细节很多，再加上自己对 AC 自动机不熟悉，踩了不少坑：
- 基于 DFS 的记忆化搜索是错的：转移顺序有问题。
- 走到自动机上的某个点，就代表这个点在失配树（fail 树）上到根路径上的所有点都被匹配到了，所以要把它们的贡献累加起来。
- DP 的各个细节要好好想清楚，别看错。
- 分清楚这一步转移是从哪个状态到哪个状态的。
将原题的式子取 log 后，问题变成经典的“选取若干项，使它们的平均值最大”的问题，用 01 分数规划（二分答案 $x$）即可求解。
具体做法：给每个串附加权值 $-x$，即每匹配一次得分 $\ln v - x$；然后在 AC 自动机上跑 DP 最大化总得分。若最大总得分大于 $0$，说明平均值可以超过 $x$。
#include <assert.h>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <queue>
#include <string>
#include <unordered_map>
#include <utility>
using int_t = long long;
using real_t = double;
using pair_t = std::pair<int_t, int_t>;
constexpr real_t EPS = 1e-7;       // binary-search termination width
constexpr int_t INF = 0x7fffffff;  // -INF is used as the "unreachable" DP sentinel
using std::cin;
using std::cout;
using std::endl;
constexpr int_t LARGE = 2000;      // capacity bound for text length and trie node count
struct Node;
Node* nodes[LARGE + 1];            // nodes[idx] -> the Node whose idx is `idx` (1-based)
int_t used = 0;                    // how many Nodes have been constructed so far
struct Node {
real_t val = 0;
Node* link = nullptr;
Node* chds[10];
int_t count = 0;
int_t id = 0;
int_t idx = 0;
Node*& access(char chr) { return chds[chr - '0']; }
Node() {
idx = ++used;
nodes[idx] = this;
memset(chds, 0, sizeof(chds));
}
};
// Adds the NUL-terminated digit string `str` to the trie rooted at `root`,
// creating missing nodes along the way. The terminal node accumulates `val`
// and has its pattern count bumped by one. Returns that terminal node.
Node* insert(Node* root, const char* str, real_t val) {
    Node* cursor = root;
    while (*str != '\0') {
        Node*& slot = cursor->access(*str);
        if (!slot) slot = new Node;
        cursor = slot;
        ++str;
    }
    cursor->count += 1;
    cursor->val += val;
    return cursor;
}
// Builds failure links level by level and completes the goto function, so
// every node ends up with all 10 transitions non-null. It also folds each
// node's fail-link `count`/`val` into its own, so a node's totals cover every
// pattern that ends at that position in the text.
void BFS(Node* root) {
    std::queue<Node*> pending;
    // Depth-1 nodes fail to the root; missing root edges loop back to the root.
    for (int_t d = 0; d < 10; ++d) {
        Node*& child = root->chds[d];
        if (child == nullptr) {
            child = root;
            continue;
        }
        child->link = root;
        pending.push(child);
    }
    while (!pending.empty()) {
        Node* cur = pending.front();
        pending.pop();
        // `fail` is shallower than `cur`, so it was dequeued (and completed)
        // earlier: its totals are final and all of its transitions are non-null.
        Node* fail = cur->link;
        cur->count += fail->count;
        cur->val += fail->val;
        for (int_t d = 0; d < 10; ++d) {
            Node*& next = cur->chds[d];
            if (next == nullptr) {
                next = fail->chds[d];
            } else {
                next->link = fail->chds[d];
                pending.push(next);
            }
        }
    }
}
// Text to fill: digits are fixed, '.' positions may take any digit.
char buf[LARGE + 1];
int_t n;  // length of buf
int_t m;  // number of weighted patterns
// Trie root; presumably the first Node constructed, hence idx 1.
Node* root = new Node;
// dp[i][j]: best score after fixing the first i characters and standing on node j.
real_t dp[LARGE + 1][LARGE + 1];
// from[i][j]: (previous node idx, character) that produced dp[i][j].
pair_t from[LARGE + 1][LARGE + 1];
// Filled-in text reconstructed from `from`.
char string[LARGE + 1];
// Aho-Corasick DP for a fixed penalty x (0/1 fractional-programming step).
// Every pattern occurrence contributes (its log-weight - x); node->val and
// node->count already include the whole fail chain (see BFS).
// Returns the maximum total score over all fillings of buf. As a side effect,
// it writes one optimal filling into the global `string`; string[n] is never
// written here, so it relies on static zero-initialization for the terminator.
real_t check(real_t x) {
    for (int_t i = 0; i <= n; i++)
        for (int_t j = 1; j <= used; j++)
            dp[i][j] = -INF;
    dp[0][1] = 0;  // node idx 1 is the trie root

    // Relax the edge leaving node j (after i characters) along digit chr.
    auto relax = [x](int_t i, int_t j, char chr) {
        Node* node = nodes[j]->access(chr);
        const real_t cand = dp[i][j] + node->val - x * node->count;
        if (dp[i + 1][node->idx] < cand) {
            dp[i + 1][node->idx] = cand;
            from[i + 1][node->idx] = pair_t(j, chr);
        }
    };

    for (int_t i = 0; i < n; i++) {
        for (int_t j = 1; j <= used; j++) {
            if (dp[i][j] == -INF) continue;  // state not reachable
            if (buf[i] != '.') {
                relax(i, j, buf[i]);
            } else {
                for (char chr = '0'; chr <= '9'; chr++) relax(i, j, chr);
            }
        }
    }

    // Pick the first end node with the maximal score. The running maximum used
    // to be an int_t, which truncated the double scores and could select a
    // non-optimal end node; compare doubles directly instead. Starting from
    // node 1 also avoids indexing dp[n][-1] if nothing beats -INF.
    int_t best = 1;
    for (int_t j = 2; j <= used; j++)
        if (dp[n][j] > dp[n][best]) best = j;

    // Walk the predecessor table backwards to recover the chosen characters.
    for (int_t i = n, pos = best; i >= 1; i--) {
        string[i - 1] = static_cast<char>(from[i][pos].second);
        pos = from[i][pos].first;
    }
    return dp[n][best];
}
int main() {
cout.setf(std::ios::fixed);
cout << std::setprecision(5);
scanf("%lld%lld%s", &n, &m, buf);
for (int_t i = 1; i <= m; i++) {
static char buf[LARGE + 1];
int_t val;
scanf("%s%lld", buf, &val);
insert(root, buf, log(val));
}
BFS(root);
real_t left = 0, right = 25;
real_t mid;
real_t result;
while ((right - left) > EPS) {
mid = (left + right) / 2;
real_t checkval = check(mid);
if (checkval > 0) {
result = left = mid;
} else {
right = mid;
}
}
check(result);
printf("%s", string);
return 0;
}
发表回复