分类: 字符串

  • 洛谷5357 AC自动机模板 二次加强 复习

    简单复习了以下AC自动机。

    首先注意,$$n$$个字符串构建的AC自动机至多可能有总串长+1个节点(根节点不隶属于任何一个字符串)

    由fail指针构成的树叫做后缀链接树,每个点的父节点为这个点所表示的字符串,在整个ACAM里能匹配到的公共后缀最长的字符串。(与SAM里的后缀连接好像是一样的….)

    构建后缀链接,并且把Trie树补成Trie图(即将每个点的不存在的子节点指向下一个待匹配位置的操作)通过BFS进行。

    初始时: 根节点的子节点的fail指针全都指向根,根节点的不存在的子节点构造关于根的自环。

    接下来根节点的存在的子节点入队,开始BFS。

    从队列里取出来队首元素V,按照如下方法操作:

    遍历其子节点,设当前遍历到的子节点通过字符chr获得,如果这个子节点存在,则其后缀链接指向V的后缀链接指向的节点的chr子节点,并将该节点入队。

    如果该子节点不存在,则将其替换为V的后缀链接的chr子节点。

     

    如果遇到一个所有字符串都没出现的字符怎么办?这时我们肯定会走到根,然后沿着根节点的子节点走自环,成功滚到下一个字符。

     

    匹配时直接沿着Trie图走即可(我们在BFS中已经把所有点的子节点补完了),每走到一个点,则说明了我们匹配到了根到这个点的边所构成的字符串。

     

    然后怎么统计每个子串出现的次数呢?

    首先我们在ACAM上的每个点标记上以这个点结尾的字符串,可以知道这种标记只会有总字符串个数个。

    然后我们构建出后缀链接树,我们可以知道,当我们的字符串走到某个点的时候,就意味着匹配到了这个点在后缀链接树上到树根经历的所有字符串,然后我们打个到根的差分标记,最后DFS统计下标记即可求出来每个字符串访问了几遍。

    #include <algorithm>
    #include <cstdlib>
    #include <cstring>
    #include <iostream>
    #include <queue>
    #include <string>
    using int_t = long long int;
    using std::cin;
    using std::cout;
    using std::endl;
    
    struct Node {
        Node* chd[26];
        Node* fail = nullptr;
        // std::vector<int_t> strs;
        std::vector<int_t> strs;
        int_t mark = 0;
        Node*& access(char x) { return chd[x - 'a']; }
        Node() { memset(chd, 0, sizeof(chd)); }
        int_t id;
    };
    std::vector<int_t> graph[int_t(2e5) + 10];
    
    Node* mapping[int_t(2e5) + 10];
    Node* root = new Node;
    int_t used = 1;
    void insert(const char* str, int_t id) {
        Node* curr = root;
        for (auto ptr = str; *ptr; ptr++) {
            auto& chd = curr->access(*ptr);
            if (chd == nullptr) {
                chd = new Node;
                chd->id = ++used;
                mapping[used] = chd;
            }
            curr = chd;
        }
        curr->strs.push_back(id);
    }
    void buildFail() {
        std::queue<Node*> queue;
        // 安排根节点的子节点
        for (auto& ref : root->chd) {
            if (ref) {
                ref->fail = root;
                queue.push(ref);
                graph[1].push_back(ref->id);
            } else {
                ref = root;
            }
        }
        while (!queue.empty()) {
            Node* front = queue.front();
            queue.pop();
            for (char chr = 'a'; chr <= 'z'; chr++) {
                auto& ptr = front->access(chr);
                if (ptr) {
                    queue.push(ptr);
                    ptr->fail = front->fail->access(chr);
                    graph[ptr->fail->id].push_back(ptr->id);
    #ifdef DEBUG
                    cout << "edge " << ptr->fail->id << " to " << ptr->id << endl;
    #endif
                } else {
                    ptr = front->fail->access(chr);
                }
            }
        }
    }
    int_t counts[int_t(2e5 + 1)];
    void DFS(int_t vtx) {
        for (auto chd : graph[vtx]) {
            DFS(chd);
            mapping[vtx]->mark += mapping[chd]->mark;
        }
        Node* curr = mapping[vtx];
        for (auto x : curr->strs) {
            counts[x] += curr->mark;
        }
    }
    int main() {
        std::ios::sync_with_stdio(false);
        root->id = 1;
        mapping[1] = root;
        int_t n;
        cin >> n;
        static char buf[int_t(3e6 + 10)];
        for (int_t i = 1; i <= n; i++) {
            cin >> buf;
            insert(buf, i);
        }
        buildFail();
        cin >> buf;
        auto curr = root;
    
        for (auto ptr = buf; *ptr; ptr++) {
            auto chd = curr->access(*ptr);
            chd->mark += 1;
            curr = chd;
        }
        DFS(1);
        for (int_t i = 1; i <= n; i++)
            cout << counts[i] << endl;
        return 0;
    }

     

  • 洛谷1368 工艺

    题意:给定一个序列,求其所有旋转同构序列中字典序最小的序列。

    将原序列复制一份然后拼到原序列后面,例如

    10 9 8 7 6 5 4 3 2 1变为10 9 8 7 6 5 4 3 2 1 10 9 8 7 6 5 4 3 2 1

    然后对拼接后的串建立后缀自动机:

    然后在后缀自动机上走n步,n是原串长,每次选当前节点出边中字典序最小的边走,所经过的序列就是符合要求的序列。

    以下代码会MLE,因为我用了指针,换成数组版本即可。

    #include <iostream>
    #include <map>
    #include <vector>
    #include <algorithm>
    using std::cin;
    using std::cout;
    using std::endl;
    using int_t = int;
    struct Node
    {
        //这个点所表示的最长长度
        int_t max;
        std::map<int_t, Node *> chds;
        Node *link = nullptr;
        Node *&access(int_t x)
        {
            return chds[x];
        }
        Node(int max)
        {
            this->max = max;
        }
        Node *clone()
        {
            return new Node(*this);
        }
    } __attribute__((aligned(1)));
    
    Node *root = new Node(0);
    Node *last = root;
    
    void append(char chr)
    {
        Node *next = new Node(last->max + 1);
        Node *curr = last;
        while (curr && curr->access(chr) == nullptr)
        {
            curr->access(chr) = next;
            curr = curr->link;
        }
        if (curr == nullptr)
        {
            next->link = root;
        }
        else if (curr->access(chr)->max == curr->max + 1)
        {
            next->link = curr->access(chr);
        }
        else
        {
            Node *newNode = curr->access(chr)->clone();
            Node *oldNode = curr->access(chr);
            newNode->max = curr->max + 1;
            oldNode->link = next->link = newNode;
            while (curr && curr->access(chr) == oldNode)
            {
                curr->access(chr) = newNode;
                curr = curr->link;
            }
            // oldNode->link = newNode;
        }
        last = next;
    }
    
    int main()
    {
        std::vector<int> seq;
        int n;
        cin >> n;
        for (int i = 1; i <= n; i++)
        {
            int x;
            cin >> x;
            seq.push_back(x);
        }
        for (int x : seq)
            append(x);
        for (int x : seq)
            append(x);
        Node *curr = root;
        for (int i = 1; i <= n; i++)
        {
            auto next = *curr->chds.begin();
            cout << next.first << " ";
            curr = next.second;
        }
        return 0;
    }

     

     

  • AC自动机模板

    我在这里卡了一个周了,终于写出来了qwq

     当走到一个节点时,一定要沿着失配路径走,一直走到根节点,因为失配路径所经过的节点已定能匹配到!

    /*
     * To change this license header, choose License Headers in Project Properties.
     * To change this template file, choose Tools | Templates
     * and open the template in the editor.
     */
    
    /* 
     * File:   main.cpp
     * Author: Ytong
     *
     * Created on 2017年12月3日, 下午7:11
     */
    #pra\
    gma GCC optimize("O3")
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    
    #include 
    #include 
    #include 
    #include 
    using int_t = long long int;
    using std::string;
    using std::cout;
    using std::endl;
    using std::strlen;
    using std::cin;
    using std::queue;
    using std::vector;
    using std::ifstream;
    using std::max;
    
    enum MemoryType {
        preAllocation, dynamicAllocation
    };
    
    struct Node {
        Node* children[26];
        int_t id = 0;
        Node* fail = nullptr;
        //标明这个节点是如何分配内存的
        MemoryType type = MemoryType::preAllocation;
    //当前节点的某个子节点是否在构建失配路径前就已经存在
        bool raw[26];
    
        inline Node* &getChild(char chr) {
            return children[chr - 'a'];
        }
    };
    int_t used;
    //预分配一块内存,用来加快程序分配内存的速度
    char memPool[sizeof (Node)*20000 + 1];
    const int_t size = sizeof (memPool) / sizeof (Node);
    Node* root = nullptr;
    char buff[1000000 + 10];
    char patterns[200][100];
    int_t result[200];
    //获取一个新的Node
    Node* nextNewNode() {
        if (used < size) //预分配的内存还有空位的情况下 直接使用定位new在预分配的内存中分一个Node return new(memPool + (used++) * sizeof (Node)) Node; else { //否则就用new在堆里分配内存 Node* x = new Node; memset(x->children, 0, sizeof (x->children));
            memset(x->raw, 1, sizeof (x->raw));
            x->type = MemoryType::dynamicAllocation;
            return x;
        }
    }
    //将字符串插入到Trie
    inline void insert(const char * str, int_t id) {
        int_t length = strlen(str);
        Node* ptr = root;
        for (int_t i = 0; i < length; i++) {
            //    cout << i << endl;
            //    if (i % 10 == 0) cout << i << endl; if (ptr->getChild(str[i]) == nullptr) {
                ptr->getChild(str[i]) = nextNewNode();
            }
            ptr = ptr->getChild(str[i]);
        }
        ptr->id = id;
    
    }
    //构建失配路径
    inline void buildUpFailPtrs() {
        queue<Node* > qu;
        for (char chr = 'a'; chr <= 'z'; chr++) { //根节点所有子节点的失配路径指向根 if (root->getChild(chr)) {
                root->getChild(chr)->fail = root;
                qu.push(root->getChild(chr));
            }
        }
        while (qu.empty() == false) {
            Node* front = qu.front();
            qu.pop();
            for (char chr = 'a'; chr <= 'z'; chr++) { Node* & chd = front->getChild(chr);
                if (chd == nullptr) {
                    //如果一个节点不存在 那么就补上这条路径 可降低匹配函数中的代码量
                    chd = front->fail->getChild(chr);
                    front->raw[chr - 'a'] = false;
                } else {
                    Node* fail = front->fail;
                    while (fail && (fail->getChild(chr) == nullptr)) {
                        fail = fail->fail;
                    }
                    if (fail == nullptr) fail = root;
                    chd->fail = fail->getChild(chr);
                    if (chd->fail == nullptr) chd->fail = root;
                    qu.push(chd);
                }
            }
        }
    }
    
    inline void match(char * str) {
        memset(result, 0, sizeof (result));
        int_t length = strlen(str);
        Node* ptr = root;
        for (int_t i = 0; i < length; i++) {
            //  if (i % 10 == 0) cout << i << endl; ptr = ptr->getChild(str[i]);
            //        cout << "moving to " << ptr << " with str " << str[i] << endl; if (ptr == nullptr) { ptr = root; continue; } Node* temp = ptr; //重要!!!一定要沿着当前节点的失配路径已知往回走,因为这些失配路径上的字符串一定匹配到了!! while (temp) { // if(temp->id==0) break;
                if (temp->id) {
                    //    cout << "found " << patterns[temp->id] << " at pos " << i;
                    //       cout << "result " << temp->id << " from " << result[temp->id] << " to " << result[temp->id] + 1 << endl; result[temp->id]++;
                }
    
                temp = temp->fail;
            }
    
        }
    
    }
    //删掉树
    void remove(Node* r) {
        if (r == nullptr) return;
        for (char chr = 'a'; chr <= 'z'; chr++) { if (r->raw[chr - 'a'])
                remove(r->getChild(chr));
        }
        //只有某个节点是使用堆分配的内存时才能delete
        //定位new不能用delete
        if (r->type == MemoryType::dynamicAllocation)
            delete r;
    }
    
    int main() {
        int_t n;
        while (true) {
            remove(root);
            memset(memPool, 0, sizeof (memPool));
            used = 0;
            cin>>n;
            if (n == 0) break;
    
            root = nextNewNode();
            for (int_t i = 1; i <= n; i++) { cin >> patterns[i];
                insert(patterns[i], i);
            }
            buildUpFailPtrs();
            cin>>buff;
            match(buff);
            int_t maxTime = 0;
            for (int_t i = 1; i <= n; i++) {
                maxTime = max(maxTime, result[i]);
            }
            cout << maxTime << endl;
            for (int_t i = 1; i <= n; i++) {
                if (result[i] == maxTime) {
                    cout << patterns[i] << endl;
                }
            }
        }
        return 0;
    }
    
  • KMP字符串匹配算法

    KMP字符串匹配算法可以理解为一种经过优化的暴力匹配。

     

    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    using namespace std;
    using int_t = long long int;
    
    int_t nxt[1000000 + 1];
    
    void getNext(string pattern) {
        nxt[0] = -1;
        int_t k = -1;
        int_t j = 0;
        while (j < pattern.length()) {
            if (k == -1 || pattern[j] == pattern[k]) {
                if (pattern[++j] == pattern[++k]) {
                    nxt[j] = nxt[k];
                } else {
                    nxt[j] = k;
                }
            } else {
                k = nxt[k];
            }
        }
    }
    
    int main() {
        srand(time(0));
        string pattern, source;
        memset(nxt, 0, sizeof (nxt));
        int_t x1, x2;
        cin >> x1>>x2;
        for (int_t i = 0; i < x1; i++) {
            source = source + char(rand() % 5 + 'A');
        }
        for (int_t i = 0; i < x2; i++) {
            pattern = pattern + char(rand() % 5 + 'A');
        }
        cout << source << endl << pattern << endl;
        int_t i = 0, j = 0;
        getNext(pattern);
        while (i < source.length() && j < (signed long long) pattern.length()) {
            //cout << "x:i=" << i << " j=" << j << endl;
            if (j == -1 || source[i] == pattern[j]) {
                //cout << "i=" << i << " j=" << j << " ok" << endl;
                i++;
                j++;
            } else {
                //  cout << "j=" << j << " to " << nxt[j] << endl;
                j = nxt[j];
            }
            if (j == pattern.length()) {
                cout << i - j + 1 << endl;
                i = i - j + 2;
                j = 0;
            }
            // cout << "y:i=" << i << " j=" << j << endl;
            ///  cout << (i < source.length()) << " " << (j < (signed long long )pattern.length()) << endl;
        }
        for (int_t i = 0; i < pattern.length(); i++) cout << nxt[i] << " ";
        cout << endl;
        main();
        return 0;
    }
    
  • Trie树

    Trie树是一种用来保存字符串集合的树。

    Trie树可在$$O(n)$$的时间内将一个字符串插入该集合中,也可以在$$O(n)$$的时间内确定一个字符串是否在集合中。

    这是一颗Trie树(图自https://baike.baidu.com/pic/%E5%AD%97%E5%85%B8%E6%A0%91/9825209/0/9252ae7e96e893610dd7da67?fr=lemma&ct=single#aid=0&pic=9252ae7e96e893610dd7da67)

    每i层的节点存储字符串中的第i个字符。

    代码实现:

     

     

    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    #include 
    using namespace std;
    using int_t = unsigned long long int;
    
    struct TrieNode {
        TrieNode* chd[26];
        int_t value = 0;
    
        TrieNode() {
            memset(chd, 0, sizeof (chd));
        }
    };
    TrieNode* root = new TrieNode;
    
    void insert(string str, int_t v) {
        TrieNode * ptr = root;
        for (int_t i = 0; i < str.length(); i++) {
            if (ptr->chd[str[i] - 'a'] == nullptr) {
                ptr->chd[str[i] - 'a'] = new TrieNode;
            }
            ptr = ptr->chd[str[i] - 'a'];
        }
        ptr->value = v;
    
    }
    vector vec;
    
    void map(TrieNode * v) {
        if (v->value) {
            for (char chr : vec) cout << chr;
            cout << endl;
        }
        for (int_t i = 0; i <= 'z' - 'a'; i++) {
            if (v->chd[i]) {
                vec.push_back(i + 'a');
                map(v->chd[i]);
                vec.pop_back();
            }
        }
    }
    
    bool query(string str) {
        TrieNode * ptr = root;
        for (int_t i = 0; i < str.length(); i++) {
            if (ptr->chd[str[i] - 'a'] == nullptr) {
                return false;
            }
            ptr = ptr->chd[str[i] - 'a'];
        }
        if (ptr->value == 0) return false;
        return true;
    
    }
    
    int main() {
        while (true) {
            string opt;
            cin>>opt;
            if (opt == "exit") break;
            else if (opt == "insert") {
                string str;
                cin>>str;
                insert(str, 1);
            } else if (opt == "query") {
                string str;
                cin>>str;
                cout << query(str) << endl;
            } else if (opt == "map") {
                map(root);
            }
        }
    
        return 0;
    
    }