前回はParallel Scanの逆伝播と状態空間モデルの離散化について確認した。今回は残りの細かい部分として
- Mambaの高速化工夫
- 実はBはオイラー法
- Gated RNNとの関係
について確認する。
訂正
前回の記事で
Mambaの論文ではメモリ使用量を抑えるために、値を保持しておくのではなく逆伝播時に再計算をすると書かれている。mamba.pyでもそのようにされている
と書いていたが、論文を読み直していて気づいたこととしてMambaでのメモリ消費量(およびデータ転送を抑えることで計算自体の高速化)は、この部分ではなかったようだ。そもそもの離散化なども含めた計算パートをGPUのSRAMで上手く行うとのことである。
そもそもMambaの高速動作に寄与する工夫は大きく3つあると記述されているので、改めてそれを整理する。
(1)Parallel Scan
前回などでこれまで書いてきた通りなので省略
(2)カーネル融合
まず、SSMのScanの計算量は
- : バッチサイズ
- : 系列長
- : SSMに入ってくる次元
- : SSM内部での状態の次元
としたときに FLOPsとなる。ただし、実践的な実行ボトルネックはメモリIOであるため、そこを減らすことが重要になる。
Mambaの式をそのまま実装すると
- まず離散化されたをHBM (High-Bandwidth Memory:一般的にGPUメモリと言われるもの) で計算する。のサイズはである。
- Parallel ScanもHBMで計算される。この操作後のデータのサイズもである
- をかけての出力を得る
という手順になる。しかし、これだと手順中にHBMでのメモリIOが発生する。
なのでこれを改良し
- バイトのデータをメモリHBMからSRAM (Static Random-Access Memory) に転送する
- SRAM上でを計算する
- SRAM上でParallel Scanを実行する
- をかけての出力を得た後、HBMに結果を戻す
とする。系列長が長すぎてSRAM上に収まらない場合、シーケンスをチャンクに分割して各チャンクで実行する。中間状態があれば処理を上手く継続できる。
(3)再計算
forwardパスをカーネル融合で実装していても、逆伝播の際に順伝播の値を取っておいて利用するならば結局の値を転送することになる。これはやはりメモリIOが多くなって遅いので、逆伝播計算も再計算することが工夫となる。結局SSMの入出力部分・パラメータのshapeを考えるとであるため、その量だけHBMとSRAM間でやり取りする方が高速になる。スキャン操作部分だけでなく、SSMブロック全体で再計算した方が良いところは再計算する。
はBはオイラー法
前回、状態空間モデルのZOHによる離散化を考えたのは良いが、の実装においては違う式が利用されている。
これはオイラー法ということらしい。多少調べた理解だと、サンプリング時間が十分に短いならばオイラー法で良い近似になるのでZOHを真面目に考える必要がないということだろうか。次の項で見るように、Mambaではある意味でサンプリング時間を動的に調節できるようになるということでもあるので、それが十分小になっていれば良い? しかし言語モデルのトークン系列としては1トークンが持つ意味の量というものが決まってしまうので、そこがどうなんだろうか。
Gated RNNとの関係
Mambaで重要だとして導入された選択機構は、古典的なRNNのGating機構と関係がある(3.5節)。特に、SSMのはRNN Gating機構の一般化になっている。
定理1 とする。このとき、SSMの定式化は
と同じになる(これがMambaの中においてをこう定めた理由にもなっている)。
証明 前提より、状態空間モデルの式は
となる。離散ステップは
と計算される。ここで関数は
である。なのでZOHによる離散化まで考慮すると、をシグモイド関数として
であり、
となる。これはつまり
ということである。
つまり、Mambaの動作解釈として、Gateによって入力をどの程度受け入れるか、ということを制御できるということになる。Selective CopyタスクやInduction Headsタスクなどをやれるのがそういう機能だと見なせる。ただし、結局言語モデルとしての高級な機能にどこまで寄与するものなのだかは、正直よくわからない。
所感
しばらくMambaの論文等を追ってみて、ある程度はわかってきた気がする。これが強いアーキテクチャとして有力になっていくのかどうかはわからないが、結局RNN的に推論時とても軽量である性質はエッジデバイスで動かすことなどを考えると必要になってくるのではないか。
個人的には、さらに階層的な時間ステップを導入していくことに興味がある。言語トークンのステップに縛られない、あるいはセンサデータの入力周波数に縛られないような動作が重要になると信じたいものだ。