Kerasでランドマーク検出用の損失関数を作る上でのポイント
ランドマーク検出やオブジェクト検出では、yに最初に物体やランドマークが存在する確率をおいて、それ以降に座標を配置するというようなデータ構造を取ります。その場合、カスタム損失関数を定義する必要が出てきますが、どのように定義するれば良いでしょうか。それを見ていきます。
目次
Kerasの損失関数
分類問題ではy(ラベル)はそのクラスに属せば0、属さなければ1という簡単な設定でした。例えば猫かどうかを分類する場合、1サンプルあたりのyは「猫ならば1、猫以外ならば0」という形になります。ところで、損失関数は
def loss_function(y_true, y_pred):
# 損失を計算するための計算
return loss
という形で定義されます。実際はloss=”binary_crossentropy”とか指定するため、このように損失関数を書く必要はありませんが。この猫の分類の場合、y_true, y_predに渡されるテンソルの形はどのようになるでしょうか。
Pythonのブロードキャスティングによるバグを回避するために、ランク1の行列を使わないとするなら、損失関数に渡されるy_true, y_predの形はおそらく(ミニバッチサイズ, 1)となるはずです。これはサイズ1のベクトルをミニバッチサイズ分積み重ねたもので、例えばXの行列と似ていますね。実際にy_trueの中身をprint(K.int_shape(y_true))とかで表示すると、(None, None)が出てきてデバッグに少し困ることがあるのですがそこはちょっと闇ということで…。
では、これを犬か猫かという多クラス分類にする場合、損失関数のy_true, y_predの形はどうなるでしょうか。1サンプルの場合、猫だったら[1, 0]、犬だったら[0, 1]と定義すれば、2つ目の次元が増えるだけですね。なので、(ミニバッチサイズ, 2)となるはずです。クラス数が増えればあとは同様ですね。
ランドマーク検出におけるy(ラベル)
さてここからが本題。ランドマーク検出では問題設定にもよりますが、yを次のようにおきます。
- 画像内にランドマークが存在するか(0なら存在しない、1なら存在する)
- ランドマーク1のxの座標(画像左端が0で、右端が1)
- ランドマーク1のyの座標(画像上橋が0で、右端が1)
- ランドマーク2のxの座標(以下同様)
これをランドマークの数だけつぎ込みます。ランドマークの数がLなら、1サンプルあたりのyの次元は「N=2×L+1」となります。基本的には多クラス分類と同じなのですが、確率と座標という全く2つの値を扱うため、カスタム損失関数を定義してもいいと思います。面倒だったら両方とも平均二乗誤差で計算しても大丈夫だとは思います。
ではどうカスタム損失関数を定義すればよいでしょうか?確率の部分はbinary_cross_entropy(交差エントロピー)で、座標の部分はmean_squared_error(平均二乗誤差)で計算するのが良さそうです。これを定義しましょう。
ランドマーク検出のカスタム損失関数
y_true, y_predの次元は(ミニバッチサイズ, N)であることを意識して定義します。これが一例です。
from keras.objectives import binary_crossentropy, mean_squared_error
import keras.backend as K
def loss_function(y_true, y_pred):
bce = binary_crossentropy(K.expand_dims(y_true[:, 0]), K.expand_dims(y_pred[:, 0]))
mse = K.sign(y_true[:,0]) * mean_squared_error(y_true[:, 1:], y_pred[:, 1:])
return K.mean(bce + mse)
追記:最後のreturn以降のK.mean()ですが、別になくてもいいようです(むしろないほうが推奨?)。これが必要かどうかは、サンプル単位で損失の加重をつけるかどうかなので、そこらへんを気にしない場合はどっちでもいいと思います。
交差エントロピー部分
まずは交差エントロピーから。スライスしたのを直接binary_cross_entropyに放り込んではいけません。そのまま放り込むとランク1のテンソル同士の計算になり、交差エントロピーがスカラー(実数)になります。本来、サンプル間の確率の交差エントロピーを取りたいため、ここの出力はベクトルでほしいのです。奇妙なことが起こるかどうか実際に確認してみましょう。
>>> K.get_value(y_true)
array([[1. , 0.2, 0.2, 0.4, 0.5],
[0. , 0. , 0. , 0. , 0. ],
[1. , 0.7, 0.8, 0.3, 0.3]], dtype=float32)
>>> K.get_value(y_pred)
array([[0.9, 0.1, 0.3, 0.5, 0.7],
[0.2, 0.7, 0.8, 0.1, 0.2],
[0.7, 0.5, 0.8, 0.2, 0.3]], dtype=float32)
これは今説明のためにy_true, y_predに簡単な値を設定してみました。まずはダメな例から。
>>> K.get_value(binary_crossentropy(y_true[:,0], y_pred[:,0]))
0.228393
はい、ダメです。途中でランク1になってしまったので、サンプル間の確率がそのまま交差エントロピーの計算に放り込まれてしまったのです。ここはK.expand_dims()でスライスした後に次元を増やすのが正解です。
>>> K.get_value(binary_crossentropy(K.expand_dims(y_true[:,0]), K.expand_dims(y
pred[:,0])))
array([0.10536053, 0.22314355, 0.35667494], dtype=float32)
これが本来欲しかった、サンプル別の確率の交差エントロピーです。非常に間違いやすいところなのですが、ブロードキャスティングで値が出る場合は、こういうケースでも実行時に警告一切出してくれないのですよね。なので、しっかり別立てで確認していかないと思わぬ落とし穴になると思います。
ちなみにブロードキャスティングのバグを直したら訓練が1~2割高速化しました。Numpyのブロードキャスティングは気づく場合があっても、Kerasのテンソルのブロードキャスティングは知らないまま動かしていて、闇が深いような感じがします。これが原因で訓練結果がおかしかったら普通はデータやモデルのほうを疑ってしまうので、なかなかここまで気づけない。
平均二乗誤差部分
さて、損失関数の続きを。こちらはブロードキャスティングの問題がないので、比較的わかりやすいと思います。
mse = K.sign(y_true[:,0]) * mean_squared_error(y_true[:, 1:], y_pred[:, 1:])
最初のsign関数ですが、ランドマークがない場合は続く平均二乗誤差を無視しますよ、という意味です。K.signは要素別の計算なので、ランク1のテンソルを代入してもスカラーになることはありません。
>>> K.get_value(K.sign(y_true[:,0]))
array([1., 0., 1.], dtype=float32)
続く平均二乗誤差ですが、スライスはしているものの複数列のスライスなのでランク2のテンソルのままです。したがって、集計機能のある(ランクが落ちる)mean_squared_errorに代入してもスカラーになることはありません。ランク1のテンソル(ベクトル)になります。
>>> K.get_value(mean_squared_error(y_true[:, 1:], y_pred[:, 1:]))
array([0.0175, 0.295 , 0.0125], dtype=float32)
交差エントロピーと平均二乗誤差を足す
あとは簡単ですね。最後にサンプル別のベクトルを何らかの集計関数にかけましょう。合計でも平均でもどっちでもいいと思うのですが、一応平均にしてみました(ミニバッチサイズが変わったときに統一性を出すため)。
return K.mean(bce + mse)
一応数値例でも確認してみましょう。
>>> bce = binary_crossentropy(K.expand_dims(y_true[:,0]), K.expand_dims(y_pred
,0]))
>>> bce
<tf.Tensor 'Mean_22:0' shape=(3,) dtype=float32>
>>> mse = mean_squared_error(y_true[:, 1:], y_pred[:, 1:])
>>> mse
<tf.Tensor 'Mean_23:0' shape=(3,) dtype=float32>
>>> error = K.mean(bce + mse)
>>> error
<tf.Tensor 'Mean_24:0' shape=() dtype=float32>
>>> K.get_value(error)
0.33672634
うまくいきましたね。たった3行のコードでしたが、ランク1のテンソルという思わぬ落とし穴があり、テンソル演算の部分はきちんとデバッグしないといけないという闇の深さを思い知らされました。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー