こしあん
2019-05-11

クラス別のData Augmentationって意味ある?を調べてみた


3k{icon} {views}


Data Augmentationで精度を上げることにお熱になっていると、「特定の足引っ張っているクラスだけAugmetationかけたらいいんじゃない?」的なことをやりたくなります。これを調べてみました。その結果、全体を一括で処理するケースと、クラス別に局所的な処理をするケースで、精度に大きなギャップが生まれることが確認できました(局所的にやるほうが悪くなります)。

クラス別の精度を見てみよう

これは以前の投稿でやった、CIFAR-10を10層CNNで分類し、左右反転とクロップを加えたStandard Data Augmentationを加えたときの混同行列です。

このように、自動車の精度は97%以上出ても、猫の精度は85%も出ていなく、クラス間にかなりムラがあることがわかります。これを見ると「なら、自動車はAugmentationかけずに、猫だけAugmentationすればいいやん」って誰しも一度は考えると思うのです。

しかし、クラス別のData Augmentationについて調べると、自分が探した限りではまるで情報が出てこない、少なくとも論文が見当たらなかったというのが実情でした(もしかしたら自分が知らないだけで無知なのかもしれません)。多分うまくいかないんだろうな、ということなんでしょう。これを確かめてみました。

自分が確かめてみた結果、クラス別のAugmentationは多分良くないんじゃないかという結論になりました。

やること

CIFAR-10にStandard Data Augmentationを加えた時点をベースラインとします。ここから、

  1. ベースラインのクラス別精度を取る
  2. color_shift_range = 50を全部のクラスでAugmentationする
  3. 1,2でベースラインに対して精度が上がったクラスのみ、color_shift_range = 50のAugmentationをかける(基準はベースラインから)

部分的なAugmentationってどうするの?っていうことですが、こうします。

def partial_generator(X, y, color_shift, batch_size=128):
    non_shift_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32
                                   ).flow(X, y, batch_size, seed=123)
    shift_gen     = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
                                   width_shift_range=4.0/32, height_shift_range=4.0/32,
                                   channel_shift_range=color_shift).flow(X, y, batch_size, seed=123)
    while True:
        X_batch_no_shift, y_batch = next(non_shift_gen)
        X_batch_shifted, _ = next(shift_gen)
        if X_batch_no_shift.shape[0] != batch_size or X_batch_shifted.shape[0] != batch_size:
            continue
        y_ind = np.argmax(y_batch, axis=-1)
        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]
        yield X_batch_no_shift, y_batch

このように2つのジェネレーターを同時に走らせ、同一の乱数シードでflowさせます。こうすることで、シャッフルされるものの、同一のインデックスが抽出されるようになります。

ベースラインのクラス別精度

まずは、ベースラインのクラス別精度です。混同行列の対角成分を取りました。(各値を10で割ると%単位の精度になります)

ベース 飛行機 自動車 鹿 カエル トラック
1 944 976 884 850 937 879 951 961 961 956
2 950 972 890 855 925 880 958 947 955 949
3 945 979 903 846 931 881 961 943 966 950
4 943 976 895 846 939 886 953 950 952 957
5 953 973 890 852 937 885 953 949 959 957
平均 947.0 975.2 892.4 849.8 933.8 882.2 955.2 950.0 958.6 953.8

全体の平均=92.98%となりました。ちなみに、CIFAR-10はクラス単位のサンプル数が同じ均衡データなので(テストデータは全クラス1000個ずつある)、ただ単に平均して構いません。

全部のクラスにColor shiftを入れる

次に全部のクラスに対して、color_shift=50のAugmentationを追加で入れます。混同行列の値は以下の通りです。

全体50 飛行機 自動車 鹿 カエル トラック
1 941 972 911 848 933 879 954 958 952 967
2 939 977 906 837 948 867 953 958 960 953
3 942 980 912 849 938 881 968 944 950 964
4 956 978 901 841 936 878 958 946 954 954
5 930 971 915 843 942 875 957 952 954 960
平均 941.6 975.6 909.0 843.6 939.4 876.0 958.0 951.6 954.0 959.6

全体の平均は=93.08%です。効果あるんだかかなり微妙ですが、一応効果あるとしましょう。ベースラインとの差分を取ります。

差分 飛行機 自動車 鹿 カエル トラック
1 -3 -4 27 -2 -4 0 3 -3 -9 11
2 -11 5 16 -18 23 -13 -5 11 5 4
3 -3 1 9 3 7 0 7 1 -16 14
4 13 2 6 -5 -3 -8 5 -4 2 -3
5 -23 -2 25 -9 5 -10 4 3 -5 3
平均 -5.4 0.4 16.6 -6.2 5.6 -6.2 2.8 1.6 -4.6 5.8

これを見ると、確かにクラス別に効きやすい・効きにくいという傾向があるようです。太字としたクラスを「効いている」クラスとして扱いました。先程のコードの、

        # shift=50だと2,4,6,7,9が上がったから
        for i in range(batch_size):
            if y_ind[i] in [2, 4, 6, 7, 9]:
                X_batch_no_shift[i] = X_batch_shifted[i]

「2,4,6,7,9」の出処というのはこの太字のインデックスに対応します。

効果があったクラスに対してだけ、Color shiftを加える

先程のshift_range=50のカラーシフトのうち、効果のあった「インデックス=2,4,6,7,9」に対してのみColor shiftのAugmentationを加えます。結果は以下の通りです。

部分50 飛行機 自動車 鹿 カエル トラック
1 944 963 836 820 883 872 946 916 949 928
2 935 966 835 836 894 880 944 898 941 938
3 944 959 857 807 920 867 938 903 953 938
4 937 971 870 780 879 851 933 911 957 922
5 921 972 858 821 896 837 936 911 951 941
平均 936.2 966.2 851.2 812.8 894.4 861.4 939.4 907.8 950.2 933.4

全体の平均は90.53%となり明らかに悪くなっているのがわかります。クラス別に比較してみましょう。ベースラインからの比較です。

部分-ベース 飛行機 自動車 鹿 カエル トラック
1 0 -13 -48 -30 -54 -7 -5 -45 -12 -28
2 -15 -6 -55 -19 -31 0 -14 -49 -14 -11
3 -1 -20 -46 -39 -11 -14 -23 -40 -13 -12
4 -6 -5 -25 -66 -60 -35 -20 -39 5 -35
5 -32 -1 -32 -31 -41 -48 -17 -38 -8 -16
平均 -10.8 -9.0 -41.2 -37.0 -39.4 -20.8 -15.8 -42.2 -8.4 -20.4

太字にしたのがAugmentationしたクラスです。部分的にAugmentしたクラスが明らかに足を引っ張っています

全体-ベース 飛行機 自動車 鹿 カエル トラック
1 3 -9 -75 -28 -50 -7 -8 -42 -3 -39
2 -4 -11 -71 -1 -54 13 -9 -60 -19 -15
3 2 -21 -55 -42 -18 -14 -30 -41 3 -26
4 -19 -7 -31 -61 -57 -27 -25 -35 3 -32
5 -9 1 -57 -22 -46 -38 -21 -41 -3 -19
平均 -5.4 -9.4 -57.8 -30.8 -45.0 -14.6 -18.6 -43.8 -3.8 -26.2

全部で悪化しているので、こんなのやらないほうがマシですね。猫も引っ張られて悪くなっていますが、局所的なAugmentationを入れたクラスの方が大きく悪くなっているのがポイントです。しかし、これは部分的なAugmentationをするときのみおこることで、全クラス一括でやればここまで悪くなりません。

局所的なAugmentation vs 局所的なSampling

局所的なAugmentationはどうも意味がないというのがわかりそうですが、一方で、局所的なSamplingというのは意味があるだろうというのが通説です。これは不均衡データで用いられているOversampling/Undersamplingというテクニックによります。

CIFARのような整ったデータではないケースで、クラス間にサンプル数が偏りがあるケースを想定します。Oversamplingなら少ないクラスをより高い頻度でサンプリングし、Undersamplingなら多いクラスをより低い頻度でサンプリングするテクニックで、これは不均衡の是正やトータルでの精度向上に寄与します。同じクラス単位の局所的な操作でも、AugmentationとSamplingを切り離して考えたほうがよさそうです。

まとめ

クラス単位のAugmentationというのはどうも意味がないではないか。

全体でAugmentationしたときにクラス間にばらつきが出て、仮に大きく上昇するクラスがあったとしても、それを局所的にやると大きく精度を損なうことがある。Augmentationをするのなら、全体に共通の処理を施すべき。

ということでした。もしかしたら「いや、クラス単位のAugmentationも意味あるんだ」というケースもあるかもしれませんので、そうだったら知らせていただけると助かります。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です