こしあん
2019-10-05

PyTorchでガウシアンピラミッド+ラプラシアンピラミッド(Gaussian/Laplacian Pyramid)

Pocket
LINEで送る

Progressive-GANの論文で、SWD(Sliced Wasserstein Distance)が評価指標として出てきたので、その途中で必要になったガウシアンピラミッド、ラプラシアンピラミッドをPyTorchで実装してみました。これらのピラミッドはGAN関係なく、画像処理一般で使えるものです。応用例として、ラプラシアンブレンドもPyTorchで実装しています。

関連

OpenCVのドキュメントが非常に参考になりました。こちらを参考にCNNのConvolutionに置き換えて実装しました。

画像ピラミッド
http://labs.eecs.tottori-u.ac.jp/sd/Member/oyamada/OpenCV/html/py_tutorials/py_imgproc/py_pyramids/py_pyramids.html

またここに出てくる実装は、自分の下記の記事がベースになっています。必要なら参考にしてみてください。

ガウシアンピラミッド、ラプラシアンピラミッド

まず画像のピラミッドとは何かというと、もとの画像の解像度を1/2、1/2、1/2……のように繰り返し縮小してまとめたものがピラミッドです。縮小する過程でガウシアンぼかしを入れたものがガウシアンピラミッド、ガウシアンピラミッドに解像度間の差を取ったものがラプラシアンピラミッドとなります。

ラプラシアンピラミッドはバンドパスフィルタのようなことをやっているので、画像の周波数分解とも考えることができます。OpenCVのドキュメントからです。


上がガウシアンピラミッドで、下がラプラシアンピラミッドです。大きい画像が高周波数成分、低い画像が低周波数成分です。発想としてはOctave ConvolutionのOctaveと同じです。

以下にアルゴリズムの詳細を示します。

ガウシアンピラミッド

以下のようなコードで示される5×5の畳み込みカーネルを用意します。

import numpy as np

kernel = np.array([
    [1, 4, 6, 4, 1],
    [4, 16, 24, 16, 4],
    [6, 24, 36, 24, 6],
    [4, 16, 24, 16, 4],
    [1, 4, 6, 4, 1]], np.float32) / 256.0

OpenCVのドキュメントだと一部分母が16になっていましたが、256にするのが正しいかと思われます(普通の畳み込みカーネルも行列の全要素の和が分母になってます)。

OpenCVでは、「5×5カーネルの畳み込み→偶数の行と列を削除する」という処理を実行しています。しかし、ディープラーニングの畳み込み関数(Conv2D)ではもっと簡潔に表せて、「5×5カーネル、stride=2、padding=2」というパラメーターで(ほぼ)同じ処理を実行できます。固定値のカーネルをどう計算するのかはこちらを参照してください。

つまり、PyTorchでガウシアンピラミッドを実装するには、上で表されるような5×5カーネルによるConv2dを連続で実行するだけということになります。

ラプラシアンピラミッド

ラプラシアンピラミッドは少し複雑です。ガウシアンピラミッドで紹介した処理は解像度を下げる方向なので「pyramid_down」と表すことにします。ラプラシアンピラミッドでは、解像度を上げる「pyramid_up」という処理が必要になります。大きな流れは次の通りです。

  1. ガウシアンピラミッド(pyramid_down)で計算されたピラミッドのi番目の画像について考える
  2. i番目の画像に対し、pyramid_downをしたi+1番目のガウシアンピラミッドに注目する
  3. i+1番目のガウシアンピラミッドをpyramid_upし、1.で示したi番目の画像との差分を取る

ラプラシアンピラミッドのやっていることは周波数帯ごとの差分計算なので、「ラプラシアンピラミッドの個数=ガウシアンピラミッドの個数-1」となります。

次は「pyramid_up」はどういう処理をするかということです。

  • 入力画像を2倍の解像度にアップサンプリングする(後でガウシアンぼかしを入れることから、Nearest Neighbor法で良いと思われる)
  • pyramid_downで使ったのと同じ畳み込みカーネルでガウシアンぼかしをかける

これでラプラシアンピラミッドのアルゴリズムは終わりです。

コード

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

# 画像ファイル→PyTorchテンソル
def load_tensor(img_path):
    with Image.open(img_path) as img:
        array = np.asarray(img, np.float32).transpose([2, 0, 1]) / 255.0
        tensor = torch.as_tensor(np.expand_dims(array, axis=0))  # rank 4
    return tensor

# ピラミッドを1枚の画像に結合して保存するための関数
def tile_pyramid(imgs):
    height, width = imgs[0].size()[2:]
    with torch.no_grad():
        canvas = torch.zeros(1, 3, height * 3 // 2, width)
        x, y = 0, 0
        for i, img in enumerate(imgs):
            h, w = img.size()[2:]
            canvas[:,:, y:(y + h), x:(x + w)] = img            
            if i % 2 == 0:
                x += width // (2 ** (i + 3))                    
                y += height // (2 ** i) # 0, 2, 4..でy方向にシフト
            else:
                x += width // (2 ** i)  # 1, 3, 5..でx方向にシフト
                y += height // (2 ** (i + 3))
        return canvas

# ガウシアンぼかしのカーネル
def get_gaussian_kernel():
    # [out_ch, in_ch, .., ..] : channel wiseに計算
    kernel = np.array([
        [1, 4, 6, 4, 1],
        [4, 16, 24, 16, 4],
        [6, 24, 36, 24, 6],
        [4, 16, 24, 16, 4],
        [1, 4, 6, 4, 1]], np.float32) / 256.0
    gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5))
    return gaussian_k

def pyramid_down(image):
    with torch.no_grad():
        gaussian_k = get_gaussian_kernel()        
        # channel-wise conv(大事)
        multiband = [F.conv2d(image[:, i:i + 1,:,:], gaussian_k, padding=2, stride=2) for i in range(3)]
        down_image = torch.cat(multiband, dim=1)
    return down_image

def pyramid_up(image):
    with torch.no_grad():
        gaussian_k = get_gaussian_kernel()
        upsample = F.interpolate(image, scale_factor=2)
        multiband = [F.conv2d(upsample[:, i:i + 1,:,:], gaussian_k, padding=2) for i in range(3)]
        up_image = torch.cat(multiband, dim=1)
    return up_image

def gaussian_pyramid(original, n_pyramids):
    x = original
    # pyramid down
    pyramids = [original]
    for i in range(n_pyramids):
        x = pyramid_down(x)
        pyramids.append(x)
    return pyramids

def laplacian_pyramid(original, n_pyramids):
    # gaussian pyramidを作る
    pyramids = gaussian_pyramid(original, n_pyramids)

    # pyramid up - diff
    laplacian = []
    for i in range(len(pyramids) - 1):
        diff = pyramids[i] - pyramid_up(pyramids[i + 1])
        laplacian.append(diff)
    return laplacian

# 見やすいようにMin-Maxでスケーリングする
def normalize_pyramids(pyramids):
    result = []
    for diff in pyramids:
        diff_min = torch.min(diff)
        diff_max = torch.max(diff)
        diff_normalize = (diff - diff_min) / (diff_max - diff_min)
        result.append(diff_normalize)
    return result

ラプラシアンピラミッドの出力結果を見やすくするために、Min-MaxスケーリングでNormalizeした関数も入れています。

結果:ガウシアンピラミッド

三江線に登場していただきましょう。これをtrain.jpgとします。

if __name__ == "__main__":
    original = load_tensor("train.jpg")
    pyramid = gaussian_pyramid(original, 6)
    torchvision.utils.save_image(tile_pyramid(pyramid), "gaussian_pyramid.jpg")    

ガウシアンピラミッドは次のようになります。

画像がだんだん小さくなっているのでわかりづらいですが、ガウシアンぼかしをかけているのでだんだんぼやけてくるのがわかります。

結果:ラプラシアンピラミッド

同じようにラプラシアンピラミッドを作ってみます。元画像のピクセル値が0~1の場合、ラプラシアンピラミッドは差分を取っているので、-1~1までの値が取り得ます。そこで、Min-Max Scalerを適用し、出力画像が0~1の範囲になるようにします。

if __name__ == "__main__":
    original = load_tensor("train.jpg")
    pyramid = laplacian_pyramid(original, 6)
    normalize = normalize_pyramids(pyramid)
    torchvision.utils.save_image(tile_pyramid(normalize), "laplacian_pyramid.jpg")

この結果はとてもおもしろいです。細かい輪郭線が高周波数帯に、逆にテクスチャや色は低周波数帯に集中しているのがわかります。

これはCNNとのアナロジーとしても理解できます。CNNは浅い層では解像度が高く(高周波)、深い層では解像度が低い(低周波)構成となっています。同時に浅い層では主に輪郭やエッジのような波形的情報を見ており、深い層ではテクスチャや何が映っているかという空間的・意味的な情報を見ていると言われています。

ラプラシアンピラミッドによって、CNNでなくても画像そのものが高周波・低周波で同様の構成になっており、CNNの構造はまさにそれにフィットしている――だからCNNが画像に向いている、ということが可視化できるともいえるでしょう。

ちなみにOpenCVの実装では、subtract(引き算)関数でオーバーフロー抑制機能が機能し、マイナスの値を0に丸められています。したがって、このPyTorchの実装とOpenCVの実装では結果が少し違います。ただし、マイナスの値を丸めないほうが明らかに情報は落ちないので、ここではそのまま「ただの引き算」として処理しています。後述のラプラシアンブレンドでも、丸めないほうが自然な出力になります。

ラプラシアンブレンド

応用例として、OpenCVで紹介されているラプラシアンブレンドをPyTorchで実装してみました。

こちらのサイトから、「apple.jpg」と「orange.jpg」をダウンロードしてきます。

# ラプラシアンブレンディング
def laplacian_blending():
    apple = load_tensor("apple.jpg")
    orange = load_tensor("orange.jpg")
    apple_lap = laplacian_pyramid(apple, 5)
    orange_lap = laplacian_pyramid(orange, 5)

    # ラプラシアンピラミッドを左右にブレンドする
    blend_lap = []
    for x, y in zip(apple_lap, orange_lap):
        width = x.size(3)
        b = torch.cat([x[:,:,:,:width // 2], y[:,:,:, width // 2:]], dim=3)
        blend_lap.append(b)

    # 最高レベルのガウシアンピラミッドのブレンド
    apple_top = gaussian_pyramid(apple, 5)[-1]
    orange_top = gaussian_pyramid(orange, 5)[-1]
    out = torch.cat([apple_top[:,:,:,:apple_top.size(3) // 2],
                     orange_top[:,:,:, orange_top.size(3) // 2:]], dim=3)

    # ラプラシアンピラミッドからの再構築
    for lap in blend_lap[::-1]:
        out = pyramid_up(out) + lap
    torchvision.utils.save_image(out, "laplacian_blend.png")

    # 比較例:ダイレクトにブレンド
    direct = torch.cat([apple[:,:,:,:apple.size(3) // 2],
                        orange[:,:,:, orange.size(3) // 2:]], dim=3)
    torchvision.utils.save_image(direct, "direct_blend.png")

結果は以下の通りです。

ダイレクトブレンド(元画像を左右に貼り付けただけ)

ラプラシアンブレンド

ダイレクトブレンドはつぎはぎ目が明らかに不自然ですが(クソコラ感がすごい)、ラプラシアンブレンドは滑らかになっていますね。

まとめ

PyTorchだろうとラプラシアンピラミッドは実装できました。これを使えばSWD(Sliced Wasserstein Distance)は多分計算できると思います。ラプラシアンピラミッドは単なる画像特徴量としても面白そうです。

Related Posts

keras_preprocessingを使ってお手軽に画像を回転させる方法... Data Augmentationで画像を回転させたいことがあります。画像の回転は一般に「アフィン変換」と呼ばれる操作で、OpenCVやPillowのライブラリを使えば簡単にできるのですが、Numpy配列に対して1から書くとかなりめんどいのです。Kerasが裏で使っているkeras_preproc...
TensorFlowでコサイン類似度を計算する方法... TensorFlowで損失関数や距離関数に「コサイン類似度」を使うことを考えます。Scikit-learnでは簡単に計算できますが、同様にTensorFlowでの行列演算でも計算できます。それを見ていきます。 コサイン類似度 コサイン類似度は、ユークリッド距離と同様に距離関数として使われる評価...
画像のピラミッドを1枚の画像として出力するサンプル... 同一画像で繰り返し半分に縮小しながら積み重ねていく操作(ピラミッド)が必要になったので、ピラミッドを1枚の画像として出力するサンプルを作ってみました。 ピラミッド 同一画像の解像度をある一定比率(よくある例では半分)で繰り返し縮小しながら積み重ねていくことを、ピラミッドと言います。OpenCV...
KerasのModelCheckpointのsave_best_onlyは何を表すのか?... Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。 ModelCheckpointとは? 公式ドキュメントより ke...
note開設のお知らせ 本日noteを開設いたしました。 https://note.mu/koshian2 これは自分の記事をより多くの方々に読んでいただき、新たな読者の開拓を図るためであります。 当面は既存の記事の再送を中心に考えていますが、いくつかnote向けに読みやすい新規の記事も考えています。好評なら新規の...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です