首先考虑如何转成一个偏序问题。
把样例的过程画出图来。

可以注意到一个时间点t产生的逆序对,在比他靠前的时间里仍然存在。
所以我们可以考虑单独计算每个时间点因为删掉元素而丢掉的逆序对数。
考虑搞出来n个三元组,每个形如(time,pos,val),表示第time个时间点把位于pos的元素val删掉。
如果有元素没有被删除那么他们的time顺次往下排即可。
一个三元组a会对另一个时间的三元组b产生1的贡献,当且仅当a.time>b.time(a比b晚删除,根据上面的图感性理解一下),然后(a.pos<b.pos且a.val>b.val)或者(a.pos>b.pos且a.val<b.val)。
因为我们统计的是每个时间点因为删除而消失的逆序对个数,求一个后缀和,然后输出前m个就是答案。
注意删除的是元素而不是下标!!
#include <assert.h>
#include <algorithm>
#include <cstdio>
#include <iostream>
#include <vector>
using int_t = long long int;
using std::cin;
using std::cout;
using std::endl;
const int_t LARGE = 5e5;
struct Triple {
int_t time, pos, val;
int_t* target;
} datas[LARGE + 1];
int_t used = 0;
int_t results[LARGE + 1];
int_t perm[LARGE + 1];
int_t n, m;
int_t arr[LARGE + 1];
std::ostream& operator<<(std::ostream& os, const Triple& a) {
os << "Triple{time=" << a.time << ",pos=" << a.pos << ",val=" << a.val
<< ",target=" << (a.target - results) << ",targetval=" << *a.target
<< "}";
return os;
}
int_t lowbit(int_t x) { return x & -x; }
int_t query(int_t x) {
#ifdef DEBUG
cout << "query " << x << " = ";
#endif
int_t result = 0;
while (x >= 1) {
result += arr[x];
assert(arr[x] >= 0);
x -= lowbit(x);
}
return result;
}
void add(int_t pos, int_t val) {
while (pos <= n) {
arr[pos] += val;
pos += lowbit(pos);
}
}
void reset(int_t pos) {
while (pos <= n) {
arr[pos] = 0;
pos += lowbit(pos);
}
}
void process(Triple* left, Triple* right) {
if (left >= right) return;
auto mid = left + (right - left) / 2;
process(left, mid);
process(mid + 1, right);
#ifdef DEBUG
cout << "processing left\n";
for (auto ptr = left; ptr <= mid; ptr++) cout << *ptr << endl;
cout << "processing right\n";
for (auto ptr = mid + 1; ptr <= right; ptr++) cout << *ptr << endl;
#endif
auto lptr = left, rptr = mid + 1;
std::vector<Triple> result;
#ifdef DEBUG
cout << "before merge ";
for (int_t i = 1; i <= n; i++) cout << arr[i] << " ";
cout << endl;
#endif
//统计a.pos<b.pos且a.val>b.val的个数
while (lptr <= mid && rptr <= right) {
#ifdef DEBUG
cout << "lptr pos = " << lptr->pos << " rptr pos = " << rptr->pos
<< endl;
#endif
if (lptr->pos <= rptr->pos) {
add(lptr->val, 1);
result.push_back(*lptr);
#ifdef DEBUG
cout << "put " << *lptr << " to left" << endl;
#endif
lptr++;
} else {
*rptr->target += query(n) - query(rptr->val);
#ifdef DEBUG
cout << "put " << *rptr << " to right "
<< "with val " << query(n) << " - " << query(rptr->val)
<< endl;
#endif
result.push_back(*rptr);
rptr++;
}
}
while (rptr <= right) {
*rptr->target += query(n) - query(rptr->val);
#ifdef DEBUG
cout << "put " << *rptr << " to right "
<< "with val " << query(n) << " - " << query(rptr->val) << endl;
cout << "n = " << n << ",rptr val = " << rptr->val << endl;
#endif
result.push_back(*rptr);
rptr++;
}
while (lptr <= mid) result.push_back(*(lptr++));
assert(result.size() == right - left + 1);
for (auto p = left; p <= right; p++) reset(p->val);
#ifdef DEBUG
cout << "sort ok0 " << endl;
for (const auto& x : result) cout << x << endl;
#endif
//统计a.pos>b.pos且a.val<b.val的个数
lptr = mid, rptr = right;
//倒着归并
while (lptr >= left && rptr >= mid + 1) {
if (lptr->pos > rptr->pos) {
add(lptr->val, 1);
lptr--;
} else {
*rptr->target += query(rptr->val - 1);
rptr--;
}
}
while (rptr >= mid + 1) {
*rptr->target += query(rptr->val - 1);
rptr--;
}
for (auto p = left; p <= right; p++) reset(p->val);
std::copy(result.begin(), result.end(), left);
#ifdef DEBUG
cout << "sort ok1 " << endl;
for (auto p = left; p <= right; p++) cout << *p << endl;
cout << endl;
#endif
}
int main() {
scanf("%lld%lld", &n, &m);
for (int_t i = 1; i <= n; i++) {
scanf("%lld", &perm[i]);
}
static int_t inv[LARGE + 1];
for (int_t i = 1; i <= n; i++) inv[perm[i]] = i;
for (int_t i = 1; i <= m; i++) {
//删的是元素,而不是下标!
int_t val;
scanf("%lld", &val);
used++;
datas[used] = Triple{used, inv[val], val, &results[i]};
perm[inv[val]] = 0;
}
for (int_t i = 1; i <= n; i++) {
if (perm[i]) {
used++;
datas[used] = Triple{used, i, perm[i], &results[used]};
}
}
std::sort(datas + 1, datas + 1 + n,
[](const Triple& a, const Triple& b) { return a.time > b.time; });
process(datas + 1, datas + n);
// for (int_t i = 1; i <= n; i++) cout << results[i] << endl;
for (int_t i = n; i >= 2; i--) {
assert(results[i] >= 0);
results[i - 1] += results[i];
}
for (int_t i = 1; i <= m; i++) {
printf("%lld\n", results[i]);
}
return 0;
}
发表回复