CQOI2011 动态逆序对

首先考虑如何转成一个偏序问题。

把样例的过程画出图来。

可以注意到一个时间点t产生的逆序对,在比他靠前的时间里仍然存在。

所以我们可以考虑单独计算每个时间点因为删掉元素而丢掉的逆序对数。

考虑搞出来n个三元组,每个形如(time,pos,val),表示第time个时间点把位于pos的元素val删掉。

如果有元素没有被删除那么他们的time顺次往下排即可。

一个三元组a会对另一个时间的三元组b产生1的贡献,当且仅当a.time>b.time(a比b晚删除,根据上面的图感性理解一下),然后(a.pos<b.pos且a.val>b.val)或者(a.pos>b.pos且a.val<b.val)。

因为我们统计的是每个时间点因为删除而消失的逆序对个数,求一个后缀和,然后输出前m个就是答案。

注意删除的是元素而不是下标!!

#include <assert.h>
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;

const int_t LARGE = 5e5;

struct Triple {
    int_t time, pos, val;
    int_t* target;
} datas[LARGE + 1];
int_t used = 0;
int_t results[LARGE + 1];
int_t perm[LARGE + 1];
int_t n, m;
int_t arr[LARGE + 1];

std::ostream& operator<<(std::ostream& os, const Triple& a) {
    os << "Triple{time=" << a.time << ",pos=" << a.pos << ",val=" << a.val
       << ",target=" << (a.target - results) << ",targetval=" << *a.target
       << "}";
    return os;
}
int_t lowbit(int_t x) { return x & -x; }
int_t query(int_t x) {
#ifdef DEBUG
    cout << "query " << x << " = ";
#endif
    int_t result = 0;
    while (x >= 1) {
        result += arr[x];
        assert(arr[x] >= 0);
        x -= lowbit(x);
    }
    return result;
}
void add(int_t pos, int_t val) {
    while (pos <= n) {
        arr[pos] += val;
        pos += lowbit(pos);
    }
}
void reset(int_t pos) {
    while (pos <= n) {
        arr[pos] = 0;
        pos += lowbit(pos);
    }
}
void process(Triple* left, Triple* right) {
    if (left >= right) return;
    auto mid = left + (right - left) / 2;
    process(left, mid);
    process(mid + 1, right);
#ifdef DEBUG

    cout << "processing left\n";
    for (auto ptr = left; ptr <= mid; ptr++) cout << *ptr << endl;
    cout << "processing right\n";
    for (auto ptr = mid + 1; ptr <= right; ptr++) cout << *ptr << endl;

#endif
    auto lptr = left, rptr = mid + 1;
    std::vector<Triple> result;
#ifdef DEBUG

    cout << "before merge ";
    for (int_t i = 1; i <= n; i++) cout << arr[i] << " ";
    cout << endl;
#endif
    //统计a.pos<b.pos且a.val>b.val的个数
    while (lptr <= mid && rptr <= right) {
#ifdef DEBUG
        cout << "lptr pos = " << lptr->pos << " rptr pos = " << rptr->pos
             << endl;
#endif
        if (lptr->pos <= rptr->pos) {
            add(lptr->val, 1);
            result.push_back(*lptr);
#ifdef DEBUG
            cout << "put " << *lptr << " to left" << endl;
#endif
            lptr++;

        } else {
            *rptr->target += query(n) - query(rptr->val);
#ifdef DEBUG
            cout << "put " << *rptr << " to right "
                 << "with val " << query(n) << " - " << query(rptr->val)
                 << endl;
#endif
            result.push_back(*rptr);
            rptr++;
        }
    }
    while (rptr <= right) {
        *rptr->target += query(n) - query(rptr->val);
#ifdef DEBUG
        cout << "put " << *rptr << " to right "
             << "with val " << query(n) << " - " << query(rptr->val) << endl;
        cout << "n = " << n << ",rptr val = " << rptr->val << endl;
#endif
        result.push_back(*rptr);
        rptr++;
    }
    while (lptr <= mid) result.push_back(*(lptr++));
    assert(result.size() == right - left + 1);
    for (auto p = left; p <= right; p++) reset(p->val);
#ifdef DEBUG
    cout << "sort ok0 " << endl;
    for (const auto& x : result) cout << x << endl;

#endif
    //统计a.pos>b.pos且a.val<b.val的个数
    lptr = mid, rptr = right;
    //倒着归并
    while (lptr >= left && rptr >= mid + 1) {
        if (lptr->pos > rptr->pos) {
            add(lptr->val, 1);
            lptr--;
        } else {
            *rptr->target += query(rptr->val - 1);
            rptr--;
        }
    }
    while (rptr >= mid + 1) {
        *rptr->target += query(rptr->val - 1);
        rptr--;
    }
    for (auto p = left; p <= right; p++) reset(p->val);

    std::copy(result.begin(), result.end(), left);
#ifdef DEBUG
    cout << "sort ok1 " << endl;
    for (auto p = left; p <= right; p++) cout << *p << endl;
    cout << endl;
#endif
}
int main() {
    scanf("%lld%lld", &n, &m);
    for (int_t i = 1; i <= n; i++) {
        scanf("%lld", &perm[i]);
    }
    static int_t inv[LARGE + 1];
    for (int_t i = 1; i <= n; i++) inv[perm[i]] = i;
    for (int_t i = 1; i <= m; i++) {
        //删的是元素,而不是下标!
        int_t val;
        scanf("%lld", &val);
        used++;
        datas[used] = Triple{used, inv[val], val, &results[i]};
        perm[inv[val]] = 0;
    }
    for (int_t i = 1; i <= n; i++) {
        if (perm[i]) {
            used++;
            datas[used] = Triple{used, i, perm[i], &results[used]};
        }
    }
    std::sort(datas + 1, datas + 1 + n,
              [](const Triple& a, const Triple& b) { return a.time > b.time; });
    process(datas + 1, datas + n);
    // for (int_t i = 1; i <= n; i++) cout << results[i] << endl;
    for (int_t i = n; i >= 2; i--) {
        assert(results[i] >= 0);
        results[i - 1] += results[i];
    }
    for (int_t i = 1; i <= m; i++) {
        printf("%lld\n", results[i]);
    }
    return 0;
}

 

评论

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注

这个站点使用 Akismet 来减少垃圾评论。了解你的评论数据如何被处理