要約
PTQ(要するにINT8演算)の導入でR+30程度
実装
ニューラルネットワークは基本的に浮動小数点演算(FP32)を用いている。今までは半精度浮動小数点演算(FP16)により高速化を行っていたが、Post Training Quantization(PTQ)という、FP32の範囲を絞ってINT8の演算に変換することで高速化する手法も存在する。TRTorchでもこれがサポートされているのでチュートリアルに従って試した。
なかなか不安定なところもあったが、
- Calibrationデータとして取り出すデータのミニバッチサイズを、コンパイルで指定する最適バッチサイズと同じにする
- Calibrationデータの総数をミニバッチサイズの整数倍にする
- Calibrationのアルゴリズムには
IInt8MinMaxCalibrator
を使用する
あたりに気をつけることで上手くいくようになった感覚がある。ライブラリをそのまま使うだけなので実装自体の難易度は低く、下のように十数行でTorchScript形式のモデルをコンパイル・Calibrationできる。
強化学習で得た128chのモデルについて
損失計測
floodgate2015年の棋譜を用いてCalibrationを行い、2019年の棋譜に検証損失の計算を行った。
Calibrationに用いるデータ数(局面数)を変えながらINT8での計測を行った結果が以下となる。
データ数 | Policy損失 | Value損失 |
---|---|---|
FP16(比較用) | 1.8464 | 0.6429 |
64 | 1.8565 | 0.6450 |
128 | 1.8537 | 0.6435 |
256 | 1.8537 | 0.6442 |
512 | 1.8569 | 0.6437 |
1024 | 1.8583 | 0.6435 |
2048 | 1.8584 | 0.6438 |
損失はいくらか悪化している。
Calibrationに用いるデータ数が多ければ多いほど良いわけではなかった。後の256chでの結果も踏まえて、推論時のバッチサイズ(64)のちょうど2倍のデータ数(128)が最も良い値と見なすことにした。以降のNPSの測定ではそのデータ数でCalibrationを行っている。
NPSの測定
演算精度 | 初期局面 | 中盤の局面 |
---|---|---|
fp16 | 35198 ± 606 | 30392 ± 1802 |
INT8 | 37191 ± 797 | 34403 ± 1003 |
(倍率) | 1.057倍 | 1.131倍 |
128chのモデルではもともと推論が比較的高速なので、INT8による推論にしたところで恩恵が大きくない。これだと損失の悪化の方が大きいのではないか。1.1倍程度ではまともなレート差になりえないので対局はスキップ。
教師あり学習(AobaZeroの棋譜)で得た256chのモデルについて
損失計測
データ数 | Policy損失 | Value損失 |
---|---|---|
fp16(比較用) | 1.8390 | 0.5804 |
64 | 1.8442 | 0.5855 |
128 | 1.8495 | 0.5829 |
256 | 1.8552 | 0.5846 |
512 | 1.8511 | 0.5849 |
1024 | 1.8501 | 0.5883 |
2048 | 1.8565 | 0.5896 |
128chに比べて全体的に悪化幅が大きくなった。
NPSの測定
演算精度 | 初期局面 | 中盤の局面 |
---|---|---|
fp16 | 18612 ± 188 | 16278 ± 474 |
INT8 | 26790 ± 481 | 24140 ± 836 |
(倍率) | 1.439倍 | 1.482倍 |
256chだとNPSの向上が大きい。NPS2倍でレート100とすると、1.4倍あれば+50程度は見込める。これなら測定できそうだと思ったため256chモデルについては対局も行った。
対局
Miacisは1手0.5秒
対戦相手の条件
- 探索エンジン:やねうら王
- 評価関数:Kristallweizen
- NodesLimit:800000
結果
演算精度 | 勝数 | 引分数 | 負数 | 勝率 | 相対Eloレート |
---|---|---|---|---|---|
FP16 | 382 | 186 | 432 | 47.5% | -17.4 |
INT8 | 426 | 207 | 367 | 52.9% | 20.5 |
R+37.9となった。+50には届かなかったが、精度の悪化も含めるとこれでも伸びすぎかもしれないと思うくらいではある。