こしあん
2019-10-17

TensorFlow2.0+TPUで複数のモデルを交互に訓練するテスト

Pocket
LINEで送る


GANの利用を想定します。以前TPUだと複数のモデルを同時or交互に訓練するというのは厳しかったのですが、これがTF2.0で変わったのか確かめます。

環境:TensorFlow2.0.0、Colab TPU

コード

おまじない

import tensorflow as tf
import os
# tpu用
# 詳細 https://www.tensorflow.org/guide/distributed_training#tpustrategy
tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver) # TF2.0の場合、ここを追加
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver) # TF2.0の場合、今後experimentialが取れる可能性がある    
strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)  # ここも同様

KerasのAPIを使うのではなくカスタムの訓練ループを書きます。

import tensorflow as tf
import tensorflow.keras.layers as layers

def model_a():
    inputs = layers.Input((1,))
    x = layers.Dense(2, use_bias=False)(inputs)
    return tf.keras.models.Model(inputs, x)

def model_b():
    inputs = layers.Input((2,))
    x = layers.Dense(1, use_bias=False)(inputs)
    return tf.keras.models.Model(inputs, x)

def main():
    with strategy.scope():
        ma = model_a()
        mb = model_b()
        loss_func = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        optima = tf.keras.optimizers.Adam()
        optimb = tf.keras.optimizers.Adam()

        print("Before update")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

        def train_a():
            # update A
            with tf.GradientTape() as tape:
                x = tf.random.normal(shape=(2,1))
                y_pred = mb(ma(x))
                y = tf.ones(shape=(2, 1))
                loss = loss_func(y, y_pred)
            grad = tape.gradient(loss, ma.trainable_weights)
            optima.apply_gradients(zip(grad, ma.trainable_weights))

        def train_b():
            # update B
            with tf.GradientTape() as tape:
                x = tf.random.normal(shape=(2,1))
                y_pred = mb(ma(x))
                y = tf.ones(shape=(2, 1))
                loss = loss_func(y, y_pred)
            grad = tape.gradient(loss, mb.trainable_weights)
            optimb.apply_gradients(zip(grad, mb.trainable_weights))

        @tf.function
        def distributed_train_a():
            return strategy.experimental_run_v2(train_a)

        @tf.function
        def distributed_train_b():
            return strategy.experimental_run_v2(train_b)

        distributed_train_a()

        print("Update only A ")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

        distributed_train_b()

        print("Update only B")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

if __name__ == "__main__":
    main()

「ModelB(ModelA(input))」のような出力をして、あとは適当に損失関数を入れ、「Aのみ訓練→Bのみ訓練」がうまく行っているか確かめます。

結果

Before update
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>
}]
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>
}]
Update only A 
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>
}]
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>
}]
Update only B
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>
}]
[TPUMirroredVariable:{
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,
  7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>
}

TPUの8個分のデバイスの出力が出てちょっとわかりづらいです(Distribute Trainingのため)が、0番目のデバイスだけ着目してみましょう。

Before update
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,

Update only A 
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
       [ 1.3821994]], dtype=float32)>,

Update only B
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
  0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
       [ 1.3811986]], dtype=float32)>,

Update only A では、モデルAの係数のみ変わっており、モデルBの係数は変わっていません。

Update only B では、モデルBの係数のみ変わっており、モデルAの係数は変わっていません。

うまくいってる

この結果を踏まえると、TPUでもいい感じにGANは実装できそうです。

Related Posts

TensorFlowの関数で画像にモザイクを書ける方法... TensorFlow2.0の関数を使い、画像にモザイクをかける方法を紹介します。OpenCVやPILでの書き方はいろいろありますが、TensorFlowでどう書くかはまず出てきませんでした。GPUやTPUでのブーストも使えます。 モザイク付与のアルゴリズム いろいろあるとは思いますが、自分が使...
PyTorchで双方向連結リストなデータ構造のモデルを作る... ディープラーニングのモデルには、訓練の途中でレイヤーを追加するなど特殊な訓練をするものがあります(Progressive-GANなど)。そのとき、モデルを「レイヤーやブロックの連結リスト」として定義しておくと見通しがよくなることがあります。その例を見ていきます。 訓練中に継ぎ足していくモデル ...
Spectral Normalization(SNGAN)を実装していろいろ遊んでみた... GANの安定化の大きなブレイクスルーである「Spectral Normalization」をPyTorchで実装していろいろ遊んでみました。従来のGANよりも多クラスの出力がかなりやりやすくなりました。確かにGANの安定化についてはものすごい効いているので、ぜひ皆さんも遊んでみてください。 ※ア...
Kerasで損失関数に複数の変数を渡す方法... Kerasで少し複雑なモデルを訓練させるときに、損失関数にy_true, y_pred以外の値を渡したいときがあります。クラスのインスタンス変数などでキャッシュさせることなく、ダイレクトに損失関数に複数の値を渡す方法を紹介します。 元ネタ:Passing additional arguments...
SA-GANの実装から見る画像のSelf attention 自然言語処理でよく使われるSelf-attentionは画像処理においてもたびたび使われることがあります。自然言語処理のは出てきても、画像のはあまり情報が出てこなかったので、SAGANの実装から画像におけるSelf attentionを見ていきます。 SA-GAN Self attention...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です