こしあん
2019-01-20

条件に応じた配列の要素の抽出をTensorFlowで行う

Pocket
LINEで送る


Numpyで条件を与えて、インデックスのスライスによって配列の要素を抽出する、というようなケースはよくあります。これをTensorFlowのテンソルでやるのにはどうすればいいのでしょうか?それを見ていきます。

Numpyではこんな例

例えば、5×5のランダムな行列をデータとします。この配列を左上から右下に0~24の通し番号を振り、この通し番号が偶数の要素のみ抽出したいとします。

これをNumpyで書くと次のようになります。

import numpy as np

def numpy_select():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    print(X)
    #[[ 5 21 22 18 15]
    # [ 8  7 11  4  3]
    # [24 12 16  9 14]
    # [20  0  1 10 19]
    # [17  6 23  2 13]]
    ind = np.arange(25).reshape(5,5)
    flag = ind % 2 == 0
    print(X[flag])
    # [ 5 22 15  7  4 24 16 14  0 10 17 23 13]

ちゃんと偶数番目の要素のみ抽出されているのがわかります。

TensorFlowで書く場合はちょっとめんどい

Numpyの場合はインデックスに与えて終わり、でしたが、TensorFlowで書く場合はちょっと面倒になります。Numpyと同じようにやってしまうと、

import tensorflow as tf
import keras.backend as K

def tf_select_bad():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    print(tX[tflag]) # エラー
    # TypeError: unsupported operand type(s) for +: 'Tensor' and 'int'

このようにエラーになってしまいます。tf.whereで条件のインデックス配列を選択、tf.gather_ndでそのインデックスに応じて抽出とちょっと回りくどいことをしなくてはいけません。次の例は正しく動作します。

import tensorflow as tf
import keras.backend as K

def tf_select_good():
    np.random.seed(123)
    X = np.random.permutation(25).reshape(5,5)
    tX = K.variable(X)
    print(K.eval(tX))
    #[[ 5. 21. 22. 18. 15.]
    # [ 8.  7. 11.  4.  3.]
    # [24. 12. 16.  9. 14.]
    # [20.  0.  1. 10. 19.]
    # [17.  6. 23.  2. 13.]]

    ind = np.arange(25).reshape(5,5)
    tind = K.variable(ind)
    tflag = K.equal(tind%2, 0)

    select_indices = tf.where(tflag)
    gather = tf.gather_nd(tX, select_indices)
    print(K.eval(gather))
    # [ 5. 22. 15.  7.  4. 24. 16. 14.  0. 10. 17. 23. 13.]

うまくいきました。公式ドキュメントは以下のとおりです。

tf.gatherとtf.gather_ndは別物なので注意してくださいね。

Related Posts

PyTorch/TorchVisionで複数の入力をモデルに渡したいケース... PyTorch/TorchVisionで入力が複数あり、それぞれの入力に対して同じ前処理(transforms)をかけるケースを考えます。デフォルトのtransformsは複数対応していないのでうまくいきません。しかし、ラッパークラスを作り、それで前処理をラップするといい感じにできたのでその方法を...
argparseに直接dictを読み込ませる怪しいやり方... argparseにコマンドライン引数ではなく、ファイルから読み込んだdictをオーバーラップさせる方法を試してみました。本来のargparseの使い方ではない怪しいやり方ですが、JSONやyamlファイルとの連携が可能なので便利ではないかなと思います。 注意 これは本来のargparseの使い...
pix2pixを1から実装して白黒画像をカラー化してみた(PyTorch)... pix2pixによる白黒画像のカラー化を1から実装します。PyTorchで行います。かなり自然な色付けができました。pix2pixはGANの中でも理論が単純なのにくわえ、学習も比較的安定しているので結構おすすめです。 はじめに PyTorchでDCGANができたので、今回はpix2pixをやり...
PyTorchで双方向連結リストなデータ構造のモデルを作る... ディープラーニングのモデルには、訓練の途中でレイヤーを追加するなど特殊な訓練をするものがあります(Progressive-GANなど)。そのとき、モデルを「レイヤーやブロックの連結リスト」として定義しておくと見通しがよくなることがあります。その例を見ていきます。 訓練中に継ぎ足していくモデル ...
TensorFlow2.0でDistribute Trainingしたときにfitと訓練ループで精度... TensorFlowでDistribute Training(複数GPUやTPUでの訓練)をしたときに、Keras APIのfit()でのValidation精度と、訓練ループを書いたときの精度でかなり(1~2%)違うという状況に遭遇しました。特定の文を忘れただけだったのですが、解決に1日かかった...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です