AtCoder Regular Contest 058 E - 和風いろはちゃん / Iroha and Haiku

問題

 長さがNであり、各要素が1から10までの整数である数列a_0, a_1, \dots, a_{N - 1}について、連続する部分列で和がX, Y, Zとなるものがこの順で連続しているときXYZを含むとする。XYZを含むものが何通りか求めよ。

解法

 ほぼ解説の通り。

 XYZを含まないものの数を考える。左から一つずつ数列の値を決定していくDPをしていくとき、必要な状態は直前の数列のいくつかである。具体的には和がX + Y + Zまでとなるような直前の数列を保持していれば今考えているところの数字を決定できる。

 これは数字を1 \to "1",2\to "10",3 \to "100", 4 \to "1000"と符号化し、直前の数列とはこれを並べたものとして表現することで状態数を減らすことができる。直前の数列の和は符号化後の数字の長さとなるため、状態数はO(2^{X + Y + Z})となる。あとはこれをN個分、各数字が1から10まで回していくだけなので全体の計算量はO(10N2^{X + Y + Z})

 DPの遷移はたとえば直前の数列を符号化したものが sであったとすると

今回の数字 遷移後の数列を符号化したもの
1 s1
2 s10
3 s100
4 s1000

 となるので、今回の数字をiとすると結局siだけ左にシフトしたものと1i - 1だけ左にシフトしたもののorで書ける。

 「遷移した数列がXYZを含んでいる⇔符号化したものについてX + Y + Z個目,Y + Z個目,Z個目のビットが立っている」なのであらかじめそのようなビットを立てたものを用意しておくとXYZを含んでいるかの判定も簡単に行える。

反省

 数日前1時間10分ほど考えたが解けなかったので今日はすぐ解説を見たが、それでも1時間近くかかった。

 解いていたときは、X, Y, Zの長さを全探索して含まれないところは自由に取れるというので解けるんじゃないかと勘違いしてサンプルが合わずずっと悩んでいた。問題を解いた人の記事を探してみると同じような誤解をしていた人もちらほら見かけるのでちょっと安心するなど。

 解説はぱっと見では何を言っているのかわからないし、実装もどうするのかよくわからなかったけれど、creep04の実装を参考になんとか理解してみるとすごく賢い解法で良い問題に思えた。こういう問題を解けるようになりたい。

コード

#include"bits/stdc++.h"
using namespace std;
using ll = int64_t;

int main() {
    ll N, X, Y, Z;
    cin >> N >> X >> Y >> Z;

    constexpr ll MOD = (ll)1e9 + 7;
    const ll MAX = 1LL << (X + Y + Z - 1);

    //XYZを含む⇔符号化するとX + Y + Z - 1個目, Y + Z - 1個目, Z - 1個目のビットが立っている
    const ll NG_BITS = (1LL << (X + Y + Z - 1)) | (1LL << (Y + Z - 1)) | (1LL << (Z - 1));

    //dp[i][j] := i番目まで見て,直前の数列を符号化したものがjであるときのXYZとなっていない数
    vector<vector<ll>> dp(N + 1, vector<ll>(MAX, 0));
    dp[0][0] = 1;

    for (ll i = 0; i < N; i++) {
        for (ll j = 1; j <= 10; j++) {
            for (ll k = 0; k < MAX; k++) {
                //k:前回までのbit列
                //遷移:kをjだけ左にシフトして、そこに先頭だけ1を立てたものを入れる
                ll t = (k << j) | (1LL << (j - 1));

                //XYZとなっているかの判定
                if ((NG_BITS & t) != NG_BITS) {
                    //上の方は関係ないのでマスクする
                    t &= (MAX - 1);

                    dp[i + 1][t] += dp[i][k];
                    dp[i + 1][t] %= MOD;
                }
            }
        }
    }

    //全ての数からXYZとならない数を引く
    ll ans = 1;
    for (ll i = 0; i < N; i++) {
        ans *= 10;
        ans %= MOD;
    }

    for (ll i = 0; i < MAX; i++) {
        ans += MOD - dp[N][i];
        ans %= MOD;
    }
    cout << ans << endl;
}