こしあん
2019-11-10

スタイル変換のStyle Lossとは何をやっているか

Pocket
LINEで送る
Delicious にシェア

490{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



スタイル変換やImage to Imageの損失関数で使われる・Style Lossの実装を詳しく見ていきます。Style Lossの計算で用いているグラム行列の計算方法をTensorFlowで考えます。

Style Lossのやっていること

2つの画像の、VGG16や19(どっちを使うか、どのレイヤーを使うかは論文によって異なる)の中間層の値を取り出し、ぞれぞれのグラム行列を取りL1ロスを計算する。

比較する画像は論文によって異なるが、例えばP-Convの場合は、復元画像とGround Truthの画像を比較する。

実装から見る

例えばStyle Lossを使っているP-Convの場合。コードはこちらより。

https://github.com/MathiasGruber/PConv-Keras/blob/master/libs/pconv_model.py

    def loss_style(self, output, vgg_gt):
        """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
        loss = 0
        for o, g in zip(output, vgg_gt):
            loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
        return loss

グラム行列のL1ロスを取っています。L1ロスはおなじみの定義ですね。

    @staticmethod
    def l1(y_true, y_pred):
        """Calculate the L1 loss used in all loss calculations"""
        if K.ndim(y_true) == 4:
            return K.mean(K.abs(y_pred - y_true), axis=[1,2,3])
        elif K.ndim(y_true) == 3:
            return K.mean(K.abs(y_pred - y_true), axis=[1,2])
        else:
            raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")

で、問題はグラム行列のほうです。

    @staticmethod
    def gram_matrix(x, norm_by_channels=False):
        """Calculate gram matrix used in style loss"""

        # Assertions on input
        assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
        assert K.image_data_format() == 'channels_last', "Please use channels-last format"        

        # Permute channels and get resulting shape
        x = K.permute_dimensions(x, (0, 3, 1, 2))
        shape = K.shape(x)
        B, C, H, W = shape[0], shape[1], shape[2], shape[3]

        # Reshape x and do batch dot product
        features = K.reshape(x, K.stack([B, C, H*W]))
        gram = K.batch_dot(features, features, axes=2)

        # Normalize with channels, height and width
        gram = gram /  K.cast(C * H * W, x.dtype)

        return gram

入力が$(B, H, W, C)$という4階テンソルだとしましょう。Bはバッチサイズ、Hは縦解像度、Wは横解像度、Cはチャンネル数を表します。

コードを順におっていくと、まずは次元を入れ替えます。$(B, C, H, W)$としています。次に、$(B, C, HW)$という3階に変形します。そのあとのbatch_dotというのがKeras独特の関数なのですが、1階以降に対して、

$$(C, HW)\dot (HW, C)=(C, C)$$

という計算をするのがbatch_dotの関数です。

batch_dotをTensorFlowの関数で置き換えて考える

batch_dotをもう少しわかりやすい関数で見てみます。以下の処理とbatch_dotは同じ結果を出します。

import tensorflow as tf
import tensorflow.keras.backend as K

def batch_dot():
    x = tf.reshape(tf.range(24), (4, 2, 3))
    # Kerasのbatch_dot
    a = K.batch_dot(x, x, axes=2)
    print("Keras batch dot")
    print(a)
    # tfの関数で再現
    y = tf.transpose(x, (0, 2, 1))
    b = tf.matmul(x, y)
    print("Implemented by tensorflow function")
    print(b)
    print("DIff")
    print(b - a)

if __name__ == "__main__":
    batch_dot()

結果は次の通り。

Keras batch dot
tf.Tensor(
[[[   5   14]
  [  14   50]]

 [[ 149  212]
  [ 212  302]]

 [[ 509  626]
  [ 626  770]]

 [[1085 1256]
  [1256 1454]]], shape=(4, 2, 2), dtype=int32)
Implemented by tensorflow function
tf.Tensor(
[[[   5   14]
  [  14   50]]

 [[ 149  212]
  [ 212  302]]

 [[ 509  626]
  [ 626  770]]

 [[1085 1256]
  [1256 1454]]], shape=(4, 2, 2), dtype=int32)
DIff
tf.Tensor(
[[[0 0]
  [0 0]]

 [[0 0]
  [0 0]]

 [[0 0]
  [0 0]]

 [[0 0]
  [0 0]]], shape=(4, 2, 2), dtype=int32)

この通り、計算結果が同一になります。つまり、3階に変形して、1階以降を転置したものとの積を取ったのがbatch_dotなのです。

一般にグラム行列といえば、2階の行列に対して$WW^T$を計算するのがグラム行列(左右は逆転してもOK)なので、TensorFlowで実装したコードは3階に拡張したグラム行列の定義そのものになります。

グラム行列の直感的な解釈

そもそも「なんでスタイルの損失を決めるのにグラム行列を取るの?」ということなのですが、分散共分散行列・相関行列はグラム行列の特別な場合です。詳しくはこちらの記事で書いたのでを参照してみてください。

つまり、「VGGでの中間層での空間での、ピクセル間の相関のようなものを似せる」というのが、Style Lossのやっていることです。もしこれをグラム行列のL1ではなくピクセル間のL1を取るとPerceptual Lossというものになりますが、Perceptual LossとStyle Lossの違いとは、Style Lossのほうが相関を見ているので、より広範囲、より位置に依存しない特徴を見ているということになります。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です