返回

LangChain LLMChain 流式传输:异步生成器详解与实战

python

从 LangChain 的 LLMChain 获取异步生成器进行流式传输

在使用 LangChain 与大型语言模型(LLM)构建流式应用程序时,开发人员可能会遇到一个常见的难题:LLMChainastream 方法并未按照预期返回异步生成器,导致无法逐 token 进行流式传输。 本文探讨问题根源,并提供多种解决方案。

问题剖析

期望的行为是, astream 方法能返回一个异步生成器,让我们可以按顺序获取每个生成的文本块。 但实际情况可能并非如此,终端上出现的是完整输出,而不是流式的 token。 这主要是因为:

  1. 回调处理器的限制: 默认的回调处理器,例如 StreamingStdOutCallbackHandler,设计用来在终端输出流式内容,不直接向用户代码提供 token 级别的控制。 当使用 StreamingStdOutCallbackHandler 时,会直接输出整个文本流,而不是生成器。
  2. 异步生成器的误解: LLMChainastream 方法,在配置不当时,可能并没有以正确的方式产生异步生成器,而是等整个输出完成后才一并返回。 这导致客户端看似接收到单一的响应,而非数据流。

解决方案一:使用自定义的回调处理器

通过创建自定义回调处理器,我们可以精确控制模型输出的 token 处理方式,并正确实现异步生成。 该方法将从每个 token 中生成一个异步生成器。

操作步骤:

  1. 定义自定义回调处理器: 创建一个继承自 AsyncCallbackHandler 的类。 在 on_llm_new_token 方法中,使用 async yield 来发出 token。
  2. 初始化 LLM 时应用: 将自定义的回调管理器传给 LLM 模型,同时启用 streaming 参数。

代码示例:

from langchain.callbacks.base import AsyncCallbackHandler
from langchain.chat_models import AzureChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import AsyncCallbackManager
import asyncio

class CustomAsyncCallbackHandler(AsyncCallbackHandler):
    async def on_llm_new_token(self, token: str, **kwargs):
        yield token

class OpenAIModel:
    def __init__(self):
        self.llm = None

    def __call__(self, streaming:bool = False) -> str:
        if self.llm is None:
            openai_params = {
                # other params removed for debugging purpose
                'streaming': streaming,
                'callback_manager': AsyncCallbackManager([CustomAsyncCallbackHandler()]) if streaming else None,
                'verbose': True if streaming else False
            }

            self.llm = AzureChatOpenAI(**openai_params)
        return self

    async def streaming_answer(self, question: str):
      qaPrompt = PromptTemplate(
            input_variables=["question"], template="OPENAI_TEMPLATE"  # Anonymized template
        )

      chain = LLMChain(llm=self.llm, prompt=qaPrompt)
      async for chunk in chain.astream({"question": question}):
        async for token in chunk["text_generation"]:
           yield token

async def main():
    model = OpenAIModel()
    model(streaming=True)
    async for token in model.streaming_answer(question="hello"):
        print(token, end="", flush=True)
    print("\nStreaming completed!")

if __name__ == "__main__":
    asyncio.run(main())

解决方案二:手动处理 stream 返回

有些时候 LLMChain.astream 由于配置问题, 或者底层逻辑调整会不如预期, 没有返回预期的异步生成器。 可以尝试通过llm.stream获取流式数据, 自己处理, 构建异步生成器。
此方法可以直接使用底层的流式 API,对流的控制粒度更精细。

操作步骤:

  1. 使用 llm.stream 获取数据: llmstream 方法直接返回可迭代的响应。
  2. 创建异步生成器: 使用 async for 遍历stream 方法返回的可迭代对象, 并将解析得到的 token 通过 yield 输出。

代码示例:

from langchain.chat_models import AzureChatOpenAI
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain.callbacks.manager import AsyncCallbackManager
import asyncio

class OpenAIModel:
    def __init__(self):
        self.llm = None

    def __call__(self, streaming: bool = False):
        if self.llm is None:
           openai_params = {
            # other params removed for debugging purpose
            'streaming': streaming,
            'callback_manager': AsyncCallbackManager() if streaming else None,
            'verbose': True if streaming else False
          }

           self.llm = AzureChatOpenAI(**openai_params)

        return self

    async def streaming_answer(self, question: str):
      qaPrompt = PromptTemplate(
          input_variables=["question"], template="OPENAI_TEMPLATE"  # Anonymized template
        )
      chain = LLMChain(llm=self.llm, prompt=qaPrompt)
      async def stream():
         async for output in self.llm.stream(qaPrompt.format_prompt(question=question).to_messages()):
            yield output

      async for item in stream():
           if 'content' in item.dict()["message"]:
              yield item.dict()["message"]["content"]


async def main():
    model = OpenAIModel()
    model(streaming=True)
    async for token in model.streaming_answer(question="hello"):
       print(token, end="", flush=True)

    print("\nStreaming completed!")

if __name__ == "__main__":
    asyncio.run(main())

安全建议

在处理流式数据时,需要注意以下几点:

  • 速率限制: 在快速处理 token 流时,留意 OpenAI API 或你使用的服务的速率限制,避免请求过多导致错误。
  • 错误处理: 在流式过程中,增加对错误的鲁棒性处理。使用 try-except 代码块处理 API 错误或网络问题。
  • 内容安全: 需要实施额外的安全机制,用于监控或过滤输出中的潜在有害内容。可以配合 LangChain 的相关工具进行实现。

通过上述两种方法,您可以更好地控制 LLM 的输出流,实现高效和自定义的流式数据处理。选择哪种方案取决于您对代码可读性和精细控制的需求。