こしあん
2019-09-26

PyTorchで双方向連結リストなデータ構造のモデルを作る

Pocket
LINEで送る
Delicious にシェア

127{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



ディープラーニングのモデルには、訓練の途中でレイヤーを追加するなど特殊な訓練をするものがあります(Progressive-GANなど)。そのとき、モデルを「レイヤーやブロックの連結リスト」として定義しておくと見通しがよくなることがあります。その例を見ていきます。

訓練中に継ぎ足していくモデル

例えば、Progressive-GAN(図は論文から

Progressive-GAN(PG-GAN)は低解像度から高解像度で積み重ねるように訓練していくのが特徴です。低解像度の訓練が終わったら、より高い解像度のレイヤーを継ぎ足す→そのレイヤーを訓練する→また継ぎ足すのように層状に訓練していきます。

双方向連結リストとして考える

このようなモデルの何が困るかというと、訓練する際にいちいちレイヤーを指定したりインデックスを指定しないといけないのです。PG-GANはモデル構造が決定的なので、全てのレイヤーをリストとかで格納するでも良いのですが、連結リストとして考えてみるとわかりやすいかもしれません。

連結リストとはデータ構造の一種で、このように前後のオブジェクト(ノードとか言います)の参照を記録しておくものです。後ろのオブジェクトだけ記録しておけば(前もあるかもしれないけど普通は後ろ)片方向、前後両方記録しておけば双方向の連結リストとなります。

今回は、PyTorchのレイヤーやモジュール(nn.Module)を連結リストのノードと見立てて、リスト全体で1個のモデルになるようにしてみます。

コード

ざっくりと実装してこんな感じ。Moduleクラスが追加するブロックの中身。Blockクラスが前後の参照を含んだブロック構成になります。

import torch
from torch import nn

class Module(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True)
        )

    def forward(self, inputs):
        return self.layers(inputs)


class Block(nn.Module):
    def __init__(self, new_in_ch, new_out_ch, prev_block):
        super().__init__()
        self.prev_block = prev_block
        self.current_block = Module(new_in_ch, new_out_ch)
        self.next_block = None

    def forward(self, inputs):
        x = self.current_block(inputs)
        if self.next_block is not None:
            return self.next_block.forward(x) # forwardを再帰にする
        else:
            return x

    def add_block(self, new_block):
        self.next_block = new_block

def count_params(model):
    return sum([p.numel() for p in model.parameters()])

def main():
    x = torch.randn(16, 3, 32, 32)
    b1 = Block(3, 32, None) # 最初のブロック
    print(count_params(b1), b1(x).size())

    b2= Block(32, 64, b1) # 2番目ブロック
    b1.add_block(b2) # 最初のブロックに2番目を追加
    print(count_params(b1), b1(x).size()) # 呼び出すときは最初のブロック

    b3 = Block(64, 128, b2)
    b2.add_block(b3)
    print(count_params(b1), b1(x).size())

if __name__ == "__main__":
    main()  

forwardを次のブロックの有無に応じて再帰にするのがポイントですね。呼び出すときは最初のブロックでやります。出力はこんな感じ。

928 torch.Size([16, 32, 32, 32])
19488 torch.Size([16, 64, 32, 32])
93472 torch.Size([16, 128, 32, 32])

パラメーターが二重カウントされるのではないかと危惧しましたがそういうことはないようです。backpropもちゃんと計算できました。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です