AtCoder Regular Contest 101 D - Median of Medians

問題

 長さNの数列aについて(l, r)(1 \le l \le r \le N)で切り取られる部分列a_l, a_{l + 1}, \dots, a_rの中央値をm_{l,r}とする。全ての(l,r)についてm_{l,r}を並べた数列mについて中央値を求めよ。

解法

 解説そのまま。

 長さがMである数列bにおける中央値は「bの中にそれより大きいものが\left\lceil \frac{M}{2} \right\rceil個ある整数の中で最大のもの」である。つまりbの中にある数xより大きい整数が何個あるか求められれば二分探索できる。

 mの中にx以上の整数が何個入るか? という問題を考えるとaの各要素について必要な情報はx以上であるかどうかだけ。よってaの各要素をx以上なら1、そうでないなら-1と変換した数列bを考える。これの累積和を取った数列をSとして、S_{l} \le S_{r}である数を求めればよい。これは2(\frac{N(N + 1)}{2}  - (転倒数))であり、転倒数はO(N\log{N})で求まるので間に合う。

反省

 コンテスト終わったすぐ後に解説を見て、中央値は二分探索で求められるんだなーということまでは覚えていたがそこからが思い出せずまた解説を見ることに。この問題の変換はかなり難しいと思う。

 とりあえず作っておいたInversionCountのライブラリを使う機会が来た。マージソート的な実装でやっているんだけど、他の人の提出を見るとBITでやっているほうが多いかもしれない。確かに実行時間1155 msってのはちょっと不安だな。

コード

#include"bits/stdc++.h"
using namespace std;
using ll = int64_t;

class InversionCount {
public:
    InversionCount(vector<ll> target) : target_(target) {}
    ll count(ll left = 0, ll right = -1) {
        if (right < 0) {
            //rはtarget_.size()で初期化
            right = target_.size();
        }
        if (right <= left + 1) {
            return 0;
        }
        ll mid = (left + right) / 2;
        ll result = 0;
        //左半分を数える
        result += count(left, mid);

        //右半分を数える
        result += count(mid, right);

        //左右またぐ数を数える
        result += merge(left, mid, right);

        return result;
    }
private:
    ll merge(ll left, ll mid, ll right) {
        vector<ll> l, r;
        for (ll i = left; i < mid; i++) {
            l.push_back(target_[i]);
        }
        for (ll i = mid; i < right; i++) {
            r.push_back(target_[i]);
        }
        //番兵
        l.push_back(LLONG_MAX);
        r.push_back(LLONG_MAX);

        ll left_index = 0;
        ll right_index = 0;
        ll result = 0;
        for (ll i = left; i < right; i++) {
            if (l[left_index] <= r[right_index]) {
                target_[i] = l[left_index];
                left_index++;
            } else {
                target_[i] = r[right_index];
                right_index++;
                result += ((mid - left) - left_index);
            }
        }
        return result;
    }

    vector<ll> target_;
};

int main() {
    ll N;
    cin >> N;
    vector<ll> a(N);
    for (ll i = 0; i < N; i++) {
        cin >> a[i];
    }

    auto medianNum = [&](ll x) {
        vector<ll> b(N + 1, 0);
        for (ll i = 0; i < N; i++) {
            b[i + 1] = (a[i] >= x ? 1 : -1) + b[i];
        }
        InversionCount ic(b);
        return N * (N + 1) / 2 - ic.count();
    };

    ll ok = 0, ng = INT_MAX;
    while (ok + 1 != ng) {
        ll mid = (ok + ng) / 2;
        ll num = medianNum(mid);
        (2 * num >= N * (N + 1) / 2 ? ok = mid : ng = mid);
    }
    cout << ok << endl;
}