Sharpness-Aware Minimizationの検証

 Sharpness-Aware Minimizationという手法が提案されています。

 詳しい人が説明してくれています(僕もこれで知りました)。

 上記事の

パラメータ周辺での最大の損失を求めて、それが下がる方向でパラメータを更新する

というのが基本的なコンセプトでしょう。

 非公式ですが、再現実装も公開されています。

 それほど複雑でもなかったのでMiacisでも実装してみました。

 以下AobaZeroの棋譜を用いて実験して結果を掲載します。

Policy損失

f:id:tokumini:20210129100227p:plain

 最終的な検証損失

手法 時間(hhh:mm:ss)
通常のSGD 1.950
SAM(SGD) 1.864

 はっきりと改善が見られました。1.8台というのは強化学習を最後まで回しきってなんとか見られるかどうかという値であり、1.86というのはとても小さいという印象です。

 train損失は通常のSGDより悪い値となっていますが、SAMの方では2回目の損失、つまり近傍内で悪化する方向に移動してからの損失を表示しているので妥当なのかなと思います。

Value損失

f:id:tokumini:20210129100250p:plain

 最終的な検証損失

手法 時間(hhh:mm:ss)
通常のSGD 0.6453
SAM(SGD) 0.6520

 Value損失は途中までは良かったんですが、最終的な値は通常のSGDよりも悪化してしまいました。

検証対局

Miacis time = 250msec, YaneuraOu time = 250msec, YaneuraOu Threads = 4,NodesLimit=400000
手法 勝数 引分数 負数 勝率 相対レート
通常のSGD 488 117 395 54.6% 32.4
SAM(SGD) 411 109 480 46.6% -24.0

 残念ながら対局では性能が上がってないという結果になりました。個人的な印象ですが、他の実験を見ていてもPolicy損失よりValue損失の方が重要だという印象があります。

学習時間

 SAMは1回損失・勾配を計算して周囲の最も悪いところに移動した後、もう一度損失・勾配を計算するので単純に考えると2倍の時間がかかります。実際の学習時間は

手法 時間(hhh:mm:ss)
通常のSGD 080:44:58
SAM(SGD) 136:37:22

となり、2倍ではないにしろそこそこの増加はありました。しかし強化学習では「データ生成時間 >> 学習時間」なのでこの程度の増加は許容できるでしょう。

 一応Policy損失の方で良さそうな雰囲気は出ているので強化学習でも試してみようと思います。