こしあん
2019-01-20

条件に応じた配列の要素の抽出をTensorFlowで行う

Pocket
LINEで送る


Numpyで条件を与えて、インデックスのスライスによって配列の要素を抽出する、というようなケースはよくあります。これをTensorFlowのテンソルでやるのにはどうすればいいのでしょうか?それを見ていきます。

Numpyではこんな例

例えば、5×5のランダムな行列をデータとします。この配列を左上から右下に0~24の通し番号を振り、この通し番号が偶数の要素のみ抽出したいとします。

これをNumpyで書くと次のようになります。

import numpy as np

def numpy_select():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    print(X)
    #[[ 5 21 22 18 15]
    # [ 8  7 11  4  3]
    # [24 12 16  9 14]
    # [20  0  1 10 19]
    # [17  6 23  2 13]]
    ind = np.arange(25).reshape(5,5)
    flag = ind % 2 == 0
    print(X[flag])
    # [ 5 22 15  7  4 24 16 14  0 10 17 23 13]

ちゃんと偶数番目の要素のみ抽出されているのがわかります。

TensorFlowで書く場合はちょっとめんどい

Numpyの場合はインデックスに与えて終わり、でしたが、TensorFlowで書く場合はちょっと面倒になります。Numpyと同じようにやってしまうと、

import tensorflow as tf
import keras.backend as K

def tf_select_bad():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    print(tX[tflag]) # エラー
    # TypeError: unsupported operand type(s) for +: 'Tensor' and 'int'

このようにエラーになってしまいます。tf.whereで条件のインデックス配列を選択、tf.gather_ndでそのインデックスに応じて抽出とちょっと回りくどいことをしなくてはいけません。次の例は正しく動作します。

import tensorflow as tf
import keras.backend as K

def tf_select_good():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)
    print(K.eval(tX))
    #[[ 5. 21. 22. 18. 15.]
    # [ 8.  7. 11.  4.  3.]
    # [24. 12. 16.  9. 14.]
    # [20.  0.  1. 10. 19.]
    # [17.  6. 23.  2. 13.]]

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    select_indices = tf.where(tflag)
    gather = tf.gather_nd(tX, select_indices)
    print(K.eval(gather))
    # [ 5. 22. 15.  7.  4. 24. 16. 14.  0. 10. 17. 23. 13.]

うまくいきました。公式ドキュメントは以下のとおりです。

tf.gatherとtf.gather_ndは別物なので注意してくださいね。

Related Posts

Kerasで転移学習用にレイヤー名とそのインデックスを調べる方法... Kerasで転移学習をするときに、学習済みモデルのレイヤーの名前と、そのインデックス(何番目にあるかということ)の対応を知りたいことがあります。その方法を解説します。 転移学習とは 転移学習とは、ImageNetなど何百万もの大量の画像で事前学習させたモデルを使い、それを「特徴量検出器」として...
OpenCVで画像を歪ませる方法 PythonでOpenCVを使い画像を歪ませる方法を考えます。アフィン変換というちょっと直感的に理解しにくいことをしますが、慣れればそこまで難しくはありません。ディープラーニングのData Augmentationにも使えます。 OpenCVでのアフィン変換のイメージ アフィン変換というと、ま...
ディープラーニング=最小二乗法のどこがダメなのか解説する... あるニュース記事で、ディープラーニング=最小二乗法で三次関数なんていう「伝説の画像」が出回っていたので、それに対して突っ込みつつ、非線形関数という立場からディープラーニングの本当の表現の豊かさを見ていきたいと思います。 きっかけ ある画像が出回っていた。日経新聞の解説らしい。 伝説の画像にな...
WarmupとData Augmentationのバッチサイズ別の精度低下について... 大きいバッチサイズで訓練する際は、バッチサイズの増加にともなう精度低下が深刻になります。この精度低下を抑制することはできるのですが、例えばData Augmentationのようなデータ増強・正則化による精度向上とは何が違うのでしょうか。それを調べてみました。 きっかけ この記事を書いたときに...
Numpyの配列に対して「最も多く存在する値」を求める方法... アンサンブル学習などで、Numpyの配列のある軸に対して「最も多く存在する値」を求めたい、つまり「多数決」をしたいことがあります。その方法を見ていきます。 最も大きい値がmax, 最も大きい値が存在するインデックスがargmax, では「最も多く存在する値」は? 配列のある軸に対して、「最も大...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です