こしあん
2018-11-14

KerasのModelCheckpointのsave_best_onlyは何を表すのか?

Pocket
LINEで送る
Delicious にシェア

5k{icon} {views}



Kerasには「モデルの精度が良くなったときだけ係数を保存する」のに便利なModelCheckpointというクラスがあります。ただこのsave_best_onlyがいまいち公式の解説だとピンとこないので調べてみました。

ModelCheckpointとは?

公式ドキュメントより

keras.callbacks.ModelCheckpoint(filepath, monitor=’val_loss’, verbose=0, save_best_only=False, save_weights_only=False, mode=’auto’, period=1)

各エポック終了後にモデルを保存します.

filepath、monitorとかはわかりやすいんでいいんですが、この「save_best_onlyって何?なんでしかもデフォルトでFalseになってるの?」というのが気になって仕方がありませんでした。この公式解説が紛らわしくて

save_best_only: save_best_only=Trueの場合,監視しているデータによって最新の最良モデルが上書きされません.

これを読むと、「じゃあsave_best_only=Falseなら、より精度なりが良くなったときにファイルを上書きしないでどんどん新しいファイルを作るの?」と思ってしまいます。自分は一時期そう思っていました。ただこれは勘違いです。

もしsave_best_only=False(デフォルト)で実行すると

試しにModelCheckpointの「save_best_only=False」で実行してみます。データはMNISTとします。

from keras.layers import Dense, Input, Flatten
from keras.models import Model
from keras.callbacks import ModelCheckpoint

from keras.datasets import mnist
from keras.utils import to_categorical

(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train, X_test = X_train/255.0, X_test/255.0
y_train, y_test = to_categorical(y_train), to_categorical(y_test)

input = Input((28, 28))
x = Flatten()(input)
x = Dense(128, activation="relu")(x)
x = Dense(10, activation="softmax")(x)

model = Model(input, x)
model.compile("adam", "categorical_crossentropy", ["acc"])

cp = ModelCheckpoint("weights.hdf5", monitor="val_loss", verbose=1,
                     save_best_only=False, save_weights_only=True)
model.fit(X_train, y_train, batch_size=128, epochs=20, callbacks=[cp],
        validation_data=(X_test, y_test))

ここではverbose=1としてログに書き出すようにしています。結果は次のようになります。

Epoch 00011: saving model to weights.hdf5
Epoch 12/20
60000/60000 [==============================] - 3s 49us/step - loss: 0.0208 - acc: 0.9950 - val_loss: 0.0717 - val_acc: 0.9780

Epoch 00012: saving model to weights.hdf5
Epoch 13/20
60000/60000 [==============================] - 3s 49us/step - loss: 0.0178 - acc: 0.9956 - val_loss: 0.0748 - val_acc: 0.9781

Epoch 00013: saving model to weights.hdf5
Epoch 14/20
60000/60000 [==============================] - 3s 49us/step - loss: 0.0146 - acc: 0.9968 - val_loss: 0.0745 - val_acc: 0.979

ここのval_lossを注意深く見てください。val_lossが0.0717→0.0748→0.0745と上がっているにもかかわらず、モデルが上書きされていますよね?。これがsave_best_only=Falseの正体です。

つまり、「save_best_only=False」なら、モニターしている評価値(この場合はval_loss)悪くなったときに上書きしないということではなくて、periodの引数で指定したのエポックごとに必ず保存するという意味なのです。試しにperiod=2としてみましょう。

Epoch 17/20
60000/60000 [==============================] - 3s 49us/step - loss: 0.0101 - acc: 0.9981 - val_loss: 0.0806 - val_acc: 0.9768
Epoch 18/20
60000/60000 [==============================] - 3s 48us/step - loss: 0.0088 - acc: 0.9984 - val_loss: 0.0842 - val_acc: 0.9761

Epoch 00018: saving model to weights.hdf5
Epoch 19/20
60000/60000 [==============================] - 3s 48us/step - loss: 0.0070 - acc: 0.9990 - val_loss: 0.0790 - val_acc: 0.9795
Epoch 20/20
60000/60000 [==============================] - 3s 50us/step - loss: 0.0077 - acc: 0.9984 - val_loss: 0.0877 - val_acc: 0.9771

Epoch 00020: saving model to weights.hdf5

「save_best_only=False, period=2」なら、モニターしている評価値が上がろうが下がろうが、2エポックごとに自動的に保存しているというのが確認できました。

一番いい精度を頼みたい場合(save_best_only=Trueなら)

GANとかモデルが崩壊するおそれのあるケースでもない限り、かなりの場合で「一番いい精度(少ない損失)の係数だけ保存していればいい」と思います。その場合は、save_best_only=Trueにしましょう。これをすることで想定された挙動になります。

ModelCheckpointの設定を以下のように変えます。デフォルトでsave_best_only=Falseなので、設定変更は必須です。

cp = ModelCheckpoint("weights.hdf5", monitor="val_loss", verbose=1,
                     save_best_only=True, save_weights_only=True)
Epoch 13/20
60000/60000 [==============================] - 3s 50us/step - loss: 0.0189 - acc: 0.9954 - val_loss: 0.0675 - val_acc: 0.9798

Epoch 00013: val_loss improved from 0.06868 to 0.06753, saving model to weights.hdf5
Epoch 14/20
60000/60000 [==============================] - 3s 50us/step - loss: 0.0167 - acc: 0.9960 - val_loss: 0.0744 - val_acc: 0.9773

Epoch 00014: val_loss did not improve from 0.06753

このように悪化した場合は「val_loss did not improve from 0.06753」と表示され、改善したときだけ保存されます。おそらくこれが大多数の人がやりたかったことではないでしょうか。

ちなみにモニターする評価値には訓練損失(loss)や訓練精度(acc)を入れることもできます。試しに訓練精度を入れてみましょう。もし訓練損失でいいときは、「monitor=”loss”」でOKです。

cp = ModelCheckpoint("weights.hdf5", monitor="acc", verbose=1,
                     save_best_only=True, save_weights_only=True)
Epoch 1/20
60000/60000 [==============================] - 4s 66us/step - loss: 0.3603 - acc: 0.9002 - val_loss: 0.1877 - val_acc: 0.9450

Epoch 00001: acc improved from -inf to 0.90017, saving model to weights.hdf5

このようにOKですね。デフォルトで組み込みの損失関数や評価関数では、大小の変化で改善か悪化かを自動的に判断してくれるようです。自作の関数だと「mode」の引数で設定必須だと思います

まとめ

  • ModelCheckpointを使うとモデルの係数を自動的に保存してくれる
  • しかし、デフォルトでsave_best_only=Falseなので、改善しようが悪化しようがなんでもかんでも上書きしようとする。save_best_only=Trueへの変更を忘れずに


Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

【新刊】インフィニティNumPy――配列の初期化から、ゲームの戦闘、静止画や動画作成までの221問

「本当の実装力を身につける」ための221本ノック――
機械学習(ML)で避けて通れない数値計算ライブラリ・NumPyを、自在に活用できるようになろう。「できる」ための体系的な理解を目指します。基礎から丁寧に解説し、ディープラーニング(DL)の難しいモデルで遭遇する、NumPyの黒魔術もカバー。初心者から経験者・上級者まで楽しめる一冊です。問題を解き終わったとき、MLやDLなどの発展分野にスムーズに入っていけるでしょう。

本書の大きな特徴として、Pythonの本でありがちな「NumPyとML・DLの結合を外した」点があります。NumPyを理解するのに、MLまで理解するのは負担が大きいです。本書ではあえてこれらの内容を書いていません。行列やテンソルの理解に役立つ「従来の画像処理」をNumPyベースで深く解説・実装していきます。

しかし、問題の多くは、DLの実装で頻出の関数・処理を重点的に取り上げています。経験者なら思わず「あー」となるでしょう。関数丸暗記では自分で実装できません。「覚える関数は最小限、できる内容は無限大」の世界をぜひ体験してみてください。画像編集ソフトの処理をNumPyベースで実装する楽しさがわかるでしょう。※紙の本は電子版の特典つき

モザイク除去から学ぶ 最先端のディープラーニング

「誰もが夢見るモザイク除去」を起点として、機械学習・ディープラーニングの基本をはじめ、GAN(敵対的生成ネットワーク)の基本や発展型、ICCV, CVPR, ECCVといった国際学会の最新論文をカバーしていく本です。
ディープラーニングの研究は発展が目覚ましく、特にGANの発展型は市販の本でほとんどカバーされていない内容です。英語の原著論文を著者がコードに落とし込み、実装を踏まえながら丁寧に解説していきます。
また、本コードは全てTensorFlow2.0(Keras)に対応し、Googleの開発した新しい機械学習向け計算デバイス・TPU(Tensor Processing Unit)をフル活用しています。Google Colaboratoryを用いた環境構築不要の演習問題もあるため、読者自ら手を動かしながら理解を深めていくことができます。

AI、機械学習、ディープラーニングの最新事情、奥深いGANの世界を知りたい方にとってぜひ手にとっていただきたい一冊となっています。持ち運びに便利な電子書籍のDLコードが付属しています。

「おもしろ同人誌バザールオンライン」で紹介されました!(14:03~) https://youtu.be/gaXkTj7T79Y?t=843

まとめURL:https://github.com/koshian2/MosaicDeeplearningBook
A4 全195ページ、カラー12ページ / 2020年3月発行

Shikoan's ML Blog -Vol.1/2-

累計100万PV超の人気ブログが待望の電子化! このブログが電子書籍になって読みやすくなりました!

・1章完結のオムニバス形式
・機械学習の基本からマニアックなネタまで
・どこから読んでもOK
・何巻から読んでもOK

・短いものは2ページ、長いものは20ページ超のものも…
・通勤・通学の短い時間でもすぐ読める!
・読むのに便利な「しおり」機能つき

・全巻はA5サイズでたっぷりの「200ページオーバー」
・1冊にたっぷり30本収録。1本あたり18.3円の圧倒的コストパフォーマンス!
・文庫本感覚でお楽しみください

北海道の駅巡りコーナー

日高本線 車なし全駅巡り

ローカル線や秘境駅、マニアックな駅に興味のある方におすすめ! 2021年に大半区間が廃線になる、北海道の日高本線の全区間・全29駅(苫小牧~様似)を記録した本です。マイカーを使わずに、公共交通機関(バス)と徒歩のみで全駅訪問を行いました。日高本線が延伸する計画のあった、襟裳岬まで様似から足を伸ばしています。代行バスと路線バスの織り成す極限の時刻表ゲームと、絶海の太平洋と馬に囲まれた日高路、日高の隠れたグルメを是非たっぷり堪能してください。A4・フルカラー・192ページのたっぷりのボリュームで、あなたも旅行気分を漫喫できること待ったなし!

見どころ:日高本線被災区間(大狩部、慶能舞川橋梁、清畠~豊郷) / 牧場に囲まれた絵笛駅 / 窓口のあっただるま駅・荻伏駅 / 汐見の戦争遺跡のトーチカ / 新冠温泉、三石温泉 / 襟裳岬

A4 全192ページフルカラー / 2020年11月発行


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です