PyTorchで複数出力があるモデルの出力の型について
Posted On 2019-06-28
出力が複数あるモデルの訓練というのは少し複雑なモデルだとよく出てきます。PyTorchでは複数出力のモデルの、出力の型はどうなっているでしょうか。それを見ていきます。中間層の値を取りたい場合も使えます。
目次
サンプルコード
import torch
from torch import nn
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 8, kernel_size=1)
self.conv2 = nn.Conv2d(8, 16, kernel_size=1)
self.conv3 = nn.Conv2d(16, 32, kernel_size=1)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(nn.AvgPool2d(2)(x1))
x3 = self.conv3(nn.AvgPool2d(2)(x2))
return [x1, x2, x3] # listで返しているので
def main():
model = TestModel()
input = torch.randn(4, 3, 16, 16) # batch_size=4, channel=3, x=y=16
out = model(input) # 出力はtorch.Tensorのlist
print(type(out))
for x in out:
print(x.size(), type(x))
if __name__ == "__main__":
main()
ここではconv1, conv2, conv3という複数の畳み込み層があります。各中間層の出力をまとめて返すモデルを試しに作ってみます。
結論からいうと、forwardの返り値の定義に依存します。
出力
上のコードのように、forwardのreturnをlistとして定義すると、出力はtorch.Tensorのlistになります。
<class 'list'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>
またforwardの定義を、Tupleにすると、
return (x1, x2, x3) # Tupleにしてみる
出力もTupleになります。forで取り出せるのは変わりませんね。
<class 'tuple'>
torch.Size([4, 8, 16, 16]) <class 'torch.Tensor'>
torch.Size([4, 16, 8, 8]) <class 'torch.Tensor'>
torch.Size([4, 32, 4, 4]) <class 'torch.Tensor'>
つまりは、モデル側のforwardで定義した通りに帰ってくるということです。わかりやすい。
まとめ
PyTorchで出力が複数あるモデルの出力は、モデルのforward定義に依存
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー