POI2014 Hotels

设$f(vtx,n)$表示vtx为根的子树内,到vtx距离为n的点的个数。

设$g(vtx,n)$表示vtx内,如果在vtx再接上一个长度为u的链(目标为v)就能和v形成符合题目要求的三元组的点对个数。

转移的时候枚举子节点,先统计答案,然后合并状态。

对于边$(u,v)$,能贡献的答案为$g(u,x)\times f(v,x-1)+f(u,x)\times g(v,x+1)$。

然后合并进去,$g(u,x)+=f(u,x)\times f(v,x-1)+g(v,x+1)$以及$f(u,x)+=f(v,x-1)$。

注意要先在统计答案的同时合并状态,因为我们统计的是无序三元组。

复杂度$O(n^3)$,优化一下到$O(n^2)$。

考虑长链剖分优化。

每个点的状态直接从重子节点拖过来,然后其他节点的答案参考上面暴力合并进去。

#include <iostream>
#include <vector>
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;

const int_t LARGE = 2e5;
int_t depth[LARGE + 1], maxChd[LARGE + 1], maxDepth[LARGE + 1];
std::vector<int_t> graph[LARGE + 1];

void DFS1(int_t vtx, int_t from = -1) {
    if (from == -1) depth[vtx] = 1;
    maxDepth[vtx] = depth[vtx];
    for (int_t to : graph[vtx]) {
        if (to != from) {
            depth[to] = depth[vtx] + 1;
            DFS1(to, vtx);
            maxDepth[vtx] = std::max(maxDepth[vtx], maxDepth[to]);
            if (maxDepth[to] > maxDepth[maxChd[vtx]]) maxChd[vtx] = to;
        }
    }
}
// f(vtx,u)表示vtx子树内到vtx距离为u的点的个数
// g(vtx,u)表示vtx内,如果在vtx再接上一个长度为u的链(目标为v)就能和v形成符合题目要求的三元组的点对个数
void DFS2(int_t vtx, int_t from, int_t* f, int_t* g, int_t& result) {
    //递归重子节点
    if (maxChd[vtx]) {
        DFS2(maxChd[vtx], vtx, f + 1, g - 1, result);
#ifdef DEBUG
        cout << "at vtx " << vtx << " done heavys" << endl;
        for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
            cout << "f " << i << " = " << f[i] << endl;
        for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
            cout << "g " << i << " = " << g[i] << endl;
        cout << endl;
#endif
    }
    f[0] += 1;
    //重链贡献的答案
    result += g[0];
    //递归轻子节点
    for (int_t to : graph[vtx]) {
        if (to != from && to != maxChd[vtx]) {
            const int_t size = maxDepth[to] - depth[vtx] + 1;
            auto nextF = new int_t[size];
            auto nextG = new int_t[size * 2];
            std::fill(nextF, nextF + size, 0);
            std::fill(nextG, nextG + 2 * size, 0);
            nextG += size;
            DFS2(to, vtx, nextF, nextG, result);
            for (int_t i = 0; i < maxDepth[to] - depth[vtx]; i++) {
                result += g[i + 1] * (nextF)[i];
                if (i) result += (nextG)[i] * f[i - 1];
            }
            //合并答案
            for (int_t i = 0; i < maxDepth[to] - depth[vtx]; i++) {
                //合并g
                g[i + 1] += f[i + 1] * (nextF)[i];
                if (i) g[i - 1] += (nextG)[i];
                f[i + 1] += (nextF)[i];
            }
#ifdef DEBUG
            cout << "at vtx " << vtx << " done chd " << to << endl;
            for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
                cout << "f " << i << " = " << f[i] << endl;
            for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
                cout << "g " << i << " = " << g[i] << endl;
            cout << endl;
#endif
        }
    }
#ifdef DEBUG
    cout << "at vtx " << vtx << endl;
    for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
        cout << "f " << i << " = " << f[i] << endl;
    for (int_t i = 0; i <= maxDepth[vtx] - depth[vtx]; i++)
        cout << "g " << i << " = " << g[i] << endl;
    cout << endl;
#endif
    // result += g[0];
}
int n;
int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n - 1; i++) {
        int from, to;
        scanf("%d%d", &from, &to);
        graph[from].push_back(to);
        graph[to].push_back(from);
    }
    DFS1(1);
#ifdef DEBUG
    for (int_t i = 1; i <= n; i++) {
        cout << "maxChd " << i << " = " << maxChd[i] << endl;
    }
#endif
    int_t result = 0;
    int_t size = n + 1;
    auto f = new int_t[size];
    auto g = new int_t[size * 2];
    std::fill(f, f + size, 0);
    std::fill(g, g + size * 2, 0);
    DFS2(1, -1, f, g + size, result);
    cout << result << endl;
    return 0;
}

 

 

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理