Pythonでzipを使ってJSONライクなdictをいい感じにforループで回す
Pythonでzip関数を使ってJSONライクな辞書(dict)を、いい感じにforループで要素ごとに抽出する例を解説します。ほぼワンライナーでできるので簡単です。
目次
想定
こんなJSONライクな辞書があったとします。JSONからパースした場合でもいいですね。
data = {
"in_channels": [1024, 1024, 512, 256, 128],
"out_channels": [1024, 512, 256, 128, 64],
"upsample": [True, True, True, True, True],
"resolution": [8, 16, 32, 64, 128],
"attention": {
8: False, 16:False, 32:False, 64:True, 128:False
}
}
これをforループで回して、各配列の中身を取得したいのです。例えば、こんな出力にしたいのです。
(1024, 1024, True, 8, False)
(1024, 512, True, 16, False)
(512, 256, True, 32, False)
(256, 128, True, 64, True)
(128, 64, True, 128, False)
実はこれはほぼワンライナーでできます。それを見ていきましょう。
zip関数の基本
まずはzip関数の基本から。zip関数は
for a, b in zip([1, 2, 3], [4, 5, 6]):
print(a, b)
とすると
1 4
2 5
3 6
のように出力するのがzip関数です。要はzip関数の中に複数の一次元配列を放り込めればいいのです。
ただしdictの配列数(要素数)は可変なので、真面目にzip(a, b, …)のように書くのはアホらしいです。どうすればいいでしょう?
アンパックのアスタリスク「*」
答えは、シーケンスのアンパック・アスタリスク(*)を使いましょう。可変長引数でも使われるやつですね。zipと組み合わせると先程の例は次のようにも書けます。
data = [[1, 2, 3], [4, 5, 6]]
for a, b in zip(*data):
print(a, b)
これは先程と同じ出力になります。
1 4
2 5
3 6
dictの値だけ配列の配列に格納するには、dict.values()を使えばいいのでこれでできそうな気がします。
ネストされた辞書のケア
辞書の中身がすべて配列だったらOKですが、values()→アンパック→zipだとネストされた辞書が要素にあったときに困ります。例えば、
def main():
data = {
"in_channels": [1024, 1024, 512, 256, 128],
"out_channels": [1024, 512, 256, 128, 64],
"upsample": [True, True, True, True, True],
"resolution": [8, 16, 32, 64, 128],
"attention": {
8: False, 16:False, 32:False, 64:True, 128:False
}
}
for seq in zip(*data.values()):
print(seq)
このようにすると、
(1024, 1024, True, 8, 8)
(1024, 512, True, 16, 16)
(512, 256, True, 32, 32)
(256, 128, True, 64, 64)
(128, 64, True, 128, 128)
ネストされた辞書(attention)のキーだけ抽出されてしまいました。ネストされた辞書の値がほしいのでちょっとこれは困りました。
そこで以下のようにループ側で対策します。
# ネストされた辞書対策
def main():
data = {
"in_channels": [1024, 1024, 512, 256, 128],
"out_channels": [1024, 512, 256, 128, 64],
"upsample": [True, True, True, True, True],
"resolution": [8, 16, 32, 64, 128],
"attention": {
8: False, 16:False, 32:False, 64:True, 128:False
}
}
for seq in zip(*(v.values() if type(v) is dict else v for v in data.values())):
print(seq)
単純にvalues()で出てきた値がdictかどうかを調べているだけです。もしvがネストされた辞書だったら、さらにそれのvalues()を読みに行くことで、値側を取りに行くようにしています。
結果は以下の通りです。
(1024, 1024, True, 8, False)
(1024, 512, True, 16, False)
(512, 256, True, 32, False)
(256, 128, True, 64, True)
(128, 64, True, 128, False)
これで目的の結果になりました。
注意点
この例はPython3.7以上で行っています。Python3.7以上ではdictの順番が保証されるのが明確に言語仕様として追加されたのでOrderedDictを使わなくてもキーの順番保証ができるようになっています。Pythonのバージョンが古い場合は、OrderedDictを使って順番のケアをしてある必要があるかもしれません(あるいはライブラリによってはOrderedDictでパースしているケースがあるので、そのタイプのケアなど)。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー