こしあん
2019-02-18

TPUでも大きなバッチサイズに対して精度を出す

Pocket
LINEで送る


TPUでは大きなバッチサイズを適用することが訓練の高速化につながりますが、これは精度と引き換えになることがあります。大きなバッチサイズでも精度を出す方法を論文をもとに調べてみました。

背景

Qiitaに書いたGoogle Brainの論文「学習率を落とすな、バッチサイズを増やせ」を読むの続き。自分でも調べてみました。

実験

CIFAR-10で実験、10層のレイヤーのネットワークを作り以下の条件で調べる。オプティマイザーはモメンタム(特に断りなければ係数0.9)でGoogle ColabのTPUで調べました。すべて250エポック訓練させます。

  1. バッチサイズを128、初期学習率を0.1として、100、150、200エポックで学習率を1/5ずつ減衰(baseline)。
  2. バッチサイズを128、初期学習率を0.1として、100、150、200エポックでバッチサイズを5倍ずつ増やす。128→640→3200→16000となる(increase batch size
  3. バッチサイズを640、初期学習率を0.1、モメンタム係数を0.98として、100、150、200エポックで学習率を1/5ずつ減衰(increase momentum
  4. バッチサイズを640、初期学習率を0.5として、100、150、200エポックで学習率を1/5ずつ減衰

理論的には、ノイズスケールはすべて一緒で、

  • 1と2の学習曲線は一緒になるはず
  • 3は1,2と比べると、モメンタムの係数を増やしているので若干テスト精度が落ちるはず
  • 4は3との比較用で、仮に「初期の学習率」を上げた場合、精度の落ち方は3と比べてどのぐらいなのか

ということを確認していく。

コード

結果

縦軸はValidationのエラーレートで、横軸はエポック数です

考察

  • 1と2の学習曲線は一緒?→一緒、つまり学習率を下げることとバッチサイズを上げることは同じ
  • モメンタムの係数を上げた3場合は?→だいたい学習曲線は一緒に見えるが、やはりテスト精度は下がっている
  • モメンタム係数ではなく学習率を上げると?(4の場合)→テスト精度の落ち方がややマイルドになる。ただしこれは元の学習率によりけりなので、必ずしもこうなるとは限らない。

ほぼ論文の実験の通りの結果になりました。よりわかったことは、バッチサイズを上げる前提でいじる優先順位は、初期学習率>>モメンタム係数で、初期学習率を上げるとテスト精度が大きく下がってしまうケースではモメンタム係数を上げてみるというところではないでしょうか。

Related Posts

ColabのTPUでNASNet Largeを訓練しようとして失敗した話... ColabのTPUはとてもメモリ容量が大きく、計算が速いのでモデルのパラメーターを多くしてもそこまでメモリオーバーor遅くなりません。ただし、あまりにモデルが深すぎると訓練の初期設定で失敗することがあります。NASNet Largeを訓練しようとして発生しました。これを見ていきます。 CIFAR...
WarmupとData Augmentationのバッチサイズ別の精度低下について... 大きいバッチサイズで訓練する際は、バッチサイズの増加にともなう精度低下が深刻になります。この精度低下を抑制することはできるのですが、例えばData Augmentationのようなデータ増強・正則化による精度向上とは何が違うのでしょうか。それを調べてみました。 きっかけ この記事を書いたときに...
データのお気持ちを考えながらData Augmentationする... Data Augmentationの「なぜ?」に注目しながら、エラー分析をしてCIFAR-10の精度向上を目指します。その結果、オレオレAugmentationながら、Wide ResNetで97.3%という、Auto Augmentとほぼ同じ(-0.1%)精度を出すことができました。 (※すご...
Affinity LossをCIFAR-10で精度を求めてひたすら頑張った話... 不均衡データに対して有効性があると言われている損失関数「Affinity loss」をCIFAR-10で精度を出すためにひたすら頑張った、というひたすら泥臭い話。条件10個試したらやっと精度を出すためのコツみたいなのが見えてきました。 結論 長いので先に結論から。CIFAR-10をAffini...
PyTorchで行列(テンソル)積としてConv2dを使う... PyTorchではmatmulの挙動が特殊なので、思った通りにテンソル積が取れないことがあります。この記事では、基本的な畳み込み演算である「Conv2D」を使い、Numpyのドット積相当の演算を行うという方法を解説します。 はじめに PyTorchの変態コーディング技術です。多分。 画像のテ...
Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です