tf.tensordotで行列積を表現するための設定
TensorFlowのtensordotという関数はとても強力で、テンソルに対する行列積に対する計算をだいたい表現できます。しかし、軸の設定がいまいちよくわからなかったので、確かめてみました。
目次
2×2行列同士の積の場合(Numpy)
まず単純に2×2行列同士の(ドット)積を考えます。まずはNumpyから。
import numpy as np
a = np.arange(4).reshape(2, 2) + 3
b = np.arange(4).reshape(2, 2)
print("a", a)
print("b", b)
print("np.dot(a,b)", np.dot(a, b))
答えは次のようになります。
a [[3 4]
[5 6]]
b [[0 1]
[2 3]]
np.dot(a,b) [[ 8 15]
[12 23]]
これは行列の積の定義にならって計算したものです。次のような計算が行われています。
- 8 = 3×0 + 4×2
- 15 = 3×1 + 4×3
- 12 = 5×0 + 6×2
- 23 = 5×1 + 6×3
2×2行列の積の場合(tensordot)
同じ計算をtf.tensordotで表現すると次のようになります。
import tensorflow as tf
import keras.backend as K
a = np.arange(4).reshape(2, 2) + 3
b = np.arange(4).reshape(2, 2)
ta, tb = K.variable(a), K.variable(b)
result = tf.tensordot(ta, tb, axes=[1, 0])
print(K.eval(result))
[[ 8. 15.]
[12. 23.]]
ポイントは、tensordotのaxesのです。それぞれ、掛ける(reduceする)軸と考えます。ta側がaxes=1, tb側がaxes=0で掛けるということですね。行列の積は横×縦で計算するので、行列の横は1番目の軸、縦は0番目の軸であるということを考えるとしっくりきます。
もっと複雑なテンソルの積の場合
より一般的に、「複雑なテンソル×行列」を計算したい場合です。例えば、ある軸に対してのみ行列積を適用して、それを他の軸に対しても直感的にはforループみたいな感じでやってほしいという例です。テンソル×テンソルよりかはあると思います。
Numpy
まずはNumpyで試してみます。掛けられる側を4階テンソル、掛ける側を行列(2階テンソル)とします。
a = np.arange(4).reshape(2, 2) + 3
a = np.broadcast_to(a, (4, 3, 2, 2))
b = np.arange(4).reshape(2, 2)
print(np.dot(a, b))
先程の例の2×2行列を4×3回コピーしたものと同じです。np.dotでの計算結果は先程の2×2行列の積をひたすらコピーしたものとなります。
[[[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]]
[[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]]
[[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]]
[[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]
[[ 8 15]
[12 23]]]]
tensordotの場合
こちらは先程の2×2行列のと同じように考えればよいです。4階テンソル側の積を取る軸は1番目ではなく、3番目(どっちにしても最後の軸)となるので、axesは「3と0」になります。
a = np.arange(4).reshape(2, 2) + 3
a = np.broadcast_to(a, (4, 3, 2, 2))
b = np.arange(4).reshape(2, 2)
ta, tb = K.variable(a), K.variable(b)
result = tf.tensordot(ta, tb, axes=[3, 0])
出力は先程のNumpyの場合と一緒です。
まとめ
2×2行列の積ができてしまえば、同じ発想でテンソル同士の積もいけそう。2×2行列の積でのaxesの指定の仕方は、左が1で右が0。これは行列積の定義を振りかえれば明らかということでした。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー