mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
fix: Fix trailing } in chat completions interface (#842)
This commit is contained in:
parent
9a9bcbc0c5
commit
96c72fb157
@ -431,7 +431,6 @@ class Agent(BaseAgent):
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.logger.info(f"Function call message: {messages[-1]}")
|
||||
|
||||
nonnull_content = False
|
||||
if response_message.content:
|
||||
@ -445,10 +444,7 @@ class Agent(BaseAgent):
|
||||
function_call = (
|
||||
response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
|
||||
)
|
||||
|
||||
# Get the name of the function
|
||||
function_name = function_call.name
|
||||
self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")
|
||||
|
||||
# Failure case 1: function name is wrong (not in agent_state.tools)
|
||||
target_letta_tool = None
|
||||
|
@ -17,48 +17,45 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
|
||||
|
||||
with httpx.Client() as client:
|
||||
"""
|
||||
Sends an SSE POST request and yields parsed response chunks.
|
||||
"""
|
||||
# TODO: Please note his is a very generous timeout for e2b reasons
|
||||
with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
|
||||
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
||||
|
||||
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
|
||||
# Check for immediate HTTP errors before processing the SSE stream
|
||||
if not event_source.response.is_success:
|
||||
# handle errors
|
||||
pass
|
||||
|
||||
logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
|
||||
logger.warning(event_source.response.read().decode("utf-8"))
|
||||
response_bytes = event_source.response.read()
|
||||
logger.warning(f"SSE request error: {vars(event_source.response)}")
|
||||
logger.warning(response_bytes.decode("utf-8"))
|
||||
|
||||
try:
|
||||
response_bytes = event_source.response.read()
|
||||
response_dict = json.loads(response_bytes.decode("utf-8"))
|
||||
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
|
||||
if (
|
||||
"error" in response_dict
|
||||
and "message" in response_dict["error"]
|
||||
and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
|
||||
):
|
||||
logger.error(response_dict["error"]["message"])
|
||||
raise LLMError(response_dict["error"]["message"])
|
||||
error_message = response_dict.get("error", {}).get("message", "")
|
||||
|
||||
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
|
||||
logger.error(error_message)
|
||||
raise LLMError(error_message)
|
||||
except LLMError:
|
||||
raise
|
||||
except:
|
||||
logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
|
||||
except Exception:
|
||||
logger.error("Failed to parse SSE message, raising HTTP error")
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
# if sse.data == OPENAI_SSE_DONE:
|
||||
# print("finished")
|
||||
# break
|
||||
if sse.data in [status.value for status in MessageStreamStatus]:
|
||||
# break
|
||||
if sse.data in {status.value for status in MessageStreamStatus}:
|
||||
yield MessageStreamStatus(sse.data)
|
||||
if sse.data == MessageStreamStatus.done.value:
|
||||
# We received the [DONE], so stop reading the stream.
|
||||
break
|
||||
else:
|
||||
chunk_data = json.loads(sse.data)
|
||||
|
||||
if "reasoning" in chunk_data:
|
||||
yield ReasoningMessage(**chunk_data)
|
||||
elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
|
||||
elif chunk_data.get("message_type") == "assistant_message":
|
||||
yield AssistantMessage(**chunk_data)
|
||||
elif "tool_call" in chunk_data:
|
||||
yield ToolCallMessage(**chunk_data)
|
||||
@ -67,33 +64,31 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStrea
|
||||
elif "step_count" in chunk_data:
|
||||
yield LettaUsageStatistics(**chunk_data)
|
||||
elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
|
||||
yield ChatCompletionChunk(**chunk_data) # Add your processing logic for chat chunks here
|
||||
yield ChatCompletionChunk(**chunk_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
|
||||
|
||||
except SSEError as e:
|
||||
logger.error("Caught an error while iterating the SSE stream:", str(e))
|
||||
if "application/json" in str(e): # Check if the error is because of JSON response
|
||||
# TODO figure out a better way to catch the error other than re-trying with a POST
|
||||
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
|
||||
if response.headers["Content-Type"].startswith("application/json"):
|
||||
error_details = response.json() # Parse the JSON to get the error message
|
||||
logger.error("Request:", vars(response.request))
|
||||
logger.error("POST Error:", error_details)
|
||||
logger.error("Original SSE Error:", str(e))
|
||||
logger.error(f"SSE stream error: {e}")
|
||||
|
||||
if "application/json" in str(e):
|
||||
response = client.post(url=url, json=data, headers=headers)
|
||||
|
||||
if response.headers.get("Content-Type", "").startswith("application/json"):
|
||||
error_details = response.json()
|
||||
logger.error(f"POST Error: {error_details}")
|
||||
else:
|
||||
logger.error("Failed to retrieve JSON error message via retry.")
|
||||
else:
|
||||
logger.error("SSEError not related to 'application/json' content type.")
|
||||
|
||||
# Optionally re-raise the exception if you need to propagate it
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
if event_source.response.request is not None:
|
||||
logger.error("HTTP Request:", vars(event_source.response.request))
|
||||
if event_source.response is not None:
|
||||
logger.error("HTTP Status:", event_source.response.status_code)
|
||||
logger.error("HTTP Headers:", event_source.response.headers)
|
||||
logger.error("Exception message:", str(e))
|
||||
logger.error(f"Unexpected exception: {e}")
|
||||
|
||||
if event_source.response.request:
|
||||
logger.error(f"HTTP Request: {vars(event_source.response.request)}")
|
||||
if event_source.response:
|
||||
logger.error(f"HTTP Status: {event_source.response.status_code}")
|
||||
logger.error(f"HTTP Headers: {event_source.response.headers}")
|
||||
|
||||
raise e
|
||||
|
@ -7,6 +7,7 @@ from openai import OpenAI
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as _Message
|
||||
from letta.schemas.message import MessageRole as _MessageRole
|
||||
@ -26,7 +27,7 @@ from letta.schemas.openai.embedding_response import EmbeddingResponse
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.utils import get_tool_call_id, smart_urljoin
|
||||
|
||||
OPENAI_SSE_DONE = "[DONE]"
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def openai_get_model_list(
|
||||
@ -354,9 +355,10 @@ def openai_chat_completions_process_stream(
|
||||
except Exception as e:
|
||||
if stream_interface:
|
||||
stream_interface.stream_end()
|
||||
print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
|
||||
logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
logger.info(f"Finally ending streaming interface.")
|
||||
if stream_interface:
|
||||
stream_interface.stream_end()
|
||||
|
||||
|
@ -41,7 +41,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
def __init__(
|
||||
self,
|
||||
multi_step: bool = True,
|
||||
timeout: int = 150,
|
||||
timeout: int = 3 * 60,
|
||||
# The following are placeholders for potential expansions; they
|
||||
# remain if you need to differentiate between actual "assistant messages"
|
||||
# vs. tool calls. By default, they are set for the "send_message" tool usage.
|
||||
@ -55,6 +55,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
# Parsing state for incremental function-call data
|
||||
self.current_function_name = ""
|
||||
self.current_function_arguments = []
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
# Internal chunk buffer and event for async notification
|
||||
self._chunks = deque()
|
||||
@ -85,6 +86,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
try:
|
||||
await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("Chat completions interface timed out! Please check that this is intended.")
|
||||
break
|
||||
|
||||
while self._chunks:
|
||||
@ -105,7 +107,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
self,
|
||||
item: ChatCompletionChunk,
|
||||
):
|
||||
"""
|
||||
"""m
|
||||
Add an item (a LettaMessage, status marker, or partial chunk)
|
||||
to the queue and signal waiting consumers.
|
||||
"""
|
||||
@ -156,6 +158,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
Called externally with a ChatCompletionChunkResponse. Transforms
|
||||
it if necessary, then enqueues partial messages for streaming back.
|
||||
"""
|
||||
# print("RECEIVED CHUNK...")
|
||||
processed_chunk = self._process_chunk_to_openai_style(chunk)
|
||||
if processed_chunk is not None:
|
||||
self._push_to_buffer(processed_chunk)
|
||||
@ -216,37 +219,43 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
combined_args = "".join(self.current_function_arguments)
|
||||
parsed_args = OptimisticJSONParser().parse(combined_args)
|
||||
|
||||
# If we can see a "message" field, return it as partial content
|
||||
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
# If the parsed result is different
|
||||
# This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
|
||||
if parsed_args != self.current_json_parse_result:
|
||||
self.current_json_parse_result = parsed_args
|
||||
# If we can see a "message" field, return it as partial content
|
||||
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# If there's a finish reason, pass that along
|
||||
if choice.finish_reason is not None:
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(),
|
||||
finish_reason=self.FINISH_REASON_STR,
|
||||
)
|
||||
],
|
||||
)
|
||||
# only emit a final chunk if finish_reason == "stop"
|
||||
if choice.finish_reason == "stop":
|
||||
return ChatCompletionChunk(
|
||||
id=chunk.id,
|
||||
object=chunk.object,
|
||||
created=chunk.created.timestamp(),
|
||||
model=chunk.model,
|
||||
choices=[
|
||||
Choice(
|
||||
index=choice.index,
|
||||
delta=ChoiceDelta(), # no partial text here
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
@ -9,6 +9,7 @@ from fastapi import Header
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
|
||||
@ -24,10 +25,14 @@ SSE_FINISH_MSG = "[DONE]" # mimic openai
|
||||
SSE_ARTIFICIAL_DELAY = 0.1
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def sse_formatter(data: Union[dict, str]) -> str:
|
||||
"""Prefix with 'data: ', and always include double newlines"""
|
||||
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
|
||||
data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
|
||||
# print(f"data: {data_str}\n\n")
|
||||
return f"data: {data_str}\n\n"
|
||||
|
||||
|
||||
@ -62,23 +67,29 @@ async def sse_async_generator(
|
||||
usage = await usage_task
|
||||
# Double-check the type
|
||||
if not isinstance(usage, LettaUsageStatistics):
|
||||
raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
|
||||
err_msg = f"Expected LettaUsageStatistics, got {type(usage)}"
|
||||
logger.error(err_msg)
|
||||
raise ValueError(err_msg)
|
||||
yield sse_formatter(usage.model_dump())
|
||||
|
||||
except ContextWindowExceededError as e:
|
||||
log_error_to_sentry(e)
|
||||
logger.error(f"ContextWindowExceededError error: {e}")
|
||||
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
|
||||
|
||||
except RateLimitExceededError as e:
|
||||
log_error_to_sentry(e)
|
||||
logger.error(f"RateLimitExceededError error: {e}")
|
||||
yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})
|
||||
|
||||
except Exception as e:
|
||||
log_error_to_sentry(e)
|
||||
yield sse_formatter({"error": f"Stream failed (internal error occured)"})
|
||||
logger.error(f"Caught unexpected Exception: {e}")
|
||||
yield sse_formatter({"error": f"Stream failed (internal error occurred)"})
|
||||
|
||||
except Exception as e:
|
||||
log_error_to_sentry(e)
|
||||
logger.error(f"Caught unexpected Exception: {e}")
|
||||
yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})
|
||||
|
||||
finally:
|
||||
|
@ -5,101 +5,139 @@ import uuid
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta import RESTClient, create_client
|
||||
from letta import create_client
|
||||
from letta.client.streaming import _sse_post
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
def run_server():
|
||||
|
||||
def _run_server():
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv()
|
||||
|
||||
# _reset_config()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
scope="module",
|
||||
)
|
||||
def client():
|
||||
# get URL from enviornment
|
||||
server_url = os.getenv("LETTA_SERVER_URL")
|
||||
if server_url is None:
|
||||
# run server in thread
|
||||
server_url = "http://localhost:8283"
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
@pytest.fixture(scope="session")
|
||||
def server_url():
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5)
|
||||
print("Running client tests with server:", server_url)
|
||||
# create user via admin client
|
||||
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
|
||||
time.sleep(5) # Allow server startup time
|
||||
|
||||
return url
|
||||
|
||||
|
||||
# --- Client Setup --- #
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
"""Creates a REST client for testing."""
|
||||
client = create_client(base_url=server_url, token=None)
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
yield client
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_state(client: RESTClient):
|
||||
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
|
||||
yield agent_state
|
||||
@pytest.fixture(scope="function")
|
||||
def roll_dice_tool(client):
|
||||
def roll_dice():
|
||||
"""
|
||||
Rolls a 6 sided die.
|
||||
|
||||
# delete agent
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.create_or_update_tool(func=roll_dice)
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent(client, roll_dice_tool):
|
||||
"""Creates an agent and ensures cleanup after tests."""
|
||||
agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
|
||||
yield agent_state
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
def test_voice_streaming(mock_e2b_api_key_none, client: RESTClient, agent_state: AgentState):
|
||||
"""
|
||||
Test voice streaming for chat completions using the streaming API.
|
||||
# --- Helper Functions --- #
|
||||
|
||||
This test ensures the SSE (Server-Sent Events) response from the voice streaming endpoint
|
||||
adheres to the expected structure and contains valid data for each type of chunk.
|
||||
"""
|
||||
|
||||
# Prepare the chat completion request with streaming enabled
|
||||
request = ChatCompletionRequest(
|
||||
def _get_chat_request(agent_id, message, stream=True):
|
||||
"""Returns a chat completion request with streaming enabled."""
|
||||
return ChatCompletionRequest(
|
||||
model="gpt-4o-mini",
|
||||
messages=[UserMessage(content="Tell me something interesting about bananas.")],
|
||||
user=agent_state.id,
|
||||
stream=True,
|
||||
messages=[UserMessage(content=message)],
|
||||
user=agent_id,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# Perform a POST request to the voice/chat/completions endpoint and collect the streaming response
|
||||
|
||||
def _assert_valid_chunk(chunk, idx, chunks):
|
||||
"""Validates the structure of each streaming chunk."""
|
||||
if isinstance(chunk, ChatCompletionChunk):
|
||||
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
|
||||
|
||||
elif isinstance(chunk, LettaUsageStatistics):
|
||||
assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
|
||||
assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
|
||||
assert chunk.total_tokens > 0, "Total tokens must be > 0."
|
||||
assert chunk.step_count == 1, "Step count must be 1."
|
||||
|
||||
elif isinstance(chunk, MessageStreamStatus):
|
||||
assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
|
||||
assert idx == len(chunks) - 1, "The last chunk must be 'done'."
|
||||
|
||||
else:
|
||||
pytest.fail(f"Unexpected chunk type: {chunk}")
|
||||
|
||||
|
||||
# --- Test Cases --- #
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
|
||||
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
|
||||
"""Tests chat completion streaming via SSE."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
|
||||
response = _sse_post(
|
||||
f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
|
||||
)
|
||||
|
||||
# Convert the streaming response into a list of chunks for processing
|
||||
chunks = list(response)
|
||||
|
||||
for idx, chunk in enumerate(chunks):
|
||||
if isinstance(chunk, ChatCompletionChunk):
|
||||
# Assert that the chunk has at least one choice (a response from the model)
|
||||
assert len(chunk.choices) > 0, "Each ChatCompletionChunk should have at least one choice."
|
||||
_assert_valid_chunk(chunk, idx, chunks)
|
||||
|
||||
elif isinstance(chunk, LettaUsageStatistics):
|
||||
# Assert that the usage statistics contain valid token counts
|
||||
assert chunk.completion_tokens > 0, "Completion tokens should be greater than 0 in LettaUsageStatistics."
|
||||
assert chunk.prompt_tokens > 0, "Prompt tokens should be greater than 0 in LettaUsageStatistics."
|
||||
assert chunk.total_tokens > 0, "Total tokens should be greater than 0 in LettaUsageStatistics."
|
||||
assert chunk.step_count == 1, "Step count in LettaUsageStatistics should always be 1 for a single request."
|
||||
|
||||
elif isinstance(chunk, MessageStreamStatus):
|
||||
# Assert that the stream ends with a 'done' status
|
||||
assert chunk == MessageStreamStatus.done, "The last chunk should indicate the stream has completed."
|
||||
assert idx == len(chunks) - 1, "The 'done' status must be the last chunk in the stream."
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "Roll a dice!"])
|
||||
async def test_chat_completions_streaming_async(client, agent, message):
|
||||
"""Tests chat completion streaming using the Async OpenAI client."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
|
||||
else:
|
||||
# Fail the test if an unexpected chunk type is encountered
|
||||
pytest.fail(f"Unexpected chunk type: {chunk}", pytrace=True)
|
||||
async_client = AsyncOpenAI(base_url=f"{client.base_url}/openai/{client.api_prefix}", max_retries=0)
|
||||
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
|
||||
|
||||
async with stream:
|
||||
async for chunk in stream:
|
||||
if isinstance(chunk, ChatCompletionChunk):
|
||||
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
|
||||
assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
|
||||
else:
|
||||
pytest.fail(f"Unexpected chunk type: {chunk}")
|
||||
|
Loading…
Reference in New Issue
Block a user