import os
import threading
import uuid

import pytest
from dotenv import load_dotenv
from letta_client import Letta
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest
from letta.schemas.openai.chat_completion_request import UserMessage as OpenAIUserMessage
from letta.schemas.tool import ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.tool_manager import ToolManager
from tests.utils import wait_for_server

# --- Server Management --- #


def _run_server():
    """Starts the Letta server in a background thread."""
    load_dotenv()
    from letta.server.rest_api.app import start_server

    start_server(debug=True)


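# If LETTA_SERVER_URL is set, the suite targets that server; otherwise the fixture below
# starts an in-process instance on a daemon thread, which dies with the pytest process.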
@pytest.fixture(scope="session")
def server_url():
    """Ensures a server is running and returns its base URL."""
    url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()
        wait_for_server(url)  # Allow server startup time

    return url


# --- Client Setup --- #


@pytest.fixture(scope="session")
def client(server_url):
    """Creates a REST client for testing."""
    client = Letta(base_url=server_url)
    yield client


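# Note: roll_dice returns a fixed string rather than a random roll, which keeps any
# assertions about tool output deterministic across test runs.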
@pytest.fixture(scope="function")
def roll_dice_tool(client):
    def roll_dice():
        """
        Rolls a 6-sided die.

        Returns:
            str: The roll result.
        """
        return "Rolled a 10!"

    tool = client.tools.upsert_from_function(func=roll_dice)
    # Yield the created tool
    yield tool


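# get_weather issues a live HTTP request to wttr.in, so this tool requires outbound
# network access whenever it is actually executed.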
@pytest.fixture(scope="function")
def weather_tool(client):
    def get_weather(location: str) -> str:
        """
        Fetches the current weather for a given location.

        Parameters:
            location (str): The location to get the weather for.

        Returns:
            str: A formatted string describing the weather in the given location.

        Raises:
            RuntimeError: If the request to fetch weather data fails.
        """
        import requests

        url = f"https://wttr.in/{location}?format=%C+%t"

        response = requests.get(url)
        if response.status_code == 200:
            weather_data = response.text
            return f"The weather in {location} is {weather_data}."
        else:
            raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")

    tool = client.tools.upsert_from_function(func=get_weather)
    # Yield the created tool
    yield tool


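# Relies on a `default_user` fixture defined elsewhere in the test suite and registers the
# Composio GMAIL_GET_PROFILE action as a tool; running it presumably requires Composio credentials.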
@pytest.fixture(scope="function")
def composio_gmail_get_profile_tool(default_user):
    tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
    tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
    yield tool


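# The agent fixture attaches both custom tools alongside the base tool set and deletes the
# agent after each test, so every test function starts from a clean agent.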
@pytest.fixture(scope="function")
def agent(client, roll_dice_tool, weather_tool):
    """Creates an agent and ensures cleanup after tests."""
    agent_state = client.agents.create(
        name=f"test_compl_{str(uuid.uuid4())[5:]}",
        tool_ids=[roll_dice_tool.id, weather_tool.id],
        include_base_tools=True,
        memory_blocks=[
            {"label": "human", "value": "(I know nothing about the human)"},
            {"label": "persona", "value": "Friendly agent"},
        ],
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )
    yield agent_state
    client.agents.delete(agent_state.id)


# --- Helper Functions --- #


def _get_chat_request(message, stream=True):
    """Builds a chat completion request, with streaming enabled by default."""
    return ChatCompletionRequest(
        model="gpt-4o-mini",
        messages=[OpenAIUserMessage(content=message)],
        stream=stream,
    )


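# Helper covering the three chunk types checked below: OpenAI-style content chunks,
# Letta usage statistics, and the terminal MessageStreamStatus.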
def _assert_valid_chunk(chunk, idx, chunks):
    """Validates the structure of each streaming chunk."""
    if isinstance(chunk, ChatCompletionChunk):
        assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."

    elif isinstance(chunk, LettaUsageStatistics):
        assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
        assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
        assert chunk.total_tokens > 0, "Total tokens must be > 0."
        assert chunk.step_count == 1, "Step count must be 1."

    elif isinstance(chunk, MessageStreamStatus):
        assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
        assert idx == len(chunks) - 1, "The last chunk must be 'done'."

    else:
        pytest.fail(f"Unexpected chunk type: {chunk}")


# --- Test Cases --- #


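# Streams a completion through the stock OpenAI SDK pointed at Letta's OpenAI-compatible
# endpoint (openai/v1/{agent_id}) and checks that every chunk is well formed.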
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "What's the weather in SF?"])
@pytest.mark.parametrize("endpoint", ["openai/v1"])
async def test_chat_completions_streaming_openai_client(disable_e2b_api_key, client, agent, message, endpoint):
    """Tests chat completion streaming using the Async OpenAI client."""
    request = _get_chat_request(message)

    async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{agent.id}", max_retries=0)
    stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))

    received_chunks = 0
    stop_chunk_count = 0
    last_chunk = None

    try:
        async with stream:
            async for chunk in stream:
                assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
                assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."

                # Track last chunk for final verification
                last_chunk = chunk

                # If this chunk has a finish reason of "stop", track it
                if chunk.choices[0].finish_reason == "stop":
                    stop_chunk_count += 1
                    # Fail early if more than one stop chunk is sent
                    assert stop_chunk_count == 1, f"Multiple stop chunks detected: {chunk.model_dump_json(indent=4)}"
                    continue

                # Validate regular content chunks
                assert chunk.choices[0].delta.content, f"Chunk at index {received_chunks} has no content: {chunk.model_dump_json(indent=4)}"
                received_chunks += 1
    except Exception as e:
        pytest.fail(f"Streaming failed with exception: {e}")

    assert received_chunks > 0, "No valid streaming chunks were received."

    # Ensure the last chunk is the expected stop chunk
    assert last_chunk is not None, "No last chunk received."
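    # A likely final check (assumed; not visible in the excerpt above): the stream should
    # close on the single stop chunk tracked earlier.
    assert last_chunk.choices[0].finish_reason == "stop", "Last chunk should have finish_reason 'stop'."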