Kerasのジェネレーターでサンプルが列挙される順番について
Kerasの(カスタム)ジェネレーターでサンプルがどの順番で呼び出されるか、1ループ終わったあとにどういう処理がなされるのか調べてみました。ジェネレーターを自分で定義するとモデルの表現の幅は広がるものの、バグが起きやすくなるので「本当に順番が保証されるのか」や「ハマりどころ」を確認します。
目次
0~9の数字をループさせるジェネレーター
Kerasのジェネレーターは、無限ループさせてその中で訓練データをループさせるという構造を取ります。
def generator(batch_size):
X_cache, y_cache = [], []
while True:
for i in range(10):
X_cache.append(i)
y_cache.append(i)
if len(X_cache)==batch_size:
X_batch = np.asarray(X_cache)
y_batch = np.asarray(y_cache)
X_cache, y_cache = [], []
yield X_batch, y_batch
このように、Xとyのキャッシュを用意しておいてどんどん突っ込んで、バッチサイズになったらNumpy配列として返すというのがわかりやすいのではないでしょうか(複数GPUのことは考えていないので、並列化はやらないものとします)。
このコードは「0→1→2…→9→0→1」の順にバッチサイズ分切り出して返すだけのジェネレーターです。
恒等出力モデル
Kerasで「入力=出力となるモデル」を作ります。Lambdaレイヤーを使って、「lambda x: x」のような恒等出力関数を導入します。これによって、predictしたときに入力の値がそのまま取り出せるというわけです。
from keras.layers import Lambda, Input
from keras.models import Model
import numpy as np
input = Input((1,))
x = Lambda(lambda x: x)(input)
model = Model(input, x)
検証
バッチサイズを7(中途半端!)な値にして、predict_generatorさせて結果を見ます。generatorの場合はstepsで回す回数を指定します。この場合は、「0~6」と「7~9+0~3」が出てくれば想定どおりの挙動です。
result = model.predict_generator(generator(7), steps=2, max_queue_size=10)
print(result)
[[0.]
[1.]
[2.]
[3.]
[4.]
[5.]
[6.]
[7.]
[8.]
[9.]
[0.]
[1.]
[2.]
[3.]]
とりあえず想定どおりの結果が出てきましたね。
max_queue_sizeを変えると?
ちなみにmax_queue_sizeはデフォルトで10ですが、これはバッチサイズを大きくしたときに、RAM(メモリ)を圧迫するので下げたほうがいい場合もあります。この値を変えても同じように順番が保証されるのでしょうか?
result = model.predict_generator(generator(7), steps=2, max_queue_size=1)
print(result)
[[0.]
[1.]
[2.]
[3.]
[4.]
[5.]
[6.]
[7.]
[8.]
[9.]
[0.]
[1.]
[2.]
[3.]]
とりあえず、max_queue_sizeを変えてもジェネレーターがシーケンシャルである限り順番は保証されます。
ありがちなミス1:ジェネレーターの中で並列化(マルチプロセッシング)をする
「ジェネレーターがシーケンシャルである限り」という条件です。よくありがちな例として、訓練(train)のジェネレーターは、並列化しても問題ありません。なぜなら、ジェネレーターの中では、仮に並列化してもXとyの順番を保証することは容易だからです。並列化する関数でXとyを同時に返せば保証されます。
ここでの並列化というのは、multiprocessingやjoblibによるマルチプロセス化を示します。jpegの読み込みでこのマルチプロセス化をすると結構速くなるので便利です。
しかし、テストのジェネレーターのように、ジェネレーターの外側に正しいラベルデータがあって、ジェネレーターは外側の順番に一致するように返していくようなケースだと、この並列化は大変まずいです。なぜなら、並列化は基本的に順番が保証されないからです(これはPythonに限らずどのだいたいの言語でそうです)。「シーケンシャル=並列化をしない」と考えればOKです。
ありがちなミス2:ジェネレーターのインスタンスを使い回す
ジェネレーターのインスタンスを使い回すと、イテレーションの位置が継続されるので、バグのもとになります。
result1 = model.predict_generator(generator(3), steps=1)
result2 = model.predict_generator(generator(3), steps=1)
print(result1)
print(result2)
このようにeval_generatorやpredict_generatorのタイミングで、ジェネレーターのインスタンスを別に直すとイテレーションの位置がリセットされます。
[[0.]
[1.]
[2.]]
[[0.]
[1.]
[2.]]
しかし、ジェネレーターのインスタンスを一度宣言して使い回すとイテレーションの位置が継続されます。
gen = generator(3)
result1 = model.predict_generator(gen, steps=1, max_queue_size=0)
result2 = model.predict_generator(gen, steps=1, max_queue_size=0)
print(result1)
print(result2)
#[[0.]
# [1.]
# [2.]]
#[[8.]
# [9.]
# [0.]]
また、イテレーションの位置はmax_queue_sizeによるキューの分も入っているので、max_queue_sizeの値を変えたり、または呼び出しのたびにも値は変わってきます。
gen = generator(3)
result1 = model.predict_generator(gen, steps=1, max_queue_size=5)
result2 = model.predict_generator(gen, steps=1, max_queue_size=5)
print(result1)
print(result2)
#[[0.]
# [1.]
# [2.]]
#[[1.]
# [2.]
# [3.]]
非常に面倒ですね。順番気にするときはジェネレーターのインスタンスの使い回しをやめましょう。
ありがちなミス3:キャッシュを無限ループの中でリセットする
これはただのケアレスミスですが、このコードを
def generator(batch_size):
X_cache, y_cache = [], []
while True:
for i in range(10):
X_cache.append(i)
y_cache.append(i)
if len(X_cache)==batch_size:
X_batch = np.asarray(X_cache)
y_batch = np.asarray(y_cache)
X_cache, y_cache = [], []
yield X_batch, y_batch
このように、X_cache, y_cacheの宣言の位置を微妙に変えて、無限ループの内側でやるとまたややこしいことになります。
def generator(batch_size):
while True:
X_cache, y_cache = [], []
for i in range(10):
X_cache.append(i)
y_cache.append(i)
if len(X_cache)==batch_size:
X_batch = np.asarray(X_cache)
y_batch = np.asarray(y_cache)
X_cache, y_cache = [], []
yield X_batch, y_batch
gen = generator(7)
result = model.predict_generator(gen, steps=2, max_queue_size=1)
print(result)
[[0.]
[1.]
[2.]
[3.]
[4.]
[5.]
[6.]
[0.]
[1.]
[2.]
[3.]
[4.]
[5.]
[6.]]
7,8,9→「リセット」となってしまうので、7~9が呼び出されることが永遠にありません。ケアレスミスですが気をつけましょう。
まとめ
Kerasのジェネレーター絡みはかなりバグが出やすい。そしてpredictのように順番気にする場合は以下の点に気をつけると良さそうです。
- 順番気にする(シーケンシャルにする)場合は、絶対に並列化をしない
- キャッシュをしてバッチに変換するのなら、変数の宣言は無限ループの外側に書く
- 複数回推論や評価をする場合は、ジェネレーターのインスタンスの使い回しをしない
また、より安全にいくために、以下のようにしてしまうといいと思います。
- サンプル数がバッチサイズの倍数になるように、ダミーデータや既存のデータ(例えば1番目のデータ)を末尾にコピーして端数が出ないように調整する
こんなところでしょうか。テストデータ数が素数だったりすると地味に発狂したりします(ダミーデータ入れればいいだけですが)。
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー