スタイル変換のStyle Lossとは何をやっているか
スタイル変換やImage to Imageの損失関数で使われる・Style Lossの実装を詳しく見ていきます。Style Lossの計算で用いているグラム行列の計算方法をTensorFlowで考えます。
目次
Style Lossのやっていること
2つの画像の、VGG16や19(どっちを使うか、どのレイヤーを使うかは論文によって異なる)の中間層の値を取り出し、ぞれぞれのグラム行列を取りL1ロスを計算する。
比較する画像は論文によって異なるが、例えばP-Convの場合は、復元画像とGround Truthの画像を比較する。
実装から見る
例えばStyle Lossを使っているP-Convの場合。コードはこちらより。
https://github.com/MathiasGruber/PConv-Keras/blob/master/libs/pconv_model.py
def loss_style(self, output, vgg_gt):
"""Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
loss = 0
for o, g in zip(output, vgg_gt):
loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
return loss
グラム行列のL1ロスを取っています。L1ロスはおなじみの定義ですね。
@staticmethod
def l1(y_true, y_pred):
"""Calculate the L1 loss used in all loss calculations"""
if K.ndim(y_true) == 4:
return K.mean(K.abs(y_pred - y_true), axis=[1,2,3])
elif K.ndim(y_true) == 3:
return K.mean(K.abs(y_pred - y_true), axis=[1,2])
else:
raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")
で、問題はグラム行列のほうです。
@staticmethod
def gram_matrix(x, norm_by_channels=False):
"""Calculate gram matrix used in style loss"""
# Assertions on input
assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
assert K.image_data_format() == 'channels_last', "Please use channels-last format"
# Permute channels and get resulting shape
x = K.permute_dimensions(x, (0, 3, 1, 2))
shape = K.shape(x)
B, C, H, W = shape[0], shape[1], shape[2], shape[3]
# Reshape x and do batch dot product
features = K.reshape(x, K.stack([B, C, H*W]))
gram = K.batch_dot(features, features, axes=2)
# Normalize with channels, height and width
gram = gram / K.cast(C * H * W, x.dtype)
return gram
入力が$(B, H, W, C)$という4階テンソルだとしましょう。Bはバッチサイズ、Hは縦解像度、Wは横解像度、Cはチャンネル数を表します。
コードを順におっていくと、まずは次元を入れ替えます。$(B, C, H, W)$としています。次に、$(B, C, HW)$という3階に変形します。そのあとのbatch_dotというのがKeras独特の関数なのですが、1階以降に対して、
$$(C, HW)\dot (HW, C)=(C, C)$$
という計算をするのがbatch_dotの関数です。
batch_dotをTensorFlowの関数で置き換えて考える
batch_dotをもう少しわかりやすい関数で見てみます。以下の処理とbatch_dotは同じ結果を出します。
import tensorflow as tf
import tensorflow.keras.backend as K
def batch_dot():
x = tf.reshape(tf.range(24), (4, 2, 3))
# Kerasのbatch_dot
a = K.batch_dot(x, x, axes=2)
print("Keras batch dot")
print(a)
# tfの関数で再現
y = tf.transpose(x, (0, 2, 1))
b = tf.matmul(x, y)
print("Implemented by tensorflow function")
print(b)
print("DIff")
print(b - a)
if __name__ == "__main__":
batch_dot()
結果は次の通り。
Keras batch dot
tf.Tensor(
[[[ 5 14]
[ 14 50]]
[[ 149 212]
[ 212 302]]
[[ 509 626]
[ 626 770]]
[[1085 1256]
[1256 1454]]], shape=(4, 2, 2), dtype=int32)
Implemented by tensorflow function
tf.Tensor(
[[[ 5 14]
[ 14 50]]
[[ 149 212]
[ 212 302]]
[[ 509 626]
[ 626 770]]
[[1085 1256]
[1256 1454]]], shape=(4, 2, 2), dtype=int32)
DIff
tf.Tensor(
[[[0 0]
[0 0]]
[[0 0]
[0 0]]
[[0 0]
[0 0]]
[[0 0]
[0 0]]], shape=(4, 2, 2), dtype=int32)
この通り、計算結果が同一になります。つまり、3階に変形して、1階以降を転置したものとの積を取ったのがbatch_dotなのです。
一般にグラム行列といえば、2階の行列に対して$WW^T$を計算するのがグラム行列(左右は逆転してもOK)なので、TensorFlowで実装したコードは3階に拡張したグラム行列の定義そのものになります。
グラム行列の直感的な解釈
そもそも「なんでスタイルの損失を決めるのにグラム行列を取るの?」ということなのですが、分散共分散行列・相関行列はグラム行列の特別な場合です。詳しくはこちらの記事で書いたのでを参照してみてください。
つまり、「VGGでの中間層での空間での、ピクセル間の相関のようなものを似せる」というのが、Style Lossのやっていることです。もしこれをグラム行列のL1ではなくピクセル間のL1を取るとPerceptual Lossというものになりますが、Perceptual LossとStyle Lossの違いとは、Style Lossのほうが相関を見ているので、より広範囲、より位置に依存しない特徴を見ているということになります。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー