CarRacing-v3環境で強化学習(その6)DACER2もどきの実装

 \newcommand{\bs}{\boldsymbol}

Soft Actor Criticでの学習が安定してきたので、SACにおける方策を拡散モデルで表現することを試します。(拡散モデルといいつつ、実際のところはFlow Matching的な表現にしています)

拡散モデル/Flow Matchingでの方策は、行動のサンプリングは再パラメータトリックでできますが、ある行動  \bs{a} に対して  \log \pi(\bs{a}|\bs{s}) を計算するというのは(素朴には)できず、エントロピー正則化が上手く実現できません。

拡散モデル方策にエントロピー正則化的な手法を持ち込む方法はたくさんありますが、今回は Enhanced DACER Algorithm with High Diffusion Efficiency という論文の手法から部分的にアイデアを借用してみます。

DACER2では、方策を2種類の損失を使って学習します。

  • (1)メインの損失 : 行動価値最大化
  • (2)補助損失 : 行動価値の勾配マッチング

(1)の方は再パラメータトリックを使う方策でよくあるように

 \displaystyle
\begin{align}
\mathcal{L} _ q(\theta) = \mathbb{E} \lbrack {-Q _ \phi(\bs{s}, \bs{a})} \rbrack
\end{align}

を損失として  \bs{a} に勾配を流して学習するものです。

(2)の方は、拡散モデルがランジュバンモンテカルロ法からのサンプリングと考えたときに、確率分布のスコアを基に逆拡散を行っていくことを表現したものです。エントロピー最大化強化学習の設定では、行動価値関数の符号を逆にした  -Q をエネルギー関数としたボルツマン分布が方策に近しいものになると考えられるため、これを補助として採用します。

実践的には、DACER2では  \nabla _ \bs{a} Q をそのまま使うのではなく、方向のみを使うようにノルムで割って正規化したり、拡散タイムステップに応じて大きさを変えられるように重みを付けています。スコアを表現するネットワークを  S _ \theta として

 \displaystyle
\begin{align}
w(t) &= \exp(c \cdot t + d) \\
\nabla _ {\bs{a} _ t} Q _ \text{norm}(\bs{s}, \bs{a} _ t) &= \frac{\nabla _ {\bs{a} _ t} Q(\bs{s}, \bs{a} _ t)}{||\nabla _ {\bs{a} _ t} Q(\bs{s}, \bs{a} _ t)|| + \epsilon} \\
\mathcal{L} _ g(\theta) &= \mathbb{E} \left\lbrack ||w(t) \nabla _ {\bs{a} _ t} Q _ \text{norm}(\bs{s}, \bs{a} _ t) - S _ \theta (\bs{s}, \bs{a} _ t, t)|| \right\rbrack
\end{align}

(2)の詳細

補助損失の背景を整理します。

パラメータ  \theta を持つエネルギー関数  f _ \theta (\bs{x}):\mathbb{R} ^ d \to \mathbb{R} を使って、非正規化確率密度を  \gamma _ \theta(\bs{x}) = \exp (-f _ \theta(\bs{x})) と表した確率モデルをエネルギーベースモデルと呼びます。

 \displaystyle
\begin{align}
q _ \theta (\bs{x}) &= \exp(-f _ \theta(\bs{x})) / Z(\theta) \\
Z(\theta) &= \int \exp(-f _ \theta(\bs{y})) d \bs{y}
\end{align}

また、対数尤度  \log p(\bs{x}) の入力  \bs{x} についての勾配をスコア  \bs{s}(\bs{x}) と呼びます。

 \displaystyle
\begin{align}
\bs{s}(\bs{x}) := \nabla _ \bs{x} \log p(\bs{x}) : \mathbb{R} ^ d \to \mathbb{R} ^ d
\end{align}

エネルギーベースモデルである場合、

 \displaystyle
\begin{align}
\bs{s}(\bs{x}) &= \nabla _ \bs{x} \log p(\bs{x}) \\
&= \nabla _ \bs{x} \log \left(\exp(-f _ \theta(\bs{x})) / Z(\theta)\right) \\
&= -\nabla _ \bs{x} f _ \theta(\bs{x}) - \underbrace{\nabla _ \bs{x} Z(\theta)} _ {=0} \\
&= -\nabla _ \bs{x} f _ \theta(\bs{x})
\end{align}

であり、エネルギー関数についての負の勾配と一致します。

拡散モデルでは条件付き分布  p(\bs{x} _ t|\bs{x} _ 1)ガウス分布で表されるのでした。  \alpha _ t, \sigma _ t をノイズスケジュールを示すパラメータとして

 \displaystyle
\begin{align}
p(\bs{x} _ t|\bs{x} _ 1) = \mathcal{N}(\alpha _ t \bs{x} _ 1, \sigma _ t ^ 2 \bs{I})
\end{align}

したがって、このスコアは

 \displaystyle
\begin{align}
\nabla _ {\bs{x} _ t} p(\bs{x} _ t|\bs{x} _ 1) &= \nabla _ {\bs{x} _ t} \mathcal{N}(\alpha _ t \bs{x} _ 1, \sigma _ t ^ 2 \bs{I}) \\
&= -\frac{\bs{x} _ t - \alpha _ t \bs{x} _ 1}{\sigma _ t ^ 2} \\
&= -\frac{\bs{\epsilon}}{\sigma _ t}
\end{align}

となります。

Flow Matchingの場合

(この節の式変形が全て正しい自信はあまりないです。得に符号の入れ替えなどが細かくあってどこかで逆になっていそうな……)

今回は拡散モデルといいつつFlow Matching的な定式化で実装しているので、単にスコアを教師とするのではなく、そこから変形が必要になると思われます。

https://d2jud02ci9yv69.cloudfront.net/2025-04-28-diffusion-flow-173/blog/diffusion-flow/

のブログでまとめられている通り、ノイズ予測、データ予測、速度予測、フローマッチング予測については以下のような関係があります。

フローマッチング予測について整理すると

 \displaystyle
\begin{align}
\bs{u} &= \bs{\epsilon} - \bs{x} \\
&= \bs{\epsilon} - \frac{\bs{x} _ t - \sigma _ t \bs{\epsilon}}{\alpha _ t} \\
&= \left(1 + \frac{\sigma _ t}{\alpha _ t}\right) \bs{\epsilon} - \frac{1}{\alpha _ t} \bs{x} _ t
\end{align}

となります。これに、先のスコアとノイズの関係から  \bs{\epsilon} = -\sigma _ t \bs{s} _ t を代入して

 \displaystyle
\begin{align}
\bs{u} &= - \sigma _ t\left(1 + \frac{\sigma _ t}{\alpha _ t}\right) \bs{s} _ t - \frac{1}{\alpha _ t} \bs{x} _ t
\end{align}

となります。

上記のブログでは時刻0でデータ分布、時刻1でノイズ分布という設定でデータ分布からノイズ分布に進んでいく拡散過程での定式化を行っていますが、今回はFlow Matchingで、まず逆拡散過程と同じ方向になるので、先の式が符号逆になります。

 \displaystyle
\begin{align}
\bs{u} &= \sigma _ t\left(1 + \frac{\sigma _ t}{\alpha _ t}\right) \bs{s} _ t + \frac{1}{\alpha _ t} \bs{x} _ t
\end{align}

また、フローマッチングのスケジューリングから  \alpha _ t = t, \sigma _ t = 1 - t となるので、

 \displaystyle
\begin{align}
\bs{u} &= (1-t) \left(1 + \frac{1 - t}{t}\right) \bs{s} + \frac{1}{t} \bs{x} _ t \\
&= \frac{1-t}{t} \bs{s} _ t + \frac{1}{t} \bs{x} _ t
\end{align}

となります。

最後に今回は行動価値  Q の符号反転したものをエネルギー関数として考えているので(エネルギーは低いほど高い確率になるので、行動価値が高いほど方策の確率を上げたいとすると符号は逆にします)

 \displaystyle
\begin{align}
\bs{s} _ t = \nabla _ {\bs{a} _ t} Q _ \theta(\text{state}, \bs{a} _ t)
\end{align}

を代入して

 \displaystyle
\begin{align}
\bs{u} = \frac{1-t}{t} \nabla _ {\bs{a} _ t} Q _ \theta(\text{state}, \bs{a} _ t) + \frac{1}{t} \bs{a} _ t
\end{align}

これを基に、ノルムの正規化と時間での重みづけをしたものをターゲットとして、二乗誤差を補助損失として加えて学習を行います。

結果

 補助損失がない場合だと、学習の序盤で完全におかしな方策になってしまい、以降まともな報酬が得られないため復帰することもない、ということが多々ありました。補助損失を付与するとそういうことがなくなります。

 ただし、実装の正しさに自信はなく、特に符号が怪しいとは思っているので試しに勾配の符号を反転してみても、学習がおかしくならないどころかむしろちょっと性能良かったりもします。

 序盤で方策が潰れないだけの適当なノイズとして作用しているだけなのでは? という疑いも生じてくるところですが、とりあえず100kステップ以内にスコア900超えを達成できたため、性能としてはこれで満足という感じもします。

 1ステップあたりの更新回数も上げたため、PPOよりサンプル効率がハッキリ良くなっています。

 バッチサイズも16くらいまで下げてもなんとか学習できることはあり、今後ネットワークを巨大化していったときにも誤魔化せる可能性が出てきたので、拡散モデルベースの方策を用いたSACがある程度使えるようになってきたという感触です。