PyTorchでガウシアンピラミッド+ラプラシアンピラミッド(Gaussian/Laplacian Pyramid)
Progressive-GANの論文で、SWD(Sliced Wasserstein Distance)が評価指標として出てきたので、その途中で必要になったガウシアンピラミッド、ラプラシアンピラミッドをPyTorchで実装してみました。これらのピラミッドはGAN関係なく、画像処理一般で使えるものです。応用例として、ラプラシアンブレンドもPyTorchで実装しています。
目次
関連
OpenCVのドキュメントが非常に参考になりました。こちらを参考にCNNのConvolutionに置き換えて実装しました。
またここに出てくる実装は、自分の下記の記事がベースになっています。必要なら参考にしてみてください。
- PyTorchでConvolutionフィルターをやる(エッジ検出やアンシャープマスク)
https://blog.shikoan.com/pytorch-edge-unsharp/ - 画像のピラミッドを1枚の画像として出力するサンプル
https://blog.shikoan.com/image-pyramid-concat/ - OpenCVのsubtractについての小ネタ
https://blog.shikoan.com/opencv-subtract/
ガウシアンピラミッド、ラプラシアンピラミッド
まず画像のピラミッドとは何かというと、もとの画像の解像度を1/2、1/2、1/2……のように繰り返し縮小してまとめたものがピラミッドです。縮小する過程でガウシアンぼかしを入れたものがガウシアンピラミッド、ガウシアンピラミッドに解像度間の差を取ったものがラプラシアンピラミッドとなります。
ラプラシアンピラミッドはバンドパスフィルタのようなことをやっているので、画像の周波数分解とも考えることができます。OpenCVのドキュメントからです。
上がガウシアンピラミッドで、下がラプラシアンピラミッドです。大きい画像が高周波数成分、低い画像が低周波数成分です。発想としてはOctave ConvolutionのOctaveと同じです。
以下にアルゴリズムの詳細を示します。
ガウシアンピラミッド
以下のようなコードで示される5×5の畳み込みカーネルを用意します。
import numpy as np
kernel = np.array([
[1, 4, 6, 4, 1],
[4, 16, 24, 16, 4],
[6, 24, 36, 24, 6],
[4, 16, 24, 16, 4],
[1, 4, 6, 4, 1]], np.float32) / 256.0
OpenCVのドキュメントだと一部分母が16になっていましたが、256にするのが正しいかと思われます(普通の畳み込みカーネルも行列の全要素の和が分母になってます)。
OpenCVでは、「5×5カーネルの畳み込み→偶数の行と列を削除する」という処理を実行しています。しかし、ディープラーニングの畳み込み関数(Conv2D)ではもっと簡潔に表せて、「5×5カーネル、stride=2、padding=2」というパラメーターで(ほぼ)同じ処理を実行できます。固定値のカーネルをどう計算するのかはこちらを参照してください。
つまり、PyTorchでガウシアンピラミッドを実装するには、上で表されるような5×5カーネルによるConv2dを連続で実行するだけということになります。
ラプラシアンピラミッド
ラプラシアンピラミッドは少し複雑です。ガウシアンピラミッドで紹介した処理は解像度を下げる方向なので「pyramid_down」と表すことにします。ラプラシアンピラミッドでは、解像度を上げる「pyramid_up」という処理が必要になります。大きな流れは次の通りです。
- ガウシアンピラミッド(pyramid_down)で計算されたピラミッドのi番目の画像について考える
- i番目の画像に対し、pyramid_downをしたi+1番目のガウシアンピラミッドに注目する
- i+1番目のガウシアンピラミッドをpyramid_upし、1.で示したi番目の画像との差分を取る
ラプラシアンピラミッドのやっていることは周波数帯ごとの差分計算なので、「ラプラシアンピラミッドの個数=ガウシアンピラミッドの個数-1」となります。
次は「pyramid_up」はどういう処理をするかということです。
- 入力画像を2倍の解像度にアップサンプリングする(後でガウシアンぼかしを入れることから、Nearest Neighbor法で良いと思われる)
- pyramid_downで使ったのと同じ畳み込みカーネルでガウシアンぼかしをかける
これでラプラシアンピラミッドのアルゴリズムは終わりです。
コード
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
# 画像ファイル→PyTorchテンソル
def load_tensor(img_path):
with Image.open(img_path) as img:
array = np.asarray(img, np.float32).transpose([2, 0, 1]) / 255.0
tensor = torch.as_tensor(np.expand_dims(array, axis=0)) # rank 4
return tensor
# ピラミッドを1枚の画像に結合して保存するための関数
def tile_pyramid(imgs):
height, width = imgs[0].size()[2:]
with torch.no_grad():
canvas = torch.zeros(1, 3, height * 3 // 2, width)
x, y = 0, 0
for i, img in enumerate(imgs):
h, w = img.size()[2:]
canvas[:,:, y:(y + h), x:(x + w)] = img
if i % 2 == 0:
x += width // (2 ** (i + 3))
y += height // (2 ** i) # 0, 2, 4..でy方向にシフト
else:
x += width // (2 ** i) # 1, 3, 5..でx方向にシフト
y += height // (2 ** (i + 3))
return canvas
# ガウシアンぼかしのカーネル
def get_gaussian_kernel():
# [out_ch, in_ch, .., ..] : channel wiseに計算
kernel = np.array([
[1, 4, 6, 4, 1],
[4, 16, 24, 16, 4],
[6, 24, 36, 24, 6],
[4, 16, 24, 16, 4],
[1, 4, 6, 4, 1]], np.float32) / 256.0
gaussian_k = torch.as_tensor(kernel.reshape(1, 1, 5, 5))
return gaussian_k
def pyramid_down(image):
with torch.no_grad():
gaussian_k = get_gaussian_kernel()
# channel-wise conv(大事)
multiband = [F.conv2d(image[:, i:i + 1,:,:], gaussian_k, padding=2, stride=2) for i in range(3)]
down_image = torch.cat(multiband, dim=1)
return down_image
def pyramid_up(image):
with torch.no_grad():
gaussian_k = get_gaussian_kernel()
upsample = F.interpolate(image, scale_factor=2)
multiband = [F.conv2d(upsample[:, i:i + 1,:,:], gaussian_k, padding=2) for i in range(3)]
up_image = torch.cat(multiband, dim=1)
return up_image
def gaussian_pyramid(original, n_pyramids):
x = original
# pyramid down
pyramids = [original]
for i in range(n_pyramids):
x = pyramid_down(x)
pyramids.append(x)
return pyramids
def laplacian_pyramid(original, n_pyramids):
# gaussian pyramidを作る
pyramids = gaussian_pyramid(original, n_pyramids)
# pyramid up - diff
laplacian = []
for i in range(len(pyramids) - 1):
diff = pyramids[i] - pyramid_up(pyramids[i + 1])
laplacian.append(diff)
return laplacian
# 見やすいようにMin-Maxでスケーリングする
def normalize_pyramids(pyramids):
result = []
for diff in pyramids:
diff_min = torch.min(diff)
diff_max = torch.max(diff)
diff_normalize = (diff - diff_min) / (diff_max - diff_min)
result.append(diff_normalize)
return result
ラプラシアンピラミッドの出力結果を見やすくするために、Min-MaxスケーリングでNormalizeした関数も入れています。
結果:ガウシアンピラミッド
三江線に登場していただきましょう。これをtrain.jpgとします。
if __name__ == "__main__":
original = load_tensor("train.jpg")
pyramid = gaussian_pyramid(original, 6)
torchvision.utils.save_image(tile_pyramid(pyramid), "gaussian_pyramid.jpg")
ガウシアンピラミッドは次のようになります。
画像がだんだん小さくなっているのでわかりづらいですが、ガウシアンぼかしをかけているのでだんだんぼやけてくるのがわかります。
結果:ラプラシアンピラミッド
同じようにラプラシアンピラミッドを作ってみます。元画像のピクセル値が0~1の場合、ラプラシアンピラミッドは差分を取っているので、-1~1までの値が取り得ます。そこで、Min-Max Scalerを適用し、出力画像が0~1の範囲になるようにします。
if __name__ == "__main__":
original = load_tensor("train.jpg")
pyramid = laplacian_pyramid(original, 6)
normalize = normalize_pyramids(pyramid)
torchvision.utils.save_image(tile_pyramid(normalize), "laplacian_pyramid.jpg")
この結果はとてもおもしろいです。細かい輪郭線が高周波数帯に、逆にテクスチャや色は低周波数帯に集中しているのがわかります。
これはCNNとのアナロジーとしても理解できます。CNNは浅い層では解像度が高く(高周波)、深い層では解像度が低い(低周波)構成となっています。同時に浅い層では主に輪郭やエッジのような波形的情報を見ており、深い層ではテクスチャや何が映っているかという空間的・意味的な情報を見ていると言われています。
ラプラシアンピラミッドによって、CNNでなくても画像そのものが高周波・低周波で同様の構成になっており、CNNの構造はまさにそれにフィットしている――だからCNNが画像に向いている、ということが可視化できるともいえるでしょう。
ちなみにOpenCVの実装では、subtract(引き算)関数でオーバーフロー抑制機能が機能し、マイナスの値を0に丸められています。したがって、このPyTorchの実装とOpenCVの実装では結果が少し違います。ただし、マイナスの値を丸めないほうが明らかに情報は落ちないので、ここではそのまま「ただの引き算」として処理しています。後述のラプラシアンブレンドでも、丸めないほうが自然な出力になります。
ラプラシアンブレンド
応用例として、OpenCVで紹介されているラプラシアンブレンドをPyTorchで実装してみました。
こちらのサイトから、「apple.jpg」と「orange.jpg」をダウンロードしてきます。
# ラプラシアンブレンディング
def laplacian_blending():
apple = load_tensor("apple.jpg")
orange = load_tensor("orange.jpg")
apple_lap = laplacian_pyramid(apple, 5)
orange_lap = laplacian_pyramid(orange, 5)
# ラプラシアンピラミッドを左右にブレンドする
blend_lap = []
for x, y in zip(apple_lap, orange_lap):
width = x.size(3)
b = torch.cat([x[:,:,:,:width // 2], y[:,:,:, width // 2:]], dim=3)
blend_lap.append(b)
# 最高レベルのガウシアンピラミッドのブレンド
apple_top = gaussian_pyramid(apple, 5)[-1]
orange_top = gaussian_pyramid(orange, 5)[-1]
out = torch.cat([apple_top[:,:,:,:apple_top.size(3) // 2],
orange_top[:,:,:, orange_top.size(3) // 2:]], dim=3)
# ラプラシアンピラミッドからの再構築
for lap in blend_lap[::-1]:
out = pyramid_up(out) + lap
torchvision.utils.save_image(out, "laplacian_blend.png")
# 比較例:ダイレクトにブレンド
direct = torch.cat([apple[:,:,:,:apple.size(3) // 2],
orange[:,:,:, orange.size(3) // 2:]], dim=3)
torchvision.utils.save_image(direct, "direct_blend.png")
結果は以下の通りです。
ダイレクトブレンド(元画像を左右に貼り付けただけ)
ラプラシアンブレンド
ダイレクトブレンドはつぎはぎ目が明らかに不自然ですが(クソコラ感がすごい)、ラプラシアンブレンドは滑らかになっていますね。
まとめ
PyTorchだろうとラプラシアンピラミッドは実装できました。これを使えばSWD(Sliced Wasserstein Distance)は多分計算できると思います。ラプラシアンピラミッドは単なる画像特徴量としても面白そうです。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー