Affinity LossをCIFAR-10で精度を求めてひたすら頑張った話
不均衡データに対して有効性があると言われている損失関数「Affinity loss」をCIFAR-10で精度を出すためにひたすら頑張った、というひたすら泥臭い話。条件10個試したらやっと精度を出すためのコツみたいなのが見えてきました。
目次
結論
長いので先に結論から。CIFAR-10をAffinity lossで結果をよくしたい場合、
- Data Augmentationすると効果がブーストされ、確認しやすい
- F1スコアのクラス間の集計方法(Micro / Macro)を比較検討しても良い。Macroだと結果的に不均衡側のクラスに高い比重がかかってしまい、それが見た目の数字を悪くしていることもある。
- 出力層の直前のレイヤー数が多ければ、Affinity Lossの直前のBatch Normalizationを切っても良い
- SoftmaxとAffinity Lossを同時に訓練すると若干安定性が良くなるが、効果としては地味
- ネットワークのアーキテクチャの変更はハイパーパラメータの再チューニングが必要で、上の2つに比べると結果を保証しにくい。ただし、大きいネットワークのほうが中心の数(m)を増やしたときの効果は大きそう
背景
私が書いたQiitaのこちらの記事からの続きになります。Qiitaの記事だと、「MNISTはうまく行ったけどCIFAR-10だと微妙だったよね」という結論になりました。
ざっくりいうと、Affinity lossとは「不均衡データに対して有効性のある損失関数」です。元ネタの論文と自分の実装はこちらになります。
元の論文:Munawar Hayat, Salman Khan, Waqas Zamir, Jianbing Shen, Ling Shao. Max-margin Class Imbalanced Learning with Gaussian Affinity. 2019. https://arxiv.org/abs/1901.07711
自分の実装:https://github.com/koshian2/affinity-loss
試した条件
以下の条件でひたすら試しました。特に断りがなければ、ネットワークは10層のレイヤーのCNNを使っています。ハイパーパラメータはOptunaを使い、$m=1,5 \sigma=90, \lambda=0.43$で試しています。このハイパーパラメータは10層のレイヤーのCNNを使ってチューニングしています。
10層レイヤーの構成は以下の通りです。
レイヤー | チャンネル数 | カーネル/Stride | 繰り返し |
---|---|---|---|
Conv | 64 | 3 | 3 |
AvgPool | – | 2 | 1 |
Conv | 128 | 3 | 3 |
AvgPool | – | 2 | 1 |
Conv | 256 | 3 | 3 |
Global AvgPool | – | – | 1 |
Batch Norm | – | – | 1 |
Affinity loss | 10 | – | 1 |
他の条件は以前と同じです。
- バッチサイズは640
- Weight Decayはなし
- MNISTの場合と同様に、テストデータを基準に1クラスあたりのサンプル数を500、200、100、50、20、10と変化させる。訓練データのサンプル数はテストデータの各5倍。
- TensorFlow/KerasのTPUで訓練する
- すべてのケースについて5回試行した
- Macro F1スコアで比較
試行ごとに変えた条件は次の通りです。
- 長く訓練してみる(100epoch→250epoch)
- ネットワークをMobileNetに変更する
- ネットワークをResNet50に変更する
- ネットワークを10層から3層に変更する
- バッチサイズを640→128に変更する(学習率も適切に合わせる)
- Affinity Lossの直前のBatchNormalizationを削る
- SoftmaxとAffinity lossを混合して訓練する
- Data Augmentationを加える
- 評価尺度をMacro F1からMicro F1にする
- Micro F1で評価し、Data Augmentationを加える
表の見方
ヘッダーの意味は以下の通りです。
- Affinity m=1/5 : Affinity lossで訓練したときの評価。行の要素はテストデータにおける1クラスあたりの不均衡クラスのサンプル数。例えば500のときは、半分の均衡クラスは1クラス1000サンプル、残りの半分の不均衡クラスは1クラス500サンプルとなる。この値を500~10まで変化させる。最小値の10のときは1000:10、つまり100倍の不均衡が発生する。訓練データは総サンプル数の比にしたがって5倍の値となる。例えばテストデータにおいて、1000:500なら、訓練データにおける1クラスのサンプル数は5000:2500となる。
- Acc Median:5回試行したときのサンプル単位の精度(Accuracy)の中央値
- Acc Max:5回試行したときのサンプル単位の精度(Accuracy)の最大値
- Diff Acc Median:Softmax、AffinityLossを各5回試行したときの、精度(Accuracy)の中央値同士の差。Affinity Loss-Softmaxで計算し、この値が高いほどAffinity Lossの効きが強いことになる。
- Diff Acc Max:Softmax、AffinityLossを各5回試行したときの、精度(Accuracy)の最大値同士の差。Affinity Loss-Softmaxで計算し、この値が高いほどAffinity Lossの効きが強いことになる。
- F1 Median:5回試行したときのF1スコアの中央値。特に断りがなければMacro F1スコアでクラス単位のF1スコアの平均値を評価する。
- F1 Max:5回試行したときのF1スコアの最大値。特に断りがなければMacro F1スコアでクラス単位のF1スコアの平均値を評価する。
- Diff F1 Median:Softmax、AffinityLossを各5回試行したときの、F1スコアの中央値同士の差。特に断りがなければいずれもMacro F1スコア。Affinity Loss-Softmaxで計算し、この値が高いほどAffinity Lossの効きが強いことになる。
- Diff F1 Max:Softmax、AffinityLossを各5回試行したときの、F1スコアの最大値同士の差。特に断りがなければいずれもMacro F1スコア。Affinity Loss-Softmaxで計算し、この値が高いほどAffinity Lossの効きが強いことになる。
ちょっと評価項目が多すぎるので、最後に「m=1のDiff Acc Median、m=5のDiff Acc Median、m=1のDiff F1 Median、m=5のDiff F1 Median」(それぞれの行ごとの和)をまとめて表記します。
ちなみにSoftmaxとの比較は各条件で再度Softmaxで訓練させ、その値をもとに計算しています。
1.長く訓練してみる(100epoch→250epoch)
100epoch→250epochに変更してみます。100epochの場合は、50エポック、80エポックで学習率をそれぞれ1/5にしました。しかし、250epochの場合は、100、150、200エポックでそれぞれ学習率を1/5にします。初期学習率はすべて5e-3です。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8751 | 0.8778 | 0.30% | 0.16% | 0.8679 | 0.8724 | 0.19% | 0.21% |
200 | 0.871 | 0.8762 | 0.35% | 0.84% | 0.8214 | 0.8234 | -0.44% | -0.70% |
100 | 0.8703 | 0.8759 | 0.29% | 0.13% | 0.7654 | 0.774 | -1.49% | -1.22% |
50 | 0.8896 | 0.8953 | 1.69% | 2.07% | 0.7275 | 0.7474 | 2.58% | 2.88% |
20 | 0.8927 | 0.8964 | 0.38% | 0.41% | 0.638 | 0.6449 | -1.88% | -3.57% |
10 | 0.8997 | 0.9041 | 0.59% | 0.51% | 0.5751 | 0.6096 | -0.16% | 1.86% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8717 | 0.878 | -0.04% | 0.18% | 0.8655 | 0.8728 | -0.05% | 0.25% |
200 | 0.8706 | 0.876 | 0.31% | 0.82% | 0.8165 | 0.8272 | -0.93% | -0.32% |
100 | 0.8723 | 0.8764 | 0.49% | 0.18% | 0.7687 | 0.7774 | -1.16% | -0.88% |
50 | 0.8866 | 0.8939 | 1.39% | 1.93% | 0.7131 | 0.7159 | 1.14% | -0.27% |
20 | 0.8921 | 0.8978 | 0.32% | 0.55% | 0.6158 | 0.6235 | -4.10% | -5.71% |
10 | 0.895 | 0.9004 | 0.12% | 0.14% | 0.5521 | 0.576 | -2.46% | -1.50% |
3.60% / 2.59% / -1.20% / -7.56%
(m=1のDiff Acc Median / m=5のDiff Acc Median / m=1のDiff F1 Median / m=5のDiff F1 Medianのそれぞれの行ごとの和)
サンプル単位の精度では全般的に良くても、クラス単位のF1スコアでは悪いという結果になりました。これは以前の結果から変わっていません。
2.ネットワークをMobileNetに変更する
Affinity lossは画像埋め込みを考えているから単純すぎるネットワークでは良くならないのでは?と考え、ネットワークをMobileNetに変更してみました。ただしハイパーパラメータの再チューニングはしていません。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8329 | 0.8372 | -0.67% | -0.48% | 0.8232 | 0.8288 | -0.82% | -0.47% |
200 | 0.824 | 0.8303 | -0.15% | -0.04% | 0.759 | 0.77 | -0.86% | -0.56% |
100 | 0.8309 | 0.8404 | 0.21% | -0.09% | 0.6993 | 0.7198 | -0.57% | -0.65% |
50 | 0.8462 | 0.8533 | 0.53% | 0.90% | 0.6403 | 0.6568 | -1.63% | -1.96% |
20 | 0.8491 | 0.8559 | -0.39% | -0.22% | 0.5105 | 0.5451 | -6.22% | -4.02% |
10 | 0.862 | 0.8641 | 0.13% | -0.20% | 0.4777 | 0.5141 | -6.08% | -5.08% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8405 | 0.8457 | 0.09% | 0.37% | 0.831 | 0.8398 | -0.04% | 0.63% |
200 | 0.8214 | 0.8295 | -0.41% | -0.12% | 0.7668 | 0.7792 | -0.08% | 0.36% |
100 | 0.8253 | 0.8311 | -0.35% | -1.02% | 0.6964 | 0.7173 | -0.86% | -0.90% |
50 | 0.8528 | 0.8627 | 1.19% | 1.84% | 0.6482 | 0.6651 | -0.84% | -1.13% |
20 | 0.8476 | 0.8582 | -0.54% | 0.01% | 0.5303 | 0.5408 | -4.24% | -4.45% |
10 | 0.856 | 0.8607 | -0.47% | -0.54% | 0.4657 | 0.5123 | -7.28% | -5.26% |
-0.34% / -0.49% / -16.18% / -13.34%
かえって悪くなってしまいました。MobileNetのハイパラの最適解はおそらく別の場所にあるからまあ当然ですよね。
3.ネットワークをResNet50に変更する
もっと深くしてResNet50にしてみます。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.6761 | 0.7119 | -5.40% | -3.49% | 0.6164 | 0.696 | -10.30% | -4.07% |
200 | 0.5427 | 0.5577 | -15.90% | -15.31% | 0.2716 | 0.2937 | -36.11% | -35.01% |
100 | 0.5328 | 0.5826 | -20.32% | -16.92% | 0.2551 | 0.2747 | -35.31% | -33.98% |
50 | 0.5894 | 0.629 | -14.23% | -12.37% | 0.2851 | 0.294 | -23.20% | -23.90% |
20 | 0.4755 | 0.523 | -24.80% | -27.12% | 0.1825 | 0.2342 | -25.35% | -29.36% |
10 | 0.5772 | 0.645 | -10.78% | -13.11% | 0.2663 | 0.295 | -13.18% | -16.96% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.7321 | 0.7483 | 0.20% | 0.15% | 0.7204 | 0.7387 | 0.10% | 0.20% |
200 | 0.729 | 0.7446 | 2.73% | 3.38% | 0.6503 | 0.6777 | 1.76% | 3.39% |
100 | 0.7522 | 0.7573 | 1.62% | 0.55% | 0.6072 | 0.622 | -0.10% | 0.75% |
50 | 0.7738 | 0.7905 | 4.21% | 3.78% | 0.5565 | 0.5664 | 3.94% | 3.34% |
20 | 0.7896 | 0.8139 | 6.61% | 1.97% | 0.454 | 0.4788 | 1.80% | -4.90% |
10 | 0.7841 | 0.8101 | 9.91% | 3.40% | 0.4033 | 0.4146 | 0.52% | -5.00% |
-91.43% / 25.28% / -143.45% / 8.02%
なんかものすごい極端な結果になってしまいました。ResNet50だとそもそもSoftmaxで精度が出ていないですね。もうちょっと長く訓練したり、学習率やバッチサイズのチューニングをしないといけないと思います。
ただこの結果だけで見ると、m=1では超絶悪いものの、m=5では相対的に良い(Softmaxがかなり悪いから絶対的にいいわけではない)ということになりました。ResNet50みたいに大きなネットワークではm(中心の数)を増やすことによるメリットが大きいのかもしれません。
4. ネットワークを10層から3層に変更する
今まではネットワークを大きく・深くしていましたが、今度はネットワークを小さく・浅くしてみます。どのように変わるのでしょうか?
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8244 | 0.8321 | -1.02% | -0.57% | 0.8157 | 0.8235 | -1.15% | -0.63% |
200 | 0.8214 | 0.8311 | -0.79% | -0.12% | 0.7526 | 0.7644 | -2.30% | -1.66% |
100 | 0.83 | 0.8366 | -0.97% | -0.36% | 0.6807 | 0.6921 | -5.27% | -4.96% |
50 | 0.8476 | 0.8533 | -0.11% | -0.21% | 0.601 | 0.611 | -6.29% | -7.38% |
20 | 0.8578 | 0.861 | 0.01% | -0.01% | 0.4597 | 0.4648 | -12.84% | -15.05% |
10 | 0.8679 | 0.8706 | -0.02% | 0.04% | 0.4365 | 0.4375 | -14.23% | -14.38% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8255 | 0.8273 | -0.91% | -1.05% | 0.8171 | 0.8186 | -1.01% | -1.12% |
200 | 0.8252 | 0.8269 | -0.41% | -0.54% | 0.7574 | 0.7652 | -1.82% | -1.58% |
100 | 0.825 | 0.8336 | -1.47% | -0.66% | 0.6604 | 0.6722 | -7.30% | -6.95% |
50 | 0.8443 | 0.8507 | -0.44% | -0.47% | 0.5745 | 0.5861 | -8.94% | -9.87% |
20 | 0.8491 | 0.8503 | -0.86% | -1.08% | 0.4318 | 0.4376 | -15.63% | -17.77% |
10 | 0.8561 | 0.859 | -1.20% | -1.12% | 0.4301 | 0.432 | -14.87% | -14.93% |
-2.90% / -5.29% / -42.08% / -49.57%
かえって悪くなってしまいました。浅すぎるネットワークはダメみたいですね。
5. バッチサイズを640→128に変更する(学習率も適切に合わせる)
変更前は初期学習率5e-3であったのに対し、変更後は初期学習率1e-3に変更しました。バッチサイズと一緒に学習率も落としたのは、Don’t Decay the Learning Rate, Increase the Batch Sizeという論文によります。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.887 | 0.8915 | 0.00% | -0.08% | 0.8816 | 0.8863 | 0.02% | 0.12% |
200 | 0.88 | 0.8813 | -0.65% | -0.87% | 0.84 | 0.8443 | -0.77% | -0.58% |
100 | 0.8833 | 0.8845 | -0.48% | -0.69% | 0.7909 | 0.7945 | -2.13% | -1.94% |
50 | 0.8895 | 0.8948 | -0.49% | -0.29% | 0.7335 | 0.7382 | -2.18% | -2.35% |
20 | 0.9015 | 0.9023 | -0.41% | -0.47% | 0.6515 | 0.6629 | -3.85% | -4.25% |
10 | 0.9106 | 0.9107 | -0.22% | -0.53% | 0.5916 | 0.6101 | -6.64% | -5.91% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8775 | 0.8799 | -0.95% | -1.24% | 0.8712 | 0.8743 | -1.02% | -1.08% |
200 | 0.8751 | 0.881 | -1.14% | -0.90% | 0.8351 | 0.8375 | -1.26% | -1.26% |
100 | 0.8803 | 0.8818 | -0.78% | -0.96% | 0.783 | 0.792 | -2.92% | -2.19% |
50 | 0.8897 | 0.8902 | -0.47% | -0.75% | 0.723 | 0.733 | -3.23% | -2.87% |
20 | 0.8997 | 0.9028 | -0.59% | -0.42% | 0.6386 | 0.6539 | -5.14% | -5.15% |
10 | 0.9026 | 0.909 | -1.02% | -0.70% | 0.5879 | 0.6161 | -7.01% | -5.31% |
-2.25% / -4.95% / -15.55% / -20.58%
今度はSoftmaxが精度を出しすぎるために悪くなってしまいました。Softmaxに打ち勝つにはやはりハイパーパラメータを再チューニングしないといけないようです。
6. Affinity Lossの直前のBatchNormalizationを削る
個人的にこれは本命ではないかと思った条件です。出力層直前のBatchNormは潜在次元が低次元なら必要でも、256次元ぐらいになると実はいらないのでは?となるからです。ちなみにSoftmaxの場合はBatchNormは常に入れていません。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8712 | 0.8742 | 0.42% | 0.29% | 0.8643 | 0.8674 | 0.44% | 0.50% |
200 | 0.8689 | 0.8728 | 0.72% | 0.38% | 0.8211 | 0.8247 | 0.27% | 0.12% |
100 | 0.879 | 0.881 | 1.58% | 1.62% | 0.7836 | 0.792 | 1.99% | 2.77% |
50 | 0.8762 | 0.8832 | 0.58% | 1.11% | 0.6763 | 0.6907 | -3.36% | -2.38% |
20 | 0.8931 | 0.8983 | 0.99% | 1.13% | 0.6253 | 0.6337 | -1.43% | -3.67% |
10 | 0.8991 | 0.9012 | 1.10% | 0.60% | 0.5885 | 0.6056 | 0.33% | -1.72% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8761 | 0.8793 | 0.91% | 0.80% | 0.8699 | 0.8717 | 1.00% | 0.93% |
200 | 0.8688 | 0.8727 | 0.71% | 0.37% | 0.8181 | 0.8222 | -0.03% | -0.13% |
100 | 0.8712 | 0.8758 | 0.80% | 1.10% | 0.7736 | 0.7778 | 0.99% | 1.35% |
50 | 0.8783 | 0.8806 | 0.79% | 0.85% | 0.6874 | 0.6991 | -2.25% | -1.54% |
20 | 0.8922 | 0.8941 | 0.90% | 0.71% | 0.6195 | 0.6482 | -2.01% | -2.22% |
10 | 0.8954 | 0.8978 | 0.73% | 0.26% | 0.534 | 0.5831 | -5.12% | -3.97% |
5.39% / 4.84% / -1.76% / -7.42%
精度面では今までは一番よいのではないでしょうか。出力層の直前にBatchNormを入れるのは、潜在次元のプロットが超球面上になるという縛りを置くのとほぼ同じなので、初期値ガチャに対する安定性は上がりますが(F1の中央値はこちらのほうが悪いのはそういうことかと)、うまく行ったときに精度が出やすいのはこちらだと思います。
7. SoftmaxとAffinity lossを混合して訓練する
これも悪くないなと思ったもので、SoftmaxとAffinity lossの合計21チャンネルを出力層(Affinity lossはdiversity regularizationを入れて実装しているので11チャンネルになります)として、推定時にはAffinity lossを使います。ここでのSoftmaxの位置づけは、あくまでニューラルネットワークが特徴量を得やすくするための学習の加速というものです。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8732 | 0.8771 | 0.81% | 0.64% | 0.865 | 0.8691 | 0.55% | 0.62% |
200 | 0.8644 | 0.8723 | 0.31% | 0.96% | 0.8163 | 0.8261 | 0.12% | 0.20% |
100 | 0.8702 | 0.8784 | 0.45% | 1.03% | 0.7628 | 0.7825 | -0.77% | 0.97% |
50 | 0.8802 | 0.8835 | 1.36% | 0.81% | 0.7028 | 0.7227 | 0.13% | 1.39% |
20 | 0.8889 | 0.893 | 0.57% | 0.68% | 0.6268 | 0.6466 | 0.25% | 1.24% |
10 | 0.8943 | 0.899 | 0.86% | 1.15% | 0.5891 | 0.6134 | 1.52% | 0.78% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8725 | 0.876 | 0.74% | 0.53% | 0.8654 | 0.8696 | 0.59% | 0.67% |
200 | 0.8676 | 0.8705 | 0.63% | 0.78% | 0.821 | 0.8227 | 0.59% | -0.14% |
100 | 0.869 | 0.8711 | 0.33% | 0.30% | 0.7655 | 0.7706 | -0.50% | -0.22% |
50 | 0.8781 | 0.8833 | 1.15% | 0.79% | 0.7031 | 0.7137 | 0.16% | 0.49% |
20 | 0.8901 | 0.8931 | 0.69% | 0.69% | 0.6069 | 0.6341 | -1.74% | -0.01% |
10 | 0.8959 | 0.9002 | 1.02% | 1.27% | 0.572 | 0.5984 | -0.19% | -0.72% |
4.36% / 4.56% / 1.80% / -1.09%
悪くないですね。F1スコアの中央値での評価は一番良かったと思います。
8. Data Augmentationを加える
今度はData Augmentationを入れてみます。水平反転+4ピクセルの上下左右のクロップといういわゆる「Standard Data Augmentation」ですね。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.883 | 0.8907 | 0.65% | 0.82% | 0.8789 | 0.884 | 0.89% | 0.92% |
200 | 0.8766 | 0.8809 | 0.94% | 0.89% | 0.833 | 0.84 | 0.65% | 0.21% |
100 | 0.8685 | 0.8769 | 0.19% | 0.70% | 0.7699 | 0.7912 | -0.33% | 1.41% |
50 | 0.8972 | 0.8983 | 2.79% | 2.04% | 0.7366 | 0.7567 | 4.21% | 5.56% |
20 | 0.8909 | 0.8984 | 1.28% | 0.81% | 0.6453 | 0.6494 | 5.35% | 2.62% |
10 | 0.8946 | 0.9016 | 0.65% | 0.44% | 0.5417 | 0.5571 | -1.24% | -5.25% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8865 | 0.8948 | 1.00% | 1.23% | 0.8816 | 0.8906 | 1.16% | 1.58% |
200 | 0.8794 | 0.8833 | 1.22% | 1.13% | 0.8349 | 0.8397 | 0.84% | 0.18% |
100 | 0.8771 | 0.8862 | 1.05% | 1.63% | 0.7795 | 0.7975 | 0.63% | 2.04% |
50 | 0.8969 | 0.901 | 2.76% | 2.31% | 0.7225 | 0.7507 | 2.80% | 4.96% |
20 | 0.894 | 0.901 | 1.59% | 1.07% | 0.6062 | 0.6275 | 1.44% | 0.43% |
10 | 0.8998 | 0.9024 | 1.17% | 0.52% | 0.5164 | 0.5494 | -3.77% | -6.02% |
6.50% / 8.79% / 9.53% / 3.10%
今までで一番良かったです。10のケースはかなり落ちていますが、Data Augmentationをすると良くなるということでしょうか。
9. 評価尺度をMacro F1からMicro F1にする
評価尺度を変えたら良くなったというのはちょっとずるいかなと思って避けていたのですが、どうもMacro F1で測るのが正しくないような気がしてきたので、Micro F1に変えてみました。
Macro F1とMicro F1の違いは、Macro F1がクラス間のF1スコアを計算するのに対して、Micro F1はサンプル間のF1スコアを計算します。Macroの場合は、クラス間での平均計算をどのクラス計算でも均一にしてしまうので、不均衡なクラスのサンプルに対して大きな重みがかかってしまうのです。例えば次のような例があったとしましょう。これはSoftmaxの例です。
大きなクラス:1万サンプル 精度85% F1スコア80%
小さなクラス:100サンプル 精度80% F1スコア70%
今Affinity Lossを使ったら、評価が次のように変わったとしましょう。
大きなクラス:1万サンプル 精度87% F1スコア82%
小さなクラス:100サンプル 精度78% F1スコア65%
全体の精度は、サンプル単位で計算するので、
【変更前:Softmax】
精度:(85×10000+80×100)÷10100=84.95%
MacroF1:(80+70)÷2=75.00%
【変更後:Affinity Loss】
精度:(87×10000+78×100)÷10100=86.83%
MacroF1:(82+65)÷2=73.50%
精度は上がっているのに、MacroF1は下がっているのがわかるでしょうか。MacroF1だとあまり結果が良くないのはおそらくこういう事情があるのだと思います。そこで、F1スコアも精度と同じくサンプル単位で考える必要があるのではないでしょうか。具体的は、
- 変更前Micro F1:(80×10000+70×100)÷10100=79.90%
- 変更後Micro F1:(82×10000+65×100)÷10100=81.83%
このようにMicro F1で見ると実は良かったというケースも考えられます。これを見ていきます。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8715 | 0.8822 | 0.82% | 1.49% | 0.8719 | 0.8811 | 0.88% | 1.37% |
200 | 0.8676 | 0.8705 | 0.53% | 0.43% | 0.8652 | 0.8665 | 0.25% | 0.05% |
100 | 0.8668 | 0.8714 | 0.09% | 0.19% | 0.8652 | 0.8697 | -0.03% | -0.04% |
50 | 0.8872 | 0.8908 | 1.83% | 1.48% | 0.8832 | 0.8851 | 1.47% | 0.88% |
20 | 0.885 | 0.8888 | 0.00% | 0.22% | 0.8848 | 0.8891 | -0.06% | 0.19% |
10 | 0.8937 | 0.8955 | 0.27% | -0.53% | 0.893 | 0.8948 | 0.20% | -0.63% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.869 | 0.8722 | 0.57% | 0.49% | 0.8693 | 0.8714 | 0.62% | 0.40% |
200 | 0.8644 | 0.8692 | 0.21% | 0.30% | 0.861 | 0.8657 | -0.17% | -0.03% |
100 | 0.8722 | 0.8759 | 0.63% | 0.64% | 0.8697 | 0.8755 | 0.42% | 0.54% |
50 | 0.8811 | 0.8877 | 1.22% | 1.17% | 0.879 | 0.8834 | 1.05% | 0.71% |
20 | 0.8829 | 0.8938 | -0.21% | 0.72% | 0.883 | 0.8932 | -0.24% | 0.60% |
10 | 0.8922 | 0.8965 | 0.12% | -0.43% | 0.8922 | 0.8966 | 0.12% | -0.45% |
3.54% / 2.54% / 2.71% / 1.80%
確かに安定して数字が出るようになりました。多少悪くなっているケースもありますが、これなら悪くはないですね。
10. Micro F1で評価し、Data Augmentationを加える
最後にData Augmentationを加え、Micro F1で評価してみましょう。
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.889 | 0.8907 | 0.47% | 0.36% | 0.8883 | 0.8909 | 0.45% | 0.30% |
200 | 0.8754 | 0.8794 | 1.24% | -0.29% | 0.8733 | 0.8788 | 1.03% | -0.40% |
100 | 0.8715 | 0.8781 | -0.10% | 0.02% | 0.8706 | 0.8765 | -0.15% | -0.01% |
50 | 0.8964 | 0.9023 | 2.85% | 3.03% | 0.8899 | 0.8981 | 2.29% | 2.69% |
20 | 0.8888 | 0.8939 | 0.81% | 0.56% | 0.8881 | 0.8944 | 0.70% | 0.63% |
10 | 0.8932 | 0.9056 | 0.02% | 0.54% | 0.893 | 0.9055 | 0.00% | 0.49% |
Affinity m=1 | Acc Median | Acc Max | Diff Acc Median | Diff Acc Max | F1 Median | F1 Max | Diff F1 Median | Diff F1 Max |
---|---|---|---|---|---|---|---|---|
500 | 0.8881 | 0.8897 | 0.38% | 0.26% | 0.887 | 0.8891 | 0.32% | 0.12% |
200 | 0.8775 | 0.8781 | 1.45% | -0.42% | 0.8752 | 0.8773 | 1.22% | -0.55% |
100 | 0.8765 | 0.882 | 0.40% | 0.41% | 0.8752 | 0.8805 | 0.31% | 0.39% |
50 | 0.8986 | 0.9003 | 3.07% | 2.83% | 0.8939 | 0.895 | 2.69% | 2.38% |
20 | 0.8943 | 0.8971 | 1.36% | 0.88% | 0.8938 | 0.8972 | 1.27% | 0.91% |
10 | 0.898 | 0.9069 | 0.50% | 0.67% | 0.8976 | 0.9079 | 0.46% | 0.73% |
5.29% / 7.16% / 4.32% / 6.27%
確かにData Augmentationすると良いですね。Micro F1でもAffinity Lossを使ったほうが良くなっているのが確認できます。
結果まとめ
各ケースの4つの値:「m=1のDiff Acc Median、m=5のDiff Acc Median、m=1のDiff F1 Median、m=5のDiff F1 Median(それぞれの行ごとの和)」をまとめてみます。
ケース | 条件 | m=1のDiff Acc Medianの和 | m=5のDiff Acc Medianの和 | m=1のDiff F1 Medianの和 | m=5のDiff F1 Medianの和 |
---|---|---|---|---|---|
1 | 長く訓練してみる(100epoch→250epoch) | 3.60% | 2.59% | -1.20% | -7.56% |
2 | ネットワークをMobileNetに変更する | -0.34% | -0.49% | -16.18% | -13.34% |
3 | ネットワークをResNet50に変更する | -91.43% | 25.28% | -143.45% | 8.02% |
4 | ネットワークを10層から3層に変更する | -2.90% | -5.29% | -42.08% | -49.57% |
5 | バッチサイズを640→128に変更する(学習率も適切に合わせる) | -2.25% | -4.95% | -15.55% | -20.58% |
6 | Affinity Lossの直前のBatchNormalizationを削る | 5.39% | 4.84% | -1.76% | -7.42% |
7 | SoftmaxとAffinity lossを混合して訓練する | 4.36% | 4.56% | 1.80% | -1.09% |
8 | Data Augmentationを加える | 6.50% | 8.79% | 9.53% | 3.10% |
9 | 評価尺度をMacro F1からMicro F1にする | 3.54% | 2.54% | 2.71% | 1.80% |
10 | Micro F1で評価し、Data Augmentationを加える | 5.29% | 7.16% | 4.32% | 6.27% |
となりました。これからわかることは次の通りです。
- Affinity LossはData Augmentationとセットで使うと効果を発揮しやすい。SoftmaxよりもData Augmentationの効果がブーストされる。
- 一見値が良くなくてもF1スコアなどのクラス間の集計方法(Macro→Micro)を変えると実は良い結果が出ていたということがわかる
- 出力層の直前のレイヤー数が多ければ、Affinity Lossの直前のBatch Normalizationを切っても良い
- SoftmaxとAffinity Lossを同時に訓練させると相互作用により若干良くなるが、効果としては地味
- m=5のように中心の値を増やすようなケースではより、ResNet50のようなより大きなネットワークを使うとうまくいきやすいかもしれない。しかし、ネットワークの変更はハイパーパラメータのチューニングが必要になるので、ハイパーパラメータが効いているのか、ネットワークが効いているのかよくわからなく、Data Augmentationやクラス間の集計に比べるとややぱっとしないかもしれない。
ということで、結局は「Data Augmentation大事だよ」ということになりました。これで自分の中では割とすっきりしました。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー