こしあん
2018-08-30

PyTorchでサイズの異なる画像を読み込む方法

Pocket
LINEで送る

実際の画像判定では、MNISTやCIFARのようにサイズが完全に整形されたデータはなかなか少ないです。例えばサイズが横幅は一定でも縦幅が異なっていたりするケースがあります。訓練画像間でサイズが異なる場合、そのまま読み込みするとエラーになります。その解決法を示します。

transforms.RandomResizedCropを使おう

他にもあるかもしれませんが、ToTensor()の前にRandomResizedCropを挟むのがかなり確実ではないかと思います。自分がやった限りでは特にエラーが起きませんでした。

class torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomResizedCrop

もともとこれData Augmentation用の関数で、指定した比率のサイズ(scale)とアスペクト比(ratio)でトリミングします。例えば、縦100×横100の画像があり、scale=0.5、ratio=0.75なら縦81×横61でランダムにトリミングするようです。最終的にsizeで合わせたサイズに拡大・縮小されて出力されます。

クロップ部分の細かいことはさておいて、scale=1、ratio=1で固定すれば、入力画像をそのままリサイズするだけの関数になります。本来こっちを使いそうなtransforms.Resizeを、異なるサイズのある環境で使うとなぜかエラーになります(バグかもしれないのでそのうち改善されるかもしれません)。TorchVisionのバージョン0.2.1では、RandomResizedCropを使うとエラーは起きませんでした。

次のように使います。

import torch
from torchvision import datasets, transforms

data_transform = transforms.Compose([
     transforms.RandomResizedCrop(160, scale=(1.0, 1.0), ratio=(1.0, 1.0)), 
     transforms.ToTensor()
    ])

your_datasets = datasets.ImageFolder(root="path-to-your-dataset/train", transform=data_transform) 
loader = torch.utils.data.DataLoader(your_datasets, batch_size=100)

for batch_index, (X, y) in enumerate(loader):
    # ここに処理を書く

DataLoaderの画像を表示する

ちなみにPyTorchの画像はChannels_firstなので、Pyplotで表示するときに少し工夫がいります。np.rollaxisでChannels_lastに変換しましょう。

import numpy as np
import matplotlib.pyplot as plt

plt.plot(figsize=(10, 10))
plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, hspace=0.05, wspace=0.05)
for i in range(100):
    x = np.rollaxis(X[i].numpy(), 0, 3)
    plt.subplot(10, 10, i+1, xticks=[], yticks=[])
    plt.imshow(x)
plt.show()

以上です。

Related Posts

Numpyだけでサクッと画像を拡大する方法... Numpyだけで画像をサクッと拡大する方法を紹介します。OpenCVやPillowを使うまでもないな、というようなときに便利な方法です。ニューラルネットワークでインプットのサイズを調整するときも使えます。 ただのNearest Neighbor法 拡大前の1ピクセルを1つの四角形と見立てて、拡...
PythonのMessagePack-Numpyで独自のクラスをシリアライズする方法... MessagePackを使ってシリアライズを高速化したかったのですが、独自のクラスやネストされたオブジェクトについてシリアル化する方法が全然なかったので調べてみました。Numpyのシリアライズも使えるMessagePackの拡張版、MessagePack-Numpyを使って確かめます。 Mess...
One-Hotエンコーディング(ダミー変数)ならPandasのget_dummies()を使おう... 特徴量処理(特徴量エンジニアリング)でよく使う処理として、「A,B,C」「1,2,3」といったカテゴリー変数をOne-Hotベクトル化するというのがあります。SkelarnのOneHotEncoderでもできますが、Pandasのget_dummies()を使うと、もっと統合的にすることができま...
PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
PyTorchでOnehotエンコーディングするためのワンライナー... PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。 TL;DR PyTorchではこれでOnehotエンコーディングができます。 onehot = torch.eye(10) ただし、labelはLongTe...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です