こしあん
2019-09-26

PyTorchで双方向連結リストなデータ構造のモデルを作る

Pocket
LINEで送る


ディープラーニングのモデルには、訓練の途中でレイヤーを追加するなど特殊な訓練をするものがあります(Progressive-GANなど)。そのとき、モデルを「レイヤーやブロックの連結リスト」として定義しておくと見通しがよくなることがあります。その例を見ていきます。

訓練中に継ぎ足していくモデル

例えば、Progressive-GAN(図は論文から

Progressive-GAN(PG-GAN)は低解像度から高解像度で積み重ねるように訓練していくのが特徴です。低解像度の訓練が終わったら、より高い解像度のレイヤーを継ぎ足す→そのレイヤーを訓練する→また継ぎ足すのように層状に訓練していきます。

双方向連結リストとして考える

このようなモデルの何が困るかというと、訓練する際にいちいちレイヤーを指定したりインデックスを指定しないといけないのです。PG-GANはモデル構造が決定的なので、全てのレイヤーをリストとかで格納するでも良いのですが、連結リストとして考えてみるとわかりやすいかもしれません。

連結リストとはデータ構造の一種で、このように前後のオブジェクト(ノードとか言います)の参照を記録しておくものです。後ろのオブジェクトだけ記録しておけば(前もあるかもしれないけど普通は後ろ)片方向、前後両方記録しておけば双方向の連結リストとなります。

今回は、PyTorchのレイヤーやモジュール(nn.Module)を連結リストのノードと見立てて、リスト全体で1個のモデルになるようにしてみます。

コード

ざっくりと実装してこんな感じ。Moduleクラスが追加するブロックの中身。Blockクラスが前後の参照を含んだブロック構成になります。

import torch
from torch import nn

class Module(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True)
        )

    def forward(self, inputs):
        return self.layers(inputs)


class Block(nn.Module):
    def __init__(self, new_in_ch, new_out_ch, prev_block):
        super().__init__()
        self.prev_block = prev_block
        self.current_block = Module(new_in_ch, new_out_ch)
        self.next_block = None

    def forward(self, inputs):
        x = self.current_block(inputs)
        if self.next_block is not None:
            return self.next_block.forward(x) # forwardを再帰にする
        else:
            return x

    def add_block(self, new_block):
        self.next_block = new_block

def count_params(model):
    return sum([p.numel() for p in model.parameters()])

def main():
    x = torch.randn(16, 3, 32, 32)
    b1 = Block(3, 32, None) # 最初のブロック
    print(count_params(b1), b1(x).size())

    b2= Block(32, 64, b1) # 2番目ブロック
    b1.add_block(b2) # 最初のブロックに2番目を追加
    print(count_params(b1), b1(x).size()) # 呼び出すときは最初のブロック

    b3 = Block(64, 128, b2)
    b2.add_block(b3)
    print(count_params(b1), b1(x).size())

if __name__ == "__main__":
    main()  

forwardを次のブロックの有無に応じて再帰にするのがポイントですね。呼び出すときは最初のブロックでやります。出力はこんな感じ。

928 torch.Size([16, 32, 32, 32])
19488 torch.Size([16, 64, 32, 32])
93472 torch.Size([16, 128, 32, 32])

パラメーターが二重カウントされるのではないかと危惧しましたがそういうことはないようです。backpropもちゃんと計算できました。

Related Posts

TensorFlow2.0+TPUでData AugmentationしながらCIFAR-10... TensorFlow2.0+TPUでData AugmentationしながらCIFAR-10を分類するサンプルです。Data Augmentationはtf.dataでやるのがポイントです。 TensorFlowを2.Xに上げる まずは、ランタイム切り替えで「TPU」を選択しましょう。無料で...
Self-attention GAN(SAGAN)を実装して遊んでみた... 前回の投稿では、Spectral Noramlizationを使ったGAN「SNGAN」を実装しましたが、それの応用系であるSelf-attention GAN「SAGAN」を実装して遊んでみました。CIFAR-10、STL-10、AnimeFace Dataset、Oxford Flowerを生...
GANでGeneratorの損失関数をmin(log(1-D))からmaxlog Dにした場合の実験... GANの訓練をうまくいくためのTipとしてよく引用される、How to train GANの中から、Generatorの損失関数をmin(log(1-D))からmaxlog Dにした場合を実験してみました。その結果、損失結果を変更しても出力画像のクォリティーには大して差が出ないことがわかりました。...
TensorFlow2.0でDistribute Trainingしたときにfitと訓練ループで精度... TensorFlowでDistribute Training(複数GPUやTPUでの訓練)をしたときに、Keras APIのfit()でのValidation精度と、訓練ループを書いたときの精度でかなり(1~2%)違うという状況に遭遇しました。特定の文を忘れただけだったのですが、解決に1日かかった...
Numpyの配列に対して「最も多く存在する値」を求める方法... アンサンブル学習などで、Numpyの配列のある軸に対して「最も多く存在する値」を求めたい、つまり「多数決」をしたいことがあります。その方法を見ていきます。 最も大きい値がmax, 最も大きい値が存在するインデックスがargmax, では「最も多く存在する値」は? 配列のある軸に対して、「最も大...
Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です