こしあん
2019-06-28

PyTorchで複数出力があるモデルの出力の型について

Pocket
LINEで送る


出力が複数あるモデルの訓練というのは少し複雑なモデルだとよく出てきます。PyTorchでは複数出力のモデルの、出力の型はどうなっているでしょうか。それを見ていきます。中間層の値を取りたい場合も使えます。

サンプルコード

import torch
from torch import nn

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=1)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=1)
        self.conv3 = nn.Conv2d(16, 32, kernel_size=1)

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(nn.AvgPool2d(2)(x1))
        x3 = self.conv3(nn.AvgPool2d(2)(x2))
        return [x1, x2, x3] # listで返しているので

def main():
    model = TestModel()
    input = torch.randn(4, 3, 16, 16) # batch_size=4, channel=3, x=y=16
    out = model(input) # 出力はtorch.Tensorのlist
    print(type(out))
    for x in out:
        print(x.size(), type(x))

if __name__ == "__main__":
    main()

ここではconv1, conv2, conv3という複数の畳み込み層があります。各中間層の出力をまとめて返すモデルを試しに作ってみます。

結論からいうと、forwardの返り値の定義に依存します。

出力

上のコードのように、forwardのreturnをlistとして定義すると、出力はtorch.Tensorのlistになります。

<class 'list'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>

またforwardの定義を、Tupleにすると、

        return (x1, x2, x3) # Tupleにしてみる

出力もTupleになります。forで取り出せるのは変わりませんね。

<class 'tuple'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>

つまりは、モデル側のforwardで定義した通りに帰ってくるということです。わかりやすい。

まとめ

PyTorchで出力が複数あるモデルの出力は、モデルのforward定義に依存

Related Posts

TensorFlowでコサイン類似度を計算する方法... TensorFlowで損失関数や距離関数に「コサイン類似度」を使うことを考えます。Scikit-learnでは簡単に計算できますが、同様にTensorFlowでの行列演算でも計算できます。それを見ていきます。 コサイン類似度 コサイン類似度は、ユークリッド距離と同様に距離関数として使われる評価...
TensorFlow2.0+TPUでData AugmentationしながらCIFAR-10... TensorFlow2.0+TPUでData AugmentationしながらCIFAR-10を分類するサンプルです。Data Augmentationはtf.dataでやるのがポイントです。 TensorFlowを2.Xに上げる まずは、ランタイム切り替えで「TPU」を選択しましょう。無料で...
Kerasで複数のラベル(出力)があるモデルを訓練する... Kerasで複数のラベル(出力)のあるモデルを訓練することを考えます。ここでの複数のラベルとは、あるラベルとそれに付随する情報が送られてきて、それを同時に損失関数で計算する例です。これを見ていきましょう。 問題設定 MNISTの分類で、ラベルが奇数のときだけ損失を評価し(categorical...
PyTorchで複数のGPUで訓練するときのSync Batch Normalizationの必要性... PyTorchにはSync Batch Normalizationというレイヤーがありますが、これが通常のBatch Normzalitionと何が違うのか具体例を通じて見ていきます。また、通常のBatch Normは複数GPUでData Parallelするときにデメリットがあるのでそれも確認し...
PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です