前言
Treap 这个词是由 Tree 和 Heap 组合形成的,可以看出 Treap 是查找树和堆的结合,因此中文叫树堆。
和其他平衡树一样,Treap 的中序遍历值单调不减;而根据堆的性质,每个结点的权小于两个子结点的权。
Treap 分为有旋和无旋两种,而无旋 Treap又叫 FHQ-Treap,主要通过分裂(split)和合并(merge)实现维护操作。
操作
1. 分裂(split)
分裂操作是将一个树分成 \(x,y\) 两个树。\(x\) 中每一个结点的值都小于 \(k\),而 \(y\) 中每一个结点的值都大于等于 \(k\)。复杂度 \(O(logn)\)
举个例子:
(此图
盗自出自某dalao blog)代码:
void split(int p,int _val,int &x,int &y) { if(!p) x=y=0; else if(t[p].val<=_val) { split(t[p].r,_val,t[p].r,y); pushup(p); x=p; } else { split(t[p].l,_val,x,t[p].l); pushup(p); y=p; } }
\(2.\) 合并 (merge)
合并是将 \(x,y\) 两棵树合并为一棵树 复杂度 \(O(logn)\)
(此图
盗自出自另一位dalao blog)代码:
int merge(int x,int y) { if(!x||!y) return x+y; if(t[x].key<=t[y].key) { t[x].r=merge(t[x].r,y); pushup(x); return x; } else { t[y].l=merge(x,t[y].l); pushup(y); return y; } }
\(3.\) 插入
先将申请一个新的结点,作为一棵树 \(y\);并将原来的树分裂成 \(x,z\) 两棵树。
然后依次合并 \(x,y,z\),就完成了。复杂度 \(O(logn)\)代码:
inline void insert(int _val) { int x,y; split(root,_val,x,y); t[++cnt].init(_val); root=merge(x,merge(cnt,y)); }
\(4.\) 删除
删除比较巧妙,先将树分裂成 \(x,y,z\) 三棵树;其中 \(x\) 的每个结点的值均小于 \(k\),\(y\) 的每个结点的值均为 \(k\),\(z\) 的每个结点的值均大于 \(k\) 。
直接合并 \(y\) 的左右两棵子树,根节点就被删除掉了。最后,依次合并 \(x,y,z\)。
复杂度 \(O(logn)\)
代码:
inline void del(int _val) { int x,y,z; split(root,_val,x,y); split(x,_val-1,x,z); z=merge(t[z].l,t[z].r); root=merge(x,merge(z,y)); }
\(5.\) 查询排名
直接分裂,小于 \(k\) 的树的大小加一即为排名。
复杂度 \(O(logn)\)
int getrank(int p,int _val) { int x,y; split(root,_val-1,x,y); int res=t[x].siz+1; root=merge(x,y); return res; }
\(6.\) 排名为 \(x\) 的数
这个操作是查询第 \(x\) 大,要按照普通的查询方法来搞
详见代码:
int getvalue(int p,int k) { if(t[t[p].l].siz+1==k) return t[p].val; if(t[t[p].l].siz+1<k) return getvalue(t[p].r,k-t[t[p].l].siz-1); else return getvalue(t[p].l,k); }
\(7.\) 前驱
所以直接查找小于 \(x\) 的数里最大的。
代码:
inline int prev(int _val) { int x,y; split(root,_val-1,x,y); int tmp=x; while(t[tmp].r) tmp=t[tmp].r; root=merge(x,y); return t[tmp].val; }
\(8.\) 后继
于前驱同理,找大于等于 \(x\) 里最小的
代码:
inline int nex(int _val) { int x,y; split(root,_val,x,y); int tmp=y; while(t[tmp].l) tmp=t[tmp].l; root=merge(x,y); return t[tmp].val; }
P3369 【模板】普通平衡树
#include<iostream>
#include<cstdio>
#include<ctime>
using namespace std;
namespace fastio
{
#define te template<typename T>
#define tem template<typename T,typename ...Args>
te inline static void read(T &x){x=0;int f=1;char ch=getchar();while(ch<'0'||ch>'9'){if(ch=='-') f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}if(f==-1) x=-x;}
tem inline static void read(T& x,Args& ...args){read(x);read(args...);}
te inline static void write(char c,T x){T p=x;if(!p) putchar('0');if(p<0){putchar('-');p=-p;}int cnt[105],tot=0;while(p){cnt[++tot]=p%10;p/=10;}for(int i=tot;i>=1;i--){putchar(cnt[i]+'0');}putchar(c);}
tem inline static void write(const char c,T x,Args ...args){write(c,x);write(c,args...);}
}using namespace fastio;
const int N=1e5+10;
int n,m;
int op,x;
int root,cnt;
struct node{
int l,r,key,val,siz;
inline void init(int _val){
val=_val;
siz=1;
key=rand();
}
}t[N];
inline void pushup(int p){t[p].siz=t[t[p].l].siz+t[t[p].r].siz+1;}
void split(int p,int _val,int &x,int &y){
if(!p) x=y=0;
else if(t[p].val<=_val){
split(t[p].r,_val,t[p].r,y);
pushup(p);x=p;
}else{
split(t[p].l,_val,x,t[p].l);
pushup(p);y=p;
}
}
int merge(int x,int y){
if(!x||!y) return x+y;
if(t[x].key<=t[y].key){
t[x].r=merge(t[x].r,y);
pushup(x);return x;
}else{
t[y].l=merge(x,t[y].l);
pushup(y);return y;
}
}
inline void insert(int _val){
int x,y;
split(root,_val,x,y);
t[++cnt].init(_val);
root=merge(x,merge(cnt,y));
}
inline void del(int _val){
int x,y,z;
split(root,_val,x,y);
split(x,_val-1,x,z);
z=merge(t[z].l,t[z].r);
root=merge(x,merge(z,y));
}
int getrank(int p,int _val){
int x,y;
split(root,_val-1,x,y);
int res=t[x].siz+1;
root=merge(x,y);return res;
}
int getvalue(int p,int k){
if(t[t[p].l].siz+1==k) return t[p].val;
if(t[t[p].l].siz+1<k) return getvalue(t[p].r,k-t[t[p].l].siz-1);
else return getvalue(t[p].l,k);
}
inline int prev(int _val){
int x,y;
split(root,_val-1,x,y);
int tmp=x;
while(t[tmp].r) tmp=t[tmp].r;
root=merge(x,y);
return t[tmp].val;
}
inline int nex(int _val){
int x,y;
split(root,_val,x,y);
int tmp=y;
while(t[tmp].l) tmp=t[tmp].l;
root=merge(x,y);
return t[tmp].val;
}
int main(){
srand(time(0));
read(n);
while(n--){
read(op,x);
if(op==1) insert(x);
else if(op==2) del(x);
else if(op==3) write('\n',getrank(root,x));
else if(op==4) write('\n',getvalue(root,x));
else if(op==5) write('\n',prev(x));
else if(op==6) write('\n',nex(x));
}
return 0;
}