こしあん
2019-02-13

転移学習でネットワーク内でアップサンプリングする方法(Keras)


Pocket
LINEで送る
Delicious にシェア

4.2k{icon} {views}


転移学習でインプットのサイズを揃えなければいけないことがありますが、これをRAM(CPU)上でやるとメモリが不足することがあります。転移学習の重みをそのまま使い、事前にアップサンプリングレイヤーを差し込む方法を紹介します。

関連記事とバックグラウンド

まず前提知識としてCPU側でアップサンプリングするとメモリーサイズが2乗のオーダーで増えます。解像度が128は解像度が64の4倍メモリーサイズを使います。これをメモリーに余裕があるネットワーク側でやりたいのです。

以前書いたこの記事とも重なるのですが、基本的にUpsampling2Dレイヤーを使います。以下の方法で独自にUpsamplingを定義してもOKです。

TPUでアップサンプリングする際にエラーを出さない方法
https://blog.shikoan.com/tpu-upsampling/

ただし、新規にモデルを定義する場合はそれでよかったのですが、Functional APIの要領でUpsampling済みのテンソルを差し込むというのが難しくなります。

また私が書いたこの記事の要領で、モデルを再定義し、Upsamplingのレイヤーを差し込むというのもできますが面倒です。

keras.applicationsで指定できた

実はkeras.applicationsの入力にはinput_shapeで(128,128,3)のようにshapeを指定する方法のほかに、input_tensorでKerasのテンソルを指定する方法があります。ここにアップサンプリング済みのテンソルを入れればよいわけです。イメージ的には、

vgg = VGG16(include_top=False, weights="imagenet", input_tensor=x)

このxというのが、アップサンプリング済みのテンソルです。

コード

from keras.layers import UpSampling2D, Input
from keras.applications import VGG16
from keras.applications.imagenet_utils import preprocess_input
from keras.datasets import cifar10
from keras.models import Model

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
input = Input((32,32,3))
x = UpSampling2D(4)(input)
vgg = VGG16(include_top=False, weights="imagenet", input_tensor=x)
model = Model(input, vgg.output)

model.summary()

y_pred = model.predict(preprocess_input(X_test[:16]))
print(y_pred.shape)

summaryを見ると、このようにVGG16に入る前にアップサンプリングされているのが確認できます。

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 32, 32, 3)         0
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 128, 128, 3)       0
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 128, 128, 64)      1792
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 128, 128, 64)      36928
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 64, 64, 64)        0
_________________________________________________________________
~中略~
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 4, 4, 512)         0
=================================================================
Total params: 14,714,688
Trainable params: 14,714,688
Non-trainable params: 0
_________________________________________________________________

出力のshapeも問題ありません。

(16, 4, 4, 512)

これで堂々とネットワーク内でアップサンプリングできるというわけです。簡単でしたね。

Pocket
LINEで送る
Delicious にシェア



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です