fix: Fix chat completions streaming (#1078)

Matthew Zhou 2025-02-20 11:37:52 -08:00 committed by GitHub
parent 3cf6a2b692
commit c325d11c35
3 changed files with 29 additions and 11 deletions


@@ -153,7 +153,9 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         """No-op retained for interface compatibility."""
         return
 
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime) -> None:
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ) -> None:
         """
         Called externally with a ChatCompletionChunkResponse. Transforms
         it if necessary, then enqueues partial messages for streaming back.
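
The new expect_reasoning_content flag lets callers signal that deltas may carry a separate reasoning field alongside normal assistant text. A minimal sketch of how an implementation of this method might branch on the flag; the reasoning_content delta field and the _push_to_buffer helper are assumptions for illustration, not shown in this diff:

# Illustrative sketch only: reasoning_content and _push_to_buffer are
# assumptions, not part of this diff.
def process_chunk(
    self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
) -> None:
    delta = chunk.choices[0].delta
    if expect_reasoning_content and getattr(delta, "reasoning_content", None):
        # Route chain-of-thought tokens separately from assistant text.
        self._push_to_buffer({"reasoning": delta.reasoning_content}, message_id, message_date)
    if delta.content:
        self._push_to_buffer({"content": delta.content}, message_id, message_date)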


@@ -48,7 +48,9 @@ class AgentChunkStreamingInterface(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ):
         """Process a streaming chunk from an OpenAI-compatible server"""
         raise NotImplementedError
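
Because the base class only declares the widened signature, a subclass that is not updated in lockstep still imports cleanly and only fails at call time. A standalone illustration of that failure mode:

# Standalone illustration: an override missing the new keyword only fails
# once a caller actually passes it.
class Stale:
    def process_chunk(self, chunk, message_id, message_date):
        pass

Stale().process_chunk(None, "msg-1", None, expect_reasoning_content=True)
# TypeError: process_chunk() got an unexpected keyword argument 'expect_reasoning_content'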
@@ -92,7 +94,9 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
     def _flush(self):
         pass
 
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ):
         assert len(chunk.choices) == 1, chunk
         message_delta = chunk.choices[0].delta
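
In a CLI consumer the same flag could simply control whether reasoning tokens are echoed to the terminal; a hedged sketch continuing the method above (reasoning_content is again an assumed field, not shown in this diff):

# Sketch of a CLI-style continuation; reasoning_content is an assumed field.
def process_chunk(
    self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
):
    assert len(chunk.choices) == 1, chunk
    message_delta = chunk.choices[0].delta
    if expect_reasoning_content and getattr(message_delta, "reasoning_content", None):
        print(f"[reasoning] {message_delta.reasoning_content}", end="", flush=True)
    if message_delta.content:
        print(message_delta.content, end="", flush=True)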


@@ -120,9 +120,16 @@ def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message
         f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
     )
-    chunks = list(response)
-    for idx, chunk in enumerate(chunks):
-        _assert_valid_chunk(chunk, idx, chunks)
+    try:
+        chunks = list(response)
+        assert len(chunks) > 1, "Streaming response did not return enough chunks (may have failed silently)."
+        for idx, chunk in enumerate(chunks):
+            assert chunk, f"Empty chunk received at index {idx}."
+            print(chunk)
+            _assert_valid_chunk(chunk, idx, chunks)
+    except Exception as e:
+        pytest.fail(f"Streaming failed with exception: {e}")
 
 
 @pytest.mark.asyncio
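
The test leans on an _assert_valid_chunk helper that this diff does not show. A hypothetical sketch of the invariants such a helper might enforce; every check below is an assumption:

# Hypothetical validator; the real _assert_valid_chunk is not shown in this
# diff, so every check below is an assumption.
def _assert_valid_chunk(chunk, idx, chunks):
    assert chunk.choices, f"Chunk {idx} has no choices."
    if idx < len(chunks) - 1:
        # Intermediate chunks should carry an incremental content delta.
        assert chunk.choices[0].delta.content is not None, f"Chunk {idx} has no delta content."
    else:
        # The final chunk typically carries finish_reason instead of content.
        assert chunk.choices[0].finish_reason is not None, "Final chunk lacks a finish_reason."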
@@ -134,10 +141,15 @@ async def test_chat_completions_streaming_async(client, agent, message):
     async_client = AsyncOpenAI(base_url=f"{client.base_url}/openai/{client.api_prefix}", max_retries=0)
     stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
-    async with stream:
-        async for chunk in stream:
-            if isinstance(chunk, ChatCompletionChunk):
+    received_chunks = 0
+    try:
+        async with stream:
+            async for chunk in stream:
+                assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
                 assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
                 assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
-            else:
-                pytest.fail(f"Unexpected chunk type: {chunk}")
+                received_chunks += 1
+    except Exception as e:
+        pytest.fail(f"Streaming failed with exception: {e}")
+    assert received_chunks > 1, "No valid streaming chunks were received."
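
The received_chunks > 1 guard is what separates a real stream from a server that buffered the whole completion into a single chunk. A self-contained sketch of the same pattern against any OpenAI-compatible endpoint; the base_url and model values are placeholders, not taken from this repository:

# Self-contained sketch of the async pattern above; base_url and model are
# placeholders, not values from this repository.
import asyncio

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk


async def count_stream_chunks(base_url: str, model: str) -> int:
    client = AsyncOpenAI(base_url=base_url, max_retries=0)
    received = 0
    stream = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "hello"}],
        stream=True,
    )
    async with stream:
        async for chunk in stream:
            assert isinstance(chunk, ChatCompletionChunk)
            received += 1
    return received


# A count of exactly 1 usually means the server buffered the reply instead
# of streaming it, which is the regression this test guards against, e.g.:
# asyncio.run(count_stream_chunks("http://localhost:8283/v1", "gpt-4o-mini"))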