こしあん
2019-02-13

TensorFlow/Kerasでネットワーク内でData Augmentationする方法


NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。

こんなイメージ

Numpy(CPU)でやる場合

  • NumpyでDataAugmentation→model.fit_generator(…)→Input→ニューラルネットワーク

ニューラルネットワークでやる場合

  • model.fit(…)/.fit_generator(…)→Input→Data Augmetation層→ニューラルネットワークの隠れ層

今回はこちらを見ていきます。

Data Augmentation層

今回はランダムなHorizontal Flip+Random Cropといういわゆる「Standard Data Augmentation」を実装します。

Girlという標準画像を変形していきます。

ちなみに答えから言ってしまうと、Horizontal FlipもRandom CropもTensorFlowで関数があります。

これをLambda層でラップしていきます。

from keras.layers import Input, Lambda
from keras.models import Model
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def augmentation_perimage(img):
    x = tf.image.random_flip_left_right(img)
    x = tf.random_crop(x, [128,128,3])
    return x

def standard_augmentation(inputs):
    random_flip = tf.map_fn(augmentation_perimage, inputs)
    return random_flip

def create_model():
    input = Input((256,256,3))
    x = Lambda(standard_augmentation)(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    pred = model.predict(images)

    for i in range(16):
        ax = plt.subplot(4,4,i+1)
        ax.axis("off")
        ax.imshow(pred[i])
    plt.show()

main()

うまくいった。Data Augmenation専用の関数があったことに驚きですね。

train/testの挙動の振り分け

だいたいこれでいいんですが、これだとpredictでもAugmentationしちゃうので、trainだけAugmentationするようにします。

from keras.layers import Input, Lambda, Layer
from keras.models import Model
from PIL import Image
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

class StandardAugmentation(Layer):
    def __init__(self):
        super().__init__()

    def call(self, inputs, training=None):
        return K.in_train_phase(tf.map_fn(self.augmentation_perimage, inputs), 
                                inputs, training=training)

    def augmentation_perimage(self, img):
        x = tf.image.random_flip_left_right(img)
        x = tf.random_crop(x, [128,128,3])
        x = tf.image.resize_image_with_crop_or_pad(x, 256, 256)
        return x

def create_model():
    input = Input((256,256,3))
    x = StandardAugmentation()(input)
    return Model(input, x)

def main():
    with Image.open("girl.jpg") as img:
        girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
    images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
    model = create_model()
    model.summary()

    f = K.function([model.layers[0].input, K.learning_phase()],
                   [model.layers[-1].output])
    pred = f([images, 0])[0] # ここを0、1にする


あえてこうやって書くとちょっと面倒かもしれませんね。

Related Posts

note開設のお知らせ 本日noteを開設いたしました。 https://note.mu/koshian2 これは自分の記事をより多くの方々に読んでいただき、新たな読者の開拓を図るためであります。 当面は既存の記事の再送を中心に考えていますが、いくつかnote向けに読みやすい新規の記事も考えています。好評なら新規の...
Kerasでメモリ使用量を減らしたかったらmax_queue_sizeを調整しよう... Kerasで大きめの画像を使ったモデルを訓練していると、メモリが足りなくなるということがよくあります。途中処理の変数のデータ型(np.uint8)を変えるのだけではなく、max_queue_sizeの調整をする必要があることがあります。それを見ていきます。 メモリサイズの目安 ニューラルネット...
TPUで学習率減衰させる方法 TPUで学習率減衰したいが、TensorFlowのオプティマイザーを使うべきか、tf.kerasのオプティマイザーを使うべきか、あるいはKerasのオプティマイザーを使うべきか非常にややこしいことがあります。TPUで学習率を減衰させる方法を再現しました。 結論から TPU環境でtf.keras...
KerasのModelCheckpointのsave_best_onlyは何を表すのか?... Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。 ModelCheckpointとは? 公式ドキュメントより ke...
TensorFlow/Kerasでの分散共分散行列・相関行列、テンソル主成分分析の実装... TensorFlowでは分散共分散行列や主成分分析用の関数が用意されていません。訓練を一切せずにTensorFlowとKeras関数だけを使って、分散共分散行列、相関行列、主成分分析を実装します。最終的にはカテゴリー別のテンソル主成分分析を作れるようにします。 何らかの論文でこれらのテクニックを...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です