太好了孩子们,主播学了树链剖分薄纱绿蓝紫题
2025/9/11 一个 911 的日子
前几天刚写了《雨》,没想到第二天便是阳光明媚,甚至是暴晒,我们初三去跑操还在阳光最大的地方暴晒了 1 小时,给我大帅哥晒成包青天,白月光晒成黑魔仙了。
于是,有了
$$\ce{6CO2 + 6H2O ->[{\text{光}}][{\text{叶绿体}}] C6H12O6 + 6O2}$$
把光能转化成学习的动力了——总算不是打游戏。

重链剖分

因为主播只会这个。

引入

我们学过很多维护序列的数据结构,如 st 表、树状数组、线段树……
但树能用的好像没什么,那我们怎么把他们用到树上呢?(绝对不是我想线段树了,大笨蛋)
把树分成链就好了。但若是直接用 DFS 序来剖分链,链的数量可能会很大,最后又成了暴力了。
于是有了重链剖分

原理

定义重孩子:子树最大的孩子。
重链:除了第一个点其他都是重孩子的链。
轻边:不属于任何重链的边,一定是轻孩子和父节点的边。
重链和轻边的总数不会超过 $2\log n$

证明

设节点为 $u$,重儿子为 $w$,一个轻儿子为 $v$,则有:
$$sz_w \geq sz_v$$
$$sz_u \geq sz_w+sz_v \geq 2sz_v$$
也就是说,每经过一条轻边,节点数就至少会翻倍!
哇塞,可以把树变成 $O(\log n)$ 的链了!
我们直接两次 DFS,先求出 sz、fa、dep 和重孩子,再优先遍历重孩子,求出 dfn、top(链头,非重孩子的 top 是它本身)
然后再沿用倍增求 LCA 的思想即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
int sz[N],dep[N],wc[N],dfn[N],cnt,seq[N],top[N],fa[N];
void dfs1(int u,int from){
dep[u]=dep[from]+1,sz[u]=1,fa[u]=from;
for(int v:g[u]){
if(v==from) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[wc[u]]) wc[u]=v;
}
}
void dfs2(int u,int Top){
dfn[u]=++cnt,seq[cnt]=u;
top[u]=Top;
if(wc[u]!=0){
dfs2(wc[u],Top);
for(int v:g[u])
if(v!=fa[u]&&v!=wc[u]) dfs2(v,v);
}
}
void upd(int x,int y){//你想干什么都可以哦
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
//对dfn[top[x]]~dfn[x]序列操作
x=fa[top[x]];
}
//min(dfn[x],dfn[y])~max(dfn[x],dfn[y])序列操作
}

搭配线段树食用最佳!

P3384 【模板】重链剖分/树链剖分

蓝题模板
经典的树剖+线段树,后面的题基本都是。

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#include<bits/stdc++.h>
#define int long long
#define ls o<<1
#define rs o<<1|1
using namespace std;
const int N=2e5+5;
int n,m,r,p,a[N];
int sz[N],dep[N],wc[N],dfn[N],cnt,seq[N],top[N],fa[N];
vector<int> g[N];
void dfs1(int u,int from){
dep[u]=dep[from]+1,sz[u]=1,fa[u]=from;
for(int v:g[u]){
if(v==from) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[wc[u]]) wc[u]=v;
}
}
void dfs2(int u,int Top){
dfn[u]=++cnt,seq[cnt]=u;
top[u]=Top;
if(wc[u]!=0){
dfs2(wc[u],Top);
for(int v:g[u])
if(v!=fa[u]&&v!=wc[u]) dfs2(v,v);
}
}
struct node{
int sum,add;
}tr[N<<2];
void pushup(int o){
tr[o].sum=(tr[ls].sum+tr[rs].sum)%p;
}
void tag(int o,int l,int r,int add){
tr[o].add=(tr[o].add+add)%p;
tr[o].sum=(tr[o].sum+(r-l+1)*add%p)%p;
}
void pushdown(int o,int l,int r){//更新子节点
int mid=l+r>>1;
tag(ls,l,mid,tr[o].add);
tag(rs,mid+1,r,tr[o].add);
tr[o].add=0;
}
void build(int o,int l,int r){
if(l==r){
tr[o].sum=a[seq[l]];//初始化
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(o);
}
void update(int o,int l,int r,int L,int R,int x){
if(l>=L&&r<=R){
tag(o,l,r,x);
return;
}
pushdown(o,l,r);
int mid=l+r>>1;
if(L<=mid) update(ls,l,mid,L,R,x);
if(R>mid) update(rs,mid+1,r,L,R,x);
pushup(o);
}
void upd(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],z);
x=fa[top[x]];
}
update(1,1,n,min(dfn[x],dfn[y]),max(dfn[x],dfn[y]),z);
}
int query(int o,int l,int r,int L,int R){
if(l>=L&&r<=R) return tr[o].sum;
pushdown(o,l,r);
int mid=l+r>>1,res=0;
if(L<=mid) res=(res+query(ls,l,mid,L,R))%p;//合并区间值
if(R>mid) res=(res+query(rs,mid+1,r,L,R))%p;
return res;
}
int qry(int x,int y){
int res=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res=(res+query(1,1,n,dfn[top[x]],dfn[x]))%p;
x=fa[top[x]];
}
return (res+query(1,1,n,min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))%p;
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>n>>m>>r>>p;
for(int i=1;i<=n;i++) cin>>a[i],a[i]%=p;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(r,0);
dfs2(r,0);
build(1,1,n);
int op,x,y,z;
while(m--){
cin>>op;
if(op==1) cin>>x>>y>>z,upd(x,y,z);
if(op==2) cin>>x>>y,cout<<qry(x,y)<<'\n';
if(op==3) cin>>x>>z,update(1,1,n,dfn[x],dfn[x]+sz[x]-1,z);
if(op==4) cin>>x,cout<<query(1,1,n,dfn[x],dfn[x]+sz[x]-1)<<'\n';
}
return 0;
}

P3038 [USACO11DEC] Grass Planting G

边权转点权

讲解

令子节点的点权为与父亲连边的边权,然后 LCA 的点权不用算进去即可。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#include<bits/stdc++.h>
#define int long long
#define ls o<<1
#define rs o<<1|1
using namespace std;
const int N=2e5+5;
int n,m,r,a[N];
int sz[N],dep[N],wc[N],dfn[N],cnt,seq[N],top[N],fa[N];
vector<int> g[N];
void dfs1(int u,int from){
dep[u]=dep[from]+1,sz[u]=1,fa[u]=from;
for(int v:g[u]){
if(v==from) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[wc[u]]) wc[u]=v;
}
}
void dfs2(int u,int Top){
dfn[u]=++cnt,seq[cnt]=u;
top[u]=Top;
if(wc[u]!=0){
dfs2(wc[u],Top);
for(int v:g[u])
if(v!=fa[u]&&v!=wc[u]) dfs2(v,v);
}
}
struct node{
int sum,add;
}tr[N<<2];
void pushup(int o){
tr[o].sum=tr[ls].sum+tr[rs].sum;
}
void tag(int o,int l,int r,int add){
tr[o].add+=add;
tr[o].sum+=(r-l+1)*add;
}
void pushdown(int o,int l,int r){//更新子节点
int mid=l+r>>1;
tag(ls,l,mid,tr[o].add);
tag(rs,mid+1,r,tr[o].add);
tr[o].add=0;
}
void build(int o,int l,int r){
if(l==r){
tr[o].sum=a[seq[l]];//初始化
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(o);
}
void update(int o,int l,int r,int L,int R,int x){
if(l>=L&&r<=R){
tag(o,l,r,x);
return;
}
pushdown(o,l,r);
int mid=l+r>>1;
if(L<=mid) update(ls,l,mid,L,R,x);
if(R>mid) update(rs,mid+1,r,L,R,x);
pushup(o);
}
void upd(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],z);
x=fa[top[x]];
}
if(x!=y) update(1,1,n,min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y]),z);
}
int query(int o,int l,int r,int L,int R){
if(l>=L&&r<=R) return tr[o].sum;
pushdown(o,l,r);
int mid=l+r>>1,res=0;
if(L<=mid) res+=query(ls,l,mid,L,R);//合并区间值
if(R>mid) res+=query(rs,mid+1,r,L,R);
return res;
}
int qry(int x,int y){
int res=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res+=query(1,1,n,dfn[top[x]],dfn[x]);
x=fa[top[x]];
}
return res+query(1,1,n,min(dfn[x],dfn[y])+1,max(dfn[x],dfn[y]));
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>n>>m;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0);
dfs2(1,0);
build(1,1,n);
char op;
int u,v;
while(m--){
cin>>op>>u>>v;
if(op=='P') upd(u,v,1);
else cout<<qry(u,v)<<'\n';
}
return 0;
}

P4092 [HEOI2016/TJOI2016] 树 && P4116 Qtree3

简单1
简单2

讲解

两道题类似,都可以直接用线段树维护,但两者都有更简单的写法——前者并查集,后者 set。

线段树写法(题1的,题2基本双倍经验)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include<bits/stdc++.h>
#define int long long
#define ls o<<1
#define rs o<<1|1
using namespace std;
const int N=2e5+5;
int n,m,r,p,a[N];
int sz[N],dep[N],wc[N],dfn[N],cnt,seq[N],top[N],fa[N];
vector<int> g[N];
void dfs1(int u,int from){
dep[u]=dep[from]+1,sz[u]=1,fa[u]=from;
for(int v:g[u]){
if(v==from) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[wc[u]]) wc[u]=v;
}
}
void dfs2(int u,int Top){
dfn[u]=++cnt,seq[cnt]=u;
top[u]=Top;
if(wc[u]!=0){
dfs2(wc[u],Top);
for(int v:g[u])
if(v!=fa[u]&&v!=wc[u]) dfs2(v,v);
}
}
int tr[N<<2];
void pushup(int o){
tr[o]=max(tr[ls],tr[rs]);
}
void tag(int o,int l,int r,int pos){
if(l<=pos&&pos<=r) tr[o]=max(tr[o],pos);
}
void build(int o,int l,int r){
if(l==r){
tr[o]=0;//初始化
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(o);
}
void update(int o,int l,int r,int x){
if(l==r){
tr[o]=x;
return;
}
int mid=l+r>>1;
if(x<=mid) update(ls,l,mid,x);
else update(rs,mid+1,r,x);
pushup(o);
}

int query(int o,int l,int r,int L,int R){
if(l>=L&&r<=R) return tr[o];
int mid=l+r>>1,res=-1;
if(L<=mid) res=max(res,query(ls,l,mid,L,R));//合并区间值
if(R>mid) res=max(res,query(rs,mid+1,r,L,R));
return res;
}
int qry(int x,int y){
int res=-1;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res=max(res,query(1,1,n,dfn[top[x]],dfn[x]));
x=fa[top[x]];
}
return seq[max(res,query(1,1,n,min(dfn[x],dfn[y]),max(dfn[x],dfn[y])))];
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>n>>m;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0);
dfs2(1,0);
build(1,1,n);
update(1,1,n,dfn[1]);
char op;
int x,y,z;
while(m--){
cin>>op>>x;
if(op=='Q') cout<<qry(x,1)<<'\n';
else update(1,1,n,dfn[x]);
}
return 0;
}

P3976 [TJOI2015] 旅游

紫题
其实没有比前面蓝题难多少。线段树维护一下前-后和后-前的最大值,在模板基础上判断一下链的方向即可。
注意我用的重载运算符,加法要注意顺序(写 merge 函数可能会好分辨一点)

代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
#include<bits/stdc++.h>
#define int long long
#define ls o<<1
#define rs o<<1|1
using namespace std;
const int N=2e5+5,inf=1e9;
int n,m,a[N];
int sz[N],dep[N],wc[N],dfn[N],cnt,seq[N],top[N],fa[N];
vector<int> g[N];
void dfs1(int u,int from){
dep[u]=dep[from]+1,sz[u]=1,fa[u]=from;
for(int v:g[u]){
if(v==from) continue;
dfs1(v,u);
sz[u]+=sz[v];
if(sz[v]>sz[wc[u]]) wc[u]=v;
}
}
void dfs2(int u,int Top){
dfn[u]=++cnt,seq[cnt]=u;
top[u]=Top;
if(wc[u]!=0){
dfs2(wc[u],Top);
for(int v:g[u])
if(v!=fa[u]&&v!=wc[u]) dfs2(v,v);
}
}
struct node{
int qh,hq,mx,mn,add;
node operator +(const node&w)const{//注意这里加法有顺序!
return {max({qh,w.qh,mx-w.mn}),max({hq,w.hq,w.mx-mn}),max(mx,w.mx),min(mn,w.mn),0};
}
}tr[N<<2];
void pushup(int o){
tr[o]=tr[ls]+tr[rs];
}
void tag(int o,int l,int r,int add){
tr[o].add+=add,tr[o].mx+=add,tr[o].mn+=add;
}
void pushdown(int o,int l,int r){//更新子节点
int mid=l+r>>1;
tag(ls,l,mid,tr[o].add);
tag(rs,mid+1,r,tr[o].add);
tr[o].add=0;
}
void build(int o,int l,int r){
if(l==r){
tr[o]={0,0,a[seq[l]],a[seq[l]],0};//初始化
return;
}
int mid=l+r>>1;
build(ls,l,mid);
build(rs,mid+1,r);
pushup(o);
}
void update(int o,int l,int r,int L,int R,int x){
if(l>=L&&r<=R){
tag(o,l,r,x);
return;
}
pushdown(o,l,r);
int mid=l+r>>1;
if(L<=mid) update(ls,l,mid,L,R,x);
if(R>mid) update(rs,mid+1,r,L,R,x);
pushup(o);
}
node query(int o,int l,int r,int L,int R){
if(l>=L&&r<=R) return tr[o];
pushdown(o,l,r);
int mid=l+r>>1;
node res={0,0,0,inf,0};
if(L<=mid) res=res+query(ls,l,mid,L,R);//合并区间值
if(R>mid) res=res+query(rs,mid+1,r,L,R);
return res;
}
void upd(int x,int y,int z){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
update(1,1,n,dfn[top[x]],dfn[x],z);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
update(1,1,n,dfn[x],dfn[y],z);
}
int qry(int x,int y){
node l,r;//u和v的
l=r={0,0,0,inf,0};
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]){
r=query(1,1,n,dfn[top[y]],dfn[y])+r;
y=fa[top[y]];
}
else{
l=query(1,1,n,dfn[top[x]],dfn[x])+l;
x=fa[top[x]];
}
}
if(dep[x]>dep[y]) l=query(1,1,n,dfn[y],dfn[x])+l;
else r=query(1,1,n,dfn[x],dfn[y])+r;
swap(l.qh,l.hq);
return (l+r).hq;
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
dfs1(1,0);
dfs2(1,0);
build(1,1,n);
cin>>m;
while(m--){
int x,y,z;
cin>>x>>y>>z;
cout<<qry(x,y)<<'\n';
upd(x,y,z);
}
return 0;
}