MemGPT/tests/integration_test_batch.py
Sarah Wooders fb139dead6 feat: Write fern api tests for batch API (#1821)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
2025-04-21 15:48:06 -07:00

116 lines
3.0 KiB
Python

import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
from letta.config import LettaConfig
from letta.server.server import SyncServer
def run_server():
    """Thread target that boots the Letta REST server.

    Calls ``load_dotenv()`` first and only then imports the REST app module,
    so the import happens after the dotenv call has run. The call to
    ``start_server`` is presumably blocking for the lifetime of the server
    (TODO confirm), which is why this is meant to run in a background thread.
    """
    load_dotenv()

    # Deliberately imported lazily, after load_dotenv(), rather than at module top.
    from letta.server.rest_api.app import start_server as launch_rest_server

    launch_rest_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
    """
    Ensures a server is reachable and returns its base URL.

    If the ``LETTA_SERVER_URL`` environment variable is set to a non-empty
    value, that URL is returned as-is and no server is started. Otherwise a
    server is started in a daemon thread via ``run_server`` and this fixture
    blocks until the local port accepts TCP connections.

    Returns:
        The base URL of the server to test against.

    Raises:
        RuntimeError: If the background server thread exits before the port
            opens, or the port does not open within the startup timeout.
    """
    import socket  # local import: only needed when we start our own server

    # Read the variable once; an empty value is treated the same as unset so
    # we never hand an empty base URL to the client.
    configured_url = os.getenv("LETTA_SERVER_URL")
    if configured_url:
        return configured_url

    host, port = "localhost", 8283
    url = f"http://{host}:{port}"
    startup_timeout_s = 30.0
    poll_interval_s = 0.1

    thread = threading.Thread(target=run_server, daemon=True)
    thread.start()

    # Poll for readiness instead of sleeping a fixed amount: returns as soon
    # as the port is open and fails loudly if the server never comes up.
    deadline = time.monotonic() + startup_timeout_s
    while True:
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return url
        except OSError as err:
            if not thread.is_alive():
                raise RuntimeError(
                    f"Letta server thread exited before {url} became reachable"
                ) from err
            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f"Letta server at {url} did not start within {startup_timeout_s:.0f}s"
                ) from err
            time.sleep(poll_interval_s)
@pytest.fixture(scope="module")
def server():
    """
    Module-scoped fixture that yields a fresh ``SyncServer``.

    Before constructing the server, the Letta configuration is loaded and
    immediately saved back, mirroring the initialization step the server
    is expected to rely on (presumably so a config file exists on disk —
    TODO confirm).
    """
    # Round-trip the config: load whatever is present, then persist it.
    LettaConfig.load().save()

    sync_server = SyncServer()
    return sync_server
@pytest.fixture(scope="session")
def client(server_url):
    """Session-scoped ``Letta`` REST client pointed at ``server_url``."""
    rest_client = Letta(base_url=server_url)
    return rest_client
def test_create_batch(client: Letta):
    """Smoke-test the batch API: create a batch, list batches, retrieve it.

    Two agents are created, a single batch containing one "hi" user message
    per agent is submitted, and the test then checks that listing batches
    returns at least one entry and that the created batch can be retrieved
    by its id.

    Args:
        client: REST client fixture connected to the test server.
    """
    # Create the agents; they differ only in their index (name and persona).
    agents = [
        client.agents.create(
            name=f"agent{index}_batch",
            memory_blocks=[{"label": "persona", "value": f"you are agent {index}"}],
            model="anthropic/claude-3-7-sonnet-20250219",
            embedding="letta/letta-free",
        )
        for index in (1, 2)
    ]

    # Create a batch with one identical request per agent, in creation order.
    run = client.batches.create(
        requests=[
            LettaBatchRequest(
                messages=[
                    MessageCreate(
                        role="user",
                        content=[TextContent(text="hi")],
                    )
                ],
                agent_id=agent.id,
            )
            for agent in agents
        ]
    )
    assert run is not None

    # List batches. Other batches may already exist on the server, so only a
    # lower bound is checked; the message states that bound accurately.
    batches = client.batches.list()
    assert len(batches) > 0, f"Expected at least 1 batch, got {len(batches)}"

    # Retrieve the batch we just created by id.
    results = client.batches.retrieve(
        batch_id=run.id,
    )
    assert results is not None
    print(results)