こしあん
2022-03-05

Albumentationsとtorchvisionで前処理の挙動を揃えたい


6.4k{icon} {views}


AlbumentationsとtorchvisionのToTensorは微妙に挙動が異なります。テンソル化の前処理を揃えないと精度が下がることがあるので、その検証をしていきたいと思います。結論としては、AlbumentationsではToTensorの前にNormalizeが必要です。

やりたいこと

  • AlbumentationsのToTensorとtorchvisionのToTensorで、型や出力値域が微妙に異なる
  • Albumentationsを使いたいんだけど、挙動はtorchvisionと一緒にしてほしい

前処理の動作検証が必要だよね、ということで調べてみました。

とりあえずデータを読み込んでみる

Albumentations / torchvisionそれぞれで、Dataset、DataLoaderを作ります。題材はCIFAR-10とします。

import torch
import torchvision
import albumentations as alb

class AlbumentationsCIFAR(torchvision.datasets.CIFAR10):
    def __init__(self, root="./data", train=True, download=True, transform=None):
        super().__init__("./data", train=True, download=True, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label

def test():
    pytorch_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True)
    albumentations_dataset = AlbumentationsCIFAR()

    # 検証用に順序保証する。シーケンシャルにし並列化させない
    pytorch_loader = torch.utils.data.DataLoader(
        pytorch_dataset, batch_size=128, shuffle=False, num_workers=1)
    albumentations_loader = torch.utils.data.DataLoader(
        albumentations_dataset, batch_size=128, shuffle=False, num_workers=1)

    pytorch_batch = next(iter(pytorch_loader))
    albumentations_batch = next(iter(albumentations_loader))

    # Labelを見る
    print(pytorch_batch[1])
    print(albumentations_batch[1])

if __name__ == "__main__":
    test()

AlbumentationsでのDataLoaderは公式ドキュメントを参考にしています。Albumentationsのインポート略称は公式ドキュメントでは「A」ですが、宗教上の理由で「alb」としています。

ここでは、torchvisionとAlbumentationsの2つのデータセットを定義し、それぞれDataLoaderを作成しています。検証のため、2つのDataLoaderの各バッチに対し、参照するサンプルが同じになるように保証したほうが都合が良いので、

  • num_workers=1にして並列化させない
  • shuffle=Falseにしてシャッフルしない

などの配慮をします。各DataLoaderの最初のバッチのラベルが同一なら、参照しているサンプルが同じとみなせるので、それを確認します。

tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6,
        2, 6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, 2,
        2, 0, 9, 5, 7, 9, 2, 2, 5, 2, 4, 3, 1, 1, 8, 2, 1, 1, 4, 9, 7, 8, 5, 9,
        6, 7, 3, 1, 9, 0, 3, 1, 3, 5, 4, 5, 7, 7, 4, 7, 9, 4, 2, 3, 8, 0, 1, 6,
        1, 1, 4, 1, 8, 3, 9, 6, 6, 1, 8, 5, 2, 9, 9, 8, 1, 7, 7, 0, 0, 6, 9, 1,
        2, 2, 9, 2, 6, 6, 1, 9])
tensor([6, 9, 9, 4, 1, 1, 2, 7, 8, 3, 4, 7, 7, 2, 9, 9, 9, 3, 2, 6, 4, 3, 6, 6,
        2, 6, 3, 5, 4, 0, 0, 9, 1, 3, 4, 0, 3, 7, 3, 3, 5, 2, 2, 7, 1, 1, 1, 2,
        2, 0, 9, 5, 7, 9, 2, 2, 5, 2, 4, 3, 1, 1, 8, 2, 1, 1, 4, 9, 7, 8, 5, 9,
        6, 7, 3, 1, 9, 0, 3, 1, 3, 5, 4, 5, 7, 7, 4, 7, 9, 4, 2, 3, 8, 0, 1, 6,
        1, 1, 4, 1, 8, 3, 9, 6, 6, 1, 8, 5, 2, 9, 9, 8, 1, 7, 7, 0, 0, 6, 9, 1,
        2, 2, 9, 2, 6, 6, 1, 9])

このとおり、torchvisionとAlbumentationsの各DataLoaderについて、参照するサンプルの同期がとれました

データ型、値域の確認

次に各DataLoaderの値域、データ型を確認します。transformは各ライブラリのToTensorのみ入れます。

AlbumentationsのToTensorについて

公式ドキュメントでは、ToTensorとToTensorV2の2種類が記載されています(2022/3現在)が、実際に動くのはToTensorV2のみです(ver1.1.0)。

https://albumentations.ai/docs/api_reference/pytorch/transforms/

実際にAlbumentationsのToTensorV2をToTensorで置き換えると、

    transform=alb.pytorch.ToTensor()
AttributeError: module 'albumentations.pytorch' has no attribute 'ToTensor'

と怒られます。ネーミング的にはちょっとアレですが、ToTensorV2を使えば良さそうです。

なお、Albumentationsの各処理は長々とパスが書かれていますが、ほとんどAlbumentations直下にエイリアスがかかっているため、「alb.RandomResizedCrop」のように扱えます。

値域、データ型の確認

最初のバッチの画像のデータ型、値域を確認します。

import torch
import torchvision
import albumentations as alb
import albumentations.pytorch

class AlbumentationsCIFAR(torchvision.datasets.CIFAR10):
    def __init__(self, root="./data", train=True, download=True, transform=None):
        super().__init__("./data", train=True, download=True, transform=transform)

    def __getitem__(self, index):
        image, label = self.data[index], self.targets[index]

        if self.transform is not None:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label

def test():
    pytorch_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=torchvision.transforms.ToTensor())
    albumentations_dataset = AlbumentationsCIFAR(
        transform=alb.pytorch.ToTensorV2()
    )

    # 検証用に順序保証する。シーケンシャルにし並列化させない
    pytorch_loader = torch.utils.data.DataLoader(
        pytorch_dataset, batch_size=128, shuffle=False, num_workers=1)
    albumentations_loader = torch.utils.data.DataLoader(
        albumentations_dataset, batch_size=128, shuffle=False, num_workers=1)

    pytorch_batch  = next(iter(pytorch_loader))
    albumentations_batch = next(iter(albumentations_loader))

    # 画像
    print(pytorch_batch[0].dtype, pytorch_batch[0].min(), pytorch_batch[0].max())
    print(albumentations_batch[0].dtype, albumentations_batch[0].min(), albumentations_batch[0].max())

    # ラベル(出力は同じ)
    print(pytorch_batch[1])
    print(albumentations_batch[1])

if __name__ == "__main__":
    test()

最初の「import albumentations.pytorch」は、「alb.pytorch.ToTensorV2」でエラーを出さないために必要です。このコードの出力は、

torch.float32 tensor(0.) tensor(1.)
torch.uint8 tensor(0, dtype=torch.uint8) tensor(255, dtype=torch.uint8)
// ラベルの一覧は前と変わらないので省略

ここでtorchvisionとAlbumentationsのToTensorは、デフォルトではデータ型と値域が異なることがわかります。

  • torchvisionはfloat32で値域が0~1
  • Albumentationsはuint8で値域が0~255

データ型と値域を統一する

Albumentationsの挙動をtorchvisionに寄せるには、Normalizeを入れるのが考えられます。

Normalizeの式は、公式ドキュメントより、

img = (img – mean * max_pixel_value) / (std * max_pixel_value)

とあり、デフォルトはmax_pixel_value=255.0なので、mean=0, std=1にすれば0-1の値に収まることがわかります。コードのtest関数部分を次のように変えます。

def test():
    pytorch_dataset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True,
        transform=torchvision.transforms.ToTensor())
    albumentations_dataset = AlbumentationsCIFAR(
        transform=alb.Compose([
            alb.Normalize(mean=(0,0,0), std=(1,1,1)), # ここを追加
            alb.pytorch.ToTensorV2()
        ])
    )

    # 検証用に順序保証する。シーケンシャルにし並列化させない
    pytorch_loader = torch.utils.data.DataLoader(
        pytorch_dataset, batch_size=128, shuffle=False, num_workers=1)
    albumentations_loader = torch.utils.data.DataLoader(
        albumentations_dataset, batch_size=128, shuffle=False, num_workers=1)

    pytorch_batch  = next(iter(pytorch_loader))
    albumentations_batch = next(iter(albumentations_loader))

    # 画像
    print(pytorch_batch[0].dtype, pytorch_batch[0].min(), pytorch_batch[0].max())
    print(albumentations_batch[0].dtype, albumentations_batch[0].min(), albumentations_batch[0].max())

    # 比較
    print(torch.mean(torch.abs(pytorch_batch[0]-albumentations_batch[0]), dim=(1,2,3)))
torch.float32 tensor(0.) tensor(1.)
torch.float32 tensor(0.) tensor(1.)

tensor([1.1922e-08, 1.7047e-08, 1.6349e-08, 9.2180e-09, 1.2913e-08, 1.2052e-08,
        1.4027e-08, 1.4488e-08, 2.1964e-08, 7.8968e-09, 6.9073e-09, 2.2551e-08,
        2.0100e-08, 1.2904e-08, 9.0664e-09, 1.6928e-08, 2.2592e-08, 9.2283e-09,
        1.9999e-08, 1.6108e-08, 3.6200e-08, 1.8515e-08, 7.1183e-09, 1.0181e-08,
        1.0170e-08, 1.6735e-08, 9.7671e-09, 8.0836e-09, 1.3194e-08, 3.3542e-08,
        1.6452e-08, 1.2935e-08, 9.4127e-09, 1.2455e-08, 8.8230e-09, 1.2179e-08,
        1.9774e-08, 1.0544e-08, 2.0370e-08, 1.8105e-08, 1.5532e-08, 3.4272e-08,
        8.8403e-09, 1.7240e-08, 1.1451e-08, 1.0562e-08, 1.6852e-08, 1.3733e-08,
        2.7241e-08, 3.0786e-08, 9.5497e-09, 1.3034e-08, 1.0662e-08, 1.9575e-08,
        8.5590e-09, 2.3874e-08, 1.4196e-08, 7.6816e-09, 1.3233e-08, 1.0464e-08,
        1.7320e-08, 1.1161e-08, 1.0346e-08, 8.5990e-09, 1.3130e-08, 1.8357e-08,
        1.2044e-08, 2.3544e-08, 1.0615e-08, 1.4474e-08, 1.5284e-08, 1.0124e-08,
        1.9643e-08, 1.0413e-08, 1.6433e-08, 1.4508e-08, 2.2245e-08, 1.3305e-08,
        1.0749e-08, 2.9506e-08, 1.1559e-08, 1.0640e-08, 1.3761e-08, 1.8641e-08,
        1.1540e-08, 5.4903e-09, 9.9487e-09, 1.0912e-08, 8.7436e-09, 1.7547e-08,
        1.2513e-08, 8.6875e-09, 2.0400e-08, 9.8456e-09, 2.0542e-08, 2.5347e-08,
        1.4476e-08, 1.1507e-08, 1.1457e-08, 1.0666e-08, 3.1619e-08, 8.9452e-09,
        1.6097e-08, 1.1660e-08, 8.2731e-09, 2.9899e-08, 1.0493e-08, 1.0522e-08,
        4.2317e-08, 1.8354e-08, 1.1478e-08, 9.0210e-09, 1.4171e-08, 1.4848e-08,
        1.6785e-08, 3.6614e-08, 1.0471e-08, 1.2806e-08, 3.2504e-08, 3.1142e-08,
        4.4505e-09, 7.2590e-09, 1.8738e-08, 1.5882e-08, 1.0721e-08, 2.2569e-08,
        2.6089e-08, 2.5713e-08])

これはどちらもFloat32の0-1スケールになっています。もし、Albumentationsでスケールを-1~1にしたいときは、mean=std=0.5とします。

実際に画素値同士のL1ロスを取ってみると、平均して1e-8とほぼ誤差であることがわかります。

結論

AlbumentationsのToTensorの挙動をtorchvisionと同じにしたければ、mean=0, std=1でNormalizeする必要がある。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です