こしあん
2018-11-06

Kerasのバックエンドで「○○以上☓☓以下」を計算する方法


2.8k{icon} {views}


Kerasのバックエンド関数を使ったときに「○○以上☓☓以下」を求めたい場合があります。しかし、KerasではAndのような論理演算をすると少し困ることがあります。その方法を解説します。

10×10の行列

0~99の数字を並べた「10×10」の行列を用意します。TensorFlowのテンソルで定義します。

import numpy as np
import keras.backend as K

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(X))

このとおり。

[[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9.]
 [10. 11. 12. 13. 14. 15. 16. 17. 18. 19.]
 [20. 21. 22. 23. 24. 25. 26. 27. 28. 29.]
 [30. 31. 32. 33. 34. 35. 36. 37. 38. 39.]
 [40. 41. 42. 43. 44. 45. 46. 47. 48. 49.]
 [50. 51. 52. 53. 54. 55. 56. 57. 58. 59.]
 [60. 61. 62. 63. 64. 65. 66. 67. 68. 69.]
 [70. 71. 72. 73. 74. 75. 76. 77. 78. 79.]
 [80. 81. 82. 83. 84. 85. 86. 87. 88. 89.]
 [90. 91. 92. 93. 94. 95. 96. 97. 98. 99.]]

「9で割った余りが3以上7以下」なら1、それ以外なら0を出力する

上の行列に対して「9で割った余りが3以上7以下」なら1、それ以外なら0を出力するという操作を適用します。はじめは10で割った余りにしようと思いましたが、それだとインデックスによるスライスで簡単にできてしまいそうなので、9で割った余りにしました。

割り算の余り

まず余りですが、普通のNumpyのように「%」演算子でできます。

print(K.eval(X%9))
[[0. 1. 2. 3. 4. 5. 6. 7. 8. 0.]
 [1. 2. 3. 4. 5. 6. 7. 8. 0. 1.]
 [2. 3. 4. 5. 6. 7. 8. 0. 1. 2.]
 [3. 4. 5. 6. 7. 8. 0. 1. 2. 3.]
 [4. 5. 6. 7. 8. 0. 1. 2. 3. 4.]
 [5. 6. 7. 8. 0. 1. 2. 3. 4. 5.]
 [6. 7. 8. 0. 1. 2. 3. 4. 5. 6.]
 [7. 8. 0. 1. 2. 3. 4. 5. 6. 7.]
 [8. 0. 1. 2. 3. 4. 5. 6. 7. 8.]
 [0. 1. 2. 3. 4. 5. 6. 7. 8. 0.]]

「○○以上」、「○○以下」という判定

次に「○○以上」という判定ですが、これもK.greater_equalという関数を使うとできます。

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(K.greater_equal(X%9, 3)))
[[False False False  True  True  True  True  True  True False]
 [False False  True  True  True  True  True  True False False]
 [False  True  True  True  True  True  True False False False]
 [ True  True  True  True  True  True False False False  True]
 [ True  True  True  True  True False False False  True  True]
 [ True  True  True  True False False False  True  True  True]
 [ True  True  True False False False  True  True  True  True]
 [ True  True False False False  True  True  True  True  True]
 [ True False False False  True  True  True  True  True  True]
 [False False False  True  True  True  True  True  True False]]

同じく「○○以下」という判定も、K.less_equalという関数を使うとできます。

X = K.variable(np.arange(100).reshape(10,10))
print(K.eval(K.less_equal(X%9, 7)))
[[ True  True  True  True  True  True  True  True False  True]
 [ True  True  True  True  True  True  True False  True  True]
 [ True  True  True  True  True  True False  True  True  True]
 [ True  True  True  True  True False  True  True  True  True]
 [ True  True  True  True False  True  True  True  True  True]
 [ True  True  True False  True  True  True  True  True  True]
 [ True  True False  True  True  True  True  True  True  True]
 [ True False  True  True  True  True  True  True  True  True]
 [False  True  True  True  True  True  True  True  True False]
 [ True  True  True  True  True  True  True  True False  True]]

ポイントは要素間のAND

次がポイント。「3以上」と「7以下」の判定をどう結合(論理積)を取ればよいでしょうか。ちなみにここではbool型で返ってきているので、掛け算をしようとすると怒られます。

実はあんまり有名になっていませんが、論理演算子もオーバーロードされているようです。なので、「&:AND」「|:OR」で要素間の論理演算ができてしまいます。

X = K.variable(np.arange(100).reshape(10,10))
flag = K.greater_equal(X%9, 3) & K.less_equal(X%9, 7)
print(K.eval(flag))
[[False False False  True  True  True  True  True False False]
 [False False  True  True  True  True  True False False False]
 [False  True  True  True  True  True False False False False]
 [ True  True  True  True  True False False False False  True]
 [ True  True  True  True False False False False  True  True]
 [ True  True  True False False False False  True  True  True]
 [ True  True False False False False  True  True  True  True]
 [ True False False False False  True  True  True  True  True]
 [False False False False  True  True  True  True  True False]
 [False False False  True  True  True  True  True False False]]

これで一通りOKですね。「9で割って3以上7以下のところだけTrue」になっています。

最後にキャスト

今欲しいのはTrue/Falseではなく、0/1という数字なので(このまま例えば数字のXなどの数字のテンソルと掛け算をすると怒られます)、数字にキャストしてあげる必要があります。K.castを使えばいいです。

X = K.variable(np.arange(100).reshape(10,10))
flag = K.cast(K.greater_equal(X%9, 3) & K.less_equal(X%9, 7), "float32")
print(K.eval(flag))
[[0. 0. 0. 1. 1. 1. 1. 1. 0. 0.]
 [0. 0. 1. 1. 1. 1. 1. 0. 0. 0.]
 [0. 1. 1. 1. 1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 1. 1. 0. 0. 0. 0. 1.]
 [1. 1. 1. 1. 0. 0. 0. 0. 1. 1.]
 [1. 1. 1. 0. 0. 0. 0. 1. 1. 1.]
 [1. 1. 0. 0. 0. 0. 1. 1. 1. 1.]
 [1. 0. 0. 0. 0. 1. 1. 1. 1. 1.]
 [0. 0. 0. 0. 1. 1. 1. 1. 1. 0.]
 [0. 0. 0. 1. 1. 1. 1. 1. 0. 0.]]

これで完成しました。

「9で割って3以下」または「9で割って7以上」なら1を返す場合

今度は逆に論理和のケースです。

import numpy as np
import keras.backend as K

X = K.variable(np.arange(100).reshape(10,10))
flag = K.cast(K.less_equal(X%9, 3) | K.greater_equal(X%9, 7), "float32")
print(K.eval(flag))
[[1. 1. 1. 1. 0. 0. 0. 1. 1. 1.]
 [1. 1. 1. 0. 0. 0. 1. 1. 1. 1.]
 [1. 1. 0. 0. 0. 1. 1. 1. 1. 1.]
 [1. 0. 0. 0. 1. 1. 1. 1. 1. 1.]
 [0. 0. 0. 1. 1. 1. 1. 1. 1. 0.]
 [0. 0. 1. 1. 1. 1. 1. 1. 0. 0.]
 [0. 1. 1. 1. 1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 1. 1. 0. 0. 0. 1.]
 [1. 1. 1. 1. 1. 0. 0. 0. 1. 1.]
 [1. 1. 1. 1. 0. 0. 0. 1. 1. 1.]]

同様にOKですね。

まとめ

Kerasで「○○以上」を取りたいときはK.greater_equal、○○以下を取りたいときはK.less_equalを使う。そして、論理演算子のオーバーロードが用意されているので、&や|でAND,ORといった計算をする

以上です。条件分岐で使うことも多いと思います。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です