mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add e2e tests where agents invoke MCP tooling (#2020)
This commit is contained in:
parent
0ef64dc365
commit
7b13ab91b1
@ -1,8 +1 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"github_composio": {
|
||||
"transport": "sse",
|
||||
"url": "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
||||
}
|
||||
}
|
||||
}
|
||||
{}
|
||||
|
@ -1,18 +1,21 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import threading
|
||||
import uuid
|
||||
import venv
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from mcp import Tool as MCPTool
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, McpTool, ToolCallMessage, ToolReturnMessage
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.server.server import SyncServer
|
||||
from letta.utils import parse_json
|
||||
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
|
||||
def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path:
|
||||
@ -40,122 +43,165 @@ def create_virtualenv_and_install_requirements(requirements_path: Path, name="ve
|
||||
return venv_dir
|
||||
|
||||
|
||||
# --- Server Management --- #
|
||||
|
||||
|
||||
def _run_server():
    """Load the dotenv environment and launch the Letta REST server in debug mode.

    Intended to be used as the ``target`` of a background thread; the call to
    ``start_server`` is the last statement and its result is not used.
    """
    # Populate the process environment from a .env file before touching the
    # server package.
    load_dotenv()

    # NOTE(review): the import is kept after load_dotenv() on purpose --
    # presumably so the .env values are visible when the server module is
    # imported; confirm before moving it to module level.
    from letta.server.rest_api.app import start_server

    start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_mcp_config(tmp_path):
|
||||
def empty_mcp_config():
|
||||
path = Path(__file__).parent / "mcp_config.json"
|
||||
path.write_text(json.dumps({})) # writes "{}"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server(empty_mcp_config):
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
@pytest.fixture()
|
||||
def server_url(empty_mcp_config):
|
||||
"""Ensures a server is running and returns its base URL."""
|
||||
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
config.save()
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(url)
|
||||
|
||||
old_dir = constants.LETTA_DIR
|
||||
constants.LETTA_DIR = str(Path(__file__).parent)
|
||||
|
||||
server = SyncServer()
|
||||
yield server
|
||||
constants.LETTA_DIR = old_dir
|
||||
return url
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server):
|
||||
user = server.user_manager.get_user_or_default()
|
||||
yield user
|
||||
@pytest.fixture()
def client(server_url):
    """Return a ``Letta`` REST client whose base URL is the ``server_url`` fixture value."""
    return Letta(base_url=server_url)
|
||||
|
||||
|
||||
def test_sse_mcp_server(server, default_user):
|
||||
assert server.mcp_clients == {}
|
||||
@pytest.fixture()
def agent_state(client):
    """Yield a newly created agent and delete it when the test is finished.

    The agent is created through ``client.agents.create`` with base tools
    enabled, a "human" and a "persona" memory block, the ``gpt-4o-mini`` LLM
    config and the default OpenAI embedding config.
    """
    # Name suffix is derived from a random UUID4 string.
    agent_name = f"test_compl_{str(uuid.uuid4())[5:]}"
    initial_memory = [
        {"label": "human", "value": "Name: Matt"},
        {"label": "persona", "value": "Friendly agent"},
    ]

    created_agent = client.agents.create(
        name=agent_name,
        include_base_tools=True,
        memory_blocks=initial_memory,
        llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
    )

    yield created_agent

    # Teardown: remove the agent that was created for this test.
    client.agents.delete(created_agent.id)
|
||||
|
||||
|
||||
def test_sse_mcp_server(client, agent_state):
|
||||
mcp_server_name = "github_composio"
|
||||
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
server.add_mcp_server_to_config(sse_mcp_config)
|
||||
|
||||
# Check that it's in clients
|
||||
assert mcp_server_name in server.mcp_clients
|
||||
client.tools.add_mcp_server(request=sse_mcp_config)
|
||||
|
||||
# Check that it's in the server mapping
|
||||
mcp_server_mapping = server.get_mcp_servers()
|
||||
mcp_server_mapping = client.tools.list_mcp_servers()
|
||||
assert mcp_server_name in mcp_server_mapping
|
||||
assert mcp_server_mapping[mcp_server_name] == sse_mcp_config
|
||||
|
||||
# Check tools
|
||||
tools = server.get_tools_from_mcp_server(mcp_server_name)
|
||||
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
assert len(tools) > 0
|
||||
assert isinstance(tools[0], MCPTool)
|
||||
assert isinstance(tools[0], McpTool)
|
||||
star_mcp_tool = next((t for t in tools if t.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"), None)
|
||||
|
||||
# Check that one of the tools are executable
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=star_mcp_tool)
|
||||
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
||||
letta_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name=star_mcp_tool.name)
|
||||
|
||||
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(
|
||||
tool_name=star_mcp_tool.name, tool_args={"owner": "letta-ai", "repo": "letta"}
|
||||
tool_args = {"owner": "letta-ai", "repo": "letta"}
|
||||
|
||||
# Add to agent, have agent invoke tool
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[TextContent(text=f"Use the `{letta_tool.name}` tool with these arguments: {tool_args}.")],
|
||||
)
|
||||
],
|
||||
)
|
||||
assert not is_error
|
||||
function_response = parse_json(function_response)
|
||||
assert function_response.get("successful"), function_response
|
||||
assert function_response.get("data").get("details") == "Action executed successfully", function_response
|
||||
seq = response.messages
|
||||
calls = [m for m in seq if isinstance(m, ToolCallMessage)]
|
||||
assert calls, "Expected a ToolCallMessage"
|
||||
assert calls[0].tool_call.name == "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
|
||||
|
||||
returns = [m for m in seq if isinstance(m, ToolReturnMessage)]
|
||||
assert returns, "Expected a ToolReturnMessage"
|
||||
tr = returns[0]
|
||||
# status field
|
||||
assert tr.status == "success", f"Bad status: {tr.status}"
|
||||
# parse JSON payload
|
||||
payload = json.loads(tr.tool_return)
|
||||
assert payload.get("successful", False), f"Tool returned failure payload: {payload}"
|
||||
assert payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {payload}"
|
||||
|
||||
|
||||
def test_stdio_mcp_server(server, default_user):
|
||||
assert server.mcp_clients == {}
|
||||
|
||||
# Create venv
|
||||
create_virtualenv_and_install_requirements(Path(__file__).parent / "weather" / "requirements.txt")
|
||||
def test_stdio_mcp_server(client, agent_state):
|
||||
req_file = Path(__file__).parent / "weather" / "requirements.txt"
|
||||
create_virtualenv_and_install_requirements(req_file, name="venv")
|
||||
|
||||
mcp_server_name = "weather"
|
||||
command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3")
|
||||
args = [str(Path(__file__).parent / "weather" / "weather.py")]
|
||||
stdio_mcp_config = StdioServerConfig(server_name=mcp_server_name, command=command, args=args)
|
||||
server.add_mcp_server_to_config(stdio_mcp_config)
|
||||
|
||||
# Check that it's in clients
|
||||
assert mcp_server_name in server.mcp_clients
|
||||
|
||||
# Check that it's in the server mapping
|
||||
mcp_server_mapping = server.get_mcp_servers()
|
||||
assert mcp_server_name in mcp_server_mapping
|
||||
assert mcp_server_mapping[mcp_server_name] == StdioServerConfig(
|
||||
server_name=mcp_server_name, type=MCPServerType.STDIO, command=command, args=args, env=None
|
||||
stdio_config = StdioServerConfig(
|
||||
server_name=mcp_server_name,
|
||||
command=command,
|
||||
args=args,
|
||||
)
|
||||
|
||||
# Check that it can return valid tools
|
||||
tools = server.get_tools_from_mcp_server(mcp_server_name)
|
||||
assert tools == [
|
||||
MCPTool(
|
||||
name="get_alerts",
|
||||
description="Get weather alerts for a US state.\n\n Args:\n state: Two-letter US state code (e.g. CA, NY)\n ",
|
||||
inputSchema={
|
||||
"properties": {"state": {"title": "State", "type": "string"}},
|
||||
"required": ["state"],
|
||||
"title": "get_alertsArguments",
|
||||
"type": "object",
|
||||
},
|
||||
),
|
||||
MCPTool(
|
||||
name="get_forecast",
|
||||
description="Get weather forecast for a location.\n\n Args:\n latitude: Latitude of the location\n longitude: Longitude of the location\n ",
|
||||
inputSchema={
|
||||
"properties": {"latitude": {"title": "Latitude", "type": "number"}, "longitude": {"title": "Longitude", "type": "number"}},
|
||||
"required": ["latitude", "longitude"],
|
||||
"title": "get_forecastArguments",
|
||||
"type": "object",
|
||||
},
|
||||
),
|
||||
]
|
||||
get_alerts_mcp_tool = tools[0]
|
||||
client.tools.add_mcp_server(request=stdio_config)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=get_alerts_mcp_tool)
|
||||
server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user)
|
||||
servers = client.tools.list_mcp_servers()
|
||||
assert mcp_server_name in servers
|
||||
|
||||
# Attempt running the tool
|
||||
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(tool_name="get_alerts", tool_args={"state": "CA"})
|
||||
assert not is_error
|
||||
assert len(function_response) > 20, function_response # Crude heuristic for an expected result
|
||||
tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
assert tools, "Expected at least one tool from the weather MCP server"
|
||||
assert any(t.name == "get_alerts" for t in tools), f"Got: {[t.name for t in tools]}"
|
||||
|
||||
get_alerts = next(t for t in tools if t.name == "get_alerts")
|
||||
|
||||
letta_tool = client.tools.add_mcp_tool(
|
||||
mcp_server_name=mcp_server_name,
|
||||
mcp_tool_name=get_alerts.name,
|
||||
)
|
||||
|
||||
client.agents.tools.attach(agent_id=agent_state.id, tool_id=letta_tool.id)
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[TextContent(text=(f"Use the `{letta_tool.name}` tool with these arguments: " f"{{'state': 'CA'}}."))],
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
calls = [m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "get_alerts"]
|
||||
assert calls, "Expected a get_alerts ToolCallMessage"
|
||||
|
||||
returns = [m for m in response.messages if isinstance(m, ToolReturnMessage) and m.tool_call_id == calls[0].tool_call.tool_call_id]
|
||||
assert returns, "Expected a ToolReturnMessage for get_alerts"
|
||||
ret = returns[0]
|
||||
|
||||
assert ret.status == "success", f"Unexpected status: {ret.status}"
|
||||
# make sure there's at least some payload
|
||||
assert len(ret.tool_return.strip()) >= 10, f"Expected at least 10 characters in tool_return, got {len(ret.tool_return.strip())}"
|
||||
|
Loading…
Reference in New Issue
Block a user