设$f(vtx,n)$表示vtx为根的子树内,到vtx距离为n的点的个数。
设$g(vtx,n)$表示vtx内,如果在vtx再接上一个长度为u的链(目标为v)就能和v形成符合题目要求的三元组的点对个数。
转移的时候枚举子节点,先统计答案,然后合并状态。
对于边$(u,v)$,能贡献的答案为$g(u,x)\times f(v,x-1)+f(u,x)\times g(v,x+1)$。
然后合并进去,$g(u,x)+=f(u,x)\times f(v,x-1)+g(v,x+1)$以及$f(u,x)+=f(v,x-1)$。
注意要先在统计答案的同时合并状态,因为我们统计的是无序三元组。
复杂度$O(n^3)$,优化一下到$O(n^2)$。
考虑长链剖分优化。
每个点的状态直接从重子节点拖过来,然后其他节点的答案参考上面暴力合并进去。
#include <iostream>
#include <vector>
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;
const int_t LARGE = 2e5;
int_t depth[LARGE + 1], maxChd[LARGE + 1], maxDepth[LARGE + 1];
std::vector<int_t> graph[LARGE + 1];
void DFS1(int_t vtx, int_t from = -1) {
if (from == -1) depth[vtx] = 1;
maxDepth[vtx] = depth[vtx];
for (int_t to : graph[vtx]) {
if (to != from) {
depth[to] = depth[vtx] + 1;
DFS1(to, vtx);
maxDepth[vtx] = std::max(maxDepth[vtx], maxDepth[to]);
if (maxDepth[to] > maxDepth[maxChd[vtx]]) maxChd[vtx] = to;
}
}
}
// f(vtx,u)表示vtx子树内到vtx距离为u的点的个数
// g(vtx,u)表示vtx内,如果在vtx再接上一个长度为u的链(目标为v)就能和v形成符合题目要求的三元组的点对个数
void DFS2(int_t vtx, int_t from, int_t* f, int_t* g, int_t& result) {
//递归重子节点
if (maxChd[vtx]) {
DFS2(maxChd[vtx], vtx, f + 1, g - 1, result);
#ifdef DEBUG
cout << "at vtx " << vtx << " done heavys" << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "f " << i << " = " << f[i] << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "g " << i << " = " << g[i] << endl;
cout << endl;
#endif
}
f[0] += 1;
//重链贡献的答案
result += g[0];
//递归轻子节点
for (int_t to : graph[vtx]) {
if (to != from && to != maxChd[vtx]) {
const int_t size = maxDepth[to] - depth[vtx] + 1;
auto nextF = new int_t[size];
auto nextG = new int_t[size * 2];
std::fill(nextF, nextF + size, 0);
std::fill(nextG, nextG + 2 * size, 0);
nextG += size;
DFS2(to, vtx, nextF, nextG, result);
for (int_t i = 0; i < maxDepth[to] - depth[vtx]; i++) {
result += g[i + 1] * (nextF)[i];
if (i) result += (nextG)[i] * f[i - 1];
}
//合并答案
for (int_t i = 0; i < maxDepth[to] - depth[vtx]; i++) {
//合并g
g[i + 1] += f[i + 1] * (nextF)[i];
if (i) g[i - 1] += (nextG)[i];
f[i + 1] += (nextF)[i];
}
#ifdef DEBUG
cout << "at vtx " << vtx << " done chd " << to << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "f " << i << " = " << f[i] << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "g " << i << " = " << g[i] << endl;
cout << endl;
#endif
}
}
#ifdef DEBUG
cout << "at vtx " << vtx << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "f " << i << " = " << f[i] << endl;
for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
cout << "g " << i << " = " << g[i] << endl;
cout << endl;
#endif
// result += g[0];
}
int n;
int main() {
scanf("%d", &n);
for (int i = 1; i <= n - 1; i++) {
int from, to;
scanf("%d%d", &from, &to);
graph[from].push_back(to);
graph[to].push_back(from);
}
DFS1(1);
#ifdef DEBUG
for (int_t i = 1; i <= n; i++) {
cout << "maxChd " << i << " = " << maxChd[i] << endl;
}
#endif
int_t result = 0;
int_t size = n + 1;
auto f = new int_t[size];
auto g = new int_t[size * 2];
std::fill(f, f + size, 0);
std::fill(g, g + size * 2, 0);
DFS2(1, -1, f, g + size, result);
cout << result << endl;
return 0;
}
发表回复