Direct Preference Optimizationを読む(その1)

 Direct Preference Optimizationという手法があるらしいです。

 LLM訓練の最終段ではRLHF(Reinforcement Learning from Human Feedback)として、人手で良し悪しを評価してあるデータセットを使って(1)報酬モデルの学習 (2)それを用いた強化学習 ということが行われたりします。冒頭の論文では、この2ステップは実は適切な損失関数を使った最尤法での学習だけで行えると主張しているようです。報酬関数の推定をすっ飛ばせるという話が面白そうだったので詳しく読んでみます。

前提:そもそもRLHFを確認

 まずRLHFとしてはFine-tuning language models from human preferencesに則るそうです。これだと次の3段階だとしているようです。

  1. SFT(Supervised Fine-Tuning)
  2. Preference Sampling and Reward Learning
  3. Reinforcement-Learning Optimization

(1) Supervised Fine-Tuning

 事前学習されたLanguage Modelから、更に関心のあるタスクでFine-Tuningを行うステップとなります。単純に最尤法で学習して、 \pi ^ {\mathrm{SFT}}を得ます。

(2) Reward Modelling

 SFTのモデルを使って、様々な入力xに対して2種類の回答 y _ 1, y _ 2を得ます。この2つについて、良い方を y _ w、悪い方を y _ lとします。またこの良い悪いの関係を y _ w \succ y _ lと表現します。この嗜好は潜在的な報酬関数 r ^ * ( y, x )によって生成されると仮定されますが、この報酬関数は未知のものです。

 こういった嗜好のモデル化には様々な手法があるらしいですが、ここではその中からBradley-Terryモデル[5]を使うこととします。Bradley-Terryモデルでは、人間の嗜好分布 p ^ *が次のように書けるとしています。

 
p ^ * (y _ 1 \succ y _ 2 | x) = \frac{\exp(r ^ * ( x, y _ 1 ))} {\exp(r ^ * ( x, y _ 1 )) + \exp(r ^ * ( x, y _ 2 ))}
\qquad(1)

要するに報酬にexpをかけたものの割合で考えればいいらしいです。

 ここで、我々の手元には p ^ *からサンプリングされた固定データセット \mathcal{D} = \lbrace x ^ {(i)}, y ^ {(i)} _ w, y ^ {(i)} _ l \rbrace ^ N _ {i = 1}があるとします。報酬関数 r _ \phi (x, y)を適当なパラメータを持つ関数とモデル化すると、最尤法でパラメータを推定することができます。問題を二値分類として仮定して、負の対数尤度

 
\mathcal{L} _ R (r _ \phi, \mathcal{D}) = - \mathbb{E} _ {(x, y_w, y_l) \sim \mathcal{D}} \lbrack \log \sigma (r _ \phi(x, y _ w) - r _ \phi (x, y _ l)) \rbrack
\qquad(2)

を最小化すればいいです。(式 (1)から(2)への変形は、分母分子を \exp(r ^ * ( x, y _ 1 ))で割ればすぐです)。 \sigmaは(標準)ロジスティック関数です。

(3) RL Fine-Tuning Phase

 学習した報酬関数を使って言語モデル自体を訓練します。もとの方策 \pi _ \mathrm{ref} = \pi ^ \mathrm{SFT}から外れすぎないように正則化を係数\betaで入れて

 
\max _ {\pi _ \theta} \mathbb{E} _ {x \sim \mathcal{D}, y \sim \pi(y | x)} \lbrack r _ \phi (x, y) \rbrack
 - \beta \mathbb{D} _ {\mathrm{KL}} \lbrack \pi _ \theta ( y | x ) || \pi _ \mathrm{ref} (y | x) \rbrack
\qquad(3)

とします。これは微分不可能なので、強化学習として

 
r(x, y) = r _ \phi(x, y) - \beta (\log \pi _ \theta (y | x) - \log \pi _ \mathrm{ref} (y | x))

を考えてPPOとかで最適化します。

本題

 報酬関数から得られる最適方策を解析的に真面目に考えるといろいろ導けるらしいです。

 まず、先行研究[25, 26]に従うと、先の式(3)に対する最適解は次のようになることが簡単に示せるとのことです。

 
\pi _ r ( y|x) = \frac{1}{Z(x)} \pi _ \mathrm{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y) \right)
\qquad(4)

 ここで Z(x)は分配関数

 
Z(x) = \sum _ y \pi _ \mathrm{ref} \exp \left( \frac{1}{\beta} r(x, y) \right)

です。

 付録A.1に証明が書いてあるので先にそれを見ましょう。

 まず、式 (3)を変形していきます。

 
\max _ {\pi _ \theta} \mathbb{E} _ {x \sim \mathcal{D}, y \sim \pi(y | x)} \lbrack r _ \phi (x, y) \rbrack - \beta \mathbb{D} _ {\mathrm{KL}} \lbrack \pi _ \theta ( y | x ) || \pi _ \mathrm{ref} (y | x) \rbrack \\
= \max _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack r(x, y) - \beta \log \frac{\pi(y|x)}{\pi _ \mathrm{ref} (y|x)} \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\pi _ \mathrm{ref} (y|x)} - \frac{1}{\beta} r(x, y) \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\pi _ \mathrm{ref} (y|x)} - \log \exp \left( \frac{1}{\beta} r(x, y) \right) \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\pi _ \mathrm{ref} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right) } \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\frac{1}{Z(x)} \pi _ \mathrm{ref} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right) } \frac{1}{Z(x)} \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\frac{1}{Z(x)} \pi _ \mathrm{ref} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right) } - \log Z(x) \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \mathbb{E} _ {y \sim \pi(y|x)} \left\lbrack \log \frac{\pi(y|x)}{\pi _ r (y|x)} - \log Z(x) \right\rbrack \\
= \min _ {\pi} \mathbb{E} _ {x \sim \mathcal{D}} \left\lbrack \mathbb{D} _ \mathrm{KL}(\pi (y|x) || \pi _ r (y|x)) + Z(x) \right\rbrack
\qquad(14)

 (最後で急に Z(x)のlogが外れて符号が反転するのはどうしてでしょう。しばらく考えたけどわかりませんでした)

 結局 Z(x) \piによらない関数であるため、argminを考えるときには関係ありません。そしてKLダイバージェンスが最小値になるのはまさにそれらが一致しているときなので、これで証明されました。

 さて、最適な方策

 
\pi _ r ( y|x) = \frac{1}{Z(x)} \pi _ \mathrm{ref}(y|x) \exp\left(\frac{1}{\beta} r(x, y) \right)
\qquad(4)

がわかったので、これを r(x, y)について整理してみます。

 
r(x, y) = \beta \log \frac{\pi _ r (y|x)}{\pi _ \mathrm{ref}(y|x)} + \beta \log Z(x)
\qquad(5)

となります。真の報酬 r ^ *を考えたときにもこの式が成立します。そして、式 (2)で表現したように、Bradley-Terryモデルの下では最適化したい負の対数尤度が報酬の差だけに依存するので、分配関数の部分が打ち消されます。(ここからの式変形がわざわざ付録A.2で書かれていますが、あまりにも自明なのでなんでこれがあるのかがよくわかりません)。単純に式 (2)に入れ込んで

 
\mathcal{L} _ \mathrm{DPO} (\pi _ \theta; \pi _ \mathrm{ref})= - \mathbb{E} _ {(x, y_w, y_l) \sim \mathcal{D}} \left\lbrack \log \sigma \left( \beta \log \frac{\pi _ \theta (y _ w|x)}{\pi _ \mathrm{ref}(y _ w|x)} - \beta \log \frac{\pi _ \theta (y _ l|x)}{\pi _ \mathrm{ref}(y _ l|x)} \right) \right\rbrack
\qquad(7)

を損失関数として学習させればいいそうです。Pythonでのコード的に書くと

import torch.nn.functional as F
def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta):
    """
    pi_logps: policy logprobs, shape (B,)
    ref_logps: reference model logprobs, shape (B,)
    yw_idxs: preferred completion indices in [0, B-1], shape (T,)
    yl_idxs: dispreferred completion indices in [0, B-1], shape (T,)
    beta: temperature controlling strength of KL penalty
    Each pair of (yw_idxs[i], yl_idxs[i]) represents the
    indices of a single preference pair.
    """
    pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs]
    ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs]

    pi_logratios = pi_yw_logps - pi_yl_logps
    ref_logratios = ref_yw_logps - ref_yl_logps

    losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios))
    rewards = beta * (pi_logps - ref_logps).detach()
    return losses, rewards

とのことです。