Bigger, Better, Fasterのコードを動かす

 コードが公開されているので動かしてみる。

 venvで行ったので、おおよその手順は

git clone https://github.com/google-research/google-research
cd bigger_better_faster

python3 -m venv .env
source .env/bin/activate

pip3 install -r requirements.txt

export PYTHONPATH="$(readlink -f ../):$PYTHONPATH"

python3 -m bbf.train \
    --agent=BBF \
    --gin_files=bbf/configs/BBF.gin \
    --base_dir=./bbf_result \
    --run_number=1

という感じ。しかし途中途中でいろいろ直さないと動かなかった。

(1)dopamine_rlで落ちる

 BBFは内部的にdopamine_rlというライブラリのNoisyNetworkを読み込んでいるところがあるが、その読み込み時に

AttributeError: module 'jax.interpreters.xla' has no attribute 'DeviceArray'

で落ちる。これは最新のGitHubコードだと直っているのでこれと同じ修正をローカルのものに入れることで回避できる。

[JAX] Replace uses of jax.interpreters.xla.DeviceArray with jax.Array. · google/dopamine@a9a8fc0 · GitHub

 試行錯誤中だと上手くいかなかったが、単純にpipでdopamineライブラリのバージョンを上げても回避できるかもしれない。

(2)spr_agent.pyで落ちる

 dopamine_rlを直したら次はspr_agent.pyのところで落ちた。

ValueError: Custom node type mismatch: expected type: <class 'flax.core.frozen_dict.FrozenDict'>, value {Pythonの辞書形式}

というエラーが出る。要するにdictとFrozenDictの違いで型があってないということらしいので変換を入れると良い。

 その他、get_default_device_assignmentというやつも削除されたせいかエラーが出る。

Remove use of get_default_device_assignment(). · google/jax@3ce5cb6 · GitHub

 なんだかよくわからないが、これらを勘で次のように直した。(コピペ後に微修正したのでパッチファイルとしては機能しないと思われる。該当行を手動で修正する)

diff --git a/bigger_better_faster/bbf/agents/spr_agent.py b/bigger_better_faster/bbf/agents/spr_agent.py
index 25306dcb6..56c696091 100755
--- a/bigger_better_faster/bbf/agents/spr_agent.py
+++ b/bigger_better_faster/bbf/agents/spr_agent.py
@@ -45,9 +45,7 @@ def _pmap_device_order():
   if jax.process_count() == 1:
     return [
         d
-        for d in xb.get_backend().get_default_device_assignment(
-            jax.device_count()
-        )
+        for d in xb.get_backend().local_devices()
         if d.process_index == jax.process_index()
     ]
   else:
@@ -177,6 +175,13 @@ def interpolate_weights(
   if keys is None:
     keys = old_params.keys()
   for k in keys:
+    old_params = FrozenDict(old_params)
+    new_params = FrozenDict(new_params)
     combined_params[k] = jax.tree_util.tree_map(combination, old_params[k],
                                                 new_params[k])
   for k, v in old_params.items():
@@ -1309,6 +1314,10 @@ class BBFAgent(dqn_agent.JaxDQNAgent):
         optax.masked(optimizer, self.head_mask),
     )
 
+    self.online_params = FrozenDict(self.online_params)
     self.optimizer_state = self.optimizer.init(self.online_params)
     self.target_network_params = copy.deepcopy(self.online_params)
     self.random_params = copy.deepcopy(self.online_params)

(3)JaxがGPUを使っていない

 実験が動き出したようなログが流れ続けるが、nvidia-smiで見てもプロセスが乗っておらず、GPU利用率も低かった。

python3 -c "import jax; print(jax.devices())"

で確認したところ

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
[CpuDevice(id=0)]

 と出た。公式のインストール手順https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-via-pip-easier:titileを見て、自分の環境はCUDA-12が入っているので

pip3 install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

とした。これでpython3 -c "import jax; print(jax.devices())"の表示は

[cuda(id=0), cuda(id=1)]

となった。(2枚挿さっているPCなので2枚分ちゃんと出た)

動作結果

 設定はデフォルトのままで実行してみた。configをよく見ると、BBFAgent.replay_ratio = 64となっている(https://github.com/google-research/google-research/blob/4da7251308decf0a61807c09a8f4c087cbd06310/bigger_better_faster/bbf/configs/BBF.gin#L33)。

 ChopperCommandという、他の手法だとあまりスコアが伸びないゲームでHuman Normalized Scoreが1を超える結果を得られた。

 実験時間も3090GPUを1枚だけ使って2.6hだったため、かなり短い。