こしあん
2019-05-13

pipからインストールしたTorchVisionにImageNetがないときの対応

Pocket
LINEで送る
Delicious にシェア

3.1k{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



TorchVisionの公式ドキュメントにはImageNetが利用できるとの記述がありますが、pipからインストールするとImageNetのモジュール自体がないことがあります。TorchVisionにImageNetのモジュールを手動でインストールする方法を解説します。

発生状況

  • Python3.7(pip環境でAnacondaの使用はなし)
  • PyTorch1.1+CPUバージョン
  • torchvision v0.2.2.post3

PyTorchとtorchvisionは以下のpipでインストール。公式サイトより

pip3 install https://download.pytorch.org/whl/cpu/torch-1.1.0-cp37-cp37m-win_amd64.whl
pip3 install torchvision

ImageNetがない

TorchVisionのドキュメントを見ると、Imagenet-12が使えるとの記述があります。

しかし、インストール直後に、

import torchvision

def test():
    data_train = torchvision.datasets.ImageNet("./imagenet", download=True)
    data_test = torchvision.datasets.ImageNet("./imagenet", download=True, split="val")

if __name__ == "__main__":
    test()

のようなコードを実行すると、

  File "***.py", line 4, in test
    data_train = torchvision.datasets.ImageNet("./imagenet", download=True)
AttributeError: module 'torchvision.datasets' has no attribute 'ImageNet'

のようにエラーが表示されてデータのダウンロードや読み込みがうまくいきません。

TorchVisionのディレクトリを見ると

Windows10環境の場合ですが、エクスプローラーで「torchvision」を検索すると以下のようなフォルダが出てきました。

imagenet.pyがない??

ディレクトリは、「C:\Users\ユーザー名\AppData\Local\Programs\Python\Python37\Lib\site-packages\torchvision\datasets」にありました。これをdatasetsフォルダとします。

しかし一方で、公式のリポジトリを見ると、確かにソースは存在するのです。まぁごにょごにょな理由なのかもしれませんし、そのうち直るのかもしれません。

ここでは、公式からファイルを手動でコピーしてくる方法を取ります。

手動インストール

  1. datasetsフォルダに公式の「imagenet.py」をコピー
  2. _init_.pyを次のように変更して保存。from~とallの中身の2箇所を変更します。
from .lsun import LSUN, LSUNClass
from .folder import ImageFolder, DatasetFolder
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .stl10 import STL10
from .mnist import MNIST, EMNIST, FashionMNIST, KMNIST
from .svhn import SVHN
from .phototour import PhotoTour
from .fakedata import FakeData
from .semeion import SEMEION
from .omniglot import Omniglot
from .sbu import SBU
from .flickr import Flickr8k, Flickr30k
from .voc import VOCSegmentation, VOCDetection
from .cityscapes import Cityscapes
from .imagenet import ImageNet # ここを追加

__all__ = ('LSUN', 'LSUNClass',
           'ImageFolder', 'DatasetFolder', 'FakeData',
           'CocoCaptions', 'CocoDetection',
           'CIFAR10', 'CIFAR100', 'EMNIST', 'FashionMNIST', 'ImageNet', # ここに'ImageNet'を追加
           'MNIST', 'KMNIST', 'STL10', 'SVHN', 'PhotoTour', 'SEMEION',
           'Omniglot', 'SBU', 'Flickr8k', 'Flickr30k',
           'VOCSegmentation', 'VOCDetection', 'Cityscapes')

試してみる

先程のコードを実行してみると、

  0%|                                                                      | 77717504/147897477120 [00:16<23:36:47, 1738894.67it/s]

確かにダウンロードが始まりました。しかしとんでもない容量(Trainだけで147GB)なので、暇なときにダウンロードするのをおすすめします。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です