Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

fix: Fix chat completions streaming (#1078)

parent 3cf6a2b692
commit c325d11c35
```diff
@@ -153,7 +153,9 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         """No-op retained for interface compatibility."""
         return
 
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime) -> None:
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ) -> None:
         """
         Called externally with a ChatCompletionChunkResponse. Transforms
         it if necessary, then enqueues partial messages for streaming back.
```
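The change above widens `process_chunk` with a defaulted `expect_reasoning_content` flag, so existing three-argument call sites keep working while new callers can opt in. A minimal sketch of that pattern, using a hypothetical `FakeChunk` stand-in for `ChatCompletionChunkResponse`:

```python
from dataclasses import dataclass
from datetime import datetime, timezone


@dataclass
class FakeChunk:
    """Hypothetical stand-in for ChatCompletionChunkResponse."""

    content: str


class StreamingInterfaceSketch:
    def process_chunk(
        self, chunk: FakeChunk, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
    ) -> None:
        # The defaulted flag keeps old three-argument call sites valid
        # while letting new callers opt in to reasoning-content handling.
        prefix = "[reasoning] " if expect_reasoning_content else ""
        print(f"{prefix}{chunk.content}")


iface = StreamingInterfaceSketch()
now = datetime.now(timezone.utc)
iface.process_chunk(FakeChunk("hello"), "msg-1", now)  # legacy call site, flag defaults to False
iface.process_chunk(FakeChunk("hello"), "msg-1", now, expect_reasoning_content=True)  # new call site
```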
```diff
@@ -48,7 +48,9 @@ class AgentChunkStreamingInterface(ABC):
         raise NotImplementedError
 
     @abstractmethod
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ):
         """Process a streaming chunk from an OpenAI-compatible server"""
         raise NotImplementedError
 
```
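Because `process_chunk` is abstract, every subclass has to accept the new keyword argument. A sketch of the override pattern with a hypothetical `CollectingInterface`; Python does not enforce override signatures, so repeating the default explicitly is what keeps both positional and keyword callers working across implementations:

```python
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, List, Tuple


class ChunkStreamingABC(ABC):
    """Hypothetical mirror of the abstract interface in this diff."""

    @abstractmethod
    def process_chunk(self, chunk: Any, message_id: str, message_date: datetime, expect_reasoning_content: bool = False):
        raise NotImplementedError


class CollectingInterface(ChunkStreamingABC):
    """Minimal concrete subclass that just buffers what it receives."""

    def __init__(self) -> None:
        self.received: List[Tuple[str, Any]] = []

    def process_chunk(self, chunk: Any, message_id: str, message_date: datetime, expect_reasoning_content: bool = False):
        # Repeat the default here: the ABC cannot enforce it, and omitting
        # it would break legacy three-argument call sites.
        self.received.append((message_id, chunk))
```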
```diff
@@ -92,7 +94,9 @@ class StreamingCLIInterface(AgentChunkStreamingInterface):
     def _flush(self):
         pass
 
-    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime):
+    def process_chunk(
+        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime, expect_reasoning_content: bool = False
+    ):
         assert len(chunk.choices) == 1, chunk
 
         message_delta = chunk.choices[0].delta
```
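The CLI implementation asserts exactly one choice per chunk before reading the delta. A sketch of how delta dispatch might look once reasoning content is in play; the `reasoning_content` field name is an assumption about what reasoning-capable backends emit, not something this diff defines:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Delta:
    content: Optional[str] = None
    reasoning_content: Optional[str] = None  # assumed field name, not defined in this diff


@dataclass
class Choice:
    delta: Delta


@dataclass
class Chunk:
    choices: List[Choice]


def print_chunk(chunk: Chunk, expect_reasoning_content: bool = False) -> None:
    assert len(chunk.choices) == 1, chunk  # streaming chunks carry exactly one choice
    delta = chunk.choices[0].delta
    if expect_reasoning_content and delta.reasoning_content:
        print(delta.reasoning_content, end="", flush=True)
    elif delta.content:
        print(delta.content, end="", flush=True)


print_chunk(Chunk(choices=[Choice(delta=Delta(content="Hi"))]))
```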
```diff
@@ -120,9 +120,16 @@ def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
         f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
     )
 
-    chunks = list(response)
-    for idx, chunk in enumerate(chunks):
-        _assert_valid_chunk(chunk, idx, chunks)
+    try:
+        chunks = list(response)
+        assert len(chunks) > 1, "Streaming response did not return enough chunks (may have failed silently)."
+
+        for idx, chunk in enumerate(chunks):
+            assert chunk, f"Empty chunk received at index {idx}."
+            print(chunk)
+            _assert_valid_chunk(chunk, idx, chunks)
+    except Exception as e:
+        pytest.fail(f"Streaming failed with exception: {e}")
 
 
 @pytest.mark.asyncio
```
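The hardened test wraps consumption in try/except so a transport error surfaces as a clear `pytest.fail` rather than an opaque traceback, and the `len(chunks) > 1` guard catches streams that terminate after a lone (often error) chunk. The `_assert_valid_chunk` helper itself is not part of this diff; a plausible shape for it, with every field expectation treated as an assumption:

```python
def _assert_valid_chunk(chunk, idx, chunks):
    """Plausible sketch only; the real helper is defined elsewhere in the test suite."""
    assert chunk.choices, f"Chunk {idx} has no choices"
    if idx < len(chunks) - 1:
        # Non-terminal chunks should carry streamed delta content.
        assert chunk.choices[0].delta.content is not None, f"Chunk {idx} missing delta content"
    else:
        # The terminal chunk conventionally carries a finish_reason instead.
        assert chunk.choices[0].finish_reason is not None, f"Chunk {idx} missing finish_reason"
```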
```diff
@@ -134,10 +141,15 @@ async def test_chat_completions_streaming_async(client, agent, message):
     async_client = AsyncOpenAI(base_url=f"{client.base_url}/openai/{client.api_prefix}", max_retries=0)
     stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
 
-    async with stream:
-        async for chunk in stream:
-            if isinstance(chunk, ChatCompletionChunk):
+    received_chunks = 0
+    try:
+        async with stream:
+            async for chunk in stream:
+                assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
+                assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
                 assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
-            else:
-                pytest.fail(f"Unexpected chunk type: {chunk}")
+                received_chunks += 1
+    except Exception as e:
+        pytest.fail(f"Streaming failed with exception: {e}")
 
+    assert received_chunks > 1, "No valid streaming chunks were received."
```
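The async variant counts content-bearing chunks and asserts on the total after the stream closes, so an empty stream fails loudly instead of passing vacuously. A self-contained sketch of the same consumption pattern against any OpenAI-compatible endpoint; the base URL and model name below are placeholders, not values from this repository:

```python
import asyncio

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionChunk


async def consume_stream(base_url: str) -> int:
    """Return the number of content-bearing chunks streamed back."""
    client = AsyncOpenAI(base_url=base_url, api_key="dummy-key", max_retries=0)
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Hello!"}],
        stream=True,
    )
    received = 0
    async with stream:  # ensures the HTTP response closes even if iteration raises
        async for chunk in stream:
            assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
            if chunk.choices and chunk.choices[0].delta.content:
                received += 1
    return received


if __name__ == "__main__":
    count = asyncio.run(consume_stream("http://localhost:8283/openai/v1"))  # placeholder base URL
    assert count > 1, "No valid streaming chunks were received."
```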