こしあん
2018-11-06

Kerasのバックエンドで「○○以上☓☓以下」を計算する方法


Kerasのバックエンド関数を使ったときに「○○以上☓☓以下」を求めたい場合があります。しかし、KerasではAndのような論理演算をすると少し困ることがあります。その方法を解説します。

10×10の行列

0~99の数字を並べた「10×10」の行列を用意します。TensorFlowのテンソルで定義します。

import numpy as np
import keras.backend as K

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(X))

このとおり。

[[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14. 15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24. 25. 26. 27. 28. 29.]
 [30. 31. 32. 33. 34. 35. 36. 37. 38. 39.]
 [40. 41. 42. 43. 44. 45. 46. 47. 48. 49.]
 [50. 51. 52. 53. 54. 55. 56. 57. 58. 59.]
 [60. 61. 62. 63. 64. 65. 66. 67. 68. 69.]
 [70. 71. 72. 73. 74. 75. 76. 77. 78. 79.]
 [80. 81. 82. 83. 84. 85. 86. 87. 88. 89.]
 [90. 91. 92. 93. 94. 95. 96. 97. 98. 99.]]

「9で割った余りが3以上7以下」なら1、それ以外なら0を出力する

上の行列に対して「9で割った余りが3以上7以下」なら1、それ以外なら0を出力するという操作を適用します。はじめは10で割った余りにしようと思いましたが、それだとインデックスによるスライスで簡単にできてしまいそうなので、9で割った余りにしました。

割り算の余り

まず余りですが、普通のNumpyのように「%」演算子でできます。

print(K.eval(X%9))
[[0. 1. 2. 3. 4. 5. 6. 7. 8. 0.]
 [1. 2. 3. 4. 5. 6. 7. 8. 0. 1.]
 [2. 3. 4. 5. 6. 7. 8. 0. 1. 2.]
 [3. 4. 5. 6. 7. 8. 0. 1. 2. 3.]
 [4. 5. 6. 7. 8. 0. 1. 2. 3. 4.]
 [5. 6. 7. 8. 0. 1. 2. 3. 4. 5.]
 [6. 7. 8. 0. 1. 2. 3. 4. 5. 6.]
 [7. 8. 0. 1. 2. 3. 4. 5. 6. 7.]
 [8. 0. 1. 2. 3. 4. 5. 6. 7. 8.]
 [0. 1. 2. 3. 4. 5. 6. 7. 8. 0.]]

「○○以上」、「○○以下」という判定

次に「○○以上」という判定ですが、これもK.greater_equalという関数を使うとできます。

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(K.greater_equal(X%9, 3)))
[[False False False  True  True  True  True  True  True False]
 [False False  True  True  True  True  True  True False False]
 [False  True  True  True  True  True  True False False False]
 [ True  True  True  True  True  True False False False  True]
 [ True  True  True  True  True False False False  True  True]
 [ True  True  True  True False False False  True  True  True]
 [ True  True  True False False False  True  True  True  True]
 [ True  True False False False  True  True  True  True  True]
 [ True False False False  True  True  True  True  True  True]
 [False False False  True  True  True  True  True  True False]]

同じく「○○以下」という判定も、K.less_equalという関数を使うとできます。

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(K.less_equal(X%9, 7)))
[[ True  True  True  True  True  True  True  True False  True]
 [ True  True  True  True  True  True  True False  True  True]
 [ True  True  True  True  True  True False  True  True  True]
 [ True  True  True  True  True False  True  True  True  True]
 [ True  True  True  True False  True  True  True  True  True]
 [ True  True  True False  True  True  True  True  True  True]
 [ True  True False  True  True  True  True  True  True  True]
 [ True False  True  True  True  True  True  True  True  True]
 [False  True  True  True  True  True  True  True  True False]
 [ True  True  True  True  True  True  True  True False  True]]

ポイントは要素間のAND

次がポイント。「3以上」と「7以下」の判定をどう結合(論理積)を取ればよいでしょうか。ちなみにここではbool型で返ってきているので、掛け算をしようとすると怒られます。

実はあんまり有名になっていませんが、論理演算子もオーバーロードされているようです。なので、「&:AND」「|:OR」で要素間の論理演算ができてしまいます。

X = K.variable(np.arange(100).reshape(10,10))
flag = K.greater_equal(X%9, 3) & K.less_equal(X%9, 7)
print(K.eval(flag))
[[False False False  True  True  True  True  True False False]
 [False False  True  True  True  True  True False False False]
 [False  True  True  True  True  True False False False False]
 [ True  True  True  True  True False False False False  True]
 [ True  True  True  True False False False False  True  True]
 [ True  True  True False False False False  True  True  True]
 [ True  True False False False False  True  True  True  True]
 [ True False False False False  True  True  True  True  True]
 [False False False False  True  True  True  True  True False]
 [False False False  True  True  True  True  True False False]]

これで一通りOKですね。「9で割って3以上7以下のところだけTrue」になっています。

最後にキャスト

今欲しいのはTrue/Falseではなく、0/1という数字なので(このまま例えば数字のXなどの数字のテンソルと掛け算をすると怒られます)、数字にキャストしてあげる必要があります。K.castを使えばいいです。

X = K.variable(np.arange(100).reshape(10,10))
flag = K.cast(K.greater_equal(X%9, 3) & K.less_equal(X%9, 7), "float32")
print(K.eval(flag))
[[0. 0. 0. 1. 1. 1. 1. 1. 0. 0.]
 [0. 0. 1. 1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0. 0. 1. 1.]
 [1. 1. 1. 0. 0. 0. 0. 1. 1. 1.]
 [1. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
 [1. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 1. 1. 0. 0.]]

これで完成しました。

「9で割って3以下」または「9で割って7以上」なら1を返す場合

今度は逆に論理和のケースです。

import numpy as np
import keras.backend as K

X = K.variable(np.arange(100).reshape(10,10))
flag = K.cast(K.less_equal(X%9, 3) | K.greater_equal(X%9, 7), "float32")
print(K.eval(flag))
[[1. 1. 1. 1. 0. 0. 0. 1. 1. 1.]
 [1. 1. 1. 0. 0. 0. 1. 1. 1. 1.]
 [1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]
 [1. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1. 0.]
 [0. 0. 1. 1. 1. 1. 1. 1. 0. 0.]
 [0. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 1. 0. 0. 0. 1. 1.]
 [1. 1. 1. 1. 0. 0. 0. 1. 1. 1.]]

同様にOKですね。

まとめ

Kerasで「○○以上」を取りたいときはK.greater_equal、○○以下を取りたいときはK.less_equalを使う。そして、論理演算子のオーバーロードが用意されているので、&や|でAND,ORといった計算をする

以上です。条件分岐で使うことも多いと思います。

Related Posts

Kerasに組み込まれているNASNet(Large)の実装... NASNet Largeのsummary __________________________________________________________________________________________________ Layer (type) ...
Kerasに組み込まれているResNet50の実装 ResNet50のsummary __________________________________________________________________________________________________ Layer (type) O...
Kerasに組み込まれているNASNet(Mobile)の実装... NASNet Mobileのsummary __________________________________________________________________________________________________ Layer (type) ...
ニューラル協調フィルタリングで特徴量抽出して、アニメ同士の足し算・引き算をやってみる... Qiitaからお引越しテスト。Qiitaの記事では、ニューラル協調フィルタリングでMyAnimeListのレコメンドデータから、アニメの作品単位の特徴量抽出を行い、クラスタリングの手法を用いて、アニメを10個のグループに分類しました。この記事では、同様に抽出した特徴量を用いて、Word2Vecの...
Google ColabのTPUでResNetのベンチマークを取ってみた... Google ColaboratoryでTPUが使えるようになりましたが、さっそくどのぐらい速いのかベンチマークを取ってみました。以前やったResNetのベンチマークを使います。 環境:Google Colab(TPU)、TensorFlow:1.11.0-rc2、Keras:2.1.6 コ...

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です