前回はランダムな動作をするエージェントとVAEを動かせることを確認した。18000フレームの行動-状態(画像)のペアを得ることができるようになっている。これを何度か繰り返すことでデータセットを作れる。
これに対して、まず状態と行動から次状態を予測するモデル(いわゆる世界モデル)を作ってみる。強化学習をやるにしても行動の影響がわかった状態でやる方が良いのではないか、という予測に基づく。
Minecraftのような複雑な環境では、同じ状況で同じ行動を取っても次になにが起こるかを完全に予測することは難しいと思われるため、世界モデルには確率的な遷移予測ができてほしい。確率的生成モデルの一つである拡散モデルを使ってみることにする。単純に最近拡散モデルに興味があるので使ってみたいというだけのことでもある。
学習設定
18000フレームx3セットのデータから、連続4フレームずつ抜き出してきて、「状態1~3 + 行動1~3を条件として、状態4についてノイズから逆拡散過程をたどるように学習させる」
状態(画像)はすべて学習済みVAEで1/8サイズにエンコードし、その潜在変数の中で拡散過程/逆拡散過程を行う。行動は24次元のものを潜在変数の次元を同じサイズにMLPで変換し、
画像は256x256サイズにリサイズし、バッチサイズは8にした。
結果
損失の推移をプロットすると以下のようになった。
損失はあるとき急に下がり始めた。
あるデータセットの先頭16フレームを描画する。
- 正解画像
- 予測画像
なんらかのゲーム画面っぽいものは学習できている。インベントリを開く(閉じる)という行動が大きく画面に変化を与えるので、まず大まかに画面の傾向が2種類あるということは大まかに学習できていそう。
一方、現状だとインベントリを開いているタイミングが正解画像と異なっているので、インベントリを開く(閉じる)という動作と関連付けて画面の変化を予測できるわけではない。行動のネットワークへの入力の仕方が良くないのかもしれない。
次回へ向けて
行動の入力の仕方を工夫した方が良さそうだ。また、もっと学習の試行錯誤を高速でやっていくために、画像サイズを小さくしてバッチサイズを上げていった方がやりやすいかもしれない。
またインベントリを3フレーム以上開いてから閉じると、今の仕組みでは開く前の状況が情報としてネットワークに入力できないため、本来はより長く入力を入れていく必要があると思われる。
その他
今回は学習済みEncoderをそのまま利用しているが、最終的にはこれも共同学習したい。やり方についてちょっと悩んでいたが、よく考えると拡散モデルの条件として与えているものにも勾配は通るので、ちゃんとやればそのまま学習できるかもしれない。