Mambaが気になったので触ってみることにした。大規模な実験を回すのは大変なので、元論文で使われていた人工的なタスクのInduction Headsをやる。
Mamba
とりあえず動かしてみることを優先してMamba自体の理解は後回しにする。
概要としては以下の資料がわかりやすい。
細かい内容については
などがある。実装で参考にしたのは
- GitHub - state-spaces/mamba : 公式
- GitHub - johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch.
- GitHub - alxndrTL/mamba.py: A simple and efficient Mamba implementation in PyTorch and MLX. : 今回実験に利用
- GitHub - hrbigelow/mamba-recall: Experiments with Mamba SSM models
など。
Parallel Scan
Mambaの内容を軽く追っているときに、単純な累積和が2分木を利用したParallel Scanで実行できることはわかるが、実際には
を計算することになる。これもできるのか一瞬迷うところはあった。根本的には、Parallel Scanは結合則が成り立つものに対してはできるということなのだと思う。全く厳密ではないが、直感的にはもう1ステップ展開して
まで見ると、上手くデータを保持すればできそうということになる。
たとえば4系列で処理の動きを考えてみると
となる。これはコード的にはmamba.pyのpscan.pyを読むとわかりやすい。
https://github.com/alxndrTL/mamba.py/blob/a21c448e0f46a011d8edf792f8349c952e87d620/pscan.py#L37-L92
Induction Head
論文のTable 2(Table? Figureでは?)のMambaの線を再現する。
ハイパーパラメータは論文の通り。比較手法はなしでMambaだけがちゃんと動くことを確認した。
学習の損失推移は以下の通り。途中で"気づく"タイミングが発生してそこから損失が一気に減る。
学習系列長は256で固定して、学習終了後、テストとして32から2倍ずつした系列長で汎化を確認した。各長さ100サンプルで簡単に検証したところ、系列長16384までは正答率100%になった。
つまり系列長に対して行動を丸暗記しているのではなく、規定のトークンが来たときにその次のトークンを記憶しておくということがちゃんとできていると考えられる。
系列長をさらに2倍にするとGPUのメモリが足りなかったため、自己回帰モードでの推論ができることも確認した。系列長262144まで正答率100%であり、524288で82%、1048576で49%という結果になった。
速度については、Parallel Scanでの実行も自己回帰も、系列全体として考えるとであり、自己回帰の方がとても定数倍が大きく遅い。たとえば系列長1048576をテストするのには2時間以上かかっている。(以下は両対数グラフ)
系列がすでに定まっている状態で推論させるなら当然Parallel Scanで良いが、逐次的な生成が必要になる場面ではやはり各ステップで時間・空間計算量がであるというのは魅力的だ。