MemGPT/letta/interfaces/openai_chat_completions_streaming_interface.py

from typing import Any, AsyncGenerator, Dict, List, Optional

from openai import AsyncStream
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta

from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.interfaces.utils import _format_sse_chunk
from letta.server.rest_api.json_parser import OptimisticJSONParser


class OpenAIChatCompletionsStreamingInterface:
"""
Encapsulates the logic for streaming responses from OpenAI.
This class handles parsing of partial tokens, pre-execution messages,
and detection of tool call events.
"""

    def __init__(self, stream_pre_execution_message: bool = True):
        self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
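        # Assumption (inferred from how the parser is used below, not from its
        # docs): OptimisticJSONParser best-effort-parses an incomplete JSON
        # prefix, e.g. turning '{"msg": "Sea' into {"msg": "Sea"} (key name
        # illustrative). That is what lets pre-execution messages stream out
        # before the tool-call arguments have finished arriving.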
        self.stream_pre_execution_message: bool = stream_pre_execution_message
        # Last parsed view of the (still partial) tool-call arguments.
        self.current_parsed_json_result: Dict[str, Any] = {}
        # Assistant content tokens accumulated across the stream.
        self.content_buffer: List[str] = []
        self.tool_call_happened: bool = False
        self.finish_reason_stop: bool = False
        # Tool-call metadata accumulated from streamed deltas.
        self.tool_call_name: Optional[str] = None
        self.tool_call_args_str: str = ""
        self.tool_call_id: Optional[str] = None

    async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]:
        """
        Iterates over the OpenAI stream, yielding SSE events.

        It also collects tokens and detects if a tool call is triggered.
        """
        async with stream:
            async for chunk in stream:
                # Some chunks (e.g. usage-only chunks) can arrive with no choices.
                if not chunk.choices:
                    continue
                choice = chunk.choices[0]
                delta = choice.delta
                finish_reason = choice.finish_reason
                async for sse_chunk in self._process_content(delta, chunk):
                    yield sse_chunk
                async for sse_chunk in self._process_tool_calls(delta, chunk):
                    yield sse_chunk
                if self._handle_finish_reason(finish_reason):
                    break

    async def _process_content(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]:
        """Processes regular content tokens and streams them."""
        if delta.content:
            self.content_buffer.append(delta.content)
            yield _format_sse_chunk(chunk)

    async def _process_tool_calls(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]:
        """Handles tool call initiation and streaming of pre-execution messages."""
        if not delta.tool_calls:
            return

        tool_call = delta.tool_calls[0]
        self._update_tool_call_info(tool_call)

        # Arguments arrive as string fragments spread across many deltas;
        # accumulate them so the optimistic parser can work on the full prefix.
        if self.stream_pre_execution_message and tool_call.function.arguments:
            self.tool_call_args_str += tool_call.function.arguments
            async for sse_chunk in self._stream_pre_execution_message(chunk, tool_call):
                yield sse_chunk

    def _update_tool_call_info(self, tool_call: Any) -> None:
        """Updates tool call-related attributes."""
        if tool_call.function.name:
            self.tool_call_name = tool_call.function.name
        if tool_call.id:
            self.tool_call_id = tool_call.id

    async def _stream_pre_execution_message(self, chunk: ChatCompletionChunk, tool_call: Any) -> AsyncGenerator[str, None]:
        """Parses and streams pre-execution messages if they have changed."""
        parsed_args = self.optimistic_json_parser.parse(self.tool_call_args_str)
        if parsed_args.get(PRE_EXECUTION_MESSAGE_ARG) and parsed_args[PRE_EXECUTION_MESSAGE_ARG] != self.current_parsed_json_result.get(
            PRE_EXECUTION_MESSAGE_ARG
        ):
            # Extract old and new message content
            old = self.current_parsed_json_result.get(PRE_EXECUTION_MESSAGE_ARG, "")
            new = parsed_args[PRE_EXECUTION_MESSAGE_ARG]
            # Compute the new content by slicing off the old prefix
            content = new[len(old) :] if old else new
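            # For example (assuming the message only ever grows by appending):
            # old == "Searching" and new == "Searching the web" yields
            # content == " the web", so each fragment is emitted exactly once.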
            # Update current state
            self.current_parsed_json_result = parsed_args

            # Yield the formatted SSE chunk
            yield _format_sse_chunk(
                ChatCompletionChunk(
                    id=chunk.id,
                    object=chunk.object,
                    created=chunk.created,
                    model=chunk.model,
                    choices=[Choice(index=0, delta=ChoiceDelta(content=content, role="assistant"), finish_reason=None)],
                )
            )

    def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool:
        """Handles the finish reason and determines if streaming should stop."""
        if finish_reason == "tool_calls":
            self.tool_call_happened = True
            return True
        if finish_reason == "stop":
            self.finish_reason_stop = True
            return True
        return False
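

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# Assumes an AsyncOpenAI client configured via OPENAI_API_KEY; the model name
# and prompt below are placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    from openai import AsyncOpenAI

    async def _demo() -> None:
        client = AsyncOpenAI()
        stream = await client.chat.completions.create(
            model="gpt-4o-mini",  # placeholder model
            messages=[{"role": "user", "content": "Hello!"}],
            stream=True,
        )
        interface = OpenAIChatCompletionsStreamingInterface()
        # Each yielded item is an SSE-formatted string ready to forward to a client.
        async for sse_event in interface.process(stream):
            print(sse_event, end="")
        if interface.tool_call_happened:
            print("tool call:", interface.tool_call_name, interface.tool_call_args_str)

    asyncio.run(_demo())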