こしあん
2018-10-19

Kerasで損失関数に複数の変数を渡す方法

Pocket
LINEで送る
Delicious にシェア

4.4k{icon} {views}



Kerasで少し複雑なモデルを訓練させるときに、損失関数にy_true, y_pred以外の値を渡したいときがあります。クラスのインスタンス変数などでキャッシュさせることなく、ダイレクトに損失関数に複数の値を渡す方法を紹介します。

元ネタ:Passing additional arguments to objective function #2121

損失関数を別の関数でラップさせよう

まずはサンプルモデルから。多層パーセプトロンによる簡単なMNISTの例です。

from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.objectives import categorical_crossentropy
import numpy as np

(X, y), (_, _) = mnist.load_data()
X = (X / 255.0).reshape(-1, 784).astype(np.float32)
y = to_categorical(y).astype(np.float32)

input = Input((784,))
x = Dense(64, activation="relu")(input)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)

今説明のために損失関数を、「通常のクラス間の交差エントロピー+定数」で定義します。複数の値を渡すための説明用に定数を導入したのであって、これ自体にモデル的な意味はありません。実際使うときはモデルに合わせて適宜変えてください。

def outer_loss(args):
    def loss_function(y_true, y_pred):
        return categorical_crossentropy(y_true, y_pred) + args
    return loss_function

使い方はそのまま。交差エントロピーのロスに+10した値を表示してみましょう。

model.compile(Adam(), loss=outer_loss(10), metrics=["acc"])
model.fit(X, y, batch_size=128, epochs=3)

結果は次のようになります。

40448/60000 [===================>..........] - ETA: 0s - loss: 10.1593 - acc: 0.
45312/60000 [=====================>........] - ETA: 0s - loss: 10.1582 - acc: 0.
50176/60000 [========================>.....] - ETA: 0s - loss: 10.1594 - acc: 0.
54912/60000 [==========================>...] - ETA: 0s - loss: 10.1577 - acc: 0.
59776/60000 [============================>.] - ETA: 0s - loss: 10.1566 - acc: 0.
60000/60000 [==============================] - 1s 14us/step - loss: 10.1564 - ac
c: 0.9561

いい感じに+10されていますね。ではこのプラスを+5にしてみましょう。

model.compile(Adam(), loss=outer_loss(5), metrics=["acc"])
40448/60000 [===================>..........] - ETA: 0s - loss: 5.1530 - acc: 0.9
45056/60000 [=====================>........] - ETA: 0s - loss: 5.1523 - acc: 0.9
49152/60000 [=======================>......] - ETA: 0s - loss: 5.1519 - acc: 0.9
53376/60000 [=========================>....] - ETA: 0s - loss: 5.1510 - acc: 0.9
57856/60000 [===========================>..] - ETA: 0s - loss: 5.1495 - acc: 0.9
60000/60000 [==============================] - 1s 14us/step - loss: 5.1491 - acc
: 0.9576

良いですね。ちなみにこの+の値は分類にはなんの影響も与えていないので、値を変えたからといって精度がよくなるということはありません。精度が変わるのは今回の場合、初期値ガチャのせいです。

なぜうまく行くのか

最初不思議だったのですが、次のように考えると理解できました。

まず普通のカスタム損失関数の場合。loss=”categorical_crossentropy”と指定した場合もそうかもしれませんが、次のようなコードが動いているはずです。

def loss_function(y_true, y_pred):
    # 何らかのlossの計算
    return loss

func_obj = loss_function
loss = func_obj(y_true, y_pred)

疑似コードでの説明ですが、model.compile()で指定した損失関数は内部的に「func_obj=loss_function」の部分のように代入されるのだと思われます。そこで代入された関数のオブジェクトを動的に読み出して計算をし、訓練をしています。

では先程のようにネストした場合はどうかというと、

def outer_loss(args):
    def loss_function(y_true, y_pred):
        return categorical_crossentropy(y_true, y_pred) + args
    return loss_function

func_obj = outer_loss(10)
loss = func_obj(y_true, y_pred)

となります。このouter_loss(10)がloss_functionの参照を返すために、このように定義すると損失が計算できるというわけですね。

ただテクニックはmodel.saveをするときは気をつけてください。issueにも書かれていますが、model.saveするときにネストした損失関数をシリアル化できない可能性があります。自分が試した限りでは上手く行ったのですが、もし複雑なモデルだとエラーが出るかもしれません。model.saveではなく、model.save_weightsでhdf5形式で保存すればおそらく大丈夫だと思います。

やっていることはクロージャー

これを見て思うのが、「(JavaScriptでよく使う)クロージャと同じじゃね?」と気づきます。そうです、まるっきりそっくりです。ただJavaScriptのクロージャーはメモリリークの温床となりやすいので、関数の参照(ID)だけ見ておきます。

サンプルコードです。

def outer(args):
    def inner(str):
        print(args, str)
    func_obj = inner
    print(func_obj)
    return func_obj

先ほどとほとんど同じですが、innerの関数の参照を表示できるようにしてみました。このinner関数を複数回読んでみます。1つ目、ループのたびに代入する。

for i in range(3):
    tmp = outer(10)
    tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World
<function outer.<locals>.inner at 0x0000000003173158>
10 Hello, World
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World

代入によって元のouterのオブジェクトが変わってしまったために、ループのたびにinnerの参照が変わってしまいましたね。ガベージコレクション(GC)が機能して上手くいくかもしれませんが、これはメモリリーク起こしやすいので避けたほうがいいパターンではないかと思います。

次にループの外側でouterを代入し、呼び出しだけループ内で書くパターン。

tmp = outer(10)
for i in range(3):
    tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000036C30D0>
10 Hello, World
10 Hello, World
10 Hello, World

これは代入が1回だけなので、共通のinner関数のオブジェクトを呼び出しているのがわかります。これは良いパターンです。おそらくKerasもこのように定義しているのだと思います。

以上です。なかなかトリッキーな書き方ですが、損失関数にキャッシュなしで複数の値を与えることができました。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

【新刊】インフィニティNumPy――配列の初期化から、ゲームの戦闘、静止画や動画作成までの221問

「本当の実装力を身につける」ための221本ノック――
機械学習(ML)で避けて通れない数値計算ライブラリ・NumPyを、自在に活用できるようになろう。「できる」ための体系的な理解を目指します。基礎から丁寧に解説し、ディープラーニング(DL)の難しいモデルで遭遇する、NumPyの黒魔術もカバー。初心者から経験者・上級者まで楽しめる一冊です。問題を解き終わったとき、MLやDLなどの発展分野にスムーズに入っていけるでしょう。

本書の大きな特徴として、Pythonの本でありがちな「NumPyとML・DLの結合を外した」点があります。NumPyを理解するのに、MLまで理解するのは負担が大きいです。本書ではあえてこれらの内容を書いていません。行列やテンソルの理解に役立つ「従来の画像処理」をNumPyベースで深く解説・実装していきます。

しかし、問題の多くは、DLの実装で頻出の関数・処理を重点的に取り上げています。経験者なら思わず「あー」となるでしょう。関数丸暗記では自分で実装できません。「覚える関数は最小限、できる内容は無限大」の世界をぜひ体験してみてください。画像編集ソフトの処理をNumPyベースで実装する楽しさがわかるでしょう。※紙の本は電子版の特典つき

モザイク除去から学ぶ 最先端のディープラーニング

「誰もが夢見るモザイク除去」を起点として、機械学習・ディープラーニングの基本をはじめ、GAN(敵対的生成ネットワーク)の基本や発展型、ICCV, CVPR, ECCVといった国際学会の最新論文をカバーしていく本です。
ディープラーニングの研究は発展が目覚ましく、特にGANの発展型は市販の本でほとんどカバーされていない内容です。英語の原著論文を著者がコードに落とし込み、実装を踏まえながら丁寧に解説していきます。
また、本コードは全てTensorFlow2.0(Keras)に対応し、Googleの開発した新しい機械学習向け計算デバイス・TPU(Tensor Processing Unit)をフル活用しています。Google Colaboratoryを用いた環境構築不要の演習問題もあるため、読者自ら手を動かしながら理解を深めていくことができます。

AI、機械学習、ディープラーニングの最新事情、奥深いGANの世界を知りたい方にとってぜひ手にとっていただきたい一冊となっています。持ち運びに便利な電子書籍のDLコードが付属しています。

「おもしろ同人誌バザールオンライン」で紹介されました!(14:03~) https://youtu.be/gaXkTj7T79Y?t=843

まとめURL:https://github.com/koshian2/MosaicDeeplearningBook
A4 全195ページ、カラー12ページ / 2020年3月発行

Shikoan's ML Blog -Vol.1/2-

累計100万PV超の人気ブログが待望の電子化! このブログが電子書籍になって読みやすくなりました!

・1章完結のオムニバス形式
・機械学習の基本からマニアックなネタまで
・どこから読んでもOK
・何巻から読んでもOK

・短いものは2ページ、長いものは20ページ超のものも…
・通勤・通学の短い時間でもすぐ読める!
・読むのに便利な「しおり」機能つき

・全巻はA5サイズでたっぷりの「200ページオーバー」
・1冊にたっぷり30本収録。1本あたり18.3円の圧倒的コストパフォーマンス!
・文庫本感覚でお楽しみください

北海道の駅巡りコーナー

日高本線 車なし全駅巡り

ローカル線や秘境駅、マニアックな駅に興味のある方におすすめ! 2021年に大半区間が廃線になる、北海道の日高本線の全区間・全29駅(苫小牧~様似)を記録した本です。マイカーを使わずに、公共交通機関(バス)と徒歩のみで全駅訪問を行いました。日高本線が延伸する計画のあった、襟裳岬まで様似から足を伸ばしています。代行バスと路線バスの織り成す極限の時刻表ゲームと、絶海の太平洋と馬に囲まれた日高路、日高の隠れたグルメを是非たっぷり堪能してください。A4・フルカラー・192ページのたっぷりのボリュームで、あなたも旅行気分を漫喫できること待ったなし!

見どころ:日高本線被災区間(大狩部、慶能舞川橋梁、清畠~豊郷) / 牧場に囲まれた絵笛駅 / 窓口のあっただるま駅・荻伏駅 / 汐見の戦争遺跡のトーチカ / 新冠温泉、三石温泉 / 襟裳岬

A4 全192ページフルカラー / 2020年11月発行


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です