こしあん
2019-06-24

PyTorch/TorchVisionで複数の入力をモデルに渡したいケース


7.6k{icon} {views}


PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を解説します。

はじめに

PyTorchで例えば、Pix2Pixを使って白黒画像のカラー化の学習を行いたいとします。このケースでは、訓練時の入力に「白黒画像、カラー画像」の2つの入力が必要になります。

カラー画像の白黒化のように簡単なケースの場合は、カラー画像単体を渡してしまって、モデル内/食わせる前後で、白黒化の式を使ってグレースケールのバッチを作ってしまう方法もできます。しかし、複数入力があるケースというのは必ずどこかで出現します。

しかし、torchvision.transformsでの変換は、入力が1つのみのケースを想定していて、入力が複数の場合はうまくいきません。TensorDatasetの場合は、2個以上の引数を返すというのができますが、データを全てメモリに乗せてしまうため、大きいデータの場合はメモリが制約になるケースがあります。今回はTorchVisionの一般的なDataSetをそのまま使います。具体的にはSTL-10で試してみます。

transformの複数入力対応用のラッパークラスを作る

まず結論から。複数入力に対応したラッパークラスを作るのが便利そうです。各処理を別々に対応して定義するのはバカバカしいので。

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

例えばこんな感じ。base_funcには、tranforms.ToTensor()やtransforms.Normalize()などの関数/クラスを渡します。「call」以下で、各入力に対してbase_funcを適用した値をリストとして返します。リストがどんどんつながっていくので今までのtransformみたいにパイプライン状に定義することが可能ですよ。

base_funcの引数はリスト対応しておくと便利でしょう。これは、入力に対して引数が異なる場合を想定しています。例えば、1つ目の入力はカラー画像なので3チャンネル全てに対してスケール調整したいが、2つ目の入力はグレースケール画像なので1チャンネルだけスケール調整したいというケース。

関数が1つで引数が複数だとちょっと処理が面倒なので、関数を複数にしてしまってリストで渡すのが簡単だと思われます。

ラッパークラスの使い方

普通のtransformsと同じような感覚で使えます。

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

先程定義したMultiInputWrapperで、transformsの各処理をラップするだけですね。これでいい感じに複数入力対応してくれます。

MultiInputWrapperのbase_funcがリストの場合は、Normalizeのケースですが、このように入力順に適用する関数を与えればいいと思います。1番目がカラー画像、2番目がグレースケール画像ですね。ColorAndGray()の定義は次で説明します。

1つの入力から2つの入力へ分岐させる部分

もともとSTL-10の画像の入力は1つなので、これをカラー画像、白黒画像の2つの入力に分岐させます。これがColorAndGrayの部分です。

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

callの引数は何なんの?ということですが、呼ぶ場所によって違います。このケースではToTensor()の前で読んでいるのでPILのインスタンスです。もしToTensor()の後で呼ぶとPyTorchのテンソルになります。

このColorAndGrayの定義だと返り値は、カラー画像のPILインスタンス+白黒画像のPILインスタンスのTupleになります。以降の処理はtransformsそのままでは処理できないんで、MultiInputWrapperでラップしてケアします。

試してみる

STL-10からunlabeledを16枚とって、カラー画像、白黒画像をタイルして表示します。

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()

結果はこのようになりました。うまくいきました。


まとめ

PyTorch/TorchVisionでtransformsを複数入力に対応させたかったら、複数入力用のラッパークラスを作ると便利かもしれないよ、ということでした。

コード全体

import torch
import torchvision
from torchvision import transforms

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()


Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です