Splay
早就听说SPLAY大名,可以做区间翻转操作
但是查到的都是平衡树相关的东西
因为这个是LCT的前置,所以终于要学了
基本功能
结构
一颗二叉搜索树,需要维护根节点编号与节点个数
每个节点维护父亲、左右儿子编号、节点权值、权值出现次数、子树大小
感觉跟AVL维护的东西差不多
基本操作
splay(伸展):将某个点旋转至根(其实可以旋转至任意父亲处)
push,remove
伸展
其中一个关键的地方在于旋转,应当使用双旋而不是单旋
单旋即将目标节点一直向上旋转
而双旋分六种情况,如下图(图源自OIWIKI)
前两种情况为父亲为根,做一次旋转即可
中间两种情况为,父亲方向与儿子方向相同,父亲先上旋转,儿子再上旋
底下两种情况为,父亲方向与儿子方向不同,儿子上旋两次
显然单旋弄一条链就直接卡没了,而双旋试一下就知道用链至少不会一直保持在最坏情况
严格的均摊复杂度可以由势能法证明是 O(logn)
的
具体实现起来,可以一次双旋,而不是做两次单旋,能省一些常数
分别画一下这六种情况的前后对应关系
图中红色标出的为断开的位置,观察一下最终连接,可以归纳出几个共通点(下述类型均指其为左儿子还是右儿子)
- 最顶端节点(c)断开处将连接首次旋转节点的异侧
- 中部的节点(后四种情况的x)连接底部节点(y)的断开处总是连接其儿子的异侧
- 最底部节点(后四种情况的y)与其父亲 (x) 类型相异侧,中间两种情况连接中部节点(x),后两种情况连接顶部节点(c)
同时前两种情况(父亲为根),可以合并到最后两种情况进行处理 - (后四种情况)最后剩下的一个断开为首次旋转节点 (z) 与 y 类型相异侧,中间两种情况连接顶部节点(c),后两种情况连接中部节点(x)
这样就是四处断开的位置及其所有可能的连接方式,其中2,4两种是后四种情况独有的,1,3是全部情况共用的
再记得更新父亲的连接,这里用函数fix做,顺便维护想维护的值例如子树大小
代码实现如下
//该函数在最顶端节点处调用
//i指定首层节点的类型,0左,1右
//j指定次层节点的类型,0左,1右,2不存在
void rot(int i, int j) {
int b = i ^ j; //确定此次旋转的类型,0同侧,1异侧,当j为2时恒不为0,视作异侧
//x为首层节点,y为次层节点,z为首次旋转的节点
Node *x = c[i], *y = j == 2 ? x : x->c[j], *z = b ? y : x;
if ((y->p = p)) p->c[type()] = y; //不要忘了更新父亲
c[i] = z->c[i ^ 1];//连接1
if (j < 2) {
x->c[j] = y->c[j ^ 1]; //连接2
z->c[j ^ 1] = b ? x : this; //连接4
}
y->c[i ^ 1] = b ? this : x; //连接3
fix();
x->fix();
y->fix();
if (p) p->fix();
}
void fix() {
if (c[0]) c[0]->p = this;
if (c[1]) c[1]->p = this;
}
剩下的就好说了,递归上旋直到为根即完成伸展操作
增加push
寻找目标位置,遇到相等就直接计数+1,遇到空位就新增节点
最后记得splay
删除remove
寻找目标位置,旋到根
若数量多于1,则计数器-1
否则删除根,将左子树的最右节点旋为左子树的根(先断开左子树与根)
再将右子树直接作为新根的右子树
左子树为空直接将右子树顶上来
其他操作
kth(给排名查数),rnk(查排名),求前驱、后继
后三项都可以先插入待查的数,旋到根,查完之后再删除,非常方便
代码(洛谷P6136 【模板】普通平衡树(数据加强版))
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;
const int MAX_N = 1E6 + 1E5 + 100;
const int INF = (1 << 30) + 100;
//===========================================io
int read() {
int ret = 0;
char c = getchar();
while ('0' > c || c > '9') c = getchar();
while ('0' <= c && c <= '9') {
ret *= 10;
ret += c - '0';
c = getchar();
}
return ret;
}
void write(int x) {
static char wbuf[8];
if (0 == x) putchar('0');
int sz = 0;
while (x) {
wbuf[sz++] = '0' + x % 10;
x /= 10;
}
while (sz) putchar(wbuf[--sz]);
}
//===========================================splay
class Node { // Splay
public:
Node *p = 0, *c[2];
bool flip = 0; //可选(区间翻转)
int val, cnt; //可选(视维护的值而定)
int sz; //可选(kth,rnk需要)
Node() {}
Node(int _val) {
c[0] = c[1] = 0;
fix();
val = _val;
cnt = 1;
sz = 1;
}
void fix() {
sz = cnt;
if (c[0]) {
c[0]->p = this;
sz += c[0]->sz;
}
if (c[1]) {
c[1]->p = this;
sz += c[1]->sz;
}
}
void pushFlip() {
if (!flip) return;
flip = 0;
swap(c[0], c[1]);
if (c[0]) c[0]->flip ^= 1;
if (c[1]) c[1]->flip ^= 1;
}
int type() { return p ? p->c[1] == this : -1; }
void rot(int i, int j) {
int b = i ^ j;
Node *x = c[i], *y = j == 2 ? x : x->c[j], *z = b ? y : x;
if ((y->p = p)) p->c[type()] = y;
c[i] = z->c[i ^ 1];
if (j < 2) {
x->c[j] = y->c[j ^ 1];
z->c[j ^ 1] = b ? x : this;
}
y->c[i ^ 1] = b ? this : x;
fix();
x->fix();
y->fix();
if (p) p->fix();
}
void splay() {
for (pushFlip(); p;) {
if (p->p) p->p->pushFlip();
p->pushFlip();
pushFlip();
int c1 = type(), c2 = p->type();
if (-1 == c2)
p->rot(c1, 2);
else
p->p->rot(c2, c1);
}
}
Node *first() {
Node *c = this;
while (c->c[1]) c = c->c[1];
return c;
}
};
class Splay {
private:
int n;
Node node[MAX_N];
Node *rt;
//=======================
public:
void push(int x) {
if (!n) {
node[n++] = Node(x);
rt = node;
} else {
Node *c = rt;
while (true) {
if (c->val == x) {
c->cnt += 1;
break;
}
int i = c->val < x;
if (c->c[i]) {
c = c->c[i];
} else {
c->c[i] = node + n;
node[n++] = Node(x);
c->c[i]->p = c;
c = c->c[i];
break;
}
}
c->splay();
rt = c;
}
}
//需确保有数可删
void remove(int x) {
Node *c = rt;
while (true) {
if (c->val == x) {
break;
}
c = c->c[c->val < x];
}
c->splay();
rt = c;
if (!--c->cnt) {
if (c->c[0]) {
rt = c->c[0]->first();
c->c[0]->p = 0;
rt->splay();
rt->c[1] = c->c[1];
rt->fix();
} else {
rt = c->c[1];
if (rt) {
rt->p = 0;
} else {
n = 0;
}
}
} else {
--c->sz;
}
}
int rnk(int x) {
push(x);
int ret = rt->c[0] ? rt->c[0]->sz : 0;
remove(x);
return ret + 1;
}
//需确保k不大于总数
int kth(int k) {
Node *c = rt;
while (true) {
int lsz = c->c[0] ? c->c[0]->sz : 0;
if (lsz + c->cnt < k) {
k -= c->cnt + lsz;
c = c->c[1];
} else if (lsz < k) {
c->splay();
rt = c;
return c->val;
} else {
c = c->c[0];
}
}
}
int lst(int x) {
push(x);
Node *c = rt->c[0];
while (c->c[1]) c = c->c[1];
remove(x);
return c->val;
}
int nxt(int x) {
push(x);
Node *c = rt->c[1];
while (c->c[0]) c = c->c[0];
remove(x);
return c->val;
}
};
//===========================================
Splay tr;
int n, m;
int main() {
n = read();
m = read();
while (n--) {
int x = read();
tr.push(x);
}
int ans = 0, lst = 0;
while (m--) {
int opt = read();
int x = read() ^ lst;
switch (opt) {
case 1:
tr.push(x);
break;
case 2:
tr.remove(x);
break;
case 3:
ans ^= lst = tr.rnk(x);
break;
case 4:
ans ^= lst = tr.kth(x);
break;
case 5:
ans ^= lst = tr.lst(x);
break;
case 6:
ans ^= lst = tr.nxt(x);
break;
}
}
write(ans);
putchar('\n');
return 0;
}
区间翻转
还是基于splay操作,对于翻转 [l,r]
先把 l-1 splay至根
再把 r+1 splay至根的右儿子
翻转根的右儿子的左儿子即可
类似lazytag的做法
代码(洛谷P3391 【模板】文艺平衡树)
#include <iostream>
using namespace std;
const int MAX_N = 1E5 + 100;
int n, m;
class Node {
public:
int val, sz;
bool flip;
Node *p, *c[2];
Node() {
flip = 0;
p = c[0] = c[1] = 0;
}
void pushFlip() {
if (!flip) return;
flip = 0;
swap(c[0], c[1]);
if (c[0]) c[0]->flip ^= 1;
if (c[1]) c[1]->flip ^= 1;
}
void fix() {
sz = 1;
if (c[0]) {
sz += c[0]->sz;
c[0]->p = this;
}
if (c[1]) {
sz += c[1]->sz;
c[1]->p = this;
}
}
int type() { return p ? this == p->c[1] : 2; }
void rot(int i, int j) {
int b = i ^ j;
Node *x = c[i], *y = 2 & j ? x : x->c[j], *z = b ? y : x;
if (y->p = p) p->c[type()] = y;
c[i] = z->c[i ^ 1];
if (j < 2) {
x->c[j] = y->c[j ^ 1];
z->c[j ^ 1] = b ? x : this;
}
y->c[i ^ 1] = b ? this : x;
fix();
x->fix();
y->fix();
if (p) p->fix();
}
void splay() {
for (pushFlip(); p;) {
if (p->p) p->p->pushFlip();
p->pushFlip();
pushFlip();
if (2 & p->type())
p->rot(type(), 2);
else
p->p->rot(p->type(), type());
}
}
};
class Splay {
private:
int n;
Node node[MAX_N];
Node *rt;
void print(Node *c) {
c->pushFlip();
if (c->c[0]) print(c->c[0]);
if (1 <= c->val && c->val <= n) cout << c->val << ' ';
if (c->c[1]) print(c->c[1]);
}
public:
void init(int _n) {
n = _n;
rt = node;
node[0].val = 0;
node[0].sz = n + 2;
for (int i = 1; i <= n + 1; ++i) {
node[i].val = i;
node[i].p = node + i - 1;
node[i - 1].c[1] = node + i;
node[i].sz = n + 2 - i;
}
}
Node *find(int k) {
k += 1;
Node *c = rt;
while (true) {
c->pushFlip();
int lsz = c->c[0] ? c->c[0]->sz : 0;
if (lsz + 1 < k) {
k -= 1 + lsz;
c = c->c[1];
} else if (lsz < k) {
return c;
} else {
c = c->c[0];
}
}
}
void flip(int l, int r) {
Node *ln = find(l - 1);
ln->splay();
rt = ln;
Node *rn = find(r + 1);
ln->c[1]->p = 0;
rn->splay();
rn->p = ln;
ln->c[1] = rn;
rn->c[0]->flip ^= 1;
}
void print() { print(rt); }
};
Splay tr;
int main() {
ios::sync_with_stdio(false);
cin >> n >> m;
tr.init(n);
while (m--) {
int l, r;
cin >> l >> r;
tr.flip(l, r);
}
tr.print();
return 0;
}