MCTSnetの損失計算部

 MCTSnetの解説は他にもある

ので、そちらも参照されたし。この記事では損失計算部分にだけ注目して記述する。

 arXiv版OpenReview版は式番号が異なるので注意。OpenReview(ICLR2018)で一回Rejectになって、ICML2018に通っていて、arXivの最新版はそのICMLにある版のと近そうなので、とりあえずこの記事では基本的にarXiv版に準拠する。

3.5 MCTSnetの訓練

 MCTSnetの読み出しネットワークは、最終的に探索全体を考慮した行動決定および評価を出力する。 この最終出力は、基本的に値ベースの学習や方策勾配法などの損失関数に従って(つまり強化学習によって)訓練できるが、簡単のため以下では各状態で教師行動が決まっているもとで教師あり学習でMCTSnetを学習することを考える。

 記号の定義

  •  z _m :  m回目のシミュレーションで確率的にサンプリングされた行動系列
  •  z_{\le m} :  m回目までのシミュレーションで確率的にサンプリングされた行動系列の系列
  •  \boldsymbol{\mathrm{z}} : 確率変数としての z_{\le m}
  •  M : 総シミュレーション回数

 基本的には M回シミュレーションを行った後の最終的な出力方策 p _\theta(a|s,  \boldsymbol{\mathrm{z}})について、教師行動 a ^ *について交差エントロピーを計算する。


\displaystyle
l(s, a ^ * ) = \mathbb{E} _ {\boldsymbol{\mathrm{z}} \sim \pi(\boldsymbol{\mathrm{z}}|s)} [ -\log p _\theta(a ^ *|s,  \boldsymbol{\mathrm{z}}) ] \tag{8}

 上式は周辺分布についての対数尤度の下限としても解釈できる(※ここはよくわからない)。

 そして現実的に \boldsymbol{\mathrm{z}}全てについて計算できるわけはないので、一つのサンプルについて勾配の推定を行う。(Schulman et al., 2015が引用されているのでちらっと見てみたがよくわからなかった)


\displaystyle
\nabla _ \theta l(s, a ^ * ) = -\mathbb{E} _ {z} [\nabla _ \theta \log p _\theta(a ^ *|s,  \boldsymbol{\mathrm{z}}) 
+ (\nabla \log \pi(\boldsymbol{\mathrm{z}}|s; \theta _ s)) \log p _\theta(a ^ *|s,  \boldsymbol{\mathrm{z}}) ]
\tag{9}

 (この式中の \boldsymbol{\mathrm{z}} zの間違いである気もするが、とりあえず論文通りの記述にしておく)

 最初の項は単純に上の損失をそのまま微分したもの。

 2番目の項はシミュレーション分布に関する勾配に対応し、REINFORCEアルゴリズムあるいはActor-Criticのようなやり方によって学習する。この項において \log p _\theta(a ^ *|s,  \boldsymbol{\mathrm{z}})が報酬信号の役割を果たす。これが大きくなるようなシミュレーション系列は良いシミュレーション系列だったと見なすことができるため。

 (式としては記述がないが?)シミュレーション方策 \pi(a|s; \theta _ s)に負のエントロピー正則化項を追加する。

3.6 信頼割り当ての工夫

 REINFORCE勾配は不偏だが分散が大きい。最終的な行動 a ^ *を決定するまでに M回のシミュレーションが寄与していて、各シミュレーション中でもどのノードを選択していくかという問題があるため計 O(M \log M)から O(M ^ 2)ほどの行動決定があり、信頼の割り当てが難しい。これに対処するため一つの逐次的決定問題のサンプルからBias-Varianceのトレードオフを調整できる手法を提案する。

  M回のシミュレーションを行うとして、 m = 1, \dots, Mのいつでも最終的な行動決定分布 p _\theta(a|s, z_{\le m})を考え、損失 lを計算することはできる。 m回時点での損失を l _ m = l(p _ \theta(a|s, z _ {\le m}))と定義する。

 最終目標は -l _ Mの最大化である。 l _ 0 = 0としたとき、この最終的な損失は次のように変形できる。


\displaystyle
-l _ M = -(l _ M - l _ 0) = \sum _ {m = 1} ^ {M} -(l _ m - l _ {m - 1})

  \bar{r} _ m = -(l _ m - l _ {m - 1})と置くと、これは m回目のシミュレーションで損失が減少した量であると見なせる。ここで、 m回目以降のこの量の累積和を R _ m = \sum _ {m' \ge m} \bar{r} _ {m'}と置くと、最終的に求めたい量は - l _ M = R _ 1である。

 式(9)のREINFORCE項について -(\nabla \log \pi(\boldsymbol{\mathrm{z}}|s; \theta _ s)) \log p _\theta(a ^ *|s,  \boldsymbol{\mathrm{z}}) = Aと置くと、


\displaystyle
A = \sum _ m (\nabla \log \pi(z _ {m}|s; \theta _ s)) R _ 1
\tag{10}

 (次の式以降そうなるのだが、ここも多分 \pi(z _ {m}|s; \theta _ s)ではなく \pi(z _ {m}|s, z _ {\lt m}; \theta _ s)だと思われる。まぁ自明なものとして省略しているのか)

 ここで z _ mにおける確率変数は m以降の部分にしか影響を及ぼさないことから R _ 1の部分は R _ mに置き換えることができ、


\displaystyle
A = \sum _ m (\nabla \log \pi(z _ {m}|s, z _ {\lt m}; \theta _ s)) R _ m \\
 = -\sum _ m (\nabla \log \pi(z _ {m}|s, z _ {\lt m}; \theta _ s)) (l _ M - l _ {m - 1}) \tag{11, 12}

とすることができる。つまり、 m回目のシミュレーション終了時点では、最終損失をそのまま強化信号として使うのではなく、最終損失と m回目時点での差分を使用する。これによってバイアスは生じないまま分散は小さくなる。

 さらに割引率 \gammaを導入してバイアスと分散のトレードオフを行う。 m回目のシミュレーションは m回目に続く近いところで多く貢献していると考え、遠い未来についての貢献は多少考慮する量を減らすという意味合いがある。 R ^ \gamma _ m = \sum _ {m' \ge m} \gamma ^ {m' - m} r _ {m'}と置いて、MCTSnetの最終的な勾配計算は


\displaystyle
\nabla _ \theta l(s, a ^ * ) = -\mathbb{E} _ {z} [\nabla _ \theta \log p _\theta(a ^ *|s,  z) 
+ \sum _ m \nabla \log \pi(z _ m | s; \theta _ s) R ^ \gamma _ m ] \tag{13}