LibTorchにおける半精度浮動小数点演算

記事の要約

 LibTorchを使って半精度浮動小数点演算(FP16)を行うことで探索は速くなったが、学習は上手くいかなかった。どうもBatch Normalizationの部分はFP32で計算しなければならないようだ。

LibTorchによる半精度浮動小数点演算

 深層学習では厳密な精度よりも計算速度が求められる場合が多い。特に近頃のGPUTensorコアなるものによって半精度(16bit)の浮動小数点演算を高速に行うことができるようで、FP16で計算を行うプログラムを書くと高速になるとのことである。

 LibTorchではモデルをGPUに転送する際にtorch::kHalfを指定することで半精度浮動小数点演算ができるようだ。自作の将棋ソフト『Miacis』のneural_network.cppから、モデルをGPUへ転送する部分のコードを示す。以下示すコードはUSE_HALF_FLOATシンボルを立てている場合に半精度浮動小数点演算を行うようになっている。

    device_ = torch::Device(torch::kCUDA, gpu_id);
#ifdef USE_HALF_FLOAT
    to(device_, torch::kHalf);
#else
    to(device_);
#endif

 入力特徴量もGPUへ転送する際に同様の指定を行う。

std::pair<torch::Tensor, torch::Tensor> NeuralNetworkImpl::forward(const std::vector<float>& inputs) {
#ifdef USE_HALF_FLOAT
    torch::Tensor x = torch::tensor(inputs).to(device_, torch::kHalf);
#else
    torch::Tensor x = torch::tensor(inputs).to(device_);
#endif
    //以下ネットワークによる処理

    return { policy, value };
}

 ネットワークの出力をCPUで受け取る場合はtorch::Half型を用いる。たとえばbatch_size分のvaluestd::vector<float>で受け取るコードは次のようになる。

    auto y = forward(inputs);

    //CPUに持ってくる
    torch::Tensor value = y.second.cpu();

    //float型のstd::vectorで受け取る
    std::vector<float> values(batch_size);
#ifdef USE_HALF_FLOAT
    std::copy(value.data<torch::Half>(), value.data<torch::Half>() + batch_size, values.begin());
#else
    std::copy(value.data<float>(), value.data<float>() + batch_size, values.begin());
#endif

実験

 上記の改良により本当に速くなるか、精度の問題はないかを検証した。

探索速度

 初期局面を1GPU(2080ti)、2スレッド、512バッチで探索した。GPU処理部分を重くして差がわかりやすくなるようにResidualブロックにおけるCNNのチャネル数を256にして実験を行った。探索速度は次のようになった。

FP32 FP16
7322 NPS 11284 NPS

 FP16にすることで1.5倍ほどのNPSになった。

教師あり学習

 RTX2080tiを用いてfloodgate2016年の棋譜をもとに教師あり学習を行った。損失の推移を次に示す。

f:id:tokumini:20190509202901p:plainf:id:tokumini:20190509202910p:plain
左:Policy損失 右:Value損失

 FP16では上手く学習できていない。

 少し調べてみるとBatch Normalizationは32bitで計算しければいけないとの話があった。指数移動平均のところが原因と言われていたりもするが、どこまで信用できるものかはわからない。Qiitaにも同内容の記事があったので少なくともBatch NormalizationはFP32で計算しなければならないという説の信憑性は高そうだ。山岡さんのブログでもそのようにしていた。やはりBatch Normalizationの部分が問題であるようだ。

結論

 FP16によって探索は速くなったが学習での精度が低下してしまった。今後はモデルの一部だけ(Batch Normalization以外)をFP16で計算する方法などを探っていきたい。