こしあん
2019-11-10

スタイル変換のStyle Lossとは何をやっているか

Pocket
LINEで送る


スタイル変換やImage to Imageの損失関数で使われる・Style Lossの実装を詳しく見ていきます。Style Lossの計算で用いているグラム行列の計算方法をTensorFlowで考えます。

Style Lossのやっていること

2つの画像の、VGG16や19(どっちを使うか、どのレイヤーを使うかは論文によって異なる)の中間層の値を取り出し、ぞれぞれのグラム行列を取りL1ロスを計算する。

比較する画像は論文によって異なるが、例えばP-Convの場合は、復元画像とGround Truthの画像を比較する。

実装から見る

例えばStyle Lossを使っているP-Convの場合。コードはこちらより。

https://github.com/MathiasGruber/PConv-Keras/blob/master/libs/pconv_model.py

    def loss_style(self, output, vgg_gt):
        """Style loss based on output/computation, used for both eq. 4 & 5 in paper"""
        loss = 0
        for o, g in zip(output, vgg_gt):
            loss += self.l1(self.gram_matrix(o), self.gram_matrix(g))
        return loss

グラム行列のL1ロスを取っています。L1ロスはおなじみの定義ですね。

    @staticmethod
    def l1(y_true, y_pred):
        """Calculate the L1 loss used in all loss calculations"""
        if K.ndim(y_true) == 4:
            return K.mean(K.abs(y_pred - y_true), axis=[1,2,3])
        elif K.ndim(y_true) == 3:
            return K.mean(K.abs(y_pred - y_true), axis=[1,2])
        else:
            raise NotImplementedError("Calculating L1 loss on 1D tensors? should not occur for this network")

で、問題はグラム行列のほうです。

    @staticmethod
    def gram_matrix(x, norm_by_channels=False):
        """Calculate gram matrix used in style loss"""

        # Assertions on input
        assert K.ndim(x) == 4, 'Input tensor should be a 4d (B, H, W, C) tensor'
        assert K.image_data_format() == 'channels_last', "Please use channels-last format"        

        # Permute channels and get resulting shape
        x = K.permute_dimensions(x, (0, 3, 1, 2))
        shape = K.shape(x)
        B, C, H, W = shape[0], shape[1], shape[2], shape[3]

        # Reshape x and do batch dot product
        features = K.reshape(x, K.stack([B, C, H*W]))
        gram = K.batch_dot(features, features, axes=2)

        # Normalize with channels, height and width
        gram = gram /  K.cast(C * H * W, x.dtype)

        return gram

入力が$(B, H, W, C)$という4階テンソルだとしましょう。Bはバッチサイズ、Hは縦解像度、Wは横解像度、Cはチャンネル数を表します。

コードを順におっていくと、まずは次元を入れ替えます。$(B, C, H, W)$としています。次に、$(B, C, HW)$という3階に変形します。そのあとのbatch_dotというのがKeras独特の関数なのですが、1階以降に対して、

$$(C, HW)\dot (HW, C)=(C, C)$$

という計算をするのがbatch_dotの関数です。

batch_dotをTensorFlowの関数で置き換えて考える

batch_dotをもう少しわかりやすい関数で見てみます。以下の処理とbatch_dotは同じ結果を出します。

import tensorflow as tf
import tensorflow.keras.backend as K

def batch_dot():
    x = tf.reshape(tf.range(24), (4, 2, 3))
    # Kerasのbatch_dot
    a = K.batch_dot(x, x, axes=2)
    print("Keras batch dot")
    print(a)
    # tfの関数で再現
    y = tf.transpose(x, (0, 2, 1))
    b = tf.matmul(x, y)
    print("Implemented by tensorflow function")
    print(b)
    print("DIff")
    print(b - a)

if __name__ == "__main__":
    batch_dot()

結果は次の通り。

Keras batch dot
tf.Tensor(
[[[   5   14]
  [  14   50]]

 [[ 149  212]
  [ 212  302]]

 [[ 509  626]
  [ 626  770]]

 [[1085 1256]
  [1256 1454]]], shape=(4, 2, 2), dtype=int32)
Implemented by tensorflow function
tf.Tensor(
[[[   5   14]
  [  14   50]]

 [[ 149  212]
  [ 212  302]]

 [[ 509  626]
  [ 626  770]]

 [[1085 1256]
  [1256 1454]]], shape=(4, 2, 2), dtype=int32)
DIff
tf.Tensor(
[[[0 0]
  [0 0]]

 [[0 0]
  [0 0]]

 [[0 0]
  [0 0]]

 [[0 0]
  [0 0]]], shape=(4, 2, 2), dtype=int32)

この通り、計算結果が同一になります。つまり、3階に変形して、1階以降を転置したものとの積を取ったのがbatch_dotなのです。

一般にグラム行列といえば、2階の行列に対して$WW^T$を計算するのがグラム行列(左右は逆転してもOK)なので、TensorFlowで実装したコードは3階に拡張したグラム行列の定義そのものになります。

グラム行列の直感的な解釈

そもそも「なんでスタイルの損失を決めるのにグラム行列を取るの?」ということなのですが、分散共分散行列・相関行列はグラム行列の特別な場合です。詳しくはこちらの記事で書いたのでを参照してみてください。

つまり、「VGGでの中間層での空間での、ピクセル間の相関のようなものを似せる」というのが、Style Lossのやっていることです。もしこれをグラム行列のL1ではなくピクセル間のL1を取るとPerceptual Lossというものになりますが、Perceptual LossとStyle Lossの違いとは、Style Lossのほうが相関を見ているので、より広範囲、より位置に依存しない特徴を見ているということになります。

Related Posts

pix2pix HDのCoarse to fineジェネレーターを考える... pix2pix HDの論文を読んでいたら「Coarse to fineジェネレーター」という、低解像度→高解像度と解像度を分けて訓練するネットワークの工夫をしていました。pix2pixはGANですが、このジェネレーターや訓練の工夫は、Non-GANでも理屈上は使えるはずなので、この有効性をImag...
PyTorchでweight clipping WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。 Weight Clippingとは レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という...
KerasのModelCheckpointのsave_best_onlyは何を表すのか?... Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。 ModelCheckpointとは? 公式ドキュメントより ke...
tf.tensordotで行列積を表現するための設定... TensorFlowのtensordotという関数はとても強力で、テンソルに対する行列積に対する計算をだいたい表現できます。しかし、軸の設定がいまいちよくわからなかったので、確かめてみました。 2x2行列同士の積の場合(Numpy) まず単純に2x2行列同士の(ドット)積を考えます。まずはNu...
PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です