树上启发式合并模板

这个算法真是研究了好久才明白……众多博客真的不是面向小白的

模板拿cf600E举例子,问题是给一棵有根树(rt==1),每个节点有一个颜色值,树的节点数和颜色值的范围都是1e5。对于每个节点我们定义其子树中数量最多的那个颜色值为主颜色(可以有多个并列),现要给出每个节点的主颜色的值的和。

首先想到暴力做法,直接去挨个点算他们的子树的答案,统计子树中每种颜色的数量并维护当前主颜色的和。这种算法显然是$O(n^2)$的。同时我们注意到一个细节,每个点的访问计算中所用到的cnt数组,应该是一个公用的全局数组,因为节点和颜色值的上限都是1e5,我们没法建立一个cnt[maxn][maxcolor]的数组(如果能建就简单多了),所以每次在访问一个点后,需要清空该点的树的记录,因为后面马上要去访问它的兄弟节点。这时就需要再写一个dfs去清除记录,用dfs而不是memset去清除是因为可以精准限制访问次数。

那么优化点在哪里呢,我们发现,当你把所有的兄弟节点访问完后,接下来需要计算父亲的答案值,又需要把这些兄弟节点的贡献加起来,十分浪费时间。如果能把他们保存在一个大数组里就好了(但是空间受限不行),最多只能保存最后访问的那个兄弟,于是我们选择把重儿子放在最后,尽可能利用这一优势。

那么这样优化下来的复杂度是多少呢?oiwiki上有严格证明,是nlogn。大概是每个点的访问次数是它到根的轻边+1?而轻边最多logn条,因为轻树的size小于父树的1/2。

看代码吧

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
using namespace std;

const int maxn=1e5+1;

#define ll long long
vector<int> vec[maxn];
int col[maxn],cnt[maxn],vis[maxn],n,top=0; //cnt是每种颜色的数量
ll sum[maxn],ans[maxn]; //sum[i]表示出现次数为i的颜色数量和,sum[top]是答案ans[u]

int son[maxn],siz[maxn];

void dfs1(int u,int fa){
siz[u]=1;
for(int i=0;i<vec[u].size();i++){
if(fa!=vec[u][i]){
dfs1(vec[u][i],u);
siz[u]+=siz[vec[u][i]];
if(siz[vec[u][i]]>siz[son[u]]){ //一开始默认siz[0]=0
son[u]=vec[u][i];
}
}
}
}

void cal(int u,int fa,int val){ //计算树u,sum[top]是答案
sum[cnt[col[u]]]-=col[u];
cnt[col[u]]+=val;
sum[cnt[col[u]]]+=col[u];

if(cnt[col[u]]>top||sum[top]==0){
top=cnt[col[u]];
}
for(int i=0;i<vec[u].size();i++){
if(vec[u][i]==fa||vis[vec[u][i]])
continue;
cal(vec[u][i],u,val);
}
}

void dfs(int u,int fa,bool keep){ //预处理得到u及其子树的答案
for(int i=0;i<vec[u].size();i++){
if(fa==vec[u][i]||vec[u][i]==son[u])
continue;
dfs(vec[u][i],u,0); //先处理轻儿子的树
}
if(son[u])
dfs(son[u],u,1),vis[son[u]]=1; //处理重儿子,保留痕迹

cal(u,fa,1); //这一步才真正是统计树u的答案,
//它在向下计算的时候不会再对已vis标记的重儿子进行dfs,优化在这里了
ans[u]=sum[top];
if(son[u])
vis[son[u]]=0; //已经利用过了此次便利,清除标记
if(!keep)
cal(u,fa,-1); //清除对数组的影响
}

int main(){
cin>>n;
for(int i=1;i<=n;i++){
scanf("%d",col+i);
}
for(int i=1,x,y;i<n;i++){
scanf("%d%d",&x,&y);
vec[x].push_back(y);
vec[y].push_back(x);
}
dfs1(1,0);
dfs(1,0,1); //填1是为了省下最后那次遍历
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
puts("");
return 0;
}

再来一道cf375d

给出一棵树,m个询问(v,k)要求回答树v中有多少种颜色出现超过k次。

我们在cal函数中维护树u中各个出现次数的颜色数量,然后将询问预先存在数组中,询问同一棵树的放在一起,在dfs函数中执行完cal后去依次回答这些提问。

只贴上不同的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
void cal(int u,int fa,int val){
if(val==-1){
res[cnt[col[u]]]--;
cnt[col[u]]--;
}else{
cnt[col[u]]++;
res[cnt[col[u]]]++;
}

for(int i=0;i<vec[u].size();i++){
if(fa==vec[u][i]||vis[vec[u][i]])
continue;
cal(vec[u][i],u,val);
}
}

void dfs(int u,int fa,bool keep){
for(int i=0;i<vec[u].size();i++){
if(fa==vec[u][i]||vec[u][i]==son[u])
continue;
dfs(vec[u][i],u,0);
}
if(son[u])
dfs(son[u],u,1),vis[son[u]]=1;
cal(u,fa,1);
for(int i=0;i<q[u].size();i++){
int id=q[u][i].first;
int k=q[u][i].second;
ans[id]=res[k];
}
if(son[u])
vis[son[u]]=0;
if(!keep)
cal(u,fa,-1);
}

int main(){
cin>>n>>m;
for(int i=1;i<=n;i++)
cin>>col[i];
for(int i=1,x,y;i<n;i++){
cin>>x>>y;
vec[x].push_back(y);
vec[y].push_back(x);
}
for(int i=1,v,k;i<=m;i++){
cin>>v>>k;
q[v].push_back(P(i,k));
}
dfs1(1,0);
dfs(1,0,1);
for(int i=1;i<=m;i++)
cout<<ans[i]<<endl;
return 0;
}