こしあん
2019-06-24

PyTorch/TorchVisionで複数の入力をモデルに渡したいケース

Pocket
LINEで送る


PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を解説します。

はじめに

PyTorchで例えば、Pix2Pixを使って白黒画像のカラー化の学習を行いたいとします。このケースでは、訓練時の入力に「白黒画像、カラー画像」の2つの入力が必要になります。

カラー画像の白黒化のように簡単なケースの場合は、カラー画像単体を渡してしまって、モデル内/食わせる前後で、白黒化の式を使ってグレースケールのバッチを作ってしまう方法もできます。しかし、複数入力があるケースというのは必ずどこかで出現します。

しかし、torchvision.transformsでの変換は、入力が1つのみのケースを想定していて、入力が複数の場合はうまくいきません。TensorDatasetの場合は、2個以上の引数を返すというのができますが、データを全てメモリに乗せてしまうため、大きいデータの場合はメモリが制約になるケースがあります。今回はTorchVisionの一般的なDataSetをそのまま使います。具体的にはSTL-10で試してみます。

transformの複数入力対応用のラッパークラスを作る

まず結論から。複数入力に対応したラッパークラスを作るのが便利そうです。各処理を別々に対応して定義するのはバカバカしいので。

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

例えばこんな感じ。base_funcには、tranforms.ToTensor()やtransforms.Normalize()などの関数/クラスを渡します。「call」以下で、各入力に対してbase_funcを適用した値をリストとして返します。リストがどんどんつながっていくので今までのtransformみたいにパイプライン状に定義することが可能ですよ。

base_funcの引数はリスト対応しておくと便利でしょう。これは、入力に対して引数が異なる場合を想定しています。例えば、1つ目の入力はカラー画像なので3チャンネル全てに対してスケール調整したいが、2つ目の入力はグレースケール画像なので1チャンネルだけスケール調整したいというケース。

関数が1つで引数が複数だとちょっと処理が面倒なので、関数を複数にしてしまってリストで渡すのが簡単だと思われます。

ラッパークラスの使い方

普通のtransformsと同じような感覚で使えます。

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

先程定義したMultiInputWrapperで、transformsの各処理をラップするだけですね。これでいい感じに複数入力対応してくれます。

MultiInputWrapperのbase_funcがリストの場合は、Normalizeのケースですが、このように入力順に適用する関数を与えればいいと思います。1番目がカラー画像、2番目がグレースケール画像ですね。ColorAndGray()の定義は次で説明します。

1つの入力から2つの入力へ分岐させる部分

もともとSTL-10の画像の入力は1つなので、これをカラー画像、白黒画像の2つの入力に分岐させます。これがColorAndGrayの部分です。

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

callの引数は何なんの?ということですが、呼ぶ場所によって違います。このケースではToTensor()の前で読んでいるのでPILのインスタンスです。もしToTensor()の後で呼ぶとPyTorchのテンソルになります。

このColorAndGrayの定義だと返り値は、カラー画像のPILインスタンス+白黒画像のPILインスタンスのTupleになります。以降の処理はtransformsそのままでは処理できないんで、MultiInputWrapperでラップしてケアします。

試してみる

STL-10からunlabeledを16枚とって、カラー画像、白黒画像をタイルして表示します。

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()

結果はこのようになりました。うまくいきました。


まとめ

PyTorch/TorchVisionでtransformsを複数入力に対応させたかったら、複数入力用のラッパークラスを作ると便利かもしれないよ、ということでした。

コード全体

import torch
import torchvision
from torchvision import transforms

class ColorAndGray(object):
    def __call__(self, img):
        # ToTensor()の前に呼ぶ場合はimgはPILのインスタンス
        gray = img.convert("L")
        return img, gray

# 複数の入力をtransformsに展開するラッパークラスを作る
class MultiInputWrapper(object):
    def __init__(self, base_func):
        """
        複数の入力に対してtransformsを使うためのラッパークラス
        * base_func : 各入力に対して適用するtransformsの関数/クラス。関数またはlist(関数)。
        """
        self.base_func = base_func

    def __call__(self, xs):
        if isinstance(self.base_func, list):
            return [f(x) for f, x in zip(self.base_func, xs)]
        else:
            return [self.base_func(x) for x in xs]

def dataloader():
    transform = transforms.Compose([
        ColorAndGray(),
        MultiInputWrapper(transforms.ToTensor()),
        MultiInputWrapper([
            transforms.Normalize(mean=(0.5,0.5,0.5,), std=(0.5,0.5,0.5,)),
            transforms.Normalize(mean=(0.5,), std=(0.5,))
        ])
    ])

    dataset = torchvision.datasets.STL10(
        root="./data",
        split="train",
        download=True,
        transform=transform
    )
    loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=16
    )
    return loader

def main():
    # dataloader
    loader = dataloader()
    # バッチを1つだけ取得
    (color_batch, gray_batch), y = next(iter(loader))
    # 4x4に結合して保存
    torchvision.utils.save_image(color_batch, "colors.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)
    torchvision.utils.save_image(gray_batch, "grays.png", nrow=4, padding=10, range=(-1.0,1.0), normalize=True)

if __name__ == "__main__":
    main()

Related Posts

TPUで学習率減衰させる方法 TPUで学習率減衰したいが、TensorFlowのオプティマイザーを使うべきか、tf.kerasのオプティマイザーを使うべきか、あるいはKerasのオプティマイザーを使うべきか非常にややこしいことがあります。TPUで学習率を減衰させる方法を再現しました。 結論から TPU環境でtf.keras...
note開設のお知らせ 本日noteを開設いたしました。 https://note.mu/koshian2 これは自分の記事をより多くの方々に読んでいただき、新たな読者の開拓を図るためであります。 当面は既存の記事の再送を中心に考えていますが、いくつかnote向けに読みやすい新規の記事も考えています。好評なら新規の...
KerasのModelCheckpointのsave_best_onlyは何を表すのか?... Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。 ModelCheckpointとは? 公式ドキュメントより ke...
Numpyの配列のみを操作して四角形を描画する(Numpyの画像処理)... Numpyの画像処理です。Numpyの配列のみを操作して、画面上に四角形を描画してみます。Numpyの画像処理は出力結果の合成のときにたまに使う割には若干独特なので注意が必要です。 Numpy arrayの画像の構造は(y, x, channel) ここだけ覚えておけば大丈夫です。Pillo...
KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です