こしあん
2019-07-20

PyTorchでOnehotエンコーディングするためのワンライナー

Pocket
LINEで送る


PyTorchでクラスの数字を0,1のベクトルに変形するOnehotベクトルを簡単に書く方法を紹介します。ワンライナーでできます。

TL;DR

PyTorchではこれでOnehotエンコーディングができます。

onehot = torch.eye(10)[label]

ただし、labelはLongTensorとします。

解説

Numpyの場合と同じです。torch.eyeが単位行列(対角成分が1、他が0の行列)なので、それをインデックスで取り出すことでOnehotエンコーディングになります。

MNISTで確認

MNISTのData Loaderで確認してみます。

import torch
import torchvision
from torchvision import transforms

def load_mnist():
    trans = transforms.Compose([
        transforms.ToTensor()
    ])
    dataset = torchvision.datasets.MNIST(root="./data", train=True, transform=trans, download=True)
    return torch.utils.data.DataLoader(dataset, batch_size=16)

def onehot_convert():
    dataloader = load_mnist()
    for image, label in dataloader:
        print(label) # tensor([5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5, 3, 6, 1, 7])
        print(label.type())  # torch.LongTensor
        onehot = torch.eye(10)[label] # 10はクラス数
        print(onehot)
        print(onehot.type()) # torch.FloatTensor
        exit()

if __name__ == "__main__":
    onehot_convert()

組み込みのDataLoaderから読み込ませる場合は、LongTensorで読み込まれるので、型のキャストは不要です。Onehotベクトルの出力は以下のようになります。

tensor([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]])

Numpyからのテンソルの場合

NumpyからのテンソルをOnehotエンコーディングするときは、型に気をつける必要があります。LongTensorに対応するのはnp.int64です。

import numpy as np
def from_numpy():
    np.random.seed(123)
    label_numpy = np.random.permutation(10).astype(np.int64) # Numpyの時点でint64にするとLongTensorになる
    label = torch.as_tensor(label_numpy)
    print(label) # tensor([4, 0, 7, 5, 8, 3, 1, 6, 9, 2])
    print(label.type()) # torch.LongTensor
    onehot = torch.eye(10)[label]
    print(onehot)
    print(onehot.type()) # torch.FloatTensor

Numpyの時点でint64にキャストしてしまうのが一例です。このonehotベクトルの出力は次のようになります。

tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])

ただし、astypeの部分をnp.int32にすると次のようなエラーが出ます。

IndexError: tensors used as indices must be long, byte or bool tensors

PyTorchにおいてはインデックスの数字はLongでなくてはいけないようです。

Related Posts

PyTorchでDCGANやってみた PyTorchでDCGANをやってみました。MNISTとCIFAR-10、STL-10を動かしてみましたがかなり簡単にできました。訓練時間もそこまで長くはないので結構手軽に遊べます。 はじめに PyTorchでDCGANやってみました。コードはほとんどこの記事のコピペです。MNISTとCIFA...
Kerasで転移学習用にレイヤー名とそのインデックスを調べる方法... Kerasで転移学習をするときに、学習済みモデルのレイヤーの名前と、そのインデックス(何番目にあるかということ)の対応を知りたいことがあります。その方法を解説します。 転移学習とは 転移学習とは、ImageNetなど何百万もの大量の画像で事前学習させたモデルを使い、それを「特徴量検出器」として...
Google ColabのTPU環境でmodel.fitのhistoryが消える現象... Google ColabのTPU環境でmodel.fitしたときに、通常の環境で得られるhistory(誤差や精度のログ)が消えていることがあります。その対応法を示します。 原因はTPU用のモデルに変換したから まず結論からいうとこの現象はCPU/GPU環境では再発しません。TPU環境特有の現...
pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch)... pix2pixによる白黒画像のカラー化を1から実装します。PyTorchで行います。かなり自然な色付けができました。pix2pixはGANの中でも理論が単純なのにくわえ、学習も比較的安定しているので結構おすすめです。 はじめに PyTorchでDCGANができたので、今回はpix2pixをやり...
PyTorchでweight clipping WGANの論文見てたらWeight Clippingしていたので、簡単な例を実装して実験してみました。かなり簡単にできます。それを見ていきましょう。 Weight Clippingとは レイヤーの係数の値を一定範囲以内に収める手法。例えば、あるレイヤーが「-2, -1, 0, 1, 2」という...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です