こしあん
2019-05-11

クラス別のData Augmentationって意味ある?を調べてみた

Pocket
LINEで送る


Data Augmentationで精度を上げることにお熱になっていると、「特定の足引っ張っているクラスだけAugmetationかけたらいいんじゃない?」的なことをやりたくなります。これを調べてみました。その結果、全体を一括で処理するケースと、クラス別に局所的な処理をするケースで、精度に大きなギャップが生まれることが確認できました(局所的にやるほうが悪くなります)。

クラス別の精度を見てみよう

これは以前の投稿でやった、CIFAR-10を10層CNNで分類し、左右反転とクロップを加えたStandard Data Augmentationを加えたときの混同行列です。

このように、自動車の精度は97%以上出ても、猫の精度は85%も出ていなく、クラス間にかなりムラがあることがわかります。これを見ると「なら、自動車はAugmentationかけずに、猫だけAugmentationすればいいやん」って誰しも一度は考えると思うのです。

しかし、クラス別のData Augmentationについて調べると、自分が探した限りではまるで情報が出てこない、少なくとも論文が見当たらなかったというのが実情でした(もしかしたら自分が知らないだけで無知なのかもしれません)。多分うまくいかないんだろうな、ということなんでしょう。これを確かめてみました。

自分が確かめてみた結果、クラス別のAugmentationは多分良くないんじゃないかという結論になりました。

やること

CIFAR-10にStandard Data Augmentationを加えた時点をベースラインとします。ここから、

  1. ベースラインのクラス別精度を取る
  2. color_shift_range = 50を全部のクラスでAugmentationする
  3. 1,2でベースラインに対して精度が上がったクラスのみ、color_shift_range = 50のAugmentationをかける(基準はベースラインから)

部分的なAugmentationってどうするの?っていうことですが、こうします。

def partial_generator(X, y, color_shift, batch_size=128):
    non_shift_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32
                                   ).flow(X, y, batch_size, seed=123)
    shift_gen     = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32,
                                   channel_shift_range=color_shift).flow(X, y, batch_size, seed=123)
    while True:
        X_batch_no_shift, y_batch = next(non_shift_gen)
        X_batch_shifted, _ = next(shift_gen)
        if X_batch_no_shift.shape[0] != batch_size or X_batch_shifted.shape[0] != batch_size:
            continue
        y_ind = np.argmax(y_batch, axis=-1)
        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]
        yield X_batch_no_shift, y_batch

このように2つのジェネレーターを同時に走らせ、同一の乱数シードでflowさせます。こうすることで、シャッフルされるものの、同一のインデックスが抽出されるようになります。

ベースラインのクラス別精度

まずは、ベースラインのクラス別精度です。混同行列の対角成分を取りました。(各値を10で割ると%単位の精度になります)

ベース 飛行機 自動車 鹿 カエル トラック
1 944 976 884 850 937 879 951 961 961 956
2 950 972 890 855 925 880 958 947 955 949
3 945 979 903 846 931 881 961 943 966 950
4 943 976 895 846 939 886 953 950 952 957
5 953 973 890 852 937 885 953 949 959 957
平均 947.0 975.2 892.4 849.8 933.8 882.2 955.2 950.0 958.6 953.8

全体の平均=92.98%となりました。ちなみに、CIFAR-10はクラス単位のサンプル数が同じ均衡データなので(テストデータは全クラス1000個ずつある)、ただ単に平均して構いません。

全部のクラスにColor shiftを入れる

次に全部のクラスに対して、color_shift=50のAugmentationを追加で入れます。混同行列の値は以下の通りです。

全体50 飛行機 自動車 鹿 カエル トラック
1 941 972 911 848 933 879 954 958 952 967
2 939 977 906 837 948 867 953 958 960 953
3 942 980 912 849 938 881 968 944 950 964
4 956 978 901 841 936 878 958 946 954 954
5 930 971 915 843 942 875 957 952 954 960
平均 941.6 975.6 909.0 843.6 939.4 876.0 958.0 951.6 954.0 959.6

全体の平均は=93.08%です。効果あるんだかかなり微妙ですが、一応効果あるとしましょう。ベースラインとの差分を取ります。

差分 飛行機 自動車 鹿 カエル トラック
1 -3 -4 27 -2 -4 0 3 -3 -9 11
2 -11 5 16 -18 23 -13 -5 11 5 4
3 -3 1 9 3 7 0 7 1 -16 14
4 13 2 6 -5 -3 -8 5 -4 2 -3
5 -23 -2 25 -9 5 -10 4 3 -5 3
平均 -5.4 0.4 16.6 -6.2 5.6 -6.2 2.8 1.6 -4.6 5.8

これを見ると、確かにクラス別に効きやすい・効きにくいという傾向があるようです。太字としたクラスを「効いている」クラスとして扱いました。先程のコードの、

        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]

「2,4,6,7,9」の出処というのはこの太字のインデックスに対応します。

効果があったクラスに対してだけ、Color shiftを加える

先程のshift_range=50のカラーシフトのうち、効果のあった「インデックス=2,4,6,7,9」に対してのみColor shiftのAugmentationを加えます。結果は以下の通りです。

部分50 飛行機 自動車 鹿 カエル トラック
1 944 963 836 820 883 872 946 916 949 928
2 935 966 835 836 894 880 944 898 941 938
3 944 959 857 807 920 867 938 903 953 938
4 937 971 870 780 879 851 933 911 957 922
5 921 972 858 821 896 837 936 911 951 941
平均 936.2 966.2 851.2 812.8 894.4 861.4 939.4 907.8 950.2 933.4

全体の平均は90.53%となり明らかに悪くなっているのがわかります。クラス別に比較してみましょう。ベースラインからの比較です。

部分-ベース 飛行機 自動車 鹿 カエル トラック
1 0 -13 -48 -30 -54 -7 -5 -45 -12 -28
2 -15 -6 -55 -19 -31 0 -14 -49 -14 -11
3 -1 -20 -46 -39 -11 -14 -23 -40 -13 -12
4 -6 -5 -25 -66 -60 -35 -20 -39 5 -35
5 -32 -1 -32 -31 -41 -48 -17 -38 -8 -16
平均 -10.8 -9.0 -41.2 -37.0 -39.4 -20.8 -15.8 -42.2 -8.4 -20.4

太字にしたのがAugmentationしたクラスです。部分的にAugmentしたクラスが明らかに足を引っ張っています

全体-ベース 飛行機 自動車 鹿 カエル トラック
1 3 -9 -75 -28 -50 -7 -8 -42 -3 -39
2 -4 -11 -71 -1 -54 13 -9 -60 -19 -15
3 2 -21 -55 -42 -18 -14 -30 -41 3 -26
4 -19 -7 -31 -61 -57 -27 -25 -35 3 -32
5 -9 1 -57 -22 -46 -38 -21 -41 -3 -19
平均 -5.4 -9.4 -57.8 -30.8 -45.0 -14.6 -18.6 -43.8 -3.8 -26.2

全部で悪化しているので、こんなのやらないほうがマシですね。猫も引っ張られて悪くなっていますが、局所的なAugmentationを入れたクラスの方が大きく悪くなっているのがポイントです。しかし、これは部分的なAugmentationをするときのみおこることで、全クラス一括でやればここまで悪くなりません。

局所的なAugmentation vs 局所的なSampling

局所的なAugmentationはどうも意味がないというのがわかりそうですが、一方で、局所的なSamplingというのは意味があるだろうというのが通説です。これは不均衡データで用いられているOversampling/Undersamplingというテクニックによります。

CIFARのような整ったデータではないケースで、クラス間にサンプル数が偏りがあるケースを想定します。Oversamplingなら少ないクラスをより高い頻度でサンプリングし、Undersamplingなら多いクラスをより低い頻度でサンプリングするテクニックで、これは不均衡の是正やトータルでの精度向上に寄与します。同じクラス単位の局所的な操作でも、AugmentationとSamplingを切り離して考えたほうがよさそうです。

まとめ

クラス単位のAugmentationというのはどうも意味がないではないか。

全体でAugmentationしたときにクラス間にばらつきが出て、仮に大きく上昇するクラスがあったとしても、それを局所的にやると大きく精度を損なうことがある。Augmentationをするのなら、全体に共通の処理を施すべき。

ということでした。もしかしたら「いや、クラス単位のAugmentationも意味あるんだ」というケースもあるかもしれませんので、そうだったら知らせていただけると助かります。

Related Posts

Kerasで転移学習用にレイヤー名とそのインデックスを調べる方法... Kerasで転移学習をするときに、学習済みモデルのレイヤーの名前と、そのインデックス(何番目にあるかということ)の対応を知りたいことがあります。その方法を解説します。 転移学習とは 転移学習とは、ImageNetなど何百万もの大量の画像で事前学習させたモデルを使い、それを「特徴量検出器」として...
Chainerで画像の前処理やDataAugmentationをしたいときはDatasetMixin... Chainerにはデフォルトでランダムクロップや標準化といった、画像の前処理やDataAugmentation用の関数が用意されていません。別途のChainer CVというライブラリを使う方法もありますが、chainer.dataset.DatasetMixinを継承させて独自のデータ・セットを定...
転移学習でネットワーク内でアップサンプリングする方法(Keras)... 転移学習でインプットのサイズを揃えなければいけないことがありますが、これをRAM(CPU)上でやるとメモリが不足することがあります。転移学習の重みをそのまま使い、事前にアップサンプリングレイヤーを差し込む方法を紹介します。 関連記事とバックグラウンド まず前提知識としてCPU側でアップサンプリ...
TPUでも大きなバッチサイズに対して精度を出す... TPUでは大きなバッチサイズを適用することが訓練の高速化につながりますが、これは精度と引き換えになることがあります。大きなバッチサイズでも精度を出す方法を論文をもとに調べてみました。 背景 Qiitaに書いたGoogle Brainの論文「学習率を落とすな、バッチサイズを増やせ」を読むの続き。...
TensorFlow2.0でDistribute Trainingしたときにfitと訓練ループで精度... TensorFlowでDistribute Training(複数GPUやTPUでの訓練)をしたときに、Keras APIのfit()でのValidation精度と、訓練ループを書いたときの精度でかなり(1~2%)違うという状況に遭遇しました。特定の文を忘れただけだったのですが、解決に1日かかった...
Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です