1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
using ll = long long;
// Capacity limits. N bounds the per-vertex arrays (vertices are numbered 1..n).
// M bounds the mask-counter table, which is indexed by masks 0 .. (1 << k) - 1,
// so M = 1050 leaves room for k <= 10.
constexpr int N = 50010;
constexpr int M = 1050;
// One directed half of an undirected edge in a forward-star adjacency list.
// `to` is the target vertex. `nex` is the index of the next edge leaving the
// same source vertex; 0 terminates the list.
struct node {
    int to;
    int nex;
};
// --- Adjacency storage (forward star). edge[1..cnt] are in use; head[v] == 0 means v has no edges. ---
node edge[2 * N];
int head[N];
int cnt;
// --- Input: vertex count n, bit count k, and per-vertex single-bit masks val[1..n]. ---
int n, k;
int val[N];
// --- Centroid search: sz[] subtree sizes, son[] largest remaining piece, vis[] marks processed centroids. ---
int rt, sum;
int sz[N];
int son[N];
int vis[N];
// --- Path-mask buffers filled by deep(): d[1..c] masks, tot[] per-mask counters. ---
int c;
int d[N];
int tot[M];
// s: mask with all k low bits set. res: answer accumulator.
int s;
ll res;
inline void add_edge(int u, int v) { edge[++cnt].to = v; edge[cnt].nex = head[u]; head[u] = cnt; }
// Computes subtree sizes for the unvisited component containing u (rooted at u)
// and keeps in rt the vertex whose largest remaining piece son[] is smallest,
// i.e. the centroid candidate. Callers set `sum` to the component size and
// rt = 0 beforehand; son[0] (set to n in main) acts as a worst-case sentinel.
void dfs(int u, int fa) {
    sz[u] = 1;
    son[u] = 0;
    for (int e = head[u]; e; e = edge[e].nex) {
        const int to = edge[e].to;
        if (to != fa && !vis[to]) {
            dfs(to, u);
            sz[u] += sz[to];
            if (sz[to] > son[u]) son[u] = sz[to];
        }
    }
    // The piece "above" u: everything in the component outside u's subtree.
    const int above = sum - sz[u];
    if (above > son[u]) son[u] = above;
    if (son[u] < son[rt]) rt = u;
}
// Resets per-test-case state: full k-bit mask s, answer, edge counter,
// adjacency list heads and the processed-centroid marks.
void init() {
    s = (1 << k) - 1;
    res = 0;
    cnt = 0;
    memset(head, 0, sizeof head);
    memset(vis, 0, sizeof vis);
}
// Walks every vertex reachable from u without stepping onto fa or a processed
// centroid (vis[]). `now` is the OR of val[] along the path up to and including u.
// Each vertex's mask is appended to d[1..c] and counted in tot[].
void deep(int u, int fa, int now) {
    ++c;
    d[c] = now;
    ++tot[now];
    for (int e = head[u]; e != 0; e = edge[e].nex) {
        const int to = edge[e].to;
        if (!vis[to] && to != fa) deep(to, u, now | val[to]);
    }
}
// Counts ordered pairs (x, y) of vertices collected by deep() from u, starting
// with mask `now` at u, with x == y allowed, such that d[x] | d[y] == s.
ll calc(int u, int now) {
    c = 0;
    memset(tot, 0, sizeof tot);
    deep(u, 0, now);
    // Superset-sum transform: afterwards tot[m] = number of recorded masks that contain m.
    for (int bit = 0; bit < k; ++bit) {
        const int b = 1 << bit;
        for (int m = s; m >= 0; --m)
            if (!(m & b)) tot[m] += tot[m | b];
    }
    // d[x] | d[y] == s  <=>  d[y] contains every bit of s missing from d[x].
    ll pairs = 0;
    for (int x = 1; x <= c; ++x) pairs += tot[s ^ d[x]];
    return pairs;
}
// Centroid-decomposition driver for the component whose centroid is u.
// It adds all qualifying pairs whose path is seen through u. It then subtracts
// the pairs lying entirely inside one child subtree, which calc(u) counted via u.
// Finally it recurses into each child component through that component's centroid.
void solve(int u) {
    res += calc(u, val[u]);
    vis[u] = 1;  // must precede the child calls so deep()/dfs() never re-enter u
    for (int e = head[u]; e; e = edge[e].nex) {
        const int v = edge[e].to;
        if (!vis[v]) {
            // Paths from u into v's subtree start with mask val[u] | val[v].
            res -= calc(v, val[u] | val[v]);
            // sz[v] comes from the dfs that selected u. It is exact only when v
            // lay below u in that dfs. Otherwise it overestimates, which changes
            // only which vertex becomes rt, not the count.
            sum = sz[v];
            rt = 0;
            dfs(v, -1);
            solve(rt);
        }
    }
}
// Reads test cases until input ends. Each case is n and k, then n values a
// (vertex i gets mask bit a-1), then n-1 undirected edges. For each case it prints
// the number of ordered vertex pairs (endpoints may coincide) whose path OR
// of masks equals the full k-bit mask.
int main() {
    // scanf returns the number of converted fields. Requiring both stops on a
    // malformed header; with `!= EOF` a return of 0 would loop forever on stale n/k.
    while (scanf("%d%d", &n, &k) == 2) {
        init();
        for (int i = 1; i <= n; i++) {
            int a;
            // A failed read would leave `a` uninitialised (UB), so give up instead.
            if (scanf("%d", &a) != 1) return 1;
            val[i] = 1 << (a - 1);  // assumes 1 <= a <= k per the input format
        }
        for (int i = 1; i <= n - 1; i++) {
            int u, v;
            // Garbage u/v would index head[]/edge[] out of bounds.
            if (scanf("%d%d", &u, &v) != 2) return 1;
            add_edge(u, v);
            add_edge(v, u);
        }
        rt = 0;
        sum = son[0] = n;  // son[0] is the sentinel dfs() compares centroid candidates against
        dfs(1, -1);
        solve(rt);
        printf("%lld\n", res);
    }
    return 0;
}
|