背景
唐突にLLMでMulti-Armed Banditを解きたくなった。
全く読んでいないけれど
という論文もある。
ブラウザでのChatGPT4で数回試した感じだとそこそこ理屈立ててやってくれたので、ある程度やれるのではないかという期待。
実装
正規分布でスコアを返す3台のスロットマシンでのMulti-Armed Banditを実装した。
という設定で、2台目が一番良い状況になっている。
これはEpsilon Greedyとかでやるとある程度ちゃんと2台目に多く割り当てられる。
これをLLMに解かせる。OpenAIのAPIを使って、テキストで状況を伝えてどのスロットマシンを選択するかを決定してもらう。引いた結果もテキストで伝える。指定はかなり雑に、最終行で「[1]」という感じで出力させる。フォーマットが合わなかったらその旨を伝えて再度出力させるようにする。
そのあたりの実装は以下のような感じ。
結果・所感
結果は、API呼び出しがわりと時間かかり、また1試行でスロットを引く数(ここでは100)以上呼び出すことになるので、謎の通信エラーもあってまともに終了までやりきれることがなかった。
結局、APIでは追加学習も難しいので、これで押し切るというのは無理そう。ローカルで推論できる適当なモデルを使ってなんとかする方針を探りたい。追加学習する方法あるんだろうか、特になにも思いついてはいない。
また、プロンプトの入れ方によって結構ちゃんと出力されるかどうかが変わってきそうな雰囲気を感じる。プロンプトの良し悪しとか気休めでしょくらいに思っていたが、意外と影響大きいか。
後は、過去試行の情報をどう記憶に留めておくかという問題もある。今は100回引く分の対話全てを入れるとトークン数超過になるので、なにか対処は必要になる。要約させるなどもありえるけど、どこまで手を入れるか。平均・標準偏差の入力とかしたらそりゃ短期的に性能は良くなりそうではあるが、それが目指しているものかというと。