時系列モデルが木構造を学習できることの検証

 前回の考察では、時系列モデルが暗黙のうちに木構造を学習できるので木の遷移履歴を時系列展開しても良いという仮説を立てた。この仮説を多少なりとも検証するため、今回は木に関する簡単なタスクを考えて、それが学習可能かどうかを実験により確かめた。

実験

 タスク概要:「木についての深さ優先探索での訪問順が与えられるので、幅優先探索での訪問順を出力せよ」

具体的なデータ生成手順

  1.  N頂点の木をランダムに生成する
  2. 生成した木についてランダムに根を決定
  3. 根から深さ優先探索を行い、訪問順を記録(これを時系列モデルへの入力とする)
  4. 根から幅優先探索を行い、訪問順を記録(これを正答とする)

f:id:tokumini:20200708111225p:plain

細かい点

  • DFSでは戻る際の訪問も記録しているので、系列から完全に木を復元可能(だと思う)
  • BFSにおける同じ深さでの優先度はDFSで先に探索された順と同じとする
  • ノードを入力する際はOnehotベクトルの形式にして入力
  •  N頂点の木から得られるDFSの系列(入力)は 2N - 1

実装

実験結果

頂点数 N = 6

f:id:tokumini:20200708111650p:plain

 単純なLSTMおよび以前実装したDNCのいずれにおいても50000データほど学習させると正答率(BFS順を完全一致で出力できた割合)が100%になった。 N = 6という小さい木ではあるが、訪問順を詳細に記録しておけば木構造を認識することはできていそうである。

頂点数 N = 20, 50

 LSTMについて N = 20, 50として実験を行った。

f:id:tokumini:20200710135415p:plain

 20ノード, 50ノードのものはバッチ学習として1ステップ32データにしているので、データ数としては5 * 32 = 160万データとなる。それだけ学習させると20ノードのとき正答率93.7%、50ノードのとき正答率8.0%となった。

頂点数 N = 50 その2

 50頂点くらいはなんとか学習してほしかったので、GPUを使うようにプログラムを変更してバッチサイズを256、学習ステップ数を20万(計データ数は5億程度)にして再度実験を行った。

f:id:tokumini:20200711101243p:plain

 最初は隠れ層の次元数512(青線)でやったところ正答率60%程度で頭打ちになり最後は突然nanになって崩壊してしまった。1024(赤線)でやり直したが、学習は遅いうえ損失も(nanではないが)爆発して正答率が崩壊するのは同様だった。LSTMで記憶容量を超えるとこういう感じで崩壊してしまうものなのだろうか。

雑記

 特に工夫していない1層のLSTMではそこまで長い系列(大きい木)は扱えないかもしれない。DNCをバッチ学習に対応させて可能性があるかどうか、あるいは違う系列モデルを実装するか。 N = 50だと入力系列の長さは 99であり、自然言語処理で考えて99単語からなる文となると長めではあるか。

 一方で30頂点くらいでもちゃんと探索できれば探索なしよりは強くなるんじゃないかという気もする。この記事でやっている実験は実際のゲーム木探索での状況とはやや異なるので、これの性能を追い求めていてもなぁとも思うところ。MCTSnetも実装するべきことを考えると結構実装が重たいので早めに実際の利用環境で試してみたいという気持ちはある。どちらの方針を取るにしてもとにかく実装を進めねば。