Mamba探訪(3)

 前回はParallel Scanの逆伝播と状態空間モデルの離散化について確認した。今回は残りの細かい部分として

  • Mambaの高速化工夫
  • 実はBはオイラー
  • Gated RNNとの関係

について確認する。

訂正

 前回の記事で

Mambaの論文ではメモリ使用量を抑えるために、値を保持しておくのではなく逆伝播時に再計算をすると書かれている。mamba.pyでもそのようにされている

と書いていたが、論文を読み直していて気づいたこととしてMambaでのメモリ消費量(およびデータ転送を抑えることで計算自体の高速化)は、この部分ではなかったようだ。そもそもの離散化なども含めた計算パートをGPUSRAMで上手く行うとのことである。

 そもそもMambaの高速動作に寄与する工夫は大きく3つあると記述されているので、改めてそれを整理する。

(1)Parallel Scan

 前回などでこれまで書いてきた通りなので省略

(2)カーネル融合

 まず、SSMのScanの計算量は

  •  B : バッチサイズ
  •  L : 系列長
  •  D : SSMに入ってくる次元
  •  N : SSM内部での状態の次元

としたときに O(BLDN) FLOPsとなる。ただし、実践的な実行ボトルネックはメモリIOであるため、そこを減らすことが重要になる。

 Mambaの式をそのまま実装すると

  1. まず離散化された \bar{A}, \bar{B}をHBM (High-Bandwidth Memory:一般的にGPUメモリと言われるもの) で計算する。 \bar{A}, \bar{B}のサイズは (B, L, D, N)である。
  2. Parallel ScanもHBMで計算される。この操作後のデータのサイズも (B, L, D, N)である
  3.  Cをかけて (B, L, D)の出力を得る

という手順になる。しかし、これだと手順中にHBMで O(BLDN)のメモリIOが発生する。

 なのでこれを改良し

  1.  O(BLD+DN)バイトのデータ (\Delta, A, B, C)をメモリHBMからSRAM (Static Random-Access Memory) に転送する
  2. SRAM上で \bar{A}, \bar{B}を計算する
  3. SRAM上でParallel Scanを実行する
  4.  Cをかけて (B, L, D)の出力を得た後、HBMに結果を戻す

とする。系列長が長すぎてSRAM上に収まらない場合、シーケンスをチャンクに分割して各チャンクで実行する。中間状態があれば処理を上手く継続できる。

(3)再計算

 forwardパスをカーネル融合で実装していても、逆伝播の際に順伝播の値を取っておいて利用するならば結局 (B, L, D, N)の値を転送することになる。これはやはりメモリIOが多くなって遅いので、逆伝播計算も再計算することが工夫となる。結局SSMの入出力部分・パラメータ \Delta, A, B, Cのshapeを考えると O(BLN + DN)であるため、その量だけHBMとSRAM間でやり取りする方が高速になる。スキャン操作部分だけでなく、SSMブロック全体で再計算した方が良いところは再計算する。

はBはオイラー

 前回、状態空間モデルのZOHによる離散化を考えたのは良いが、 Bの実装においては違う式が利用されている。

 これはオイラー法ということらしい。多少調べた理解だと、サンプリング時間が十分に短いならばオイラー法で良い近似になるのでZOHを真面目に考える必要がないということだろうか。次の項で見るように、Mambaではある意味でサンプリング時間を動的に調節できるようになるということでもあるので、それが十分小になっていれば良い? しかし言語モデルトークン系列としては1トークンが持つ意味の量というものが決まってしまうので、そこがどうなんだろうか。

Gated RNNとの関係

 Mambaで重要だとして導入された選択機構は、古典的なRNNのGating機構と関係がある(3.5節)。特に、SSMの \DeltaはRNN Gating機構の一般化になっている。

 定理1  N = 1, A = -1, B = 1, s _ \Delta = \mathrm{Linear}(x), \tau _ \Delta = \mathrm{softplus}とする。このとき、SSMの定式化は

 
\begin{cases}
g _ t = \sigma ( \mathrm{Linear}(x _ t)) \\
h _ t = (1 - g _ t) h _ {t - 1} + g _ t x _ t
\end{cases}

と同じになる(これがMambaの中において s _ \Delta, \tau _ \Deltaをこう定めた理由にもなっている)。

 証明 前提より、状態空間モデルの式は

 
\begin{align}
\dot{h}(t) = -h(t) + x(t)
\end{align}

となる。離散ステップは

 
\begin{align}
\Delta _ t = \mathrm{softplus}(\mathrm{Linear}(x _ t))
\end{align}

と計算される。ここで \mathrm{softplus}関数は

 
\begin{align}
\mathrm{softplus}(x) = \log (1 + e ^ x)
\end{align}

である。なのでZOHによる離散化まで考慮すると、 \sigmaシグモイド関数として

 
\begin{align}
\bar{A} _ t  &= \exp (\Delta A) \\
&= \exp (\log(1 + \exp (\mathrm{Linear}(x _ t)) (-1)) \\
&= \frac{1}{1 + \exp (\mathrm{Linear}(x _ t))} \\
&= \sigma (-\mathrm{Linear}(x _ t)) \\
&= 1 - \sigma (\mathrm{Linear}(x _ t))
\end{align}

であり、

 
\begin{align}
\bar{B} _ t &= (\Delta A) ^ {-1} (\exp (\Delta A) - I) \Delta B \\
&= (-\Delta) ^ {-1} (\exp (\Delta A) - 1) \Delta 1 \\
&= 1 - \exp (\Delta A) \\
&= 1 - \bar{A} \\
&= \sigma (\mathrm{Linear}(x _ t))
\end{align}

となる。これはつまり

 
\begin{cases}
g _ t = \sigma ( \mathrm{Linear}(x _ t)) \\
h _ t = (1 - g _ t) h _ {t - 1} + g _ t x _ t
\end{cases}

ということである。

 つまり、Mambaの動作解釈として、Gateによって入力をどの程度受け入れるか、ということを制御できるということになる。Selective CopyタスクやInduction Headsタスクなどをやれるのがそういう機能だと見なせる。ただし、結局言語モデルとしての高級な機能にどこまで寄与するものなのだかは、正直よくわからない。

所感

 しばらくMambaの論文等を追ってみて、ある程度はわかってきた気がする。これが強いアーキテクチャとして有力になっていくのかどうかはわからないが、結局RNN的に推論時とても軽量である性質はエッジデバイスで動かすことなどを考えると必要になってくるのではないか。

 個人的には、さらに階層的な時間ステップを導入していくことに興味がある。言語トークンのステップに縛られない、あるいはセンサデータの入力周波数に縛られないような動作が重要になると信じたいものだ。