Mamba-2読む Section 2~5

 ※ このブログ記事筆者の理解・説明には誤りが含まれている可能性があります

 以下の論文を読む。今回はSection 3から5あたりの、SSMとAttentionの双対性についての記述を読解する。

特に断りのない限り、式番号や画像番号は上記論文に合わせる。

概要

 状態空間モデル(State Space Model : SSM)の中の一種と、線形Attentionの中の一種が、Semiseparable行列による系列変換という形で結びつくことを示す。

 状態空間モデルは、系列長  T に比例する計算量で系列変換を実行できるので、見た目上は  T ^ 2 に比例するAttention系よりも高速だが、愚直にやると並列性が低いので特に学習時の系列処理をGPUなどで計算する場合に遅い。この問題に対して、

  • 状態空間モデルのうち行列  A, B, C が入力によらないタイプのものであるLinear Time Invariance (LTI)では畳み込みにより効率的に計算できる
  • Mamba-1では、並列スキャンという方法によって部分的に並列化する
  • Mamba-2では、SSMによる変換を行列積で書けるように式変形を行う

ということになる。多少計算量的に不利になったとしても、行列積という並列性が高くアクセラレータの親和性の高い方法で書けると実践的には速くなることもある。ここで、行列積として書いたものが線形Attentionの特殊な形と同一視できることがわかるので、SSMの線形形式とAttentionの形式を双対として捉えている。

[Section 3] 状態空間モデルの変形

 離散化された状態空間モデルは以下のように定式化できる。状態遷移行列  A \in \mathbb{R} ^ \mathrm{(T, N, N)} , 入力行列  B \in \mathbb{R} ^ \mathrm{(T, N)} , 出力行列  C \in \mathbb{R} ^ \mathrm{(T, N)} を使い、隠れ状態  h \in \mathbb{R} ^ \mathrm{(T, N)} を通して  x \in \mathbb{R} ^ \mathrm{T} \mapsto y \in \mathbb{R} ^ \mathrm{T} を行う。

 \displaystyle
\begin{align}
h _ t &= A _ t h _ {t - 1} + B _ t x _ t \tag{2a}\\
y _ t &= C _ t ^ \top h _ t \tag{2b}
\end{align}

それぞれ添字は最初の時間軸での値を取る。

 これを時間方向に展開して行列積で書けるようにしたい。  h _ 0 = B _ 0 x _ 0 から展開すると

 \displaystyle
\begin{align}
h _ t &= A _ t \dots A _ 1 B _ 0 x _ 0 + A _ t \dots A _ 2 B _ 1 x _ 1 + \cdots A _ t A _ {t - 1} B _ {t - 2} x _ {t - 2} + A _ t B _ {t - 1} x _ {t - 1} + B _ t x _ t \\
&= \sum _ {s = 0} ^ {t} A _ {t:s} ^ \times B _ s x _ s
\end{align}

となる。ここで  A _ {t:s} ^ \times = A _ t A _ {t - 1} \dots A _ {s + 1} である。行列  C _ t をかけることも考慮すると、

 \displaystyle
\begin{align}
y _ t &= \sum _ {s = 0} ^ t C _ t ^ \top A _ {t:s} ^ \times B _ s x _ s \\
y &= \mathrm{SSM}(A, B, C)(x) = M x \tag{3}\\
M _ {ji} &:= C _ j ^ \top A _ j \cdots A _ {i + 1} B _ i
\end{align}

であるので、行列  M による行列積として表現できることがわかった。そして一般に、行列が  M _ {ji} := C _  j ^ \top A _ j \cdots A _ {i + 1} B _ i として書けるとき、これはSequentially semiseparable(SSS)表現として知られたものであり、画像のようなものである。

Figure 2

 いくつか重要な事実を挙げると、

定義3.1 下三角行列  M は下三角部分に含まれるすべての部分行列の階数が最大で  N である場合に N-Semiseparable行列という。  N をSemiseparable行列の階数と呼ぶ。

補題3.3 N-SSS行列は N-Semiseparableである。

命題3.4 全てのN-Semiseparable行列はN-SSS表現を持つ。

定義3.5 状態サイズが  N である状態空間モデル  y = \mathrm{SSM}(A, B, C)(x) はN-SSS行列による積  y = \mathrm{SSS}(A, B, C) \cdot x と書ける。

 そしてここからSection 5の内容にも部分的に踏み込むが、状態遷移行列  A  _ t に対して強い制約を入れることを考える。  A _ t = a _ t I と、単位行列スカラー倍したものまで単純化すると  M

 \displaystyle
\begin{align}
M _ {ji} = A _ {j:i} \cdot (C _ j ^ \top B _ i)
\end{align}

と書けるので、 A について

 \displaystyle
\begin{align}
L = \mathrm{1SS}(a _ {0:T}) := \begin{bmatrix}
1 & & & & &  \\
a_1 & 1 & & & & \\
a_2 a_1 & a_2 & 1 & & & \\
\vdots & \vdots & \vdots & \ddots & & \\
a_{T-1} \cdots a_1 & a_{T-1} \cdots a_2 & \cdots & a_{T-1} & 1
\end{bmatrix} \tag{6}
\end{align}

(説明の都合上、元の論文の式(6)から変数名を置き換えた)

を考えると

 \displaystyle
\begin{align}
M &= L \circ (C B ^ \top)
\end{align}

であり、つまり系列変換全体として見ると

 \displaystyle
\begin{align}
y = (L \circ (C B ^ \top)) \cdot x \tag{*}
\end{align}

である。

[Section 4] Attentionの変形

Source系列の長さを  \mathrm{S} 、 Target系列の長さを  \mathrm{T} 、特徴次元を  \mathrm{N} 、ヘッド次元  \mathrm{P} としたとき、基本的なAttentionの定式化は以下のようになる。

 \displaystyle
\begin{align}
Q &= \mathrm{input} \quad & \mathrm{(T, N)} \\
K &= \mathrm{input} \quad & \mathrm{(S, N)} \\
V &= \mathrm{input} \quad & \mathrm{(S, P)} \tag{9} \\
G &= Q K ^ \top \quad & \mathrm{(T, S)} \\
M &= f(G) \quad & \mathrm{(T, S)} \\
Y &= M V \quad & \mathrm{(T, P)}
\end{align}

一般的に  f はソフトマックス関数となる。その場合、  G の各要素に指数関数を適用し、  \mathrm{S} 軸について正規化することになる。正規化は全てが1であるベクトルを追加で考えるとわりと自明なのでいったん無視できる。また指数関数を適用するところは、カーネル変換として捉えられる。つまり  \exp (Q K ^ \top) = \psi(Q) \psi(K) ^ \top となる特徴マップ  \psi が存在する。先に  Q, K に特徴マップを適用したものを  Q, K と思い直すことで、関数  f の適用も無視できる。カーネルとしていろいろなものを考えられるという点もまた発展的な話としてある。

以上の議論を踏まえるとAttentionの定式化が

 \displaystyle
\begin{align}
y = (Q K ^ \top) \cdot V
\end{align}

となる。ここで、特にSource系列=Target系列のときなどは、Attentionに因果マスク  L を導入することが自然となる。これは

 \displaystyle
\begin{align}
y = (L \circ (Q K ^ \top)) \cdot V \tag{10}
\end{align}

ということになる。ここでマスク  L は一般的には要素が1の下三角行列になるが、そうでないものを考えることもできる。

Figure 3

[Section 5] Structured State-Space Duality (SSD)

式(10)において L として1-Semiseparable行列を考えとき、これを式(*)を見比べると同じになっているため、つまり「  A単位行列スカラー倍であるという制約を入れたSSM」と「マスクとして1-Semiseparable行列を考えたAttention」は同一視できる。

Figure 4