Transformer系世界モデル手法IRISとTWMの比較

 以下の2つの論文を比較する。

 共通点としてどちらも

となっている。

 提案手法に沿って、呼称は以下のようにする

手法

IRIS

IRIS Figure 1

TWM

TWM Figure1

 まず手法の概要として共通点は

  • 観測を離散AutoEncoderで複数の離散トークンにする
  • 観測の離散トークンと行動などをTransformerに入力する

大雑把には似ている。以下違いを見ていく。

Transformerへの入力

  • IRISの入力は「観測の離散トークン + 行動」
  • TWMの入力は「観測の離散トークン + 行動 + 報酬」

Transformerの使い方

  • IRISは自己回帰的に次状態の離散トークン・報酬・終了判定を予測する
  • TWMは行動に相当する位置のトークンの変換結果を hとして、そこに情報が集約されていると考え、 hから各種次の予測を行う

離散化Encoder

要素 IRIS TWM
入力サイズ 64x64x3(スタックしているかは不明。たぶんしているのでは) 64x64x4 (直近4フレームをスタック・グレースケール化)
離散化トークンのクラス数 512 32
1観測を何個のトークンにするか 16 32
損失の工夫 Perceptual Lossあり consistency lossというものを提案・導入

 離散化クラス数などはやや異なるが、概ね一致。

Transformer

要素 IRIS TWM
ベース実装 minGPT Transformer-XL
次元 256 256
層数 10 10
head数 4 4
入力する最大タイムステップ数 20 16

 Transformerパラメータはかなり近い。

方策の入力

  • IRIS:復元した観測
  • TWM:観測の表現(離散トークン)

 TWMの方で良し悪しについては記述がある。

  • 観測そのまま or 復元した観測:画像での学習なので方策が安定しやすい
  • 観測の表現:遷移予測の学習によっても変更が加えられていくので方策モデルがその変化にも追従しなければいけない。エントロピー正則化と一貫性損失を入れると安定する
  • 観測の表現 + 遷移の履歴(Dreamer形式):推論時に遷移モデルも実行するコストが発生する

 IRISの方策はさも各ステップ独立で推論できそうだが、Appendixを見るとLSTMを使っているとあり、過去系列も見ていそう。

実装

実験にかかる時間

  • IRIS
    • 実験全体は8基のA100で実行
    • 1枚のGPUで2つのAtari環境を実行して、おおよそ7日かかる。つまり1環境あたりだと3.5日
    • おそらく上記が5シードまとめての時間なので、1環境1シードだと0.7日 = 16.8時間
  • TWM(おそらく1環境・1シードあたりについての記述)
    • A100を1台だと約10時間
    • 3090を1台だと約12時間
    • Transformer-XLのメモリ機構を使わないVanilla Transformerを使うと約1.5倍かかる

 IRISの方が若干遅いのは、(1) TransformerがVanillaのもの (2)Actor-CriticにLSTMセルが使用されている というところがありそう。とはいえそこまで大差ではない。

評価

 どちらもDeep Reinforcement Learning at the Edge of the Statistical Precipiceに従っている。(おそらくAtari 100kベンチマークについての評価用OSSツールが利用されているのだと思われる。良さそう)

 比較手法もだいたい同じ。

IRIS Figure 5

TWM Figure 3

 どちらもMedianは幅広い。MeanとInterquantile Meanでは良い。

 各ゲームでのスコア表をマージすると以下のようになった。色付けは各行ごとに緑が最小、白が最大。

 (スプレッドシートhttps://docs.google.com/spreadsheets/d/1B1gMkAqUhcmGdsDui7_1lHjne3E3sOvXSg_4fIovNw4/edit?usp=sharing

 得意/苦手なゲームの傾向もやや似ていそう? ちょっとよくわからないか。EfficientZeroが強い。

その他・実験考察など

  • IRIS
    • Frostbite と Krullというステージ遷移がある場合に苦労する
    • フレームあたりの観測離散化トークン数を64に増やすと改善するゲームもある(もちろんその分世界モデルが重くなる)
    • 学習量を10Mに伸ばすと性能も伸びる
  • TWM
    • Actorのエントロピー正則化についても閾値を導入して、下回った場合のみ有効になるようにする
    • 過去のデータを何度もサンプリングして学習に利用しすぎないようにBalanced Dataset Samplingを行う
    • Attentionのヒートマップを見ると、報酬が1になるようなタイミングに重みがあったりするので、そういうタイミングへの学習が行われていそう
    • 世界モデルの学習系列を16から4に減らすと明確に性能が落ちる

所感

  • 細かい差異はともかく、Transformerに入れて世界モデルとして使うとこれくらいできますというのが確度の高い情報として提示された印象
  • それぞれの手法について、予測部分で行動まで出してそれをPolicyにするというのは不可能なんだろうか