こしあん
2019-09-16

画像分類で比較するBatch Norm, Instance Norm, Spectral Normの勾配の大きさ

Pocket
LINEで送る
Delicious にシェア

3.4k{icon} {views}



GANの安定化のために、Batch Normalizationを置き換えるということがしばしば行われます。その置き換え先として、Spectral Norm、Instance Normなどが挙げられます。今回はGANではなく普通の画像分類の問題としてBatch Normを置き換えし、勾配のノルムどのように変わるかを比較します。

はじめに

GANの安定化のためにBatch Normalizationをそのまま使わない、ここを工夫するというのはほとんど定石になりつつあります。最たるものが「Spectral Normalization」で、特異値分解によるリプシッツ定数のコントロールにより、Dのリプシッツ連続性を保証するような制約をおくということを行っています。具体的には、GANのDのBatch NormをSpectral Normで置き換えするという単純かつ明快なものです。このSpectral Normは、GANにおいてImage Netに対する現在のSoTAであるBigGANにおいても広く使われています。

Spectral Normを使わなくてもBatch Normの置き換えをするというケースはあります。例えば、pix2pix HDではNormalizationにInstance Normalizationを使っています。またProgressive GANではLayer Normzalitionを使ったりPixelwise Normalizationを使ったりといろいろ工夫をしています。

今回は、一旦GANから離れて通常の画像分類の問題に戻り、「Normalizationの違いによってレイヤー単位での勾配がどの程度変わるのか」を調べます。具体的には、画像分類のモデルを訓練していって、レイヤー単位の勾配ノルムがどのように変化するのかを調べます。

やること

CIFAR-10の分類モデルを訓練します。以下の2つのモデルを用意します(モデルの詳細はコード参照)。

  • 10層モデル
  • ResNetライクなモデル(ResNetの論文からCIFARの解像度にあうように適当に変更したものです)

そして各モデルに対し、Normalizationを3種類変更します。すべてPyTorchに組み込まれているものです。

  • Batch Normalization
  • Instance Normalization
  • Spectral Normalization

すべてで6ケースになります。各ケースに対してエポックのはじめにConvレイヤーの勾配ノルムを計算します。具体的には、

def calc_gradients(model, input_x, target_y):
    model.zero_grad()
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(model(input_x), target_y)
    loss.backward()
    grads = [p.grad.norm().item() for p in model.parameters() if len(p.size()) == 4]
    return grads

このようなコードで計算します。ここでinput_xは入力画像、target_yは本物のラベルです。テストデータすべてに対する勾配ノルムを計算し、ミニバッチ単位で平均します。この平均をエポック単位、ケース単位で比較します。

バッチサイズは128とし、学習率は0.1でモメンタム係数は0.9としました。100エポック訓練し、50エポック目と80エポック目で学習率を1/10にします。また、ResNetライクなモデルに限り、Weight Decay=0.0001を追加しました。

注意

Instance NormのResNetライクなモデルだけ、なぜかうまく行かなかった(ロスがずっと下がらないし初期の勾配がすべて0になってしまう)ので結果を省略します。Batch Normはすべてで訓練が成功しましたが、Spectral NormやInstance Normを画像分類で使う場合は「初期値ガチャ」みたいなものがあるようで、そのガチャに失敗するとずっとロスが下がらないということが観測されました。ここに掲載しているのはうまくいったケースです。

テスト精度推移

テストデータをValidationとしたときの精度推移はこちらです。横軸がエポック、縦軸が精度を表します。

精度を単純に上げたいのなら、従来どおりBatch Normzalitionがやはり有効なようです。Spectral Normは振れ幅が低すぎてロスが下がり切る前に学習率が減衰してしまったのかもしれません。

Normalization別の勾配ノルム

次はモデル-エポック単位で切り出したときの勾配推移です。横軸が層のインデックス(大きい値ほど深い)、縦軸がノルムの値です。

Batch Normは訓練開始時の勾配がひたすら大きいという傾向があります。これはResNet、10層問わず同じです。Batch Normの勾配は訓練が進むほど他のNormalizationよりも低くなりがちになります。

Instance Normは10層モデルだけですが、Batch Normほど訓練開始時の勾配は大きくありません。最終的にはBatch Normと同じように訓練が進んでいきます。

少し変わっているのがSpectral Normです。Spectral Normの訓練開始時の勾配は驚くほど低いです。GANのように失敗が訓練開始時に起きやすいケースでは、Spectral Normは有効でしょう。しかし面白いのがその後で、訓練が進むほどBatch NormやInstance Normの勾配が下がっていくのに対して、Spectral Normはだんだん勾配が大きくなっていきます。特に10層モデルのようなResNetではない(Skip Connectionのない)ケースでは顕著で、層の深さに対して指数関数的に上がっているようにも見えます。Self attention GANやBig GANのように訓練が進むと崩壊するケースがあるのはもしかしたらこういう理由なのかもしれません。ResNetを使えば多少は勾配の発散を軽減できそうなので、GANでのSpectral NormがResNetとセットで使われるのもなんとなく理解できます。

Epoch別の勾配ノルム

次は、モデル-Normalization単位で切り出したときの勾配推移です。横軸が層のインデックス(大きいほど深い層)、縦軸がノルムの値です。

訓練が進むほど勾配が大きくなるのはどのNormalizationでも共通でした(ただしResNetのBatch Normだけ例外)。変わるのはNormalization別の相対的な順位ということでしょうか。

10層モデルの場合、Spectral Normは初期の勾配は低いですが、訓練が進むほど、層の深さに対して指数関数的な勾配構成になります。Batch NormやInstance Normの場合は真ん中の層の勾配が膨らむので、深い層の勾配が膨らむのはあんまりよくないような感じがします(GANでモデルを深くするというのはあんまりよくないという話は聞いた記憶がある)。Spectral Normでも、ResNetにすると深い層の勾配が膨らまなくなるので、ResNetは大事なのでしょう。

しかし、ResNetにするとなぜかBatch Normの初期の勾配が驚くほど高いという不思議なことがおこります。これは10層モデルでは起こりませんでした。確かにこれはこれでよくないのでしょう(Spectral Normが良いと言われる理由はおそらくこれ)。

まとめ

今回の実験をまとめます。

  • 訓練開始時の勾配ノルムは、「Batch Norm>Instance Norm>Spectral Norm」だった。特にResNetではBatch Normが初期になぜかとても高い勾配を出すので、これがGANの失敗に寄与しているのだと思われる。この点ではSpectral Normは良い。
  • しかし訓練が進んでいくと、Spectral Normの勾配ノルムがBatch Normよりも遥かに高くなるケースがある。ResNetを使わない場合、深い層ほど指数関数的に勾配が増加していくので、Spectral Normを使う場合はResNetを使ったほうがよさそう。GANで訓練がある程度進んだところから急に崩壊するというのは、このSpectral Normの勾配が大きくなりすぎていることによるものなのかもれない。
  • Spectral Normは現時点(2019年9月)では非常に有効なGANの安定化手法であるが、もしかするとまだまだ改良の余地があるのかもしれない。

コード

モデル(models.py)

import torch
from torch import nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.nn.utils.spectral_norm import spectral_norm

class TenLayersModel(nn.Module):
    def __init__(self, normalization):
        super().__init__()
        self.convs = self.create_model(normalization)
        self.linear = nn.Linear(256, 10)

    def conv_norm_relu(self, in_ch, out_ch, normalization):
        layers = []
        if normalization == "spectral":
            w = nn.Conv2d(in_ch, out_ch, 3, padding=1)
            layers.append(spectral_norm(w))
        else:
            layers.append(nn.Conv2d(in_ch, out_ch, 3, padding=1))
            if normalization == "batch":
                layers.append(nn.BatchNorm2d(out_ch))
            elif normalization == "instance":
                layers.append(nn.InstanceNorm2d(out_ch))
        layers.append(nn.ReLU(True))
        return layers

    def create_model(self, normalization):
        layers = []
        for in_ch in [3, 64, 64]:
            layers += self.conv_norm_relu(in_ch, 64, normalization)
        layers.append(nn.AvgPool2d(2))
        for in_ch in [64, 128, 128]:
            layers += self.conv_norm_relu(in_ch, 128, normalization)
        layers.append(nn.AvgPool2d(2))
        for in_ch in [128, 256, 256]:
            layers += self.conv_norm_relu(in_ch, 256, normalization)
        layers.append(nn.AvgPool2d(8))
        return nn.Sequential(*layers)

    def forward(self, inputs):
        x = self.convs(inputs).view(inputs.size(0), -1)
        x = self.linear(x)
        return x

class ResNetPreactModule(nn.Module):
    def __init__(self, in_ch, out_ch, downsampling, normalization):
        assert normalization in ["batch", "spectral", "instance"]
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.shortcut_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else None

        self.norm1 = self.get_normalization(in_ch, normalization)
        self.norm2 = self.get_normalization(out_ch, normalization)
        if normalization == "spectral":
            self.conv1 = spectral_norm(self.conv1)
            self.conv2 = spectral_norm(self.conv2)
            self.shortcut_conv = spectral_norm(self.shortcut_conv) if self.shortcut_conv is not None else None
        self.downsampling = nn.AvgPool2d(downsampling) if downsampling > 1 else None

    def get_normalization(self, ch, normalization):
        if normalization == "batch":
            return nn.BatchNorm2d(ch)
        elif normalization == "instance":
            return nn.InstanceNorm2d(ch)
        else:
            return None

    def forward(self, inputs):
        # main path
        x = self.norm1(inputs) if self.norm1 is not None else inputs
        x = F.relu(x)
        x = self.conv1(x)
        x = self.norm2(x) if self.norm2 is not None else x
        x = F.relu(x)
        x = self.conv2(x)
        # shortcut path
        shortcut = self.shortcut_conv(inputs) if self.shortcut_conv is not None else inputs
        # downsampling
        if self.downsampling is not None:
            x = self.downsampling(x)
            shortcut = self.downsampling(shortcut)
        return x + shortcut

class ResNetLikeModel(nn.Module):
    def __init__(self, normalization):
        super().__init__()
        self.conv = nn.Sequential(
            *self.resnet_block(3, 64, 3, normalization),
            *self.resnet_block(64, 128, 4, normalization),
            *self.resnet_block(128, 256, 6, normalization, enable_downsampling=False),
            nn.AvgPool2d(8)
        )
        if normalization == "batch":
            self.last_norm = nn.BatchNorm2d(256)
        elif normalization == "instance":
            self.last_norm = nn.InstanceNorm2d(256)
        else:
            self.last_norm = None
        self.linear = nn.Linear(256, 10)

    def resnet_block(self, in_ch, out_ch, reps, normalization, enable_downsampling=True):
        layers = []
        for i in range(reps):
            current_in = in_ch if i == 0 else out_ch
            down = 2 if i == reps-1 and enable_downsampling else 1
            layers.append(ResNetPreactModule(current_in, out_ch, down, normalization))
        return layers

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.last_norm(x) if self.last_norm is not None else x
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

勾配計算(compute_gradients.py)

import torch
import torchvision
from models import TenLayersModel, ResNetLikeModel
from torchvision import transforms
from tqdm import tqdm
import numpy as np
import os
import pickle

def load_cifar():
    trans = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))        
    ])
    trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=trans)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True)
    testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=trans)
    testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False)
    return trainloader, testloader

def calc_gradients(model, input_x, target_y):
    model.zero_grad()
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(model(input_x), target_y)
    loss.backward()
    grads = [p.grad.norm().item() for p in model.parameters() if len(p.size()) == 4]
    return grads

def main(network, normalization):
    if network == "ten":
        model = TenLayersModel(normalization)
    elif network == "resnet":
        model = ResNetLikeModel(normalization)

    model_name = f"{network}_{normalization}"
    output_dir = "snapshot"
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    device = "cuda"
    batch_size = 128
    model.to(device)
    model = torch.nn.DataParallel(model)

    trainloader, testloader = load_cifar()
    log_gradients = []
    log_loss = []
    log_val_acc = []

    weight_decay = 0.0001 if network == "resnet" else 0

    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.1)
    criterion = torch.nn.CrossEntropyLoss()
    max_val_acc = 0.0

    for epoch in tqdm(range(100)):
        # gradient check
        gradients = []
        for X, y in testloader:
            if len(X) != 128: continue
            X, y = X.to(device), y.to(device)
            gradients.append(calc_gradients(model, X, y))
        model.zero_grad()
        layer_grads = np.mean(np.array(gradients), axis=0)  # batch-wise mean
        log_gradients.append(layer_grads)

        # train
        train_loss = 0.0
        for i, (X, y) in enumerate(trainloader):
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            y_pred = model(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
        log_loss.append(train_loss / (i + 1))

        # validation
        with torch.no_grad():
            correct, total = 0, 0
            for X, y in testloader:
                X, y = X.to(device), y.to(device)
                outputs = model(X)
                _, pred = torch.max(outputs.data, 1)
                total += y.size(0)
                correct += (pred == y).sum().item()
            log_val_acc.append(correct / total)

        # save model
        if max_val_acc < log_val_acc[-1]:
            torch.save(model.state_dict(), f"{output_dir}/{model_name}.pytorch")
            max_val_acc = log_val_acc[-1]

        scheduler.step()

        print("Epoch =", epoch, "Loss =", log_loss[-1], "Val_acc =", log_val_acc[-1], "/ ", model_name)
        # print(log_gradients[-1])

    # save result
    with open(f"{output_dir}/log_{model_name}.pkl", "wb") as fp:
        result = {"gradient":log_gradients, "loss":log_loss, "val_acc":log_val_acc}
        pickle.dump(result, fp)

if __name__ == "__main__":
    for model in ["ten", "resnet"]:
        for norm in ["batch", "instance", "spectral"]:
            main(model, norm)



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

【新刊】インフィニティNumPy――配列の初期化から、ゲームの戦闘、静止画や動画作成までの221問

「本当の実装力を身につける」ための221本ノック――
機械学習(ML)で避けて通れない数値計算ライブラリ・NumPyを、自在に活用できるようになろう。「できる」ための体系的な理解を目指します。基礎から丁寧に解説し、ディープラーニング(DL)の難しいモデルで遭遇する、NumPyの黒魔術もカバー。初心者から経験者・上級者まで楽しめる一冊です。問題を解き終わったとき、MLやDLなどの発展分野にスムーズに入っていけるでしょう。

本書の大きな特徴として、Pythonの本でありがちな「NumPyとML・DLの結合を外した」点があります。NumPyを理解するのに、MLまで理解するのは負担が大きいです。本書ではあえてこれらの内容を書いていません。行列やテンソルの理解に役立つ「従来の画像処理」をNumPyベースで深く解説・実装していきます。

しかし、問題の多くは、DLの実装で頻出の関数・処理を重点的に取り上げています。経験者なら思わず「あー」となるでしょう。関数丸暗記では自分で実装できません。「覚える関数は最小限、できる内容は無限大」の世界をぜひ体験してみてください。画像編集ソフトの処理をNumPyベースで実装する楽しさがわかるでしょう。※紙の本は電子版の特典つき

モザイク除去から学ぶ 最先端のディープラーニング

「誰もが夢見るモザイク除去」を起点として、機械学習・ディープラーニングの基本をはじめ、GAN(敵対的生成ネットワーク)の基本や発展型、ICCV, CVPR, ECCVといった国際学会の最新論文をカバーしていく本です。
ディープラーニングの研究は発展が目覚ましく、特にGANの発展型は市販の本でほとんどカバーされていない内容です。英語の原著論文を著者がコードに落とし込み、実装を踏まえながら丁寧に解説していきます。
また、本コードは全てTensorFlow2.0(Keras)に対応し、Googleの開発した新しい機械学習向け計算デバイス・TPU(Tensor Processing Unit)をフル活用しています。Google Colaboratoryを用いた環境構築不要の演習問題もあるため、読者自ら手を動かしながら理解を深めていくことができます。

AI、機械学習、ディープラーニングの最新事情、奥深いGANの世界を知りたい方にとってぜひ手にとっていただきたい一冊となっています。持ち運びに便利な電子書籍のDLコードが付属しています。

「おもしろ同人誌バザールオンライン」で紹介されました!(14:03~) https://youtu.be/gaXkTj7T79Y?t=843

まとめURL:https://github.com/koshian2/MosaicDeeplearningBook
A4 全195ページ、カラー12ページ / 2020年3月発行

Shikoan's ML Blog -Vol.1/2-

累計100万PV超の人気ブログが待望の電子化! このブログが電子書籍になって読みやすくなりました!

・1章完結のオムニバス形式
・機械学習の基本からマニアックなネタまで
・どこから読んでもOK
・何巻から読んでもOK

・短いものは2ページ、長いものは20ページ超のものも…
・通勤・通学の短い時間でもすぐ読める!
・読むのに便利な「しおり」機能つき

・全巻はA5サイズでたっぷりの「200ページオーバー」
・1冊にたっぷり30本収録。1本あたり18.3円の圧倒的コストパフォーマンス!
・文庫本感覚でお楽しみください

北海道の駅巡りコーナー

日高本線 車なし全駅巡り

ローカル線や秘境駅、マニアックな駅に興味のある方におすすめ! 2021年に大半区間が廃線になる、北海道の日高本線の全区間・全29駅(苫小牧~様似)を記録した本です。マイカーを使わずに、公共交通機関(バス)と徒歩のみで全駅訪問を行いました。日高本線が延伸する計画のあった、襟裳岬まで様似から足を伸ばしています。代行バスと路線バスの織り成す極限の時刻表ゲームと、絶海の太平洋と馬に囲まれた日高路、日高の隠れたグルメを是非たっぷり堪能してください。A4・フルカラー・192ページのたっぷりのボリュームで、あなたも旅行気分を漫喫できること待ったなし!

見どころ:日高本線被災区間(大狩部、慶能舞川橋梁、清畠~豊郷) / 牧場に囲まれた絵笛駅 / 窓口のあっただるま駅・荻伏駅 / 汐見の戦争遺跡のトーチカ / 新冠温泉、三石温泉 / 襟裳岬

A4 全192ページフルカラー / 2020年11月発行


Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です