TPUでアップサンプリングする際にエラーを出さない方法
画像処理をしているとUpsamplingが必要になることがあります。Keras/TensorFlowではUpsampling2Dというレイヤーを使ってアップサンプリングができますが、このレイヤーがTPUだとエラーを出すので解決法を探しました。自分でアップサンプリングレイヤーを定義するとうまく行ったので、それを見ていきます。
環境:TensorFlow v1.12.0
目次
TPU使わないとうまくいく
例えば、今簡単な例として、shape=(8,1,1,1)の0~7の数字の配列を画像と見立てます。これを縦横2倍のアップサンプリングをして、shape=(8,2,2,1)という形に変形します。
KerasのUpsampling2Dレイヤーを使った実装ではこうでしょう。
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import numpy as np
def upsampling_model():
input = layers.Input((1,1,1))
x = layers.UpSampling2D(2)(input)
return Model(input, x)
def upsampling_test():
model = upsampling_model()
X = np.arange(8).reshape(-1,1,1,1)
y = model.predict(X)
print(y.shape)
if __name__ == "__main__":
upsampling_test()
これは実際うまく行って、出力のshapeは正しい値になります。
(8, 2, 2, 1)
UpSampling2DはTPUではうまくは行かない
ただこのUpsampling2Dを使った方法はTPUだとエラーを出します。TPUに変換する処理を含めたコードです。
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
from tensorflow.contrib.tpu.python.tpu import keras_support
import numpy as np
import os
def upsampling_model():
input = layers.Input((1,1,1))
x = layers.UpSampling2D(2)(input)
return Model(input, x)
def upsampling_test():
model = upsampling_model()
# TPUモデルに変換するためにコンパイルが必要なのでこの値に意味はない
model.compile(tf.train.GradientDescentOptimizer(0.1), "mean_squared_error")
tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
X = np.arange(8).reshape(-1,1,1,1)
y = model.predict(X)
print(y.shape)
if __name__ == "__main__":
upsampling_test()
途中コンパイルで適当なオプティマイザーや損失関数を与えていますが、これはTPUのモデルに変換するのにコンパイルが必要なので、この例では、コンパイルで指定した値は全く意味がありません。
ちなみにこのコードはコンパイル失敗します。
RuntimeError: Compilation failed: Compilation failure: Detected unsupported operations when trying to compile graph cluster_1_5423988494562430942[] on XLA_TPU_JIT: ResizeNearestNeighbor (No registered 'ResizeNearestNeighbor' OpKernel for XLA_TPU_JIT devices compatible with node {{node tpu_140457990329008/up_sampling2d_1/ResizeNearestNeighbor}} = ResizeNearestNeighbor[T=DT_FLOAT, align_corners=false, _device="/device:TPU_REPLICATED_CORE"](infeed-infer, tpu_140457990329008/up_sampling2d_1/mul)
. Registered: device='CPU'; T in [DT_DOUBLE]
device='CPU'; T in [DT_FLOAT]
device='CPU'; T in [DT_BFLOAT16]
device='CPU'; T in [DT_HALF]
device='CPU'; T in [DT_INT8]
device='CPU'; T in [DT_UINT8]
device='CPU'; T in [DT_INT16]
device='CPU'; T in [DT_UINT16]
device='CPU'; T in [DT_INT32]
device='CPU'; T in [DT_INT64]
){{node tpu_140457990329008/up_sampling2d_1/ResizeNearestNeighbor}}
どうもNearestNeighborのアップサンプリングのコンパイルが失敗するみたいですね。
NumpyでのNearestNeighbor法
さて、もっと簡単にNumpyでの例を振り返りましょう。以前こちらの記事で紹介した方法です。
Numpyだけでサクッと画像を拡大する方法
https://blog.shikoan.com/numpy-upsampling-image/
実はNumpyではForループを一切使わずにNearestNeighbor法の画像拡大をすることができます。repeatを2回かませる方法です。
image.repeat(2, axis=0).repeat(2, axis=1)
実はKerasのバックエンド関数に同様のrepeat関数があるため(K.repeat_elements()
)、このNumpyの方法をそのままTensorFlowに転用することができます。つまり、TensorFlow用のアップサンプリング用の関数を自分で書いてLambdaレイヤーでラップしてあげればOKです。
TPUで動く方法
こうしてみました。「upsampling2d_tpu」というのが自分で定義したアップサンプリング用の関数です。
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.contrib.tpu.python.tpu import keras_support
import tensorflow.keras.backend as K
import numpy as np
import os
def upsampling_model():
input = layers.Input((1,1,1))
x = layers.Lambda(upsampling2d_tpu, arguments={"scale":2})(input)
return Model(input, x)
def upsampling2d_tpu(inputs, scale=2):
x = K.repeat_elements(inputs, scale, axis=1)
x = K.repeat_elements(x, scale, axis=2)
return x
def upsampling_test():
model = upsampling_model()
# TPUモデルに変換するためにコンパイルが必要なのでこの値に意味はない
model.compile(tf.train.GradientDescentOptimizer(0.1), "mean_squared_error")
tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
strategy = keras_support.TPUDistributionStrategy(tpu_cluster_resolver)
model = tf.contrib.tpu.keras_to_tpu_model(model, strategy=strategy)
X = np.arange(8).reshape(-1,1,1,1)
y = model.predict(X)
print(y.shape)
if __name__ == "__main__":
upsampling_test()
これはうまくいきます。
INFO:tensorflow:New input shapes; (re-)compiling: mode=infer (# of cores 8), [TensorSpec(shape=(1, 1, 1, 1), dtype=tf.float32, name='input_12_10')]
INFO:tensorflow:Overriding default placeholder.
INFO:tensorflow:Remapping placeholder for input_12
INFO:tensorflow:Started compiling
INFO:tensorflow:Finished compiling. Time elapsed: 0.20801544189453125 secs
INFO:tensorflow:Setting weights on TPU model.
(8, 2, 2, 1)
TPUでもちゃんとアップサンプリングを行うことができました。NearestNeighbor法での拡大でいいところに、わざわざConv2DTransposeなどのレイヤーをはさむ必要はなさそうです。
まとめ
TPUではどうもUpSampling2Dのコンパイルに失敗するっぽい(今後のバージョンで改善される可能性あり)。K.repeat_elementsなどの関数を使い、自分でアップサンプリングの処理を書くとエラーにならないよ、ということでした。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー