LangChainのRecursiveCharacterTextSplitterの動作がおかしいので自作する
LangChainのRecursiveCharacterTextSplitter.from_tiktoken_encoderの動作が思ってたのと違ったので、それに相当するものを自作してみました。とてつもなく大きいトークン数のチャンクが出てきたときに有効かもしれません。
目次
概要
RecursiveCharacterTextSplitter.from_tiktoken_encoderがちゃんとトークン数をコントロールできていない(2024年2月現在)ので、自作しました
環境
- LangChain0.0.350
- pypdfium24.25.0
- tiktoken0.5.2
LangChainの場合
令和5年度版の外交青書のPDFを例に、PyPdfium2で読み取り、LangChainのRecursiveCharacterTextSplitter.from_tiktoken_encoderを使い、適当なトークンサイズのチャンクに区切る。
ここでトークンサイズは512、チャンクオーバーラップは128とする。生成されたチャンクが指定されたトークンサイズを満たしているのかを、tiktokenを使い事後検証する。事後検証では、全チャンクに対する最小・最大トークン数を計算する。RecursiveCharacterTextSplitterの原理上、本来この最大トークン数はトークンサイズにほぼ一致しているはずである。
LangChainを使った場合のコードを示す。
import pypdfium2 as pdfium
from langchain.text_splitter import RecursiveCharacterTextSplitter
import tiktoken
import time
def langchain_splitter(join_delimiter="\n\n", separator="\n\n", chunk_size=512, chunk_overlap=128):
pdf_read_start_time = time.time()
pdf = pdfium.PdfDocument("100523089.pdf")
contents = []
for page in pdf:
contents.append(page.get_textpage().get_text_range().replace("\r\n", "\n"))
contents = join_delimiter.join(contents)
pdf_read_end_time = time.time()
chunk_start_time = time.time()
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
separators=[separator],
chunk_size=chunk_size,
encoding_name="cl100k_base",
chunk_overlap=chunk_overlap,
)
chunks = text_splitter.split_text(contents)
chunk_end_time = time.time()
# Verification
encoder = tiktoken.get_encoding("cl100k_base")
encoded_text = encoder.encode_batch(chunks)
num_tokens = [len(t) for t in encoded_text]
if len(num_tokens) > 0:
print(f"Min tokens: {min(num_tokens)}, Max tokens : {max(num_tokens)}, Len chunks : {len(chunks)}")
print("PDF Read Elapsed : ", pdf_read_end_time - pdf_read_start_time)
print("Chunk Elapsed : ", chunk_end_time - chunk_start_time)
if __name__ == "__main__":
langchain_splitter(join_delimiter="\n\n")
ところが、生成されたチャンク内での最大トークン数は、Splitterで指定したトークンサイズと一致しない。ページごとのjoin_delimiter="\n\n"
の場合、
Min tokens: 58, Max tokens : 3503, Len chunks : 433
PDF Read Elapsed : 7.720760822296143
Chunk Elapsed : 0.4266848564147949
Max tokensが512を大きく超えているのである。join_delimiter="\n"
の場合、酷いことになる。全トークンが1個のチャンクに放り込まれる。これはLLMはたまったものじゃない。
Min tokens: 690650, Max tokens : 690650, Len chunks : 1
PDF Read Elapsed : 7.665188789367676
Chunk Elapsed : 0.4167459011077881
LangChainのバージョン問題でもない
LangChainのバージョン問題かと思って、Colabで以下のバージョンをインストールして検証たが、改善しなかった。日本語特有のバグなのかもしれない。
- langchain0.1.7
- tiktoken0.6.0
自作する
仕方がないので自作する。
import pypdfium2 as pdfium
import tiktoken
import time
def custom_recursive_splitter(contents, encoder="cl100k_base",
separator="\n\n", chunk_size=512, chunk_overlap=128):
encoder = tiktoken.get_encoding(encoder)
separator_token = encoder.encode(separator)
chunks = []
cursor = 0
while cursor < len(contents):
encoded_tokens = encoder.encode(contents[cursor:cursor+chunk_size*3//2]) # ここは適当
if separator_token[0] not in encoded_tokens:
tokens_to_take = chunk_size
else:
min_separator_index = encoded_tokens.index(separator_token[0])
tokens_to_take = max(min(min_separator_index, chunk_size), chunk_overlap)
chunk_tokens = encoded_tokens[:tokens_to_take]
current_chunk = encoder.decode(chunk_tokens)
if tokens_to_take <= chunk_overlap:
cursor += len(current_chunk)
else:
overlap_start_chunk = encoder.decode(
encoded_tokens[:tokens_to_take-chunk_overlap])
cursor += len(overlap_start_chunk)
chunks.append(current_chunk)
return chunks
def run_custom_splitter(join_delimiter="\n\n", separator="\n\n", chunk_size=512, chunk_overlap=128):
pdf_read_start_time = time.time()
pdf = pdfium.PdfDocument("100523089.pdf")
contents = []
for page in pdf:
contents.append(page.get_textpage().get_text_range().replace("\r\n", "\n"))
contents = join_delimiter.join(contents)
pdf_read_end_time = time.time()
chunk_start_time = time.time()
chunks = custom_recursive_splitter(contents,
separator=separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
chunk_end_time = time.time()
# Verification
encoder = tiktoken.get_encoding("cl100k_base")
encoded_text = encoder.encode_batch(chunks)
num_tokens = [len(t) for t in encoded_text]
if len(num_tokens) > 0:
print(f"Min tokens: {min(num_tokens)}, Max tokens : {max(num_tokens)}, Len chunks : {len(chunks)}")
print("PDF Read Elapsed : ", pdf_read_end_time - pdf_read_start_time)
print("Chunk Elapsed : ", chunk_end_time - chunk_start_time)
if __name__ == "__main__":
run_custom_splitter(join_delimiter="\n")
処理は愚直なやり方なので、高速化はほとんど期待できない。join_delimiter="\n\n"
の場合:
Min tokens: 127, Max tokens : 513, Len chunks : 2544
PDF Read Elapsed : 8.0279860496521
Chunk Elapsed : 1.2823963165283203
join_delimiter="\n"
の場合:
Min tokens: 254, Max tokens : 513, Len chunks : 1798
PDF Read Elapsed : 7.655973672866821
Chunk Elapsed : 0.9463059902191162
このように自作すれば最大トークン数は担保できた。512ではなく513となっているが、ここは誤差であろう。少なくとも60万トークン一気にくるなんてことはないので安心である。
チャンク生成は2倍ぐらい遅くなっているが、そもそも処理のボトルネックはPDFの読み込みにあるので、チャンク生成が遅くなってもこの例ではそこまで大きな差ではない。どうしても気になるなら処理を並列化すればいい(シーケンシャルサーチなので面倒だけど)。tiktoken自体が高速(むしろそれが売り)なので、こんな何度も呼び出すような処理でもなんとかなってしまう。
LangChainに対するぼやき
「最近LangChain使う意味ないよな。JSONモードできちゃったし、ストリーミングの取り回しもOpenAIのライブラリのほうが簡単だし、RAGも自分で実装できるし」と漠然と感じていた。「でもRecursiveTextSplitterは便利だからこれだけ使おうかな」と思っていた
しかし、RecursiveTextSplitterがかなり変なチャンク生成をすることがわかったので、「もうLangChain当分使うことはないな。バージョンアップで互換性いろいろ犠牲になるし、プロンプトもあんまり良くないし」という確信に変わってしまった。あれをプロダクションで使うのはしんどいと気がする。LangChainさようなら~ LangSmithは気になったら使ってみるかも
Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内
技術書コーナー
北海道の駅巡りコーナー