MineRL 世界モデル学習 その6 バッチサイズ変更

 前回はMambaを用いて時系列入力をしたときに上手く学習できることを確認した。

 今回は、今後ストリーム学習(データをバッファに溜めてランダムサンプルするのではなく、その場ですぐ学習して捨てること)をするにあたって、これは (1)バッチサイズが1である、 かつ (2)ランダムシャッフルなし ということでもあるので、まずオフライン学習としてその条件でまともに学習が進むのかどうかを確認する。

実験1 バッチサイズを1まで下げて大丈夫か

 前回までと同様の設定でバッチサイズを64から16, 4, 1と減らしていく。

 ステップ数や学習サイズは変えずに学習を実行した結果が以下となる。

 バッチサイズを小さくすると学習損失が不安定になるし、最終的な損失としても高いままであることがわかる。

 しかし、バッチサイズ1でも学習が崩壊するということはなさそうだったので、なんとかやれるのかもしれない。バッチサイズ1での出力結果は以下の通り。

GT 予測

 良くはないし、2行目1列目で、インベントリは開いていないのにそういう画面を予測してしまっているようなところもあるが、とりあえずは許容範囲内ということでOKとする。

 一応、学習に使ったデータ数や、学習時間で比較すると以下のようになる。

学習に使ったデータ数 学習時間
データ数で見るとバッチサイズを小さくする弊害はそこまでなさそう 時間効率としてはやはりバッチサイズ小さいことは得にはならない。得に1だと良くない

実験2 ランダムサンプルなしで学習できるか

 バッチサイズ1の条件で、train_loader

    train_loader = DataLoader(
        dataset,
        batch_size=int(args.batch_size),
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True,
    )

shuffle=Falseとして学習できるかどうかを確認した。結果は以下の通り。

 学習している状況が大きく変わるたびに損失が波打ってしまっているが、やはり学習が根本的にできないということではなさそうだった。

 予測結果は以下の通り。

GT 予測

 ちょうど最後に学習したのが海付近の環境だったのか、出力には水っぽいものが多く出ている。予測として出力させるときには条件付けとしてこの予測画像の前15フレームほどが入っているはずだが、その文脈を適切に捉えられてはいないようだ。

実装

 最後に、ここまでは毎回データセットから固定長16のデータを取り出して、毎回ゼロの状態からMambaを含めた予測ネットワークに入力していたが、shuffle=Falseにした状態でMambaをステップモード(毎回1つのデータが入力されるモード)で動かすことで、ほぼストリーム学習に近い設定になる。これを実現するためには実装の修正が必要なのでそれを行った。

 まず、Mambaに allocate_inference_cache という関数があるので、それを使って状態を初期化する。

    def allocate_inference_cache(self, batch_size, dtype=None):
        return self.mamba.allocate_inference_cache(batch_size, dtype=dtype)

 この状態を使って step を繰り返して、ずっと状態を連続させる形で学習させた。大まかには

  • 時刻 t開始
    • 画像 t、行動 tが得られる
    • ノイズを作り、時系列モデルの状態で条件付けして画像 tを予測するように生成する
    • Flowを比較して損失を計算し、誤差逆伝播してネットワークの学習を1ステップ進める
    • 画像 tと行動 tを用いて時系列モデルの状態を進める
    •  t \leftarrow t + 1

 具体的な実装はhttps://github.com/SakodaShintaro/minerl_practice/blob/main/python/train_diffusion/train_stream.py

 この実装をもとに学習はできたものの、先の教師あり学習(バッチサイズ1、ランダムシャッフルなし)よりも損失が大きい状態で止まってしまっている。

 Mambaの状態が、教師あり学習であれば、結局固定長のシーケンスで学習されるが、今回の step を使う学習ではずっと過去の状態を引きずることになるので、そこで差分があるとは思われる。

 同じタイムスタンプのGTと予測を並べると以下のようになる。

 全然良い予測にはなっていなさそう。なんらかのバグがあるような気がするので、次回はそこの調査から。