mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
116 lines
3.0 KiB
Python
116 lines
3.0 KiB
Python
import os
|
|
import threading
|
|
import time
|
|
|
|
import pytest
|
|
from dotenv import load_dotenv
|
|
from letta_client import Letta, LettaBatchRequest, MessageCreate, TextContent
|
|
|
|
from letta.config import LettaConfig
|
|
from letta.server.server import SyncServer
|
|
|
|
|
|
def run_server():
    """Thread target that loads ``.env`` settings and then runs the Letta REST server.

    ``start_server`` is imported inside the function, so the REST app module is
    only imported when a local server is actually launched.
    """
    load_dotenv()
    from letta.server.rest_api.app import start_server as _launch

    _launch(debug=True)
|
|
|
|
|
|
@pytest.fixture(scope="session")
def server_url():
    """
    Ensure a Letta server is reachable and return its base URL.

    If the ``LETTA_SERVER_URL`` environment variable is set to a non-empty value,
    that URL is returned unchanged and no server is started. Otherwise
    ``run_server`` is launched in a daemon thread. The fixture then blocks until
    the server's port accepts TCP connections, rather than sleeping for a fixed
    interval.

    Returns:
        str: the base URL clients should connect to.

    Raises:
        RuntimeError: if the server thread exits before the port accepts
            connections, or the port is still closed after the startup timeout.
    """
    import socket
    from urllib.parse import urlsplit

    # Treat an empty LETTA_SERVER_URL the same as an unset one. Otherwise the
    # fixture would start a local server but return "" as the URL.
    env_url = os.getenv("LETTA_SERVER_URL")
    if env_url:
        return env_url

    url = "http://localhost:8283"
    thread = threading.Thread(target=run_server, daemon=True)
    thread.start()

    parts = urlsplit(url)
    host, port = parts.hostname, parts.port
    startup_timeout = 60.0  # upper bound only; polling returns as soon as the port opens
    deadline = time.monotonic() + startup_timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1):
                return url
        except OSError:
            # Probe the port before checking the thread. If something is already
            # serving on the port, the connect above succeeds even if our
            # thread failed to bind.
            if not thread.is_alive():
                raise RuntimeError(
                    f"Letta server thread exited before {url} accepted connections"
                )
            time.sleep(0.25)
    raise RuntimeError(
        f"Letta server at {url} did not accept connections within {startup_timeout:.0f} seconds"
    )
|
|
|
|
|
|
@pytest.fixture(scope="module")
def server():
    """
    Provide a module-scoped ``SyncServer`` for tests.

    The Letta config is loaded and immediately written back before the server
    is constructed, so a saved config exists when ``SyncServer()`` runs.
    """
    LettaConfig.load().save()
    return SyncServer()
|
|
|
|
|
|
@pytest.fixture(scope="session")
def client(server_url):
    """Build a session-wide Letta REST client bound to ``server_url``."""
    rest_client = Letta(base_url=server_url)
    return rest_client
|
|
|
|
|
|
def test_create_batch(client: Letta):
    """
    Submit a two-agent batch and verify it can be listed and retrieved.

    Creates two agents on the same Anthropic model and sends each one a single
    "hi" user message via ``client.batches.create``. It then checks that
    ``client.batches.list()`` is non-empty and that ``client.batches.retrieve``
    returns the batch that was just created.
    """
    # NOTE(review): the agents created here are never deleted. Deleting them while
    # the batch may still be in flight could interfere with server-side batch
    # processing, so confirm that before adding cleanup.
    agents = [
        client.agents.create(
            name=f"agent{index}_batch",
            memory_blocks=[{"label": "persona", "value": f"you are agent {index}"}],
            model="anthropic/claude-3-7-sonnet-20250219",
            embedding="letta/letta-free",
        )
        for index in (1, 2)
    ]

    # One batch request per agent, each carrying a single "hi" user message.
    run = client.batches.create(
        requests=[
            LettaBatchRequest(
                messages=[
                    MessageCreate(
                        role="user",
                        content=[TextContent(text="hi")],
                    )
                ],
                agent_id=agent.id,
            )
            for agent in agents
        ]
    )
    assert run is not None

    # The message now matches the condition actually being checked (at least one).
    batches = client.batches.list()
    assert len(batches) > 0, f"Expected at least 1 batch, got {len(batches)}"

    # Retrieving by id must return the same batch that was just created.
    results = client.batches.retrieve(
        batch_id=run.id,
    )
    assert results is not None
    assert results.id == run.id, f"Retrieved batch {results.id} does not match created batch {run.id}"
|