こしあん
2018-11-09

Kerasでランドマーク検出用の損失関数を作る上でのポイント

Pocket
LINEで送る
Delicious にシェア

2.6k{icon} {views}



ランドマーク検出やオブジェクト検出では、yに最初に物体やランドマークが存在する確率をおいて、それ以降に座標を配置するというようなデータ構造を取ります。その場合、カスタム損失関数を定義する必要が出てきますが、どのように定義するれば良いでしょうか。それを見ていきます。

Kerasの損失関数

分類問題ではy(ラベル)はそのクラスに属せば0、属さなければ1という簡単な設定でした。例えば猫かどうかを分類する場合、1サンプルあたりのyは「猫ならば1、猫以外ならば0」という形になります。ところで、損失関数は

def loss_function(y_true, y_pred):
    # 損失を計算するための計算
    return loss

という形で定義されます。実際はloss=”binary_crossentropy”とか指定するため、このように損失関数を書く必要はありませんが。この猫の分類の場合、y_true, y_predに渡されるテンソルの形はどのようになるでしょうか。

Pythonのブロードキャスティングによるバグを回避するために、ランク1の行列を使わないとするなら、損失関数に渡されるy_true, y_predの形はおそらく(ミニバッチサイズ, 1)となるはずです。これはサイズ1のベクトルをミニバッチサイズ分積み重ねたもので、例えばXの行列と似ていますね。実際にy_trueの中身をprint(K.int_shape(y_true))とかで表示すると、(None, None)が出てきてデバッグに少し困ることがあるのですがそこはちょっと闇ということで…。

では、これを犬か猫かという多クラス分類にする場合、損失関数のy_true, y_predの形はどうなるでしょうか。1サンプルの場合、猫だったら[1, 0]、犬だったら[0, 1]と定義すれば、2つ目の次元が増えるだけですね。なので、(ミニバッチサイズ, 2)となるはずです。クラス数が増えればあとは同様ですね。

ランドマーク検出におけるy(ラベル)

さてここからが本題。ランドマーク検出では問題設定にもよりますが、yを次のようにおきます。

  1. 画像内にランドマークが存在するか(0なら存在しない、1なら存在する)
  2. ランドマーク1のxの座標(画像左端が0で、右端が1)
  3. ランドマーク1のyの座標(画像上橋が0で、右端が1)
  4. ランドマーク2のxの座標(以下同様)

これをランドマークの数だけつぎ込みます。ランドマークの数がLなら、1サンプルあたりのyの次元は「N=2×L+1」となります。基本的には多クラス分類と同じなのですが、確率と座標という全く2つの値を扱うため、カスタム損失関数を定義してもいいと思います。面倒だったら両方とも平均二乗誤差で計算しても大丈夫だとは思います。

ではどうカスタム損失関数を定義すればよいでしょうか?確率の部分はbinary_cross_entropy(交差エントロピー)で、座標の部分はmean_squared_error(平均二乗誤差)で計算するのが良さそうです。これを定義しましょう。

ランドマーク検出のカスタム損失関数

y_true, y_predの次元は(ミニバッチサイズ, N)であることを意識して定義します。これが一例です。

from keras.objectives import binary_crossentropy, mean_squared_error
import keras.backend as K

def loss_function(y_true, y_pred):
    bce = binary_crossentropy(K.expand_dims(y_true[:, 0]), K.expand_dims(y_pred[:, 0]))
    mse = K.sign(y_true[:,0]) * mean_squared_error(y_true[:, 1:], y_pred[:, 1:])
    return K.mean(bce + mse)

追記:最後のreturn以降のK.mean()ですが、別になくてもいいようです(むしろないほうが推奨?)。これが必要かどうかは、サンプル単位で損失の加重をつけるかどうかなので、そこらへんを気にしない場合はどっちでもいいと思います。

交差エントロピー部分

まずは交差エントロピーから。スライスしたのを直接binary_cross_entropyに放り込んではいけません。そのまま放り込むとランク1のテンソル同士の計算になり、交差エントロピーがスカラー(実数)になります。本来、サンプル間の確率の交差エントロピーを取りたいため、ここの出力はベクトルでほしいのです。奇妙なことが起こるかどうか実際に確認してみましょう。

>>> K.get_value(y_true)
array([[1. , 0.2, 0.2, 0.4, 0.5],
       [0. , 0. , 0. , 0. , 0. ],
       [1. , 0.7, 0.8, 0.3, 0.3]], dtype=float32)
>>> K.get_value(y_pred)
array([[0.9, 0.1, 0.3, 0.5, 0.7],
       [0.2, 0.7, 0.8, 0.1, 0.2],
       [0.7, 0.5, 0.8, 0.2, 0.3]], dtype=float32)

これは今説明のためにy_true, y_predに簡単な値を設定してみました。まずはダメな例から。

>>> K.get_value(binary_crossentropy(y_true[:,0], y_pred[:,0]))
0.228393

はい、ダメです。途中でランク1になってしまったので、サンプル間の確率がそのまま交差エントロピーの計算に放り込まれてしまったのです。ここはK.expand_dims()でスライスした後に次元を増やすのが正解です。

>>> K.get_value(binary_crossentropy(K.expand_dims(y_true[:,0]), K.expand_dims(y
pred[:,0])))
array([0.10536053, 0.22314355, 0.35667494], dtype=float32)

これが本来欲しかった、サンプル別の確率の交差エントロピーです。非常に間違いやすいところなのですが、ブロードキャスティングで値が出る場合は、こういうケースでも実行時に警告一切出してくれないのですよね。なので、しっかり別立てで確認していかないと思わぬ落とし穴になると思います。

ちなみにブロードキャスティングのバグを直したら訓練が1~2割高速化しました。Numpyのブロードキャスティングは気づく場合があっても、Kerasのテンソルのブロードキャスティングは知らないまま動かしていて、闇が深いような感じがします。これが原因で訓練結果がおかしかったら普通はデータやモデルのほうを疑ってしまうので、なかなかここまで気づけない。

平均二乗誤差部分

さて、損失関数の続きを。こちらはブロードキャスティングの問題がないので、比較的わかりやすいと思います。

    mse = K.sign(y_true[:,0]) * mean_squared_error(y_true[:, 1:], y_pred[:, 1:])

最初のsign関数ですが、ランドマークがない場合は続く平均二乗誤差を無視しますよ、という意味です。K.signは要素別の計算なので、ランク1のテンソルを代入してもスカラーになることはありません。

>>> K.get_value(K.sign(y_true[:,0]))
array([1., 0., 1.], dtype=float32)

続く平均二乗誤差ですが、スライスはしているものの複数列のスライスなのでランク2のテンソルのままです。したがって、集計機能のある(ランクが落ちる)mean_squared_errorに代入してもスカラーになることはありません。ランク1のテンソル(ベクトル)になります。

>>> K.get_value(mean_squared_error(y_true[:, 1:], y_pred[:, 1:]))
array([0.0175, 0.295 , 0.0125], dtype=float32)

交差エントロピーと平均二乗誤差を足す

あとは簡単ですね。最後にサンプル別のベクトルを何らかの集計関数にかけましょう。合計でも平均でもどっちでもいいと思うのですが、一応平均にしてみました(ミニバッチサイズが変わったときに統一性を出すため)。

    return K.mean(bce + mse)

一応数値例でも確認してみましょう。

>>> bce = binary_crossentropy(K.expand_dims(y_true[:,0]), K.expand_dims(y_pred
,0]))
>>> bce
<tf.Tensor 'Mean_22:0' shape=(3,) dtype=float32>
>>> mse = mean_squared_error(y_true[:, 1:], y_pred[:, 1:])
>>> mse
<tf.Tensor 'Mean_23:0' shape=(3,) dtype=float32>
>>> error = K.mean(bce + mse)
>>> error
<tf.Tensor 'Mean_24:0' shape=() dtype=float32>
>>> K.get_value(error)
0.33672634

うまくいきましたね。たった3行のコードでしたが、ランク1のテンソルという思わぬ落とし穴があり、テンソル演算の部分はきちんとデバッグしないといけないという闇の深さを思い知らされました。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

【新刊】インフィニティNumPy――配列の初期化から、ゲームの戦闘、静止画や動画作成までの221問

「本当の実装力を身につける」ための221本ノック――
機械学習(ML)で避けて通れない数値計算ライブラリ・NumPyを、自在に活用できるようになろう。「できる」ための体系的な理解を目指します。基礎から丁寧に解説し、ディープラーニング(DL)の難しいモデルで遭遇する、NumPyの黒魔術もカバー。初心者から経験者・上級者まで楽しめる一冊です。問題を解き終わったとき、MLやDLなどの発展分野にスムーズに入っていけるでしょう。

本書の大きな特徴として、Pythonの本でありがちな「NumPyとML・DLの結合を外した」点があります。NumPyを理解するのに、MLまで理解するのは負担が大きいです。本書ではあえてこれらの内容を書いていません。行列やテンソルの理解に役立つ「従来の画像処理」をNumPyベースで深く解説・実装していきます。

しかし、問題の多くは、DLの実装で頻出の関数・処理を重点的に取り上げています。経験者なら思わず「あー」となるでしょう。関数丸暗記では自分で実装できません。「覚える関数は最小限、できる内容は無限大」の世界をぜひ体験してみてください。画像編集ソフトの処理をNumPyベースで実装する楽しさがわかるでしょう。※紙の本は電子版の特典つき

モザイク除去から学ぶ 最先端のディープラーニング

「誰もが夢見るモザイク除去」を起点として、機械学習・ディープラーニングの基本をはじめ、GAN(敵対的生成ネットワーク)の基本や発展型、ICCV, CVPR, ECCVといった国際学会の最新論文をカバーしていく本です。
ディープラーニングの研究は発展が目覚ましく、特にGANの発展型は市販の本でほとんどカバーされていない内容です。英語の原著論文を著者がコードに落とし込み、実装を踏まえながら丁寧に解説していきます。
また、本コードは全てTensorFlow2.0(Keras)に対応し、Googleの開発した新しい機械学習向け計算デバイス・TPU(Tensor Processing Unit)をフル活用しています。Google Colaboratoryを用いた環境構築不要の演習問題もあるため、読者自ら手を動かしながら理解を深めていくことができます。

AI、機械学習、ディープラーニングの最新事情、奥深いGANの世界を知りたい方にとってぜひ手にとっていただきたい一冊となっています。持ち運びに便利な電子書籍のDLコードが付属しています。

「おもしろ同人誌バザールオンライン」で紹介されました!(14:03~) https://youtu.be/gaXkTj7T79Y?t=843

まとめURL:https://github.com/koshian2/MosaicDeeplearningBook
A4 全195ページ、カラー12ページ / 2020年3月発行

Shikoan's ML Blog -Vol.1/2-

累計100万PV超の人気ブログが待望の電子化! このブログが電子書籍になって読みやすくなりました!

・1章完結のオムニバス形式
・機械学習の基本からマニアックなネタまで
・どこから読んでもOK
・何巻から読んでもOK

・短いものは2ページ、長いものは20ページ超のものも…
・通勤・通学の短い時間でもすぐ読める!
・読むのに便利な「しおり」機能つき

・全巻はA5サイズでたっぷりの「200ページオーバー」
・1冊にたっぷり30本収録。1本あたり18.3円の圧倒的コストパフォーマンス!
・文庫本感覚でお楽しみください

北海道の駅巡りコーナー

日高本線 車なし全駅巡り

ローカル線や秘境駅、マニアックな駅に興味のある方におすすめ! 2021年に大半区間が廃線になる、北海道の日高本線の全区間・全29駅(苫小牧~様似)を記録した本です。マイカーを使わずに、公共交通機関(バス)と徒歩のみで全駅訪問を行いました。日高本線が延伸する計画のあった、襟裳岬まで様似から足を伸ばしています。代行バスと路線バスの織り成す極限の時刻表ゲームと、絶海の太平洋と馬に囲まれた日高路、日高の隠れたグルメを是非たっぷり堪能してください。A4・フルカラー・192ページのたっぷりのボリュームで、あなたも旅行気分を漫喫できること待ったなし!

見どころ:日高本線被災区間(大狩部、慶能舞川橋梁、清畠~豊郷) / 牧場に囲まれた絵笛駅 / 窓口のあっただるま駅・荻伏駅 / 汐見の戦争遺跡のトーチカ / 新冠温泉、三石温泉 / 襟裳岬

A4 全192ページフルカラー / 2020年11月発行


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です