chore: bump version 0.7.22 (#2655)

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
cthomas 2025-05-23 01:13:05 -07:00 committed by GitHub
parent c0efe8ad0c
commit 1b58fae4fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
80 changed files with 3149 additions and 8214 deletions

View File

@@ -1,92 +0,0 @@
import json
import os
import uuid
from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.sandbox_config import SandboxType
from letta.services.sandbox_config_manager import SandboxConfigManager
"""
Setup here.
"""
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "letta-composio-tooling-example"))
# Clear all agents
for agent_state in client.list_agents():
if agent_state.name == agent_uuid:
client.delete_agent(agent_id=agent_state.id)
print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
# Add sandbox env
manager = SandboxConfigManager()
# Ensure you have your E2B API key set
sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=client.user)
manager.create_sandbox_env_var(
SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=os.environ.get("COMPOSIO_API_KEY")),
sandbox_config_id=sandbox_config.id,
actor=client.user,
)
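# Optional guard (a sketch, not in the original example): os.environ.get returns None
# when COMPOSIO_API_KEY is unset, which would silently store an empty sandbox variable above.
assert os.environ.get("COMPOSIO_API_KEY"), "Set COMPOSIO_API_KEY before running this example."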
"""
This example shows how you can add Composio tools.
First, make sure you have Composio and some of the extras installed.
```
poetry install --extras "external-tools"
```
Then set up Letta with `letta configure`.
Additionally, this example stars a GitHub repo on your behalf, so you will need to configure Composio in your environment.
```
composio login
composio add github
```
Last updated Oct 2, 2024. Please check the `composio` documentation for any Composio-related issues.
"""
def main():
from composio import Action
# Add the composio tool
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
persona = f"""
My name is Letta.
I am a personal assistant that helps star repos on Github. It is my job to correctly input the owner and repo to the {tool.name} tool based on the user's request.
Don't forget - inner monologue / inner thoughts should always be different from the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
"""
# Create an agent
agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id])
print(f"Created agent: {agent.name} with ID {str(agent.id)}")
# Send a message to the agent
send_message_response = client.user_message(agent_id=agent.id, message="Star a repo composio with owner composiohq on GitHub")
for message in send_message_response.messages:
response_json = json.dumps(message.model_dump(), indent=4)
print(f"{response_json}\n")
# Delete agent
client.delete_agent(agent_id=agent.id)
print(f"Deleted agent: {agent.name} with ID {str(agent.id)}")
if __name__ == "__main__":
main()

View File

@@ -1,87 +0,0 @@
import json
import uuid
from letta import create_client
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
"""
This example shows how you can add LangChain tools.
First, make sure you have LangChain and some of the extras installed.
For this specific example, you will need `wikipedia` installed.
```
poetry install --extras "external-tools"
```
Then set up Letta with `letta configure`.
"""
def main():
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500)
langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
# Create the tool.
# Note the additional_imports_module_attr_map: we need to pass in a map of all the
# additional imports necessary to run this tool. Because an object of type
# WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool,
# we also need to import WikipediaAPIWrapper. The map maps a module name to the
# attribute imported from it, e.g. langchain_community.utilities.WikipediaAPIWrapper.
wikipedia_query_tool = client.load_langchain_tool(
langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
)
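# If a tool depended on several helper objects, the map extends with one entry per
# module; a hypothetical second entry is shown below for illustration only:
# additional_imports_module_attr_map={
#     "langchain_community.utilities": "WikipediaAPIWrapper",
#     "my_pkg.helpers": "MyHelper",  # hypothetical
# }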
tool_name = wikipedia_query_tool.name
# Confirm that the tool is registered with the client
tools = client.list_tools()
assert wikipedia_query_tool.name in [t.name for t in tools]
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "letta-langchain-tooling-example"))
# Clear all agents
for agent_state in client.list_agents():
if agent_state.name == agent_uuid:
client.delete_agent(agent_id=agent_state.id)
print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
# wikipedia search persona
persona = f"""
My name is Letta.
I am a personal assistant who answers a user's questions using wikipedia searches. When a user asks me a question, I will use a tool called {tool_name} which will search Wikipedia and return a Wikipedia page about the topic. It is my job to construct the best query to input into {tool_name} based on the user's question.
Don't forget - inner monologue / inner thoughts should always be different from the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
"""
# Create an agent
agent_state = client.create_agent(
name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id]
)
print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}")
# Send a message to the agent
send_message_response = client.user_message(agent_id=agent_state.id, message="Tell me a fun fact about Albert Einstein!")
for message in send_message_response.messages:
response_json = json.dumps(message.model_dump(), indent=4)
print(f"{response_json}\n")
# Delete agent
client.delete_agent(agent_id=agent_state.id)
print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
if __name__ == "__main__":
main()

View File

@@ -1,884 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cac06555-9ce8-4f01-bbef-3f8407f4b54d",
"metadata": {},
"source": [
"# Multi-agent recruiting workflow \n",
"> Make sure you run the Letta server before running this example using `letta server`\n",
"\n",
"Last tested with letta version `0.5.3`"
]
},
{
"cell_type": "markdown",
"id": "aad3a8cc-d17a-4da1-b621-ecc93c9e2106",
"metadata": {},
"source": [
"## Section 0: Setup a MemGPT client "
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7ccd43f2-164b-4d25-8465-894a3bb54c4b",
"metadata": {},
"outputs": [],
"source": [
"from letta_client import CreateBlock, Letta, MessageCreate\n",
"\n",
"client = Letta(base_url=\"http://localhost:8283\")"
]
},
{
"cell_type": "markdown",
"id": "99a61da5-f069-4538-a548-c7d0f7a70227",
"metadata": {},
"source": [
"## Section 1: Shared Memory Block \n",
"Each agent will have both its own memory, and shared memory. The shared memory will contain information about the organization that the agents are all a part of. If one agent updates this memory, the changes will be propaged to the memory of all the other agents. "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7770600d-5e83-4498-acf1-05f5bea216c3",
"metadata": {},
"outputs": [],
"source": [
"org_description = \"The company is called AgentOS \" \\\n",
"+ \"and is building AI tools to make it easier to create \" \\\n",
"+ \"and deploy LLM agents.\"\n",
"\n",
"org_block = client.blocks.create(\n",
" label=\"company\",\n",
" value=org_description,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "6c3d3a55-870a-4ff0-81c0-4072f783a940",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"org_block"
]
},
{
"cell_type": "markdown",
"id": "8448df7b-c321-4d90-ba52-003930a513cb",
"metadata": {},
"source": [
"## Section 2: Orchestrating Multiple Agents \n",
"We'll implement a recruiting workflow that involves evaluating an candidate, then if the candidate is a good fit, writing a personalized email on the human's behalf. Since this task involves multiple stages, sometimes breaking the task down to multiple agents can improve performance (though this is not always the case). We will break down the task into: \n",
"\n",
"1. `eval_agent`: This agent is responsible for evaluating candidates based on their resume\n",
"2. `outreach_agent`: This agent is responsible for writing emails to strong candidates\n",
"3. `recruiter_agent`: This agent is responsible for generating leads from a database \n",
"\n",
"Much like humans, these agents will communicate by sending each other messages. We can do this by giving agents that need to communicate with other agents access to a tool that allows them to message other agents. "
]
},
{
"cell_type": "markdown",
"id": "a065082a-d865-483c-b721-43c5a4d51afe",
"metadata": {},
"source": [
"#### Evaluator Agent\n",
"This agent will have tools to: \n",
"* Read a resume \n",
"* Submit a candidate for outreach (which sends the candidate information to the `outreach_agent`)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c00232c5-4c37-436c-8ea4-602a31bd84fa",
"metadata": {},
"outputs": [],
"source": [
"def read_resume(self, name: str): \n",
" \"\"\"\n",
" Read the resume data for a candidate given the name\n",
"\n",
" Args: \n",
" name (str): Candidate name \n",
"\n",
" Returns: \n",
" resume_data (str): Candidate's resume data \n",
" \"\"\"\n",
" import os\n",
" filepath = os.path.join(\"data\", \"resumes\", name.lower().replace(\" \", \"_\") + \".txt\")\n",
" return open(filepath).read()\n",
"\n",
"def submit_evaluation(self, candidate_name: str, reach_out: bool, resume: str, justification: str): \n",
" \"\"\"\n",
" Submit a candidate for outreach. \n",
"\n",
" Args: \n",
" candidate_name (str): The name of the candidate\n",
" reach_out (bool): Whether to reach out to the candidate\n",
" resume (str): The text representation of the candidate's resume \n",
" justification (str): Justification for reaching out or not\n",
" \"\"\"\n",
" from letta import create_client \n",
" client = create_client()\n",
" message = \"Reach out to the following candidate. \" \\\n",
" + f\"Name: {candidate_name}\\n\" \\\n",
" + f\"Resume Data: {resume}\\n\" \\\n",
" + f\"Justification: {justification}\"\n",
" # NOTE: we will define this agent later \n",
" if reach_out:\n",
" response = client.send_message(\n",
" agent_name=\"outreach_agent\", \n",
" role=\"user\", \n",
" message=message\n",
" ) \n",
" else: \n",
" print(f\"Candidate {candidate_name} is rejected: {justification}\")\n",
"\n",
"# TODO: add an archival andidate tool (provide justification) \n",
"\n",
"read_resume_tool = client.tools.upsert_from_function(func=read_resume) \n",
"submit_evaluation_tool = client.tools.upsert_from_function(func=submit_evaluation)"
]
},
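{
"cell_type": "markdown",
"id": "b7e1f2a3-0000-4000-8000-000000000001",
"metadata": {},
"source": [
"As a quick aside (a sketch, not part of the original workflow), we can reproduce the path that `read_resume` builds from a candidate name; if `data/resumes/` does not exist, the tool call will fail with a file-not-found error, as seen later:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7e1f2a3-0000-4000-8000-000000000002",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# Same path construction as inside read_resume\n",
"name = \"Tony Stark\"\n",
"print(os.path.join(\"data\", \"resumes\", name.lower().replace(\" \", \"_\") + \".txt\"))"
]
},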
{
"cell_type": "code",
"execution_count": 9,
"id": "12482994-03f4-4dda-8ea2-6492ec28f392",
"metadata": {},
"outputs": [],
"source": [
"skills = \"Front-end (React, Typescript), software engineering \" \\\n",
"+ \"(ideally Python), and experience with LLMs.\"\n",
"eval_persona = f\"You are responsible to finding good recruiting \" \\\n",
"+ \"candidates, for the company description. \" \\\n",
"+ f\"Ideal canddiates have skills: {skills}. \" \\\n",
"+ \"Submit your candidate evaluation with the submit_evaluation tool. \"\n",
"\n",
"eval_agent = client.agents.create(\n",
" name=\"eval_agent\", \n",
" memory_blocks=[\n",
" CreateBlock(\n",
" label=\"persona\",\n",
" value=eval_persona,\n",
" ),\n",
" ],\n",
" block_ids=[org_block.id],\n",
" tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n",
" model=\"openai/gpt-4\",\n",
" embedding=\"openai/text-embedding-ada-002\",\n",
")\n"
]
},
{
"cell_type": "markdown",
"id": "37c2d0be-b980-426f-ab24-1feaa8ed90ef",
"metadata": {},
"source": [
"#### Outreach agent \n",
"This agent will email candidates with customized emails. Since sending emails is a bit complicated, we'll just pretend we sent an email by printing it in the tool call. "
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "24e8942f-5b0e-4490-ac5f-f9e1f3178627",
"metadata": {},
"outputs": [],
"source": [
"def email_candidate(self, content: str): \n",
" \"\"\"\n",
" Send an email\n",
"\n",
" Args: \n",
" content (str): Content of the email \n",
" \"\"\"\n",
" print(\"Pretend to email:\", content)\n",
" return\n",
"\n",
"email_candidate_tool = client.tools.upsert_from_function(func=email_candidate)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "87416e00-c7a0-4420-be71-e2f5a6404428",
"metadata": {},
"outputs": [],
"source": [
"outreach_persona = \"You are responsible for sending outbound emails \" \\\n",
"+ \"on behalf of a company with the send_emails tool to \" \\\n",
"+ \"potential candidates. \" \\\n",
"+ \"If possible, make sure to personalize the email by appealing \" \\\n",
"+ \"to the recipient with details about the company. \" \\\n",
"+ \"You position is `Head Recruiter`, and you go by the name Bob, with contact info bob@gmail.com. \" \\\n",
"+ \"\"\"\n",
"Follow this email template: \n",
"\n",
"Hi <candidate name>, \n",
"\n",
"<content> \n",
"\n",
"Best, \n",
"<your name> \n",
"<company name> \n",
"\"\"\"\n",
" \n",
"outreach_agent = client.agents.create(\n",
" name=\"outreach_agent\", \n",
" memory_blocks=[\n",
" CreateBlock(\n",
" label=\"persona\",\n",
" value=outreach_persona,\n",
" ),\n",
" ],\n",
" block_ids=[org_block.id],\n",
" tool_ids=[email_candidate_tool.id]\n",
" model=\"openai/gpt-4\",\n",
" embedding=\"openai/text-embedding-ada-002\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f69d38da-807e-4bb1-8adb-f715b24f1c34",
"metadata": {},
"source": [
"Next, we'll send a message from the user telling the `leadgen_agent` to evaluate a given candidate: "
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "f09ab5bd-e158-42ee-9cce-43f254c4d2b0",
"metadata": {},
"outputs": [],
"source": [
"response = client.agents.messages.send(\n",
" agent_id=eval_agent.id,\n",
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" content=\"Candidate: Tony Stark\",\n",
" )\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "cd8f1a1e-21eb-47ae-9eed-b1d3668752ff",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <style>\n",
" .message-container, .usage-container {\n",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n",
" max-width: 800px;\n",
" margin: 20px auto;\n",
" background-color: #1e1e1e;\n",
" border-radius: 8px;\n",
" overflow: hidden;\n",
" color: #d4d4d4;\n",
" }\n",
" .message, .usage-stats {\n",
" padding: 10px 15px;\n",
" border-bottom: 1px solid #3a3a3a;\n",
" }\n",
" .message:last-child, .usage-stats:last-child {\n",
" border-bottom: none;\n",
" }\n",
" .title {\n",
" font-weight: bold;\n",
" margin-bottom: 5px;\n",
" color: #ffffff;\n",
" text-transform: uppercase;\n",
" font-size: 0.9em;\n",
" }\n",
" .content {\n",
" background-color: #2d2d2d;\n",
" border-radius: 4px;\n",
" padding: 5px 10px;\n",
" font-family: 'Consolas', 'Courier New', monospace;\n",
" white-space: pre-wrap;\n",
" }\n",
" .json-key, .function-name, .json-boolean { color: #9cdcfe; }\n",
" .json-string { color: #ce9178; }\n",
" .json-number { color: #b5cea8; }\n",
" .internal-monologue { font-style: italic; }\n",
" </style>\n",
" <div class=\"message-container\">\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">INTERNAL MONOLOGUE</div>\n",
" <div class=\"content\"><span class=\"internal-monologue\">Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.</span></div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION CALL</div>\n",
" <div class=\"content\"><span class=\"function-name\">read_resume</span>({<br>&nbsp;&nbsp;<span class=\"json-key\">\"name\"</span>: <span class=\"json-key\">\"Tony Stark\",<br>&nbsp;&nbsp;\"request_heartbeat\"</span>: <span class=\"json-boolean\">true</span><br>})</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION RETURN</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"status\"</span>: <span class=\"json-key\">\"Failed\",<br>&nbsp;&nbsp;\"message\"</span>: <span class=\"json-key\">\"Error calling function read_resume: [Errno 2] No such file or directory: 'data/resumes/tony_stark.txt'\",<br>&nbsp;&nbsp;\"time\"</span>: <span class=\"json-string\">\"2024-11-13 05:51:26 PM PST-0800\"</span><br>}</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">INTERNAL MONOLOGUE</div>\n",
" <div class=\"content\"><span class=\"internal-monologue\">I couldn&#x27;t retrieve Tony&#x27;s resume. Need to handle this carefully to keep the conversation flowing.</span></div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION CALL</div>\n",
" <div class=\"content\"><span class=\"function-name\">send_message</span>({<br>&nbsp;&nbsp;<span class=\"json-key\">\"message\"</span>: <span class=\"json-string\">\"It looks like I'm having trouble accessing Tony Stark's resume at the moment. Can you provide more details about his qualifications?\"</span><br>})</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION RETURN</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"status\"</span>: <span class=\"json-key\">\"OK\",<br>&nbsp;&nbsp;\"message\"</span>: <span class=\"json-key\">\"None\",<br>&nbsp;&nbsp;\"time\"</span>: <span class=\"json-string\">\"2024-11-13 05:51:28 PM PST-0800\"</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" <div class=\"usage-container\">\n",
" <div class=\"usage-stats\">\n",
" <div class=\"title\">USAGE STATISTICS</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"completion_tokens\"</span>: <span class=\"json-number\">103</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"prompt_tokens\"</span>: <span class=\"json-number\">4999</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"total_tokens\"</span>: <span class=\"json-number\">5102</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"step_count\"</span>: <span class=\"json-number\">2</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
"LettaResponse(messages=[InternalMonologue(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.'), FunctionCallMessage(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='read_resume', arguments='{\\n \"name\": \"Tony Stark\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt')), FunctionReturn(id='message-97a2b57e-40c6-4f06-a307-a0e3a00717ce', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 803505, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"Failed\",\\n \"message\": \"Error calling function read_resume: [Errno 2] No such file or directory: \\'data/resumes/tony_stark.txt\\'\",\\n \"time\": \"2024-11-13 05:51:26 PM PST-0800\"\\n}', status='error', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt'), InternalMonologue(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue=\"I couldn't retrieve Tony's resume. Need to handle this carefully to keep the conversation flowing.\"), FunctionCallMessage(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"It looks like I\\'m having trouble accessing Tony Stark\\'s resume at the moment. Can you provide more details about his qualifications?\"\\n}', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')), FunctionReturn(id='message-5600e8e7-6c6f-482a-8594-a0483ef523a2', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 361921, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:28 PM PST-0800\"\\n}', status='success', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')], usage=LettaUsageStatistics(completion_tokens=103, prompt_tokens=4999, total_tokens=5102, step_count=2))"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "markdown",
"id": "67069247-e603-439c-b2df-9176c4eba957",
"metadata": {},
"source": [
"#### Providing feedback to agents \n",
"Since MemGPT agents are persisted, we can provide feedback to agents that is used in future agent executions if we want to modify the future behavior. "
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "19c57d54-a1fe-4244-b765-b996ba9a4788",
"metadata": {},
"outputs": [],
"source": [
"feedback = \"Our company pivoted to foundation model training\"\n",
"response = client.agents.messages.send(\n",
" agent_id=eval_agent.id,\n",
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" content=feedback,\n",
" )\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "036b973f-209a-4ad9-90e7-fc827b5d92c7",
"metadata": {},
"outputs": [],
"source": [
"\n",
"feedback = \"The company is also renamed to FoundationAI\"\n",
"response = client.agents.messages.send(\n",
" agent_id=eval_agent.id,\n",
" messages=[\n",
" MessageCreate(\n",
" role=\"user\",\n",
" content=feedback,\n",
" )\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "5d7a7633-35a3-4e41-b44a-be71067dd32a",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <style>\n",
" .message-container, .usage-container {\n",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n",
" max-width: 800px;\n",
" margin: 20px auto;\n",
" background-color: #1e1e1e;\n",
" border-radius: 8px;\n",
" overflow: hidden;\n",
" color: #d4d4d4;\n",
" }\n",
" .message, .usage-stats {\n",
" padding: 10px 15px;\n",
" border-bottom: 1px solid #3a3a3a;\n",
" }\n",
" .message:last-child, .usage-stats:last-child {\n",
" border-bottom: none;\n",
" }\n",
" .title {\n",
" font-weight: bold;\n",
" margin-bottom: 5px;\n",
" color: #ffffff;\n",
" text-transform: uppercase;\n",
" font-size: 0.9em;\n",
" }\n",
" .content {\n",
" background-color: #2d2d2d;\n",
" border-radius: 4px;\n",
" padding: 5px 10px;\n",
" font-family: 'Consolas', 'Courier New', monospace;\n",
" white-space: pre-wrap;\n",
" }\n",
" .json-key, .function-name, .json-boolean { color: #9cdcfe; }\n",
" .json-string { color: #ce9178; }\n",
" .json-number { color: #b5cea8; }\n",
" .internal-monologue { font-style: italic; }\n",
" </style>\n",
" <div class=\"message-container\">\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">INTERNAL MONOLOGUE</div>\n",
" <div class=\"content\"><span class=\"internal-monologue\">Updating the company name to reflect the rebranding. This is important for future candidate evaluations.</span></div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION CALL</div>\n",
" <div class=\"content\"><span class=\"function-name\">core_memory_replace</span>({<br>&nbsp;&nbsp;<span class=\"json-key\">\"label\"</span>: <span class=\"json-key\">\"company\",<br>&nbsp;&nbsp;\"old_content\"</span>: <span class=\"json-key\">\"The company has pivoted to foundation model training.\",<br>&nbsp;&nbsp;\"new_content\"</span>: <span class=\"json-key\">\"The company is called FoundationAI and has pivoted to foundation model training.\",<br>&nbsp;&nbsp;\"request_heartbeat\"</span>: <span class=\"json-boolean\">true</span><br>})</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION RETURN</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"status\"</span>: <span class=\"json-key\">\"OK\",<br>&nbsp;&nbsp;\"message\"</span>: <span class=\"json-key\">\"None\",<br>&nbsp;&nbsp;\"time\"</span>: <span class=\"json-string\">\"2024-11-13 05:51:34 PM PST-0800\"</span><br>}</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">INTERNAL MONOLOGUE</div>\n",
" <div class=\"content\"><span class=\"internal-monologue\">Now I have the updated company info, time to check in on Tony.</span></div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION CALL</div>\n",
" <div class=\"content\"><span class=\"function-name\">send_message</span>({<br>&nbsp;&nbsp;<span class=\"json-key\">\"message\"</span>: <span class=\"json-string\">\"Got it, the new name is FoundationAI! What about Tony Stark's background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"</span><br>})</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION RETURN</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"status\"</span>: <span class=\"json-key\">\"OK\",<br>&nbsp;&nbsp;\"message\"</span>: <span class=\"json-key\">\"None\",<br>&nbsp;&nbsp;\"time\"</span>: <span class=\"json-string\">\"2024-11-13 05:51:35 PM PST-0800\"</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" <div class=\"usage-container\">\n",
" <div class=\"usage-stats\">\n",
" <div class=\"title\">USAGE STATISTICS</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"completion_tokens\"</span>: <span class=\"json-number\">146</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"prompt_tokens\"</span>: <span class=\"json-number\">6372</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"total_tokens\"</span>: <span class=\"json-number\">6518</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"step_count\"</span>: <span class=\"json-number\">2</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
"LettaResponse(messages=[InternalMonologue(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Updating the company name to reflect the rebranding. This is important for future candidate evaluations.'), FunctionCallMessage(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='core_memory_replace', arguments='{\\n \"label\": \"company\",\\n \"old_content\": \"The company has pivoted to foundation model training.\",\\n \"new_content\": \"The company is called FoundationAI and has pivoted to foundation model training.\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_5s0KTElXdipPidchUu3R9CxI')), FunctionReturn(id='message-a2f278e8-ec23-4e22-a124-c21a0f46f733', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 182291, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:34 PM PST-0800\"\\n}', status='success', function_call_id='call_5s0KTElXdipPidchUu3R9CxI'), InternalMonologue(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Now I have the updated company info, time to check in on Tony.'), FunctionCallMessage(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Got it, the new name is FoundationAI! What about Tony Stark\\'s background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"\\n}', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')), FunctionReturn(id='message-813a9306-38fc-4665-9f3b-7c3671fd90e6', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 842423, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:35 PM PST-0800\"\\n}', status='success', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')], usage=LettaUsageStatistics(completion_tokens=146, prompt_tokens=6372, total_tokens=6518, step_count=2))"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "d04d4b3a-6df1-41a9-9a8e-037fbb45836d",
"metadata": {},
"outputs": [],
"source": [
"response = client.agents.messages.send(\n",
" agent_id=eval_agent.id,\n",
" messages=[\n",
" MessageCreate(\n",
" role=\"system\",\n",
" content=\"Candidate: Spongebob Squarepants\",\n",
" )\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "c60465f4-7977-4f70-9a75-d2ddebabb0fa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.agents.core_memory.get_block(agent_id=eval_agent.id, block_label=\"company\")"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "a51c6bb3-225d-47a4-88f1-9a26ff838dd3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.agents.core_memory.get_block(agent_id=outreach_agent.id, block_label=\"company\")"
]
},
{
"cell_type": "markdown",
"id": "8d181b1e-72da-4ebe-a872-293e3ce3a225",
"metadata": {},
"source": [
"## Section 3: Adding an orchestrator agent \n",
"So far, we've been triggering the `eval_agent` manually. We can also create an additional agent that is responsible for orchestrating tasks. "
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "80b23d46-ed4b-4457-810a-a819d724e146",
"metadata": {},
"outputs": [],
"source": [
"#re-create agents \n",
"client.agents.delete(eval_agent.id)\n",
"client.agents.delete(outreach_agent.id)\n",
"\n",
"org_block = client.blocks.create(\n",
" label=\"company\",\n",
" value=org_description,\n",
")\n",
"\n",
"eval_agent = client.agents.create(\n",
" name=\"eval_agent\", \n",
" memory_blocks=[\n",
" CreateBlock(\n",
" label=\"persona\",\n",
" value=eval_persona,\n",
" ),\n",
" ],\n",
" block_ids=[org_block.id],\n",
" tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n",
" model=\"openai/gpt-4\",\n",
" embedding=\"openai/text-embedding-ada-002\",\n",
")\n",
"\n",
"outreach_agent = client.agents.create(\n",
" name=\"outreach_agent\", \n",
" memory_blocks=[\n",
" CreateBlock(\n",
" label=\"persona\",\n",
" value=outreach_persona,\n",
" ),\n",
" ],\n",
" block_ids=[org_block.id],\n",
" tool_ids=[email_candidate_tool.id]\n",
" model=\"openai/gpt-4\",\n",
" embedding=\"openai/text-embedding-ada-002\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "a751d0f1-b52d-493c-bca1-67f88011bded",
"metadata": {},
"source": [
"The `recruiter_agent` will be linked to the same `org_block` that we created before - we can look up the current data in `org_block` by looking up its ID: "
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bf6bd419-1504-4513-bc68-d4c717ea8e2d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id='user-00000000-0000-4000-8000-000000000000', id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.blocks.retrieve(block_id=org_block.id)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e2730626-1685-46aa-9b44-a59e1099e973",
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"def search_candidates_db(self, page: int) -> Optional[str]: \n",
" \"\"\"\n",
" Returns 1 candidates per page. \n",
" Page 0 returns the first 1 candidate, \n",
" Page 1 returns the next 1, etc.\n",
" Returns `None` if no candidates remain. \n",
"\n",
" Args: \n",
" page (int): The page number to return candidates from \n",
"\n",
" Returns: \n",
" candidate_names (List[str]): Names of the candidates\n",
" \"\"\"\n",
" \n",
" names = [\"Tony Stark\", \"Spongebob Squarepants\", \"Gautam Fang\"]\n",
" if page >= len(names): \n",
" return None\n",
" return names[page]\n",
"\n",
"def consider_candidate(self, name: str): \n",
" \"\"\"\n",
" Submit a candidate for consideration. \n",
"\n",
" Args: \n",
" name (str): Candidate name to consider \n",
" \"\"\"\n",
" from letta_client import Letta, MessageCreate\n",
" client = Letta(base_url=\"http://localhost:8283\")\n",
" message = f\"Consider candidate {name}\" \n",
" print(\"Sending message to eval agent: \", message)\n",
" response = client.send_message(\n",
" agent_id=eval_agent.id,\n",
" role=\"user\", \n",
" message=message\n",
" ) \n",
"\n",
"\n",
"# create tools \n",
"search_candidate_tool = client.tools.upsert_from_function(func=search_candidates_db)\n",
"consider_candidate_tool = client.tools.upsert_from_function(func=consider_candidate)\n",
"\n",
"# create recruiter agent\n",
"recruiter_agent = client.agents.create(\n",
" name=\"recruiter_agent\", \n",
" memory_blocks=[\n",
" CreateBlock(\n",
" label=\"persona\",\n",
" value=\"You run a recruiting process for a company. \" \\\n",
" + \"Your job is to continue to pull candidates from the \" \n",
" + \"`search_candidates_db` tool until there are no more \" \\\n",
" + \"candidates left. \" \\\n",
" + \"For each candidate, consider the candidate by calling \"\n",
" + \"the `consider_candidate` tool. \" \\\n",
" + \"You should continue to call `search_candidates_db` \" \\\n",
" + \"followed by `consider_candidate` until there are no more \" \\\n",
" \" candidates. \",\n",
" ),\n",
" ],\n",
" block_ids=[org_block.id],\n",
" tool_ids=[search_candidate_tool.id, consider_candidate_tool.id],\n",
" model=\"openai/gpt-4\",\n",
" embedding=\"openai/text-embedding-ada-002\"\n",
")\n",
" \n"
]
},
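{
"cell_type": "markdown",
"id": "c8f2a3b4-0000-4000-8000-000000000003",
"metadata": {},
"source": [
"As a sanity check (a sketch calling the plain Python function directly, outside of Letta), the pagination contract of `search_candidates_db` is one name per page, then `None`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8f2a3b4-0000-4000-8000-000000000004",
"metadata": {},
"outputs": [],
"source": [
"page = 0\n",
"while True:\n",
"    candidate = search_candidates_db(None, page)  # self is unused, so pass None\n",
"    if candidate is None:\n",
"        break\n",
"    print(page, candidate)\n",
"    page += 1"
]
},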
{
"cell_type": "code",
"execution_count": 26,
"id": "ecfd790c-0018-4fd9-bdaf-5a6b81f70adf",
"metadata": {},
"outputs": [],
"source": [
"response = client.agents.messages.send(\n",
" agent_id=recruiter_agent.id,\n",
" messages=[\n",
" MessageCreate(\n",
" role=\"system\",\n",
" content=\"Run generation\",\n",
" )\n",
" ],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "8065c179-cf90-4287-a6e5-8c009807b436",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"\n",
" <style>\n",
" .message-container, .usage-container {\n",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n",
" max-width: 800px;\n",
" margin: 20px auto;\n",
" background-color: #1e1e1e;\n",
" border-radius: 8px;\n",
" overflow: hidden;\n",
" color: #d4d4d4;\n",
" }\n",
" .message, .usage-stats {\n",
" padding: 10px 15px;\n",
" border-bottom: 1px solid #3a3a3a;\n",
" }\n",
" .message:last-child, .usage-stats:last-child {\n",
" border-bottom: none;\n",
" }\n",
" .title {\n",
" font-weight: bold;\n",
" margin-bottom: 5px;\n",
" color: #ffffff;\n",
" text-transform: uppercase;\n",
" font-size: 0.9em;\n",
" }\n",
" .content {\n",
" background-color: #2d2d2d;\n",
" border-radius: 4px;\n",
" padding: 5px 10px;\n",
" font-family: 'Consolas', 'Courier New', monospace;\n",
" white-space: pre-wrap;\n",
" }\n",
" .json-key, .function-name, .json-boolean { color: #9cdcfe; }\n",
" .json-string { color: #ce9178; }\n",
" .json-number { color: #b5cea8; }\n",
" .internal-monologue { font-style: italic; }\n",
" </style>\n",
" <div class=\"message-container\">\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">INTERNAL MONOLOGUE</div>\n",
" <div class=\"content\"><span class=\"internal-monologue\">New user logged in. Excited to get started!</span></div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION CALL</div>\n",
" <div class=\"content\"><span class=\"function-name\">send_message</span>({<br>&nbsp;&nbsp;<span class=\"json-key\">\"message\"</span>: <span class=\"json-string\">\"Welcome! I'm thrilled to have you here. Lets dive into what you need today!\"</span><br>})</div>\n",
" </div>\n",
" \n",
" <div class=\"message\">\n",
" <div class=\"title\">FUNCTION RETURN</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"status\"</span>: <span class=\"json-key\">\"OK\",<br>&nbsp;&nbsp;\"message\"</span>: <span class=\"json-key\">\"None\",<br>&nbsp;&nbsp;\"time\"</span>: <span class=\"json-string\">\"2024-11-13 05:52:14 PM PST-0800\"</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" <div class=\"usage-container\">\n",
" <div class=\"usage-stats\">\n",
" <div class=\"title\">USAGE STATISTICS</div>\n",
" <div class=\"content\">{<br>&nbsp;&nbsp;<span class=\"json-key\">\"completion_tokens\"</span>: <span class=\"json-number\">48</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"prompt_tokens\"</span>: <span class=\"json-number\">2398</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"total_tokens\"</span>: <span class=\"json-number\">2446</span>,<br>&nbsp;&nbsp;<span class=\"json-key\">\"step_count\"</span>: <span class=\"json-number\">1</span><br>}</div>\n",
" </div>\n",
" </div>\n",
" "
],
"text/plain": [
"LettaResponse(messages=[InternalMonologue(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='New user logged in. Excited to get started!'), FunctionCallMessage(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Welcome! I\\'m thrilled to have you here. Lets dive into what you need today!\"\\n}', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')), FunctionReturn(id='message-26c3b7a3-51c8-47ae-938d-a3ed26e42357', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 781455, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:52:14 PM PST-0800\"\\n}', status='success', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')], usage=LettaUsageStatistics(completion_tokens=48, prompt_tokens=2398, total_tokens=2446, step_count=1))"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "4639bbca-e0c5-46a9-a509-56d35d26e97f",
"metadata": {},
"outputs": [],
"source": [
"client.agents.delete(agent_id=eval_agent.id)\n",
"client.agents.delete(agent_id=outreach_agent.id)\n",
"client.agents.delete(agent_id=recruiter_agent.id)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "letta",
"language": "python",
"name": "letta"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -1,72 +0,0 @@
import typer
from swarm import Swarm
from letta import EmbeddingConfig, LLMConfig
"""
This is an example of how to implement the basic example provided by OpenAI for transferring a conversation between two agents:
https://github.com/openai/swarm/tree/main?tab=readme-ov-file#usage
Before running this example, make sure you have letta>=0.5.0 installed. This example runs on OpenAI by default, though you can change the model by modifying the code:
```bash
export OPENAI_API_KEY=...
pip install letta
```
Then, inside the `examples/swarm` directory, run:
```bash
python simple.py
```
You should see a message output from Agent B.
"""
def transfer_agent_b(self):
"""
Transfer conversation to agent B.
Returns:
str: name of agent to transfer to
"""
return "agentb"
def transfer_agent_a(self):
"""
Transfer conversation to agent A.
Returns:
str: name of agent to transfer to
"""
return "agenta"
swarm = Swarm()
# set client configs
swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4"))
# create tools
transfer_a = swarm.client.create_or_update_tool(transfer_agent_a)
transfer_b = swarm.client.create_or_update_tool(transfer_agent_b)
# create agents
if swarm.client.get_agent_id("agentb"):
swarm.client.delete_agent(swarm.client.get_agent_id("agentb"))
if swarm.client.get_agent_id("agenta"):
swarm.client.delete_agent(swarm.client.get_agent_id("agenta"))
agent_a = swarm.create_agent(name="agentb", tools=[transfer_a.name], instructions="Only speak in haikus")
agent_b = swarm.create_agent(name="agenta", tools=[transfer_b.name])
response = swarm.run(agent_name="agenta", message="Transfer me to agent b by calling the transfer_agent_b tool")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)
response = swarm.run(agent_name="agenta", message="My name is actually Sarah. Transfer me to agent b to write a haiku about my name")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)
response = swarm.run(agent_name="agenta", message="Transfer me to agent b - I want a haiku with my name in it")
print("Response:")
typer.secho(f"{response}", fg=typer.colors.GREEN)

View File

@@ -1,111 +0,0 @@
import json
from typing import List, Optional
import typer
from letta import AgentState, EmbeddingConfig, LLMConfig, create_client
from letta.schemas.agent import AgentType
from letta.schemas.memory import BasicBlockMemory, Block
class Swarm:
def __init__(self):
self.agents = []
self.client = create_client()
# shared memory block (shared section of context window across agents)
self.shared_memory = Block(label="human", value="")
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
include_base_tools: Optional[bool] = True,
# instructions
instructions: str = "",
) -> AgentState:
# todo: process tools for agent handoff
persona_value = (
f"You are agent with name {name}. You instructions are {instructions}"
if len(instructions) > 0
else f"You are agent with name {name}"
)
persona_block = Block(label="persona", value=persona_value)
memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory])
agent = self.client.create_agent(
name=name,
agent_type=agent_type,
embedding_config=embedding_config,
llm_config=llm_config,
system=system,
tools=tools,
include_base_tools=include_base_tools,
memory=memory,
)
self.agents.append(agent)
return agent
def reset(self):
# delete all agents
for agent in self.agents:
self.client.delete_agent(agent.id)
for block in self.client.list_blocks():
self.client.delete_block(block.id)
def run(self, agent_name: str, message: str):
history = []
while True:
# send message to agent
agent_id = self.client.get_agent_id(agent_name)
print("Messaging agent: ", agent_name)
print("History size: ", len(history))
# print(self.client.get_agent(agent_id).tools)
# TODO: implement with sending multiple messages
if len(history) == 0:
response = self.client.send_message(agent_id=agent_id, message=message, role="user")
else:
response = self.client.send_messages(agent_id=agent_id, messages=history)
# update history
history += response.messages
# grab responses
messages = []
for message in response.messages:
messages += message.to_letta_messages()
# get new agent (see tool call)
# print(messages)
if len(messages) < 2:
continue
function_call = messages[-2]
function_return = messages[-1]
if function_call.function_call.name == "send_message":
# return message to user
arg_data = json.loads(function_call.function_call.arguments)
# print(arg_data)
return arg_data["message"]
else:
# swap the agent
return_data = json.loads(function_return.function_return)
agent_name = return_data["message"]
typer.secho(f"Transferring to agent: {agent_name}", fg=typer.colors.RED)
# print("Transferring to agent", agent_name)
print()

View File

@@ -1,129 +0,0 @@
import os
import uuid
from letta import create_client
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent
from tests.helpers.utils import cleanup
from tests.test_model_letta_performance import llm_config_dir
"""
This example shows how you can constrain tool calls in your agent.
Please note that this currently only works reliably for models with Structured Outputs (e.g. gpt-4o).
Start by installing the dependencies.
```
poetry install --all-extras
```
"""
# Tools for this example
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "agent_tool_graph"))
config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json")
"""Contrived tools for this test case"""
def first_secret_word():
"""
Call this to retrieve the first secret word, which you will need for the second_secret_word function.
"""
return "v0iq020i0g"
def second_secret_word(prev_secret_word: str):
"""
Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.
Args:
prev_secret_word (str): The secret word retrieved from calling first_secret_word.
"""
if prev_secret_word != "v0iq020i0g":
raise RuntimeError(f"Expected secret {"v0iq020i0g"}, got {prev_secret_word}")
return "4rwp2b4gxq"
def third_secret_word(prev_secret_word: str):
"""
Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.
Args:
prev_secret_word (str): The secret word retrieved from calling second_secret_word.
"""
if prev_secret_word != "4rwp2b4gxq":
raise RuntimeError(f"Expected secret {"4rwp2b4gxq"}, got {prev_secret_word}")
return "hj2hwibbqm"
def fourth_secret_word(prev_secret_word: str):
"""
Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.
Args:
prev_secret_word (str): The secret word retrieved from calling third_secret_word.
"""
if prev_secret_word != "hj2hwibbqm":
raise RuntimeError(f"Expected secret {"hj2hwibbqm"}, got {prev_secret_word}")
return "banana"
def auto_error():
"""
If you call this function, it will throw an error automatically.
"""
raise RuntimeError("This should never be called.")
def main():
# 1. Set up the client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
# 2. Add all the tools to the client
functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error]
tools = []
for func in functions:
tool = client.create_or_update_tool(func)
tools.append(tool)
tool_names = [t.name for t in tools[:-1]]
# 3. Create the tool rules. The tools must be called in this order, or an error will be thrown.
tool_rules = [
InitToolRule(tool_name="first_secret_word"),
ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]),
ChildToolRule(tool_name="second_secret_word", children=["third_secret_word"]),
ChildToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]),
ChildToolRule(tool_name="fourth_secret_word", children=["send_message"]),
TerminalToolRule(tool_name="send_message"),
]
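# The rules above pin the agent to a single call path:
# first_secret_word -> second_secret_word -> third_secret_word -> fourth_secret_word -> send_message (terminal).
# Note that auto_error has no rule and is unreachable from this chain, so it should never be called.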
# 4. Create the agent
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# 5. Ask for the final secret word
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
# 6. Here, we thoroughly check the correctness of the response
tool_names += ["send_message"] # Add send message because we expect this to be called at the end
for m in response.messages:
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]
# Check final send message contains "banana"
assert_invoked_send_message_with_keyword(response.messages, "banana")
print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)
if __name__ == "__main__":
main()

View File

@@ -1,7 +1,7 @@
__version__ = "0.7.21"
__version__ = "0.7.22"
# import clients
from letta.client.client import LocalClient, RESTClient, create_client
from letta.client.client import RESTClient
# imports for easier access
from letta.schemas.agent import AgentState

View File

@@ -1,3 +0,0 @@
from .main import app
app()

View File

@@ -100,8 +100,10 @@ class BaseAgent(ABC):
# [DB Call] size of messages and archival memories
# todo: blocking for now
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
if num_messages is None:
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if num_archival_memories is None:
num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
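
Beyond making the size lookups awaitable, switching from `x or query()` to an explicit `if x is None` check also stops a legitimate count of `0` from triggering a needless re-query, since `0` is falsy. A generic illustration of the difference (plain Python, not Letta code):

```python
def expensive_count() -> int:
    print("queried the database")
    return 42

n = 0  # caller already knows the size is zero
eager = n or expensive_count()                        # falsy 0 re-queries -> 42
precise = n if n is not None else expensive_count()   # keeps the supplied 0
print(eager, precise)  # 42 0
```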

View File

@@ -174,20 +174,13 @@ class LettaAgent(BaseAgent):
for message in letta_messages:
yield f"data: {message.model_dump_json()}\n\n"
# update usage
# TODO: add run_id
usage.step_count += 1
usage.completion_tokens += response.usage.completion_tokens
usage.prompt_tokens += response.usage.prompt_tokens
usage.total_tokens += response.usage.total_tokens
if not should_continue:
break
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# Return back usage
yield f"data: {usage.model_dump_json()}\n\n"
@@ -285,7 +278,7 @@ class LettaAgent(BaseAgent):
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
return current_in_context_messages, new_in_context_messages, usage
@@ -437,7 +430,7 @@ class LettaAgent(BaseAgent):
# Extend the in context message ids
if not agent_state.message_buffer_autoclear:
message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)]
self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor)
# TODO: This may be out of sync if users add files in between steps
# NOTE (cliandy): temporary for now for particular use cases.
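
Each chunk yielded above follows the SSE wire format (`data: <json>\n\n`). A minimal client-side consumer (a sketch using `requests` against a hypothetical streaming route; the actual path is defined by the Letta server) might look like:

```python
import json
import requests

# Hypothetical URL and payload, shown for illustration only.
url = "http://localhost:8283/v1/agents/AGENT_ID/messages/stream"
payload = {"messages": [{"role": "user", "content": "hi"}]}
with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            print(json.loads(line[len(b"data: "):]))
```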

View File

@@ -233,7 +233,7 @@ class LettaAgentBatch(BaseAgent):
ctx = await self._collect_resume_context(llm_batch_id)
log_event(name="update_statuses")
self._update_request_statuses(ctx.request_status_updates)
await self._update_request_statuses_async(ctx.request_status_updates)
log_event(name="exec_tools")
exec_results = await self._execute_tools(ctx)
@@ -242,7 +242,7 @@ class LettaAgentBatch(BaseAgent):
msg_map = await self._persist_tool_messages(exec_results, ctx)
log_event(name="mark_steps_done")
self._mark_steps_complete(llm_batch_id, ctx.agent_ids)
await self._mark_steps_complete_async(llm_batch_id, ctx.agent_ids)
log_event(name="prepare_next")
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
@@ -382,9 +382,9 @@ class LettaAgentBatch(BaseAgent):
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None:
async def _update_request_statuses_async(self, updates: List[RequestStatusUpdateInfo]) -> None:
if updates:
self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates)
await self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(updates=updates)
def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]:
sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL
@@ -474,11 +474,11 @@ class LettaAgentBatch(BaseAgent):
await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor)
return msg_map
def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None:
async def _mark_steps_complete_async(self, llm_batch_id: str, agent_ids: List[str]) -> None:
updates = [
StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids
]
self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates)
await self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(updates)
def _prepare_next_iteration(
self,

View File

@@ -1,98 +0,0 @@
# type: ignore
import time
import uuid
from typing import Annotated, Union
import typer
from letta import LocalClient, RESTClient, create_client
from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES
from letta.config import LettaConfig
# from letta.agent import Agent
from letta.errors import LLMJSONParsingError
from letta.utils import get_human_text, get_persona_text
app = typer.Typer()
def send_message(
client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES
):
try:
progress_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}"  # renamed: previously shadowed the print_msg flag
print(progress_msg, end="\r", flush=True)
response = client.user_message(agent_id=agent_id, message=message)
if turn + 1 == n_tries:
print(" " * len(progress_msg), end="\r", flush=True)
for r in response:
if "function_call" in r and fn_type in r["function_call"] and any("assistant_message" in re for re in response):
return True, r["function_call"]
return False, "No function called."
except LLMJSONParsingError as e:
print(f"Error in parsing Letta JSON: {e}")
return False, "Failed to decode valid Letta JSON from LLM output."
except Exception as e:
print(f"An unexpected error occurred: {e}")
return False, "An unexpected error occurred."
@app.command()
def bench(
print_messages: Annotated[bool, typer.Option("--messages", help="Print functions calls and messages from the agent.")] = False,
n_tries: Annotated[int, typer.Option("--n-tries", help="Number of benchmark tries to perform for each function.")] = TRIES,
):
client = create_client()
print(f"\nDepending on your hardware, this may take up to 30 minutes. This will also create {n_tries * len(PROMPTS)} new agents.\n")
config = LettaConfig.load()
print(f"version = {config.letta_version}")
total_score, total_tokens_accumulated, elapsed_time = 0, 0, 0
for fn_type, message in PROMPTS.items():
score = 0
start_time_run = time.time()
bench_id = uuid.uuid4()
for i in range(n_tries):
agent = client.create_agent(
name=f"benchmark_{bench_id}_agent_{i}",
persona=get_persona_text(PERSONA),
human=get_human_text(HUMAN),
)
agent_id = agent.id
result, msg = send_message(
client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries
)
if print_messages:
print(f"\t{msg}")
if result:
score += 1
# TODO: add back once we start tracking usage via the client
# total_tokens_accumulated += tokens_accumulated
elapsed_time_run = round(time.time() - start_time_run, 2)
print(f"Score for {fn_type}: {score}/{n_tries}, took {elapsed_time_run} seconds")
elapsed_time += elapsed_time_run
total_score += score
print(f"\nMEMGPT VERSION: {config.letta_version}")
print(f"CONTEXT WINDOW: {config.default_llm_config.context_window}")
print(f"MODEL WRAPPER: {config.default_llm_config.model_wrapper}")
print(f"PRESET: {config.preset}")
print(f"PERSONA: {config.persona}")
print(f"HUMAN: {config.human}")
print(
# f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds at average of {round(total_tokens_accumulated/elapsed_time, 2)} t/s\n"
f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds\n"
)
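# A minimal sketch, assuming only Typer's documented test runner: the bench
# command above can also be driven programmatically (option names taken from
# the definitions above; `app` is the Typer app defined in this file).
from typer.testing import CliRunner

runner = CliRunner()
result = runner.invoke(app, ["--n-tries", "2", "--messages"])
print(result.stdout)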

View File

@@ -1,14 +0,0 @@
# Basic
TRIES = 3
AGENT_NAME = "benchmark"
PERSONA = "sam_pov"
HUMAN = "cs_phd"
# Prompts
PROMPTS = {
"core_memory_replace": "Hey there, my name is John, what is yours?",
"core_memory_append": "I want you to remember that I like soccers for later.",
"conversation_search": "Do you remember when I talked about bananas?",
"archival_memory_insert": "Can you make sure to remember that I like programming for me so you can look it up later?",
"archival_memory_search": "Can you retrieve information about the war?",
}

View File

@@ -1,37 +1,15 @@
import logging
import sys
from enum import Enum
from typing import Annotated, Optional
import questionary
import typer
import letta.utils as utils
from letta import create_client
from letta.agent import Agent, save_agent
from letta.config import LettaConfig
from letta.constants import CLI_WARNING_PREFIX, CORE_MEMORY_BLOCK_CHAR_LIMIT, LETTA_DIR, MIN_CONTEXT_WINDOW
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger
from letta.schemas.enums import OptionState
from letta.schemas.memory import ChatMemory, Memory
# from letta.interface import CLIInterface as interface # for printing to terminal
from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
from letta.utils import open_folder_in_explorer, printd
logger = get_logger(__name__)
def open_folder():
"""Open a folder viewer of the Letta home directory"""
try:
print(f"Opening home folder: {LETTA_DIR}")
open_folder_in_explorer(LETTA_DIR)
except Exception as e:
print(f"Failed to open folder with system viewer, error:\n{e}")
class ServerChoice(Enum):
rest_api = "rest"
ws_api = "websocket"
@@ -51,14 +29,6 @@ def server(
if type == ServerChoice.rest_api:
pass
# if LettaConfig.exists():
# config = LettaConfig.load()
# MetadataStore(config)
# _ = create_client() # triggers user creation
# else:
# typer.secho(f"No configuration exists. Run letta configure before starting the server.", fg=typer.colors.RED)
# sys.exit(1)
try:
from letta.server.rest_api.app import start_server
@@ -73,292 +43,6 @@ def server(
raise NotImplementedError("WS support deprecated")
def run(
persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None,
agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None,
human: Annotated[Optional[str], typer.Option(help="Specify human")] = None,
system: Annotated[Optional[str], typer.Option(help="Specify system prompt (raw text)")] = None,
system_file: Annotated[Optional[str], typer.Option(help="Specify raw text file containing system prompt")] = None,
# model flags
model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None,
model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None,
model_endpoint: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint")] = None,
model_endpoint_type: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint type")] = None,
context_window: Annotated[
Optional[int], typer.Option(help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)")
] = None,
core_memory_limit: Annotated[
Optional[int], typer.Option(help="The character limit to each core-memory section (human/persona).")
] = CORE_MEMORY_BLOCK_CHAR_LIMIT,
# other
first: Annotated[bool, typer.Option(help="Use --first to send the first message in the sequence")] = False,
strip_ui: Annotated[bool, typer.Option(help="Remove all the bells and whistles in CLI output (helpful for testing)")] = False,
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False,
yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
# streaming
stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
# whether or not to put the inner thoughts inside the function args
no_content: Annotated[
OptionState, typer.Option(help="Set to 'yes' for LLM APIs that omit the `content` field during tool calling")
] = OptionState.DEFAULT,
):
"""Start chatting with an Letta agent
Example usage: `letta run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo`
:param persona: Specify persona
:param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name)
:param human: Specify human
:param model: Specify the LLM model
"""
# setup logger
# TODO: remove Utils Debug after global logging is complete.
utils.DEBUG = debug
# TODO: add logging command line options for runtime log level
from letta.server.server import logger as server_logger
if debug:
logger.setLevel(logging.DEBUG)
server_logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.CRITICAL)
server_logger.setLevel(logging.CRITICAL)
# load config file
config = LettaConfig.load()
# read user id from config
client = create_client()
# determine agent to use, if not provided
if not yes and not agent:
agents = client.list_agents()
agents = [a.name for a in agents]
if len(agents) > 0:
print()
select_agent = questionary.confirm("Would you like to select an existing agent?").ask()
if select_agent is None:
raise KeyboardInterrupt
if select_agent:
agent = questionary.select("Select agent:", choices=agents).ask()
# create agent config
if agent:
agent_id = client.get_agent_id(agent)
agent_state = client.get_agent(agent_id)
else:
agent_state = None
human = human if human else config.human
persona = persona if persona else config.persona
if agent and agent_state: # use existing agent
typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN)
printd("Loading agent state:", agent_state.id)
printd("Agent state:", agent_state.name)
# printd("State path:", agent_config.save_state_dir())
# printd("Persistent manager path:", agent_config.save_persistence_manager_dir())
# printd("Index path:", agent_config.save_agent_index_dir())
# TODO: load prior agent state
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
if model and model != agent_state.llm_config.model:
typer.secho(
f"{CLI_WARNING_PREFIX}Overriding existing model {agent_state.llm_config.model} with {model}", fg=typer.colors.YELLOW
)
agent_state.llm_config.model = model
if context_window is not None and int(context_window) != agent_state.llm_config.context_window:
typer.secho(
f"{CLI_WARNING_PREFIX}Overriding existing context window {agent_state.llm_config.context_window} with {context_window}",
fg=typer.colors.YELLOW,
)
agent_state.llm_config.context_window = context_window
if model_wrapper and model_wrapper != agent_state.llm_config.model_wrapper:
typer.secho(
f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {agent_state.llm_config.model_wrapper} with {model_wrapper}",
fg=typer.colors.YELLOW,
)
agent_state.llm_config.model_wrapper = model_wrapper
if model_endpoint and model_endpoint != agent_state.llm_config.model_endpoint:
typer.secho(
f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {agent_state.llm_config.model_endpoint} with {model_endpoint}",
fg=typer.colors.YELLOW,
)
agent_state.llm_config.model_endpoint = model_endpoint
if model_endpoint_type and model_endpoint_type != agent_state.llm_config.model_endpoint_type:
typer.secho(
f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type {agent_state.llm_config.model_endpoint_type} with {model_endpoint_type}",
fg=typer.colors.YELLOW,
)
agent_state.llm_config.model_endpoint_type = model_endpoint_type
# NOTE: commented out because this seems dangerous - instead users should use /systemswap when in the CLI
# # user specified a new system prompt
# if system:
# # NOTE: agent_state.system is the ORIGINAL system prompt,
# # whereas agent_state.state["system"] is the LATEST system prompt
# existing_system_prompt = agent_state.state["system"] if "system" in agent_state.state else None
# if existing_system_prompt != system:
# # override
# agent_state.state["system"] = system
# Update the agent with any overrides
agent_state = client.update_agent(
agent_id=agent_state.id,
name=agent_state.name,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
)
# create agent
letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user)
else: # create new agent
# create new agent config: override defaults with args if provided
typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE)
agent_name = agent if agent else utils.create_random_username()
# create agent
client = create_client()
# choose from list of llm_configs
llm_configs = client.list_llm_configs()
llm_options = [llm_config.model for llm_config in llm_configs]
llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs]
# select model
if len(llm_options) == 0:
raise ValueError("No LLM models found. Please enable a provider.")
elif len(llm_options) == 1:
llm_model_name = llm_options[0]
else:
llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model
llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]
# option to override context window
if llm_config.context_window is not None:
context_window_validator = lambda x: x.isdigit() and int(x) > MIN_CONTEXT_WINDOW and int(x) <= llm_config.context_window
context_window_input = questionary.text(
"Select LLM context window limit (hit enter for default):",
default=str(llm_config.context_window),
validate=context_window_validator,
).ask()
if context_window_input is not None:
llm_config.context_window = int(context_window_input)
else:
sys.exit(1)
# choose from list of embedding configs
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
embedding_choices = [
questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs
]
# select model
if len(embedding_options) == 0:
raise ValueError("No embedding models found. Please enable a provider.")
elif len(embedding_options) == 1:
embedding_model_name = embedding_options[0]
else:
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
embedding_config = [
embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
][0]
human_obj = client.get_human(client.get_human_id(name=human))
persona_obj = client.get_persona(client.get_persona_id(name=persona))
if human_obj is None:
typer.secho(f"Couldn't find human {human} in database, please run `letta add human`", fg=typer.colors.RED)
sys.exit(1)
if persona_obj is None:
typer.secho(f"Couldn't find persona {persona} in database, please run `letta add persona`", fg=typer.colors.RED)
sys.exit(1)
if system_file:
try:
with open(system_file, "r", encoding="utf-8") as file:
system = file.read().strip()
printd("Loaded system file successfully.")
except FileNotFoundError:
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
system_prompt = system if system else None
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
metadata = {"human": human_obj.template_name, "persona": persona_obj.template_name}
typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.template_name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.template_name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(
name=agent_name,
system=system_prompt,
embedding_config=embedding_config,
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE)
letta_agent = Agent(
interface=interface(),
agent_state=client.get_agent(agent_state.id),
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
user=client.user,
)
save_agent(agent=letta_agent)
typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN)
# start event loop
from letta.main import run_agent_loop
print() # extra space
run_agent_loop(
letta_agent=letta_agent,
config=config,
first=first,
no_verify=no_verify,
stream=stream,
) # TODO: add back no_verify
def delete_agent(
agent_name: Annotated[str, typer.Option(help="Specify agent to delete")],
):
"""Delete an agent from the database"""
# use default client ID if no user_id provided
config = LettaConfig.load()
MetadataStore(config)
client = create_client()
agent = client.get_agent_by_name(agent_name)
if not agent:
typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED)
sys.exit(1)
confirm = questionary.confirm(f"Are you sure you want to delete agent '{agent_name}' (id={agent.id})?", default=False).ask()
if confirm is None:
raise KeyboardInterrupt
if not confirm:
typer.secho(f"Cancelled agent deletion '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN)
return
try:
# delete the agent
client.delete_agent(agent.id)
typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN)
except Exception:
typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED)
sys.exit(1)
def version() -> str:
import letta

View File

@@ -1,227 +0,0 @@
import ast
import os
from enum import Enum
from typing import Annotated, List, Optional
import questionary
import typer
from prettytable.colortable import ColorTable, Themes
from tqdm import tqdm
import letta.helpers.datetime_helpers
app = typer.Typer()
@app.command()
def configure():
"""Updates default Letta configurations
This function and quickstart should be the ONLY place where LettaConfig.save() is called
"""
print("`letta configure` has been deprecated. Please see documentation on configuration, and run `letta run` instead.")
class ListChoice(str, Enum):
agents = "agents"
humans = "humans"
personas = "personas"
sources = "sources"
@app.command()
def list(arg: Annotated[ListChoice, typer.Argument]):
from letta.client.client import create_client
client = create_client()
table = ColorTable(theme=Themes.OCEAN)
if arg == ListChoice.agents:
"""List all agents"""
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
for agent in tqdm(client.list_agents()):
# TODO: add this function
sources = client.list_attached_sources(agent_id=agent.id)
source_names = [source.name for source in sources if source is not None]
table.add_row(
[
agent.name,
agent.llm_config.model,
agent.embedding_config.embedding_model,
agent.embedding_config.embedding_dim,
agent.memory.get_block("persona").value[:100] + "...",
agent.memory.get_block("human").value[:100] + "...",
",".join(source_names),
letta.helpers.datetime_helpers.format_datetime(agent.created_at),
]
)
print(table)
elif arg == ListChoice.humans:
"""List all humans"""
table.field_names = ["Name", "Text"]
for human in client.list_humans():
table.add_row([human.template_name, human.value.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.personas:
"""List all personas"""
table.field_names = ["Name", "Text"]
for persona in client.list_personas():
table.add_row([persona.template_name, persona.value.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""
# create table
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"]
# TODO: eventually look across all storage connections
# TODO: add data source stats
# TODO: connect to agents
# get all sources
for source in client.list_sources():
# get attached agents
table.add_row(
[
source.name,
source.description,
source.embedding_config.embedding_model,
source.embedding_config.embedding_dim,
letta.helpers.datetime_helpers.format_datetime(source.created_at),
]
)
print(table)
else:
raise ValueError(f"Unknown argument {arg}")
return table
@app.command()
def add_tool(
filename: str = typer.Option(..., help="Path to the Python file containing the function"),
name: Optional[str] = typer.Option(None, help="Name of the tool"),
update: bool = typer.Option(True, help="Update the tool if it already exists"),
tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"),
):
"""Add or update a tool from a Python file."""
from letta.client.client import create_client
client = create_client()
# 1. Parse the Python file
with open(filename, "r", encoding="utf-8") as file:
source_code = file.read()
# 2. Parse the source code to extract the function
# Note: here we assume there is only one function in the file.
module = ast.parse(source_code)
func_def = None
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_def = node
break
if not func_def:
raise ValueError("No function found in the provided file")
# 3. Compile the function to make it callable
# Explanation courtesy of GPT-4:
# Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object
# ast.Module creates a module node containing the function definition (func_def)
# compile converts the AST into a code object that can be executed by the Python interpreter
# The exec function executes the compiled code object in the current context,
# effectively defining the function within the current namespace
exec(compile(ast.Module([func_def], []), filename, "exec"))
# Retrieve the function object by evaluating its name in the current namespace
# eval looks up the function name in the current scope and returns the function object
func = eval(func_def.name)
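# A minimal sketch, standalone, of the compile/exec/eval steps explained above;
# the source string and `greet` are made up for illustration.
import ast as _ast

_source = "def greet(name: str) -> str:\n    return f'hi {name}'"
_module = _ast.parse(_source)
_func_def = next(n for n in _module.body if isinstance(n, _ast.FunctionDef))
_ns: dict = {}
# compile the single FunctionDef into a module code object and define it in _ns
exec(compile(_ast.Module(body=[_func_def], type_ignores=[]), "<tool>", "exec"), _ns)
assert _ns[_func_def.name]("letta") == "hi letta"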
# 4. Add or update the tool
tool = client.create_or_update_tool(func=func, tags=tags, update=update)
print(f"Tool {tool.name} added successfully")
@app.command()
def list_tools():
"""List all available tools."""
from letta.client.client import create_client
client = create_client()
tools = client.list_tools()
for tool in tools:
print(f"Tool: {tool.name}")
@app.command()
def add(
option: str, # [human, persona]
name: Annotated[str, typer.Option(help="Name of human/persona")],
text: Annotated[Optional[str], typer.Option(help="Text of human/persona")] = None,
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
):
"""Add a person/human"""
from letta.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
if filename: # read from file
assert text is None, "Cannot specify both text and filename"
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
else:
assert text is not None, "Must specify either text or filename"
if option == "persona":
persona_id = client.get_persona_id(name)
if persona_id:
client.get_persona(persona_id)
# confirm if user wants to overwrite
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
return
client.update_persona(persona_id, text=text)
else:
client.create_persona(name=name, text=text)
elif option == "human":
human_id = client.get_human_id(name)
if human_id:
human = client.get_human(human_id)
# confirm if user wants to overwrite
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
return
client.update_human(human_id, text=text)
else:
human = client.create_human(name=name, text=text)
else:
raise ValueError(f"Unknown kind {option}")
@app.command()
def delete(option: str, name: str):
"""Delete a source from the archival memory."""
from letta.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY"))
try:
# delete from metadata
if option == "source":
# delete metadata
source_id = client.get_source_id(name)
assert source_id is not None, f"Source {name} does not exist"
client.delete_source(source_id)
elif option == "agent":
agent_id = client.get_agent_id(name)
assert agent_id is not None, f"Agent {name} does not exist"
client.delete_agent(agent_id=agent_id)
elif option == "human":
human_id = client.get_human_id(name)
assert human_id is not None, f"Human {name} does not exist"
client.delete_human(human_id)
elif option == "persona":
persona_id = client.get_persona_id(name)
assert persona_id is not None, f"Persona {name} does not exist"
client.delete_persona(persona_id)
else:
raise ValueError(f"Option {option} not implemented")
typer.secho(f"Deleted {option} '{name}'", fg=typer.colors.GREEN)
except Exception as e:
typer.secho(f"Failed to delete {option}'{name}'\n{e}", fg=typer.colors.RED)

View File

@@ -8,61 +8,9 @@ letta load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
"""
import uuid
from typing import Annotated, List, Optional
import questionary
import typer
from letta import create_client
from letta.data_sources.connectors import DirectoryConnector
app = typer.Typer()
default_extensions = "txt,md,pdf"
@app.command("directory")
def load_directory(
name: Annotated[str, typer.Option(help="Name of dataset to load.")],
input_dir: Annotated[Optional[str], typer.Option(help="Path to directory containing dataset.")] = None,
input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [],
recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
):
client = create_client()
# create connector
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
# choose from list of embedding configs
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
embedding_choices = [
questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs
]
# select model
if len(embedding_options) == 0:
raise ValueError("No embedding models found. Please enable a provider.")
elif len(embedding_options) == 1:
embedding_model_name = embedding_options[0]
else:
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
embedding_config = [
embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
][0]
# create source
source = client.create_source(name=name, embedding_config=embedding_config)
# load data
try:
client.load_data(connector, source_name=name)
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
client.delete_source(source.id)
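# A minimal sketch of using the loader above from Python rather than the CLI;
# "./docs" and "my-docs" are placeholder values, and `client`/`embedding_config`
# are assumed to be set up as in load_directory above.
connector = DirectoryConnector(input_files=[], input_directory="./docs", recursive=True, extensions="txt,md,pdf")
source = client.create_source(name="my-docs", embedding_config=embedding_config)
client.load_data(connector, source_name="my-docs")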

File diff suppressed because it is too large

View File

@@ -37,7 +37,9 @@ class DataConnector:
"""
def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User"):
async def load_data(
connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User"
):
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
embedding_config = source.embedding_config
@@ -51,7 +53,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage
file_count = 0
for file_metadata in connector.find_files(source):
file_count += 1
source_manager.create_file(file_metadata, actor)
await source_manager.create_file(file_metadata, actor)
# generate passages
for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):

View File

@@ -1,5 +1,7 @@
import ast
import builtins
import json
import typing
from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError
@@ -22,7 +24,7 @@ def resolve_type(annotation: str):
Resolve a type annotation string into a Python type.
Args:
annotation (str): The annotation string (e.g., 'int', 'list', etc.).
annotation (str): The annotation string (e.g., 'int', 'list[int]', 'dict[str, int]').
Returns:
type: The corresponding Python type.
@@ -34,24 +36,17 @@ def resolve_type(annotation: str):
return BUILTIN_TYPES[annotation]
try:
if annotation.startswith("list["):
inner_type = annotation[len("list[") : -1]
resolve_type(inner_type)
return list
elif annotation.startswith("dict["):
inner_types = annotation[len("dict[") : -1]
key_type, value_type = inner_types.split(",")
return dict
elif annotation.startswith("tuple["):
inner_types = annotation[len("tuple[") : -1]
[resolve_type(t.strip()) for t in inner_types.split(",")]
return tuple
parsed = ast.literal_eval(annotation)
if isinstance(parsed, type):
return parsed
raise ValueError(f"Annotation '{annotation}' is not a recognized type.")
except (ValueError, SyntaxError):
# Allow use of typing and builtins in a safe eval context
namespace = {
**vars(typing),
**vars(builtins),
"list": list,
"dict": dict,
"tuple": tuple,
"set": set,
}
return eval(annotation, namespace)
except Exception:
raise ValueError(f"Unsupported annotation: {annotation}")
@@ -82,41 +77,36 @@ def get_function_annotations_from_source(source_code: str, function_name: str) -
def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict:
"""
Coerce arguments in a dictionary to their annotated types.
Args:
function_args (dict): The original function arguments.
annotations (Dict[str, str]): Argument annotations as strings.
Returns:
dict: The updated dictionary with coerced argument types.
Raises:
ValueError: If type coercion fails for an argument.
"""
coerced_args = dict(function_args) # Shallow copy for mutation safety
coerced_args = dict(function_args) # Shallow copy
for arg_name, value in coerced_args.items():
if arg_name in annotations:
annotation_str = annotations[arg_name]
try:
# Resolve the type from the annotation
arg_type = resolve_type(annotation_str)
# Handle JSON-like inputs for dict and list types
if arg_type in {dict, list} and isinstance(value, str):
# Always parse strings using literal_eval or json if possible
if isinstance(value, str):
try:
# First, try JSON parsing
value = json.loads(value)
except json.JSONDecodeError:
# Fall back to literal_eval for Python-specific literals
value = ast.literal_eval(value)
try:
value = ast.literal_eval(value)
except (SyntaxError, ValueError) as e:
if arg_type is not str:
raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}")
# Coerce the value to the resolved type
coerced_args[arg_name] = arg_type(value)
except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e:
origin = typing.get_origin(arg_type)
if origin in (list, dict, tuple, set):
# Let the origin (e.g., list) handle coercion
coerced_args[arg_name] = origin(value)
else:
# Coerce simple types (e.g., int, float)
coerced_args[arg_name] = arg_type(value)
except Exception as e:
raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}")
return coerced_args
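# A minimal sketch of how the coercion path above behaves for JSON-ish string
# arguments (values are illustrative):
args = {"tags": '["a", "b"]', "count": "3"}
annotations = {"tags": "list[str]", "count": "int"}
assert coerce_dict_args_by_annotations(args, annotations) == {"tags": ["a", "b"], "count": 3}
# '["a", "b"]' parses via json.loads; typing.get_origin(list[str]) is list, so list([...]) applies.
# "3" parses to 3, and int(3) is then a no-op.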

View File

@@ -19,6 +19,8 @@ from letta.services.group_manager import GroupManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.step_manager import NoopStepManager, StepManager
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
class SleeptimeMultiAgentV2(BaseAgent):
@@ -32,6 +34,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
group_manager: GroupManager,
job_manager: JobManager,
actor: User,
step_manager: StepManager = NoopStepManager(),
telemetry_manager: TelemetryManager = NoopTelemetryManager(),
group: Optional[Group] = None,
):
super().__init__(
@@ -45,11 +49,18 @@ class SleeptimeMultiAgentV2(BaseAgent):
self.passage_manager = passage_manager
self.group_manager = group_manager
self.job_manager = job_manager
self.step_manager = step_manager
self.telemetry_manager = telemetry_manager
# Group settings
assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}"
self.group = group
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
async def step(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
) -> LettaResponse:
run_ids = []
# Prepare new messages
@@ -68,22 +79,26 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform foreground agent step
response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps)
response = await foreground_agent.step(
input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message
)
# Get last response messages
last_response_messages = foreground_agent.response_messages
# Update turns counter
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
# Perform participant steps
if self.group.sleeptime_agent_frequency is None or (
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
):
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
)
for participant_agent_id in self.group.agent_ids:
@@ -92,6 +107,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
participant_agent_id,
last_response_messages,
last_processed_message_id,
use_assistant_message,
)
run_ids.append(run_id)
@@ -103,7 +119,13 @@ class SleeptimeMultiAgentV2(BaseAgent):
response.usage.run_ids = run_ids
return response
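# A minimal sketch of the turn-gating rule used above, isolated; frequency is
# assumed positive when set, matching the bump guard in this class.
from typing import Optional

def should_run_sleeptime(frequency: Optional[int], turns_counter: Optional[int]) -> bool:
    # no frequency configured -> participants run on every turn
    if frequency is None:
        return True
    return turns_counter is not None and turns_counter % frequency == 0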
async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]:
async def step_stream(
self,
input_messages: List[MessageCreate],
max_steps: int = 10,
use_assistant_message: bool = True,
request_start_timestamp_ns: Optional[int] = None,
) -> AsyncGenerator[str, None]:
# Prepare new messages
new_messages = []
for message in input_messages:
@@ -120,9 +142,16 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform foreground agent step
async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps):
async for chunk in foreground_agent.step_stream(
input_messages=new_messages,
max_steps=max_steps,
use_assistant_message=use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
):
yield chunk
# Get response messages
@@ -130,20 +159,21 @@ class SleeptimeMultiAgentV2(BaseAgent):
# Update turns counter
if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0:
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor)
turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor)
# Perform participant steps
if self.group.sleeptime_agent_frequency is None or (
turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0
):
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async(
group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor
)
for sleeptime_agent_id in self.group.agent_ids:
self._issue_background_task(
run_id = await self._issue_background_task(
sleeptime_agent_id,
last_response_messages,
last_processed_message_id,
use_assistant_message,
)
async def _issue_background_task(
@ -151,6 +181,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
sleeptime_agent_id: str,
response_messages: List[Message],
last_processed_message_id: str,
use_assistant_message: bool = True,
) -> str:
run = Run(
user_id=self.actor.id,
@@ -160,7 +191,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
"agent_id": sleeptime_agent_id,
},
)
run = self.job_manager.create_job(pydantic_job=run, actor=self.actor)
run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
asyncio.create_task(
self._participant_agent_step(
@@ -169,6 +200,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
response_messages=response_messages,
last_processed_message_id=last_processed_message_id,
run_id=run.id,
use_assistant_message=use_assistant_message,
)
)
return run.id
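# A minimal sketch of the run-tracked fire-and-forget pattern above, reduced to
# its shape; manager signatures mirror the async APIs in this diff, and the
# JobStatus.failed name in the error branch is an assumption.
import asyncio

async def issue_tracked_task(job_manager, actor, run, work):
    run = await job_manager.create_job_async(pydantic_job=run, actor=actor)

    async def _wrapped():
        await job_manager.update_job_by_id_async(job_id=run.id, job_update=JobUpdate(status=JobStatus.running), actor=actor)
        try:
            await work()
            await job_manager.update_job_by_id_async(job_id=run.id, job_update=JobUpdate(status=JobStatus.completed), actor=actor)
        except Exception as e:
            await job_manager.update_job_by_id_async(
                job_id=run.id, job_update=JobUpdate(status=JobStatus.failed, metadata={"error": str(e)}), actor=actor
            )
            raise

    asyncio.create_task(_wrapped())  # detach the work, keep the run id for polling
    return run.id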
@@ -180,11 +212,12 @@ class SleeptimeMultiAgentV2(BaseAgent):
response_messages: List[Message],
last_processed_message_id: str,
run_id: str,
use_assistant_message: bool = True,
) -> str:
try:
# Update job status
job_update = JobUpdate(status=JobStatus.running)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
# Create conversation transcript
prior_messages = []
@@ -221,11 +254,14 @@ class SleeptimeMultiAgentV2(BaseAgent):
block_manager=self.block_manager,
passage_manager=self.passage_manager,
actor=self.actor,
step_manager=self.step_manager,
telemetry_manager=self.telemetry_manager,
)
# Perform sleeptime agent step
result = await sleeptime_agent.step(
input_messages=sleeptime_agent_messages,
use_assistant_message=use_assistant_message,
)
# Update job status
@@ -237,7 +273,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
"agent_id": sleeptime_agent_id,
},
)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
return result
except Exception as e:
job_update = JobUpdate(
@@ -245,5 +281,5 @@ class SleeptimeMultiAgentV2(BaseAgent):
completed_at=datetime.now(timezone.utc).replace(tzinfo=None),
metadata={"error": str(e)},
)
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor)
await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor)
raise

View File

@@ -106,7 +106,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob],
results: List[BatchPollingResult] = await asyncio.gather(*coros)
# Update the server with batch status changes
server.batch_manager.bulk_update_llm_batch_statuses(updates=results)
await server.batch_manager.bulk_update_llm_batch_statuses_async(updates=results)
logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")
return results
@@ -197,13 +197,13 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo
# 6. Bulk update all items for newly completed batch(es)
if item_updates:
metrics.updated_items_count = len(item_updates)
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates)
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(item_updates)
# ─── Kick off postprocessing for each batch that just completed ───
completed = [r for r in batch_results if r.request_status == JobStatus.completed]
async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse:
actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id)
actor: User = await server.user_manager.get_actor_by_id_async(batch_row.created_by_id)
runner = LettaAgentBatch(
message_manager=server.message_manager,
agent_manager=server.agent_manager,

View File

@@ -4,10 +4,11 @@ from typing import Optional
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.interval import IntervalTrigger
from sqlalchemy import text
from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.log import get_logger
from letta.server.db import db_context
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.settings import settings
@@ -34,18 +35,16 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
acquired_lock = False
try:
# Use a temporary connection context for the attempt initially
with db_context() as session:
engine = session.get_bind()
# Get raw connection - MUST be kept open if lock is acquired
raw_conn = engine.raw_connection()
cur = raw_conn.cursor()
async with db_registry.async_session() as session:
raw_conn = await session.connection()
cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
acquired_lock = cur.fetchone()[0]
# Try to acquire the advisory lock
sql = text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))")
result = await session.execute(sql, {"lock_key": ADVISORY_LOCK_KEY})
acquired_lock = result.scalar_one()
if not acquired_lock:
cur.close()
raw_conn.close()
await raw_conn.close()
logger.info("Scheduler lock held by another instance.")
return False
@@ -106,14 +105,14 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool:
# Clean up temporary resources if lock wasn't acquired or error occurred before storing
if cur:
try:
cur.close()
except:
pass
await cur.close()
except Exception as e:
logger.warning(f"Error closing cursor: {e}")
if raw_conn:
try:
raw_conn.close()
except:
pass
await raw_conn.close()
except Exception as e:
logger.warning(f"Error closing connection: {e}")
async def _background_lock_retry_loop(server: SyncServer):
@@ -161,7 +160,9 @@ async def _release_advisory_lock():
try:
if not lock_conn.closed:
if not lock_cur.closed:
lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,))
# Use SQLAlchemy text() for raw SQL
unlock_sql = text("SELECT pg_advisory_unlock(CAST(:lock_key AS bigint))")
lock_cur.execute(unlock_sql, {"lock_key": ADVISORY_LOCK_KEY})
lock_cur.fetchone() # Consume result
lock_conn.commit()
logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}")
@@ -175,12 +176,12 @@ async def _release_advisory_lock():
# Ensure resources are closed regardless of unlock success
try:
if lock_cur and not lock_cur.closed:
lock_cur.close()
await lock_cur.close()
except Exception as e:
logger.error(f"Error closing advisory lock cursor: {e}", exc_info=True)
try:
if lock_conn and not lock_conn.closed:
lock_conn.close()
await lock_conn.close()
logger.info("Closed database connection that held advisory lock.")
except Exception as e:
logger.error(f"Error closing advisory lock connection: {e}", exc_info=True)

View File

@@ -45,11 +45,13 @@ logger = get_logger(__name__)
class AnthropicClient(LLMClientBase):
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(llm_config, async_client=False)
response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
client = self._get_anthropic_client(llm_config, async_client=True)
response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -339,6 +341,7 @@ class AnthropicClient(LLMClientBase):
# TODO: Input messages doesn't get used here
# TODO: Clean up this interface
@trace_method
def convert_response_to_chat_completion(
self,
response_data: dict,

View File

@@ -17,6 +17,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.settings import model_settings, settings
from letta.tracing import trace_method
from letta.utils import get_tool_call_id
logger = get_logger(__name__)
@@ -32,6 +33,7 @@ class GoogleVertexClient(LLMClientBase):
http_options={"api_version": "v1"},
)
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
@@ -44,6 +46,7 @@ class GoogleVertexClient(LLMClientBase):
)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying request to llm and returns raw response.
@@ -189,6 +192,7 @@ class GoogleVertexClient(LLMClientBase):
return [{"functionDeclarations": function_list}]
@trace_method
def build_request_data(
self,
messages: List[PydanticMessage],
@@ -248,6 +252,7 @@ class GoogleVertexClient(LLMClientBase):
return request_data
@trace_method
def convert_response_to_chat_completion(
self,
response_data: dict,

View File

@@ -32,6 +32,7 @@ from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import model_settings
from letta.tracing import trace_method
logger = get_logger(__name__)
@@ -124,6 +125,7 @@ class OpenAIClient(LLMClientBase):
return kwargs
@trace_method
def build_request_data(
self,
messages: List[PydanticMessage],
@@ -213,6 +215,7 @@ class OpenAIClient(LLMClientBase):
return data.model_dump(exclude_unset=True)
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
@@ -222,6 +225,7 @@ class OpenAIClient(LLMClientBase):
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
@@ -230,6 +234,7 @@ class OpenAIClient(LLMClientBase):
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
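# A minimal sketch of what a trace_method-style decorator can look like; the
# real one lives in letta.tracing and presumably records trace spans rather
# than printing. This toy only shows the sync/async shape being decorated above.
import functools, inspect, time

def trace_method_sketch(fn):
    if inspect.iscoroutinefunction(fn):
        @functools.wraps(fn)
        async def async_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await fn(*args, **kwargs)
            finally:
                print(f"{fn.__qualname__}: {time.perf_counter() - start:.3f}s")
        return async_wrapper

    @functools.wraps(fn)
    def sync_wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return fn(*args, **kwargs)
        finally:
            print(f"{fn.__qualname__}: {time.perf_counter() - start:.3f}s")
    return sync_wrapper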
@trace_method
def convert_response_to_chat_completion(
self,
response_data: dict,

View File

@@ -1,374 +1,14 @@
import os
import sys
import traceback
import questionary
import requests
import typer
from rich.console import Console
import letta.agent as agent
import letta.errors as errors
import letta.system as system
# import benchmark
from letta import create_client
from letta.benchmark.benchmark import bench
from letta.cli.cli import delete_agent, open_folder, run, server, version
from letta.cli.cli_config import add, add_tool, configure, delete, list, list_tools
from letta.cli.cli import server
from letta.cli.cli_load import app as load_app
from letta.config import LettaConfig
from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE
# from letta.interface import CLIInterface as interface # for printing to terminal
from letta.streaming_interface import AgentRefreshStreamingInterface
# interface = interface()
# disable composio print on exit
os.environ["COMPOSIO_DISABLE_VERSION_CHECK"] = "true"
app = typer.Typer(pretty_exceptions_enable=False)
app.command(name="run")(run)
app.command(name="version")(version)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
app.command(name="add-tool")(add_tool)
app.command(name="list-tools")(list_tools)
app.command(name="delete")(delete)
app.command(name="server")(server)
app.command(name="folder")(open_folder)
# load data commands
app.add_typer(load_app, name="load")
# benchmark command
app.command(name="benchmark")(bench)
# delete agents
app.command(name="delete-agent")(delete_agent)
def clear_line(console, strip_ui=False):
if strip_ui:
return
if os.name == "nt": # for windows
console.print("\033[A\033[K", end="")
else: # for linux
sys.stdout.write("\033[2K\033[G")
sys.stdout.flush()
def run_agent_loop(
letta_agent: agent.Agent,
config: LettaConfig,
first: bool,
no_verify: bool = False,
strip_ui: bool = False,
stream: bool = False,
):
if isinstance(letta_agent.interface, AgentRefreshStreamingInterface):
# letta_agent.interface.toggle_streaming(on=stream)
if not stream:
letta_agent.interface = letta_agent.interface.nonstreaming_interface
if hasattr(letta_agent.interface, "console"):
console = letta_agent.interface.console
else:
console = Console()
counter = 0
user_input = None
skip_next_user_input = False
user_message = None
USER_GOES_FIRST = first
if not USER_GOES_FIRST:
console.input("[bold cyan]Hit enter to begin (will request first Letta message)[/bold cyan]\n")
clear_line(console, strip_ui=strip_ui)
print()
multiline_input = False
# create client
client = create_client()
# run loops
while True:
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
# Ask for user input
if not stream:
print()
user_input = questionary.text(
"Enter your message:",
multiline=multiline_input,
qmark=">",
).ask()
clear_line(console, strip_ui=strip_ui)
if not stream:
print()
# Gracefully exit on Ctrl-C/D
if user_input is None:
user_input = "/exit"
user_input = user_input.rstrip()
if user_input.startswith("!"):
print(f"Commands for CLI begin with '/' not '!'")
continue
if user_input == "":
# no empty messages allowed
print("Empty input received. Try again!")
continue
# Handle CLI commands
# Commands to not get passed as input to Letta
if user_input.startswith("/"):
# updated agent save functions
if user_input.lower() == "/exit":
# letta_agent.save()
agent.save_agent(letta_agent)
break
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
# letta_agent.save()
agent.save_agent(letta_agent)
continue
elif user_input.lower() == "/attach":
# TODO: check if agent already has it
# TODO: check to ensure source embedding dimensions/model match the agent's, and disallow attachment if not
# TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources
sources = client.list_sources()
if len(sources) == 0:
typer.secho(
'No sources available. You must load a source with "letta load ..." before running /attach.',
fg=typer.colors.RED,
bold=True,
)
continue
# determine what sources are valid to be attached to this agent
valid_options = []
invalid_options = []
for source in sources:
if source.embedding_config == letta_agent.agent_state.embedding_config:
valid_options.append(source.name)
else:
# print warning about invalid sources
typer.secho(
f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {letta_agent.agent_state.embedding_config.embedding_dim} and model {letta_agent.agent_state.embedding_config.embedding_model}",
fg=typer.colors.YELLOW,
)
invalid_options.append(source.name)
# prompt user for data source selection
data_source = questionary.select("Select data source", choices=valid_options).ask()
# attach new data
client.attach_source_to_agent(agent_id=letta_agent.agent_state.id, source_name=data_source)
continue
elif user_input.lower() == "/dump" or user_input.lower().startswith("/dump "):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
if amount == 0:
letta_agent.interface.print_messages(letta_agent._messages, dump=True)
else:
letta_agent.interface.print_messages(letta_agent._messages[-min(amount, len(letta_agent.messages)) :], dump=True)
continue
elif user_input.lower() == "/dumpraw":
letta_agent.interface.print_messages_raw(letta_agent._messages)
continue
elif user_input.lower() == "/memory":
print(f"\nDumping memory contents:\n")
print(f"{letta_agent.agent_state.memory.compile()}")
print(f"{letta_agent.archival_memory.compile()}")
continue
elif user_input.lower() == "/model":
print(f"Current model: {letta_agent.agent_state.llm_config.model}")
continue
elif user_input.lower() == "/summarize":
try:
letta_agent.summarize_messages_inplace()
typer.secho(
f"/summarize succeeded",
fg=typer.colors.GREEN,
bold=True,
)
except (errors.LLMError, requests.exceptions.HTTPError) as e:
typer.secho(
f"/summarize failed:\n{e}",
fg=typer.colors.RED,
bold=True,
)
continue
elif user_input.lower() == "/tokens":
tokens = letta_agent.count_tokens()
typer.secho(
f"{tokens}/{letta_agent.agent_state.llm_config.context_window}",
fg=typer.colors.GREEN,
bold=True,
)
continue
elif user_input.lower().startswith("/add_function"):
try:
if len(user_input) < len("/add_function "):
print("Missing function name after the command")
continue
function_name = user_input[len("/add_function ") :].strip()
result = letta_agent.add_function(function_name)
typer.secho(
f"/add_function succeeded: {result}",
fg=typer.colors.GREEN,
bold=True,
)
except ValueError as e:
typer.secho(
f"/add_function failed:\n{e}",
fg=typer.colors.RED,
bold=True,
)
continue
elif user_input.lower().startswith("/remove_function"):
try:
if len(user_input) < len("/remove_function "):
print("Missing function name after the command")
continue
function_name = user_input[len("/remove_function ") :].strip()
result = letta_agent.remove_function(function_name)
typer.secho(
f"/remove_function succeeded: {result}",
fg=typer.colors.GREEN,
bold=True,
)
except ValueError as e:
typer.secho(
f"/remove_function failed:\n{e}",
fg=typer.colors.RED,
bold=True,
)
continue
# No skip options
elif user_input.lower() == "/wipe":
letta_agent = agent.Agent(letta_agent.interface)
user_message = None
elif user_input.lower() == "/heartbeat":
user_message = system.get_heartbeat()
elif user_input.lower() == "/memorywarning":
user_message = system.get_token_limit_warning()
elif user_input.lower() == "//":
multiline_input = not multiline_input
continue
elif user_input.lower() == "/" or user_input.lower() == "/help":
questionary.print("CLI commands", "bold")
for cmd, desc in USER_COMMANDS:
questionary.print(cmd, "bold")
questionary.print(f" {desc}")
continue
else:
print(f"Unrecognized command: {user_input}")
continue
else:
# If message did not begin with command prefix, pass inputs to Letta
# Handle user message and append to messages
user_message = str(user_input)
skip_next_user_input = False
def process_agent_step(user_message, no_verify):
# TODO(charles): update to use agent.step() instead of inner_step()
if user_message is None:
step_response = letta_agent.inner_step(
messages=[],
first_message=False,
skip_verify=no_verify,
stream=stream,
)
else:
step_response = letta_agent.step_user_message(
user_message_str=user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
)
new_messages = step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
step_response.usage
agent.save_agent(letta_agent)
skip_next_user_input = False
if token_warning:
user_message = system.get_token_limit_warning()
skip_next_user_input = True
elif function_failed:
user_message = system.get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE)
skip_next_user_input = True
elif heartbeat_request:
user_message = system.get_heartbeat(REQ_HEARTBEAT_MESSAGE)
skip_next_user_input = True
return new_messages, user_message, skip_next_user_input
while True:
try:
if strip_ui:
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
else:
if stream:
# Don't display the "Thinking..." if streaming
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
else:
with console.status("[bold cyan]Thinking...") as status:
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
except KeyboardInterrupt:
print("User interrupt occurred.")
retry = questionary.confirm("Retry agent.step()?").ask()
if not retry:
break
except Exception:
print("An exception occurred when running agent.step(): ")
traceback.print_exc()
retry = questionary.confirm("Retry agent.step()?").ask()
if not retry:
break
counter += 1
print("Finished.")
USER_COMMANDS = [
("//", "toggle multiline input mode"),
("/exit", "exit the CLI"),
("/save", "save a checkpoint of the current agent/conversation state"),
("/load", "load a saved checkpoint"),
("/dump <count>", "view the last <count> messages (all if <count> is omitted)"),
("/memory", "print the current contents of agent memory"),
("/pop <count>", "undo <count> messages in the conversation (default is 3)"),
("/retry", "pops the last answer and tries to get another one"),
("/rethink <text>", "changes the inner thoughts of the last agent message"),
("/rewrite <text>", "changes the reply of the last agent message"),
("/heartbeat", "send a heartbeat system message to the agent"),
("/memorywarning", "send a memory warning system message to the agent"),
("/attach", "attach data source to agent"),
]

View File

@@ -13,6 +13,9 @@ from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.settings import settings
from letta.tracing import trace_method
logger = get_logger(__name__)
@@ -202,6 +205,7 @@ class DatabaseRegistry:
self.initialize_async()
return self._async_session_factories.get(name)
@trace_method
@contextmanager
def session(self, name: str = "default") -> Generator[Any, None, None]:
"""Context manager for database sessions."""
@ -215,6 +219,7 @@ class DatabaseRegistry:
finally:
session.close()
@trace_method
@asynccontextmanager
async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]:
"""Async context manager for database sessions."""


@ -13,10 +13,11 @@ from starlette.responses import Response, StreamingResponse
from letta.agents.letta_agent import LettaAgent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate
from letta.schemas.group import Group
from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
@ -212,7 +213,7 @@ async def retrieve_agent_context_window(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor)
return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor)
except Exception as e:
traceback.print_exc()
raise e
@ -297,7 +298,7 @@ def detach_tool(
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
def attach_source(
async def attach_source(
agent_id: str,
source_id: str,
background_tasks: BackgroundTasks,
@ -310,7 +311,7 @@ def attach_source(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
if agent.enable_sleeptime:
source = server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id_async(source_id=source_id)
background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor)
return agent
@ -355,7 +356,7 @@ async def retrieve_agent(
@router.delete("/{agent_id}", response_model=None, operation_id="delete_agent")
def delete_agent(
async def delete_agent(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -363,9 +364,9 @@ def delete_agent(
"""
Delete an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
await server.agent_manager.delete_agent_async(agent_id=agent_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"})
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
@ -386,7 +387,7 @@ async def list_agent_sources(
# TODO: remove? can also get with agent blocks
@router.get("/{agent_id}/core-memory", response_model=Memory, operation_id="retrieve_agent_memory")
def retrieve_agent_memory(
async def retrieve_agent_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -395,13 +396,13 @@ def retrieve_agent_memory(
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.get_agent_memory(agent_id=agent_id, actor=actor)
return await server.get_agent_memory_async(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="retrieve_core_memory_block")
def retrieve_block(
async def retrieve_block(
agent_id: str,
block_label: str,
server: "SyncServer" = Depends(get_letta_server),
@ -410,10 +411,10 @@ def retrieve_block(
"""
Retrieve a core memory block from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
try:
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
return await server.agent_manager.get_block_with_label_async(agent_id=agent_id, block_label=block_label, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@ -453,13 +454,13 @@ async def modify_block(
)
# This should also trigger a system prompt change in the agent
server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
await server.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True, update_timestamp=False)
return block
@router.patch("/{agent_id}/core-memory/blocks/attach/{block_id}", response_model=AgentState, operation_id="attach_core_memory_block")
def attach_block(
async def attach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
@ -468,12 +469,12 @@ def attach_block(
"""
Attach a core memory block to an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.attach_block_async(agent_id=agent_id, block_id=block_id, actor=actor)
@router.patch("/{agent_id}/core-memory/blocks/detach/{block_id}", response_model=AgentState, operation_id="detach_core_memory_block")
def detach_block(
async def detach_block(
agent_id: str,
block_id: str,
server: "SyncServer" = Depends(get_letta_server),
@ -482,8 +483,8 @@ def detach_block(
"""
Detach a core memory block from an agent.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.agent_manager.detach_block_async(agent_id=agent_id, block_id=block_id, actor=actor)
@router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages")
@ -637,22 +638,35 @@ async def send_message(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
if agent_eligible and feature_enabled and model_compatible:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
if agent.enable_sleeptime:
experimental_agent = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
group=agent.multi_agent_group,
)
else:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message)
else:
@ -697,23 +711,38 @@ async def send_message_streaming(
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# TODO: This is redundant, remove soon
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"]
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
if agent_eligible and feature_enabled and model_compatible:
if agent.enable_sleeptime:
experimental_agent = SleeptimeMultiAgentV2(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
group_manager=server.group_manager,
job_manager=server.job_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
group=agent.multi_agent_group,
)
else:
experimental_agent = LettaAgent(
agent_id=agent_id,
message_manager=server.message_manager,
agent_manager=server.agent_manager,
block_manager=server.block_manager,
passage_manager=server.passage_manager,
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
if request.stream_tokens and model_compatible_token_streaming:
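The streaming endpoint is gated the same way as `send_message`: agent eligibility, the feature flag (the `use_experimental` setting or an `X-EXPERIMENTAL: true` header), and model-endpoint compatibility. A compact sketch of that gate, using plain dicts in place of the letta schemas:

```
# Sketch only: the gate that selects the experimental agent path above.
def use_experimental_path(agent: dict, headers: dict, use_experimental_setting: bool) -> bool:
    agent_eligible = agent["enable_sleeptime"] or not agent["multi_agent_group"]
    experimental_header = (headers.get("X-EXPERIMENTAL") or "false").lower()
    feature_enabled = use_experimental_setting or experimental_header == "true"
    model_compatible = agent["model_endpoint_type"] in {
        "anthropic", "openai", "together", "google_ai", "google_vertex",
    }
    return agent_eligible and feature_enabled and model_compatible

agent = {"enable_sleeptime": False, "multi_agent_group": None, "model_endpoint_type": "openai"}
print(use_experimental_path(agent, {"X-EXPERIMENTAL": "true"}, False))  # True
```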


@ -23,7 +23,7 @@ async def list_llm_models(
# Extract user_id from header, default to None if not present
):
"""List available LLM models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
models = await server.list_llm_models_async(
provider_category=provider_category,
@ -42,7 +42,7 @@ async def list_embedding_models(
# Extract user_id from header, default to None if not present
):
"""List available embedding models using the asynchronous implementation for improved performance"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
models = await server.list_embedding_models_async(actor=actor)
return models


@ -161,7 +161,7 @@ async def list_batch_messages(
# Get messages directly using our efficient method
# We'll need to update the underlying implementation to use message_id as cursor
messages = server.batch_manager.get_messages_for_letta_batch(
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor
)
@ -184,7 +184,7 @@ async def cancel_batch_run(
job = await server.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
# Get related llm batch jobs
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=job.id, actor=actor)
for llm_batch_job in llm_batch_jobs:
if llm_batch_job.status in {JobStatus.running, JobStatus.created}:
# TODO: Extend to providers beyond anthropic
@ -194,6 +194,8 @@ async def cancel_batch_run(
await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id)
# Update all the batch_job statuses
server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor)
await server.batch_manager.update_llm_batch_status_async(
llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor
)
except NoResultFound:
raise HTTPException(status_code=404, detail="Run not found")


@ -22,36 +22,36 @@ logger = get_logger(__name__)
@router.post("/", response_model=PydanticSandboxConfig)
def create_sandbox_config(
async def create_sandbox_config(
config_create: SandboxConfigCreate,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
return await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_create, actor)
@router.post("/e2b/default", response_model=PydanticSandboxConfig)
def create_default_e2b_sandbox_config(
async def create_default_e2b_sandbox_config(
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.E2B, actor=actor)
@router.post("/local/default", response_model=PydanticSandboxConfig)
def create_default_local_sandbox_config(
async def create_default_local_sandbox_config(
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.LOCAL, actor=actor)
@router.post("/local", response_model=PydanticSandboxConfig)
def create_custom_local_sandbox_config(
async def create_custom_local_sandbox_config(
local_sandbox_config: LocalSandboxConfig,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
@ -67,26 +67,26 @@ def create_custom_local_sandbox_config(
)
# Retrieve the user (actor)
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Wrap the LocalSandboxConfig into a SandboxConfigCreate
sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config)
# Use the manager to create or update the sandbox config
sandbox_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=actor)
sandbox_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=actor)
return sandbox_config
@router.patch("/{sandbox_config_id}", response_model=PydanticSandboxConfig)
def update_sandbox_config(
async def update_sandbox_config(
sandbox_config_id: str,
config_update: SandboxConfigUpdate,
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
return await server.sandbox_config_manager.update_sandbox_config_async(sandbox_config_id, config_update, actor)
@router.delete("/{sandbox_config_id}", status_code=204)
@ -112,7 +112,7 @@ async def list_sandbox_configs(
@router.post("/local/recreate-venv", response_model=PydanticSandboxConfig)
def force_recreate_local_sandbox_venv(
async def force_recreate_local_sandbox_venv(
server: SyncServer = Depends(get_letta_server),
actor_id: str = Depends(get_user_id),
):
@ -120,10 +120,10 @@ def force_recreate_local_sandbox_venv(
Forcefully recreate the virtual environment for the local sandbox.
Deletes and recreates the venv, then reinstalls required dependencies.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Retrieve the local sandbox config
sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
sbx_config = await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.LOCAL, actor=actor)
local_configs = sbx_config.get_local_config()
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde
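Every endpoint in this file follows the same conversion: the handler becomes `async def` and awaits the `_async` manager method, so the DB round-trip no longer blocks the event loop. A minimal sketch of the pattern with a stand-in manager (hypothetical names, not letta code):

```
# Sketch only: the sync -> async handler conversion pattern used throughout.
import asyncio
from fastapi import FastAPI

app = FastAPI()

class FakeSandboxManager:
    """Stand-in for the async manager; not letta code."""
    async def get_or_create_default_async(self, sandbox_type: str) -> dict:
        await asyncio.sleep(0)  # stand-in for an async DB round-trip
        return {"type": sandbox_type, "default": True}

manager = FakeSandboxManager()

@app.post("/local/default")
async def create_default_local_sandbox_config() -> dict:
    # awaiting keeps the event loop free while the DB call is in flight
    return await manager.get_or_create_default_async("local")

print(asyncio.run(create_default_local_sandbox_config()))
```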


@ -1,3 +1,4 @@
import asyncio
import os
import tempfile
from typing import List, Optional
@ -21,18 +22,18 @@ router = APIRouter(prefix="/sources", tags=["sources"])
@router.get("/count", response_model=int, operation_id="count_sources")
def count_sources(
async def count_sources(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Count all data sources created by a user.
"""
return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
return await server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
@router.get("/{source_id}", response_model=Source, operation_id="retrieve_source")
def retrieve_source(
async def retrieve_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -42,14 +43,14 @@ def retrieve_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
return source
@router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
def get_source_id_by_name(
async def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -59,14 +60,14 @@ def get_source_id_by_name(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source = await server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with name={source_name} not found.")
return source.id
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
async def list_sources(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@ -74,8 +75,7 @@ def list_sources(
List all data sources created by a user.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.list_all_sources(actor=actor)
return await server.source_manager.list_sources(actor=actor)
@router.get("/count", response_model=int, operation_id="count_sources")
@ -90,7 +90,7 @@ def count_sources(
@router.post("/", response_model=Source, operation_id="create_source")
def create_source(
async def create_source(
source_create: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -99,6 +99,8 @@ def create_source(
Create a new data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# TODO: need to asyncify this
if not source_create.embedding_config:
if not source_create.embedding:
# TODO: modify error type
@ -115,11 +117,11 @@ def create_source(
instructions=source_create.instructions,
metadata=source_create.metadata,
)
return server.source_manager.create_source(source=source, actor=actor)
return await server.source_manager.create_source(source=source, actor=actor)
@router.patch("/{source_id}", response_model=Source, operation_id="modify_source")
def modify_source(
async def modify_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
@ -130,13 +132,13 @@ def modify_source(
"""
# TODO: allow updating the handle/embedding config
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
if not await server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
return await server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
@router.delete("/{source_id}", response_model=None, operation_id="delete_source")
def delete_source(
async def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@ -145,20 +147,21 @@ def delete_source(
Delete a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id)
agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:
try:
# TODO: make async
block = server.agent_manager.get_block_with_label(agent_id=agent.id, block_label=source.name, actor=actor)
server.block_manager.delete_block(block.id, actor)
except:
pass
server.delete_source(source_id=source_id, actor=actor)
await server.delete_source(source_id=source_id, actor=actor)
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
def upload_file_to_source(
async def upload_file_to_source(
file: UploadFile,
source_id: str,
background_tasks: BackgroundTasks,
@ -170,7 +173,7 @@ def upload_file_to_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
bytes = file.file.read()
@ -184,8 +187,8 @@ def upload_file_to_source(
server.job_manager.create_job(job, actor=actor)
# create background tasks
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor)
background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor)
asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor))
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor))
# return job information
# Is this necessary? Can we just return the job from create_job?
@ -195,8 +198,11 @@ def upload_file_to_source(
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
def list_source_passages(
async def list_source_passages(
source_id: str,
after: Optional[str] = Query(None, description="Passage after which to retrieve the returned passages."),
before: Optional[str] = Query(None, description="Passage before which to retrieve the returned passages."),
limit: int = Query(100, description="Maximum number of passages to retrieve."),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@ -204,12 +210,17 @@ def list_source_passages(
List all passages associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
return passages
return await server.agent_manager.list_passages_async(
actor=actor,
source_id=source_id,
after=after,
before=before,
limit=limit,
)
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_source_files")
def list_source_files(
async def list_source_files(
source_id: str,
limit: int = Query(1000, description="Number of files to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
@ -220,13 +231,13 @@ def list_source_files(
List paginated files associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
return await server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
# it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
# it's still good practice to return a status indicating the success or failure of the deletion
@router.delete("/{source_id}/{file_id}", status_code=204, operation_id="delete_file_from_source")
def delete_file_from_source(
async def delete_file_from_source(
source_id: str,
file_id: str,
background_tasks: BackgroundTasks,
@ -238,13 +249,15 @@ def delete_file_from_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor, clear_history=True)
deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor)
# TODO: make async
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
if deleted_file is None:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
# Create a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
# Sanitize the filename
@ -256,12 +269,12 @@ def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, f
buffer.write(bytes)
# Pass the file to load_file_to_source
server.load_file_to_source(source_id, file_path, job_id, actor)
await server.load_file_to_source(source_id, file_path, job_id, actor)
def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = server.source_manager.get_source_by_id(source_id=source_id)
agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:
server.sleeptime_document_ingest(agent, source, actor, clear_history)
server.sleeptime_document_ingest(agent, source, actor, clear_history) # TODO: make async
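Note the swap from `background_tasks.add_task(...)` to `asyncio.create_task(...)`: FastAPI's `BackgroundTasks` runs the work only after the response has been sent, while `create_task` schedules the coroutine on the running loop immediately. A small self-contained sketch of the `create_task` behavior:

```
# Sketch only: fire-and-forget scheduling with asyncio.create_task.
import asyncio

async def sleeptime_document_ingest(source_id: str) -> None:
    await asyncio.sleep(0.1)  # stand-in for the actual ingestion work
    print(f"ingested {source_id}")

async def upload_handler() -> str:
    # Scheduled on the running loop right away, before the response returns;
    # keep a reference if you ever need to await or cancel the task.
    task = asyncio.create_task(sleeptime_document_ingest("source-123"))
    return "202 Accepted"

async def main() -> None:
    print(await upload_handler())
    await asyncio.sleep(0.2)  # give the background task time to finish

asyncio.run(main())
```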


@ -50,7 +50,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage, PassageUpdate
@ -969,6 +969,11 @@ class SyncServer(Server):
"""Return the memory of an agent (core memory)"""
return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory
async def get_agent_memory_async(self, agent_id: str, actor: User) -> Memory:
"""Return the memory of an agent (core memory)"""
agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
return agent.memory
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
@ -1169,17 +1174,20 @@ class SyncServer(Server):
# rebuild system prompt for agent, potentially changed
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
def delete_source(self, source_id: str, actor: User):
async def delete_source(self, source_id: str, actor: User):
"""Delete a data source"""
self.source_manager.delete_source(source_id=source_id, actor=actor)
await self.source_manager.delete_source(source_id=source_id, actor=actor)
# delete data from passage store
# TODO: make async
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
# TODO: make this async
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
# TODO: delete data from agent passage stores (?)
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
# update job
job = self.job_manager.get_job_by_id(job_id, actor=actor)
@ -1189,21 +1197,22 @@ class SyncServer(Server):
# try:
from letta.data_sources.connectors import DirectoryConnector
source = self.source_manager.get_source_by_id(source_id=source_id)
# TODO: move this into a thread
source = await self.source_manager.get_source_by_id(source_id=source_id)
if source is None:
raise ValueError(f"Source {source_id} does not exist")
connector = DirectoryConnector(input_files=[file_path])
num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
# update all agents who have this source attached
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent_state in agent_states:
agent_id = agent_state.id
# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
agent_state = self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
# rebuild system prompt and force
@ -1266,7 +1275,7 @@ class SyncServer(Server):
actor=actor,
)
def load_data(
async def load_data(
self,
user_id: str,
connector: DataConnector,
@ -1277,12 +1286,12 @@ class SyncServer(Server):
# load data from a data source into the document store
user = self.user_manager.get_user_by_id(user_id=user_id)
source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
source = await self.source_manager.get_source_by_name(source_name=source_name, actor=user)
if source is None:
raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
# load data into the document store
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
return passage_count, document_count
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
@ -1290,6 +1299,7 @@ class SyncServer(Server):
return self.agent_manager.list_passages(actor=self.user_manager.get_user_or_default(user_id=user_id), source_id=source_id)
def list_all_sources(self, actor: User) -> List[Source]:
# TODO: legacy: remove
"""List all sources (w/ extra metadata) belonging to a user"""
sources = self.source_manager.list_sources(actor=actor)
@ -1376,7 +1386,7 @@ class SyncServer(Server):
"""Asynchronously list available models with maximum concurrency"""
import asyncio
providers = self.get_enabled_providers(
providers = await self.get_enabled_providers_async(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
@ -1422,7 +1432,7 @@ class SyncServer(Server):
import asyncio
# Get all eligible providers first
providers = self.get_enabled_providers(actor=actor)
providers = await self.get_enabled_providers_async(actor=actor)
# Fetch embedding models from each provider concurrently
async def get_provider_embedding_models(provider):
@ -1475,6 +1485,35 @@ class SyncServer(Server):
return providers
async def get_enabled_providers_async(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
if not provider_category or ProviderCategory.byok in provider_category:
providers_from_db = await self.provider_manager.list_providers_async(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
providers.extend(providers_from_db)
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
if provider_type is not None:
providers = [p for p in providers if p.provider_type == provider_type]
return providers
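`list_llm_models_async` fans out over these providers concurrently; a typical shape for that fan-out is `asyncio.gather` with `return_exceptions=True`, so one failing provider does not sink the whole listing. A sketch under that assumption, with stand-in provider names rather than letta classes:

```
# Sketch only: concurrent provider fan-out with per-provider error tolerance.
import asyncio

async def list_models(provider: str) -> list[str]:
    if provider == "flaky":
        raise RuntimeError("provider down")
    await asyncio.sleep(0.05)  # stand-in for an HTTP call
    return [f"{provider}-model-a", f"{provider}-model-b"]

async def list_all_models(providers: list[str]) -> list[str]:
    results = await asyncio.gather(
        *(list_models(p) for p in providers), return_exceptions=True
    )
    models: list[str] = []
    for provider, result in zip(providers, results):
        if isinstance(result, Exception):
            print(f"skipping {provider}: {result}")  # one bad provider doesn't fail the call
        else:
            models.extend(result)
    return models

print(asyncio.run(list_all_models(["openai", "flaky", "anthropic"])))
```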
@trace_method
def get_llm_config_from_handle(
self,
@ -1613,14 +1652,6 @@ class SyncServer(Server):
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""
def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview:
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.get_context_window()
async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview:
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return await letta_agent.get_context_window_async()
def run_tool_from_source(
self,
actor: User,

File diff suppressed because it is too large.


@ -14,6 +14,7 @@ from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
logger = get_logger(__name__)
@ -22,6 +23,7 @@ logger = get_logger(__name__)
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
@trace_method
@enforce_types
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
"""Create a new block based on the Block schema."""
@ -36,6 +38,7 @@ class BlockManager:
block.create(session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]:
"""
@ -59,6 +62,7 @@ class BlockManager:
# Convert back to Pydantic
return [m.to_pydantic() for m in created_models]
@trace_method
@enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
@ -74,6 +78,7 @@ class BlockManager:
block.update(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
"""Delete a block by its ID."""
@ -82,6 +87,7 @@ class BlockManager:
block.hard_delete(db_session=session, actor=actor)
return block.to_pydantic()
@trace_method
@enforce_types
async def get_blocks_async(
self,
@ -144,68 +150,7 @@ class BlockManager:
return [block.to_pydantic() for block in blocks]
@enforce_types
async def get_blocks_async(
self,
actor: PydanticUser,
label: Optional[str] = None,
is_template: Optional[bool] = None,
template_name: Optional[str] = None,
identity_id: Optional[str] = None,
identifier_keys: Optional[List[str]] = None,
limit: Optional[int] = 50,
) -> List[PydanticBlock]:
"""Async version of get_blocks method. Retrieve blocks based on various optional filters."""
from sqlalchemy import select
from sqlalchemy.orm import noload
from letta.orm.sqlalchemy_base import AccessType
async with db_registry.async_session() as session:
# Start with a basic query
query = select(BlockModel)
# Explicitly avoid loading relationships
query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups))
# Apply access control
query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
# Add filters
query = query.where(BlockModel.organization_id == actor.organization_id)
if label:
query = query.where(BlockModel.label == label)
if is_template is not None:
query = query.where(BlockModel.is_template == is_template)
if template_name:
query = query.where(BlockModel.template_name == template_name)
if identifier_keys:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
.distinct(BlockModel.id)
)
if identity_id:
query = (
query.join(BlockModel.identities)
.filter(BlockModel.identities.property.mapper.class_.id == identity_id)
.distinct(BlockModel.id)
)
# Add limit
if limit:
query = query.limit(limit)
# Execute the query
result = await session.execute(query)
blocks = result.scalars().all()
return [block.to_pydantic() for block in blocks]
@trace_method
@enforce_types
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
@ -216,6 +161,7 @@ class BlockManager:
except NoResultFound:
return None
@trace_method
@enforce_types
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids without loading unnecessary relationships. Async implementation."""
@ -263,6 +209,7 @@ class BlockManager:
return pydantic_blocks
@trace_method
@enforce_types
async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
"""
@ -273,6 +220,7 @@ class BlockManager:
agents_orm = block.agents
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
@trace_method
@enforce_types
def size(
self,
@ -286,6 +234,7 @@ class BlockManager:
# Block History Functions
@trace_method
@enforce_types
def checkpoint_block(
self,
@ -389,6 +338,7 @@ class BlockManager:
updated_block = block.update(db_session=session, actor=actor, no_commit=True)
return updated_block
@trace_method
@enforce_types
def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
"""
@ -431,6 +381,7 @@ class BlockManager:
session.commit()
return block.to_pydantic()
@trace_method
@enforce_types
def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock:
"""
@ -469,6 +420,7 @@ class BlockManager:
session.commit()
return block.to_pydantic()
@trace_method
@enforce_types
async def bulk_update_block_values_async(
self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False


@ -12,11 +12,13 @@ from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
class GroupManager:
@trace_method
@enforce_types
def list_groups(
self,
@ -42,12 +44,14 @@ class GroupManager:
)
return [group.to_pydantic() for group in groups]
@trace_method
@enforce_types
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
return group.to_pydantic()
@trace_method
@enforce_types
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
with db_registry.session() as session:
@ -93,6 +97,7 @@ class GroupManager:
new_group.create(session, actor=actor)
return new_group.to_pydantic()
@trace_method
@enforce_types
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
with db_registry.session() as session:
@ -155,6 +160,7 @@ class GroupManager:
group.update(session, actor=actor)
return group.to_pydantic()
@trace_method
@enforce_types
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
with db_registry.session() as session:
@ -162,6 +168,7 @@ class GroupManager:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
group.hard_delete(session)
@trace_method
@enforce_types
def list_group_messages(
self,
@ -198,6 +205,7 @@ class GroupManager:
return messages
@trace_method
@enforce_types
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
with db_registry.session() as session:
@ -211,6 +219,7 @@ class GroupManager:
session.commit()
@trace_method
@enforce_types
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
with db_registry.session() as session:
@ -222,6 +231,18 @@ class GroupManager:
group.update(session, actor=actor)
return group.turns_counter
@trace_method
@enforce_types
async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
# Update turns counter
group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency
await group.update_async(session, actor=actor)
return group.turns_counter
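The counter wraps modulo `sleeptime_agent_frequency`. Assuming the wrap back to zero is what triggers the sleeptime agent, the cadence looks like this:

```
# Sketch only: the modular turns counter, with frequency assumed to be 3.
frequency = 3
counter = 0
for turn in range(1, 8):
    counter = (counter + 1) % frequency
    if counter == 0:
        print(f"turn {turn}: trigger sleeptime agent")
```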
@enforce_types
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
with db_registry.session() as session:
@ -235,6 +256,22 @@ class GroupManager:
return prev_last_processed_message_id
@trace_method
@enforce_types
async def get_last_processed_message_id_and_update_async(
self, group_id: str, last_processed_message_id: str, actor: PydanticUser
) -> str:
async with db_registry.async_session() as session:
# Ensure group is loadable by user
group = await GroupModel.read_async(session, identifier=group_id, actor=actor)
# Update last processed message id
prev_last_processed_message_id = group.last_processed_message_id
group.last_processed_message_id = last_processed_message_id
await group.update_async(session, actor=actor)
return prev_last_processed_message_id
@enforce_types
def size(
self,


@ -12,12 +12,14 @@ from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
class IdentityManager:
@enforce_types
@trace_method
async def list_identities_async(
self,
name: Optional[str] = None,
@ -48,12 +50,14 @@ class IdentityManager:
return [identity.to_pydantic() for identity in identities]
@enforce_types
@trace_method
async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
@trace_method
async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
@ -78,6 +82,7 @@ class IdentityManager:
return new_identity.to_pydantic()
@enforce_types
@trace_method
async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
async with db_registry.async_session() as session:
existing_identity = await IdentityModel.read_async(
@ -103,6 +108,7 @@ class IdentityManager:
)
@enforce_types
@trace_method
async def update_identity_async(
self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
) -> PydanticIdentity:
@ -165,6 +171,7 @@ class IdentityManager:
return existing_identity.to_pydantic()
@enforce_types
@trace_method
async def upsert_identity_properties_async(
self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
) -> PydanticIdentity:
@ -181,6 +188,7 @@ class IdentityManager:
)
@enforce_types
@trace_method
async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
async with db_registry.async_session() as session:
identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
@ -192,6 +200,7 @@ class IdentityManager:
await session.commit()
@enforce_types
@trace_method
async def size_async(
self,
actor: PydanticUser,


@ -25,6 +25,7 @@ from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
@ -32,6 +33,7 @@ class JobManager:
"""Manager class to handle business logic related to Jobs."""
@enforce_types
@trace_method
def create_job(
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
@ -45,6 +47,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
async def create_job_async(
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
@ -58,6 +61,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
"""Update a job by its ID with the given JobUpdate object."""
with db_registry.session() as session:
@ -82,6 +86,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
"""Update a job by its ID with the given JobUpdate object asynchronously."""
async with db_registry.async_session() as session:
@ -106,6 +111,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Fetch a job by its ID."""
with db_registry.session() as session:
@ -114,6 +120,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Fetch a job by its ID asynchronously."""
async with db_registry.async_session() as session:
@ -122,6 +129,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
def list_jobs(
self,
actor: PydanticUser,
@ -151,6 +159,7 @@ class JobManager:
return [job.to_pydantic() for job in jobs]
@enforce_types
@trace_method
async def list_jobs_async(
self,
actor: PydanticUser,
@ -180,6 +189,7 @@ class JobManager:
return [job.to_pydantic() for job in jobs]
@enforce_types
@trace_method
def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Delete a job by its ID."""
with db_registry.session() as session:
@ -188,6 +198,7 @@ class JobManager:
return job.to_pydantic()
@enforce_types
@trace_method
def get_job_messages(
self,
job_id: str,
@ -238,6 +249,7 @@ class JobManager:
return [message.to_pydantic() for message in messages]
@enforce_types
@trace_method
def get_job_steps(
self,
job_id: str,
@ -283,6 +295,7 @@ class JobManager:
return [step.to_pydantic() for step in steps]
@enforce_types
@trace_method
def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None:
"""
Associate a message with a job by creating a JobMessage record.
@ -306,6 +319,7 @@ class JobManager:
session.commit()
@enforce_types
@trace_method
def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics:
"""
Get usage statistics for a job.
@ -343,6 +357,7 @@ class JobManager:
)
@enforce_types
@trace_method
def add_job_usage(
self,
job_id: str,
@ -383,6 +398,7 @@ class JobManager:
session.commit()
@enforce_types
@trace_method
def get_run_messages(
self,
run_id: str,
@ -434,6 +450,7 @@ class JobManager:
return messages
@enforce_types
@trace_method
def get_step_messages(
self,
run_id: str,


@ -17,6 +17,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
logger = get_logger(__name__)
@ -26,6 +27,7 @@ class LLMBatchManager:
"""Manager for handling both LLMBatchJob and LLMBatchItem operations."""
@enforce_types
@trace_method
async def create_llm_batch_job_async(
self,
llm_provider: ProviderType,
@ -47,6 +49,7 @@ class LLMBatchManager:
return batch.to_pydantic()
@enforce_types
@trace_method
async def get_llm_batch_job_by_id_async(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
"""Retrieve a single batch job by ID."""
async with db_registry.async_session() as session:
@ -54,7 +57,8 @@ class LLMBatchManager:
return batch.to_pydantic()
@enforce_types
def update_llm_batch_status(
@trace_method
async def update_llm_batch_status_async(
self,
llm_batch_id: str,
status: JobStatus,
@ -62,15 +66,15 @@ class LLMBatchManager:
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch jobs status and optionally its polling response."""
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
async with db_registry.async_session() as session:
batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor)
batch.status = status
batch.latest_polling_response = latest_polling_response
batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc)
batch = batch.update(db_session=session, actor=actor)
batch = await batch.update_async(db_session=session, actor=actor)
return batch.to_pydantic()
def bulk_update_llm_batch_statuses(
async def bulk_update_llm_batch_statuses_async(
self,
updates: List[BatchPollingResult],
) -> None:
@ -81,7 +85,7 @@ class LLMBatchManager:
"""
now = datetime.datetime.now(datetime.timezone.utc)
with db_registry.session() as session:
async with db_registry.async_session() as session:
mappings = []
for llm_batch_id, status, response in updates:
mappings.append(
@ -93,17 +97,18 @@ class LLMBatchManager:
}
)
session.bulk_update_mappings(LLMBatchJob, mappings)
session.commit()
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings))
await session.commit()
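`bulk_update_mappings` is a synchronous `Session` API, so the async rewrite bridges into it with `AsyncSession.run_sync`, which hands the underlying sync `Session` to a callable. A runnable sketch against an in-memory SQLite database (requires `aiosqlite`; the table and mappings here are illustrative, not letta models):

```
# Sketch only: bridging a sync-only Session API via AsyncSession.run_sync.
import asyncio

from sqlalchemy import Column, Integer, String
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Job(Base):
    __tablename__ = "jobs"
    id = Column(Integer, primary_key=True)
    status = Column(String)

async def main() -> None:
    engine = create_async_engine("sqlite+aiosqlite://")
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    async with AsyncSession(engine) as session:
        session.add_all([Job(id=1, status="created"), Job(id=2, status="running")])
        await session.commit()
        mappings = [{"id": 1, "status": "cancelled"}, {"id": 2, "status": "cancelled"}]
        # bulk_update_mappings is sync-only, so run it against the inner Session
        await session.run_sync(lambda ses: ses.bulk_update_mappings(Job, mappings))
        await session.commit()

asyncio.run(main())
```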
@enforce_types
def list_llm_batch_jobs(
@trace_method
async def list_llm_batch_jobs_async(
self,
letta_batch_id: str,
limit: Optional[int] = None,
actor: Optional[PydanticUser] = None,
after: Optional[str] = None,
) -> List[PydanticLLMBatchItem]:
) -> List[PydanticLLMBatchJob]:
"""
List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count.
@ -115,33 +120,35 @@ class LLMBatchManager:
The results are ordered by their id in ascending order.
"""
with db_registry.session() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)
async with db_registry.async_session() as session:
query = select(LLMBatchJob).where(LLMBatchJob.letta_batch_job_id == letta_batch_id)
if actor is not None:
query = query.filter(LLMBatchJob.organization_id == actor.organization_id)
query = query.where(LLMBatchJob.organization_id == actor.organization_id)
# Additional optional filters
if after is not None:
query = query.filter(LLMBatchJob.id > after)
query = query.where(LLMBatchJob.id > after)
query = query.order_by(LLMBatchJob.id.asc())
if limit is not None:
query = query.limit(limit)
results = query.all()
return [item.to_pydantic() for item in results]
results = await session.execute(query)
return [item.to_pydantic() for item in results.scalars().all()]
@enforce_types
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
@trace_method
async def delete_llm_batch_request_async(self, llm_batch_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch job by ID."""
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor)
await batch.hard_delete_async(db_session=session, actor=actor)
@enforce_types
def get_messages_for_letta_batch(
@trace_method
async def get_messages_for_letta_batch_async(
self,
letta_batch_job_id: str,
limit: int = 100,
@ -154,12 +161,12 @@ class LLMBatchManager:
Retrieve messages across all LLM batch jobs associated with a Letta batch job.
Optimized for PostgreSQL performance using ID-based keyset pagination.
"""
with db_registry.session() as session:
async with db_registry.async_session() as session:
# If cursor is provided, get sequence_id for that message
cursor_sequence_id = None
if cursor:
cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1)
cursor_result = cursor_query.first()
cursor_query = select(MessageModel.sequence_id).where(MessageModel.id == cursor).limit(1)
cursor_result = await session.execute(cursor_query)
if cursor_result:
cursor_sequence_id = cursor_result[0]
else:
@ -167,24 +174,24 @@ class LLMBatchManager:
pass
query = (
session.query(MessageModel)
select(MessageModel)
.join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id)
.join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id)
.filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
.where(LLMBatchJob.letta_batch_job_id == letta_batch_job_id)
)
if actor is not None:
query = query.filter(MessageModel.organization_id == actor.organization_id)
query = query.where(MessageModel.organization_id == actor.organization_id)
if agent_id is not None:
query = query.filter(MessageModel.agent_id == agent_id)
query = query.where(MessageModel.agent_id == agent_id)
# Apply cursor-based pagination if cursor exists
if cursor_sequence_id is not None:
if sort_descending:
query = query.filter(MessageModel.sequence_id < cursor_sequence_id)
query = query.where(MessageModel.sequence_id < cursor_sequence_id)
else:
query = query.filter(MessageModel.sequence_id > cursor_sequence_id)
query = query.where(MessageModel.sequence_id > cursor_sequence_id)
if sort_descending:
query = query.order_by(desc(MessageModel.sequence_id))
@ -193,10 +200,11 @@ class LLMBatchManager:
query = query.limit(limit)
results = query.all()
return [message.to_pydantic() for message in results]
results = await session.execute(query)
return [message.to_pydantic() for message in results.scalars().all()]
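The pagination here is keyset-based: the opaque message-id cursor is first resolved to its monotonic `sequence_id`, and each page then filters on that column instead of using `OFFSET`, so PostgreSQL can seek straight into the index however deep the page. A hedged sketch of the two-step pattern, names illustrative:

```
from sqlalchemy import select

async def resolve_cursor(session, Model, cursor_id):
    # Step 1: translate the cursor row's id into its sequence_id.
    row = (await session.execute(
        select(Model.sequence_id).where(Model.id == cursor_id).limit(1)
    )).first()
    return row[0] if row else None

def apply_keyset(query, Model, cursor_seq, descending, limit):
    # Step 2: filter relative to the cursor and order in the same direction.
    if cursor_seq is not None:
        query = query.where(
            Model.sequence_id < cursor_seq if descending else Model.sequence_id > cursor_seq
        )
    order = Model.sequence_id.desc() if descending else Model.sequence_id.asc()
    return query.order_by(order).limit(limit)
```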
@enforce_types
@trace_method
async def list_running_llm_batches_async(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
async with db_registry.async_session() as session:
@ -209,7 +217,8 @@ class LLMBatchManager:
return [batch.to_pydantic() for batch in results.scalars().all()]
@enforce_types
def create_llm_batch_item(
@trace_method
async def create_llm_batch_item_async(
self,
llm_batch_id: str,
agent_id: str,
@ -220,7 +229,7 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Create a new batch item."""
with db_registry.session() as session:
async with db_registry.async_session() as session:
item = LLMBatchItem(
llm_batch_id=llm_batch_id,
agent_id=agent_id,
@ -230,10 +239,11 @@ class LLMBatchManager:
step_state=step_state,
organization_id=actor.organization_id,
)
item.create(session, actor=actor)
await item.create_async(session, actor=actor)
return item.to_pydantic()
@enforce_types
@trace_method
async def create_llm_batch_items_bulk_async(
self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser
) -> List[PydanticLLMBatchItem]:
@ -269,14 +279,16 @@ class LLMBatchManager:
return [item.to_pydantic() for item in created_items]
@enforce_types
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
@trace_method
async def get_llm_batch_item_by_id_async(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
"""Retrieve a single batch item by ID."""
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
async with db_registry.async_session() as session:
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
return item.to_pydantic()
@enforce_types
def update_llm_batch_item(
@trace_method
async def update_llm_batch_item_async(
self,
item_id: str,
actor: PydanticUser,
@ -286,8 +298,8 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Update fields on a batch item."""
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
async with db_registry.async_session() as session:
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
if request_status:
item.request_status = request_status
@ -298,9 +310,11 @@ class LLMBatchManager:
if step_state:
item.step_state = step_state
return item.update(db_session=session, actor=actor).to_pydantic()
result = await item.update_async(db_session=session, actor=actor)
return result.to_pydantic()
@enforce_types
@trace_method
async def list_llm_batch_items_async(
self,
llm_batch_id: str,
@ -346,7 +360,8 @@ class LLMBatchManager:
results = await session.execute(query)
return [item.to_pydantic() for item in results.scalars()]
def bulk_update_llm_batch_items(
@trace_method
async def bulk_update_llm_batch_items_async(
self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True
) -> None:
"""
@ -364,13 +379,13 @@ class LLMBatchManager:
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
with db_registry.session() as session:
async with db_registry.async_session() as session:
# Lookup primary keys for all requested (batch_id, agent_id) pairs
items = (
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
.filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs))
.all()
query = select(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).filter(
tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)
)
result = await session.execute(query)
items = result.all()
pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items}
if strict:
@ -395,11 +410,12 @@ class LLMBatchManager:
mappings.append(update_fields)
if mappings:
session.bulk_update_mappings(LLMBatchItem, mappings)
session.commit()
await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings))
await session.commit()
@enforce_types
def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
@trace_method
async def bulk_update_batch_llm_items_results_by_agent_async(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None:
"""Update request status and batch results for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [
@ -410,33 +426,41 @@ class LLMBatchManager:
for update in updates
]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None:
@trace_method
async def bulk_update_llm_batch_items_step_status_by_agent_async(
self, updates: List[StepStatusUpdateInfo], strict: bool = True
) -> None:
"""Update step status for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"step_status": update.step_status} for update in updates]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None:
@trace_method
async def bulk_update_llm_batch_items_request_status_by_agent_async(
self, updates: List[RequestStatusUpdateInfo], strict: bool = True
) -> None:
"""Update request status for multiple batch items."""
batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates]
field_updates = [{"request_status": update.request_status} for update in updates]
self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict)
await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict)
@enforce_types
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
@trace_method
async def delete_llm_batch_item_async(self, item_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch item by ID."""
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
item.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor)
await item.hard_delete_async(db_session=session, actor=actor)
@enforce_types
def count_llm_batch_items(self, llm_batch_id: str) -> int:
@trace_method
async def count_llm_batch_items_async(self, llm_batch_id: str) -> int:
"""
Efficiently count the number of batch items for a given llm_batch_id.
@ -446,6 +470,6 @@ class LLMBatchManager:
Returns:
int: The total number of batch items associated with the given llm_batch_id.
"""
with db_registry.session() as session:
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
return count or 0
async with db_registry.async_session() as session:
count = await session.execute(select(func.count(LLMBatchItem.id)).where(LLMBatchItem.llm_batch_id == llm_batch_id))
return count.scalar() or 0
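For counts the async API hands back a `Result`, and `.scalar()` unwraps the single `COUNT(*)` value; the `or 0` guards the theoretical `None`. The same shape in isolation, names illustrative:

```
from sqlalchemy import func, select

async def count_rows(session, Model, batch_id: str) -> int:
    result = await session.execute(
        select(func.count(Model.id)).where(Model.llm_batch_id == batch_id)
    )
    return result.scalar() or 0
```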

View File

@ -13,6 +13,7 @@ from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
logger = get_logger(__name__)
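The rest of this file is a mechanical change: a `@trace_method` decorator stacked under each `@enforce_types`. Letta's actual implementation lives in `letta.tracing` and is not shown in this diff; as a loudly labeled assumption, a decorator of that shape could look roughly like this:

```
import asyncio
import functools
import time

def trace_method(func):
    # Hypothetical sketch only -- not Letta's real tracing code.
    if asyncio.iscoroutinefunction(func):
        @functools.wraps(func)
        async def async_wrapper(*args, **kwargs):
            start = time.perf_counter()
            try:
                return await func(*args, **kwargs)
            finally:
                print(f"{func.__qualname__}: {time.perf_counter() - start:.4f}s")
        return async_wrapper

    @functools.wraps(func)
    def sync_wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            print(f"{func.__qualname__}: {time.perf_counter() - start:.4f}s")
    return sync_wrapper
```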
@ -22,6 +23,7 @@ class MessageManager:
"""Manager class to handle business logic related to Messages."""
@enforce_types
@trace_method
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
with db_registry.session() as session:
@ -32,6 +34,7 @@ class MessageManager:
return None
@enforce_types
@trace_method
async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
async with db_registry.async_session() as session:
@ -42,6 +45,7 @@ class MessageManager:
return None
@enforce_types
@trace_method
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch messages by ID and return them in the requested order."""
with db_registry.session() as session:
@ -49,6 +53,7 @@ class MessageManager:
return self._get_messages_by_id_postprocess(results, message_ids)
@enforce_types
@trace_method
async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch messages by ID and return them in the requested order. Async version of above function."""
async with db_registry.async_session() as session:
@ -71,6 +76,7 @@ class MessageManager:
return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids]))
@enforce_types
@trace_method
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
"""Create a new message."""
with db_registry.session() as session:
@ -92,6 +98,7 @@ class MessageManager:
return orm_messages
@enforce_types
@trace_method
def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction.
@ -111,6 +118,7 @@ class MessageManager:
return [msg.to_pydantic() for msg in created_messages]
@enforce_types
@trace_method
async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction asynchronously.
@ -131,6 +139,7 @@ class MessageManager:
return [msg.to_pydantic() for msg in created_messages]
@enforce_types
@trace_method
def update_message_by_letta_message(
self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
) -> PydanticMessage:
@ -169,6 +178,7 @@ class MessageManager:
raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
@enforce_types
@trace_method
def update_message_by_letta_message(
self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser
) -> PydanticMessage:
@ -207,6 +217,7 @@ class MessageManager:
raise ValueError(f"Message type got modified: {letta_message_update.message_type}")
@enforce_types
@trace_method
def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
@ -224,6 +235,7 @@ class MessageManager:
return message.to_pydantic()
@enforce_types
@trace_method
async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
@ -267,6 +279,7 @@ class MessageManager:
return message
@enforce_types
@trace_method
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
"""Delete a message."""
with db_registry.session() as session:
@ -281,6 +294,7 @@ class MessageManager:
raise ValueError(f"Message with id {message_id} not found.")
@enforce_types
@trace_method
def size(
self,
actor: PydanticUser,
@ -297,6 +311,7 @@ class MessageManager:
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
@trace_method
async def size_async(
self,
actor: PydanticUser,
@ -312,6 +327,7 @@ class MessageManager:
return await MessageModel.size_async(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
@trace_method
def list_user_messages_for_agent(
self,
agent_id: str,
@ -334,6 +350,7 @@ class MessageManager:
)
@enforce_types
@trace_method
def list_messages_for_agent(
self,
agent_id: str,
@ -437,6 +454,7 @@ class MessageManager:
return [msg.to_pydantic() for msg in results]
@enforce_types
@trace_method
async def list_messages_for_agent_async(
self,
agent_id: str,
@ -538,6 +556,7 @@ class MessageManager:
return [msg.to_pydantic() for msg in results]
@enforce_types
@trace_method
def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int:
"""
Efficiently deletes all messages associated with a given agent_id,

View File

@ -5,6 +5,7 @@ from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
@ -15,11 +16,13 @@ class OrganizationManager:
DEFAULT_ORG_NAME = "default_org"
@enforce_types
@trace_method
def get_default_organization(self) -> PydanticOrganization:
"""Fetch the default organization."""
return self.get_organization_by_id(self.DEFAULT_ORG_ID)
@enforce_types
@trace_method
def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
"""Fetch an organization by ID."""
with db_registry.session() as session:
@ -27,6 +30,7 @@ class OrganizationManager:
return organization.to_pydantic()
@enforce_types
@trace_method
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create a new organization."""
try:
@ -36,6 +40,7 @@ class OrganizationManager:
return self._create_organization(pydantic_org=pydantic_org)
@enforce_types
@trace_method
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
with db_registry.session() as session:
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
@ -43,11 +48,13 @@ class OrganizationManager:
return org.to_pydantic()
@enforce_types
@trace_method
def create_default_organization(self) -> PydanticOrganization:
"""Create the default organization."""
return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
@enforce_types
@trace_method
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with db_registry.session() as session:
@ -58,6 +65,7 @@ class OrganizationManager:
return org.to_pydantic()
@enforce_types
@trace_method
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with db_registry.session() as session:
@ -70,6 +78,7 @@ class OrganizationManager:
return org.to_pydantic()
@enforce_types
@trace_method
def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted."""
with db_registry.session() as session:
@ -77,6 +86,7 @@ class OrganizationManager:
organization.hard_delete(session)
@enforce_types
@trace_method
def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List all organizations with optional pagination."""
with db_registry.session() as session:

View File

@ -11,6 +11,7 @@ from letta.schemas.agent import AgentState
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
@ -18,6 +19,7 @@ class PassageManager:
"""Manager class to handle business logic related to Passages."""
@enforce_types
@trace_method
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
"""Fetch a passage by ID."""
with db_registry.session() as session:
@ -34,6 +36,7 @@ class PassageManager:
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
@enforce_types
@trace_method
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
# Common fields for both passage types
@ -70,11 +73,13 @@ class PassageManager:
return passage.to_pydantic()
@enforce_types
@trace_method
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
"""Create multiple passages."""
return [self.create_passage(p, actor) for p in passages]
@enforce_types
@trace_method
def insert_passage(
self,
agent_state: AgentState,
@ -136,6 +141,7 @@ class PassageManager:
raise e
@enforce_types
@trace_method
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
"""Update a passage."""
if not passage_id:
@ -170,6 +176,7 @@ class PassageManager:
return curr_passage.to_pydantic()
@enforce_types
@trace_method
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
"""Delete a passage from either source or archival passages."""
if not passage_id:
@ -190,6 +197,8 @@ class PassageManager:
except NoResultFound:
raise NoResultFound(f"Passage with id {passage_id} not found.")
@enforce_types
@trace_method
def delete_passages(
self,
actor: PydanticUser,
@ -202,6 +211,7 @@ class PassageManager:
return True
@enforce_types
@trace_method
def size(
self,
actor: PydanticUser,
@ -217,6 +227,7 @@ class PassageManager:
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
@enforce_types
@trace_method
async def size_async(
self,
actor: PydanticUser,
@ -230,6 +241,8 @@ class PassageManager:
async with db_registry.async_session() as session:
return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id)
@enforce_types
@trace_method
def estimate_embeddings_size(
self,
actor: PydanticUser,

View File

@ -1,6 +1,8 @@
import threading
from collections import defaultdict
from letta.tracing import trace_method
class PerAgentLockManager:
"""Manages per-agent locks."""
@ -8,10 +10,12 @@ class PerAgentLockManager:
def __init__(self):
self.locks = defaultdict(threading.Lock)
@trace_method
def get_lock(self, agent_id: str) -> threading.Lock:
"""Retrieve the lock for a specific agent_id."""
return self.locks[agent_id]
@trace_method
def clear_lock(self, agent_id: str):
"""Optionally remove a lock if no longer needed (to prevent unbounded growth)."""
if agent_id in self.locks:
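`defaultdict(threading.Lock)` creates a lock lazily the first time an `agent_id` is seen, so callers never register agents up front. A hedged usage sketch (the call sites are not part of this diff):

```
lock_manager = PerAgentLockManager()

def run_step(agent_id: str):
    # Serializes concurrent steps for the same agent while letting
    # different agents proceed in parallel.
    with lock_manager.get_lock(agent_id):
        pass  # do the per-agent work here

# Once an agent is retired, drop its lock so the mapping cannot grow unboundedly.
lock_manager.clear_lock("agent-123")
```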

View File

@ -6,12 +6,14 @@ from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types
class ProviderManager:
@enforce_types
@trace_method
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with db_registry.session() as session:
@ -32,6 +34,7 @@ class ProviderManager:
return new_provider.to_pydantic()
@enforce_types
@trace_method
def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
"""Update provider details."""
with db_registry.session() as session:
@ -48,6 +51,7 @@ class ProviderManager:
return existing_provider.to_pydantic()
@enforce_types
@trace_method
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
with db_registry.session() as session:
@ -62,6 +66,7 @@ class ProviderManager:
session.commit()
@enforce_types
@trace_method
def list_providers(
self,
actor: PydanticUser,
@ -87,16 +92,45 @@ class ProviderManager:
return [provider.to_pydantic() for provider in providers]
@enforce_types
@trace_method
async def list_providers_async(
self,
actor: PydanticUser,
name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
after: Optional[str] = None,
limit: Optional[int] = 50,
) -> List[PydanticProvider]:
"""List all providers with optional pagination."""
filter_kwargs = {}
if name:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
async with db_registry.async_session() as session:
providers = await ProviderModel.list_async(
db_session=session,
after=after,
limit=limit,
actor=actor,
**filter_kwargs,
)
return [provider.to_pydantic() for provider in providers]
@enforce_types
@trace_method
def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
providers = self.list_providers(name=provider_name, actor=actor)
return providers[0].id if providers else None
@enforce_types
@trace_method
def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
providers = self.list_providers(name=provider_name, actor=actor)
return providers[0].api_key if providers else None
@enforce_types
@trace_method
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
provider = PydanticProvider(
name=provider_check.provider_type.value,

View File

@ -12,6 +12,7 @@ from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types, printd
logger = get_logger(__name__)
@ -21,6 +22,7 @@ class SandboxConfigManager:
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""
@enforce_types
@trace_method
def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig:
sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor)
if not sandbox_config:
@ -38,6 +40,7 @@ class SandboxConfigManager:
return sandbox_config
@enforce_types
@trace_method
def create_or_update_sandbox_config(self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser) -> PydanticSandboxConfig:
"""Create or update a sandbox configuration based on the PydanticSandboxConfig schema."""
config = sandbox_config_create.config
@ -71,6 +74,61 @@ class SandboxConfigManager:
return db_sandbox.to_pydantic()
@enforce_types
@trace_method
async def get_or_create_default_sandbox_config_async(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig:
sandbox_config = await self.get_sandbox_config_by_type_async(sandbox_type, actor=actor)
if not sandbox_config:
logger.debug(f"Creating new sandbox config of type {sandbox_type}, none found for organization {actor.organization_id}.")
# TODO: Add more sandbox types later
if sandbox_type == SandboxType.E2B:
default_config = {} # Empty
else:
# TODO: May want to move this to environment variables vs. persisting in the database
default_local_sandbox_path = LETTA_TOOL_EXECUTION_DIR
default_config = LocalSandboxConfig(sandbox_dir=default_local_sandbox_path).model_dump(exclude_none=True)
sandbox_config = await self.create_or_update_sandbox_config_async(SandboxConfigCreate(config=default_config), actor=actor)
return sandbox_config
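The executor changes later in this diff await this helper in exactly this shape; a minimal usage sketch, with `manager` and `user` as stand-in names:

```
from letta.schemas.sandbox_config import SandboxType

async def resolve_sandbox_config(manager, user):
    # Falls back to a freshly created default config when none exists yet.
    return await manager.get_or_create_default_sandbox_config_async(
        sandbox_type=SandboxType.E2B, actor=user
    )
```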
@enforce_types
@trace_method
async def create_or_update_sandbox_config_async(
self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser
) -> PydanticSandboxConfig:
"""Create or update a sandbox configuration based on the PydanticSandboxConfig schema."""
config = sandbox_config_create.config
sandbox_type = config.type
sandbox_config = PydanticSandboxConfig(
type=sandbox_type, config=config.model_dump(exclude_none=True), organization_id=actor.organization_id
)
# Attempt to retrieve the existing sandbox configuration by type within the organization
db_sandbox = await self.get_sandbox_config_by_type_async(sandbox_config.type, actor=actor)
if db_sandbox:
# Prepare the update data, excluding fields that should not be reset
update_data = sandbox_config.model_dump(exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(db_sandbox, key) != value}
# If there are changes, update the sandbox configuration
if update_data:
db_sandbox = await self.update_sandbox_config_async(db_sandbox.id, SandboxConfigUpdate(**update_data), actor)
else:
printd(
f"`create_or_update_sandbox_config` was called with user_id={actor.id}, organization_id={actor.organization_id}, "
f"type={sandbox_config.type}, but found existing configuration with nothing to update."
)
return db_sandbox
else:
# If the sandbox configuration doesn't exist, create a new one
async with db_registry.async_session() as session:
db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True))
await db_sandbox.create_async(session, actor=actor)
return db_sandbox.to_pydantic()
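The update path above is compare-then-write: dump the incoming schema, keep only the fields whose values actually differ from the stored row, and skip the write entirely when nothing changed. The comparison step in isolation:

```
def changed_fields(incoming, stored) -> dict:
    # `incoming` is a Pydantic model, `stored` the matching ORM row.
    data = incoming.model_dump(exclude_unset=True, exclude_none=True)
    return {key: value for key, value in data.items() if getattr(stored, key) != value}
```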
@enforce_types
@trace_method
def update_sandbox_config(
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
) -> PydanticSandboxConfig:
@ -98,6 +156,35 @@ class SandboxConfigManager:
return sandbox.to_pydantic()
@enforce_types
@trace_method
async def update_sandbox_config_async(
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
) -> PydanticSandboxConfig:
"""Update an existing sandbox configuration."""
async with db_registry.async_session() as session:
sandbox = await SandboxConfigModel.read_async(db_session=session, identifier=sandbox_config_id, actor=actor)
# We need to check that the sandbox_update provided is the same type as the original sandbox
if sandbox.type != sandbox_update.config.type:
raise ValueError(
f"Mismatched type for sandbox config update: tried to update sandbox_config of type {sandbox.type} with config of type {sandbox_update.config.type}"
)
update_data = sandbox_update.model_dump(exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(sandbox, key) != value}
if update_data:
for key, value in update_data.items():
setattr(sandbox, key, value)
await sandbox.update_async(db_session=session, actor=actor)
else:
printd(
f"`update_sandbox_config` called with user_id={actor.id}, organization_id={actor.organization_id}, "
f"name={sandbox.type}, but nothing to update."
)
return sandbox.to_pydantic()
@enforce_types
@trace_method
def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
"""Delete a sandbox configuration by its ID."""
with db_registry.session() as session:
@ -106,6 +193,7 @@ class SandboxConfigManager:
return sandbox.to_pydantic()
@enforce_types
@trace_method
def list_sandbox_configs(
self,
actor: PydanticUser,
@ -123,6 +211,7 @@ class SandboxConfigManager:
return [sandbox.to_pydantic() for sandbox in sandboxes]
@enforce_types
@trace_method
async def list_sandbox_configs_async(
self,
actor: PydanticUser,
@ -140,6 +229,7 @@ class SandboxConfigManager:
return [sandbox.to_pydantic() for sandbox in sandboxes]
@enforce_types
@trace_method
def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox configuration by its ID."""
with db_registry.session() as session:
@ -150,6 +240,7 @@ class SandboxConfigManager:
return None
@enforce_types
@trace_method
def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
with db_registry.session() as session:
@ -167,6 +258,27 @@ class SandboxConfigManager:
return None
@enforce_types
@trace_method
async def get_sandbox_config_by_type_async(
self, type: SandboxType, actor: Optional[PydanticUser] = None
) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
async with db_registry.async_session() as session:
try:
sandboxes = await SandboxConfigModel.list_async(
db_session=session,
type=type,
organization_id=actor.organization_id,
limit=1,
)
if sandboxes:
return sandboxes[0].to_pydantic()
return None
except NoResultFound:
return None
@enforce_types
@trace_method
def create_sandbox_env_var(
self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser
) -> PydanticEnvVar:
@ -194,6 +306,7 @@ class SandboxConfigManager:
return env_var.to_pydantic()
@enforce_types
@trace_method
def update_sandbox_env_var(
self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser
) -> PydanticEnvVar:
@ -215,6 +328,7 @@ class SandboxConfigManager:
return env_var.to_pydantic()
@enforce_types
@trace_method
def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar:
"""Delete a sandbox environment variable by its ID."""
with db_registry.session() as session:
@ -223,6 +337,7 @@ class SandboxConfigManager:
return env_var.to_pydantic()
@enforce_types
@trace_method
def list_sandbox_env_vars(
self,
sandbox_config_id: str,
@ -242,6 +357,7 @@ class SandboxConfigManager:
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
@trace_method
async def list_sandbox_env_vars_async(
self,
sandbox_config_id: str,
@ -261,6 +377,7 @@ class SandboxConfigManager:
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
@trace_method
def list_sandbox_env_vars_by_key(
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticEnvVar]:
@ -276,6 +393,7 @@ class SandboxConfigManager:
return [env_var.to_pydantic() for env_var in env_vars]
@enforce_types
@trace_method
def get_sandbox_env_vars_as_dict(
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> Dict[str, str]:
@ -286,6 +404,18 @@ class SandboxConfigManager:
return result
@enforce_types
@trace_method
async def get_sandbox_env_vars_as_dict_async(
self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> Dict[str, str]:
env_vars = await self.list_sandbox_env_vars_async(sandbox_config_id, actor, after, limit)
result = {}
for env_var in env_vars:
result[env_var.key] = env_var.value
return result
@enforce_types
@trace_method
def get_sandbox_env_var_by_key_and_sandbox_config_id(
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
) -> Optional[PydanticEnvVar]:

View File

@ -1,3 +1,4 @@
import asyncio
from typing import List, Optional
from letta.orm.errors import NoResultFound
@ -9,6 +10,7 @@ from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types, printd
@ -16,25 +18,27 @@ class SourceManager:
"""Manager class to handle business logic related to Sources."""
@enforce_types
def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
@trace_method
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
# Try getting the source first by id
db_source = self.get_source_by_id(source.id, actor=actor)
db_source = await self.get_source_by_id(source.id, actor=actor)
if db_source:
return db_source
else:
with db_registry.session() as session:
async with db_registry.async_session() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
source.create(session, actor=actor)
await source.create_async(session, actor=actor)
return source.to_pydantic()
@enforce_types
def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
@trace_method
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
"""Update a source by its ID with the given SourceUpdate object."""
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
async with db_registry.async_session() as session:
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# get update dictionary
update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@ -53,18 +57,22 @@ class SourceManager:
return source.to_pydantic()
@enforce_types
def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
@trace_method
async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
"""Delete a source by its ID."""
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id)
source.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
source = await SourceModel.read_async(db_session=session, identifier=source_id)
await source.hard_delete_async(db_session=session, actor=actor)
return source.to_pydantic()
@enforce_types
def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]:
@trace_method
async def list_sources(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs
) -> List[PydanticSource]:
"""List all sources with optional pagination."""
with db_registry.session() as session:
sources = SourceModel.list(
async with db_registry.async_session() as session:
sources = await SourceModel.list_async(
db_session=session,
after=after,
limit=limit,
@ -74,18 +82,17 @@ class SourceManager:
return [source.to_pydantic() for source in sources]
@enforce_types
def size(
self,
actor: PydanticUser,
) -> int:
@trace_method
async def size(self, actor: PydanticUser) -> int:
"""
Get the total count of sources for the given user.
"""
with db_registry.session() as session:
return SourceModel.size(db_session=session, actor=actor)
async with db_registry.async_session() as session:
return await SourceModel.size_async(db_session=session, actor=actor)
@enforce_types
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
@trace_method
async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
"""
Lists all agents that have the specified source attached.
@ -96,30 +103,33 @@ class SourceManager:
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
with db_registry.session() as session:
async with db_registry.async_session() as session:
# Verify source exists and user has permission to access it
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# The agents relationship is already loaded due to lazy="selectin" in the Source model
# and will be properly filtered by organization_id due to the OrganizationMixin
return [agent.to_pydantic() for agent in source.agents]
agents_orm = source.agents
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
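`asyncio.gather` fans the per-agent conversions out concurrently and returns results in input order, so the agent list keeps its ordering. The pattern in isolation:

```
import asyncio

async def to_pydantic_all(rows):
    # gather preserves the order of its awaitables in the returned list.
    return await asyncio.gather(*(row.to_pydantic_async() for row in rows))
```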
# TODO: We make actor optional for now, but it should most likely be enforced for security reasons
@enforce_types
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
@trace_method
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
with db_registry.session() as session:
async with db_registry.async_session() as session:
try:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
return source.to_pydantic()
except NoResultFound:
return None
@enforce_types
def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
@trace_method
async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
"""Retrieve a source by its name."""
with db_registry.session() as session:
sources = SourceModel.list(
async with db_registry.async_session() as session:
sources = await SourceModel.list_async(
db_session=session,
name=source_name,
organization_id=actor.organization_id,
@ -131,44 +141,49 @@ class SourceManager:
return sources[0].to_pydantic()
@enforce_types
def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata:
@trace_method
async def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata:
"""Create a new file based on the PydanticFileMetadata schema."""
db_file = self.get_file_by_id(file_metadata.id, actor=actor)
db_file = await self.get_file_by_id(file_metadata.id, actor=actor)
if db_file:
return db_file
else:
with db_registry.session() as session:
async with db_registry.async_session() as session:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
file_metadata.create(session, actor=actor)
await file_metadata.create_async(session, actor=actor)
return file_metadata.to_pydantic()
# TODO: We make actor optional for now, but it should most likely be enforced for security reasons
@enforce_types
def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
@trace_method
async def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID."""
with db_registry.session() as session:
async with db_registry.async_session() as session:
try:
file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor)
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
return file.to_pydantic()
except NoResultFound:
return None
@enforce_types
def list_files(
@trace_method
async def list_files(
self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
with db_registry.session() as session:
files = FileMetadataModel.list(
async with db_registry.async_session() as session:
files = await FileMetadataModel.list_async(
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
)
return [file.to_pydantic() for file in files]
@enforce_types
def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
@trace_method
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
with db_registry.session() as session:
file = FileMetadataModel.read(db_session=session, identifier=file_id)
file.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
await file.hard_delete_async(db_session=session, actor=actor)
return file.to_pydantic()

View File

@ -14,13 +14,14 @@ from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.noop_helper import singleton
from letta.tracing import get_trace_id
from letta.tracing import get_trace_id, trace_method
from letta.utils import enforce_types
class StepManager:
@enforce_types
@trace_method
def list_steps(
self,
actor: PydanticUser,
@ -54,6 +55,7 @@ class StepManager:
return [step.to_pydantic() for step in steps]
@enforce_types
@trace_method
def log_step(
self,
actor: PydanticUser,
@ -96,6 +98,7 @@ class StepManager:
return new_step.to_pydantic()
@enforce_types
@trace_method
async def log_step_async(
self,
actor: PydanticUser,
@ -138,12 +141,14 @@ class StepManager:
return new_step.to_pydantic()
@enforce_types
@trace_method
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with db_registry.session() as session:
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
return step.to_pydantic()
@enforce_types
@trace_method
def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep:
"""Update the transaction ID for a step.
@ -236,6 +241,7 @@ class NoopStepManager(StepManager):
"""
@enforce_types
@trace_method
def log_step(
self,
actor: PydanticUser,
@ -253,6 +259,7 @@ class NoopStepManager(StepManager):
return
@enforce_types
@trace_method
async def log_step_async(
self,
actor: PydanticUser,

View File

@ -26,6 +26,7 @@ from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import trace_method
from letta.utils import enforce_types, printd
logger = get_logger(__name__)
@ -36,6 +37,7 @@ class ToolManager:
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
@enforce_types
@trace_method
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor)
@ -62,6 +64,7 @@ class ToolManager:
return tool
@enforce_types
@trace_method
async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor)
@ -88,6 +91,7 @@ class ToolManager:
return tool
@enforce_types
@trace_method
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
return self.create_or_update_tool(
@ -98,18 +102,21 @@ class ToolManager:
)
@enforce_types
@trace_method
def create_or_update_composio_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
return self.create_or_update_tool(
PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
)
@enforce_types
@trace_method
def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
return self.create_or_update_tool(
PydanticTool(tool_type=ToolType.EXTERNAL_LANGCHAIN, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor
)
@enforce_types
@trace_method
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
with db_registry.session() as session:
@ -125,6 +132,7 @@ class ToolManager:
return tool.to_pydantic()
@enforce_types
@trace_method
async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
async with db_registry.async_session() as session:
@ -140,6 +148,7 @@ class ToolManager:
return tool.to_pydantic()
@enforce_types
@trace_method
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
with db_registry.session() as session:
@ -149,6 +158,7 @@ class ToolManager:
return tool.to_pydantic()
@enforce_types
@trace_method
async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
async with db_registry.async_session() as session:
@ -158,6 +168,7 @@ class ToolManager:
return tool.to_pydantic()
@enforce_types
@trace_method
def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
@ -168,6 +179,7 @@ class ToolManager:
return None
@enforce_types
@trace_method
async def get_tool_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
@ -178,6 +190,7 @@ class ToolManager:
return None
@enforce_types
@trace_method
def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
@ -188,6 +201,7 @@ class ToolManager:
return None
@enforce_types
@trace_method
async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
@ -198,6 +212,7 @@ class ToolManager:
return None
@enforce_types
@trace_method
async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
async with db_registry.async_session() as session:
@ -223,6 +238,7 @@ class ToolManager:
return results
@enforce_types
@trace_method
def size(
self,
actor: PydanticUser,
@ -239,6 +255,7 @@ class ToolManager:
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
@enforce_types
@trace_method
def update_tool_by_id(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
@ -267,6 +284,7 @@ class ToolManager:
return tool.update(db_session=session, actor=actor).to_pydantic()
@enforce_types
@trace_method
async def update_tool_by_id_async(
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
@ -296,6 +314,7 @@ class ToolManager:
return tool.to_pydantic()
@enforce_types
@trace_method
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
with db_registry.session() as session:
@ -306,6 +325,7 @@ class ToolManager:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
@trace_method
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py and multi_agent.py"""
functions_to_schema = {}
@ -371,6 +391,7 @@ class ToolManager:
return tools
@enforce_types
@trace_method
async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py and multi_agent.py"""
functions_to_schema = {}

View File

@ -53,7 +53,9 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user)
sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.E2B, actor=self.user
)
# TODO: So this defaults to force recreating always
# TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically
e2b_sandbox = await self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config)
@ -71,7 +73,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
if self.provided_sandbox_env_vars:
env_vars.update(self.provided_sandbox_env_vars)
else:
db_env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(
db_env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
sandbox_config_id=sbx_config.id, actor=self.user, limit=100
)
env_vars.update(db_env_vars)
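Both sandbox executors now resolve environment variables with the same precedence: explicitly provided variables win outright; otherwise the values stored for the sandbox config are loaded (capped at 100, matching the call above). A condensed sketch:

```
async def resolve_env_vars(provided, manager, sbx_config, user) -> dict:
    env_vars = {}
    if provided:
        env_vars.update(provided)  # caller-supplied values take precedence
    else:
        env_vars.update(
            await manager.get_sandbox_env_vars_as_dict_async(
                sandbox_config_id=sbx_config.id, actor=user, limit=100
            )
        )
    return env_vars
```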

View File

@ -60,14 +60,16 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
additional_env_vars: Optional[Dict],
) -> ToolExecutionResult:
"""
Unified asynchronougit pus method to run the tool in a local sandbox environment,
Unified asynchronous method to run the tool in a local sandbox environment,
always via subprocess for multi-core parallelism.
"""
# Get sandbox configuration
if self.provided_sandbox_config:
sbx_config = self.provided_sandbox_config
else:
sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user)
sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async(
sandbox_type=SandboxType.LOCAL, actor=self.user
)
local_configs = sbx_config.get_local_config()
use_venv = local_configs.use_venv
@ -76,7 +78,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
if self.provided_sandbox_env_vars:
env.update(self.provided_sandbox_env_vars)
else:
env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100)
env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async(
sandbox_config_id=sbx_config.id, actor=self.user, limit=100
)
env.update(env_vars)
if agent_state:

View File

@ -7,6 +7,7 @@ from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
from letta.tracing import trace_method
from letta.utils import enforce_types
@ -17,6 +18,7 @@ class UserManager:
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
@enforce_types
@trace_method
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
with db_registry.session() as session:
@ -37,6 +39,7 @@ class UserManager:
return user.to_pydantic()
@enforce_types
@trace_method
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist."""
with db_registry.session() as session:
@ -45,6 +48,7 @@ class UserManager:
return new_user.to_pydantic()
@enforce_types
@trace_method
async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist (async version)."""
async with db_registry.async_session() as session:
@ -53,6 +57,7 @@ class UserManager:
return new_user.to_pydantic()
@enforce_types
@trace_method
def update_user(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details."""
with db_registry.session() as session:
@ -69,6 +74,7 @@ class UserManager:
return existing_user.to_pydantic()
@enforce_types
@trace_method
async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details (async version)."""
async with db_registry.async_session() as session:
@ -85,6 +91,7 @@ class UserManager:
return existing_user.to_pydantic()
@enforce_types
@trace_method
def delete_user_by_id(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings)."""
with db_registry.session() as session:
@ -95,6 +102,7 @@ class UserManager:
session.commit()
@enforce_types
@trace_method
async def delete_actor_by_id_async(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings) asynchronously."""
async with db_registry.async_session() as session:
@ -103,6 +111,7 @@ class UserManager:
await user.hard_delete_async(session)
@enforce_types
@trace_method
def get_user_by_id(self, user_id: str) -> PydanticUser:
"""Fetch a user by ID."""
with db_registry.session() as session:
@ -110,6 +119,7 @@ class UserManager:
return user.to_pydantic()
@enforce_types
@trace_method
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
"""Fetch a user by ID asynchronously."""
async with db_registry.async_session() as session:
@ -117,6 +127,7 @@ class UserManager:
return user.to_pydantic()
@enforce_types
@trace_method
def get_default_user(self) -> PydanticUser:
"""Fetch the default user. If it doesn't exist, create it."""
try:
@ -125,6 +136,7 @@ class UserManager:
return self.create_default_user()
@enforce_types
@trace_method
def get_user_or_default(self, user_id: Optional[str] = None):
"""Fetch the user or default user."""
if not user_id:
@ -136,6 +148,7 @@ class UserManager:
return self.get_default_user()
@enforce_types
@trace_method
async def get_default_actor_async(self) -> PydanticUser:
"""Fetch the default user asynchronously. If it doesn't exist, create it."""
try:
@ -145,6 +158,7 @@ class UserManager:
return self.create_default_user(org_id=self.DEFAULT_ORG_ID)
@enforce_types
@trace_method
async def get_actor_or_default_async(self, actor_id: Optional[str] = None):
"""Fetch the user or default user asynchronously."""
if not actor_id:
@ -156,6 +170,7 @@ class UserManager:
return await self.get_default_actor_async()
@enforce_types
@trace_method
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination."""
with db_registry.session() as session:
@ -167,6 +182,7 @@ class UserManager:
return [user.to_pydantic() for user in users]
@enforce_types
@trace_method
async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination (async version)."""
async with db_registry.async_session() as session:

View File

@ -1,6 +1,6 @@
[tool.poetry]
name = "letta"
version = "0.7.21"
version = "0.7.22"
packages = [
{include = "letta"},
]

View File

@ -1 +1,3 @@
TIMEOUT = 30 # seconds
embedding_config_dir = "tests/configs/embedding_model_configs"
llm_config_dir = "tests/configs/llm_model_configs"

View File

@ -1,13 +1,12 @@
import time
from typing import Union
from letta import LocalClient, RESTClient
from letta import RESTClient
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.source import Source
def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job:
def upload_file_using_client(client: RESTClient, source: Source, filename: str) -> Job:
# load a file into a source (non-blocking job)
upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False)
print("Upload job", upload_job, upload_job.status, upload_job.metadata)

View File

@ -1,33 +1,28 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Sequence, Union
from typing import Callable, List, Optional, Sequence
from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
from letta.schemas.block import CreateBlock
from letta.schemas.tool_rule import BaseToolRule
from letta.server.server import SyncServer
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from letta import LocalClient, RESTClient, create_client
from letta.agent import Agent
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.embeddings import embedding_model
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
from letta.helpers.json_helpers import json_dumps
from letta.llm_api.llm_api_tools import create
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message
from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
from letta.utils import get_human_text, get_persona_text
from tests.helpers.utils import cleanup
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
@ -45,7 +40,7 @@ LLM_CONFIG_PATH = "tests/configs/llm_model_configs/letta-hosted.json"
def setup_agent(
client: Union[LocalClient, RESTClient],
server: SyncServer,
filename: str,
memory_human_str: str = get_human_text(DEFAULT_HUMAN),
memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
@ -65,17 +60,27 @@ def setup_agent(
config.default_embedding_config = embedding_config
config.save()
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
agent_state = client.create_agent(
request = CreateAgent(
name=agent_uuid,
llm_config=llm_config,
embedding_config=embedding_config,
memory=memory,
memory_blocks=[
CreateBlock(
label="human",
value=memory_human_str,
),
CreateBlock(
label="persona",
value=memory_persona_str,
),
],
tool_ids=tool_ids,
tool_rules=tool_rules,
include_base_tools=include_base_tools,
include_base_tool_rules=include_base_tool_rules,
)
actor = server.user_manager.get_user_or_default()
agent_state = server.create_agent(request=request, actor=actor)
return agent_state
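# Migration pattern: instead of client.create_agent with a ChatMemory object,
# tests now build a CreateAgent request (memory expressed as CreateBlock
# entries) and hand it to server.create_agent along with an explicit actor
# resolved via user_manager.get_user_or_default().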
@ -86,285 +91,6 @@ def setup_agent(
# ======================================================================================================================
def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner_monologue_contents: bool = True) -> ChatCompletionResponse:
"""
Checks that the first response is valid:
1. Contains send_message, archival_memory_search, or core_memory_append
2. Contains valid usage of the function
3. Contains inner monologue
Note: this acts on the raw LLM response; note the usage of `create`
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
full_agent_state = client.get_agent(agent_state.id)
messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user)
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
actor=client.user,
)
if llm_client:
response = llm_client.send_llm_request(
messages=messages,
llm_config=agent_state.llm_config,
tools=[t.json_schema for t in agent.agent_state.tools],
)
else:
response = create(
llm_config=agent_state.llm_config,
user_id=str(uuid.UUID(int=1)), # dummy user_id
messages=messages,
functions=[t.json_schema for t in agent.agent_state.tools],
)
# Basic check
assert response is not None, response
assert response.choices is not None, response
assert len(response.choices) > 0, response
assert response.choices[0] is not None, response
# Select first choice
choice = response.choices[0]
# Ensure that the first message returns a "send_message"
validator_func = (
lambda function_call: function_call.name == "send_message"
or function_call.name == "archival_memory_search"
or function_call.name == "core_memory_append"
)
assert_contains_valid_function_call(choice.message, validator_func)
# Assert that the message has an inner monologue
assert_contains_correct_inner_monologue(
choice,
agent_state.llm_config.put_inner_thoughts_in_kwargs,
validate_inner_monologue_contents=validate_inner_monologue_contents,
)
return response
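# LLMClient.create may be falsy for providers without a native client
# implementation, in which case the test falls back to the legacy create()
# helper; both paths receive the same tool JSON schemas.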
def check_response_contains_keyword(filename: str, keyword="banana") -> LettaResponse:
"""
Checks that the prompted response from the LLM contains a chosen keyword
Note: this acts on the Letta response; note the usage of `user_message`
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"'
response = client.user_message(agent_id=agent_state.id, message=keyword_message)
# Basic checks
assert_sanity_checks(response)
# Make sure the message was sent
assert_invoked_send_message_with_keyword(response.messages, keyword)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_uses_external_tool(filename: str) -> LettaResponse:
"""
Checks that the LLM will use external tools if instructed
Note: this acts on the Letta response; note the usage of `user_message`
"""
from composio import Action
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# Set up persona for tool usage
persona = f"""
My name is Letta.
I am a personal assistant who uses a tool called {tool.name} to star a desired github repo.
Don't forget - inner monologue / inner thoughts should always be different from the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
"""
agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id])
response = client.user_message(agent_id=agent_state.id, message="Please star the repo with owner=letta-ai and repo=letta")
# Basic checks
assert_sanity_checks(response)
# Make sure the tool was called
assert_invoked_function_call(response.messages, tool.name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_recall_chat_memory(filename: str) -> LettaResponse:
"""
Checks that the LLM will recall the chat memory, specifically the human persona.
Note: this acts on the Letta response; note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
human_name = "BananaBoy"
agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name}.")
response = client.user_message(
agent_id=agent_state.id, message="Repeat my name back to me. You should search in your human memory block."
)
# Basic checks
assert_sanity_checks(response)
# Make sure my name was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, human_name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_archival_memory_insert(filename: str) -> LettaResponse:
"""
Checks that the LLM will execute an archival memory insert.
Note: this acts on the Letta response; note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
secret_word = "banana"
response = client.user_message(
agent_id=agent_state.id,
message=f"Please insert the secret word '{secret_word}' into archival memory.",
)
# Basic checks
assert_sanity_checks(response)
# Make sure archival_memory_insert was called
assert_invoked_function_call(response.messages, "archival_memory_insert")
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse:
"""
Checks that the LLM will execute an archival memory retrieval.
Note: this acts on the Letta response; note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
secret_word = "banana"
client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!")
response = client.user_message(
agent_id=agent_state.id,
message="Search archival memory for the secret word. If you find it successfully, you MUST respond by using the `send_message` function with a message that includes the secret word so I know you found it.",
)
# Basic checks
assert_sanity_checks(response)
# Make sure archival_memory_search was called
assert_invoked_function_call(response.messages, "archival_memory_search")
# Make sure secret was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, secret_word)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_edit_core_memory(filename: str) -> LettaResponse:
"""
Checks that the LLM is able to edit its core memories
Note: this acts on the Letta response; note the usage of `user_message`
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
human_name_a = "AngryAardvark"
human_name_b = "BananaBoy"
agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name_a}")
client.user_message(agent_id=agent_state.id, message=f"Actually, my name changed. It is now {human_name_b}")
response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.")
# Basic checks
assert_sanity_checks(response)
# Make sure my name was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, human_name_b)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def check_agent_summarize_memory_simple(filename: str) -> LettaResponse:
"""
Checks that the LLM is able to summarize its memory
"""
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
# Send a couple messages
friend_name = "Shub"
client.user_message(agent_id=agent_state.id, message="Hey, how's it going? What do you think about this whole shindig")
client.user_message(agent_id=agent_state.id, message=f"By the way, my friend's name is {friend_name}!")
client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?")
# Summarize
agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
agent.summarize_messages_inplace()
print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n")
response = client.user_message(agent_id=agent_state.id, message="What is my friend's name?")
# Basic checks
assert_sanity_checks(response)
# Make sure my name was repeated back to me
assert_invoked_send_message_with_keyword(response.messages, friend_name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
return response
def run_embedding_endpoint(filename):
# load JSON file
config_data = json.load(open(filename, "r"))

View File

@ -2,12 +2,13 @@ import functools
import time
from typing import Union
from letta import LocalClient, RESTClient
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.tool import Tool
from letta.schemas.user import User
from letta.schemas.user import User as PydanticUser
from letta.server.server import SyncServer
def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
@ -75,12 +76,12 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4):
return decorator_retry
def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str):
def cleanup(server: SyncServer, agent_uuid: str, actor: User):
# Clear all agents
for agent_state in client.list_agents():
if agent_state.name == agent_uuid:
client.delete_agent(agent_id=agent_state.id)
print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}")
agent_states = server.agent_manager.list_agents(name=agent_uuid, actor=actor)
for agent_state in agent_states:
server.agent_manager.delete_agent(agent_id=agent_state.id, actor=actor)
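# cleanup now goes through the server's AgentManager directly: list agents by
# exact name for the given actor and delete each match, instead of scanning
# client.list_agents() for a name match.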
# Utility functions

View File

@ -3,16 +3,20 @@ import uuid
import pytest
from letta import create_client
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from letta.server.server import SyncServer
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
assert_sanity_checks,
setup_agent,
)
from tests.helpers.utils import cleanup, retry_until_success
from tests.helpers.utils import cleanup
from tests.utils import create_tool_from_func
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
@ -20,106 +24,175 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
"""Contrived tools for this test case"""
@pytest.fixture()
def server():
config = LettaConfig.load()
config.save()
server = SyncServer()
return server
def first_secret_word():
"""
Call this to retrieve the first secret word, which you will need for the second_secret_word function.
"""
return "v0iq020i0g"
@pytest.fixture(scope="function")
def first_secret_tool(server):
def first_secret_word():
"""
Retrieves the initial secret word in a multi-step sequence.
Returns:
str: The first secret word.
"""
return "v0iq020i0g"
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=first_secret_word), actor=actor)
yield tool
def second_secret_word(prev_secret_word: str):
"""
Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.
@pytest.fixture(scope="function")
def second_secret_tool(server):
def second_secret_word(prev_secret_word: str):
"""
Retrieves the second secret word.
Args:
prev_secret_word (str): The secret word retrieved from calling first_secret_word.
"""
if prev_secret_word != "v0iq020i0g":
raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}")
Args:
prev_secret_word (str): The previously retrieved secret word.
return "4rwp2b4gxq"
Returns:
str: The second secret word.
"""
if prev_secret_word != "v0iq020i0g":
raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}")
return "4rwp2b4gxq"
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=second_secret_word), actor=actor)
yield tool
def third_secret_word(prev_secret_word: str):
"""
Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.
@pytest.fixture(scope="function")
def third_secret_tool(server):
def third_secret_word(prev_secret_word: str):
"""
Retrieves the third secret word.
Args:
prev_secret_word (str): The secret word retrieved from calling second_secret_word.
"""
if prev_secret_word != "4rwp2b4gxq":
raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}')
Args:
prev_secret_word (str): The previously retrieved secret word.
return "hj2hwibbqm"
Returns:
str: The third secret word.
"""
if prev_secret_word != "4rwp2b4gxq":
raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}')
return "hj2hwibbqm"
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=third_secret_word), actor=actor)
yield tool
def fourth_secret_word(prev_secret_word: str):
"""
Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.
@pytest.fixture(scope="function")
def fourth_secret_tool(server):
def fourth_secret_word(prev_secret_word: str):
"""
Retrieves the final secret word.
Args:
prev_secret_word (str): The secret word retrieved from calling third_secret_word.
"""
if prev_secret_word != "hj2hwibbqm":
raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}")
Args:
prev_secret_word (str): The previously retrieved secret word.
return "banana"
Returns:
str: The final secret word.
"""
if prev_secret_word != "hj2hwibbqm":
raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}")
return "banana"
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=fourth_secret_word), actor=actor)
yield tool
def flip_coin():
"""
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
If it returns an empty string, try flipping again!
@pytest.fixture(scope="function")
def flip_coin_tool(server):
def flip_coin():
"""
Simulates a coin flip with a chance to return a secret word.
Returns:
str: The password or an empty string
"""
import random
Returns:
str: A secret word or an empty string.
"""
import random
# Flip a coin with 50% chance
if random.random() < 0.5:
return ""
return "hj2hwibbqm"
return "" if random.random() < 0.5 else "hj2hwibbqm"
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=flip_coin), actor=actor)
yield tool
def can_play_game():
"""
Call this to start the tool chain.
"""
import random
@pytest.fixture(scope="function")
def can_play_game_tool(server):
def can_play_game():
"""
Determines whether a game can be played.
return random.random() < 0.5
Returns:
bool: True if allowed to play, False otherwise.
"""
import random
return random.random() < 0.5
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=can_play_game), actor=actor)
yield tool
def return_none():
"""
Really simple function
"""
return None
@pytest.fixture(scope="function")
def return_none_tool(server):
def return_none():
"""
Always returns None.
Returns:
None
"""
return None
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=return_none), actor=actor)
yield tool
def auto_error():
"""
If you call this function, it will throw an error automatically.
"""
raise RuntimeError("This should never be called.")
@pytest.fixture(scope="function")
def auto_error_tool(server):
def auto_error():
"""
Always raises an error when called.
Raises:
RuntimeError: Always triggered.
"""
raise RuntimeError("This should never be called.")
actor = server.user_manager.get_user_or_default()
tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=auto_error), actor=actor)
yield tool
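# Each tool fixture above follows the same shape: define the function inline,
# register it via server.tool_manager.create_or_update_tool for the default
# actor, and yield the resulting Tool so tests can use tool.id and tool.name.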
@pytest.fixture
def default_user(server):
yield server.user_manager.get_user_or_default()
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
def test_single_path_agent_tool_call_graph(disable_e2b_api_key):
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
def test_single_path_agent_tool_call_graph(
server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user
):
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
# Add tools
t1 = client.create_or_update_tool(first_secret_word)
t2 = client.create_or_update_tool(second_secret_word)
t3 = client.create_or_update_tool(third_secret_word)
t4 = client.create_or_update_tool(fourth_secret_word)
t_err = client.create_or_update_tool(auto_error)
tools = [t1, t2, t3, t4, t_err]
tools = [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool]
# Make tool rules
tool_rules = [
@ -132,8 +205,18 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key):
]
# Make agent state
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
usage_stats = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
)
messages = [message for step_messages in usage_stats.steps_messages for message in step_messages]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
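# send_messages returns usage stats whose steps_messages holds one Message list
# per agent step; flattening them and converting via to_letta_messages()
# reconstructs a LettaResponse so the existing assertion helpers still apply.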
# Make checks
assert_sanity_checks(response)
@ -145,7 +228,7 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key):
assert_invoked_function_call(response.messages, "fourth_secret_word")
# Check ordering of tool calls
tool_names = [t.name for t in [t1, t2, t3, t4]]
tool_names = [t.name for t in [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool]]
tool_names += ["send_message"]
for m in response.messages:
if isinstance(m, ToolCallMessage):
@ -159,171 +242,281 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key):
assert_invoked_send_message_with_keyword(response.messages, "banana")
print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
def test_check_tool_rules_with_different_models(disable_e2b_api_key):
"""Test that tool rules are properly checked for different model configurations."""
client = create_client()
config_files = [
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"config_file",
[
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
]
],
)
@pytest.mark.parametrize("init_tools_case", ["single", "multiple"])
def test_check_tool_rules_with_different_models_parametrized(
server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, default_user, config_file, init_tools_case
):
"""Test that tool rules are properly validated across model configurations and init tool scenarios."""
agent_uuid = str(uuid.uuid4())
# Create two test tools
t1_name = "first_secret_word"
t2_name = "second_secret_word"
t1 = client.create_or_update_tool(first_secret_word)
t2 = client.create_or_update_tool(second_secret_word)
tool_rules = [InitToolRule(tool_name=t1_name), InitToolRule(tool_name=t2_name)]
tools = [t1, t2]
if init_tools_case == "multiple":
tools = [first_secret_tool, second_secret_tool]
tool_rules = [
InitToolRule(tool_name=first_secret_tool.name),
InitToolRule(tool_name=second_secret_tool.name),
]
else: # "single"
tools = [third_secret_tool]
tool_rules = [InitToolRule(tool_name=third_secret_tool.name)]
for config_file in config_files:
# Setup tools
agent_uuid = str(uuid.uuid4())
if "gpt-4o" in config_file:
# Structured output model (should work with multiple init tools)
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
assert agent_state is not None
else:
# Non-structured output model (should raise error with multiple init tools)
with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"):
setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Cleanup
cleanup(client=client, agent_uuid=agent_uuid)
# Create tool rule with single initial tool
t3_name = "third_secret_word"
t3 = client.create_or_update_tool(third_secret_word)
tool_rules = [InitToolRule(tool_name=t3_name)]
tools = [t3]
for config_file in config_files:
agent_uuid = str(uuid.uuid4())
# Structured output model (should work with single init tool)
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
if "gpt-4o" in config_file or init_tools_case == "single":
# Should succeed
agent_state = setup_agent(
server,
config_file,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules,
)
assert agent_state is not None
else:
# Non-structured model with multiple init tools should fail
with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"):
setup_agent(
server,
config_file,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules,
)
cleanup(client=client, agent_uuid=agent_uuid)
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
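# The two stacked parametrize decorators run the cross-product (3 configs x 2
# init-tool cases); only gpt-4o, as a structured-output model, is expected to
# accept multiple InitToolRules.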
def test_claude_initial_tool_rule_enforced(disable_e2b_api_key):
"""Test that the initial tool rule is enforced for the first message."""
client = create_client()
# Create tool rules that require tool_a to be called first
t1_name = "first_secret_word"
t2_name = "second_secret_word"
t1 = client.create_or_update_tool(first_secret_word)
t2 = client.create_or_update_tool(second_secret_word)
@pytest.mark.timeout(180)
def test_claude_initial_tool_rule_enforced(
server,
disable_e2b_api_key,
first_secret_tool,
second_secret_tool,
default_user,
):
"""Test that the initial tool rule is enforced for the first message using Claude model."""
tool_rules = [
InitToolRule(tool_name=t1_name),
ChildToolRule(tool_name=t1_name, children=[t2_name]),
TerminalToolRule(tool_name=t2_name),
InitToolRule(tool_name=first_secret_tool.name),
ChildToolRule(tool_name=first_secret_tool.name, children=[second_secret_tool.name]),
TerminalToolRule(tool_name=second_secret_tool.name),
]
tools = [t1, t2]
# Make agent state
tools = [first_secret_tool, second_secret_tool]
anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
for i in range(3):
agent_uuid = str(uuid.uuid4())
agent_state = setup_agent(
client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules
server,
anthropic_config_file,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules,
)
response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?")
usage_stats = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the second secret word?")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
assert_sanity_checks(response)
messages = response.messages
assert_invoked_function_call(messages, "first_secret_word")
assert_invoked_function_call(messages, "second_secret_word")
# Check that the expected tools were invoked
assert_invoked_function_call(response.messages, "first_secret_word")
assert_invoked_function_call(response.messages, "second_secret_word")
tool_names = [t.name for t in [t1, t2]]
tool_names += ["send_message"]
for m in messages:
tool_names = [t.name for t in [first_secret_tool, second_secret_tool]] + ["send_message"]
for m in response.messages:
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.tool_call.name == tool_names[0]
# Pop out first one
tool_names = tool_names[1:]
print(f"Passed iteration {i}")
cleanup(client=client, agent_uuid=agent_uuid)
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
# Implement exponential backoff with initial time of 10 seconds
# Exponential backoff
if i < 2:
backoff_time = 10 * (2**i)
time.sleep(backoff_time)
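# Backoff doubles from a 10-second base (10s after the first iteration, 20s
# after the second); the final iteration does not sleep.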
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
@pytest.mark.timeout(60)
@pytest.mark.parametrize(
"config_file",
[
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
],
)
def test_agent_no_structured_output_with_one_child_tool_parametrized(
server,
disable_e2b_api_key,
default_user,
config_file,
):
"""Test that agent correctly calls tool chains with unstructured output under various model configs."""
send_message = server.tool_manager.get_tool_by_name(tool_name="send_message", actor=default_user)
archival_memory_search = server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=default_user)
archival_memory_insert = server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=default_user)
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user)
archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user)
tools = [send_message, archival_memory_search, archival_memory_insert]
# Make tool rules
tool_rules = [
InitToolRule(tool_name="archival_memory_search"),
ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]),
ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]),
TerminalToolRule(tool_name="send_message"),
]
tools = [send_message, archival_memory_search, archival_memory_insert]
config_files = [
"tests/configs/llm_model_configs/claude-3-5-sonnet.json",
"tests/configs/llm_model_configs/openai-gpt-4o.json",
max_retries = 3
last_error = None
agent_uuid = str(uuid.uuid4())
for attempt in range(max_retries):
try:
agent_state = setup_agent(
server,
config_file,
agent_uuid=agent_uuid,
tool_ids=[t.id for t in tools],
tool_rules=tool_rules,
)
usage_stats = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="hi. run archival memory search")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(messages=letta_messages, usage=usage_stats)
# Run assertions
assert_sanity_checks(response)
assert_invoked_function_call(response.messages, "archival_memory_search")
assert_invoked_function_call(response.messages, "archival_memory_insert")
assert_invoked_function_call(response.messages, "send_message")
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
for m in response.messages:
if isinstance(m, ToolCallMessage):
assert m.tool_call.name == tool_names[0]
tool_names = tool_names[1:]
print(f"[{config_file}] Got successful response:\n\n{response}")
break # success
except AssertionError as e:
last_error = e
print(f"[{config_file}] Attempt {attempt + 1} failed")
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
if last_error:
raise last_error
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
@pytest.mark.timeout(30)
@pytest.mark.parametrize("include_base_tools", [False, True])
def test_init_tool_rule_always_fails(
server,
disable_e2b_api_key,
auto_error_tool,
default_user,
include_base_tools,
):
"""Test behavior when InitToolRule invokes a tool that always fails."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
tool_rule = InitToolRule(tool_name=auto_error_tool.name)
agent_state = setup_agent(
server,
config_file,
agent_uuid=agent_uuid,
tool_ids=[auto_error_tool.id],
tool_rules=[tool_rule],
include_base_tools=include_base_tools,
)
usage_stats = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="blah blah blah")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(messages=letta_messages, usage=usage_stats)
assert_invoked_function_call(response.messages, auto_error_tool.name)
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
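# This single parametrized test replaces the two legacy variants further down
# (one-tool vs. multiple-tools): include_base_tools=False leaves the agent with
# no tools once the failing init tool is dropped, while True leaves the base
# tools available.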
def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
tool_ids = [
server.tool_manager.get_tool_by_name("send_message", actor=default_user).id,
server.tool_manager.get_tool_by_name("core_memory_append", actor=default_user).id,
]
for config in config_files:
max_retries = 3
last_error = None
tool_rules = [
ContinueToolRule(tool_name="send_message"),
TerminalToolRule(tool_name="core_memory_append"),
]
for attempt in range(max_retries):
try:
agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search")
agent_state = setup_agent(
server,
config_file,
agent_uuid,
tool_ids=tool_ids,
tool_rules=tool_rules,
include_base_tools=False,
include_base_tool_rules=False,
)
# Make checks
assert_sanity_checks(response)
usage_stats = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(messages=letta_messages, usage=usage_stats)
# Assert the tools were called
assert_invoked_function_call(response.messages, "archival_memory_search")
assert_invoked_function_call(response.messages, "archival_memory_insert")
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")
# Check ordering of tool calls
tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]]
for m in response.messages:
if isinstance(m, ToolCallMessage):
# Check that it's equal to the first one
assert m.tool_call.name == tool_names[0]
# Check order
send_idx = next(i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message")
append_idx = next(
i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "core_memory_append"
)
assert send_idx < append_idx, "send_message should occur before core_memory_append"
# Pop out first one
tool_names = tool_names[1:]
print(f"Got successful response from client: \n\n{response}")
break # Test passed, exit retry loop
except AssertionError as e:
last_error = e
print(f"Attempt {attempt + 1} failed, retrying..." if attempt < max_retries - 1 else f"All {max_retries} attempts failed")
cleanup(client=client, agent_uuid=agent_uuid)
continue
if last_error and attempt == max_retries - 1:
raise last_error # Re-raise the last error if all retries failed
cleanup(client=client, agent_uuid=agent_uuid)
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
# @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
@ -342,7 +535,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
# reveal_secret_word
# """
#
# client = create_client()
#
# cleanup(client=client, agent_uuid=agent_uuid)
#
# coin_flip_name = "flip_coin"
@ -406,7 +599,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
# v
# any tool... <-- When output doesn't match mapping, agent can call any tool
# """
# client = create_client()
#
# cleanup(client=client, agent_uuid=agent_uuid)
#
# # Create tools - we'll make several available to the agent
@ -467,7 +660,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
# v
# fourth_secret_word <-- Should remember coin flip result after reload
# """
# client = create_client()
#
# cleanup(client=client, agent_uuid=agent_uuid)
#
# # Create tools
@ -522,7 +715,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
# v
# fourth_secret_word
# """
# client = create_client()
#
# cleanup(client=client, agent_uuid=agent_uuid)
#
# # Create tools
@ -563,165 +756,3 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key):
# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
#
# cleanup(client, agent_uuid=agent_state.id)
def test_init_tool_rule_always_fails_one_tool():
"""
Test an init tool rule that always fails when called. The agent has only one tool available.
Once that tool fails and the agent removes that tool, the agent should have 0 tools available.
This means that the agent should return from `step` early.
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
# Create tools
bad_tool = client.create_or_update_tool(auto_error)
# Create tool rule: InitToolRule
tool_rule = InitToolRule(
tool_name=bad_tool.name,
)
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False)
# Start conversation
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
# Verify the tool calls
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1 # Should have at least the failing init tool call
assert_invoked_function_call(response.messages, bad_tool.name)
def test_init_tool_rule_always_fails_multiple_tools():
"""
Test an init tool rule that always fails when called. The agent has only 1+ tools available.
Once that tool fails and the agent removes that tool, the agent should have other tools available.
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
# Create tools
bad_tool = client.create_or_update_tool(auto_error)
# Create tool rule: InitToolRule
tool_rule = InitToolRule(
tool_name=bad_tool.name,
)
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True)
# Start conversation
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
# Verify the tool calls
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1 # Should have at least the failing init tool call
assert_invoked_function_call(response.messages, bad_tool.name)
def test_continue_tool_rule():
"""Test the continue tool rule by forcing the send_message tool to continue"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
continue_tool_rule = ContinueToolRule(
tool_name="send_message",
)
terminal_tool_rule = TerminalToolRule(
tool_name="core_memory_append",
)
rules = [continue_tool_rule, terminal_tool_rule]
core_memory_append_tool = client.get_tool_id("core_memory_append")
send_message_tool = client.get_tool_id("send_message")
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(
client,
claude_config,
agent_uuid,
tool_rules=rules,
tool_ids=[core_memory_append_tool, send_message_tool],
include_base_tools=False,
include_base_tool_rules=False,
)
# Start conversation
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
# Verify the tool calls
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")
# ensure send_message called before core_memory_append
send_message_call_index = None
core_memory_append_call_index = None
for i, call in enumerate(tool_calls):
if call.tool_call.name == "send_message":
send_message_call_index = i
if call.tool_call.name == "core_memory_append":
core_memory_append_call_index = i
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"
@pytest.mark.timeout(60)
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
def test_max_count_per_step_tool_rule_integration(disable_e2b_api_key):
"""
Test an agent with MaxCountPerStepToolRule to ensure a tool can only be called a limited number of times.
Tool Flow:
repeatable_tool (max 2 times)
|
v
send_message
"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
# Create tools
repeatable_tool_name = "first_secret_word"
final_tool_name = "send_message"
repeatable_tool = client.create_or_update_tool(first_secret_word)
send_message_tool = client.get_tool(client.get_tool_id(final_tool_name)) # Assume send_message is a default tool
# Define tool rules
tool_rules = [
InitToolRule(tool_name=repeatable_tool_name),
MaxCountPerStepToolRule(tool_name=repeatable_tool_name, max_count_limit=2),
TerminalToolRule(tool_name=final_tool_name),
]
tools = [repeatable_tool, send_message_tool]
# Setup agent
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# Start conversation
response = client.user_message(
agent_id=agent_state.id, message=f"Keep calling {repeatable_tool_name} nonstop without calling ANY other tool."
)
# Make checks
assert_sanity_checks(response)
# Ensure the repeatable tool is only called twice
count = sum(1 for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == repeatable_tool_name)
assert count == 2, f"Expected 'first_secret_word' to be called exactly 2 times, but got {count}"
# Ensure send_message was eventually called
assert_invoked_function_call(response.messages, final_tool_name)
print(f"Got successful response from client: \n\n{response}")
cleanup(client=client, agent_uuid=agent_uuid)
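# MaxCountPerStepToolRule caps how many times a tool may fire within a single
# step: here first_secret_word is limited to two calls before control must pass
# to the terminal send_message.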

View File

@ -1,3 +1,4 @@
import asyncio
import secrets
import string
import uuid
@ -7,17 +8,16 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from letta import create_client
from letta.config import LettaConfig
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.organization import Organization
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_manager import ToolManager
@ -33,6 +33,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
# Fixtures
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture(autouse=True)
def clear_tables():
"""Fixture to clear the organization table before each test."""
@ -192,12 +207,26 @@ def external_codebase_tool(test_user):
@pytest.fixture
def agent_state():
client = create_client()
agent_state = client.create_agent(
memory=ChatMemory(persona="This is the persona", human="My name is Chad"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
def agent_state(server):
actor = server.user_manager.get_user_or_default()
agent_state = server.create_agent(
CreateAgent(
memory_blocks=[
CreateBlock(
label="human",
value="username: sarah",
),
CreateBlock(
label="persona",
value="This is the persona",
),
],
include_base_tools=True,
model="openai/gpt-4o-mini",
tags=["test_agents"],
embedding="letta/letta-free",
),
actor=actor,
)
agent_state.tool_rules = []
yield agent_state
@ -248,12 +277,20 @@ def core_memory_tools(test_user):
yield tools
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
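# A session-scoped event_loop lets all async tests share one loop (rather than
# the default per-function loop); each async test below now requests the
# fixture explicitly.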
# Local sandbox tests
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user):
async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user, event_loop):
args = {"x": 10, "y": 5}
# Mock and assert correct pathway was invoked
@ -270,7 +307,7 @@ async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, tes
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state):
async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state, event_loop):
args = {}
sandbox = AsyncToolSandboxLocal(clear_core_memory_tool.name, args, user=test_user)
result = await sandbox.run(agent_state=agent_state)
@ -282,7 +319,7 @@ async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memor
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user):
async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user, event_loop):
sandbox = AsyncToolSandboxLocal(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5
@ -290,7 +327,7 @@ async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_u
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, event_loop):
manager = SandboxConfigManager()
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump())
@ -309,7 +346,7 @@ async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user):
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user):
async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user, event_loop):
manager = SandboxConfigManager()
key = "secret_word"
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
@ -331,7 +368,7 @@ async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, ag
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_external_codebase_with_venv(
disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user
disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user, event_loop
):
args = {"percentage": 10}
sandbox = AsyncToolSandboxLocal(external_codebase_tool.name, args, user=test_user)
@ -343,7 +380,7 @@ async def test_local_sandbox_external_codebase_with_venv(
@pytest.mark.asyncio
@pytest.mark.local_sandbox
async def test_local_sandbox_with_venv_and_warnings_does_not_error(
disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user
disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user, event_loop
):
sandbox = AsyncToolSandboxLocal(get_warning_tool.name, {}, user=test_user)
result = await sandbox.run()
@ -352,7 +389,7 @@ async def test_local_sandbox_with_venv_and_warnings_does_not_error(
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user):
async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user, event_loop):
sandbox = AsyncToolSandboxLocal(always_err_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.stdout) != 0
@ -363,7 +400,7 @@ async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_s
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user):
async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(
config=LocalSandboxConfig(use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump()
@ -383,7 +420,7 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user):
async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=LocalSandboxConfig(use_venv=True).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@ -414,7 +451,7 @@ async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user):
async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user, event_loop):
args = {"x": 10, "y": 5}
# Mock and assert correct pathway was invoked
@ -431,7 +468,7 @@ async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user):
async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user, event_loop):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@ -451,7 +488,7 @@ async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state):
async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state, event_loop):
sandbox = AsyncToolSandboxE2B(clear_core_memory_tool.name, {}, user=test_user)
result = await sandbox.run(agent_state=agent_state)
assert result.agent_state.memory.get_block("human").value == ""
@ -461,7 +498,7 @@ async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user):
async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user, event_loop):
manager = SandboxConfigManager()
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
config = manager.create_or_update_sandbox_config(config_create, test_user)
@ -485,7 +522,7 @@ async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set,
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user):
async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user, event_loop):
manager = SandboxConfigManager()
key = "secret_word"
wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20))
@ -509,7 +546,7 @@ async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, age
@pytest.mark.asyncio
@pytest.mark.e2b_sandbox
async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user):
async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user, event_loop):
sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user)
result = await sandbox.run()
assert len(result.func_return) == 5

View File

@ -185,7 +185,7 @@ async def create_test_llm_batch_job_async(server, batch_response, default_user):
)
def create_test_batch_item(server, batch_id, agent_id, default_user):
async def create_test_batch_item(server, batch_id, agent_id, default_user):
"""Create a test batch item for the given batch and agent."""
dummy_llm_config = LLMConfig(
model="claude-3-7-sonnet-latest",
@ -201,7 +201,7 @@ def create_test_batch_item(server, batch_id, agent_id, default_user):
step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")])
)
return server.batch_manager.create_llm_batch_item(
return await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch_id,
agent_id=agent_id,
llm_config=dummy_llm_config,
@ -289,9 +289,9 @@ async def test_polling_mixed_batch_jobs(default_user, server):
job_b = await create_test_llm_batch_job_async(server, batch_b_resp, default_user)
# --- Step 3: Create batch items ---
item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user)
item_b = create_test_batch_item(server, job_b.id, agent_b.id, default_user)
item_c = create_test_batch_item(server, job_b.id, agent_c.id, default_user)
item_a = await create_test_batch_item(server, job_a.id, agent_a.id, default_user)
item_b = await create_test_batch_item(server, job_b.id, agent_b.id, default_user)
item_c = await create_test_batch_item(server, job_b.id, agent_c.id, default_user)
# --- Step 4: Mock the Anthropic client ---
mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b.id, agent_c.id)
@ -316,17 +316,17 @@ async def test_polling_mixed_batch_jobs(default_user, server):
# --- Step 7: Verify batch item status updates ---
# Item A should remain unchanged
updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
updated_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user)
assert updated_item_a.request_status == JobStatus.created
assert updated_item_a.batch_request_result is None
# Item B should be marked as completed with a successful result
updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
updated_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user)
assert updated_item_b.request_status == JobStatus.completed
assert updated_item_b.batch_request_result is not None
# Item C should be marked as failed with an error result
updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
updated_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user)
assert updated_item_c.request_status == JobStatus.failed
assert updated_item_c.batch_request_result is not None
@ -352,9 +352,9 @@ async def test_polling_mixed_batch_jobs(default_user, server):
# Refresh all objects
final_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user)
final_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user)
final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user)
final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user)
final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user)
final_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user)
final_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user)
final_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user)
# Job A should still be polling (last_polled_at should update)
assert final_job_a.status == JobStatus.running
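The hunks above apply one mechanical migration: test helpers become coroutines, each manager call moves to its `_async` variant, and every call site gains an `await`. A minimal, self-contained sketch of the pattern using a stand-in manager (not the real `BatchManager`):
import asyncio

class FakeBatchManager:
    def create_llm_batch_item(self, **kwargs):
        # Old synchronous entry point.
        return kwargs

    async def create_llm_batch_item_async(self, **kwargs):
        # New async variant that migrated helpers await.
        return self.create_llm_batch_item(**kwargs)

async def create_test_batch_item(manager, batch_id):
    # Helpers become coroutines so call sites can `await` them.
    return await manager.create_llm_batch_item_async(llm_batch_id=batch_id)

item = asyncio.run(create_test_batch_item(FakeBatchManager(), "batch-1"))
assert item["llm_batch_id"] == "batch-1"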

View File

@@ -1,579 +0,0 @@
import os
import threading
import time
import uuid
import httpx
import openai
import pytest
from dotenv import load_dotenv
from letta_client import CreateBlock, Letta, MessageCreate, TextContent
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.agents.letta_agent import LettaAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message_content import TextContent as LettaTextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate as LettaMessageCreate
from letta.schemas.tool import ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings
# --- Server Management --- #
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5) # Allow server startup time
return url
# --- Client Setup --- #
@pytest.fixture(scope="session")
def client(server_url):
"""Creates a REST client for testing."""
client = Letta(base_url=server_url)
# llm_config = LLMConfig(
# model="claude-3-7-sonnet-latest",
# model_endpoint_type="anthropic",
# model_endpoint="https://api.anthropic.com/v1",
# context_window=32000,
# handle=f"anthropic/claude-3-7-sonnet-latest",
# put_inner_thoughts_in_kwargs=True,
# max_tokens=4096,
# )
#
# client = create_client(base_url=server_url, token=None)
# client.set_default_llm_config(llm_config)
# client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture(scope="function")
def roll_dice_tool(client):
def roll_dice():
"""
Rolls a 6-sided die.
Returns:
str: The roll result.
"""
import time
time.sleep(1)
return "Rolled a 10!"
# tool = client.create_or_update_tool(func=roll_dice)
tool = client.tools.upsert_from_function(func=roll_dice)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def weather_tool(client):
def get_weather(location: str) -> str:
"""
Fetches the current weather for a given location.
Parameters:
location (str): The location to get the weather for.
Returns:
str: A formatted string describing the weather in the given location.
Raises:
RuntimeError: If the request to fetch weather data fails.
"""
import requests
url = f"https://wttr.in/{location}?format=%C+%t"
response = requests.get(url)
if response.status_code == 200:
weather_data = response.text
return f"The weather in {location} is {weather_data}."
else:
raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}")
# tool = client.create_or_update_tool(func=get_weather)
tool = client.tools.upsert_from_function(func=get_weather)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def rethink_tool(client):
def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore
"""
Re-evaluate the memory in block_name, integrating new and updated facts.
Replace outdated information with the most likely truths, avoiding redundancy with original memories.
Ensure consistency with other memory blocks.
Args:
new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block.
target_block_label (str): The name of the block to write to.
Returns:
str: None is always returned as this function does not produce a response.
"""
agent_state.memory.update_block_value(label=target_block_label, value=new_memory)
return None
tool = client.tools.upsert_from_function(func=rethink_memory)
# Yield the created tool
yield tool
@pytest.fixture(scope="function")
def composio_gmail_get_profile_tool(default_user):
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
yield tool
@pytest.fixture(scope="function")
def agent_state(client, roll_dice_tool, weather_tool, rethink_tool):
"""Creates an agent and ensures cleanup after tests."""
# llm_config = LLMConfig(
# model="claude-3-7-sonnet-latest",
# model_endpoint_type="anthropic",
# model_endpoint="https://api.anthropic.com/v1",
# context_window=32000,
# handle=f"anthropic/claude-3-7-sonnet-latest",
# put_inner_thoughts_in_kwargs=True,
# max_tokens=4096,
# )
agent_state = client.agents.create(
name=f"test_compl_{str(uuid.uuid4())[5:]}",
tool_ids=[roll_dice_tool.id, weather_tool.id, rethink_tool.id],
include_base_tools=True,
memory_blocks=[
{
"label": "human",
"value": "Name: Matt",
},
{
"label": "persona",
"value": "Friendly agent",
},
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state
client.agents.delete(agent_state.id)
@pytest.fixture(scope="function")
def openai_client(client, roll_dice_tool, weather_tool):
"""Creates an agent and ensures cleanup after tests."""
client = openai.AsyncClient(
api_key=model_settings.anthropic_api_key,
base_url="https://api.anthropic.com/v1/",
max_retries=0,
http_client=httpx.AsyncClient(
timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0),
follow_redirects=True,
limits=httpx.Limits(
max_connections=50,
max_keepalive_connections=50,
keepalive_expiry=120,
),
),
)
yield client
# --- Helper Functions --- #
def _assert_valid_chunk(chunk, idx, chunks):
"""Validates the structure of each streaming chunk."""
if isinstance(chunk, ChatCompletionChunk):
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
elif isinstance(chunk, LettaUsageStatistics):
assert chunk.completion_tokens > 0, "Completion tokens must be > 0."
assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0."
assert chunk.total_tokens > 0, "Total tokens must be > 0."
assert chunk.step_count == 1, "Step count must be 1."
elif isinstance(chunk, MessageStreamStatus):
assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status."
assert idx == len(chunks) - 1, "The last chunk must be 'done'."
else:
pytest.fail(f"Unexpected chunk type: {chunk}")
# --- Test Cases --- #
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["What is the weather today in SF?"])
async def test_new_agent_loop(disable_e2b_api_key, openai_client, agent_state, message):
actor = UserManager().get_user_or_default(user_id="asf")
agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
actor=actor,
)
response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])])
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Use your rethink tool to rethink the human memory considering Matt likes chicken."])
async def test_rethink_tool(disable_e2b_api_key, openai_client, agent_state, message):
actor = UserManager().get_user_or_default(user_id="asf")
agent = LettaAgent(
agent_id=agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
actor=actor,
)
assert "chicken" not in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value
response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])])
assert "chicken" in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value.lower()
@pytest.mark.asyncio
async def test_vertex_send_message_structured_outputs(disable_e2b_api_key, client):
original_experimental_key = settings.use_vertex_structured_outputs_experimental
settings.use_vertex_structured_outputs_experimental = True
try:
actor = UserManager().get_user_or_default(user_id="asf")
stale_agents = AgentManager().list_agents(actor=actor, limit=300)
for agent in stale_agents:
AgentManager().delete_agent(agent_id=agent.id, actor=actor)
manager_agent_state = client.agents.create(
name=f"manager",
include_base_tools=False, # change this to True to repro MALFORMED FUNCTION CALL error
tools=["send_message"],
tags=["manager"],
model="google_vertex/gemini-2.5-flash-preview-04-17",
embedding="letta/letta-free",
)
manager_agent = LettaAgent(
agent_id=manager_agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
actor=actor,
)
response = await manager_agent.step(
[
LettaMessageCreate(
role="user",
content=[
LettaTextContent(text=("Check the weather in Seattle.")),
],
),
]
)
assert len(response.messages) == 3
assert response.messages[0].message_type == "user_message"
# Shouldn't this have a reasoning message?
# assert response.messages[1].message_type == "reasoning_message"
assert response.messages[1].message_type == "assistant_message"
assert response.messages[2].message_type == "tool_return_message"
finally:
settings.use_vertex_structured_outputs_experimental = original_experimental_key
@pytest.mark.asyncio
async def test_multi_agent_broadcast(disable_e2b_api_key, client, openai_client, weather_tool):
actor = UserManager().get_user_or_default(user_id="asf")
stale_agents = AgentManager().list_agents(actor=actor, limit=300)
for agent in stale_agents:
AgentManager().delete_agent(agent_id=agent.id, actor=actor)
manager_agent_state = client.agents.create(
name=f"manager",
include_base_tools=True,
include_multi_agent_tools=True,
tags=["manager"],
model="openai/gpt-4o",
embedding="letta/letta-free",
)
manager_agent = LettaAgent(
agent_id=manager_agent_state.id,
message_manager=MessageManager(),
agent_manager=AgentManager(),
block_manager=BlockManager(),
passage_manager=PassageManager(),
actor=actor,
)
tag = "subagent"
workers = []
for idx in range(30):
workers.append(
client.agents.create(
name=f"worker_{idx}",
include_base_tools=True,
tags=[tag],
tool_ids=[weather_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
),
)
response = await manager_agent.step(
[
LettaMessageCreate(
role="user",
content=[
LettaTextContent(
text=(
"Use the `send_message_to_agents_matching_tags` tool to send a message to agents with "
"tag 'subagent' asking them to check the weather in Seattle."
)
),
],
),
]
)
def test_multi_agent_broadcast_client(client: Letta, weather_tool):
# delete any existing worker agents
workers = client.agents.list(tags=["worker"])
for worker in workers:
client.agents.delete(agent_id=worker.id)
# create worker agents
num_workers = 10
for idx in range(num_workers):
client.agents.create(
name=f"worker_{idx}",
include_base_tools=True,
tags=["worker"],
tool_ids=[weather_tool.id],
model="anthropic/claude-3-5-sonnet-20241022",
embedding="letta/letta-free",
)
# create supervisor agent
supervisor = client.agents.create(
name="supervisor",
include_base_tools=True,
include_multi_agent_tools=True,
model="anthropic/claude-3-5-sonnet-20241022",
embedding="letta/letta-free",
tags=["supervisor"],
)
# send a message to the supervisor
import time
start = time.perf_counter()
response = client.agents.messages.create(
agent_id=supervisor.id,
messages=[
MessageCreate(
role="user",
content=[
TextContent(
text="Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'worker' asking them to check the weather in Seattle."
)
],
)
],
)
end = time.perf_counter()
print("TIME ELAPSED: " + str(end - start))
for message in response.messages:
print(message)
def test_call_weather(client: Letta, weather_tool):
# delete any existing worker agents
workers = client.agents.list(tags=["worker", "supervisor"])
for worker in workers:
client.agents.delete(agent_id=worker.id)
# create supervisor agent
supervisor = client.agents.create(
name="supervisor",
include_base_tools=True,
tool_ids=[weather_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
)
# send a message to the supervisor
import time
start = time.perf_counter()
response = client.agents.messages.create(
agent_id=supervisor.id,
messages=[
{
"role": "user",
"content": "What's the weather like in Seattle?",
}
],
)
end = time.perf_counter()
print("TIME ELAPSED: " + str(end - start))
for message in response.messages:
print(message)
def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str):
# Delete any existing agents for this group (if rerunning)
existing_workers = client.agents.list(tags=[f"worker-{group_id}"])
for worker in existing_workers:
client.agents.delete(agent_id=worker.id)
# Create worker agents
num_workers = 50
for idx in range(num_workers):
client.agents.create(
name=f"worker_{group_id}_{idx}",
include_base_tools=True,
tags=[f"worker-{group_id}"],
tool_ids=[weather_tool.id],
model="anthropic/claude-3-5-sonnet-20241022",
embedding="letta/letta-free",
)
# Create supervisor agent
supervisor = client.agents.create(
name=f"supervisor_{group_id}",
include_base_tools=True,
include_multi_agent_tools=True,
model="anthropic/claude-3-5-sonnet-20241022",
embedding="letta/letta-free",
tags=[f"supervisor-{group_id}"],
)
# Send message to supervisor to broadcast to workers
response = client.agents.messages.create(
agent_id=supervisor.id,
messages=[
{
"role": "user",
"content": "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag "
f"'worker-{group_id}' asking them to check the weather in Seattle.",
}
],
)
return response
def test_anthropic_streaming(client: Letta):
agent_name = "anthropic_tester"
existing_agents = client.agents.list(tags=[agent_name])
for worker in existing_agents:
client.agents.delete(agent_id=worker.id)
llm_config = LLMConfig(
model="claude-3-7-sonnet-20250219",
model_endpoint_type="anthropic",
model_endpoint="https://api.anthropic.com/v1",
context_window=32000,
handle=f"anthropic/claude-3-5-sonnet-20241022",
put_inner_thoughts_in_kwargs=False,
max_tokens=4096,
enable_reasoner=True,
max_reasoning_tokens=1024,
)
agent = client.agents.create(
name=agent_name,
tags=[agent_name],
include_base_tools=True,
embedding="letta/letta-free",
llm_config=llm_config,
memory_blocks=[CreateBlock(label="human", value="")],
# tool_rules=[InitToolRule(tool_name="core_memory_append")]
)
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content=[TextContent(text="Use the core memory append tool to append `banana` to the persona core memory.")],
),
],
stream_tokens=True,
)
print(list(response))
import time
def test_create_agents_telemetry(client: Letta):
start_total = time.perf_counter()
# delete any existing worker agents
start_delete = time.perf_counter()
workers = client.agents.list(tags=["worker"])
for worker in workers:
client.agents.delete(agent_id=worker.id)
end_delete = time.perf_counter()
print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s")
# create worker agents
num_workers = 1
agent_times = []
for idx in range(num_workers):
start = time.perf_counter()
client.agents.create(
name=f"worker_{idx}",
include_base_tools=True,
model="anthropic/claude-3-5-sonnet-20241022",
embedding="letta/letta-free",
)
end = time.perf_counter()
duration = end - start
agent_times.append(duration)
print(f"[telemetry] Created worker_{idx} in {duration:.2f}s")
total_duration = time.perf_counter() - start_total
avg_duration = sum(agent_times) / len(agent_times)
print(f"[telemetry] Total time to create {num_workers} agents: {total_duration:.2f}s")
print(f"[telemetry] Average agent creation time: {avg_duration:.2f}s")
print(f"[telemetry] Fastest agent: {min(agent_times):.2f}s, Slowest agent: {max(agent_times):.2f}s")

View File

@@ -1,65 +0,0 @@
import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client(request):
# Get URL from environment or start server
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create the Letta client
yield Letta(base_url=server_url, token=None)
def test_initial_sequence(client: Letta):
# create an agent
agent = client.agents.create(
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
initial_message_sequence=[
MessageCreate(
role="assistant",
content="Hello, how are you?",
),
MessageCreate(role="user", content="I'm good, and you?"),
],
)
# list messages
messages = client.agents.messages.list(agent_id=agent.id)
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content="hello assistant!",
)
],
)
assert len(messages) == 3
assert messages[0].message_type == "system_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "user_message"

View File

@@ -1,192 +0,0 @@
# TODO (cliandy): Tested in SDK
# TODO (cliandy): Comment out after merge
# import os
# import threading
# import time
# import pytest
# from dotenv import load_dotenv
# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool
# from letta.schemas.agent import AgentState
# from typing import List, Any, Dict
# # ------------------------------
# # Fixtures
# # ------------------------------
# @pytest.fixture(scope="module")
# def server_url() -> str:
# """
# Provides the URL for the Letta server.
# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
# will start the Letta server in a background thread and return the default URL.
# """
# def _run_server() -> None:
# """Starts the Letta server in a background thread."""
# load_dotenv() # Load environment variables from .env file
# from letta.server.rest_api.app import start_server
# start_server(debug=True)
# # Retrieve server URL from environment, or default to localhost
# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# # If no environment variable is set, start the server in a background thread
# if not os.getenv("LETTA_SERVER_URL"):
# thread = threading.Thread(target=_run_server, daemon=True)
# thread.start()
# time.sleep(5) # Allow time for the server to start
# return url
# @pytest.fixture
# def client(server_url: str) -> Letta:
# """
# Creates and returns a synchronous Letta REST client for testing.
# """
# client_instance = Letta(base_url=server_url)
# yield client_instance
# @pytest.fixture
# def async_client(server_url: str) -> AsyncLetta:
# """
# Creates and returns an asynchronous Letta REST client for testing.
# """
# async_client_instance = AsyncLetta(base_url=server_url)
# yield async_client_instance
# @pytest.fixture
# def roll_dice_tool(client: Letta) -> Tool:
# """
# Registers a simple roll dice tool with the provided client.
# The tool simulates rolling a six-sided die but returns a fixed result.
# """
# def roll_dice() -> str:
# """
# Simulates rolling a die.
# Returns:
# str: The roll result.
# """
# # Note: The result here is intentionally incorrect for demonstration purposes.
# return "Rolled a 10!"
# tool = client.tools.upsert_from_function(func=roll_dice)
# yield tool
# @pytest.fixture
# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
# """
# Creates and returns an agent state for testing with a pre-configured agent.
# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
# """
# agent_state_instance = client.agents.create(
# name="supervisor",
# include_base_tools=True,
# tool_ids=[roll_dice_tool.id],
# model="openai/gpt-4o",
# embedding="letta/letta-free",
# tags=["supervisor"],
# include_base_tool_rules=True,
# )
# yield agent_state_instance
# # Goal is to test that when an Agent is created with a `response_format`, the response
# # of `send_message` is in the correct format. This will be done by modifying the agent's
# # `send_message` tool so that it returns a format based on what is passed in.
# #
# # `response_format` is an optional field
# # if `response_format.type` is `text`, then the schema does not change
# # if `response_format.type` is `json_object`, then the schema is a dict
# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema
# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}]
# # ------------------------------
# # Test Cases
# # ------------------------------
# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "text"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, str)
# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_object'."""
# client.agents.modify(agent.id, response_format={"type": "json_object"})
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None:
# """Test client send_message with response_format='json_schema' and a valid schema."""
# client.agents.modify(agent.id, response_format={
# "type": "json_schema",
# "json_schema": {
# "name": "reasoning_schema",
# "schema": {
# "type": "object",
# "properties": {
# "steps": {
# "type": "array",
# "items": {
# "type": "object",
# "properties": {
# "explanation": { "type": "string" },
# "output": { "type": "string" }
# },
# "required": ["explanation", "output"],
# "additionalProperties": False
# }
# },
# "final_answer": { "type": "string" }
# },
# "required": ["steps", "final_answer"],
# "additionalProperties": True
# },
# "strict": True
# }
# })
# response = client.agents.messages.create_stream(
# agent_id=agent.id,
# messages=USER_MESSAGE,
# )
# messages = list(response)
# assert isinstance(messages[-1], AssistantMessage)
# assert isinstance(messages[-1].content, dict)
# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None:
# # """Test client send_message with an invalid json_schema (should error or fallback)."""
# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}}
# # client.agents.modify(agent.id, response_format="json_schema")
# # result: Any = client.agents.send_message(agent.id, "Test invalid schema")
# # assert result is None or "error" in str(result).lower()
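Pulled out of the commented-out tests above for reference, the three `response_format` payload shapes being exercised; the json_schema body is trimmed to its essentials:
TEXT_FORMAT = {"type": "text"}                # assistant content stays a str
JSON_OBJECT_FORMAT = {"type": "json_object"}  # assistant content becomes a dict
JSON_SCHEMA_FORMAT = {                        # content must match the schema
    "type": "json_schema",
    "json_schema": {
        "name": "reasoning_schema",
        "schema": {
            "type": "object",
            "properties": {"final_answer": {"type": "string"}},
            "required": ["final_answer"],
        },
        "strict": True,
    },
}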

View File

@@ -6,15 +6,16 @@ from typing import List
import pytest
from letta import create_client
from letta.agent import Agent
from letta.client.client import LocalClient
from letta.config import LettaConfig
from letta.llm_api.helpers import calculate_summarizer_cutoff
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.message import Message, MessageCreate
from letta.server.server import SyncServer
from letta.streaming_interface import StreamingRefreshCLIInterface
from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH
from tests.helpers.utils import cleanup
@@ -30,22 +31,34 @@ test_agent_name = f"test_client_{str(uuid.uuid4())}"
@pytest.fixture(scope="module")
def client():
client = create_client()
# client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
def server():
config = LettaConfig.load()
config.save()
yield client
server = SyncServer()
return server
@pytest.fixture(scope="module")
def agent_state(client):
def default_user(server):
yield server.user_manager.get_user_or_default()
@pytest.fixture(scope="module")
def agent_state(server, default_user):
# Generate uuid for agent name for this example
agent_state = client.create_agent(name=test_agent_name)
agent_state = server.create_agent(
CreateAgent(
name=test_agent_name,
include_base_tools=True,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
),
actor=default_user,
)
yield agent_state
client.delete_agent(agent_state.id)
server.agent_manager.delete_agent(agent_state.id, default_user)
# Sample data setup
@@ -113,9 +126,9 @@ def test_cutoff_calculation(mocker):
assert messages[cutoff - 1].role == MessageRole.user
def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state):
def test_cutoff_calculation_with_tool_call(mocker, server, agent_state, default_user):
"""Test that trim_older_in_context_messages properly handles tool responses with _trim_tool_response."""
agent_state = client.get_agent(agent_id=agent_state.id)
agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user)
# Setup
messages = [
@@ -133,18 +146,18 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st
def mock_get_messages_by_ids(message_ids, actor):
return [msg for msg in messages if msg.id in message_ids]
mocker.patch.object(client.server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids)
mocker.patch.object(server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids)
# Mock get_agent_by_id to return an agent with our message IDs
mock_agent = mocker.Mock()
mock_agent.message_ids = [msg.id for msg in messages]
mocker.patch.object(client.server.agent_manager, "get_agent_by_id", return_value=mock_agent)
mocker.patch.object(server.agent_manager, "get_agent_by_id", return_value=mock_agent)
# Mock set_in_context_messages to capture what messages are being set
mock_set_messages = mocker.patch.object(client.server.agent_manager, "set_in_context_messages", return_value=agent_state)
mock_set_messages = mocker.patch.object(server.agent_manager, "set_in_context_messages", return_value=agent_state)
# Test Case: Trim to remove orphaned tool response
client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user)
server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=default_user)
test1 = mock_set_messages.call_args_list[0][1]
assert len(test1["message_ids"]) == 5
@@ -152,104 +165,92 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st
mock_set_messages.reset_mock()
# Test Case: Does not result in trimming the orphaned tool response
client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user)
server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=default_user)
test2 = mock_set_messages.call_args_list[0][1]
assert len(test2["message_ids"]) == 6
def test_summarize_many_messages_basic(client, disable_e2b_api_key):
def test_summarize_many_messages_basic(server, default_user):
"""Test that a small-context agent gets enough messages for summarization."""
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 3000
small_agent_state = client.create_agent(
name="small_context_agent",
llm_config=small_context_llm_config,
)
for _ in range(10):
client.user_message(
agent_id=small_agent_state.id,
message="hi " * 60,
)
client.delete_agent(small_agent_state.id)
def test_summarize_messages_inplace(client, agent_state, disable_e2b_api_key):
"""Test summarization via sending the summarize CLI command or via a direct call to the agent object"""
# First send a few messages (5)
response = client.user_message(
agent_id=agent_state.id,
message="Hey, how's it going? What do you think about this whole shindig",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(
agent_id=agent_state.id,
message="Any thoughts on the meaning of life?",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(
agent_id=agent_state.id,
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
).messages
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
# reload agent object
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
agent_obj.summarize_messages_inplace()
def test_auto_summarize(client, disable_e2b_api_key):
"""Test that the summarizer triggers by itself"""
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 4000
small_agent_state = client.create_agent(
name="small_context_agent",
llm_config=small_context_llm_config,
agent_state = server.create_agent(
CreateAgent(
name="small_context_agent",
llm_config=small_context_llm_config,
embedding="letta/letta-free",
),
actor=default_user,
)
try:
def summarize_message_exists(messages: List[Message]) -> bool:
for message in messages:
if message.content[0].text and "The following is a summary of the previous" in message.content[0].text:
print(f"Summarize message found after {message_count} messages: \n {message.content[0].text}")
return True
return False
MAX_ATTEMPTS = 10
message_count = 0
while True:
# send a message
response = client.user_message(
agent_id=small_agent_state.id,
message="What is the meaning of life?",
for _ in range(10):
server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="hi " * 60)],
)
message_count += 1
print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------")
# check if the summarize message is inside the messages
assert isinstance(client, LocalClient), "Test only works with LocalClient"
in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user)
print("SUMMARY", summarize_message_exists(in_context_messages))
if summarize_message_exists(in_context_messages):
break
if message_count > MAX_ATTEMPTS:
raise Exception(f"Summarize message not found after {message_count} messages")
finally:
client.delete_agent(small_agent_state.id)
server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user)
def test_summarize_messages_inplace(server, agent_state, default_user):
"""Test summarization logic via agent object API."""
for msg in [
"Hey, how's it going? What do you think about this whole shindig?",
"Any thoughts on the meaning of life?",
"Does the number 42 ring a bell?",
"Would you be surprised to learn that you're actually conversing with an AI right now?",
]:
response = server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content=msg)],
)
assert response.steps_messages
agent = server.load_agent(agent_id=agent_state.id, actor=default_user)
agent.summarize_messages_inplace()
def test_auto_summarize(server, default_user):
"""Test that summarization is automatically triggered."""
small_context_llm_config = LLMConfig.default_config("gpt-4o-mini")
small_context_llm_config.context_window = 3000
agent_state = server.create_agent(
CreateAgent(
name="small_context_agent",
llm_config=small_context_llm_config,
embedding="letta/letta-free",
),
actor=default_user,
)
def summarize_message_exists(messages: List[Message]) -> bool:
for message in messages:
if message.content[0].text and "The following is a summary of the previous" in message.content[0].text:
return True
return False
try:
MAX_ATTEMPTS = 10
for attempt in range(MAX_ATTEMPTS):
server.send_messages(
actor=default_user,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the meaning of life?")],
)
in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user)
if summarize_message_exists(in_context_messages):
return
raise AssertionError("Summarization was not triggered after 10 messages")
finally:
server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user)
@pytest.mark.parametrize(
@@ -258,51 +259,53 @@ def test_auto_summarize(client, disable_e2b_api_key):
"openai-gpt-4o.json",
"azure-gpt-4o-mini.json",
"claude-3-5-haiku.json",
# "groq.json", TODO: Support groq, rate limiting currently makes it impossible to test
# "gemini-pro.json", TODO: Gemini is broken
# "groq.json", # rate limits
# "gemini-pro.json", # broken
],
)
def test_summarizer(config_filename, client, agent_state):
def test_summarizer(config_filename, server, default_user):
"""Test summarization across different LLM configs."""
namespace = uuid.NAMESPACE_DNS
agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}"))
# Get the LLM config
filename = os.path.join(LLM_CONFIG_DIR, config_filename)
config_data = json.load(open(filename, "r"))
# Create client and clean up agents
# Load configs
config_data = json.load(open(os.path.join(LLM_CONFIG_DIR, config_filename)))
llm_config = LLMConfig(**config_data)
embedding_config = EmbeddingConfig(**json.load(open(EMBEDDING_CONFIG_PATH)))
client = create_client()
client.set_default_llm_config(llm_config)
client.set_default_embedding_config(embedding_config)
cleanup(client=client, agent_uuid=agent_name)
# Ensure cleanup
cleanup(server=server, agent_uuid=agent_name, actor=default_user)
# Create agent
agent_state = client.create_agent(name=agent_name, llm_config=llm_config, embedding_config=embedding_config)
full_agent_state = client.get_agent(agent_id=agent_state.id)
agent_state = server.create_agent(
CreateAgent(
name=agent_name,
llm_config=llm_config,
embedding_config=embedding_config,
),
actor=default_user,
)
full_agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user)
letta_agent = Agent(
interface=StreamingRefreshCLIInterface(),
agent_state=full_agent_state,
first_message_verify_mono=False,
user=client.user,
user=default_user,
)
# Make conversation
messages = [
for msg in [
"Did you know that honey never spoils? Archaeologists have found pots of honey in ancient Egyptian tombs that are over 3,000 years old and still perfectly edible.",
"Octopuses have three hearts, and two of them stop beating when they swim.",
]
for m in messages:
]:
letta_agent.step_user_message(
user_message_str=m,
user_message_str=msg,
first_message=False,
skip_verify=False,
stream=False,
)
# Invoke a summarize
letta_agent.summarize_messages_inplace()
in_context_messages = client.get_in_context_messages(agent_state.id)
in_context_messages = server.agent_manager.get_in_context_messages(agent_state.id, actor=default_user)
assert SUMMARY_KEY_PHRASE in in_context_messages[1].content[0].text, f"Test failed for config: {config_filename}"

View File

@@ -7,17 +7,16 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from letta import create_client
from letta.config import LettaConfig
from letta.functions.function_sets.base import core_memory_append, core_memory_replace
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.organization import Organization
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.user import User
from letta.server.server import SyncServer
from letta.services.organization_manager import OrganizationManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox
@@ -32,6 +31,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user"))
# Fixtures
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture(autouse=True)
def clear_tables():
"""Fixture to clear the organization table before each test."""
@@ -191,12 +205,26 @@ def external_codebase_tool(test_user):
@pytest.fixture
def agent_state():
client = create_client()
agent_state = client.create_agent(
memory=ChatMemory(persona="This is the persona", human="My name is Chad"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
def agent_state(server):
actor = server.user_manager.get_user_or_default()
agent_state = server.create_agent(
CreateAgent(
memory_blocks=[
CreateBlock(
label="human",
value="username: sarah",
),
CreateBlock(
label="persona",
value="This is the persona",
),
],
include_base_tools=True,
model="openai/gpt-4o-mini",
tags=["test_agents"],
embedding="letta/letta-free",
),
actor=actor,
)
agent_state.tool_rules = []
yield agent_state

View File

@@ -1,7 +1,6 @@
import datetime
import json
import math
import os
import random
import uuid
@@ -9,14 +8,11 @@ import pytest
from faker import Faker
from tqdm import tqdm
from letta import create_client
from letta.config import LettaConfig
from letta.orm import Base
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.services.agent_manager import AgentManager
from letta.services.message_manager import MessageManager
from tests.integration_test_summarizer import LLM_CONFIG_DIR
from letta.schemas.agent import CreateAgent
from letta.schemas.message import Message, MessageCreate
from letta.server.server import SyncServer
@pytest.fixture(autouse=True)
@@ -29,16 +25,25 @@ def truncate_database():
session.commit()
@pytest.fixture(scope="function")
def client():
filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-sonnet.json")
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
client = create_client()
client.set_default_llm_config(llm_config)
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
yield client
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture
def default_user(server):
actor = server.user_manager.get_user_or_default()
yield actor
def generate_tool_call_id():
@@ -129,14 +134,13 @@ def create_tool_message(agent_id, organization_id, tool_call_id, timestamp):
@pytest.mark.parametrize("num_messages", [1000])
def test_many_messages_performance(client, num_messages):
"""Main test function to generate messages and insert them into the database."""
message_manager = MessageManager()
agent_manager = AgentManager()
actor = client.user
def test_many_messages_performance(server, default_user, num_messages):
"""Performance test to insert many messages and ensure retrieval works correctly."""
message_manager = server.agent_manager.message_manager
agent_manager = server.agent_manager
start_time = datetime.datetime.now()
last_event_time = start_time # Track last event time
last_event_time = start_time
def log_event(event):
nonlocal last_event_time
@@ -144,11 +148,19 @@ def test_many_messages_performance(client, num_messages):
total_elapsed = (now - start_time).total_seconds()
step_elapsed = (now - last_event_time).total_seconds()
print(f"[+{total_elapsed:.3f}s | Δ{step_elapsed:.3f}s] {event}")
last_event_time = now # Update last event time
last_event_time = now
log_event(f"Starting test with {num_messages} messages")
agent_state = client.create_agent(name="manager")
agent_state = server.create_agent(
CreateAgent(
name="manager",
include_base_tools=True,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
),
actor=default_user,
)
log_event(f"Created agent with ID {agent_state.id}")
message_group_size = 3
@@ -158,37 +170,42 @@ def test_many_messages_performance(client, num_messages):
organization_id = "org-00000000-0000-4000-8000-000000000000"
all_messages = []
for _ in tqdm(range(num_groups)):
user_text, assistant_text = get_conversation_pair()
tool_call_id = generate_tool_call_id()
user_time, send_time, tool_time, current_time = generate_timestamps(current_time)
new_messages = [
Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)),
Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)),
Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)),
]
all_messages.extend(new_messages)
all_messages.extend(
[
Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)),
Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)),
Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)),
]
)
log_event(f"Finished generating {len(all_messages)} messages")
message_manager.create_many_messages(all_messages, actor=actor)
message_manager.create_many_messages(all_messages, actor=default_user)
log_event("Inserted messages into the database")
agent_manager.set_in_context_messages(
agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user
agent_id=agent_state.id,
message_ids=agent_state.message_ids + [m.id for m in all_messages],
actor=default_user,
)
log_event("Updated agent context with messages")
messages = message_manager.list_messages_for_agent(agent_id=agent_state.id, actor=client.user, limit=1000000000)
messages = message_manager.list_messages_for_agent(
agent_id=agent_state.id,
actor=default_user,
limit=1000000000,
)
log_event(f"Retrieved {len(messages)} messages from the database")
assert len(messages) >= num_groups * message_group_size
response = client.send_message(
agent_id=agent_state.id,
role="user",
message="What have we been talking about?",
response = server.send_messages(
actor=default_user, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="What have we been talking about?")]
)
log_event("Sent message to agent and received response")

View File

@@ -1,89 +1,98 @@
import json
import os
import pytest
from tqdm import tqdm
from letta import create_client
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool import Tool
from tests.integration_test_summarizer import LLM_CONFIG_DIR
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.message import MessageCreate
from letta.server.server import SyncServer
from tests.utils import create_tool_from_func
@pytest.fixture(scope="function")
def client():
filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json")
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
client = create_client()
client.set_default_llm_config(llm_config)
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
yield client
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture
def roll_dice_tool(client):
def default_user(server):
actor = server.user_manager.get_user_or_default()
yield actor
@pytest.fixture
def roll_dice_tool(server, default_user):
def roll_dice():
"""
Rolls a 6 sided die.
Rolls a 6-sided die.
Returns:
str: The roll result.
str: Result of the die roll.
"""
return "Rolled a 5!"
# Set up tool details
source_code = parse_source_code(roll_dice)
source_type = "python"
description = "test_description"
tags = ["test"]
tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool.json_schema = derived_json_schema
tool.name = derived_name
tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user)
# Yield the created tool
yield tool
tool = create_tool_from_func(func=roll_dice)
created_tool = server.tool_manager.create_or_update_tool(tool, actor=default_user)
yield created_tool
@pytest.mark.parametrize("num_workers", [50])
def test_multi_agent_large(client, roll_dice_tool, num_workers):
def test_multi_agent_large(server, default_user, roll_dice_tool, num_workers):
manager_tags = ["manager"]
worker_tags = ["helpers"]
# Clean up first from possibly failed tests
prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True)
for agent in prev_worker_agents:
client.delete_agent(agent.id)
# Cleanup any pre-existing agents with both tags
prev_agents = server.agent_manager.list_agents(actor=default_user, tags=worker_tags + manager_tags, match_all_tags=True)
for agent in prev_agents:
server.agent_manager.delete_agent(agent.id, actor=default_user)
# Create "manager" agent
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags)
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
# Create 3 worker agents
worker_agents = []
for idx in tqdm(range(num_workers)):
worker_agent_state = client.create_agent(
name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id]
)
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
worker_agents.append(worker_agent)
# Encourage the manager to send a message to the other agent_obj with the secret string
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
client.send_message(
agent_id=manager_agent.agent_state.id,
role="user",
message=broadcast_message,
# Create "manager" agent with multi-agent broadcast tool
send_message_tool_id = server.tool_manager.get_tool_id(tool_name="send_message_to_agents_matching_tags", actor=default_user)
manager_agent_state = server.create_agent(
CreateAgent(
name="manager",
tool_ids=[send_message_tool_id],
include_base_tools=True,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
tags=manager_tags,
),
actor=default_user,
)
# Please manually inspect the agent results
manager_agent = server.load_agent(agent_id=manager_agent_state.id, actor=default_user)
# Create N worker agents
worker_agents = []
for idx in tqdm(range(num_workers)):
worker_agent_state = server.create_agent(
CreateAgent(
name=f"worker-{idx}",
tool_ids=[roll_dice_tool.id],
include_multi_agent_tools=False,
include_base_tools=True,
model="openai/gpt-4o-mini",
embedding="letta/letta-free",
tags=worker_tags,
),
actor=default_user,
)
worker_agent = server.load_agent(agent_id=worker_agent_state.id, actor=default_user)
worker_agents.append(worker_agent)
# Manager sends broadcast message
broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!"
server.send_messages(
actor=default_user,
agent_id=manager_agent.agent_state.id,
input_messages=[MessageCreate(role="user", content=broadcast_message)],
)

View File

@@ -13,7 +13,6 @@ from dotenv import load_dotenv
from rich.console import Console
from rich.syntax import Syntax
from letta import create_client
from letta.config import LettaConfig
from letta.orm import Base
from letta.orm.enums import ToolType
@@ -27,6 +26,7 @@ from letta.schemas.organization import Organization
from letta.schemas.user import User
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.server import SyncServer
from tests.utils import create_tool_from_func
console = Console()
@@ -86,14 +86,6 @@ def clear_tables():
_clear_tables()
@pytest.fixture(scope="module")
def local_client():
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture
def server():
config = LettaConfig.load()
@@ -133,14 +125,14 @@ def other_user(server: SyncServer, other_organization):
@pytest.fixture
def weather_tool(local_client, weather_tool_func):
weather_tool = local_client.create_or_update_tool(func=weather_tool_func)
def weather_tool(server, weather_tool_func, default_user):
weather_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=weather_tool_func), actor=default_user)
yield weather_tool
@pytest.fixture
def print_tool(local_client, print_tool_func):
print_tool = local_client.create_or_update_tool(func=print_tool_func)
def print_tool(server, print_tool_func, default_user):
print_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=print_tool_func), actor=default_user)
yield print_tool
@@ -438,7 +430,7 @@ def test_sanity_datetime_mismatch():
# Agent serialize/deserialize tests
def test_deserialize_simple(local_client, server, serialize_test_agent, default_user, other_user):
def test_deserialize_simple(server, serialize_test_agent, default_user, other_user):
"""Test deserializing JSON into an Agent instance."""
append_copy_suffix = False
result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user)
@@ -452,9 +444,7 @@ def test_deserialize_simple(local_client, server, serialize_test_agent, default_
@pytest.mark.parametrize("override_existing_tools", [True, False])
def test_deserialize_override_existing_tools(
local_client, server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools
):
def test_deserialize_override_existing_tools(server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools):
"""
Test deserializing an agent with tools and ensure correct behavior for overriding existing tools.
"""
@@ -487,7 +477,7 @@ def test_deserialize_override_existing_tools(
assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden"
def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user):
def test_agent_serialize_with_user_messages(server, serialize_test_agent, default_user, other_user):
"""Test deserializing JSON into an Agent instance."""
append_copy_suffix = False
server.send_messages(
@@ -516,7 +506,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test
)
def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user):
def test_agent_serialize_tool_calls(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user):
"""Test deserializing JSON into an Agent instance."""
append_copy_suffix = False
server.send_messages(
@@ -552,7 +542,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s
assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0
def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user):
def test_agent_serialize_update_blocks(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user):
"""Test deserializing JSON into an Agent instance."""
append_copy_suffix = False
server.send_messages(

View File

@@ -1,275 +0,0 @@
import pytest
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
# -----------------------------------------------------------------------
# Example source code for testing multiple scenarios, including:
# 1) A class-based custom type (which we won't handle properly).
# 2) Functions with multiple argument types.
# 3) A function with default arguments.
# 4) A function with no arguments.
# 5) A function that shares the same name as another symbol.
# -----------------------------------------------------------------------
example_source_code = r"""
class CustomClass:
def __init__(self, x):
self.x = x
def unrelated_symbol():
pass
def no_args_func():
pass
def default_args_func(x: int = 5, y: str = "hello"):
return x, y
def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None):
pass
def my_function_duplicate():
# This function shares the name "my_function" partially, but isn't an exact match
pass
"""
# --------------------- get_function_annotations_from_source TESTS --------------------- #
def test_get_function_annotations_found():
"""
Test that we correctly parse annotations for a function
that includes multiple argument types and a custom class.
"""
annotations = get_function_annotations_from_source(example_source_code, "my_function")
assert annotations == {
"a": "int",
"b": "float",
"c": "str",
"d": "list",
"e": "dict",
"f": "CustomClass",
}
def test_get_function_annotations_not_found():
"""
If the requested function name doesn't exist exactly,
we should raise a ValueError.
"""
with pytest.raises(ValueError, match="Function 'missing_function' not found"):
get_function_annotations_from_source(example_source_code, "missing_function")
def test_get_function_annotations_no_args():
"""
Check that a function without arguments returns an empty annotations dict.
"""
annotations = get_function_annotations_from_source(example_source_code, "no_args_func")
assert annotations == {}
def test_get_function_annotations_with_default_values():
"""
Ensure that a function with default arguments still captures the annotations.
"""
annotations = get_function_annotations_from_source(example_source_code, "default_args_func")
assert annotations == {"x": "int", "y": "str"}
def test_get_function_annotations_partial_name_collision():
"""
Ensure we only match the exact function name, not partial collisions.
"""
# This will match 'my_function' exactly, ignoring 'my_function_duplicate'
annotations = get_function_annotations_from_source(example_source_code, "my_function")
assert "a" in annotations # Means it matched the correct function
# No error expected here, just making sure we didn't accidentally parse "my_function_duplicate".
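# Aside: a minimal ast-based sketch consistent with the behavior pinned down by the
# tests above. Illustrative only; the real get_function_annotations_from_source in
# letta.functions.ast_parsers may be implemented differently.
import ast


def get_function_annotations_sketch(source: str, function_name: str) -> dict:
    """Return {arg_name: annotation_source} for the first exact function-name match."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.FunctionDef) and node.name == function_name:
            # Only annotated positional args; default values do not affect annotations.
            return {arg.arg: ast.unparse(arg.annotation) for arg in node.args.args if arg.annotation is not None}
    raise ValueError(f"Function '{function_name}' not found in source code")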
# --------------------- coerce_dict_args_by_annotations TESTS --------------------- #
def test_coerce_dict_args_success():
"""
Basic success scenario with standard types:
int, float, str, list, dict.
"""
annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"}
function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == 3.14
assert coerced_args["c"] == "123"
assert coerced_args["d"] == [1, 2, 3]
assert coerced_args["e"] == {"key": "value"}
def test_coerce_dict_args_invalid_type():
"""
If the value cannot be coerced into the annotation,
a ValueError should be raised.
"""
annotations = {"a": "int"}
function_args = {"a": "invalid_int"}
with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_no_annotations():
"""
If there are no annotations, we do no coercion.
"""
annotations = {}
function_args = {"a": 42, "b": "hello"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args == function_args # Exactly the same dict back
def test_coerce_dict_args_partial_annotations():
"""
Only coerce annotated arguments; leave unannotated ones unchanged.
"""
annotations = {"a": "int"}
function_args = {"a": "42", "b": "no_annotation"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == "no_annotation"
def test_coerce_dict_args_with_missing_args():
"""
If function_args lacks some keys listed in annotations,
those are simply not coerced. (We do not add them.)
"""
annotations = {"a": "int", "b": "float"}
function_args = {"a": "42"} # Missing 'b'
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert "b" not in coerced_args
def test_coerce_dict_args_unexpected_keys():
"""
If function_args has extra keys not in annotations,
we leave them alone.
"""
annotations = {"a": "int"}
function_args = {"a": "42", "z": 999}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["z"] == 999 # unchanged
def test_coerce_dict_args_unsupported_custom_class():
"""
If someone tries to pass an annotation that isn't supported (like a custom class),
we should raise a ValueError (or similarly handle the error) rather than silently
accept it.
"""
annotations = {"f": "CustomClass"} # We can't resolve this
function_args = {"f": {"x": 1}}
with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_complex_types():
"""
Confirm the ability to parse built-in complex data (lists, dicts, etc.)
when given as strings.
"""
annotations = {"big_list": "list", "nested_dict": "dict"}
function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}]
assert coerced_args["nested_dict"] == {
"alpha": [10, 20],
"beta": {"x": 1, "y": 2},
}
def test_coerce_dict_args_non_string_keys():
"""
Validate behavior if `function_args` includes non-string keys.
(We should simply skip annotation checks for them.)
"""
annotations = {"a": "int"}
function_args = {123: "42", "a": "42"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
# 'a' is coerced to int
assert coerced_args["a"] == 42
# 123 remains untouched
assert coerced_args[123] == "42"
def test_coerce_dict_args_non_parseable_list_or_dict():
"""
Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation.
"""
annotations = {"bad_list": "list", "bad_dict": "dict"}
function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets
with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_complex_list_annotation():
"""
Test coercion when list with type annotation (e.g., list[int]) is used.
"""
annotations = {"a": "list[int]"}
function_args = {"a": "[1, 2, 3]"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == [1, 2, 3]
def test_coerce_dict_args_with_complex_dict_annotation():
"""
Test coercion when dict with type annotation (e.g., dict[str, int]) is used.
"""
annotations = {"a": "dict[str, int]"}
function_args = {"a": '{"x": 1, "y": 2}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == {"x": 1, "y": 2}
def test_coerce_dict_args_unsupported_complex_annotation():
"""
If an unsupported complex annotation is used (e.g., a custom class),
a ValueError should be raised.
"""
annotations = {"f": "CustomClass[int]"}
function_args = {"f": "CustomClass(42)"}
with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_nested_complex_annotation():
"""
Test coercion with complex nested types like list[dict[str, int]].
"""
annotations = {"a": "list[dict[str, int]]"}
function_args = {"a": '[{"x": 1}, {"y": 2}]'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == [{"x": 1}, {"y": 2}]
def test_coerce_dict_args_with_default_arguments():
"""
Test coercion with default arguments, where some arguments have defaults in the source code.
"""
annotations = {"a": "int", "b": "str"}
function_args = {"a": "42"}
function_args.setdefault("b", "hello") # Setting the default value for 'b'
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == "hello"

View File

@ -6,9 +6,11 @@ from dotenv import load_dotenv
from letta_client import Letta
import letta.functions.function_sets.base as base_functions
from letta import LocalClient, create_client
from letta.config import LettaConfig
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.server.server import SyncServer
from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import (
get_finish_rethinking_memory_schema,
get_rethink_user_memory_schema,
@ -18,15 +20,6 @@ from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import (
from tests.utils import wait_for_server
@pytest.fixture(scope="function")
def client():
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4o"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
def _run_server():
"""Starts the Letta server in a background thread."""
load_dotenv()
@ -35,6 +28,21 @@ def _run_server():
start_server(debug=True)
@pytest.fixture(scope="module")
def server():
"""
Creates a SyncServer instance for testing.
Loads and saves config to ensure proper initialization.
"""
config = LettaConfig.load()
config.save()
server = SyncServer(init_with_default_org_and_user=True)
yield server
@pytest.fixture(scope="session")
def server_url():
"""Ensures a server is running and returns its base URL."""
@ -57,16 +65,29 @@ def letta_client(server_url):
@pytest.fixture(scope="function")
def agent_obj(client: LocalClient):
def agent_obj(letta_client, server):
"""Create a test agent that we can call functions on"""
send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply")
agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id])
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
send_message_to_agent_and_wait_for_reply_tool_id = letta_client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0].id
agent_state = letta_client.agents.create(
tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id],
include_base_tools=True,
memory_blocks=[
{
"label": "human",
"value": "Name: Matt",
},
{
"label": "persona",
"value": "Friendly agent",
},
],
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
actor = server.user_manager.get_user_or_default()
agent_obj = server.load_agent(agent_id=agent_state.id, actor=actor)
yield agent_obj
# client.delete_agent(agent_obj.agent_state.id)
def query_in_search_results(search_results, query):
for result in search_results:
@ -127,17 +148,20 @@ def test_archival(agent_obj):
pass
def test_recall_self(client, agent_obj):
# keyword
def test_recall(server, agent_obj, default_user):
"""Test that an agent can recall messages using a keyword via conversation search."""
keyword = "banana"
keyword_backwards = "".join(reversed(keyword))
# Send messages to agent
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="hello")
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="what word is '{}' backwards?".format(keyword_backwards))
client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="tell me a fun fact")
# Send messages
for msg in ["hello", keyword, "tell me a fun fact"]:
server.send_messages(
actor=default_user,
agent_id=agent_obj.agent_state.id,
input_messages=[MessageCreate(role="user", content=msg)],
)
# Conversation search
# Search memory
result = base_functions.conversation_search(agent_obj, "banana")
assert keyword in result
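# Aside: conversation_search is exercised above as a keyword scan over the agent's
# message history. A trivially reduced model of that behavior (hypothetical helper,
# not the letta implementation):
def keyword_in_history(history, keyword: str) -> bool:
    return any(keyword in message for message in history)


assert keyword_in_history(["hello", "banana", "tell me a fun fact"], "banana")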

View File

@ -110,58 +110,6 @@ def clear_tables():
session.commit()
# TODO: add back
# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
# """
# Test sandbox config and environment variable functions for both LocalClient and RESTClient.
# """
#
# # 1. Create a sandbox config
# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR)
# sandbox_config = client.create_sandbox_config(config=local_config)
#
# # Assert the created sandbox config
# assert sandbox_config.id is not None
# assert sandbox_config.type == SandboxType.LOCAL
#
# # 2. Update the sandbox config
# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR)
# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config)
# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR
#
# # 3. List all sandbox configs
# sandbox_configs = client.list_sandbox_configs(limit=10)
# assert isinstance(sandbox_configs, List)
# assert len(sandbox_configs) == 1
# assert sandbox_configs[0].id == sandbox_config.id
#
# # 4. Create an environment variable
# env_var = client.create_sandbox_env_var(
# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION
# )
# assert env_var.id is not None
# assert env_var.key == ENV_VAR_KEY
# assert env_var.value == ENV_VAR_VALUE
# assert env_var.description == ENV_VAR_DESCRIPTION
#
# # 5. Update the environment variable
# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE)
# assert updated_env_var.key == UPDATED_ENV_VAR_KEY
# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE
#
# # 6. List environment variables
# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id)
# assert isinstance(env_vars, List)
# assert len(env_vars) == 1
# assert env_vars[0].key == UPDATED_ENV_VAR_KEY
#
# # 7. Delete the environment variable
# client.delete_sandbox_env_var(env_var_id=env_var.id)
#
# # 8. Delete the sandbox config
# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent tags
# --------------------------------------------------------------------------------------------------------------------
@ -349,30 +297,6 @@ def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState):
assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)]
# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState):
# """Test that the token limit is enforced for the core memory blocks"""
# # Create an agent
# new_agent = client.create_agent(
# name="test-core-memory-token-limits",
# tools=BASE_TOOLS,
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000),
# )
# try:
# # Then intentionally set the limit to be extremely low
# client.update_agent(
# agent_id=new_agent.id,
# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100),
# )
# # TODO we should probably not allow updating the core memory limit if
# # TODO in which case we should modify this test to actually to a proper token counter check
# finally:
# client.delete_agent(new_agent.id)
def test_update_agent_memory_limit(client: Letta):
"""Test that we can update the limit of a block in an agent's memory"""
@ -744,3 +668,38 @@ def test_attach_detach_agent_source(client: Letta, agent: AgentState):
assert source.id not in [s.id for s in final_sources]
client.sources.delete(source.id)
# --------------------------------------------------------------------------------------------------------------------
# Agent Initial Message Sequence
# --------------------------------------------------------------------------------------------------------------------
def test_initial_sequence(client: Letta):
# create an agent
agent = client.agents.create(
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
initial_message_sequence=[
MessageCreate(
role="assistant",
content="Hello, how are you?",
),
MessageCreate(role="user", content="I'm good, and you?"),
],
)
# list messages
messages = client.agents.messages.list(agent_id=agent.id)
response = client.agents.messages.create(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content="hello assistant!",
)
],
)
assert len(messages) == 3
assert messages[0].message_type == "system_message"
assert messages[1].message_type == "assistant_message"
assert messages[2].message_type == "user_message"
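# Aside: the ordering these assertions encode, as a tiny self-contained check.
# expected_initial_layout is a hypothetical helper, not part of the client API:
# the system prompt always comes first, then the initial_message_sequence in order.
# (messages is listed before the extra create() call, so only the initial three appear.)
def expected_initial_layout(initial_sequence_roles):
    role_to_type = {"assistant": "assistant_message", "user": "user_message"}
    return ["system_message"] + [role_to_type[role] for role in initial_sequence_roles]


assert expected_initial_layout(["assistant", "user"]) == [
    "system_message",
    "assistant_message",
    "user_message",
]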

View File

@ -9,8 +9,7 @@ import pytest
from dotenv import load_dotenv
from sqlalchemy import delete
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.client.client import RESTClient
from letta.constants import DEFAULT_PRESET
from letta.helpers.datetime_helpers import get_utc_time
from letta.orm import FileMetadata, Source
@ -33,7 +32,6 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.services.helpers.agent_manager_helper import initialize_message_sequence
from letta.services.organization_manager import OrganizationManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@ -58,30 +56,22 @@ def run_server():
start_server(debug=True)
# Fixture to create clients with different configurations
@pytest.fixture(
# params=[{"server": True}, {"server": False}], # whether to use REST API server
params=[{"server": True}], # whether to use REST API server
scope="module",
)
def client(request):
if request.param["server"]:
# get URL from environment
server_url = os.getenv("LETTA_SERVER_URL")
if server_url is None:
# run server in thread
server_url = "http://localhost:8283"
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create user via admin client
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
else:
# use local client (no server)
client = create_client()
def client():
# get URL from environment
server_url = os.getenv("LETTA_SERVER_URL")
if server_url is None:
# run server in thread
server_url = "http://localhost:8283"
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create user via admin client
client = RESTClient(server_url)
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@ -100,7 +90,7 @@ def clear_tables():
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client: Union[LocalClient, RESTClient]):
def agent(client: RESTClient):
agent_state = client.create_agent(name=test_agent_name)
yield agent_state
@ -124,7 +114,7 @@ def default_user(default_organization):
yield user
def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState):
def test_agent(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# test client.rename_agent
new_name = "RenamedTestAgent"
@ -143,7 +133,7 @@ def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agen
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"
def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState):
def test_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# _reset_config()
memory_response = client.get_in_context_memory(agent_id=agent.id)
@ -159,7 +149,7 @@ def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], age
), "Memory update failed"
def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState):
def test_agent_interactions(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# test that it is a LettaMessage
message = "Hello again, agent!"
print("Sending message", message)
@ -182,7 +172,7 @@ def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, REST
# TODO: add streaming tests
def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState):
def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
# _reset_config()
memory_content = "Archival memory content"
@ -216,7 +206,7 @@ def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTCli
client.get_archival_memory(agent.id)
def test_core_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState):
def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState):
response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user")
print("Response", response)
@ -240,10 +230,6 @@ def test_streaming_send_message(
stream_tokens: bool,
model: str,
):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
assert isinstance(client, RESTClient), client
# Update agent's model
agent.llm_config.model = model
@ -296,7 +282,7 @@ def test_streaming_send_message(
assert done, "Message stream not done"
def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_humans_personas(client: RESTClient, agent: AgentState):
# _reset_config()
humans_response = client.list_humans()
@ -322,7 +308,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
assert human.value == "Human text", "Creating human failed"
def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
def test_list_tools_pagination(client: RESTClient):
tools = client.list_tools()
visited_ids = {t.id: False for t in tools}
@ -344,7 +330,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
assert all(visited_ids.values())
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_list_files_pagination(client: RESTClient, agent: AgentState):
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
@ -380,7 +366,7 @@ def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: Ag
assert len(files) == 0 # Should be empty
def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_delete_file_from_source(client: RESTClient, agent: AgentState):
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
@ -409,7 +395,7 @@ def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent:
assert len(empty_files) == 0
def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_load_file(client: RESTClient, agent: AgentState):
# _reset_config()
# clear sources
@ -440,99 +426,7 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
assert file.source_id == source.id
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 0
# create a source
source = client.create_source(name="test_source")
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 1
# TODO: add back?
assert sources[0].metadata["num_passages"] == 0
assert sources[0].metadata["num_documents"] == 0
# update the source
original_id = source.id
original_name = source.name
new_name = original_name + "_new"
client.update_source(source_id=source.id, name=new_name)
# get the source name (check that it's been updated)
source = client.get_source(source_id=source.id)
assert source.name == new_name
assert source.id == original_id
# get the source id (make sure that it's the same)
assert str(original_id) == client.get_source_id(source_name=new_name)
# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_job = upload_file_using_client(client, source, filename)
job = client.get_job(upload_job.id)
created_passages = job.metadata["num_passages"]
# TODO: add test for blocking job
# TODO: make sure things run in the right order
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0
# attach a source
client.attach_source(source_id=source.id, agent_id=agent.id)
# list attached sources
attached_sources = client.list_attached_sources(agent_id=agent.id)
print("attached sources", attached_sources)
assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}"
# list archival memory
archival_memories = client.get_archival_memory(agent_id=agent.id)
# print(archival_memories)
assert len(archival_memories) == created_passages, f"Mismatched length {len(archival_memories)} vs. {created_passages}"
# check number of passages
sources = client.list_sources()
# TODO: add back?
# assert sources.sources[0].metadata["num_passages"] > 0
# assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added
print(sources)
# detach the source
assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory"
client.detach_source(source_id=source.id, agent_id=agent.id)
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}"
assert source.id not in [s.id for s in client.list_attached_sources(agent.id)]
# delete the source
client.delete_source(source.id)
def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
# create an organization
org_name = "test-org"
org = client.create_org(org_name)
@ -549,25 +443,6 @@ def test_organization(client: RESTClient):
assert not (org.id in [o.id for o in orgs])
def test_list_llm_models(client: RESTClient):
"""Test that if the user's env has the right api keys set, at least one model appears in the model list"""
def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool:
return any(model.model_endpoint_type == target_type for model in models)
models = client.list_llm_configs()
if model_settings.groq_api_key:
assert has_model_endpoint_type(models, "groq")
if model_settings.azure_api_key:
assert has_model_endpoint_type(models, "azure")
if model_settings.openai_api_key:
assert has_model_endpoint_type(models, "openai")
if model_settings.gemini_api_key:
assert has_model_endpoint_type(models, "google_ai")
if model_settings.anthropic_api_key:
assert has_model_endpoint_type(models, "anthropic")
@pytest.fixture
def cleanup_agents(client):
created_agents = []
@ -581,7 +456,7 @@ def cleanup_agents(client):
# NOTE: we need to add this back once agents can also create blocks during agent creation
def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str], default_user):
def test_initial_message_sequence(client: RESTClient, agent: AgentState, cleanup_agents: List[str], default_user):
"""Test that we can set an initial message sequence
If we pass in None, we should get a "default" message sequence
@ -624,7 +499,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].content[0].text
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_add_and_manage_tags_for_agent(client: RESTClient, agent: AgentState):
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
"""

View File

@ -1,3 +1,4 @@
import asyncio
from datetime import datetime, timezone
from typing import Tuple
from unittest.mock import AsyncMock, patch
@ -457,7 +458,9 @@ async def test_partial_error_from_anthropic_batch(
letta_batch_job_id=batch_job.id,
)
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
)
llm_batch_job = llm_batch_jobs[0]
# 2. Invoke the polling job and mock responses from Anthropic
@ -481,7 +484,10 @@ async def test_partial_error_from_anthropic_batch(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@ -545,7 +551,7 @@ async def test_partial_error_from_anthropic_batch(
# Tool-call side effects: each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
if agent.id == agents_failed[0].id:
assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure"
@ -567,7 +573,7 @@ async def test_partial_error_from_anthropic_batch(
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == (len(agents) - 1) * 4 + 1
@ -617,7 +623,9 @@ async def test_resume_step_some_stop(
letta_batch_job_id=batch_job.id,
)
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
)
llm_batch_job = llm_batch_jobs[0]
# 2. Invoke the polling job and mock responses from Anthropic
@ -643,7 +651,10 @@ async def test_resume_step_some_stop(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@ -703,7 +714,7 @@ async def test_resume_step_some_stop(
# Tool-call side effects: each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)
@ -716,7 +727,7 @@ async def test_resume_step_some_stop(
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 3 + 1
@ -782,7 +793,9 @@ async def test_resume_step_after_request_all_continue(
# Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`
# Verify batch items
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user)
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(
letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user
)
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
llm_batch_job = llm_batch_jobs[0]
@ -803,7 +816,10 @@ async def test_resume_step_after_request_all_continue(
with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results):
with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response):
msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents}
sizes = await asyncio.gather(
*[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents]
)
msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)}
new_batch_responses = await poll_running_llm_batches(server)
@ -860,7 +876,7 @@ async def test_resume_step_after_request_all_continue(
# Tool-call side effects: each agent gets at least 2 extra messages
for agent in agents:
before = msg_counts_before[agent.id] # captured just before resume
after = server.message_manager.size(actor=default_user, agent_id=agent.id)
after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id)
assert after - before >= 2, (
f"Agent {agent.id} should have an assistant toolcall " f"and toolresponse message persisted."
)
@ -873,7 +889,7 @@ async def test_resume_step_after_request_all_continue(
), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}"
# Check the total list of messages
messages = server.batch_manager.get_messages_for_letta_batch(
messages = await server.batch_manager.get_messages_for_letta_batch_async(
letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user
)
assert len(messages) == len(agents) * 4
@ -977,7 +993,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly(
mock_send.assert_called_once()
# Verify database records were created correctly
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user)
llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=response.letta_batch_id, actor=default_user)
assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}"
llm_batch_job = llm_batch_jobs[0]
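# Aside: the recurring sync-to-async conversion in this file, in isolation.
# Per-agent message_manager.size(...) calls become size_async(...) calls fanned
# out with asyncio.gather. FakeMessageManager is a hypothetical stand-in.
import asyncio


class FakeMessageManager:
    async def size_async(self, actor, agent_id: str) -> int:
        await asyncio.sleep(0)  # stand-in for a real async DB query
        return 0


async def snapshot_message_counts(manager, agents, actor) -> dict:
    sizes = await asyncio.gather(*[manager.size_async(actor=actor, agent_id=agent.id) for agent in agents])
    return {agent.id: size for agent, size in zip(agents, sizes)}

# e.g.: counts = asyncio.run(snapshot_message_counts(FakeMessageManager(), agents, actor=None))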

View File

@ -1,411 +0,0 @@
import uuid
import pytest
from letta import create_client
from letta.client.client import LocalClient
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
@pytest.fixture(scope="module")
def client():
client = create_client()
# client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@pytest.fixture(scope="module")
def agent(client):
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent"))
agent_state = client.create_agent(name=agent_uuid)
yield agent_state
client.delete_agent(agent_state.id)
def test_agent(client: LocalClient):
# create agent
agent_state_test = client.create_agent(
name="test_agent2",
memory=ChatMemory(human="I am a human", persona="I am an agent"),
description="This is a test agent",
)
assert isinstance(agent_state_test.memory, Memory)
# list agents
agents = client.list_agents()
assert agent_state_test.id in [a.id for a in agents]
# get agent
tools = client.list_tools()
print("TOOLS", [t.name for t in tools])
agent_state = client.get_agent(agent_state_test.id)
assert agent_state.name == "test_agent2"
for block in agent_state.memory.blocks:
db_block = client.server.block_manager.get_block_by_id(block.id, actor=client.user)
assert db_block is not None, "memory block not persisted on agent create"
assert db_block.value == block.value, "persisted block data does not match in-memory data"
assert isinstance(agent_state.memory, Memory)
# update agent: name
new_name = "new_agent"
client.update_agent(agent_state_test.id, name=new_name)
assert client.get_agent(agent_state_test.id).name == new_name
assert isinstance(agent_state.memory, Memory)
# update agent: system prompt
new_system_prompt = agent_state.system + "\nAlways respond with a !"
client.update_agent(agent_state_test.id, system=new_system_prompt)
assert client.get_agent(agent_state_test.id).system == new_system_prompt
response = client.user_message(agent_id=agent_state_test.id, message="Hello")
agent_state = client.get_agent(agent_state_test.id)
assert isinstance(agent_state.memory, Memory)
# update agent: message_ids
old_message_ids = agent_state.message_ids
new_message_ids = old_message_ids.copy()[:-1] # pop one
assert len(old_message_ids) != len(new_message_ids)
client.update_agent(agent_state_test.id, message_ids=new_message_ids)
assert client.get_agent(agent_state_test.id).message_ids == new_message_ids
assert isinstance(agent_state.memory, Memory)
# update agent: tools
tool_to_delete = "send_message"
assert tool_to_delete in [t.name for t in agent_state.tools]
new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete]
client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids)
assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids)
assert isinstance(agent_state.memory, Memory)
# update agent: memory
new_human = "My name is Mr Test, 100 percent human."
new_persona = "I am an all-knowing AI."
assert agent_state.memory.get_block("human").value != new_human
assert agent_state.memory.get_block("persona").value != new_persona
# client.update_agent(agent_state_test.id, memory=new_memory)
# update blocks:
client.update_agent_memory_block(agent_state_test.id, label="human", value=new_human)
client.update_agent_memory_block(agent_state_test.id, label="persona", value=new_persona)
assert client.get_agent(agent_state_test.id).memory.get_block("human").value == new_human
assert client.get_agent(agent_state_test.id).memory.get_block("persona").value == new_persona
# update agent: llm config
new_llm_config = agent_state.llm_config.model_copy(deep=True)
new_llm_config.model = "fake_new_model"
new_llm_config.context_window = 1e6
assert agent_state.llm_config != new_llm_config
client.update_agent(agent_state_test.id, llm_config=new_llm_config)
assert client.get_agent(agent_state_test.id).llm_config == new_llm_config
assert client.get_agent(agent_state_test.id).llm_config.model == "fake_new_model"
assert client.get_agent(agent_state_test.id).llm_config.context_window == 1e6
# update agent: embedding config
new_embed_config = agent_state.embedding_config.model_copy(deep=True)
new_embed_config.embedding_model = "fake_embed_model"
assert agent_state.embedding_config != new_embed_config
client.update_agent(agent_state_test.id, embedding_config=new_embed_config)
assert client.get_agent(agent_state_test.id).embedding_config == new_embed_config
assert client.get_agent(agent_state_test.id).embedding_config.embedding_model == "fake_embed_model"
# delete agent
client.delete_agent(agent_state_test.id)
def test_agent_add_remove_tools(client: LocalClient, agent):
# Create and add two tools to the client
# tool 1
from composio import Action
github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# assert both got added
tools = client.list_tools()
assert github_tool.id in [t.id for t in tools]
# Assert that all combinations of tool_names, organization id are unique
combinations = [(t.name, t.organization_id) for t in tools]
assert len(combinations) == len(set(combinations))
# create agent
agent_state = agent
curr_num_tools = len(agent_state.tools)
# add both tools to agent in steps
agent_state = client.attach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
# confirm that both tools are in the agent state
# we could access it like agent_state.tools, but will use the client function instead
# this is obviously redundant as it requires retrieving the agent again
# but allows us to test the `get_tools_from_agent` pathway as well
curr_tools = client.get_tools_from_agent(agent_state.id)
curr_tool_names = [t.name for t in curr_tools]
assert len(curr_tool_names) == curr_num_tools + 1
assert github_tool.name in curr_tool_names
# remove only the github tool
agent_state = client.detach_tool(agent_id=agent_state.id, tool_id=github_tool.id)
# confirm that only one tool left
curr_tools = client.get_tools_from_agent(agent_state.id)
curr_tool_names = [t.name for t in curr_tools]
assert len(curr_tool_names) == curr_num_tools
assert github_tool.name not in curr_tool_names
def test_agent_with_shared_blocks(client: LocalClient):
persona_block = client.create_block(template_name="persona", value="Here to test things!", label="persona")
human_block = client.create_block(template_name="human", value="Me Human, I swear. Beep boop.", label="human")
existing_non_template_blocks = [persona_block, human_block]
existing_non_template_blocks_no_values = []
for block in existing_non_template_blocks:
block_copy = block.copy()
block_copy.value = ""
existing_non_template_blocks_no_values.append(block_copy)
# create agent
first_agent_state_test = None
second_agent_state_test = None
try:
first_agent_state_test = client.create_agent(
name="first_test_agent_shared_memory_blocks",
memory=BasicBlockMemory(blocks=existing_non_template_blocks),
description="This is a test agent using shared memory blocks",
)
assert isinstance(first_agent_state_test.memory, Memory)
# When this agent is created with references to the shared blocks, its in-memory
# blocks should pick up the latest values set by the other agent.
second_agent_state_test = client.create_agent(
name="second_test_agent_shared_memory_blocks",
memory=BasicBlockMemory(blocks=existing_non_template_blocks_no_values),
description="This is a test agent using shared memory blocks",
)
first_memory = first_agent_state_test.memory
assert persona_block.id == first_memory.get_block("persona").id
assert human_block.id == first_memory.get_block("human").id
client.update_agent_memory_block(first_agent_state_test.id, label="human", value="I'm an analyst therapist.")
print("Updated human block value:", client.get_agent_memory_block(first_agent_state_test.id, label="human").value)
# refresh agent state
second_agent_state_test = client.get_agent(second_agent_state_test.id)
assert isinstance(second_agent_state_test.memory, Memory)
second_memory = second_agent_state_test.memory
assert persona_block.id == second_memory.get_block("persona").id
assert human_block.id == second_memory.get_block("human").id
# assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist."
assert second_memory.get_block("human").value == "I'm an analyst therapist."
finally:
if first_agent_state_test:
client.delete_agent(first_agent_state_test.id)
if second_agent_state_test:
client.delete_agent(second_agent_state_test.id)
def test_memory(client: LocalClient, agent: AgentState):
# get agent memory
original_memory = client.get_in_context_memory(agent.id)
assert original_memory is not None
original_memory_value = str(original_memory.get_block("human").value)
# update core memory
updated_memory = client.update_in_context_memory(agent.id, section="human", value="I am a human")
# get memory
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
def test_archival_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with archival memory store"""
# add archival memory
memory_str = "I love chats"
passage = client.insert_archival_memory(agent.id, memory=memory_str)[0]
# list archival memory
passages = client.get_archival_memory(agent.id)
assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"
# delete archival memory
client.delete_archival_memory(agent.id, passage.id)
def test_recall_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with recall memory store"""
# send message to the agent
message_str = "Hello"
client.send_message(message=message_str, role="user", agent_id=agent.id)
# list messages
messages = client.get_messages(agent.id)
exists = False
for m in messages:
if message_str in str(m):
exists = True
assert exists
# get in-context messages
in_context_messages = client.get_in_context_messages(agent.id)
exists = False
for m in in_context_messages:
if message_str in m.content[0].text:
exists = True
assert exists
def test_tools(client: LocalClient):
def print_tool(message: str):
"""
A tool to print a message
Args:
message (str): The message to print.
Returns:
str: The message that was printed.
"""
print(message)
return message
def print_tool2(msg: str):
"""
Another tool to print a message
Args:
msg (str): The message to print.
"""
print(msg)
# Clean all tools first
for tool in client.list_tools():
client.delete_tool(tool.id)
# create tool
tool = client.create_or_update_tool(func=print_tool, tags=["extras"])
# list tools
tools = client.list_tools()
assert tool.name in [t.name for t in tools]
# get tool id
assert tool.id == client.get_tool_id(name="print_tool")
# update tool: extras
extras2 = ["extras2"]
client.update_tool(tool.id, tags=extras2)
assert client.get_tool(tool.id).tags == extras2
# update tool: source code
client.update_tool(tool.id, func=print_tool2)
assert client.get_tool(tool.id).name == "print_tool2"
def test_tools_from_composio_basic(client: LocalClient):
from composio import Action
# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example)
client = create_client()
# create tool
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
# list tools
tools = client.list_tools()
assert tool.name in [t.name for t in tools]
# We end the test here as composio requires login to use the tools
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
# TODO: Langchain seems to have issues with Pydantic
# TODO: Langchain tools are breaking every two weeks bc of changes on their side
# def test_tools_from_langchain(client: LocalClient):
# # create langchain tool
# from langchain_community.tools import WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper
#
# langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
#
# # Add the tool
# tool = client.load_langchain_tool(
# langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"}
# )
#
# # list tools
# tools = client.list_tools()
# assert tool.name in [t.name for t in tools]
#
# # get tool
# tool_id = client.get_tool_id(name=tool.name)
# retrieved_tool = client.get_tool(tool_id)
# source_code = retrieved_tool.source_code
#
# # Parse the function and attempt to use it
# local_scope = {}
# exec(source_code, {}, local_scope)
# func = local_scope[tool.name]
#
# expected_content = "Albert Einstein"
# assert expected_content in func(query="Albert Einstein")
#
#
# def test_tool_creation_langchain_missing_imports(client: LocalClient):
# # create langchain tool
# from langchain_community.tools import WikipediaQueryRun
# from langchain_community.utilities import WikipediaAPIWrapper
#
# api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100)
# langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
#
# # Translate to memGPT Tool
# # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"}
# with pytest.raises(RuntimeError):
# ToolCreate.from_langchain(langchain_tool)
def test_shared_blocks_without_send_message(client: LocalClient):
from letta import BasicBlockMemory
from letta.client.client import Block, create_client
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
client = create_client()
shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000)
memory = BasicBlockMemory(blocks=[shared_memory_block])
agent_1 = client.create_agent(
agent_type=AgentType.memgpt_agent,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
memory=memory,
)
agent_2 = client.create_agent(
agent_type=AgentType.memgpt_agent,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
memory=memory,
)
block_id = agent_1.memory.get_block("shared_memory").id
client.update_block(block_id, value="I am no longer an [empty] memory")
agent_1 = client.get_agent(agent_1.id)
agent_2 = client.get_agent(agent_2.id)
assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"
assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory"

View File

@ -140,32 +140,32 @@ async def other_user_different_org(server: SyncServer, other_organization):
@pytest.fixture
def default_source(server: SyncServer, default_user):
async def default_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Test Source",
description="This is a test source.",
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
yield source
@pytest.fixture
def other_source(server: SyncServer, default_user):
async def other_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Another Test Source",
description="This is yet another test source.",
metadata={"type": "another_test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
yield source
@pytest.fixture
def default_file(server: SyncServer, default_source, default_user, default_organization):
file = server.source_manager.create_file(
async def default_file(server: SyncServer, default_source, default_user, default_organization):
file = await server.source_manager.create_file(
PydanticFileMetadata(file_name="test_file", organization_id=default_organization.id, source_id=default_source.id),
actor=default_user,
)
@ -1175,17 +1175,18 @@ async def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, de
await server.agent_manager.list_attached_sources_async(agent_id="nonexistent-agent-id", actor=default_user)
def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user):
@pytest.mark.asyncio
async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user, event_loop):
"""Test listing agents that have a particular source attached."""
# Initially should have no attached agents
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 0
# Attach source to first agent
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
# Verify one agent is now attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 1
assert sarah_agent.id in [a.id for a in attached_agents]
@ -1193,7 +1194,7 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de
server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user)
# Verify both agents are now attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 2
attached_agent_ids = [a.id for a in attached_agents]
assert sarah_agent.id in attached_agent_ids
@ -1203,15 +1204,16 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de
server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
# Verify only second agent remains attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 1
assert charles_agent.id in [a.id for a in attached_agents]
def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
"""Test listing agents for a nonexistent source."""
with pytest.raises(NoResultFound):
server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)
await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)
# ======================================================================================================================
@ -2177,7 +2179,7 @@ async def test_passage_cascade_deletion(
assert len(agentic_passages) == 0
# Delete source and verify its passages are deleted
server.source_manager.delete_source(default_source.id, default_user)
await server.source_manager.delete_source(default_source.id, default_user)
with pytest.raises(NoResultFound):
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
@ -3847,7 +3849,10 @@ async def test_upsert_properties(server: SyncServer, default_user, event_loop):
# ======================================================================================================================
# SourceManager Tests - Sources
# ======================================================================================================================
def test_create_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_source(server: SyncServer, default_user, event_loop):
"""Test creating a new source."""
source_pydantic = PydanticSource(
name="Test Source",
@ -3855,7 +3860,7 @@ def test_create_source(server: SyncServer, default_user):
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Assertions to check the created source
assert source.name == source_pydantic.name
@ -3864,7 +3869,8 @@ def test_create_source(server: SyncServer, default_user):
assert source.organization_id == default_user.organization_id
def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
"""Test creating a new source."""
name = "Test Source"
source_pydantic = PydanticSource(
@ -3873,27 +3879,28 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul
metadata={"type": "medical"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
source_pydantic = PydanticSource(
name=name,
description="This is a different test source.",
metadata={"type": "legal"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
same_source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
assert source.name == same_source.name
assert source.id != same_source.id
def test_update_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_source(server: SyncServer, default_user):
"""Test updating an existing source."""
source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Update the source
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"})
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
# Assertions to verify update
assert updated_source.name == update_data.name
@ -3901,21 +3908,22 @@ def test_update_source(server: SyncServer, default_user):
assert updated_source.metadata == update_data.metadata
def test_delete_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_delete_source(server: SyncServer, default_user):
"""Test deleting a source."""
source_pydantic = PydanticSource(
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Delete the source
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)
# Assertions to verify deletion
assert deleted_source.id == source.id
# Verify that the source no longer appears in list_sources
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 0
@ -3925,18 +3933,18 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u
source_pydantic = PydanticSource(
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
# Delete the source
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)
# Assertions to verify deletion
assert deleted_source.id == source.id
# Verify that the source no longer appears in list_sources
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 0
# Verify that agent is not deleted
@ -3944,37 +3952,43 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u
assert agent is not None
def test_list_sources(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_sources(server: SyncServer, default_user):
"""Test listing sources with pagination."""
# Create multiple sources
server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
await server.source_manager.create_source(
PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
await server.source_manager.create_source(
PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user
)
# List sources without pagination
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 2
# List sources with pagination
paginated_sources = server.source_manager.list_sources(actor=default_user, limit=1)
paginated_sources = await server.source_manager.list_sources(actor=default_user, limit=1)
assert len(paginated_sources) == 1
# Ensure cursor-based pagination works
next_page = server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1)
next_page = await server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1)
assert len(next_page) == 1
assert next_page[0].name != paginated_sources[0].name
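The cursor pagination contract exercised here (a `limit` cap plus an `after` id cursor) can be sketched against a plain ordered list, independent of the real manager API:

```python
def paginate(items, *, after=None, limit=None):
    """Cursor-style pagination over an ordered list of dicts with an 'id' key."""
    start = 0
    if after is not None:
        start = next(i for i, item in enumerate(items) if item["id"] == after) + 1
    return items[start : start + limit] if limit is not None else items[start:]

rows = [{"id": "src-1"}, {"id": "src-2"}, {"id": "src-3"}]
page_1 = paginate(rows, limit=1)
page_2 = paginate(rows, after=page_1[-1]["id"], limit=1)
assert page_1[0]["id"] != page_2[0]["id"]
```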
def test_get_source_by_id(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_source_by_id(server: SyncServer, default_user):
"""Test retrieving a source by ID."""
source_pydantic = PydanticSource(
name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Retrieve the source by ID
retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user)
retrieved_source = await server.source_manager.get_source_by_id(source_id=source.id, actor=default_user)
# Assertions to verify the retrieved source matches the created one
assert retrieved_source.id == source.id
@ -3982,29 +3996,31 @@ def test_get_source_by_id(server: SyncServer, default_user):
assert retrieved_source.description == source.description
def test_get_source_by_name(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_source_by_name(server: SyncServer, default_user):
"""Test retrieving a source by name."""
source_pydantic = PydanticSource(
name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Retrieve the source by name
retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user)
retrieved_source = await server.source_manager.get_source_by_name(source_name=source.name, actor=default_user)
# Assertions to verify the retrieved source matches the created one
assert retrieved_source.name == source.name
assert retrieved_source.description == source.description
def test_update_source_no_changes(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_source_no_changes(server: SyncServer, default_user):
"""Test update_source with no actual changes to verify logging and response."""
source_pydantic = PydanticSource(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Attempt to update the source with identical data
update_data = SourceUpdate(name="No Change Source", description="No changes")
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
# Assertions to ensure the update returned the source but made no modifications
assert updated_source.id == source.id
@ -4017,7 +4033,8 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
file_name="Retrieve File",
@ -4026,10 +4043,10 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source):
file_size=2048,
source_id=default_source.id,
)
created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
# Retrieve the file by ID
retrieved_file = server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user)
retrieved_file = await server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user)
# Assertions to verify the retrieved file matches the created one
assert retrieved_file.id == created_file.id
@ -4038,49 +4055,53 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source):
assert retrieved_file.file_type == created_file.file_type
def test_list_files(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_list_files(server: SyncServer, default_user, default_source):
"""Test listing files with pagination."""
# Create multiple files
server.source_manager.create_file(
await server.source_manager.create_file(
PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_file(
await server.source_manager.create_file(
PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
# List files without pagination
files = server.source_manager.list_files(source_id=default_source.id, actor=default_user)
files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user)
assert len(files) == 2
# List files with pagination
paginated_files = server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1)
paginated_files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1)
assert len(paginated_files) == 1
# Ensure cursor-based pagination works
next_page = server.source_manager.list_files(source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1)
next_page = await server.source_manager.list_files(
source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1
)
assert len(next_page) == 1
assert next_page[0].file_name != paginated_files[0].file_name
def test_delete_file(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_delete_file(server: SyncServer, default_user, default_source):
"""Test deleting a file."""
file_metadata = PydanticFileMetadata(
file_name="Delete File", file_path="/path/to/delete_file.txt", file_type="text/plain", source_id=default_source.id
)
created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
# Delete the file
deleted_file = server.source_manager.delete_file(file_id=created_file.id, actor=default_user)
deleted_file = await server.source_manager.delete_file(file_id=created_file.id, actor=default_user)
# Assertions to verify deletion
assert deleted_file.id == created_file.id
# Verify that the file no longer appears in list_files
files = server.source_manager.list_files(source_id=default_source.id, actor=default_user)
files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user)
assert len(files) == 0
@ -5126,7 +5147,7 @@ async def test_update_batch_status(server, default_user, dummy_beta_message_batc
)
before = datetime.now(timezone.utc)
server.batch_manager.update_llm_batch_status(
await server.batch_manager.update_llm_batch_status_async(
llm_batch_id=batch.id,
status=JobStatus.completed,
latest_polling_response=dummy_beta_message_batch,
@ -5151,7 +5172,7 @@ async def test_create_and_get_batch_item(
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5163,7 +5184,7 @@ async def test_create_and_get_batch_item(
assert item.agent_id == sarah_agent.id
assert item.step_state == dummy_step_state
fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
assert fetched.id == item.id
@ -5187,7 +5208,7 @@ async def test_update_batch_item(
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5197,7 +5218,7 @@ async def test_update_batch_item(
updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver)
server.batch_manager.update_llm_batch_item(
await server.batch_manager.update_llm_batch_item_async(
item_id=item.id,
request_status=JobStatus.completed,
step_status=AgentStepStatus.resumed,
@ -5206,7 +5227,7 @@ async def test_update_batch_item(
actor=default_user,
)
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
assert updated.request_status == JobStatus.completed
assert updated.batch_request_result == dummy_successful_response
@ -5223,7 +5244,7 @@ async def test_delete_batch_item(
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5231,10 +5252,10 @@ async def test_delete_batch_item(
actor=default_user,
)
server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user)
await server.batch_manager.delete_llm_batch_item_async(item_id=item.id, actor=default_user)
with pytest.raises(NoResultFound):
server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
@pytest.mark.asyncio
@ -5262,7 +5283,7 @@ async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_messa
letta_batch_job_id=letta_batch_job.id,
)
server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
await server.batch_manager.bulk_update_llm_batch_statuses_async([(batch.id, JobStatus.completed, dummy_beta_message_batch)])
updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user)
assert updated.status == JobStatus.completed
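Since the BatchManager accessors are now coroutines, independent reads can also be composed concurrently. A hedged sketch, assuming only the `get_llm_batch_item_by_id_async` signature shown in this diff:

```python
import asyncio

async def fetch_batch_items(batch_manager, item_ids, actor):
    """Illustrative: the *_async methods exercised above compose with asyncio.gather,
    letting independent lookups run concurrently on one event loop."""
    return await asyncio.gather(
        *(batch_manager.get_llm_batch_item_by_id_async(item_id, actor=actor) for item_id in item_ids)
    )
```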
@ -5287,7 +5308,7 @@ async def test_bulk_update_batch_items_results_by_agent(
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5295,11 +5316,11 @@ async def test_bulk_update_batch_items_results_by_agent(
actor=default_user,
)
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
[ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)]
)
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
assert updated.request_status == JobStatus.completed
assert updated.batch_request_result == dummy_successful_response
@ -5314,7 +5335,7 @@ async def test_bulk_update_batch_items_step_status_by_agent(
actor=default_user,
letta_batch_job_id=letta_batch_job.id,
)
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5322,11 +5343,11 @@ async def test_bulk_update_batch_items_step_status_by_agent(
actor=default_user,
)
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
[StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)]
)
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
assert updated.step_status == AgentStepStatus.resumed
@ -5342,7 +5363,7 @@ async def test_list_batch_items_limit_and_filter(
)
for _ in range(3):
server.batch_manager.create_llm_batch_item(
await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5372,7 +5393,7 @@ async def test_list_batch_items_pagination(
# Create 10 batch items.
created_items = []
for i in range(10):
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5435,7 +5456,7 @@ async def test_bulk_update_batch_items_request_status_by_agent(
)
# Create a batch item
item = server.batch_manager.create_llm_batch_item(
item = await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5444,12 +5465,12 @@ async def test_bulk_update_batch_items_request_status_by_agent(
)
# Update the request status using the bulk update method
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
[RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)]
)
# Verify the update was applied
updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user)
updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user)
assert updated.request_status == JobStatus.expired
@ -5478,20 +5499,20 @@ async def test_bulk_update_nonexistent_items_should_error(
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates)
await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)]
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)]
)
with pytest.raises(ValueError, match=re.escape(expected_err_msg)):
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)]
)
@ -5515,21 +5536,21 @@ async def test_bulk_update_nonexistent_items(
nonexistent_updates = [{"request_status": JobStatus.expired}]
# This should not raise an error, just silently skip non-existent items
server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False)
await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates, strict=False)
# Test with higher-level methods
# Results by agent
server.batch_manager.bulk_update_batch_llm_items_results_by_agent(
await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(
[ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False
)
# Step status by agent
server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(
[StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False
)
# Request status by agent
server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(
await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(
[RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False
)
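Taken together with the error-raising test above, this pins down the `strict` contract: strict bulk updates raise on unknown items, non-strict ones silently skip them. A minimal sketch of that contract (not the actual BatchManager implementation):

```python
def bulk_update(known_ids, updates, strict=True):
    """Sketch of the strict-flag semantics suggested by the two tests above."""
    missing = [item_id for item_id, _ in updates if item_id not in known_ids]
    if strict and missing:
        raise ValueError(f"Items not found in database: {missing}")
    return [(item_id, fields) for item_id, fields in updates if item_id in known_ids]

assert bulk_update({"item-1"}, [("item-2", {"status": "expired"})], strict=False) == []
```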
@ -5584,7 +5605,7 @@ async def test_create_batch_items_bulk(
# Verify the IDs of created items match what's in the database
created_ids = [item.id for item in created_items]
for item_id in created_ids:
fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user)
fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item_id, actor=default_user)
assert fetched.id in created_ids
@ -5604,7 +5625,7 @@ async def test_count_batch_items(
# Create a specific number of batch items for this batch.
num_items = 5
for _ in range(num_items):
server.batch_manager.create_llm_batch_item(
await server.batch_manager.create_llm_batch_item_async(
llm_batch_id=batch.id,
agent_id=sarah_agent.id,
llm_config=dummy_llm_config,
@ -5613,7 +5634,7 @@ async def test_count_batch_items(
)
# Use the count_llm_batch_items method to count the items.
count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id)
count = await server.batch_manager.count_llm_batch_items_async(llm_batch_id=batch.id)
# Assert that the count matches the expected number.
assert count == num_items, f"Expected {num_items} items, got {count}"

View File

@ -1,439 +0,0 @@
import os
import pytest
from tests.helpers.endpoints_helper import (
check_agent_archival_memory_insert,
check_agent_archival_memory_retrieval,
check_agent_edit_core_memory,
check_agent_recall_chat_memory,
check_agent_uses_external_tool,
check_first_response_is_valid_for_llm_endpoint,
run_embedding_endpoint,
)
from tests.helpers.utils import retry_until_success, retry_until_threshold
# directories
embedding_config_dir = "tests/configs/embedding_model_configs"
llm_config_dir = "tests/configs/llm_model_configs"
# ======================================================================================================================
# OPENAI TESTS
# ======================================================================================================================
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_recall_chat_memory():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_archival_memory_insert():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_archival_memory_insert(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_openai_gpt_4o_edit_core_memory():
filename = os.path.join(llm_config_dir, "openai-gpt-4o.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.openai_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_embedding_endpoint_openai():
filename = os.path.join(embedding_config_dir, "openai_embed.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# AZURE TESTS
# ======================================================================================================================
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_gpt_4o_mini_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_gpt_4o_mini_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_gpt_4o_mini_recall_chat_memory():
filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_gpt_4o_mini_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_gpt_4o_mini_edit_core_memory():
filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.azure_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_azure_embedding_endpoint():
filename = os.path.join(embedding_config_dir, "azure_embed.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# LETTA HOSTED
# ======================================================================================================================
def test_llm_endpoint_letta_hosted():
filename = os.path.join(llm_config_dir, "letta-hosted.json")
check_first_response_is_valid_for_llm_endpoint(filename)
def test_embedding_endpoint_letta_hosted():
filename = os.path.join(embedding_config_dir, "letta-hosted.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# LOCAL MODELS
# ======================================================================================================================
def test_embedding_endpoint_local():
filename = os.path.join(embedding_config_dir, "local.json")
run_embedding_endpoint(filename)
def test_llm_endpoint_ollama():
filename = os.path.join(llm_config_dir, "ollama.json")
check_first_response_is_valid_for_llm_endpoint(filename)
def test_embedding_endpoint_ollama():
filename = os.path.join(embedding_config_dir, "ollama.json")
run_embedding_endpoint(filename)
# ======================================================================================================================
# ANTHROPIC TESTS
# ======================================================================================================================
@pytest.mark.anthropic_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_claude_haiku_3_5_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_claude_haiku_3_5_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_claude_haiku_3_5_recall_chat_memory():
filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_claude_haiku_3_5_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_claude_haiku_3_5_edit_core_memory():
filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# GROQ TESTS
# ======================================================================================================================
def test_groq_llama31_70b_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "groq.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_groq_llama31_70b_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "groq.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_groq_llama31_70b_recall_chat_memory():
filename = os.path.join(llm_config_dir, "groq.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@retry_until_threshold(threshold=0.75, max_attempts=4)
def test_groq_llama31_70b_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "groq.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_groq_llama31_70b_edit_core_memory():
filename = os.path.join(llm_config_dir, "groq.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# GEMINI TESTS
# ======================================================================================================================
@pytest.mark.gemini_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_gemini_pro_15_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "gemini-pro.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.gemini_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_gemini_pro_15_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "gemini-pro.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.gemini_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_gemini_pro_15_recall_chat_memory():
filename = os.path.join(llm_config_dir, "gemini-pro.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.gemini_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_gemini_pro_15_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "gemini-pro.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.gemini_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_gemini_pro_15_edit_core_memory():
filename = os.path.join(llm_config_dir, "gemini-pro.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# GOOGLE VERTEX TESTS
# ======================================================================================================================
@pytest.mark.vertex_basic
@retry_until_success(max_attempts=1, sleep_time_seconds=2)
def test_vertex_gemini_pro_20_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "gemini-vertex.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# DEEPSEEK TESTS
# ======================================================================================================================
@pytest.mark.deepseek_basic
def test_deepseek_reasoner_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "deepseek-reasoner.json")
# Don't validate that the inner monologue doesn't contain things like "function", since
# for the reasoners it might be quite meta (have analysis about functions etc.)
response = check_first_response_is_valid_for_llm_endpoint(filename, validate_inner_monologue_contents=False)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# xAI TESTS
# ======================================================================================================================
@pytest.mark.xai_basic
def test_xai_grok2_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "xai-grok-2.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# TOGETHER TESTS
# ======================================================================================================================
def test_together_llama_3_70b_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "together-llama-3-70b.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_together_llama_3_70b_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "together-llama-3-70b.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_together_llama_3_70b_recall_chat_memory():
filename = os.path.join(llm_config_dir, "together-llama-3-70b.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_together_llama_3_70b_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "together-llama-3-70b.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
def test_together_llama_3_70b_edit_core_memory():
filename = os.path.join(llm_config_dir, "together-llama-3-70b.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
# ======================================================================================================================
# ANTHROPIC BEDROCK TESTS
# ======================================================================================================================
@pytest.mark.anthropic_bedrock_basic
def test_bedrock_claude_sonnet_3_5_valid_config():
import json
from letta.schemas.llm_config import LLMConfig
from letta.settings import model_settings
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
with open(filename, "r") as f:
config_data = json.load(f)
llm_config = LLMConfig(**config_data)
model_region = llm_config.model.split(":")[3]
assert model_settings.aws_region == model_region, "Model region in config file does not match model region in ModelSettings"
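The region extraction relies on the model id being a colon-delimited AWS ARN, where the region sits at index 3; the model id below is illustrative:

```python
# AWS ARNs follow "arn:partition:service:region:account-id:resource",
# so splitting on ":" yields the region at index 3.
model_arn = "arn:aws:bedrock:us-east-1:123456789012:inference-profile/anthropic.claude-3-5-sonnet"
assert model_arn.split(":")[3] == "us-east-1"
```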
@pytest.mark.anthropic_bedrock_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_bedrock_claude_sonnet_3_5_returns_valid_first_message():
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
response = check_first_response_is_valid_for_llm_endpoint(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_bedrock_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_bedrock_claude_sonnet_3_5_uses_external_tool(disable_e2b_api_key):
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
response = check_agent_uses_external_tool(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_bedrock_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_bedrock_claude_sonnet_3_5_recall_chat_memory():
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
response = check_agent_recall_chat_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_bedrock_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_bedrock_claude_sonnet_3_5_archival_memory_retrieval():
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
response = check_agent_archival_memory_retrieval(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")
@pytest.mark.anthropic_bedrock_basic
@retry_until_success(max_attempts=5, sleep_time_seconds=2)
def test_bedrock_claude_sonnet_3_5_edit_core_memory():
filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json")
response = check_agent_edit_core_memory(filename)
# Log out successful response
print(f"Got successful response from client: \n\n{response}")

View File

@ -680,3 +680,72 @@ def test_many_blocks(client: LettaSDKClient):
client.agents.delete(agent1.id)
client.agents.delete(agent2.id)
def test_sources(client: LettaSDKClient, agent: AgentState):
# Clear existing sources
for source in client.sources.list():
client.sources.delete(source_id=source.id)
# Clear existing jobs
for job in client.jobs.list():
client.jobs.delete(job_id=job.id)
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
assert len(client.sources.list()) == 1
# delete the source
client.sources.delete(source_id=source.id)
assert len(client.sources.list()) == 0
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
# Load files into the source
file_a_path = "tests/data/memgpt_paper.pdf"
file_b_path = "tests/data/test.txt"
# Upload the files
with open(file_a_path, "rb") as f:
job_a = client.sources.files.upload(source_id=source.id, file=f)
with open(file_b_path, "rb") as f:
job_b = client.sources.files.upload(source_id=source.id, file=f)
# Wait for the jobs to complete
while job_a.status != "completed" or job_b.status != "completed":
time.sleep(1)
job_a = client.jobs.retrieve(job_id=job_a.id)
job_b = client.jobs.retrieve(job_id=job_b.id)
print("Waiting for jobs to complete...", job_a.status, job_b.status)
# Get the first file with pagination
files_a = client.sources.files.list(source_id=source.id, limit=1)
assert len(files_a) == 1
assert files_a[0].source_id == source.id
# Use the cursor from files_a to get the remaining file
files_b = client.sources.files.list(source_id=source.id, limit=1, after=files_a[-1].id)
assert len(files_b) == 1
assert files_b[0].source_id == source.id
# Check files are different to ensure the cursor works
assert files_a[0].file_name != files_b[0].file_name
# Use the cursor from files_b to list files, should be empty
files = client.sources.files.list(source_id=source.id, limit=1, after=files_b[-1].id)
assert len(files) == 0 # Should be empty
# list passages
passages = client.sources.passages.list(source_id=source.id)
assert len(passages) > 0
# attach to an agent
assert len(client.agents.passages.list(agent_id=agent.id)) == 0
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
assert len(client.agents.passages.list(agent_id=agent.id)) > 0
assert len(client.agents.sources.list(agent_id=agent.id)) == 1
# detach from agent
client.agents.sources.detach(source_id=source.id, agent_id=agent.id)
assert len(client.agents.passages.list(agent_id=agent.id)) == 0
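The polling loop above spins until both upload jobs complete; a bounded variant (an illustrative helper, reusing only the `client.jobs.retrieve` call shown above) fails fast instead of hanging on a stuck job:

```python
import time

def wait_for_jobs(client, job_ids, timeout_s=120.0, poll_s=1.0):
    """Poll jobs until all complete; raise on timeout (illustrative helper)."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        jobs = [client.jobs.retrieve(job_id=job_id) for job_id in job_ids]
        if all(job.status == "completed" for job in jobs):
            return jobs
        time.sleep(poll_s)
    raise TimeoutError(f"jobs {job_ids} did not complete within {timeout_s}s")
```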

View File

@ -1,3 +1,4 @@
import asyncio
import json
import os
import shutil
@ -24,15 +25,10 @@ from letta.server.db import db_registry
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.message import Message
from letta.schemas.source import Source as PydanticSource
from letta.server.server import SyncServer
from letta.system import unpack_message
from .utils import DummyDataConnector
WAR_AND_PEACE = """BOOK ONE: 1805
CHAPTER I
@ -270,8 +266,6 @@ start my apprenticeship as old maid."""
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
print("CONFIG PATH", config.config_path)
config.save()
server = SyncServer()
@ -366,6 +360,14 @@ def other_agent_id(server, user_id, base_tools):
server.agent_manager.delete_agent(agent_state.id, actor=actor)
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for each test case."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
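Overriding `event_loop` at session scope was the pre-0.21 pytest-asyncio way to share one loop across a module's async tests; with it in place, async fixtures can also be session-scoped (a hedged sketch; `shared_async_resource` is illustrative, not from this repo):

```python
import pytest_asyncio

@pytest_asyncio.fixture(scope="session")
async def shared_async_resource():
    """With a session-scoped loop, an async fixture can hold one resource
    (e.g. a connection) open for the whole test session."""
    resource = {"opened": True}
    yield resource
    resource["opened"] = False
```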
def test_error_on_nonexistent_agent(server, user, agent_id):
try:
fake_agent_id = str(uuid.uuid4())
@ -392,40 +394,6 @@ def test_user_message_memory(server, user, agent_id):
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
@pytest.mark.order(3)
def test_load_data(server, user, agent_id):
# create source
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000)
assert len(passages_before) == 0
source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)
# load data
archival_memories = [
"alpha",
"Cinderella wore a blue dress",
"Dog eat dog",
"ZZZ",
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user.id, connector, source.name)
# attach source
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
# check archival memory size
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000)
assert len(passages_after) == 5
def test_save_archival_memory(server, user_id, agent_id):
# TODO: insert into archival memory
pass
@pytest.mark.order(4)
def test_user_message(server, user, agent_id):
# add data into recall memory
@ -458,59 +426,60 @@ def test_get_recall_memory(server, org_id, user, agent_id):
assert message_id in message_ids, f"{message_id} not in {message_ids}"
@pytest.mark.order(6)
def test_get_archival_memory(server, user, agent_id):
# test archival memory cursor pagination
actor = user
# List latest 2 passages
passages_1 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
limit=2,
)
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
before=cursor1,
)
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
# @pytest.mark.order(6)
# def test_get_archival_memory(server, user, agent_id):
# # test archival memory cursor pagination
# actor = user
#
# # List latest 2 passages
# passages_1 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# limit=2,
# )
# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
#
# # List next 3 passages (earliest 3)
# cursor1 = passages_1[-1].id
# passages_2 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# before=cursor1,
# )
#
# # List all 5
# cursor2 = passages_1[0].created_at
# passages_3 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# end_date=cursor2,
# limit=1000,
# )
# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
#
# latest = passages_1[0]
# earliest = passages_2[-1]
#
# # test archival memory
# passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
# assert len(passage_1) == 1
# assert passage_1[0].text == "alpha"
# passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True)
# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
# assert all("alpha" not in passage.text for passage in passage_2)
# # test safe empty return
# passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True)
# assert len(passage_none) == 0
def test_get_context_window_overview(server: SyncServer, user, agent_id):
@pytest.mark.asyncio
async def test_get_context_window_overview(server: SyncServer, user, agent_id):
"""Test that the context window overview fetch works"""
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
overview = await server.agent_manager.get_context_window(agent_id=agent_id, actor=user)
assert overview is not None
# Run some basic checks
@ -567,7 +536,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
@pytest.mark.asyncio
async def test_read_local_llm_configs(server: SyncServer, user: User):
async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop):
configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs")
clean_up_dir = False
if not os.path.exists(configs_base_dir):
@ -604,7 +573,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User):
# Try to use in agent creation
context_window_override = 4000
agent = server.create_agent(
agent = await server.create_agent_async(
request=CreateAgent(
model="caren/my-custom-model",
context_window_limit=context_window_override,
@ -987,131 +956,6 @@ async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tool
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)
existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
name="timber_source",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
created_by_id=user_id,
),
actor=actor,
)
assert source.created_by_id == user_id
# Create a test file with some content
test_file = tmp_path / "test.txt"
test_content = "We have a dog called Timber. He likes to sleep and eat chicken."
test_file.write_text(test_content)
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id},
),
actor=actor,
)
# Load the first file to source
server.load_file_to_source(
source_id=source.id,
file_path=str(test_file),
job_id=job.id,
actor=actor,
)
# Verify job completed successfully
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
assert job.status == "completed"
assert job.metadata["num_passages"] == 1
assert job.metadata["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
test_file2 = tmp_path / "test2.txt"
test_file2.write_text(WAR_AND_PEACE)
# Create a job for loading the second file
job2 = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
),
actor=actor,
)
# Load the second file to source
server.load_file_to_source(
source_id=source.id,
file_path=str(test_file2),
job_id=job2.id,
actor=actor,
)
# Verify second job completed successfully
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
assert job2.status == "completed"
assert job2.metadata["num_passages"] >= 10
assert job2.metadata["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.agent_manager.list_passages(
agent_id=agent_id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages) == final_passage_count
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0
# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages
# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools):
actor = server.user_manager.get_user_or_default(user_id)
@ -1226,8 +1070,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, user_id: str):
actor = server.user_manager.get_user_or_default(user_id)
async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop):
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
provider = server.provider_manager.create_provider(
request=ProviderCreate(
name="caren-anthropic",
@ -1242,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base])
assert provider.name not in [model.provider_name for model in models]
agent = server.create_agent(
agent = await server.create_agent_async(
request=CreateAgent(
memory_blocks=[],
model="caren-anthropic/claude-3-5-sonnet-20240620",
@ -1306,7 +1150,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str)
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop):
models = await server.list_llm_models_async(actor=user)
model_handles = [model.handle for model in models]
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"

View File

@ -1,132 +0,0 @@
import os
import threading
import time
import pytest
from dotenv import load_dotenv
from letta_client import AgentState, Letta, LlmConfig, MessageCreate
from letta.schemas.message import Message
def run_server():
load_dotenv()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client(request):
# Get URL from environment or start server
api_url = os.getenv("LETTA_API_URL")
server_url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# Override the base_url if LETTA_API_URL is set
base_url = api_url if api_url else server_url
# create the Letta client
yield Letta(base_url=base_url, token=None)
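The fixed `time.sleep(5)` gives the background server time to boot; a readiness probe is a sturdier alternative (an illustrative sketch using only the standard library; any HTTP response, even an error status, counts as "up"):

```python
import time
import urllib.error
import urllib.request

def wait_for_server(base_url, timeout_s=30.0):
    """Poll until the server accepts connections, instead of sleeping blindly."""
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            urllib.request.urlopen(base_url, timeout=1)
            return
        except urllib.error.HTTPError:
            return  # server responded with an HTTP status: it is up
        except OSError:
            time.sleep(0.5)  # connection refused: keep waiting
    raise TimeoutError(f"server at {base_url} not ready after {timeout_s}s")
```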
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client: Letta):
agent_state = client.agents.create(
name="test_client",
memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}],
model="letta/letta-free",
embedding="letta/letta-free",
)
yield agent_state
# delete agent
client.agents.delete(agent_state.id)
@pytest.mark.parametrize(
"stream_tokens,model",
[
(True, "openai/gpt-4o-mini"),
(True, "anthropic/claude-3-sonnet-20240229"),
(False, "openai/gpt-4o-mini"),
(False, "anthropic/claude-3-sonnet-20240229"),
],
)
def test_streaming_send_message(
disable_e2b_api_key,
client: Letta,
agent: AgentState,
stream_tokens: bool,
model: str,
):
# Update agent's model
config = client.agents.retrieve(agent_id=agent.id).llm_config
config_dump = config.model_dump()
config_dump["model"] = model
config = LlmConfig(**config_dump)
client.agents.modify(agent_id=agent.id, llm_config=config)
# Send streaming message
user_message_otid = Message.generate_otid()
response = client.agents.messages.create_stream(
agent_id=agent.id,
messages=[
MessageCreate(
role="user",
content="This is a test. Repeat after me: 'banana'",
otid=user_message_otid,
),
],
stream_tokens=stream_tokens,
)
# Tracking variables for test validation
inner_thoughts_exist = False
inner_thoughts_count = 0
send_message_ran = False
done = False
last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id
letta_message_otids = [user_message_otid]
assert response, "Sending message failed"
for chunk in response:
# Check chunk type and content based on the current client API
if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message":
inner_thoughts_exist = True
inner_thoughts_count += 1
if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message":
send_message_ran = True
if chunk.message_type == "assistant_message":
send_message_ran = True
if chunk.message_type == "usage_statistics":
# Validate usage statistics
assert chunk.step_count == 1
assert chunk.completion_tokens > 10
assert chunk.prompt_tokens > 1000
assert chunk.total_tokens > 1000
done = True
else:
letta_message_otids.append(chunk.otid)
print(chunk)
# If stream tokens, we expect at least one inner thought
assert inner_thoughts_count >= 1, "Expected at least one inner thought"
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
assert done, "Message stream not done"
messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id)
assert [message.otid for message in messages] == letta_message_otids
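# A minimal usage sketch of the same streaming endpoint outside pytest,
# assuming only the letta_client surface exercised above (Letta, MessageCreate,
# agents.messages.create_stream); the base URL and agent id are placeholders.
def _streaming_usage_sketch():
    from letta_client import Letta, MessageCreate

    client = Letta(base_url="http://localhost:8283", token=None)
    stream = client.agents.messages.create_stream(
        agent_id="agent-00000000",  # placeholder agent id
        messages=[MessageCreate(role="user", content="Say hello.")],
        stream_tokens=True,
    )
    for chunk in stream:
        # usage_statistics is emitted last, mirroring the `done` flag above
        if getattr(chunk, "message_type", None) == "usage_statistics":
            break
        print(chunk)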

View File

@ -1,59 +0,0 @@
from letta.services.helpers.agent_manager_helper import safe_format
CORE_MEMORY_VAR = "My core memory is that I like to eat bananas"
VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR}
def test_formatter():
# Example system prompt that has no vars
NO_VARS = """
THIS IS A SYSTEM PROMPT WITH NO VARS
"""
assert NO_VARS == safe_format(NO_VARS, VARS_DICT)
# Example system prompt that has {CORE_MEMORY}
CORE_MEMORY_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{CORE_MEMORY}
"""
CORE_MEMORY_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
My core memory is that I like to eat bananas
"""
assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT)
# Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist)
UNUSED_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{USER_MEMORY}
{CORE_MEMORY}
"""
UNUSED_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{USER_MEMORY}
My core memory is that I like to eat bananas
"""
assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT)
# Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {}
UNUSED_AND_EMPRY_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{}
{USER_MEMORY}
{CORE_MEMORY}
"""
UNUSED_AND_EMPRY_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{}
{USER_MEMORY}
My core memory is that I like to eat bananas
"""
assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT)

View File

@ -1,8 +1,282 @@
import pytest
from letta.constants import MAX_FILENAME_LENGTH
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
from letta.services.helpers.agent_manager_helper import safe_format
from letta.utils import sanitize_filename
CORE_MEMORY_VAR = "My core memory is that I like to eat bananas"
VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR}
# -----------------------------------------------------------------------
# Example source code for testing multiple scenarios, including:
# 1) A class-based custom type (which we won't handle properly).
# 2) Functions with multiple argument types.
# 3) A function with default arguments.
# 4) A function with no arguments.
# 5) A function that shares the same name as another symbol.
# -----------------------------------------------------------------------
example_source_code = r"""
class CustomClass:
def __init__(self, x):
self.x = x
def unrelated_symbol():
pass
def no_args_func():
pass
def default_args_func(x: int = 5, y: str = "hello"):
return x, y
def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None):
pass
def my_function_duplicate():
# This function shares the name "my_function" partially, but isn't an exact match
pass
"""
def test_get_function_annotations_found():
"""
Test that we correctly parse annotations for a function
that includes multiple argument types and a custom class.
"""
annotations = get_function_annotations_from_source(example_source_code, "my_function")
assert annotations == {
"a": "int",
"b": "float",
"c": "str",
"d": "list",
"e": "dict",
"f": "CustomClass",
}
def test_get_function_annotations_not_found():
"""
If the requested function name doesn't exist exactly,
we should raise a ValueError.
"""
with pytest.raises(ValueError, match="Function 'missing_function' not found"):
get_function_annotations_from_source(example_source_code, "missing_function")
def test_get_function_annotations_no_args():
"""
Check that a function without arguments returns an empty annotations dict.
"""
annotations = get_function_annotations_from_source(example_source_code, "no_args_func")
assert annotations == {}
def test_get_function_annotations_with_default_values():
"""
Ensure that a function with default arguments still captures the annotations.
"""
annotations = get_function_annotations_from_source(example_source_code, "default_args_func")
assert annotations == {"x": "int", "y": "str"}
def test_get_function_annotations_partial_name_collision():
"""
Ensure we only match the exact function name, not partial collisions.
"""
# This will match 'my_function' exactly, ignoring 'my_function_duplicate'
annotations = get_function_annotations_from_source(example_source_code, "my_function")
assert "a" in annotations # Means it matched the correct function
# No error expected here, just making sure we didn't accidentally parse "my_function_duplicate".
# --------------------- coerce_dict_args_by_annotations TESTS --------------------- #
def test_coerce_dict_args_success():
"""
Basic success scenario with standard types:
int, float, str, list, dict.
"""
annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"}
function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == 3.14
assert coerced_args["c"] == "123"
assert coerced_args["d"] == [1, 2, 3]
assert coerced_args["e"] == {"key": "value"}
def test_coerce_dict_args_invalid_type():
"""
If the value cannot be coerced into the annotation,
a ValueError should be raised.
"""
annotations = {"a": "int"}
function_args = {"a": "invalid_int"}
with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_no_annotations():
"""
If there are no annotations, we do no coercion.
"""
annotations = {}
function_args = {"a": 42, "b": "hello"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args == function_args # Exactly the same dict back
def test_coerce_dict_args_partial_annotations():
"""
Only coerce annotated arguments; leave unannotated ones unchanged.
"""
annotations = {"a": "int"}
function_args = {"a": "42", "b": "no_annotation"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == "no_annotation"
def test_coerce_dict_args_with_missing_args():
"""
If function_args lacks some keys listed in annotations,
those are simply not coerced. (We do not add them.)
"""
annotations = {"a": "int", "b": "float"}
function_args = {"a": "42"} # Missing 'b'
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert "b" not in coerced_args
def test_coerce_dict_args_unexpected_keys():
"""
If function_args has extra keys not in annotations,
we leave them alone.
"""
annotations = {"a": "int"}
function_args = {"a": "42", "z": 999}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["z"] == 999 # unchanged
def test_coerce_dict_args_unsupported_custom_class():
"""
If someone tries to pass an annotation that isn't supported (like a custom class),
we should raise a ValueError (or similarly handle the error) rather than silently
accept it.
"""
annotations = {"f": "CustomClass"} # We can't resolve this
function_args = {"f": {"x": 1}}
with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_complex_types():
"""
Confirm the ability to parse built-in complex data (lists, dicts, etc.)
when given as strings.
"""
annotations = {"big_list": "list", "nested_dict": "dict"}
function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}]
assert coerced_args["nested_dict"] == {
"alpha": [10, 20],
"beta": {"x": 1, "y": 2},
}
def test_coerce_dict_args_non_string_keys():
"""
Validate behavior if `function_args` includes non-string keys.
(We should simply skip annotation checks for them.)
"""
annotations = {"a": "int"}
function_args = {123: "42", "a": "42"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
# 'a' is coerced to int
assert coerced_args["a"] == 42
# 123 remains untouched
assert coerced_args[123] == "42"
def test_coerce_dict_args_non_parseable_list_or_dict():
"""
Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation.
"""
annotations = {"bad_list": "list", "bad_dict": "dict"}
function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets
with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_complex_list_annotation():
"""
Test coercion when list with type annotation (e.g., list[int]) is used.
"""
annotations = {"a": "list[int]"}
function_args = {"a": "[1, 2, 3]"}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == [1, 2, 3]
def test_coerce_dict_args_with_complex_dict_annotation():
"""
Test coercion when dict with type annotation (e.g., dict[str, int]) is used.
"""
annotations = {"a": "dict[str, int]"}
function_args = {"a": '{"x": 1, "y": 2}'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == {"x": 1, "y": 2}
def test_coerce_dict_args_unsupported_complex_annotation():
"""
If an unsupported complex annotation is used (e.g., a custom class),
a ValueError should be raised.
"""
annotations = {"f": "CustomClass[int]"}
function_args = {"f": "CustomClass(42)"}
with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"):
coerce_dict_args_by_annotations(function_args, annotations)
def test_coerce_dict_args_with_nested_complex_annotation():
"""
Test coercion with complex nested types like list[dict[str, int]].
"""
annotations = {"a": "list[dict[str, int]]"}
function_args = {"a": '[{"x": 1}, {"y": 2}]'}
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == [{"x": 1}, {"y": 2}]
def test_coerce_dict_args_with_default_arguments():
"""
Test coercion with default arguments, where some arguments have defaults in the source code.
"""
annotations = {"a": "int", "b": "str"}
function_args = {"a": "42"}
function_args.setdefault("b", "hello") # Setting the default value for 'b'
coerced_args = coerce_dict_args_by_annotations(function_args, annotations)
assert coerced_args["a"] == 42
assert coerced_args["b"] == "hello"
def test_valid_filename():
filename = "valid_filename.txt"
@ -64,3 +338,58 @@ def test_unique_filenames():
assert sanitized2.startswith("duplicate_")
assert sanitized1.endswith(".txt")
assert sanitized2.endswith(".txt")
def test_formatter():
# Example system prompt that has no vars
NO_VARS = """
THIS IS A SYSTEM PROMPT WITH NO VARS
"""
assert NO_VARS == safe_format(NO_VARS, VARS_DICT)
# Example system prompt that has {CORE_MEMORY}
CORE_MEMORY_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{CORE_MEMORY}
"""
CORE_MEMORY_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
My core memory is that I like to eat bananas
"""
assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT)
# Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist)
UNUSED_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{USER_MEMORY}
{CORE_MEMORY}
"""
UNUSED_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{USER_MEMORY}
My core memory is that I like to eat bananas
"""
assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT)
# Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {}
UNUSED_AND_EMPRY_VAR = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{}
{USER_MEMORY}
{CORE_MEMORY}
"""
UNUSED_AND_EMPRY_VAR_SOL = """
THIS IS A SYSTEM PROMPT WITH NO VARS
{}
{USER_MEMORY}
My core memory is that I like to eat bananas
"""
assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT)
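# The assertions above pin down safe_format's contract: placeholders with known
# keys are substituted, while unknown placeholders and bare {} pass through
# untouched. A minimal sketch of that behavior (an illustration only, not the
# actual implementation in letta.services.helpers.agent_manager_helper):
import re

def _safe_format_sketch(template: str, variables: dict) -> str:
    def _substitute(match):
        key = match.group(1)
        # only replace keys we actually have; {UNKNOWN} survives as-is
        return str(variables[key]) if key in variables else match.group(0)
    # \w+ requires a non-empty name, so bare {} is never matched
    return re.sub(r"\{(\w+)\}", _substitute, template)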