こしあん
2019-01-20

条件に応じた配列の要素の抽出をTensorFlowで行う

Pocket
LINEで送る
Delicious にシェア

2.5k{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



Numpyで条件を与えて、インデックスのスライスによって配列の要素を抽出する、というようなケースはよくあります。これをTensorFlowのテンソルでやるのにはどうすればいいのでしょうか?それを見ていきます。

Numpyではこんな例

例えば、5×5のランダムな行列をデータとします。この配列を左上から右下に0~24の通し番号を振り、この通し番号が偶数の要素のみ抽出したいとします。

これをNumpyで書くと次のようになります。

import numpy as np

def numpy_select():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    print(X)
    #[[ 5 21 22 18 15]
    # [ 8  7 11  4  3]
    # [24 12 16  9 14]
    # [20  0  1 10 19]
    # [17  6 23  2 13]]
    ind = np.arange(25).reshape(5,5)
    flag = ind % 2 == 0
    print(X[flag])
    # [ 5 22 15  7  4 24 16 14  0 10 17 23 13]

ちゃんと偶数番目の要素のみ抽出されているのがわかります。

TensorFlowで書く場合はちょっとめんどい

Numpyの場合はインデックスに与えて終わり、でしたが、TensorFlowで書く場合はちょっと面倒になります。Numpyと同じようにやってしまうと、

import tensorflow as tf
import keras.backend as K

def tf_select_bad():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    print(tX[tflag]) # エラー
    # TypeError: unsupported operand type(s) for +: 'Tensor' and 'int'

このようにエラーになってしまいます。tf.whereで条件のインデックス配列を選択、tf.gather_ndでそのインデックスに応じて抽出とちょっと回りくどいことをしなくてはいけません。次の例は正しく動作します。

import tensorflow as tf
import keras.backend as K

def tf_select_good():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)
    print(K.eval(tX))
    #[[ 5. 21. 22. 18. 15.]
    # [ 8.  7. 11.  4.  3.]
    # [24. 12. 16.  9. 14.]
    # [20.  0.  1. 10. 19.]
    # [17.  6. 23.  2. 13.]]

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    select_indices = tf.where(tflag)
    gather = tf.gather_nd(tX, select_indices)
    print(K.eval(gather))
    # [ 5. 22. 15.  7.  4. 24. 16. 14.  0. 10. 17. 23. 13.]

うまくいきました。公式ドキュメントは以下のとおりです。

tf.gatherとtf.gather_ndは別物なので注意してくださいね。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です