こしあん
2019-02-18

TPUでも大きなバッチサイズに対して精度を出す


TPUでは大きなバッチサイズを適用することが訓練の高速化につながりますが、これは精度と引き換えになることがあります。大きなバッチサイズでも精度を出す方法を論文をもとに調べてみました。

背景

Qiitaに書いたGoogle Brainの論文「学習率を落とすな、バッチサイズを増やせ」を読むの続き。自分でも調べてみました。

実験

CIFAR-10で実験、10層のレイヤーのネットワークを作り以下の条件で調べる。オプティマイザーはモメンタム(特に断りなければ係数0.9)でGoogle ColabのTPUで調べました。すべて250エポック訓練させます。

  1. バッチサイズを128、初期学習率を0.1として、100、150、200エポックで学習率を1/5ずつ減衰(baseline)。
  2. バッチサイズを128、初期学習率を0.1として、100、150、200エポックでバッチサイズを5倍ずつ増やす。128→640→3200→16000となる(increase batch size
  3. バッチサイズを640、初期学習率を0.1、モメンタム係数を0.98として、100、150、200エポックで学習率を1/5ずつ減衰(increase momentum
  4. バッチサイズを640、初期学習率を0.5として、100、150、200エポックで学習率を1/5ずつ減衰

理論的には、ノイズスケールはすべて一緒で、

  • 1と2の学習曲線は一緒になるはず
  • 3は1,2と比べると、モメンタムの係数を増やしているので若干テスト精度が落ちるはず
  • 4は3との比較用で、仮に「初期の学習率」を上げた場合、精度の落ち方は3と比べてどのぐらいなのか

ということを確認していく。

コード

結果

縦軸はValidationのエラーレートで、横軸はエポック数です

考察

  • 1と2の学習曲線は一緒?→一緒、つまり学習率を下げることとバッチサイズを上げることは同じ
  • モメンタムの係数を上げた3場合は?→だいたい学習曲線は一緒に見えるが、やはりテスト精度は下がっている
  • モメンタム係数ではなく学習率を上げると?(4の場合)→テスト精度の落ち方がややマイルドになる。ただしこれは元の学習率によりけりなので、必ずしもこうなるとは限らない。

ほぼ論文の実験の通りの結果になりました。よりわかったことは、バッチサイズを上げる前提でいじる優先順位は、初期学習率>>モメンタム係数で、初期学習率を上げるとテスト精度が大きく下がってしまうケースではモメンタム係数を上げてみるというところではないでしょうか。

Related Posts

Kerasでモデルのsummaryをテキストとして保存する方法... Kerasで「plot_modelを使えばモデルの可視化ができるが、GraphViz入れないといけなかったり、セットアップが面倒くさい!model.summary()のテキストをファイル保存で十分だ!」という場合に使えるテクニックです。 summary()のprint_fn引数を使う sum...
Affinity LossをCIFAR-10で精度を求めてひたすら頑張った話... 不均衡データに対して有効性があると言われている損失関数「Affinity loss」をCIFAR-10で精度を出すためにひたすら頑張った、というひたすら泥臭い話。条件10個試したらやっと精度を出すためのコツみたいなのが見えてきました。 結論 長いので先に結論から。CIFAR-10をAffini...
KerasでSTL-10を扱う方法 スタンフォード大学が公開している画像データセットに「STL-10」というのがあります。これはCIFAR-10に似た形式ながら、高画質かつ教師なし学習にも応用できる便利なデータです。PyTorchではtorchvisionを使うと簡単に読み込めるのですが、Kerasではデフォルトで容易されていないの...
Kerasに組み込まれているInceptionV3の実装... InceptionV3のsummary __________________________________________________________________________________________________ Layer (type) ...
BrestCancerデータセットをCNNで分類する 構造化データを畳み込みニューラルネットワーク(CNN)で分析することを考えます。BrestCancerデータセットはScikit-learnに用意されている、乳がんが良性か悪性かの2種類を分類する典型的な構造化データです。サンプル数569、データの次元30の典型的な構造化データです。 なぜ畳み込...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です