こしあん
2019-09-26

PyTorchで双方向連結リストなデータ構造のモデルを作る

Pocket
LINEで送る


ディープラーニングのモデルには、訓練の途中でレイヤーを追加するなど特殊な訓練をするものがあります(Progressive-GANなど)。そのとき、モデルを「レイヤーやブロックの連結リスト」として定義しておくと見通しがよくなることがあります。その例を見ていきます。

訓練中に継ぎ足していくモデル

例えば、Progressive-GAN(図は論文から

Progressive-GAN(PG-GAN)は低解像度から高解像度で積み重ねるように訓練していくのが特徴です。低解像度の訓練が終わったら、より高い解像度のレイヤーを継ぎ足す→そのレイヤーを訓練する→また継ぎ足すのように層状に訓練していきます。

双方向連結リストとして考える

このようなモデルの何が困るかというと、訓練する際にいちいちレイヤーを指定したりインデックスを指定しないといけないのです。PG-GANはモデル構造が決定的なので、全てのレイヤーをリストとかで格納するでも良いのですが、連結リストとして考えてみるとわかりやすいかもしれません。

連結リストとはデータ構造の一種で、このように前後のオブジェクト(ノードとか言います)の参照を記録しておくものです。後ろのオブジェクトだけ記録しておけば(前もあるかもしれないけど普通は後ろ)片方向、前後両方記録しておけば双方向の連結リストとなります。

今回は、PyTorchのレイヤーやモジュール(nn.Module)を連結リストのノードと見立てて、リスト全体で1個のモデルになるようにしてみます。

コード

ざっくりと実装してこんな感じ。Moduleクラスが追加するブロックの中身。Blockクラスが前後の参照を含んだブロック構成になります。

import torch
from torch import nn

class Module(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(True)
        )

    def forward(self, inputs):
        return self.layers(inputs)


class Block(nn.Module):
    def __init__(self, new_in_ch, new_out_ch, prev_block):
        super().__init__()
        self.prev_block = prev_block
        self.current_block = Module(new_in_ch, new_out_ch)
        self.next_block = None

    def forward(self, inputs):
        x = self.current_block(inputs)
        if self.next_block is not None:
            return self.next_block.forward(x) # forwardを再帰にする
        else:
            return x

    def add_block(self, new_block):
        self.next_block = new_block

def count_params(model):
    return sum([p.numel() for p in model.parameters()])

def main():
    x = torch.randn(16, 3, 32, 32)
    b1 = Block(3, 32, None) # 最初のブロック
    print(count_params(b1), b1(x).size())

    b2= Block(32, 64, b1) # 2番目ブロック
    b1.add_block(b2) # 最初のブロックに2番目を追加
    print(count_params(b1), b1(x).size()) # 呼び出すときは最初のブロック

    b3 = Block(64, 128, b2)
    b2.add_block(b3)
    print(count_params(b1), b1(x).size())

if __name__ == "__main__":
    main()  

forwardを次のブロックの有無に応じて再帰にするのがポイントですね。呼び出すときは最初のブロックでやります。出力はこんな感じ。

928 torch.Size([16, 32, 32, 32])
19488 torch.Size([16, 64, 32, 32])
93472 torch.Size([16, 128, 32, 32])

パラメーターが二重カウントされるのではないかと危惧しましたがそういうことはないようです。backpropもちゃんと計算できました。

Related Posts

Google ColabのTPUでメッセージがうるさいときにやること... Google ColabのTPUを使っているとえらいメッセージが表示されて、うるさいときがあります。そんなときにメッセージを消す裏技を発見したので書いていきたいと思います。 こんなメッセージ おそらくTPUがまだ実験段階だからかと思われますが、デフォルトだとこんな感じにずらずらとログが流れます...
Python(Numpy)で画像を水平反転する方法:Data Augmentation向け... OpenCVを使わずに単純に画像を左右反転(水平反転)する方法を考えます。ディープラーニングでデータのジェネレーターを自分で実装した場合、Data Augmentationを組み込む際にも必要になります。それを見ていきましょう。 左右反転自体は実は簡単 例えばNumpyの行列を左右反転させてみ...
PyTorchでOnehotエンコーディングするためのワンライナー... PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。 TL;DR PyTorchではこれでOnehotエンコーディングができます。 onehot = torch.eye(10) ただし、labelはLongTe...
TensorFlowでもラプラシアンピラミッドを作る... 以前作ったPyTorchのラプラシアンピラミッドをTensorFlow2.0に移植しました。何かと便利なラプラシアンピラミッドをつかってみよう。 環境 TensorFlow2.0 CPUの動作で確認しましたが、一応TPUでも動くように配慮はしました。 PyTorchの元記事 コード Co...
Kerasのジェネレーターでサンプルが列挙される順番について... Kerasの(カスタム)ジェネレーターでサンプルがどの順番で呼び出されるか、1ループ終わったあとにどういう処理がなされるのか調べてみました。ジェネレーターを自分で定義するとモデルの表現の幅は広がるものの、バグが起きやすくなるので「本当に順番が保証されるのか」や「ハマりどころ」を確認します。 0~...
Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です