PyTorchで画像を小さいパッチに切り出す方法
Posted On 2019-10-10
PyTorchで1枚の画像を複数の小さい画像(パッチ)に切り出す方法を紹介します。TensorFlowだとtf.image.extract_patchesにあたる処理です。
目次
torch.Tensor.unfold
torch.Tensor.unfoldという関数を使います。
unfold(dimension, size, step) → Tensor
という形式で、順番にパッチを切り出す次元、パッチサイズ、パッチを切り出す間隔ですね。次元は縦と横で取ればいいので画像の4階テンソルなら2,3で取れば良いでしょう。
コード
この画像を「cat.jpg」とします。
128px × 128pxのパッチで、64px間隔に取り出すものとします。
import torch
import torchvision
from PIL import Image
import numpy as np
# テンソルの読み込み
with Image.open("cat.jpg") as img:
original = np.expand_dims(np.asarray(img, np.float32).transpose([2, 0, 1]), axis=0) / 255.0
x = torch.as_tensor(original)
print(x.size()) # [1, 3, 821, 547]
# パッチサイズは128x128, 64px間隔で切り出す
patches = x.unfold(2, 128, 64).unfold(3, 128, 64)
print(patches.size()) # [1, 3, 11, 7, 128, 128]
# パッチをタイルして保存
out = patches.permute([0, 2, 3, 1, 4, 5]).reshape(-1, 3, 128, 128) # permuteが必要
torchvision.utils.save_image(out, "patch.png", nrow=7)
縦方向と横方向で2回パッチを取れば良いです。パッチを取った後がごちゃごちゃしていますが、次元は「バッチサイズ、チャンネル数、縦方向のパッチ数、横方向のパッチ数、縦方向のパッチ解像度、横方向のパッチ解像度」となります。縦・横はunfoldを適用する順番によって変わります。結果は次の通り。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー