AtCoder Regular Contest 043 C - 転倒距離

問題

 2つの順列について順序が入れ替わっている数字の組の個数を転倒距離と呼ぶこととする。サイズN(\le 10^5)の順列A, Bが与えられたとき、AともBとも転倒距離が等しい順列があるか判定し、ある場合には1つ挙げよ。

解法

 まず適当な数字の置き換えをしてA1, 2, \dots, Nという配列になるようにする。これによりABA'( = 1, 2, \dots, N)B'になったとすると、ABの転倒距離はB'の転倒数に等しくなる。ある順列の転倒数はマージソートのようなやり方でO(Nlog{N})で求めることができる。( tmaehara氏のスライドなどを参照のこと)

 B'の転倒数が奇数であるとき、AともBとも転倒距離が等しい順列は存在しない。(証明は解説スライドを参照のこと)

 B'の転倒数が偶数であるとき、バブルソートをシミュレーションすればちょうど転倒距離が等しいものを求めることができるが、愚直に行うとO(N^2)かかるため、セグメント木などを使ってO(N\log{N})でシミュレーションできるように高速化することでこの問題が解ける。

反省

 1時間52分(4WA)でAC。転倒数に帰着させて奇数の場合は構築不可というところまでは(感覚的に)わかったが、まず転倒数を高速に求める方法がわからなかったので30分ほど過ぎたあたりで諦めて検索をしてO(N\log{N})で求められることを知った。なんとかそれを30分ほどかけて実装したが、偶数のときに実際に構築する方法がわからず結局解説スライドを見た。

 転倒数はバブルソートと深い関係があるというのはちょっとだけ知っていたが、シミュレーションすればいいという発想に至っていなかった。さらにそれをセグメント木などを使って高速化するということで、コンテストで出されたらなかなか解けそうにない。

 Aを変換したり、最後に答えを変換しなおしたりというのを勘でやってなんか通ってしまったのであまり問題(解法)をしっかり理解できたという気がしない。

 とりあえずセグメント木に加えて転倒数も自作ライブラリに突っ込んでおいた。あまりマージソートを理解していないままなので良いこととは思えないが……。

コード

#include"bits/stdc++.h"
using namespace std;
using ll = int64_t;

class InversionCount {
public:
    InversionCount(vector<ll> target) : target_(target) {}
    ll count(ll left = 0, ll right = -1) {
        if (right < 0) {
            //rはtarget_.size()で初期化
            right = target_.size();
        }
        if (right <= left + 1) {
            return 0;
        }
        ll mid = (left + right) / 2;
        ll result = 0;
        //左半分を数える
        result += count(left, mid);

        //右半分を数える
        result += count(mid, right);

        //左右またぐ数を数える
        result += merge(left, mid, right);

        return result;
    }
private:
    ll merge(ll left, ll mid, ll right) {
        vector<ll> l, r;
        for (ll i = left; i < mid; i++) {
            l.push_back(target_[i]);
        }
        for (ll i = mid; i < right; i++) {
            r.push_back(target_[i]);
        }
        //番兵
        l.push_back(LLONG_MAX);
        r.push_back(LLONG_MAX);

        ll left_index = 0;
        ll right_index = 0;
        ll result = 0;
        for (ll i = left; i < right; i++) {
            if (l[left_index] <= r[right_index]) {
                target_[i] = l[left_index];
                left_index++;
            } else {
                target_[i] = r[right_index];
                right_index++;
                result += ((mid - left) - left_index);
            }
        }
        return result;
    }

    vector<ll> target_;
};

//1点更新,区間和
class SegmentTree {
public:
    SegmentTree(ll n, ll value) {
        n_ = (ll)pow(2, ceil(log2(n)));
        nodes_.resize(2 * n_ - 1, value);
    }
    void update(ll x, ll v) {
        nodes_[x + n_ - 1] = v;
        for (ll i = (x + n_ - 2) / 2; i > 0; i = (i - 1) / 2) {
            nodes_[i] = nodes_[2 * i + 1] + nodes_[2 * i + 2];
        }
    }

    ll getSum(ll a, ll b, ll k = 0, ll l = 0, ll r = -1) {
        if (r < 0) {
            r = n_;
        }
        if (r <= a || b <= l) {
            return 0;
        }
        if (a <= l && r <= b) {
            return nodes_[k];
        }
        ll left  = getSum(a, b, 2 * k + 1, l, (l + r) / 2);
        ll right = getSum(a, b, 2 * k + 2, (l + r) / 2, r);
        return left + right;
    }

private:
    //2のべき乗
    ll n_;
    vector<ll> nodes_;
};

int main() {
    ll N;
    cin >> N;
    vector<ll> A(N), B(N);
    vector<ll> pos(N + 1, 0), B2(N);
    for (ll i = 0; i < N; i++) {
        cin >> A[i];
        pos[A[i]] = i + 1;
    }
    for (ll i = 0; i < N; i++) {
        cin >> B[i];
        B2[i] = pos[B[i]];
    }

    InversionCount ic(B2);
    ll inversion_count = ic.count();

    if (inversion_count % 2 == 1) {
        cout << -1 << endl;
    } else {
        inversion_count /= 2;
        SegmentTree st(N + 1, 0);
        vector<ll> ans;
        for (ll i = 0; i < N; i++) {
            ll sum = st.getSum(B2[i] + 1, N + 1);
            st.update(B2[i], 1);

            //移動量
            ll v = min(sum, inversion_count);

            inversion_count -= v;
            ans.insert(ans.end() - v, B2[i]);
        }
        assert(inversion_count == 0);

        for (ll i = 0; i < N; i++) {
            cout << A[ans[i] - 1] << " \n"[i == N - 1];
        }
    }
}