こしあん
2019-01-20

条件に応じた配列の要素の抽出をTensorFlowで行う


Numpyで条件を与えて、インデックスのスライスによって配列の要素を抽出する、というようなケースはよくあります。これをTensorFlowのテンソルでやるのにはどうすればいいのでしょうか?それを見ていきます。

Numpyではこんな例

例えば、5×5のランダムな行列をデータとします。この配列を左上から右下に0~24の通し番号を振り、この通し番号が偶数の要素のみ抽出したいとします。

これをNumpyで書くと次のようになります。

import numpy as np

def numpy_select():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    print(X)
    #[[ 5 21 22 18 15]
    # [ 8  7 11  4  3]
    # [24 12 16  9 14]
    # [20  0  1 10 19]
    # [17  6 23  2 13]]
    ind = np.arange(25).reshape(5,5)
    flag = ind % 2 == 0
    print(X[flag])
    # [ 5 22 15  7  4 24 16 14  0 10 17 23 13]

ちゃんと偶数番目の要素のみ抽出されているのがわかります。

TensorFlowで書く場合はちょっとめんどい

Numpyの場合はインデックスに与えて終わり、でしたが、TensorFlowで書く場合はちょっと面倒になります。Numpyと同じようにやってしまうと、

import tensorflow as tf
import keras.backend as K

def tf_select_bad():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    print(tX[tflag]) # エラー
    # TypeError: unsupported operand type(s) for +: 'Tensor' and 'int'

このようにエラーになってしまいます。tf.whereで条件のインデックス配列を選択、tf.gather_ndでそのインデックスに応じて抽出とちょっと回りくどいことをしなくてはいけません。次の例は正しく動作します。

import tensorflow as tf
import keras.backend as K

def tf_select_good():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)
    print(K.eval(tX))
    #[[ 5. 21. 22. 18. 15.]
    # [ 8.  7. 11.  4.  3.]
    # [24. 12. 16.  9. 14.]
    # [20.  0.  1. 10. 19.]
    # [17.  6. 23.  2. 13.]]

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    select_indices = tf.where(tflag)
    gather = tf.gather_nd(tX, select_indices)
    print(K.eval(gather))
    # [ 5. 22. 15.  7.  4. 24. 16. 14.  0. 10. 17. 23. 13.]

うまくいきました。公式ドキュメントは以下のとおりです。

tf.gatherとtf.gather_ndは別物なので注意してくださいね。

Related Posts

Kerasで重みを共有しつつ、必要に応じて入力の位置を変える方法... Kerasで訓練させて、途中から新しく入力を作ってそこからの出力までの値を取りたいということがたまにあります。例えば、Variational Auto Encoderのサンプリングなんかそうです。このあまり書かれていないのでざっとですが整理しておきます。 こういうことをやりたい 言葉で書いても...
Kerasのジェネレーターでサンプルが列挙される順番について... Kerasの(カスタム)ジェネレーターでサンプルがどの順番で呼び出されるか、1ループ終わったあとにどういう処理がなされるのか調べてみました。ジェネレーターを自分で定義するとモデルの表現の幅は広がるものの、バグが起きやすくなるので「本当に順番が保証されるのか」や「ハマりどころ」を確認します。 0~...
TPUで学習率減衰させる方法 TPUで学習率減衰したいが、TensorFlowのオプティマイザーを使うべきか、tf.kerasのオプティマイザーを使うべきか、あるいはKerasのオプティマイザーを使うべきか非常にややこしいことがあります。TPUで学習率を減衰させる方法を再現しました。 結論から TPU環境でtf.keras...
Numpyの配列に対して「最も多く存在する値」を求める方法... アンサンブル学習などで、Numpyの配列のある軸に対して「最も多く存在する値」を求めたい、つまり「多数決」をしたいことがあります。その方法を見ていきます。 最も大きい値がmax, 最も大きい値が存在するインデックスがargmax, では「最も多く存在する値」は? 配列のある軸に対して、「最も大...
KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です