OpenAIのChatGPT APIの並列化を試す(LangChain)
大量のデータをChatGPTで推論したいときに、並列化したらどの程度速くなったのかを実験してみました。振れ幅はかなり大きいですが、かなり並列化の効果はありました。
目次
はじめに
ChatGPTのAPIは普通に叩いていると遅いので並列化したい…という話。forループで大量のデータに対してAPIを叩く場合を想定
※OpenAIのAPIリミットがアカウント単位で異なるので、全てのパターンでこの計測結果になるとは限りません
https://platform.openai.com/docs/guides/rate-limits
使用するライブラリ
- LangChain v0.0.220
- Guidance v0.0.64
※Guidanceはv0.0.64では、内部のエラーハンドリングでうまくいかない部分があり(502 Bad Gateway)、try~catchで別途実装しないとリトライがうまくいかないことがわかりました。Guidance派は今後のアップデートに期待しましょう。
この点はLangChainはうまくいったので、LangChainに焦点をあてて実験します。
実験
- 並列化なし
- 並列化あり
- 並列化数 2
- 並列化数 4
- 並列化数 8
- 並列化数 16
並列化はRayを使い、並列化数はray.initのnum_cpusの値を指定します。
実験内容
こちらの記事から国の一覧データを取得します(249カ国)。各国に対し、「{国名}の観光名所を1つ簡潔に教えてください」という問いを、ChatGPTに対してします。使用するモデルは「gpt-3.5-turbo-0613」です。
結果
並列化数 | 1回目(s) | 2回目(s) | 3回目(s) | 平均 |
---|---|---|---|---|
なし | 1575 | – | – | 1575 |
2 | 496 | 811 | 530 | 612 |
4 | 784 | 282 | 439 | 502 |
8 | 429 | 432 | 354 | 405 |
16 | 160 | 346 | 334 | 280 |
失敗したケースはどれもありませんでした。502エラーは出ていますが、LangChainのリトライの中で吸収されています。
がーっと進むときもあれば、エラーなのかRete Limitなのかで一時的に止まることもあります。
結論:振れ幅かなり大きいけど、雑に大きめの並列化で良さそう
コード
並列化なし
import os
os.environ["OPENAI_API_KEY"] = "<your-openai-api-key>"
import pandas as pd
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
import time
def main():
df = pd.read_csv("country_data.csv", encoding="utf-8")
chat = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=1, max_tokens=256)
prompt_template = HumanMessagePromptTemplate.from_template("{country}の観光名所を1つ簡潔に教えてください")
chat_prompt = ChatPromptTemplate.from_messages([prompt_template])
llm_chain = LLMChain(
llm=chat,
prompt=chat_prompt
)
failed_cnt = 0
start_time = time.time()
for country in df["国・地域名"].tolist():
try:
print(llm_chain.run(country=country))
except:
failed_cnt += 1
print("失敗数", failed_cnt)
print("経過時間", time.time()-start_time)
if __name__ == "__main__":
main()
並列化あり
import pandas as pd
from langchain import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import HumanMessagePromptTemplate, ChatPromptTemplate
import ray
import time
@ray.remote
def run_unit(chain, country):
try:
result = chain.run(country=country)
print(result)
return 0
except:
return 1
def main():
df = pd.read_csv("country_data.csv", encoding="utf-8")
chat = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=1, max_tokens=256)
prompt_template = HumanMessagePromptTemplate.from_template("{country}の観光名所を1つ簡潔に教えてください")
chat_prompt = ChatPromptTemplate.from_messages([prompt_template])
llm_chain = LLMChain(
llm=chat,
prompt=chat_prompt
)
result = {"time":[], "failure":[]}
for parallel in [2, 4, 8, 16]:
start_time = time.time()
ray.init(num_cpus=parallel)
results = []
for country in df["国・地域名"].tolist():
results.append(run_unit.remote(llm_chain, country))
results = ray.get(results)
elapsed = time.time() - start_time
result["time"].append(elapsed)
result["failure"].append(sum(results))
print(result)
ray.shutdown()
time.sleep(120)
print(result)
if __name__ == "__main__":
main()
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー