PyTorchでOnehotエンコーディングするためのワンライナー
Posted On 2019-07-20
PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。
目次
TL;DR
PyTorchではこれでOnehotエンコーディングができます。
onehot = torch.eye(10)[label]
ただし、labelはLongTensorとします。
解説
Numpyの場合と同じです。torch.eyeが単位行列(対角成分が1、他が0の行列)なので、それをインデックスで取り出すことでOnehotエンコーディングになります。
MNISTで確認
MNISTのData Loaderで確認してみます。
import torch
import torchvision
from torchvision import transforms
def load_mnist():
trans = transforms.Compose([
transforms.ToTensor()
])
dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=trans, download=True)
return torch.utils.data.DataLoader(dataset, batch_size=16)
def onehot_convert():
dataloader = load_mnist()
for image, label in dataloader:
print(label) # tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
print(label.type()) # torch.LongTensor
onehot = torch.eye(10)[label] # 10はクラス数
print(onehot)
print(onehot.type()) # torch.FloatTensor
exit()
if __name__ == "__main__":
onehot_convert()
組み込みのDataLoaderから読み込ませる場合は、LongTensorで読み込まれるので、型のキャストは不要です。Onehotベクトルの出力は以下のようになります。
tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])
Numpyからのテンソルの場合
NumpyからのテンソルをOnehotエンコーディングするときは、型に気をつける必要があります。LongTensorに対応するのはnp.int64です。
import numpy as np
def from_numpy():
np.random.seed(123)
label_numpy = np.random.permutation(10).astype(np.int64) # Numpyの時点でint64にするとLongTensorになる
label = torch.as_tensor(label_numpy)
print(label) # tensor([4, 0, 7, 5, 8, 3, 1, 6, 9, 2])
print(label.type()) # torch.LongTensor
onehot = torch.eye(10)[label]
print(onehot)
print(onehot.type()) # torch.FloatTensor
Numpyの時点でint64にキャストしてしまうのが一例です。このonehotベクトルの出力は次のようになります。
tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
[0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
[0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
[0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])
ただし、astypeの部分をnp.int32にすると次のようなエラーが出ます。
IndexError: tensors used as indices must be long, byte or bool tensors
PyTorchにおいてはインデックスの数字はLongでなくてはいけないようです。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー