こしあん
2023-09-29

Streamlit+LangChainでストリーミング対応しつつPDFに複数の質問をさせる


2k{icon} {views}


Streamlit+LangChainでChatGPTのストリーミング表示を実装してみます。PDFの検索ベースで、かつテンプレートの質問を連続的に行うという実践的な例を紹介します。LangChainのコールバックの実装と、UIへのつなぎ込みの部分に工夫が必要です。

はじめに

Streamlit+LangChainでChatGPTをストリーミング表示させるのはいくつかありますが、単発の質問で複数の質問を表示させたり、ページを切り替えたときにチャット履歴を保持したり、実践的な例がほとんどなかったので試してみました。結論からいうとできました。ただ2023年9月現在、なかなかネットに出てこない情報を手探りで試した感はあったのでメモとしておいておきます。

最終的に作りたいもの

  • アプリとしてのガワ(フロントエンド)はStreamlitを使う
  • PDFから検索してQAさせたい
  • RetrievalQAさせるのだがストリーミング表示させたい
  • 複数のテンプレート質問を連続的に行いたい(インタラクティブでなくて良い)

LangChain単体でストリーミング

ChatGPTのストリーミングだとLangChainではなく、OpenAIのライブラリを使ったほうが便利だったりします。

for response in openai.ChatCompletion.create(
        model=st.session_state["openai_model"],
        messages=[{"role": m["role"], "content": m["content"]} for m in st.session_state.messages],
        stream=True,
    ):
    # ここに何らかの処理を書く

これはStreamlit公式のサンプルプログラムからですが、OpenAIのライブラリを使って実装しています。OpenAIのライブラリですとストリーミング対応は、

  • 「stream=True」のオプションを入れる。
  • ストリームで逐次生成されるトークン(英語の場合は単語)が、ジェネレーターの形式で飛んでくるので、forループで待ってれば取り出せる

なので扱いが楽です。ところがLangChainになると、OpenAIのライブラリではforループで良いところを、独自のコールバックで実装しないといけなく、この情報がそこまでネットになかったので逆にわかりづらかったりします。このコールバックはBaseCallbackHandlerというLnagChainで用意されているクラスを継承して実装するものです。

まず先に全体のプログラムを示します。UIとのつなぎ込みはなしで、LangChain単位でストリーミングさせつつPDFに質問させるプログラムです。OpenAIのAPIKeyは「OPENAI_API_KEY」という環境変数に入れておいてください。

PDFの検索部分のコードはML Bearさんのこちらの記事を参考にさせていただきました(ほぼコピペ)。Qdrantはメモリ上に展開しています。PDFは最近読んだこちらのPDFをローカルにDLして保存しておきます。

from typing import Any

from PyPDF2 import PdfReader
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Qdrant
from langchain.callbacks.base import BaseCallbackHandler

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

class StreamHandler(BaseCallbackHandler):
    def __init__(self, initial_text=""):
        self.text=initial_text
        self.cnt = 0

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        print("\n---Question", self.cnt, " start---")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        print(token, end="")

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        print("\n---Question", self.cnt, " end---")
        self.cnt += 1

def main():
    filename = "2211.11559.pdf"
    pdf_reader = PdfReader(filename)
    text = '\n\n'.join([page.extract_text() for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        separators=["\n\n"],
        chunk_size=512,
        chunk_overlap=128,
    )
    pdf_text = text_splitter.split_text(text)

    # 今ここ揮発性DBだから永続化したいときは別なのに変えてね
    client = QdrantClient(":memory:")

    collections = client.get_collections().collections
    collection_names = [collection.name for collection in collections]
    collection_name = filename

    if collection_name not in collection_names:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )

    qdrant = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=OpenAIEmbeddings())
    qdrant.add_texts(pdf_text)

    stream_handler = StreamHandler()
    llm = ChatOpenAI(model="gpt-3.5-turbo-16k-0613", temperature=0.2, max_tokens=512, streaming=True, callbacks=[stream_handler])
    retriever = qdrant.as_retriever(
        search_type="similarity",
        search_kwargs={"k":4}
    )
    qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff", 
            retriever=retriever,
            return_source_documents=True,
        )

    questions = [
        "What is this paper about?",
        "What's so great about it compared to previous research?",
        "What is the key to the technique or method?",
        "How does it validate the proposed method?",
        "Is there any discussion of this paper?",
        "What paper should I read next to better understand this paper?"
    ]
    for question in questions:
        qa(question)

if __name__ == "__main__":
    main()

動かすと次のようになります(拡大してみてください)。

ポイントは以下のStreamHandlerの定義の部分で、よくStreamingの例として以下のようなStreamHandlerの実装を目にすることがあると思います。

class StreamHandler(BaseCallbackHandler):
    def __init__(self, container, initial_text=""):
        self.container = container
        self.text = initial_text

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        self.container.markdown(self.text)

これはStreamlit公式のストリーミングデモコードです。「on_llm_new_token」で「トークンがきたときにテキストを追加しているな」というのはなんとなくわかると思います。このコールバックはLangChainの機能なので、LangChainでチュートリアルページを作るべきだと思うのですが、LangChainの公式だとStreamingStdOutCallbackHandlerを使うだけで終わりという割と不親切な感じなのですよね。

上記のコードでは、1つのメッセージに対するストリーミングのやり取りはできるのですが、複数の質問を連鎖的に流したときに、そのメッセージを記録する場所(ステート)の切り替えまでは実装できていません。ここを対応する必要があります。

結論は最初に示したコードで、コールバックの「on_llm_start」「on_llm_end」で対応します。

class StreamHandler(BaseCallbackHandler):
    def __init__(self, initial_text=""):
        self.text=initial_text
        self.cnt = 0

    def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
        print("\n---Question", self.cnt, " start---")

    def on_llm_new_token(self, token: str, **kwargs) -> None:
        self.text += token
        print(token, end="")

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        print("\n---Question", self.cnt, " end---")
        self.cnt += 1

ディープラーニングやっている方なら、Kerasのコールバックと同じと理解すればわかりやすいと思います(自分はこれに気づいて腑に落ちました)。「on_llm_start」はLLMの返答開始時に呼ばれる関数で、「on_llm_end」はLLMの返答終了時に呼ばれる関数です。メッセージ間にまたがるステートの切り替えはここで実装すればいいことになります。

この例では、メッセージ単位のインデックスをつけて、返答開始・終了時にPrefix, Suffixとなるメッセージを表示しているだけです。

Streamlitとつなぎこんだ場合

Streamlitにつなぎこんだ場合のコードがこちらです。

from typing import Any
from PyPDF2 import PdfReader
from langchain.chains import RetrievalQA
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Qdrant
from langchain.callbacks.base import BaseCallbackHandler
import streamlit as st

from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

if "messages" not in st.session_state:
    st.session_state.messages = []

class StreamHandler(BaseCallbackHandler):
    def __init__(self, initial_text=""):
        self.initial_text = initial_text
        self.text = initial_text

    def on_llm_start(self, *args: Any, **kwargs: Any):
        self.text = self.initial_text
        # Weird code. But just works fine.
        with st.chat_message("assistant"):
            self.container = st.empty()

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        # Add to UI Only
        self.text += token
        self.container.markdown(self.text)
        # print(token, end="")

    def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
        # Add to state
        st.session_state.messages.append({
            "role": "assistant",
            "content": self.text
        })


def start_logic():
    pdf_reader = PdfReader("2211.11559.pdf")
    text = '\n\n'.join([page.extract_text() for page in pdf_reader.pages])
    text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
        separators=["\n\n"],
        chunk_size=512,
        chunk_overlap=128,
    )
    pdf_text = text_splitter.split_text(text)

    client = QdrantClient(":memory:")

    collections = client.get_collections().collections
    collection_names = [collection.name for collection in collections]
    collection_name = "pdf_content"

    if collection_name not in collection_names:
        client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
        )

    qdrant = Qdrant(
        client=client,
        collection_name=collection_name,
        embeddings=OpenAIEmbeddings())
    qdrant.add_texts(pdf_text)

    stream_handler = StreamHandler()
    llm = ChatOpenAI(model="gpt-3.5-turbo-16k-0613", temperature=0.2, max_tokens=512, streaming=True, callbacks=[stream_handler])
    retriever = qdrant.as_retriever(
        search_type="similarity",
        search_kwargs={"k":4}
    )
    qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff", 
            retriever=retriever,
            return_source_documents=True,
        )

    questions = [
        "What is this paper about?",
        "What's so great about it compared to previous research?",
        "What is the key to the technique or method?",
        "How does it validate the proposed method?",
        "Is there any discussion of this paper?",
        "What paper should I read next to better understand this paper?"
    ]
    for question in questions:
        # Add to state
        st.session_state.messages.append({
            "role": "user",
            "content": question
        })
        # Add to UI
        with st.chat_message("user"):
            st.markdown(question)
        qa(question)


st.title("Chat with PDF Paper")

# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

# main button
button = st.button("Click to start chat")
if button:
    start_logic()

これを実行した結果が先程示した冒頭の動画になります。これを自力で実装するには、Streamlitのステートとビューと再読み込みのタイミングの不思議な関係を知る必要があります。

Streamlitではステート変数はst.session_stateという場所に保存され、基本的にステート変数とUIはリンクすることもあります。例えば、

st.markdown(st.session_state.hoge)

というUIの表示があったとして、hogeがチェックボックスのようなコンポーネントの値を記録し、チェックボックスに動作が入りhogeというステート変数が変わると、UI側に自動的に反映されます。

なら、「ステート変数にストリーミングの内容を逐次記録していけば、順次反映されるのではないの?」と思うかもしれませんが、そうなるとは限らないこともあります。私もちゃんと理解しきれていないのですが、例えば別スレッドで動かしたプログラムからステートをいじった場合、UIへの反映がトリガーされないようで、単にストリーミングからステート変数に書き出しているだけではUIへの反映がされないこともありました。dictやlistの変更がキャッチできていないのかもしれません(それだったらVueでもある例なので納得はいきます)

今回のストリーミングの場合は、コールバック内でUIの書き換えを明示することではじめてUIの変更が反映されました。

今回はうまく回避したのですが、Streamlitの場合、ボタンを押すなどのUIのトリガーが走ったときの、ページの再読み込みで全画面を再描画するという癖の強い仕様があります。なので、ストリーミングチャットを2個同時に並べて表示すると、片方のストリーミングを更新中にもう片方のボタンを押すと実装次第では前にボタンを押したチャットが消えてしまうという問題が発生します。このへんの描画するエリアを限定できないのがStreamlitの気持ち悪い部分でもあり、実装する際にハマりがちな部分です。ストリーミングのような非同期っぽい処理を入れ込もうとすると、結構慣れが必要なのが難しい点です。

Gradioだと関数ドリブンの記述になるので、全画面再描画は回避できますが、UIが野暮ったいのでStreamlitを使いたいこともあります(あくまでPythonベースなら)。もっとちゃんとしたフロントエンド言語を使えばシュッとできちゃうかもしれないのですが、あくまでPython限定ということで。

ひとまずStreamlitでも動いたので安心しました。複数チャットはもう1画面に収めるのを諦めて2ページに跨がらせました(このコードだとチャット履歴は保存されているので、ストリーミングの実行完了した状態でページをまたいで戻ればチャット履歴は復活します)。



Shikoan's ML Blogの中の人が運営しているサークル「じゅ~しぃ~すくりぷと」の本のご案内

技術書コーナー

北海道の駅巡りコーナー


Add a Comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です