こしあん
2021-02-06

tf.keras.models.Modelのsave_weightsのあれこれ:オプティマイザーの値を復元するには


9.9k{icon} {views}


Kerasでモデルの保存するとき、save_weightsの関数をよく使います。しかし、オプティマイザーの値を復元して訓練再開しようとするとかなりややこしいことになります。モデルの値の復元は簡単ですが、オプティマイザーの復元をどうやるのか探ってみました。

結論

まずは先に結論から。

環境:TensorFlow2.3.1 GPU

  • save_weightsのsave_formatには“tf”“h5”の2種類を指定できる
  • どちらのフォーマットでもモデルの値は保存される。この復元には特殊な操作はいらずただ、load_weightsすればよい
  • save_formatがtfの場合は、オプティマイザーの値を保存できるが、復元するにはダミーの勾配を適用したあとにload_weightsしないといけない
  • save_formatがh5の場合は、オプティマイザーの値は保存されない

なぜオプティマイザーの値の復元が必要か

そもそも「オプティマイザーの値とはなにか、モデルの値(係数)とは何が違うのか」というと、例えばAdamの移動平均の計算に使うパラメーターを指します。

これは訓練を途中でやめて再開するときに意識する必要が出ます。具体的なシチュエーションとしては、GPUがクラッシュしたときや、Colabの時間制限で再開するときです。

オプティマイザーの値を復元しなくても訓練再開自体は可能ですが、再開時に移動平均などが初期値から計算されるので、勾配の蓄積が進むまでは変な方向に勾配降下法が進みます。訓練再開直後にロスが暴れるのはこういう理由です。学習率のWarmupをし、勾配の蓄積が進むまでの学習率を小さくするという方法でも対処できます。

今回はより直接的に、モデルのチェックポイントからオプティマイザーの係数を復元する方法を見ていきます。訓練終了直前の状態を復元するのが目標です。

訓練サンプルコード

今回訓練のコードはどうでもよくて、適当なモデルを作って、h5とtf形式でsave_weightsします。CIFAR-10を適当に1エポック訓練します。

import tensorflow as tf
import tensorflow.keras.layers as layers

def conv_bn_relu(inputs, chs, reps):
    x = inputs
    for i in range(reps):
        x = layers.Conv2D(chs, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x

def create_model():
    inputs = layers.Input((32, 32, 3))
    x = conv_bn_relu(inputs, 64, 3)
    x = layers.AveragePooling2D(2)(x)
    x = conv_bn_relu(x, 128, 3)
    x = layers.AveragePooling2D(2)(x)
    x = conv_bn_relu(x, 256, 3)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation="softmax")(x)
    return tf.keras.models.Model(inputs, x)

def perprocess(img, label):
    img = tf.cast(img, tf.float32) / 255.0
    label = tf.cast(label, tf.float32)
    return img, label

def train():
    (X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
    trainset = tf.data.Dataset.from_tensor_slices((X_train, y_train)
        ).map(perprocess).shuffle(4096).batch(128).repeat().prefetch(50)

    model = create_model()
    model.compile("adam", "sparse_categorical_crossentropy", ["acc"])

    model.fit(trainset, steps_per_epoch=50000//128, epochs=1)        
    # 'Adam/conv2d/kernel/m:0' shape=(3, 3, 3, 64) 
    print(model.optimizer.weights[1][0, 0, 0,:10])
    # <tf.Variable 'conv2d/kernel:0' shape=(3, 3, 3, 64)
    print(model.weights[0][0, 0, 0,:10])

    model.save_weights("model_tf.ckpt", save_format="tf")  # デフォルト
    model.save_weights("model_h5.h5", save_format="h5")  

「model_tf.ckpt」と「model_h5.h5」という2つのモデルが出力されます。save_modelではtfとh5の2種類の保存形式が選べます。保存形式の切り替えは「save_format」で指定します。デフォルトはtfです。TensorFlowのソースコードによると次のように判定されます。

  • save_formatに何も指定しない場合
    • 保存のファイルパスの拡張子が「.h5」「.keras」「.hdf5」の場合はh5
    • そうでなければtf
  • save_formatを指定した場合
    • save_formatが「h5」「keras」「hdf5」ならh5
    • save_formatが「tensorflow」か「tf」ならtf

保存前のオプティマイザーとモデルの係数は次の通りです。それぞれ同一のカーネルを指しています(見やすいように先頭の10要素を出します)。

// オプティマイザーの値
tf.Tensor(
[-0.00127623  0.00114414  0.00413108  0.01570049  0.00921293 -0.01972224
 -0.00671442 -0.00677514 -0.02500603  0.00259613], shape=(10,), dtype=float32)
// モデルの値
tf.Tensor(
[ 0.03293408  0.0522718   0.0577691   0.01291534 -0.0351864   0.02534967
 -0.0685387  -0.08971363 -0.05289717 -0.0040391 ], shape=(10,), dtype=float32)

オプティマイザーの値とモデルの値は一致する保証はありません。なぜなら、Adamは移動平均取っているからです。一致させたいのはモデルのロードの前後のオプティマイザーの値です。

つぎに、モデルの読み込み時の処理を比較してみます。

save_formatがtfかつ、コンパイルしないで読み込む

モデルの推論で出てくる書き方です。モデルの値は復元され、推論自体はできますが、オプティマイザーの値は復元されません。警告メッセージがざーっと出ます。

def load_tf_wo_compile():
    model = create_model()
    model.load_weights("model_tf.ckpt")
    # 係数は読めるがオプティマイザーの係数が見つからないと怒られる
    print(model.weights[0][0, 0, 0,:10])
    print(model.optimizer)
tf.Tensor(
[ 0.03293408  0.0522718   0.0577691   0.01291534 -0.0351864   0.02534967
 -0.0685387  -0.08971363 -0.05289717 -0.0040391 ], shape=(10,), dtype=float32)
None
WARNING: Logging before flag parsing goes to stderr.
W0206 00:16:41.323693 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer
W0206 00:16:41.326685 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer.iter
W0206 00:16:41.326685 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer.beta_1
W0206 00:16:41.327682 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer.beta_2
W0206 00:16:41.327682 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer.decay
W0206 00:16:41.327682 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0206 00:16:41.328680 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
W0206 00:16:41.328680 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
W0206 00:16:41.329677 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-1.gamma
W0206 00:16:41.329677 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-1.beta
W0206 00:16:41.331672 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-2.kernel
W0206 00:16:41.332669 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-2.bias
W0206 00:16:41.334664 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-3.gamma
W0206 00:16:41.334664 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-3.beta
W0206 00:16:41.336658 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-4.kernel
W0206 00:16:41.342642 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-4.bias
W0206 00:16:41.343640 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-5.gamma
W0206 00:16:41.345635 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-5.beta
W0206 00:16:41.345635 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-6.kernel
W0206 00:16:41.347629 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-6.bias
W0206 00:16:41.348626 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-7.gamma
W0206 00:16:41.348626 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-7.beta
W0206 00:16:41.350621 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-8.kernel
W0206 00:16:41.351618 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-8.bias
W0206 00:16:41.353613 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-9.gamma
W0206 00:16:41.353613 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-9.beta
W0206 00:16:41.358600 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-10.kernel
W0206 00:16:41.359597 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-10.bias
W0206 00:16:41.361592 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-11.gamma
W0206 00:16:41.361592 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-11.beta
W0206 00:16:41.362589 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-12.kernel
W0206 00:16:41.364584 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-12.bias
W0206 00:16:41.364584 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-13.gamma
W0206 00:16:41.365581 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-13.beta
W0206 00:16:41.367576 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-14.kernel
W0206 00:16:41.367576 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-14.bias
W0206 00:16:41.368573 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-15.gamma
W0206 00:16:41.374557 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-15.beta
W0206 00:16:41.375555 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-16.kernel
W0206 00:16:41.377549 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-16.bias
W0206 00:16:41.377549 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-17.gamma
W0206 00:16:41.378547 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-17.beta
W0206 00:16:41.380541 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-18.kernel
W0206 00:16:41.381538 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-18.bias
W0206 00:16:41.383533 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
W0206 00:16:41.384530 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
W0206 00:16:41.385756 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-1.gamma
W0206 00:16:41.385756 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-1.beta
W0206 00:16:41.386756 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-2.kernel
W0206 00:16:41.392184 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-2.bias
W0206 00:16:41.392184 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-3.gamma
W0206 00:16:41.393184 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-3.beta
W0206 00:16:41.395178 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-4.kernel
W0206 00:16:41.395178 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-4.bias
W0206 00:16:41.396175 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-5.gamma
W0206 00:16:41.398170 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-5.beta
W0206 00:16:41.398170 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-6.kernel
W0206 00:16:41.398170 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-6.bias
W0206 00:16:41.401100 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-7.gamma
W0206 00:16:41.401100 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-7.beta
W0206 00:16:41.402100 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-8.kernel
W0206 00:16:41.404133 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-8.bias
W0206 00:16:41.408620 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-9.gamma
W0206 00:16:41.408620 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-9.beta
W0206 00:16:41.410616 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-10.kernel
W0206 00:16:41.410616 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-10.bias
W0206 00:16:41.411611 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-11.gamma
W0206 00:16:41.413606 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-11.beta
W0206 00:16:41.413606 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-12.kernel
W0206 00:16:41.414603 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-12.bias
W0206 00:16:41.416598 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-13.gamma
W0206 00:16:41.416598 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-13.beta
W0206 00:16:41.416598 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-14.kernel
W0206 00:16:41.419591 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-14.bias
W0206 00:16:41.419591 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-15.gamma
W0206 00:16:41.426572 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-15.beta
W0206 00:16:41.426572 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-16.kernel
W0206 00:16:41.427568 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-16.bias
W0206 00:16:41.428591 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-17.gamma
W0206 00:16:41.429563 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-17.beta
W0206 00:16:41.429563 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-18.kernel
W0206 00:16:41.431586 26668 util.py:150] Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-18.bias
W0206 00:16:41.432555 26668 util.py:158] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) 
but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

モデルがコンパイルされていないので、オプティマイザーがモデルに登録されていません。「model.optimizerがNone」になるのはそういう理由です。

「Unresolved object in checkpoint: (root).optimizer’s state 」というのが「オプティマイザーの状態が解決できなくて復元できない」ということにあたります。オプティマイザーが登録されていないから当たり前です。「tf.train.Checkpoint云々」と出ていますが、これはKerasのload_modelなので、元のチェックポイントの処理までできればたどりたくありません。

save_formatがtfかつ、コンパイルして読み込む

「オプティマイザーはコンパイルすれば登録されるんやろ。ならload_weightsする前にコンパイルすればいいやん」と思うかもしれませんが、そう簡単にはいきません。

def load_tf_w_compile():
    model = create_model()
    model.compile("adam", "sparse_categorical_crossentropy", ["acc"])
    model.load_weights("model_tf.ckpt")
    print(model.weights[0][0, 0, 0,:10])
    # コンパイルしても別のオプティマイザーなので値が保持されない
    print(model.optimizer.weights[1][0, 0, 0,:10])
tf.Tensor(
[ 0.03293408  0.0522718   0.0577691   0.01291534 -0.0351864   0.02534967
 -0.0685387  -0.08971363 -0.05289717 -0.0040391 ], shape=(10,), dtype=float32)
tf.Tensor(
[0.00163518 0.00058693 0.00144919 0.00051345 0.00025841 0.00122939
 0.00028696 0.00024856 0.0035538  0.00060467], shape=(10,), dtype=float32)
WARNING: Logging before flag parsing goes to stderr.
W0206 00:27:26.504664  8500 util.py:150] Unresolved object in checkpoint: (root).optimizer.iter
W0206 00:27:26.505199  8500 util.py:150] Unresolved object in checkpoint: (root).optimizer.beta_1
W0206 00:27:26.505199  8500 util.py:150] Unresolved object in checkpoint: (root).optimizer.beta_2
W0206 00:27:26.505199  8500 util.py:150] Unresolved object in checkpoint: (root).optimizer.decay
W0206 00:27:26.506199  8500 util.py:150] Unresolved object in checkpoint: (root).optimizer.learning_rate
W0206 00:27:26.506199  8500 util.py:158] A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) 
but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

エラーの数は先程より減りましたが、保存前のオプティマイザーの値は「-0.00127623 0.00114414 0.00413108 …」だったのでこれは違います(バグなんですかねこれ:TF2.3.1環境)

save_formatがtfかつ、コンパイル+ダミーの勾配を適用して読み込む【成功例】

次の方法でようやくオプティマイザーの値を正常に読み込めました。

def load_tf_w_zero_grad():
    model = create_model()
    model.compile("adam", "sparse_categorical_crossentropy", ["acc"])

    zero_grad = [tf.zeros_like(x) for x in model.weights]
    model.optimizer.apply_gradients(zip(zero_grad, model.weights))

    model.load_weights("model_tf.ckpt")
    # これでようやくオプティマイザーの値も同一になる
    print(model.weights[0][0, 0, 0,:10])
    print(model.optimizer.weights[1][0, 0, 0,:10])
tf.Tensor(
[ 0.03293408  0.0522718   0.0577691   0.01291534 -0.0351864   0.02534967
 -0.0685387  -0.08971363 -0.05289717 -0.0040391 ], shape=(10,), dtype=float32)
tf.Tensor(
[-0.00127623  0.00114414  0.00413108  0.01570049  0.00921293 -0.01972224
 -0.00671442 -0.00677514 -0.02500603  0.00259613], shape=(10,), dtype=float32)

オプティマイザーの値が訓練終了時と同じになりました。追加のエラーメッセージも消えました。

ここでやったことは、「全て0の勾配(ダミーの勾配)を1回オプティマイザーに適用した」ということです。これが入るとオプティマイザー側に係数がセットされるのでしょう。ダミーの勾配を適用してからload_weightsすると全て元通りになりました。

h5形式だとオプティマイザーの値は保存されない

同じことをh5形式でやってもオプティマイザーの値は復元されません

def load_h5_w_zero_grad():
    model = create_model()
    model.compile("adam", "sparse_categorical_crossentropy", ["acc"])

    zero_grad = [tf.zeros_like(x) for x in model.weights]
    model.optimizer.apply_gradients(zip(zero_grad, model.weights))

    model.load_weights("model_h5.h5")
    # h5はオプティマイザーの値は持っていない
    print(model.weights[0][0, 0, 0,:10])
    print(model.optimizer.weights[1][0, 0, 0,:10])
tf.Tensor(
[ 0.03293408  0.0522718   0.0577691   0.01291534 -0.0351864   0.02534967
 -0.0685387  -0.08971363 -0.05289717 -0.0040391 ], shape=(10,), dtype=float32)
tf.Tensor([0. 0. 0. 0. 0. 0. 0. 0. 0. 0.], shape=(10,), dtype=float32) // 全て0になる

h5のケースではオプティマイザーの値が全て0になってしまいました。

それはそうで、h5とtf形式ではファイルサイズが異なります。h5では7609KBだったのが、tfだと22591KBでした。保存している情報がtfのほうが多く、オプティマイザーの値はh5では持っていないが、tfだと持っているということを表します。

これはh5の保存部分のソースからもわかります。h5の保存では、

    if save_format == 'h5':
      with h5py.File(filepath, 'w') as f:
        hdf5_format.save_weights_to_hdf5_group(f, self.layers)

となっており、モデルのレイヤーは保存対象ですが、付随するオプティマイザーは保存対象ではありません。

tfとh5両方保存すれば良くない?

なら「h5いらない子じゃん」と思うかもしれませんが、h5の場合だとload_weightsしたときにコンパイルしなくてもエラーは出ません。h5は推論特化でしょうね。

他にも例えばTPUで保存したモデルをCPUで読み込むときに、tf形式だと失敗してもh5形式だと復元できたりするので、容量が許すのなら、念の為両方の形式で保存しておくのが良いのではないでしょうか。わざわざ訓練したのに、推論時に読み込めないことがわかったときの悲しさは果てしない。

とりあえず「tf形式にして復元時にダミーの勾配を適用してからload_weightsしないと、オプティマイザーの値は完全には復元されませんでした」という報告を書いて終わりにしたいと思います。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です