Plain BSTs can degenerate into a chain when keys arrive in sorted order, making operations O(n). Balanced BSTs restructure the tree (rotations plus per-node metadata) so that its height stays logarithmic, giving O(log n) search, insert, and delete regardless of input order.
BST invariant: for every node with key $x$, all keys in its left subtree are $< x$ and all keys in its right subtree are $> x$. This is the property that makes an inorder traversal visit the keys in sorted order.
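AVL invariant: at every node the balance factor $\mathrm{BF} = h(\text{left}) - h(\text{right})$ is in $\{-1, 0, +1\}$ (an empty subtree has height 0).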
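Red–black invariants: treat every missing child as a special NIL leaf (black, no key). Then (1) every node is red or black, (2) the root is black, (3) a red node has only black children, and (4) every path from a node down to a NIL leaf contains the same number of black nodes.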
| Tree | Height bound | Search | Insert | Delete | Notes |
|---|---|---|---|---|---|
| AVL | < 1.44·log₂(n+2) | O(log n) | O(log n), ≤ 2 rotations (one single or double rotation) | O(log n), rotations may cascade up to the root | Best lookups; more rebalancing |
| RBT | ≤ 2·log₂(n+1) | O(log n) | O(log n), ≤ 2 rotations + recolors | O(log n), ≤ 3 rotations + recolors | Great all-rounder; standard library choice |
A rotation is a local, O(1) pointer rearrangement that swaps the parent/child roles of two nodes while preserving the inorder (sorted) order of the keys.
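AVL rebalancing: after a BST insert or delete, update heights on the way back up; wherever BF becomes 2 or −2, rotate using the LL/LR/RR/RL rules.

Red–black insert: insert z as a red node and update subtree sizes along the search path. While z's parent is red: if z's uncle is red, recolor (parent and uncle black, grandparent red) and set z ← grandparent; otherwise rotate an inner child into the outer position, recolor the parent black and the grandparent red, and rotate at the grandparent. Finally color the root black. The fix-up walks up at most once, so it costs O(h).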
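Order statistics: store $x.\text{size} = 1 + \text{size}(x.\text{left}) + \text{size}(x.\text{right})$ at each node. To find the k-th smallest key, compare k to size(left)+1 to go left/right in O(h): if equal, the current node is the answer; if smaller, go left; if larger, subtract size(left)+1 from k and go right.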
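Implementation notes: after every rotation, call upd() on the affected nodes (the node that moved down first, then the new subtree root) to refresh height/size. Treat null as black in helpers, or use a shared black NIL sentinel. Augmentation (subtree_size): the AVL code below keeps height & size in every node; the red–black code keeps size (plus color and a parent pointer).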
class Node:
    """AVL node carrying its key plus cached subtree height ``h`` and ``size``."""

    __slots__ = ("key", "left", "right", "h", "size")

    def __init__(self, k):
        self.key = k
        self.left = None
        self.right = None
        self.h = 1      # a lone node has height 1 (an empty subtree has 0)
        self.size = 1   # number of nodes in this subtree


def height(u):
    """Height of the subtree rooted at ``u``; 0 for an empty subtree."""
    if u is None:
        return 0
    return u.h


def sz(u):
    """Number of nodes in the subtree rooted at ``u``; 0 for an empty subtree."""
    return 0 if u is None else u.size


def upd(u):
    """Refresh ``u.h`` and ``u.size`` from its children (which must be current)."""
    lo, hi = u.left, u.right
    u.h = max(height(lo), height(hi)) + 1
    u.size = sz(lo) + sz(hi) + 1


def rotL(x):
    """Rotate left at ``x``; its right child becomes the new subtree root."""
    top = x.right
    x.right, top.left = top.left, x
    upd(x)      # x is now below ``top``: refresh it first
    upd(top)
    return top


def rotR(y):
    """Rotate right at ``y``; its left child becomes the new subtree root."""
    top = y.left
    y.left, top.right = top.right, y
    upd(y)
    upd(top)
    return top


def balance(u):
    """Balance factor: height(left) - height(right)."""
    return height(u.left) - height(u.right)


def insert(u, x):
    """Insert key ``x`` into the subtree ``u`` and return the new subtree root.

    Duplicate keys are ignored.
    """
    if u is None:
        return Node(x)
    if x == u.key:
        return u
    if x < u.key:
        u.left = insert(u.left, x)
    else:
        u.right = insert(u.right, x)
    upd(u)
    b = balance(u)
    if b > 1:
        if x < u.left.key:             # LL: single right rotation
            return rotR(u)
        if x > u.left.key:             # LR: rotate child left, then right
            u.left = rotL(u.left)
            return rotR(u)
    elif b < -1:
        if x > u.right.key:            # RR: single left rotation
            return rotL(u)
        if x < u.right.key:            # RL: rotate child right, then left
            u.right = rotR(u.right)
            return rotL(u)
    return u


def minNode(u):
    """Leftmost (smallest-key) node of the non-empty subtree ``u``."""
    while u.left is not None:
        u = u.left
    return u


def delete(u, x):
    """Remove key ``x`` from subtree ``u`` (if present); return the new subtree root."""
    if u is None:
        return None
    if x < u.key:
        u.left = delete(u.left, x)
    elif x > u.key:
        u.right = delete(u.right, x)
    else:
        if u.left is None:
            return u.right
        if u.right is None:
            return u.left
        # Two children: copy the inorder successor's key, then delete the successor.
        successor = minNode(u.right)
        u.key = successor.key
        u.right = delete(u.right, successor.key)
    upd(u)
    b = balance(u)
    if b > 1:
        if balance(u.left) < 0:        # LR shape: straighten it first
            u.left = rotL(u.left)
        return rotR(u)
    if b < -1:
        if balance(u.right) > 0:       # RL shape: straighten it first
            u.right = rotR(u.right)
        return rotL(u)
    return u


def kth(u, k):
    """Return the node holding the k-th smallest key (1-based), or None if out of range."""
    while u is not None and not (k < 1 or k > u.size):
        left = sz(u.left)
        if k == left + 1:
            return u
        if k < left + 1:
            u = u.left
        else:
            k = k - left - 1
            u = u.right
    return None
# Colors: 'R'/'B'. Treat None as black in helpers (None plays the black NIL leaf).
# Every node also caches `size` (nodes in its subtree) for O(h) k-th queries.
class Node:
    __slots__ = ("key", "left", "right", "parent", "color", "size")

    def __init__(self, k, c='R', p=None):
        self.key = k
        self.left = self.right = None
        self.parent = p
        self.color = c
        self.size = 1


def color(u):
    """Color of ``u``; a missing child (None) counts as a black NIL leaf."""
    return 'B' if u is None else u.color


def sz(u):
    """Subtree size of ``u`` (0 for None)."""
    return 0 if u is None else u.size


def upd(u):
    """Recompute ``u.size`` from its children's sizes (no-op for None)."""
    if u:
        u.size = 1 + sz(u.left) + sz(u.right)


def rotL(root, x):
    """Rotate left around ``x`` (``x.right`` must exist); return the possibly new root."""
    y = x.right
    x.right = y.left
    if y.left:
        y.left.parent = x
    y.parent = x.parent
    if x.parent is None:
        root = y
    elif x is x.parent.left:
        x.parent.left = y
    else:
        x.parent.right = y
    y.left = x
    x.parent = y
    upd(x)  # x now sits below y: refresh it first
    upd(y)
    return root


def rotR(root, y):
    """Rotate right around ``y`` (``y.left`` must exist); return the possibly new root."""
    x = y.left
    y.left = x.right
    if x.right:
        x.right.parent = y
    x.parent = y.parent
    if y.parent is None:
        root = x
    elif y is y.parent.left:
        y.parent.left = x
    else:
        y.parent.right = x
    x.right = y
    y.parent = x
    upd(y)
    upd(x)
    return root


def insert(root, k):
    """Insert key ``k`` (duplicates go to the right subtree); return the new root."""
    # Plain BST descent; every node on the path gains exactly one descendant.
    p = None
    u = root
    while u:
        p = u
        u.size += 1
        u = u.left if k < u.key else u.right
    u = Node(k, 'R', p)
    if p is None:
        root = u
    elif k < p.key:
        p.left = u
    else:
        p.right = u
    # Fix-up: resolve a red node with a red parent.
    while u.parent and u.parent.color == 'R':
        g = u.parent.parent  # exists: a red parent is never the (black) root
        if u.parent is g.left:
            y = g.right  # uncle
            if color(y) == 'R':  # case 1: recolor and move the problem up
                u.parent.color = 'B'
                y.color = 'B'
                g.color = 'R'
                u = g
            else:
                if u is u.parent.right:  # case 2: rotate into case 3
                    u = u.parent
                    root = rotL(root, u)
                u.parent.color = 'B'  # case 3
                g.color = 'R'
                root = rotR(root, g)
        else:
            y = g.left
            if color(y) == 'R':
                u.parent.color = 'B'
                y.color = 'B'
                g.color = 'R'
                u = g
            else:
                if u is u.parent.left:
                    u = u.parent
                    root = rotR(root, u)
                u.parent.color = 'B'
                g.color = 'R'
                root = rotL(root, g)
    root.color = 'B'
    return root


def transplant(root, u, v):
    """Put subtree ``v`` (may be None) in ``u``'s place; return the possibly new root."""
    if u.parent is None:
        root = v
    elif u is u.parent.left:
        u.parent.left = v
    else:
        u.parent.right = v
    if v:
        v.parent = u.parent
    return root


def minimum(u):
    """Leftmost node of the non-empty subtree ``u``."""
    while u.left:
        u = u.left
    return u


def delete(root, k):
    """Remove one node with key ``k`` and return the new root.

    A missing key leaves the tree, including all cached sizes, untouched.
    """
    # Pure search first: sizes are only touched once we know a node is removed.
    u = root
    while u is not None and u.key != k:
        u = u.left if k < u.key else u.right
    if u is None:
        return root

    orig = u.color
    # x takes the removed position; xp is its parent. xp is tracked explicitly
    # because x may be None (a NIL leaf), which has no .parent attribute.
    if u.left is None:
        x, xp = u.right, u.parent
        root = transplant(root, u, u.right)
    elif u.right is None:
        x, xp = u.left, u.parent
        root = transplant(root, u, u.left)
    else:
        y = minimum(u.right)
        orig = y.color
        x = y.right
        if y.parent is u:
            xp = y
        else:
            xp = y.parent
            root = transplant(root, y, y.right)
            y.right = u.right
            y.right.parent = y
        root = transplant(root, u, y)
        y.left = u.left
        y.left.parent = y
        y.color = u.color

    # Every subtree whose node count changed lies on the path xp -> root.
    t = xp
    while t is not None:
        upd(t)
        t = t.parent

    if orig == 'B':
        root = _delete_fixup(root, x, xp)
    return root


def _delete_fixup(root, x, xp):
    """Restore red-black properties after a black node was removed above ``x``."""
    while x is not root and color(x) == 'B':
        # x is None only while it fills a NIL slot of xp; its sibling is then
        # non-None (it must carry the missing black height), so this test is exact.
        if x is xp.left:
            w = xp.right
            if color(w) == 'R':
                w.color = 'B'
                xp.color = 'R'
                root = rotL(root, xp)
                w = xp.right
            if color(w.left) == 'B' and color(w.right) == 'B':
                w.color = 'R'
                x, xp = xp, xp.parent
            else:
                if color(w.right) == 'B':
                    if w.left:
                        w.left.color = 'B'
                    w.color = 'R'
                    root = rotR(root, w)
                    w = xp.right
                w.color = xp.color
                xp.color = 'B'
                if w.right:
                    w.right.color = 'B'
                root = rotL(root, xp)
                x = root
        else:
            w = xp.left
            if color(w) == 'R':
                w.color = 'B'
                xp.color = 'R'
                root = rotR(root, xp)
                w = xp.left
            if color(w.left) == 'B' and color(w.right) == 'B':
                w.color = 'R'
                x, xp = xp, xp.parent
            else:
                if color(w.left) == 'B':
                    if w.right:
                        w.right.color = 'B'
                    w.color = 'R'
                    root = rotL(root, w)
                    w = xp.left
                w.color = xp.color
                xp.color = 'B'
                if w.left:
                    w.left.color = 'B'
                root = rotR(root, xp)
                x = root
    if x is not None:
        x.color = 'B'
    return root


def kth(u, k):
    """Node holding the k-th smallest key (1-based), or None if out of range."""
    while u:
        left = sz(u.left)
        if k == left + 1:
            return u
        if k < left + 1:
            u = u.left
        else:
            k -= left + 1
            u = u.right
    return None
sortedcontainers
# sortedcontainers (not a tree)
# pip install sortedcontainers
from sortedcontainers import SortedList

S = SortedList()
S.add(10)         # O(log n) amortized
S.discard(7)      # removing an absent value is a no-op
x = 25 in S       # O(log n)
k = 1             # 1-based rank of the element to fetch
kth = S[k - 1]    # O(log n) positional lookup; IndexError if k is out of range
# Note: SortedList keeps a list of sorted sublists (B-tree-like in spirit),
# not an AVL/RBT.
/** AVL tree on int keys; every node also caches its subtree height and size. */
class Avl {
    static class Node {
        int key, h = 1, size = 1;
        Node l, r;
        Node(int k) { key = k; }
    }

    /** Height of u; an empty subtree has height 0. */
    static int H(Node u) {
        if (u == null) return 0;
        return u.h;
    }

    /** Number of nodes in the subtree rooted at u (0 when empty). */
    static int S(Node u) {
        if (u == null) return 0;
        return u.size;
    }

    /** Refresh u.h and u.size from its children, which must already be current. */
    static void upd(Node u) {
        int hl = H(u.l), hr = H(u.r);
        u.h = (hl >= hr ? hl : hr) + 1;
        u.size = S(u.l) + S(u.r) + 1;
    }

    /** Left rotation: x's right child becomes the subtree root. */
    static Node rotL(Node x) {
        Node top = x.r;
        x.r = top.l;
        top.l = x;
        upd(x);
        upd(top);
        return top;
    }

    /** Right rotation: y's left child becomes the subtree root. */
    static Node rotR(Node y) {
        Node top = y.l;
        y.l = top.r;
        top.r = y;
        upd(y);
        upd(top);
        return top;
    }

    /** Balance factor height(left) - height(right). */
    static int bal(Node u) { return H(u.l) - H(u.r); }

    /** Insert k (duplicates ignored); returns the new subtree root. */
    static Node insert(Node u, int k) {
        if (u == null) return new Node(k);
        if (k == u.key) return u;
        if (k < u.key) u.l = insert(u.l, k);
        else u.r = insert(u.r, k);
        upd(u);
        int b = bal(u);
        if (b > 1) {
            if (k < u.l.key) return rotR(u);                       // LL
            if (k > u.l.key) { u.l = rotL(u.l); return rotR(u); }  // LR
        } else if (b < -1) {
            if (k > u.r.key) return rotL(u);                       // RR
            if (k < u.r.key) { u.r = rotR(u.r); return rotL(u); }  // RL
        }
        return u;
    }

    /** Leftmost node of a non-empty subtree. */
    static Node min(Node u) {
        Node cur = u;
        while (cur.l != null) cur = cur.l;
        return cur;
    }

    /** Remove k if present; returns the new subtree root. */
    static Node delete(Node u, int k) {
        if (u == null) return null;
        if (k < u.key) {
            u.l = delete(u.l, k);
        } else if (k > u.key) {
            u.r = delete(u.r, k);
        } else if (u.l == null || u.r == null) {
            return u.l == null ? u.r : u.l;
        } else {
            // Two children: take the inorder successor's key, then remove the successor.
            u.key = min(u.r).key;
            u.r = delete(u.r, u.key);
        }
        upd(u);
        int b = bal(u);
        if (b > 1) {
            if (bal(u.l) < 0) u.l = rotL(u.l);
            return rotR(u);
        }
        if (b < -1) {
            if (bal(u.r) > 0) u.r = rotR(u.r);
            return rotL(u);
        }
        return u;
    }

    /** Node with the k-th smallest key (1-based), or null when out of range. */
    static Node kth(Node u, int k) {
        if (u == null) return null;
        int left = S(u.l);
        if (k == left + 1) return u;
        return k <= left ? kth(u.l, k) : kth(u.r, k - left - 1);
    }
}
/**
 * Red-black tree on int keys with subtree sizes. Null children act as black
 * NIL leaves; duplicate keys are inserted into the right subtree.
 */
class Rb {
    static final boolean R = true, B = false;

    static class Node {
        int key, size = 1;
        boolean c = R;
        Node l, r, p;
        Node(int k) { key = k; }
    }

    /** Subtree size of u (0 for null). */
    static int S(Node u) { return u == null ? 0 : u.size; }

    /** Recompute u.size from its children (no-op for null). */
    static void upd(Node u) { if (u != null) u.size = 1 + S(u.l) + S(u.r); }

    /** Left rotation around x (x.r must be non-null); returns the possibly new root. */
    static Node rotL(Node root, Node x) {
        Node y = x.r;
        x.r = y.l;
        if (y.l != null) y.l.p = x;
        y.p = x.p;
        if (x.p == null) root = y;
        else if (x == x.p.l) x.p.l = y;
        else x.p.r = y;
        y.l = x;
        x.p = y;
        upd(x); // x now sits below y: refresh it first
        upd(y);
        return root;
    }

    /** Right rotation around y (y.l must be non-null); returns the possibly new root. */
    static Node rotR(Node root, Node y) {
        Node x = y.l;
        y.l = x.r;
        if (x.r != null) x.r.p = y;
        x.p = y.p;
        if (y.p == null) root = x;
        else if (y == y.p.l) y.p.l = x;
        else y.p.r = x;
        x.r = y;
        y.p = x;
        upd(y);
        upd(x);
        return root;
    }

    /** Color of u; null counts as a black NIL leaf. */
    static boolean color(Node u) { return u == null ? B : u.c; }

    /** Insert k and return the new root. */
    static Node insert(Node root, int k) {
        // Plain BST descent; each node on the path gains exactly one descendant.
        Node p = null, u = root;
        while (u != null) {
            p = u;
            u.size++;
            u = (k < u.key) ? u.l : u.r;
        }
        u = new Node(k);
        u.p = p;
        u.c = R;
        if (p == null) root = u;
        else if (k < p.key) p.l = u;
        else p.r = u;
        // Fix-up: resolve a red node with a red parent.
        while (u.p != null && u.p.c == R) {
            Node g = u.p.p; // non-null: a red parent is never the (black) root
            if (u.p == g.l) {
                Node y = g.r; // uncle
                if (color(y) == R) {
                    u.p.c = B; y.c = B; g.c = R; u = g;
                } else {
                    if (u == u.p.r) { u = u.p; root = rotL(root, u); }
                    u.p.c = B; g.c = R; root = rotR(root, g);
                }
            } else {
                Node y = g.l;
                if (color(y) == R) {
                    u.p.c = B; y.c = B; g.c = R; u = g;
                } else {
                    if (u == u.p.l) { u = u.p; root = rotR(root, u); }
                    u.p.c = B; g.c = R; root = rotL(root, g);
                }
            }
        }
        root.c = B;
        return root;
    }

    /** Put subtree v (may be null) in u's place; returns the possibly new root. */
    static Node transplant(Node root, Node u, Node v) {
        if (u.p == null) root = v;
        else if (u == u.p.l) u.p.l = v;
        else u.p.r = v;
        if (v != null) v.p = u.p;
        return root;
    }

    /** Leftmost node of a non-empty subtree. */
    static Node min(Node u) {
        while (u.l != null) u = u.l;
        return u;
    }

    /**
     * Remove one node with key k and return the new root. A missing key leaves
     * the tree (including cached sizes) untouched.
     */
    static Node delete(Node root, int k) {
        // Pure search: sizes are only updated once a node is actually removed.
        Node u = root;
        while (u != null && u.key != k) u = (k < u.key) ? u.l : u.r;
        if (u == null) return root;

        boolean orig = u.c;
        // x takes the removed position; xp is its parent, tracked explicitly
        // because x may be null (a NIL leaf) and then has no parent pointer.
        Node x, xp;
        if (u.l == null) {
            x = u.r; xp = u.p;
            root = transplant(root, u, u.r);
        } else if (u.r == null) {
            x = u.l; xp = u.p;
            root = transplant(root, u, u.l);
        } else {
            Node y = min(u.r);
            orig = y.c;
            x = y.r;
            if (y.p == u) {
                xp = y;
            } else {
                xp = y.p;
                root = transplant(root, y, y.r);
                y.r = u.r;
                y.r.p = y;
            }
            root = transplant(root, u, y);
            y.l = u.l;
            y.l.p = y;
            y.c = u.c;
        }
        // Every subtree whose node count changed lies on the path xp -> root.
        for (Node t = xp; t != null; t = t.p) upd(t);
        if (orig == B) root = deleteFixup(root, x, xp);
        return root;
    }

    /** Restore red-black properties after a black node was removed above x. */
    private static Node deleteFixup(Node root, Node x, Node xp) {
        while (x != root && color(x) == B) {
            // A null x fills a NIL slot of xp; its sibling then carries the missing
            // black height and is non-null, so this comparison is exact.
            if (x == xp.l) {
                Node w = xp.r;
                if (color(w) == R) { w.c = B; xp.c = R; root = rotL(root, xp); w = xp.r; }
                if (color(w.l) == B && color(w.r) == B) {
                    w.c = R;
                    x = xp;
                    xp = x.p;
                } else {
                    if (color(w.r) == B) {
                        if (w.l != null) w.l.c = B;
                        w.c = R;
                        root = rotR(root, w);
                        w = xp.r;
                    }
                    w.c = xp.c;
                    xp.c = B;
                    if (w.r != null) w.r.c = B;
                    root = rotL(root, xp);
                    x = root;
                }
            } else {
                Node w = xp.l;
                if (color(w) == R) { w.c = B; xp.c = R; root = rotR(root, xp); w = xp.l; }
                if (color(w.l) == B && color(w.r) == B) {
                    w.c = R;
                    x = xp;
                    xp = x.p;
                } else {
                    if (color(w.l) == B) {
                        if (w.r != null) w.r.c = B;
                        w.c = R;
                        root = rotL(root, w);
                        w = xp.l;
                    }
                    w.c = xp.c;
                    xp.c = B;
                    if (w.l != null) w.l.c = B;
                    root = rotR(root, xp);
                    x = root;
                }
            }
        }
        if (x != null) x.c = B;
        return root;
    }

    /** Node with the k-th smallest key (1-based), or null when out of range. */
    static Node kth(Node u, int k) {
        while (u != null) {
            int left = S(u.l);
            if (k == left + 1) return u;
            if (k <= left) u = u.l;
            else { k -= left + 1; u = u.r; }
        }
        return null;
    }
}
TreeSet / TreeMap
// Java TreeSet / TreeMap (red–black)
import java.util.*;

TreeSet<Integer> set = new TreeSet<>();
set.add(10);
set.remove(7);                    // removing an absent element is a no-op
boolean has = set.contains(25);   // O(log n) lookup
Integer lo = set.first();         // smallest element (set must be non-empty)
Integer hi = set.last();          // largest element (set must be non-empty)
// Inclusive range view [10, 20], backed by the set itself:
NavigableSet<Integer> sub = set.subSet(10, true, 20, true);
// Note: Java's TreeSet/TreeMap are Red–Black trees; they don't expose k-th directly.
/// AVL tree node for int keys, augmented with subtree height `h` and subtree `size`.
struct Node {
    int key, h = 1, size = 1;
    Node *l = nullptr, *r = nullptr;
    Node(int k) : key(k) {}
};

/// Height of `u`; an empty subtree has height 0.
int H(Node* u) {
    if (u == nullptr) return 0;
    return u->h;
}

/// Number of nodes in the subtree rooted at `u` (0 when empty).
int S(Node* u) {
    if (u == nullptr) return 0;
    return u->size;
}

/// Refresh `h` and `size` of `u` from its children; no-op on nullptr.
void upd(Node* u) {
    if (u == nullptr) return;
    u->h = std::max(H(u->l), H(u->r)) + 1;
    u->size = S(u->l) + S(u->r) + 1;
}

/// Left rotation: `x`'s right child becomes the subtree root (returned).
Node* rotL(Node* x) {
    Node* const top = x->r;
    x->r = top->l;
    top->l = x;
    upd(x);  // x is now below `top`: refresh it first
    upd(top);
    return top;
}

/// Right rotation: `y`'s left child becomes the subtree root (returned).
Node* rotR(Node* y) {
    Node* const top = y->l;
    y->l = top->r;
    top->r = y;
    upd(y);
    upd(top);
    return top;
}

/// Balance factor height(left) - height(right).
int bal(Node* u) { return H(u->l) - H(u->r); }

/// Insert `k` (duplicates ignored) and return the new subtree root.
/// New nodes come from `new`; release them with `erase`.
Node* insert(Node* u, int k) {
    if (u == nullptr) return new Node(k);
    if (k == u->key) return u;
    Node*& child = (k < u->key) ? u->l : u->r;
    child = insert(child, k);
    upd(u);
    const int b = bal(u);
    if (b > 1) {
        if (k < u->l->key) return rotR(u);                            // LL
        if (k > u->l->key) { u->l = rotL(u->l); return rotR(u); }     // LR
    } else if (b < -1) {
        if (k > u->r->key) return rotL(u);                            // RR
        if (k < u->r->key) { u->r = rotR(u->r); return rotL(u); }     // RL
    }
    return u;
}

/// Leftmost node of a non-empty subtree.
Node* minNode(Node* u) {
    for (; u->l != nullptr; u = u->l) {
    }
    return u;
}

/// Remove `k` if present (freeing its node) and return the new subtree root.
Node* erase(Node* u, int k) {
    if (u == nullptr) return nullptr;
    if (k < u->key) {
        u->l = erase(u->l, k);
    } else if (k > u->key) {
        u->r = erase(u->r, k);
    } else if (u->l == nullptr || u->r == nullptr) {
        Node* const survivor = u->l ? u->l : u->r;
        delete u;
        return survivor;
    } else {
        // Two children: adopt the inorder successor's key, then erase the successor.
        u->key = minNode(u->r)->key;
        u->r = erase(u->r, u->key);
    }
    upd(u);
    const int b = bal(u);
    if (b > 1) {
        if (bal(u->l) < 0) u->l = rotL(u->l);
        return rotR(u);
    }
    if (b < -1) {
        if (bal(u->r) > 0) u->r = rotR(u->r);
        return rotL(u);
    }
    return u;
}

/// Node holding the k-th smallest key (1-based), or nullptr when out of range.
Node* kth(Node* u, int k) {
    if (u == nullptr) return nullptr;
    const int left = S(u->l);
    if (k == left + 1) return u;
    return k <= left ? kth(u->l, k) : kth(u->r, k - left - 1);
}
std::set / std::map
// std::set / std::map (RBT) + PBDS order-statistics
#include <cstddef>
#include <functional>
#include <map>
#include <set>
// GNU PBDS (non-standard, libstdc++ only, but handy in contests). Both headers are
// required: tree_policy.hpp is what declares tree_order_statistics_node_update.
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>

/// Order-statistics set: a PBDS red-black tree (rb_tree_tag) whose node-update
/// policy adds find_by_order (0-indexed k-th) and order_of_key (# keys < x).
template <class T>
using ordered_set =
    __gnu_pbds::tree<T, __gnu_pbds::null_type, std::less<T>, __gnu_pbds::rb_tree_tag,
                     __gnu_pbds::tree_order_statistics_node_update>;

/// Usage tour of std::set and the PBDS order-statistics tree. Kept inside a
/// function: statements are not allowed at namespace scope, and locals here
/// cannot clash with the global S()/kth() helpers of the AVL code.
inline void ordered_containers_demo() {
    std::set<int> s;
    s.insert(10);
    s.erase(7);                                    // erasing an absent key is a no-op
    [[maybe_unused]] const bool has = s.count(25) != 0;
    [[maybe_unused]] const int mn = *s.begin();    // begin()/rbegin(): set must be non-empty
    [[maybe_unused]] const int mx = *s.rbegin();

    ordered_set<int> t;
    t.insert(10);
    t.insert(20);
    [[maybe_unused]] const int kth = *t.find_by_order(0);          // 0-indexed k-th
    [[maybe_unused]] const std::size_t rank = t.order_of_key(15);  // #elements < 15
}
Exercise: find the k-th smallest key using subtree_size (O(h)) and compare it with simply stepping through an inorder traversal (O(h + k)).