AtCoder Regular Contest 008 C - THE☆たこ焼き祭り2012

問題

 あなたと参加者を足してN人の人がおり、あなたは持っているN個のたこ焼きを参加者に投げて配る必要がある。参加者もたこ焼きを投げることができ、参加者を経由して他の参加者にたこ焼きを配ることが可能である。各人は2次元座標上に散らばっており、たこ焼きを投げる速度の上限と受け取る速度の上限が決まっている。たこ焼きを全員に配りきるために必要な時間の最小値を求めよ。

解法

 i番目の人がj番目の人にたこ焼きを直接投げつける場合の最小コストは$$\frac{iからjへの距離}{\min\{iが投げる速度の上限,jが受け取る速度の上限\}}$$である。このような重み付きのエッジを持つグラフとして考えて、0番目の人からダイクストラ法を使えば各人への最短経路が求まる。1秒に1個しか投げられないという制約から、この最短経路をソートして小さいほうから N - 1, N - 2, \dots, 1を足したものの最小値が答えとなる。

反省

 22分36秒でAC。最初は1秒に1個しか投げられないという制約や、参加者を経由できるという点に戸惑ってすごく複雑な答えになるのではないかといろいろ考えていたが、13分ほどで結局ダイクストラ法やるだけというところに気づけた。気づいてからは特に詰まることもなく……とは言ってもやはり実装がちょっと遅いか。しかしダイクストラ法の書き方は固まってきたのでこれ以上速くできる気はしない。

 解法としては最短経路で投げる限り同じ人が同時に2個たこ焼きを持つことにはならないという点が重要なんだろうけど、感覚的にしか言えなくてはっきり証明しろと言われたら困ってしまう気がする。おそらくそのような場合がありうると最短性に矛盾ということになるんだな。

コード

#include"bits/stdc++.h"
using namespace std;
using ll = long long;

int main() {
    ll N;
    cin >> N;

    vector<double> x(N), y(N), t(N), r(N);
    for (ll i = 0; i < N; i++) {
        cin >> x[i] >> y[i] >> t[i] >> r[i];
    }

    //iからjに投げるときにかかる時間
    vector<vector<double>> connect(N, vector<double>(N));
    for (ll i = 0; i < N; i++) {
        for (ll j = 0; j < N; j++) {
            double distance = sqrt(pow(x[i] - x[j], 2) + pow(y[i] - y[j], 2));
            double speed = min(t[i], r[j]);
            connect[i][j] = distance / speed;
        }
    }

    //ダイクストラ法
    vector<double> cost(N, INT_MAX);
    struct Element {
        double cost;
        ll curr;
        bool operator<(const Element& rhs) const {
            return cost > rhs.cost;
        }
    };

    priority_queue<Element> pq;

    cost[0] = 0;
    pq.push({ 0.0, 0 });
    while (!pq.empty()) {
        auto t = pq.top();
        pq.pop();

        for (ll i = 0; i < N; i++) {
            if (i == t.curr) {
                continue;
            }

            double new_cost = t.cost + connect[t.curr][i];
            if (new_cost < cost[i]) {
                cost[i] = new_cost;
                pq.push({ new_cost, i });
            }
        }
    }
    
    sort(cost.begin(), cost.end());
    for (ll i = 1; i < N - 1; i++) {
        cost[i] += N - 1 - i;
    }

    printf("%.10f\n", *max_element(cost.begin(), cost.end()));
}