こしあん
2019-09-24

argparseに直接dictを読み込ませる怪しいやり方

Pocket
LINEで送る


argparseにコマンドライン引数ではなく、ファイルから読み込んだdictをオーバーラップさせる方法を試してみました。本来のargparseの使い方ではない怪しいやり方ですが、JSONやyamlファイルとの連携が可能なので便利ではないかなと思います。

注意

これは本来のargparseの使い方ではありません。この例では動作しますが、argparseの型や値の保証のような機能が完全に損なわれる場合があります。

やりたいこと

ハイパーパラメータをソースに直書きするのあんまりよくないよね?
→ argparseを使えばいいけど、コマンドライン引数で指定するのめんどくさいよね
→ ただargparseってデフォルト値や、値の説明を放り込めて便利だよね
→ なら設定値を別ファイル(yamlやJSON)に直書きして、それをargparseに放り込みたいよね

というのがきっかけ。こういうのないかなーみたいなことつぶやいたら即教えていただきました(ふぁむたろうさんありがとうございました)

サンプルコード

この例では、読み込ませるdictを「dummy」としてハードコーディングしていますが、実際はファイルから読み込ませると良いでしょう。

import argparse

p = argparse.ArgumentParser(description="load from dictionary sample")
p.add_argument("--batch_size", help="batch size", default=128)
p.add_argument("--learning_rate", help="initial leraning rate", default=0.1)

args = p.parse_args()

def update_args(args_instance):
    # ダミーのdict : yamlやJSONから読んでも良い
    dummy = {
        "batch_size": 256,
        "momentum": 0.9
    }
    # この方法だとなんでも代入できてしまうので、add_argumentで指定したもののみ代入する
    for key, value in dummy.items():
        if args_instance.__contains__(key):
            args_instance.__setattr__(key, value)
    return args_instance

print("before loading")
print(args)
update_args(args) # dictからの値の読み込み
print("after loading")
print(args)

まずは普通にparse_args()でコマンドライン引数を読んでいきます。ただし、パースはあまり意味がなくて、後のdictからの読み込みで上書きされてしまいます(ここでやっていることは、実質ただの初期化)。

dictから読み込んで上書きするのが「update_args」という関数です。この例ではちょっと意地悪をして、add_argumentで定義をしていない「momentum」というセットを含んでいます。

argparseのインスタンスにダイレクトに代入するには、「setattr」の関数を使うとできます。ただし、この関数はセットするキーや値のチェックもしないので、何でも代入できてしまいます。そこで、「contains」という関数を使って、add_argumentで定義したキーのみアップデートするようにしています。

この結果は次のようになります。

before loading
Namespace(batch_size=128, learning_rate=0.1)
after loading
Namespace(batch_size=256, learning_rate=0.1)

batch_sizeがdummyで指定した256にアップデートされました。

ちなみに、このupdate_argsという関数は、add_argumentやハイパーパラメータの条件に依存しないので、例えばutilsに入れてtrain側で読み出せば、事足りることになります(引数に設定ファイルのパスでも入れておくといいですね)。

さらに怪しい使い方

ちなみに「setattr」は本当になんでも代入できてしまうので、これを逆手に取り、loggingなどの訓練結果を放り込んで保存するというやり方があるそうです

確かにこれなら、ハイパーパラメータから訓練結果まで一つのファイルで管理可能です。最後にargsだけ保存してしまえば、訓練結果をそれっぽく管理できるということになります。本来のargparseの使い方からは思いっきりかけ離れていますが、ちょっとした機械学習の運用としてなら便利でしょう。

Related Posts

KerasのCallbackを使って継承したImageDataGeneratorに値が渡せるか確かめ... Kerasで前処理の内容をエポックごとに変えたいというケースがたまにあります。これを実装するとなると、CallbackからGeneratorに値を渡すというコードになりますが、これが本当にできるかどうか確かめてみました。 想定する状況 例えば、前処理で正則化に関係するData Augmenta...
TensorFlow2.0のGradientTapeを複数使う場合のサンプル... TF2.0系で少し複雑なモデルを訓練するときに、GradientTapeを複数使うことがあります。例として、微分を取りたい場所が2箇所あるケースや、2階微分を取りたいケースを挙げます。場合によって微妙に書き方が違うので注意が必要です。 微分を取りたい場所が2箇所あるケース 簡単な例にします。次...
Numpyの配列をN個飛ばしで列挙する簡単な方法... Numpyの配列から奇数番目、偶数番目の要素を取り出したいときが稀によくあります。インデックスの配列を定義する必要があるのかなと思いますが、とても簡単な方法があります。それを見ていきましょう。 基本は「::スキップしたい間隔」 例として、0~9までの配列をとります。 >>>...
モルフォロジー変換は実はMaxPoolingだったという話(TensorFlowでの実装)... 画像処理の重要な変換に膨張(Dilation)や収縮(Erosion)といったモルフォロジー変換があります。実はこれはディープラーニングでよく使われるMaxPoolingフィルターで置き換えることができます。TensorFlowの実装で見ていきます。 モルフォロジー変換 OpenCVでのモルフ...
TensorFlow2.0で訓練の途中に学習率を変える方法... TensorFlow2.0で訓練の途中に学習率を変える方法を、Keras APIと訓練ループを自分で書くケースとで見ていきます。従来のKerasではLearning Rate Schedulerを使いましたが、TF2.0ではどうすればいいでしょうか? Keras APIの場合 従来どおりLea...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です