PonderNet: Learning to Ponderを読んだ際のメモ

出典

Abst

 標準的なニューラルネットワークにおいて、計算量は入力のサイズに伴って大きくなるが、学習する問題の複雑さに対応して大きくなるわけではない。この限界を打ち破るために、PonderNetを提案する。このアルゴリズムは問題の複雑さに応じて計算量を適応させることを学習する。PonderNetは、学習時の予測精度、計算コスト、汎化という3点についてちょうど良い妥協点を達成するために、エンドツーエンドで計算ステップ数を学習する。複雑な合成問題において、PonderNetは既存の計算量を適応的に決定する手法に比べて大幅に性能を向上し、典型的なニューラルネットワークが失敗するような外挿テストに成功する。また、Pondernetは現実世界の質問応答データセットにおいて現在のSOTAと同様の性能をより少ない計算量で達成する。ニューラルネットワークの推論能力を検証するための複雑なタスクにおいてもSOTAと同等の性能を示す。

手法

 一般的な教師あり学習を考える。常に同じ入力が与えられるものとして、再帰的なネットワークについて、出力を一つ増やして停止確率を出すようにする。

  •  \hat{y} _ n, h _ {n + 1}, \lambda _ n = s(x, h _ n)
    •  \hat{y} _ n : ネットワークの出力
    •  h _ {n + 1} : 隠れ状態
    •  \lambda _ n : ネットワークの推論を停止する確率
    •  x : 入力

 ベルヌーイ確率変数 \Lambda _ nを定義し、継続( \Lambda _ n = 0)と停止( \Lambda _ n = 1)の2状態についてのマルコフ過程を考える。遷移確率を

 P(\Lambda _ n = 1 | \Lambda _ {n - 1} = 0) = \lambda _ n \; \forall 1 \le n \le N

とする。ステップnまでずっと停止せず、nで初めて停止状態に入る確率は

 p _ n = \lambda _ n \prod _ {j = 1} ^ {n - 1} (1 - \lambda _ j)

である。PonderNetの出力として、上記の確率で当該ステップの出力を選択する。

細かい部分について

 最大ステップ数 N \to \inftyであれば上記の説明で良いのだが、実際は Nが有限であるため、確率 p _ nの合計が1になるように修正する必要がある。修正方法は2案あり、

  • 合計が1になるように全体を正規化する
  • 1に足りない停止確率を最後のステップに割り当てる

 最大ステップ数 Nの決め方は

  • 検証時 : 計算時間等の制約に基づいて定数として決定する
  • 学習時 : 停止の最小累積確率を決める
    • つまり \sum _ {j = 1} ^ n p _ j \gt 1 - \epsilonとなったら終わり(実験では \epsilon = 0.05)

損失

 損失は再構成損失 L _ {Rec}正則化 L _ {Reg}から成る。

 L = L _ {Rec}  + \beta L_{Reg}

 L _ {Rec} = \sum _ {n = 1} ^ N p _ n \mathcal{L} (y, \hat{y} _ n)

 L _ {Reg} = KL(p _ n || p _ G(\lambda _ p))

  •  \mathcal{L}は自乗誤差なり交差エントロピーなり、タスクに応じた損失関数
  •  \lambda _ pは停止ポリシーを示す事前分布(幾何分布)を決定するハイパーパラメータ
    • 幾何分布とは? Wikipedia
    • ベルヌーイ試行を繰り返して初めて成功させるまでの試行回数 X の分布
    • 例えば、サイコロの1の目が出るまで繰り返し投げるとする。p = 1/6 の幾何分布に従うといい、それの台は {1, 2, 3, …} である。
  •  L _ {Reg}を導入する目的
    • ネットワークを事前分布の確率1 / \lambda _ pにバイアスする
    • 可能な全てのステップ数に0ではない確率を与えることを促進する
    • (所感 : ここはちょっとすぐにはピンとこなかった。この項が無い場合にはどういうことが起こるのだろうか?)

実験

パリティタスク

  • 提案手法はACTという既存手法に比べて性能が高く、また学習時と少し異なるような設定における外挿にも強かった。
  • PonderNetが学習に失敗したのは \lambda _ p = 0.9としたとき
    •  1 / 0.9が1に近く、つまり平均ステップ数が1になるよう正則化がかけられると失敗すると考えられる
  •  \lambda _ p = 0.1としたときは平均10ステップにせよという正則化がかけられることになるわけだが、学習すると3ステップ程度で良いという平均思考ステップ数に落ち着くようになった

 その他bAbIタスク、Paired associative inferenceについてもUniversal Transformerと同等、あるいはより良い性能を達成した。

所感

  • シンプルなやり方で順当に結果を出しているという印象で面白かった
  • 正則化の効力についてはすぐには理解しきれないところもあったが、これは実験を重ねることで肌感覚的にわかってきそうなものにも思える
  • 実験でRNNやUniversal Transformerをベースとしているように、重み共有というか、再帰的なネットワークを基調としていて、やっぱり自分としては今後そこが流行りになっていくんじゃないか(もう流行ってる?)という気持ちはある
    • そういう重み共有モデルにおいてループ回数を動的に決めたいというモチベーションはあり、そのための手法として有力そう