こしあん
2019-05-11

クラス別のData Augmentationって意味ある?を調べてみた

Pocket
LINEで送る


Data Augmentationで精度を上げることにお熱になっていると、「特定の足引っ張っているクラスだけAugmetationかけたらいいんじゃない?」的なことをやりたくなります。これを調べてみました。その結果、全体を一括で処理するケースと、クラス別に局所的な処理をするケースで、精度に大きなギャップが生まれることが確認できました(局所的にやるほうが悪くなります)。

クラス別の精度を見てみよう

これは以前の投稿でやった、CIFAR-10を10層CNNで分類し、左右反転とクロップを加えたStandard Data Augmentationを加えたときの混同行列です。

このように、自動車の精度は97%以上出ても、猫の精度は85%も出ていなく、クラス間にかなりムラがあることがわかります。これを見ると「なら、自動車はAugmentationかけずに、猫だけAugmentationすればいいやん」って誰しも一度は考えると思うのです。

しかし、クラス別のData Augmentationについて調べると、自分が探した限りではまるで情報が出てこない、少なくとも論文が見当たらなかったというのが実情でした(もしかしたら自分が知らないだけで無知なのかもしれません)。多分うまくいかないんだろうな、ということなんでしょう。これを確かめてみました。

自分が確かめてみた結果、クラス別のAugmentationは多分良くないんじゃないかという結論になりました。

やること

CIFAR-10にStandard Data Augmentationを加えた時点をベースラインとします。ここから、

  1. ベースラインのクラス別精度を取る
  2. color_shift_range = 50を全部のクラスでAugmentationする
  3. 1,2でベースラインに対して精度が上がったクラスのみ、color_shift_range = 50のAugmentationをかける(基準はベースラインから)

部分的なAugmentationってどうするの?っていうことですが、こうします。

def partial_generator(X, y, color_shift, batch_size=128):
    non_shift_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32
                                   ).flow(X, y, batch_size, seed=123)
    shift_gen     = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32,
                                   channel_shift_range=color_shift).flow(X, y, batch_size, seed=123)
    while True:
        X_batch_no_shift, y_batch = next(non_shift_gen)
        X_batch_shifted, _ = next(shift_gen)
        if X_batch_no_shift.shape[0] != batch_size or X_batch_shifted.shape[0] != batch_size:
            continue
        y_ind = np.argmax(y_batch, axis=-1)
        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]
        yield X_batch_no_shift, y_batch

このように2つのジェネレーターを同時に走らせ、同一の乱数シードでflowさせます。こうすることで、シャッフルされるものの、同一のインデックスが抽出されるようになります。

ベースラインのクラス別精度

まずは、ベースラインのクラス別精度です。混同行列の対角成分を取りました。(各値を10で割ると%単位の精度になります)

ベース 飛行機 自動車 鹿 カエル トラック
1 944 976 884 850 937 879 951 961 961 956
2 950 972 890 855 925 880 958 947 955 949
3 945 979 903 846 931 881 961 943 966 950
4 943 976 895 846 939 886 953 950 952 957
5 953 973 890 852 937 885 953 949 959 957
平均 947.0 975.2 892.4 849.8 933.8 882.2 955.2 950.0 958.6 953.8

全体の平均=92.98%となりました。ちなみに、CIFAR-10はクラス単位のサンプル数が同じ均衡データなので(テストデータは全クラス1000個ずつある)、ただ単に平均して構いません。

全部のクラスにColor shiftを入れる

次に全部のクラスに対して、color_shift=50のAugmentationを追加で入れます。混同行列の値は以下の通りです。

全体50 飛行機 自動車 鹿 カエル トラック
1 941 972 911 848 933 879 954 958 952 967
2 939 977 906 837 948 867 953 958 960 953
3 942 980 912 849 938 881 968 944 950 964
4 956 978 901 841 936 878 958 946 954 954
5 930 971 915 843 942 875 957 952 954 960
平均 941.6 975.6 909.0 843.6 939.4 876.0 958.0 951.6 954.0 959.6

全体の平均は=93.08%です。効果あるんだかかなり微妙ですが、一応効果あるとしましょう。ベースラインとの差分を取ります。

差分 飛行機 自動車 鹿 カエル トラック
1 -3 -4 27 -2 -4 0 3 -3 -9 11
2 -11 5 16 -18 23 -13 -5 11 5 4
3 -3 1 9 3 7 0 7 1 -16 14
4 13 2 6 -5 -3 -8 5 -4 2 -3
5 -23 -2 25 -9 5 -10 4 3 -5 3
平均 -5.4 0.4 16.6 -6.2 5.6 -6.2 2.8 1.6 -4.6 5.8

これを見ると、確かにクラス別に効きやすい・効きにくいという傾向があるようです。太字としたクラスを「効いている」クラスとして扱いました。先程のコードの、

        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]

「2,4,6,7,9」の出処というのはこの太字のインデックスに対応します。

効果があったクラスに対してだけ、Color shiftを加える

先程のshift_range=50のカラーシフトのうち、効果のあった「インデックス=2,4,6,7,9」に対してのみColor shiftのAugmentationを加えます。結果は以下の通りです。

部分50 飛行機 自動車 鹿 カエル トラック
1 944 963 836 820 883 872 946 916 949 928
2 935 966 835 836 894 880 944 898 941 938
3 944 959 857 807 920 867 938 903 953 938
4 937 971 870 780 879 851 933 911 957 922
5 921 972 858 821 896 837 936 911 951 941
平均 936.2 966.2 851.2 812.8 894.4 861.4 939.4 907.8 950.2 933.4

全体の平均は90.53%となり明らかに悪くなっているのがわかります。クラス別に比較してみましょう。ベースラインからの比較です。

部分-ベース 飛行機 自動車 鹿 カエル トラック
1 0 -13 -48 -30 -54 -7 -5 -45 -12 -28
2 -15 -6 -55 -19 -31 0 -14 -49 -14 -11
3 -1 -20 -46 -39 -11 -14 -23 -40 -13 -12
4 -6 -5 -25 -66 -60 -35 -20 -39 5 -35
5 -32 -1 -32 -31 -41 -48 -17 -38 -8 -16
平均 -10.8 -9.0 -41.2 -37.0 -39.4 -20.8 -15.8 -42.2 -8.4 -20.4

太字にしたのがAugmentationしたクラスです。部分的にAugmentしたクラスが明らかに足を引っ張っています

全体-ベース 飛行機 自動車 鹿 カエル トラック
1 3 -9 -75 -28 -50 -7 -8 -42 -3 -39
2 -4 -11 -71 -1 -54 13 -9 -60 -19 -15
3 2 -21 -55 -42 -18 -14 -30 -41 3 -26
4 -19 -7 -31 -61 -57 -27 -25 -35 3 -32
5 -9 1 -57 -22 -46 -38 -21 -41 -3 -19
平均 -5.4 -9.4 -57.8 -30.8 -45.0 -14.6 -18.6 -43.8 -3.8 -26.2

全部で悪化しているので、こんなのやらないほうがマシですね。猫も引っ張られて悪くなっていますが、局所的なAugmentationを入れたクラスの方が大きく悪くなっているのがポイントです。しかし、これは部分的なAugmentationをするときのみおこることで、全クラス一括でやればここまで悪くなりません。

局所的なAugmentation vs 局所的なSampling

局所的なAugmentationはどうも意味がないというのがわかりそうですが、一方で、局所的なSamplingというのは意味があるだろうというのが通説です。これは不均衡データで用いられているOversampling/Undersamplingというテクニックによります。

CIFARのような整ったデータではないケースで、クラス間にサンプル数が偏りがあるケースを想定します。Oversamplingなら少ないクラスをより高い頻度でサンプリングし、Undersamplingなら多いクラスをより低い頻度でサンプリングするテクニックで、これは不均衡の是正やトータルでの精度向上に寄与します。同じクラス単位の局所的な操作でも、AugmentationとSamplingを切り離して考えたほうがよさそうです。

まとめ

クラス単位のAugmentationというのはどうも意味がないではないか。

全体でAugmentationしたときにクラス間にばらつきが出て、仮に大きく上昇するクラスがあったとしても、それを局所的にやると大きく精度を損なうことがある。Augmentationをするのなら、全体に共通の処理を施すべき。

ということでした。もしかしたら「いや、クラス単位のAugmentationも意味あるんだ」というケースもあるかもしれませんので、そうだったら知らせていただけると助かります。

Related Posts

TPUでアップサンプリングする際にエラーを出さない方法... 画像処理をしているとUpsamplingが必要になることがあります。Keras/TensorFlowではUpsampling2Dというレイヤーを使ってアップサンプリングができますが、このレイヤーがTPUだとエラーを出すので解決法を探しました。自分でアップサンプリングレイヤーを定義するとうまく行った...
OpenCVで作成した動画がブラウザで正常に表示できない場合の解決法... OpenCVで作成した動画をサイトで表示する場合、ローカルで再生できていても、ブラウザ上では突然プレビューがでなり、ハマることがあります。原因の特定が難しい現象ですが、動画を作成する際にH.264形式でエンコードするとうまくいきました。その方法を解説します。 MPV4は手軽だが… OpenCV...
Affinity LossをCIFAR-10で精度を求めてひたすら頑張った話... 不均衡データに対して有効性があると言われている損失関数「Affinity loss」をCIFAR-10で精度を出すためにひたすら頑張った、というひたすら泥臭い話。条件10個試したらやっと精度を出すためのコツみたいなのが見えてきました。 結論 長いので先に結論から。CIFAR-10をAffini...
KerasでSTL-10を扱う方法 スタンフォード大学が公開している画像データセットに「STL-10」というのがあります。これはCIFAR-10に似た形式ながら、高画質かつ教師なし学習にも応用できる便利なデータです。PyTorchではtorchvisionを使うと簡単に読み込めるのですが、Kerasではデフォルトで容易されていないの...
データのお気持ちを考えながらData Augmentationする... Data Augmentationの「なぜ?」に注目しながら、エラー分析をしてCIFAR-10の精度向上を目指します。その結果、オレオレAugmentationながら、Wide ResNetで97.3%という、Auto Augmentとほぼ同じ(-0.1%)精度を出すことができました。 (※すご...
Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です