こしあん
2021-11-06

SwinTransformerでCIFAR-10を一から訓練する


Pocket
LINEで送る
Delicious にシェア

4.9k{icon} {views}


画像のTransformer系で有望なモデルである「Swin Transformer」でCIFAR-10を1から訓練してみました。1からの訓練はCNNほど楽ではありませんが、流行りのTransformerを気軽に扱うことができました。

Swin Transformer概要

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
https://arxiv.org/abs/2103.14030

2021年の論文です。画像のTransformer系の論文では現在SoTAですし、EfficientNetに肉薄する精度を出している画期的な論文です。

大きく見るとResNetに近い構成ですね。ただ、レイヤーがほぼMLPやAttentionなのでTransformer版のResNetという感じでしょう。「Patch Partition/Merging」、「W-MSA/SW-MSA」という独自のテクニックが出てきます。この説明は今回詳しくしないですが、ざっくりとだけ書いておきます。

Patch Partition/Merging

Patch Partition/Mergingというのは、画像のテンソルをreshapeして小さなパッチを作る手法です。

この手法は今後も使われそうだな(と勝手に思っているので)、少し深堀りしておきます。512×512のカラー画像を128×128の16パッチに分割する方法を考えます。これをNumPy配列のreshapeやtransposeだけで実装してみましょう。

from PIL import Image
import numpy as np

def make_patch():
    with Image.open("lena.png") as img:
        x = np.array(img)
    print(x.shape) # (512, 512, 3)
    # PyTorchの表記に合わせる
    x = x.transpose([2, 0, 1])[None, ...] # (1, 3, 512, 512)
    # 128x128のパッチを作る
    x = x.reshape(1, 3, 4, 128, 4, 128)
    # パッチをバッチの軸に集約する
    x = x.transpose([0, 2, 4, 1, 3, 5]).reshape(16, 3, 128, 128)

    fig = plt.figure(figsize=(12, 12))
    for i in range(16):
        ax = fig.add_subplot(4, 4, i+1)
        ax.imshow(x[i].transpose([1, 2, 0])) # [C, H, W] -> [H, W, C]
    plt.show()

これをニューラルネットワークの中でやっているのがSwin Transformerです。コードを見ているとこれとほぼ同じ処理が出てきます。実際はこれに縦横をロールを加えてニューロンが位置に依存してしまわないように調整しています。

実はこの処理、私が書いた「インフィニティNumPy」に丸々出てきたんですよね。これはSwin Transformerの論文が出る前に書いたものなのですが、こういったところで応用されてるのを見ると感慨深いです(自分がこれを知ったのはContextual Attentionという論文です)。

W-MSA/SW-MSA

これがいわゆるTransformerのAttentionなレイヤーです。この他にRelative position biasというPositional Encoding的な要素を組み込んでSelf Attentionを計算しています。ここの実装はかなりややこしいので割愛します。

Swin Transformerをざっくり実装してCIFAR-10

論文ではCIFAR-10の実装はありませんでした。論文を参考にそれっぽい実装してみます。斜め読みしただけなので、条件を相当端折っている点はご了承ください。

では早速実装していきましょう。Swin Transformerの公式実装(Microsoftによる)を使います。

https://github.com/microsoft/Swin-Transformer

こちらから「/models/swin_transformer.py」をコピーしてきます。swin_transformer.pyを作業ディレクトリ直下においた前提で進めます。

論文では、モデルを小さい順に「Swin-T, Swin-S, Swin-B, Swin-L」があったのですが、今回はSwin-SをベースにCIFAR-10用に調整して使います。以下はImageNet用の設定です。

  • Swin-S: C = 96, layer numbers ={2, 2, 18, 2}

ImageNetの解像度は224に対し、CIFAR-10は32です。最後の2を取っ払って「2, 2, 18」としてみます。

from swin_transformer import SwinTransformer
from torchinfo import summary

def show_summary():
    model = SwinTransformer(img_size=32, num_classes=10, embed_dim=96, 
            num_heads=[3, 6, 12], depths=[2, 2, 18], window_size=8)
    summary(model, (1, 3, 32, 32))

if __name__ == "__main__":
    show_summary()

num_headsはデフォルトの引数(3, 6, 12, 24)の最後だけ削りました。window_sizeはデフォルト(ImageNetの設定)だと7ですが、7だと解像度が割り切れず、モデル内部のreshapeでエラーを引き起こすので8に変えました。4でもいいと思います。

torchinfoでsummaryを見ると以下のようになります。

====================================================================================================
Layer (type:depth-idx)                             Output Shape              Param #
====================================================================================================
SwinTransformer                                    --                        --
├─ModuleList: 1-1                                  --                        --
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-1                        --                        225,030
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-2                        --                        890,316
│    └─BasicLayer: 2                               --                        --
│    │    └─ModuleList: 3-3                        --                        31,942,296
├─PatchEmbed: 1-2                                  [1, 64, 96]               --
│    └─Conv2d: 2-1                                 [1, 96, 8, 8]             4,704
│    └─LayerNorm: 2-2                              [1, 64, 96]               192
├─Dropout: 1-3                                     [1, 64, 96]               --
├─ModuleList: 1-1                                  --                        --
│    └─BasicLayer: 2-3                             [1, 16, 192]              --
│    │    └─PatchMerging: 3-4                      [1, 16, 192]              74,496
│    └─BasicLayer: 2-4                             [1, 4, 384]               --
│    │    └─PatchMerging: 3-5                      [1, 4, 384]               296,448
│    └─BasicLayer: 2-5                             [1, 4, 384]               --
├─LayerNorm: 1-4                                   [1, 4, 384]               768
├─AdaptiveAvgPool1d: 1-5                           [1, 384, 1]               --
├─Linear: 1-6                                      [1, 10]                   3,850
====================================================================================================
Total params: 33,434,218
Trainable params: 33,434,218
Non-trainable params: 0
Total mult-adds (M): 33.73
====================================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 4.28
Params size (MB): 133.74
Estimated Total Size (MB): 138.03
====================================================================================================

特徴としては、MLP中心なので係数が多いです。ConvNetでパラメーター数が33MというとResNet101に匹敵するような大型のモデルですが、Forward/Backwardの必要メモリ数が少ないです。CIFAR-10なら、バッチサイズ512や1024のような大きな値でも1枚のGPUで訓練できます。

何も考えずに訓練する

Swin Transformerを含め、画像のTransformer系のモデルを1から訓練するときは、強めのData Augmentationをかけることがほとんどです。最初はそこらを全部無視して、Data AugmentationなしでConvNetのように、単純な訓練でどの程度精度が出るか確かめます。

from swin_transformer import SwinTransformer
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torchmetrics

def main():
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
    batch_size = 1024
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

    model = SwinTransformer(img_size=32, num_classes=10, embed_dim=96, 
            num_heads=[3, 6, 12], depths=[2, 2, 18], window_size=8)
    model = model.cuda()

    optimizer = torch.optim.AdamW(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()

    train_acc, test_acc = torchmetrics.Accuracy(), torchmetrics.Accuracy()
    max_test_acc = 0.0
    sw = SummaryWriter("logs")

    for epoch in range(300):  # loop over the dataset multiple times
        train_acc.reset(), test_acc.reset()
        model.train()
        for data in trainloader:
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            pred = torch.argmax(outputs, dim=-1)
            train_acc(pred.cpu(), labels.cpu())

        model.eval()
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            pred = torch.argmax(outputs, dim=-1)
            test_acc(pred.cpu(), labels.cpu())

        print(f"Epoch{epoch+1}, TrainAcc:{train_acc.compute()}, TestAcc:{test_acc.compute()}")
        if max_test_acc < test_acc.compute():
            print(f"Test acc improved from {max_test_acc} to {test_acc.compute()}")
            torch.save(model.state_dict(), "model")
            print("Model saved.")
            max_test_acc = test_acc.compute()
        sw.add_scalar("Train Accuracy", train_acc.compute(), epoch+1)
        sw.add_scalar("Test Accuracy", test_acc.compute(), epoch+1)

if __name__ == "__main__":
    main()

PyTorchのCIFAR-10のチュートリアルをベースにしたものです。TensorBoardに記録した精度を見てみます。

ん…ん?? なんかめちゃくちゃオーバーフィッティングしてる??

テスト精度は67-68%でしょうか。MLP中心にしては頑張っているほうではないかと思います。ただこれは比較的うまく行っている方で、初期値ガチャでハズレを引くと、

20エポックやってもせいぜい25%程度(もっと悪い場合もある)という悲しい結果に。学習の初期に不安定になりやすいようですね。

論文の設定に近づける

さすがにこれだと手抜きすぎなので、もう少しがんばります。

  • Data Augmentationの追加
     + RandAugment
     + Mixup + Cutmix
  • 学習率のWarmup+Cosine Annealingのスケジューラーを追加
  • weight decay=0.05を追加(論文のImageNet 1Kの設定流用)

必要な処理はtimmライブラリにあるので、こちらから流用します。RandAugmentはAutoAugmentのディレクトリにあるので、AutoAugmentの親戚なのでしょう。

MixupとCutMixに対応するために損失関数をSoftTargetCrossEntropyに変えます。この損失関数もtimmに組み込まれています。

RandAugmentとMixup+CutMixのハイパーパラメータはtimmのドキュメントにある値を流用しました。コードは以下の通りです。

from swin_transformer import SwinTransformer
import torch
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
import torchvision
import torchvision.transforms as transforms
import torchmetrics
from timm.loss import SoftTargetCrossEntropy
from timm.data.auto_augment import rand_augment_transform
from timm.data.mixup import Mixup
from timm.scheduler import CosineLRScheduler

def main():
    transform = transforms.Compose(
        [
            # RandAugment
            rand_augment_transform( 
                config_str='rand-m9-mstd0.5', 
                hparams={'translate_const': 117, 'img_mean': (124, 116, 104)}
            ),
            transforms.ToTensor()])
    batch_size = 1024
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                            download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                            shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                        download=True, transform=transforms.Compose(
                                            [transforms.ToTensor()]))
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                            shuffle=False, num_workers=2)

    model = SwinTransformer(img_size=32, num_classes=10, embed_dim=96, 
            num_heads=[3, 6, 12], depths=[2, 2, 18], window_size=8)
    model = model.cuda()

    optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.05)
    criterion = SoftTargetCrossEntropy()
    scheduler = CosineLRScheduler(optimizer, t_initial=300, 
                                  lr_min=1e-4, warmup_t=20, 
                                  warmup_lr_init=1e-4, warmup_prefix=True)

    # mixup + cutmix
    mixup_args = {
        'mixup_alpha': 0.,
        'cutmix_alpha': 1.0,
        'cutmix_minmax': None,
        'prob': 1.0,
        'switch_prob': 0.,
        'mode': 'batch',
        'label_smoothing': 0,
        'num_classes': 10}
    mixup_fn = Mixup(**mixup_args)

    train_acc, test_acc = torchmetrics.Accuracy(), torchmetrics.Accuracy()
    max_test_acc = 0.0
    sw = SummaryWriter("logs")

    for epoch in range(300):  # loop over the dataset multiple times
        train_acc.reset(), test_acc.reset()
        model.train()
        for data in trainloader:
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()
            inputs, labels_mixup = mixup_fn(inputs, labels)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            loss = criterion(outputs, labels_mixup)
            loss.backward()
            optimizer.step()

            pred = torch.argmax(outputs, dim=-1)
            train_acc(pred.cpu(), labels.cpu())

        model.eval()
        for data in testloader:
            inputs, labels = data
            inputs, labels = inputs.cuda(), labels.cuda()

            outputs = model(inputs)
            pred = torch.argmax(outputs, dim=-1)
            test_acc(pred.cpu(), labels.cpu())

        print(f"Epoch{epoch+1}, TrainAcc:{train_acc.compute()}, TestAcc:{test_acc.compute()}")
        if max_test_acc < test_acc.compute():
            print(f"Test acc improved from {max_test_acc} to {test_acc.compute()}")
            torch.save(model.state_dict(), "model")
            print("Model saved.")
            max_test_acc = test_acc.compute()
        sw.add_scalar("Train Accuracy", train_acc.compute(), epoch+1)
        sw.add_scalar("Test Accuracy", test_acc.compute(), epoch+1)

        scheduler.step(epoch+1)

if __name__ == "__main__":
    main()

テスト精度は最終的には86.1%(以前81.1%と書いていましたが、テストデータのtransformが間違っていました。訂正いたします)でした。強いData Augmentationを使っているため訓練精度が低いです。「正則化が強すぎるから収束遅いよね。Weight Decayいらないんじゃない?」と外してみたら、テスト精度は0.4%落ちました(81.1%→80.7%。テストデータのtransformが間違っている時点)。

もちろんこれがSwin Transformerのフルスペックの実装ではありませんが、CIFAR-10の場合、CNNなら割といい加減にやっても90%は出るので「こんだけ頑張っても90%出ないのか」という思いはありました。もちろんもっと大きい画像でやれば精度は出るのかもしれません。

しかし、解像度224のImageNetの場合でも、1から訓練する場合はこれと同程度かそれ以上の正則化や学習率のコントロールなどが必要になると論文にありました。転移学習の場合はもっと簡単にできるかもしれませんが、1から学習させる場合はTransformer系はかなり大変そうだなという印象が拭えませんでした。1から学習させるときに「Swin TransformerかCNNか選べるよ」って言われたら自分は95%ぐらいCNNを選ぶと思います。Transformer系でもSwin Transformerは相当頑張っている方なので、あと1~2年したらもっと簡単なモデルが出るかもしれません。

Swin Transformerの良い点は、モデルの係数の大きさの割にメモリ消費が少ないことです。これは畳み込みを全結合に置き換えたことに由来するものです。係数が33Mでもバッチサイズ1024で訓練可能でした。また、訓練速度もそこまで落ちることはなく、2080Ti1枚で300エポックを1時間半で回せました。全結合中心なのでGPUに向いていると思います(逆にTPUは速度あんまり出ないかもしれません)。悪い点は頑張ってはいるがまだハイパラに敏感すぎるのと、モデルサイズが大きいことですね。何回もスナップショットと取っているとすぐに数GBや数十GB行ってしまうと思います。

今回はSwin TransformerでCIFAR-10を1から訓練する方法をざっくりと見ていきましたが、次回以降はSwin Transformerの転移学習を見ていきたいと思います。こっちはもっと簡単にできると期待したいです。

転移学習の記事を書きました

Pocket
Delicious にシェア



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です