こしあん
2019-02-13

TensorFlow/Kerasでネットワーク内でData Augmentationする方法

Pocket
LINEで送る


NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。

こんなイメージ

Numpy(CPU)でやる場合

  • NumpyでDataAugmentation→model.fit_generator(…)→Input→ニューラルネットワーク

ニューラルネットワークでやる場合

  • model.fit(…)/.fit_generator(…)→Input→Data Augmetation層→ニューラルネットワークの隠れ層

今回はこちらを見ていきます。

Data Augmentation層

今回はランダムなHorizontal Flip+Random Cropといういわゆる「Standard Data Augmentation」を実装します。

Girlという標準画像を変形していきます。

ちなみに答えから言ってしまうと、Horizontal FlipもRandom CropもTensorFlowで関数があります。

これをLambda層でラップしていきます。

from keras.layers import Input, Lambda
from keras.models import Model
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def augmentation_perimage(img):
    x = tf.image.random_flip_left_right(img)
    x = tf.random_crop(x, [128,128,3])
    return x

def standard_augmentation(inputs):
    random_flip = tf.map_fn(augmentation_perimage, inputs)
    return random_flip

def create_model():
    input = Input((256,256,3))
    x = Lambda(standard_augmentation)(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    pred = model.predict(images)

    for i in range(16):
        ax = plt.subplot(4,4,i+1)
        ax.axis("off")
        ax.imshow(pred[i])
    plt.show()

main()

うまくいった。Data Augmenation専用の関数があったことに驚きですね。

train/testの挙動の振り分け

だいたいこれでいいんですが、これだとpredictでもAugmentationしちゃうので、trainだけAugmentationするようにします。

from keras.layers import Input, Lambda, Layer
from keras.models import Model
from PIL import Image
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

class StandardAugmentation(Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None):
        return K.in_train_phase(tf.map_fn(self.augmentation_perimage, inputs), 
                                inputs, training=training)

    def augmentation_perimage(self, img):
        x = tf.image.random_flip_left_right(img)
        x = tf.random_crop(x, [128,128,3])
        x = tf.image.resize_image_with_crop_or_pad(x, 256, 256)
        return x

def create_model():
    input = Input((256,256,3))
    x = StandardAugmentation()(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    model.summary()

    f = K.function([model.layers[0].input, K.learning_phase()],
                   [model.layers[-1].output])
    pred = f([images, 0])[0] # ここを0、1にする


あえてこうやって書くとちょっと面倒かもしれませんね。

Related Posts

TensorFlow/Kerasでグラム行列(テンソル)を計算する方法... TensorFlowで分散や共分散が絡む演算を定義していると、グラム行列を計算する必要が出てくることがあります。行列はまだよくてもテンソルのグラム行列はどう計算するでしょうか?今回はテンソルの共分散計算に行く前に、その前提のテンソルのグラム行列の計算から見ていきます。 グラム行列とは 名前は仰...
Pillowでグレースケール化するときに3チャンネルで出力するテクニック... カラー画像をグレースケール化すると1チャンネルの出力になりますが、カラー画像と同時に扱うと整合性からグレースケールでも3チャンネルで欲しいことがあります。Numpyのブロードキャストを使わずに簡単に3チャンネルで出力する方法を書きます。 グレースケール化してカラー化すればOK これでOK i...
WarmupとData Augmentationのバッチサイズ別の精度低下について... 大きいバッチサイズで訓練する際は、バッチサイズの増加にともなう精度低下が深刻になります。この精度低下を抑制することはできるのですが、例えばData Augmentationのようなデータ増強・正則化による精度向上とは何が違うのでしょうか。それを調べてみました。 きっかけ この記事を書いたときに...
データのお気持ちを考えながらData Augmentationする... Data Augmentationの「なぜ?」に注目しながら、エラー分析をしてCIFAR-10の精度向上を目指します。その結果、オレオレAugmentationながら、Wide ResNetで97.3%という、Auto Augmentとほぼ同じ(-0.1%)精度を出すことができました。 (※すご...
Affinity LossをCIFAR-10で精度を求めてひたすら頑張った話... 不均衡データに対して有効性があると言われている損失関数「Affinity loss」をCIFAR-10で精度を出すためにひたすら頑張った、というひたすら泥臭い話。条件10個試したらやっと精度を出すためのコツみたいなのが見えてきました。 結論 長いので先に結論から。CIFAR-10をAffini...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です