前回
前回は『Which are you?』というトイタスクを考え、LSTMを使って方策勾配法を用いることで10回中7回でそれなりな正答率になった。
今回は時系列モデルをLSTMからTransformerへと変更した。
実装
Transformerで時系列を扱う場合、過去の入力情報が必要になるので、適当にstd::vector
で確保しておく。LSTMのときに状態リセットAPIを用意していたので、それが呼ばれたタイミングでstd::vector
を空にする。
学習時にエピソード全体を一気に推論することがあり、そのときには未来情報を見ないようにmaskを作る必要がある。これは以下のような形で作れるというのは学びだった。
torch::Tensor mask = torch::ones({seq_len, seq_len}).to(device); mask = torch::triu(mask); mask = mask.transpose(0, 1);
PyTorchではmaskがbool値である場合は「Trueであるところが計算対象外」となるが、float値である場合は(0,1でなく中間の値も使えるように)maskの値がAttentionに直接かかるので、「1であるところが計算対象、0であるところが計算対象外」となり、意味が逆転するという罠があることも初めて知った。
以下が推論部の実装。seq_lenを一度に入力される場合(学習時)でも、seq_lenを1個ずつ切り出して入力される場合(実行時)でも、上手く動くようになっている。Valueを出力する機構も準備しているが、今回はValueは損失計算などに使用していない。
std::tuple<torch::Tensor, torch::Tensor> Network::forward(torch::Tensor x) { // xのshapeは(seq_len, batch, input_size) // 出力shape policy:(seq_len, batch, POLICY_DIM) value(seq_len, batch, 1) const int64_t seq_len = x.size(0); const torch::Device& device = transformer_->parameters().front().device(); x = x.view({-1, 1, input_size_}); // 過去の情報と結合 input_history_.push_back(x); x = torch::cat(input_history_, 0); x = x.to(device); // 1層目 x = first_layer_(x); // maskを作る torch::Tensor mask = torch::ones({seq_len, seq_len}).to(device); mask = torch::triu(mask); mask = mask.transpose(0, 1); torch::Tensor output = transformer_->forward(x, mask); const int64_t output_len = output.size(0); output = output.slice(0, output_len - seq_len, output_len); torch::Tensor policy = policy_head_->forward(output); torch::Tensor value = value_head_->forward(output); return std::make_tuple(policy, value); } void Network::resetState() { input_history_.clear(); }
結果
10回中10回で正答率は80%を超えるようになり、ちゃんと学習できていそうなことがわかった。
1回はほぼ100%に張り付くような形になったが、他では85%~95%あたりで止まってしまっている。これは、1回目の行動がランダムエージェントと被り(確率25%)、そこで行動を続けずに二択を回答すると確率50%で外れて、合わせて12.5%で外すような形になっているのではないかと思われる。
所感
Transformerが強いというよりも、過去の情報をCPUメモリに直接保持する形だと解きやすいのは当然な気がする。計算量が間に合うなら別にTransformerの構造でなくとも全結合ネットワークに放り込んでもこれくらいできるのではないかという気はするが、まぁ計算量の抑え方というのも重要なところではあるか。(今回のタスクでは系列長が2とかになっているはずなので正直なんでも解けるとは思うが)。
とはいえ入力系列が長くなっていくとこれではメモリが足りなくなったりといったことは起こるはずなので、そういうときには先日読んだような手法を使っていくことになるのだろう。