こしあん
2019-07-20

PyTorchでOnehotエンコーディングするためのワンライナー

Pocket
LINEで送る


PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。

TL;DR

PyTorchではこれでOnehotエンコーディングができます。

onehot = torch.eye(10)[label]

ただし、labelはLongTensorとします。

解説

Numpyの場合と同じです。torch.eyeが単位行列(対角成分が1、他が0の行列)なので、それをインデックスで取り出すことでOnehotエンコーディングになります。

MNISTで確認

MNISTのData Loaderで確認してみます。

import torch
import torchvision
from torchvision import transforms

def load_mnist():
    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=trans, download=True)
    return torch.utils.data.DataLoader(dataset, batch_size=16)

def onehot_convert():
    dataloader = load_mnist()
    for image, label in dataloader:
        print(label) # tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
        print(label.type())  # torch.LongTensor
        onehot = torch.eye(10)[label] # 10はクラス数
        print(onehot)
        print(onehot.type()) # torch.FloatTensor
        exit()

if __name__ == "__main__":
    onehot_convert()

組み込みのDataLoaderから読み込ませる場合は、LongTensorで読み込まれるので、型のキャストは不要です。Onehotベクトルの出力は以下のようになります。

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])

Numpyからのテンソルの場合

NumpyからのテンソルをOnehotエンコーディングするときは、型に気をつける必要があります。LongTensorに対応するのはnp.int64です。

import numpy as np
def from_numpy():
    np.random.seed(123)
    label_numpy = np.random.permutation(10).astype(np.int64) # Numpyの時点でint64にするとLongTensorになる
    label = torch.as_tensor(label_numpy)
    print(label) # tensor([4, 0, 7, 5, 8, 3, 1, 6, 9, 2])
    print(label.type()) # torch.LongTensor
    onehot = torch.eye(10)[label]
    print(onehot)
    print(onehot.type()) # torch.FloatTensor

Numpyの時点でint64にキャストしてしまうのが一例です。このonehotベクトルの出力は次のようになります。

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

ただし、astypeの部分をnp.int32にすると次のようなエラーが出ます。

IndexError: tensors used as indices must be long, byte or bool tensors

PyTorchにおいてはインデックスの数字はLongでなくてはいけないようです。

Related Posts

Pillowでグレースケール化するときに3チャンネルで出力するテクニック... カラー画像をグレースケール化すると1チャンネルの出力になりますが、カラー画像と同時に扱うと整合性からグレースケールでも3チャンネルで欲しいことがあります。Numpyのブロードキャストを使わずに簡単に3チャンネルで出力する方法を書きます。 グレースケール化してカラー化すればOK これでOK i...
TensorFlow2.0+TPUで複数のモデルを交互に訓練するテスト... GANの利用を想定します。以前TPUだと複数のモデルを同時or交互に訓練するというのは厳しかったのですが、これがTF2.0で変わったのか確かめます。 環境:TensorFlow2.0.0、Colab TPU コード おまじない import tensorflow as tf import ...
TensorFlowでもラプラシアンピラミッドを作る... 以前作ったPyTorchのラプラシアンピラミッドをTensorFlow2.0に移植しました。何かと便利なラプラシアンピラミッドをつかってみよう。 環境 TensorFlow2.0 CPUの動作で確認しましたが、一応TPUでも動くように配慮はしました。 PyTorchの元記事 コード Co...
GANでGeneratorの損失関数をmin(log(1-D))からmaxlog Dにした場合の実験... GANの訓練をうまくいくためのTipとしてよく引用される、How to train GANの中から、Generatorの損失関数をmin(log(1-D))からmaxlog Dにした場合を実験してみました。その結果、損失結果を変更しても出力画像のクォリティーには大して差が出ないことがわかりました。...
PyTorchでガウシアンピラミッド+ラプラシアンピラミッド(Gaussian/Laplacian ... Progressive-GANの論文で、SWD(Sliced Wasserstein Distance)が評価指標として出てきたので、その途中で必要になったガウシアンピラミッド、ラプラシアンピラミッドをPyTorchで実装してみました。これらのピラミッドはGAN関係なく、画像処理一般で使えるものです...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です