画像to画像のネットワークに二桁の足し算を学習させる

 手頃な大きさのニューラルネットワークが、画像を入力とし画像を出力する形で二桁の足し算を解くことができるのかどうかを検証した。

実験方法

データ形式

 入力画像に「A+B=」という式を画像化したもの、教師画像として「A+B=C」の正しい式を画像化したものを与え、入力から教師を予測できるように学習する。たとえば以下のようなデータペアが与えられる。

f:id:tokumini:20211128183016p:plainf:id:tokumini:20211128183026p:plain
左:入力画像 右:教師画像

 詳細には、入力・教師画像は

  • サイズは256×256
  • 白黒の1ch画像
  • ネットワーク入力時、値は[0, 1]

である。

データの作り方

 0~99までの2つの値の足し算の式を全部生成し、計100 * 100 = 10000通りのデータを作り、ランダムに選んだ9000個を学習用データ、残り1000個をvalidation用データとする。

学習モデル

 これを適当なネットワークで処理する。AutoEncoderなりなんなり、入力・出力が画像で行えるものならなんでも良かったのだが、今回は実装を手抜きするために以下のセマンティクスセグメンテーション向けのライブラリを利用した。

 入力・出力をともに1chとして、出力の最後にはsigmoid関数をかけて出力を[0, 1]にする。出力と教師画像の二乗誤差を損失として勾配法で学習を行う。

実装全体

結果

 1周期Cosine annealingを用いて50エポックくらい回すと検証損失は収束している見た目になった。

f:id:tokumini:20211128182929p:plain

 最終エポックまで終わったパラメータについて、検証データからランダムに18データ選んでネットワーク出力を見てみると以下のようになった。

f:id:tokumini:20211128183223p:plain f:id:tokumini:20211128183230p:plain f:id:tokumini:20211128183254p:plain f:id:tokumini:20211128183259p:plain f:id:tokumini:20211128183307p:plain f:id:tokumini:20211128183312p:plain f:id:tokumini:20211128183318p:plain f:id:tokumini:20211128183330p:plain f:id:tokumini:20211128183338p:plain f:id:tokumini:20211128183345p:plain f:id:tokumini:20211128183401p:plain f:id:tokumini:20211128183407p:plain f:id:tokumini:20211128183415p:plain f:id:tokumini:20211128183423p:plain f:id:tokumini:20211128183430p:plain f:id:tokumini:20211128183639p:plain f:id:tokumini:20211128183647p:plain f:id:tokumini:20211128183654p:plain

 最後の画像など、少し数字が滲んでいて怪しいところはあるが、基本的に多くの場合で計算自体は合っている(と思う、目視で確認しているのでもしかしたら漏れがあるかもしれない)。

所感

 こういうことができそうだなぁと思ったが上手く論文を見つけられなかったので自分で試してみた。想定していたことがそれなりにできていそうで、こういう「これぐらいの大きさのニューラルネットワークでこれくらいのことができそう」という感覚はできるだけ持っておきたいところ。

 整数の足し算など、おそらくTransformerで大量に学習させた言語モデルに突っ込めばできるものだとは思うが、数字をトークン化して入力している事自体がある程度のヒントであるとも思える。また、ごちゃごちゃしてtex形式だと大変なことになる数式や、画像(グラフなど)と文章がセットになっている文書などを読む場合とかも考えると、文章自体を画像として捉える手法の重要性もそれなりにあるのかなと想像しているところではある。

 今後、タスクの難易度を上げる方法として

  • 引き算、掛け算を追加する(割り算は整数に閉じないからどうだろう?)
  • 値の上限をもっと増やす
  • 項の数を増やす

などがパッと思い浮かぶ。

 タスクの難易度上昇なのかは明確にはわからないが、他にもデータ拡張として

  • 数式表示位置をランダムに少しズラす
  • 数式のフォント、フォントサイズをランダムに変える

といったことをすることも考えられる。

 精度向上のためには単純にもっと巨大モデルを用いても良いので、そのあたりでどこまでできるものなのかは少し検証してみたい。