前回の記事の通り、データセットの充実により教師あり学習でも十分な性能を出せるようになった。このため、CNNがスタンダードとなっているネットワーク構造についても再考ができるかもしれない。
巨大データセットを用いて巨大なモデルを学習させるというのが現在の深層学習のトレンドの一つと言える。たとえば画像認識で話題になっているVision Transformerは、JFT-300Mという3億枚の画像からなる巨大データセットで事前学習することにより性能が出るものになっている。
『dlshogi with GCTの棋譜』は約225Mと、3億(300M)に迫るデータ量になっているため、CNNのように帰納バイアスが強い構造でなくとも性能を出せるようになっている可能性がある。今回は予備的な調査として、Vision Transformer、MLP-Mixerについてこのデータセットについて軽い学習を回して様子を見ると共に、推論時の速度を計測する。
ネットワークの実装
特にMLP-Mixerなど、細かいところは正確ではないかもしれない。
学習結果
ResNet, Transformer, MLP-Mixerについて、どれも10ブロック・256chとしてネットワークを生成した。
『dlshogi with GCTの棋譜』データセットで、バッチサイズ512、16万ステップ回した。合計でもデータセットをすべて使い切れていない(8192万データしか見ていない)軽い学習となる。MLP-Mixerについては間違えてバッチサイズ256にしてしまっていたのでさらに参考値程度に。
Policy損失ではTransformerがResNetに肉薄する程度。しかしValue損失ではどちらも大きくResNetに水を開けられている。まぁこの程度の学習量ならそうであるのが当然なはず。
学習時間
モデル名 | 学習環境 | 学習時間 |
---|---|---|
ResNet | 自宅の2080tiデスクトップ | 12時間07分 |
Transformer | Google Colab(V100) | 16時間13分 |
MLP-Mixer | 自宅の2080tiデスクトップ | 6時間32分 |
MLP-Mixerはバッチサイズを間違えていたので速い。それを除くと、学習時間としては大差はないか、とはいえTransformerはV100でこれなのでやや重いか。
推論速度
上で学習した16万ステップ時点でのパラメータを用いて、初期局面について10回10秒の推論を行ってNPSの平均値を計測した。ついでにモデルのパラメータ数も付記する。
(環境:自宅の2080tiデスクトップ、PyTorch1.10.1、 CUDA10.2、 TensorRT7.2.2.3、INT8推論)
モデル名 | 初期局面におけるNPS | パラメータ数 |
---|---|---|
ResNet | 30,632 | 12,351,310 |
Transformer | 4,771 | 8,081,230 |
MLP-Mixer | 8,788 | 3,064,136 |
NPSはCNNが圧倒的であり、その他のモデルはかなり低下してしまうことがわかった。
いくらかネットワークを切除して速度計測を行った結果、
- LayerNormがBatchNormよりも遅い
- TransformerやMLP-Mixerは1ブロックがResNetより大きめ
- TransformerやMLP-Mixerにおける1ブロック中のチャンネル方向の処理について「一端チャンネル数を増加させてからまた戻す」部分が重い
などの理由が明らかになった。
簡単に各項目についてもう少し説明を加えると
(1)BatchNormは推論時には各要素を決まった値で線形変換するだけなのでとても速い。LayerNormなどでは他の活性値を参照してどうのこうのするのでどうしても遅くなる
(2, 3)一つのブロックにおける処理は
- ResNet : 3x3Conv2層とSqueeze-and-Excitation
- Transformer : Self-AttensionとMLP2層(MLPの中間ではchが128の4倍)
- MLP-Mixer : MLP2層(token_mix)とMLP2層(channel_mix)(それぞれMLPの中間ではchがもとの2倍)
となっており、MLP-Mixerだと1ブロックでほぼ4層であり、ResNet2ブロック分の重さがあるようだった。またMLP2層の中間でチャンネル数を増やす部分でかなり重いようで、3x3Conv2層よりも時間がかかっていそうだった。
学習では2倍も差がついていなさそうなのに、推論時には6倍近くの差が出ていることは気になる。TensorRT化における相性があるのだろうか。
所感
MetaFormerとかの話を見ると、結局本質的なのはchannel_mixに相当する後半のMLP2層なのではという気がするし、最新論文系だとそこでの中間チャンネルをかなり大きめに取っている印象なので、結局速度はどうしようもなく落ちるのを受け入れる必要があるのかもしれない。それに見合うだけの精度向上を引き出せるかが焦点になりそうだが、探索が重要なゲームではなかなか厳しそうにも思える。
とりあえず16万ステップでは少なすぎるので、100倍とかそのレベルの量を回す必要がありそう。そうなるとまともな比較実験をするのは相当困難なので、どこかで見切りつけて博打をやることになるのだろう。