こしあん
2018-08-30

PyTorchでサイズの異なる画像を読み込む方法

実際の画像判定では、MNISTやCIFARのようにサイズが完全に整形されたデータはなかなか少ないです。例えばサイズが横幅は一定でも縦幅が異なっていたりするケースがあります。訓練画像間でサイズが異なる場合、そのまま読み込みするとエラーになります。その解決法を示します。

transforms.RandomResizedCropを使おう

他にもあるかもしれませんが、ToTensor()の前にRandomResizedCropを挟むのがかなり確実ではないかと思います。自分がやった限りでは特にエラーが起きませんでした。

class torchvision.transforms.RandomResizedCrop(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=2)
https://pytorch.org/docs/stable/torchvision/transforms.html#torchvision.transforms.RandomResizedCrop

もともとこれData Augmentation用の関数で、指定した比率のサイズ(scale)とアスペクト比(ratio)でトリミングします。例えば、縦100×横100の画像があり、scale=0.5、ratio=0.75なら縦81×横61でランダムにトリミングするようです。最終的にsizeで合わせたサイズに拡大・縮小されて出力されます。

クロップ部分の細かいことはさておいて、scale=1、ratio=1で固定すれば、入力画像をそのままリサイズするだけの関数になります。本来こっちを使いそうなtransforms.Resizeを、異なるサイズのある環境で使うとなぜかエラーになります(バグかもしれないのでそのうち改善されるかもしれません)。TorchVisionのバージョン0.2.1では、RandomResizedCropを使うとエラーは起きませんでした。

次のように使います。

import torch
from torchvision import datasets, transforms

data_transform = transforms.Compose([
     transforms.RandomResizedCrop(160, scale=(1.0, 1.0), ratio=(1.0, 1.0)), 
     transforms.ToTensor()
    ])

your_datasets = datasets.ImageFolder(root="path-to-your-dataset/train", transform=data_transform) 
loader = torch.utils.data.DataLoader(your_datasets, batch_size=100)

for batch_index, (X, y) in enumerate(loader):
    # ここに処理を書く

DataLoaderの画像を表示する

ちなみにPyTorchの画像はChannels_firstなので、Pyplotで表示するときに少し工夫がいります。np.rollaxisでChannels_lastに変換しましょう。

import numpy as np
import matplotlib.pyplot as plt

plt.plot(figsize=(10, 10))
plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05, hspace=0.05, wspace=0.05)
for i in range(100):
    x = np.rollaxis(X[i].numpy(), 0, 3)
    plt.subplot(10, 10, i+1, xticks=[], yticks=[])
    plt.imshow(x)
plt.show()

以上です。

Related Posts

Numpyだけでサクッと画像を拡大する方法... Numpyだけで画像をサクッと拡大する方法を紹介します。OpenCVやPillowを使うまでもないな、というようなときに便利な方法です。ニューラルネットワークでインプットのサイズを調整するときも使えます。 ただのNearest Neighbor法 拡大前の1ピクセルを1つの四角形と見立てて、拡...
One-Hotエンコーディング(ダミー変数)ならPandasのget_dummies()を使おう... 特徴量処理(特徴量エンジニアリング)でよく使う処理として、「A,B,C」「1,2,3」といったカテゴリー変数をOne-Hotベクトル化するというのがあります。SkelarnのOneHotEncoderでもできますが、Pandasのget_dummies()を使うと、もっと統合的にすることができま...
Chainerで画像の前処理やDataAugmentationをしたいときはDatasetMixin... Chainerにはデフォルトでランダムクロップや標準化といった、画像の前処理やDataAugmentation用の関数が用意されていません。別途のChainer CVというライブラリを使う方法もありますが、chainer.dataset.DatasetMixinを継承させて独自のデータ・セットを定...
PandasのDataFrameでグループ別にサンプルをN個抜き出す方法... 「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。 例題 今、中国地方と四国地方の県と面積をDataFrameにして...
Python(Numpy)で画像を水平反転する方法:Data Augmentation向け... OpenCVを使わずに単純に画像を左右反転(水平反転)する方法を考えます。ディープラーニングでデータのジェネレーターを自分で実装した場合、Data Augmentationを組み込む際にも必要になります。それを見ていきましょう。 左右反転自体は実は簡単 例えばNumpyの行列を左右反転させてみ...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です