BJOI2019 奥术神杖

大量的细节+自己不熟悉AC自动机。

  1. 基于DFS的记忆化搜索是错的,转移顺序有问题。
  2. 走到了一个点,即代表这个点在失配树上到根的所有点都匹配到了,所以要把他们加起来。
  3. DP各种东西好好想想,别眼瞎。
  4. 分清楚这步转移时从哪到哪的

原题的式子取个log,变成了经典的取若干个东西使得他们的平均值最大的问题,01分数规划即可。

给每个串加上一个附加权,然后AC自动机跑DP,最大化权值即可。

#include <assert.h>
#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
using int_t = long long int;
using real_t = double;
using pair_t = std::pair<int_t, int_t>;
const real_t EPS = 1e-7;
const int_t INF = 0x7fffffff;
using std::cin;
using std::cout;
using std::endl;
const int_t LARGE = 2000;
struct Node* nodes[LARGE + 1];
int_t used = 0;
struct Node {
    real_t val = 0;
    Node* link = nullptr;
    Node* chds[10];
    int_t count = 0;
    int_t id = 0;
    int_t idx = 0;

    Node*& access(char chr) { return chds[chr - '0']; }
    Node() {
        idx = ++used;
        nodes[idx] = this;
        memset(chds, 0, sizeof(chds));
    }
};
Node* insert(Node* root, const char* str, real_t val) {
    for (auto p = str; *p != '\0'; p++) {
        if (root->access(*p) == nullptr) {
            root->access(*p) = new Node;
        }
        root = root->access(*p);
    }
    root->val += val;
    root->count += 1;
    return root;
}
void BFS(Node* root) {
    std::queue<Node*> queue;
    for (Node*& to : root->chds)
        if (to) {
            to->link = root;
            queue.push(to);
        } else {
            to = root;
        }
    while (queue.empty() == false) {
        Node* front = queue.front();
        queue.pop();
        front->count += front->link->count;
        front->val += front->link->val;
        for (auto chr = '0'; chr <= '9'; chr++) {
            Node*& to = front->access(chr);
            if (to == nullptr) {
                to = front->link->access(chr);
            } else {
                Node* parent = front->link;
                while (parent && parent->access(chr) == nullptr)
                    parent = parent->link;
                if (parent == nullptr)
                    to->link = root;
                else
                    to->link = parent->access(chr);
                queue.push(to);
            }
        }
    }
}

char buf[LARGE + 1];
int_t n, m;

Node* root = new Node;
//已匹配长度为n,在第m个点的最优解
real_t dp[LARGE + 1][LARGE + 1];
pair_t from[LARGE + 1][LARGE + 1];
char string[LARGE + 1];
real_t check(real_t x) {

    for (int_t i = 0; i <= n; i++) {
        for (int_t j = 1; j <= used; j++) {
            dp[i][j] = -INF;
        }
    }
    dp[0][1] = 0;
    for (int_t i = 0; i < n; i++) {
        for (int_t j = 1; j <= used; j++) {
            if (dp[i][j] == -INF) continue;
            //枚举出边
            if (buf[i] != '.') {
                auto chr = buf[i];
                Node* node = nodes[j]->access(chr);
                if (dp[i + 1][node->idx] <
                    dp[i][j] + node->val - x * node->count) {
                    dp[i + 1][node->idx] =
                        dp[i][j] + node->val - x * node->count;
                    from[i + 1][node->idx] = pair_t(j, chr);
                }

            } else {
                for (auto chr = '0'; chr <= '9'; chr++) {
                    Node* node = nodes[j]->access(chr);
                    if (dp[i + 1][node->idx] <
                        dp[i][j] + node->val - x * node->count) {
                        dp[i + 1][node->idx] =
                            dp[i][j] + node->val - x * node->count;
                        from[i + 1][node->idx] = pair_t(j, chr);
                    }
                }
            }
        }
    }
    int_t max = -1, val = -INF;
    for (int_t i = 1; i <= used; i++)
        if (dp[n][i] > val) {
            val = dp[n][i];
            max = i;
        }
    int_t pos = max;
    for (int_t i = n; i >= 1; i--) {
        string[i - 1] = from[i][pos].second;
        pos = from[i][pos].first;
    }
    return dp[n][max];
}
int main() {
    cout.setf(std::ios::fixed);
    cout << std::setprecision(5);
    scanf("%lld%lld%s", &n, &m, buf);
    for (int_t i = 1; i <= m; i++) {
        static char buf[LARGE + 1];
        int_t val;
        scanf("%s%lld", buf, &val);
        insert(root, buf, log(val));
    }
    BFS(root);
    real_t left = 0, right = 25;
    real_t mid;
    real_t result;
    while ((right - left) > EPS) {
        mid = (left + right) / 2;
        real_t checkval = check(mid);
        if (checkval > 0) {
            result = left = mid;
        } else {
            right = mid;
        }
    }
    check(result);
    printf("%s", string);
    return 0;
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理