こしあん
2019-02-13

TensorFlow/Kerasでネットワーク内でData Augmentationする方法

Pocket
LINEで送る
Delicious にシェア

2.8k{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。

こんなイメージ

Numpy(CPU)でやる場合

  • NumpyでDataAugmentation→model.fit_generator(…)→Input→ニューラルネットワーク

ニューラルネットワークでやる場合

  • model.fit(…)/.fit_generator(…)→Input→Data Augmetation層→ニューラルネットワークの隠れ層

今回はこちらを見ていきます。

Data Augmentation層

今回はランダムなHorizontal Flip+Random Cropといういわゆる「Standard Data Augmentation」を実装します。

Girlという標準画像を変形していきます。

ちなみに答えから言ってしまうと、Horizontal FlipもRandom CropもTensorFlowで関数があります。

これをLambda層でラップしていきます。

from keras.layers import Input, Lambda
from keras.models import Model
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def augmentation_perimage(img):
    x = tf.image.random_flip_left_right(img)
    x = tf.random_crop(x, [128,128,3])
    return x

def standard_augmentation(inputs):
    random_flip = tf.map_fn(augmentation_perimage, inputs)
    return random_flip

def create_model():
    input = Input((256,256,3))
    x = Lambda(standard_augmentation)(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    pred = model.predict(images)

    for i in range(16):
        ax = plt.subplot(4,4,i+1)
        ax.axis("off")
        ax.imshow(pred[i])
    plt.show()

main()

うまくいった。Data Augmenation専用の関数があったことに驚きですね。

train/testの挙動の振り分け

だいたいこれでいいんですが、これだとpredictでもAugmentationしちゃうので、trainだけAugmentationするようにします。

from keras.layers import Input, Lambda, Layer
from keras.models import Model
from PIL import Image
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

class StandardAugmentation(Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None):
        return K.in_train_phase(tf.map_fn(self.augmentation_perimage, inputs), 
                                inputs, training=training)

    def augmentation_perimage(self, img):
        x = tf.image.random_flip_left_right(img)
        x = tf.random_crop(x, [128,128,3])
        x = tf.image.resize_image_with_crop_or_pad(x, 256, 256)
        return x

def create_model():
    input = Input((256,256,3))
    x = StandardAugmentation()(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    model.summary()

    f = K.function([model.layers[0].input, K.learning_phase()],
                   [model.layers[-1].output])
    pred = f([images, 0])[0] # ここを0、1にする


あえてこうやって書くとちょっと面倒かもしれませんね。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です