Bigger, Better, Fasterのコードを動かす

 コードが公開されているので動かしてみる。

 venvで行ったので、おおよその手順は

git clone https://github.com/google-research/google-research
cd bigger_better_faster

python3 -m venv .env
source .env/bin/activate

pip3 install -r requirements.txt

export PYTHONPATH="$(readlink -f ../):$PYTHONPATH"

python3 -m bbf.train \
    --agent=BBF \
    --gin_files=bbf/configs/BBF.gin \
    --base_dir=./bbf_result \
    --run_number=1

という感じ。しかし途中途中でいろいろ直さないと動かなかった。

(1)dopamine_rlで落ちる

 BBFは内部的にdopamine_rlというライブラリのNoisyNetworkを読み込んでいるところがあるが、その読み込み時に

AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'

で落ちる。これは最新のGitHubコードだと直っているのでこれと同じ修正をローカルのものに入れることで回避できる。

[JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array. · google/dopamine@a9a8fc0 · GitHub

 試行錯誤中だと上手くいかなかったが、単純にpipでdopamineライブラリのバージョンを上げても回避できるかもしれない。

(2)spr_agent.pyで落ちる

 dopamine_rlを直したら次はspr_agent.pyのところで落ちた。

ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value {Pythonの辞書形式}

というエラーが出る。要するにdictとFrozenDictの違いで型があってないということらしいので変換を入れると良い。

 その他、get_default_device_assignmentというやつも削除されたせいかエラーが出る。

Remove use of get_default_device_assignment(). · google/jax@3ce5cb6 · GitHub

 なんだかよくわからないが、これらを勘で次のように直した。(コピペ後に微修正したのでパッチファイルとしては機能しないと思われる。該当行を手動で修正する)

diff --git a/bigger_better_faster/bbf/agents/spr_agent.py b/bigger_better_faster/bbf/agents/spr_agent.py
index 25306dcb6..56c696091 100755
--- a/bigger_better_faster/bbf/agents/spr_agent.py
+++ b/bigger_better_faster/bbf/agents/spr_agent.py
@@ -45,9 +45,7 @@ def _pmap_device_order():
   if jax.process_count() == 1:
     return [
         d
-        for d in xb.get_backend().get_default_device_assignment(
-            jax.device_count()
-        )
+        for d in xb.get_backend().local_devices()
         if d.process_index == jax.process_index()
     ]
   else:
@@ -177,6 +175,13 @@ def interpolate_weights(
   if keys is None:
     keys = old_params.keys()
   for k in keys:
+    old_params = FrozenDict(old_params)
+    new_params = FrozenDict(new_params)
     combined_params[k] = jax.tree_util.tree_map(combination, old_params[k],
                                                 new_params[k])
   for k, v in old_params.items():
@@ -1309,6 +1314,10 @@ class BBFAgent(dqn_agent.JaxDQNAgent):
         optax.masked(optimizer, self.head_mask),
     )
 
+    self.online_params = FrozenDict(self.online_params)
     self.optimizer_state = self.optimizer.init(self.online_params)
     self.target_network_params = copy.deepcopy(self.online_params)
     self.random_params = copy.deepcopy(self.online_params)

(3)JaxがGPUを使っていない

 実験が動き出したようなログが流れ続けるが、nvidia-smiで見てもプロセスが乗っておらず、GPU利用率も低かった。

python3 -c "import jax; print(jax.devices())"

で確認したところ

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

 と出た。公式のインストール手順https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-via-pip-easier:titileを見て、自分の環境はCUDA-12が入っているので

pip3 install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

とした。これでpython3 -c "import jax; print(jax.devices())"の表示は

[cuda(id=0), cuda(id=1)]

となった。(2枚挿さっているPCなので2枚分ちゃんと出た)

動作結果

 設定はデフォルトのままで実行してみた。configをよく見ると、BBFAgent.replay_ratio = 64となっている(https://github.com/google-research/google-research/blob/4da7251308decf0a61807c09a8f4c087cbd06310/bigger_better_faster/bbf/configs/BBF.gin#L33)。

 ChopperCommandという、他の手法だとあまりスコアが伸びないゲームでHuman Normalized Scoreが1を超える結果を得られた。

 実験時間も3090GPUを1枚だけ使って2.6hだったため、かなり短い。

Transformer系世界モデル手法IRISとTWMの比較

 以下の2つの論文を比較する。

 共通点としてどちらも

となっている。

 提案手法に沿って、呼称は以下のようにする

手法

IRIS

IRIS Figure 1

TWM

TWM Figure1

 まず手法の概要として共通点は

  • 観測を離散AutoEncoderで複数の離散トークンにする
  • 観測の離散トークンと行動などをTransformerに入力する

大雑把には似ている。以下違いを見ていく。

Transformerへの入力

  • IRISの入力は「観測の離散トークン + 行動」
  • TWMの入力は「観測の離散トークン + 行動 + 報酬」

Transformerの使い方

  • IRISは自己回帰的に次状態の離散トークン・報酬・終了判定を予測する
  • TWMは行動に相当する位置のトークンの変換結果を hとして、そこに情報が集約されていると考え、 hから各種次の予測を行う

離散化Encoder

要素 IRIS TWM
入力サイズ 64x64x3(スタックしているかは不明。たぶんしているのでは) 64x64x4 (直近4フレームをスタック・グレースケール化)
離散化トークンのクラス数 512 32
1観測を何個のトークンにするか 16 32
損失の工夫 Perceptual Lossあり consistency lossというものを提案・導入

 離散化クラス数などはやや異なるが、概ね一致。

Transformer

要素 IRIS TWM
ベース実装 minGPT Transformer-XL
次元 256 256
層数 10 10
head数 4 4
入力する最大タイムステップ数 20 16

 Transformerパラメータはかなり近い。

方策の入力

  • IRIS:復元した観測
  • TWM:観測の表現(離散トークン)

 TWMの方で良し悪しについては記述がある。

  • 観測そのまま or 復元した観測:画像での学習なので方策が安定しやすい
  • 観測の表現:遷移予測の学習によっても変更が加えられていくので方策モデルがその変化にも追従しなければいけない。エントロピー正則化と一貫性損失を入れると安定する
  • 観測の表現 + 遷移の履歴(Dreamer形式):推論時に遷移モデルも実行するコストが発生する

 IRISの方策はさも各ステップ独立で推論できそうだが、Appendixを見るとLSTMを使っているとあり、過去系列も見ていそう。

実装

実験にかかる時間

  • IRIS
    • 実験全体は8基のA100で実行
    • 1枚のGPUで2つのAtari環境を実行して、おおよそ7日かかる。つまり1環境あたりだと3.5日
    • おそらく上記が5シードまとめての時間なので、1環境1シードだと0.7日 = 16.8時間
  • TWM(おそらく1環境・1シードあたりについての記述)
    • A100を1台だと約10時間
    • 3090を1台だと約12時間
    • Transformer-XLのメモリ機構を使わないVanilla Transformerを使うと約1.5倍かかる

 IRISの方が若干遅いのは、(1) TransformerがVanillaのもの (2)Actor-CriticにLSTMセルが使用されている というところがありそう。とはいえそこまで大差ではない。

評価

 どちらもDeep Reinforcement Learning at the Edge of the Statistical Precipiceに従っている。(おそらくAtari 100kベンチマークについての評価用OSSツールが利用されているのだと思われる。良さそう)

 比較手法もだいたい同じ。

IRIS Figure 5

TWM Figure 3

 どちらもMedianは幅広い。MeanとInterquantile Meanでは良い。

 各ゲームでのスコア表をマージすると以下のようになった。色付けは各行ごとに緑が最小、白が最大。

 (スプレッドシートhttps://docs.google.com/spreadsheets/d/1B1gMkAqUhcmGdsDui7_1lHjne3E3sOvXSg_4fIovNw4/edit?usp=sharing

 得意/苦手なゲームの傾向もやや似ていそう? ちょっとよくわからないか。EfficientZeroが強い。

その他・実験考察など

  • IRIS
    • Frostbite と Krullというステージ遷移がある場合に苦労する
    • フレームあたりの観測離散化トークン数を64に増やすと改善するゲームもある(もちろんその分世界モデルが重くなる)
    • 学習量を10Mに伸ばすと性能も伸びる
  • TWM
    • Actorのエントロピー正則化についても閾値を導入して、下回った場合のみ有効になるようにする
    • 過去のデータを何度もサンプリングして学習に利用しすぎないようにBalanced Dataset Samplingを行う
    • Attentionのヒートマップを見ると、報酬が1になるようなタイミングに重みがあったりするので、そういうタイミングへの学習が行われていそう
    • 世界モデルの学習系列を16から4に減らすと明確に性能が落ちる

所感

  • 細かい差異はともかく、Transformerに入れて世界モデルとして使うとこれくらいできますというのが確度の高い情報として提示された印象
  • それぞれの手法について、予測部分で行動まで出してそれをPolicyにするというのは不可能なんだろうか

強化学習における自分の興味範囲

 強化学習と一口に言ってもその範囲はとても広いので、自分はどこに興味があるのかを絞って考えたい。その点について改めて整理する。手法や工夫ではなく問題設定の方に着目する。

 まず、究極的な目標としては「実世界で動作できる知性を実現したい」となる。しかし、現実的に役に立つものを作りたいというよりも、知性を機械的に再現する手法を知りたいという気持ちの方が強い。

 現実的に有用なものを目指す場合、実世界と相互作用するボディの作り込みや、MLOpsのようなデータ収集と学習の仕組み作りを整備することが近道になりそうで、それは個人でやるには難しく、有益なことはどうせ企業がやるだろうとも思う。個人の趣味としてやるなら、実用性よりも納得を優先したい。

 生物は「他のマシンで学習したパラメータを脳内にコピーしてくる」ということはできず、個体として振る舞う独立したこの脳だけで学習する必要があると思われる。それを完全再現することはとても難しいだろうが、「この技術が洗練されていけば実現できそうだな」と納得できるようなものがわかると良い。

 環境準備の手軽さを鑑みて、題材自体はAtari等々のようなゲームとするのが現実的だろう。ゲームではやりつつも、先を見据える意味で以下のような4つの要素を念頭に置きたくなる。

オンライン強化学習

 データを予め収集しておいた中で学習するオフライン設定ではなく、探索のことも考える必要があるオンラインの強化学習に興味がある。探索以外にも、環境と相互作用をしていて今起こった失敗を(理屈上の最短で)いつ学習で取り込めるかという問題意識や、学習中も同じFPSで動き続けられるかといった話題もありそうだが、とりあえずはオンラインで学習できていれば良いとする。

非分散学習

 強化学習手法の中には、環境を並列でたくさん生成して相互作用の効率を上げるものがあるが、最終的な目標が現実世界との相互作用ということになると現実世界は一つしかないのでそのような手法には頼れなくなる。もちろん、シミュレータで並列的に事前学習してから現実世界に持ち込むというのは有力な手段なのだろうが、興味の範囲からは外れる。環境のモデル自体を学習して、それを内部的に並列に動かして学習するのは良い。最終的に相互作用する真の環境が一つだけであることが重要になる。

非エピソディック

 現実世界での振る舞いを考えると、明確なエピソード区切りはないものと考えた方が自然に思える。擬似的な再現としては、ゲーム系の環境として1回終わったらすぐ次が始まると想定するだけで良い。ただ、連続的にクリアできるような環境だと、報酬の割引率を適切に設定しないと報酬和自体がかなり大きくなっていってしまい学習に悪影響が出そうだ。このあたりの報酬の大きさの不安定性にも興味がある。

非定常環境

 結局、オンライン学習が求められるのは環境が非定常だからという側面が大きそうなので、評価する上ではそのような設定を準備するのが良さそう。この場合、環境が切り替わっても得られる報酬が落ちにくい、あるいはすぐ高い報酬に行けることが重要になるので、単純に最高スコアが高いことだけでなくサンプル効率が良いものが重要になる。


 上記を念頭にやっていけたらなという所存。

週記 20231225~20231231

 今週は実装を進めようとしていたが、思ったようには進まなかった。

 目標としている変更は、DQNをベースとして

  • ネットワークを過去系列を入力に含むTransformerに変える
  • 上に伴って、ReplayBufferも系列として情報をサンプリングできるものに変える
  • そうするとReplayBufferで管理するもの、返り値が変わる
  • 学習が終わってテストするときも、毎回の状態をReplayBuffer(的ななにか)に溜めておく必要がある

となることが実装を進めてみてわかってきた。最初はStableBaselines3のDQNをベースにして実装していこうとしていたが、これの実装がかなり継承を駆使した作りになっており、ReplayBufferやテスト時の動作を変更しようとしたときに結構つらいことになってしまった。ちゃんと理解していれば修正するべき点は意外と少ないのかもしれないが、そもそもどこを理解するべきかが難しい。

 なのでcleanrlのdqnから出発するように方針転換した。こちらの方が見通しがよく、またそもそものベースDQNの収束も早かった(ハイパーパラメータのちょっとした違い?)ので一石二鳥だった。

 現状はまだこのdqnでどうやればネットワークとかReplayBufferをいじれるかという見通しを立てたところで止まっている。

 一応先週の目標は「見通しを立てる」ところまでではあったので想定通りの進み具合ではあるが、まぁ実際に実装までできないものだなという学びでもある。

その他

 Sutton & Bartoの強化学習(第2版)の翻訳されたものをようやく読み始めている。非エピソディック設定とか、従来的なQ学習(から派生するDQN)に触って、読む準備が多少できてきたか。

 しかし内容はかなり難しい。テーブルベースのところはなんとか、くらいではあるけど、関数近似が入ったところから完全に置き去りにされている。

 最後の章で課題として挙げられていたところは

  • 一つ目はオンラインでの学習。これはやっていきたい。深層学習がオンラインに向いてないと言及されているが、もっと良い手法があるとも感じられていないのでとりあえずこれでやるしかなさそう
  • 二つ目は汎用性を見据えた表現学習。これについても深層学習でについて悲観的な見方をされているが、なんとかなるんじゃないかな
  • 三つ目はプランニングのためのモデル(モデル自体の学習も含む)。明らかに難しい話だが、モデルの学習も現状は深層学習以外道が見えていない
  • 四つ目はタスク選択の自動化。これは初めて聞いた話でまだピンと来ていない
  • 五つ目は好奇心の導入。これは環境モデルの学習と近いところで行われるのではないかという予想。環境をできるだけ少ない労力で説明できるようになることを目指すのが、好奇心の学習と関連するのではないか
  • 六つ目は安全性。これに関しては粛々とやりましょう

という印象。

 とりあえず一周目としては丁寧に数式を追いかけるというよりも書かれているトピックを把握する程度の雑な読み方をしたので、ここから気になったところをまた数式手書きして辿ってみたい。

週記 20231218~20231224

読んだ本

  • 今井 むつみ,秋田 喜美『言語の本質-ことばはどう生まれ、進化したか』

 言語学には特に詳しくないし思想も持っていないのでオノマトペについて語られるあれこれは素直にそう思える。個人的に興味を惹かれたのはアブダクション推論部分についての仮説で、ちょうどLLMは同じ知識を逆転した形式で問うと精度が落ちるという話とリンクするのではないか。外れているかもしれなくとも勝手に外挿して当てはめてみようとした方が学習効率が高そう。でもそれをどうやって機械で実現する?

書いたもの

 補足としては、「対策3:入力系列を圧縮する」というものもある。たとえば将棋の盤面を9x9 = 81トークンとしてずっと扱うのは流石に無駄が多い。もっと少ないトークン数で局面を十全に表すことができそう。おそらく、ここでのトークンは直接Policy, Valueを計算できるほど抽象化された分散状態じゃなくて良いので、局面を一意に指定・復元できる情報があれば良い(?)

やっていること

 強化学習の実装周りは、強化学習が想像以上に難しくて撤退戦という趣になっている。とにかくちゃんと学習できることを確認できる地点まで戻るために、グリッド世界をGym形式環境として書き直し、Stable Baseline3の既存実装を使って学習ができるかどうか、というところまで戻ってきた。強化学習をやるならGymから逃げられない!

 一応学習はできるらしい。上手くいくPPOとDQNのデフォルトパラメータで10回回して、安定性も問題なさそうなことを確認した。

 グラフ化している指標としては、単に収益ではなく最適行動を取る確率にしている。ターゲット位置が離れた位置に生成されるとどうやっても収益1にすることは不可能なので、細かい最善を考えるのが面倒くさい。最適な行動とは、今がターゲット位置ならクリック行動だし、今がターゲット位置でないならターゲット位置に近づくものとなる(相対位置関係によっては2方向が正解になることがあり得る)。これだと正しく100%が上限となる。

 DQNではかなり安定した収束を見せているのだが、それでも理想的な行動を取る確率が100%にはなりきらず96%ほどで止まってしまうことが気になる。どういう状態で失敗しているのか、確認してみなければいけない。まだライブラリの操作に慣れていないので少し調べながら進める必要があるのが面倒だ。

 とはいえDQNの方は収束した後に崩壊することも少なさそうなので、グリッド離散行動に適しているのは性質からいってもこちらだろうし、とりあえずはDQNの改造という方針で考えてみたいところ。

来週の目標

 Stable Baseline3の互換形式で、エピソードをまたいだ系列を入力とするモデルを実装する目処をつける。

断想:系列入力ベースの強化学習

 最近は状態や報酬などを系列データとして扱う強化学習に興味が出ている。端的に言えばDecision Transformer1 のことになる。

 特に、エピソードをまたいだ(across-episodicな)長い系列を入れることに可能性を感じる。着目点は違うが、やっていることとしてはAlgorithm Distillation2に近い。個人的に期待するacross-episodicの良い点としては、体験した成功例・失敗例をそのまま有効活用してサンプル効率を高めることにある。明示的な体験の参照ができない状況では、ニューラルネットのパラメータに勾配法で知識が反映されるまで例を活かすことができないが、それはどうしても遅くなってしまうのではないかというところが気になっている。

 改善されると嬉しい別の例をコンピュータ将棋で挙げると、(ランダム性を入れずパラメータ更新もない状況の想定でも)全く同じ棋譜で全く同じ負け方をすることを回避できるようになることだと思う。完全に同じ失敗を繰り返すのはあまりに知的でない振る舞いで、見ていてとても気分が悪い。これをなんとかしたい。

 体験を外部情報の一種として捉えると、LLMにおける検索の利用(Retrieval Augmented Generation)と類似した側面もあるのかもしれない。知識をニューラルネットのパラメータへ反映させるには学習量が必要そうなので、プロンプトの参照情報として与えてしまうという発想になる。

課題

 この方針の課題は、今のニューラルネット(Transformer)では入れられる系列長が短いことだと思う。GPT-3.5を例にすると16Kが上限となっている。将棋でどの程度過去の棋譜を入れられるか概算すると、

1局だいたい200手と考えて、片方の手番だけを入れるとして1局あたり100局面で、各局面は9x9トークンになり、たとえば10局分入れようとすると81000トーク

となる。単純にやるとまともに入れられないことがわかる。それに、コンピュータ将棋では推論速度も重要になるので、なんとか動くけど遅いという状況ではメリットがほとんどなさそうだ。できるだけ速度は落とさないままに過去の情報を活用できる必要がある。

対策1: 系列長に対して線形な計算量のアーキテクチャを使う

 もちろんTransformerの、系列長に対して2乗の計算量が必要になってしまう性質は多くの人が改善したいと考えており、様々に研究がされている。最近だとMamba3であるとか、Based4であるといったものが有力なのではないかと囁かれているらしい。

 論文などを多少見てはみたが、状態空間モデルをいろいろこねくり回して上手いことやる手法は数式・概念とかが難しく、さらっと読んですぐに理解できるというものではなかった。ここに関しては、多少興味はあるが自分でめちゃくちゃ試行錯誤していくほどのモチベーションはないし、多くの人が取り組んでいくだろうから、その成果だけを後々享受させてもらうという姿勢で良い気もしている。

対策2: 入れる系列の選び方を工夫する

 先程は単純に1局から100局面を全て入力するという考えで系列長を試算したが、それはあまりにも愚直なやり方ではある。RAGのように、検索概念を導入して、現局面と類似する過去の局面を探して入力に追加するというのも有望だろう。

 また、直感的にヒント情報として重要なのは、良い方向であれ悪い方向であれ予測を大きく外した局面だと思われるので、過去の体験に対してValueの予測ハズレ具合で重み付け(優先順序付け)をするのも可能性があるのではないかと感じる。これは経験再生における優先順位付け(Prioritized Experience Replay5)の発想そのままではある。

 いずれにしても、少量の効果的な状態だけ上手く系列の一部に取り入れることで、計算効率をそこまで下げずに精度に寄与できると面白いと思う。とりあえずはこちらの方向性で考えてみたいか、という気分。

週記 20231211~20231217

 今週はDecision Transformerの実装をしていたが、あまり上手くいっていない。

 題材としては先週と同じで丸をクリックさせるタスクをやっており、ランダムエージェントで動かした100MステップのデータからDecision Transformerを学習させて、Returnに応じた方策が学習できている事自体は確認できた。

 しかし、動かすときにReturn1を目標として動作させると、Return1を得られる場合の行動しか取らないようになってしまう。つまり、カーソルが丸の中に入っていないのにクリックを連打している。確かに、最終的にはクリックで報酬が入るのでクリック動作自体が強くReturn1と結びついてしまうのはそうなのかもしれない。

 どうもTransformerが画像情報を上手く拾えていない気がする(のでReturnだけに依存した振る舞いをしている)。系列長を小さくするためにだいぶ大きいパッチサイズにしているのが良くないのかもしれないし、画像部分の位置エンコーディングとして学習可能なものではなく三角関数ベースでのx,y位置を指定するものの方が良いのかもしれない。あるいはCNNで埋め込みにいくか。このあたり結局どうした方が良いのかというのが難しい。

 簡単な題材から始めているつもりだったが、これでも全然簡単でなかったということがわかってきたので、さらに簡単にするためGUIから離れてグリッド世界でのActor-Criticからやり直している。

 3x3のグリッド世界で、これまでと同じように上下左右移動 + クリック動作で、自己位置とターゲット位置が重なっているときにクリックした場合だけ報酬を与えるということにする。3x3ならテーブルに乗っけられるのでそれでActor-Criticを書いて、ひとまずこれは1ステップのオンライン学習で上手くできることを確認した。

 ここからまた(1)CNNベースニューラルネット、(2)Transformer で、それぞれどうなるのかというのを見ていくことになるだろうか。グリッドを適宜広げていくと簡単にスパース報酬タスクになるので、この状態でいろいろ試行錯誤することになりそう。

競技プログラミング

 本当は日曜のコンテストまで終わってから書きたいところだったが、今日はAGCだったので原理的に無理だった。

  • ABC333 : パフォーマンス2190 レート1770(+58)

 前回に続いての成功で戻ってきた。これくらいが今の適正値なのではないか。もう少し下かな。

 FのDPがわりと早く解けているかもしれず、自分はひょっとしてDP得意な部類に入るのだろうか? まぁ競技プログラミング歴で言ったら長い方ではあると思うので、典型DPというのはまだ戦える分野なのかもしれない。

週記 20231204~20231210

 今週からGUI操作のプログラミングを始めている。

今週やったこと

 結局、機械にGUIを直接いじってもらうのがわかりやすいなという考えになって、GUIを操作させるプログラムを書いている。

 当面の目標としては「スクリーンショットを入力、マウス操作を出力としたできるだけEnd-to-Endなニューラルネットワークで将棋所を直接操作して指し、強化学習をしてランダム指しエージェントに勝つ」になるかと考えている。これを向こう1,2年くらいで達成できれば。

 GUI操作のプログラム自体は、UbuntuC++ならxlibを使えばそこまで大変でもなかった。ただ、学習がとことん大変に思えるのでまだまだ将棋をやる段階ではないと思い、まずはSiv3Dでマウス操作を必要とする簡単なタスクを実装して、それを強化学習で解かせるところから始めている。

 【タスク】ランダム位置に出現する青い丸の中をクリックすると報酬が得られる。1回青い丸をクリックすると別のランダム位置に飛ぶ。(なんかFPSのエイム練習みたいだ。最初はもっと小さい丸にするつもりだったが、あまりにもランダムエージェントが報酬を得られないので丸がどんどん大きくなってきた……)

 ちゃんと方策勾配ベースの強化学習を実装したのは初めてなので、わりと苦労している。今はなんとかギリギリ学習できたような気配があるという段階。

 方策勾配だけでなく、今回エピソードという区切りを設けていないので非エピソディックな設定というのも初めてになっている。報酬に割引が必要というのをようやく実感した。将棋みたいなエピソード区切りが明確なゲームだと割引を考える必然性はあまりないからなぁ。

 今は離散行動で学習させているが、本当はカーソル移動を連続値にしたい。その他やりたいことはポツポツと。


 Pythonが全然好きになれなくて、こういうプログラムでもC++を担ぎ出してくることになるのは結局良くないのではないかという気はしている。今後「なんらかの学習済みパラメータを元にして学習を始める」ということをやりたくなったときにかなり面倒くさそう。このあたりどうするか。

競技プログラミング

  • ARC:パフォーマンス1356 レート1680(-31)
    • とうとう1600台まで落ちてきた。直近6コンテスト連続でマイナス!
  • ABC:unratedにならなければ久しぶりにプラスになりそう。E問題で詰まったときわりと早めにF問題読みにいったのが良かった
  • AHC:理由は自分でもよくわからないが、全くやる気が出なかった。

その他

 少しずつ、休日に自分のやりたいプログラミングができる状態になってきている。強化学習の勘所をもっと理解したいと思いつつ、でもやっぱり強化学習って本当に効率良いのかなという疑いも強い。まぁいろいろ見つつですか。

週記 20231127~20231203

 あまり好ましくない業務が今週にまで食い込んでいたが、なんとかやりきったと思う。いや、若干不穏なところはいくらかあるが……。1年前から考えると状況は良くなっているのに、それでもこんなもんかという気持ちにはなってしまう。

DPO

 週の特に前半でDPOの論文をわりと時間かけて読んでいた。

 別にRLHFなんてやってみたことはないが、面倒くさそうではあり、その大変そうな工程を簡略化できるなら嬉しそう。それに、これ場合によっては一般の強化学習にも影響してくるのではないかと思った。報酬を定義より良い行動/悪い行動の順序付けの方が簡単、という状況はいくらかありそうだ。深い探索と浅い探索(あるいは探索なし)もそういう関係にならないかとか考えるけれど、ボードゲームなら報酬がわかりやすいからこんなことをする必要性はなさそうでもある。一般に、報酬を複雑に(多くの場合、学習を進みやすくさせるため細かく与えようとするから複雑になるのだと思う)するのではなく、疎な報酬から学習できるように頑張るべきではあるんだろうな。

JARVIS-1

 あとはJARVIS-1も若干面白そうではあり、勾配法のパラメータ学習なくても記憶部分に成功例を溜めていけばいくらかの学習(?)みたいな振る舞いにはなると。雑な目で見たら検索でLLMの性能を上げたい話と近いのだろうか(検索対象が外部データじゃなくて過去の体験になるだけ)。まぁ確かに何でもかんでもパラメータに反映させる必要もないのかなという気はする。両方が上手い感じに結びつくと嬉しいことになりそうではある。

 あとはシンボル操作みたいなところがな。論理性とかをどうやって実現すればいいのか。きっとどこかの頭いい人が良い方法を考えてくれると思うので、それを見てちゃんとおぉーとなれるようでいたい。

競技プログラミング

  • AtCoder Beginner Contest 331終了後:レート1711(-12)

 ローリングハッシュをセグメント木に乗せるゲームをあれだけの人が解けるのはすごいことだと思います。それでマイナスになるならそれはもう仕方ないという気分。

その他

 相変わらずやることに迷いがちで、いろいろなことを摘み食いしてはすぐ飽きて放り出す感じになっている。なにか数年単位でやることを決めたいと思って、もう2,3年経っているのではないか。まぁまぁこれが平常運転。

Direct Preference Optimizationを読む(その2)

 その1

でDPOの損失関数

 
\mathcal{L} _ \mathrm{DPO} (\pi _ \theta; \pi _ \mathrm{ref})= - \mathbb{E} _ {(x, y_w, y_l) \sim \mathcal{D}} \left\lbrack \log \sigma \left( \beta \log \frac{\pi _ \theta (y _ w|x)}{\pi _ \mathrm{ref}(y _ w|x)} - \beta \log \frac{\pi _ \theta (y _ l|x)}{\pi _ \mathrm{ref}(y _ l|x)} \right) \right\rbrack
\quad(7)

が導出できたので、この関数の性質を分析してみます。

勾配がどうなっているか

 まず微分してみます。整理するために

 \hat{r} _ \theta (x, y) = \beta \log \frac{\pi _ \theta (y _ w|x)}{\pi _ \mathrm{ref}(y _ w|x)}

と置きます。後にこれは暗黙の報酬モデルであることが明らかになりますが、それはさておいて、まずは勾配を求めます。 f = \hat{r} _ \theta(x, y _ w) - \hat{r} _ \theta(x, y _ l)とすると、損失関数の中身は

 
h = \log \sigma(f)

という形式になっているので、連鎖律で

 
\frac{\partial h}{\partial \theta} = f' \sigma(f) (1 - \sigma(f)) \frac{1}{\sigma(f)}

となり、ロジスティック関数の性質から 1 - \sigma(f) = \sigma(-f)なので、結局勾配は f' \sigma(-f)となります。この1つ目は

 
f' _ \theta = \nabla _ \theta \log \pi _ \theta (y _ w | x) - \nabla _ \theta \log \pi _ \theta (y _ l | x)

であり、単純に良いものを増加させ、悪いものを低下させるように働きます。2つ目の

 
\sigma(-f) = \sigma(-\hat{r} _ \theta(x, y _ w) + \hat{r} _ \theta(x, y _ l))

がその係数であり、つまり報酬によって上手いこと重み付けがなされると解釈できます。

DPOのさらなる理論解析

 まず次のことを定めます。

定義1. 2つの報酬関数 r(x, y) r'(x, y)が、ある関数 fに対して r(x, y) - r'(x, y) = f(x)となる場合、その2つの報酬関数は等価であるという。

 これは報酬関数の集合をクラスに分割する同値関係を定めていることになります。

 この同値関係について、次の補題2つがあります。

補題1. 同じクラスにある2つの報酬関数は、Plackett-Luceモデル(特にBradley-Terryモデル)の下で同じ嗜好分布を導く
補題2. 同じクラスにある2つの報酬関数は、制約付きRL問題の下で同じ最適方策を導く

 嗜好分布とは回答の良し悪し y _ w, y _ lの組についての分布で、補題1はunder-specification問題としてPlackett-Luceモデル族については知られている話のようです。報酬が一意に上手く定まらないわけですが、逆に補題2から、クラスさえ決まってしまえばどの報酬関数でも問題ないことがわかります。

補題1の証明】

 一般的なPlackett-Luceモデル(Bradley-Terryモデルはこの内 K = 2の特殊なケース)を考えます。特定の報酬関数 r'(x, y) = r(x, y) + f(x)によって誘導されるランキング上の確率分布を p _ rとします。手順はほぼ自明で、expなので和が掛け算としてくくり出せて約分できるので同じということになります。

 
p _ r' (\tau | y _ 1, \dots, y _ K, x) = \Pi _ {k = 1} ^ {K} \frac{ \exp(r'(x, y _ {\tau(k)})) }{ \sum _ {j = k}  ^ {K} \exp (r'(x, y _ {\tau(j)} ))} \\
= \Pi _ {k = 1} ^ {K} \frac{ \exp(r(x, y _ {\tau(k)}) + f(x)) }{ \sum _ {j = k}  ^ {K} \exp (r(x, y _ {\tau(j)} ) + f(x))} \\
= \Pi _ {k = 1} ^ {K} \frac{ \exp(r(x, y _ {\tau(k)})) }{ \sum _ {j = k}  ^ {K} \exp (r(x, y _ {\tau(j)} ))} \\
= p _ r (\tau | y _ 1, \dots, y _ K, x)

 補題2の証明もほぼ同様なので省略します。最適方策が \pi _ \mathrm{ref}(y | x) \exp \left( \frac{1}{\beta} r(x, y) \right)に比例する形でかけるのでexpが打ち消し合います。

 そして、以下の定理が成り立ちます。

定理1. 適当な仮定の下で、Plackett-Luce(特にBradley-Terry)モデルと一致するすべての報酬クラスは、再パラメータ化によって表現できる。この再パラメータ化は、ある方策 \pi (y | x)と与えられた参照方策 \pi _ \mathrm{ref}(y | x) を用いて、 r(x, y) = \beta \log\left(\frac{\pi(y|x)}{\pi _ \mathrm{ref}(y|x)}\right)の形をする。

(ちょっと書き方がわかりにくいのですが、要するに報酬をクラスとしてしか指定できていなかったところから、そのうちの一つに定まるという意味合いなのだと思います)

 証明の概要として、最適方策 \pi _ r (x, y)を導く報酬のクラスに属する任意の関数 r(x, y)に対して、以下のような写像 fを考えます。

 
f(r; \pi _ \mathrm{ref}, \beta) = r(x, y) - \beta \log \sum _ y \pi _ \mathrm{ref} (y|x) \exp \left( \frac{1}{\beta} r(x, y) \right)

これは要するに r(x, y) \pi _ rの分配関数で正規化するような操作になっています。また、第二項は xについてだけの関数なので、定義1より報酬関数としてのクラスは変わらないことがわかります。ここで、その1で求めた報酬関数についての等式(5)(これはどの報酬関数についても成り立ちます)

 
r(x, y) = \beta \log \frac{\pi _ r (y|x)}{\pi _ \mathrm{ref}(y|x)} + \beta \log Z(x)
\qquad(5)

を代入すると、きれいに消えて \beta \log \frac{\pi _ r (y|x)}{\pi _ \mathrm{ref}(y|x)}だけ残ります。つまり、どの報酬関数もこの写像 fによって一点に潰れるということなのではないかと思います。

 別の見方をすると、分配関数

 
Z(x) = \sum _ y \pi _ \mathrm{ref} \exp \left( \frac{1}{\beta} r(x, y) \right)

について、定理1を代入すると、

 
Z(x) = \sum _ y \pi (y|x) = 1

であることが直ちにわかるため、分配関数が1になる、という形で制約づけていると見なすこともできます。