こしあん
2019-06-24

PyTorch/TorchVisionで複数の入力をモデルに渡したいケース

Pocket
LINEで送る


PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を解説します。

はじめに

PyTorchで例えば、Pix2Pixを使って白黒画像のカラー化の学習を行いたいとします。このケースでは、訓練時の入力に「白黒画像、カラー画像」の2つの入力が必要になります。

カラー画像の白黒化のように簡単なケースの場合は、カラー画像単体を渡してしまって、モデル内/食わせる前後で、白黒化の式を使ってグレースケールのバッチを作ってしまう方法もできます。しかし、複数入力があるケースというのは必ずどこかで出現します。

しかし、torchvision.transformsでの変換は、入力が1つのみのケースを想定していて、入力が複数の場合はうまくいきません。TensorDatasetの場合は、2個以上の引数を返すというのができますが、データを全てメモリに乗せてしまうため、大きいデータの場合はメモリが制約になるケースがあります。今回はTorchVisionの一般的なDataSetをそのまま使います。具体的にはSTL-10で試してみます。

transformの複数入力対応用のラッパークラスを作る

まず結論から。複数入力に対応したラッパークラスを作るのが便利そうです。各処理を別々に対応して定義するのはバカバカしいので。

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

例えばこんな感じ。base_funcには、tranforms.ToTensor()やtransforms.Normalize()などの関数/クラスを渡します。「call」以下で、各入力に対してbase_funcを適用した値をリストとして返します。リストがどんどんつながっていくので今までのtransformみたいにパイプライン状に定義することが可能ですよ。

base_funcの引数はリスト対応しておくと便利でしょう。これは、入力に対して引数が異なる場合を想定しています。例えば、1つ目の入力はカラー画像なので3チャンネル全てに対してスケール調整したいが、2つ目の入力はグレースケール画像なので1チャンネルだけスケール調整したいというケース。

関数が1つで引数が複数だとちょっと処理が面倒なので、関数を複数にしてしまってリストで渡すのが簡単だと思われます。

ラッパークラスの使い方

普通のtransformsと同じような感覚で使えます。

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

先程定義したMultiInputWrapperで、transformsの各処理をラップするだけですね。これでいい感じに複数入力対応してくれます。

MultiInputWrapperのbase_funcがリストの場合は、Normalizeのケースですが、このように入力順に適用する関数を与えればいいと思います。1番目がカラー画像、2番目がグレースケール画像ですね。ColorAndGray()の定義は次で説明します。

1つの入力から2つの入力へ分岐させる部分

もともとSTL-10の画像の入力は1つなので、これをカラー画像、白黒画像の2つの入力に分岐させます。これがColorAndGrayの部分です。

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

callの引数は何なんの?ということですが、呼ぶ場所によって違います。このケースではToTensor()の前で読んでいるのでPILのインスタンスです。もしToTensor()の後で呼ぶとPyTorchのテンソルになります。

このColorAndGrayの定義だと返り値は、カラー画像のPILインスタンス+白黒画像のPILインスタンスのTupleになります。以降の処理はtransformsそのままでは処理できないんで、MultiInputWrapperでラップしてケアします。

試してみる

STL-10からunlabeledを16枚とって、カラー画像、白黒画像をタイルして表示します。

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()

結果はこのようになりました。うまくいきました。


まとめ

PyTorch/TorchVisionでtransformsを複数入力に対応させたかったら、複数入力用のラッパークラスを作ると便利かもしれないよ、ということでした。

コード全体

import torch
import torchvision
from torchvision import transforms

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()

Related Posts

Kerasで損失関数に複数の変数を渡す方法... Kerasで少し複雑なモデルを訓練させるときに、損失関数にy_true, y_pred以外の値を渡したいときがあります。クラスのインスタンス変数などでキャッシュさせることなく、ダイレクトに損失関数に複数の値を渡す方法を紹介します。 元ネタ:Passing additional arguments...
PandasのDataFrameでグループ別にサンプルをN個抜き出す方法... 「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。 例題 今、中国地方と四国地方の県と面積をDataFrameにして...
Kerasで転移学習用にレイヤー名とそのインデックスを調べる方法... Kerasで転移学習をするときに、学習済みモデルのレイヤーの名前と、そのインデックス(何番目にあるかということ)の対応を知りたいことがあります。その方法を解説します。 転移学習とは 転移学習とは、ImageNetなど何百万もの大量の画像で事前学習させたモデルを使い、それを「特徴量検出器」として...
画像をただ並べたいときに使えるTorchVision... TorchVisionはPyTorchの画像処理を手軽に行うためのライブラリですが、ディープラーニングを全く使わない、ただの画像処理でも有効に使うことができます。もちろんKerasやTensorFlowといった他のディープラーニングからの利用可能です。今回は、「画像をただ並べたいとき」にTorch...
PythonのMessagePack-Numpyで独自のクラスをシリアライズする方法... MessagePackを使ってシリアライズを高速化したかったのですが、独自のクラスやネストされたオブジェクトについてシリアル化する方法が全然なかったので調べてみました。Numpyのシリアライズも使えるMessagePackの拡張版、MessagePack-Numpyを使って確かめます。 Mess...
Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です