こしあん
2019-09-24

argparseに直接dictを読み込ませる怪しいやり方

Pocket
LINEで送る


argparseにコマンドライン引数ではなく、ファイルから読み込んだdictをオーバーラップさせる方法を試してみました。本来のargparseの使い方ではない怪しいやり方ですが、JSONやyamlファイルとの連携が可能なので便利ではないかなと思います。

注意

これは本来のargparseの使い方ではありません。この例では動作しますが、argparseの型や値の保証のような機能が完全に損なわれる場合があります。

やりたいこと

ハイパーパラメータをソースに直書きするのあんまりよくないよね?
→ argparseを使えばいいけど、コマンドライン引数で指定するのめんどくさいよね
→ ただargparseってデフォルト値や、値の説明を放り込めて便利だよね
→ なら設定値を別ファイル(yamlやJSON)に直書きして、それをargparseに放り込みたいよね

というのがきっかけ。こういうのないかなーみたいなことつぶやいたら即教えていただきました(ふぁむたろうさんありがとうございました)

サンプルコード

この例では、読み込ませるdictを「dummy」としてハードコーディングしていますが、実際はファイルから読み込ませると良いでしょう。

import argparse

p = argparse.ArgumentParser(description="load from dictionary sample")
p.add_argument("--batch_size", help="batch size", default=128)
p.add_argument("--learning_rate", help="initial leraning rate", default=0.1)

args = p.parse_args()

def update_args(args_instance):
    # ダミーのdict : yamlやJSONから読んでも良い
    dummy = {
        "batch_size": 256,
        "momentum": 0.9
    }
    # この方法だとなんでも代入できてしまうので、add_argumentで指定したもののみ代入する
    for key, value in dummy.items():
        if args_instance.__contains__(key):
            args_instance.__setattr__(key, value)
    return args_instance

print("before loading")
print(args)
update_args(args) # dictからの値の読み込み
print("after loading")
print(args)

まずは普通にparse_args()でコマンドライン引数を読んでいきます。ただし、パースはあまり意味がなくて、後のdictからの読み込みで上書きされてしまいます(ここでやっていることは、実質ただの初期化)。

dictから読み込んで上書きするのが「update_args」という関数です。この例ではちょっと意地悪をして、add_argumentで定義をしていない「momentum」というセットを含んでいます。

argparseのインスタンスにダイレクトに代入するには、「setattr」の関数を使うとできます。ただし、この関数はセットするキーや値のチェックもしないので、何でも代入できてしまいます。そこで、「contains」という関数を使って、add_argumentで定義したキーのみアップデートするようにしています。

この結果は次のようになります。

before loading
Namespace(batch_size=128, learning_rate=0.1)
after loading
Namespace(batch_size=256, learning_rate=0.1)

batch_sizeがdummyで指定した256にアップデートされました。

ちなみに、このupdate_argsという関数は、add_argumentやハイパーパラメータの条件に依存しないので、例えばutilsに入れてtrain側で読み出せば、事足りることになります(引数に設定ファイルのパスでも入れておくといいですね)。

さらに怪しい使い方

ちなみに「setattr」は本当になんでも代入できてしまうので、これを逆手に取り、loggingなどの訓練結果を放り込んで保存するというやり方があるそうです

確かにこれなら、ハイパーパラメータから訓練結果まで一つのファイルで管理可能です。最後にargsだけ保存してしまえば、訓練結果をそれっぽく管理できるということになります。本来のargparseの使い方からは思いっきりかけ離れていますが、ちょっとした機械学習の運用としてなら便利でしょう。

Related Posts

OpenCVのsubtractについての小ネタ OpenCVのsubtractと通常のNumpyの引き算の差が気になったのでメモ。実際に試してみました。 環境:Numpy:1.16.3, OpenCV:4.1.0 NumpyとOpenCVのsubtractの差 OpenCVとNumpy配列は密接に関係していて、新しい画像を作るときにnp....
OpenCVで作成した動画がブラウザで正常に表示できない場合の解決法... OpenCVで作成した動画をサイトで表示する場合、ローカルで再生できていても、ブラウザ上では突然プレビューがでなり、ハマることがあります。原因の特定が難しい現象ですが、動画を作成する際にH.264形式でエンコードするとうまくいきました。その方法を解説します。 MPV4は手軽だが… OpenCV...
PyTorchでConvolutionフィルターをやる(エッジ検出やアンシャープマスク)... PyTorchでPILのConvolutionフィルター(エッジ検出やアンシャープマスク)をやりたくなったので、どう実装するか考えてみました。 やりたいこと PIL/PillowのConvolutionフィルター(ImageFilterなど)の処理をPyTorchの畳み込み演算で再現したい ...
PyTorchで複数出力があるモデルの出力の型について... 出力が複数あるモデルの訓練というのは少し複雑なモデルだとよく出てきます。PyTorchでは複数出力のモデルの、出力の型はどうなっているでしょうか。それを見ていきます。中間層の値を取りたい場合も使えます。 サンプルコード import torch from torch import nn cl...
TensorFlowで値のソートをする方法 TensorFlowでNumpyのnp.sortやnp.argsortのようなソートを行うことを考えます。一般にTensorFlowで値のソートというと、自動微分もあわさって難しいように思えますが、実はちゃんとソートできます。うまくやればKerasからも使うことができます。 tf.nn.top_...
Pocket
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です