mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
test: Write unit tests for stdio MCP server integration (#2001)
This commit is contained in:
parent
16cc387fe1
commit
b1e6d053a2
0
tests/mcp/__init__.py
Normal file
0
tests/mcp/__init__.py
Normal file
11
tests/mcp/mcp_config.json
Normal file
11
tests/mcp/mcp_config.json
Normal file
@ -0,0 +1,11 @@
|
||||
{
|
||||
"mcpServers": {
|
||||
"weather": {
|
||||
"transport": "stdio",
|
||||
"command": "/Users/mattzhou/letta-cloud/apps/core/tests/mcp/weather/venv/bin/python3",
|
||||
"args": [
|
||||
"/Users/mattzhou/letta-cloud/apps/core/tests/mcp/weather/weather.py"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
125
tests/mcp/test_mcp.py
Normal file
125
tests/mcp/test_mcp.py
Normal file
@ -0,0 +1,125 @@
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import venv
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from mcp import Tool as MCPTool
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.config import LettaConfig
|
||||
from letta.functions.mcp_client.types import MCPServerType, StdioServerConfig
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
def create_virtualenv_and_install_requirements(requirements_path: Path, name="venv") -> Path:
|
||||
requirements_path = requirements_path.resolve()
|
||||
|
||||
if not requirements_path.exists():
|
||||
raise FileNotFoundError(f"Requirements file not found: {requirements_path}")
|
||||
if requirements_path.name != "requirements.txt":
|
||||
raise ValueError(f"Expected file named 'requirements.txt', got: {requirements_path.name}")
|
||||
|
||||
venv_dir = requirements_path.parent / name
|
||||
|
||||
if not venv_dir.exists():
|
||||
venv.EnvBuilder(with_pip=True).create(venv_dir)
|
||||
|
||||
pip_path = venv_dir / ("Scripts" if os.name == "nt" else "bin") / "pip"
|
||||
if not pip_path.exists():
|
||||
raise FileNotFoundError(f"pip executable not found at: {pip_path}")
|
||||
|
||||
try:
|
||||
subprocess.check_call([str(pip_path), "install", "-r", str(requirements_path)])
|
||||
except subprocess.CalledProcessError as exc:
|
||||
raise RuntimeError(f"pip install failed with exit code {exc.returncode}")
|
||||
|
||||
return venv_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def empty_mcp_config(tmp_path):
|
||||
path = Path(__file__).parent / "mcp_config.json"
|
||||
path.write_text(json.dumps({})) # writes "{}"
|
||||
|
||||
create_virtualenv_and_install_requirements(Path(__file__).parent / "weather" / "requirements.txt")
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def server(empty_mcp_config):
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
old_dir = constants.LETTA_DIR
|
||||
constants.LETTA_DIR = str(Path(__file__).parent)
|
||||
|
||||
server = SyncServer()
|
||||
yield server
|
||||
constants.LETTA_DIR = old_dir
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server):
|
||||
user = server.user_manager.get_user_or_default()
|
||||
yield user
|
||||
|
||||
|
||||
def test_stdio_mcp_server(server, default_user):
|
||||
assert server.mcp_clients == {}
|
||||
|
||||
mcp_server_name = "weather"
|
||||
command = str(Path(__file__).parent / "weather" / "venv" / "bin" / "python3")
|
||||
args = [str(Path(__file__).parent / "weather" / "weather.py")]
|
||||
stdio_mcp_config = StdioServerConfig(server_name=mcp_server_name, command=command, args=args)
|
||||
server.add_mcp_server_to_config(stdio_mcp_config)
|
||||
|
||||
# Check that it's in clients
|
||||
assert mcp_server_name in server.mcp_clients
|
||||
|
||||
# Check that it's in the server mapping
|
||||
mcp_server_mapping = server.get_mcp_servers()
|
||||
assert mcp_server_name in mcp_server_mapping
|
||||
assert mcp_server_mapping[mcp_server_name] == StdioServerConfig(
|
||||
server_name=mcp_server_name, type=MCPServerType.STDIO, command=command, args=args, env=None
|
||||
)
|
||||
|
||||
# Check that it can return valid tools
|
||||
tools = server.get_tools_from_mcp_server(mcp_server_name)
|
||||
assert tools == [
|
||||
MCPTool(
|
||||
name="get_alerts",
|
||||
description="Get weather alerts for a US state.\n\n Args:\n state: Two-letter US state code (e.g. CA, NY)\n ",
|
||||
inputSchema={
|
||||
"properties": {"state": {"title": "State", "type": "string"}},
|
||||
"required": ["state"],
|
||||
"title": "get_alertsArguments",
|
||||
"type": "object",
|
||||
},
|
||||
),
|
||||
MCPTool(
|
||||
name="get_forecast",
|
||||
description="Get weather forecast for a location.\n\n Args:\n latitude: Latitude of the location\n longitude: Longitude of the location\n ",
|
||||
inputSchema={
|
||||
"properties": {"latitude": {"title": "Latitude", "type": "number"}, "longitude": {"title": "Longitude", "type": "number"}},
|
||||
"required": ["latitude", "longitude"],
|
||||
"title": "get_forecastArguments",
|
||||
"type": "object",
|
||||
},
|
||||
),
|
||||
]
|
||||
get_alerts_mcp_tool = tools[0]
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=get_alerts_mcp_tool)
|
||||
get_alerts_tool = server.tool_manager.create_or_update_mcp_tool(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, actor=default_user
|
||||
)
|
||||
|
||||
# Attempt running the tool
|
||||
function_response, is_error = server.mcp_clients[mcp_server_name].execute_tool(tool_name="get_alerts", tool_args={"state": "CA"})
|
||||
assert not is_error
|
||||
assert len(function_response) > 1000, function_response # Crude heuristic for an expected result
|
27
tests/mcp/weather/requirements.txt
Normal file
27
tests/mcp/weather/requirements.txt
Normal file
@ -0,0 +1,27 @@
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
certifi==2025.4.26
|
||||
click==8.1.8
|
||||
h11==0.16.0
|
||||
httpcore==1.0.9
|
||||
httpx==0.28.1
|
||||
httpx-sse==0.4.0
|
||||
idna==3.10
|
||||
markdown-it-py==3.0.0
|
||||
mcp==1.7.1
|
||||
mdurl==0.1.2
|
||||
pydantic==2.11.4
|
||||
pydantic-settings==2.9.1
|
||||
pydantic_core==2.33.2
|
||||
Pygments==2.19.1
|
||||
python-dotenv==1.1.0
|
||||
python-multipart==0.0.20
|
||||
rich==14.0.0
|
||||
shellingham==1.5.4
|
||||
sniffio==1.3.1
|
||||
sse-starlette==2.3.3
|
||||
starlette==0.46.2
|
||||
typer==0.15.3
|
||||
typing-inspection==0.4.0
|
||||
typing_extensions==4.13.2
|
||||
uvicorn==0.34.2
|
97
tests/mcp/weather/weather.py
Normal file
97
tests/mcp/weather/weather.py
Normal file
@ -0,0 +1,97 @@
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
# Initialize FastMCP server
|
||||
mcp = FastMCP("weather")
|
||||
|
||||
# Constants
|
||||
NWS_API_BASE = "https://api.weather.gov"
|
||||
USER_AGENT = "weather-app/1.0"
|
||||
|
||||
|
||||
async def make_nws_request(url: str) -> dict[str, Any] | None:
|
||||
"""Make a request to the NWS API with proper error handling."""
|
||||
headers = {"User-Agent": USER_AGENT, "Accept": "application/geo+json"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.get(url, headers=headers, timeout=30.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def format_alert(feature: dict) -> str:
|
||||
"""Format an alert feature into a readable string."""
|
||||
props = feature["properties"]
|
||||
return f"""
|
||||
Event: {props.get('event', 'Unknown')}
|
||||
Area: {props.get('areaDesc', 'Unknown')}
|
||||
Severity: {props.get('severity', 'Unknown')}
|
||||
Description: {props.get('description', 'No description available')}
|
||||
Instructions: {props.get('instruction', 'No specific instructions provided')}
|
||||
"""
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_alerts(state: str) -> str:
|
||||
"""Get weather alerts for a US state.
|
||||
|
||||
Args:
|
||||
state: Two-letter US state code (e.g. CA, NY)
|
||||
"""
|
||||
url = f"{NWS_API_BASE}/alerts/active/area/{state}"
|
||||
data = await make_nws_request(url)
|
||||
|
||||
if not data or "features" not in data:
|
||||
return "Unable to fetch alerts or no alerts found."
|
||||
|
||||
if not data["features"]:
|
||||
return "No active alerts for this state."
|
||||
|
||||
alerts = [format_alert(feature) for feature in data["features"]]
|
||||
return "\n---\n".join(alerts)
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_forecast(latitude: float, longitude: float) -> str:
|
||||
"""Get weather forecast for a location.
|
||||
|
||||
Args:
|
||||
latitude: Latitude of the location
|
||||
longitude: Longitude of the location
|
||||
"""
|
||||
# First get the forecast grid endpoint
|
||||
points_url = f"{NWS_API_BASE}/points/{latitude},{longitude}"
|
||||
points_data = await make_nws_request(points_url)
|
||||
|
||||
if not points_data:
|
||||
return "Unable to fetch forecast data for this location."
|
||||
|
||||
# Get the forecast URL from the points response
|
||||
forecast_url = points_data["properties"]["forecast"]
|
||||
forecast_data = await make_nws_request(forecast_url)
|
||||
|
||||
if not forecast_data:
|
||||
return "Unable to fetch detailed forecast."
|
||||
|
||||
# Format the periods into a readable forecast
|
||||
periods = forecast_data["properties"]["periods"]
|
||||
forecasts = []
|
||||
for period in periods[:5]: # Only show next 5 periods
|
||||
forecast = f"""
|
||||
{period['name']}:
|
||||
Temperature: {period['temperature']}°{period['temperatureUnit']}
|
||||
Wind: {period['windSpeed']} {period['windDirection']}
|
||||
Forecast: {period['detailedForecast']}
|
||||
"""
|
||||
forecasts.append(forecast)
|
||||
|
||||
return "\n---\n".join(forecasts)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Initialize and run the server
|
||||
mcp.run(transport="stdio")
|
Loading…
Reference in New Issue
Block a user