こしあん
2019-10-04

PyTorchでConvolutionフィルターをやる(エッジ検出やアンシャープマスク)

Pocket
LINEで送る


PyTorchでPILのConvolutionフィルター(エッジ検出やアンシャープマスク)をやりたくなったので、どう実装するか考えてみました。

やりたいこと

PIL/PillowのConvolutionフィルター(ImageFilterなど)の処理をPyTorchの畳み込み演算で再現したい

使う画像はこれ。自分が昔撮ってきたものです。「train.jpg」とします。

畳み込みカーネルの値はこちらの記事のものを使っています。だいたいPILの結果と同じになるはずです。

準備

PyTorchのテンソルに変換するために下記の関数を用意しましょう。

from PIL import Image
import numpy as np
import torch
import torch.nn.functional as F
import torchvision

def load_tensor():
    with Image.open("train.jpg") as img:
        array = np.asarray(img, np.float32).transpose([2, 0, 1]) / 255.0
        tensor = torch.as_tensor(np.expand_dims(array, axis=0))  # rank 4
    return tensor

アンシャープマスク

def sharpen_filter():
    kernel = np.array([[-2, -2, -2], [-2, 32, -2], [-2, -2, -2]], np.float32) / 16.0  # convolution filter
    with torch.no_grad():
        # [out_ch, in_ch, .., ..] : channel wiseに計算
        sharpen_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))

        color = load_tensor()  # color image [1, 3, H, W]
        # channel-wise conv(大事) 3x3 convなのでPadding=1を入れる
        multiband = [F.conv2d(color[:, i:i + 1,:,:], sharpen_k, padding=1) for i in range(3)]
        sharpened_image = torch.cat(multiband, dim=1)
        torchvision.utils.save_image(sharpened_image, "shapren.jpg")

結果

全面の窓ガラスのあたりが少しくっきりしたイメージがあります。ただ、右上の屋根と空の境界線は少しシャギーになりましたね。

コードは、アンシャープマスク用の畳み込みカーネルを用意します。ただ、カラー画像にConvolutionをかけるときは、畳み込みカーネルを(3, 3, 3, 3)というshapeで用意して一括でConv2dを入れると、最後にチャンネルすべての和を取って、値が溢れたりグレースケールのような出力になってしまいます。チャンネル単位でConv2dをとって、最後のConcatしましょう(大事)

エッジ検出

def edge_detection():
    kernel = np.array([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], np.float32)  # convolution filter
    with torch.no_grad():
        # [out_ch, in_ch, .., ..] : channel wiseに計算
        edge_k = torch.as_tensor(kernel.reshape(1, 1, 3, 3))

        # エッジ検出はグレースケール化してからやる
        color = load_tensor()  # color image [1, 3, H, W]
        gray_kernel = np.array([0.299, 0.587, 0.114], np.float32).reshape(1, 3, 1, 1)  # color -> gray kernel
        gray_k = torch.as_tensor(gray_kernel)
        gray = torch.sum(color * gray_k, dim=1, keepdim=True)  # grayscale image [1, 1, H, W]

        # エッジ検出
        edge_image = F.conv2d(gray, edge_k, padding=1)
        torchvision.utils.save_image(edge_image, "edge.jpg")

結果:

いい感じにできていますね。

エッジ検出はもともとグレースケールにかけるものなので、チャンネル単位のConv2dは考える必要はありません。グレースケール化の係数やエッジ検出のカーネルは何パターンかあるので、あくまで一例です。

まとめ

PyTorchだろうと(ここでは紹介しなかったけどKerasだろうと)、PILのようなConvolutionフィルターは頑張れば再現できますよ、ということでした。

Related Posts

ML Study Jams中級編終わらせてきた ML Study JamsというGoogle Cloudが提供している無料の学習プログラムの第二弾がオープンしています。今度は中級編が追加されており、全部終わらせてきたのでその報告と感想を書いていきたいと思います。 前回の記事 QWIKLABSの使い方とかはこっち。 ML Study Jam...
TensorFlow2.0で訓練の途中に学習率を変える方法... TensorFlow2.0で訓練の途中に学習率を変える方法を、Keras APIと訓練ループを自分で書くケースとで見ていきます。従来のKerasではLearning Rate Schedulerを使いましたが、TF2.0ではどうすればいいでしょうか? Keras APIの場合 従来どおりLea...
OpenCVで画像を歪ませる方法 PythonでOpenCVを使い画像を歪ませる方法を考えます。アフィン変換というちょっと直感的に理解しにくいことをしますが、慣れればそこまで難しくはありません。ディープラーニングのData Augmentationにも使えます。 OpenCVでのアフィン変換のイメージ アフィン変換というと、ま...
TensorFlow/Kerasでネットワーク内でData Augmentationする方法... NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。 こんなイメージ Numpy(CPU)でやる場合 Num...
Numpyの配列をN個飛ばしで列挙する簡単な方法... Numpyの配列から奇数番目、偶数番目の要素を取り出したいときが稀によくあります。インデックスの配列を定義する必要があるのかなと思いますが、とても簡単な方法があります。それを見ていきましょう。 基本は「::スキップしたい間隔」 例として、0~9までの配列をとります。 >>>...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です