こしあん
2019-02-13

TensorFlow/Kerasでネットワーク内でData Augmentationする方法


NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。

こんなイメージ

Numpy(CPU)でやる場合

  • NumpyでDataAugmentation→model.fit_generator(…)→Input→ニューラルネットワーク

ニューラルネットワークでやる場合

  • model.fit(…)/.fit_generator(…)→Input→Data Augmetation層→ニューラルネットワークの隠れ層

今回はこちらを見ていきます。

Data Augmentation層

今回はランダムなHorizontal Flip+Random Cropといういわゆる「Standard Data Augmentation」を実装します。

Girlという標準画像を変形していきます。

ちなみに答えから言ってしまうと、Horizontal FlipもRandom CropもTensorFlowで関数があります。

これをLambda層でラップしていきます。

from keras.layers import Input, Lambda
from keras.models import Model
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def augmentation_perimage(img):
    x = tf.image.random_flip_left_right(img)
    x = tf.random_crop(x, [128,128,3])
    return x

def standard_augmentation(inputs):
    random_flip = tf.map_fn(augmentation_perimage, inputs)
    return random_flip

def create_model():
    input = Input((256,256,3))
    x = Lambda(standard_augmentation)(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    pred = model.predict(images)

    for i in range(16):
        ax = plt.subplot(4,4,i+1)
        ax.axis("off")
        ax.imshow(pred[i])
    plt.show()

main()

うまくいった。Data Augmenation専用の関数があったことに驚きですね。

train/testの挙動の振り分け

だいたいこれでいいんですが、これだとpredictでもAugmentationしちゃうので、trainだけAugmentationするようにします。

from keras.layers import Input, Lambda, Layer
from keras.models import Model
from PIL import Image
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

class StandardAugmentation(Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None):
        return K.in_train_phase(tf.map_fn(self.augmentation_perimage, inputs), 
                                inputs, training=training)

    def augmentation_perimage(self, img):
        x = tf.image.random_flip_left_right(img)
        x = tf.random_crop(x, [128,128,3])
        x = tf.image.resize_image_with_crop_or_pad(x, 256, 256)
        return x

def create_model():
    input = Input((256,256,3))
    x = StandardAugmentation()(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    model.summary()

    f = K.function([model.layers[0].input, K.learning_phase()],
                   [model.layers[-1].output])
    pred = f([images, 0])[0] # ここを0、1にする


あえてこうやって書くとちょっと面倒かもしれませんね。

Related Posts

TPUで学習率減衰させる方法 TPUで学習率減衰したいが、TensorFlowのオプティマイザーを使うべきか、tf.kerasのオプティマイザーを使うべきか、あるいはKerasのオプティマイザーを使うべきか非常にややこしいことがあります。TPUで学習率を減衰させる方法を再現しました。 結論から TPU環境でtf.keras...
Kerasに組み込まれているMobileNetの実装 MobileNetのsummary _________________________________________________________________ Layer (type) Output Shape Param # ...
OpenCVで画像を歪ませる方法 PythonでOpenCVを使い画像を歪ませる方法を考えます。アフィン変換というちょっと直感的に理解しにくいことをしますが、慣れればそこまで難しくはありません。ディープラーニングのData Augmentationにも使えます。 OpenCVでのアフィン変換のイメージ アフィン変換というと、ま...
Python(Numpy)で画像を水平反転する方法:Data Augmentation向け... OpenCVを使わずに単純に画像を左右反転(水平反転)する方法を考えます。ディープラーニングでデータのジェネレーターを自分で実装した場合、Data Augmentationを組み込む際にも必要になります。それを見ていきましょう。 左右反転自体は実は簡単 例えばNumpyの行列を左右反転させてみ...
Kerasでモデルのsummaryをテキストとして保存する方法... Kerasで「plot_modelを使えばモデルの可視化ができるが、GraphViz入れないといけなかったり、セットアップが面倒くさい!model.summary()のテキストをファイル保存で十分だ!」という場合に使えるテクニックです。 summary()のprint_fn引数を使う sum...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です