こしあん
2018-08-30

PyTorchでサイズの異なる画像を読み込む方法

実際の画像判定では、MNISTやCIFARのようにサイズが完全に整形されたデータはなかなか少ないです。例えばサイズが横幅は一定でも縦幅が異なっていたりするケースがあります。訓練画像間でサイズが異なる場合、そのまま読み込みするとエラーになります。その解決法を示します。

transforms.RandomResizedCropを使おう

他にもあるかもしれませんが、ToTensor()の前にRandomResizedCropを挟むのがかなり確実ではないかと思います。自分がやった限りでは特にエラーが起きませんでした。

class torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomResizedCrop

もともとこれData Augmentation用の関数で、指定した比率のサイズ(scale)とアスペクト比(ratio)でトリミングします。例えば、縦100×横100の画像があり、scale=0.5、ratio=0.75なら縦81×横61でランダムにトリミングするようです。最終的にsizeで合わせたサイズに拡大・縮小されて出力されます。

クロップ部分の細かいことはさておいて、scale=1、ratio=1で固定すれば、入力画像をそのままリサイズするだけの関数になります。本来こっちを使いそうなtransforms.Resizeを、異なるサイズのある環境で使うとなぜかエラーになります(バグかもしれないのでそのうち改善されるかもしれません)。TorchVisionのバージョン0.2.1では、RandomResizedCropを使うとエラーは起きませんでした。

次のように使います。

import torch
from torchvision import datasets, transforms

data_transform = transforms.Compose([
     transforms.RandomResizedCrop(160, scale=(1.0, 1.0), ratio=(1.0, 1.0)), 
     transforms.ToTensor()
    ])

your_datasets = datasets.ImageFolder(root="path-to-your-dataset/train", transform=data_transform) 
loader = torch.utils.data.DataLoader(your_datasets, batch_size=100)

for batch_index, (X, y) in enumerate(loader):
    # ここに処理を書く

DataLoaderの画像を表示する

ちなみにPyTorchの画像はChannels_firstなので、Pyplotで表示するときに少し工夫がいります。np.rollaxisでChannels_lastに変換しましょう。

import numpy as np
import matplotlib.pyplot as plt

plt.plot(figsize=(10, 10))
plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, hspace=0.05, wspace=0.05)
for i in range(100):
    x = np.rollaxis(X[i].numpy(), 0, 3)
    plt.subplot(10, 10, i+1, xticks=[], yticks=[])
    plt.imshow(x)
plt.show()

以上です。

Related Posts

Python(Numpy)で画像を水平反転する方法:Data Augmentation向け... OpenCVを使わずに単純に画像を左右反転(水平反転)する方法を考えます。ディープラーニングでデータのジェネレーターを自分で実装した場合、Data Augmentationを組み込む際にも必要になります。それを見ていきましょう。 左右反転自体は実は簡単 例えばNumpyの行列を左右反転させてみ...
OpenCVで画像を歪ませる方法 PythonでOpenCVを使い画像を歪ませる方法を考えます。アフィン変換というちょっと直感的に理解しにくいことをしますが、慣れればそこまで難しくはありません。ディープラーニングのData Augmentationにも使えます。 OpenCVでのアフィン変換のイメージ アフィン変換というと、ま...
TensorFlow/Kerasでネットワーク内でData Augmentationする方法... NumpyでData Augmentationするのが遅かったり、書くの面倒だったりすることありますよね。今回はNumpy(CPU)ではなく、ニューラルネットワーク側(GPU、TPU)でAugmetationをする方法を見ていきます。 こんなイメージ Numpy(CPU)でやる場合 Num...
KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
PandasのDataFrameでグループ別にサンプルをN個抜き出す方法... 「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。 例題 今、中国地方と四国地方の県と面積をDataFrameにして...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です