こしあん
2019-07-20

PyTorchでOnehotエンコーディングするためのワンライナー

Pocket
LINEで送る
Delicious にシェア

4.2k{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。

TL;DR

PyTorchではこれでOnehotエンコーディングができます。

onehot = torch.eye(10)[label]

ただし、labelはLongTensorとします。

解説

Numpyの場合と同じです。torch.eyeが単位行列(対角成分が1、他が0の行列)なので、それをインデックスで取り出すことでOnehotエンコーディングになります。

MNISTで確認

MNISTのData Loaderで確認してみます。

import torch
import torchvision
from torchvision import transforms

def load_mnist():
    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=trans, download=True)
    return torch.utils.data.DataLoader(dataset, batch_size=16)

def onehot_convert():
    dataloader = load_mnist()
    for image, label in dataloader:
        print(label) # tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
        print(label.type())  # torch.LongTensor
        onehot = torch.eye(10)[label] # 10はクラス数
        print(onehot)
        print(onehot.type()) # torch.FloatTensor
        exit()

if __name__ == "__main__":
    onehot_convert()

組み込みのDataLoaderから読み込ませる場合は、LongTensorで読み込まれるので、型のキャストは不要です。Onehotベクトルの出力は以下のようになります。

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])

Numpyからのテンソルの場合

NumpyからのテンソルをOnehotエンコーディングするときは、型に気をつける必要があります。LongTensorに対応するのはnp.int64です。

import numpy as np
def from_numpy():
    np.random.seed(123)
    label_numpy = np.random.permutation(10).astype(np.int64) # Numpyの時点でint64にするとLongTensorになる
    label = torch.as_tensor(label_numpy)
    print(label) # tensor([4, 0, 7, 5, 8, 3, 1, 6, 9, 2])
    print(label.type()) # torch.LongTensor
    onehot = torch.eye(10)[label]
    print(onehot)
    print(onehot.type()) # torch.FloatTensor

Numpyの時点でint64にキャストしてしまうのが一例です。このonehotベクトルの出力は次のようになります。

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

ただし、astypeの部分をnp.int32にすると次のようなエラーが出ます。

IndexError: tensors used as indices must be long, byte or bool tensors

PyTorchにおいてはインデックスの数字はLongでなくてはいけないようです。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です