TensorFlow/Kerasでネットワーク内でData Augmentationする方法
Posted On 2019-02-13
NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。
目次
こんなイメージ
Numpy(CPU)でやる場合
- NumpyでDataAugmentation→model.fit_generator(…)→Input→ニューラルネットワーク
ニューラルネットワークでやる場合
- model.fit(…)/.fit_generator(…)→Input→Data Augmetation層→ニューラルネットワークの隠れ層
今回はこちらを見ていきます。
Data Augmentation層
今回はランダムなHorizontal Flip+Random Cropといういわゆる「Standard Data Augmentation」を実装します。
Girlという標準画像を変形していきます。
ちなみに答えから言ってしまうと、Horizontal FlipもRandom CropもTensorFlowで関数があります。
- Horizontal Flip:https://www.tensorflow.org/api_docs/python/tf/image/random_flip_left_right
- RandomCrop : https://www.tensorflow.org/api_docs/python/tf/image/random_crop
これをLambda層でラップしていきます。
from keras.layers import Input, Lambda
from keras.models import Model
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
def augmentation_perimage(img):
x = tf.image.random_flip_left_right(img)
x = tf.random_crop(x, [128,128,3])
return x
def standard_augmentation(inputs):
random_flip = tf.map_fn(augmentation_perimage, inputs)
return random_flip
def create_model():
input = Input((256,256,3))
x = Lambda(standard_augmentation)(input)
return Model(input, x)
def main():
with Image.open("girl.jpg") as img:
girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
model = create_model()
pred = model.predict(images)
for i in range(16):
ax = plt.subplot(4,4,i+1)
ax.axis("off")
ax.imshow(pred[i])
plt.show()
main()
うまくいった。Data Augmenation専用の関数があったことに驚きですね。
train/testの挙動の振り分け
だいたいこれでいいんですが、これだとpredictでもAugmentationしちゃうので、trainだけAugmentationするようにします。
from keras.layers import Input, Lambda, Layer
from keras.models import Model
from PIL import Image
import keras.backend as K
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
class StandardAugmentation(Layer):
def __init__(self):
super().__init__()
def call(self, inputs, training=None):
return K.in_train_phase(tf.map_fn(self.augmentation_perimage, inputs),
inputs, training=training)
def augmentation_perimage(self, img):
x = tf.image.random_flip_left_right(img)
x = tf.random_crop(x, [128,128,3])
x = tf.image.resize_image_with_crop_or_pad(x, 256, 256)
return x
def create_model():
input = Input((256,256,3))
x = StandardAugmentation()(input)
return Model(input, x)
def main():
with Image.open("girl.jpg") as img:
girl = np.asarray(img.resize((256,256)), np.float32) / 255.0
images = np.repeat(np.expand_dims(girl, 0), 16, axis=0)
model = create_model()
model.summary()
f = K.function([model.layers[0].input, K.learning_phase()],
[model.layers[-1].output])
pred = f([images, 0])[0] # ここを0、1にする
あえてこうやって書くとちょっと面倒かもしれませんね。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー