复习长链剖分。
首先划分长链,走到每个长链顶就给这个长链开个数组拿来记答案,数组的长度(从0开始算)是整个长链上点的个数,下标为k的位置所表示的意义是到长链顶距离为k的点的个数,沿着重链走的时候,直接指针偏移一位就行了(毕竟状态在深度上连续)。
然后DFS,走到每一个点,先给这个点的DP数组指针(设为arr),下标为0的位置加个1,表示这个点自己的答案。然后如果这个点有重子节点,就先沿着重子节点走,走的时候DP数组直接偏移上1,然后这个点的答案先记录为重子节点的答案+1(我们要考虑所有子节点的答案,一会拿这玩意来更新)。
对于非重子节点的点,开个新的DP数组,然后直接传下去往下递归。
非重子节点返回后尝试更新答案。枚举非重子节点DP数组上的每一个状态,把他的值加到当前节点的DP值上(即合并子链状态,把状态合并到重链上),然后顺带更新一波结果(如果某个深度的值出现的比当前的多,就直接更新,如果相当就比一下深度),合并晚状态后直接把非重子节点的内存回收掉即可。
显然复杂度是$\theta(n)$的,每条重链会有一个DP数组,这个DP数组的长度是重链的长度,每条重链的答案只会被合并一次,而所有重链的长度之和是$n$。
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
using int_t = long long int;
const int_t LARGE = 1e6;
using std::cin;
using std::cout;
using std::endl;
int_t results[LARGE + 1];
int_t depth[LARGE + 1];
int_t maxDepth[LARGE + 1];
int_t maxChd[LARGE + 1];
std::vector<int_t> graph[LARGE + 1];
int_t n;
void DFS1(int_t vtx, int_t from = -1) {
maxChd[vtx] = -1;
maxDepth[vtx] = depth[vtx];
for (auto to : graph[vtx])
if (to != from) {
depth[to] = depth[vtx] + 1;
DFS1(to, vtx);
if (maxChd[vtx] == -1 || maxDepth[to] > maxDepth[maxChd[vtx]]) {
maxChd[vtx] = to;
}
maxDepth[vtx] = std::max(maxDepth[vtx], maxDepth[to]);
}
}
// arr 距离为x的点的个数
void DFS2(int_t vtx, int_t from = -1, int_t* arr = nullptr) {
arr[0]++;
auto& result = results[vtx];
if (maxChd[vtx] != -1) {
DFS2(maxChd[vtx], vtx, arr + 1);
result = results[maxChd[vtx]] + 1; //答案初始值
}
for (auto to : graph[vtx])
if (to != from && to != maxChd[vtx]) {
int_t arrlen = maxDepth[to] - depth[vtx] + 1;
int_t* next = new int_t[arrlen];
std::fill(next, next + arrlen, 0);
// next[0] = 1;
DFS2(to, vtx, next + 1);
for (int_t i = 1; i < arrlen; i++) {
arr[i] += next[i];
if (arr[i] > arr[result])
result = i;
else if (arr[i] == arr[result] && i < result)
result = i;
}
delete[] next;
}
if (arr[result] == 1)
result = 0; //最小的距离
}
int main() {
std::ios::sync_with_stdio(false);
cin >> n;
for (int_t i = 1; i <= n - 1; i++) {
int_t u, v;
cin >> u >> v;
graph[u].push_back(v);
graph[v].push_back(u);
}
depth[1] = 1;
DFS1(1);
#ifdef DEBUG
for (int_t i = 1; i <= n; i++) {
cout << "maxchd " << i << " = " << maxChd[i] << endl;
}
#endif
static int_t arr[LARGE + 1];
DFS2(1, -1, arr);
for (int_t i = 1; i <= n; i++) {
cout << results[i] << endl;
}
return 0;
}
发表回复