平均化MCTSの実装変更

 AlphaZeroで用いられているようなモンテカルロ木探索においては、状態sで行動aを取った回数N(s, a)とその行動以下の部分木から得られた報酬の総和W(s,a)を保存しておくことで、必要な際に行動価値をQ(s, a) = \frac{W(s, a)}{N(s, a)}として求める。

 しかしこの書き方ではSarsa-UCT(λ)の実装をしようとしたときに大きな変更が必要になってしまう。よってまずこの報酬の平均を用いるMCTSQ(s, a)の値を常に保持する形で実装し直し、それが今までのものと同じような出力を出すことを確認してからSarsa-UCT(\lambda)の実装に移りたい。

 今まで状態sにおいて行動an回選択し、報酬r_1, r_2, \dots, r_nが得られていたとする。このとき平均を取るならばQ_n(s, a)

$$ \begin{align} Q_n(s, a) = \frac{\sum_{i = 1}^{n} r_i}{n} \end{align} $$

 である。ここでn+1回目の行動選択が発生し、そのときの報酬が r_{n + 1}だったとするとQ_{n + 1}(s, a)

$$ \begin{align} Q_{n + 1}(s, a) &= \frac{\sum_{i = 1}^{n + 1} r_i}{n + 1} \\ &= \frac{n}{n+1} Q_n(s, a) + \frac{1}{n + 1}r_{n + 1} \\ &= Q_n(s, a) + \frac{1}{n + 1} \left(r_{n+1} - Q_n(s, a) \right) \end{align} $$

 となるので、Q_n(s, a)から学習率(更新ステップ幅)\frac{1}{n + 1}r_{n + 1}を目標値として更新を行ったのだと解釈できる。

 以上の内容はSutton & Barto『強化学習』のp.39にも「(行動価値の)漸進的手法による実装」として同様の式変形が記載されている。

 この結果からバックアップの実装を次のようにした。

void Searcher::backup(std::stack<int32_t>& indices,
                      std::stack<int32_t>& actions) {
    //リーフノードのvalueで初期化
    auto value = hash_table_[indices.top()].value;
    indices.pop();

    static constexpr float LAMBDA = 1.0;

    //バックアップ
    while (!actions.empty()) {
        //置換表におけるindexとそのノードの合法手のうち何番目の指し手かを取得
        auto index = indices.top();
        indices.pop();
        auto action = actions.top();
        actions.pop();

        //手番が変わるので反転
        value = -value;

        //探索回数の更新
        hash_table_[index].N[action]++;
        hash_table_[index].sum_N++;

        //現在のQを退避
        auto curr_q = hash_table_[index].Q[action];

        //Qを更新
        float alpha = 1.0 / hash_table_[index].N[action];
        hash_table_[index].Q[action] += alpha * (value - curr_q);

        //上のノードへバックアップするvalue値を更新
        value = LAMBDA * value + (1.0f - LAMBDA) * curr_q;
        }
    }
}

 示したソースコード中のように\lambda = 1としていればこれは今まで通りの報酬の平均を用いる手法と一致する。これはTD(\lambda)が\lambda = 1において(強化学習における)モンテカルロ法と一致することと同等である。

 この実装に変更し、同じ評価パラメータを用いて元のMCTS実装と500局ほど対局させてみたところ、勝率が37%になってしまった。出力を確認してみたところ、並列化しない場合は元のMCTSと同じ値が出力されていたが、並列化した際に異なる値になっていた。バーチャルロスの影響で学習率が正しく計算できていないようだ。無視できないほどの弱体化なのでバーチャルロスの更新タイミングを考え直すことにする。