Plain BSTs can degenerate into a chain when keys arrive in sorted order, making operations O(n). Balanced BSTs restructure the tree (rotations plus per-node metadata) so the height stays logarithmic—giving O(log n) search, insert, and delete regardless of input order.
BST property: for every node with key x, all keys in its left subtree are < x and all keys in its right subtree are > x. This property is what makes an inorder traversal visit the keys in sorted order.
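AVL invariant: at every node, the balance factor BF = height(left) − height(right) is in {−1, 0, +1}. Red–black invariants: every node is red or black, the root is black, a red node never has a red child, and every path from a node down to a NIL leaf passes through the same number of black nodes. Treat every missing child as a special NIL leaf (black, no key).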
| Tree | Height bound | Search | Insert | Delete | Notes |
|---|---|---|---|---|---|
| AVL | ≤ ~1.44·log₂(n) | O(log n) | O(log n), ≤2 rotations | O(log n), rotations may cascade up to the root | Best lookups; more rebalancing |
| RBT | ≤ 2·log₂(n+1) | O(log n) | O(log n), ≤2 rotations + recolors | O(log n), ≤3 rotations + recolors | Great all-rounder; the usual standard-library choice (e.g. C++ std::map, Java TreeMap) |
A rotation is a local pointer rearrangement that swaps the parent/child roles of two nodes while preserving the inorder (sorted) order of the keys.
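AVL rebalancing: after a BST insert/delete, recompute height and size on the way back up; wherever BF is 2 or −2, rotate using the LL/LR/RR/RL rules.
Red–black insert: BST-insert the new node z, color z red, and update sizes along the path. While parent(z) is red: if the uncle is red, recolor parent and uncle black and the grandparent red, then set z ← grandparent; otherwise rotate z into line with its parent if needed, recolor, and rotate at the grandparent (at most 2 rotations; total work O(h)). Order statistics: store size = 1 + size(left) + size(right) at each node.
To find the k-th smallest key, compare k to size(left)+1 to decide whether to stop or go left/right, in O(h). Pitfalls: after every rotation, call upd() on the affected nodes (height/size), lower node first; in red–black code, treat null as black in helpers or use a shared black NIL sentinel; library trees usually don't store a subtree size (subtree_size), so augment the node yourself with height & size.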
class Node:
    """AVL tree node: key, child links, cached subtree height ``h`` and subtree ``size``."""
    __slots__ = ("key", "left", "right", "h", "size")

    def __init__(self, k):
        # A fresh node is a leaf: no children, height 1, one key in its subtree.
        self.key = k
        self.left = None
        self.right = None
        self.h = 1
        self.size = 1
def height(u):
    """Cached height of subtree ``u``; an empty subtree (None) has height 0."""
    if u is None:
        return 0
    return u.h
def sz(u):
    """Number of keys in subtree ``u``; an empty subtree (None) has size 0."""
    return 0 if u is None else u.size
def upd(u):
    """Refresh the cached height ``h`` and ``size`` of non-None node ``u`` from its children."""
    left, right = u.left, u.right
    u.h = max(height(left), height(right)) + 1
    u.size = sz(left) + sz(right) + 1
def rotL(x):
    """Left-rotate around ``x``: its right child becomes the subtree root, which is returned.

    ``x`` is refreshed before the new root because it is now the lower node.
    """
    new_root = x.right
    x.right = new_root.left   # the inner subtree crosses over to x
    new_root.left = x
    upd(x)
    upd(new_root)
    return new_root
def rotR(y):
    """Right-rotate around ``y``: its left child becomes the subtree root, which is returned.

    ``y`` is refreshed before the new root because it is now the lower node.
    """
    new_root = y.left
    y.left = new_root.right   # the inner subtree crosses over to y
    new_root.right = y
    upd(y)
    upd(new_root)
    return new_root
def balance(u):
    """AVL balance factor of node ``u``: height(left) - height(right)."""
    left_h = height(u.left)
    right_h = height(u.right)
    return left_h - right_h
def insert(u, x):
    """Insert key ``x`` into the AVL subtree rooted at ``u``; return the new subtree root.

    A key already present is ignored and the subtree is returned unchanged.
    """
    if not u:
        return Node(x)
    if x == u.key:
        return u
    if x < u.key:
        u.left = insert(u.left, x)
    else:
        u.right = insert(u.right, x)
    upd(u)
    b = balance(u)
    if b > 1:
        # Left-heavy: an outer (LL) insert needs one rotation, an inner (LR) one needs two.
        if x < u.left.key:
            return rotR(u)                 # LL
        if x > u.left.key:
            u.left = rotL(u.left)          # LR
            return rotR(u)
    elif b < -1:
        if x > u.right.key:
            return rotL(u)                 # RR
        if x < u.right.key:
            u.right = rotR(u.right)        # RL
            return rotL(u)
    return u
def minNode(u):
    """Leftmost (minimum-key) node of the non-empty subtree rooted at ``u``."""
    node = u
    while node.left:
        node = node.left
    return node
def delete(u, x):
    """Remove key ``x`` from the AVL subtree rooted at ``u``; return the new subtree root.

    A missing key leaves the subtree unchanged. A node with two children takes its
    inorder successor's key, and that successor is then removed from the right subtree.
    """
    if not u:
        return None
    if x < u.key:
        u.left = delete(u.left, x)
    elif x > u.key:
        u.right = delete(u.right, x)
    else:
        if not u.left or not u.right:
            # At most one child: splice this node out.
            return u.right if not u.left else u.left
        successor = minNode(u.right)
        u.key = successor.key
        u.right = delete(u.right, successor.key)
    upd(u)
    b = balance(u)
    if b > 1:
        # Left-heavy; a right-leaning left child is the LR case, so straighten it first.
        if balance(u.left) < 0:
            u.left = rotL(u.left)
        return rotR(u)
    if b < -1:
        if balance(u.right) > 0:
            u.right = rotR(u.right)
        return rotL(u)
    return u
def kth(u, k):
    """Return the node holding the k-th smallest key (1-based) under ``u``, or None if k is out of range."""
    node = u
    # Same range check the recursive form performs at every level.
    while node and 1 <= k <= sz(node):
        rank = sz(node.left) + 1   # position of `node` within its own subtree
        if k == rank:
            return node
        if k < rank:
            node = node.left
        else:
            k -= rank
            node = node.right
    return None
# Colors: 'R'/'B'. Treat None as black in helpers.
class Node:
    """Red–black tree node: key, child links, parent link, color ('R'/'B') and subtree size."""
    __slots__ = ("key", "left", "right", "parent", "color", "size")

    def __init__(self, k, c='R', p=None):
        # Defaults: red, no parent; a new node is a childless leaf of size 1.
        self.key = k
        self.left = None
        self.right = None
        self.parent = p
        self.color = c
        self.size = 1
def color(u):
    """Color of ``u``; a missing child (None) counts as a black NIL leaf."""
    if u is None:
        return 'B'
    return u.color
def sz(u):
    """Number of keys in subtree ``u``; 0 for an empty subtree (None)."""
    if u is None:
        return 0
    return u.size
def upd(u):
    """Recompute ``u.size`` from its children; does nothing when ``u`` is None."""
    if not u:
        return
    u.size = sz(u.left) + sz(u.right) + 1
def rotL(root, x):
    """Left-rotate around ``x`` (its right child rises); return the possibly new tree root."""
    pivot = x.right
    inner = pivot.left
    parent = x.parent
    # Pivot's left subtree moves under x.
    x.right = inner
    if inner:
        inner.parent = x
    # Hook the pivot into x's former position.
    pivot.parent = parent
    if parent is None:
        root = pivot
    elif parent.left == x:
        parent.left = pivot
    else:
        parent.right = pivot
    pivot.left = x
    x.parent = pivot
    # Lower node first, then the new subtree root.
    upd(x)
    upd(pivot)
    return root
def rotR(root, y):
    """Right-rotate around ``y`` (its left child rises); return the possibly new tree root."""
    pivot = y.left
    inner = pivot.right
    parent = y.parent
    # Pivot's right subtree moves under y.
    y.left = inner
    if inner:
        inner.parent = y
    # Hook the pivot into y's former position.
    pivot.parent = parent
    if parent is None:
        root = pivot
    elif parent.left == y:
        parent.left = pivot
    else:
        parent.right = pivot
    pivot.right = y
    y.parent = pivot
    # Lower node first, then the new subtree root.
    upd(y)
    upd(pivot)
    return root
def insert(root, k):
    """Insert key ``k`` into the red–black tree rooted at ``root``; return the (possibly new) root.

    Equal keys are not rejected: they descend into the right subtree. Sizes on the
    search path are incremented on the way down; rotations refresh the rest.
    """
    # 1) Plain BST descent, growing every visited subtree's size by one.
    parent, cur = None, root
    while cur:
        parent = cur
        cur.size += 1
        cur = cur.left if k < cur.key else cur.right
    z = Node(k, 'R', parent)
    if parent is None:
        root = z
    elif k < parent.key:
        parent.left = z
    else:
        parent.right = z

    # 2) Fix-up: while z and its parent are both red, recolor or rotate.
    while z.parent and z.parent.color == 'R':
        grand = z.parent.parent
        if z.parent == grand.left:
            uncle = grand.right
            if color(uncle) == 'R':
                # Case 1: red uncle -> push the red up to the grandparent.
                z.parent.color = 'B'
                uncle.color = 'B'
                grand.color = 'R'
                z = grand
            else:
                if z == z.parent.right:
                    # Case 2: inner child -> rotate it into line with its parent.
                    z = z.parent
                    root = rotL(root, z)
                # Case 3: outer child -> recolor and rotate the grandparent.
                z.parent.color = 'B'
                grand.color = 'R'
                root = rotR(root, grand)
        else:
            uncle = grand.left
            if color(uncle) == 'R':
                z.parent.color = 'B'
                uncle.color = 'B'
                grand.color = 'R'
                z = grand
            else:
                if z == z.parent.left:
                    z = z.parent
                    root = rotR(root, z)
                z.parent.color = 'B'
                grand.color = 'R'
                root = rotL(root, grand)
    root.color = 'B'
    return root
def transplant(root, u, v):
    """Hang subtree ``v`` (may be None) where ``u`` currently hangs; return the possibly new root.

    ``u``'s own child/parent links are left untouched.
    """
    parent = u.parent
    if parent is None:
        root = v
    elif parent.left == u:
        parent.left = v
    else:
        parent.right = v
    if v:
        v.parent = parent
    return root
def minimum(u):
    """Leftmost node (smallest key) of the non-empty subtree rooted at ``u``."""
    node = u
    while node.left:
        node = node.left
    return node
def delete(root, k):
    """Remove one node with key ``k`` from the red–black tree; return the (possibly new) root.

    If ``k`` is absent, the tree and every subtree size are left untouched.
    This follows the CLRS RB-DELETE scheme. Missing children are ``None`` here
    rather than a sentinel, so the node ``x`` that moves into the vacated slot may
    be ``None``. The slot's parent is therefore tracked explicitly in ``x_parent``.
    """
    # Locate the node first; sizes are only changed once removal is certain.
    u = root
    while u is not None and u.key != k:
        u = u.left if k < u.key else u.right
    if u is None:
        return root

    orig = u.color  # color of the node that physically leaves its position
    if u.left is None:
        x, x_parent = u.right, u.parent
        root = transplant(root, u, u.right)
    elif u.right is None:
        x, x_parent = u.left, u.parent
        root = transplant(root, u, u.left)
    else:
        # Two children: the inorder successor y takes u's place and color.
        y = minimum(u.right)
        orig = y.color
        x = y.right
        if y.parent is u:
            x_parent = y
        else:
            x_parent = y.parent
            root = transplant(root, y, y.right)
            y.right = u.right
            y.right.parent = y
        root = transplant(root, u, y)
        y.left = u.left
        y.left.parent = y
        y.color = u.color

    # Every node from the vacated slot's parent up to the root lost one descendant
    # (y, on that path, also adopted u's children): recompute sizes bottom-up.
    t = x_parent
    while t is not None:
        upd(t)
        t = t.parent

    if orig == 'B':
        root = _delete_fixup(root, x, x_parent)
    return root


def _delete_fixup(root, x, x_parent):
    """Discharge the extra black carried by ``x`` (possibly None) after a black node was removed.

    ``x_parent`` is the parent of ``x``'s slot. Returns the (possibly new) root.
    """
    while x is not root and color(x) == 'B':
        p = x_parent
        if x is p.left:
            w = p.right
            if color(w) == 'R':
                # Case 1: red sibling -> rotate so the sibling becomes black.
                w.color = 'B'
                p.color = 'R'
                root = rotL(root, p)
                w = p.right
            if color(w.left) == 'B' and color(w.right) == 'B':
                # Case 2: both nephews black -> recolor and move the extra black up.
                w.color = 'R'
                x, x_parent = p, p.parent
            else:
                if color(w.right) == 'B':
                    # Case 3: only the near nephew is red -> rotate it into the far slot.
                    if w.left:
                        w.left.color = 'B'
                    w.color = 'R'
                    root = rotR(root, w)
                    w = p.right
                # Case 4: far nephew red -> one rotation at p finishes the repair.
                w.color = p.color
                p.color = 'B'
                if w.right:
                    w.right.color = 'B'
                root = rotL(root, p)
                x, x_parent = root, None
        else:
            w = p.left
            if color(w) == 'R':
                w.color = 'B'
                p.color = 'R'
                root = rotR(root, p)
                w = p.left
            if color(w.left) == 'B' and color(w.right) == 'B':
                w.color = 'R'
                x, x_parent = p, p.parent
            else:
                if color(w.left) == 'B':
                    if w.right:
                        w.right.color = 'B'
                    w.color = 'R'
                    root = rotL(root, w)
                    w = p.left
                w.color = p.color
                p.color = 'B'
                if w.left:
                    w.left.color = 'B'
                root = rotR(root, p)
                x, x_parent = root, None
    if x is not None:
        x.color = 'B'
    return root
def kth(u, k):
    """Return the node with the k-th smallest key (1-based) under ``u``, or None if out of range."""
    node = u
    while node:
        rank = sz(node.left) + 1   # position of `node` inside its own subtree
        if k < rank:
            node = node.left
        elif k > rank:
            k -= rank
            node = node.right
        else:
            return node
    return None
# sortedcontainers (not a tree)
# pip install sortedcontainers
from sortedcontainers import SortedList

S = SortedList()
S.add(10)
S.discard(7)        # discard() silently ignores a missing value (remove() would raise)
x = 25 in S         # O(log n)
k = 1               # 1-based rank to look up
kth = S[k - 1]      # k-th smallest via positional indexing: O(log n), not O(1)
# Note: SortedList is a list of sorted sublists with a positional index
# (B-tree-like) under the hood, not an AVL/RBT.
/**
 * AVL tree over {@code int} keys. Each node caches its subtree height (for
 * balancing) and subtree size (for 1-based k-th smallest queries).
 * Operations take and return subtree roots; duplicate keys are ignored.
 */
class Avl {
    /** Tree node; a new node is a leaf with height 1 and size 1. */
    static class Node {
        int key, h = 1, size = 1;
        Node l, r;

        Node(int k) {
            key = k;
        }
    }

    /** Height of {@code u}, or 0 for an empty subtree. */
    static int H(Node u) {
        if (u == null) return 0;
        return u.h;
    }

    /** Number of keys in {@code u}, or 0 for an empty subtree. */
    static int S(Node u) {
        if (u == null) return 0;
        return u.size;
    }

    /** Recompute the cached height and size of non-null {@code u} from its children. */
    static void upd(Node u) {
        u.h = Math.max(H(u.l), H(u.r)) + 1;
        u.size = S(u.l) + S(u.r) + 1;
    }

    /** Left rotation around {@code x}; returns the new subtree root (x's former right child). */
    static Node rotL(Node x) {
        Node pivot = x.r;
        x.r = pivot.l;   // inner subtree crosses over to x
        pivot.l = x;
        upd(x);          // x is now the lower node, so refresh it first
        upd(pivot);
        return pivot;
    }

    /** Right rotation around {@code y}; returns the new subtree root (y's former left child). */
    static Node rotR(Node y) {
        Node pivot = y.l;
        y.l = pivot.r;   // inner subtree crosses over to y
        pivot.r = y;
        upd(y);
        upd(pivot);
        return pivot;
    }

    /** Balance factor: height(left) - height(right). */
    static int bal(Node u) {
        return H(u.l) - H(u.r);
    }

    /** Insert {@code k} into subtree {@code u}; returns the new subtree root. */
    static Node insert(Node u, int k) {
        if (u == null) return new Node(k);
        if (k == u.key) return u;
        if (k < u.key) {
            u.l = insert(u.l, k);
        } else {
            u.r = insert(u.r, k);
        }
        upd(u);
        int b = bal(u);
        if (b > 1) {
            if (k < u.l.key) return rotR(u);                          // LL
            if (k > u.l.key) { u.l = rotL(u.l); return rotR(u); }     // LR
        } else if (b < -1) {
            if (k > u.r.key) return rotL(u);                          // RR
            if (k < u.r.key) { u.r = rotR(u.r); return rotL(u); }     // RL
        }
        return u;
    }

    /** Leftmost (smallest-key) node of non-null {@code u}. */
    static Node min(Node u) {
        Node cur = u;
        while (cur.l != null) cur = cur.l;
        return cur;
    }

    /**
     * Remove {@code k} from subtree {@code u} (no-op if absent); returns the new subtree root.
     * A node with two children copies its inorder successor's key, then that successor is removed.
     */
    static Node delete(Node u, int k) {
        if (u == null) return null;
        if (k < u.key) {
            u.l = delete(u.l, k);
        } else if (k > u.key) {
            u.r = delete(u.r, k);
        } else if (u.l == null || u.r == null) {
            return (u.l == null) ? u.r : u.l;   // splice out a node with at most one child
        } else {
            Node succ = min(u.r);
            u.key = succ.key;
            u.r = delete(u.r, succ.key);
        }
        upd(u);
        int b = bal(u);
        if (b > 1) {
            if (bal(u.l) < 0) u.l = rotL(u.l);   // LR: straighten the left child first
            return rotR(u);
        }
        if (b < -1) {
            if (bal(u.r) > 0) u.r = rotR(u.r);   // RL
            return rotL(u);
        }
        return u;
    }

    /** Node with the k-th smallest key (1-based) in {@code u}, or null if out of range. */
    static Node kth(Node u, int k) {
        Node cur = u;
        while (cur != null) {
            int rank = S(cur.l) + 1;
            if (k == rank) return cur;
            if (k < rank) {
                cur = cur.l;
            } else {
                k -= rank;
                cur = cur.r;
            }
        }
        return null;
    }
}
/**
 * Red–black tree over {@code int} keys, augmented with subtree sizes.
 * Nodes keep parent pointers; a {@code null} child is treated as a black NIL leaf.
 * Every mutator returns the (possibly new) root.
 */
class Rb {
    static final boolean R = true, B = false;

    /** Tree node; new nodes are red leaves of size 1. */
    static class Node {
        int key, size = 1;
        boolean c = R;
        Node l, r, p;

        Node(int k) { key = k; }
    }

    /** Subtree size of {@code u}; 0 for null. */
    static int S(Node u) { return u == null ? 0 : u.size; }

    /** Recompute {@code u.size} from its children; no-op for null. */
    static void upd(Node u) { if (u != null) u.size = 1 + S(u.l) + S(u.r); }

    /** Left rotation around {@code x}; returns the possibly new tree root. */
    static Node rotL(Node root, Node x) {
        Node y = x.r;
        x.r = y.l;
        if (y.l != null) y.l.p = x;
        y.p = x.p;
        if (x.p == null) root = y;
        else if (x == x.p.l) x.p.l = y;
        else x.p.r = y;
        y.l = x;
        x.p = y;
        upd(x);
        upd(y);
        return root;
    }

    /** Right rotation around {@code y}; returns the possibly new tree root. */
    static Node rotR(Node root, Node y) {
        Node x = y.l;
        y.l = x.r;
        if (x.r != null) x.r.p = y;
        x.p = y.p;
        if (y.p == null) root = x;
        else if (y == y.p.l) y.p.l = x;
        else y.p.r = x;
        x.r = y;
        y.p = x;
        upd(y);
        upd(x);
        return root;
    }

    /** Color of {@code u}; null (NIL) counts as black. */
    static boolean color(Node u) { return u == null ? B : u.c; }

    /** Insert {@code k} (equal keys go to the right subtree); returns the new root. */
    static Node insert(Node root, int k) {
        Node p = null, u = root;
        while (u != null) { p = u; u.size++; u = (k < u.key) ? u.l : u.r; }
        u = new Node(k);
        u.p = p;
        u.c = R;
        if (p == null) root = u;
        else if (k < p.key) p.l = u;
        else p.r = u;
        while (u.p != null && u.p.c == R) {
            Node g = u.p.p;
            if (u.p == g.l) {
                Node y = g.r;
                if (color(y) == R) { u.p.c = B; y.c = B; g.c = R; u = g; }
                else {
                    if (u == u.p.r) { u = u.p; root = rotL(root, u); }
                    u.p.c = B; g.c = R; root = rotR(root, g);
                }
            } else {
                Node y = g.l;
                if (color(y) == R) { u.p.c = B; y.c = B; g.c = R; u = g; }
                else {
                    if (u == u.p.l) { u = u.p; root = rotR(root, u); }
                    u.p.c = B; g.c = R; root = rotL(root, g);
                }
            }
        }
        root.c = B;
        return root;
    }

    /** Hang subtree {@code v} (may be null) where {@code u} hangs; returns the possibly new root. */
    static Node transplant(Node root, Node u, Node v) {
        if (u.p == null) root = v;
        else if (u == u.p.l) u.p.l = v;
        else u.p.r = v;
        if (v != null) v.p = u.p;
        return root;
    }

    /** Leftmost node of non-null {@code u}. */
    static Node min(Node u) { while (u.l != null) u = u.l; return u; }

    /**
     * Remove one node with key {@code k}; returns the new root. An absent key leaves the
     * tree (including all sizes) unchanged. Because NIL children are {@code null}, the
     * parent of the vacated slot is tracked explicitly ({@code xParent}).
     */
    static Node delete(Node root, int k) {
        // Find first; sizes are only touched once we know a node will be removed.
        Node u = root;
        while (u != null && u.key != k) u = (k < u.key) ? u.l : u.r;
        if (u == null) return root;

        boolean orig = u.c;   // color of the node that physically leaves its position
        Node x;               // node moving into the vacated slot (may be null = NIL)
        Node xParent;         // parent of that slot
        if (u.l == null) {
            x = u.r; xParent = u.p;
            root = transplant(root, u, u.r);
        } else if (u.r == null) {
            x = u.l; xParent = u.p;
            root = transplant(root, u, u.l);
        } else {
            Node y = min(u.r);
            orig = y.c;
            x = y.r;
            if (y.p == u) {
                xParent = y;
            } else {
                xParent = y.p;
                root = transplant(root, y, y.r);
                y.r = u.r;
                y.r.p = y;
            }
            root = transplant(root, u, y);
            y.l = u.l;
            y.l.p = y;
            y.c = u.c;
        }
        // Every node from the slot's parent to the root lost one descendant: refresh bottom-up.
        for (Node t = xParent; t != null; t = t.p) upd(t);
        if (orig == B) root = deleteFixup(root, x, xParent);
        return root;
    }

    /** Discharge the extra black carried by {@code x} (possibly null) whose slot's parent is {@code xParent}. */
    private static Node deleteFixup(Node root, Node x, Node xParent) {
        while (x != root && color(x) == B) {
            Node p = xParent;
            if (x == p.l) {
                Node w = p.r;
                if (color(w) == R) { w.c = B; p.c = R; root = rotL(root, p); w = p.r; }
                if (color(w.l) == B && color(w.r) == B) {
                    w.c = R;
                    x = p;
                    xParent = p.p;
                } else {
                    if (color(w.r) == B) { if (w.l != null) w.l.c = B; w.c = R; root = rotR(root, w); w = p.r; }
                    w.c = p.c;
                    p.c = B;
                    if (w.r != null) w.r.c = B;
                    root = rotL(root, p);
                    x = root;
                    xParent = null;
                }
            } else {
                Node w = p.l;
                if (color(w) == R) { w.c = B; p.c = R; root = rotR(root, p); w = p.l; }
                if (color(w.l) == B && color(w.r) == B) {
                    w.c = R;
                    x = p;
                    xParent = p.p;
                } else {
                    if (color(w.l) == B) { if (w.r != null) w.r.c = B; w.c = R; root = rotL(root, w); w = p.l; }
                    w.c = p.c;
                    p.c = B;
                    if (w.l != null) w.l.c = B;
                    root = rotR(root, p);
                    x = root;
                    xParent = null;
                }
            }
        }
        if (x != null) x.c = B;
        return root;
    }
}
// TreeSet/TreeMap (red–black)
import java.util.*;

TreeSet<Integer> set = new TreeSet<>();
set.add(10);
set.remove(7);                               // removing an absent element just returns false
boolean has = set.contains(25);              // O(log n)
Integer lo = set.first(), hi = set.last();   // both throw NoSuchElementException on an empty set
NavigableSet<Integer> sub = set.subSet(10, true, 20, true); // range view (inclusive bounds), backed by set
// Note: Java's TreeSet/TreeMap are Red–Black trees; they don't expose k-th directly.
/// AVL tree node: key, cached subtree height `h` and subtree `size`, child links.
struct Node {
    int key;
    int h = 1;     // height of the subtree rooted here (a leaf has height 1)
    int size = 1;  // number of keys in that subtree
    Node* l = nullptr;
    Node* r = nullptr;
    // explicit: a bare int should never silently convert into a tree node.
    explicit Node(int k) : key(k) {}
};
/// Cached height of subtree `u`; an empty subtree (nullptr) has height 0.
int H(Node* u) {
    if (u == nullptr) return 0;
    return u->h;
}
/// Cached number of keys in subtree `u`; an empty subtree has size 0.
int S(Node* u) {
    if (u == nullptr) return 0;
    return u->size;
}
/// Refresh the cached height and size of `u` from its children; does nothing for nullptr.
void upd(Node* u) {
    if (u == nullptr) return;
    u->h = std::max(H(u->l), H(u->r)) + 1;
    u->size = S(u->l) + S(u->r) + 1;
}
/// Left rotation: `x`'s right child becomes the subtree root, which is returned.
/// `x` is refreshed first because it ends up below the new root.
Node* rotL(Node* x) {
    Node* const pivot = x->r;
    x->r = pivot->l;  // pivot's inner subtree moves under x
    pivot->l = x;
    upd(x);
    upd(pivot);
    return pivot;
}
/// Right rotation: `y`'s left child becomes the subtree root, which is returned.
/// `y` is refreshed first because it ends up below the new root.
Node* rotR(Node* y) {
    Node* const pivot = y->l;
    y->l = pivot->r;  // pivot's inner subtree moves under y
    pivot->r = y;
    upd(y);
    upd(pivot);
    return pivot;
}
/// Balance factor of non-null `u`: height(left) - height(right).
int bal(Node* u) {
    const int left_h = H(u->l);
    const int right_h = H(u->r);
    return left_h - right_h;
}
/// Insert `k` into the AVL subtree `u`; returns the new subtree root.
/// Keys already present are ignored. New nodes are allocated with `new`.
Node* insert(Node* u, int k) {
    if (u == nullptr) return new Node(k);
    if (k == u->key) return u;
    Node*& side = (k < u->key) ? u->l : u->r;  // child slot on the search path
    side = insert(side, k);
    upd(u);
    const int b = bal(u);
    if (b > 1) {
        if (k < u->l->key) return rotR(u);                           // LL
        if (k > u->l->key) { u->l = rotL(u->l); return rotR(u); }    // LR
    } else if (b < -1) {
        if (k > u->r->key) return rotL(u);                           // RR
        if (k < u->r->key) { u->r = rotR(u->r); return rotL(u); }    // RL
    }
    return u;
}
/// Leftmost (minimum-key) node of the non-null subtree `u`.
Node* minNode(Node* u) {
    Node* cur = u;
    while (cur->l != nullptr) cur = cur->l;
    return cur;
}
/// Remove `k` from the AVL subtree `u` (no-op if absent); returns the new subtree root.
/// The unlinked node is `delete`d. A node with two children copies its inorder
/// successor's key, and the successor is erased from the right subtree instead.
Node* erase(Node* u, int k) {
    if (u == nullptr) return nullptr;
    if (k < u->key) {
        u->l = erase(u->l, k);
    } else if (k > u->key) {
        u->r = erase(u->r, k);
    } else if (u->l == nullptr || u->r == nullptr) {
        Node* const child = (u->l == nullptr) ? u->r : u->l;
        delete u;
        return child;
    } else {
        Node* const succ = minNode(u->r);
        u->key = succ->key;
        u->r = erase(u->r, succ->key);
    }
    upd(u);
    const int b = bal(u);
    if (b > 1) {
        if (bal(u->l) < 0) u->l = rotL(u->l);  // LR: straighten the left child first
        return rotR(u);
    }
    if (b < -1) {
        if (bal(u->r) > 0) u->r = rotR(u->r);  // RL
        return rotL(u);
    }
    return u;
}
/// Node holding the k-th smallest key (1-based) in subtree `u`, or nullptr if out of range.
Node* kth(Node* u, int k) {
    for (Node* cur = u; cur != nullptr;) {
        const int rank = S(cur->l) + 1;  // position of `cur` within its own subtree
        if (k == rank) return cur;
        if (k < rank) {
            cur = cur->l;
        } else {
            k -= rank;
            cur = cur->r;
        }
    }
    return nullptr;
}
// std::set/std::map (RBT) + PBDS order-statistics
#include <cassert>
#include <cstddef>
#include <functional>
#include <map>
#include <set>
// GNU PBDS (non-standard but handy in contests):
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>  // declares tree_order_statistics_node_update

/// Usage sketch: std::set (a red–black tree in common implementations) plus a
/// GNU PBDS order-statistics tree. The asserts document the expected results.
inline void std_set_pbds_example() {
    std::set<int> S;
    S.insert(10);
    S.erase(7);                                    // erasing an absent key is a no-op
    const bool has = S.count(25) != 0;             // O(log n)
    const int mn = *S.begin(), mx = *S.rbegin();   // S must be non-empty here
    assert(!has && mn == 10 && mx == 10);

    using namespace __gnu_pbds;
    tree<int, null_type, std::less<int>, rb_tree_tag,
         tree_order_statistics_node_update> T;
    T.insert(10);
    T.insert(20);
    const int kth = *T.find_by_order(0);           // 0-indexed k-th
    const std::size_t rank = T.order_of_key(15);   // #elements < 15
    assert(kth == 10 && rank == 1);
}
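Testing tip: after a batch of random inserts and deletes, verify every node's subtree_size and compare kth() results to stepping through an inorder traversal.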