こしあん
2019-10-04

PyTorchでConvolutionフィルターをやる(エッジ検出やアンシャープマスク)


4.7k{icon} {views}


PyTorchでPILのConvolutionフィルター(エッジ検出やアンシャープマスク)をやりたくなったので、どう実装するか考えてみました。

やりたいこと

PIL/PillowのConvolutionフィルター(ImageFilterなど)の処理をPyTorchの畳み込み演算で再現したい

使う画像はこれ。自分が昔撮ってきたものです。「train.jpg」とします。

畳み込みカーネルの値はこちらの記事のものを使っています。だいたいPILの結果と同じになるはずです。

準備

PyTorchのテンソルに変換するために下記の関数を用意しましょう。

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

def load_tensor():
    with Image.open("train.jpg") as img:
        array = np.asarray(img, np.float32).transpose([2, 0, 1]) / 255.0
        tensor = torch.as_tensor(np.expand_dims(array, axis=0))  # rank 4
    return tensor

アンシャープマスク

def sharpen_filter():
    kernel = np.array([[-2, -2, -2], [-2, 32, -2], [-2, -2, -2]], np.float32) / 16.0  # convolution filter
    with torch.no_grad():
        # [out_ch, in_ch, .., ..] : channel wiseに計算
        sharpen_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))

        color = load_tensor()  # color image [1, 3, H, W]
        # channel-wise conv(大事) 3x3 convなのでPadding=1を入れる
        multiband = [F.conv2d(color[:, i:i + 1,:,:], sharpen_k, padding=1) for i in range(3)]
        sharpened_image = torch.cat(multiband, dim=1)
        torchvision.utils.save_image(sharpened_image, "shapren.jpg")

結果

全面の窓ガラスのあたりが少しくっきりしたイメージがあります。ただ、右上の屋根と空の境界線は少しシャギーになりましたね。

コードは、アンシャープマスク用の畳み込みカーネルを用意します。ただ、カラー画像にConvolutionをかけるときは、畳み込みカーネルを(3, 3, 3, 3)というshapeで用意して一括でConv2dを入れると、最後にチャンネルすべての和を取って、値が溢れたりグレースケールのような出力になってしまいます。チャンネル単位でConv2dをとって、最後のConcatしましょう(大事)

エッジ検出

def edge_detection():
    kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], np.float32)  # convolution filter
    with torch.no_grad():
        # [out_ch, in_ch, .., ..] : channel wiseに計算
        edge_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))

        # エッジ検出はグレースケール化してからやる
        color = load_tensor()  # color image [1, 3, H, W]
        gray_kernel = np.array([0.299, 0.587, 0.114], np.float32).reshape(1, 3, 1, 1)  # color -> gray kernel
        gray_k = torch.as_tensor(gray_kernel)
        gray = torch.sum(color * gray_k, dim=1, keepdim=True)  # grayscale image [1, 1, H, W]

        # エッジ検出
        edge_image = F.conv2d(gray, edge_k, padding=1)
        torchvision.utils.save_image(edge_image, "edge.jpg")

結果:

いい感じにできていますね。

エッジ検出はもともとグレースケールにかけるものなので、チャンネル単位のConv2dは考える必要はありません。グレースケール化の係数やエッジ検出のカーネルは何パターンかあるので、あくまで一例です。

まとめ

PyTorchだろうと(ここでは紹介しなかったけどKerasだろうと)、PILのようなConvolutionフィルターは頑張れば再現できますよ、ということでした。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です