题目链接:Garden of Eden

题意:给定一颗n个节点的树,每个节点有一种颜色,颜色有k种,求树上有多少条路径包含这k种颜色,n<=50000,k<=10

思路:树上路径问题,用点分治求解,又由于k<=10,所以可以用二进制状态表示一条路径上包含的颜色集合,比如状态8转换成二进制为1000,那么就表示状态8表示的路径上含有第3种颜色(颜色标号从0开始)

那么考虑过重心向下的某一条路径,假设这条路径的二进制状态为d,设s表示含有所有颜色的集合,即s=(1<<k)-1,那么我们只需要找到二进制状态为s^d的集合的超集与d配对即可,所以需要求超集的和

1
2
3
4
5
6
7
8
9
// 超集和
for (int j = 0; j < k; j++)
for (int i = s; i >= 0; i--)
if (((1 << j) & i) == 0) tot[i] += tot[i | (1 << j)];

// 子集和
for (int j = 0; j < k; j++)
for (int i = 0; i <= s; i++)
if ((1 << j) & i) tot[i] ++ tot[i ^ (1 << j)];

求出超集后之后,按照点分治的一般步骤求解即可

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

using namespace std;

typedef long long ll;

const int N = 50010;
const int M = 1050;

struct node {
int to, nex;
};

node edge[2 * N];
int n, k, cnt, rt, sum, s, c, val[N], d[N];
int head[N], sz[N], son[N], vis[N], tot[M];
ll res;

inline void add_edge(int u, int v)
{
edge[++cnt].to = v;
edge[cnt].nex = head[u];
head[u] = cnt;
}

void dfs(int u, int fa)
{
sz[u] = 1;
son[u] = 0;
for (int i = head[u]; 0 != i; i = edge[i].nex) {
int v = edge[i].to;
if (v == fa || vis[v]) continue;
dfs(v, u);
sz[u] += sz[v];
son[u] = max(son[u], sz[v]);
}
son[u] = max(son[u], sum - sz[u]);
if (son[u] < son[rt]) rt = u;
}

void init()
{
memset(vis, 0, sizeof(vis));
memset(head, 0, sizeof(head));
cnt = 0;
res = 0;
s = (1 << k) - 1;
}

void deep(int u, int fa, int now)
{
d[++c] = now;
tot[now]++;
for (int i = head[u]; 0 != i; i = edge[i].nex) {
int v = edge[i].to;
if (vis[v] || v == fa) continue;
deep(v, u, now | val[v]);
}
}

ll calc(int u, int now)
{
c = 0;
memset(tot, 0, sizeof(tot));
deep(u, 0, now);
for (int j = 0; j < k; j++)
for (int i = s; i >= 0; i--)
if (((1 << j) & i) == 0) tot[i] += tot[i | (1 << j)];
ll r = 0;
for (int i = 1; i <= c; i++) r += tot[d[i] ^ s];
return r;
}

void solve(int u)
{
res += calc(u, val[u]);
vis[u] = 1;
for (int i = head[u]; 0 != i; i = edge[i].nex) {
int v = edge[i].to;
if (vis[v]) continue;
res -= calc(v, val[u] | val[v]);
sum = sz[v];
rt = 0;
dfs(v, -1);
solve(rt);
}
}

int main()
{
while (scanf("%d%d", &n, &k) != EOF) {
init();
for (int i = 1; i <= n; i++) {
int a;
scanf("%d", &a);
val[i] = 1 << (a - 1);
}
for (int i = 1; i <= n - 1; i++) {
int u, v;
scanf("%d%d", &u, &v);
add_edge(u, v);
add_edge(v, u);
}
rt = 0;
sum = son[0] = n;
dfs(1, -1);
solve(rt);
printf("%lld\n", res);
}
return 0;
}