0%

树链剖分

。。。。

有点懒;

需要先理解几个概念:

  1. LCA

  2. 线段树(熟练,要不代码能调一天)

  3. 图论的基本知识(dfs序的性质)

这大概就好了;

定义

  1.重儿子:一个点所连点树size最大的,这个son被称为这个点的重儿子;

  2.轻儿子:一个点所连点除重儿子以外的都是轻儿子;

  3.重链:从一个轻儿子或根节点开始沿重儿子走所成的链;

步骤

  在代码里,结合代码更清晰。。。(其实是太懒了)

 有重点需要注意的东西在code中有提到,仔细看。。。。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include<bits/stdc++.h>
#define maxn 100007
#define le(x) x<<1
#define re(x) x<<1|1
using namespace std;
int n,m,root,mod,a[maxn],head[maxn],fa[maxn],son[maxn],cnt,tag[maxn<<2];
//a:原始点值,fa:父亲节点,son:重儿子,tag:懒标记
int top[maxn],sz[maxn],id[maxn],dep[maxn],w[maxn],cent,tr[maxn<<2];
//top:所在重链的头结点,sz:子树大小,id:dfs序,dep:深度
//w:dfs序所对应的值(建线段树),tr:线段树
struct node{
int next,to;
}edge[maxn<<2];

template<typename type_of_scan>
inline void scan(type_of_scan &x){
type_of_scan f=1;x=0;char s=getchar();
while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
x*=f;
}

inline void add(int u,int v){
edge[++cent]=(node){head[u],v};head[u]=cent;
}
//-----------------------------------------------------线段树红色预警
void push_up(int p){
tr[p]=tr[le(p)]+tr[re(p)];
tr[p]%=mod;
}

void build(int l,int r,int p){
if(l==r){
tr[p]=w[l];
return ;
}
int mid=(l+r)>>1;
build(l,mid,le(p));
build(mid+1,r,re(p));
push_up(p);
}

void push_down(int l,int r,int p,int k){
int mid=l+r>>1;
tr[le(p)]+=k*(mid-l+1),tr[re(p)]+=k*(r-mid);
tr[le(p)]%=mod,tr[re(p)]%=mod;
tag[le(p)]+=k,tag[re(p)]+=k;
tag[le(p)]%=mod,tag[re(p)]%=mod;
}

void r_add(int nl,int nr,int l,int r,int p,int k){
if(nl<=l&&nr>=r){
tr[p]+=k*(r-l+1);tag[p]+=k;
tr[p]%=mod,tag[p]%=mod;
return ;
}
push_down(l,r,p,tag[p]),tag[p]=0;
int mid=(l+r)>>1;
if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
push_up(p);
}

int r_query(int nl,int nr,int l,int r,int p){
int ans=0;
if(nl<=l&&nr>=r) return tr[p];
push_down(l,r,p,tag[p]),tag[p]=0;
int mid=l+r>>1;
if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p)),ans%=mod;
if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p)),ans%=mod;
push_up(p);
return ans;
}

//-----------------------------------------------------线段树结束
//-----------------------------------------------------开始预处理

void dfs1(int x){
sz[x]=1;//sz初始化
int max_part=-1;//max_part更新寻找重儿子
for(int i=head[x];i;i=edge[i].next){
int y=edge[i].to;
if(y==fa[x]) continue;
fa[y]=x,dep[y]+=dep[x]+1;//更新子节点,准备开始继续dfs1
dfs1(y);sz[x]+=sz[y];//更新自身的sz数组
if(max_part<sz[y]) son[x]=y,max_part=sz[y];//更新重儿子
}
}
/*dfs1功能介绍
1.更新fa数组;
2.更新dep数组;
3.更新sz数组;
4.更新son数组;
*/

void dfs2(int x,int t){
id[x]=++cnt,w[cnt]=a[x],top[x]=t;//更新dfs序,dfs序所对的值,重链头节点
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x];i;i=edge[i].next){
int y=edge[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
/*dfs2功能介绍
1.更新id数组;
2.更新w数组;
3.更新top数组
*/

//------------------------------------------------预处理结束
//------------------------------------------------开始主要操作

//其实没有说的这么简单,这里重点是理解重链之间的跳跃方式,线段树的优化
//一个性质:重链上的dfs序是连续的,dfs1在dfs2前的原因就在此

int road_query(int x,int y){
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下面往上跳
ans+=r_query(id[top[x]],id[x],1,n,1);//更新重链
ans%=mod;
x=fa[top[x]];//跳到重链头的fa
}
if(dep[x]>dep[y]) swap(x,y);
ans+=r_query(id[x],id[y],1,n,1);//已经在同一条重链上,直接加
return ans%mod;
}

int tree_query(int x){
return r_query(id[x],id[x]+sz[x]-1,1,n,1)%mod;
}//一个性质:在同一颗子树上的dfs序是连续的

void road_add(int x,int y,int k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
r_add(id[top[x]],id[x],1,n,1,k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
r_add(id[x],id[y],1,n,1,k);
return ;
}//类比

void tree_add(int x,int k){
r_add(id[x],id[x]+sz[x]-1,1,n,1,k);
return ;
}//相同的性质

//-----------------------------------------------树链剖分

int main(){
scan(n),scan(m),scan(root),scan(mod);
for(int i=1;i<=n;i++) scan(a[i]);
for(int i=1,u,v;i<=n-1;i++)
scan(u),scan(v),add(u,v),add(v,u);
dfs1(root),dfs2(root,root),build(1,n,1);
for(int i=1;i<=m;i++){
int type,x,y,z;
scan(type);
if(type==1) scan(x),scan(y),scan(z),
road_add(x,y,z);
else if(type==2) scan(x),scan(y),
printf("%d\n",road_query(x,y));
else if(type==3) scan(x),scan(z),
tree_add(x,z);
else if(type==4) scan(x),
printf("%d\n",tree_query(x));
}
return 0;
}

好了,可以开始调代码了

拓展:

  树链剖分,作为一个优秀的暴力结构,以O(n logn logn)的时间复杂度完成路径查询,在子树查询做到了nlogn级别,所以不得不说其优秀;

  但是,它的作用远不及此:

  1.LCA查询:

    与倍增相同,树链剖分可以用logn的时间复杂度完成LCA查询(跳跃性好像更优),而他的初始化是两遍dfs O(n),理论上更优。

    可以猜测,LCA依旧运用重链跳法,然后比较即可,这里给出示范代码

1
2
3
4
5
6
7
int Lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return dep[x]>dep[y]?y:x;
}//只要看懂树链剖分的基本操作,这个很简单

    可以看到,其实代码很短。。。

  2.换根操作:

    设现在的根是root,我们可以发现,换根对于路径上的操作并没有影响,但是子树操作就会影响了,所以我们分类讨论

      设u为我们要查的子树的根节点

      (1)如果root=u,那么子树即为整棵树;

      (2)设 lca 为root和u的LCA,这里可以用上面所讲的树链剖分做,如果lca!=u,那么root并不是u的子节点,所以对于查询并不影响,常规操作即可

      (3)如果lca=u,那么u节点的子树就是整颗树减去u-root这个路径上与u相挨的节点v的子树即可,这里给出logn求点v的方法

1
2
3
4
5
6
7
8
9
10
//前提条件:要求的节点相挨的节点u,必须是root的LCA 
int find(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳
if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了
x=fa[top[x]];//跳
}
if(dep[x]<dep[y]) swap(x,y);//让y最浅
return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的
}

    整个操作的代码层次感我写的还是比较清楚了

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
void tree_add(int x,int k){
if(root==x) r_add(1,n,1,n,1,k);//CASE 1
else{
int lca=Lca(x,root);
if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2
else{
int dson=find(x,root);
r_add(1,n,1,n,1,k);
r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
}//CASE 3
}
return ;
}

ll tree_query(int x){
if(root==x) return r_query(1,n,1,n,1);//CASE 1
else{
int lca=Lca(x,root);
if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2
else{
int dson=find(x,root);
return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
}//CASE 3
}
}

推荐评测网站LOJ 。。。(因为洛谷没有换根操作)

AC代码附上

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#include<bits/stdc++.h>
#define maxn 100007
#define ol putchar('\n')
#define le(x) x<<1
#define re(x) x<<1|1
#define ll long long
using namespace std;
int n,m,head[maxn],cent,dep[maxn],son[maxn],fa[maxn],vis[maxn];
int top[maxn],a[maxn],id[maxn],w[maxn],sz[maxn],cnt,ij,root;
ll tr[maxn<<3],tag[maxn<<3];
struct node{
int next,to;
}edge[maxn<<3];

template<typename type_of_scan>
inline void scan(type_of_scan &x){
type_of_scan f=1;x=0;char s=getchar();
while(s<'0'||s>'9') f=s=='-'?-1:1,s=getchar();
while(s>='0'&&s<='9') x=(x<<3)+(x<<1)+s-'0',s=getchar();
x*=f;
}
template<typename type_of_print>
inline void print(type_of_print x){
if(x<0) putchar('-'),x=-x;
if(x>9) print(x/10);
putchar(x%10+'0');
}

inline void add(int u,int v){
edge[++cent]=(node){head[u],v};head[u]=cent;
}

void push_up(int p){
tr[p]=tr[le(p)]+tr[re(p)];
}

void push_down(int l,int r,int p,ll k){
int mid=l+r>>1;
tr[le(p)]+=1ll*(mid-l+1)*k,
tr[re(p)]+=1ll*(r-mid)*k,
tag[le(p)]+=k,tag[re(p)]+=k;
}

void build(int l,int r,int p){
if(l==r){
tr[p]=w[l];
return ;
}
int mid=l+r>>1;
build(l,mid,le(p));
build(mid+1,r,re(p));
push_up(p);
}

void r_add(int nl,int nr,int l,int r,int p,int k){
if(nl<=l&&nr>=r){
tr[p]+=1ll*(r-l+1)*k;
tag[p]+=1ll*k;
return ;
}
push_down(l,r,p,tag[p]),tag[p]=0;
int mid=l+r>>1;
if(nl<=mid) r_add(nl,nr,l,mid,le(p),k);
if(nr>mid) r_add(nl,nr,mid+1,r,re(p),k);
push_up(p);
}

ll r_query(int nl,int nr,int l,int r,int p){
ll ans=0;
if(nl<=l&&nr>=r) return tr[p];
push_down(l,r,p,tag[p]),tag[p]=0;
int mid=l+r>>1;
if(nl<=mid) ans+=r_query(nl,nr,l,mid,le(p));
if(nr>mid) ans+=r_query(nl,nr,mid+1,r,re(p));
return ans;
}

void dfs1(int x){
sz[x]=1;int max_part=-1;vis[x]++;
for(int i=head[x];i;i=edge[i].next){
int y=edge[i].to;
if(y==fa[x]) continue;
fa[y]=x;dep[y]=dep[x]+1;
dfs1(y);sz[x]+=sz[y];
if(max_part<sz[y]) son[x]=y,max_part=sz[y];
}
}

void dfs2(int x,int t){
id[x]=++cnt;w[cnt]=a[x];top[x]=t;
if(!son[x]) return ;
dfs2(son[x],t);
for(int i=head[x];i;i=edge[i].next){
int y=edge[i].to;
if(y==son[x]||fa[x]==y) continue;
dfs2(y,y);
}
}

int Lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
return dep[x]>dep[y]?y:x;
}//只要看懂树链剖分的基本操作,这个很简单


//前提条件:要求的节点相挨的节点u,必须是root的LCA
int find(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);//从最下往上跳
if(fa[top[x]]==y) return top[x];//如果y是x所在重链top的父亲节点,那么就可以返回了
x=fa[top[x]];//跳
}
if(dep[x]<dep[y]) swap(x,y);//让y最浅
return son[y];// 因为在一条重链上,那么重儿子一定是路径上与要求节点相挨的
}

void tree_add(int x,int k){
if(root==x) r_add(1,n,1,n,1,k);//CASE 1
else{
int lca=Lca(x,root);
if(lca!=x) r_add(id[x],id[x]+sz[x]-1,1,n,1,k);//CASE 2
else{
int dson=find(x,root);
r_add(1,n,1,n,1,k);
r_add(id[dson],id[dson]+sz[dson]-1,1,n,1,-k);
}//CASE 3
}
return ;
}

ll tree_query(int x){
if(root==x) return r_query(1,n,1,n,1);//CASE 1
else{
int lca=Lca(x,root);
if(lca!=x) return r_query(id[x],id[x]+sz[x]-1,1,n,1);//CASE 2
else{
int dson=find(x,root);
return r_query(1,n,1,n,1)-r_query(id[dson],id[dson]+sz[dson]-1,1,n,1);
}//CASE 3
}
}

void road_add(int x,int y,ll k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
r_add(id[top[x]],id[x],1,n,1,k);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
r_add(id[x],id[y],1,n,1,k);
return ;
}

ll road_query(int x,int y){
ll ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
ans+=r_query(id[top[x]],id[x],1,n,1);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
ans+=r_query(id[x],id[y],1,n,1);
return ans;
}

int main(){
// freopen("cin.in","r",stdin);
// freopen("co.out","w",stdout);
scan(n);
for(int i=1;i<=n;i++) scan(a[i]);
for(int i=2,v;i<=n;i++) scan(v),add(i,v),add(v,i);
dfs1(1),dfs2(1,1),build(1,n,1);root=1;
scan(m);
for(int i=1;i<=m;i++){
int type,x,y,z;
scan(type),scan(x);
if(type==1) root=x;
else if(type==2) scan(y),scan(z),road_add(x,y,z);
else if(type==3) scan(z),tree_add(x,z);
else if(type==4) scan(y),printf("%lld\n",road_query(x,y));
else if(type==5) printf("%lld\n",tree_query(x));
}
return 0;
}
本站访问次数: