こしあん
2019-06-28

PyTorchで複数出力があるモデルの出力の型について

Pocket
LINEで送る


出力が複数あるモデルの訓練というのは少し複雑なモデルだとよく出てきます。PyTorchでは複数出力のモデルの、出力の型はどうなっているでしょうか。それを見ていきます。中間層の値を取りたい場合も使えます。

サンプルコード

import torch
from torch import nn

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(nn.AvgPool2d(2)(x1))
        x3 = self.conv3(nn.AvgPool2d(2)(x2))
        return [x1, x2, x3] # listで返しているので

def main():
    model = TestModel()
    input = torch.randn(4, 3, 16, 16) # batch_size=4, channel=3, x=y=16
    out = model(input) # 出力はtorch.Tensorのlist
    print(type(out))
    for x in out:
        print(x.size(), type(x))

if __name__ == "__main__":
    main()

ここではconv1, conv2, conv3という複数の畳み込み層があります。各中間層の出力をまとめて返すモデルを試しに作ってみます。

結論からいうと、forwardの返り値の定義に依存します。

出力

上のコードのように、forwardのreturnをlistとして定義すると、出力はtorch.Tensorのlistになります。

<class 'list'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>

またforwardの定義を、Tupleにすると、

        return (x1, x2, x3) # Tupleにしてみる

出力もTupleになります。forで取り出せるのは変わりませんね。

<class 'tuple'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>

つまりは、モデル側のforwardで定義した通りに帰ってくるということです。わかりやすい。

まとめ

PyTorchで出力が複数あるモデルの出力は、モデルのforward定義に依存

Related Posts

TensorFlowでコサイン類似度を計算する方法... TensorFlowで損失関数や距離関数に「コサイン類似度」を使うことを考えます。Scikit-learnでは簡単に計算できますが、同様にTensorFlowでの行列演算でも計算できます。それを見ていきます。 コサイン類似度 コサイン類似度は、ユークリッド距離と同様に距離関数として使われる評価...
KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
WarmupとData Augmentationのバッチサイズ別の精度低下について... 大きいバッチサイズで訓練する際は、バッチサイズの増加にともなう精度低下が深刻になります。この精度低下を抑制することはできるのですが、例えばData Augmentationのようなデータ増強・正則化による精度向上とは何が違うのでしょうか。それを調べてみました。 きっかけ この記事を書いたときに...
条件に応じた配列の要素の抽出をTensorFlowで行う... Numpyで条件を与えて、インデックスのスライスによって配列の要素を抽出する、というようなケースはよくあります。これをTensorFlowのテンソルでやるのにはどうすればいいのでしょうか?それを見ていきます。 Numpyではこんな例 例えば、5×5のランダムな行列をデータとします。この配列を左...
Numpyの配列に対して「最も多く存在する値」を求める方法... アンサンブル学習などで、Numpyの配列のある軸に対して「最も多く存在する値」を求めたい、つまり「多数決」をしたいことがあります。その方法を見ていきます。 最も大きい値がmax, 最も大きい値が存在するインデックスがargmax, では「最も多く存在する値」は? 配列のある軸に対して、「最も大...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です