こしあん
2021-11-06

SwinTransformerで転移学習(EfficientNet/ResNet50との比較)


Pocket
LINEで送る
Delicious にシェア

8k{icon} {views}


Swin Transformerを転移学習してみます。1から学習させる場合と異なり、そこまで強いData Augmentationをかけなくても訓練は安定します。訓練済み係数も含め、timmライブラリから簡単に利用できます。

この記事の目的

前回の記事でSwinTransformer-Likeなモデルを1から訓練してみました。1から訓練する際は、Data Augmentationを強くしたり、学習率を調整したりとハイパラチューニングに気を遣う必要がありました。転移学習の場合も同様にハイパラに敏感なのかどうなのかを見ていきます。

今回使うのはKaggleにある「315 Bird Species」というデータです。315種類の鳥の画像分類をするデータで、全データが224×224×3の解像度に統一されています。訓練が45980、Valとテストが1575ずつあり、アーカイブで1.23GBのかなり扱いやすいデータです。ライセンスはパブリックドメインなのも良いです。

データセットの説明によると、EfficientNetB3のモデルでValが98.92%、テストが99%出るそうです。ぱっと見人間的には難しいのですが、ニューラルネットワークにとっては精度の出やすいデータセットとなっています(少し精度がサチったデータセットを選んでしまったかもしれない)。

訓練の流れ

Swin Transformerを含めて以下の3つのモデルで転移学習をします。すべてtimmライブラリからダウンロードできる訓練済み係数を利用します。

  • ResNet 50
  • EfficientNet B0
  • Swin Transformer (Swin-S)

入力解像度はすべて224です。Data AugmentationはすべてHorizontal Flip(水平反転)だけ使います。これはSwin Transformerの転移学習が、既存のCNN同様簡単なData Augmentationでも安定して訓練できるかどうかを見るためです。

学習率のコントロールはすべてのモデルで、Warmup+Cosine Decayを使います。学習率、エポック数はすべて同一とします。バッチサイズはすべて32としました。

訓練コード

import os
import glob
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import timm 
from timm.scheduler import CosineLRScheduler
import torchmetrics
from tqdm import tqdm
from PIL import Image

class BirdSpeciesDataset(Dataset):
    def __init__(self, img_dir, root_dir="./data", transform=None):
        df = pd.read_csv(root_dir+"/class_dict.csv")
        self.classes = df["class"].to_list()

        self.img_labels = []
        for dir in sorted(glob.glob(f"{root_dir}/{img_dir}/*")):
            class_key = os.path.basename(dir)
            key_ind = self.classes.index(class_key)
            if key_ind >= 0:
                for img in sorted(glob.glob(f"{dir}/*.jpg")):
                    self.img_labels.append([img.replace("\\", "/"), key_ind])
        self.transform = transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        image = Image.open(self.img_labels[idx][0])
        label = self.img_labels[idx][1]
        if self.transform:
            image = self.transform(image)
        return image, label

def main(net_name, batch_size):
    trainset = BirdSpeciesDataset("train", transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
    ]))
    valset = BirdSpeciesDataset("valid", transform=transforms.Compose([
        transforms.ToTensor()
    ]))
    trainloader = DataLoader(trainset, batch_size, shuffle=True, num_workers=2)
    valloader = DataLoader(valset, batch_size, shuffle=False, num_workers=2)

    if net_name == "swin_transformer":
        model = timm.create_model('swin_small_patch4_window7_224', pretrained=True)
        model.head = torch.nn.Linear(768, len(trainset.classes), bias=True)
    elif net_name == "resnet":
        model = timm.create_model('resnet50', pretrained=True)
        model.head = torch.nn.Linear(2048, len(trainset.classes), bias=True)
    elif net_name == "efficient_net":
        model = timm.create_model("efficientnet_b0", pretrained=True)
        model.classifier = torch.nn.Linear(1280, len(trainset.classes), bias=True)

    gpu_id = 0
    model = model.cuda(gpu_id)
    optimizer = torch.optim.AdamW(model.parameters())
    criterion = torch.nn.CrossEntropyLoss()
    scheduler = CosineLRScheduler(optimizer, t_initial=40, 
                                  lr_min=1e-4, warmup_t=5, 
                                  warmup_lr_init=5e-5, warmup_prefix=True)

    train_acc, val_acc = torchmetrics.Accuracy(), torchmetrics.Accuracy()
    max_val_acc = 0.0
    sw = SummaryWriter(net_name+"_logs")

    for epoch in range(40):
        model.train()
        train_acc.reset(), val_acc.reset()

        for x, y in tqdm(trainloader):
            optimizer.zero_grad()

            x, y = x.cuda(gpu_id), y.cuda(gpu_id)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            train_acc(torch.argmax(y_pred, -1).cpu(), y.cpu())

        model.eval()
        for x, y in valloader:
            with torch.no_grad():
                x, y = x.cuda(gpu_id), y.cuda(gpu_id)
                y_pred = model(x)
            val_acc(torch.argmax(y_pred, -1).cpu(), y.cpu())

        print(f"Epoch{epoch+1}, TrainAcc:{train_acc.compute()}, ValAcc:{val_acc.compute()}")
        if max_val_acc < val_acc.compute():
            print(f"Val acc improved from {max_val_acc} to {val_acc.compute()}")
            torch.save(model.state_dict(), net_name)
            print("Model saved.")
            max_val_acc = val_acc.compute()
        sw.add_scalar("Train Accuracy", train_acc.compute(), epoch+1)
        sw.add_scalar("Validation Accuracy", val_acc.compute(), epoch+1)

        scheduler.step(epoch+1)

if __name__ == "__main__":
    #main("resnet", 32)
    #main("efficient_net", 32)
    main("swin_transformer", 32)

テストコード

import timm
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from train import BirdSpeciesDataset
import torchmetrics
import time

def test():
    testset = BirdSpeciesDataset("test", transform=transforms.Compose([
        transforms.ToTensor()
    ]))
    testloader = DataLoader(testset, 32, shuffle=False, num_workers=2)

    test_acc = torchmetrics.Accuracy()

    for net_name in ["swin_transformer", "resnet", "efficient_net"]:
        if net_name == "swin_transformer":
            model = timm.create_model('swin_small_patch4_window7_224')
            model.head = torch.nn.Linear(768, len(testset.classes), bias=True)
        elif net_name == "resnet":
            model = timm.create_model('resnet50')
            model.head = torch.nn.Linear(2048, len(testset.classes), bias=True)
        elif net_name == "efficient_net":
            model = timm.create_model("efficientnet_b0")
            model.classifier = torch.nn.Linear(1280, len(testset.classes), bias=True)

        model.load_state_dict(torch.load(net_name))
        model = model.cuda()

        torch.cuda.empty_cache()
        test_acc.reset()
        model.eval()

        start_time = time.time()
        for x, y in testloader:
            with torch.no_grad():
                y_pred = model(x.cuda())
            test_acc(torch.argmax(y_pred, -1).cpu(), y)

        print(net_name, test_acc.compute(), time.time()-start_time)

if __name__ == "__main__":
    test()

結果

train valid test testの推論秒 訓練秒/エポック
ResNet 50 99.99% 98.73% 99.30% 5.46 316
EfficientNet B0 99.97% 99.05% 99.30% 4.58 298
Swin S 99.96% 98.67% 99.37% 10.08 679

※2080Ti×1で訓練・評価

ValidationではEfficientNet B0が最も精度が高く、次にResNet 50、Swin-Sは最も低い結果となりました。スループットでもEfficientNet B0が最も良く、Swin-SはResNet 50と比べて2倍近く遅い結果となりました。ただ、テストの精度ではわずかにSwin-Sのほうが高くなっていますが、精度がサチっているのでこの0.07%にどの程度の意味があるかはなんともいえません。

推論時にmodel.eval()を忘れると…

これだけ見ると「EfficientNet B0でいいじゃん」となるのですが、EfficientNetは思わぬハマりどころがあることに気づきました。このコードでは評価時はmodel.eval()をしているのですが、model.eval()を忘れると大きく精度が変わります。model.eval()をコメントアウトし、trainのモードでValidとTestの精度を見てみましょう。

model.eval()なし valid test valid diff test diff
ResNet 50 96.06% 97.08% -2.67% -2.22%
EfficientNet B0 93.59% 95.11% -5.46% -4.19%
Swin S 98.35% 99.17% -0.32% -0.20%

model.eval()で訓練時に評価し、推論時にmodel.eval()を忘れたときの精度の低下は無視できないものがあります。EfficientNetは最も下落幅が大きく4~5%も下落し、最下位となっています。ResNet 50よりも悪い結果となったのは驚きです。

推論時にmodel.eval()を忘れたことによる精度低下が最も少なかったのがSwin Transformerです。これは使っているNoramlizationのレイヤーの違いによるものでしょう。ResNetやEfficientNetはBatchNormを使っていますが、Swin TransformerはLayer Normalizationを使っています。これらのtrain/evalモードでの挙動の違いが、精度の低下をもたらしていると考えられます。Layer Normalizationはこの副作用が少ないのかもしれません。

まとめ

今回の実験の目的は、Swin Transformerの転移学習において、Horizontal Flipのような簡単なData Augmentationでも訓練が安定するかという点を調べることでした。今回の結果では、ResNetやEfficientNetと若干劣る精度となりましたが、いい勝負が出るくらいの精度は出せているので、訓練の安定性という点ではYESといえるでしょう。前回のCIFAR-10の記事で見たとおり、Swin Transformerは1から訓練するとハイパラの設定が大変でしたが、転移学習では他のCNN同様に簡単に訓練できることが示せました。

ただし、Swin Transformerの精度やスループットについては疑問が残りそうです。精度やスループットでEfficientNetをぶっちぎったかというとそうでもなく、スループットではGPUでもかなり重いのが現状でした。もう少し長く訓練したらSwin Transformerも、もう少し精度上がったのかもしれません。今回は精度がResNetクラスで相当サチっていたのでそれもあるかもしれません。

一方で精度やスループットで最高を示したEfficientNetは、model.eval()の付け忘れには弱いことがわかりました。EfficientNetやResNetといったBatch Normを含むモデルでは、これらの切り替えに気をつける必要がありそうです。この点ではSwin TransformerのLayer Normalizationは副作用が少ないようです。

Pocket
LINEで送る



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


One Comment

Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です