标签: 树状数组

  • NOI2011 阿狸的打字机

    将询问离线并按照y分类,再构建出Fail树。询问(x,y)(第x个打印的字符串在第y个打印的字符串中出现了多少次)等价于:Trie树上从根节点到y的路径上经过的点中,有多少个位于Fail树上x的子树内。

    #include <cstdio>
    #include <cstring>
    #include <iostream>
    #include <queue>
    #include <string>
    #include <vector>
    // 64-bit signed integer type used for every counter and index below.
    using int_t = long long;
    using std::cin;
    using std::cout;
    using std::endl;
    // Capacity bound shared by all fixed-size arrays of this program.
    constexpr int_t LARGE = 100000;
    // Running counter for trie node ids; each new Node takes the next value.
    int_t vtxID = 0;
    // Node of the trie built from the typed text; it also carries the
    // Aho-Corasick failure link computed in main.
    struct Node {
        Node* trieChd[26] = {};    // children, indexed by letter - 'a'
        Node* link = nullptr;      // failure link
        Node* parent = nullptr;    // trie parent, followed on 'B'
        int_t id = ++vtxID;        // ids are handed out in creation order
        std::vector<int_t> strID;  // numbers of the strings printed ('P') here

        // Returns the child slot for lowercase letter chr (assignable).
        Node*& access(char chr) {
            const int slot = chr - 'a';
            return trieChd[slot];
        }

        Node() = default;
    };
    // Children lists of the fail tree, indexed by node id.
    std::vector<int_t> failTree[LARGE + 1];
    // Fenwick tree storage, indexed by fail-tree DFS number.
    int_t arr[LARGE + 1];
    // Trie root; it is the first Node created, so its id is 1.
    Node* root = new Node;
    // strs[k]: trie node at which the k-th string was printed.
    Node* strs[LARGE + 1];
    // DFSN[v]: preorder number of node v in the fail tree; DFSN[0] is the counter.
    int_t DFSN[LARGE + 1];
    // result[q]: answer of the q-th query in input order.
    int_t result[LARGE + 1];
    // size[v]: number of nodes in the fail-tree subtree of v.
    int_t size[LARGE + 1];
    // Number of queries.
    int_t m;

    // One offline query (x, y). It is answered as the number of trie nodes on
    // the root -> y path that lie in the fail-tree subtree of x's end node.
    struct Query {
        int_t y;
        int_t x;
        int_t id;  // position of the query in the input
    };
    // querys[y]: all queries whose second string is y.
    std::vector<Query> querys[LARGE + 1];
    // Assigns preorder numbers to the fail-tree subtree rooted at vtx and
    // records each subtree's size. DFSN[0] serves as the running counter.
    void DFS(int_t vtx) {
        ++DFSN[0];
        DFSN[vtx] = DFSN[0];
        int_t total = 1;
        const std::vector<int_t>& kids = failTree[vtx];
        for (std::size_t k = 0; k < kids.size(); ++k) {
            const int_t child = kids[k];
            DFS(child);
            total += size[child];
        }
        size[vtx] = total;
    }
    // Fenwick tree over positions 1..vtxID, stored in arr.

    // Lowest set bit of x.
    int_t lowbit(int_t x) { return x & -x; }

    // Adds val at position pos.
    void add(int_t pos, int_t val) {
        for (int_t i = pos; i <= vtxID; i += lowbit(i)) arr[i] += val;
    }

    // Sum of positions 1..pos.
    int_t query(int_t pos) {
        int_t sum = 0;
        for (int_t i = pos; i > 0; i -= lowbit(i)) sum += arr[i];
        return sum;
    }

    // Sum of positions left..right.
    int_t query(int_t left, int_t right) {
        const int_t upto = query(right);
        return upto - query(left - 1);
    }
    // Walks the trie from node. While visiting node, the Fenwick tree holds +1
    // at the fail-tree DFS number of every node on the root -> node path. Every
    // query (x, y) whose string y ends at node is answered by counting those
    // marks inside the fail-tree subtree of x's end node.
    void DFS2(Node* node) {
        const int_t here = DFSN[node->id];
        add(here, 1);
        for (int_t y : node->strID) {
            for (const Query& q : querys[y]) {
                const int_t target = strs[q.x]->id;
                const int_t from = DFSN[target];
                const int_t to = from + size[target] - 1;
                result[q.id] = query(from, to);
            }
        }
        for (int letter = 0; letter < 26; ++letter) {
            Node* child = node->trieChd[letter];
            if (child != nullptr) DFS2(child);
        }
        add(here, -1);
    }
    int main() {
    #ifdef DEBUG
        freopen("qwq.txt", "r", stdin);
    #endif
        static char buf[LARGE + 10];
        int_t used = 0;
        scanf("%s", buf);
        Node* node = root;
        for (char* ptr = buf; *ptr != '\0'; ptr++) {
            if (*ptr == 'B') {
                node = node->parent;
    #ifdef DEBUG
                cout << "back to " << node->id << endl;
    #endif
            } else if (*ptr == 'P') {
                node->strID.push_back(++used);
                strs[used] = node;
    #ifdef DEBUG
                cout << "print at " << node->id << " strid = " << used << endl;
    #endif
            } else {
                if (node->access(*ptr) == nullptr) {
                    Node*& next = node->access(*ptr);
                    next = new Node;
                    next->parent = node;
                }
                node = node->access(*ptr);
    #ifdef DEBUG
                cout << "walk with " << *ptr << " to " << node->id << endl;
    #endif
            }
        }
        std::queue<Node*> queue;
        for (char chr = 'a'; chr <= 'z'; chr++)
            if (root->access(chr)) {
                root->access(chr)->link = root;
                queue.push(root->access(chr));
                failTree[root->id].push_back(root->access(chr)->id);
            }
        while (queue.empty() == false) {
            Node* front = queue.front();
            queue.pop();
            for (char chr = 'a'; chr <= 'z'; chr++) {
                if (front->access(chr)) {
                    Node*& link = front->access(chr)->link;
                    Node* curr = front->link;
                    while (curr && curr->access(chr) == nullptr) curr = curr->link;
                    if (curr == nullptr)
                        link = root;
                    else
                        link = curr->access(chr);
                    failTree[link->id].push_back(front->access(chr)->id);
    #ifdef DEBUG
                    cout << "fail " << front->access(chr)->id << " = " << link->id
                         << endl;
    #endif
                    queue.push(front->access(chr));
                }
            }
        }
        DFS(1);
        scanf("%lld", &m);
        for (int_t i = 1; i <= m; i++) {
            int_t x, y;
            scanf("%lld%lld", &x, &y);
            querys[y].push_back(Query{y, x, i});
        }
        DFS2(root);
        for (int_t i = 1; i <= m; i++) {
            printf("%d\n", (int)result[i]);
        }
        return 0;
    }

     

  • Luogu4113 HEOI2012 采花

    题意:

    求一段区间内,出现次数至少为2的数的个数。

    可以考虑转化成像HH的项链一样的模型。

    扫描线+树状数组。

    考虑从左端弹出一个数会造成什么影响。设弹出的数为x:从弹出位置到x下一次出现位置之前的右端点不受影响(因为在这些区间里x只出现了0次或1次,本来就不计入答案);从x下一次出现的位置到x下下一次出现位置之前的右端点,答案都会减一,因为x本来在这些区间里恰好出现了两次,删掉一个后只剩一次;对于更靠右的右端点,答案不变,因为x在那些区间里至少出现了三次,删掉一个后仍至少出现两次。

    #include <algorithm>
    #include <cstdio>
    #include <iostream>
    
    using int_t = int;
    using std::cin;
    using std::cout;
    using std::endl;

    constexpr int_t LARGE = 2000010;

    // One offline query [left, right]. result is filled in when the query is
    // answered; id remembers its input position. Queries sort by left end.
    struct Operation {
        int_t left;
        int_t right;
        int_t result;
        int_t id;
        bool operator<(const Operation& other) const { return left < other.left; }
    };
    Operation opts[LARGE + 1];

    int_t n;
    int_t c;
    int_t m;
    // Fenwick tree storage over positions 1..n.
    int_t arr[LARGE + 1];
    // next[i]: next position with the same color as position i.
    int_t next[LARGE + 1];
    // prev[i]: previous position with the same color (written, never read).
    int_t prev[LARGE + 1];
    // last[k]: last position where color k has appeared so far.
    int_t last[LARGE + 1];
    // colors[i]: color at position i.
    int_t colors[LARGE + 1];
    // Lowest set bit of x.
    inline int_t lowbit(int_t x) { return x & -x; }

    // Fenwick point update: adds val at pos (positions beyond n are ignored).
    void add(int_t pos, int_t val) {
        for (; pos <= n; pos += lowbit(pos)) arr[pos] += val;
    }

    // Fenwick prefix sum over positions 1..pos.
    int_t query(int_t pos) {
        int_t total = 0;
        for (; pos > 0; pos -= lowbit(pos)) total += arr[pos];
        return total;
    }
    int main() {
        scanf("%d%d%d", &n, &c, &m);
        for (int_t i = 1; i <= n; i++) scanf("%d", &colors[i]);
        for (int_t i = 1; i <= m; i++) {
            auto& thiz = opts[i];
            scanf("%d%d", &thiz.left, &thiz.right);
            thiz.id = i;
        }
        std::sort(opts + 1, opts + 1 + m);
        for (int_t i = 1; i <= n; i++) {
            int_t color = colors[i];
            if (last[color] == 0) {
                last[color] = i;
                prev[color] = -1;
            } else {
                next[last[color]] = i;
                prev[i] = last[color];
                last[color] = i;
            }
        }
        static int_t count[LARGE + 1];
        int_t curr = 0;
        for (int_t i = 1; i <= n; i++) {
            if (next[i] == 0) next[i] = n + 1;
            count[colors[i]] += 1;
            if (count[colors[i]] == 2) {
    #ifdef DEBUG
                cout << "color " << i << " = " << colors[i]
                     << " count = " << count[colors[i]] << endl;
    #endif
                curr += 1;
            }
            add(i, curr);
            add(i + 1, -curr);
        }
        int_t pos = 1;
        for (int_t i = 1; i <= m; i++) {
            auto& curr = opts[i];
            while (pos < curr.left) {
                int_t next0 = next[pos];
                if (next0 <= n) {
                    int_t next1 = next[next0];
                    add(next0, -1);
                    add(next1, 1);
                }
                pos++;
            }
            curr.result = query(curr.right);
        }
        std::sort(opts + 1, opts + 1 + m,
                  [](const Operation& a, const Operation& b) -> bool {
                      return a.id < b.id;
                  });
        for (int_t i = 1; i <= m; i++) {
            printf("%d\n", opts[i].result);
        }
        return 0;
    }

     

  • SDOI2009 HH的项链

    区间内不同的数的个数。

    考虑离线做法。

    预处理出区间$[1,i],i\in [1,n]$这$n$个区间的答案,这显然可以$O(n)$预处理出来。

    然后考虑一下,现在我们知道了区间$[1,i]$的答案,如何转移到区间$[2,i]$的答案.

    从区间$[1,i]$转移到$[2,i]$的过程中,删除了第一个数,这会造成什么影响呢?

    从第一个数的位置开始,直到它下一次出现位置的前一个位置(如果它之后不再出现,则一直到$n$),这之间所有右端点对应的区间都失去了这个数,所以它们的答案都应该减1。

    这就变成了区间减法和单点查询。

    把询问按照左端点排序,然后线段树或者树状数组维护即可。

    如果要在线的话,就直接依次删掉$1\sim n$位置的数,然后用主席树(可持久化线段树)保存每个版本即可。

    离线版本:

    // luogu-judger-enable-o2
    #include <iostream>
    #include <algorithm>
    #include <vector>
    #include <queue>
    #include <deque>
    using int_t = int;
    using std::cin;
    using std::cout;
    using std::endl;

    constexpr int_t LARGE = 500001;

    // One offline query [left, right]; id keeps the input order.
    struct Query {
        int_t left;
        int_t right;
        int_t result;
        int_t id;
    };
    Query querys[LARGE + 1];

    int_t n, m;
    int_t colors[LARGE + 1];     // input values, replaced in place by 1-based compressed ids
    int_t arr[LARGE + 1];        // Fenwick tree storage
    int_t prefix[LARGE + 1];     // prefix[i]: distinct values in [1, i]
    std::vector<int_t> numbers;  // sorted distinct input values (for compression)
    int_t prev[LARGE + 1];       // prev[k]: last position of compressed value k seen so far
    int_t next[LARGE + 1];       // next[i]: next position with the same value (-1 if none)
    // Lowest set bit of x.
    int_t lowbit(int_t x) { return x & -x; }

    // Fenwick prefix sum over 1..pos. main stores a difference array here,
    // so this is the current value at pos.
    int_t query(int_t pos) {
        int_t sum = 0;
        for (; pos >= 1; pos -= lowbit(pos)) sum += arr[pos];
        return sum;
    }

    // Adds x at position pos (positions beyond n are ignored).
    void modify(int_t pos, int_t x) {
        for (; pos <= n; pos += lowbit(pos)) arr[pos] += x;
    }

    // Adds x to every position of [left, right] in the underlying array.
    void modify(int_t left, int_t right, int_t x) {
        modify(right + 1, -x);
        modify(left, x);
    }
    int main()
    {
        scanf("%d", &n);
        for (int_t i = 1; i <= n; i++)
        {
            scanf("%d", &colors[i]);
            numbers.push_back(colors[i]);
        }
        std::sort(numbers.begin(), numbers.end());
        numbers.resize(std::unique(numbers.begin(), numbers.end()) - numbers.begin());
        for (int_t i = 1; i <= n; i++)
        {
            colors[i] = std::lower_bound(numbers.begin(), numbers.end(), colors[i]) - numbers.begin() + 1;
            prefix[i] = prefix[i - 1];
            if (prev[colors[i]] == 0)
            {
                prefix[i] += 1;
                prev[colors[i]] = i;
            }
            else
            {
                next[prev[colors[i]]] = i;
                prev[colors[i]] = i;
            }
            modify(i, i, prefix[i]);
        }
        for (int_t i = 1; i <= n; i++)
            if (next[i] == 0)
                next[i] = -1;
        scanf("%d", &m);
        for (int_t i = 1; i <= m; i++)
        {
            querys[i].id = i;
            scanf("%d%d", &querys[i].left, &querys[i].right);
        }
        std::sort(querys + 1, querys + 1 + m, [](const Query &a, const Query &b) -> bool { return a.left < b.left; });
        int_t ptr = 1;
        for (int_t i = 1; i <= m; i++)
        {
            auto &query = querys[i];
            while (ptr < query.left)
            {
                if (next[ptr] != -1)
                {
                    modify(ptr, next[ptr] - 1, -1);
                }
                else
                {
                    modify(ptr, n, -1);
                }
                ptr++;
            }
            query.result = ::query(query.right);
        }
        std::sort(querys + 1, querys + 1 + m, [](const Query &a, const Query &b) -> bool { return a.id < b.id; });
        for (int_t i = 1; i <= m; i++)
            printf("%d\n", querys[i].result);
        return 0;
    }

    在线版本:

    (洛谷MLE两个点)

    #include <iostream>
    #include <algorithm>
    #include <vector>
    #include <queue>
    #include <deque>
    #include <assert.h>
    using int_t = int;
    using std::cin;
    using std::cout;
    using std::endl;

    const int_t LARGE = 500001;

    // Query record type. The online solution prints each answer as soon as the
    // query is read, so no global array of queries is allocated. Such an array
    // would cost about 8 MB (16 bytes x 500002) in a program that is already
    // close to the memory limit.
    struct Query
    {
        int_t left;
        int_t right;
        int_t result;
        int_t id;
    };
    int_t n, m;
    int_t colors[LARGE + 1];     // input values, replaced by 1-based compressed ids
    int_t prefix[LARGE + 1];     // prefix[i]: distinct values in [1, i]
    std::vector<int_t> numbers;  // sorted distinct input values (for compression)
    int_t prev[LARGE + 1];       // prev[k]: last position of compressed value k seen so far
    int_t next[LARGE + 1];       // next[i]: next position with the same value (0 if none)
    // Node of a persistent segment tree with permanent (never pushed down)
    // range-add marks. value already includes every add applied at or below
    // this node. mark holds adds that covered this node's whole range; query
    // applies them to the descendants on the way down.
    struct Node
    {
        int_t begin;
        int_t end;
        int_t value;
        int_t mark;
        Node *left = nullptr;
        Node *right = nullptr;

        // Returns a fresh heap copy of this node; the children are shared.
        Node *clone()
        {
            Node *copy = new Node(begin, end);
            *copy = *this;
            return copy;
        }

        Node(int_t begin, int_t end) : begin(begin), end(end), value(0), mark(0) {}

        // Sum over [begin, end], which must lie inside this node's range.
        // sum is the total of the marks of all ancestors above this node.
        int_t query(int_t begin, int_t end, int_t sum = 0)
        {
            if (this->begin == begin && this->end == end)
                return value + (end - begin + 1) * sum;
            const int_t mid = (this->begin + this->end) / 2;
            const int_t inherited = sum + mark;
            if (end <= mid)
                return left->query(begin, end, inherited);
            if (begin > mid)
                return right->query(begin, end, inherited);
            return left->query(begin, mid, inherited) + right->query(mid + 1, end, inherited);
        }

        // Recomputes value from the two children.
        void maintain()
        {
            value = left->value + right->value;
        }

        // Builds a tree over [begin, end] whose leaf i holds arr[i].
        static Node *build(int_t begin, int_t end, int_t *arr)
        {
            Node *node = new Node(begin, end);
            if (begin == end)
            {
                node->value = arr[begin];
                return node;
            }
            const int_t mid = (begin + end) / 2;
            node->left = build(begin, mid, arr);
            node->right = build(mid + 1, end, arr);
            node->maintain();
            return node;
        }
    };
    // Returns the root of a new version equal to prev with value added to
    // every position of [begin, end]. Only the nodes on the update path are
    // copied; every untouched subtree is shared with prev.
    Node *buildNext(Node *prev, int_t begin, int_t end, int_t value)
    {
        Node *node = prev->clone();
        node->value += (end - begin + 1) * value;
        const bool covered = (node->begin == begin && node->end == end);
        if (covered)
        {
            node->mark += value;
            return node;
        }
        const int_t mid = (prev->begin + prev->end) / 2;
        if (end <= mid)
            node->left = buildNext(prev->left, begin, end, value);
        else if (begin > mid)
            node->right = buildNext(prev->right, begin, end, value);
        else
        {
            node->left = buildNext(prev->left, begin, mid, value);
            node->right = buildNext(prev->right, mid + 1, end, value);
        }
        return node;
    }
    int main()
    {
        scanf("%d", &n);
        for (int_t i = 1; i <= n; i++)
        {
            scanf("%d", &colors[i]);
            numbers.push_back(colors[i]);
        }
        std::sort(numbers.begin(), numbers.end());
        numbers.resize(std::unique(numbers.begin(), numbers.end()) - numbers.begin());
        for (int_t i = 1; i <= n; i++)
        {
            colors[i] = std::lower_bound(numbers.begin(), numbers.end(), colors[i]) - numbers.begin() + 1;
            prefix[i] = prefix[i - 1];
            if (prev[colors[i]] == 0)
            {
                prefix[i] += 1;
                prev[colors[i]] = i;
            }
            else
            {
                next[prev[colors[i]]] = i;
                prev[colors[i]] = i;
            }
        }
        static Node *roots[LARGE + 1] = {Node::build(1, n, prefix)};
        for (int_t i = 1; i <= n; i++)
        {
            //处理删掉第i个数后的结果
            if (next[i] == 0)
            {
                roots[i] = buildNext(roots[i - 1], i, n, -1);
            }
            else
            {
                roots[i] = buildNext(roots[i - 1], i, next[i] - 1, -1);
            }
        }
        scanf("%d", &m);
        for (int_t i = 1; i <= m; i++)
        {
            int_t left, right;
            scanf("%d%d", &left, &right);
            printf("%d\n", roots[left - 1]->query(right, right));
        }
        return 0;
    }

     

  • 树状数组实现区间操作

    使用树状数组维护区间和,并支持区间加。

    这玩意除了常数比线段树小以外没有任何优点。

    $$ \text{仍然将原序列差分。} \\ \text{设差分后的序列为}d_1,d_2,\dots,d_N \\ \text{序列中的第}n\text{项则为}a_n=\sum_{1\le i\le n}{d_i} \\ \text{则序列前}n\text{项的和为} \\ s_n=\sum_{1\le i\le n}{\sum_{1\le j\le i}{d_j}}=\sum_{1\le i\le n}{d_i\sum_{i\le j\le n}{1}}=\sum_{1\le i\le n}{d_i\left( n-i+1 \right)} \\ \text{整理一下} \\ s_n=\left( n+1 \right) \times \sum_{1\le i\le n}{d_i}-\sum_{1\le i\le n}{i\times d_i} \\ \text{然后分别用两个树状数组维护差分序列}d_i\text{和}i\times d_i\text{即可} \\ $$

    #include <iostream>
    
    // 64-bit integer type for every stored value and sum.
    using int_t = long long;

    using std::cin;
    using std::cout;
    using std::endl;

    constexpr int_t LARGE = 1000001;

    // Fenwick tree over the difference array d_i.
    int_t diff[LARGE + 1];
    // Fenwick tree over i * d_i.
    int_t diff2[LARGE + 1];
    // Sequence length and operation count (plain int, read with "%d").
    int n;
    int m;
    // Lowest set bit of x.
    inline int_t lowbit(int_t x)
    {
        return x & -x;
    }

    // Adds value at position pos of the Fenwick tree arr, whose length is n.
    void addpos(int_t *arr, int_t pos, int_t n, int_t value)
    {
        for (; pos <= n; pos += lowbit(pos))
            arr[pos] += value;
    }

    // Prefix sum arr[1..pos] of the Fenwick tree arr.
    int_t querypos(int_t *arr, int_t pos)
    {
        int_t total = 0;
        for (; pos >= 1; pos -= lowbit(pos))
            total += arr[pos];
        return total;
    }

    // Adds val to every element of [left, right]: d_left += val and
    // d_{right+1} -= val, plus the matching updates of i * d_i.
    void add(int_t left, int_t right, int_t val)
    {
        const int_t past = right + 1;
        addpos(diff, left, n, val);
        addpos(diff2, left, n, left * val);
        addpos(diff, past, n, -val);
        addpos(diff2, past, n, past * -val);
    }

    // Sum of the first pos elements: (pos + 1) * sum(d_i) - sum(i * d_i).
    int_t query(int_t pos)
    {
        const int_t sumD = querypos(diff, pos);
        const int_t sumID = querypos(diff2, pos);
        return (pos + 1) * sumD - sumID;
    }
    // Reads n initial values, then m operations:
    //   1 l r v  -> add v to every element of [l, r]
    //   2 l r    -> print the sum of [l, r]
    int main()
    {
        scanf("%d%d", &n, &m);
        for (int_t i = 1; i <= n; i++)
        {
            int x;
            scanf("%d", &x);
            add(i, i, x);
        }
        for (int_t i = 1; i <= m; i++)
        {
            int opt, left, right;
            scanf("%d%d%d", &opt, &left, &right);
            if (opt == 1)
            {
                int value;
                scanf("%d", &value);
                add(left, right, value);
            }
            else if (opt == 2)
            {
                printf("%lld\n", query(right) - query(left - 1));
            }
        }
        return 0;
    }