Post Training Quantization(PTQ)の導入

要約

 PTQ(要するにINT8演算)の導入でR+30程度

実装

 ニューラルネットワークは基本的に浮動小数点演算(FP32)を用いている。今までは半精度浮動小数点演算(FP16)により高速化を行っていたが、Post Training Quantization(PTQ)という、FP32の範囲を絞ってINT8の演算に変換することで高速化する手法も存在する。TRTorchでもこれがサポートされているのでチュートリアルに従って試した。

 なかなか不安定なところもあったが、

  • Calibrationデータとして取り出すデータのミニバッチサイズを、コンパイルで指定する最適バッチサイズと同じにする
  • Calibrationデータの総数をミニバッチサイズの整数倍にする
  • CalibrationのアルゴリズムにはIInt8MinMaxCalibratorを使用する

あたりに気をつけることで上手くいくようになった感覚がある。ライブラリをそのまま使うだけなので実装自体の難易度は低く、下のように十数行でTorchScript形式のモデルをコンパイル・Calibrationできる。

強化学習で得た128chのモデルについて

損失計測

 floodgate2015年の棋譜を用いてCalibrationを行い、2019年の棋譜に検証損失の計算を行った。

 Calibrationに用いるデータ数(局面数)を変えながらINT8での計測を行った結果が以下となる。

データ数 Policy損失 Value損失
FP16(比較用) 1.8464 0.6429
64 1.8565 0.6450
128 1.8537 0.6435
256 1.8537 0.6442
512 1.8569 0.6437
1024 1.8583 0.6435
2048 1.8584 0.6438

 損失はいくらか悪化している。

 Calibrationに用いるデータ数が多ければ多いほど良いわけではなかった。後の256chでの結果も踏まえて、推論時のバッチサイズ(64)のちょうど2倍のデータ数(128)が最も良い値と見なすことにした。以降のNPSの測定ではそのデータ数でCalibrationを行っている。

NPSの測定

演算精度 初期局面 中盤の局面
fp16 35198 ± 606 30392 ± 1802
INT8 37191 ± 797 34403 ± 1003
(倍率) 1.057倍 1.131倍

 128chのモデルではもともと推論が比較的高速なので、INT8による推論にしたところで恩恵が大きくない。これだと損失の悪化の方が大きいのではないか。1.1倍程度ではまともなレート差になりえないので対局はスキップ。

教師あり学習(AobaZeroの棋譜)で得た256chのモデルについて

損失計測

データ数 Policy損失 Value損失
fp16(比較用) 1.8390 0.5804
64 1.8442 0.5855
128 1.8495 0.5829
256 1.8552 0.5846
512 1.8511 0.5849
1024 1.8501 0.5883
2048 1.8565 0.5896

 128chに比べて全体的に悪化幅が大きくなった。

NPSの測定

演算精度 初期局面 中盤の局面
fp16 18612 ± 188 16278 ± 474
INT8 26790 ± 481 24140 ± 836
(倍率) 1.439倍 1.482倍

 256chだとNPSの向上が大きい。NPS2倍でレート100とすると、1.4倍あれば+50程度は見込める。これなら測定できそうだと思ったため256chモデルについては対局も行った。

対局

 Miacisは1手0.5秒

 対戦相手の条件

  • 探索エンジン:やねうら王
  • 評価関数:Kristallweizen
  • NodesLimit:800000

 結果

演算精度 勝数 引分数 負数 勝率 相対Eloレート
FP16 382 186 432 47.5% -17.4
INT8 426 207 367 52.9% 20.5

 R+37.9となった。+50には届かなかったが、精度の悪化も含めるとこれでも伸びすぎかもしれないと思うくらいではある。