fix: Fix trailing } in chat completions interface (#842)

Matthew Zhou 2025-02-10 15:45:13 -08:00 committed by GitHub
parent 9a9bcbc0c5
commit 96c72fb157
6 changed files with 191 additions and 140 deletions

View File

@@ -431,7 +431,6 @@ class Agent(BaseAgent):
openai_message_dict=response_message.model_dump(),
)
) # extend conversation with assistant's reply
self.logger.info(f"Function call message: {messages[-1]}")
nonnull_content = False
if response_message.content:
@@ -445,10 +444,7 @@ class Agent(BaseAgent):
function_call = (
response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
)
# Get the name of the function
function_name = function_call.name
self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
# Failure case 1: function name is wrong (not in agent_state.tools)
target_letta_tool = None
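For context, the hunk above picks the call target from either the legacy function_call field or the newer tool_calls list. A minimal sketch of that selection logic, assuming an OpenAI-style response message (the helper name is illustrative, not part of the codebase):

def resolve_function_call(response_message):
    """Return (function_name, arguments_json, tool_call_id) from an OpenAI-style message."""
    if response_message.function_call is not None:
        # Legacy single-function format: no tool_call_id is available.
        function_call = response_message.function_call
        tool_call_id = None
    else:
        # Newer tool-calls format: take the first tool call.
        tool_call = response_message.tool_calls[0]
        function_call = tool_call.function
        tool_call_id = tool_call.id
    return function_call.name, function_call.arguments, tool_call_id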

View File

@@ -17,48 +17,45 @@ logger = get_logger(__name__)
def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
with httpx.Client() as client:
"""
Sends an SSE POST request and yields parsed response chunks.
"""
# TODO: Please note this is a very generous timeout for e2b reasons
with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
# Check for immediate HTTP errors before processing the SSE stream
if not event_source.response.is_success:
# handle errors
pass
logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
logger.warning(event_source.response.read().decode("utf-8"))
response_bytes = event_source.response.read()
logger.warning(f"SSE request error: {vars(event_source.response)}")
logger.warning(response_bytes.decode("utf-8"))
try:
response_bytes = event_source.response.read()
response_dict = json.loads(response_bytes.decode("utf-8"))
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
if (
"error" in response_dict
and "message" in response_dict["error"]
and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
):
logger.error(response_dict["error"]["message"])
raise LLMError(response_dict["error"]["message"])
error_message = response_dict.get("error", {}).get("message", "")
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
logger.error(error_message)
raise LLMError(error_message)
except LLMError:
raise
except:
logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
except Exception:
logger.error("Failed to parse SSE message, raising HTTP error")
event_source.response.raise_for_status()
try:
for sse in event_source.iter_sse():
# if sse.data == OPENAI_SSE_DONE:
# print("finished")
# break
if sse.data in [status.value for status in MessageStreamStatus]:
# break
if sse.data in {status.value for status in MessageStreamStatus}:
yield MessageStreamStatus(sse.data)
if sse.data == MessageStreamStatus.done.value:
# We received the [DONE], so stop reading the stream.
break
else:
chunk_data = json.loads(sse.data)
if "reasoning" in chunk_data:
yield ReasoningMessage(**chunk_data)
elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
elif chunk_data.get("message_type") == "assistant_message":
yield AssistantMessage(**chunk_data)
elif "tool_call" in chunk_data:
yield ToolCallMessage(**chunk_data)
@@ -67,33 +64,31 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStrea
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
yield ChatCompletionChunk(**chunk_data) # Add your processing logic for chat chunks here
yield ChatCompletionChunk(**chunk_data)
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
except SSEError as e:
logger.error("Caught an error while iterating the SSE stream:", str(e))
if "application/json" in str(e): # Check if the error is because of JSON response
# TODO figure out a better way to catch the error other than re-trying with a POST
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
if response.headers["Content-Type"].startswith("application/json"):
error_details = response.json() # Parse the JSON to get the error message
logger.error("Request:", vars(response.request))
logger.error("POST Error:", error_details)
logger.error("Original SSE Error:", str(e))
logger.error(f"SSE stream error: {e}")
if "application/json" in str(e):
response = client.post(url=url, json=data, headers=headers)
if response.headers.get("Content-Type", "").startswith("application/json"):
error_details = response.json()
logger.error(f"POST Error: {error_details}")
else:
logger.error("Failed to retrieve JSON error message via retry.")
else:
logger.error("SSEError not related to 'application/json' content type.")
# Optionally re-raise the exception if you need to propagate it
raise e
except Exception as e:
if event_source.response.request is not None:
logger.error("HTTP Request:", vars(event_source.response.request))
if event_source.response is not None:
logger.error("HTTP Status:", event_source.response.status_code)
logger.error("HTTP Headers:", event_source.response.headers)
logger.error("Exception message:", str(e))
logger.error(f"Unexpected exception: {e}")
if event_source.response.request:
logger.error(f"HTTP Request: {vars(event_source.response.request)}")
if event_source.response:
logger.error(f"HTTP Status: {event_source.response.status_code}")
logger.error(f"HTTP Headers: {event_source.response.headers}")
raise e
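To make the control flow above concrete, here is a minimal consumer sketch. The imports mirror those used in the test file later in this diff; the handling inside each branch is illustrative only, not the library's API:

from letta.client.streaming import _sse_post
from letta.schemas.enums import MessageStreamStatus
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

def consume_stream(url: str, payload: dict, headers: dict) -> None:
    for chunk in _sse_post(url, payload, headers):
        if isinstance(chunk, MessageStreamStatus):
            if chunk == MessageStreamStatus.done:
                break  # [DONE] marker: stop reading the stream
        elif isinstance(chunk, ChatCompletionChunk):
            # Partial assistant text arrives in choices[0].delta.content
            if chunk.choices and chunk.choices[0].delta.content:
                print(chunk.choices[0].delta.content, end="", flush=True)
        else:
            # ReasoningMessage, ToolCallMessage, LettaUsageStatistics, etc.
            print(chunk)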

View File

@@ -7,6 +7,7 @@ from openai import OpenAI
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.log import get_logger
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as _Message
from letta.schemas.message import MessageRole as _MessageRole
@@ -26,7 +27,7 @@ from letta.schemas.openai.embedding_response import EmbeddingResponse
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.utils import get_tool_call_id, smart_urljoin
OPENAI_SSE_DONE = "[DONE]"
logger = get_logger(__name__)
def openai_get_model_list(
@@ -354,9 +355,10 @@ def openai_chat_completions_process_stream(
except Exception as e:
if stream_interface:
stream_interface.stream_end()
print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
raise e
finally:
logger.info(f"Finally ending streaming interface.")
if stream_interface:
stream_interface.stream_end()
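The hunk above only shows the tail of openai_chat_completions_process_stream. A rough sketch of the guard pattern it implies follows; process_one_chunk is a hypothetical stand-in for the per-chunk handling, which the diff does not show:

from letta.log import get_logger

logger = get_logger(__name__)

def process_stream_guarded(stream_interface, chunk_iterator, process_one_chunk):
    try:
        for chunk in chunk_iterator:
            process_one_chunk(chunk)
    except Exception as e:
        if stream_interface:
            stream_interface.stream_end()
        logger.error(f"Parsing ChatCompletion stream failed with error:\n{e}")
        raise
    finally:
        logger.info("Finally ending streaming interface.")
        if stream_interface:
            stream_interface.stream_end()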

View File

@@ -41,7 +41,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
def __init__(
self,
multi_step: bool = True,
timeout: int = 150,
timeout: int = 3 * 60,
# The following are placeholders for potential expansions; they
# remain if you need to differentiate between actual "assistant messages"
# vs. tool calls. By default, they are set for the "send_message" tool usage.
@@ -55,6 +55,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
# Parsing state for incremental function-call data
self.current_function_name = ""
self.current_function_arguments = []
self.current_json_parse_result = {}
# Internal chunk buffer and event for async notification
self._chunks = deque()
@@ -85,6 +86,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
try:
await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
logger.warning("Chat completions interface timed out! Please check that this is intended.")
break
while self._chunks:
@@ -105,7 +107,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
self,
item: ChatCompletionChunk,
):
"""
"""m
Add an item (a LettaMessage, status marker, or partial chunk)
to the queue and signal waiting consumers.
"""
@@ -156,6 +158,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
Called externally with a ChatCompletionChunkResponse. Transforms
it if necessary, then enqueues partial messages for streaming back.
"""
# print("RECEIVED CHUNK...")
processed_chunk = self._process_chunk_to_openai_style(chunk)
if processed_chunk is not None:
self._push_to_buffer(processed_chunk)
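The class hands chunks from the producer path above to an async consumer through an internal deque plus an asyncio.Event, with the timeout guarding against a stalled producer. A standalone sketch of that hand-off; the class and method names here are illustrative, and the real interface carries extra termination logic:

import asyncio
from collections import deque

class ChunkBuffer:
    """Illustrative producer/consumer buffer; not the real streaming interface."""

    def __init__(self, timeout: float = 3 * 60):
        self.timeout = timeout
        self._chunks = deque()
        self._event = asyncio.Event()

    def push(self, item) -> None:
        # Producer side: enqueue the item and wake any waiting consumer.
        self._chunks.append(item)
        self._event.set()

    async def drain(self):
        # Consumer side: wait for new items, yield everything buffered, repeat.
        while True:
            try:
                await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                break  # assume the producer stalled; the real class also logs a warning
            while self._chunks:
                yield self._chunks.popleft()
            self._event.clear()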
@@ -216,37 +219,43 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
combined_args = "".join(self.current_function_arguments)
parsed_args = OptimisticJSONParser().parse(combined_args)
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
# If the parsed result is different
# This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
if parsed_args != self.current_json_parse_result:
self.current_json_parse_result = parsed_args
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
# If there's a finish reason, pass that along
if choice.finish_reason is not None:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(),
finish_reason=self.FINISH_REASON_STR,
)
],
)
# only emit a final chunk if finish_reason == "stop"
if choice.finish_reason == "stop":
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(), # no partial text here
finish_reason="stop",
)
],
)
return None
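The heart of the fix sits in the block above: the buffered tool-call arguments are re-parsed leniently on every streamed token, and a partial chunk is emitted only when the parse result actually changes. When the closing '}' arrives, the parsed dict is identical to the previous one, so that token is never streamed out. A self-contained sketch of the idea, where lenient_parse is only a stand-in for OptimisticJSONParser and the token sequence is invented for illustration:

import json

def lenient_parse(fragment: str) -> dict:
    # Stand-in for an optimistic/partial JSON parser: try closing the string
    # and the object so incomplete fragments still yield a dict.
    for candidate in (fragment, fragment + '"}', fragment + "}"):
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    return {}

tokens = ['{"message": "Hi', ' there"', "}"]  # hypothetical streamed tokens
buffer, last_parse = "", {}
for token in tokens:
    buffer += token
    parsed = lenient_parse(buffer)
    if parsed != last_parse:      # only stream when the parse result changed
        last_parse = parsed
        print("emit:", token)
    else:
        print("skip:", token)     # the trailing "}" lands here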

View File

@@ -9,6 +9,7 @@ from fastapi import Header
from pydantic import BaseModel
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.log import get_logger
from letta.schemas.usage import LettaUsageStatistics
from letta.server.rest_api.interface import StreamingServerInterface
@@ -24,10 +25,14 @@ SSE_FINISH_MSG = "[DONE]" # mimic openai
SSE_ARTIFICIAL_DELAY = 0.1
logger = get_logger(__name__)
def sse_formatter(data: Union[dict, str]) -> str:
"""Prefix with 'data: ', and always include double newlines"""
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
# print(f"data: {data_str}\n\n")
return f"data: {data_str}\n\n"
@@ -62,23 +67,29 @@ async def sse_async_generator(
usage = await usage_task
# Double-check the type
if not isinstance(usage, LettaUsageStatistics):
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
err_msg = f"Expected LettaUsageStatistics, got {type(usage)}"
logger.error(err_msg)
raise ValueError(err_msg)
yield sse_formatter(usage.model_dump())
except ContextWindowExceededError as e:
log_error_to_sentry(e)
logger.error(f"ContextWindowExceededError error: {e}")
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
except RateLimitExceededError as e:
log_error_to_sentry(e)
logger.error(f"RateLimitExceededError error: {e}")
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
except Exception as e:
log_error_to_sentry(e)
yield sse_formatter({"error": f"Stream failed (internal error occured)"})
logger.error(f"Caught unexpected Exception: {e}")
yield sse_formatter({"error": f"Stream failed (internal error occurred)"})
except Exception as e:
log_error_to_sentry(e)
logger.error(f"Caught unexpected Exception: {e}")
yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})
finally:

View File

@@ -5,101 +5,139 @@ import uuid
import pytest
from dotenv import load_dotenv
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta import RESTClient, create_client
from letta import create_client
from letta.client.streaming import _sse_post
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
from letta.schemas.usage import LettaUsageStatistics
# --- Server Management --- #
def run_server():
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
# _reset_config()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client():
# get URL from environment
server_url = os.getenv("LETTA_SERVER_URL")
if server_url is None:
# run server in thread
server_url = "http://localhost:8283"
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create user via admin client
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
time.sleep(5) # Allow server startup time
return url
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = create_client(base_url=server_url, token=None)
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
# Fixture for test agent
@pytest.fixture(scope="module")
def agent_state(client: RESTClient):
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
yield agent_state
# delete agent
@pytest.fixture(scope="function")
def roll_dice_tool(client):
def roll_dice():
"""
Rolls a 6 sided die.
Returns:
str: The roll result.
"""
return "Rolled a 10!"
tool = client.create_or_update_tool(func=roll_dice)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
yield agent_state
client.delete_agent(agent_state.id)
def test_voice_streaming(mock_e2b_api_key_none, client: RESTClient, agent_state: AgentState):
"""
Test voice streaming for chat completions using the streaming API.
This test ensures the SSE (Server-Sent Events) response from the voice streaming endpoint
adheres to the expected structure and contains valid data for each type of chunk.
"""
# Prepare the chat completion request with streaming enabled
request = ChatCompletionRequest(
# --- Helper Functions --- #
def _get_chat_request(agent_id, message, stream=True):
"""Returns a chat completion request with streaming enabled."""
return ChatCompletionRequest(
model="gpt-4o-mini",
messages=[UserMessage(content="Tell me something interesting about bananas.")],
user=agent_state.id,
stream=True,
messages=[UserMessage(content=message)],
user=agent_id,
stream=stream,
)
# Perform a POST request to the voice/chat/completions endpoint and collect the streaming response
def _assert_valid_chunk(chunk, idx, chunks):
"""Validates the structure of each streaming chunk."""
if isinstance(chunk, ChatCompletionChunk):
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
elif isinstance(chunk, LettaUsageStatistics):
assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
assert chunk.total_tokens > 0, "Total tokens must be > 0."
assert chunk.step_count == 1, "Step count must be 1."
elif isinstance(chunk, MessageStreamStatus):
assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
assert idx == len(chunks) - 1, "The last chunk must be 'done'."
else:
pytest.fail(f"Unexpected chunk type: {chunk}")
# --- Test Cases --- #
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
"""Tests chat completion streaming via SSE."""
request = _get_chat_request(agent.id, message)
response = _sse_post(
f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
)
# Convert the streaming response into a list of chunks for processing
chunks = list(response)
for idx, chunk in enumerate(chunks):
if isinstance(chunk, ChatCompletionChunk):
# Assert that the chunk has at least one choice (a response from the model)
assert len(chunk.choices) > 0, "Each ChatCompletionChunk should have at least one choice."
elif isinstance(chunk, LettaUsageStatistics):
# Assert that the usage statistics contain valid token counts
assert chunk.completion_tokens > 0, "Completion tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.prompt_tokens > 0, "Prompt tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.total_tokens > 0, "Total tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.step_count == 1, "Step count in LettaUsageStatistics should always be 1 for a single request."
elif isinstance(chunk, MessageStreamStatus):
# Assert that the stream ends with a 'done' status
assert chunk == MessageStreamStatus.done, "The last chunk should indicate the stream has completed."
assert idx == len(chunks) - 1, "The 'done' status must be the last chunk in the stream."
else:
# Fail the test if an unexpected chunk type is encountered
pytest.fail(f"Unexpected chunk type: {chunk}", pytrace=True)
_assert_valid_chunk(chunk, idx, chunks)
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "Roll a dice!"])
async def test_chat_completions_streaming_async(client, agent, message):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(agent.id, message)
async_client = AsyncOpenAI(base_url=f"{client.base_url}/openai/{client.api_prefix}", max_retries=0)
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
async with stream:
async for chunk in stream:
if isinstance(chunk, ChatCompletionChunk):
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
else:
pytest.fail(f"Unexpected chunk type: {chunk}")