fix: Fix trailing } in chat completions interface (#842)

Matthew Zhou 2025-02-10 15:45:13 -08:00 committed by GitHub
parent 9a9bcbc0c5
commit 96c72fb157
6 changed files with 191 additions and 140 deletions

View File

@@ -431,7 +431,6 @@ class Agent(BaseAgent):
                     openai_message_dict=response_message.model_dump(),
                 )
             )  # extend conversation with assistant's reply
-            self.logger.info(f"Function call message: {messages[-1]}")

             nonnull_content = False
             if response_message.content:
@@ -445,10 +444,7 @@ class Agent(BaseAgent):
             function_call = (
                 response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
             )
-            # Get the name of the function
             function_name = function_call.name
-            self.logger.info(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")

             # Failure case 1: function name is wrong (not in agent_state.tools)
             target_letta_tool = None

View File

@@ -17,48 +17,45 @@ logger = get_logger(__name__)
 def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
-    with httpx.Client() as client:
+    """
+    Sends an SSE POST request and yields parsed response chunks.
+    """
+    # TODO: Please note his is a very generous timeout for e2b reasons
+    with httpx.Client(timeout=httpx.Timeout(5 * 60.0, read=5 * 60.0)) as client:
         with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:

-            # Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
+            # Check for immediate HTTP errors before processing the SSE stream
             if not event_source.response.is_success:
-                # handle errors
-                pass
-                logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
-                logger.warning(event_source.response.read().decode("utf-8"))
+                response_bytes = event_source.response.read()
+                logger.warning(f"SSE request error: {vars(event_source.response)}")
+                logger.warning(response_bytes.decode("utf-8"))

                 try:
-                    response_bytes = event_source.response.read()
                     response_dict = json.loads(response_bytes.decode("utf-8"))
-                    # e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
-                    if (
-                        "error" in response_dict
-                        and "message" in response_dict["error"]
-                        and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
-                    ):
-                        logger.error(response_dict["error"]["message"])
-                        raise LLMError(response_dict["error"]["message"])
+                    error_message = response_dict.get("error", {}).get("message", "")
+
+                    if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
+                        logger.error(error_message)
+                        raise LLMError(error_message)
                 except LLMError:
                     raise
-                except:
-                    logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
+                except Exception:
+                    logger.error("Failed to parse SSE message, raising HTTP error")
                     event_source.response.raise_for_status()

             try:
                 for sse in event_source.iter_sse():
-                    # if sse.data == OPENAI_SSE_DONE:
-                    #     print("finished")
-                    #     break
-                    if sse.data in [status.value for status in MessageStreamStatus]:
-                        # break
+                    if sse.data in {status.value for status in MessageStreamStatus}:
                         yield MessageStreamStatus(sse.data)
+                        if sse.data == MessageStreamStatus.done.value:
+                            # We received the [DONE], so stop reading the stream.
+                            break
                     else:
                         chunk_data = json.loads(sse.data)
                         if "reasoning" in chunk_data:
                             yield ReasoningMessage(**chunk_data)
-                        elif "message_type" in chunk_data and chunk_data["message_type"] == "assistant_message":
+                        elif chunk_data.get("message_type") == "assistant_message":
                             yield AssistantMessage(**chunk_data)
                         elif "tool_call" in chunk_data:
                             yield ToolCallMessage(**chunk_data)
@@ -67,33 +64,31 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStrea
                         elif "step_count" in chunk_data:
                             yield LettaUsageStatistics(**chunk_data)
                         elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
-                            yield ChatCompletionChunk(**chunk_data)  # Add your processing logic for chat chunks here
+                            yield ChatCompletionChunk(**chunk_data)
                         else:
                             raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
             except SSEError as e:
-                logger.error("Caught an error while iterating the SSE stream:", str(e))
-                if "application/json" in str(e):  # Check if the error is because of JSON response
-                    # TODO figure out a better way to catch the error other than re-trying with a POST
-                    response = client.post(url=url, json=data, headers=headers)  # Make the request again to get the JSON response
-                    if response.headers["Content-Type"].startswith("application/json"):
-                        error_details = response.json()  # Parse the JSON to get the error message
-                        logger.error("Request:", vars(response.request))
-                        logger.error("POST Error:", error_details)
-                        logger.error("Original SSE Error:", str(e))
+                logger.error(f"SSE stream error: {e}")
+
+                if "application/json" in str(e):
+                    response = client.post(url=url, json=data, headers=headers)
+
+                    if response.headers.get("Content-Type", "").startswith("application/json"):
+                        error_details = response.json()
+                        logger.error(f"POST Error: {error_details}")
                     else:
                         logger.error("Failed to retrieve JSON error message via retry.")
-                else:
-                    logger.error("SSEError not related to 'application/json' content type.")
-                # Optionally re-raise the exception if you need to propagate it
                 raise e
             except Exception as e:
-                if event_source.response.request is not None:
-                    logger.error("HTTP Request:", vars(event_source.response.request))
-                if event_source.response is not None:
-                    logger.error("HTTP Status:", event_source.response.status_code)
-                    logger.error("HTTP Headers:", event_source.response.headers)
-                logger.error("Exception message:", str(e))
+                logger.error(f"Unexpected exception: {e}")
+
+                if event_source.response.request:
+                    logger.error(f"HTTP Request: {vars(event_source.response.request)}")
+                if event_source.response:
+                    logger.error(f"HTTP Status: {event_source.response.status_code}")
+                    logger.error(f"HTTP Headers: {event_source.response.headers}")
                 raise e
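
For orientation, here is a minimal sketch of how a caller might consume the generator returned by _sse_post. The endpoint path, payload, and headers are illustrative assumptions, not taken from this diff.

# Hypothetical consumer of _sse_post; the URL, payload, and headers below are placeholders.
from letta.client.streaming import _sse_post
from letta.schemas.enums import MessageStreamStatus


def print_stream(base_url: str, payload: dict, headers: dict) -> None:
    # The generator yields typed chunks (reasoning, assistant messages, tool calls,
    # usage statistics, raw ChatCompletionChunks) plus status markers until [DONE].
    for chunk in _sse_post(f"{base_url}/v1/agents/agent-123/messages/stream", payload, headers):
        if isinstance(chunk, MessageStreamStatus):
            print(f"stream status: {chunk.value}")  # the generator stops itself after 'done'
        else:
            print(chunk)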

View File

@@ -7,6 +7,7 @@ from openai import OpenAI
 from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
+from letta.log import get_logger
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as _Message
 from letta.schemas.message import MessageRole as _MessageRole
@@ -26,7 +27,7 @@ from letta.schemas.openai.embedding_response import EmbeddingResponse
 from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
 from letta.utils import get_tool_call_id, smart_urljoin

-OPENAI_SSE_DONE = "[DONE]"
+logger = get_logger(__name__)


 def openai_get_model_list(
@@ -354,9 +355,10 @@ def openai_chat_completions_process_stream(
     except Exception as e:
         if stream_interface:
             stream_interface.stream_end()
-        print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
+        logger.error(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
         raise e
     finally:
+        logger.info(f"Finally ending streaming interface.")
         if stream_interface:
             stream_interface.stream_end()

View File

@@ -41,7 +41,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
     def __init__(
         self,
         multi_step: bool = True,
-        timeout: int = 150,
+        timeout: int = 3 * 60,
         # The following are placeholders for potential expansions; they
         # remain if you need to differentiate between actual "assistant messages"
         # vs. tool calls. By default, they are set for the "send_message" tool usage.
@@ -55,6 +55,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         # Parsing state for incremental function-call data
         self.current_function_name = ""
         self.current_function_arguments = []
+        self.current_json_parse_result = {}

         # Internal chunk buffer and event for async notification
         self._chunks = deque()
@@ -85,6 +86,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
             try:
                 await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
             except asyncio.TimeoutError:
+                logger.warning("Chat completions interface timed out! Please check that this is intended.")
                 break

             while self._chunks:
@@ -105,7 +107,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         self,
         item: ChatCompletionChunk,
     ):
-        """
+        """m
        Add an item (a LettaMessage, status marker, or partial chunk)
        to the queue and signal waiting consumers.
        """
@@ -156,6 +158,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
        Called externally with a ChatCompletionChunkResponse. Transforms
        it if necessary, then enqueues partial messages for streaming back.
        """
+        # print("RECEIVED CHUNK...")
         processed_chunk = self._process_chunk_to_openai_style(chunk)
         if processed_chunk is not None:
             self._push_to_buffer(processed_chunk)
@@ -216,37 +219,43 @@
         combined_args = "".join(self.current_function_arguments)
         parsed_args = OptimisticJSONParser().parse(combined_args)

-        # If we can see a "message" field, return it as partial content
-        if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
-            return ChatCompletionChunk(
-                id=chunk.id,
-                object=chunk.object,
-                created=chunk.created.timestamp(),
-                model=chunk.model,
-                choices=[
-                    Choice(
-                        index=choice.index,
-                        delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
-                        finish_reason=None,
-                    )
-                ],
-            )
+        # If the parsed result is different
+        # This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
+        if parsed_args != self.current_json_parse_result:
+            self.current_json_parse_result = parsed_args
+            # If we can see a "message" field, return it as partial content
+            if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
+                return ChatCompletionChunk(
+                    id=chunk.id,
+                    object=chunk.object,
+                    created=chunk.created.timestamp(),
+                    model=chunk.model,
+                    choices=[
+                        Choice(
+                            index=choice.index,
+                            delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
+                            finish_reason=None,
+                        )
+                    ],
+                )

         # If there's a finish reason, pass that along
         if choice.finish_reason is not None:
-            return ChatCompletionChunk(
-                id=chunk.id,
-                object=chunk.object,
-                created=chunk.created.timestamp(),
-                model=chunk.model,
-                choices=[
-                    Choice(
-                        index=choice.index,
-                        delta=ChoiceDelta(),
-                        finish_reason=self.FINISH_REASON_STR,
-                    )
-                ],
-            )
+            # only emit a final chunk if finish_reason == "stop"
+            if choice.finish_reason == "stop":
+                return ChatCompletionChunk(
+                    id=chunk.id,
+                    object=chunk.object,
+                    created=chunk.created.timestamp(),
+                    model=chunk.model,
+                    choices=[
+                        Choice(
+                            index=choice.index,
+                            delta=ChoiceDelta(),  # no partial text here
+                            finish_reason="stop",
+                        )
+                    ],
+                )

         return None
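
The edge case called out in the new comments (a closing `}` arriving as its own streamed token) can be seen in isolation: once the brace arrives, a tolerant parse of the accumulated arguments yields the same dict as the previous parse, so the interface now skips emitting a content chunk for it. A minimal sketch of that comparison, using a hand-rolled stand-in for the optimistic JSON parser rather than the real class:

import json


def tolerant_parse(buffer: str) -> dict:
    # Stand-in for OptimisticJSONParser: try to complete a partial JSON object.
    for suffix in ("", '"}', "}"):
        try:
            return json.loads(buffer + suffix)
        except json.JSONDecodeError:
            continue
    return {}


buffer = ""
previous: dict = {}
for token in ['{"message": "Hi', ' there"', "}"]:
    buffer += token
    parsed = tolerant_parse(buffer)
    if parsed != previous:  # same comparison the interface now does via current_json_parse_result
        previous = parsed
        print("stream:", token)  # emitted for the first two tokens only; the lone '}' changes nothing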

View File

@@ -9,6 +9,7 @@ from fastapi import Header
 from pydantic import BaseModel

 from letta.errors import ContextWindowExceededError, RateLimitExceededError
+from letta.log import get_logger
 from letta.schemas.usage import LettaUsageStatistics
 from letta.server.rest_api.interface import StreamingServerInterface
@@ -24,10 +25,14 @@ SSE_FINISH_MSG = "[DONE]"  # mimic openai
 SSE_ARTIFICIAL_DELAY = 0.1

+logger = get_logger(__name__)
+

 def sse_formatter(data: Union[dict, str]) -> str:
     """Prefix with 'data: ', and always include double newlines"""
     assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
     data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
+    # print(f"data: {data_str}\n\n")
     return f"data: {data_str}\n\n"
@@ -62,23 +67,29 @@ async def sse_async_generator(
             usage = await usage_task
             # Double-check the type
             if not isinstance(usage, LettaUsageStatistics):
-                raise ValueError(f"Expected LettaUsageStatistics, got {type(usage)}")
+                err_msg = f"Expected LettaUsageStatistics, got {type(usage)}"
+                logger.error(err_msg)
+                raise ValueError(err_msg)
             yield sse_formatter(usage.model_dump())

     except ContextWindowExceededError as e:
         log_error_to_sentry(e)
+        logger.error(f"ContextWindowExceededError error: {e}")
         yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})

     except RateLimitExceededError as e:
         log_error_to_sentry(e)
+        logger.error(f"RateLimitExceededError error: {e}")
         yield sse_formatter({"error": f"Stream failed: {e}", "code": str(e.code.value) if e.code else None})

     except Exception as e:
         log_error_to_sentry(e)
-        yield sse_formatter({"error": f"Stream failed (internal error occured)"})
+        logger.error(f"Caught unexpected Exception: {e}")
+        yield sse_formatter({"error": f"Stream failed (internal error occurred)"})

     except Exception as e:
         log_error_to_sentry(e)
+        logger.error(f"Caught unexpected Exception: {e}")
         yield sse_formatter({"error": "Stream failed (decoder encountered an error)"})

     finally:
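
For reference, sse_formatter above emits standard Server-Sent Events frames: a "data: " prefix and a blank line per event. A small standalone sketch of the expected wire format (the error payload is made up for illustration):

import json


def sse_formatter(data) -> str:
    # Mirrors the helper above: JSON-encode dicts compactly, pass strings through unchanged.
    data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
    return f"data: {data_str}\n\n"


print(sse_formatter({"error": "Stream failed: example", "code": None}), end="")
# -> data: {"error":"Stream failed: example","code":null}
print(sse_formatter("[DONE]"), end="")  # SSE_FINISH_MSG, mimicking the OpenAI terminator
# -> data: [DONE]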

View File

@@ -5,101 +5,139 @@ import uuid

 import pytest
 from dotenv import load_dotenv
+from openai import AsyncOpenAI
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

-from letta import RESTClient, create_client
+from letta import create_client
 from letta.client.streaming import _sse_post
-from letta.schemas.agent import AgentState
 from letta.schemas.embedding_config import EmbeddingConfig
 from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
 from letta.schemas.usage import LettaUsageStatistics


-def run_server():
+# --- Server Management --- #
+
+
+def _run_server():
+    """Starts the Letta server in a background thread."""
     load_dotenv()
-    # _reset_config()
     from letta.server.rest_api.app import start_server

-    print("Starting server...")
     start_server(debug=True)


-@pytest.fixture(
-    scope="module",
-)
-def client():
-    # get URL from enviornment
-    server_url = os.getenv("LETTA_SERVER_URL")
-    if server_url is None:
-        # run server in thread
-        server_url = "http://localhost:8283"
-        print("Starting server thread")
-        thread = threading.Thread(target=run_server, daemon=True)
+@pytest.fixture(scope="session")
+def server_url():
+    """Ensures a server is running and returns its base URL."""
+    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
+
+    if not os.getenv("LETTA_SERVER_URL"):
+        thread = threading.Thread(target=_run_server, daemon=True)
         thread.start()
-        time.sleep(5)
-    print("Running client tests with server:", server_url)
-    # create user via admin client
-    client = create_client(base_url=server_url, token=None)  # This yields control back to the test function
+        time.sleep(5)  # Allow server startup time
+
+    return url
+
+
+# --- Client Setup --- #
+@pytest.fixture(scope="session")
+def client(server_url):
+    """Creates a REST client for testing."""
+    client = create_client(base_url=server_url, token=None)
     client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
     client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
     yield client


-# Fixture for test agent
-@pytest.fixture(scope="module")
-def agent_state(client: RESTClient):
-    agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
-    yield agent_state
-    # delete agent
+@pytest.fixture(scope="function")
+def roll_dice_tool(client):
+    def roll_dice():
+        """
+        Rolls a 6 sided die.
+
+        Returns:
+            str: The roll result.
+        """
+        return "Rolled a 10!"
+
+    tool = client.create_or_update_tool(func=roll_dice)
+    # Yield the created tool
+    yield tool
+
+
+@pytest.fixture(scope="function")
+def agent(client, roll_dice_tool):
+    """Creates an agent and ensures cleanup after tests."""
+    agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id])
+    yield agent_state
     client.delete_agent(agent_state.id)


-def test_voice_streaming(mock_e2b_api_key_none, client: RESTClient, agent_state: AgentState):
-    """
-    Test voice streaming for chat completions using the streaming API.
-
-    This test ensures the SSE (Server-Sent Events) response from the voice streaming endpoint
-    adheres to the expected structure and contains valid data for each type of chunk.
-    """
-    # Prepare the chat completion request with streaming enabled
-    request = ChatCompletionRequest(
+# --- Helper Functions --- #
+
+
+def _get_chat_request(agent_id, message, stream=True):
+    """Returns a chat completion request with streaming enabled."""
+    return ChatCompletionRequest(
         model="gpt-4o-mini",
-        messages=[UserMessage(content="Tell me something interesting about bananas.")],
-        user=agent_state.id,
-        stream=True,
+        messages=[UserMessage(content=message)],
+        user=agent_id,
+        stream=stream,
     )

-    # Perform a POST request to the voice/chat/completions endpoint and collect the streaming response
+
+def _assert_valid_chunk(chunk, idx, chunks):
+    """Validates the structure of each streaming chunk."""
+    if isinstance(chunk, ChatCompletionChunk):
+        assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
+
+    elif isinstance(chunk, LettaUsageStatistics):
+        assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
+        assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
+        assert chunk.total_tokens > 0, "Total tokens must be > 0."
+        assert chunk.step_count == 1, "Step count must be 1."
+
+    elif isinstance(chunk, MessageStreamStatus):
+        assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
+        assert idx == len(chunks) - 1, "The last chunk must be 'done'."
+
+    else:
+        pytest.fail(f"Unexpected chunk type: {chunk}")
+
+
+# --- Test Cases --- #
+
+
+@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
+def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
+    """Tests chat completion streaming via SSE."""
+    request = _get_chat_request(agent.id, message)
+
     response = _sse_post(
         f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
     )

-    # Convert the streaming response into a list of chunks for processing
     chunks = list(response)
     for idx, chunk in enumerate(chunks):
-        if isinstance(chunk, ChatCompletionChunk):
-            # Assert that the chunk has at least one choice (a response from the model)
-            assert len(chunk.choices) > 0, "Each ChatCompletionChunk should have at least one choice."
-
-        elif isinstance(chunk, LettaUsageStatistics):
-            # Assert that the usage statistics contain valid token counts
-            assert chunk.completion_tokens > 0, "Completion tokens should be greater than 0 in LettaUsageStatistics."
-            assert chunk.prompt_tokens > 0, "Prompt tokens should be greater than 0 in LettaUsageStatistics."
-            assert chunk.total_tokens > 0, "Total tokens should be greater than 0 in LettaUsageStatistics."
-            assert chunk.step_count == 1, "Step count in LettaUsageStatistics should always be 1 for a single request."
-
-        elif isinstance(chunk, MessageStreamStatus):
-            # Assert that the stream ends with a 'done' status
-            assert chunk == MessageStreamStatus.done, "The last chunk should indicate the stream has completed."
-            assert idx == len(chunks) - 1, "The 'done' status must be the last chunk in the stream."
-
-        else:
-            # Fail the test if an unexpected chunk type is encountered
-            pytest.fail(f"Unexpected chunk type: {chunk}", pytrace=True)
+        _assert_valid_chunk(chunk, idx, chunks)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "Roll a dice!"])
+async def test_chat_completions_streaming_async(client, agent, message):
+    """Tests chat completion streaming using the Async OpenAI client."""
+    request = _get_chat_request(agent.id, message)
+    async_client = AsyncOpenAI(base_url=f"{client.base_url}/openai/{client.api_prefix}", max_retries=0)
+    stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
+
+    async with stream:
+        async for chunk in stream:
+            if isinstance(chunk, ChatCompletionChunk):
+                assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
+                assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
+            else:
+                pytest.fail(f"Unexpected chunk type: {chunk}")