こしあん
2019-07-24

tf.tensordotで行列積を表現するための設定


Pocket
LINEで送る
Delicious にシェア

4k{icon} {views}


TensorFlowのtensordotという関数はとても強力で、テンソルに対する行列積に対する計算をだいたい表現できます。しかし、軸の設定がいまいちよくわからなかったので、確かめてみました。

2×2行列同士の積の場合(Numpy)

まず単純に2×2行列同士の(ドット)積を考えます。まずはNumpyから。

import numpy as np

a = np.arange(4).reshape(2, 2) + 3
b = np.arange(4).reshape(2, 2)
print("a", a)
print("b", b)    
print("np.dot(a,b)", np.dot(a, b))

答えは次のようになります。

a [[3 4]
 [5 6]]
b [[0 1]
 [2 3]]
np.dot(a,b) [[ 8 15]
 [12 23]]

これは行列の積の定義にならって計算したものです。次のような計算が行われています。

  • 8 = 3×0 + 4×2
  • 15 = 3×1 + 4×3
  • 12 = 5×0 + 6×2
  • 23 = 5×1 + 6×3

2×2行列の積の場合(tensordot)

同じ計算をtf.tensordotで表現すると次のようになります。

import tensorflow as tf
import keras.backend as K

a = np.arange(4).reshape(2, 2) + 3
b = np.arange(4).reshape(2, 2)
ta, tb = K.variable(a), K.variable(b)
result = tf.tensordot(ta, tb, axes=[1, 0])
print(K.eval(result))
[[ 8. 15.]
 [12. 23.]]

ポイントは、tensordotのaxesのです。それぞれ、掛ける(reduceする)軸と考えます。ta側がaxes=1, tb側がaxes=0で掛けるということですね。行列の積は横×縦で計算するので、行列の横は1番目の軸、縦は0番目の軸であるということを考えるとしっくりきます。

もっと複雑なテンソルの積の場合

より一般的に、「複雑なテンソル×行列」を計算したい場合です。例えば、ある軸に対してのみ行列積を適用して、それを他の軸に対しても直感的にはforループみたいな感じでやってほしいという例です。テンソル×テンソルよりかはあると思います。

Numpy

まずはNumpyで試してみます。掛けられる側を4階テンソル、掛ける側を行列(2階テンソル)とします。

a = np.arange(4).reshape(2, 2) + 3
a = np.broadcast_to(a, (4, 3, 2, 2))
b = np.arange(4).reshape(2, 2)    
print(np.dot(a, b))

先程の例の2×2行列を4×3回コピーしたものと同じです。np.dotでの計算結果は先程の2×2行列の積をひたすらコピーしたものとなります。

[[[[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]]


 [[[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]]


 [[[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]]


 [[[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]

  [[ 8 15]
   [12 23]]]]

tensordotの場合

こちらは先程の2×2行列のと同じように考えればよいです。4階テンソル側の積を取る軸は1番目ではなく、3番目(どっちにしても最後の軸)となるので、axesは「3と0」になります。

    a = np.arange(4).reshape(2, 2) + 3
    a = np.broadcast_to(a, (4, 3, 2, 2))
    b = np.arange(4).reshape(2, 2)    
    ta, tb = K.variable(a), K.variable(b)
    result = tf.tensordot(ta, tb, axes=[3, 0])

出力は先程のNumpyの場合と一緒です。

まとめ

2×2行列の積ができてしまえば、同じ発想でテンソル同士の積もいけそう。2×2行列の積でのaxesの指定の仕方は、左が1で右が0。これは行列積の定義を振りかえれば明らかということでした。

Pocket
LINEで送る
Delicious にシェア



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です