Recurrent World Models Facilitate Policy Evolutionを読んだ

出典

 David Ha and Jürgen Schmidhuber, “Recurrent World Models Facilitate Policy Evolution,” Advances in Neural Information Processing Systems 31, 2018.

 arXiv:https://arxiv.org/abs/1809.01999

 World Models:https://arxiv.org/abs/1803.10122

概要

 環境モデルをVAEとRNNを用いて学習

詳細

提案システムの全体像

 入力画像をVAEで表現に変換し、それをRNNに入力していく。制御部分(C)は表現zとRNNの隠れ状態hを入力とし、行動を出力する。

VAE

  • 入力画像(観測情報)o_tから状態表現z_tを獲得。

MDN-RNN

  • 状態表現z_tと行動a_tを受け取って次状態の表現z_{t + 1}を混合ガウス分布として予測するRNN
    • この仕組みは既存の手法

Controller

  • 状態表現z_tとRNNの隠れ状態h_tを連結したものを入力
  • 線形関数で方策決定a_t = W_c[ z_t h_t ] + b_c
  • パラメータ数が少ない線形関数なので勾配法によらない最適化が可能

実験

実験1:Car Racing

  • 真上視点の画像を入力に車を運転するタスク
  • 学習手順
    1. ランダムな行動で10,000回試行しデータを収集
    2. Vを学習:収集したデータに出てきた画像を再構成
    3. Mを学習:学習したVと収集データからP(z_{t + 1}|a_t, z_t, h_t)を学習
    4. Cを学習:学習したV,Mから入力を得て、実際の環境から得られる報酬を最大化

  • M(から得られる隠れ状態h)を使わないと低性能(下から2,3段目)
  • VとMを両方使うことで高性能(最下段)

実験2:Viz Doom

  • 敵モンスターが撃ってくる火の玉を左右移動で避けるタスク
  • 学習した環境モデルのみを用いて学習
    • Mモデルに「プレイヤーが玉に当たったかどうか:d」を予測する出力を追加
  • 学習手順
    1. ランダムな行動で10,000回試行しデータを収集
    2. Vを学習:収集したデータに出てきた画像を再構成
    3. Mを学習:学習したVと収集データからP(z_{t + 1},d_t |a_t, z_t, h_t)を学習
    4. Cを学習:学習したV,Mを用いた仮想環境で生存時間を最大化するように学習

  • 仮想環境のランダム性\tauを上げると遷移が不安定になり難易度上昇
    • 不完全な環境モデルへの過学習を抑制→学習中のスコアは落ちるが実スコアは向上

Discussion

  • 環境モデルを学習できることはコストの観点から有用
    • シミュレーション上で学習した結果を現実世界に転用する研究(4, 33)を補完
  • VAEは一般的な(再構成に有用な)特徴を獲得
    • タスクを解く上で有用な特徴量を取れるとは限らない
    • 人間はタスクに関連した特徴を学ぶ(65)
  • 上達しないと観測できない状態があると今回の手法では不十分
    • 方策の学習と環境モデルの学習を同時に行う必要性
      • 環境モデルの学習に有用な情報を得る制御(83)
      • 好奇心、内的動機づけを用いた探索(61, 64, 77, 80, 81)
      • 情報を探す手法(23, 86)
      • 新しい探索を奨励する手法(47)
  • VAE&MDN-RNNの表現力の限界
    • Catastrophic Forgettingの問題(16, 43, 69)
    • より大きいモデル(27, 89, 93, 97, 98)や記憶モジュール(19, 107)を使うことによる表現力向上の可能性
  • 単純なRNNだと1ステップごとのシミュレーション以外不可能
    • 時空間の細部を無視した抽象的思考や階層的計画が重要
      • RNNをサブルーチンに持つCを用いる研究(83)
      • CとMを一つのネットワークに融合する研究(84)
      • PowerPlay(82, 91)
      • Behavioural Replay(79)

所感

  • モデルのみを用いた学習で、モデルに過剰適合したCheatingが起こるという話(4.2節)が面白かった
  • Discussion部分が興味深く、ここで挙げられていた論文は一通り目を通したいところ
  • 結局Cへの入力に使っているのは状態表現とRNNの隠れ状態だけなので先読みを明示的に行う方法も考えたい
    • 次状態予測部分をRNNでモデル化しなければならない必然性があるのかどうか