こしあん
2018-10-19

Kerasで損失関数に複数の変数を渡す方法

Pocket
LINEで送る


Kerasで少し複雑なモデルを訓練させるときに、損失関数にy_true, y_pred以外の値を渡したいときがあります。クラスのインスタンス変数などでキャッシュさせることなく、ダイレクトに損失関数に複数の値を渡す方法を紹介します。

元ネタ:Passing additional arguments to objective function #2121

損失関数を別の関数でラップさせよう

まずはサンプルモデルから。多層パーセプトロンによる簡単なMNISTの例です。

from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import to_categorical
from keras.objectives import categorical_crossentropy
import numpy as np

(X, y), (_, _) = mnist.load_data()
X = (X / 255.0).reshape(-1, 784).astype(np.float32)
y = to_categorical(y).astype(np.float32)

input = Input((784,))
x = Dense(64, activation="relu")(input)
x = Dense(10, activation="softmax")(x)
model = Model(input, x)

今説明のために損失関数を、「通常のクラス間の交差エントロピー+定数」で定義します。複数の値を渡すための説明用に定数を導入したのであって、これ自体にモデル的な意味はありません。実際使うときはモデルに合わせて適宜変えてください。

def outer_loss(args):
    def loss_function(y_true, y_pred):
        return categorical_crossentropy(y_true, y_pred) + args
    return loss_function

使い方はそのまま。交差エントロピーのロスに+10した値を表示してみましょう。

model.compile(Adam(), loss=outer_loss(10), metrics=["acc"])
model.fit(X, y, batch_size=128, epochs=3)

結果は次のようになります。

40448/60000 [===================>..........] - ETA: 0s - loss: 10.1593 - acc: 0.
45312/60000 [=====================>........] - ETA: 0s - loss: 10.1582 - acc: 0.
50176/60000 [========================>.....] - ETA: 0s - loss: 10.1594 - acc: 0.
54912/60000 [==========================>...] - ETA: 0s - loss: 10.1577 - acc: 0.
59776/60000 [============================>.] - ETA: 0s - loss: 10.1566 - acc: 0.
60000/60000 [==============================] - 1s 14us/step - loss: 10.1564 - ac
c: 0.9561

いい感じに+10されていますね。ではこのプラスを+5にしてみましょう。

model.compile(Adam(), loss=outer_loss(5), metrics=["acc"])
40448/60000 [===================>..........] - ETA: 0s - loss: 5.1530 - acc: 0.9
45056/60000 [=====================>........] - ETA: 0s - loss: 5.1523 - acc: 0.9
49152/60000 [=======================>......] - ETA: 0s - loss: 5.1519 - acc: 0.9
53376/60000 [=========================>....] - ETA: 0s - loss: 5.1510 - acc: 0.9
57856/60000 [===========================>..] - ETA: 0s - loss: 5.1495 - acc: 0.9
60000/60000 [==============================] - 1s 14us/step - loss: 5.1491 - acc
: 0.9576

良いですね。ちなみにこの+の値は分類にはなんの影響も与えていないので、値を変えたからといって精度がよくなるということはありません。精度が変わるのは今回の場合、初期値ガチャのせいです。

なぜうまく行くのか

最初不思議だったのですが、次のように考えると理解できました。

まず普通のカスタム損失関数の場合。loss=”categorical_crossentropy”と指定した場合もそうかもしれませんが、次のようなコードが動いているはずです。

def loss_function(y_true, y_pred):
    # 何らかのlossの計算
    return loss

func_obj = loss_function
loss = func_obj(y_true, y_pred)

疑似コードでの説明ですが、model.compile()で指定した損失関数は内部的に「func_obj=loss_function」の部分のように代入されるのだと思われます。そこで代入された関数のオブジェクトを動的に読み出して計算をし、訓練をしています。

では先程のようにネストした場合はどうかというと、

def outer_loss(args):
    def loss_function(y_true, y_pred):
        return categorical_crossentropy(y_true, y_pred) + args
    return loss_function

func_obj = outer_loss(10)
loss = func_obj(y_true, y_pred)

となります。このouter_loss(10)がloss_functionの参照を返すために、このように定義すると損失が計算できるというわけですね。

ただテクニックはmodel.saveをするときは気をつけてください。issueにも書かれていますが、model.saveするときにネストした損失関数をシリアル化できない可能性があります。自分が試した限りでは上手く行ったのですが、もし複雑なモデルだとエラーが出るかもしれません。model.saveではなく、model.save_weightsでhdf5形式で保存すればおそらく大丈夫だと思います。

やっていることはクロージャー

これを見て思うのが、「(JavaScriptでよく使う)クロージャと同じじゃね?」と気づきます。そうです、まるっきりそっくりです。ただJavaScriptのクロージャーはメモリリークの温床となりやすいので、関数の参照(ID)だけ見ておきます。

サンプルコードです。

def outer(args):
    def inner(str):
        print(args, str)
    func_obj = inner
    print(func_obj)
    return func_obj

先ほどとほとんど同じですが、innerの関数の参照を表示できるようにしてみました。このinner関数を複数回読んでみます。1つ目、ループのたびに代入する。

for i in range(3):
    tmp = outer(10)
    tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World
<function outer.<locals>.inner at 0x0000000003173158>
10 Hello, World
<function outer.<locals>.inner at 0x00000000031730D0>
10 Hello, World

代入によって元のouterのオブジェクトが変わってしまったために、ループのたびにinnerの参照が変わってしまいましたね。ガベージコレクション(GC)が機能して上手くいくかもしれませんが、これはメモリリーク起こしやすいので避けたほうがいいパターンではないかと思います。

次にループの外側でouterを代入し、呼び出しだけループ内で書くパターン。

tmp = outer(10)
for i in range(3):
    tmp("Hello, World")
<function outer.<locals>.inner at 0x00000000036C30D0>
10 Hello, World
10 Hello, World
10 Hello, World

これは代入が1回だけなので、共通のinner関数のオブジェクトを呼び出しているのがわかります。これは良いパターンです。おそらくKerasもこのように定義しているのだと思います。

以上です。なかなかトリッキーな書き方ですが、損失関数にキャッシュなしで複数の値を与えることができました。

Related Posts

PyTorchで行列(テンソル)積としてConv2dを使う... PyTorchではmatmulの挙動が特殊なので、思った通りにテンソル積が取れないことがあります。この記事では、基本的な畳み込み演算である「Conv2D」を使い、Numpyのドット積相当の演算を行うという方法を解説します。 はじめに PyTorchの変態コーディング技術です。多分。 画像のテ...
TensorFlow/Kerasでグラム行列(テンソル)を計算する方法... TensorFlowで分散や共分散が絡む演算を定義していると、グラム行列を計算する必要が出てくることがあります。行列はまだよくてもテンソルのグラム行列はどう計算するでしょうか?今回はテンソルの共分散計算に行く前に、その前提のテンソルのグラム行列の計算から見ていきます。 グラム行列とは 名前は仰...
KerasのModelCheckpointのsave_best_onlyは何を表すのか?... Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。 ModelCheckpointとは? 公式ドキュメントより ke...
転移学習でネットワーク内でアップサンプリングする方法(Keras)... 転移学習でインプットのサイズを揃えなければいけないことがありますが、これをRAM(CPU)上でやるとメモリが不足することがあります。転移学習の重みをそのまま使い、事前にアップサンプリングレイヤーを差し込む方法を紹介します。 関連記事とバックグラウンド まず前提知識としてCPU側でアップサンプリ...
TPUでも大きなバッチサイズに対して精度を出す... TPUでは大きなバッチサイズを適用することが訓練の高速化につながりますが、これは精度と引き換えになることがあります。大きなバッチサイズでも精度を出す方法を論文をもとに調べてみました。 背景 Qiitaに書いたGoogle Brainの論文「学習率を落とすな、バッチサイズを増やせ」を読むの続き。...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です