Testing alternating training of multiple models with TensorFlow 2.0 + TPU
Posted On 2019-10-17
This is aimed at GANs. Training multiple models simultaneously or alternately used to be difficult on TPUs, so let's check whether that has changed with TF 2.0.
Environment: TensorFlow 2.0.0, Colab TPU
Code
Setup boilerplate
import tensorflow as tf
import os

# TPU setup
# Details: https://www.tensorflow.org/guide/distributed_training#tpustrategy
tpu_grpc_url = "grpc://" + os.environ["COLAB_TPU_ADDR"]
tpu_cluster_resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu_grpc_url)
tf.config.experimental_connect_to_cluster(tpu_cluster_resolver)  # New in TF 2.0
tf.tpu.experimental.initialize_tpu_system(tpu_cluster_resolver)  # In TF 2.0; the "experimental" prefix may be dropped in a future release
strategy = tf.distribute.experimental.TPUStrategy(tpu_cluster_resolver)  # Likewise still experimental
Instead of using the Keras training API, we write a custom training loop.
import tensorflow as tf
import tensorflow.keras.layers as layers

def model_a():
    inputs = layers.Input((1,))
    x = layers.Dense(2, use_bias=False)(inputs)
    return tf.keras.models.Model(inputs, x)

def model_b():
    inputs = layers.Input((2,))
    x = layers.Dense(1, use_bias=False)(inputs)
    return tf.keras.models.Model(inputs, x)

def main():
    with strategy.scope():
        ma = model_a()
        mb = model_b()
        # Reduction.NONE: in a custom loop under a distribution strategy,
        # the loss reduction must be handled explicitly
        loss_func = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)
        optima = tf.keras.optimizers.Adam()
        optimb = tf.keras.optimizers.Adam()

        print("Before update")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

        def train_a():
            # Update only model A's weights
            with tf.GradientTape() as tape:
                x = tf.random.normal(shape=(2, 1))
                y_pred = mb(ma(x))
                y = tf.ones(shape=(2, 1))
                loss = loss_func(y, y_pred)
            grad = tape.gradient(loss, ma.trainable_weights)
            optima.apply_gradients(zip(grad, ma.trainable_weights))

        def train_b():
            # Update only model B's weights
            with tf.GradientTape() as tape:
                x = tf.random.normal(shape=(2, 1))
                y_pred = mb(ma(x))
                y = tf.ones(shape=(2, 1))
                loss = loss_func(y, y_pred)
            grad = tape.gradient(loss, mb.trainable_weights)
            optimb.apply_gradients(zip(grad, mb.trainable_weights))

        @tf.function
        def distributed_train_a():
            return strategy.experimental_run_v2(train_a)

        @tf.function
        def distributed_train_b():
            return strategy.experimental_run_v2(train_b)

        distributed_train_a()
        print("Update only A")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

        distributed_train_b()
        print("Update only B")
        print(ma.trainable_weights)
        print(mb.trainable_weights)

if __name__ == "__main__":
    main()
We compute an output of the form ModelB(ModelA(input)), feed it into a simple loss function, and check whether "train only A → train only B" works as intended.
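The check itself is independent of the TPU. As a minimal, framework-free sketch of the same idea (all names here are illustrative, not part of the TensorFlow code above), two chained linear models can be updated one at a time with hand-written SGD, asserting that the untouched model's weights stay fixed:

```python
import numpy as np

# Two chained linear models, y_pred = B(A(x)), MSE loss against y = 1,
# updating only one model per step -- a NumPy analogue of the test above.
rng = np.random.default_rng(0)
Wa = rng.normal(size=(1, 2))   # model A: 1 -> 2, no bias
Wb = rng.normal(size=(2, 1))   # model B: 2 -> 1, no bias
lr = 0.01

def step(update_a):
    """One SGD step on MSE(1, B(A(x))), touching only A or only B."""
    global Wa, Wb
    x = rng.normal(size=(2, 1))
    h = x @ Wa                         # A(x), shape (2, 2)
    y_pred = h @ Wb                    # B(A(x)), shape (2, 1)
    g_out = 2.0 * (y_pred - 1.0) / y_pred.size   # dL/dy_pred for mean-squared error
    if update_a:
        Wa = Wa - lr * (x.T @ (g_out @ Wb.T))    # chain rule back through B
    else:
        Wb = Wb - lr * (h.T @ g_out)

Wa0, Wb0 = Wa.copy(), Wb.copy()
step(update_a=True)
assert not np.allclose(Wa, Wa0) and np.allclose(Wb, Wb0)  # only A moved
Wa1, Wb1 = Wa.copy(), Wb.copy()
step(update_a=False)
assert np.allclose(Wa, Wa1) and not np.allclose(Wb, Wb1)  # only B moved
```

The question the TPU experiment answers is whether `tf.GradientTape` plus `strategy.experimental_run_v2` reproduces this selective-update behavior across replicas.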
Results
Before update
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>
}]
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>
}]
Update only A
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>
}]
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>
}]
Update only B
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense/kernel/replica_1:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense/kernel/replica_2:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense/kernel/replica_3:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense/kernel/replica_4:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense/kernel/replica_5:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense/kernel/replica_6:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense/kernel/replica_7:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>
}]
[TPUMirroredVariable:{
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
1 /job:worker/replica:0/task:0/device:TPU:1: <tf.Variable 'dense_1/kernel/replica_1:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
2 /job:worker/replica:0/task:0/device:TPU:2: <tf.Variable 'dense_1/kernel/replica_2:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
3 /job:worker/replica:0/task:0/device:TPU:3: <tf.Variable 'dense_1/kernel/replica_3:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
4 /job:worker/replica:0/task:0/device:TPU:4: <tf.Variable 'dense_1/kernel/replica_4:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
5 /job:worker/replica:0/task:0/device:TPU:5: <tf.Variable 'dense_1/kernel/replica_5:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
6 /job:worker/replica:0/task:0/device:TPU:6: <tf.Variable 'dense_1/kernel/replica_6:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
7 /job:worker/replica:0/task:0/device:TPU:7: <tf.Variable 'dense_1/kernel/replica_7:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>
}]
Because of distributed training, the output lists all 8 TPU devices, which makes it a bit hard to read, so let's focus on device 0 only.
Before update
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.31612158, -0.6931683 ]], dtype=float32)>,
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
Update only A
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.2208728],
[ 1.3821994]], dtype=float32)>,
Update only B
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense/kernel:0' shape=(1, 2) dtype=float32, numpy=array([[ 0.3151208, -0.6921675]], dtype=float32)>,
0 /job:worker/replica:0/task:0/device:TPU:0: <tf.Variable 'dense_1/kernel:0' shape=(2, 1) dtype=float32, numpy=
array([[-1.219872 ],
[ 1.3811986]], dtype=float32)>,
After "Update only A", only model A's weights have changed; model B's weights are untouched.
After "Update only B", only model B's weights have changed; model A's weights are untouched.
→ It works.
Given this result, it looks like GANs can be implemented on TPUs without trouble.