最近の世界モデル系手法ではVQ-VAEが当たり前のように使われているので試してみたいが、生のVQ-VAEだとcommitment lossとかentropy lossとか、様々な工夫を入れなければいけないことが気になり、そういう工夫が要らないと主張されていて内容もシンプルなFSQを試してみた。
公式実装がjaxで存在しているのでこれを利用した。
タスクとしては単純な画像再構成で、CIFAR-10あたりだとちょっとトイ・タスクすぎるので、STL-10データを使った。
STL-10データの特徴は冒頭に書いてある通り
- 10 クラス: airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck
- 画像サイズ 96x96 pixels
- 各クラス train500枚、test800枚の画像がある
となる。unlabeledなデータもあるが、今回はとりあえずtrain 500 x 10クラスの5000枚で学習を試した。
VQ-VAEを使う上で、Encoder-Decoder部分のCNNも、あまり適当なものではなく実績のあるものを使いたかったため、MaskGITの実装を利用した。(両方Google Researchの研究なので)こちらもjaxでありFSQと繋げるのは簡単。
実装
FSQをlevel3で要素数10としたので、コードブックサイズとしてはと結構大きくなっており、VQ-VAEの方は単純にコードブックサイズ1024としたので、フェアではないかもしれない。FSQのlevel2とするとnanが発生して上手く学習が進まなかった。
結果
比較実験
まずFSQと補助損失がなにも入っていないvanilla VQ-VAEと比較した。
再構成損失が明らかにFSQの方が良く、損失の推移も安定的だった。
FSQの長時間実験
FSQを上記比較実験から10倍の8時間回してどの程度再構成ができるかを確認した。10倍回すとかなり損失も下がる。
学習データにあるもの
データID | 結果 |
---|---|
0 | |
1 | |
2 | |
3 | |
4 |
学習データ内にあるものはまぁまぁ再現ができているといえばできているかもしれないが、No.2のチーターの画像はちょっと細かくてよくわからない感じになっている。
テストデータ
データID | 結果 |
---|---|
0 | |
1 | |
2 | |
3 | |
4 |
テストデータにあるものだとまだ崩れがちである。学習データも学習量も足りていなさそうな中ではよくやっているか。
今後
- BBFの次状態予測部分にFSQを導入してみる
余談
FSQのLevel2だと学習が上手くいかないのは、元のコードがバグっているからのような気がした。自前で実装してみると学習は進んだ。しかし、損失の減りはLevel3(10次元)より悪かった。
前述の通り、コードブックサイズがとではかなり違うので仕方がないことではありそう。