PyTorchでConvolutionフィルターをやる(エッジ検出やアンシャープマスク)
PyTorchでPILのConvolutionフィルター(エッジ検出やアンシャープマスク)をやりたくなったので、どう実装するか考えてみました。
目次
やりたいこと
PIL/PillowのConvolutionフィルター(ImageFilterなど)の処理をPyTorchの畳み込み演算で再現したい
使う画像はこれ。自分が昔撮ってきたものです。「train.jpg」とします。
畳み込みカーネルの値はこちらの記事のものを使っています。だいたいPILの結果と同じになるはずです。
準備
PyTorchのテンソルに変換するために下記の関数を用意しましょう。
from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
def load_tensor():
with Image.open("train.jpg") as img:
array = np.asarray(img, np.float32).transpose([2, 0, 1]) / 255.0
tensor = torch.as_tensor(np.expand_dims(array, axis=0)) # rank 4
return tensor
アンシャープマスク
def sharpen_filter():
kernel = np.array([[-2, -2, -2], [-2, 32, -2], [-2, -2, -2]], np.float32) / 16.0 # convolution filter
with torch.no_grad():
# [out_ch, in_ch, .., ..] : channel wiseに計算
sharpen_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))
color = load_tensor() # color image [1, 3, H, W]
# channel-wise conv(大事) 3x3 convなのでPadding=1を入れる
multiband = [F.conv2d(color[:, i:i + 1,:,:], sharpen_k, padding=1) for i in range(3)]
sharpened_image = torch.cat(multiband, dim=1)
torchvision.utils.save_image(sharpened_image, "shapren.jpg")
結果
全面の窓ガラスのあたりが少しくっきりしたイメージがあります。ただ、右上の屋根と空の境界線は少しシャギーになりましたね。
コードは、アンシャープマスク用の畳み込みカーネルを用意します。ただ、カラー画像にConvolutionをかけるときは、畳み込みカーネルを(3, 3, 3, 3)というshapeで用意して一括でConv2dを入れると、最後にチャンネルすべての和を取って、値が溢れたりグレースケールのような出力になってしまいます。チャンネル単位でConv2dをとって、最後のConcatしましょう(大事)。
エッジ検出
def edge_detection():
kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], np.float32) # convolution filter
with torch.no_grad():
# [out_ch, in_ch, .., ..] : channel wiseに計算
edge_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))
# エッジ検出はグレースケール化してからやる
color = load_tensor() # color image [1, 3, H, W]
gray_kernel = np.array([0.299, 0.587, 0.114], np.float32).reshape(1, 3, 1, 1) # color -> gray kernel
gray_k = torch.as_tensor(gray_kernel)
gray = torch.sum(color * gray_k, dim=1, keepdim=True) # grayscale image [1, 1, H, W]
# エッジ検出
edge_image = F.conv2d(gray, edge_k, padding=1)
torchvision.utils.save_image(edge_image, "edge.jpg")
結果:
いい感じにできていますね。
エッジ検出はもともとグレースケールにかけるものなので、チャンネル単位のConv2dは考える必要はありません。グレースケール化の係数やエッジ検出のカーネルは何パターンかあるので、あくまで一例です。
まとめ
PyTorchだろうと(ここでは紹介しなかったけどKerasだろうと)、PILのようなConvolutionフィルターは頑張れば再現できますよ、ということでした。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー