こしあん
2018-10-01

tensorflow.kerasでKeras方式のhdf5で重みを保存する方法

Pocket
LINEで送る

従来のKerasで係数を保存すると「hdf5」形式で保存されたのですが、TPU環境などでTensorFlowのKerasAPIを使うと、TensorFlow形式のチェックポイントまるごと保存で互換性の面で困ったことがおきます。従来のKerasのhdf5形式で保存する方法を紹介します。

サンプルコード

これはGoogle ColabのTPUでMNISTを分類するコードです。

import tensorflow as tf
from keras.datasets import mnist
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from keras.utils import to_categorical
import numpy as np
import os
from tensorflow.contrib.tpu.python.tpu import keras_support

(X_train, y_train), (_, _) = mnist.load_data()
X_train = X_train / 255.0
y_train = to_categorical(y_train)
X_train = X_train.reshape(X_train.shape[0], -1)

input = Input((784,))
x = Dense(64, activation="relu")(input)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)
model.compile(Adam(), loss="categorical_crossentropy", metrics=["acc"])

tpu_grpc_url = "grpc://"+os.environ["COLAB_TPU_ADDR"]
model = tf.contrib.tpu.keras_to_tpu_model(model, 
            strategy=keras_support.TPUDistributionStrategy(
                        tf.contrib.cluster_resolver.TPUClusterResolver(tpu_grpc_url)))

model.fit(X_train, y_train, epochs=10, batch_size=1024)

model.save_weights("./weights.hdf5")

これを保存すると次のようにファイルがいっぱいできます。

> !ls
checkpoint  sample_data  weights.hdf5.data-00000-of-00001  weights.hdf5.index

これはTensorFlow形式で保存されてしまっているためです。「weights.hdf5.data-00000-of-00001」をwights.hdf5とリネームしてKerasで読み込ませても、そもそもHDF5形式ではないためエラーになってしまいます。どうすればよいでしょうか?

原因はTensorFlow1.9.0の仕様変更

TensorFlow1.9.0のリリースノートに、

tf.keras:
: :
tf.keras.Model.save_weights は今ではデフォルトで TensorFlow フォーマットでセーブします。

TensorFlow 1.9.0 リリースノート

これが全ての原因です。デフォルトということは、何らかのオプションを設定すればHDF5形式で保存できそうな感じはします。TensorFlowのソースコードを見てみました。save_weightsのコードのコメントにありました。

    Arguments:
        filepath: String, path to the file to save the weights to. When saving
            in TensorFlow format, this is the prefix used for checkpoint files
            (multiple files are generated). Note that the '.h5' suffix causes
            weights to be saved in HDF5 format.
        overwrite: Whether to silently overwrite any existing file at the
            target location, or provide the user with a manual prompt.
        save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
            '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
            `None` defaults to 'tf'.

https://github.com/tensorflow/tensorflow/blob/ad872f220df6808e8a5fcb926480f87cb2371dfd/tensorflow/python/keras/engine/network.py

つまり、save_formatの引数を=”h5″にしてsave_weightsすればよさそうですね。やってみましょう。

model.save_weights("./weights.hdf5", save_format="h5")
> !ls
sample_data  weights.hdf5

うまくいきました。HDF5形式で保存できています。ダウンロードして確認してみましょう。これはColab上での操作なので、他の環境ではgoogle.colabがインストールされていないと思います。

from google.colab import files
files.download("weights.hdf5")

カレントディレクトリに「weights.hdf5」をコピーします。そして係数を読み込みます。

from keras.layers import Dense, Input
from keras.models import Model

input = Input((784,))
x = Dense(64, activation="relu")(input)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)

model.load_weights("weights.hdf5")

weights = model.get_weights()
print(weights)
[array([[ 0.03950901,  0.07999339, -0.04398718, ..., -0.02395174,
         0.06044701, -0.0060069 ],
       [ 0.0700544 ,  0.06324833, -0.01576125, ...,  0.04512443,
        -0.00077055,  0.0424362 ],
       [ 0.04617605,  0.02478255,  0.00991695, ..., -0.06699679,
        -0.00292164, -0.05890182],
       ...,
       [ 0.07553779,  0.04920203, -0.05630066, ...,  0.0593554 ,
         0.08149198, -0.02658052],
       [ 0.04290282,  0.00971551, -0.02268285, ...,  0.01220566,
        -0.05852858, -0.02812307],
       [-0.05941847, -0.03612577,  0.05638266, ..., -0.04648735,
         0.07260651,  0.0159335 ]], dtype=float32), array([-0.0478633 , -0.06754
491,  0.07281522,  0.02968462,  0.05135746,
        0.05843485,  0.0194768 ,  0.00995919, -0.03050879,  0.12237937,
       -0.01765378,  0.08142806, -0.0467488 ,  0.04426579,  0.09194406,
        0.07080972,  0.09837534,  0.14349514, -0.07120208, -0.02860033,
       -0.08540137,  0.06272363,  0.14404611, -0.0416419 , -0.02341138,
       -0.00632342, -0.01621706, -0.07912031,  0.01071538,  0.07026922,
       -0.03116987,  0.02629776,  0.11185876, -0.09980662,  0.02117014,
        0.11517286,  0.03370601,  0.03579468, -0.01941629,  0.08394724,
        0.0734622 ,  0.06467377, -0.02742913, -0.09451034,  0.06308644,
       -0.00315004,  0.0798418 ,  0.09963303,  0.07617176,  0.05602382,
        0.01201982,  0.09839159, -0.01821309,  0.1587676 , -0.02780196,
        0.0340536 ,  0.0199388 , -0.00435052, -0.04387056,  0.1445573 ,
       -0.05228622, -0.04837526,  0.02425369,  0.00256828], dtype=float32), arra
y([[ 1.05951533e-01, -3.05032909e-01,  2.14703709e-01,
        -2.55714357e-01,  3.30205917e-01, -2.42590472e-01,
         2.70795047e-01, -2.87338555e-01,  1.34354711e-01,
(以下略)

OKです。このようにColab上ではtensorflow.kerasのKerasAPIで訓練して、ローカルではKerasのAPIを使うということもある程度はできます。これでTPUで訓練させた係数をローカルで読み込むということができますね。

Related Posts

Kerasのバックエンドで「○○以上☓☓以下」を計算する方法... Kerasのバックエンド関数を使ったときに「○○以上☓☓以下」を求めたい場合があります。しかし、KerasではAndのような論理演算をすると少し困ることがあります。その方法を解説します。 10×10の行列 0~99の数字を並べた「10×10」の行列を用意します。TensorFlowのテンソルで...
PCA Color Augmentationを拡張してTensorFlow/Keras向けに実装した... PCA Color AugmentationはAlexNetの論文に示された画像向けのData Augmentationですが、画像用だけではなく、テンソルの固有値分解をすることで構造化データに対しても使えるようにしてみました。これの解説と効果を書きます。 リポジトリ こちら https://...
TPUで学習率減衰させる方法 TPUで学習率減衰したいが、TensorFlowのオプティマイザーを使うべきか、tf.kerasのオプティマイザーを使うべきか、あるいはKerasのオプティマイザーを使うべきか非常にややこしいことがあります。TPUで学習率を減衰させる方法を再現しました。 結論から TPU環境でtf.keras...
Kerasに組み込まれているDenseNet(121/169/201)の実装... TL;DR パラメーター数 DenseNet-121 : 8,062,504 DenseNet-169 : 14,307,880 DenseNet-201 : 20,242,984 DenseNet-121のsummary ________________________________...
Google Colaboratoryで保存したKerasのモデルを読み込むとValueError... Google Colaboratory(Colab)上のKerasでh5形式で保存したモデルをダウンロードして、load_modelすると「TypeError: ('Keyword argument not understood:', 'data_format')」とエラーが発生して読み込めないこ...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です