tf.data.Datasetでdictなデータと仲良くする方法
TensorFlow2.0でdict構造のデータから、tf.data.Datasetを作る方法を見ていきます。バッチの軸の結合処理を一切書かずに、dict構造を保ったままバッチ化してくれる便利な方法です。
目次
サンプルデータ
こんなデータをdict構造を保ったままtf.data.Datasetで回したいのです。JSONでよくある例です。
# サンプルデータ
sample_data = [
{
"name": "綾波",
"hp": [290, 1498],
"power": [12, 57],
"torpedo": [98, 463]
},
{
"name": "ラフィー",
"hp": [319, 1650],
"power": [19, 88],
"torpedo": [56, 266]
},
{
"name": "ジャベリン",
"hp": [255, 1318],
"power": [15, 70],
"torpedo": [71, 334]
},
{
"name": "Z23",
"hp": [345, 1786],
"power": [22, 104],
"torpedo": [61, 289]
}
]
なぜdict構造を保ちたいのかというと、パラメーターが多数になったときに、tupleだとどれがどれだかわからなくなるからです(順番がややこしい)。dict構造ならkey-valueが紐付いているため混乱が少なくなります。
ここでは、名前を無視して、名前、火力、雷装の値だけを取り出すものとします。
コード
もっとスマートな方法はあるかもしれませんが、tf.data.Dataset.from_generatorを使うとできます。ただし、このジェネレーターは出力の型とshapeを指定する必要があり、公式ドキュメントだとあくまでtupleの例しか書いていなくdictの例はありませんでした。ただし、dict構造を保ったままでも実装可能です。
import tensorflow as tf
def generator():
for x in sample_data:
items = {}
for k in ["hp", "power", "torpedo"]:
items[k] = x[k]
yield items # dictで返す
def main():
data = tf.data.Dataset.from_generator(generator,
{"hp": tf.float32, "power": tf.float32, "torpedo": tf.float32},
{"hp": (2,), "power": (2,), "torpedo": (2,)}
).shuffle(4).batch(2)
for x in data:
print("---")
print(x)
if __name__ == "__main__":
main()
上の「generator」という関数を、「tf.data.Dataset.from_generator」で読ませます。ただし、ここでジェネレーターの返り値はdictです。
返り値がdictの場合、型とshapeの指定もdictでします。上の例がそうです。
結果は次のようになります。バッチサイズ2で2バッチ読み出されます。
---
{'hp': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 319., 1650.],
[ 255., 1318.]], dtype=float32)>, 'power': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[19., 88.],
[15., 70.]], dtype=float32)>, 'torpedo': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 56., 266.],
[ 71., 334.]], dtype=float32)>}
---
{'hp': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 290., 1498.],
[ 345., 1786.]], dtype=float32)>, 'power': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 12., 57.],
[ 22., 104.]], dtype=float32)>, 'torpedo': <tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[ 98., 463.],
[ 61., 289.]], dtype=float32)>}
このように、dict構造を保ったままバッチ化ができました。ニューラルネットワークからアクセスしたいときはキーでアクセスすればOKでです。
バッチの軸の結合は一切書かずにdict構造を保ったまま変換してくれるのが便利ですね。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー