feat: Add e2e tests where agents invoke MCP tooling (#2020)

This commit is contained in:
Matthew Zhou 2025-05-06 11:24:20 +08:00 committed by GitHub
parent 0ef64dc365
commit 7b13ab91b1
2 changed files with 134 additions and 95 deletions

View File

@ -1,8 +1 @@
{
"mcpServers": {
"github_composio": {
"transport": "sse",
"url": "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
}
}
}
{}

View File

@ -1,18 +1,21 @@
import json
import os
import subprocess
import threading
import uuid
import venv
from pathlib import Path
import pytest
from mcp import Tool as MCPTool
from dotenv import load_dotenv
from letta_client import Letta, McpTool, ToolCallMessage, ToolReturnMessage
import letta.constants as constants
from letta.config import LettaConfig
from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig
from letta.schemas.tool import ToolCreate
from letta.server.server import SyncServer
from letta.utils import parse_json
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from tests.utils import wait_for_server
def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path:
@ -40,122 +43,165 @@ def create_virtualenv_and_install_requirements(requirements_path: Path, name="ve
return venv_dir
# --- Server Management --- #
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture
def empty_mcp_config(tmp_path):
def empty_mcp_config():
path = Path(__file__).parent / "mcp_config.json"
path.write_text(json.dumps({})) # writes "{}"
return path
@pytest.fixture
def server(empty_mcp_config):
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
@pytest.fixture()
def server_url(empty_mcp_config):
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
config.save()
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
wait_for_server(url)
old_dir = constants.LETTA_DIR
constants.LETTA_DIR = str(Path(__file__).parent)
server = SyncServer()
yield server
constants.LETTA_DIR = old_dir
return url
@pytest.fixture
def default_user(server):
user = server.user_manager.get_user_or_default()
yield user
@pytest.fixture()
def client(server_url):
"""Creates a REST client for testing."""
client = Letta(base_url=server_url)
return client
def test_sse_mcp_server(server, default_user):
assert server.mcp_clients == {}
@pytest.fixture()
def agent_state(client):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.agents.create(
name=f"test_compl_{str(uuid.uuid4())[5:]}",
include_base_tools=True,
memory_blocks=[
{
"label": "human",
"value": "Name: Matt",
},
{
"label": "persona",
"value": "Friendly agent",
},
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state
client.agents.delete(agent_state.id)
def test_sse_mcp_server(client, agent_state):
mcp_server_name = "github_composio"
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
server.add_mcp_server_to_config(sse_mcp_config)
# Check that it's in clients
assert mcp_server_name in server.mcp_clients
client.tools.add_mcp_server(request=sse_mcp_config)
# Check that it's in the server mapping
mcp_server_mapping = server.get_mcp_servers()
mcp_server_mapping = client.tools.list_mcp_servers()
assert mcp_server_name in mcp_server_mapping
assert mcp_server_mapping[mcp_server_name] == sse_mcp_config
# Check tools
tools = server.get_tools_from_mcp_server(mcp_server_name)
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
assert len(tools) > 0
assert isinstance(tools[0], MCPTool)
assert isinstance(tools[0], McpTool)
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
# Check that one of the tools are executable
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=star_mcp_tool)
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name)
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(
tool_name=star_mcp_tool.name, tool_args={"owner": "letta-ai", "repo": "letta"}
tool_args = {"owner": "letta-ai", "repo": "letta"}
# Add to agent, have agent invoke tool
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(
role="user",
content=[TextContent(text=f"Use the `{letta_tool.name}` tool with these arguments: {tool_args}.")],
)
assert not is_error
function_response = parse_json(function_response)
assert function_response.get("successful"), function_response
assert function_response.get("data").get("details") == "Action executed successfully", function_response
],
)
seq = response.messages
calls = [m for m in seq if isinstance(m, ToolCallMessage)]
assert calls, "Expected a ToolCallMessage"
assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
returns = [m for m in seq if isinstance(m, ToolReturnMessage)]
assert returns, "Expected a ToolReturnMessage"
tr = returns[0]
# status field
assert tr.status == "success", f"Bad status: {tr.status}"
# parse JSON payload
payload = json.loads(tr.tool_return)
assert payload.get("successful", False), f"Tool returned failure payload: {payload}"
assert payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {payload}"
def test_stdio_mcp_server(server, default_user):
assert server.mcp_clients == {}
# Create venv
create_virtualenv_and_install_requirements(Path(__file__).parent / "weather" / "requirements.txt")
def test_stdio_mcp_server(client, agent_state):
req_file = Path(__file__).parent / "weather" / "requirements.txt"
create_virtualenv_and_install_requirements(req_file, name="venv")
mcp_server_name = "weather"
command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3")
args = [str(Path(__file__).parent / "weather" / "weather.py")]
stdio_mcp_config = StdioServerConfig(server_name=mcp_server_name, command=command, args=args)
server.add_mcp_server_to_config(stdio_mcp_config)
# Check that it's in clients
assert mcp_server_name in server.mcp_clients
# Check that it's in the server mapping
mcp_server_mapping = server.get_mcp_servers()
assert mcp_server_name in mcp_server_mapping
assert mcp_server_mapping[mcp_server_name] == StdioServerConfig(
server_name=mcp_server_name, type=MCPServerType.STDIO, command=command, args=args, env=None
stdio_config = StdioServerConfig(
server_name=mcp_server_name,
command=command,
args=args,
)
# Check that it can return valid tools
tools = server.get_tools_from_mcp_server(mcp_server_name)
assert tools == [
MCPTool(
name="get_alerts",
description="Get weather alerts for a US state.\n\n Args:\n state: Two-letter US state code (e.g. CA, NY)\n ",
inputSchema={
"properties": {"state": {"title": "State", "type": "string"}},
"required": ["state"],
"title": "get_alertsArguments",
"type": "object",
},
),
MCPTool(
name="get_forecast",
description="Get weather forecast for a location.\n\n Args:\n latitude: Latitude of the location\n longitude: Longitude of the location\n ",
inputSchema={
"properties": {"latitude": {"title": "Latitude", "type": "number"}, "longitude": {"title": "Longitude", "type": "number"}},
"required": ["latitude", "longitude"],
"title": "get_forecastArguments",
"type": "object",
},
),
]
get_alerts_mcp_tool = tools[0]
client.tools.add_mcp_server(request=stdio_config)
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=get_alerts_mcp_tool)
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
servers = client.tools.list_mcp_servers()
assert mcp_server_name in servers
# Attempt running the tool
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(tool_name="get_alerts", tool_args={"state": "CA"})
assert not is_error
assert len(function_response) > 20, function_response # Crude heuristic for an expected result
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
assert tools, "Expected at least one tool from the weather MCP server"
assert any(t.name == "get_alerts" for t in tools), f"Got: {[t.name for t in tools]}"
get_alerts = next(t for t in tools if t.name == "get_alerts")
letta_tool = client.tools.add_mcp_tool(
mcp_server_name=mcp_server_name,
mcp_tool_name=get_alerts.name,
)
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=[
MessageCreate(
role="user",
content=[TextContent(text=(f"Use the `{letta_tool.name}` tool with these arguments: " f"{{'state': 'CA'}}."))],
)
],
)
calls = [m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "get_alerts"]
assert calls, "Expected a get_alerts ToolCallMessage"
returns = [m for m in response.messages if isinstance(m, ToolReturnMessage) and m.tool_call_id == calls[0].tool_call.tool_call_id]
assert returns, "Expected a ToolReturnMessage for get_alerts"
ret = returns[0]
assert ret.status == "success", f"Unexpected status: {ret.status}"
# make sure there's at least some payload
assert len(ret.tool_return.strip()) >= 10, f"Expected at least 10 characters in tool_return, got {len(ret.tool_return.strip())}"