こしあん
2020-05-26

tf.data.Datasetでdictなデータと仲良くする方法

Pocket
LINEで送る
Delicious にシェア

295{icon} {views}

新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』(A4・195ページ)好評通販中です! 機械学習の入門からGANの最先端までを書いたおすすめの本となっています! Boothで試し読みできます。情報まとめ・質問用GitHub



TensorFlow2.0でdict構造のデータから、tf.data.Datasetを作る方法を見ていきます。バッチの軸の結合処理を一切書かずに、dict構造を保ったままバッチ化してくれる便利な方法です。

サンプルデータ

こんなデータをdict構造を保ったままtf.data.Datasetで回したいのです。JSONでよくある例です。

# サンプルデータ
sample_data = [
    {
        "name": "綾波",
        "hp": [290, 1498],
        "power": [12, 57],
        "torpedo": [98, 463]
    },
    {
        "name": "ラフィー",
        "hp": [319, 1650],
        "power": [19, 88],
        "torpedo": [56, 266]
    },
    {
        "name": "ジャベリン",
        "hp": [255, 1318],
        "power": [15, 70],
        "torpedo": [71, 334]
    },
    {
        "name": "Z23",
        "hp": [345, 1786],
        "power": [22, 104],
        "torpedo": [61, 289]
    }
]

なぜdict構造を保ちたいのかというと、パラメーターが多数になったときに、tupleだとどれがどれだかわからなくなるからです(順番がややこしい)。dict構造ならkey-valueが紐付いているため混乱が少なくなります。

ここでは、名前を無視して、名前、火力、雷装の値だけを取り出すものとします。

コード

もっとスマートな方法はあるかもしれませんが、tf.data.Dataset.from_generatorを使うとできます。ただし、このジェネレーターは出力の型とshapeを指定する必要があり、公式ドキュメントだとあくまでtupleの例しか書いていなくdictの例はありませんでした。ただし、dict構造を保ったままでも実装可能です。

import tensorflow as tf

def generator():
    for x in sample_data:
        items = {}
        for k in ["hp", "power", "torpedo"]:
            items[k] = x[k]
        yield items  # dictで返す

def main():
    data = tf.data.Dataset.from_generator(generator,
        {"hp": tf.float32, "power": tf.float32, "torpedo": tf.float32},
        {"hp": (2,), "power": (2,), "torpedo": (2,)}
        ).shuffle(4).batch(2)

    for x in data:
        print("---")
        print(x)

if __name__ == "__main__":
    main()

上の「generator」という関数を、「tf.data.Dataset.from_generator」で読ませます。ただし、ここでジェネレーターの返り値はdictです。

返り値がdictの場合、型とshapeの指定もdictでします。上の例がそうです。

結果は次のようになります。バッチサイズ2で2バッチ読み出されます。

---
{'hp': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 319., 1650.],
       [ 255., 1318.]], dtype=float32)>, 'power': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[19., 88.],
       [15., 70.]], dtype=float32)>, 'torpedo': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 56., 266.],
       [ 71., 334.]], dtype=float32)>}
---
{'hp': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 290., 1498.],
       [ 345., 1786.]], dtype=float32)>, 'power': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 12.,  57.],
       [ 22., 104.]], dtype=float32)>, 'torpedo': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 98., 463.],
       [ 61., 289.]], dtype=float32)>}

このように、dict構造を保ったままバッチ化ができました。ニューラルネットワークからアクセスしたいときはキーでアクセスすればOKでです。

バッチの軸の結合は一切書かずにdict構造を保ったまま変換してくれるのが便利ですね。


新刊情報

技術書典8の新刊『モザイク除去から学ぶ 最先端のディープラーニング』好評通販中(A4・195ページ)です! Boothで試し読みもできるのでよろしくね!


Pocket
LINEで送る
Delicious にシェア

Add a Comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です