こしあん
2019-10-17

TensorFlow2.0のGradientTapeを複数使う場合のサンプル


4.5k{icon} {views}


TF2.0系で少し複雑なモデルを訓練するときに、GradientTapeを複数使うことがあります。例として、微分を取りたい場所が2箇所あるケースや、2階微分を取りたいケースを挙げます。場合によって微妙に書き方が違うので注意が必要です。

微分を取りたい場所が2箇所あるケース

簡単な例にします。次の関数の微分を考えます。

$$y=\log x^2 $$

合成関数の微分の公式の要領で、$y$の$x$微分と、$x^2$の$x$微分を計算してみます。

import tensorflow as tf
import numpy as np

def get_derivative(inputs):
    x = tf.Variable(np.array([inputs, inputs], np.float32)) # gadientを取るためにVariableとする
    with tf.GradientTape() as tape1, tf.GradientTape() as tape2:
        y1 = x ** 2
        y2 = tf.math.log(y1)
    print("dy1/dx =", tape1.gradient(y1, x))
    print("dy2/dx =", tape2.gradient(y2, x))

def main():
    for i in range(1, 4):
        print("i = ", i)
        get_derivative(i)

if __name__ == "__main__":
    main()

xの値を1~3と変化させます。ちなみに数学的に計算すると、

$$\frac{dy}{dx}=\frac{2}{x}, \frac{d}{dx}x^2 = 2x $$

となります。結果は数学的に計算した通りになりました。

i =  1
dy1/dx = tf.Tensor([2. 2.], shape=(2,), dtype=float32)
dy2/dx = tf.Tensor([2. 2.], shape=(2,), dtype=float32)
i =  2
dy1/dx = tf.Tensor([4. 4.], shape=(2,), dtype=float32)
dy2/dx = tf.Tensor([1. 1.], shape=(2,), dtype=float32)
i =  3
dy1/dx = tf.Tensor([6. 6.], shape=(2,), dtype=float32)
dy2/dx = tf.Tensor([0.6666667 0.6666667], shape=(2,), dtype=float32)

ポイントはここです。

    with tf.GradientTape() as tape1, tf.GradientTape() as tape2:

このように1つのwith句の中に2つのGradientTapeをおくと、2箇所の微分を取ることができます。ちなみに、tape1をtape2にしたり、tape2をtape1にしたりと、同一のGradientTapeに対して2回以上tape.gradientを呼び出すと次のようなエラーになります。

RuntimeError: GradientTape.gradient can only be called once on non-persistent tapes.

つまりこの記事でやっているのは、「GradientTape1つに対して2回以上偏微分計算できない」という縛りを回避するために、GradientTapeを2つ定義しているということです。

ちなみに、この例は次のようにネストさせて書いてもOKです。

def get_derivative(inputs):
    x = tf.Variable(np.array([inputs, inputs], np.float32)) # gadientを取るためにVariableとする
    with tf.GradientTape() as tape1:
        with tf.GradientTape() as tape2:
            y1 = x ** 2
        y2 = tf.math.log(y1)
    print("dy1/dx =", tape2.gradient(y1, x))
    print("dy2/dx =", tape1.gradient(y2, x))

2階以上の偏微分を取る

ディープラーニングで使うケースとしては少ないと思いますが、GradientTapeをネストさせることで2階以上の偏微分が計算できます。先程の2箇所の微分計算とは少し異なる書き方をします。

import tensorflow as tf
import numpy as np

def get_higher_derivative(inputs):
    x = tf.Variable(np.array([inputs, inputs], np.float32))
    with tf.GradientTape() as tape1:
        with tf.GradientTape() as tape2:
            y = tf.math.log(x ** 2)
        dy_dx = tape2.gradient(y, x)
    d2y_dx2 = tape1.gradient(dy_dx, x)

    print("dy/dx =", dy_dx)
    print("d2y/dx2 =", d2y_dx2)

def main():
    for i in range(1, 4):
        print("i = ", i)
        get_higher_derivative(i)

if __name__ == "__main__":
    main()

ちなみに、

$$y=\log x^2 $$

この関数の2階微分までを数学的に解くと、

$$\frac{dy}{dx}=\frac{2}{x}, \frac{d^2y}{dx^2}=-\frac{2}{x^2}$$

となります。これとあっているか確認してみます。

i =  1
dy/dx = tf.Tensor([2. 2.], shape=(2,), dtype=float32)
d2y/dx2 = tf.Tensor([-2. -2.], shape=(2,), dtype=float32)
i =  2
dy/dx = tf.Tensor([1. 1.], shape=(2,), dtype=float32)
d2y/dx2 = tf.Tensor([-0.5 -0.5], shape=(2,), dtype=float32)
i =  3
dy/dx = tf.Tensor([0.6666667 0.6666667], shape=(2,), dtype=float32)
d2y/dx2 = tf.Tensor([-0.22222222 -0.22222222], shape=(2,), dtype=float32)

確かにあっているようです。GradientTapeの部分を再掲します。

    with tf.GradientTape() as tape1:
        with tf.GradientTape() as tape2:
            y = tf.math.log(x ** 2)
        dy_dx = tape2.gradient(y, x)
    d2y_dx2 = tape1.gradient(dy_dx, x)

最初に複数のTapeをネストさせて宣言するのがポイントです。yの計算はtape1, tape2両方記録しています。tape2側がgradientを計算した部分は、tape2のインデントの外側なので、この偏微分を計算するというグラフはtape1だけ記録されます。したがって、tape1でもう一回gradientを呼べば2階微分が計算できる……という仕組み。

ちなみに最後の2行のtape2とtape1を入れ替えると、

i =  1
dy/dx = tf.Tensor([2. 2.], shape=(2,), dtype=float32)
d2y/dx2 = None
i =  2
dy/dx = tf.Tensor([1. 1.], shape=(2,), dtype=float32)
d2y/dx2 = None
i =  3
dy/dx = tf.Tensor([0.6666667 0.6666667], shape=(2,), dtype=float32)
d2y/dx2 = None

となり、2階部分が正常に計算できません。これはスコープの外側になっているためです。

まとめ

少し複雑なモデルではGradientTapeを複数個作ることがありますが、ケースによって適切な書き方が微妙に違うので確かめながらやりましょうということでした。あくまで想像ですが、インデント分けて書いたほうが無駄な微分計算をしなくて厳密……なのかもしれません。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です