クラス別のData Augmentationって意味ある?を調べてみた
Data Augmentationで精度を上げることにお熱になっていると、「特定の足引っ張っているクラスだけAugmetationかけたらいいんじゃない?」的なことをやりたくなります。これを調べてみました。その結果、全体を一括で処理するケースと、クラス別に局所的な処理をするケースで、精度に大きなギャップが生まれることが確認できました(局所的にやるほうが悪くなります)。
目次
クラス別の精度を見てみよう
これは以前の投稿でやった、CIFAR-10を10層CNNで分類し、左右反転とクロップを加えたStandard Data Augmentationを加えたときの混同行列です。
このように、自動車の精度は97%以上出ても、猫の精度は85%も出ていなく、クラス間にかなりムラがあることがわかります。これを見ると「なら、自動車はAugmentationかけずに、猫だけAugmentationすればいいやん」って誰しも一度は考えると思うのです。
しかし、クラス別のData Augmentationについて調べると、自分が探した限りではまるで情報が出てこない、少なくとも論文が見当たらなかったというのが実情でした(もしかしたら自分が知らないだけで無知なのかもしれません)。多分うまくいかないんだろうな、ということなんでしょう。これを確かめてみました。
自分が確かめてみた結果、クラス別のAugmentationは多分良くないんじゃないかという結論になりました。
やること
CIFAR-10にStandard Data Augmentationを加えた時点をベースラインとします。ここから、
- ベースラインのクラス別精度を取る
- color_shift_range = 50を全部のクラスでAugmentationする
- 1,2でベースラインに対して精度が上がったクラスのみ、color_shift_range = 50のAugmentationをかける(基準はベースラインから)
部分的なAugmentationってどうするの?っていうことですが、こうします。
def partial_generator(X, y, color_shift, batch_size=128):
non_shift_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
width_shift_range=4.0/32, height_shift_range=4.0/32
).flow(X, y, batch_size, seed=123)
shift_gen = ImageDataGenerator(rescale=1.0/255, horizontal_flip=True,
width_shift_range=4.0/32, height_shift_range=4.0/32,
channel_shift_range=color_shift).flow(X, y, batch_size, seed=123)
while True:
X_batch_no_shift, y_batch = next(non_shift_gen)
X_batch_shifted, _ = next(shift_gen)
if X_batch_no_shift.shape[0] != batch_size or X_batch_shifted.shape[0] != batch_size:
continue
y_ind = np.argmax(y_batch, axis=-1)
# shift=50だと2,4,6,7,9が上がったから
for i in range(batch_size):
if y_ind[i] in [2, 4, 6, 7, 9]:
X_batch_no_shift[i] = X_batch_shifted[i]
yield X_batch_no_shift, y_batch
このように2つのジェネレーターを同時に走らせ、同一の乱数シードでflowさせます。こうすることで、シャッフルされるものの、同一のインデックスが抽出されるようになります。
ベースラインのクラス別精度
まずは、ベースラインのクラス別精度です。混同行列の対角成分を取りました。(各値を10で割ると%単位の精度になります)
ベース | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | 944 | 976 | 884 | 850 | 937 | 879 | 951 | 961 | 961 | 956 |
2 | 950 | 972 | 890 | 855 | 925 | 880 | 958 | 947 | 955 | 949 |
3 | 945 | 979 | 903 | 846 | 931 | 881 | 961 | 943 | 966 | 950 |
4 | 943 | 976 | 895 | 846 | 939 | 886 | 953 | 950 | 952 | 957 |
5 | 953 | 973 | 890 | 852 | 937 | 885 | 953 | 949 | 959 | 957 |
平均 | 947.0 | 975.2 | 892.4 | 849.8 | 933.8 | 882.2 | 955.2 | 950.0 | 958.6 | 953.8 |
全体の平均=92.98%となりました。ちなみに、CIFAR-10はクラス単位のサンプル数が同じ均衡データなので(テストデータは全クラス1000個ずつある)、ただ単に平均して構いません。
全部のクラスにColor shiftを入れる
次に全部のクラスに対して、color_shift=50のAugmentationを追加で入れます。混同行列の値は以下の通りです。
全体50 | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | 941 | 972 | 911 | 848 | 933 | 879 | 954 | 958 | 952 | 967 |
2 | 939 | 977 | 906 | 837 | 948 | 867 | 953 | 958 | 960 | 953 |
3 | 942 | 980 | 912 | 849 | 938 | 881 | 968 | 944 | 950 | 964 |
4 | 956 | 978 | 901 | 841 | 936 | 878 | 958 | 946 | 954 | 954 |
5 | 930 | 971 | 915 | 843 | 942 | 875 | 957 | 952 | 954 | 960 |
平均 | 941.6 | 975.6 | 909.0 | 843.6 | 939.4 | 876.0 | 958.0 | 951.6 | 954.0 | 959.6 |
全体の平均は=93.08%です。効果あるんだかかなり微妙ですが、一応効果あるとしましょう。ベースラインとの差分を取ります。
差分 | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | -3 | -4 | 27 | -2 | -4 | 0 | 3 | -3 | -9 | 11 |
2 | -11 | 5 | 16 | -18 | 23 | -13 | -5 | 11 | 5 | 4 |
3 | -3 | 1 | 9 | 3 | 7 | 0 | 7 | 1 | -16 | 14 |
4 | 13 | 2 | 6 | -5 | -3 | -8 | 5 | -4 | 2 | -3 |
5 | -23 | -2 | 25 | -9 | 5 | -10 | 4 | 3 | -5 | 3 |
平均 | -5.4 | 0.4 | 16.6 | -6.2 | 5.6 | -6.2 | 2.8 | 1.6 | -4.6 | 5.8 |
これを見ると、確かにクラス別に効きやすい・効きにくいという傾向があるようです。太字としたクラスを「効いている」クラスとして扱いました。先程のコードの、
# shift=50だと2,4,6,7,9が上がったから
for i in range(batch_size):
if y_ind[i] in [2, 4, 6, 7, 9]:
X_batch_no_shift[i] = X_batch_shifted[i]
「2,4,6,7,9」の出処というのはこの太字のインデックスに対応します。
効果があったクラスに対してだけ、Color shiftを加える
先程のshift_range=50のカラーシフトのうち、効果のあった「インデックス=2,4,6,7,9」に対してのみColor shiftのAugmentationを加えます。結果は以下の通りです。
部分50 | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | 944 | 963 | 836 | 820 | 883 | 872 | 946 | 916 | 949 | 928 |
2 | 935 | 966 | 835 | 836 | 894 | 880 | 944 | 898 | 941 | 938 |
3 | 944 | 959 | 857 | 807 | 920 | 867 | 938 | 903 | 953 | 938 |
4 | 937 | 971 | 870 | 780 | 879 | 851 | 933 | 911 | 957 | 922 |
5 | 921 | 972 | 858 | 821 | 896 | 837 | 936 | 911 | 951 | 941 |
平均 | 936.2 | 966.2 | 851.2 | 812.8 | 894.4 | 861.4 | 939.4 | 907.8 | 950.2 | 933.4 |
全体の平均は90.53%となり明らかに悪くなっているのがわかります。クラス別に比較してみましょう。ベースラインからの比較です。
部分-ベース | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | 0 | -13 | -48 | -30 | -54 | -7 | -5 | -45 | -12 | -28 |
2 | -15 | -6 | -55 | -19 | -31 | 0 | -14 | -49 | -14 | -11 |
3 | -1 | -20 | -46 | -39 | -11 | -14 | -23 | -40 | -13 | -12 |
4 | -6 | -5 | -25 | -66 | -60 | -35 | -20 | -39 | 5 | -35 |
5 | -32 | -1 | -32 | -31 | -41 | -48 | -17 | -38 | -8 | -16 |
平均 | -10.8 | -9.0 | -41.2 | -37.0 | -39.4 | -20.8 | -15.8 | -42.2 | -8.4 | -20.4 |
太字にしたのがAugmentationしたクラスです。部分的にAugmentしたクラスが明らかに足を引っ張っています。
全体-ベース | 飛行機 | 自動車 | 鳥 | 猫 | 鹿 | 犬 | カエル | 馬 | 船 | トラック |
---|---|---|---|---|---|---|---|---|---|---|
1 | 3 | -9 | -75 | -28 | -50 | -7 | -8 | -42 | -3 | -39 |
2 | -4 | -11 | -71 | -1 | -54 | 13 | -9 | -60 | -19 | -15 |
3 | 2 | -21 | -55 | -42 | -18 | -14 | -30 | -41 | 3 | -26 |
4 | -19 | -7 | -31 | -61 | -57 | -27 | -25 | -35 | 3 | -32 |
5 | -9 | 1 | -57 | -22 | -46 | -38 | -21 | -41 | -3 | -19 |
平均 | -5.4 | -9.4 | -57.8 | -30.8 | -45.0 | -14.6 | -18.6 | -43.8 | -3.8 | -26.2 |
全部で悪化しているので、こんなのやらないほうがマシですね。猫も引っ張られて悪くなっていますが、局所的なAugmentationを入れたクラスの方が大きく悪くなっているのがポイントです。しかし、これは部分的なAugmentationをするときのみおこることで、全クラス一括でやればここまで悪くなりません。
局所的なAugmentation vs 局所的なSampling
局所的なAugmentationはどうも意味がないというのがわかりそうですが、一方で、局所的なSamplingというのは意味があるだろうというのが通説です。これは不均衡データで用いられているOversampling/Undersamplingというテクニックによります。
CIFARのような整ったデータではないケースで、クラス間にサンプル数が偏りがあるケースを想定します。Oversamplingなら少ないクラスをより高い頻度でサンプリングし、Undersamplingなら多いクラスをより低い頻度でサンプリングするテクニックで、これは不均衡の是正やトータルでの精度向上に寄与します。同じクラス単位の局所的な操作でも、AugmentationとSamplingを切り離して考えたほうがよさそうです。
まとめ
クラス単位のAugmentationというのはどうも意味がないではないか。
全体でAugmentationしたときにクラス間にばらつきが出て、仮に大きく上昇するクラスがあったとしても、それを局所的にやると大きく精度を損なうことがある。Augmentationをするのなら、全体に共通の処理を施すべき。
ということでした。もしかしたら「いや、クラス単位のAugmentationも意味あるんだ」というケースもあるかもしれませんので、そうだったら知らせていただけると助かります。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー