こしあん
2019-07-20

PyTorchでOnehotエンコーディングするためのワンライナー

Pocket
LINEで送る


PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。

TL;DR

PyTorchではこれでOnehotエンコーディングができます。

onehot = torch.eye(10)[label]

ただし、labelはLongTensorとします。

解説

Numpyの場合と同じです。torch.eyeが単位行列(対角成分が1、他が0の行列)なので、それをインデックスで取り出すことでOnehotエンコーディングになります。

MNISTで確認

MNISTのData Loaderで確認してみます。

import torch
import torchvision
from torchvision import transforms

def load_mnist():
    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=trans, download=True)
    return torch.utils.data.DataLoader(dataset, batch_size=16)

def onehot_convert():
    dataloader = load_mnist()
    for image, label in dataloader:
        print(label) # tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
        print(label.type())  # torch.LongTensor
        onehot = torch.eye(10)[label] # 10はクラス数
        print(onehot)
        print(onehot.type()) # torch.FloatTensor
        exit()

if __name__ == "__main__":
    onehot_convert()

組み込みのDataLoaderから読み込ませる場合は、LongTensorで読み込まれるので、型のキャストは不要です。Onehotベクトルの出力は以下のようになります。

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])

Numpyからのテンソルの場合

NumpyからのテンソルをOnehotエンコーディングするときは、型に気をつける必要があります。LongTensorに対応するのはnp.int64です。

import numpy as np
def from_numpy():
    np.random.seed(123)
    label_numpy = np.random.permutation(10).astype(np.int64) # Numpyの時点でint64にするとLongTensorになる
    label = torch.as_tensor(label_numpy)
    print(label) # tensor([4, 0, 7, 5, 8, 3, 1, 6, 9, 2])
    print(label.type()) # torch.LongTensor
    onehot = torch.eye(10)[label]
    print(onehot)
    print(onehot.type()) # torch.FloatTensor

Numpyの時点でint64にキャストしてしまうのが一例です。このonehotベクトルの出力は次のようになります。

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

ただし、astypeの部分をnp.int32にすると次のようなエラーが出ます。

IndexError: tensors used as indices must be long, byte or bool tensors

PyTorchにおいてはインデックスの数字はLongでなくてはいけないようです。

Related Posts

KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
PandasのDataFrameでグループ別にサンプルをN個抜き出す方法... 「PandasでGroupbyでグルーピングしたはいんだけど、そこからグループ別にサンプルを1個、2個…と抜き出す、SQLでよくやるやつってどうやるんだっけ?」ということが気になったので、調べました。ちゃんとした方法があります。 例題 今、中国地方と四国地方の県と面積をDataFrameにして...
TensorFlow2.0+TPUで複数のモデルを交互に訓練するテスト... GANの利用を想定します。以前TPUだと複数のモデルを同時or交互に訓練するというのは厳しかったのですが、これがTF2.0で変わったのか確かめます。 環境:TensorFlow2.0.0、Colab TPU コード おまじない import tensorflow as tf import ...
argparseに直接dictを読み込ませる怪しいやり方... argparseにコマンドライン引数ではなく、ファイルから読み込んだdictをオーバーラップさせる方法を試してみました。本来のargparseの使い方ではない怪しいやり方ですが、JSONやyamlファイルとの連携が可能なので便利ではないかなと思います。 注意 これは本来のargparseの使い...
Kerasのジェネレーターでサンプルが列挙される順番について... Kerasの(カスタム)ジェネレーターでサンプルがどの順番で呼び出されるか、1ループ終わったあとにどういう処理がなされるのか調べてみました。ジェネレーターを自分で定義するとモデルの表現の幅は広がるものの、バグが起きやすくなるので「本当に順番が保証されるのか」や「ハマりどころ」を確認します。 0~...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です