AtCoder Regular Contest 018 C - 席替え

問題

 NM列の長方形状に並べられた机に対して生徒が一人ずつ座っている。i行目j列目にいる生徒の成績をG[i][j]としたとき、a \gt cならばG[a][b] \ge G[c][d]となるようにしたい。そのような席替えのうち、生徒の移動距離(マンハッタン距離)が最小になるものについてその最小値を求めよ。

解法

 まず各成績の人が何行目に来なければならないかを計算する。そして全員をその行に移動させてから、列が被っていれば被らないように横に振り分けることで目的の並び方が達成される。

 前半の操作について、 apで割り切れないときp \ge N \times Mより同じ成績の人は発生しない。したがって出現した成績を一つのvectorに押し込んでソートしてそのインデックスをMで割れば、ある成績の人が何行目へ行かなければならないかが求まる。

 後半の操作は各行について、列を順番に見ていってj列目までにはj人いないといけないなので、累積和とのずれを計算するとちょうと移動させる量が求まる。

  apで割り切れるときは同じ成績の人が発生しうる。xpより小さければ全員同じ成績なので答えは0。大きければ先頭の人だけ成績が大きいので、その人を一番後ろの人と交換するため 2(N - 1)が答えとなる。

反省

 1時間12分(6WA)でAC。40分ほどでコーナーケース以外の方針はわかったが、最後の10ケースが間違ってしまう理由がわからずかなり苦しんだ。てっきり成績は一意に定まると思い込んでいたが、 apで割り切れる場合があったのを失念していた。そして割り切れる場合も全て答えは0となるわけではないというところにも引っかかった。

 pで割ったものが成績となると明示的に問題文には書かれていないが、乱数生成といったらそれが常識なんだろうか。一番最初だけmodをとらないというのも罠で、かなりひっかけポイントが多い問題だったように感じる。

 後半の各行に対する操作は、厳密な証明はできないが感覚的にそんな感じで求まるんじゃないかと試してみたら答えが合っていたというものだった。これも典型問題だろうか。解説等を見ると行についても再度ソートしてどこへ行くかを計算すればよいだけだった。もっと簡単に考えなければ。

コード

#include"bits/stdc++.h"
using namespace std;
using ll = int64_t;

int main() {
    ll N, M;
    cin >> N >> M;
    ll x, a, p;
    cin >> x >> a >> p;

    vector<vector<ll>> G(N, vector<ll>(M));
    vector<ll> num;
    for (ll i = 0; i < N; i++) {
        for (ll j = 0; j < M; j++) {
            if (i == 0 && j == 0) {
                G[i][j] = x;
            } else {
                G[i][j] = (j == 0 ? (G[i - 1][M - 1] + a) % p : (G[i][j - 1] + a) % p);
            }
            num.push_back(G[i][j]);
        }
    }

    if (a % p == 0) {
        if (x > p) {
            cout << (N - 1) * 2 << endl;
        } else {
            cout << 0 << endl;
        }
        return 0;
    }

    sort(num.begin(), num.end());
    map<ll, ll> raw;
    for (ll i = 0; i < num.size(); i++) {
        raw[num[i]] = i / M;
    }

    ll ans = 0;

    vector<vector<ll>> after_num(N, vector<ll>(M, 0));
    for (ll i = 0; i < N; i++) {
        for (ll j = 0; j < M; j++) {
            ans += abs(i - raw[G[i][j]]);
            after_num[raw[G[i][j]]][j]++;
        }
    }

    for (ll i = 0; i < N; i++) {
        assert(accumulate(after_num[i].begin(), after_num[i].end(), (ll)0) <= M);
        ll sum = 0;
        for (ll j = 0; j < M; j++) {
            ans += abs(sum - j);
            sum += after_num[i][j];
        }
    }

    cout << ans << endl;
}