Direct Preference Optimizationという手法があるらしいです。
LLM訓練の最終段ではRLHF(Reinforcement Learning from Human Feedback)として、人手で良し悪しを評価してあるデータセットを使って(1)報酬モデルの学習 (2)それを用いた強化学習 ということが行われたりします。冒頭の論文では、この2ステップは実は適切な損失関数を使った最尤法での学習だけで行えると主張しているようです。報酬関数の推定をすっ飛ばせるという話が面白そうだったので詳しく読んでみます。
前提:そもそもRLHFを確認
まずRLHFとしてはFine-tuning language models from human preferencesに則るそうです。これだと次の3段階だとしているようです。
- SFT(Supervised Fine-Tuning)
- Preference Sampling and Reward Learning
- Reinforcement-Learning Optimization
(1) Supervised Fine-Tuning
事前学習されたLanguage Modelから、更に関心のあるタスクでFine-Tuningを行うステップとなります。単純に最尤法で学習して、を得ます。
(2) Reward Modelling
SFTのモデルを使って、様々な入力に対して2種類の回答を得ます。この2つについて、良い方を、悪い方をとします。またこの良い悪いの関係をと表現します。この嗜好は潜在的な報酬関数によって生成されると仮定されますが、この報酬関数は未知のものです。
こういった嗜好のモデル化には様々な手法があるらしいですが、ここではその中からBradley-Terryモデル[5]を使うこととします。Bradley-Terryモデルでは、人間の嗜好分布が次のように書けるとしています。
要するに報酬にexpをかけたものの割合で考えればいいらしいです。
ここで、我々の手元にはからサンプリングされた固定データセットがあるとします。報酬関数を適当なパラメータを持つ関数とモデル化すると、最尤法でパラメータを推定することができます。問題を二値分類として仮定して、負の対数尤度
を最小化すればいいです。(式からへの変形は、分母分子をで割ればすぐです)。は(標準)ロジスティック関数です。
(3) RL Fine-Tuning Phase
学習した報酬関数を使って言語モデル自体を訓練します。もとの方策から外れすぎないように正則化を係数で入れて
を考えてPPOとかで最適化します。
本題
報酬関数から得られる最適方策を解析的に真面目に考えるといろいろ導けるらしいです。
まず、先行研究[25, 26]に従うと、先の式に対する最適解は次のようになることが簡単に示せるとのことです。
ここでは分配関数
です。
付録A.1に証明が書いてあるので先にそれを見ましょう。
まず、式を変形していきます。
(最後で急にのlogが外れて符号が反転するのはどうしてでしょう。しばらく考えたけどわかりませんでした)
結局はによらない関数であるため、argminを考えるときには関係ありません。そしてKLダイバージェンスが最小値になるのはまさにそれらが一致しているときなので、これで証明されました。
さて、最適な方策
がわかったので、これをについて整理してみます。
となります。真の報酬を考えたときにもこの式が成立します。そして、式で表現したように、Bradley-Terryモデルの下では最適化したい負の対数尤度が報酬の差だけに依存するので、分配関数の部分が打ち消されます。(ここからの式変形がわざわざ付録A.2で書かれていますが、あまりにも自明なのでなんでこれがあるのかがよくわかりません)。単純に式に入れ込んで
を損失関数として学習させればいいそうです。Pythonでのコード的に書くと
import torch.nn.functional as F def dpo_loss(pi_logps, ref_logps, yw_idxs, yl_idxs, beta): """ pi_logps: policy logprobs, shape (B,) ref_logps: reference model logprobs, shape (B,) yw_idxs: preferred completion indices in [0, B-1], shape (T,) yl_idxs: dispreferred completion indices in [0, B-1], shape (T,) beta: temperature controlling strength of KL penalty Each pair of (yw_idxs[i], yl_idxs[i]) represents the indices of a single preference pair. """ pi_yw_logps, pi_yl_logps = pi_logps[yw_idxs], pi_logps[yl_idxs] ref_yw_logps, ref_yl_logps = ref_logps[yw_idxs], ref_logps[yl_idxs] pi_logratios = pi_yw_logps - pi_yl_logps ref_logratios = ref_yw_logps - ref_yl_logps losses = -F.logsigmoid(beta * (pi_logratios - ref_logratios)) rewards = beta * (pi_logps - ref_logps).detach() return losses, rewards
とのことです。