Kerasで損失関数に複数の変数を渡す方法
Kerasで少し複雑なモデルを訓練させるときに、損失関数にy_true, y_pred以外の値を渡したいときがあります。クラスのインスタンス変数などでキャッシュさせることなく、ダイレクトに損失関数に複数の値を渡す方法を紹介します。
元ネタ:Passing additional arguments to objective function #2121
目次
損失関数を別の関数でラップさせよう
まずはサンプルモデルから。多層パーセプトロンによる簡単なMNISTの例です。
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.objectives import categorical_crossentropy
import numpy as np
(X, y), (_, _) = mnist.load_data()
X = (X / 255.0).reshape(-1, 784).astype(np.float32)
y = to_categorical(y).astype(np.float32)
input = Input((784,))
x = Dense(64, activation="relu")(input)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)
今説明のために損失関数を、「通常のクラス間の交差エントロピー+定数」で定義します。複数の値を渡すための説明用に定数を導入したのであって、これ自体にモデル的な意味はありません。実際使うときはモデルに合わせて適宜変えてください。
def outer_loss(args):
def loss_function(y_true, y_pred):
return categorical_crossentropy(y_true, y_pred) + args
return loss_function
使い方はそのまま。交差エントロピーのロスに+10した値を表示してみましょう。
model.compile(Adam(), loss=outer_loss(10), metrics=["acc"])
model.fit(X, y, batch_size=128, epochs=3)
結果は次のようになります。
40448/60000 [===================>..........] - ETA: 0s - loss: 10.1593 - acc: 0.
45312/60000 [=====================>........] - ETA: 0s - loss: 10.1582 - acc: 0.
50176/60000 [========================>.....] - ETA: 0s - loss: 10.1594 - acc: 0.
54912/60000 [==========================>...] - ETA: 0s - loss: 10.1577 - acc: 0.
59776/60000 [============================>.] - ETA: 0s - loss: 10.1566 - acc: 0.
60000/60000 [==============================] - 1s 14us/step - loss: 10.1564 - ac
c: 0.9561
いい感じに+10されていますね。ではこのプラスを+5にしてみましょう。
model.compile(Adam(), loss=outer_loss(5), metrics=["acc"])
40448/60000 [===================>..........] - ETA: 0s - loss: 5.1530 - acc: 0.9
45056/60000 [=====================>........] - ETA: 0s - loss: 5.1523 - acc: 0.9
49152/60000 [=======================>......] - ETA: 0s - loss: 5.1519 - acc: 0.9
53376/60000 [=========================>....] - ETA: 0s - loss: 5.1510 - acc: 0.9
57856/60000 [===========================>..] - ETA: 0s - loss: 5.1495 - acc: 0.9
60000/60000 [==============================] - 1s 14us/step - loss: 5.1491 - acc
: 0.9576
良いですね。ちなみにこの+の値は分類にはなんの影響も与えていないので、値を変えたからといって精度がよくなるということはありません。精度が変わるのは今回の場合、初期値ガチャのせいです。
なぜうまく行くのか
最初不思議だったのですが、次のように考えると理解できました。
まず普通のカスタム損失関数の場合。loss=”categorical_crossentropy”と指定した場合もそうかもしれませんが、次のようなコードが動いているはずです。
def loss_function(y_true, y_pred):
# 何らかのlossの計算
return loss
func_obj = loss_function
loss = func_obj(y_true, y_pred)
疑似コードでの説明ですが、model.compile()で指定した損失関数は内部的に「func_obj=loss_function」の部分のように代入されるのだと思われます。そこで代入された関数のオブジェクトを動的に読み出して計算をし、訓練をしています。
では先程のようにネストした場合はどうかというと、
def outer_loss(args):
def loss_function(y_true, y_pred):
return categorical_crossentropy(y_true, y_pred) + args
return loss_function
func_obj = outer_loss(10)
loss = func_obj(y_true, y_pred)
となります。このouter_loss(10)がloss_functionの参照を返すために、このように定義すると損失が計算できるというわけですね。
ただテクニックはmodel.saveをするときは気をつけてください。issueにも書かれていますが、model.saveするときにネストした損失関数をシリアル化できない可能性があります。自分が試した限りでは上手く行ったのですが、もし複雑なモデルだとエラーが出るかもしれません。model.saveではなく、model.save_weightsでhdf5形式で保存すればおそらく大丈夫だと思います。
やっていることはクロージャー
これを見て思うのが、「(JavaScriptでよく使う)クロージャと同じじゃね?」と気づきます。そうです、まるっきりそっくりです。ただJavaScriptのクロージャーはメモリリークの温床となりやすいので、関数の参照(ID)だけ見ておきます。
サンプルコードです。
def outer(args):
def inner(str):
print(args, str)
func_obj = inner
print(func_obj)
return func_obj
先ほどとほとんど同じですが、innerの関数の参照を表示できるようにしてみました。このinner関数を複数回読んでみます。1つ目、ループのたびに代入する。
for i in range(3):
tmp = outer(10)
tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World
<function outer.<locals>.inner at 0x0000000003173158>
10 Hello, World
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World
代入によって元のouterのオブジェクトが変わってしまったために、ループのたびにinnerの参照が変わってしまいましたね。ガベージコレクション(GC)が機能して上手くいくかもしれませんが、これはメモリリーク起こしやすいので避けたほうがいいパターンではないかと思います。
次にループの外側でouterを代入し、呼び出しだけループ内で書くパターン。
tmp = outer(10)
for i in range(3):
tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000036C30D0>
10 Hello, World
10 Hello, World
10 Hello, World
これは代入が1回だけなので、共通のinner関数のオブジェクトを呼び出しているのがわかります。これは良いパターンです。おそらくKerasもこのように定義しているのだと思います。
以上です。なかなかトリッキーな書き方ですが、損失関数にキャッシュなしで複数の値を与えることができました。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー