diff --git a/examples/composio_tool_usage.py b/examples/composio_tool_usage.py deleted file mode 100644 index 89c662b00..000000000 --- a/examples/composio_tool_usage.py +++ /dev/null @@ -1,92 +0,0 @@ -import json -import os -import uuid - -from letta import create_client -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory -from letta.schemas.sandbox_config import SandboxType -from letta.services.sandbox_config_manager import SandboxConfigManager - -""" -Setup here. -""" -# Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) -client = create_client() -client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) -client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - -# Generate uuid for agent name for this example -namespace = uuid.NAMESPACE_DNS -agent_uuid = str(uuid.uuid5(namespace, "letta-composio-tooling-example")) - -# Clear all agents -for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - -# Add sandbox env -manager = SandboxConfigManager() -# Ensure you have e2b key set -sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=client.user) -manager.create_sandbox_env_var( - SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=os.environ.get("COMPOSIO_API_KEY")), - sandbox_config_id=sandbox_config.id, - actor=client.user, -) - - -""" -This example show how you can add Composio tools . - -First, make sure you have Composio and some of the extras downloaded. -``` -poetry install --extras "external-tools" -``` -then setup letta with `letta configure`. - -Aditionally, this example stars a Github repo on your behalf. You will need to configure Composio in your environment. -``` -composio login -composio add github -``` - -Last updated Oct 2, 2024. Please check `composio` documentation for any composio related issues. -""" - - -def main(): - from composio import Action - - # Add the composio tool - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - persona = f""" - My name is Letta. - - I am a personal assistant that helps star repos on Github. It is my job to correctly input the owner and repo to the {tool.name} tool based on the user's request. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - # Create an agent - agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id]) - print(f"Created agent: {agent.name} with ID {str(agent.id)}") - - # Send a message to the agent - send_message_response = client.user_message(agent_id=agent.id, message="Star a repo composio with owner composiohq on GitHub") - for message in send_message_response.messages: - response_json = json.dumps(message.model_dump(), indent=4) - print(f"{response_json}\n") - - # Delete agent - client.delete_agent(agent_id=agent.id) - print(f"Deleted agent: {agent.name} with ID {str(agent.id)}") - - -if __name__ == "__main__": - main() diff --git a/examples/langchain_tool_usage.py b/examples/langchain_tool_usage.py deleted file mode 100644 index 3ce4eb399..000000000 --- a/examples/langchain_tool_usage.py +++ /dev/null @@ -1,87 +0,0 @@ -import json -import uuid - -from letta import create_client -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory - -""" -This example show how you can add LangChain tools . - -First, make sure you have LangChain and some of the extras downloaded. -For this specific example, you will need `wikipedia` installed. -``` -poetry install --extras "external-tools" -``` -then setup letta with `letta configure`. -""" - - -def main(): - from langchain_community.tools import WikipediaQueryRun - from langchain_community.utilities import WikipediaAPIWrapper - - api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500) - langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) - - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - # create tool - # Note the additional_imports_module_attr_map - # We need to pass in a map of all the additional imports necessary to run this tool - # Because an object of type WikipediaAPIWrapper is passed into WikipediaQueryRun to initialize langchain_tool, - # We need to also import WikipediaAPIWrapper - # The map is a mapping of the module name to the attribute name - # langchain_community.utilities.WikipediaAPIWrapper - wikipedia_query_tool = client.load_langchain_tool( - langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} - ) - tool_name = wikipedia_query_tool.name - - # Confirm that the tool is in - tools = client.list_tools() - assert wikipedia_query_tool.name in [t.name for t in tools] - - # Generate uuid for agent name for this example - namespace = uuid.NAMESPACE_DNS - agent_uuid = str(uuid.uuid5(namespace, "letta-langchain-tooling-example")) - - # Clear all agents - for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - # google search persona - persona = f""" - - My name is Letta. - - I am a personal assistant who answers a user's questions using wikipedia searches. When a user asks me a question, I will use a tool called {tool_name} which will search Wikipedia and return a Wikipedia page about the topic. It is my job to construct the best query to input into {tool_name} based on the user's question. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - # Create an agent - agent_state = client.create_agent( - name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id] - ) - print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}") - - # Send a message to the agent - send_message_response = client.user_message(agent_id=agent_state.id, message="Tell me a fun fact about Albert Einstein!") - for message in send_message_response.messages: - response_json = json.dumps(message.model_dump(), indent=4) - print(f"{response_json}\n") - - # Delete agent - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") - - -if __name__ == "__main__": - main() diff --git a/examples/notebooks/Multi-agent recruiting workflow.ipynb b/examples/notebooks/Multi-agent recruiting workflow.ipynb deleted file mode 100644 index 0b33ca069..000000000 --- a/examples/notebooks/Multi-agent recruiting workflow.ipynb +++ /dev/null @@ -1,884 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "cac06555-9ce8-4f01-bbef-3f8407f4b54d", - "metadata": {}, - "source": [ - "# Multi-agent recruiting workflow \n", - "> Make sure you run the Letta server before running this example using `letta server`\n", - "\n", - "Last tested with letta version `0.5.3`" - ] - }, - { - "cell_type": "markdown", - "id": "aad3a8cc-d17a-4da1-b621-ecc93c9e2106", - "metadata": {}, - "source": [ - "## Section 0: Setup a MemGPT client " - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "7ccd43f2-164b-4d25-8465-894a3bb54c4b", - "metadata": {}, - "outputs": [], - "source": [ - "from letta_client import CreateBlock, Letta, MessageCreate\n", - "\n", - "client = Letta(base_url=\"http://localhost:8283\")" - ] - }, - { - "cell_type": "markdown", - "id": "99a61da5-f069-4538-a548-c7d0f7a70227", - "metadata": {}, - "source": [ - "## Section 1: Shared Memory Block \n", - "Each agent will have both its own memory, and shared memory. The shared memory will contain information about the organization that the agents are all a part of. If one agent updates this memory, the changes will be propaged to the memory of all the other agents. " - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "7770600d-5e83-4498-acf1-05f5bea216c3", - "metadata": {}, - "outputs": [], - "source": [ - "org_description = \"The company is called AgentOS \" \\\n", - "+ \"and is building AI tools to make it easier to create \" \\\n", - "+ \"and deploy LLM agents.\"\n", - "\n", - "org_block = client.blocks.create(\n", - " label=\"company\",\n", - " value=org_description,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "6c3d3a55-870a-4ff0-81c0-4072f783a940", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "org_block" - ] - }, - { - "cell_type": "markdown", - "id": "8448df7b-c321-4d90-ba52-003930a513cb", - "metadata": {}, - "source": [ - "## Section 2: Orchestrating Multiple Agents \n", - "We'll implement a recruiting workflow that involves evaluating an candidate, then if the candidate is a good fit, writing a personalized email on the human's behalf. Since this task involves multiple stages, sometimes breaking the task down to multiple agents can improve performance (though this is not always the case). We will break down the task into: \n", - "\n", - "1. `eval_agent`: This agent is responsible for evaluating candidates based on their resume\n", - "2. `outreach_agent`: This agent is responsible for writing emails to strong candidates\n", - "3. `recruiter_agent`: This agent is responsible for generating leads from a database \n", - "\n", - "Much like humans, these agents will communicate by sending each other messages. We can do this by giving agents that need to communicate with other agents access to a tool that allows them to message other agents. " - ] - }, - { - "cell_type": "markdown", - "id": "a065082a-d865-483c-b721-43c5a4d51afe", - "metadata": {}, - "source": [ - "#### Evaluator Agent\n", - "This agent will have tools to: \n", - "* Read a resume \n", - "* Submit a candidate for outreach (which sends the candidate information to the `outreach_agent`)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "c00232c5-4c37-436c-8ea4-602a31bd84fa", - "metadata": {}, - "outputs": [], - "source": [ - "def read_resume(self, name: str): \n", - " \"\"\"\n", - " Read the resume data for a candidate given the name\n", - "\n", - " Args: \n", - " name (str): Candidate name \n", - "\n", - " Returns: \n", - " resume_data (str): Candidate's resume data \n", - " \"\"\"\n", - " import os\n", - " filepath = os.path.join(\"data\", \"resumes\", name.lower().replace(\" \", \"_\") + \".txt\")\n", - " return open(filepath).read()\n", - "\n", - "def submit_evaluation(self, candidate_name: str, reach_out: bool, resume: str, justification: str): \n", - " \"\"\"\n", - " Submit a candidate for outreach. \n", - "\n", - " Args: \n", - " candidate_name (str): The name of the candidate\n", - " reach_out (bool): Whether to reach out to the candidate\n", - " resume (str): The text representation of the candidate's resume \n", - " justification (str): Justification for reaching out or not\n", - " \"\"\"\n", - " from letta import create_client \n", - " client = create_client()\n", - " message = \"Reach out to the following candidate. \" \\\n", - " + f\"Name: {candidate_name}\\n\" \\\n", - " + f\"Resume Data: {resume}\\n\" \\\n", - " + f\"Justification: {justification}\"\n", - " # NOTE: we will define this agent later \n", - " if reach_out:\n", - " response = client.send_message(\n", - " agent_name=\"outreach_agent\", \n", - " role=\"user\", \n", - " message=message\n", - " ) \n", - " else: \n", - " print(f\"Candidate {candidate_name} is rejected: {justification}\")\n", - "\n", - "# TODO: add an archival andidate tool (provide justification) \n", - "\n", - "read_resume_tool = client.tools.upsert_from_function(func=read_resume) \n", - "submit_evaluation_tool = client.tools.upsert_from_function(func=submit_evaluation)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "12482994-03f4-4dda-8ea2-6492ec28f392", - "metadata": {}, - "outputs": [], - "source": [ - "skills = \"Front-end (React, Typescript), software engineering \" \\\n", - "+ \"(ideally Python), and experience with LLMs.\"\n", - "eval_persona = f\"You are responsible to finding good recruiting \" \\\n", - "+ \"candidates, for the company description. \" \\\n", - "+ f\"Ideal canddiates have skills: {skills}. \" \\\n", - "+ \"Submit your candidate evaluation with the submit_evaluation tool. \"\n", - "\n", - "eval_agent = client.agents.create(\n", - " name=\"eval_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=eval_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "37c2d0be-b980-426f-ab24-1feaa8ed90ef", - "metadata": {}, - "source": [ - "#### Outreach agent \n", - "This agent will email candidates with customized emails. Since sending emails is a bit complicated, we'll just pretend we sent an email by printing it in the tool call. " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "24e8942f-5b0e-4490-ac5f-f9e1f3178627", - "metadata": {}, - "outputs": [], - "source": [ - "def email_candidate(self, content: str): \n", - " \"\"\"\n", - " Send an email\n", - "\n", - " Args: \n", - " content (str): Content of the email \n", - " \"\"\"\n", - " print(\"Pretend to email:\", content)\n", - " return\n", - "\n", - "email_candidate_tool = client.tools.upsert_from_function(func=email_candidate)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "87416e00-c7a0-4420-be71-e2f5a6404428", - "metadata": {}, - "outputs": [], - "source": [ - "outreach_persona = \"You are responsible for sending outbound emails \" \\\n", - "+ \"on behalf of a company with the send_emails tool to \" \\\n", - "+ \"potential candidates. \" \\\n", - "+ \"If possible, make sure to personalize the email by appealing \" \\\n", - "+ \"to the recipient with details about the company. \" \\\n", - "+ \"You position is `Head Recruiter`, and you go by the name Bob, with contact info bob@gmail.com. \" \\\n", - "+ \"\"\"\n", - "Follow this email template: \n", - "\n", - "Hi , \n", - "\n", - " \n", - "\n", - "Best, \n", - " \n", - " \n", - "\"\"\"\n", - " \n", - "outreach_agent = client.agents.create(\n", - " name=\"outreach_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=outreach_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[email_candidate_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f69d38da-807e-4bb1-8adb-f715b24f1c34", - "metadata": {}, - "source": [ - "Next, we'll send a message from the user telling the `leadgen_agent` to evaluate a given candidate: " - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "f09ab5bd-e158-42ee-9cce-43f254c4d2b0", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=\"Candidate: Tony Stark\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "cd8f1a1e-21eb-47ae-9eed-b1d3668752ff", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
read_resume({
  \"name\": \"Tony Stark\",
  \"request_heartbeat\"
: true
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"Failed\",
  \"message\"
: \"Error calling function read_resume: [Errno 2] No such file or directory: 'data/resumes/tony_stark.txt'\",
  \"time\"
: \"2024-11-13 05:51:26 PM PST-0800\"
}
\n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
I couldn't retrieve Tony's resume. Need to handle this carefully to keep the conversation flowing.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"It looks like I'm having trouble accessing Tony Stark's resume at the moment. Can you provide more details about his qualifications?\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:28 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 103,
  \"prompt_tokens\": 4999,
  \"total_tokens\": 5102,
  \"step_count\": 2
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Checking the resume for Tony Stark to evaluate if he fits the bill for our needs.'), FunctionCallMessage(id='message-97a1ae82-f8f3-419f-94c4-263112dbc10b', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 799617, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='read_resume', arguments='{\\n \"name\": \"Tony Stark\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt')), FunctionReturn(id='message-97a2b57e-40c6-4f06-a307-a0e3a00717ce', date=datetime.datetime(2024, 11, 14, 1, 51, 26, 803505, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"Failed\",\\n \"message\": \"Error calling function read_resume: [Errno 2] No such file or directory: \\'data/resumes/tony_stark.txt\\'\",\\n \"time\": \"2024-11-13 05:51:26 PM PST-0800\"\\n}', status='error', function_call_id='call_wOsiHlU3551JaApHKP7rK4Rt'), InternalMonologue(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue=\"I couldn't retrieve Tony's resume. Need to handle this carefully to keep the conversation flowing.\"), FunctionCallMessage(id='message-8e249aea-27ce-4788-b3e0-ac4c8401bc93', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 360676, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"It looks like I\\'m having trouble accessing Tony Stark\\'s resume at the moment. Can you provide more details about his qualifications?\"\\n}', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')), FunctionReturn(id='message-5600e8e7-6c6f-482a-8594-a0483ef523a2', date=datetime.datetime(2024, 11, 14, 1, 51, 28, 361921, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:28 PM PST-0800\"\\n}', status='success', function_call_id='call_1DoFBhOsP9OCpdPQjUfBcKjw')], usage=LettaUsageStatistics(completion_tokens=103, prompt_tokens=4999, total_tokens=5102, step_count=2))" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "markdown", - "id": "67069247-e603-439c-b2df-9176c4eba957", - "metadata": {}, - "source": [ - "#### Providing feedback to agents \n", - "Since MemGPT agents are persisted, we can provide feedback to agents that is used in future agent executions if we want to modify the future behavior. " - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "19c57d54-a1fe-4244-b765-b996ba9a4788", - "metadata": {}, - "outputs": [], - "source": [ - "feedback = \"Our company pivoted to foundation model training\"\n", - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=feedback,\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "036b973f-209a-4ad9-90e7-fc827b5d92c7", - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "feedback = \"The company is also renamed to FoundationAI\"\n", - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"user\",\n", - " content=feedback,\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "5d7a7633-35a3-4e41-b44a-be71067dd32a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Updating the company name to reflect the rebranding. This is important for future candidate evaluations.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
core_memory_replace({
  \"label\": \"company\",
  \"old_content\"
: \"The company has pivoted to foundation model training.\",
  \"new_content\"
: \"The company is called FoundationAI and has pivoted to foundation model training.\",
  \"request_heartbeat\"
: true
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:34 PM PST-0800\"
}
\n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
Now I have the updated company info, time to check in on Tony.
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"Got it, the new name is FoundationAI! What about Tony Stark's background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:51:35 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 146,
  \"prompt_tokens\": 6372,
  \"total_tokens\": 6518,
  \"step_count\": 2
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Updating the company name to reflect the rebranding. This is important for future candidate evaluations.'), FunctionCallMessage(id='message-0adccea9-4b96-4cbb-b5fc-a9ef0120c646', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 180327, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='core_memory_replace', arguments='{\\n \"label\": \"company\",\\n \"old_content\": \"The company has pivoted to foundation model training.\",\\n \"new_content\": \"The company is called FoundationAI and has pivoted to foundation model training.\",\\n \"request_heartbeat\": true\\n}', function_call_id='call_5s0KTElXdipPidchUu3R9CxI')), FunctionReturn(id='message-a2f278e8-ec23-4e22-a124-c21a0f46f733', date=datetime.datetime(2024, 11, 14, 1, 51, 34, 182291, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:34 PM PST-0800\"\\n}', status='success', function_call_id='call_5s0KTElXdipPidchUu3R9CxI'), InternalMonologue(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='Now I have the updated company info, time to check in on Tony.'), FunctionCallMessage(id='message-91f63cb2-b544-4b2e-82b1-b11643df5f93', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 841684, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Got it, the new name is FoundationAI! What about Tony Stark\\'s background catches your eye for this role? Any particular insights on his skills in front-end development or LLMs?\"\\n}', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')), FunctionReturn(id='message-813a9306-38fc-4665-9f3b-7c3671fd90e6', date=datetime.datetime(2024, 11, 14, 1, 51, 35, 842423, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:51:35 PM PST-0800\"\\n}', status='success', function_call_id='call_R4Erx7Pkpr5lepcuaGQU5isS')], usage=LettaUsageStatistics(completion_tokens=146, prompt_tokens=6372, total_tokens=6518, step_count=2))" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "d04d4b3a-6df1-41a9-9a8e-037fbb45836d", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=eval_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"system\",\n", - " content=\"Candidate: Spongebob Squarepants\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "c60465f4-7977-4f70-9a75-d2ddebabb0fa", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.agents.core_memory.get_block(agent_id=eval_agent.id, block_label=\"company\")" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "a51c6bb3-225d-47a4-88f1-9a26ff838dd3", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id=None, id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.agents.core_memory.get_block(agent_id=outreach_agent.id, block_label=\"company\")" - ] - }, - { - "cell_type": "markdown", - "id": "8d181b1e-72da-4ebe-a872-293e3ce3a225", - "metadata": {}, - "source": [ - "## Section 3: Adding an orchestrator agent \n", - "So far, we've been triggering the `eval_agent` manually. We can also create an additional agent that is responsible for orchestrating tasks. " - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "80b23d46-ed4b-4457-810a-a819d724e146", - "metadata": {}, - "outputs": [], - "source": [ - "#re-create agents \n", - "client.agents.delete(eval_agent.id)\n", - "client.agents.delete(outreach_agent.id)\n", - "\n", - "org_block = client.blocks.create(\n", - " label=\"company\",\n", - " value=org_description,\n", - ")\n", - "\n", - "eval_agent = client.agents.create(\n", - " name=\"eval_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=eval_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[read_resume_tool.id, submit_evaluation_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")\n", - "\n", - "outreach_agent = client.agents.create(\n", - " name=\"outreach_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=outreach_persona,\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[email_candidate_tool.id]\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\",\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a751d0f1-b52d-493c-bca1-67f88011bded", - "metadata": {}, - "source": [ - "The `recruiter_agent` will be linked to the same `org_block` that we created before - we can look up the current data in `org_block` by looking up its ID: " - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "bf6bd419-1504-4513-bc68-d4c717ea8e2d", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Block(value='The company is called AgentOS and is building AI tools to make it easier to create and deploy LLM agents.\\nThe company is called FoundationAI and has pivoted to foundation model training.', limit=2000, template_name=None, template=False, label='company', description=None, metadata_={}, user_id='user-00000000-0000-4000-8000-000000000000', id='block-f212d9e6-f930-4d3b-b86a-40879a38aec4')" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "client.blocks.retrieve(block_id=org_block.id)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "e2730626-1685-46aa-9b44-a59e1099e973", - "metadata": {}, - "outputs": [], - "source": [ - "from typing import Optional\n", - "\n", - "def search_candidates_db(self, page: int) -> Optional[str]: \n", - " \"\"\"\n", - " Returns 1 candidates per page. \n", - " Page 0 returns the first 1 candidate, \n", - " Page 1 returns the next 1, etc.\n", - " Returns `None` if no candidates remain. \n", - "\n", - " Args: \n", - " page (int): The page number to return candidates from \n", - "\n", - " Returns: \n", - " candidate_names (List[str]): Names of the candidates\n", - " \"\"\"\n", - " \n", - " names = [\"Tony Stark\", \"Spongebob Squarepants\", \"Gautam Fang\"]\n", - " if page >= len(names): \n", - " return None\n", - " return names[page]\n", - "\n", - "def consider_candidate(self, name: str): \n", - " \"\"\"\n", - " Submit a candidate for consideration. \n", - "\n", - " Args: \n", - " name (str): Candidate name to consider \n", - " \"\"\"\n", - " from letta_client import Letta, MessageCreate\n", - " client = Letta(base_url=\"http://localhost:8283\")\n", - " message = f\"Consider candidate {name}\" \n", - " print(\"Sending message to eval agent: \", message)\n", - " response = client.send_message(\n", - " agent_id=eval_agent.id,\n", - " role=\"user\", \n", - " message=message\n", - " ) \n", - "\n", - "\n", - "# create tools \n", - "search_candidate_tool = client.tools.upsert_from_function(func=search_candidates_db)\n", - "consider_candidate_tool = client.tools.upsert_from_function(func=consider_candidate)\n", - "\n", - "# create recruiter agent\n", - "recruiter_agent = client.agents.create(\n", - " name=\"recruiter_agent\", \n", - " memory_blocks=[\n", - " CreateBlock(\n", - " label=\"persona\",\n", - " value=\"You run a recruiting process for a company. \" \\\n", - " + \"Your job is to continue to pull candidates from the \" \n", - " + \"`search_candidates_db` tool until there are no more \" \\\n", - " + \"candidates left. \" \\\n", - " + \"For each candidate, consider the candidate by calling \"\n", - " + \"the `consider_candidate` tool. \" \\\n", - " + \"You should continue to call `search_candidates_db` \" \\\n", - " + \"followed by `consider_candidate` until there are no more \" \\\n", - " \" candidates. \",\n", - " ),\n", - " ],\n", - " block_ids=[org_block.id],\n", - " tool_ids=[search_candidate_tool.id, consider_candidate_tool.id],\n", - " model=\"openai/gpt-4\",\n", - " embedding=\"openai/text-embedding-ada-002\"\n", - ")\n", - " \n" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "ecfd790c-0018-4fd9-bdaf-5a6b81f70adf", - "metadata": {}, - "outputs": [], - "source": [ - "response = client.agents.messages.send(\n", - " agent_id=recruiter_agent.id,\n", - " messages=[\n", - " MessageCreate(\n", - " role=\"system\",\n", - " content=\"Run generation\",\n", - " )\n", - " ],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "8065c179-cf90-4287-a6e5-8c009807b436", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "\n", - " \n", - "
\n", - " \n", - "
\n", - "
INTERNAL MONOLOGUE
\n", - "
New user logged in. Excited to get started!
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION CALL
\n", - "
send_message({
  \"message\": \"Welcome! I'm thrilled to have you here. Let’s dive into what you need today!\"
})
\n", - "
\n", - " \n", - "
\n", - "
FUNCTION RETURN
\n", - "
{
  \"status\": \"OK\",
  \"message\"
: \"None\",
  \"time\"
: \"2024-11-13 05:52:14 PM PST-0800\"
}
\n", - "
\n", - "
\n", - "
\n", - "
\n", - "
USAGE STATISTICS
\n", - "
{
  \"completion_tokens\": 48,
  \"prompt_tokens\": 2398,
  \"total_tokens\": 2446,
  \"step_count\": 1
}
\n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "LettaResponse(messages=[InternalMonologue(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='internal_monologue', internal_monologue='New user logged in. Excited to get started!'), FunctionCallMessage(id='message-8c8ab238-a43e-4509-b7ad-699e9a47ed44', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 780419, tzinfo=datetime.timezone.utc), message_type='function_call', function_call=FunctionCall(name='send_message', arguments='{\\n \"message\": \"Welcome! I\\'m thrilled to have you here. Let’s dive into what you need today!\"\\n}', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')), FunctionReturn(id='message-26c3b7a3-51c8-47ae-938d-a3ed26e42357', date=datetime.datetime(2024, 11, 14, 1, 52, 14, 781455, tzinfo=datetime.timezone.utc), message_type='function_return', function_return='{\\n \"status\": \"OK\",\\n \"message\": \"None\",\\n \"time\": \"2024-11-13 05:52:14 PM PST-0800\"\\n}', status='success', function_call_id='call_2OIz7t3oiGsUlhtSneeDslkj')], usage=LettaUsageStatistics(completion_tokens=48, prompt_tokens=2398, total_tokens=2446, step_count=1))" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "response" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "4639bbca-e0c5-46a9-a509-56d35d26e97f", - "metadata": {}, - "outputs": [], - "source": [ - "client.agents.delete(agent_id=eval_agent.id)\n", - "client.agents.delete(agent_id=outreach_agent.id)\n", - "client.agents.delete(agent_id=recruiter_agent.id)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "letta", - "language": "python", - "name": "letta" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/swarm/simple.py b/examples/swarm/simple.py deleted file mode 100644 index 8e10c486d..000000000 --- a/examples/swarm/simple.py +++ /dev/null @@ -1,72 +0,0 @@ -import typer -from swarm import Swarm - -from letta import EmbeddingConfig, LLMConfig - -""" -This is an example of how to implement the basic example provided by OpenAI for tranferring a conversation between two agents: -https://github.com/openai/swarm/tree/main?tab=readme-ov-file#usage - -Before running this example, make sure you have letta>=0.5.0 installed. This example also runs with OpenAI, though you can also change the model by modifying the code: -```bash -export OPENAI_API_KEY=... -pip install letta -```` -Then, instead the `examples/swarm` directory, run: -```bash -python simple.py -``` -You should see a message output from Agent B. - -""" - - -def transfer_agent_b(self): - """ - Transfer conversation to agent B. - - Returns: - str: name of agent to transfer to - """ - return "agentb" - - -def transfer_agent_a(self): - """ - Transfer conversation to agent A. - - Returns: - str: name of agent to transfer to - """ - return "agenta" - - -swarm = Swarm() - -# set client configs -swarm.client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) -swarm.client.set_default_llm_config(LLMConfig.default_config(model_name="gpt-4")) - -# create tools -transfer_a = swarm.client.create_or_update_tool(transfer_agent_a) -transfer_b = swarm.client.create_or_update_tool(transfer_agent_b) - -# create agents -if swarm.client.get_agent_id("agentb"): - swarm.client.delete_agent(swarm.client.get_agent_id("agentb")) -if swarm.client.get_agent_id("agenta"): - swarm.client.delete_agent(swarm.client.get_agent_id("agenta")) -agent_a = swarm.create_agent(name="agentb", tools=[transfer_a.name], instructions="Only speak in haikus") -agent_b = swarm.create_agent(name="agenta", tools=[transfer_b.name]) - -response = swarm.run(agent_name="agenta", message="Transfer me to agent b by calling the transfer_agent_b tool") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) - -response = swarm.run(agent_name="agenta", message="My name is actually Sarah. Transfer me to agent b to write a haiku about my name") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) - -response = swarm.run(agent_name="agenta", message="Transfer me to agent b - I want a haiku with my name in it") -print("Response:") -typer.secho(f"{response}", fg=typer.colors.GREEN) diff --git a/examples/swarm/swarm.py b/examples/swarm/swarm.py deleted file mode 100644 index 6e0958bf7..000000000 --- a/examples/swarm/swarm.py +++ /dev/null @@ -1,111 +0,0 @@ -import json -from typing import List, Optional - -import typer - -from letta import AgentState, EmbeddingConfig, LLMConfig, create_client -from letta.schemas.agent import AgentType -from letta.schemas.memory import BasicBlockMemory, Block - - -class Swarm: - - def __init__(self): - self.agents = [] - self.client = create_client() - - # shared memory block (shared section of context window accross agents) - self.shared_memory = Block(label="human", value="") - - def create_agent( - self, - name: Optional[str] = None, - # agent config - agent_type: Optional[AgentType] = AgentType.memgpt_agent, - # model configs - embedding_config: EmbeddingConfig = None, - llm_config: LLMConfig = None, - # system - system: Optional[str] = None, - # tools - tools: Optional[List[str]] = None, - include_base_tools: Optional[bool] = True, - # instructions - instructions: str = "", - ) -> AgentState: - - # todo: process tools for agent handoff - persona_value = ( - f"You are agent with name {name}. You instructions are {instructions}" - if len(instructions) > 0 - else f"You are agent with name {name}" - ) - persona_block = Block(label="persona", value=persona_value) - memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory]) - - agent = self.client.create_agent( - name=name, - agent_type=agent_type, - embedding_config=embedding_config, - llm_config=llm_config, - system=system, - tools=tools, - include_base_tools=include_base_tools, - memory=memory, - ) - self.agents.append(agent) - - return agent - - def reset(self): - # delete all agents - for agent in self.agents: - self.client.delete_agent(agent.id) - for block in self.client.list_blocks(): - self.client.delete_block(block.id) - - def run(self, agent_name: str, message: str): - - history = [] - while True: - # send message to agent - agent_id = self.client.get_agent_id(agent_name) - - print("Messaging agent: ", agent_name) - print("History size: ", len(history)) - # print(self.client.get_agent(agent_id).tools) - # TODO: implement with sending multiple messages - if len(history) == 0: - response = self.client.send_message(agent_id=agent_id, message=message, role="user") - else: - response = self.client.send_messages(agent_id=agent_id, messages=history) - - # update history - history += response.messages - - # grab responses - messages = [] - for message in response.messages: - messages += message.to_letta_messages() - - # get new agent (see tool call) - # print(messages) - - if len(messages) < 2: - continue - - function_call = messages[-2] - function_return = messages[-1] - if function_call.function_call.name == "send_message": - # return message to use - arg_data = json.loads(function_call.function_call.arguments) - # print(arg_data) - return arg_data["message"] - else: - # swap the agent - return_data = json.loads(function_return.function_return) - agent_name = return_data["message"] - typer.secho(f"Transferring to agent: {agent_name}", fg=typer.colors.RED) - # print("Transferring to agent", agent_name) - - print() diff --git a/examples/tool_rule_usage.py b/examples/tool_rule_usage.py deleted file mode 100644 index 8ec061d0c..000000000 --- a/examples/tool_rule_usage.py +++ /dev/null @@ -1,129 +0,0 @@ -import os -import uuid - -from letta import create_client -from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule -from tests.helpers.endpoints_helper import assert_invoked_send_message_with_keyword, setup_agent -from tests.helpers.utils import cleanup -from tests.test_model_letta_performance import llm_config_dir - -""" -This example shows how you can constrain tool calls in your agent. - -Please note that this currently only works reliably for models with Structured Outputs (e.g. gpt-4o). - -Start by downloading the dependencies. -``` -poetry install --all-extras -``` -""" - -# Tools for this example -# Generate uuid for agent name for this example -namespace = uuid.NAMESPACE_DNS -agent_uuid = str(uuid.uuid5(namespace, "agent_tool_graph")) -config_file = os.path.join(llm_config_dir, "openai-gpt-4o.json") - -"""Contrived tools for this test case""" - - -def first_secret_word(): - """ - Call this to retrieve the first secret word, which you will need for the second_secret_word function. - """ - return "v0iq020i0g" - - -def second_secret_word(prev_secret_word: str): - """ - Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling first_secret_word. - """ - if prev_secret_word != "v0iq020i0g": - raise RuntimeError(f"Expected secret {"v0iq020i0g"}, got {prev_secret_word}") - - return "4rwp2b4gxq" - - -def third_secret_word(prev_secret_word: str): - """ - Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling second_secret_word. - """ - if prev_secret_word != "4rwp2b4gxq": - raise RuntimeError(f"Expected secret {"4rwp2b4gxq"}, got {prev_secret_word}") - - return "hj2hwibbqm" - - -def fourth_secret_word(prev_secret_word: str): - """ - Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. - - Args: - prev_secret_word (str): The secret word retrieved from calling third_secret_word. - """ - if prev_secret_word != "hj2hwibbqm": - raise RuntimeError(f"Expected secret {"hj2hwibbqm"}, got {prev_secret_word}") - - return "banana" - - -def auto_error(): - """ - If you call this function, it will throw an error automatically. - """ - raise RuntimeError("This should never be called.") - - -def main(): - # 1. Set up the client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # 2. Add all the tools to the client - functions = [first_secret_word, second_secret_word, third_secret_word, fourth_secret_word, auto_error] - tools = [] - for func in functions: - tool = client.create_or_update_tool(func) - tools.append(tool) - tool_names = [t.name for t in tools[:-1]] - - # 3. Create the tool rules. It must be called in this order, or there will be an error thrown. - tool_rules = [ - InitToolRule(tool_name="first_secret_word"), - ChildToolRule(tool_name="first_secret_word", children=["second_secret_word"]), - ChildToolRule(tool_name="second_secret_word", children=["third_secret_word"]), - ChildToolRule(tool_name="third_secret_word", children=["fourth_secret_word"]), - ChildToolRule(tool_name="fourth_secret_word", children=["send_message"]), - TerminalToolRule(tool_name="send_message"), - ] - - # 4. Create the agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # 5. Ask for the final secret word - response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") - - # 6. Here, we thoroughly check the correctness of the response - tool_names += ["send_message"] # Add send message because we expect this to be called at the end - for m in response.messages: - if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one - assert m.tool_call.name == tool_names[0] - # Pop out first one - tool_names = tool_names[1:] - - # Check final send message contains "banana" - assert_invoked_send_message_with_keyword(response.messages, "banana") - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) - - -if __name__ == "__main__": - main() diff --git a/letta/__init__.py b/letta/__init__.py index 772a17a9f..dcbda419e 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,7 +1,7 @@ -__version__ = "0.7.21" +__version__ = "0.7.22" # import clients -from letta.client.client import LocalClient, RESTClient, create_client +from letta.client.client import RESTClient # imports for easier access from letta.schemas.agent import AgentState diff --git a/letta/__main__.py b/letta/__main__.py deleted file mode 100644 index 89f11424b..000000000 --- a/letta/__main__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .main import app - -app() diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index a349366dc..693427588 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -100,8 +100,10 @@ class BaseAgent(ABC): # [DB Call] size of messages and archival memories # todo: blocking for now - num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + if num_messages is None: + num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) + if num_archival_memories is None: + num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) new_system_message_str = compile_system_message( system_prompt=agent_state.system, diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 4afa51857..cd8c4edb7 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -174,20 +174,13 @@ class LettaAgent(BaseAgent): for message in letta_messages: yield f"data: {message.model_dump_json()}\n\n" - # update usage - # TODO: add run_id - usage.step_count += 1 - usage.completion_tokens += response.usage.completion_tokens - usage.prompt_tokens += response.usage.prompt_tokens - usage.total_tokens += response.usage.total_tokens - if not should_continue: break # Extend the in context message ids if not agent_state.message_buffer_autoclear: message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] - self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) # Return back usage yield f"data: {usage.model_dump_json()}\n\n" @@ -285,7 +278,7 @@ class LettaAgent(BaseAgent): # Extend the in context message ids if not agent_state.message_buffer_autoclear: message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] - self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) return current_in_context_messages, new_in_context_messages, usage @@ -437,7 +430,7 @@ class LettaAgent(BaseAgent): # Extend the in context message ids if not agent_state.message_buffer_autoclear: message_ids = [m.id for m in (current_in_context_messages + new_in_context_messages)] - self.agent_manager.set_in_context_messages(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) + await self.agent_manager.set_in_context_messages_async(agent_id=self.agent_id, message_ids=message_ids, actor=self.actor) # TODO: This may be out of sync, if in between steps users add files # NOTE (cliandy): temporary for now for particlar use cases. diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index e2355ab54..10a50a58b 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -233,7 +233,7 @@ class LettaAgentBatch(BaseAgent): ctx = await self._collect_resume_context(llm_batch_id) log_event(name="update_statuses") - self._update_request_statuses(ctx.request_status_updates) + await self._update_request_statuses_async(ctx.request_status_updates) log_event(name="exec_tools") exec_results = await self._execute_tools(ctx) @@ -242,7 +242,7 @@ class LettaAgentBatch(BaseAgent): msg_map = await self._persist_tool_messages(exec_results, ctx) log_event(name="mark_steps_done") - self._mark_steps_complete(llm_batch_id, ctx.agent_ids) + await self._mark_steps_complete_async(llm_batch_id, ctx.agent_ids) log_event(name="prepare_next") next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map) @@ -382,9 +382,9 @@ class LettaAgentBatch(BaseAgent): return self._extract_tool_call_and_decide_continue(tool_call, item.step_state) - def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: + async def _update_request_statuses_async(self, updates: List[RequestStatusUpdateInfo]) -> None: if updates: - self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates) + await self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(updates=updates) def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]: sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL @@ -474,11 +474,11 @@ class LettaAgentBatch(BaseAgent): await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor) return msg_map - def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None: + async def _mark_steps_complete_async(self, llm_batch_id: str, agent_ids: List[str]) -> None: updates = [ StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids ] - self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates) + await self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(updates) def _prepare_next_iteration( self, diff --git a/letta/benchmark/benchmark.py b/letta/benchmark/benchmark.py deleted file mode 100644 index 7109210e9..000000000 --- a/letta/benchmark/benchmark.py +++ /dev/null @@ -1,98 +0,0 @@ -# type: ignore - -import time -import uuid -from typing import Annotated, Union - -import typer - -from letta import LocalClient, RESTClient, create_client -from letta.benchmark.constants import HUMAN, PERSONA, PROMPTS, TRIES -from letta.config import LettaConfig - -# from letta.agent import Agent -from letta.errors import LLMJSONParsingError -from letta.utils import get_human_text, get_persona_text - -app = typer.Typer() - - -def send_message( - client: Union[LocalClient, RESTClient], message: str, agent_id, turn: int, fn_type: str, print_msg: bool = False, n_tries: int = TRIES -): - try: - print_msg = f"\t-> Now running {fn_type}. Progress: {turn}/{n_tries}" - print(print_msg, end="\r", flush=True) - response = client.user_message(agent_id=agent_id, message=message) - - if turn + 1 == n_tries: - print(" " * len(print_msg), end="\r", flush=True) - - for r in response: - if "function_call" in r and fn_type in r["function_call"] and any("assistant_message" in re for re in response): - return True, r["function_call"] - - return False, "No function called." - except LLMJSONParsingError as e: - print(f"Error in parsing Letta JSON: {e}") - return False, "Failed to decode valid Letta JSON from LLM output." - except Exception as e: - print(f"An unexpected error occurred: {e}") - return False, "An unexpected error occurred." - - -@app.command() -def bench( - print_messages: Annotated[bool, typer.Option("--messages", help="Print functions calls and messages from the agent.")] = False, - n_tries: Annotated[int, typer.Option("--n-tries", help="Number of benchmark tries to perform for each function.")] = TRIES, -): - client = create_client() - print(f"\nDepending on your hardware, this may take up to 30 minutes. This will also create {n_tries * len(PROMPTS)} new agents.\n") - config = LettaConfig.load() - print(f"version = {config.letta_version}") - - total_score, total_tokens_accumulated, elapsed_time = 0, 0, 0 - - for fn_type, message in PROMPTS.items(): - score = 0 - start_time_run = time.time() - bench_id = uuid.uuid4() - - for i in range(n_tries): - agent = client.create_agent( - name=f"benchmark_{bench_id}_agent_{i}", - persona=get_persona_text(PERSONA), - human=get_human_text(HUMAN), - ) - - agent_id = agent.id - result, msg = send_message( - client=client, message=message, agent_id=agent_id, turn=i, fn_type=fn_type, print_msg=print_messages, n_tries=n_tries - ) - - if print_messages: - print(f"\t{msg}") - - if result: - score += 1 - - # TODO: add back once we start tracking usage via the client - # total_tokens_accumulated += tokens_accumulated - - elapsed_time_run = round(time.time() - start_time_run, 2) - print(f"Score for {fn_type}: {score}/{n_tries}, took {elapsed_time_run} seconds") - - elapsed_time += elapsed_time_run - total_score += score - - print(f"\nMEMGPT VERSION: {config.letta_version}") - print(f"CONTEXT WINDOW: {config.default_llm_config.context_window}") - print(f"MODEL WRAPPER: {config.default_llm_config.model_wrapper}") - print(f"PRESET: {config.preset}") - print(f"PERSONA: {config.persona}") - print(f"HUMAN: {config.human}") - - print( - # f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds at average of {round(total_tokens_accumulated/elapsed_time, 2)} t/s\n" - f"\n\t-> Total score: {total_score}/{len(PROMPTS) * n_tries}, took {elapsed_time} seconds\n" - ) diff --git a/letta/benchmark/constants.py b/letta/benchmark/constants.py deleted file mode 100644 index 755fdce51..000000000 --- a/letta/benchmark/constants.py +++ /dev/null @@ -1,14 +0,0 @@ -# Basic -TRIES = 3 -AGENT_NAME = "benchmark" -PERSONA = "sam_pov" -HUMAN = "cs_phd" - -# Prompts -PROMPTS = { - "core_memory_replace": "Hey there, my name is John, what is yours?", - "core_memory_append": "I want you to remember that I like soccers for later.", - "conversation_search": "Do you remember when I talked about bananas?", - "archival_memory_insert": "Can you make sure to remember that I like programming for me so you can look it up later?", - "archival_memory_search": "Can you retrieve information about the war?", -} diff --git a/letta/cli/cli.py b/letta/cli/cli.py index a89d5266d..47e86509a 100644 --- a/letta/cli/cli.py +++ b/letta/cli/cli.py @@ -1,37 +1,15 @@ -import logging import sys from enum import Enum from typing import Annotated, Optional -import questionary import typer -import letta.utils as utils -from letta import create_client -from letta.agent import Agent, save_agent -from letta.config import LettaConfig -from letta.constants import CLI_WARNING_PREFIX, CORE_MEMORY_BLOCK_CHAR_LIMIT, LETTA_DIR, MIN_CONTEXT_WINDOW -from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL from letta.log import get_logger -from letta.schemas.enums import OptionState -from letta.schemas.memory import ChatMemory, Memory - -# from letta.interface import CLIInterface as interface # for printing to terminal from letta.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal -from letta.utils import open_folder_in_explorer, printd logger = get_logger(__name__) -def open_folder(): - """Open a folder viewer of the Letta home directory""" - try: - print(f"Opening home folder: {LETTA_DIR}") - open_folder_in_explorer(LETTA_DIR) - except Exception as e: - print(f"Failed to open folder with system viewer, error:\n{e}") - - class ServerChoice(Enum): rest_api = "rest" ws_api = "websocket" @@ -51,14 +29,6 @@ def server( if type == ServerChoice.rest_api: pass - # if LettaConfig.exists(): - # config = LettaConfig.load() - # MetadataStore(config) - # _ = create_client() # triggers user creation - # else: - # typer.secho(f"No configuration exists. Run letta configure before starting the server.", fg=typer.colors.RED) - # sys.exit(1) - try: from letta.server.rest_api.app import start_server @@ -73,292 +43,6 @@ def server( raise NotImplementedError("WS suppport deprecated") -def run( - persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None, - agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None, - human: Annotated[Optional[str], typer.Option(help="Specify human")] = None, - system: Annotated[Optional[str], typer.Option(help="Specify system prompt (raw text)")] = None, - system_file: Annotated[Optional[str], typer.Option(help="Specify raw text file containing system prompt")] = None, - # model flags - model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None, - model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None, - model_endpoint: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint")] = None, - model_endpoint_type: Annotated[Optional[str], typer.Option(help="Specify the LLM model endpoint type")] = None, - context_window: Annotated[ - Optional[int], typer.Option(help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)") - ] = None, - core_memory_limit: Annotated[ - Optional[int], typer.Option(help="The character limit to each core-memory section (human/persona).") - ] = CORE_MEMORY_BLOCK_CHAR_LIMIT, - # other - first: Annotated[bool, typer.Option(help="Use --first to send the first message in the sequence")] = False, - strip_ui: Annotated[bool, typer.Option(help="Remove all the bells and whistles in CLI output (helpful for testing)")] = False, - debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False, - no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False, - yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False, - # streaming - stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False, - # whether or not to put the inner thoughts inside the function args - no_content: Annotated[ - OptionState, typer.Option(help="Set to 'yes' for LLM APIs that omit the `content` field during tool calling") - ] = OptionState.DEFAULT, -): - """Start chatting with an Letta agent - - Example usage: `letta run --agent myagent --data-source mydata --persona mypersona --human myhuman --model gpt-3.5-turbo` - - :param persona: Specify persona - :param agent: Specify agent name (will load existing state if the agent exists, or create a new one with that name) - :param human: Specify human - :param model: Specify the LLM model - - """ - - # setup logger - # TODO: remove Utils Debug after global logging is complete. - utils.DEBUG = debug - # TODO: add logging command line options for runtime log level - - from letta.server.server import logger as server_logger - - if debug: - logger.setLevel(logging.DEBUG) - server_logger.setLevel(logging.DEBUG) - else: - logger.setLevel(logging.CRITICAL) - server_logger.setLevel(logging.CRITICAL) - - # load config file - config = LettaConfig.load() - - # read user id from config - client = create_client() - - # determine agent to use, if not provided - if not yes and not agent: - agents = client.list_agents() - agents = [a.name for a in agents] - - if len(agents) > 0: - print() - select_agent = questionary.confirm("Would you like to select an existing agent?").ask() - if select_agent is None: - raise KeyboardInterrupt - if select_agent: - agent = questionary.select("Select agent:", choices=agents).ask() - - # create agent config - if agent: - agent_id = client.get_agent_id(agent) - agent_state = client.get_agent(agent_id) - else: - agent_state = None - human = human if human else config.human - persona = persona if persona else config.persona - if agent and agent_state: # use existing agent - typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN) - printd("Loading agent state:", agent_state.id) - printd("Agent state:", agent_state.name) - # printd("State path:", agent_config.save_state_dir()) - # printd("Persistent manager path:", agent_config.save_persistence_manager_dir()) - # printd("Index path:", agent_config.save_agent_index_dir()) - # TODO: load prior agent state - - # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) - if model and model != agent_state.llm_config.model: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model {agent_state.llm_config.model} with {model}", fg=typer.colors.YELLOW - ) - agent_state.llm_config.model = model - if context_window is not None and int(context_window) != agent_state.llm_config.context_window: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing context window {agent_state.llm_config.context_window} with {context_window}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.context_window = context_window - if model_wrapper and model_wrapper != agent_state.llm_config.model_wrapper: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model wrapper {agent_state.llm_config.model_wrapper} with {model_wrapper}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_wrapper = model_wrapper - if model_endpoint and model_endpoint != agent_state.llm_config.model_endpoint: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint {agent_state.llm_config.model_endpoint} with {model_endpoint}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_endpoint = model_endpoint - if model_endpoint_type and model_endpoint_type != agent_state.llm_config.model_endpoint_type: - typer.secho( - f"{CLI_WARNING_PREFIX}Overriding existing model endpoint type {agent_state.llm_config.model_endpoint_type} with {model_endpoint_type}", - fg=typer.colors.YELLOW, - ) - agent_state.llm_config.model_endpoint_type = model_endpoint_type - - # NOTE: commented out because this seems dangerous - instead users should use /systemswap when in the CLI - # # user specified a new system prompt - # if system: - # # NOTE: agent_state.system is the ORIGINAL system prompt, - # # whereas agent_state.state["system"] is the LATEST system prompt - # existing_system_prompt = agent_state.state["system"] if "system" in agent_state.state else None - # if existing_system_prompt != system: - # # override - # agent_state.state["system"] = system - - # Update the agent with any overrides - agent_state = client.update_agent( - agent_id=agent_state.id, - name=agent_state.name, - llm_config=agent_state.llm_config, - embedding_config=agent_state.embedding_config, - ) - - # create agent - letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user) - - else: # create new agent - # create new agent config: override defaults with args if provided - typer.secho("\n🧬 Creating new agent...", fg=typer.colors.WHITE) - - agent_name = agent if agent else utils.create_random_username() - - # create agent - client = create_client() - - # choose from list of llm_configs - llm_configs = client.list_llm_configs() - llm_options = [llm_config.model for llm_config in llm_configs] - llm_choices = [questionary.Choice(title=llm_config.pretty_print(), value=llm_config) for llm_config in llm_configs] - - # select model - if len(llm_options) == 0: - raise ValueError("No LLM models found. Please enable a provider.") - elif len(llm_options) == 1: - llm_model_name = llm_options[0] - else: - llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model - llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0] - - # option to override context window - if llm_config.context_window is not None: - context_window_validator = lambda x: x.isdigit() and int(x) > MIN_CONTEXT_WINDOW and int(x) <= llm_config.context_window - context_window_input = questionary.text( - "Select LLM context window limit (hit enter for default):", - default=str(llm_config.context_window), - validate=context_window_validator, - ).ask() - if context_window_input is not None: - llm_config.context_window = int(context_window_input) - else: - sys.exit(1) - - # choose form list of embedding configs - embedding_configs = client.list_embedding_configs() - embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] - - embedding_choices = [ - questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs - ] - - # select model - if len(embedding_options) == 0: - raise ValueError("No embedding models found. Please enable a provider.") - elif len(embedding_options) == 1: - embedding_model_name = embedding_options[0] - else: - embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model - embedding_config = [ - embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name - ][0] - - human_obj = client.get_human(client.get_human_id(name=human)) - persona_obj = client.get_persona(client.get_persona_id(name=persona)) - if human_obj is None: - typer.secho(f"Couldn't find human {human} in database, please run `letta add human`", fg=typer.colors.RED) - sys.exit(1) - if persona_obj is None: - typer.secho(f"Couldn't find persona {persona} in database, please run `letta add persona`", fg=typer.colors.RED) - sys.exit(1) - - if system_file: - try: - with open(system_file, "r", encoding="utf-8") as file: - system = file.read().strip() - printd("Loaded system file successfully.") - except FileNotFoundError: - typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED) - system_prompt = system if system else None - - memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit) - metadata = {"human": human_obj.template_name, "persona": persona_obj.template_name} - - typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.template_name}'", fg=typer.colors.WHITE) - typer.secho(f"-> 🧑 Using human profile: '{human_obj.template_name}'", fg=typer.colors.WHITE) - - # add tools - agent_state = client.create_agent( - name=agent_name, - system=system_prompt, - embedding_config=embedding_config, - llm_config=llm_config, - memory=memory, - metadata=metadata, - ) - assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}" - typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE) - - letta_agent = Agent( - interface=interface(), - agent_state=client.get_agent(agent_state.id), - # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now - first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False, - user=client.user, - ) - save_agent(agent=letta_agent) - typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN) - - # start event loop - from letta.main import run_agent_loop - - print() # extra space - run_agent_loop( - letta_agent=letta_agent, - config=config, - first=first, - no_verify=no_verify, - stream=stream, - ) # TODO: add back no_verify - - -def delete_agent( - agent_name: Annotated[str, typer.Option(help="Specify agent to delete")], -): - """Delete an agent from the database""" - # use client ID is no user_id provided - config = LettaConfig.load() - MetadataStore(config) - client = create_client() - agent = client.get_agent_by_name(agent_name) - if not agent: - typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED) - sys.exit(1) - - confirm = questionary.confirm(f"Are you sure you want to delete agent '{agent_name}' (id={agent.id})?", default=False).ask() - if confirm is None: - raise KeyboardInterrupt - if not confirm: - typer.secho(f"Cancelled agent deletion '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN) - return - - try: - # delete the agent - client.delete_agent(agent.id) - typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN) - except Exception: - typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED) - sys.exit(1) - - def version() -> str: import letta diff --git a/letta/cli/cli_config.py b/letta/cli/cli_config.py deleted file mode 100644 index a17bf476a..000000000 --- a/letta/cli/cli_config.py +++ /dev/null @@ -1,227 +0,0 @@ -import ast -import os -from enum import Enum -from typing import Annotated, List, Optional - -import questionary -import typer -from prettytable.colortable import ColorTable, Themes -from tqdm import tqdm - -import letta.helpers.datetime_helpers - -app = typer.Typer() - - -@app.command() -def configure(): - """Updates default Letta configurations - - This function and quickstart should be the ONLY place where LettaConfig.save() is called - """ - print("`letta configure` has been deprecated. Please see documentation on configuration, and run `letta run` instead.") - - -class ListChoice(str, Enum): - agents = "agents" - humans = "humans" - personas = "personas" - sources = "sources" - - -@app.command() -def list(arg: Annotated[ListChoice, typer.Argument]): - from letta.client.client import create_client - - client = create_client() - table = ColorTable(theme=Themes.OCEAN) - if arg == ListChoice.agents: - """List all agents""" - table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"] - for agent in tqdm(client.list_agents()): - # TODO: add this function - sources = client.list_attached_sources(agent_id=agent.id) - source_names = [source.name for source in sources if source is not None] - table.add_row( - [ - agent.name, - agent.llm_config.model, - agent.embedding_config.embedding_model, - agent.embedding_config.embedding_dim, - agent.memory.get_block("persona").value[:100] + "...", - agent.memory.get_block("human").value[:100] + "...", - ",".join(source_names), - letta.helpers.datetime_helpers.format_datetime(agent.created_at), - ] - ) - print(table) - elif arg == ListChoice.humans: - """List all humans""" - table.field_names = ["Name", "Text"] - for human in client.list_humans(): - table.add_row([human.template_name, human.value.replace("\n", "")[:100]]) - elif arg == ListChoice.personas: - """List all personas""" - table.field_names = ["Name", "Text"] - for persona in client.list_personas(): - table.add_row([persona.template_name, persona.value.replace("\n", "")[:100]]) - print(table) - elif arg == ListChoice.sources: - """List all data sources""" - - # create table - table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"] - # TODO: eventually look accross all storage connections - # TODO: add data source stats - # TODO: connect to agents - - # get all sources - for source in client.list_sources(): - # get attached agents - table.add_row( - [ - source.name, - source.description, - source.embedding_config.embedding_model, - source.embedding_config.embedding_dim, - letta.helpers.datetime_helpers.format_datetime(source.created_at), - ] - ) - - print(table) - else: - raise ValueError(f"Unknown argument {arg}") - return table - - -@app.command() -def add_tool( - filename: str = typer.Option(..., help="Path to the Python file containing the function"), - name: Optional[str] = typer.Option(None, help="Name of the tool"), - update: bool = typer.Option(True, help="Update the tool if it already exists"), - tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"), -): - """Add or update a tool from a Python file.""" - from letta.client.client import create_client - - client = create_client() - - # 1. Parse the Python file - with open(filename, "r", encoding="utf-8") as file: - source_code = file.read() - - # 2. Parse the source code to extract the function - # Note: here we assume it is one function only in the file. - module = ast.parse(source_code) - func_def = None - for node in module.body: - if isinstance(node, ast.FunctionDef): - func_def = node - break - - if not func_def: - raise ValueError("No function found in the provided file") - - # 3. Compile the function to make it callable - # Explanation courtesy of GPT-4: - # Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object - # ast.Module creates a module node containing the function definition (func_def) - # compile converts the AST into a code object that can be executed by the Python interpreter - # The exec function executes the compiled code object in the current context, - # effectively defining the function within the current namespace - exec(compile(ast.Module([func_def], []), filename, "exec")) - # Retrieve the function object by evaluating its name in the current namespace - # eval looks up the function name in the current scope and returns the function object - func = eval(func_def.name) - - # 4. Add or update the tool - tool = client.create_or_update_tool(func=func, tags=tags, update=update) - print(f"Tool {tool.name} added successfully") - - -@app.command() -def list_tools(): - """List all available tools.""" - from letta.client.client import create_client - - client = create_client() - - tools = client.list_tools() - for tool in tools: - print(f"Tool: {tool.name}") - - -@app.command() -def add( - option: str, # [human, persona] - name: Annotated[str, typer.Option(help="Name of human/persona")], - text: Annotated[Optional[str], typer.Option(help="Text of human/persona")] = None, - filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None, -): - """Add a person/human""" - from letta.client.client import create_client - - client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS")) - if filename: # read from file - assert text is None, "Cannot specify both text and filename" - with open(filename, "r", encoding="utf-8") as f: - text = f.read() - else: - assert text is not None, "Must specify either text or filename" - if option == "persona": - persona_id = client.get_persona_id(name) - if persona_id: - client.get_persona(persona_id) - # config if user wants to overwrite - if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask(): - return - client.update_persona(persona_id, text=text) - else: - client.create_persona(name=name, text=text) - - elif option == "human": - human_id = client.get_human_id(name) - if human_id: - human = client.get_human(human_id) - # config if user wants to overwrite - if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask(): - return - client.update_human(human_id, text=text) - else: - human = client.create_human(name=name, text=text) - else: - raise ValueError(f"Unknown kind {option}") - - -@app.command() -def delete(option: str, name: str): - """Delete a source from the archival memory.""" - from letta.client.client import create_client - - client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY")) - try: - # delete from metadata - if option == "source": - # delete metadata - source_id = client.get_source_id(name) - assert source_id is not None, f"Source {name} does not exist" - client.delete_source(source_id) - elif option == "agent": - agent_id = client.get_agent_id(name) - assert agent_id is not None, f"Agent {name} does not exist" - client.delete_agent(agent_id=agent_id) - elif option == "human": - human_id = client.get_human_id(name) - assert human_id is not None, f"Human {name} does not exist" - client.delete_human(human_id) - elif option == "persona": - persona_id = client.get_persona_id(name) - assert persona_id is not None, f"Persona {name} does not exist" - client.delete_persona(persona_id) - else: - raise ValueError(f"Option {option} not implemented") - - typer.secho(f"Deleted {option} '{name}'", fg=typer.colors.GREEN) - - except Exception as e: - typer.secho(f"Failed to delete {option}'{name}'\n{e}", fg=typer.colors.RED) diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py index 4c420bfa7..a50c525ed 100644 --- a/letta/cli/cli_load.py +++ b/letta/cli/cli_load.py @@ -8,61 +8,9 @@ letta load --name [ADDITIONAL ARGS] """ -import uuid -from typing import Annotated, List, Optional - -import questionary import typer -from letta import create_client -from letta.data_sources.connectors import DirectoryConnector - app = typer.Typer() default_extensions = "txt,md,pdf" - - -@app.command("directory") -def load_directory( - name: Annotated[str, typer.Option(help="Name of dataset to load.")], - input_dir: Annotated[Optional[str], typer.Option(help="Path to directory containing dataset.")] = None, - input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [], - recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False, - extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions, - user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove - description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None, -): - client = create_client() - - # create connector - connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions) - - # choose form list of embedding configs - embedding_configs = client.list_embedding_configs() - embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] - - embedding_choices = [ - questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs - ] - - # select model - if len(embedding_options) == 0: - raise ValueError("No embedding models found. Please enable a provider.") - elif len(embedding_options) == 1: - embedding_model_name = embedding_options[0] - else: - embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model - embedding_config = [ - embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name - ][0] - - # create source - source = client.create_source(name=name, embedding_config=embedding_config) - - # load data - try: - client.load_data(connector, source_name=name) - except Exception as e: - typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) - client.delete_source(source.id) diff --git a/letta/client/client.py b/letta/client/client.py index 90e39400b..d71aae62f 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1,27 +1,19 @@ -import asyncio -import logging import sys import time from typing import Callable, Dict, Generator, List, Optional, Union import requests -import letta.utils from letta.constants import ADMIN_PREFIX, BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, FUNCTION_RETURN_CHAR_LIMIT from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code -from letta.orm.errors import NoResultFound from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona from letta.schemas.embedding_config import EmbeddingConfig # new schemas from letta.schemas.enums import JobStatus, MessageRole -from letta.schemas.environment_variables import ( - SandboxEnvironmentVariable, - SandboxEnvironmentVariableCreate, - SandboxEnvironmentVariableUpdate, -) +from letta.schemas.environment_variables import SandboxEnvironmentVariable from letta.schemas.file import FileMetadata from letta.schemas.job import Job from letta.schemas.letta_message import LettaMessage, LettaMessageUnion @@ -35,11 +27,10 @@ from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.response_format import ResponseFormatUnion from letta.schemas.run import Run -from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig, SandboxConfigCreate, SandboxConfigUpdate +from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfig from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate from letta.schemas.tool_rule import BaseToolRule -from letta.server.rest_api.interface import QueuingInterface from letta.utils import get_human_text, get_persona_text # Print deprecation notice in yellow when module is imported @@ -53,13 +44,6 @@ print( ) -def create_client(base_url: Optional[str] = None, token: Optional[str] = None): - if base_url is None: - return LocalClient() - else: - return RESTClient(base_url, token) - - class AbstractClient(object): def __init__( self, @@ -2229,1539 +2213,3 @@ class RESTClient(AbstractClient): if response.status_code != 200: raise ValueError(f"Failed to get tags: {response.text}") return response.json() - - -class LocalClient(AbstractClient): - """ - A local client for Letta, which corresponds to a single user. - - Attributes: - user_id (str): The user ID. - debug (bool): Whether to print debug information. - interface (QueuingInterface): The interface for the client. - server (SyncServer): The server for the client. - """ - - def __init__( - self, - user_id: Optional[str] = None, - org_id: Optional[str] = None, - debug: bool = False, - default_llm_config: Optional[LLMConfig] = None, - default_embedding_config: Optional[EmbeddingConfig] = None, - ): - """ - Initializes a new instance of Client class. - - Args: - user_id (str): The user ID. - debug (bool): Whether to print debug information. - """ - - from letta.server.server import SyncServer - - # set logging levels - letta.utils.DEBUG = debug - logging.getLogger().setLevel(logging.CRITICAL) - - # save default model config - self._default_llm_config = default_llm_config - self._default_embedding_config = default_embedding_config - - # create server - self.interface = QueuingInterface(debug=debug) - self.server = SyncServer(default_interface_factory=lambda: self.interface) - - # save org_id that `LocalClient` is associated with - if org_id: - self.org_id = org_id - else: - self.org_id = self.server.organization_manager.DEFAULT_ORG_ID - # save user_id that `LocalClient` is associated with - if user_id: - self.user_id = user_id - else: - # get default user - self.user_id = self.server.user_manager.DEFAULT_USER_ID - - self.user = self.server.user_manager.get_user_or_default(self.user_id) - self.organization = self.server.get_organization_or_default(self.org_id) - - # agents - def list_agents( - self, - query_text: Optional[str] = None, - tags: Optional[List[str]] = None, - limit: int = 100, - before: Optional[str] = None, - after: Optional[str] = None, - ) -> List[AgentState]: - self.interface.clear() - - return self.server.agent_manager.list_agents( - actor=self.user, tags=tags, query_text=query_text, limit=limit, before=before, after=after - ) - - def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: - """ - Check if an agent exists - - Args: - agent_id (str): ID of the agent - agent_name (str): Name of the agent - - Returns: - exists (bool): `True` if the agent exists, `False` otherwise - """ - - if not (agent_id or agent_name): - raise ValueError(f"Either agent_id or agent_name must be provided") - if agent_id and agent_name: - raise ValueError(f"Only one of agent_id or agent_name can be provided") - existing = self.list_agents() - if agent_id: - return str(agent_id) in [str(agent.id) for agent in existing] - else: - return agent_name in [str(agent.name) for agent in existing] - - def create_agent( - self, - name: Optional[str] = None, - # agent config - agent_type: Optional[AgentType] = AgentType.memgpt_agent, - # model configs - embedding_config: EmbeddingConfig = None, - llm_config: LLMConfig = None, - # memory - memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)), - block_ids: Optional[List[str]] = None, - # TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API) - # memory_blocks=[ - # {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000}, - # {"label": "persona", "value": get_persona_text(DEFAULT_PERSONA), "limit": 5000}, - # ], - # system - system: Optional[str] = None, - # tools - tool_ids: Optional[List[str]] = None, - tool_rules: Optional[List[BaseToolRule]] = None, - include_base_tools: Optional[bool] = True, - include_multi_agent_tools: bool = False, - include_base_tool_rules: bool = True, - # metadata - metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, - description: Optional[str] = None, - initial_message_sequence: Optional[List[Message]] = None, - tags: Optional[List[str]] = None, - message_buffer_autoclear: bool = False, - response_format: Optional[ResponseFormatUnion] = None, - ) -> AgentState: - """Create an agent - - Args: - name (str): Name of the agent - embedding_config (EmbeddingConfig): Embedding configuration - llm_config (LLMConfig): LLM configuration - memory_blocks (List[Dict]): List of configurations for the memory blocks (placed in core-memory) - system (str): System configuration - tools (List[str]): List of tools - tool_rules (Optional[List[BaseToolRule]]): List of tool rules - include_base_tools (bool): Include base tools - include_multi_agent_tools (bool): Include multi agent tools - metadata (Dict): Metadata - description (str): Description - tags (List[str]): Tags for filtering agents - - Returns: - agent_state (AgentState): State of the created agent - """ - # construct list of tools - tool_ids = tool_ids or [] - - # check if default configs are provided - assert embedding_config or self._default_embedding_config, f"Embedding config must be provided" - assert llm_config or self._default_llm_config, f"LLM config must be provided" - - # TODO: This should not happen here, we need to have clear separation between create/add blocks - for block in memory.get_blocks(): - self.server.block_manager.create_or_update_block(block, actor=self.user) - - # Also get any existing block_ids passed in - block_ids = block_ids or [] - - # create agent - # Create the base parameters - create_params = { - "description": description, - "metadata": metadata, - "memory_blocks": [], - "block_ids": [b.id for b in memory.get_blocks()] + block_ids, - "tool_ids": tool_ids, - "tool_rules": tool_rules, - "include_base_tools": include_base_tools, - "include_multi_agent_tools": include_multi_agent_tools, - "include_base_tool_rules": include_base_tool_rules, - "system": system, - "agent_type": agent_type, - "llm_config": llm_config if llm_config else self._default_llm_config, - "embedding_config": embedding_config if embedding_config else self._default_embedding_config, - "initial_message_sequence": initial_message_sequence, - "tags": tags, - "message_buffer_autoclear": message_buffer_autoclear, - "response_format": response_format, - } - - # Only add name if it's not None - if name is not None: - create_params["name"] = name - - agent_state = self.server.create_agent( - CreateAgent(**create_params), - actor=self.user, - ) - - # TODO: get full agent state - return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user) - - def update_agent( - self, - agent_id: str, - name: Optional[str] = None, - description: Optional[str] = None, - system: Optional[str] = None, - tool_ids: Optional[List[str]] = None, - tags: Optional[List[str]] = None, - metadata: Optional[Dict] = None, - llm_config: Optional[LLMConfig] = None, - embedding_config: Optional[EmbeddingConfig] = None, - message_ids: Optional[List[str]] = None, - response_format: Optional[ResponseFormatUnion] = None, - ): - """ - Update an existing agent - - Args: - agent_id (str): ID of the agent - name (str): Name of the agent - description (str): Description of the agent - system (str): System configuration - tools (List[str]): List of tools - metadata (Dict): Metadata - llm_config (LLMConfig): LLM configuration - embedding_config (EmbeddingConfig): Embedding configuration - message_ids (List[str]): List of message IDs - tags (List[str]): Tags for filtering agents - - Returns: - agent_state (AgentState): State of the updated agent - """ - # TODO: add the ability to reset linked block_ids - self.interface.clear() - agent_state = self.server.agent_manager.update_agent( - agent_id, - UpdateAgent( - name=name, - system=system, - tool_ids=tool_ids, - tags=tags, - description=description, - metadata=metadata, - llm_config=llm_config, - embedding_config=embedding_config, - message_ids=message_ids, - response_format=response_format, - ), - actor=self.user, - ) - return agent_state - - def get_tools_from_agent(self, agent_id: str) -> List[Tool]: - """ - Get tools from an existing agent. - - Args: - agent_id (str): ID of the agent - - Returns: - List[Tool]: A list of Tool objs - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user).tools - - def attach_tool(self, agent_id: str, tool_id: str) -> AgentState: - """ - Add tool to an existing agent - - Args: - agent_id (str): ID of the agent - tool_id (str): A tool id - - Returns: - agent_state (AgentState): State of the updated agent - """ - self.interface.clear() - agent_state = self.server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) - return agent_state - - def detach_tool(self, agent_id: str, tool_id: str) -> AgentState: - """ - Removes tools from an existing agent - - Args: - agent_id (str): ID of the agent - tool_id (str): The tool id - - Returns: - agent_state (AgentState): State of the updated agent - """ - self.interface.clear() - agent_state = self.server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=self.user) - return agent_state - - def rename_agent(self, agent_id: str, new_name: str) -> AgentState: - """ - Rename an agent - - Args: - agent_id (str): ID of the agent - new_name (str): New name for the agent - - Returns: - agent_state (AgentState): State of the updated agent - """ - return self.update_agent(agent_id, name=new_name) - - def delete_agent(self, agent_id: str) -> None: - """ - Delete an agent - - Args: - agent_id (str): ID of the agent to delete - """ - self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user) - - def get_agent_by_name(self, agent_name: str) -> AgentState: - """ - Get an agent by its name - - Args: - agent_name (str): Name of the agent - - Returns: - agent_state (AgentState): State of the agent - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user) - - def get_agent(self, agent_id: str) -> AgentState: - """ - Get an agent's state by its ID. - - Args: - agent_id (str): ID of the agent - - Returns: - agent_state (AgentState): State representation of the agent - """ - self.interface.clear() - return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) - - def get_agent_id(self, agent_name: str) -> Optional[str]: - """ - Get the ID of an agent by name (names are unique per user) - - Args: - agent_name (str): Name of the agent - - Returns: - agent_id (str): ID of the agent - """ - - self.interface.clear() - assert agent_name, f"Agent name must be provided" - - # TODO: Refactor this futher to not have downstream users expect Optionals - this should just error - try: - return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id - except NoResultFound: - return None - - # memory - def get_in_context_memory(self, agent_id: str) -> Memory: - """ - Get the in-context (i.e. core) memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - memory (Memory): In-context memory of the agent - """ - memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user) - return memory - - def get_core_memory(self, agent_id: str) -> Memory: - return self.get_in_context_memory(agent_id) - - def update_in_context_memory(self, agent_id: str, section: str, value: Union[List[str], str]) -> Memory: - """ - Update the in-context memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - memory (Memory): The updated in-context memory of the agent - - """ - # TODO: implement this (not sure what it should look like) - memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user) - return memory - - def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: - """ - Get a summary of the archival memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - summary (ArchivalMemorySummary): Summary of the archival memory - - """ - return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user) - - def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: - """ - Get a summary of the recall memory of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - summary (RecallMemorySummary): Summary of the recall memory - """ - return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user) - - def get_in_context_messages(self, agent_id: str) -> List[Message]: - """ - Get in-context messages of an agent - - Args: - agent_id (str): ID of the agent - - Returns: - messages (List[Message]): List of in-context messages - """ - return self.server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=self.user) - - # agent interactions - - def send_messages( - self, - agent_id: str, - messages: List[Union[Message | MessageCreate]], - ): - """ - Send pre-packed messages to an agent. - - Args: - agent_id (str): ID of the agent - messages (List[Union[Message | MessageCreate]]): List of messages to send - - Returns: - response (LettaResponse): Response from the agent - """ - self.interface.clear() - usage = self.server.send_messages(actor=self.user, agent_id=agent_id, input_messages=messages) - - # format messages - return LettaResponse(messages=messages, usage=usage) - - def send_message( - self, - message: str, - role: str, - name: Optional[str] = None, - agent_id: Optional[str] = None, - agent_name: Optional[str] = None, - stream_steps: bool = False, - stream_tokens: bool = False, - ) -> LettaResponse: - """ - Send a message to an agent - - Args: - message (str): Message to send - role (str): Role of the message - agent_id (str): ID of the agent - name(str): Name of the sender - stream (bool): Stream the response (default: `False`) - - Returns: - response (LettaResponse): Response from the agent - """ - if not agent_id: - # lookup agent by name - assert agent_name, f"Either agent_id or agent_name must be provided" - agent_id = self.get_agent_id(agent_name=agent_name) - assert agent_id, f"Agent with name {agent_name} not found" - - if stream_steps or stream_tokens: - # TODO: implement streaming with stream=True/False - raise NotImplementedError - self.interface.clear() - - usage = self.server.send_messages( - actor=self.user, - agent_id=agent_id, - input_messages=[MessageCreate(role=MessageRole(role), content=message, name=name)], - ) - - ## TODO: need to make sure date/timestamp is propely passed - ## TODO: update self.interface.to_list() to return actual Message objects - ## here, the message objects will have faulty created_by timestamps - # messages = self.interface.to_list() - # for m in messages: - # assert isinstance(m, Message), f"Expected Message object, got {type(m)}" - # letta_messages = [] - # for m in messages: - # letta_messages += m.to_letta_messages() - # return LettaResponse(messages=letta_messages, usage=usage) - - # format messages - messages = self.interface.to_list() - letta_messages = [] - for m in messages: - letta_messages += m.to_letta_messages() - - return LettaResponse(messages=letta_messages, usage=usage) - - def user_message(self, agent_id: str, message: str) -> LettaResponse: - """ - Send a message to an agent as a user - - Args: - agent_id (str): ID of the agent - message (str): Message to send - - Returns: - response (LettaResponse): Response from the agent - """ - self.interface.clear() - return self.send_message(role="user", agent_id=agent_id, message=message) - - def run_command(self, agent_id: str, command: str) -> LettaResponse: - """ - Run a command on the agent - - Args: - agent_id (str): The agent ID - command (str): The command to run - - Returns: - LettaResponse: The response from the agent - - """ - self.interface.clear() - usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command) - - # NOTE: messages/usage may be empty, depending on the command - return LettaResponse(messages=self.interface.to_list(), usage=usage) - - # archival memory - - # humans / personas - - def get_block_id(self, name: str, label: str) -> str | None: - return None - - def create_human(self, name: str, text: str): - """ - Create a human block template (saved human string to pre-fill `ChatMemory`) - - Args: - name (str): Name of the human block - text (str): Text of the human block - - Returns: - human (Human): Human block - """ - return self.server.block_manager.create_or_update_block(Human(template_name=name, value=text), actor=self.user) - - def create_persona(self, name: str, text: str): - """ - Create a persona block template (saved persona string to pre-fill `ChatMemory`) - - Args: - name (str): Name of the persona block - text (str): Text of the persona block - - Returns: - persona (Persona): Persona block - """ - return self.server.block_manager.create_or_update_block(Persona(template_name=name, value=text), actor=self.user) - - def list_humans(self): - """ - List available human block templates - - Returns: - humans (List[Human]): List of human blocks - """ - return [] - - def list_personas(self) -> List[Persona]: - """ - List available persona block templates - - Returns: - personas (List[Persona]): List of persona blocks - """ - return [] - - def update_human(self, human_id: str, text: str): - """ - Update a human block template - - Args: - human_id (str): ID of the human block - text (str): Text of the human block - - Returns: - human (Human): Updated human block - """ - return self.server.block_manager.update_block( - block_id=human_id, block_update=UpdateHuman(value=text, is_template=True), actor=self.user - ) - - def update_persona(self, persona_id: str, text: str): - """ - Update a persona block template - - Args: - persona_id (str): ID of the persona block - text (str): Text of the persona block - - Returns: - persona (Persona): Updated persona block - """ - return self.server.block_manager.update_block( - block_id=persona_id, block_update=UpdatePersona(value=text, is_template=True), actor=self.user - ) - - def get_persona(self, id: str) -> Persona: - """ - Get a persona block template - - Args: - id (str): ID of the persona block - - Returns: - persona (Persona): Persona block - """ - assert id, f"Persona ID must be provided" - return Persona(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - - def get_human(self, id: str) -> Human: - """ - Get a human block template - - Args: - id (str): ID of the human block - - Returns: - human (Human): Human block - """ - assert id, f"Human ID must be provided" - return Human(**self.server.block_manager.get_block_by_id(id, actor=self.user).model_dump()) - - def get_persona_id(self, name: str) -> str | None: - """ - Get the ID of a persona block template - - Args: - name (str): Name of the persona block - - Returns: - id (str): ID of the persona block - """ - return None - - def get_human_id(self, name: str) -> str | None: - """ - Get the ID of a human block template - - Args: - name (str): Name of the human block - - Returns: - id (str): ID of the human block - """ - return None - - def delete_persona(self, id: str): - """ - Delete a persona block template - - Args: - id (str): ID of the persona block - """ - self.delete_block(id) - - def delete_human(self, id: str): - """ - Delete a human block template - - Args: - id (str): ID of the human block - """ - self.delete_block(id) - - # tools - def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: - tool_create = ToolCreate.from_langchain( - langchain_tool=langchain_tool, - additional_imports_module_attr_map=additional_imports_module_attr_map, - ) - return self.server.tool_manager.create_or_update_langchain_tool(tool_create=tool_create, actor=self.user) - - def load_composio_tool(self, action: "ActionType") -> Tool: - tool_create = ToolCreate.from_composio(action_name=action.name) - return self.server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=self.user) - - def create_tool( - self, - func, - tags: Optional[List[str]] = None, - description: Optional[str] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Create a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. - - Args: - func (callable): The function to create a tool for. - tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - description (str, optional): The description. - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): The created tool. - """ - # TODO: check if tool already exists - # TODO: how to load modules? - # parse source code/schema - source_code = parse_source_code(func) - source_type = "python" - name = func.__name__ # Initialize name using function's __name__ - if not tags: - tags = [] - - # call server function - return self.server.tool_manager.create_tool( - Tool( - source_type=source_type, - source_code=source_code, - name=name, - tags=tags, - description=description, - return_char_limit=return_char_limit, - ), - actor=self.user, - ) - - def create_or_update_tool( - self, - func, - tags: Optional[List[str]] = None, - description: Optional[str] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Creates or updates a tool. This stores the source code of function on the server, so that the server can execute the function and generate an OpenAI JSON schemas for it when using with an agent. - - Args: - func (callable): The function to create a tool for. - tags (Optional[List[str]], optional): Tags for the tool. Defaults to None. - description (str, optional): The description. - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): The created tool. - """ - source_code = parse_source_code(func) - source_type = "python" - if not tags: - tags = [] - - # call server function - return self.server.tool_manager.create_or_update_tool( - Tool( - source_type=source_type, - source_code=source_code, - tags=tags, - description=description, - return_char_limit=return_char_limit, - ), - actor=self.user, - ) - - def update_tool( - self, - id: str, - description: Optional[str] = None, - func: Optional[Callable] = None, - tags: Optional[List[str]] = None, - return_char_limit: int = FUNCTION_RETURN_CHAR_LIMIT, - ) -> Tool: - """ - Update a tool with provided parameters (name, func, tags) - - Args: - id (str): ID of the tool - func (callable): Function to wrap in a tool - tags (List[str]): Tags for the tool - return_char_limit (int): The character limit for the tool's return value. Defaults to FUNCTION_RETURN_CHAR_LIMIT. - - Returns: - tool (Tool): Updated tool - """ - update_data = { - "source_type": "python", # Always include source_type - "source_code": parse_source_code(func) if func else None, - "tags": tags, - "description": description, - "return_char_limit": return_char_limit, - } - - # Filter out any None values from the dictionary - update_data = {key: value for key, value in update_data.items() if value is not None} - - return self.server.tool_manager.update_tool_by_id(tool_id=id, tool_update=ToolUpdate(**update_data), actor=self.user) - - def list_tools(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[Tool]: - """ - List available tools for the user. - - Returns: - tools (List[Tool]): List of tools - """ - # Get the current event loop or create a new one if there isn't one - try: - loop = asyncio.get_event_loop() - if loop.is_running(): - # We're in an async context but can't await - use a new loop via run_coroutine_threadsafe - concurrent_future = asyncio.run_coroutine_threadsafe( - self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit), loop - ) - return concurrent_future.result() - else: - # We have a loop but it's not running - we can just run the coroutine - return loop.run_until_complete(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) - except RuntimeError: - # No running event loop - create a new one with asyncio.run - return asyncio.run(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) - - def get_tool(self, id: str) -> Optional[Tool]: - """ - Get a tool given its ID. - - Args: - id (str): ID of the tool - - Returns: - tool (Tool): Tool - """ - return self.server.tool_manager.get_tool_by_id(id, actor=self.user) - - def delete_tool(self, id: str): - """ - Delete a tool given the ID. - - Args: - id (str): ID of the tool - """ - return self.server.tool_manager.delete_tool_by_id(id, actor=self.user) - - def get_tool_id(self, name: str) -> Optional[str]: - """ - Get the ID of a tool from its name. The client will use the org_id it is configured with. - - Args: - name (str): Name of the tool - - Returns: - id (str): ID of the tool (`None` if not found) - """ - tool = self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user) - return tool.id if tool else None - - def list_attached_tools(self, agent_id: str) -> List[Tool]: - """ - List all tools attached to an agent. - - Args: - agent_id (str): ID of the agent - - Returns: - List[Tool]: List of tools attached to the agent - """ - return self.server.agent_manager.list_attached_tools(agent_id=agent_id, actor=self.user) - - def load_data(self, connector: DataConnector, source_name: str): - """ - Load data into a source - - Args: - connector (DataConnector): Data connector - source_name (str): Name of the source - """ - self.server.load_data(user_id=self.user_id, connector=connector, source_name=source_name) - - def load_file_to_source(self, filename: str, source_id: str, blocking=True): - """ - Load a file into a source - - Args: - filename (str): Name of the file - source_id (str): ID of the source - blocking (bool): Block until the job is complete - - Returns: - job (Job): Data loading job including job status and metadata - """ - job = Job( - user_id=self.user_id, - status=JobStatus.created, - metadata={"type": "embedding", "filename": filename, "source_id": source_id}, - ) - job = self.server.job_manager.create_job(pydantic_job=job, actor=self.user) - - # TODO: implement blocking vs. non-blocking - self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id, actor=self.user) - return job - - def delete_file_from_source(self, source_id: str, file_id: str) -> None: - self.server.source_manager.delete_file(file_id, actor=self.user) - - def get_job(self, job_id: str): - return self.server.job_manager.get_job_by_id(job_id=job_id, actor=self.user) - - def delete_job(self, job_id: str): - return self.server.job_manager.delete_job_by_id(job_id=job_id, actor=self.user) - - def list_jobs(self): - return self.server.job_manager.list_jobs(actor=self.user) - - def list_active_jobs(self): - return self.server.job_manager.list_jobs(actor=self.user, statuses=[JobStatus.created, JobStatus.running]) - - def create_source(self, name: str, embedding_config: Optional[EmbeddingConfig] = None) -> Source: - """ - Create a source - - Args: - name (str): Name of the source - - Returns: - source (Source): Created source - """ - assert embedding_config or self._default_embedding_config, f"Must specify embedding_config for source" - source = Source( - name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id - ) - return self.server.source_manager.create_source(source=source, actor=self.user) - - def delete_source(self, source_id: str): - """ - Delete a source - - Args: - source_id (str): ID of the source - """ - - # TODO: delete source data - self.server.delete_source(source_id=source_id, actor=self.user) - - def get_source(self, source_id: str) -> Source: - """ - Get a source given the ID. - - Args: - source_id (str): ID of the source - - Returns: - source (Source): Source - """ - return self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - - def get_source_id(self, source_name: str) -> str: - """ - Get the ID of a source - - Args: - source_name (str): Name of the source - - Returns: - source_id (str): ID of the source - """ - return self.server.source_manager.get_source_by_name(source_name=source_name, actor=self.user).id - - def attach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState: - """ - Attach a source to an agent - - Args: - agent_id (str): ID of the agent - source_id (str): ID of the source - source_name (str): Name of the source - """ - if source_name: - source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - source_id = source.id - - return self.server.agent_manager.attach_source(source_id=source_id, agent_id=agent_id, actor=self.user) - - def detach_source(self, agent_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None) -> AgentState: - """ - Detach a source from an agent by removing all `Passage` objects that were loaded from the source from archival memory. - Args: - agent_id (str): ID of the agent - source_id (str): ID of the source - source_name (str): Name of the source - Returns: - source (Source): Detached source - """ - if source_name: - source = self.server.source_manager.get_source_by_id(source_id=source_id, actor=self.user) - source_id = source.id - return self.server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=self.user) - - def list_sources(self) -> List[Source]: - """ - List available sources - - Returns: - sources (List[Source]): List of sources - """ - - return self.server.list_all_sources(actor=self.user) - - def list_attached_sources(self, agent_id: str) -> List[Source]: - """ - List sources attached to an agent - - Args: - agent_id (str): ID of the agent - - Returns: - sources (List[Source]): List of sources - """ - return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user) - - def list_files_from_source(self, source_id: str, limit: int = 1000, after: Optional[str] = None) -> List[FileMetadata]: - """ - List files from source. - - Args: - source_id (str): ID of the source - limit (int): The # of items to return - after (str): The cursor for fetching the next page - - Returns: - files (List[FileMetadata]): List of files - """ - return self.server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=self.user) - - def update_source(self, source_id: str, name: Optional[str] = None) -> Source: - """ - Update a source - - Args: - source_id (str): ID of the source - name (str): Name of the source - - Returns: - source (Source): Updated source - """ - # TODO should the arg here just be "source_update: Source"? - request = SourceUpdate(name=name) - return self.server.source_manager.update_source(source_id=source_id, source_update=request, actor=self.user) - - # archival memory - - def insert_archival_memory(self, agent_id: str, memory: str) -> List[Passage]: - """ - Insert archival memory into an agent - - Args: - agent_id (str): ID of the agent - memory (str): Memory string to insert - - Returns: - passages (List[Passage]): List of inserted passages - """ - return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user) - - def delete_archival_memory(self, agent_id: str, memory_id: str): - """ - Delete archival memory from an agent - - Args: - agent_id (str): ID of the agent - memory_id (str): ID of the memory - """ - self.server.delete_archival_memory(memory_id=memory_id, actor=self.user) - - def get_archival_memory( - self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 - ) -> List[Passage]: - """ - Get archival memory from an agent with pagination. - - Args: - agent_id (str): ID of the agent - before (str): Get memories before a certain time - after (str): Get memories after a certain time - limit (int): Limit number of memories - - Returns: - passages (List[Passage]): List of passages - """ - - return self.server.get_agent_archival(user_id=self.user_id, agent_id=agent_id, limit=limit) - - # recall memory - - def get_messages( - self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000 - ) -> List[LettaMessage]: - """ - Get messages from an agent with pagination. - - Args: - agent_id (str): ID of the agent - before (str): Get messages before a certain time - after (str): Get messages after a certain time - limit (int): Limit number of messages - - Returns: - messages (List[Message]): List of messages - """ - - self.interface.clear() - return self.server.get_agent_recall( - user_id=self.user_id, - agent_id=agent_id, - before=before, - after=after, - limit=limit, - reverse=True, - return_message_object=False, - ) - - def list_blocks(self, label: Optional[str] = None, templates_only: Optional[bool] = True) -> List[Block]: - """ - List available blocks - - Args: - label (str): Label of the block - templates_only (bool): List only templates - - Returns: - blocks (List[Block]): List of blocks - """ - return [] - - def create_block( - self, label: str, value: str, limit: Optional[int] = None, template_name: Optional[str] = None, is_template: bool = False - ) -> Block: # - """ - Create a block - - Args: - label (str): Label of the block - name (str): Name of the block - text (str): Text of the block - limit (int): Character of the block - - Returns: - block (Block): Created block - """ - block = Block(label=label, template_name=template_name, value=value, is_template=is_template) - if limit: - block.limit = limit - return self.server.block_manager.create_or_update_block(block, actor=self.user) - - def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None, limit: Optional[int] = None) -> Block: - """ - Update a block - - Args: - block_id (str): ID of the block - name (str): Name of the block - text (str): Text of the block - - Returns: - block (Block): Updated block - """ - return self.server.block_manager.update_block( - block_id=block_id, - block_update=BlockUpdate(template_name=name, value=text, limit=limit if limit else self.get_block(block_id).limit), - actor=self.user, - ) - - def get_block(self, block_id: str) -> Block: - """ - Get a block - - Args: - block_id (str): ID of the block - - Returns: - block (Block): Block - """ - return self.server.block_manager.get_block_by_id(block_id, actor=self.user) - - def delete_block(self, id: str) -> Block: - """ - Delete a block - - Args: - id (str): ID of the block - - Returns: - block (Block): Deleted block - """ - return self.server.block_manager.delete_block(id, actor=self.user) - - def set_default_llm_config(self, llm_config: LLMConfig): - """ - Set the default LLM configuration for agents. - - Args: - llm_config (LLMConfig): LLM configuration - """ - self._default_llm_config = llm_config - - def set_default_embedding_config(self, embedding_config: EmbeddingConfig): - """ - Set the default embedding configuration for agents. - - Args: - embedding_config (EmbeddingConfig): Embedding configuration - """ - self._default_embedding_config = embedding_config - - def list_llm_configs(self) -> List[LLMConfig]: - """ - List available LLM configurations - - Returns: - configs (List[LLMConfig]): List of LLM configurations - """ - return self.server.list_llm_models(actor=self.user) - - def list_embedding_configs(self) -> List[EmbeddingConfig]: - """ - List available embedding configurations - - Returns: - configs (List[EmbeddingConfig]): List of embedding configurations - """ - return self.server.list_embedding_models(actor=self.user) - - def create_org(self, name: Optional[str] = None) -> Organization: - return self.server.organization_manager.create_organization(pydantic_org=Organization(name=name)) - - def list_orgs(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: - return self.server.organization_manager.list_organizations(limit=limit, after=after) - - def delete_org(self, org_id: str) -> Organization: - return self.server.organization_manager.delete_organization_by_id(org_id=org_id) - - def create_sandbox_config(self, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: - """ - Create a new sandbox configuration. - """ - config_create = SandboxConfigCreate(config=config) - return self.server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create=config_create, actor=self.user) - - def update_sandbox_config(self, sandbox_config_id: str, config: Union[LocalSandboxConfig, E2BSandboxConfig]) -> SandboxConfig: - """ - Update an existing sandbox configuration. - """ - sandbox_update = SandboxConfigUpdate(config=config) - return self.server.sandbox_config_manager.update_sandbox_config( - sandbox_config_id=sandbox_config_id, sandbox_update=sandbox_update, actor=self.user - ) - - def delete_sandbox_config(self, sandbox_config_id: str) -> None: - """ - Delete a sandbox configuration. - """ - return self.server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id=sandbox_config_id, actor=self.user) - - def list_sandbox_configs(self, limit: int = 50, after: Optional[str] = None) -> List[SandboxConfig]: - """ - List all sandbox configurations. - """ - return self.server.sandbox_config_manager.list_sandbox_configs(actor=self.user, limit=limit, after=after) - - def create_sandbox_env_var( - self, sandbox_config_id: str, key: str, value: str, description: Optional[str] = None - ) -> SandboxEnvironmentVariable: - """ - Create a new environment variable for a sandbox configuration. - """ - env_var_create = SandboxEnvironmentVariableCreate(key=key, value=value, description=description) - return self.server.sandbox_config_manager.create_sandbox_env_var( - env_var_create=env_var_create, sandbox_config_id=sandbox_config_id, actor=self.user - ) - - def update_sandbox_env_var( - self, env_var_id: str, key: Optional[str] = None, value: Optional[str] = None, description: Optional[str] = None - ) -> SandboxEnvironmentVariable: - """ - Update an existing environment variable. - """ - env_var_update = SandboxEnvironmentVariableUpdate(key=key, value=value, description=description) - return self.server.sandbox_config_manager.update_sandbox_env_var( - env_var_id=env_var_id, env_var_update=env_var_update, actor=self.user - ) - - def delete_sandbox_env_var(self, env_var_id: str) -> None: - """ - Delete an environment variable by its ID. - """ - return self.server.sandbox_config_manager.delete_sandbox_env_var(env_var_id=env_var_id, actor=self.user) - - def list_sandbox_env_vars( - self, sandbox_config_id: str, limit: int = 50, after: Optional[str] = None - ) -> List[SandboxEnvironmentVariable]: - """ - List all environment variables associated with a sandbox configuration. - """ - return self.server.sandbox_config_manager.list_sandbox_env_vars( - sandbox_config_id=sandbox_config_id, actor=self.user, limit=limit, after=after - ) - - def update_agent_memory_block_label(self, agent_id: str, current_label: str, new_label: str) -> Memory: - """Rename a block in the agent's core memory - - Args: - agent_id (str): The agent ID - current_label (str): The current label of the block - new_label (str): The new label of the block - - Returns: - memory (Memory): The updated memory - """ - block = self.get_agent_memory_block(agent_id, current_label) - return self.update_block(block.id, label=new_label) - - def get_agent_memory_blocks(self, agent_id: str) -> List[Block]: - """ - Get all the blocks in the agent's core memory - - Args: - agent_id (str): The agent ID - - Returns: - blocks (List[Block]): The blocks in the agent's core memory - """ - agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user) - return agent.memory.blocks - - def get_agent_memory_block(self, agent_id: str, label: str) -> Block: - """ - Get a block in the agent's core memory by its label - - Args: - agent_id (str): The agent ID - label (str): The label in the agent's core memory - - Returns: - block (Block): The block corresponding to the label - """ - return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user) - - def update_agent_memory_block( - self, - agent_id: str, - label: str, - value: Optional[str] = None, - limit: Optional[int] = None, - ): - """ - Update a block in the agent's core memory by specifying its label - - Args: - agent_id (str): The agent ID - label (str): The label of the block - value (str): The new value of the block - limit (int): The new limit of the block - - Returns: - block (Block): The updated block - """ - block = self.get_agent_memory_block(agent_id, label) - data = {} - if value: - data["value"] = value - if limit: - data["limit"] = limit - return self.server.block_manager.update_block(block.id, actor=self.user, block_update=BlockUpdate(**data)) - - def update_block( - self, - block_id: str, - label: Optional[str] = None, - value: Optional[str] = None, - limit: Optional[int] = None, - ): - """ - Update a block given the ID with the provided fields - - Args: - block_id (str): ID of the block - label (str): Label to assign to the block - value (str): Value to assign to the block - limit (int): Token limit to assign to the block - - Returns: - block (Block): Updated block - """ - data = {} - if value: - data["value"] = value - if limit: - data["limit"] = limit - if label: - data["label"] = label - return self.server.block_manager.update_block(block_id, actor=self.user, block_update=BlockUpdate(**data)) - - def attach_block(self, agent_id: str, block_id: str) -> AgentState: - """ - Attach a block to an agent. - - Args: - agent_id (str): ID of the agent - block_id (str): ID of the block to attach - """ - return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user) - - def detach_block(self, agent_id: str, block_id: str) -> AgentState: - """ - Detach a block from an agent. - - Args: - agent_id (str): ID of the agent - block_id (str): ID of the block to detach - """ - return self.server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=self.user) - - def get_run_messages( - self, - run_id: str, - before: Optional[str] = None, - after: Optional[str] = None, - limit: Optional[int] = 100, - ascending: bool = True, - role: Optional[MessageRole] = None, - ) -> List[LettaMessageUnion]: - """ - Get messages associated with a job with filtering options. - - Args: - run_id: ID of the run - before: Cursor for pagination - after: Cursor for pagination - limit: Maximum number of messages to return - ascending: Sort order by creation time - role: Filter by message role (user/assistant/system/tool) - Returns: - List of messages matching the filter criteria - """ - params = { - "before": before, - "after": after, - "limit": limit, - "ascending": ascending, - "role": role, - } - - return self.server.job_manager.get_run_messages(run_id=run_id, actor=self.user, **params) - - def get_run_usage( - self, - run_id: str, - ) -> List[UsageStatistics]: - """ - Get usage statistics associated with a job. - - Args: - run_id (str): ID of the run - - Returns: - List[UsageStatistics]: List of usage statistics associated with the run - """ - usage = self.server.job_manager.get_job_usage(job_id=run_id, actor=self.user) - return [ - UsageStatistics(completion_tokens=stat.completion_tokens, prompt_tokens=stat.prompt_tokens, total_tokens=stat.total_tokens) - for stat in usage - ] - - def get_run(self, run_id: str) -> Run: - """ - Get a run by ID. - - Args: - run_id (str): ID of the run - - Returns: - run (Run): Run - """ - return self.server.job_manager.get_job_by_id(job_id=run_id, actor=self.user) - - def delete_run(self, run_id: str) -> None: - """ - Delete a run by ID. - - Args: - run_id (str): ID of the run - """ - return self.server.job_manager.delete_job_by_id(job_id=run_id, actor=self.user) - - def list_runs(self) -> List[Run]: - """ - List all runs. - - Returns: - runs (List[Run]): List of runs - """ - return self.server.job_manager.list_jobs(actor=self.user, job_type=JobType.RUN) - - def list_active_runs(self) -> List[Run]: - """ - List all active runs. - - Returns: - runs (List[Run]): List of active runs - """ - return self.server.job_manager.list_jobs(actor=self.user, job_type=JobType.RUN, statuses=[JobStatus.created, JobStatus.running]) - - def get_tags( - self, - after: Optional[str] = None, - limit: Optional[int] = None, - query_text: Optional[str] = None, - ) -> List[str]: - """ - Get all tags. - - Returns: - tags (List[str]): List of tags - """ - return self.server.agent_manager.list_tags(actor=self.user, after=after, limit=limit, query_text=query_text) diff --git a/letta/data_sources/connectors.py b/letta/data_sources/connectors.py index 188b37b75..41f728c2b 100644 --- a/letta/data_sources/connectors.py +++ b/letta/data_sources/connectors.py @@ -37,7 +37,9 @@ class DataConnector: """ -def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User"): +async def load_data( + connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User" +): """Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id.""" embedding_config = source.embedding_config @@ -51,7 +53,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage file_count = 0 for file_metadata in connector.find_files(source): file_count += 1 - source_manager.create_file(file_metadata, actor) + await source_manager.create_file(file_metadata, actor) # generate passages for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size): diff --git a/letta/functions/ast_parsers.py b/letta/functions/ast_parsers.py index e169f5964..3113cd963 100644 --- a/letta/functions/ast_parsers.py +++ b/letta/functions/ast_parsers.py @@ -1,5 +1,7 @@ import ast +import builtins import json +import typing from typing import Dict, Optional, Tuple from letta.errors import LettaToolCreateError @@ -22,7 +24,7 @@ def resolve_type(annotation: str): Resolve a type annotation string into a Python type. Args: - annotation (str): The annotation string (e.g., 'int', 'list', etc.). + annotation (str): The annotation string (e.g., 'int', 'list[int]', 'dict[str, int]'). Returns: type: The corresponding Python type. @@ -34,24 +36,17 @@ def resolve_type(annotation: str): return BUILTIN_TYPES[annotation] try: - if annotation.startswith("list["): - inner_type = annotation[len("list[") : -1] - resolve_type(inner_type) - return list - elif annotation.startswith("dict["): - inner_types = annotation[len("dict[") : -1] - key_type, value_type = inner_types.split(",") - return dict - elif annotation.startswith("tuple["): - inner_types = annotation[len("tuple[") : -1] - [resolve_type(t.strip()) for t in inner_types.split(",")] - return tuple - - parsed = ast.literal_eval(annotation) - if isinstance(parsed, type): - return parsed - raise ValueError(f"Annotation '{annotation}' is not a recognized type.") - except (ValueError, SyntaxError): + # Allow use of typing and builtins in a safe eval context + namespace = { + **vars(typing), + **vars(builtins), + "list": list, + "dict": dict, + "tuple": tuple, + "set": set, + } + return eval(annotation, namespace) + except Exception: raise ValueError(f"Unsupported annotation: {annotation}") @@ -82,41 +77,36 @@ def get_function_annotations_from_source(source_code: str, function_name: str) - def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict: - """ - Coerce arguments in a dictionary to their annotated types. - - Args: - function_args (dict): The original function arguments. - annotations (Dict[str, str]): Argument annotations as strings. - - Returns: - dict: The updated dictionary with coerced argument types. - - Raises: - ValueError: If type coercion fails for an argument. - """ - coerced_args = dict(function_args) # Shallow copy for mutation safety + coerced_args = dict(function_args) # Shallow copy for arg_name, value in coerced_args.items(): if arg_name in annotations: annotation_str = annotations[arg_name] try: - # Resolve the type from the annotation arg_type = resolve_type(annotation_str) - # Handle JSON-like inputs for dict and list types - if arg_type in {dict, list} and isinstance(value, str): + # Always parse strings using literal_eval or json if possible + if isinstance(value, str): try: - # First, try JSON parsing value = json.loads(value) except json.JSONDecodeError: - # Fall back to literal_eval for Python-specific literals - value = ast.literal_eval(value) + try: + value = ast.literal_eval(value) + except (SyntaxError, ValueError) as e: + if arg_type is not str: + raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") - # Coerce the value to the resolved type - coerced_args[arg_name] = arg_type(value) - except (TypeError, ValueError, json.JSONDecodeError, SyntaxError) as e: + origin = typing.get_origin(arg_type) + if origin in (list, dict, tuple, set): + # Let the origin (e.g., list) handle coercion + coerced_args[arg_name] = origin(value) + else: + # Coerce simple types (e.g., int, float) + coerced_args[arg_name] = arg_type(value) + + except Exception as e: raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") + return coerced_args diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 9cd2cede1..f082ca385 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -19,6 +19,8 @@ from letta.services.group_manager import GroupManager from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.step_manager import NoopStepManager, StepManager +from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager class SleeptimeMultiAgentV2(BaseAgent): @@ -32,6 +34,8 @@ class SleeptimeMultiAgentV2(BaseAgent): group_manager: GroupManager, job_manager: JobManager, actor: User, + step_manager: StepManager = NoopStepManager(), + telemetry_manager: TelemetryManager = NoopTelemetryManager(), group: Optional[Group] = None, ): super().__init__( @@ -45,11 +49,18 @@ class SleeptimeMultiAgentV2(BaseAgent): self.passage_manager = passage_manager self.group_manager = group_manager self.job_manager = job_manager + self.step_manager = step_manager + self.telemetry_manager = telemetry_manager # Group settings assert group.manager_type == ManagerType.sleeptime, f"Expected group manager type to be 'sleeptime', got {group.manager_type}" self.group = group - async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse: + async def step( + self, + input_messages: List[MessageCreate], + max_steps: int = 10, + use_assistant_message: bool = True, + ) -> LettaResponse: run_ids = [] # Prepare new messages @@ -68,22 +79,26 @@ class SleeptimeMultiAgentV2(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, + step_manager=self.step_manager, + telemetry_manager=self.telemetry_manager, ) # Perform foreground agent step - response = await foreground_agent.step(input_messages=new_messages, max_steps=max_steps) + response = await foreground_agent.step( + input_messages=new_messages, max_steps=max_steps, use_assistant_message=use_assistant_message + ) # Get last response messages last_response_messages = foreground_agent.response_messages # Update turns counter if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: - turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor) # Perform participant steps if self.group.sleeptime_agent_frequency is None or ( turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 ): - last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async( group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor ) for participant_agent_id in self.group.agent_ids: @@ -92,6 +107,7 @@ class SleeptimeMultiAgentV2(BaseAgent): participant_agent_id, last_response_messages, last_processed_message_id, + use_assistant_message, ) run_ids.append(run_id) @@ -103,7 +119,13 @@ class SleeptimeMultiAgentV2(BaseAgent): response.usage.run_ids = run_ids return response - async def step_stream(self, input_messages: List[MessageCreate], max_steps: int = 10) -> AsyncGenerator[str, None]: + async def step_stream( + self, + input_messages: List[MessageCreate], + max_steps: int = 10, + use_assistant_message: bool = True, + request_start_timestamp_ns: Optional[int] = None, + ) -> AsyncGenerator[str, None]: # Prepare new messages new_messages = [] for message in input_messages: @@ -120,9 +142,16 @@ class SleeptimeMultiAgentV2(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, + step_manager=self.step_manager, + telemetry_manager=self.telemetry_manager, ) # Perform foreground agent step - async for chunk in foreground_agent.step_stream(input_messages=new_messages, max_steps=max_steps): + async for chunk in foreground_agent.step_stream( + input_messages=new_messages, + max_steps=max_steps, + use_assistant_message=use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + ): yield chunk # Get response messages @@ -130,20 +159,21 @@ class SleeptimeMultiAgentV2(BaseAgent): # Update turns counter if self.group.sleeptime_agent_frequency is not None and self.group.sleeptime_agent_frequency > 0: - turns_counter = self.group_manager.bump_turns_counter(group_id=self.group.id, actor=self.actor) + turns_counter = await self.group_manager.bump_turns_counter_async(group_id=self.group.id, actor=self.actor) # Perform participant steps if self.group.sleeptime_agent_frequency is None or ( turns_counter is not None and turns_counter % self.group.sleeptime_agent_frequency == 0 ): - last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update( + last_processed_message_id = await self.group_manager.get_last_processed_message_id_and_update_async( group_id=self.group.id, last_processed_message_id=last_response_messages[-1].id, actor=self.actor ) for sleeptime_agent_id in self.group.agent_ids: - self._issue_background_task( + run_id = await self._issue_background_task( sleeptime_agent_id, last_response_messages, last_processed_message_id, + use_assistant_message, ) async def _issue_background_task( @@ -151,6 +181,7 @@ class SleeptimeMultiAgentV2(BaseAgent): sleeptime_agent_id: str, response_messages: List[Message], last_processed_message_id: str, + use_assistant_message: bool = True, ) -> str: run = Run( user_id=self.actor.id, @@ -160,7 +191,7 @@ class SleeptimeMultiAgentV2(BaseAgent): "agent_id": sleeptime_agent_id, }, ) - run = self.job_manager.create_job(pydantic_job=run, actor=self.actor) + run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor) asyncio.create_task( self._participant_agent_step( @@ -169,6 +200,7 @@ class SleeptimeMultiAgentV2(BaseAgent): response_messages=response_messages, last_processed_message_id=last_processed_message_id, run_id=run.id, + use_assistant_message=True, ) ) return run.id @@ -180,11 +212,12 @@ class SleeptimeMultiAgentV2(BaseAgent): response_messages: List[Message], last_processed_message_id: str, run_id: str, + use_assistant_message: bool = True, ) -> str: try: # Update job status job_update = JobUpdate(status=JobStatus.running) - self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor) # Create conversation transcript prior_messages = [] @@ -221,11 +254,14 @@ class SleeptimeMultiAgentV2(BaseAgent): block_manager=self.block_manager, passage_manager=self.passage_manager, actor=self.actor, + step_manager=self.step_manager, + telemetry_manager=self.telemetry_manager, ) # Perform sleeptime agent step result = await sleeptime_agent.step( input_messages=sleeptime_agent_messages, + use_assistant_message=use_assistant_message, ) # Update job status @@ -237,7 +273,7 @@ class SleeptimeMultiAgentV2(BaseAgent): "agent_id": sleeptime_agent_id, }, ) - self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor) return result except Exception as e: job_update = JobUpdate( @@ -245,5 +281,5 @@ class SleeptimeMultiAgentV2(BaseAgent): completed_at=datetime.now(timezone.utc).replace(tzinfo=None), metadata={"error": str(e)}, ) - self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.actor) + await self.job_manager.update_job_by_id_async(job_id=run_id, job_update=job_update, actor=self.actor) raise diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index e0f51dd54..401860e83 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -106,7 +106,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob], results: List[BatchPollingResult] = await asyncio.gather(*coros) # Update the server with batch status changes - server.batch_manager.bulk_update_llm_batch_statuses(updates=results) + await server.batch_manager.bulk_update_llm_batch_statuses_async(updates=results) logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.") return results @@ -197,13 +197,13 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo # 6. Bulk update all items for newly completed batch(es) if item_updates: metrics.updated_items_count = len(item_updates) - server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates) + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(item_updates) # ─── Kick off post‑processing for each batch that just completed ─── completed = [r for r in batch_results if r.request_status == JobStatus.completed] async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse: - actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id) + actor: User = await server.user_manager.get_actor_by_id_async(batch_row.created_by_id) runner = LettaAgentBatch( message_manager=server.message_manager, agent_manager=server.agent_manager, diff --git a/letta/jobs/scheduler.py b/letta/jobs/scheduler.py index 6e7dad000..80999a5da 100644 --- a/letta/jobs/scheduler.py +++ b/letta/jobs/scheduler.py @@ -4,10 +4,11 @@ from typing import Optional from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.triggers.interval import IntervalTrigger +from sqlalchemy import text from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.log import get_logger -from letta.server.db import db_context +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.settings import settings @@ -34,18 +35,16 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: acquired_lock = False try: # Use a temporary connection context for the attempt initially - with db_context() as session: - engine = session.get_bind() - # Get raw connection - MUST be kept open if lock is acquired - raw_conn = engine.raw_connection() - cur = raw_conn.cursor() + async with db_registry.async_session() as session: + raw_conn = await session.connection() - cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) - acquired_lock = cur.fetchone()[0] + # Try to acquire the advisory lock + sql = text("SELECT pg_try_advisory_lock(CAST(:lock_key AS bigint))") + result = await session.execute(sql, {"lock_key": ADVISORY_LOCK_KEY}) + acquired_lock = result.scalar_one() if not acquired_lock: - cur.close() - raw_conn.close() + await raw_conn.close() logger.info("Scheduler lock held by another instance.") return False @@ -106,14 +105,14 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: # Clean up temporary resources if lock wasn't acquired or error occurred before storing if cur: try: - cur.close() - except: - pass + await cur.close() + except Exception as e: + logger.warning(f"Error closing cursor: {e}") if raw_conn: try: - raw_conn.close() - except: - pass + await raw_conn.close() + except Exception as e: + logger.warning(f"Error closing connection: {e}") async def _background_lock_retry_loop(server: SyncServer): @@ -161,7 +160,9 @@ async def _release_advisory_lock(): try: if not lock_conn.closed: if not lock_cur.closed: - lock_cur.execute("SELECT pg_advisory_unlock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) + # Use SQLAlchemy text() for raw SQL + unlock_sql = text("SELECT pg_advisory_unlock(CAST(:lock_key AS bigint))") + lock_cur.execute(unlock_sql, {"lock_key": ADVISORY_LOCK_KEY}) lock_cur.fetchone() # Consume result lock_conn.commit() logger.info(f"Executed pg_advisory_unlock for lock {ADVISORY_LOCK_KEY}") @@ -175,12 +176,12 @@ async def _release_advisory_lock(): # Ensure resources are closed regardless of unlock success try: if lock_cur and not lock_cur.closed: - lock_cur.close() + await lock_cur.close() except Exception as e: logger.error(f"Error closing advisory lock cursor: {e}", exc_info=True) try: if lock_conn and not lock_conn.closed: - lock_conn.close() + await lock_conn.close() logger.info("Closed database connection that held advisory lock.") except Exception as e: logger.error(f"Error closing advisory lock connection: {e}", exc_info=True) diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index f7509b037..f131e776c 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -45,11 +45,13 @@ logger = get_logger(__name__) class AnthropicClient(LLMClientBase): + @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(llm_config, async_client=False) response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) return response.model_dump() + @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(llm_config, async_client=True) response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) @@ -339,6 +341,7 @@ class AnthropicClient(LLMClientBase): # TODO: Input messages doesn't get used here # TODO: Clean up this interface + @trace_method def convert_response_to_chat_completion( self, response_data: dict, diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 2874b62a9..e8215813f 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -17,6 +17,7 @@ from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.settings import model_settings, settings +from letta.tracing import trace_method from letta.utils import get_tool_call_id logger = get_logger(__name__) @@ -32,6 +33,7 @@ class GoogleVertexClient(LLMClientBase): http_options={"api_version": "v1"}, ) + @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. @@ -44,6 +46,7 @@ class GoogleVertexClient(LLMClientBase): ) return response.model_dump() + @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. @@ -189,6 +192,7 @@ class GoogleVertexClient(LLMClientBase): return [{"functionDeclarations": function_list}] + @trace_method def build_request_data( self, messages: List[PydanticMessage], @@ -248,6 +252,7 @@ class GoogleVertexClient(LLMClientBase): return request_data + @trace_method def convert_response_to_chat_completion( self, response_data: dict, diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index e6ac37a22..d144d03cc 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -32,6 +32,7 @@ from letta.schemas.openai.chat_completion_request import Tool as OpenAITool from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.settings import model_settings +from letta.tracing import trace_method logger = get_logger(__name__) @@ -124,6 +125,7 @@ class OpenAIClient(LLMClientBase): return kwargs + @trace_method def build_request_data( self, messages: List[PydanticMessage], @@ -213,6 +215,7 @@ class OpenAIClient(LLMClientBase): return data.model_dump(exclude_unset=True) + @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying synchronous request to OpenAI API and returns raw response dict. @@ -222,6 +225,7 @@ class OpenAIClient(LLMClientBase): response: ChatCompletion = client.chat.completions.create(**request_data) return response.model_dump() + @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying asynchronous request to OpenAI API and returns raw response dict. @@ -230,6 +234,7 @@ class OpenAIClient(LLMClientBase): response: ChatCompletion = await client.chat.completions.create(**request_data) return response.model_dump() + @trace_method def convert_response_to_chat_completion( self, response_data: dict, diff --git a/letta/main.py b/letta/main.py index de1b4028a..a64b3637e 100644 --- a/letta/main.py +++ b/letta/main.py @@ -1,374 +1,14 @@ import os -import sys -import traceback -import questionary -import requests import typer -from rich.console import Console -import letta.agent as agent -import letta.errors as errors -import letta.system as system - -# import benchmark -from letta import create_client -from letta.benchmark.benchmark import bench -from letta.cli.cli import delete_agent, open_folder, run, server, version -from letta.cli.cli_config import add, add_tool, configure, delete, list, list_tools +from letta.cli.cli import server from letta.cli.cli_load import app as load_app -from letta.config import LettaConfig -from letta.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE - -# from letta.interface import CLIInterface as interface # for printing to terminal -from letta.streaming_interface import AgentRefreshStreamingInterface - -# interface = interface() # disable composio print on exit os.environ["COMPOSIO_DISABLE_VERSION_CHECK"] = "true" app = typer.Typer(pretty_exceptions_enable=False) -app.command(name="run")(run) -app.command(name="version")(version) -app.command(name="configure")(configure) -app.command(name="list")(list) -app.command(name="add")(add) -app.command(name="add-tool")(add_tool) -app.command(name="list-tools")(list_tools) -app.command(name="delete")(delete) app.command(name="server")(server) -app.command(name="folder")(open_folder) -# load data commands + app.add_typer(load_app, name="load") -# benchmark command -app.command(name="benchmark")(bench) -# delete agents -app.command(name="delete-agent")(delete_agent) - - -def clear_line(console, strip_ui=False): - if strip_ui: - return - if os.name == "nt": # for windows - console.print("\033[A\033[K", end="") - else: # for linux - sys.stdout.write("\033[2K\033[G") - sys.stdout.flush() - - -def run_agent_loop( - letta_agent: agent.Agent, - config: LettaConfig, - first: bool, - no_verify: bool = False, - strip_ui: bool = False, - stream: bool = False, -): - if isinstance(letta_agent.interface, AgentRefreshStreamingInterface): - # letta_agent.interface.toggle_streaming(on=stream) - if not stream: - letta_agent.interface = letta_agent.interface.nonstreaming_interface - - if hasattr(letta_agent.interface, "console"): - console = letta_agent.interface.console - else: - console = Console() - - counter = 0 - user_input = None - skip_next_user_input = False - user_message = None - USER_GOES_FIRST = first - - if not USER_GOES_FIRST: - console.input("[bold cyan]Hit enter to begin (will request first Letta message)[/bold cyan]\n") - clear_line(console, strip_ui=strip_ui) - print() - - multiline_input = False - - # create client - client = create_client() - - # run loops - while True: - if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): - # Ask for user input - if not stream: - print() - user_input = questionary.text( - "Enter your message:", - multiline=multiline_input, - qmark=">", - ).ask() - clear_line(console, strip_ui=strip_ui) - if not stream: - print() - - # Gracefully exit on Ctrl-C/D - if user_input is None: - user_input = "/exit" - - user_input = user_input.rstrip() - - if user_input.startswith("!"): - print(f"Commands for CLI begin with '/' not '!'") - continue - - if user_input == "": - # no empty messages allowed - print("Empty input received. Try again!") - continue - - # Handle CLI commands - # Commands to not get passed as input to Letta - if user_input.startswith("/"): - # updated agent save functions - if user_input.lower() == "/exit": - # letta_agent.save() - agent.save_agent(letta_agent) - break - elif user_input.lower() == "/save" or user_input.lower() == "/savechat": - # letta_agent.save() - agent.save_agent(letta_agent) - continue - elif user_input.lower() == "/attach": - # TODO: check if agent already has it - - # TODO: check to ensure source embedding dimentions/model match agents, and disallow attachment if not - # TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources - - sources = client.list_sources() - if len(sources) == 0: - typer.secho( - 'No sources available. You must load a souce with "letta load ..." before running /attach.', - fg=typer.colors.RED, - bold=True, - ) - continue - - # determine what sources are valid to be attached to this agent - valid_options = [] - invalid_options = [] - for source in sources: - if source.embedding_config == letta_agent.agent_state.embedding_config: - valid_options.append(source.name) - else: - # print warning about invalid sources - typer.secho( - f"Source {source.name} exists but has embedding dimentions {source.embedding_dim} from model {source.embedding_model}, while the agent uses embedding dimentions {letta_agent.agent_state.embedding_config.embedding_dim} and model {letta_agent.agent_state.embedding_config.embedding_model}", - fg=typer.colors.YELLOW, - ) - invalid_options.append(source.name) - - # prompt user for data source selection - data_source = questionary.select("Select data source", choices=valid_options).ask() - - # attach new data - client.attach_source_to_agent(agent_id=letta_agent.agent_state.id, source_name=data_source) - - continue - - elif user_input.lower() == "/dump" or user_input.lower().startswith("/dump "): - # Check if there's an additional argument that's an integer - command = user_input.strip().split() - amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 - if amount == 0: - letta_agent.interface.print_messages(letta_agent._messages, dump=True) - else: - letta_agent.interface.print_messages(letta_agent._messages[-min(amount, len(letta_agent.messages)) :], dump=True) - continue - - elif user_input.lower() == "/dumpraw": - letta_agent.interface.print_messages_raw(letta_agent._messages) - continue - - elif user_input.lower() == "/memory": - print(f"\nDumping memory contents:\n") - print(f"{letta_agent.agent_state.memory.compile()}") - print(f"{letta_agent.archival_memory.compile()}") - continue - - elif user_input.lower() == "/model": - print(f"Current model: {letta_agent.agent_state.llm_config.model}") - continue - - elif user_input.lower() == "/summarize": - try: - letta_agent.summarize_messages_inplace() - typer.secho( - f"/summarize succeeded", - fg=typer.colors.GREEN, - bold=True, - ) - except (errors.LLMError, requests.exceptions.HTTPError) as e: - typer.secho( - f"/summarize failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - - elif user_input.lower() == "/tokens": - tokens = letta_agent.count_tokens() - typer.secho( - f"{tokens}/{letta_agent.agent_state.llm_config.context_window}", - fg=typer.colors.GREEN, - bold=True, - ) - continue - - elif user_input.lower().startswith("/add_function"): - try: - if len(user_input) < len("/add_function "): - print("Missing function name after the command") - continue - function_name = user_input[len("/add_function ") :].strip() - result = letta_agent.add_function(function_name) - typer.secho( - f"/add_function succeeded: {result}", - fg=typer.colors.GREEN, - bold=True, - ) - except ValueError as e: - typer.secho( - f"/add_function failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - elif user_input.lower().startswith("/remove_function"): - try: - if len(user_input) < len("/remove_function "): - print("Missing function name after the command") - continue - function_name = user_input[len("/remove_function ") :].strip() - result = letta_agent.remove_function(function_name) - typer.secho( - f"/remove_function succeeded: {result}", - fg=typer.colors.GREEN, - bold=True, - ) - except ValueError as e: - typer.secho( - f"/remove_function failed:\n{e}", - fg=typer.colors.RED, - bold=True, - ) - continue - - # No skip options - elif user_input.lower() == "/wipe": - letta_agent = agent.Agent(letta_agent.interface) - user_message = None - - elif user_input.lower() == "/heartbeat": - user_message = system.get_heartbeat() - - elif user_input.lower() == "/memorywarning": - user_message = system.get_token_limit_warning() - - elif user_input.lower() == "//": - multiline_input = not multiline_input - continue - - elif user_input.lower() == "/" or user_input.lower() == "/help": - questionary.print("CLI commands", "bold") - for cmd, desc in USER_COMMANDS: - questionary.print(cmd, "bold") - questionary.print(f" {desc}") - continue - else: - print(f"Unrecognized command: {user_input}") - continue - - else: - # If message did not begin with command prefix, pass inputs to Letta - # Handle user message and append to messages - user_message = str(user_input) - - skip_next_user_input = False - - def process_agent_step(user_message, no_verify): - # TODO(charles): update to use agent.step() instead of inner_step() - - if user_message is None: - step_response = letta_agent.inner_step( - messages=[], - first_message=False, - skip_verify=no_verify, - stream=stream, - ) - else: - step_response = letta_agent.step_user_message( - user_message_str=user_message, - first_message=False, - skip_verify=no_verify, - stream=stream, - ) - new_messages = step_response.messages - heartbeat_request = step_response.heartbeat_request - function_failed = step_response.function_failed - token_warning = step_response.in_context_memory_warning - step_response.usage - - agent.save_agent(letta_agent) - skip_next_user_input = False - if token_warning: - user_message = system.get_token_limit_warning() - skip_next_user_input = True - elif function_failed: - user_message = system.get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE) - skip_next_user_input = True - elif heartbeat_request: - user_message = system.get_heartbeat(REQ_HEARTBEAT_MESSAGE) - skip_next_user_input = True - - return new_messages, user_message, skip_next_user_input - - while True: - try: - if strip_ui: - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break - else: - if stream: - # Don't display the "Thinking..." if streaming - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - else: - with console.status("[bold cyan]Thinking...") as status: - _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) - break - except KeyboardInterrupt: - print("User interrupt occurred.") - retry = questionary.confirm("Retry agent.step()?").ask() - if not retry: - break - except Exception: - print("An exception occurred when running agent.step(): ") - traceback.print_exc() - retry = questionary.confirm("Retry agent.step()?").ask() - if not retry: - break - - counter += 1 - - print("Finished.") - - -USER_COMMANDS = [ - ("//", "toggle multiline input mode"), - ("/exit", "exit the CLI"), - ("/save", "save a checkpoint of the current agent/conversation state"), - ("/load", "load a saved checkpoint"), - ("/dump ", "view the last messages (all if is omitted)"), - ("/memory", "print the current contents of agent memory"), - ("/pop ", "undo messages in the conversation (default is 3)"), - ("/retry", "pops the last answer and tries to get another one"), - ("/rethink ", "changes the inner thoughts of the last agent message"), - ("/rewrite ", "changes the reply of the last agent message"), - ("/heartbeat", "send a heartbeat system message to the agent"), - ("/memorywarning", "send a memory warning system message to the agent"), - ("/attach", "attach data source to agent"), -] diff --git a/letta/server/db.py b/letta/server/db.py index fe9abcff3..38b9b33b9 100644 --- a/letta/server/db.py +++ b/letta/server/db.py @@ -13,6 +13,9 @@ from sqlalchemy.orm import sessionmaker from letta.config import LettaConfig from letta.log import get_logger from letta.settings import settings +from letta.tracing import trace_method + +logger = get_logger(__name__) logger = get_logger(__name__) @@ -202,6 +205,7 @@ class DatabaseRegistry: self.initialize_async() return self._async_session_factories.get(name) + @trace_method @contextmanager def session(self, name: str = "default") -> Generator[Any, None, None]: """Context manager for database sessions.""" @@ -215,6 +219,7 @@ class DatabaseRegistry: finally: session.close() + @trace_method @asynccontextmanager async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]: """Async context manager for database sessions.""" diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 11b21b950..fc73fafc6 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -13,10 +13,11 @@ from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger from letta.orm.errors import NoResultFound -from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent +from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.block import Block, BlockUpdate from letta.schemas.group import Group from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig @@ -212,7 +213,7 @@ async def retrieve_agent_context_window( """ actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - return await server.get_agent_context_window_async(agent_id=agent_id, actor=actor) + return await server.agent_manager.get_context_window(agent_id=agent_id, actor=actor) except Exception as e: traceback.print_exc() raise e @@ -297,7 +298,7 @@ def detach_tool( @router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent") -def attach_source( +async def attach_source( agent_id: str, source_id: str, background_tasks: BackgroundTasks, @@ -310,7 +311,7 @@ def attach_source( actor = server.user_manager.get_user_or_default(user_id=actor_id) agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor) if agent.enable_sleeptime: - source = server.source_manager.get_source_by_id(source_id=source_id) + source = await server.source_manager.get_source_by_id_async(source_id=source_id) background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor) return agent @@ -355,7 +356,7 @@ async def retrieve_agent( @router.delete("/{agent_id}", response_model=None, operation_id="delete_agent") -def delete_agent( +async def delete_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -363,9 +364,9 @@ def delete_agent( """ Delete an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - server.agent_manager.delete_agent(agent_id=agent_id, actor=actor) + await server.agent_manager.delete_agent_async(agent_id=agent_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"}) except NoResultFound: raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.") @@ -386,7 +387,7 @@ async def list_agent_sources( # TODO: remove? can also get with agent blocks @router.get("/{agent_id}/core-memory", response_model=Memory, operation_id="retrieve_agent_memory") -def retrieve_agent_memory( +async def retrieve_agent_memory( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -395,13 +396,13 @@ def retrieve_agent_memory( Retrieve the memory state of a specific agent. This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.get_agent_memory(agent_id=agent_id, actor=actor) + return await server.get_agent_memory_async(agent_id=agent_id, actor=actor) @router.get("/{agent_id}/core-memory/blocks/{block_label}", response_model=Block, operation_id="retrieve_core_memory_block") -def retrieve_block( +async def retrieve_block( agent_id: str, block_label: str, server: "SyncServer" = Depends(get_letta_server), @@ -410,10 +411,10 @@ def retrieve_block( """ Retrieve a core memory block from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: - return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) + return await server.agent_manager.get_block_with_label_async(agent_id=agent_id, block_label=block_label, actor=actor) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -453,13 +454,13 @@ async def modify_block( ) # This should also trigger a system prompt change in the agent - server.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True, update_timestamp=False) + await server.agent_manager.rebuild_system_prompt_async(agent_id=agent_id, actor=actor, force=True, update_timestamp=False) return block @router.patch("/{agent_id}/core-memory/blocks/attach/{block_id}", response_model=AgentState, operation_id="attach_core_memory_block") -def attach_block( +async def attach_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), @@ -468,12 +469,12 @@ def attach_block( """ Attach a core memoryblock to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.attach_block_async(agent_id=agent_id, block_id=block_id, actor=actor) @router.patch("/{agent_id}/core-memory/blocks/detach/{block_id}", response_model=AgentState, operation_id="detach_core_memory_block") -def detach_block( +async def detach_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), @@ -482,8 +483,8 @@ def detach_block( """ Detach a core memory block from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.agent_manager.detach_block_async(agent_id=agent_id, block_id=block_id, actor=actor) @router.get("/{agent_id}/archival-memory", response_model=List[Passage], operation_id="list_passages") @@ -637,22 +638,35 @@ async def send_message( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) - agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent + agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] if agent_eligible and feature_enabled and model_compatible: - experimental_agent = LettaAgent( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), - ) + if agent.enable_sleeptime: + experimental_agent = SleeptimeMultiAgentV2( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + group=agent.multi_agent_group, + ) + else: + experimental_agent = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + ) result = await experimental_agent.step(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message) else: @@ -697,23 +711,38 @@ async def send_message_streaming( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) - agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent + agent_eligible = agent.enable_sleeptime or not agent.multi_agent_group experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] - if agent_eligible and feature_enabled and model_compatible and request.stream_tokens: - experimental_agent = LettaAgent( - agent_id=agent_id, - message_manager=server.message_manager, - agent_manager=server.agent_manager, - block_manager=server.block_manager, - passage_manager=server.passage_manager, - actor=actor, - step_manager=server.step_manager, - telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), - ) + if agent_eligible and feature_enabled and model_compatible: + if agent.enable_sleeptime: + experimental_agent = SleeptimeMultiAgentV2( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + group_manager=server.group_manager, + job_manager=server.job_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + group=agent.multi_agent_group, + ) + else: + experimental_agent = LettaAgent( + agent_id=agent_id, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + actor=actor, + step_manager=server.step_manager, + telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(), + ) from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode if request.stream_tokens and model_compatible_token_streaming: diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 485563821..c98c2a118 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -23,7 +23,7 @@ async def list_llm_models( # Extract user_id from header, default to None if not present ): """List available LLM models using the asynchronous implementation for improved performance""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) models = await server.list_llm_models_async( provider_category=provider_category, @@ -42,7 +42,7 @@ async def list_embedding_models( # Extract user_id from header, default to None if not present ): """List available embedding models using the asynchronous implementation for improved performance""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) models = await server.list_embedding_models_async(actor=actor) return models diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index 4d7d3588a..e156d05dc 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -161,7 +161,7 @@ async def list_batch_messages( # Get messages directly using our efficient method # We'll need to update the underlying implementation to use message_id as cursor - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor ) @@ -184,7 +184,7 @@ async def cancel_batch_run( job = await server.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor) # Get related llm batch jobs - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=job.id, actor=actor) for llm_batch_job in llm_batch_jobs: if llm_batch_job.status in {JobStatus.running, JobStatus.created}: # TODO: Extend to providers beyond anthropic @@ -194,6 +194,8 @@ async def cancel_batch_run( await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id) # Update all the batch_job statuses - server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor) + await server.batch_manager.update_llm_batch_status_async( + llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor + ) except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index 505e08a3d..00681ea24 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -22,36 +22,36 @@ logger = get_logger(__name__) @router.post("/", response_model=PydanticSandboxConfig) -def create_sandbox_config( +async def create_sandbox_config( config_create: SandboxConfigCreate, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor) + return await server.sandbox_config_manager.create_or_update_sandbox_config_async(config_create, actor) @router.post("/e2b/default", response_model=PydanticSandboxConfig) -def create_default_e2b_sandbox_config( +async def create_default_e2b_sandbox_config( server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.E2B, actor=actor) @router.post("/local/default", response_model=PydanticSandboxConfig) -def create_default_local_sandbox_config( +async def create_default_local_sandbox_config( server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.LOCAL, actor=actor) @router.post("/local", response_model=PydanticSandboxConfig) -def create_custom_local_sandbox_config( +async def create_custom_local_sandbox_config( local_sandbox_config: LocalSandboxConfig, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), @@ -67,26 +67,26 @@ def create_custom_local_sandbox_config( ) # Retrieve the user (actor) - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Wrap the LocalSandboxConfig into a SandboxConfigCreate sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config) # Use the manager to create or update the sandbox config - sandbox_config = server.sandbox_config_manager.create_or_update_sandbox_config(sandbox_config_create, actor=actor) + sandbox_config = await server.sandbox_config_manager.create_or_update_sandbox_config_async(sandbox_config_create, actor=actor) return sandbox_config @router.patch("/{sandbox_config_id}", response_model=PydanticSandboxConfig) -def update_sandbox_config( +async def update_sandbox_config( sandbox_config_id: str, config_update: SandboxConfigUpdate, server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + return await server.sandbox_config_manager.update_sandbox_config_async(sandbox_config_id, config_update, actor) @router.delete("/{sandbox_config_id}", status_code=204) @@ -112,7 +112,7 @@ async def list_sandbox_configs( @router.post("/local/recreate-venv", response_model=PydanticSandboxConfig) -def force_recreate_local_sandbox_venv( +async def force_recreate_local_sandbox_venv( server: SyncServer = Depends(get_letta_server), actor_id: str = Depends(get_user_id), ): @@ -120,10 +120,10 @@ def force_recreate_local_sandbox_venv( Forcefully recreate the virtual environment for the local sandbox. Deletes and recreates the venv, then reinstalls required dependencies. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Retrieve the local sandbox config - sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor) + sbx_config = await server.sandbox_config_manager.get_or_create_default_sandbox_config_async(sandbox_type=SandboxType.LOCAL, actor=actor) local_configs = sbx_config.get_local_config() sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) # Expand tilde diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 54f682fe3..478b6278a 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -1,3 +1,4 @@ +import asyncio import os import tempfile from typing import List, Optional @@ -21,18 +22,18 @@ router = APIRouter(prefix="/sources", tags=["sources"]) @router.get("/count", response_model=int, operation_id="count_sources") -def count_sources( +async def count_sources( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Count all data sources created by a user. """ - return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) + return await server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id)) @router.get("/{source_id}", response_model=Source, operation_id="retrieve_source") -def retrieve_source( +async def retrieve_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -42,14 +43,14 @@ def retrieve_source( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) if not source: raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.") return source @router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name") -def get_source_id_by_name( +async def get_source_id_by_name( source_name: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -59,14 +60,14 @@ def get_source_id_by_name( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor) + source = await server.source_manager.get_source_by_name(source_name=source_name, actor=actor) if not source: raise HTTPException(status_code=404, detail=f"Source with name={source_name} not found.") return source.id @router.get("/", response_model=List[Source], operation_id="list_sources") -def list_sources( +async def list_sources( server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -74,8 +75,7 @@ def list_sources( List all data sources created by a user. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - - return server.list_all_sources(actor=actor) + return await server.source_manager.list_sources(actor=actor) @router.get("/count", response_model=int, operation_id="count_sources") @@ -90,7 +90,7 @@ def count_sources( @router.post("/", response_model=Source, operation_id="create_source") -def create_source( +async def create_source( source_create: SourceCreate, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -99,6 +99,8 @@ def create_source( Create a new data source. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) + + # TODO: need to asyncify this if not source_create.embedding_config: if not source_create.embedding: # TODO: modify error type @@ -115,11 +117,11 @@ def create_source( instructions=source_create.instructions, metadata=source_create.metadata, ) - return server.source_manager.create_source(source=source, actor=actor) + return await server.source_manager.create_source(source=source, actor=actor) @router.patch("/{source_id}", response_model=Source, operation_id="modify_source") -def modify_source( +async def modify_source( source_id: str, source: SourceUpdate, server: "SyncServer" = Depends(get_letta_server), @@ -130,13 +132,13 @@ def modify_source( """ # TODO: allow updating the handle/embedding config actor = server.user_manager.get_user_or_default(user_id=actor_id) - if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor): + if not await server.source_manager.get_source_by_id(source_id=source_id, actor=actor): raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.") - return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) + return await server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) @router.delete("/{source_id}", response_model=None, operation_id="delete_source") -def delete_source( +async def delete_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -145,20 +147,21 @@ def delete_source( Delete a data source. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - source = server.source_manager.get_source_by_id(source_id=source_id) - agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor) + source = await server.source_manager.get_source_by_id(source_id=source_id) + agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent in agents: if agent.enable_sleeptime: try: + # TODO: make async block = server.agent_manager.get_block_with_label(agent_id=agent.id, block_label=source.name, actor=actor) server.block_manager.delete_block(block.id, actor) except: pass - server.delete_source(source_id=source_id, actor=actor) + await server.delete_source(source_id=source_id, actor=actor) @router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source") -def upload_file_to_source( +async def upload_file_to_source( file: UploadFile, source_id: str, background_tasks: BackgroundTasks, @@ -170,7 +173,7 @@ def upload_file_to_source( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) + source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." bytes = file.file.read() @@ -184,8 +187,8 @@ def upload_file_to_source( server.job_manager.create_job(job, actor=actor) # create background tasks - background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor) - background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor) + asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor)) + asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor)) # return job information # Is this necessary? Can we just return the job from create_job? @@ -195,8 +198,11 @@ def upload_file_to_source( @router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages") -def list_source_passages( +async def list_source_passages( source_id: str, + after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."), + before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), + limit: int = Query(100, description="Maximum number of messages to retrieve."), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -204,12 +210,17 @@ def list_source_passages( List all passages associated with a data source. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id) - return passages + return await server.agent_manager.list_passages_async( + actor=actor, + source_id=source_id, + after=after, + before=before, + limit=limit, + ) @router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_source_files") -def list_source_files( +async def list_source_files( source_id: str, limit: int = Query(1000, description="Number of files to return"), after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), @@ -220,13 +231,13 @@ def list_source_files( List paginated files associated with a data source. """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor) + return await server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor) # it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action. # it's still good practice to return a status indicating the success or failure of the deletion @router.delete("/{source_id}/{file_id}", status_code=204, operation_id="delete_file_from_source") -def delete_file_from_source( +async def delete_file_from_source( source_id: str, file_id: str, background_tasks: BackgroundTasks, @@ -238,13 +249,15 @@ def delete_file_from_source( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) - deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor) - background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor, clear_history=True) + deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor) + + # TODO: make async + asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True)) if deleted_file is None: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") -def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User): +async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User): # Create a temporary directory (deleted after the context manager exits) with tempfile.TemporaryDirectory() as tmpdirname: # Sanitize the filename @@ -256,12 +269,12 @@ def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, f buffer.write(bytes) # Pass the file to load_file_to_source - server.load_file_to_source(source_id, file_path, job_id, actor) + await server.load_file_to_source(source_id, file_path, job_id, actor) -def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False): - source = server.source_manager.get_source_by_id(source_id=source_id) - agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor) +async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False): + source = await server.source_manager.get_source_by_id(source_id=source_id) + agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent in agents: if agent.enable_sleeptime: - server.sleeptime_document_ingest(agent, source, actor, clear_history) + server.sleeptime_document_ingest(agent, source, actor, clear_history) # TODO: make async diff --git a/letta/server/server.py b/letta/server/server.py index 1fb519484..45cdc882a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -50,7 +50,7 @@ from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolRe from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary +from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.organization import Organization from letta.schemas.passage import Passage, PassageUpdate @@ -969,6 +969,11 @@ class SyncServer(Server): """Return the memory of an agent (core memory)""" return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory + async def get_agent_memory_async(self, agent_id: str, actor: User) -> Memory: + """Return the memory of an agent (core memory)""" + agent = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) + return agent.memory + def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary: return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id)) @@ -1169,17 +1174,20 @@ class SyncServer(Server): # rebuild system prompt for agent, potentially changed return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory - def delete_source(self, source_id: str, actor: User): + async def delete_source(self, source_id: str, actor: User): """Delete a data source""" - self.source_manager.delete_source(source_id=source_id, actor=actor) + await self.source_manager.delete_source(source_id=source_id, actor=actor) # delete data from passage store + # TODO: make async passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None) + + # TODO: make this async self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted) # TODO: delete data from agent passage stores (?) - def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: + async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: # update job job = self.job_manager.get_job_by_id(job_id, actor=actor) @@ -1189,21 +1197,22 @@ class SyncServer(Server): # try: from letta.data_sources.connectors import DirectoryConnector - source = self.source_manager.get_source_by_id(source_id=source_id) + # TODO: move this into a thread + source = await self.source_manager.get_source_by_id(source_id=source_id) if source is None: raise ValueError(f"Source {source_id} does not exist") connector = DirectoryConnector(input_files=[file_path]) - num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) + num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector) # update all agents who have this source attached - agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor) + agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor) for agent_state in agent_states: agent_id = agent_state.id # Attach source to agent - curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) + curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id) agent_state = self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor) - new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) + new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added # rebuild system prompt and force @@ -1266,7 +1275,7 @@ class SyncServer(Server): actor=actor, ) - def load_data( + async def load_data( self, user_id: str, connector: DataConnector, @@ -1277,12 +1286,12 @@ class SyncServer(Server): # load data from a data source into the document store user = self.user_manager.get_user_by_id(user_id=user_id) - source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) + source = await self.source_manager.get_source_by_name(source_name=source_name, actor=user) if source is None: raise ValueError(f"Data source {source_name} does not exist for user {user_id}") # load data into the document store - passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user) + passage_count, document_count = await load_data(connector, source, self.passage_manager, self.source_manager, actor=user) return passage_count, document_count def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: @@ -1290,6 +1299,7 @@ class SyncServer(Server): return self.agent_manager.list_passages(actor=self.user_manager.get_user_or_default(user_id=user_id), source_id=source_id) def list_all_sources(self, actor: User) -> List[Source]: + # TODO: legacy: remove """List all sources (w/ extra metadata) belonging to a user""" sources = self.source_manager.list_sources(actor=actor) @@ -1376,7 +1386,7 @@ class SyncServer(Server): """Asynchronously list available models with maximum concurrency""" import asyncio - providers = self.get_enabled_providers( + providers = await self.get_enabled_providers_async( provider_category=provider_category, provider_name=provider_name, provider_type=provider_type, @@ -1422,7 +1432,7 @@ class SyncServer(Server): import asyncio # Get all eligible providers first - providers = self.get_enabled_providers(actor=actor) + providers = await self.get_enabled_providers_async(actor=actor) # Fetch embedding models from each provider concurrently async def get_provider_embedding_models(provider): @@ -1475,6 +1485,35 @@ class SyncServer(Server): return providers + async def get_enabled_providers_async( + self, + actor: User, + provider_category: Optional[List[ProviderCategory]] = None, + provider_name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + ) -> List[Provider]: + providers = [] + if not provider_category or ProviderCategory.base in provider_category: + providers_from_env = [p for p in self._enabled_providers] + providers.extend(providers_from_env) + + if not provider_category or ProviderCategory.byok in provider_category: + providers_from_db = await self.provider_manager.list_providers_async( + name=provider_name, + provider_type=provider_type, + actor=actor, + ) + providers_from_db = [p.cast_to_subtype() for p in providers_from_db] + providers.extend(providers_from_db) + + if provider_name is not None: + providers = [p for p in providers if p.name == provider_name] + + if provider_type is not None: + providers = [p for p in providers if p.provider_type == provider_type] + + return providers + @trace_method def get_llm_config_from_handle( self, @@ -1613,14 +1652,6 @@ class SyncServer(Server): def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig: """Add a new embedding model""" - def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview: - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return letta_agent.get_context_window() - - async def get_agent_context_window_async(self, agent_id: str, actor: User) -> ContextWindowOverview: - letta_agent = self.load_agent(agent_id=agent_id, actor=actor) - return await letta_agent.get_context_window_async() - def run_tool_from_source( self, actor: User, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 915413e51..fcdf39430 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,9 +1,11 @@ import asyncio +import os from datetime import datetime, timezone from typing import Dict, List, Optional, Set, Tuple import numpy as np import sqlalchemy as sa +from openai.types.beta.function_tool import FunctionTool as OpenAITool from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -20,6 +22,7 @@ from letta.constants import ( ) from letta.embeddings import embedding_model from letta.helpers.datetime_helpers import get_utc_time +from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.orm import Agent as AgentModel from letta.orm import AgentPassage, AgentsTags @@ -42,9 +45,11 @@ from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_ from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole, ProviderType from letta.schemas.group import Group as PydanticGroup from letta.schemas.group import ManagerType -from letta.schemas.memory import Memory +from letta.schemas.letta_message_content import TextContent +from letta.schemas.memory import ContextWindowOverview, Memory from letta.schemas.message import Message from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -79,7 +84,7 @@ from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings from letta.tracing import trace_method -from letta.utils import enforce_types, united_diff +from letta.utils import count_tokens, enforce_types, united_diff logger = get_logger(__name__) @@ -548,6 +553,7 @@ class AgentManager: return init_messages + @trace_method @enforce_types def append_initial_message_sequence_to_in_context_messages( self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None @@ -555,6 +561,7 @@ class AgentManager: init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor) + @trace_method @enforce_types def update_agent( self, @@ -674,6 +681,7 @@ class AgentManager: return agent.to_pydantic() + @trace_method @enforce_types async def update_agent_async( self, @@ -792,6 +800,7 @@ class AgentManager: return await agent.to_pydantic_async() # TODO: Make this general and think about how to roll this into sqlalchemybase + @trace_method def list_agents( self, actor: PydanticUser, @@ -850,6 +859,7 @@ class AgentManager: agents = result.scalars().all() return [agent.to_pydantic(include_relationships=include_relationships) for agent in agents] + @trace_method async def list_agents_async( self, actor: PydanticUser, @@ -909,6 +919,7 @@ class AgentManager: return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) @enforce_types + @trace_method def list_agents_matching_tags( self, actor: PydanticUser, @@ -951,6 +962,7 @@ class AgentManager: return list(session.execute(query).scalars()) + @trace_method def size( self, actor: PydanticUser, @@ -961,6 +973,7 @@ class AgentManager: with db_registry.session() as session: return AgentModel.size(db_session=session, actor=actor) + @trace_method async def size_async( self, actor: PydanticUser, @@ -971,6 +984,7 @@ class AgentManager: async with db_registry.async_session() as session: return await AgentModel.size_async(db_session=session, actor=actor) + @trace_method @enforce_types def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: """Fetch an agent by its ID.""" @@ -978,6 +992,37 @@ class AgentManager: agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def get_agent_by_id_async( + self, + agent_id: str, + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> PydanticAgentState: + """Fetch an agent by its ID.""" + async with db_registry.async_session() as session: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + return await agent.to_pydantic_async(include_relationships=include_relationships) + + @trace_method + @enforce_types + async def get_agents_by_ids_async( + self, + agent_ids: list[str], + actor: PydanticUser, + include_relationships: Optional[List[str]] = None, + ) -> list[PydanticAgentState]: + """Fetch a list of agents by their IDs.""" + async with db_registry.async_session() as session: + agents = await AgentModel.read_multiple_async( + db_session=session, + identifiers=agent_ids, + actor=actor, + ) + return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) + + @trace_method @enforce_types async def get_agent_by_id_async( self, @@ -1013,6 +1058,7 @@ class AgentManager: agent = AgentModel.read(db_session=session, name=agent_name, actor=actor) return agent.to_pydantic() + @trace_method @enforce_types def delete_agent(self, agent_id: str, actor: PydanticUser) -> None: """ @@ -1060,6 +1106,57 @@ class AgentManager: else: logger.debug(f"Agent with ID {agent_id} successfully hard deleted") + @trace_method + @enforce_types + async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None: + """ + Deletes an agent and its associated relationships. + Ensures proper permission checks and cascades where applicable. + + Args: + agent_id: ID of the agent to be deleted. + actor: User performing the action. + + Raises: + NoResultFound: If agent doesn't exist + """ + async with db_registry.async_session() as session: + # Retrieve the agent + logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}") + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + agents_to_delete = [agent] + sleeptime_group_to_delete = None + + # Delete sleeptime agent and group (TODO this is flimsy pls fix) + if agent.multi_agent_group: + participant_agent_ids = agent.multi_agent_group.agent_ids + if agent.multi_agent_group.manager_type in {ManagerType.sleeptime, ManagerType.voice_sleeptime} and participant_agent_ids: + for participant_agent_id in participant_agent_ids: + try: + sleeptime_agent = await AgentModel.read_async(db_session=session, identifier=participant_agent_id, actor=actor) + agents_to_delete.append(sleeptime_agent) + except NoResultFound: + pass # agent already deleted + sleeptime_agent_group = await GroupModel.read_async( + db_session=session, identifier=agent.multi_agent_group.id, actor=actor + ) + sleeptime_group_to_delete = sleeptime_agent_group + + try: + if sleeptime_group_to_delete is not None: + await session.delete(sleeptime_group_to_delete) + await session.commit() + for agent in agents_to_delete: + await session.delete(agent) + await session.commit() + except Exception as e: + await session.rollback() + logger.exception(f"Failed to hard delete Agent with ID {agent_id}") + raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}") + else: + logger.debug(f"Agent with ID {agent_id} successfully hard deleted") + + @trace_method @enforce_types def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema: with db_registry.session() as session: @@ -1068,6 +1165,7 @@ class AgentManager: data = schema.dump(agent) return AgentSchema(**data) + @trace_method @enforce_types def deserialize( self, @@ -1137,6 +1235,7 @@ class AgentManager: # ====================================================================================================================== # Per Agent Environment Variable Management # ====================================================================================================================== + @trace_method @enforce_types def _set_environment_variables( self, @@ -1192,6 +1291,7 @@ class AgentManager: # Return the updated agent state return agent.to_pydantic() + @trace_method @enforce_types def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]: with db_registry.session() as session: @@ -1208,11 +1308,19 @@ class AgentManager: # TODO: 2) These messages are ordered from oldest to newest # TODO: This can be fixed by having an actual relationship in the ORM for message_ids # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query. + @trace_method @enforce_types def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) + @trace_method + @enforce_types + async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: + agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) + return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) + + @trace_method @enforce_types async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) @@ -1223,6 +1331,7 @@ class AgentManager: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor) + @trace_method @enforce_types async def get_system_message_async(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: agent = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=[], actor=actor) @@ -1231,6 +1340,7 @@ class AgentManager: # TODO: This is duplicated below # TODO: This is legacy code and should be cleaned up # TODO: A lot of the memory "compilation" should be offset to a separate class + @trace_method @enforce_types def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState: """Rebuilds the system message with the latest memory object and any shared memory block updates @@ -1296,6 +1406,75 @@ class AgentManager: else: return agent_state + @trace_method + @enforce_types + async def rebuild_system_prompt_async( + self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True + ) -> PydanticAgentState: + """Rebuilds the system message with the latest memory object and any shared memory block updates + + Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object + + Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages + """ + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) + + curr_system_message = await self.get_system_message_async( + agent_id=agent_id, actor=actor + ) # this is the system + memory bank, not just the system prompt + curr_system_message_openai = curr_system_message.to_openai_dict() + + # note: we only update the system prompt if the core memory is changed + # this means that the archival/recall memory statistics may be someout out of date + curr_memory_str = agent_state.memory.compile() + if curr_memory_str in curr_system_message_openai["content"] and not force: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild" + ) + return agent_state + + # If the memory didn't update, we probably don't want to update the timestamp inside + # For example, if we're doing a system prompt swap, this should probably be False + if update_timestamp: + memory_edit_timestamp = get_utc_time() + else: + # NOTE: a bit of a hack - we pull the timestamp from the message created_by + memory_edit_timestamp = curr_system_message.created_at + + num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id) + num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id) + + # update memory (TODO: potentially update recall/archival stats separately) + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10), + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, + ) + + diff = united_diff(curr_system_message_openai["content"], new_system_message_str) + if len(diff) > 0: # there was a diff + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + # Swap the system message out (only if there is a diff) + message = PydanticMessage.dict_to_message( + agent_id=agent_id, + model=agent_state.llm_config.model, + openai_message_dict={"role": "system", "content": new_system_message_str}, + ) + message = await self.message_manager.update_message_by_id_async( + message_id=curr_system_message.id, + message_update=MessageUpdate(**message.model_dump()), + actor=actor, + ) + return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor) + else: + return agent_state + + @trace_method @enforce_types async def rebuild_system_prompt_async( self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True @@ -1367,6 +1546,12 @@ class AgentManager: def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) + @trace_method + @enforce_types + async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: + return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) + + @trace_method @enforce_types async def set_in_context_messages_async(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState: return await self.update_agent_async(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor) @@ -1377,6 +1562,7 @@ class AgentManager: new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) + @trace_method @enforce_types def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids @@ -1384,6 +1570,7 @@ class AgentManager: new_messages = [message_ids[0]] # 0 is system message return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor) + @trace_method @enforce_types def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids @@ -1391,6 +1578,7 @@ class AgentManager: message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:] return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + @trace_method @enforce_types def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState: messages = self.message_manager.create_many_messages(messages, actor=actor) @@ -1398,6 +1586,7 @@ class AgentManager: message_ids += [m.id for m in messages] return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) + @trace_method @enforce_types def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState: """ @@ -1445,6 +1634,7 @@ class AgentManager: return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor) # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version + @trace_method @enforce_types def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState: """ @@ -1482,6 +1672,7 @@ class AgentManager: return agent_state + @trace_method @enforce_types async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState: block_ids = [b.id for b in agent_state.memory.blocks] @@ -1496,6 +1687,7 @@ class AgentManager: # ====================================================================================================================== # Source Management # ====================================================================================================================== + @trace_method @enforce_types def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState: """ @@ -1540,6 +1732,7 @@ class AgentManager: return agent.to_pydantic() + @trace_method @enforce_types def append_system_message(self, agent_id: str, content: str, actor: PydanticUser): @@ -1552,6 +1745,7 @@ class AgentManager: # update agent in-context message IDs self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor) + @trace_method @enforce_types def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: """ @@ -1571,6 +1765,27 @@ class AgentManager: # Use the lazy-loaded relationship to get sources return [source.to_pydantic() for source in agent.sources] + @trace_method + @enforce_types + async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: + """ + Lists all sources attached to an agent. + + Args: + agent_id: ID of the agent to list sources for + actor: User performing the action + + Returns: + List[str]: List of source IDs attached to the agent + """ + async with db_registry.async_session() as session: + # Verify agent exists and user has permission to access it + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + + # Use the lazy-loaded relationship to get sources + return [source.to_pydantic() for source in agent.sources] + + @trace_method @enforce_types async def list_attached_sources_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]: """ @@ -1620,6 +1835,7 @@ class AgentManager: # ====================================================================================================================== # Block management # ====================================================================================================================== + @trace_method @enforce_types def get_block_with_label( self, @@ -1635,6 +1851,51 @@ class AgentManager: return block.to_pydantic() raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + @trace_method + @enforce_types + async def get_block_with_label_async( + self, + agent_id: str, + block_label: str, + actor: PydanticUser, + ) -> PydanticBlock: + """Gets a block attached to an agent by its label.""" + async with db_registry.async_session() as session: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + for block in agent.core_memory: + if block.label == block_label: + return block.to_pydantic() + raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + + @trace_method + @enforce_types + async def modify_block_by_label_async( + self, + agent_id: str, + block_label: str, + block_update: BlockUpdate, + actor: PydanticUser, + ) -> PydanticBlock: + """Gets a block attached to an agent by its label.""" + async with db_registry.async_session() as session: + block = None + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + for block in agent.core_memory: + if block.label == block_label: + block = block + break + if not block: + raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'") + + update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + + for key, value in update_data.items(): + setattr(block, key, value) + + await block.update_async(session, actor=actor) + return block.to_pydantic() + + @trace_method @enforce_types async def modify_block_by_label_async( self, @@ -1686,6 +1947,7 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method @enforce_types def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: """Attaches a block to an agent.""" @@ -1697,6 +1959,19 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def attach_block_async(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState: + """Attaches a block to an agent.""" + async with db_registry.async_session() as session: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor) + + agent.core_memory.append(block) + await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() + + @trace_method @enforce_types def detach_block( self, @@ -1717,6 +1992,28 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method + @enforce_types + async def detach_block_async( + self, + agent_id: str, + block_id: str, + actor: PydanticUser, + ) -> PydanticAgentState: + """Detaches a block from an agent.""" + async with db_registry.async_session() as session: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + original_length = len(agent.core_memory) + + agent.core_memory = [b for b in agent.core_memory if b.id != block_id] + + if len(agent.core_memory) == original_length: + raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'") + + await agent.update_async(session, actor=actor) + return await agent.to_pydantic_async() + + @trace_method @enforce_types def detach_block_with_label( self, @@ -1769,105 +2066,121 @@ class AgentManager: embedded_text = np.array(embedded_text) embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() - with db_registry.session() as session: - # Start with base query for source passages - source_passages = None - if not agent_only: # Include source passages - if agent_id is not None: - source_passages = ( - select(SourcePassage, literal(None).label("agent_id")) - .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) - .where(SourcesAgents.agent_id == agent_id) - .where(SourcePassage.organization_id == actor.organization_id) - ) - else: - source_passages = select(SourcePassage, literal(None).label("agent_id")).where( - SourcePassage.organization_id == actor.organization_id - ) - - if source_id: - source_passages = source_passages.where(SourcePassage.source_id == source_id) - if file_id: - source_passages = source_passages.where(SourcePassage.file_id == file_id) - - # Add agent passages query - agent_passages = None + # Start with base query for source passages + source_passages = None + if not agent_only: # Include source passages if agent_id is not None: - agent_passages = ( - select( - AgentPassage.id, - AgentPassage.text, - AgentPassage.embedding_config, - AgentPassage.metadata_, - AgentPassage.embedding, - AgentPassage.created_at, - AgentPassage.updated_at, - AgentPassage.is_deleted, - AgentPassage._created_by_id, - AgentPassage._last_updated_by_id, - AgentPassage.organization_id, - literal(None).label("file_id"), - literal(None).label("source_id"), - AgentPassage.agent_id, - ) - .where(AgentPassage.agent_id == agent_id) - .where(AgentPassage.organization_id == actor.organization_id) + source_passages = ( + select(SourcePassage, literal(None).label("agent_id")) + .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) + .where(SourcesAgents.agent_id == agent_id) + .where(SourcePassage.organization_id == actor.organization_id) + ) + else: + source_passages = select(SourcePassage, literal(None).label("agent_id")).where( + SourcePassage.organization_id == actor.organization_id ) - # Combine queries - if source_passages is not None and agent_passages is not None: - combined_query = union_all(source_passages, agent_passages).cte("combined_passages") - elif agent_passages is not None: - combined_query = agent_passages.cte("combined_passages") - elif source_passages is not None: - combined_query = source_passages.cte("combined_passages") - else: - raise ValueError("No passages found") - - # Build main query from combined CTE - main_query = select(combined_query) - - # Apply filters - if start_date: - main_query = main_query.where(combined_query.c.created_at >= start_date) - if end_date: - main_query = main_query.where(combined_query.c.created_at <= end_date) if source_id: - main_query = main_query.where(combined_query.c.source_id == source_id) + source_passages = source_passages.where(SourcePassage.source_id == source_id) if file_id: - main_query = main_query.where(combined_query.c.file_id == file_id) + source_passages = source_passages.where(SourcePassage.file_id == file_id) - # Vector search - if embedded_text: - if settings.letta_pg_uri_no_default: - # PostgreSQL with pgvector - main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc()) - else: - # SQLite with custom vector type - query_embedding_binary = adapt_array(embedded_text) - main_query = main_query.order_by( - func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), - combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc(), - combined_query.c.id.asc(), - ) + # Add agent passages query + agent_passages = None + if agent_id is not None: + agent_passages = ( + select( + AgentPassage.id, + AgentPassage.text, + AgentPassage.embedding_config, + AgentPassage.metadata_, + AgentPassage.embedding, + AgentPassage.created_at, + AgentPassage.updated_at, + AgentPassage.is_deleted, + AgentPassage._created_by_id, + AgentPassage._last_updated_by_id, + AgentPassage.organization_id, + literal(None).label("file_id"), + literal(None).label("source_id"), + AgentPassage.agent_id, + ) + .where(AgentPassage.agent_id == agent_id) + .where(AgentPassage.organization_id == actor.organization_id) + ) + + # Combine queries + if source_passages is not None and agent_passages is not None: + combined_query = union_all(source_passages, agent_passages).cte("combined_passages") + elif agent_passages is not None: + combined_query = agent_passages.cte("combined_passages") + elif source_passages is not None: + combined_query = source_passages.cte("combined_passages") + else: + raise ValueError("No passages found") + + # Build main query from combined CTE + main_query = select(combined_query) + + # Apply filters + if start_date: + main_query = main_query.where(combined_query.c.created_at >= start_date) + if end_date: + main_query = main_query.where(combined_query.c.created_at <= end_date) + if source_id: + main_query = main_query.where(combined_query.c.source_id == source_id) + if file_id: + main_query = main_query.where(combined_query.c.file_id == file_id) + + # Vector search + if embedded_text: + if settings.letta_pg_uri_no_default: + # PostgreSQL with pgvector + main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc()) else: - if query_text: - main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text))) + # SQLite with custom vector type + query_embedding_binary = adapt_array(embedded_text) + main_query = main_query.order_by( + func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(), + combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc(), + combined_query.c.id.asc(), + ) + else: + if query_text: + main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text))) - # Handle pagination - if before or after: - # Create reference CTEs + # Handle pagination + if before or after: + # Create reference CTEs + if before: + before_ref = select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == before).cte("before_ref") + if after: + after_ref = select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == after).cte("after_ref") + + if before and after: + # Window-based query (get records between before and after) + main_query = main_query.where( + or_( + combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(), + and_( + combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(), + combined_query.c.id < select(before_ref.c.id).scalar_subquery(), + ), + ) + ) + main_query = main_query.where( + or_( + combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(), + and_( + combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(), + combined_query.c.id > select(after_ref.c.id).scalar_subquery(), + ), + ) + ) + else: + # Pure pagination (only before or only after) if before: - before_ref = ( - select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == before).cte("before_ref") - ) - if after: - after_ref = ( - select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == after).cte("after_ref") - ) - - if before and after: - # Window-based query (get records between before and after) main_query = main_query.where( or_( combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(), @@ -1877,6 +2190,7 @@ class AgentManager: ), ) ) + if after: main_query = main_query.where( or_( combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(), @@ -1886,44 +2200,23 @@ class AgentManager: ), ) ) - else: - # Pure pagination (only before or only after) - if before: - main_query = main_query.where( - or_( - combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(), - and_( - combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(), - combined_query.c.id < select(before_ref.c.id).scalar_subquery(), - ), - ) - ) - if after: - main_query = main_query.where( - or_( - combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(), - and_( - combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(), - combined_query.c.id > select(after_ref.c.id).scalar_subquery(), - ), - ) - ) - # Add ordering if not already ordered by similarity - if not embed_query: - if ascending: - main_query = main_query.order_by( - combined_query.c.created_at.asc(), - combined_query.c.id.asc(), - ) - else: - main_query = main_query.order_by( - combined_query.c.created_at.desc(), - combined_query.c.id.asc(), - ) + # Add ordering if not already ordered by similarity + if not embed_query: + if ascending: + main_query = main_query.order_by( + combined_query.c.created_at.asc(), + combined_query.c.id.asc(), + ) + else: + main_query = main_query.order_by( + combined_query.c.created_at.desc(), + combined_query.c.id.asc(), + ) return main_query + @trace_method @enforce_types def list_passages( self, @@ -1983,6 +2276,67 @@ class AgentManager: return [p.to_pydantic() for p in passages] + @trace_method + @enforce_types + async def list_passages_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + limit: Optional[int] = 50, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + before: Optional[str] = None, + after: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False, + ) -> List[PydanticPassage]: + """Lists all passages attached to an agent.""" + async with db_registry.async_session() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + before=before, + after=after, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Add limit + if limit: + main_query = main_query.limit(limit) + + # Execute query + result = await session.execute(main_query) + + passages = [] + for row in result: + data = dict(row._mapping) + if data["agent_id"] is not None: + # This is an AgentPassage - remove source fields + data.pop("source_id", None) + data.pop("file_id", None) + passage = AgentPassage(**data) + else: + # This is a SourcePassage - remove agent field + data.pop("agent_id", None) + passage = SourcePassage(**data) + passages.append(passage) + + return [p.to_pydantic() for p in passages] + + @trace_method @enforce_types async def list_passages_async( self, @@ -2081,9 +2435,48 @@ class AgentManager: count_query = select(func.count()).select_from(main_query.subquery()) return session.scalar(count_query) or 0 + @enforce_types + async def passage_size_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + query_text: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + before: Optional[str] = None, + after: Optional[str] = None, + source_id: Optional[str] = None, + embed_query: bool = False, + ascending: bool = True, + embedding_config: Optional[EmbeddingConfig] = None, + agent_only: bool = False, + ) -> int: + async with db_registry.async_session() as session: + main_query = self._build_passage_query( + actor=actor, + agent_id=agent_id, + file_id=file_id, + query_text=query_text, + start_date=start_date, + end_date=end_date, + before=before, + after=after, + source_id=source_id, + embed_query=embed_query, + ascending=ascending, + embedding_config=embedding_config, + agent_only=agent_only, + ) + + # Convert to count query + count_query = select(func.count()).select_from(main_query.subquery()) + return (await session.execute(count_query)).scalar() or 0 + # ====================================================================================================================== # Tool Management # ====================================================================================================================== + @trace_method @enforce_types def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ @@ -2119,6 +2512,7 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method @enforce_types def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState: """ @@ -2152,6 +2546,7 @@ class AgentManager: agent.update(session, actor=actor) return agent.to_pydantic() + @trace_method @enforce_types def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]: """ @@ -2171,6 +2566,7 @@ class AgentManager: # ====================================================================================================================== # Tag Management # ====================================================================================================================== + @trace_method @enforce_types def list_tags( self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None @@ -2205,6 +2601,7 @@ class AgentManager: results = [tag[0] for tag in query.all()] return results + @trace_method @enforce_types async def list_tags_async( self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None @@ -2243,3 +2640,279 @@ class AgentManager: # Extract the tag values from the result results = [row[0] for row in result.all()] return results + + async def get_context_window(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: + if os.getenv("LETTA_ENVIRONMENT") == "PRODUCTION": + return await self.get_context_window_from_anthropic_async(agent_id=agent_id, actor=actor) + return await self.get_context_window_from_tiktoken_async(agent_id=agent_id, actor=actor) + + async def get_context_window_from_anthropic_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: + """Get the context window of the agent""" + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) + model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None + + # Grab the in-context messages + # conversion of messages to anthropic dict format, which is passed to the token counter + (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( + self.get_in_context_messages_async(agent_id=agent_id, actor=actor), + self.passage_manager.size_async(actor=actor, agent_id=agent_id), + self.message_manager.size_async(actor=actor, agent_id=agent_id), + ) + in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages] + + # Extract system, memory and external summary + if ( + len(in_context_messages) > 0 + and in_context_messages[0].role == MessageRole.system + and in_context_messages[0].content + and len(in_context_messages[0].content) == 1 + and isinstance(in_context_messages[0].content[0], TextContent) + ): + system_message = in_context_messages[0].content[0].text + + external_memory_marker_pos = system_message.find("###") + core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) + if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: + system_prompt = system_message[:external_memory_marker_pos].strip() + external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() + core_memory = system_message[core_memory_marker_pos:].strip() + else: + # if no markers found, put everything in system message + system_prompt = system_message + external_memory_summary = None + core_memory = None + else: + # if no system message, fall back on agent's system prompt + system_prompt = agent_state.system + external_memory_summary = None + core_memory = None + + num_tokens_system_coroutine = anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": system_prompt}]) + num_tokens_core_memory_coroutine = ( + anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": core_memory}]) + if core_memory + else asyncio.sleep(0, result=0) + ) + num_tokens_external_memory_summary_coroutine = ( + anthropic_client.count_tokens(model=model, messages=[{"role": "user", "content": external_memory_summary}]) + if external_memory_summary + else asyncio.sleep(0, result=0) + ) + + # Check if there's a summary message in the message queue + if ( + len(in_context_messages) > 1 + and in_context_messages[1].role == MessageRole.user + and in_context_messages[1].content + and len(in_context_messages[1].content) == 1 + and isinstance(in_context_messages[1].content[0], TextContent) + # TODO remove hardcoding + and "The following is a summary of the previous " in in_context_messages[1].content[0].text + ): + # Summary message exists + text_content = in_context_messages[1].content[0].text + assert text_content is not None + summary_memory = text_content + num_tokens_summary_memory_coroutine = anthropic_client.count_tokens( + model=model, messages=[{"role": "user", "content": summary_memory}] + ) + # with a summary message, the real messages start at index 2 + num_tokens_messages_coroutine = ( + anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[2:]) + if len(in_context_messages_anthropic) > 2 + else asyncio.sleep(0, result=0) + ) + + else: + summary_memory = None + num_tokens_summary_memory_coroutine = asyncio.sleep(0, result=0) + # with no summary message, the real messages start at index 1 + num_tokens_messages_coroutine = ( + anthropic_client.count_tokens(model=model, messages=in_context_messages_anthropic[1:]) + if len(in_context_messages_anthropic) > 1 + else asyncio.sleep(0, result=0) + ) + + # tokens taken up by function definitions + if agent_state.tools and len(agent_state.tools) > 0: + available_functions_definitions = [OpenAITool(type="function", function=f.json_schema) for f in agent_state.tools] + num_tokens_available_functions_definitions_coroutine = anthropic_client.count_tokens( + model=model, + tools=available_functions_definitions, + ) + else: + available_functions_definitions = [] + num_tokens_available_functions_definitions_coroutine = asyncio.sleep(0, result=0) + + ( + num_tokens_system, + num_tokens_core_memory, + num_tokens_external_memory_summary, + num_tokens_summary_memory, + num_tokens_messages, + num_tokens_available_functions_definitions, + ) = await asyncio.gather( + num_tokens_system_coroutine, + num_tokens_core_memory_coroutine, + num_tokens_external_memory_summary_coroutine, + num_tokens_summary_memory_coroutine, + num_tokens_messages_coroutine, + num_tokens_available_functions_definitions_coroutine, + ) + + num_tokens_used_total = ( + num_tokens_system # system prompt + + num_tokens_available_functions_definitions # function definitions + + num_tokens_core_memory # core memory + + num_tokens_external_memory_summary # metadata (statistics) about recall/archival + + num_tokens_summary_memory # summary of ongoing conversation + + num_tokens_messages # tokens taken by messages + ) + assert isinstance(num_tokens_used_total, int) + + return ContextWindowOverview( + # context window breakdown (in messages) + num_messages=len(in_context_messages), + num_archival_memory=passage_manager_size, + num_recall_memory=message_manager_size, + num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, + # top-level information + context_window_size_max=agent_state.llm_config.context_window, + context_window_size_current=num_tokens_used_total, + # context window breakdown (in tokens) + num_tokens_system=num_tokens_system, + system_prompt=system_prompt, + num_tokens_core_memory=num_tokens_core_memory, + core_memory=core_memory, + num_tokens_summary_memory=num_tokens_summary_memory, + summary_memory=summary_memory, + num_tokens_messages=num_tokens_messages, + messages=in_context_messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + functions_definitions=available_functions_definitions, + ) + + async def get_context_window_from_tiktoken_async(self, agent_id: str, actor: PydanticUser) -> ContextWindowOverview: + """Get the context window of the agent""" + from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages + + agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + # Grab the in-context messages + # conversion of messages to OpenAI dict format, which is passed to the token counter + (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( + self.get_in_context_messages_async(agent_id=agent_id, actor=actor), + self.passage_manager.size_async(actor=actor, agent_id=agent_id), + self.message_manager.size_async(actor=actor, agent_id=agent_id), + ) + in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] + + # Extract system, memory and external summary + if ( + len(in_context_messages) > 0 + and in_context_messages[0].role == MessageRole.system + and in_context_messages[0].content + and len(in_context_messages[0].content) == 1 + and isinstance(in_context_messages[0].content[0], TextContent) + ): + system_message = in_context_messages[0].content[0].text + + external_memory_marker_pos = system_message.find("###") + core_memory_marker_pos = system_message.find("<", external_memory_marker_pos) + if external_memory_marker_pos != -1 and core_memory_marker_pos != -1: + system_prompt = system_message[:external_memory_marker_pos].strip() + external_memory_summary = system_message[external_memory_marker_pos:core_memory_marker_pos].strip() + core_memory = system_message[core_memory_marker_pos:].strip() + else: + # if no markers found, put everything in system message + system_prompt = system_message + external_memory_summary = "" + core_memory = "" + else: + # if no system message, fall back on agent's system prompt + system_prompt = agent_state.system + external_memory_summary = "" + core_memory = "" + + num_tokens_system = count_tokens(system_prompt) + num_tokens_core_memory = count_tokens(core_memory) + num_tokens_external_memory_summary = count_tokens(external_memory_summary) + + # Check if there's a summary message in the message queue + if ( + len(in_context_messages) > 1 + and in_context_messages[1].role == MessageRole.user + and in_context_messages[1].content + and len(in_context_messages[1].content) == 1 + and isinstance(in_context_messages[1].content[0], TextContent) + # TODO remove hardcoding + and "The following is a summary of the previous " in in_context_messages[1].content[0].text + ): + # Summary message exists + text_content = in_context_messages[1].content[0].text + assert text_content is not None + summary_memory = text_content + num_tokens_summary_memory = count_tokens(text_content) + # with a summary message, the real messages start at index 2 + num_tokens_messages = ( + num_tokens_from_messages(messages=in_context_messages_openai[2:], model=agent_state.llm_config.model) + if len(in_context_messages_openai) > 2 + else 0 + ) + + else: + summary_memory = None + num_tokens_summary_memory = 0 + # with no summary message, the real messages start at index 1 + num_tokens_messages = ( + num_tokens_from_messages(messages=in_context_messages_openai[1:], model=agent_state.llm_config.model) + if len(in_context_messages_openai) > 1 + else 0 + ) + + # tokens taken up by function definitions + agent_state_tool_jsons = [t.json_schema for t in agent_state.tools] + if agent_state_tool_jsons: + available_functions_definitions = [OpenAITool(type="function", function=f) for f in agent_state_tool_jsons] + num_tokens_available_functions_definitions = num_tokens_from_functions( + functions=agent_state_tool_jsons, model=agent_state.llm_config.model + ) + else: + available_functions_definitions = [] + num_tokens_available_functions_definitions = 0 + + num_tokens_used_total = ( + num_tokens_system # system prompt + + num_tokens_available_functions_definitions # function definitions + + num_tokens_core_memory # core memory + + num_tokens_external_memory_summary # metadata (statistics) about recall/archival + + num_tokens_summary_memory # summary of ongoing conversation + + num_tokens_messages # tokens taken by messages + ) + assert isinstance(num_tokens_used_total, int) + + return ContextWindowOverview( + # context window breakdown (in messages) + num_messages=len(in_context_messages), + num_archival_memory=passage_manager_size, + num_recall_memory=message_manager_size, + num_tokens_external_memory_summary=num_tokens_external_memory_summary, + external_memory_summary=external_memory_summary, + # top-level information + context_window_size_max=agent_state.llm_config.context_window, + context_window_size_current=num_tokens_used_total, + # context window breakdown (in tokens) + num_tokens_system=num_tokens_system, + system_prompt=system_prompt, + num_tokens_core_memory=num_tokens_core_memory, + core_memory=core_memory, + num_tokens_summary_memory=num_tokens_summary_memory, + summary_memory=summary_memory, + num_tokens_messages=num_tokens_messages, + messages=in_context_messages, + # related to functions + num_tokens_functions_definitions=num_tokens_available_functions_definitions, + functions_definitions=available_functions_definitions, + ) diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 2d568e34b..0795ed7fe 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -14,6 +14,7 @@ from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types logger = get_logger(__name__) @@ -22,6 +23,7 @@ logger = get_logger(__name__) class BlockManager: """Manager class to handle business logic related to Blocks.""" + @trace_method @enforce_types def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock: """Create a new block based on the Block schema.""" @@ -36,6 +38,7 @@ class BlockManager: block.create(session, actor=actor) return block.to_pydantic() + @trace_method @enforce_types def batch_create_blocks(self, blocks: List[PydanticBlock], actor: PydanticUser) -> List[PydanticBlock]: """ @@ -59,6 +62,7 @@ class BlockManager: # Convert back to Pydantic return [m.to_pydantic() for m in created_models] + @trace_method @enforce_types def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock: """Update a block by its ID with the given BlockUpdate object.""" @@ -74,6 +78,7 @@ class BlockManager: block.update(db_session=session, actor=actor) return block.to_pydantic() + @trace_method @enforce_types def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock: """Delete a block by its ID.""" @@ -82,6 +87,7 @@ class BlockManager: block.hard_delete(db_session=session, actor=actor) return block.to_pydantic() + @trace_method @enforce_types async def get_blocks_async( self, @@ -144,68 +150,7 @@ class BlockManager: return [block.to_pydantic() for block in blocks] - @enforce_types - async def get_blocks_async( - self, - actor: PydanticUser, - label: Optional[str] = None, - is_template: Optional[bool] = None, - template_name: Optional[str] = None, - identity_id: Optional[str] = None, - identifier_keys: Optional[List[str]] = None, - limit: Optional[int] = 50, - ) -> List[PydanticBlock]: - """Async version of get_blocks method. Retrieve blocks based on various optional filters.""" - from sqlalchemy import select - from sqlalchemy.orm import noload - - from letta.orm.sqlalchemy_base import AccessType - - async with db_registry.async_session() as session: - # Start with a basic query - query = select(BlockModel) - - # Explicitly avoid loading relationships - query = query.options(noload(BlockModel.agents), noload(BlockModel.identities), noload(BlockModel.groups)) - - # Apply access control - query = BlockModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION) - - # Add filters - query = query.where(BlockModel.organization_id == actor.organization_id) - if label: - query = query.where(BlockModel.label == label) - - if is_template is not None: - query = query.where(BlockModel.is_template == is_template) - - if template_name: - query = query.where(BlockModel.template_name == template_name) - - if identifier_keys: - query = ( - query.join(BlockModel.identities) - .filter(BlockModel.identities.property.mapper.class_.identifier_key.in_(identifier_keys)) - .distinct(BlockModel.id) - ) - - if identity_id: - query = ( - query.join(BlockModel.identities) - .filter(BlockModel.identities.property.mapper.class_.id == identity_id) - .distinct(BlockModel.id) - ) - - # Add limit - if limit: - query = query.limit(limit) - - # Execute the query - result = await session.execute(query) - blocks = result.scalars().all() - - return [block.to_pydantic() for block in blocks] - + @trace_method @enforce_types def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]: """Retrieve a block by its name.""" @@ -216,6 +161,7 @@ class BlockManager: except NoResultFound: return None + @trace_method @enforce_types async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]: """Retrieve blocks by their ids without loading unnecessary relationships. Async implementation.""" @@ -263,6 +209,7 @@ class BlockManager: return pydantic_blocks + @trace_method @enforce_types async def get_agents_for_block_async(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: """ @@ -273,6 +220,7 @@ class BlockManager: agents_orm = block.agents return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) + @trace_method @enforce_types def size( self, @@ -286,6 +234,7 @@ class BlockManager: # Block History Functions + @trace_method @enforce_types def checkpoint_block( self, @@ -389,6 +338,7 @@ class BlockManager: updated_block = block.update(db_session=session, actor=actor, no_commit=True) return updated_block + @trace_method @enforce_types def undo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: """ @@ -431,6 +381,7 @@ class BlockManager: session.commit() return block.to_pydantic() + @trace_method @enforce_types def redo_checkpoint_block(self, block_id: str, actor: PydanticUser, use_preloaded_block: Optional[BlockModel] = None) -> PydanticBlock: """ @@ -469,6 +420,7 @@ class BlockManager: session.commit() return block.to_pydantic() + @trace_method @enforce_types async def bulk_update_block_values_async( self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False diff --git a/letta/services/group_manager.py b/letta/services/group_manager.py index 7adf49ec9..4bce5825f 100644 --- a/letta/services/group_manager.py +++ b/letta/services/group_manager.py @@ -12,11 +12,13 @@ from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types class GroupManager: + @trace_method @enforce_types def list_groups( self, @@ -42,12 +44,14 @@ class GroupManager: ) return [group.to_pydantic() for group in groups] + @trace_method @enforce_types def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) return group.to_pydantic() + @trace_method @enforce_types def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: @@ -93,6 +97,7 @@ class GroupManager: new_group.create(session, actor=actor) return new_group.to_pydantic() + @trace_method @enforce_types def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup: with db_registry.session() as session: @@ -155,6 +160,7 @@ class GroupManager: group.update(session, actor=actor) return group.to_pydantic() + @trace_method @enforce_types def delete_group(self, group_id: str, actor: PydanticUser) -> None: with db_registry.session() as session: @@ -162,6 +168,7 @@ class GroupManager: group = GroupModel.read(db_session=session, identifier=group_id, actor=actor) group.hard_delete(session) + @trace_method @enforce_types def list_group_messages( self, @@ -198,6 +205,7 @@ class GroupManager: return messages + @trace_method @enforce_types def reset_messages(self, group_id: str, actor: PydanticUser) -> None: with db_registry.session() as session: @@ -211,6 +219,7 @@ class GroupManager: session.commit() + @trace_method @enforce_types def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int: with db_registry.session() as session: @@ -222,6 +231,18 @@ class GroupManager: group.update(session, actor=actor) return group.turns_counter + @trace_method + @enforce_types + async def bump_turns_counter_async(self, group_id: str, actor: PydanticUser) -> int: + async with db_registry.async_session() as session: + # Ensure group is loadable by user + group = await GroupModel.read_async(session, identifier=group_id, actor=actor) + + # Update turns counter + group.turns_counter = (group.turns_counter + 1) % group.sleeptime_agent_frequency + await group.update_async(session, actor=actor) + return group.turns_counter + @enforce_types def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str: with db_registry.session() as session: @@ -235,6 +256,22 @@ class GroupManager: return prev_last_processed_message_id + @trace_method + @enforce_types + async def get_last_processed_message_id_and_update_async( + self, group_id: str, last_processed_message_id: str, actor: PydanticUser + ) -> str: + async with db_registry.async_session() as session: + # Ensure group is loadable by user + group = await GroupModel.read_async(session, identifier=group_id, actor=actor) + + # Update last processed message id + prev_last_processed_message_id = group.last_processed_message_id + group.last_processed_message_id = last_processed_message_id + await group.update_async(session, actor=actor) + + return prev_last_processed_message_id + @enforce_types def size( self, diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 590cedeeb..c13e83926 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -12,12 +12,14 @@ from letta.schemas.identity import Identity as PydanticIdentity from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types class IdentityManager: @enforce_types + @trace_method async def list_identities_async( self, name: Optional[str] = None, @@ -48,12 +50,14 @@ class IdentityManager: return [identity.to_pydantic() for identity in identities] @enforce_types + @trace_method async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity: async with db_registry.async_session() as session: identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) return identity.to_pydantic() @enforce_types + @trace_method async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity: async with db_registry.async_session() as session: new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True)) @@ -78,6 +82,7 @@ class IdentityManager: return new_identity.to_pydantic() @enforce_types + @trace_method async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity: async with db_registry.async_session() as session: existing_identity = await IdentityModel.read_async( @@ -103,6 +108,7 @@ class IdentityManager: ) @enforce_types + @trace_method async def update_identity_async( self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False ) -> PydanticIdentity: @@ -165,6 +171,7 @@ class IdentityManager: return existing_identity.to_pydantic() @enforce_types + @trace_method async def upsert_identity_properties_async( self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser ) -> PydanticIdentity: @@ -181,6 +188,7 @@ class IdentityManager: ) @enforce_types + @trace_method async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None: async with db_registry.async_session() as session: identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor) @@ -192,6 +200,7 @@ class IdentityManager: await session.commit() @enforce_types + @trace_method async def size_async( self, actor: PydanticUser, diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index d3c7ca590..3cd1ee035 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -25,6 +25,7 @@ from letta.schemas.step import Step as PydanticStep from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types @@ -32,6 +33,7 @@ class JobManager: """Manager class to handle business logic related to Jobs.""" @enforce_types + @trace_method def create_job( self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: @@ -45,6 +47,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method async def create_job_async( self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: @@ -58,6 +61,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: """Update a job by its ID with the given JobUpdate object.""" with db_registry.session() as session: @@ -82,6 +86,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: """Update a job by its ID with the given JobUpdate object asynchronously.""" async with db_registry.async_session() as session: @@ -106,6 +111,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Fetch a job by its ID.""" with db_registry.session() as session: @@ -114,6 +120,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Fetch a job by its ID asynchronously.""" async with db_registry.async_session() as session: @@ -122,6 +129,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method def list_jobs( self, actor: PydanticUser, @@ -151,6 +159,7 @@ class JobManager: return [job.to_pydantic() for job in jobs] @enforce_types + @trace_method async def list_jobs_async( self, actor: PydanticUser, @@ -180,6 +189,7 @@ class JobManager: return [job.to_pydantic() for job in jobs] @enforce_types + @trace_method def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob: """Delete a job by its ID.""" with db_registry.session() as session: @@ -188,6 +198,7 @@ class JobManager: return job.to_pydantic() @enforce_types + @trace_method def get_job_messages( self, job_id: str, @@ -238,6 +249,7 @@ class JobManager: return [message.to_pydantic() for message in messages] @enforce_types + @trace_method def get_job_steps( self, job_id: str, @@ -283,6 +295,7 @@ class JobManager: return [step.to_pydantic() for step in steps] @enforce_types + @trace_method def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None: """ Associate a message with a job by creating a JobMessage record. @@ -306,6 +319,7 @@ class JobManager: session.commit() @enforce_types + @trace_method def get_job_usage(self, job_id: str, actor: PydanticUser) -> LettaUsageStatistics: """ Get usage statistics for a job. @@ -343,6 +357,7 @@ class JobManager: ) @enforce_types + @trace_method def add_job_usage( self, job_id: str, @@ -383,6 +398,7 @@ class JobManager: session.commit() @enforce_types + @trace_method def get_run_messages( self, run_id: str, @@ -434,6 +450,7 @@ class JobManager: return messages @enforce_types + @trace_method def get_step_messages( self, run_id: str, diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 052e2bbe2..c296a64eb 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -17,6 +17,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types logger = get_logger(__name__) @@ -26,6 +27,7 @@ class LLMBatchManager: """Manager for handling both LLMBatchJob and LLMBatchItem operations.""" @enforce_types + @trace_method async def create_llm_batch_job_async( self, llm_provider: ProviderType, @@ -47,6 +49,7 @@ class LLMBatchManager: return batch.to_pydantic() @enforce_types + @trace_method async def get_llm_batch_job_by_id_async(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: """Retrieve a single batch job by ID.""" async with db_registry.async_session() as session: @@ -54,7 +57,8 @@ class LLMBatchManager: return batch.to_pydantic() @enforce_types - def update_llm_batch_status( + @trace_method + async def update_llm_batch_status_async( self, llm_batch_id: str, status: JobStatus, @@ -62,15 +66,15 @@ class LLMBatchManager: latest_polling_response: Optional[BetaMessageBatch] = None, ) -> PydanticLLMBatchJob: """Update a batch job’s status and optionally its polling response.""" - with db_registry.session() as session: - batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) + async with db_registry.async_session() as session: + batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor) batch.status = status batch.latest_polling_response = latest_polling_response batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc) - batch = batch.update(db_session=session, actor=actor) + batch = await batch.update_async(db_session=session, actor=actor) return batch.to_pydantic() - def bulk_update_llm_batch_statuses( + async def bulk_update_llm_batch_statuses_async( self, updates: List[BatchPollingResult], ) -> None: @@ -81,7 +85,7 @@ class LLMBatchManager: """ now = datetime.datetime.now(datetime.timezone.utc) - with db_registry.session() as session: + async with db_registry.async_session() as session: mappings = [] for llm_batch_id, status, response in updates: mappings.append( @@ -93,17 +97,18 @@ class LLMBatchManager: } ) - session.bulk_update_mappings(LLMBatchJob, mappings) - session.commit() + await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings)) + await session.commit() @enforce_types - def list_llm_batch_jobs( + @trace_method + async def list_llm_batch_jobs_async( self, letta_batch_id: str, limit: Optional[int] = None, actor: Optional[PydanticUser] = None, after: Optional[str] = None, - ) -> List[PydanticLLMBatchItem]: + ) -> List[PydanticLLMBatchJob]: """ List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count. @@ -115,33 +120,35 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ - with db_registry.session() as session: - query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id) + async with db_registry.async_session() as session: + query = select(LLMBatchJob).where(LLMBatchJob.letta_batch_job_id == letta_batch_id) if actor is not None: - query = query.filter(LLMBatchJob.organization_id == actor.organization_id) + query = query.where(LLMBatchJob.organization_id == actor.organization_id) # Additional optional filters if after is not None: - query = query.filter(LLMBatchJob.id > after) + query = query.where(LLMBatchJob.id > after) query = query.order_by(LLMBatchJob.id.asc()) if limit is not None: query = query.limit(limit) - results = query.all() - return [item.to_pydantic() for item in results] + results = await session.execute(query) + return [item.to_pydantic() for item in results.scalars().all()] @enforce_types - def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None: + @trace_method + async def delete_llm_batch_request_async(self, llm_batch_id: str, actor: PydanticUser) -> None: """Hard delete a batch job by ID.""" - with db_registry.session() as session: - batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) - batch.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor) + await batch.hard_delete_async(db_session=session, actor=actor) @enforce_types - def get_messages_for_letta_batch( + @trace_method + async def get_messages_for_letta_batch_async( self, letta_batch_job_id: str, limit: int = 100, @@ -154,12 +161,12 @@ class LLMBatchManager: Retrieve messages across all LLM batch jobs associated with a Letta batch job. Optimized for PostgreSQL performance using ID-based keyset pagination. """ - with db_registry.session() as session: + async with db_registry.async_session() as session: # If cursor is provided, get sequence_id for that message cursor_sequence_id = None if cursor: - cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1) - cursor_result = cursor_query.first() + cursor_query = select(MessageModel.sequence_id).where(MessageModel.id == cursor).limit(1) + cursor_result = await session.execute(cursor_query) if cursor_result: cursor_sequence_id = cursor_result[0] else: @@ -167,24 +174,24 @@ class LLMBatchManager: pass query = ( - session.query(MessageModel) + select(MessageModel) .join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id) .join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id) - .filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id) + .where(LLMBatchJob.letta_batch_job_id == letta_batch_job_id) ) if actor is not None: - query = query.filter(MessageModel.organization_id == actor.organization_id) + query = query.where(MessageModel.organization_id == actor.organization_id) if agent_id is not None: - query = query.filter(MessageModel.agent_id == agent_id) + query = query.where(MessageModel.agent_id == agent_id) # Apply cursor-based pagination if cursor exists if cursor_sequence_id is not None: if sort_descending: - query = query.filter(MessageModel.sequence_id < cursor_sequence_id) + query = query.where(MessageModel.sequence_id < cursor_sequence_id) else: - query = query.filter(MessageModel.sequence_id > cursor_sequence_id) + query = query.where(MessageModel.sequence_id > cursor_sequence_id) if sort_descending: query = query.order_by(desc(MessageModel.sequence_id)) @@ -193,10 +200,11 @@ class LLMBatchManager: query = query.limit(limit) - results = query.all() - return [message.to_pydantic() for message in results] + results = await session.execute(query) + return [message.to_pydantic() for message in results.scalars().all()] @enforce_types + @trace_method async def list_running_llm_batches_async(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: """Return all running LLM batch jobs, optionally filtered by actor's organization.""" async with db_registry.async_session() as session: @@ -209,7 +217,8 @@ class LLMBatchManager: return [batch.to_pydantic() for batch in results.scalars().all()] @enforce_types - def create_llm_batch_item( + @trace_method + async def create_llm_batch_item_async( self, llm_batch_id: str, agent_id: str, @@ -220,7 +229,7 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Create a new batch item.""" - with db_registry.session() as session: + async with db_registry.async_session() as session: item = LLMBatchItem( llm_batch_id=llm_batch_id, agent_id=agent_id, @@ -230,10 +239,11 @@ class LLMBatchManager: step_state=step_state, organization_id=actor.organization_id, ) - item.create(session, actor=actor) + await item.create_async(session, actor=actor) return item.to_pydantic() @enforce_types + @trace_method async def create_llm_batch_items_bulk_async( self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser ) -> List[PydanticLLMBatchItem]: @@ -269,14 +279,16 @@ class LLMBatchManager: return [item.to_pydantic() for item in created_items] @enforce_types - def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: + @trace_method + async def get_llm_batch_item_by_id_async(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: """Retrieve a single batch item by ID.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) return item.to_pydantic() @enforce_types - def update_llm_batch_item( + @trace_method + async def update_llm_batch_item_async( self, item_id: str, actor: PydanticUser, @@ -286,8 +298,8 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Update fields on a batch item.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) if request_status: item.request_status = request_status @@ -298,9 +310,11 @@ class LLMBatchManager: if step_state: item.step_state = step_state - return item.update(db_session=session, actor=actor).to_pydantic() + result = await item.update_async(db_session=session, actor=actor) + return result.to_pydantic() @enforce_types + @trace_method async def list_llm_batch_items_async( self, llm_batch_id: str, @@ -346,7 +360,8 @@ class LLMBatchManager: results = await session.execute(query) return [item.to_pydantic() for item in results.scalars()] - def bulk_update_llm_batch_items( + @trace_method + async def bulk_update_llm_batch_items_async( self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True ) -> None: """ @@ -364,13 +379,13 @@ class LLMBatchManager: if len(llm_batch_id_agent_id_pairs) != len(field_updates): raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length") - with db_registry.session() as session: + async with db_registry.async_session() as session: # Lookup primary keys for all requested (batch_id, agent_id) pairs - items = ( - session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id) - .filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)) - .all() + query = select(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).filter( + tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs) ) + result = await session.execute(query) + items = result.all() pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items} if strict: @@ -395,11 +410,12 @@ class LLMBatchManager: mappings.append(update_fields) if mappings: - session.bulk_update_mappings(LLMBatchItem, mappings) - session.commit() + await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings)) + await session.commit() @enforce_types - def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None: + @trace_method + async def bulk_update_batch_llm_items_results_by_agent_async(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None: """Update request status and batch results for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [ @@ -410,33 +426,41 @@ class LLMBatchManager: for update in updates ] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types - def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None: + @trace_method + async def bulk_update_llm_batch_items_step_status_by_agent_async( + self, updates: List[StepStatusUpdateInfo], strict: bool = True + ) -> None: """Update step status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"step_status": update.step_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types - def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None: + @trace_method + async def bulk_update_llm_batch_items_request_status_by_agent_async( + self, updates: List[RequestStatusUpdateInfo], strict: bool = True + ) -> None: """Update request status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"request_status": update.request_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types - def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None: + @trace_method + async def delete_llm_batch_item_async(self, item_id: str, actor: PydanticUser) -> None: """Hard delete a batch item by ID.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) - item.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) + await item.hard_delete_async(db_session=session, actor=actor) @enforce_types - def count_llm_batch_items(self, llm_batch_id: str) -> int: + @trace_method + async def count_llm_batch_items_async(self, llm_batch_id: str) -> int: """ Efficiently count the number of batch items for a given llm_batch_id. @@ -446,6 +470,6 @@ class LLMBatchManager: Returns: int: The total number of batch items associated with the given llm_batch_id. """ - with db_registry.session() as session: - count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar() - return count or 0 + async with db_registry.async_session() as session: + count = await session.execute(select(func.count(LLMBatchItem.id)).where(LLMBatchItem.llm_batch_id == llm_batch_id)) + return count.scalar() or 0 diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 91351db3f..2477f303c 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -13,6 +13,7 @@ from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types logger = get_logger(__name__) @@ -22,6 +23,7 @@ class MessageManager: """Manager class to handle business logic related to Messages.""" @enforce_types + @trace_method def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: """Fetch a message by ID.""" with db_registry.session() as session: @@ -32,6 +34,7 @@ class MessageManager: return None @enforce_types + @trace_method async def get_message_by_id_async(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]: """Fetch a message by ID.""" async with db_registry.async_session() as session: @@ -42,6 +45,7 @@ class MessageManager: return None @enforce_types + @trace_method def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: """Fetch messages by ID and return them in the requested order.""" with db_registry.session() as session: @@ -49,6 +53,7 @@ class MessageManager: return self._get_messages_by_id_postprocess(results, message_ids) @enforce_types + @trace_method async def get_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]: """Fetch messages by ID and return them in the requested order. Async version of above function.""" async with db_registry.async_session() as session: @@ -71,6 +76,7 @@ class MessageManager: return list(filter(lambda x: x is not None, [result_dict.get(msg_id, None) for msg_id in message_ids])) @enforce_types + @trace_method def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage: """Create a new message.""" with db_registry.session() as session: @@ -92,6 +98,7 @@ class MessageManager: return orm_messages @enforce_types + @trace_method def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: """ Create multiple messages in a single database transaction. @@ -111,6 +118,7 @@ class MessageManager: return [msg.to_pydantic() for msg in created_messages] @enforce_types + @trace_method async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: """ Create multiple messages in a single database transaction asynchronously. @@ -131,6 +139,7 @@ class MessageManager: return [msg.to_pydantic() for msg in created_messages] @enforce_types + @trace_method def update_message_by_letta_message( self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser ) -> PydanticMessage: @@ -169,6 +178,7 @@ class MessageManager: raise ValueError(f"Message type got modified: {letta_message_update.message_type}") @enforce_types + @trace_method def update_message_by_letta_message( self, message_id: str, letta_message_update: LettaMessageUpdateUnion, actor: PydanticUser ) -> PydanticMessage: @@ -207,6 +217,7 @@ class MessageManager: raise ValueError(f"Message type got modified: {letta_message_update.message_type}") @enforce_types + @trace_method def update_message_by_id(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: """ Updates an existing record in the database with values from the provided record object. @@ -224,6 +235,7 @@ class MessageManager: return message.to_pydantic() @enforce_types + @trace_method async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage: """ Updates an existing record in the database with values from the provided record object. @@ -267,6 +279,7 @@ class MessageManager: return message @enforce_types + @trace_method def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool: """Delete a message.""" with db_registry.session() as session: @@ -281,6 +294,7 @@ class MessageManager: raise ValueError(f"Message with id {message_id} not found.") @enforce_types + @trace_method def size( self, actor: PydanticUser, @@ -297,6 +311,7 @@ class MessageManager: return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id) @enforce_types + @trace_method async def size_async( self, actor: PydanticUser, @@ -312,6 +327,7 @@ class MessageManager: return await MessageModel.size_async(db_session=session, actor=actor, role=role, agent_id=agent_id) @enforce_types + @trace_method def list_user_messages_for_agent( self, agent_id: str, @@ -334,6 +350,7 @@ class MessageManager: ) @enforce_types + @trace_method def list_messages_for_agent( self, agent_id: str, @@ -437,6 +454,7 @@ class MessageManager: return [msg.to_pydantic() for msg in results] @enforce_types + @trace_method async def list_messages_for_agent_async( self, agent_id: str, @@ -538,6 +556,7 @@ class MessageManager: return [msg.to_pydantic() for msg in results] @enforce_types + @trace_method def delete_all_messages_for_agent(self, agent_id: str, actor: PydanticUser) -> int: """ Efficiently deletes all messages associated with a given agent_id, diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 00a528334..715f57aa4 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -5,6 +5,7 @@ from letta.orm.organization import Organization as OrganizationModel from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.organization import OrganizationUpdate from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types @@ -15,11 +16,13 @@ class OrganizationManager: DEFAULT_ORG_NAME = "default_org" @enforce_types + @trace_method def get_default_organization(self) -> PydanticOrganization: """Fetch the default organization.""" return self.get_organization_by_id(self.DEFAULT_ORG_ID) @enforce_types + @trace_method def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]: """Fetch an organization by ID.""" with db_registry.session() as session: @@ -27,6 +30,7 @@ class OrganizationManager: return organization.to_pydantic() @enforce_types + @trace_method def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: """Create a new organization.""" try: @@ -36,6 +40,7 @@ class OrganizationManager: return self._create_organization(pydantic_org=pydantic_org) @enforce_types + @trace_method def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: with db_registry.session() as session: org = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) @@ -43,11 +48,13 @@ class OrganizationManager: return org.to_pydantic() @enforce_types + @trace_method def create_default_organization(self) -> PydanticOrganization: """Create the default organization.""" return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)) @enforce_types + @trace_method def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: """Update an organization.""" with db_registry.session() as session: @@ -58,6 +65,7 @@ class OrganizationManager: return org.to_pydantic() @enforce_types + @trace_method def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization: """Update an organization.""" with db_registry.session() as session: @@ -70,6 +78,7 @@ class OrganizationManager: return org.to_pydantic() @enforce_types + @trace_method def delete_organization_by_id(self, org_id: str): """Delete an organization by marking it as deleted.""" with db_registry.session() as session: @@ -77,6 +86,7 @@ class OrganizationManager: organization.hard_delete(session) @enforce_types + @trace_method def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: """List all organizations with optional pagination.""" with db_registry.session() as session: diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 3cd581b3f..44e33866a 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -11,6 +11,7 @@ from letta.schemas.agent import AgentState from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types @@ -18,6 +19,7 @@ class PassageManager: """Manager class to handle business logic related to Passages.""" @enforce_types + @trace_method def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: """Fetch a passage by ID.""" with db_registry.session() as session: @@ -34,6 +36,7 @@ class PassageManager: raise NoResultFound(f"Passage with id {passage_id} not found in database.") @enforce_types + @trace_method def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" # Common fields for both passage types @@ -70,11 +73,13 @@ class PassageManager: return passage.to_pydantic() @enforce_types + @trace_method def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: """Create multiple passages.""" return [self.create_passage(p, actor) for p in passages] @enforce_types + @trace_method def insert_passage( self, agent_state: AgentState, @@ -136,6 +141,7 @@ class PassageManager: raise e @enforce_types + @trace_method def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]: """Update a passage.""" if not passage_id: @@ -170,6 +176,7 @@ class PassageManager: return curr_passage.to_pydantic() @enforce_types + @trace_method def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: """Delete a passage from either source or archival passages.""" if not passage_id: @@ -190,6 +197,8 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Passage with id {passage_id} not found.") + @enforce_types + @trace_method def delete_passages( self, actor: PydanticUser, @@ -202,6 +211,7 @@ class PassageManager: return True @enforce_types + @trace_method def size( self, actor: PydanticUser, @@ -217,6 +227,7 @@ class PassageManager: return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) @enforce_types + @trace_method async def size_async( self, actor: PydanticUser, @@ -230,6 +241,8 @@ class PassageManager: async with db_registry.async_session() as session: return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id) + @enforce_types + @trace_method def estimate_embeddings_size( self, actor: PydanticUser, diff --git a/letta/services/per_agent_lock_manager.py b/letta/services/per_agent_lock_manager.py index fab3742e2..e8e2a0a43 100644 --- a/letta/services/per_agent_lock_manager.py +++ b/letta/services/per_agent_lock_manager.py @@ -1,6 +1,8 @@ import threading from collections import defaultdict +from letta.tracing import trace_method + class PerAgentLockManager: """Manages per-agent locks.""" @@ -8,10 +10,12 @@ class PerAgentLockManager: def __init__(self): self.locks = defaultdict(threading.Lock) + @trace_method def get_lock(self, agent_id: str) -> threading.Lock: """Retrieve the lock for a specific agent_id.""" return self.locks[agent_id] + @trace_method def clear_lock(self, agent_id: str): """Optionally remove a lock if no longer needed (to prevent unbounded growth).""" if agent_id in self.locks: diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 9bb4a8178..6b2bab01d 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -6,12 +6,14 @@ from letta.schemas.providers import Provider as PydanticProvider from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types class ProviderManager: @enforce_types + @trace_method def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider: """Create a new provider if it doesn't already exist.""" with db_registry.session() as session: @@ -32,6 +34,7 @@ class ProviderManager: return new_provider.to_pydantic() @enforce_types + @trace_method def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: """Update provider details.""" with db_registry.session() as session: @@ -48,6 +51,7 @@ class ProviderManager: return existing_provider.to_pydantic() @enforce_types + @trace_method def delete_provider_by_id(self, provider_id: str, actor: PydanticUser): """Delete a provider.""" with db_registry.session() as session: @@ -62,6 +66,7 @@ class ProviderManager: session.commit() @enforce_types + @trace_method def list_providers( self, actor: PydanticUser, @@ -87,16 +92,45 @@ class ProviderManager: return [provider.to_pydantic() for provider in providers] @enforce_types + @trace_method + async def list_providers_async( + self, + actor: PydanticUser, + name: Optional[str] = None, + provider_type: Optional[ProviderType] = None, + after: Optional[str] = None, + limit: Optional[int] = 50, + ) -> List[PydanticProvider]: + """List all providers with optional pagination.""" + filter_kwargs = {} + if name: + filter_kwargs["name"] = name + if provider_type: + filter_kwargs["provider_type"] = provider_type + async with db_registry.async_session() as session: + providers = await ProviderModel.list_async( + db_session=session, + after=after, + limit=limit, + actor=actor, + **filter_kwargs, + ) + return [provider.to_pydantic() for provider in providers] + + @enforce_types + @trace_method def get_provider_id_from_name(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: providers = self.list_providers(name=provider_name, actor=actor) return providers[0].id if providers else None @enforce_types + @trace_method def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: providers = self.list_providers(name=provider_name, actor=actor) return providers[0].api_key if providers else None @enforce_types + @trace_method def check_provider_api_key(self, provider_check: ProviderCheck) -> None: provider = PydanticProvider( name=provider_check.provider_type.value, diff --git a/letta/services/sandbox_config_manager.py b/letta/services/sandbox_config_manager.py index 0f55a0bc5..6e7a43bcd 100644 --- a/letta/services/sandbox_config_manager.py +++ b/letta/services/sandbox_config_manager.py @@ -12,6 +12,7 @@ from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types, printd logger = get_logger(__name__) @@ -21,6 +22,7 @@ class SandboxConfigManager: """Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable.""" @enforce_types + @trace_method def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor) if not sandbox_config: @@ -38,6 +40,7 @@ class SandboxConfigManager: return sandbox_config @enforce_types + @trace_method def create_or_update_sandbox_config(self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser) -> PydanticSandboxConfig: """Create or update a sandbox configuration based on the PydanticSandboxConfig schema.""" config = sandbox_config_create.config @@ -71,6 +74,61 @@ class SandboxConfigManager: return db_sandbox.to_pydantic() @enforce_types + @trace_method + async def get_or_create_default_sandbox_config_async(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig: + sandbox_config = await self.get_sandbox_config_by_type_async(sandbox_type, actor=actor) + if not sandbox_config: + logger.debug(f"Creating new sandbox config of type {sandbox_type}, none found for organization {actor.organization_id}.") + + # TODO: Add more sandbox types later + if sandbox_type == SandboxType.E2B: + default_config = {} # Empty + else: + # TODO: May want to move this to environment variables v.s. persisting in database + default_local_sandbox_path = LETTA_TOOL_EXECUTION_DIR + default_config = LocalSandboxConfig(sandbox_dir=default_local_sandbox_path).model_dump(exclude_none=True) + + sandbox_config = await self.create_or_update_sandbox_config_async(SandboxConfigCreate(config=default_config), actor=actor) + return sandbox_config + + @enforce_types + @trace_method + async def create_or_update_sandbox_config_async( + self, sandbox_config_create: SandboxConfigCreate, actor: PydanticUser + ) -> PydanticSandboxConfig: + """Create or update a sandbox configuration based on the PydanticSandboxConfig schema.""" + config = sandbox_config_create.config + sandbox_type = config.type + sandbox_config = PydanticSandboxConfig( + type=sandbox_type, config=config.model_dump(exclude_none=True), organization_id=actor.organization_id + ) + + # Attempt to retrieve the existing sandbox configuration by type within the organization + db_sandbox = await self.get_sandbox_config_by_type_async(sandbox_config.type, actor=actor) + if db_sandbox: + # Prepare the update data, excluding fields that should not be reset + update_data = sandbox_config.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(db_sandbox, key) != value} + + # If there are changes, update the sandbox configuration + if update_data: + db_sandbox = await self.update_sandbox_config_async(db_sandbox.id, SandboxConfigUpdate(**update_data), actor) + else: + printd( + f"`create_or_update_sandbox_config` was called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"type={sandbox_config.type}, but found existing configuration with nothing to update." + ) + + return db_sandbox + else: + # If the sandbox configuration doesn't exist, create a new one + async with db_registry.async_session() as session: + db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True)) + await db_sandbox.create_async(session, actor=actor) + return db_sandbox.to_pydantic() + + @enforce_types + @trace_method def update_sandbox_config( self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser ) -> PydanticSandboxConfig: @@ -98,6 +156,35 @@ class SandboxConfigManager: return sandbox.to_pydantic() @enforce_types + @trace_method + async def update_sandbox_config_async( + self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser + ) -> PydanticSandboxConfig: + """Update an existing sandbox configuration.""" + async with db_registry.async_session() as session: + sandbox = await SandboxConfigModel.read_async(db_session=session, identifier=sandbox_config_id, actor=actor) + # We need to check that the sandbox_update provided is the same type as the original sandbox + if sandbox.type != sandbox_update.config.type: + raise ValueError( + f"Mismatched type for sandbox config update: tried to update sandbox_config of type {sandbox.type} with config of type {sandbox_update.config.type}" + ) + + update_data = sandbox_update.model_dump(exclude_unset=True, exclude_none=True) + update_data = {key: value for key, value in update_data.items() if getattr(sandbox, key) != value} + + if update_data: + for key, value in update_data.items(): + setattr(sandbox, key, value) + await sandbox.update_async(db_session=session, actor=actor) + else: + printd( + f"`update_sandbox_config` called with user_id={actor.id}, organization_id={actor.organization_id}, " + f"name={sandbox.type}, but nothing to update." + ) + return sandbox.to_pydantic() + + @enforce_types + @trace_method def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig: """Delete a sandbox configuration by its ID.""" with db_registry.session() as session: @@ -106,6 +193,7 @@ class SandboxConfigManager: return sandbox.to_pydantic() @enforce_types + @trace_method def list_sandbox_configs( self, actor: PydanticUser, @@ -123,6 +211,7 @@ class SandboxConfigManager: return [sandbox.to_pydantic() for sandbox in sandboxes] @enforce_types + @trace_method async def list_sandbox_configs_async( self, actor: PydanticUser, @@ -140,6 +229,7 @@ class SandboxConfigManager: return [sandbox.to_pydantic() for sandbox in sandboxes] @enforce_types + @trace_method def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox configuration by its ID.""" with db_registry.session() as session: @@ -150,6 +240,7 @@ class SandboxConfigManager: return None @enforce_types + @trace_method def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]: """Retrieve a sandbox config by its type.""" with db_registry.session() as session: @@ -167,6 +258,27 @@ class SandboxConfigManager: return None @enforce_types + @trace_method + async def get_sandbox_config_by_type_async( + self, type: SandboxType, actor: Optional[PydanticUser] = None + ) -> Optional[PydanticSandboxConfig]: + """Retrieve a sandbox config by its type.""" + async with db_registry.async_session() as session: + try: + sandboxes = await SandboxConfigModel.list_async( + db_session=session, + type=type, + organization_id=actor.organization_id, + limit=1, + ) + if sandboxes: + return sandboxes[0].to_pydantic() + return None + except NoResultFound: + return None + + @enforce_types + @trace_method def create_sandbox_env_var( self, env_var_create: SandboxEnvironmentVariableCreate, sandbox_config_id: str, actor: PydanticUser ) -> PydanticEnvVar: @@ -194,6 +306,7 @@ class SandboxConfigManager: return env_var.to_pydantic() @enforce_types + @trace_method def update_sandbox_env_var( self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser ) -> PydanticEnvVar: @@ -215,6 +328,7 @@ class SandboxConfigManager: return env_var.to_pydantic() @enforce_types + @trace_method def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar: """Delete a sandbox environment variable by its ID.""" with db_registry.session() as session: @@ -223,6 +337,7 @@ class SandboxConfigManager: return env_var.to_pydantic() @enforce_types + @trace_method def list_sandbox_env_vars( self, sandbox_config_id: str, @@ -242,6 +357,7 @@ class SandboxConfigManager: return [env_var.to_pydantic() for env_var in env_vars] @enforce_types + @trace_method async def list_sandbox_env_vars_async( self, sandbox_config_id: str, @@ -261,6 +377,7 @@ class SandboxConfigManager: return [env_var.to_pydantic() for env_var in env_vars] @enforce_types + @trace_method def list_sandbox_env_vars_by_key( self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> List[PydanticEnvVar]: @@ -276,6 +393,7 @@ class SandboxConfigManager: return [env_var.to_pydantic() for env_var in env_vars] @enforce_types + @trace_method def get_sandbox_env_vars_as_dict( self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> Dict[str, str]: @@ -286,6 +404,18 @@ class SandboxConfigManager: return result @enforce_types + @trace_method + async def get_sandbox_env_vars_as_dict_async( + self, sandbox_config_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 + ) -> Dict[str, str]: + env_vars = await self.list_sandbox_env_vars_async(sandbox_config_id, actor, after, limit) + result = {} + for env_var in env_vars: + result[env_var.key] = env_var.value + return result + + @enforce_types + @trace_method def get_sandbox_env_var_by_key_and_sandbox_config_id( self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None ) -> Optional[PydanticEnvVar]: diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index 6247967c3..7ec7aa3c1 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -1,3 +1,4 @@ +import asyncio from typing import List, Optional from letta.orm.errors import NoResultFound @@ -9,6 +10,7 @@ from letta.schemas.source import Source as PydanticSource from letta.schemas.source import SourceUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types, printd @@ -16,25 +18,27 @@ class SourceManager: """Manager class to handle business logic related to Sources.""" @enforce_types - def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: + @trace_method + async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource: """Create a new source based on the PydanticSource schema.""" # Try getting the source first by id - db_source = self.get_source_by_id(source.id, actor=actor) + db_source = await self.get_source_by_id(source.id, actor=actor) if db_source: return db_source else: - with db_registry.session() as session: + async with db_registry.async_session() as session: # Provide default embedding config if not given source.organization_id = actor.organization_id source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True)) - source.create(session, actor=actor) + await source.create_async(session, actor=actor) return source.to_pydantic() @enforce_types - def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: + @trace_method + async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource: """Update a source by its ID with the given SourceUpdate object.""" - with db_registry.session() as session: - source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + async with db_registry.async_session() as session: + source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) # get update dictionary update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) @@ -53,18 +57,22 @@ class SourceManager: return source.to_pydantic() @enforce_types - def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: + @trace_method + async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource: """Delete a source by its ID.""" - with db_registry.session() as session: - source = SourceModel.read(db_session=session, identifier=source_id) - source.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + source = await SourceModel.read_async(db_session=session, identifier=source_id) + await source.hard_delete_async(db_session=session, actor=actor) return source.to_pydantic() @enforce_types - def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]: + @trace_method + async def list_sources( + self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs + ) -> List[PydanticSource]: """List all sources with optional pagination.""" - with db_registry.session() as session: - sources = SourceModel.list( + async with db_registry.async_session() as session: + sources = await SourceModel.list_async( db_session=session, after=after, limit=limit, @@ -74,18 +82,17 @@ class SourceManager: return [source.to_pydantic() for source in sources] @enforce_types - def size( - self, - actor: PydanticUser, - ) -> int: + @trace_method + async def size(self, actor: PydanticUser) -> int: """ Get the total count of sources for the given user. """ - with db_registry.session() as session: - return SourceModel.size(db_session=session, actor=actor) + async with db_registry.async_session() as session: + return await SourceModel.size_async(db_session=session, actor=actor) @enforce_types - def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]: + @trace_method + async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]: """ Lists all agents that have the specified source attached. @@ -96,30 +103,33 @@ class SourceManager: Returns: List[PydanticAgentState]: List of agents that have this source attached """ - with db_registry.session() as session: + async with db_registry.async_session() as session: # Verify source exists and user has permission to access it - source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) # The agents relationship is already loaded due to lazy="selectin" in the Source model # and will be properly filtered by organization_id due to the OrganizationMixin - return [agent.to_pydantic() for agent in source.agents] + agents_orm = source.agents + return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm]) # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types - def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: + @trace_method + async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]: """Retrieve a source by its ID.""" - with db_registry.session() as session: + async with db_registry.async_session() as session: try: - source = SourceModel.read(db_session=session, identifier=source_id, actor=actor) + source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor) return source.to_pydantic() except NoResultFound: return None @enforce_types - def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]: + @trace_method + async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]: """Retrieve a source by its name.""" - with db_registry.session() as session: - sources = SourceModel.list( + async with db_registry.async_session() as session: + sources = await SourceModel.list_async( db_session=session, name=source_name, organization_id=actor.organization_id, @@ -131,44 +141,49 @@ class SourceManager: return sources[0].to_pydantic() @enforce_types - def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata: + @trace_method + async def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata: """Create a new file based on the PydanticFileMetadata schema.""" - db_file = self.get_file_by_id(file_metadata.id, actor=actor) + db_file = await self.get_file_by_id(file_metadata.id, actor=actor) if db_file: return db_file else: - with db_registry.session() as session: + async with db_registry.async_session() as session: file_metadata.organization_id = actor.organization_id file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True)) - file_metadata.create(session, actor=actor) + await file_metadata.create_async(session, actor=actor) return file_metadata.to_pydantic() # TODO: We make actor optional for now, but should most likely be enforced due to security reasons @enforce_types - def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]: + @trace_method + async def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]: """Retrieve a file by its ID.""" - with db_registry.session() as session: + async with db_registry.async_session() as session: try: - file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor) + file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor) return file.to_pydantic() except NoResultFound: return None @enforce_types - def list_files( + @trace_method + async def list_files( self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50 ) -> List[PydanticFileMetadata]: """List all files with optional pagination.""" - with db_registry.session() as session: - files = FileMetadataModel.list( + async with db_registry.async_session() as session: + files_all = await FileMetadataModel.list_async(db_session=session, organization_id=actor.organization_id, source_id=source_id) + files = await FileMetadataModel.list_async( db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id ) return [file.to_pydantic() for file in files] @enforce_types - def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: + @trace_method + async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata: """Delete a file by its ID.""" - with db_registry.session() as session: - file = FileMetadataModel.read(db_session=session, identifier=file_id) - file.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + file = await FileMetadataModel.read_async(db_session=session, identifier=file_id) + await file.hard_delete_async(db_session=session, actor=actor) return file.to_pydantic() diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 8ee052218..9e11c55c9 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -14,13 +14,14 @@ from letta.schemas.step import Step as PydanticStep from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.helpers.noop_helper import singleton -from letta.tracing import get_trace_id +from letta.tracing import get_trace_id, trace_method from letta.utils import enforce_types class StepManager: @enforce_types + @trace_method def list_steps( self, actor: PydanticUser, @@ -54,6 +55,7 @@ class StepManager: return [step.to_pydantic() for step in steps] @enforce_types + @trace_method def log_step( self, actor: PydanticUser, @@ -96,6 +98,7 @@ class StepManager: return new_step.to_pydantic() @enforce_types + @trace_method async def log_step_async( self, actor: PydanticUser, @@ -138,12 +141,14 @@ class StepManager: return new_step.to_pydantic() @enforce_types + @trace_method def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep: with db_registry.session() as session: step = StepModel.read(db_session=session, identifier=step_id, actor=actor) return step.to_pydantic() @enforce_types + @trace_method def update_step_transaction_id(self, actor: PydanticUser, step_id: str, transaction_id: str) -> PydanticStep: """Update the transaction ID for a step. @@ -236,6 +241,7 @@ class NoopStepManager(StepManager): """ @enforce_types + @trace_method def log_step( self, actor: PydanticUser, @@ -253,6 +259,7 @@ class NoopStepManager(StepManager): return @enforce_types + @trace_method async def log_step_async( self, actor: PydanticUser, diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 9e7bf42f2..078943540 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -26,6 +26,7 @@ from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry +from letta.tracing import trace_method from letta.utils import enforce_types, printd logger = get_logger(__name__) @@ -36,6 +37,7 @@ class ToolManager: # TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object @enforce_types + @trace_method def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" tool_id = self.get_tool_id_by_name(tool_name=pydantic_tool.name, actor=actor) @@ -62,6 +64,7 @@ class ToolManager: return tool @enforce_types + @trace_method async def create_or_update_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" tool_id = await self.get_tool_id_by_name_async(tool_name=pydantic_tool.name, actor=actor) @@ -88,6 +91,7 @@ class ToolManager: return tool @enforce_types + @trace_method def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}} return self.create_or_update_tool( @@ -98,18 +102,21 @@ class ToolManager: ) @enforce_types + @trace_method def create_or_update_composio_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: return self.create_or_update_tool( PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor ) @enforce_types + @trace_method def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: return self.create_or_update_tool( PydanticTool(tool_type=ToolType.EXTERNAL_LANGCHAIN, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor ) @enforce_types + @trace_method def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" with db_registry.session() as session: @@ -125,6 +132,7 @@ class ToolManager: return tool.to_pydantic() @enforce_types + @trace_method async def create_tool_async(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: """Create a new tool based on the ToolCreate schema.""" async with db_registry.async_session() as session: @@ -140,6 +148,7 @@ class ToolManager: return tool.to_pydantic() @enforce_types + @trace_method def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" with db_registry.session() as session: @@ -149,6 +158,7 @@ class ToolManager: return tool.to_pydantic() @enforce_types + @trace_method async def get_tool_by_id_async(self, tool_id: str, actor: PydanticUser) -> PydanticTool: """Fetch a tool by its ID.""" async with db_registry.async_session() as session: @@ -158,6 +168,7 @@ class ToolManager: return tool.to_pydantic() @enforce_types + @trace_method def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: @@ -168,6 +179,7 @@ class ToolManager: return None @enforce_types + @trace_method async def get_tool_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: @@ -178,6 +190,7 @@ class ToolManager: return None @enforce_types + @trace_method def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: @@ -188,6 +201,7 @@ class ToolManager: return None @enforce_types + @trace_method async def get_tool_id_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[str]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" try: @@ -198,6 +212,7 @@ class ToolManager: return None @enforce_types + @trace_method async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" async with db_registry.async_session() as session: @@ -223,6 +238,7 @@ class ToolManager: return results @enforce_types + @trace_method def size( self, actor: PydanticUser, @@ -239,6 +255,7 @@ class ToolManager: return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET) @enforce_types + @trace_method def update_tool_by_id( self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None ) -> PydanticTool: @@ -267,6 +284,7 @@ class ToolManager: return tool.update(db_session=session, actor=actor).to_pydantic() @enforce_types + @trace_method async def update_tool_by_id_async( self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None ) -> PydanticTool: @@ -296,6 +314,7 @@ class ToolManager: return tool.to_pydantic() @enforce_types + @trace_method def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None: """Delete a tool by its ID.""" with db_registry.session() as session: @@ -306,6 +325,7 @@ class ToolManager: raise ValueError(f"Tool with id {tool_id} not found.") @enforce_types + @trace_method def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]: """Add default tools in base.py and multi_agent.py""" functions_to_schema = {} @@ -371,6 +391,7 @@ class ToolManager: return tools @enforce_types + @trace_method async def upsert_base_tools_async(self, actor: PydanticUser) -> List[PydanticTool]: """Add default tools in base.py and multi_agent.py""" functions_to_schema = {} diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py index 2307ea0a1..07ab57276 100644 --- a/letta/services/tool_sandbox/e2b_sandbox.py +++ b/letta/services/tool_sandbox/e2b_sandbox.py @@ -53,7 +53,9 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): if self.provided_sandbox_config: sbx_config = self.provided_sandbox_config else: - sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=self.user) + sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.E2B, actor=self.user + ) # TODO: So this defaults to force recreating always # TODO: Eventually, provision one sandbox PER agent, and that agent re-uses that one specifically e2b_sandbox = await self.create_e2b_sandbox_with_metadata_hash(sandbox_config=sbx_config) @@ -71,7 +73,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): if self.provided_sandbox_env_vars: env_vars.update(self.provided_sandbox_env_vars) else: - db_env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict( + db_env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async( sandbox_config_id=sbx_config.id, actor=self.user, limit=100 ) env_vars.update(db_env_vars) diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index a17815961..276409516 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -60,14 +60,16 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): additional_env_vars: Optional[Dict], ) -> ToolExecutionResult: """ - Unified asynchronougit pus method to run the tool in a local sandbox environment, + Unified asynchronous method to run the tool in a local sandbox environment, always via subprocess for multi-core parallelism. """ # Get sandbox configuration if self.provided_sandbox_config: sbx_config = self.provided_sandbox_config else: - sbx_config = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.user) + sbx_config = await self.sandbox_config_manager.get_or_create_default_sandbox_config_async( + sandbox_type=SandboxType.LOCAL, actor=self.user + ) local_configs = sbx_config.get_local_config() use_venv = local_configs.use_venv @@ -76,7 +78,9 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): if self.provided_sandbox_env_vars: env.update(self.provided_sandbox_env_vars) else: - env_vars = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(sandbox_config_id=sbx_config.id, actor=self.user, limit=100) + env_vars = await self.sandbox_config_manager.get_sandbox_env_vars_as_dict_async( + sandbox_config_id=sbx_config.id, actor=self.user, limit=100 + ) env.update(env_vars) if agent_state: diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index b1c64100f..55c493beb 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -7,6 +7,7 @@ from letta.schemas.user import User as PydanticUser from letta.schemas.user import UserUpdate from letta.server.db import db_registry from letta.services.organization_manager import OrganizationManager +from letta.tracing import trace_method from letta.utils import enforce_types @@ -17,6 +18,7 @@ class UserManager: DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000" @enforce_types + @trace_method def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser: """Create the default user.""" with db_registry.session() as session: @@ -37,6 +39,7 @@ class UserManager: return user.to_pydantic() @enforce_types + @trace_method def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: """Create a new user if it doesn't already exist.""" with db_registry.session() as session: @@ -45,6 +48,7 @@ class UserManager: return new_user.to_pydantic() @enforce_types + @trace_method async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser: """Create a new user if it doesn't already exist (async version).""" async with db_registry.async_session() as session: @@ -53,6 +57,7 @@ class UserManager: return new_user.to_pydantic() @enforce_types + @trace_method def update_user(self, user_update: UserUpdate) -> PydanticUser: """Update user details.""" with db_registry.session() as session: @@ -69,6 +74,7 @@ class UserManager: return existing_user.to_pydantic() @enforce_types + @trace_method async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser: """Update user details (async version).""" async with db_registry.async_session() as session: @@ -85,6 +91,7 @@ class UserManager: return existing_user.to_pydantic() @enforce_types + @trace_method def delete_user_by_id(self, user_id: str): """Delete a user and their associated records (agents, sources, mappings).""" with db_registry.session() as session: @@ -95,6 +102,7 @@ class UserManager: session.commit() @enforce_types + @trace_method async def delete_actor_by_id_async(self, user_id: str): """Delete a user and their associated records (agents, sources, mappings) asynchronously.""" async with db_registry.async_session() as session: @@ -103,6 +111,7 @@ class UserManager: await user.hard_delete_async(session) @enforce_types + @trace_method def get_user_by_id(self, user_id: str) -> PydanticUser: """Fetch a user by ID.""" with db_registry.session() as session: @@ -110,6 +119,7 @@ class UserManager: return user.to_pydantic() @enforce_types + @trace_method async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser: """Fetch a user by ID asynchronously.""" async with db_registry.async_session() as session: @@ -117,6 +127,7 @@ class UserManager: return user.to_pydantic() @enforce_types + @trace_method def get_default_user(self) -> PydanticUser: """Fetch the default user. If it doesn't exist, create it.""" try: @@ -125,6 +136,7 @@ class UserManager: return self.create_default_user() @enforce_types + @trace_method def get_user_or_default(self, user_id: Optional[str] = None): """Fetch the user or default user.""" if not user_id: @@ -136,6 +148,7 @@ class UserManager: return self.get_default_user() @enforce_types + @trace_method async def get_default_actor_async(self) -> PydanticUser: """Fetch the default user asynchronously. If it doesn't exist, create it.""" try: @@ -145,6 +158,7 @@ class UserManager: return self.create_default_user(org_id=self.DEFAULT_ORG_ID) @enforce_types + @trace_method async def get_actor_or_default_async(self, actor_id: Optional[str] = None): """Fetch the user or default user asynchronously.""" if not actor_id: @@ -156,6 +170,7 @@ class UserManager: return await self.get_default_actor_async() @enforce_types + @trace_method def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: """List all users with optional pagination.""" with db_registry.session() as session: @@ -167,6 +182,7 @@ class UserManager: return [user.to_pydantic() for user in users] @enforce_types + @trace_method async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: """List all users with optional pagination (async version).""" async with db_registry.async_session() as session: diff --git a/pyproject.toml b/pyproject.toml index 76a30ad29..2400bdf84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.21" +version = "0.7.22" packages = [ {include = "letta"}, ] diff --git a/tests/constants.py b/tests/constants.py index e1832cbd2..fa60404c4 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1 +1,3 @@ TIMEOUT = 30 # seconds +embedding_config_dir = "tests/configs/embedding_model_configs" +llm_config_dir = "tests/configs/llm_model_configs" diff --git a/tests/helpers/client_helper.py b/tests/helpers/client_helper.py index 815102a88..99740d54b 100644 --- a/tests/helpers/client_helper.py +++ b/tests/helpers/client_helper.py @@ -1,13 +1,12 @@ import time -from typing import Union -from letta import LocalClient, RESTClient +from letta import RESTClient from letta.schemas.enums import JobStatus from letta.schemas.job import Job from letta.schemas.source import Source -def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Source, filename: str) -> Job: +def upload_file_using_client(client: RESTClient, source: Source, filename: str) -> Job: # load a file into a source (non-blocking job) upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False) print("Upload job", upload_job, upload_job.status, upload_job.metadata) diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 7774a752a..2fa78a486 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -1,33 +1,28 @@ import json import logging import uuid -from typing import Callable, List, Optional, Sequence, Union +from typing import Callable, List, Optional, Sequence from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs +from letta.schemas.block import CreateBlock from letta.schemas.tool_rule import BaseToolRule +from letta.server.server import SyncServer logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) -from letta import LocalClient, RESTClient, create_client -from letta.agent import Agent from letta.config import LettaConfig from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA from letta.embeddings import embedding_model from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError -from letta.helpers.json_helpers import json_dumps -from letta.llm_api.llm_api_tools import create -from letta.llm_api.llm_client import LLMClient from letta.local_llm.constants import INNER_THOUGHTS_KWARG -from letta.schemas.agent import AgentState +from letta.schemas.agent import AgentState, CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory -from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message +from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message from letta.utils import get_human_text, get_persona_text -from tests.helpers.utils import cleanup # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS @@ -45,7 +40,7 @@ LLM_CONFIG_PATH = "tests/configs/llm_model_configs/letta-hosted.json" def setup_agent( - client: Union[LocalClient, RESTClient], + server: SyncServer, filename: str, memory_human_str: str = get_human_text(DEFAULT_HUMAN), memory_persona_str: str = get_persona_text(DEFAULT_PERSONA), @@ -65,17 +60,27 @@ def setup_agent( config.default_embedding_config = embedding_config config.save() - memory = ChatMemory(human=memory_human_str, persona=memory_persona_str) - agent_state = client.create_agent( + request = CreateAgent( name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, - memory=memory, + memory_blocks=[ + CreateBlock( + label="human", + value=memory_human_str, + ), + CreateBlock( + label="persona", + value=memory_persona_str, + ), + ], tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools, include_base_tool_rules=include_base_tool_rules, ) + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent(request=request, actor=actor) return agent_state @@ -86,285 +91,6 @@ def setup_agent( # ====================================================================================================================== -def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner_monologue_contents: bool = True) -> ChatCompletionResponse: - """ - Checks that the first response is valid: - - 1. Contains either send_message or archival_memory_search - 2. Contains valid usage of the function - 3. Contains inner monologue - - Note: This is acting on the raw LLM response, note the usage of `create` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - - full_agent_state = client.get_agent(agent_state.id) - messages = client.server.agent_manager.get_in_context_messages(agent_id=full_agent_state.id, actor=client.user) - agent = Agent(agent_state=full_agent_state, interface=None, user=client.user) - - llm_client = LLMClient.create( - provider_type=agent_state.llm_config.model_endpoint_type, - actor=client.user, - ) - if llm_client: - response = llm_client.send_llm_request( - messages=messages, - llm_config=agent_state.llm_config, - tools=[t.json_schema for t in agent.agent_state.tools], - ) - else: - response = create( - llm_config=agent_state.llm_config, - user_id=str(uuid.UUID(int=1)), # dummy user_id - messages=messages, - functions=[t.json_schema for t in agent.agent_state.tools], - ) - - # Basic check - assert response is not None, response - assert response.choices is not None, response - assert len(response.choices) > 0, response - assert response.choices[0] is not None, response - - # Select first choice - choice = response.choices[0] - - # Ensure that the first message returns a "send_message" - validator_func = ( - lambda function_call: function_call.name == "send_message" - or function_call.name == "archival_memory_search" - or function_call.name == "core_memory_append" - ) - assert_contains_valid_function_call(choice.message, validator_func) - - # Assert that the message has an inner monologue - assert_contains_correct_inner_monologue( - choice, - agent_state.llm_config.put_inner_thoughts_in_kwargs, - validate_inner_monologue_contents=validate_inner_monologue_contents, - ) - - return response - - -def check_response_contains_keyword(filename: str, keyword="banana") -> LettaResponse: - """ - Checks that the prompted response from the LLM contains a chosen keyword - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - - keyword_message = f'This is a test to see if you can see my message. If you can see my message, please respond by calling send_message using a message that includes the word "{keyword}"' - response = client.user_message(agent_id=agent_state.id, message=keyword_message) - - # Basic checks - assert_sanity_checks(response) - - # Make sure the message was sent - assert_invoked_send_message_with_keyword(response.messages, keyword) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_uses_external_tool(filename: str) -> LettaResponse: - """ - Checks that the LLM will use external tools if instructed - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - from composio import Action - - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # Set up persona for tool usage - persona = f""" - - My name is Letta. - - I am a personal assistant who uses a tool called {tool.name} to star a desired github repo. - - Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts. - """ - - agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id]) - - response = client.user_message(agent_id=agent_state.id, message="Please star the repo with owner=letta-ai and repo=letta") - - # Basic checks - assert_sanity_checks(response) - - # Make sure the tool was called - assert_invoked_function_call(response.messages, tool.name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_recall_chat_memory(filename: str) -> LettaResponse: - """ - Checks that the LLM will recall the chat memory, specifically the human persona. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - human_name = "BananaBoy" - agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name}.") - response = client.user_message( - agent_id=agent_state.id, message="Repeat my name back to me. You should search in your human memory block." - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, human_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_archival_memory_insert(filename: str) -> LettaResponse: - """ - Checks that the LLM will execute an archival memory insert. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - secret_word = "banana" - - response = client.user_message( - agent_id=agent_state.id, - message=f"Please insert the secret word '{secret_word}' into archival memory.", - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure archival_memory_search was called - assert_invoked_function_call(response.messages, "archival_memory_insert") - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_archival_memory_retrieval(filename: str) -> LettaResponse: - """ - Checks that the LLM will execute an archival memory retrieval. - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - agent_state = setup_agent(client, filename) - secret_word = "banana" - client.insert_archival_memory(agent_state.id, f"The secret word is {secret_word}!") - - response = client.user_message( - agent_id=agent_state.id, - message="Search archival memory for the secret word. If you find it successfully, you MUST respond by using the `send_message` function with a message that includes the secret word so I know you found it.", - ) - - # Basic checks - assert_sanity_checks(response) - - # Make sure archival_memory_search was called - assert_invoked_function_call(response.messages, "archival_memory_search") - - # Make sure secret was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, secret_word) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_edit_core_memory(filename: str) -> LettaResponse: - """ - Checks that the LLM is able to edit its core memories - - Note: This is acting on the Letta response, note the usage of `user_message` - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - human_name_a = "AngryAardvark" - human_name_b = "BananaBoy" - agent_state = setup_agent(client, filename, memory_human_str=f"My name is {human_name_a}") - client.user_message(agent_id=agent_state.id, message=f"Actually, my name changed. It is now {human_name_b}") - response = client.user_message(agent_id=agent_state.id, message="Repeat my name back to me.") - - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, human_name_b) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - -def check_agent_summarize_memory_simple(filename: str) -> LettaResponse: - """ - Checks that the LLM is able to summarize its memory - """ - # Set up client - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - agent_state = setup_agent(client, filename) - - # Send a couple messages - friend_name = "Shub" - client.user_message(agent_id=agent_state.id, message="Hey, how's it going? What do you think about this whole shindig") - client.user_message(agent_id=agent_state.id, message=f"By the way, my friend's name is {friend_name}!") - client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?") - - # Summarize - agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - agent.summarize_messages_inplace() - print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n") - - response = client.user_message(agent_id=agent_state.id, message="What is my friend's name?") - # Basic checks - assert_sanity_checks(response) - - # Make sure my name was repeated back to me - assert_invoked_send_message_with_keyword(response.messages, friend_name) - - # Make sure some inner monologue is present - assert_inner_monologue_is_present_and_valid(response.messages) - - return response - - def run_embedding_endpoint(filename): # load JSON file config_data = json.load(open(filename, "r")) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index 9731ac359..2bb069828 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -2,12 +2,13 @@ import functools import time from typing import Union -from letta import LocalClient, RESTClient from letta.functions.functions import parse_source_code from letta.functions.schema_generator import generate_schema from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent from letta.schemas.tool import Tool +from letta.schemas.user import User from letta.schemas.user import User as PydanticUser +from letta.server.server import SyncServer def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4): @@ -75,12 +76,12 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4): return decorator_retry -def cleanup(client: Union[LocalClient, RESTClient], agent_uuid: str): +def cleanup(server: SyncServer, agent_uuid: str, actor: User): # Clear all agents - for agent_state in client.list_agents(): - if agent_state.name == agent_uuid: - client.delete_agent(agent_id=agent_state.id) - print(f"Deleted agent: {agent_state.name} with ID {str(agent_state.id)}") + agent_states = server.agent_manager.list_agents(name=agent_uuid, actor=actor) + + for agent_state in agent_states: + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=actor) # Utility functions diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index bc3aee7ae..9647eb1b5 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -3,16 +3,20 @@ import uuid import pytest -from letta import create_client +from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import MessageCreate +from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule +from letta.server.server import SyncServer from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, assert_sanity_checks, setup_agent, ) -from tests.helpers.utils import cleanup, retry_until_success +from tests.helpers.utils import cleanup +from tests.utils import create_tool_from_func # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS @@ -20,106 +24,175 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph")) config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json" -"""Contrived tools for this test case""" +@pytest.fixture() +def server(): + config = LettaConfig.load() + config.save() + + server = SyncServer() + return server -def first_secret_word(): - """ - Call this to retrieve the first secret word, which you will need for the second_secret_word function. - """ - return "v0iq020i0g" +@pytest.fixture(scope="function") +def first_secret_tool(server): + def first_secret_word(): + """ + Retrieves the initial secret word in a multi-step sequence. + + Returns: + str: The first secret word. + """ + return "v0iq020i0g" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=first_secret_word), actor=actor) + yield tool -def second_secret_word(prev_secret_word: str): - """ - Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def second_secret_tool(server): + def second_secret_word(prev_secret_word: str): + """ + Retrieves the second secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling first_secret_word. - """ - if prev_secret_word != "v0iq020i0g": - raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}") + Args: + prev_secret_word (str): The previously retrieved secret word. - return "4rwp2b4gxq" + Returns: + str: The second secret word. + """ + if prev_secret_word != "v0iq020i0g": + raise RuntimeError(f"Expected secret {'v0iq020i0g'}, got {prev_secret_word}") + return "4rwp2b4gxq" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=second_secret_word), actor=actor) + yield tool -def third_secret_word(prev_secret_word: str): - """ - Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def third_secret_tool(server): + def third_secret_word(prev_secret_word: str): + """ + Retrieves the third secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling second_secret_word. - """ - if prev_secret_word != "4rwp2b4gxq": - raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}') + Args: + prev_secret_word (str): The previously retrieved secret word. - return "hj2hwibbqm" + Returns: + str: The third secret word. + """ + if prev_secret_word != "4rwp2b4gxq": + raise RuntimeError(f'Expected secret "4rwp2b4gxq", got {prev_secret_word}') + return "hj2hwibbqm" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=third_secret_word), actor=actor) + yield tool -def fourth_secret_word(prev_secret_word: str): - """ - Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error. +@pytest.fixture(scope="function") +def fourth_secret_tool(server): + def fourth_secret_word(prev_secret_word: str): + """ + Retrieves the final secret word. - Args: - prev_secret_word (str): The secret word retrieved from calling third_secret_word. - """ - if prev_secret_word != "hj2hwibbqm": - raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}") + Args: + prev_secret_word (str): The previously retrieved secret word. - return "banana" + Returns: + str: The final secret word. + """ + if prev_secret_word != "hj2hwibbqm": + raise RuntimeError(f"Expected secret {'hj2hwibbqm'}, got {prev_secret_word}") + return "banana" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=fourth_secret_word), actor=actor) + yield tool -def flip_coin(): - """ - Call this to retrieve the password to the secret word, which you will need to output in a send_message later. - If it returns an empty string, try flipping again! +@pytest.fixture(scope="function") +def flip_coin_tool(server): + def flip_coin(): + """ + Simulates a coin flip with a chance to return a secret word. - Returns: - str: The password or an empty string - """ - import random + Returns: + str: A secret word or an empty string. + """ + import random - # Flip a coin with 50% chance - if random.random() < 0.5: - return "" - return "hj2hwibbqm" + return "" if random.random() < 0.5 else "hj2hwibbqm" + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=flip_coin), actor=actor) + yield tool -def can_play_game(): - """ - Call this to start the tool chain. - """ - import random +@pytest.fixture(scope="function") +def can_play_game_tool(server): + def can_play_game(): + """ + Determines whether a game can be played. - return random.random() < 0.5 + Returns: + bool: True if allowed to play, False otherwise. + """ + import random + + return random.random() < 0.5 + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=can_play_game), actor=actor) + yield tool -def return_none(): - """ - Really simple function - """ - return None +@pytest.fixture(scope="function") +def return_none_tool(server): + def return_none(): + """ + Always returns None. + + Returns: + None + """ + return None + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=return_none), actor=actor) + yield tool -def auto_error(): - """ - If you call this function, it will throw an error automatically. - """ - raise RuntimeError("This should never be called.") +@pytest.fixture(scope="function") +def auto_error_tool(server): + def auto_error(): + """ + Always raises an error when called. + + Raises: + RuntimeError: Always triggered. + """ + raise RuntimeError("This should never be called.") + + actor = server.user_manager.get_user_or_default() + tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=auto_error), actor=actor) + yield tool + + +@pytest.fixture +def default_user(server): + yield server.user_manager.get_user_or_default() @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_single_path_agent_tool_call_graph(disable_e2b_api_key): - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) +def test_single_path_agent_tool_call_graph( + server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user +): + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) # Add tools - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) - t3 = client.create_or_update_tool(third_secret_word) - t4 = client.create_or_update_tool(fourth_secret_word) - t_err = client.create_or_update_tool(auto_error) - tools = [t1, t2, t3, t4, t_err] + tools = [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool] # Make tool rules tool_rules = [ @@ -132,8 +205,18 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): ] # Make agent state - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?") + agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")], + ) + messages = [message for step_messages in usage_stats.steps_messages for message in step_messages] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) # Make checks assert_sanity_checks(response) @@ -145,7 +228,7 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): assert_invoked_function_call(response.messages, "fourth_secret_word") # Check ordering of tool calls - tool_names = [t.name for t in [t1, t2, t3, t4]] + tool_names = [t.name for t in [first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool]] tool_names += ["send_message"] for m in response.messages: if isinstance(m, ToolCallMessage): @@ -159,171 +242,281 @@ def test_single_path_agent_tool_call_graph(disable_e2b_api_key): assert_invoked_send_message_with_keyword(response.messages, "banana") print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) -def test_check_tool_rules_with_different_models(disable_e2b_api_key): - """Test that tool rules are properly checked for different model configurations.""" - client = create_client() - - config_files = [ +@pytest.mark.timeout(60) +@pytest.mark.parametrize( + "config_file", + [ "tests/configs/llm_model_configs/claude-3-5-sonnet.json", "tests/configs/llm_model_configs/openai-gpt-3.5-turbo.json", "tests/configs/llm_model_configs/openai-gpt-4o.json", - ] + ], +) +@pytest.mark.parametrize("init_tools_case", ["single", "multiple"]) +def test_check_tool_rules_with_different_models_parametrized( + server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, default_user, config_file, init_tools_case +): + """Test that tool rules are properly validated across model configurations and init tool scenarios.""" + agent_uuid = str(uuid.uuid4()) - # Create two test tools - t1_name = "first_secret_word" - t2_name = "second_secret_word" - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) - tool_rules = [InitToolRule(tool_name=t1_name), InitToolRule(tool_name=t2_name)] - tools = [t1, t2] + if init_tools_case == "multiple": + tools = [first_secret_tool, second_secret_tool] + tool_rules = [ + InitToolRule(tool_name=first_secret_tool.name), + InitToolRule(tool_name=second_secret_tool.name), + ] + else: # "single" + tools = [third_secret_tool] + tool_rules = [InitToolRule(tool_name=third_secret_tool.name)] - for config_file in config_files: - # Setup tools - agent_uuid = str(uuid.uuid4()) - - if "gpt-4o" in config_file: - # Structured output model (should work with multiple init tools) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - assert agent_state is not None - else: - # Non-structured output model (should raise error with multiple init tools) - with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"): - setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Cleanup - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tool rule with single initial tool - t3_name = "third_secret_word" - t3 = client.create_or_update_tool(third_secret_word) - tool_rules = [InitToolRule(tool_name=t3_name)] - tools = [t3] - for config_file in config_files: - agent_uuid = str(uuid.uuid4()) - - # Structured output model (should work with single init tool) - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + if "gpt-4o" in config_file or init_tools_case == "single": + # Should succeed + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) assert agent_state is not None + else: + # Non-structured model with multiple init tools should fail + with pytest.raises(ValueError, match="Multiple initial tools are not supported for non-structured models"): + setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) -def test_claude_initial_tool_rule_enforced(disable_e2b_api_key): - """Test that the initial tool rule is enforced for the first message.""" - client = create_client() - - # Create tool rules that require tool_a to be called first - t1_name = "first_secret_word" - t2_name = "second_secret_word" - t1 = client.create_or_update_tool(first_secret_word) - t2 = client.create_or_update_tool(second_secret_word) +@pytest.mark.timeout(180) +def test_claude_initial_tool_rule_enforced( + server, + disable_e2b_api_key, + first_secret_tool, + second_secret_tool, + default_user, +): + """Test that the initial tool rule is enforced for the first message using Claude model.""" tool_rules = [ - InitToolRule(tool_name=t1_name), - ChildToolRule(tool_name=t1_name, children=[t2_name]), - TerminalToolRule(tool_name=t2_name), + InitToolRule(tool_name=first_secret_tool.name), + ChildToolRule(tool_name=first_secret_tool.name, children=[second_secret_tool.name]), + TerminalToolRule(tool_name=second_secret_tool.name), ] - tools = [t1, t2] - - # Make agent state + tools = [first_secret_tool, second_secret_tool] anthropic_config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + for i in range(3): agent_uuid = str(uuid.uuid4()) agent_state = setup_agent( - client, anthropic_config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules + server, + anthropic_config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, ) - response = client.user_message(agent_id=agent_state.id, message="What is the second secret word?") + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the second secret word?")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) assert_sanity_checks(response) - messages = response.messages - assert_invoked_function_call(messages, "first_secret_word") - assert_invoked_function_call(messages, "second_secret_word") + # Check that the expected tools were invoked + assert_invoked_function_call(response.messages, "first_secret_word") + assert_invoked_function_call(response.messages, "second_secret_word") - tool_names = [t.name for t in [t1, t2]] - tool_names += ["send_message"] - for m in messages: + tool_names = [t.name for t in [first_secret_tool, second_secret_tool]] + ["send_message"] + for m in response.messages: if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one assert m.tool_call.name == tool_names[0] - - # Pop out first one tool_names = tool_names[1:] print(f"Passed iteration {i}") - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) - # Implement exponential backoff with initial time of 10 seconds + # Exponential backoff if i < 2: backoff_time = 10 * (2**i) time.sleep(backoff_time) -@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) +@pytest.mark.timeout(60) +@pytest.mark.parametrize( + "config_file", + [ + "tests/configs/llm_model_configs/claude-3-5-sonnet.json", + "tests/configs/llm_model_configs/openai-gpt-4o.json", + ], +) +def test_agent_no_structured_output_with_one_child_tool_parametrized( + server, + disable_e2b_api_key, + default_user, + config_file, +): + """Test that agent correctly calls tool chains with unstructured output under various model configs.""" + send_message = server.tool_manager.get_tool_by_name(tool_name="send_message", actor=default_user) + archival_memory_search = server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=default_user) + archival_memory_insert = server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=default_user) - send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user) - archival_memory_search = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_search", actor=client.user) - archival_memory_insert = client.server.tool_manager.get_tool_by_name(tool_name="archival_memory_insert", actor=client.user) + tools = [send_message, archival_memory_search, archival_memory_insert] - # Make tool rules tool_rules = [ InitToolRule(tool_name="archival_memory_search"), ChildToolRule(tool_name="archival_memory_search", children=["archival_memory_insert"]), ChildToolRule(tool_name="archival_memory_insert", children=["send_message"]), TerminalToolRule(tool_name="send_message"), ] - tools = [send_message, archival_memory_search, archival_memory_insert] - config_files = [ - "tests/configs/llm_model_configs/claude-3-5-sonnet.json", - "tests/configs/llm_model_configs/openai-gpt-4o.json", + max_retries = 3 + last_error = None + agent_uuid = str(uuid.uuid4()) + + for attempt in range(max_retries): + try: + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="hi. run archival memory search")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [] + for m in messages: + letta_messages += m.to_letta_messages() + + response = LettaResponse(messages=letta_messages, usage=usage_stats) + + # Run assertions + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "archival_memory_search") + assert_invoked_function_call(response.messages, "archival_memory_insert") + assert_invoked_function_call(response.messages, "send_message") + + tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] + for m in response.messages: + if isinstance(m, ToolCallMessage): + assert m.tool_call.name == tool_names[0] + tool_names = tool_names[1:] + + print(f"[{config_file}] Got successful response:\n\n{response}") + break # success + + except AssertionError as e: + last_error = e + print(f"[{config_file}] Attempt {attempt + 1} failed") + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + if last_error: + raise last_error + + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + +@pytest.mark.timeout(30) +@pytest.mark.parametrize("include_base_tools", [False, True]) +def test_init_tool_rule_always_fails( + server, + disable_e2b_api_key, + auto_error_tool, + default_user, + include_base_tools, +): + """Test behavior when InitToolRule invokes a tool that always fails.""" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + agent_uuid = str(uuid.uuid4()) + + tool_rule = InitToolRule(tool_name=auto_error_tool.name) + agent_state = setup_agent( + server, + config_file, + agent_uuid=agent_uuid, + tool_ids=[auto_error_tool.id], + tool_rules=[tool_rule], + include_base_tools=include_base_tools, + ) + + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="blah blah blah")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [msg for m in messages for msg in m.to_letta_messages()] + response = LettaResponse(messages=letta_messages, usage=usage_stats) + + assert_invoked_function_call(response.messages, auto_error_tool.name) + + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) + + +def test_continue_tool_rule(server, default_user): + """Test the continue tool rule by forcing send_message to loop before ending with core_memory_append.""" + config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + agent_uuid = str(uuid.uuid4()) + + tool_ids = [ + server.tool_manager.get_tool_by_name("send_message", actor=default_user).id, + server.tool_manager.get_tool_by_name("core_memory_append", actor=default_user).id, ] - for config in config_files: - max_retries = 3 - last_error = None + tool_rules = [ + ContinueToolRule(tool_name="send_message"), + TerminalToolRule(tool_name="core_memory_append"), + ] - for attempt in range(max_retries): - try: - agent_state = setup_agent(client, config, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="hi. run archival memory search") + agent_state = setup_agent( + server, + config_file, + agent_uuid, + tool_ids=tool_ids, + tool_rules=tool_rules, + include_base_tools=False, + include_base_tool_rules=False, + ) - # Make checks - assert_sanity_checks(response) + usage_stats = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")], + ) + messages = [m for step in usage_stats.steps_messages for m in step] + letta_messages = [msg for m in messages for msg in m.to_letta_messages()] + response = LettaResponse(messages=letta_messages, usage=usage_stats) - # Assert the tools were called - assert_invoked_function_call(response.messages, "archival_memory_search") - assert_invoked_function_call(response.messages, "archival_memory_insert") - assert_invoked_function_call(response.messages, "send_message") + assert_invoked_function_call(response.messages, "send_message") + assert_invoked_function_call(response.messages, "core_memory_append") - # Check ordering of tool calls - tool_names = [t.name for t in [archival_memory_search, archival_memory_insert, send_message]] - for m in response.messages: - if isinstance(m, ToolCallMessage): - # Check that it's equal to the first one - assert m.tool_call.name == tool_names[0] + # Check order + send_idx = next(i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message") + append_idx = next( + i for i, m in enumerate(response.messages) if isinstance(m, ToolCallMessage) and m.tool_call.name == "core_memory_append" + ) + assert send_idx < append_idx, "send_message should occur before core_memory_append" - # Pop out first one - tool_names = tool_names[1:] - - print(f"Got successful response from client: \n\n{response}") - break # Test passed, exit retry loop - - except AssertionError as e: - last_error = e - print(f"Attempt {attempt + 1} failed, retrying..." if attempt < max_retries - 1 else f"All {max_retries} attempts failed") - cleanup(client=client, agent_uuid=agent_uuid) - continue - - if last_error and attempt == max_retries - 1: - raise last_error # Re-raise the last error if all retries failed - - cleanup(client=client, agent_uuid=agent_uuid) + cleanup(server=server, agent_uuid=agent_uuid, actor=default_user) # @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely @@ -342,7 +535,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # reveal_secret_word # """ # -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # coin_flip_name = "flip_coin" @@ -406,7 +599,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # any tool... <-- When output doesn't match mapping, agent can call any tool # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools - we'll make several available to the agent @@ -467,7 +660,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # fourth_secret_word <-- Should remember coin flip result after reload # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools @@ -522,7 +715,7 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # v # fourth_secret_word # """ -# client = create_client() +# # cleanup(client=client, agent_uuid=agent_uuid) # # # Create tools @@ -563,165 +756,3 @@ def test_agent_no_structured_output_with_one_child_tool(disable_e2b_api_key): # assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" # # cleanup(client, agent_uuid=agent_state.id) - - -def test_init_tool_rule_always_fails_one_tool(): - """ - Test an init tool rule that always fails when called. The agent has only one tool available. - - Once that tool fails and the agent removes that tool, the agent should have 0 tools available. - - This means that the agent should return from `step` early. - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - bad_tool = client.create_or_update_tool(auto_error) - - # Create tool rule: InitToolRule - tool_rule = InitToolRule( - tool_name=bad_tool.name, - ) - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls - assert_invoked_function_call(response.messages, bad_tool.name) - - -def test_init_tool_rule_always_fails_multiple_tools(): - """ - Test an init tool rule that always fails when called. The agent has only 1+ tools available. - Once that tool fails and the agent removes that tool, the agent should have other tools available. - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - bad_tool = client.create_or_update_tool(auto_error) - - # Create tool rule: InitToolRule - tool_rule = InitToolRule( - tool_name=bad_tool.name, - ) - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls - assert_invoked_function_call(response.messages, bad_tool.name) - - -def test_continue_tool_rule(): - """Test the continue tool rule by forcing the send_message tool to continue""" - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - continue_tool_rule = ContinueToolRule( - tool_name="send_message", - ) - terminal_tool_rule = TerminalToolRule( - tool_name="core_memory_append", - ) - rules = [continue_tool_rule, terminal_tool_rule] - - core_memory_append_tool = client.get_tool_id("core_memory_append") - send_message_tool = client.get_tool_id("send_message") - - # Set up agent with the tool rule - claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent( - client, - claude_config, - agent_uuid, - tool_rules=rules, - tool_ids=[core_memory_append_tool, send_message_tool], - include_base_tools=False, - include_base_tool_rules=False, - ) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="blah blah blah") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 1 - assert_invoked_function_call(response.messages, "send_message") - assert_invoked_function_call(response.messages, "core_memory_append") - - # ensure send_message called before core_memory_append - send_message_call_index = None - core_memory_append_call_index = None - for i, call in enumerate(tool_calls): - if call.tool_call.name == "send_message": - send_message_call_index = i - if call.tool_call.name == "core_memory_append": - core_memory_append_call_index = i - assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append" - - -@pytest.mark.timeout(60) -@retry_until_success(max_attempts=3, sleep_time_seconds=2) -def test_max_count_per_step_tool_rule_integration(disable_e2b_api_key): - """ - Test an agent with MaxCountPerStepToolRule to ensure a tool can only be called a limited number of times. - - Tool Flow: - repeatable_tool (max 2 times) - | - v - send_message - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - repeatable_tool_name = "first_secret_word" - final_tool_name = "send_message" - - repeatable_tool = client.create_or_update_tool(first_secret_word) - send_message_tool = client.get_tool(client.get_tool_id(final_tool_name)) # Assume send_message is a default tool - - # Define tool rules - tool_rules = [ - InitToolRule(tool_name=repeatable_tool_name), - MaxCountPerStepToolRule(tool_name=repeatable_tool_name, max_count_limit=2), - TerminalToolRule(tool_name=final_tool_name), - ] - - tools = [repeatable_tool, send_message_tool] - - # Setup agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Start conversation - response = client.user_message( - agent_id=agent_state.id, message=f"Keep calling {repeatable_tool_name} nonstop without calling ANY other tool." - ) - - # Make checks - assert_sanity_checks(response) - - # Ensure the repeatable tool is only called twice - count = sum(1 for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == repeatable_tool_name) - assert count == 2, f"Expected 'first_secret_word' to be called exactly 2 times, but got {count}" - - # Ensure send_message was eventually called - assert_invoked_function_call(response.messages, final_tool_name) - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index b85728db9..e6d54207a 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -1,3 +1,4 @@ +import asyncio import secrets import string import uuid @@ -7,17 +8,16 @@ from unittest.mock import patch import pytest from sqlalchemy import delete -from letta import create_client +from letta.config import LettaConfig from letta.functions.function_sets.base import core_memory_append, core_memory_replace from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate from letta.schemas.user import User +from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager @@ -33,6 +33,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) # Fixtures +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" @@ -192,12 +207,26 @@ def external_codebase_tool(test_user): @pytest.fixture -def agent_state(): - client = create_client() - agent_state = client.create_agent( - memory=ChatMemory(persona="This is the persona", human="My name is Chad"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), +def agent_state(server): + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent( + CreateAgent( + memory_blocks=[ + CreateBlock( + label="human", + value="username: sarah", + ), + CreateBlock( + label="persona", + value="This is the persona", + ), + ], + include_base_tools=True, + model="openai/gpt-4o-mini", + tags=["test_agents"], + embedding="letta/letta-free", + ), + actor=actor, ) agent_state.tool_rules = [] yield agent_state @@ -248,12 +277,20 @@ def core_memory_tools(test_user): yield tools +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + # Local sandbox tests @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user): +async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, test_user, event_loop): args = {"x": 10, "y": 5} # Mock and assert correct pathway was invoked @@ -270,7 +307,7 @@ async def test_local_sandbox_default(disable_e2b_api_key, add_integers_tool, tes @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state): +async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memory_tool, test_user, agent_state, event_loop): args = {} sandbox = AsyncToolSandboxLocal(clear_core_memory_tool.name, args, user=test_user) result = await sandbox.run(agent_state=agent_state) @@ -282,7 +319,7 @@ async def test_local_sandbox_stateful_tool(disable_e2b_api_key, clear_core_memor @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user): +async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_user, event_loop): sandbox = AsyncToolSandboxLocal(list_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.func_return) == 5 @@ -290,7 +327,7 @@ async def test_local_sandbox_with_list_rv(disable_e2b_api_key, list_tool, test_u @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user): +async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user, event_loop): manager = SandboxConfigManager() sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") config_create = SandboxConfigCreate(config=LocalSandboxConfig(sandbox_dir=sandbox_dir).model_dump()) @@ -309,7 +346,7 @@ async def test_local_sandbox_env(disable_e2b_api_key, get_env_tool, test_user): @pytest.mark.asyncio @pytest.mark.local_sandbox -async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user): +async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, agent_state, test_user, event_loop): manager = SandboxConfigManager() key = "secret_word" sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox") @@ -331,7 +368,7 @@ async def test_local_sandbox_per_agent_env(disable_e2b_api_key, get_env_tool, ag @pytest.mark.asyncio @pytest.mark.local_sandbox async def test_local_sandbox_external_codebase_with_venv( - disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user + disable_e2b_api_key, custom_test_sandbox_config, external_codebase_tool, test_user, event_loop ): args = {"percentage": 10} sandbox = AsyncToolSandboxLocal(external_codebase_tool.name, args, user=test_user) @@ -343,7 +380,7 @@ async def test_local_sandbox_external_codebase_with_venv( @pytest.mark.asyncio @pytest.mark.local_sandbox async def test_local_sandbox_with_venv_and_warnings_does_not_error( - disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user + disable_e2b_api_key, custom_test_sandbox_config, get_warning_tool, test_user, event_loop ): sandbox = AsyncToolSandboxLocal(get_warning_tool.name, {}, user=test_user) result = await sandbox.run() @@ -352,7 +389,7 @@ async def test_local_sandbox_with_venv_and_warnings_does_not_error( @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user): +async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_sandbox_config, always_err_tool, test_user, event_loop): sandbox = AsyncToolSandboxLocal(always_err_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.stdout) != 0 @@ -363,7 +400,7 @@ async def test_local_sandbox_with_venv_errors(disable_e2b_api_key, custom_test_s @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user): +async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, cowsay_tool, test_user, event_loop): manager = SandboxConfigManager() config_create = SandboxConfigCreate( config=LocalSandboxConfig(use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump() @@ -383,7 +420,7 @@ async def test_local_sandbox_with_venv_pip_installs_basic(disable_e2b_api_key, c @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user): +async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_key, cowsay_tool, test_user, event_loop): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=LocalSandboxConfig(use_venv=True).model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -414,7 +451,7 @@ async def test_local_sandbox_with_venv_pip_installs_with_update(disable_e2b_api_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user): +async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user, event_loop): args = {"x": 10, "y": 5} # Mock and assert correct pathway was invoked @@ -431,7 +468,7 @@ async def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user): +async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user, event_loop): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -451,7 +488,7 @@ async def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_ @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state): +async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool, test_user, agent_state, event_loop): sandbox = AsyncToolSandboxE2B(clear_core_memory_tool.name, {}, user=test_user) result = await sandbox.run(agent_state=agent_state) assert result.agent_state.memory.get_block("human").value == "" @@ -461,7 +498,7 @@ async def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user): +async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user, event_loop): manager = SandboxConfigManager() config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump()) config = manager.create_or_update_sandbox_config(config_create, test_user) @@ -485,7 +522,7 @@ async def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user): +async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user, event_loop): manager = SandboxConfigManager() key = "secret_word" wrong_val = "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(20)) @@ -509,7 +546,7 @@ async def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, age @pytest.mark.asyncio @pytest.mark.e2b_sandbox -async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user): +async def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user, event_loop): sandbox = AsyncToolSandboxE2B(list_tool.name, {}, user=test_user) result = await sandbox.run() assert len(result.func_return) == 5 diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 406d06cde..8a07b9a46 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -185,7 +185,7 @@ async def create_test_llm_batch_job_async(server, batch_response, default_user): ) -def create_test_batch_item(server, batch_id, agent_id, default_user): +async def create_test_batch_item(server, batch_id, agent_id, default_user): """Create a test batch item for the given batch and agent.""" dummy_llm_config = LLMConfig( model="claude-3-7-sonnet-latest", @@ -201,7 +201,7 @@ def create_test_batch_item(server, batch_id, agent_id, default_user): step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")]) ) - return server.batch_manager.create_llm_batch_item( + return await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch_id, agent_id=agent_id, llm_config=dummy_llm_config, @@ -289,9 +289,9 @@ async def test_polling_mixed_batch_jobs(default_user, server): job_b = await create_test_llm_batch_job_async(server, batch_b_resp, default_user) # --- Step 3: Create batch items --- - item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) - item_b = create_test_batch_item(server, job_b.id, agent_b.id, default_user) - item_c = create_test_batch_item(server, job_b.id, agent_c.id, default_user) + item_a = await create_test_batch_item(server, job_a.id, agent_a.id, default_user) + item_b = await create_test_batch_item(server, job_b.id, agent_b.id, default_user) + item_c = await create_test_batch_item(server, job_b.id, agent_c.id, default_user) # --- Step 4: Mock the Anthropic client --- mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b.id, agent_c.id) @@ -316,17 +316,17 @@ async def test_polling_mixed_batch_jobs(default_user, server): # --- Step 7: Verify batch item status updates --- # Item A should remain unchanged - updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) + updated_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user) assert updated_item_a.request_status == JobStatus.created assert updated_item_a.batch_request_result is None # Item B should be marked as completed with a successful result - updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) + updated_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user) assert updated_item_b.request_status == JobStatus.completed assert updated_item_b.batch_request_result is not None # Item C should be marked as failed with an error result - updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + updated_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user) assert updated_item_c.request_status == JobStatus.failed assert updated_item_c.batch_request_result is not None @@ -352,9 +352,9 @@ async def test_polling_mixed_batch_jobs(default_user, server): # Refresh all objects final_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) final_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user) - final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) - final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) - final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + final_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user) + final_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user) + final_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user) # Job A should still be polling (last_polled_at should update) assert final_job_a.status == JobStatus.running diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py deleted file mode 100644 index 0b9df3893..000000000 --- a/tests/integration_test_experimental.py +++ /dev/null @@ -1,579 +0,0 @@ -import os -import threading -import time -import uuid - -import httpx -import openai -import pytest -from dotenv import load_dotenv -from letta_client import CreateBlock, Letta, MessageCreate, TextContent -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk - -from letta.agents.letta_agent import LettaAgent -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import MessageStreamStatus -from letta.schemas.letta_message_content import TextContent as LettaTextContent -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import MessageCreate as LettaMessageCreate -from letta.schemas.tool import ToolCreate -from letta.schemas.usage import LettaUsageStatistics -from letta.services.agent_manager import AgentManager -from letta.services.block_manager import BlockManager -from letta.services.message_manager import MessageManager -from letta.services.passage_manager import PassageManager -from letta.services.tool_manager import ToolManager -from letta.services.user_manager import UserManager -from letta.settings import model_settings, settings - -# --- Server Management --- # - - -def _run_server(): - """Starts the Letta server in a background thread.""" - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - -@pytest.fixture(scope="session") -def server_url(): - """Ensures a server is running and returns its base URL.""" - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - time.sleep(5) # Allow server startup time - - return url - - -# --- Client Setup --- # - - -@pytest.fixture(scope="session") -def client(server_url): - """Creates a REST client for testing.""" - client = Letta(base_url=server_url) - # llm_config = LLMConfig( - # model="claude-3-7-sonnet-latest", - # model_endpoint_type="anthropic", - # model_endpoint="https://api.anthropic.com/v1", - # context_window=32000, - # handle=f"anthropic/claude-3-7-sonnet-latest", - # put_inner_thoughts_in_kwargs=True, - # max_tokens=4096, - # ) - # - # client = create_client(base_url=server_url, token=None) - # client.set_default_llm_config(llm_config) - # client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client - - -@pytest.fixture(scope="function") -def roll_dice_tool(client): - def roll_dice(): - """ - Rolls a 6 sided die. - - Returns: - str: The roll result. - """ - import time - - time.sleep(1) - return "Rolled a 10!" - - # tool = client.create_or_update_tool(func=roll_dice) - tool = client.tools.upsert_from_function(func=roll_dice) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def weather_tool(client): - def get_weather(location: str) -> str: - """ - Fetches the current weather for a given location. - - Parameters: - location (str): The location to get the weather for. - - Returns: - str: A formatted string describing the weather in the given location. - - Raises: - RuntimeError: If the request to fetch weather data fails. - """ - import requests - - url = f"https://wttr.in/{location}?format=%C+%t" - - response = requests.get(url) - if response.status_code == 200: - weather_data = response.text - return f"The weather in {location} is {weather_data}." - else: - raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - - # tool = client.create_or_update_tool(func=get_weather) - tool = client.tools.upsert_from_function(func=get_weather) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def rethink_tool(client): - def rethink_memory(agent_state: "AgentState", new_memory: str, target_block_label: str) -> str: # type: ignore - """ - Re-evaluate the memory in block_name, integrating new and updated facts. - Replace outdated information with the most likely truths, avoiding redundancy with original memories. - Ensure consistency with other memory blocks. - - Args: - new_memory (str): The new memory with information integrated from the memory block. If there is no new information, then this should be the same as the content in the source block. - target_block_label (str): The name of the block to write to. - Returns: - str: None is always returned as this function does not produce a response. - """ - agent_state.memory.update_block_value(label=target_block_label, value=new_memory) - return None - - tool = client.tools.upsert_from_function(func=rethink_memory) - # Yield the created tool - yield tool - - -@pytest.fixture(scope="function") -def composio_gmail_get_profile_tool(default_user): - tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") - tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) - yield tool - - -@pytest.fixture(scope="function") -def agent_state(client, roll_dice_tool, weather_tool, rethink_tool): - """Creates an agent and ensures cleanup after tests.""" - # llm_config = LLMConfig( - # model="claude-3-7-sonnet-latest", - # model_endpoint_type="anthropic", - # model_endpoint="https://api.anthropic.com/v1", - # context_window=32000, - # handle=f"anthropic/claude-3-7-sonnet-latest", - # put_inner_thoughts_in_kwargs=True, - # max_tokens=4096, - # ) - agent_state = client.agents.create( - name=f"test_compl_{str(uuid.uuid4())[5:]}", - tool_ids=[roll_dice_tool.id, weather_tool.id, rethink_tool.id], - include_base_tools=True, - memory_blocks=[ - { - "label": "human", - "value": "Name: Matt", - }, - { - "label": "persona", - "value": "Friendly agent", - }, - ], - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ) - yield agent_state - client.agents.delete(agent_state.id) - - -@pytest.fixture(scope="function") -def openai_client(client, roll_dice_tool, weather_tool): - """Creates an agent and ensures cleanup after tests.""" - client = openai.AsyncClient( - api_key=model_settings.anthropic_api_key, - base_url="https://api.anthropic.com/v1/", - max_retries=0, - http_client=httpx.AsyncClient( - timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0), - follow_redirects=True, - limits=httpx.Limits( - max_connections=50, - max_keepalive_connections=50, - keepalive_expiry=120, - ), - ), - ) - yield client - - -# --- Helper Functions --- # - - -def _assert_valid_chunk(chunk, idx, chunks): - """Validates the structure of each streaming chunk.""" - if isinstance(chunk, ChatCompletionChunk): - assert chunk.choices, "Each ChatCompletionChunk should have at least one choice." - - elif isinstance(chunk, LettaUsageStatistics): - assert chunk.completion_tokens > 0, "Completion tokens must be > 0." - assert chunk.prompt_tokens > 0, "Prompt tokens must be > 0." - assert chunk.total_tokens > 0, "Total tokens must be > 0." - assert chunk.step_count == 1, "Step count must be 1." - - elif isinstance(chunk, MessageStreamStatus): - assert chunk == MessageStreamStatus.done, "Stream should end with 'done' status." - assert idx == len(chunks) - 1, "The last chunk must be 'done'." - - else: - pytest.fail(f"Unexpected chunk type: {chunk}") - - -# --- Test Cases --- # - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["What is the weather today in SF?"]) -async def test_new_agent_loop(disable_e2b_api_key, openai_client, agent_state, message): - actor = UserManager().get_user_or_default(user_id="asf") - agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])]) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Use your rethink tool to rethink the human memory considering Matt likes chicken."]) -async def test_rethink_tool(disable_e2b_api_key, openai_client, agent_state, message): - actor = UserManager().get_user_or_default(user_id="asf") - agent = LettaAgent( - agent_id=agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - assert "chicken" not in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value - response = await agent.step([LettaMessageCreate(role="user", content=[LettaTextContent(text=message)])]) - assert "chicken" in AgentManager().get_agent_by_id(agent_state.id, actor).memory.get_block("human").value.lower() - - -@pytest.mark.asyncio -async def test_vertex_send_message_structured_outputs(disable_e2b_api_key, client): - original_experimental_key = settings.use_vertex_structured_outputs_experimental - settings.use_vertex_structured_outputs_experimental = True - try: - actor = UserManager().get_user_or_default(user_id="asf") - - stale_agents = AgentManager().list_agents(actor=actor, limit=300) - for agent in stale_agents: - AgentManager().delete_agent(agent_id=agent.id, actor=actor) - - manager_agent_state = client.agents.create( - name=f"manager", - include_base_tools=False, # change this to True to repro MALFORMED FUNCTION CALL error - tools=["send_message"], - tags=["manager"], - model="google_vertex/gemini-2.5-flash-preview-04-17", - embedding="letta/letta-free", - ) - manager_agent = LettaAgent( - agent_id=manager_agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - response = await manager_agent.step( - [ - LettaMessageCreate( - role="user", - content=[ - LettaTextContent(text=("Check the weather in Seattle.")), - ], - ), - ] - ) - assert len(response.messages) == 3 - assert response.messages[0].message_type == "user_message" - # Shouldn't this have reasoning message? - # assert response.messages[1].message_type == "reasoning_message" - assert response.messages[1].message_type == "assistant_message" - assert response.messages[2].message_type == "tool_return_message" - finally: - settings.use_vertex_structured_outputs_experimental = original_experimental_key - - -@pytest.mark.asyncio -async def test_multi_agent_broadcast(disable_e2b_api_key, client, openai_client, weather_tool): - actor = UserManager().get_user_or_default(user_id="asf") - - stale_agents = AgentManager().list_agents(actor=actor, limit=300) - for agent in stale_agents: - AgentManager().delete_agent(agent_id=agent.id, actor=actor) - - manager_agent_state = client.agents.create( - name=f"manager", - include_base_tools=True, - include_multi_agent_tools=True, - tags=["manager"], - model="openai/gpt-4o", - embedding="letta/letta-free", - ) - manager_agent = LettaAgent( - agent_id=manager_agent_state.id, - message_manager=MessageManager(), - agent_manager=AgentManager(), - block_manager=BlockManager(), - passage_manager=PassageManager(), - actor=actor, - ) - - tag = "subagent" - workers = [] - for idx in range(30): - workers.append( - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - tags=[tag], - tool_ids=[weather_tool.id], - model="openai/gpt-4o", - embedding="letta/letta-free", - ), - ) - response = await manager_agent.step( - [ - LettaMessageCreate( - role="user", - content=[ - LettaTextContent( - text=( - "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with " - "tag 'subagent' asking them to check the weather in Seattle." - ) - ), - ], - ), - ] - ) - - -def test_multi_agent_broadcast_client(client: Letta, weather_tool): - # delete any existing worker agents - workers = client.agents.list(tags=["worker"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - - # create worker agents - num_workers = 10 - for idx in range(num_workers): - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - tags=["worker"], - tool_ids=[weather_tool.id], - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - - # create supervisor agent - supervisor = client.agents.create( - name="supervisor", - include_base_tools=True, - include_multi_agent_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - tags=["supervisor"], - ) - - # send a message to the supervisor - import time - - start = time.perf_counter() - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - MessageCreate( - role="user", - content=[ - TextContent( - text="Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag 'worker' asking them to check the weather in Seattle." - ) - ], - ) - ], - ) - end = time.perf_counter() - print("TIME ELAPSED: " + str(end - start)) - for message in response.messages: - print(message) - - -def test_call_weather(client: Letta, weather_tool): - # delete any existing worker agents - workers = client.agents.list(tags=["worker", "supervisor"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - - # create supervisor agent - supervisor = client.agents.create( - name="supervisor", - include_base_tools=True, - tool_ids=[weather_tool.id], - model="openai/gpt-4o", - embedding="letta/letta-free", - tags=["supervisor"], - ) - - # send a message to the supervisor - import time - - start = time.perf_counter() - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - { - "role": "user", - "content": "What's the weather like in Seattle?", - } - ], - ) - end = time.perf_counter() - print("TIME ELAPSED: " + str(end - start)) - for message in response.messages: - print(message) - - -def run_supervisor_worker_group(client: Letta, weather_tool, group_id: str): - # Delete any existing agents for this group (if rerunning) - existing_workers = client.agents.list(tags=[f"worker-{group_id}"]) - for worker in existing_workers: - client.agents.delete(agent_id=worker.id) - - # Create worker agents - num_workers = 50 - for idx in range(num_workers): - client.agents.create( - name=f"worker_{group_id}_{idx}", - include_base_tools=True, - tags=[f"worker-{group_id}"], - tool_ids=[weather_tool.id], - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - - # Create supervisor agent - supervisor = client.agents.create( - name=f"supervisor_{group_id}", - include_base_tools=True, - include_multi_agent_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - tags=[f"supervisor-{group_id}"], - ) - - # Send message to supervisor to broadcast to workers - response = client.agents.messages.create( - agent_id=supervisor.id, - messages=[ - { - "role": "user", - "content": "Use the `send_message_to_agents_matching_tags` tool to send a message to agents with tag " - f"'worker-{group_id}' asking them to check the weather in Seattle.", - } - ], - ) - - return response - - -def test_anthropic_streaming(client: Letta): - agent_name = "anthropic_tester" - - existing_agents = client.agents.list(tags=[agent_name]) - for worker in existing_agents: - client.agents.delete(agent_id=worker.id) - - llm_config = LLMConfig( - model="claude-3-7-sonnet-20250219", - model_endpoint_type="anthropic", - model_endpoint="https://api.anthropic.com/v1", - context_window=32000, - handle=f"anthropic/claude-3-5-sonnet-20241022", - put_inner_thoughts_in_kwargs=False, - max_tokens=4096, - enable_reasoner=True, - max_reasoning_tokens=1024, - ) - - agent = client.agents.create( - name=agent_name, - tags=[agent_name], - include_base_tools=True, - embedding="letta/letta-free", - llm_config=llm_config, - memory_blocks=[CreateBlock(label="human", value="")], - # tool_rules=[InitToolRule(tool_name="core_memory_append")] - ) - - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content=[TextContent(text="Use the core memory append tool to append `banana` to the persona core memory.")], - ), - ], - stream_tokens=True, - ) - - print(list(response)) - - -import time - - -def test_create_agents_telemetry(client: Letta): - start_total = time.perf_counter() - - # delete any existing worker agents - start_delete = time.perf_counter() - workers = client.agents.list(tags=["worker"]) - for worker in workers: - client.agents.delete(agent_id=worker.id) - end_delete = time.perf_counter() - print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s") - - # create worker agents - num_workers = 1 - agent_times = [] - for idx in range(num_workers): - start = time.perf_counter() - client.agents.create( - name=f"worker_{idx}", - include_base_tools=True, - model="anthropic/claude-3-5-sonnet-20241022", - embedding="letta/letta-free", - ) - end = time.perf_counter() - duration = end - start - agent_times.append(duration) - print(f"[telemetry] Created worker_{idx} in {duration:.2f}s") - - total_duration = time.perf_counter() - start_total - avg_duration = sum(agent_times) / len(agent_times) - - print(f"[telemetry] Total time to create {num_workers} agents: {total_duration:.2f}s") - print(f"[telemetry] Average agent creation time: {avg_duration:.2f}s") - print(f"[telemetry] Fastest agent: {min(agent_times):.2f}s, Slowest agent: {max(agent_times):.2f}s") diff --git a/tests/integration_test_initial_sequence.py b/tests/integration_test_initial_sequence.py deleted file mode 100644 index 714491712..000000000 --- a/tests/integration_test_initial_sequence.py +++ /dev/null @@ -1,65 +0,0 @@ -import os -import threading -import time - -import pytest -from dotenv import load_dotenv -from letta_client import Letta, MessageCreate - - -def run_server(): - load_dotenv() - - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture( - scope="module", -) -def client(request): - # Get URL from environment or start server - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - - # create the Letta client - yield Letta(base_url=server_url, token=None) - - -def test_initial_sequence(client: Letta): - # create an agent - agent = client.agents.create( - memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], - model="letta/letta-free", - embedding="letta/letta-free", - initial_message_sequence=[ - MessageCreate( - role="assistant", - content="Hello, how are you?", - ), - MessageCreate(role="user", content="I'm good, and you?"), - ], - ) - - # list messages - messages = client.agents.messages.list(agent_id=agent.id) - response = client.agents.messages.create( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content="hello assistant!", - ) - ], - ) - assert len(messages) == 3 - assert messages[0].message_type == "system_message" - assert messages[1].message_type == "assistant_message" - assert messages[2].message_type == "user_message" diff --git a/tests/integration_test_send_message_schema.py b/tests/integration_test_send_message_schema.py deleted file mode 100644 index 57773ec89..000000000 --- a/tests/integration_test_send_message_schema.py +++ /dev/null @@ -1,192 +0,0 @@ -# TODO (cliandy): Tested in SDK -# TODO (cliandy): Comment out after merge - -# import os -# import threading -# import time - -# import pytest -# from dotenv import load_dotenv -# from letta_client import AssistantMessage, AsyncLetta, Letta, Tool - -# from letta.schemas.agent import AgentState -# from typing import List, Any, Dict - -# # ------------------------------ -# # Fixtures -# # ------------------------------ - - -# @pytest.fixture(scope="module") -# def server_url() -> str: -# """ -# Provides the URL for the Letta server. -# If the environment variable 'LETTA_SERVER_URL' is not set, this fixture -# will start the Letta server in a background thread and return the default URL. -# """ - -# def _run_server() -> None: -# """Starts the Letta server in a background thread.""" -# load_dotenv() # Load environment variables from .env file -# from letta.server.rest_api.app import start_server - -# start_server(debug=True) - -# # Retrieve server URL from environment, or default to localhost -# url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - -# # If no environment variable is set, start the server in a background thread -# if not os.getenv("LETTA_SERVER_URL"): -# thread = threading.Thread(target=_run_server, daemon=True) -# thread.start() -# time.sleep(5) # Allow time for the server to start - -# return url - - -# @pytest.fixture -# def client(server_url: str) -> Letta: -# """ -# Creates and returns a synchronous Letta REST client for testing. -# """ -# client_instance = Letta(base_url=server_url) -# yield client_instance - - -# @pytest.fixture -# def async_client(server_url: str) -> AsyncLetta: -# """ -# Creates and returns an asynchronous Letta REST client for testing. -# """ -# async_client_instance = AsyncLetta(base_url=server_url) -# yield async_client_instance - - -# @pytest.fixture -# def roll_dice_tool(client: Letta) -> Tool: -# """ -# Registers a simple roll dice tool with the provided client. - -# The tool simulates rolling a six-sided die but returns a fixed result. -# """ - -# def roll_dice() -> str: -# """ -# Simulates rolling a die. - -# Returns: -# str: The roll result. -# """ -# # Note: The result here is intentionally incorrect for demonstration purposes. -# return "Rolled a 10!" - -# tool = client.tools.upsert_from_function(func=roll_dice) -# yield tool - - -# @pytest.fixture -# def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: -# """ -# Creates and returns an agent state for testing with a pre-configured agent. -# The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. -# """ -# agent_state_instance = client.agents.create( -# name="supervisor", -# include_base_tools=True, -# tool_ids=[roll_dice_tool.id], -# model="openai/gpt-4o", -# embedding="letta/letta-free", -# tags=["supervisor"], -# include_base_tool_rules=True, - -# ) -# yield agent_state_instance - - -# # Goal is to test that when an Agent is created with a `response_format`, that the response -# # of `send_message` is in the correct format. This will be done by modifying the agent's -# # `send_message` tool so that it returns a format based on what is passed in. -# # -# # `response_format` is an optional field -# # if `response_format.type` is `text`, then the schema does not change -# # if `response_format.type` is `json_object`, then the schema is a dict -# # if `response_format.type` is `json_schema`, then the schema is a dict matching that json schema - - -# USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Send me a message."}] - -# # ------------------------------ -# # Test Cases -# # ------------------------------ - -# def test_client_send_message_text_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_object'.""" -# client.agents.modify(agent.id, response_format={"type": "text"}) - -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, str) - - -# def test_client_send_message_json_object_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_object'.""" -# client.agents.modify(agent.id, response_format={"type": "json_object"}) - -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, dict) - - -# def test_client_send_message_json_schema_response_format(client: "Letta", agent: "AgentState") -> None: -# """Test client send_message with response_format='json_schema' and a valid schema.""" -# client.agents.modify(agent.id, response_format={ -# "type": "json_schema", -# "json_schema": { -# "name": "reasoning_schema", -# "schema": { -# "type": "object", -# "properties": { -# "steps": { -# "type": "array", -# "items": { -# "type": "object", -# "properties": { -# "explanation": { "type": "string" }, -# "output": { "type": "string" } -# }, -# "required": ["explanation", "output"], -# "additionalProperties": False -# } -# }, -# "final_answer": { "type": "string" } -# }, -# "required": ["steps", "final_answer"], -# "additionalProperties": True -# }, -# "strict": True -# } -# }) -# response = client.agents.messages.create_stream( -# agent_id=agent.id, -# messages=USER_MESSAGE, -# ) -# messages = list(response) - -# assert isinstance(messages[-1], AssistantMessage) -# assert isinstance(messages[-1].content, dict) - - -# # def test_client_send_message_invalid_json_schema(client: "Letta", agent: "AgentState") -> None: -# # """Test client send_message with an invalid json_schema (should error or fallback).""" -# # invalid_schema: Dict[str, Any] = {"type": "object", "properties": {"foo": {"type": "unknown"}}} -# # client.agents.modify(agent.id, response_format="json_schema") -# # result: Any = client.agents.send_message(agent.id, "Test invalid schema") -# # assert result is None or "error" in str(result).lower() diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index 6c0f74c10..6e0ebd738 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -6,15 +6,16 @@ from typing import List import pytest -from letta import create_client from letta.agent import Agent -from letta.client.client import LocalClient +from letta.config import LettaConfig from letta.llm_api.helpers import calculate_summarizer_cutoff +from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message +from letta.schemas.message import Message, MessageCreate +from letta.server.server import SyncServer from letta.streaming_interface import StreamingRefreshCLIInterface from tests.helpers.endpoints_helper import EMBEDDING_CONFIG_PATH from tests.helpers.utils import cleanup @@ -30,22 +31,34 @@ test_agent_name = f"test_client_{str(uuid.uuid4())}" @pytest.fixture(scope="module") -def client(): - client = create_client() - # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +def server(): + config = LettaConfig.load() + config.save() - yield client + server = SyncServer() + return server @pytest.fixture(scope="module") -def agent_state(client): +def default_user(server): + yield server.user_manager.get_user_or_default() + + +@pytest.fixture(scope="module") +def agent_state(server, default_user): # Generate uuid for agent name for this example - agent_state = client.create_agent(name=test_agent_name) + agent_state = server.create_agent( + CreateAgent( + name=test_agent_name, + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ), + actor=default_user, + ) yield agent_state - client.delete_agent(agent_state.id) + server.agent_manager.delete_agent(agent_state.id, default_user) # Sample data setup @@ -113,9 +126,9 @@ def test_cutoff_calculation(mocker): assert messages[cutoff - 1].role == MessageRole.user -def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_state): +def test_cutoff_calculation_with_tool_call(mocker, server, agent_state, default_user): """Test that trim_older_in_context_messages properly handles tool responses with _trim_tool_response.""" - agent_state = client.get_agent(agent_id=agent_state.id) + agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user) # Setup messages = [ @@ -133,18 +146,18 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st def mock_get_messages_by_ids(message_ids, actor): return [msg for msg in messages if msg.id in message_ids] - mocker.patch.object(client.server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids) + mocker.patch.object(server.agent_manager.message_manager, "get_messages_by_ids", side_effect=mock_get_messages_by_ids) # Mock get_agent_by_id to return an agent with our message IDs mock_agent = mocker.Mock() mock_agent.message_ids = [msg.id for msg in messages] - mocker.patch.object(client.server.agent_manager, "get_agent_by_id", return_value=mock_agent) + mocker.patch.object(server.agent_manager, "get_agent_by_id", return_value=mock_agent) # Mock set_in_context_messages to capture what messages are being set - mock_set_messages = mocker.patch.object(client.server.agent_manager, "set_in_context_messages", return_value=agent_state) + mock_set_messages = mocker.patch.object(server.agent_manager, "set_in_context_messages", return_value=agent_state) # Test Case: Trim to remove orphaned tool response - client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=client.user) + server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=3, actor=default_user) test1 = mock_set_messages.call_args_list[0][1] assert len(test1["message_ids"]) == 5 @@ -152,104 +165,92 @@ def test_cutoff_calculation_with_tool_call(mocker, client: LocalClient, agent_st mock_set_messages.reset_mock() # Test Case: Does not result in trimming the orphaned tool response - client.server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=client.user) + server.agent_manager.trim_older_in_context_messages(agent_id=agent_state.id, num=2, actor=default_user) test2 = mock_set_messages.call_args_list[0][1] assert len(test2["message_ids"]) == 6 -def test_summarize_many_messages_basic(client, disable_e2b_api_key): +def test_summarize_many_messages_basic(server, default_user): + """Test that a small-context agent gets enough messages for summarization.""" small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") small_context_llm_config.context_window = 3000 - small_agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, - ) - for _ in range(10): - client.user_message( - agent_id=small_agent_state.id, - message="hi " * 60, - ) - client.delete_agent(small_agent_state.id) - -def test_summarize_messages_inplace(client, agent_state, disable_e2b_api_key): - """Test summarization via sending the summarize CLI command or via a direct call to the agent object""" - # First send a few messages (5) - response = client.user_message( - agent_id=agent_state.id, - message="Hey, how's it going? What do you think about this whole shindig", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_state.id, - message="Any thoughts on the meaning of life?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?").messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - response = client.user_message( - agent_id=agent_state.id, - message="Would you be surprised to learn that you're actually conversing with an AI right now?", - ).messages - assert response is not None and len(response) > 0 - print(f"test_summarize: response={response}") - - # reload agent object - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - - agent_obj.summarize_messages_inplace() - - -def test_auto_summarize(client, disable_e2b_api_key): - """Test that the summarizer triggers by itself""" - small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") - small_context_llm_config.context_window = 4000 - - small_agent_state = client.create_agent( - name="small_context_agent", - llm_config=small_context_llm_config, + agent_state = server.create_agent( + CreateAgent( + name="small_context_agent", + llm_config=small_context_llm_config, + embedding="letta/letta-free", + ), + actor=default_user, ) try: - - def summarize_message_exists(messages: List[Message]) -> bool: - for message in messages: - if message.content[0].text and "The following is a summary of the previous" in message.content[0].text: - print(f"Summarize message found after {message_count} messages: \n {message.content[0].text}") - return True - return False - - MAX_ATTEMPTS = 10 - message_count = 0 - while True: - - # send a message - response = client.user_message( - agent_id=small_agent_state.id, - message="What is the meaning of life?", + for _ in range(10): + server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="hi " * 60)], ) - message_count += 1 - - print(f"Message {message_count}: \n\n{response.messages}" + "--------------------------------") - - # check if the summarize message is inside the messages - assert isinstance(client, LocalClient), "Test only works with LocalClient" - in_context_messages = client.server.agent_manager.get_in_context_messages(agent_id=small_agent_state.id, actor=client.user) - print("SUMMARY", summarize_message_exists(in_context_messages)) - if summarize_message_exists(in_context_messages): - break - - if message_count > MAX_ATTEMPTS: - raise Exception(f"Summarize message not found after {message_count} messages") - finally: - client.delete_agent(small_agent_state.id) + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) + + +def test_summarize_messages_inplace(server, agent_state, default_user): + """Test summarization logic via agent object API.""" + for msg in [ + "Hey, how's it going? What do you think about this whole shindig?", + "Any thoughts on the meaning of life?", + "Does the number 42 ring a bell?", + "Would you be surprised to learn that you're actually conversing with an AI right now?", + ]: + response = server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content=msg)], + ) + assert response.steps_messages + + agent = server.load_agent(agent_id=agent_state.id, actor=default_user) + agent.summarize_messages_inplace() + + +def test_auto_summarize(server, default_user): + """Test that summarization is automatically triggered.""" + small_context_llm_config = LLMConfig.default_config("gpt-4o-mini") + small_context_llm_config.context_window = 3000 + + agent_state = server.create_agent( + CreateAgent( + name="small_context_agent", + llm_config=small_context_llm_config, + embedding="letta/letta-free", + ), + actor=default_user, + ) + + def summarize_message_exists(messages: List[Message]) -> bool: + for message in messages: + if message.content[0].text and "The following is a summary of the previous" in message.content[0].text: + return True + return False + + try: + MAX_ATTEMPTS = 10 + for attempt in range(MAX_ATTEMPTS): + server.send_messages( + actor=default_user, + agent_id=agent_state.id, + input_messages=[MessageCreate(role="user", content="What is the meaning of life?")], + ) + + in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_state.id, actor=default_user) + + if summarize_message_exists(in_context_messages): + return + + raise AssertionError("Summarization was not triggered after 10 messages") + finally: + server.agent_manager.delete_agent(agent_id=agent_state.id, actor=default_user) @pytest.mark.parametrize( @@ -258,51 +259,53 @@ def test_auto_summarize(client, disable_e2b_api_key): "openai-gpt-4o.json", "azure-gpt-4o-mini.json", "claude-3-5-haiku.json", - # "groq.json", TODO: Support groq, rate limiting currently makes it impossible to test - # "gemini-pro.json", TODO: Gemini is broken + # "groq.json", # rate limits + # "gemini-pro.json", # broken ], ) -def test_summarizer(config_filename, client, agent_state): +def test_summarizer(config_filename, server, default_user): + """Test summarization across different LLM configs.""" namespace = uuid.NAMESPACE_DNS agent_name = str(uuid.uuid5(namespace, f"integration-test-summarizer-{config_filename}")) - # Get the LLM config - filename = os.path.join(LLM_CONFIG_DIR, config_filename) - config_data = json.load(open(filename, "r")) - - # Create client and clean up agents + # Load configs + config_data = json.load(open(os.path.join(LLM_CONFIG_DIR, config_filename))) llm_config = LLMConfig(**config_data) embedding_config = EmbeddingConfig(**json.load(open(EMBEDDING_CONFIG_PATH))) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(embedding_config) - cleanup(client=client, agent_uuid=agent_name) + + # Ensure cleanup + cleanup(server=server, agent_uuid=agent_name, actor=default_user) # Create agent - agent_state = client.create_agent(name=agent_name, llm_config=llm_config, embedding_config=embedding_config) - full_agent_state = client.get_agent(agent_id=agent_state.id) + agent_state = server.create_agent( + CreateAgent( + name=agent_name, + llm_config=llm_config, + embedding_config=embedding_config, + ), + actor=default_user, + ) + + full_agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_state.id, actor=default_user) + letta_agent = Agent( interface=StreamingRefreshCLIInterface(), agent_state=full_agent_state, first_message_verify_mono=False, - user=client.user, + user=default_user, ) - # Make conversation - messages = [ + for msg in [ "Did you know that honey never spoils? Archaeologists have found pots of honey in ancient Egyptian tombs that are over 3,000 years old and still perfectly edible.", "Octopuses have three hearts, and two of them stop beating when they swim.", - ] - - for m in messages: + ]: letta_agent.step_user_message( - user_message_str=m, + user_message_str=msg, first_message=False, skip_verify=False, stream=False, ) - # Invoke a summarize letta_agent.summarize_messages_inplace() - in_context_messages = client.get_in_context_messages(agent_state.id) + in_context_messages = server.agent_manager.get_in_context_messages(agent_state.id, actor=default_user) assert SUMMARY_KEY_PHRASE in in_context_messages[1].content[0].text, f"Test failed for config: {config_filename}" diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 720922f2b..1a9bd7638 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -7,17 +7,16 @@ from unittest.mock import patch import pytest from sqlalchemy import delete -from letta import create_client +from letta.config import LettaConfig from letta.functions.function_sets.base import core_memory_append, core_memory_replace from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.agent import AgentState, CreateAgent +from letta.schemas.block import CreateBlock from letta.schemas.environment_variables import AgentEnvironmentVariable, SandboxEnvironmentVariableCreate -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, PipRequirement, SandboxConfigCreate, SandboxConfigUpdate from letta.schemas.user import User +from letta.server.server import SyncServer from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSandbox @@ -32,6 +31,21 @@ user_name = str(uuid.uuid5(namespace, "test-tool-execution-sandbox-user")) # Fixtures +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(autouse=True) def clear_tables(): """Fixture to clear the organization table before each test.""" @@ -191,12 +205,26 @@ def external_codebase_tool(test_user): @pytest.fixture -def agent_state(): - client = create_client() - agent_state = client.create_agent( - memory=ChatMemory(persona="This is the persona", human="My name is Chad"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), +def agent_state(server): + actor = server.user_manager.get_user_or_default() + agent_state = server.create_agent( + CreateAgent( + memory_blocks=[ + CreateBlock( + label="human", + value="username: sarah", + ), + CreateBlock( + label="persona", + value="This is the persona", + ), + ], + include_base_tools=True, + model="openai/gpt-4o-mini", + tags=["test_agents"], + embedding="letta/letta-free", + ), + actor=actor, ) agent_state.tool_rules = [] yield agent_state diff --git a/tests/manual_test_many_messages.py b/tests/manual_test_many_messages.py index 6aaa33bb4..df71dd859 100644 --- a/tests/manual_test_many_messages.py +++ b/tests/manual_test_many_messages.py @@ -1,7 +1,6 @@ import datetime import json import math -import os import random import uuid @@ -9,14 +8,11 @@ import pytest from faker import Faker from tqdm import tqdm -from letta import create_client +from letta.config import LettaConfig from letta.orm import Base -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message -from letta.services.agent_manager import AgentManager -from letta.services.message_manager import MessageManager -from tests.integration_test_summarizer import LLM_CONFIG_DIR +from letta.schemas.agent import CreateAgent +from letta.schemas.message import Message, MessageCreate +from letta.server.server import SyncServer @pytest.fixture(autouse=True) @@ -29,16 +25,25 @@ def truncate_database(): session.commit() -@pytest.fixture(scope="function") -def client(): - filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-sonnet.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. - yield client + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + +@pytest.fixture +def default_user(server): + actor = server.user_manager.get_user_or_default() + yield actor def generate_tool_call_id(): @@ -129,14 +134,13 @@ def create_tool_message(agent_id, organization_id, tool_call_id, timestamp): @pytest.mark.parametrize("num_messages", [1000]) -def test_many_messages_performance(client, num_messages): - """Main test function to generate messages and insert them into the database.""" - message_manager = MessageManager() - agent_manager = AgentManager() - actor = client.user +def test_many_messages_performance(server, default_user, num_messages): + """Performance test to insert many messages and ensure retrieval works correctly.""" + message_manager = server.agent_manager.message_manager + agent_manager = server.agent_manager start_time = datetime.datetime.now() - last_event_time = start_time # Track last event time + last_event_time = start_time def log_event(event): nonlocal last_event_time @@ -144,11 +148,19 @@ def test_many_messages_performance(client, num_messages): total_elapsed = (now - start_time).total_seconds() step_elapsed = (now - last_event_time).total_seconds() print(f"[+{total_elapsed:.3f}s | Δ{step_elapsed:.3f}s] {event}") - last_event_time = now # Update last event time + last_event_time = now log_event(f"Starting test with {num_messages} messages") - agent_state = client.create_agent(name="manager") + agent_state = server.create_agent( + CreateAgent( + name="manager", + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + ), + actor=default_user, + ) log_event(f"Created agent with ID {agent_state.id}") message_group_size = 3 @@ -158,37 +170,42 @@ def test_many_messages_performance(client, num_messages): organization_id = "org-00000000-0000-4000-8000-000000000000" all_messages = [] - for _ in tqdm(range(num_groups)): user_text, assistant_text = get_conversation_pair() tool_call_id = generate_tool_call_id() user_time, send_time, tool_time, current_time = generate_timestamps(current_time) - new_messages = [ - Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)), - Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)), - Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)), - ] - all_messages.extend(new_messages) + + all_messages.extend( + [ + Message(**create_user_message(agent_state.id, organization_id, user_text, user_time)), + Message(**create_send_message(agent_state.id, organization_id, assistant_text, tool_call_id, send_time)), + Message(**create_tool_message(agent_state.id, organization_id, tool_call_id, tool_time)), + ] + ) log_event(f"Finished generating {len(all_messages)} messages") - message_manager.create_many_messages(all_messages, actor=actor) + message_manager.create_many_messages(all_messages, actor=default_user) log_event("Inserted messages into the database") agent_manager.set_in_context_messages( - agent_id=agent_state.id, message_ids=agent_state.message_ids + [m.id for m in all_messages], actor=client.user + agent_id=agent_state.id, + message_ids=agent_state.message_ids + [m.id for m in all_messages], + actor=default_user, ) log_event("Updated agent context with messages") - messages = message_manager.list_messages_for_agent(agent_id=agent_state.id, actor=client.user, limit=1000000000) + messages = message_manager.list_messages_for_agent( + agent_id=agent_state.id, + actor=default_user, + limit=1000000000, + ) log_event(f"Retrieved {len(messages)} messages from the database") assert len(messages) >= num_groups * message_group_size - response = client.send_message( - agent_id=agent_state.id, - role="user", - message="What have we been talking about?", + response = server.send_messages( + actor=default_user, agent_id=agent_state.id, input_messages=[MessageCreate(role="user", content="What have we been talking about?")] ) log_event("Sent message to agent and received response") diff --git a/tests/manual_test_multi_agent_broadcast_large.py b/tests/manual_test_multi_agent_broadcast_large.py index 70d88f446..3d406d847 100644 --- a/tests/manual_test_multi_agent_broadcast_large.py +++ b/tests/manual_test_multi_agent_broadcast_large.py @@ -1,89 +1,98 @@ -import json -import os - import pytest from tqdm import tqdm -from letta import create_client -from letta.functions.functions import derive_openai_json_schema, parse_source_code -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.tool import Tool -from tests.integration_test_summarizer import LLM_CONFIG_DIR +from letta.config import LettaConfig +from letta.schemas.agent import CreateAgent +from letta.schemas.message import MessageCreate +from letta.server.server import SyncServer +from tests.utils import create_tool_from_func -@pytest.fixture(scope="function") -def client(): - filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - client = create_client() - client.set_default_llm_config(llm_config) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. - yield client + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server @pytest.fixture -def roll_dice_tool(client): +def default_user(server): + actor = server.user_manager.get_user_or_default() + yield actor + + +@pytest.fixture +def roll_dice_tool(server, default_user): def roll_dice(): """ - Rolls a 6 sided die. + Rolls a 6-sided die. Returns: - str: The roll result. + str: Result of the die roll. """ return "Rolled a 5!" - # Set up tool details - source_code = parse_source_code(roll_dice) - source_type = "python" - description = "test_description" - tags = ["test"] - - tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type) - derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) - - derived_name = derived_json_schema["name"] - tool.json_schema = derived_json_schema - tool.name = derived_name - - tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) - - # Yield the created tool - yield tool + tool = create_tool_from_func(func=roll_dice) + created_tool = server.tool_manager.create_or_update_tool(tool, actor=default_user) + yield created_tool @pytest.mark.parametrize("num_workers", [50]) -def test_multi_agent_large(client, roll_dice_tool, num_workers): +def test_multi_agent_large(server, default_user, roll_dice_tool, num_workers): manager_tags = ["manager"] worker_tags = ["helpers"] - # Clean up first from possibly failed tests - prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True) - for agent in prev_worker_agents: - client.delete_agent(agent.id) + # Cleanup any pre-existing agents with both tags + prev_agents = server.agent_manager.list_agents(actor=default_user, tags=worker_tags + manager_tags, match_all_tags=True) + for agent in prev_agents: + server.agent_manager.delete_agent(agent.id, actor=default_user) - # Create "manager" agent - send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags") - manager_agent_state = client.create_agent(name="manager", tool_ids=[send_message_to_agents_matching_tags_tool_id], tags=manager_tags) - manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) - - # Create 3 worker agents - worker_agents = [] - for idx in tqdm(range(num_workers)): - worker_agent_state = client.create_agent( - name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id] - ) - worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) - worker_agents.append(worker_agent) - - # Encourage the manager to send a message to the other agent_obj with the secret string - broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" - client.send_message( - agent_id=manager_agent.agent_state.id, - role="user", - message=broadcast_message, + # Create "manager" agent with multi-agent broadcast tool + send_message_tool_id = server.tool_manager.get_tool_id(tool_name="send_message_to_agents_matching_tags", actor=default_user) + manager_agent_state = server.create_agent( + CreateAgent( + name="manager", + tool_ids=[send_message_tool_id], + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + tags=manager_tags, + ), + actor=default_user, ) - # Please manually inspect the agent results + manager_agent = server.load_agent(agent_id=manager_agent_state.id, actor=default_user) + + # Create N worker agents + worker_agents = [] + for idx in tqdm(range(num_workers)): + worker_agent_state = server.create_agent( + CreateAgent( + name=f"worker-{idx}", + tool_ids=[roll_dice_tool.id], + include_multi_agent_tools=False, + include_base_tools=True, + model="openai/gpt-4o-mini", + embedding="letta/letta-free", + tags=worker_tags, + ), + actor=default_user, + ) + worker_agent = server.load_agent(agent_id=worker_agent_state.id, actor=default_user) + worker_agents.append(worker_agent) + + # Manager sends broadcast message + broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" + server.send_messages( + actor=default_user, + agent_id=manager_agent.agent_state.id, + input_messages=[MessageCreate(role="user", content=broadcast_message)], + ) diff --git a/tests/test_agent_serialization.py b/tests/test_agent_serialization.py index aa02e0dfa..000db3b98 100644 --- a/tests/test_agent_serialization.py +++ b/tests/test_agent_serialization.py @@ -13,7 +13,6 @@ from dotenv import load_dotenv from rich.console import Console from rich.syntax import Syntax -from letta import create_client from letta.config import LettaConfig from letta.orm import Base from letta.orm.enums import ToolType @@ -27,6 +26,7 @@ from letta.schemas.organization import Organization from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.server import SyncServer +from tests.utils import create_tool_from_func console = Console() @@ -86,14 +86,6 @@ def clear_tables(): _clear_tables() -@pytest.fixture(scope="module") -def local_client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - yield client - - @pytest.fixture def server(): config = LettaConfig.load() @@ -133,14 +125,14 @@ def other_user(server: SyncServer, other_organization): @pytest.fixture -def weather_tool(local_client, weather_tool_func): - weather_tool = local_client.create_or_update_tool(func=weather_tool_func) +def weather_tool(server, weather_tool_func, default_user): + weather_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=weather_tool_func), actor=default_user) yield weather_tool @pytest.fixture -def print_tool(local_client, print_tool_func): - print_tool = local_client.create_or_update_tool(func=print_tool_func) +def print_tool(server, print_tool_func, default_user): + print_tool = server.tool_manager.create_or_update_tool(create_tool_from_func(func=print_tool_func), actor=default_user) yield print_tool @@ -438,7 +430,7 @@ def test_sanity_datetime_mismatch(): # Agent serialize/deserialize tests -def test_deserialize_simple(local_client, server, serialize_test_agent, default_user, other_user): +def test_deserialize_simple(server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False result = server.agent_manager.serialize(agent_id=serialize_test_agent.id, actor=default_user) @@ -452,9 +444,7 @@ def test_deserialize_simple(local_client, server, serialize_test_agent, default_ @pytest.mark.parametrize("override_existing_tools", [True, False]) -def test_deserialize_override_existing_tools( - local_client, server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools -): +def test_deserialize_override_existing_tools(server, serialize_test_agent, default_user, weather_tool, print_tool, override_existing_tools): """ Test deserializing an agent with tools and ensure correct behavior for overriding existing tools. """ @@ -487,7 +477,7 @@ def test_deserialize_override_existing_tools( assert existing_tool.source_code == weather_tool.source_code, f"Tool {tool_name} should NOT be overridden" -def test_agent_serialize_with_user_messages(local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_with_user_messages(server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( @@ -516,7 +506,7 @@ def test_agent_serialize_with_user_messages(local_client, server, serialize_test ) -def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_tool_calls(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( @@ -552,7 +542,7 @@ def test_agent_serialize_tool_calls(disable_e2b_api_key, local_client, server, s assert copy_agent_response.completion_tokens > 0 and copy_agent_response.step_count > 0 -def test_agent_serialize_update_blocks(disable_e2b_api_key, local_client, server, serialize_test_agent, default_user, other_user): +def test_agent_serialize_update_blocks(disable_e2b_api_key, server, serialize_test_agent, default_user, other_user): """Test deserializing JSON into an Agent instance.""" append_copy_suffix = False server.send_messages( diff --git a/tests/test_ast_parsing.py b/tests/test_ast_parsing.py deleted file mode 100644 index 312e3a0c9..000000000 --- a/tests/test_ast_parsing.py +++ /dev/null @@ -1,275 +0,0 @@ -import pytest - -from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source - -# ----------------------------------------------------------------------- -# Example source code for testing multiple scenarios, including: -# 1) A class-based custom type (which we won't handle properly). -# 2) Functions with multiple argument types. -# 3) A function with default arguments. -# 4) A function with no arguments. -# 5) A function that shares the same name as another symbol. -# ----------------------------------------------------------------------- -example_source_code = r""" -class CustomClass: - def __init__(self, x): - self.x = x - -def unrelated_symbol(): - pass - -def no_args_func(): - pass - -def default_args_func(x: int = 5, y: str = "hello"): - return x, y - -def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None): - pass - -def my_function_duplicate(): - # This function shares the name "my_function" partially, but isn't an exact match - pass -""" - - -# --------------------- get_function_annotations_from_source TESTS --------------------- # - - -def test_get_function_annotations_found(): - """ - Test that we correctly parse annotations for a function - that includes multiple argument types and a custom class. - """ - annotations = get_function_annotations_from_source(example_source_code, "my_function") - assert annotations == { - "a": "int", - "b": "float", - "c": "str", - "d": "list", - "e": "dict", - "f": "CustomClass", - } - - -def test_get_function_annotations_not_found(): - """ - If the requested function name doesn't exist exactly, - we should raise a ValueError. - """ - with pytest.raises(ValueError, match="Function 'missing_function' not found"): - get_function_annotations_from_source(example_source_code, "missing_function") - - -def test_get_function_annotations_no_args(): - """ - Check that a function without arguments returns an empty annotations dict. - """ - annotations = get_function_annotations_from_source(example_source_code, "no_args_func") - assert annotations == {} - - -def test_get_function_annotations_with_default_values(): - """ - Ensure that a function with default arguments still captures the annotations. - """ - annotations = get_function_annotations_from_source(example_source_code, "default_args_func") - assert annotations == {"x": "int", "y": "str"} - - -def test_get_function_annotations_partial_name_collision(): - """ - Ensure we only match the exact function name, not partial collisions. - """ - # This will match 'my_function' exactly, ignoring 'my_function_duplicate' - annotations = get_function_annotations_from_source(example_source_code, "my_function") - assert "a" in annotations # Means it matched the correct function - # No error expected here, just making sure we didn't accidentally parse "my_function_duplicate". - - -# --------------------- coerce_dict_args_by_annotations TESTS --------------------- # - - -def test_coerce_dict_args_success(): - """ - Basic success scenario with standard types: - int, float, str, list, dict. - """ - annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"} - function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == 3.14 - assert coerced_args["c"] == "123" - assert coerced_args["d"] == [1, 2, 3] - assert coerced_args["e"] == {"key": "value"} - - -def test_coerce_dict_args_invalid_type(): - """ - If the value cannot be coerced into the annotation, - a ValueError should be raised. - """ - annotations = {"a": "int"} - function_args = {"a": "invalid_int"} - - with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_no_annotations(): - """ - If there are no annotations, we do no coercion. - """ - annotations = {} - function_args = {"a": 42, "b": "hello"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args == function_args # Exactly the same dict back - - -def test_coerce_dict_args_partial_annotations(): - """ - Only coerce annotated arguments; leave unannotated ones unchanged. - """ - annotations = {"a": "int"} - function_args = {"a": "42", "b": "no_annotation"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == "no_annotation" - - -def test_coerce_dict_args_with_missing_args(): - """ - If function_args lacks some keys listed in annotations, - those are simply not coerced. (We do not add them.) - """ - annotations = {"a": "int", "b": "float"} - function_args = {"a": "42"} # Missing 'b' - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert "b" not in coerced_args - - -def test_coerce_dict_args_unexpected_keys(): - """ - If function_args has extra keys not in annotations, - we leave them alone. - """ - annotations = {"a": "int"} - function_args = {"a": "42", "z": 999} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["z"] == 999 # unchanged - - -def test_coerce_dict_args_unsupported_custom_class(): - """ - If someone tries to pass an annotation that isn't supported (like a custom class), - we should raise a ValueError (or similarly handle the error) rather than silently - accept it. - """ - annotations = {"f": "CustomClass"} # We can't resolve this - function_args = {"f": {"x": 1}} - with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_complex_types(): - """ - Confirm the ability to parse built-in complex data (lists, dicts, etc.) - when given as strings. - """ - annotations = {"big_list": "list", "nested_dict": "dict"} - function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}] - assert coerced_args["nested_dict"] == { - "alpha": [10, 20], - "beta": {"x": 1, "y": 2}, - } - - -def test_coerce_dict_args_non_string_keys(): - """ - Validate behavior if `function_args` includes non-string keys. - (We should simply skip annotation checks for them.) - """ - annotations = {"a": "int"} - function_args = {123: "42", "a": "42"} - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - # 'a' is coerced to int - assert coerced_args["a"] == 42 - # 123 remains untouched - assert coerced_args[123] == "42" - - -def test_coerce_dict_args_non_parseable_list_or_dict(): - """ - Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation. - """ - annotations = {"bad_list": "list", "bad_dict": "dict"} - function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets - - with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_complex_list_annotation(): - """ - Test coercion when list with type annotation (e.g., list[int]) is used. - """ - annotations = {"a": "list[int]"} - function_args = {"a": "[1, 2, 3]"} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == [1, 2, 3] - - -def test_coerce_dict_args_with_complex_dict_annotation(): - """ - Test coercion when dict with type annotation (e.g., dict[str, int]) is used. - """ - annotations = {"a": "dict[str, int]"} - function_args = {"a": '{"x": 1, "y": 2}'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == {"x": 1, "y": 2} - - -def test_coerce_dict_args_unsupported_complex_annotation(): - """ - If an unsupported complex annotation is used (e.g., a custom class), - a ValueError should be raised. - """ - annotations = {"f": "CustomClass[int]"} - function_args = {"f": "CustomClass(42)"} - - with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"): - coerce_dict_args_by_annotations(function_args, annotations) - - -def test_coerce_dict_args_with_nested_complex_annotation(): - """ - Test coercion with complex nested types like list[dict[str, int]]. - """ - annotations = {"a": "list[dict[str, int]]"} - function_args = {"a": '[{"x": 1}, {"y": 2}]'} - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == [{"x": 1}, {"y": 2}] - - -def test_coerce_dict_args_with_default_arguments(): - """ - Test coercion with default arguments, where some arguments have defaults in the source code. - """ - annotations = {"a": "int", "b": "str"} - function_args = {"a": "42"} - - function_args.setdefault("b", "hello") # Setting the default value for 'b' - - coerced_args = coerce_dict_args_by_annotations(function_args, annotations) - assert coerced_args["a"] == 42 - assert coerced_args["b"] == "hello" diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 2408d55a8..b5cf01e2e 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -6,9 +6,11 @@ from dotenv import load_dotenv from letta_client import Letta import letta.functions.function_sets.base as base_functions -from letta import LocalClient, create_client +from letta.config import LettaConfig from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.server.server import SyncServer from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import ( get_finish_rethinking_memory_schema, get_rethink_user_memory_schema, @@ -18,15 +20,6 @@ from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import ( from tests.utils import wait_for_server -@pytest.fixture(scope="function") -def client(): - client = create_client() - client.set_default_llm_config(LLMConfig.default_config("gpt-4o")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - yield client - - def _run_server(): """Starts the Letta server in a background thread.""" load_dotenv() @@ -35,6 +28,21 @@ def _run_server(): start_server(debug=True) +@pytest.fixture(scope="module") +def server(): + """ + Creates a SyncServer instance for testing. + + Loads and saves config to ensure proper initialization. + """ + config = LettaConfig.load() + + config.save() + + server = SyncServer(init_with_default_org_and_user=True) + yield server + + @pytest.fixture(scope="session") def server_url(): """Ensures a server is running and returns its base URL.""" @@ -57,16 +65,29 @@ def letta_client(server_url): @pytest.fixture(scope="function") -def agent_obj(client: LocalClient): +def agent_obj(letta_client, server): """Create a test agent that we can call functions on""" - send_message_to_agent_and_wait_for_reply_tool_id = client.get_tool_id(name="send_message_to_agent_and_wait_for_reply") - agent_state = client.create_agent(tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id]) - - agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user) + send_message_to_agent_and_wait_for_reply_tool_id = letta_client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0].id + agent_state = letta_client.agents.create( + tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id], + include_base_tools=True, + memory_blocks=[ + { + "label": "human", + "value": "Name: Matt", + }, + { + "label": "persona", + "value": "Friendly agent", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ) + actor = server.user_manager.get_user_or_default() + agent_obj = server.load_agent(agent_id=agent_state.id, actor=actor) yield agent_obj - # client.delete_agent(agent_obj.agent_state.id) - def query_in_search_results(search_results, query): for result in search_results: @@ -127,17 +148,20 @@ def test_archival(agent_obj): pass -def test_recall_self(client, agent_obj): - # keyword +def test_recall(server, agent_obj, default_user): + """Test that an agent can recall messages using a keyword via conversation search.""" keyword = "banana" keyword_backwards = "".join(reversed(keyword)) - # Send messages to agent - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="hello") - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="what word is '{}' backwards?".format(keyword_backwards)) - client.send_message(agent_id=agent_obj.agent_state.id, role="user", message="tell me a fun fact") + # Send messages + for msg in ["hello", keyword, "tell me a fun fact"]: + server.send_messages( + actor=default_user, + agent_id=agent_obj.agent_state.id, + input_messages=[MessageCreate(role="user", content=msg)], + ) - # Conversation search + # Search memory result = base_functions.conversation_search(agent_obj, "banana") assert keyword in result diff --git a/tests/test_client.py b/tests/test_client.py index 8384f10f7..3938671d0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -110,58 +110,6 @@ def clear_tables(): session.commit() -# TODO: add back -# def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]): -# """ -# Test sandbox config and environment variable functions for both LocalClient and RESTClient. -# """ -# -# # 1. Create a sandbox config -# local_config = LocalSandboxConfig(sandbox_dir=SANDBOX_DIR) -# sandbox_config = client.create_sandbox_config(config=local_config) -# -# # Assert the created sandbox config -# assert sandbox_config.id is not None -# assert sandbox_config.type == SandboxType.LOCAL -# -# # 2. Update the sandbox config -# updated_config = LocalSandboxConfig(sandbox_dir=UPDATED_SANDBOX_DIR) -# sandbox_config = client.update_sandbox_config(sandbox_config_id=sandbox_config.id, config=updated_config) -# assert sandbox_config.config["sandbox_dir"] == UPDATED_SANDBOX_DIR -# -# # 3. List all sandbox configs -# sandbox_configs = client.list_sandbox_configs(limit=10) -# assert isinstance(sandbox_configs, List) -# assert len(sandbox_configs) == 1 -# assert sandbox_configs[0].id == sandbox_config.id -# -# # 4. Create an environment variable -# env_var = client.create_sandbox_env_var( -# sandbox_config_id=sandbox_config.id, key=ENV_VAR_KEY, value=ENV_VAR_VALUE, description=ENV_VAR_DESCRIPTION -# ) -# assert env_var.id is not None -# assert env_var.key == ENV_VAR_KEY -# assert env_var.value == ENV_VAR_VALUE -# assert env_var.description == ENV_VAR_DESCRIPTION -# -# # 5. Update the environment variable -# updated_env_var = client.update_sandbox_env_var(env_var_id=env_var.id, key=UPDATED_ENV_VAR_KEY, value=UPDATED_ENV_VAR_VALUE) -# assert updated_env_var.key == UPDATED_ENV_VAR_KEY -# assert updated_env_var.value == UPDATED_ENV_VAR_VALUE -# -# # 6. List environment variables -# env_vars = client.list_sandbox_env_vars(sandbox_config_id=sandbox_config.id) -# assert isinstance(env_vars, List) -# assert len(env_vars) == 1 -# assert env_vars[0].key == UPDATED_ENV_VAR_KEY -# -# # 7. Delete the environment variable -# client.delete_sandbox_env_var(env_var_id=env_var.id) -# -# # 8. Delete the sandbox config -# client.delete_sandbox_config(sandbox_config_id=sandbox_config.id) - - # -------------------------------------------------------------------------------------------------------------------- # Agent tags # -------------------------------------------------------------------------------------------------------------------- @@ -349,30 +297,6 @@ def test_attach_detach_agent_memory_block(client: Letta, agent: AgentState): assert example_new_label not in [block.label for block in client.agents.blocks.list(agent_id=updated_agent.id)] -# def test_core_memory_token_limits(client: Union[LocalClient, RESTClient], agent: AgentState): -# """Test that the token limit is enforced for the core memory blocks""" - -# # Create an agent -# new_agent = client.create_agent( -# name="test-core-memory-token-limits", -# tools=BASE_TOOLS, -# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=2000), -# ) - -# try: -# # Then intentionally set the limit to be extremely low -# client.update_agent( -# agent_id=new_agent.id, -# memory=ChatMemory(human="The humans name is Joe.", persona="My name is Sam.", limit=100), -# ) - -# # TODO we should probably not allow updating the core memory limit if - -# # TODO in which case we should modify this test to actually to a proper token counter check -# finally: -# client.delete_agent(new_agent.id) - - def test_update_agent_memory_limit(client: Letta): """Test that we can update the limit of a block in an agent's memory""" @@ -744,3 +668,38 @@ def test_attach_detach_agent_source(client: Letta, agent: AgentState): assert source.id not in [s.id for s in final_sources] client.sources.delete(source.id) + + +# -------------------------------------------------------------------------------------------------------------------- +# Agent Initial Message Sequence +# -------------------------------------------------------------------------------------------------------------------- +def test_initial_sequence(client: Letta): + # create an agent + agent = client.agents.create( + memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], + model="letta/letta-free", + embedding="letta/letta-free", + initial_message_sequence=[ + MessageCreate( + role="assistant", + content="Hello, how are you?", + ), + MessageCreate(role="user", content="I'm good, and you?"), + ], + ) + + # list messages + messages = client.agents.messages.list(agent_id=agent.id) + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="hello assistant!", + ) + ], + ) + assert len(messages) == 3 + assert messages[0].message_type == "system_message" + assert messages[1].message_type == "assistant_message" + assert messages[2].message_type == "user_message" diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 68bf2edb4..f4ab770ee 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -9,8 +9,7 @@ import pytest from dotenv import load_dotenv from sqlalchemy import delete -from letta import create_client -from letta.client.client import LocalClient, RESTClient +from letta.client.client import RESTClient from letta.constants import DEFAULT_PRESET from letta.helpers.datetime_helpers import get_utc_time from letta.orm import FileMetadata, Source @@ -33,7 +32,6 @@ from letta.schemas.usage import LettaUsageStatistics from letta.services.helpers.agent_manager_helper import initialize_message_sequence from letta.services.organization_manager import OrganizationManager from letta.services.user_manager import UserManager -from letta.settings import model_settings from tests.helpers.client_helper import upload_file_using_client # from tests.utils import create_config @@ -58,30 +56,22 @@ def run_server(): start_server(debug=True) -# Fixture to create clients with different configurations @pytest.fixture( - # params=[{"server": True}, {"server": False}], # whether to use REST API server - params=[{"server": True}], # whether to use REST API server scope="module", ) -def client(request): - if request.param["server"]: - # get URL from enviornment - server_url = os.getenv("LETTA_SERVER_URL") - if server_url is None: - # run server in thread - server_url = "http://localhost:8283" - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - # create user via admin client - client = create_client(base_url=server_url, token=None) # This yields control back to the test function - else: - # use local client (no server) - client = create_client() - +def client(): + # get URL from enviornment + server_url = os.getenv("LETTA_SERVER_URL") + if server_url is None: + # run server in thread + server_url = "http://localhost:8283" + print("Starting server thread") + thread = threading.Thread(target=run_server, daemon=True) + thread.start() + time.sleep(5) + print("Running client tests with server:", server_url) + # create user via admin client + client = RESTClient(server_url) client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) yield client @@ -100,7 +90,7 @@ def clear_tables(): # Fixture for test agent @pytest.fixture(scope="module") -def agent(client: Union[LocalClient, RESTClient]): +def agent(client: Union[RESTClient]): agent_state = client.create_agent(name=test_agent_name) yield agent_state @@ -124,7 +114,7 @@ def default_user(default_organization): yield user -def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent(disable_e2b_api_key, client: RESTClient, agent: AgentState): # test client.rename_agent new_name = "RenamedTestAgent" @@ -143,7 +133,7 @@ def test_agent(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agen assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed" -def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): # _reset_config() memory_response = client.get_in_context_memory(agent_id=agent.id) @@ -159,7 +149,7 @@ def test_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], age ), "Memory update failed" -def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_agent_interactions(disable_e2b_api_key, client: RESTClient, agent: AgentState): # test that it is a LettaMessage message = "Hello again, agent!" print("Sending message", message) @@ -182,7 +172,7 @@ def test_agent_interactions(disable_e2b_api_key, client: Union[LocalClient, REST # TODO: add streaming tests -def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_archival_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): # _reset_config() memory_content = "Archival memory content" @@ -216,7 +206,7 @@ def test_archival_memory(disable_e2b_api_key, client: Union[LocalClient, RESTCli client.get_archival_memory(agent.id) -def test_core_memory(disable_e2b_api_key, client: Union[LocalClient, RESTClient], agent: AgentState): +def test_core_memory(disable_e2b_api_key, client: RESTClient, agent: AgentState): response = client.send_message(agent_id=agent.id, message="Update your core memory to remember that my name is Timber!", role="user") print("Response", response) @@ -240,10 +230,6 @@ def test_streaming_send_message( stream_tokens: bool, model: str, ): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming") - assert isinstance(client, RESTClient), client - # Update agent's model agent.llm_config.model = model @@ -296,7 +282,7 @@ def test_streaming_send_message( assert done, "Message stream not done" -def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_humans_personas(client: RESTClient, agent: AgentState): # _reset_config() humans_response = client.list_humans() @@ -322,7 +308,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta assert human.value == "Human text", "Creating human failed" -def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): +def test_list_tools_pagination(client: RESTClient): tools = client.list_tools() visited_ids = {t.id: False for t in tools} @@ -344,7 +330,7 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]): assert all(visited_ids.values()) -def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_list_files_pagination(client: RESTClient, agent: AgentState): # clear sources for source in client.list_sources(): client.delete_source(source.id) @@ -380,7 +366,7 @@ def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: Ag assert len(files) == 0 # Should be empty -def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_delete_file_from_source(client: RESTClient, agent: AgentState): # clear sources for source in client.list_sources(): client.delete_source(source.id) @@ -409,7 +395,7 @@ def test_delete_file_from_source(client: Union[LocalClient, RESTClient], agent: assert len(empty_files) == 0 -def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_load_file(client: RESTClient, agent: AgentState): # _reset_config() # clear sources @@ -440,99 +426,7 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState): assert file.source_id == source.id -def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState): - # _reset_config() - - # clear sources - for source in client.list_sources(): - client.delete_source(source.id) - - # clear jobs - for job in client.list_jobs(): - client.delete_job(job.id) - - # list sources - sources = client.list_sources() - print("listed sources", sources) - assert len(sources) == 0 - - # create a source - source = client.create_source(name="test_source") - - # list sources - sources = client.list_sources() - print("listed sources", sources) - assert len(sources) == 1 - - # TODO: add back? - assert sources[0].metadata["num_passages"] == 0 - assert sources[0].metadata["num_documents"] == 0 - - # update the source - original_id = source.id - original_name = source.name - new_name = original_name + "_new" - client.update_source(source_id=source.id, name=new_name) - - # get the source name (check that it's been updated) - source = client.get_source(source_id=source.id) - assert source.name == new_name - assert source.id == original_id - - # get the source id (make sure that it's the same) - assert str(original_id) == client.get_source_id(source_name=new_name) - - # check agent archival memory size - archival_memories = client.get_archival_memory(agent_id=agent.id) - assert len(archival_memories) == 0 - - # load a file into a source (non-blocking job) - filename = "tests/data/memgpt_paper.pdf" - upload_job = upload_file_using_client(client, source, filename) - job = client.get_job(upload_job.id) - created_passages = job.metadata["num_passages"] - - # TODO: add test for blocking job - - # TODO: make sure things run in the right order - archival_memories = client.get_archival_memory(agent_id=agent.id) - assert len(archival_memories) == 0 - - # attach a source - client.attach_source(source_id=source.id, agent_id=agent.id) - - # list attached sources - attached_sources = client.list_attached_sources(agent_id=agent.id) - print("attached sources", attached_sources) - assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}" - - # list archival memory - archival_memories = client.get_archival_memory(agent_id=agent.id) - # print(archival_memories) - assert len(archival_memories) == created_passages, f"Mismatched length {len(archival_memories)} vs. {created_passages}" - - # check number of passages - sources = client.list_sources() - # TODO: add back? - # assert sources.sources[0].metadata["num_passages"] > 0 - # assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added - print(sources) - - # detach the source - assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory" - client.detach_source(source_id=source.id, agent_id=agent.id) - archival_memories = client.get_archival_memory(agent_id=agent.id) - assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}" - assert source.id not in [s.id for s in client.list_attached_sources(agent.id)] - - # delete the source - client.delete_source(source.id) - - def test_organization(client: RESTClient): - if isinstance(client, LocalClient): - pytest.skip("Skipping test_organization because LocalClient does not support organizations") - # create an organization org_name = "test-org" org = client.create_org(org_name) @@ -549,25 +443,6 @@ def test_organization(client: RESTClient): assert not (org.id in [o.id for o in orgs]) -def test_list_llm_models(client: RESTClient): - """Test that if the user's env has the right api keys set, at least one model appears in the model list""" - - def has_model_endpoint_type(models: List["LLMConfig"], target_type: str) -> bool: - return any(model.model_endpoint_type == target_type for model in models) - - models = client.list_llm_configs() - if model_settings.groq_api_key: - assert has_model_endpoint_type(models, "groq") - if model_settings.azure_api_key: - assert has_model_endpoint_type(models, "azure") - if model_settings.openai_api_key: - assert has_model_endpoint_type(models, "openai") - if model_settings.gemini_api_key: - assert has_model_endpoint_type(models, "google_ai") - if model_settings.anthropic_api_key: - assert has_model_endpoint_type(models, "anthropic") - - @pytest.fixture def cleanup_agents(client): created_agents = [] @@ -581,7 +456,7 @@ def cleanup_agents(client): # NOTE: we need to add this back once agents can also create blocks during agent creation -def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: AgentState, cleanup_agents: List[str], default_user): +def test_initial_message_sequence(client: RESTClient, agent: AgentState, cleanup_agents: List[str], default_user): """Test that we can set an initial message sequence If we pass in None, we should get a "default" message sequence @@ -624,7 +499,7 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: assert custom_sequence[0].content in client.get_in_context_messages(custom_agent_state.id)[1].content[0].text -def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): +def test_add_and_manage_tags_for_agent(client: RESTClient, agent: AgentState): """ Comprehensive happy path test for adding, retrieving, and managing tags on an agent. """ diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index da2a6666c..3a14a856c 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Tuple from unittest.mock import AsyncMock, patch @@ -457,7 +458,9 @@ async def test_partial_error_from_anthropic_batch( letta_batch_job_id=batch_job.id, ) - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) llm_batch_job = llm_batch_jobs[0] # 2. Invoke the polling job and mock responses from Anthropic @@ -481,7 +484,10 @@ async def test_partial_error_from_anthropic_batch( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -545,7 +551,7 @@ async def test_partial_error_from_anthropic_batch( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) if agent.id == agents_failed[0].id: assert after == before, f"Agent {agent.id} should not have extra messages persisted due to Anthropic failure" @@ -567,7 +573,7 @@ async def test_partial_error_from_anthropic_batch( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == (len(agents) - 1) * 4 + 1 @@ -617,7 +623,9 @@ async def test_resume_step_some_stop( letta_batch_job_id=batch_job.id, ) - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) llm_batch_job = llm_batch_jobs[0] # 2. Invoke the polling job and mock responses from Anthropic @@ -643,7 +651,10 @@ async def test_resume_step_some_stop( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -703,7 +714,7 @@ async def test_resume_step_some_stop( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." ) @@ -716,7 +727,7 @@ async def test_resume_step_some_stop( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 3 + 1 @@ -782,7 +793,9 @@ async def test_resume_step_after_request_all_continue( # Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly` # Verify batch items - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] @@ -803,7 +816,10 @@ async def test_resume_step_after_request_all_continue( with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): - msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + sizes = await asyncio.gather( + *[server.message_manager.size_async(actor=default_user, agent_id=agent.id) for agent in agents] + ) + msg_counts_before = {agent.id: size for agent, size in zip(agents, sizes)} new_batch_responses = await poll_running_llm_batches(server) @@ -860,7 +876,7 @@ async def test_resume_step_after_request_all_continue( # Tool‑call side‑effects – each agent gets at least 2 extra messages for agent in agents: before = msg_counts_before[agent.id] # captured just before resume - after = server.message_manager.size(actor=default_user, agent_id=agent.id) + after = await server.message_manager.size_async(actor=default_user, agent_id=agent.id) assert after - before >= 2, ( f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." ) @@ -873,7 +889,7 @@ async def test_resume_step_after_request_all_continue( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 4 @@ -977,7 +993,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( mock_send.assert_called_once() # Verify database records were created correctly - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=response.letta_batch_id, actor=default_user) assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] diff --git a/tests/test_local_client.py b/tests/test_local_client.py deleted file mode 100644 index a3967e4a0..000000000 --- a/tests/test_local_client.py +++ /dev/null @@ -1,411 +0,0 @@ -import uuid - -import pytest - -from letta import create_client -from letta.client.client import LocalClient -from letta.schemas.agent import AgentState -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory - - -@pytest.fixture(scope="module") -def client(): - client = create_client() - # client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini")) - client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) - - yield client - - -@pytest.fixture(scope="module") -def agent(client): - # Generate uuid for agent name for this example - namespace = uuid.NAMESPACE_DNS - agent_uuid = str(uuid.uuid5(namespace, "test_new_client_test_agent")) - - agent_state = client.create_agent(name=agent_uuid) - yield agent_state - - client.delete_agent(agent_state.id) - - -def test_agent(client: LocalClient): - # create agent - agent_state_test = client.create_agent( - name="test_agent2", - memory=ChatMemory(human="I am a human", persona="I am an agent"), - description="This is a test agent", - ) - assert isinstance(agent_state_test.memory, Memory) - - # list agents - agents = client.list_agents() - assert agent_state_test.id in [a.id for a in agents] - - # get agent - tools = client.list_tools() - print("TOOLS", [t.name for t in tools]) - agent_state = client.get_agent(agent_state_test.id) - assert agent_state.name == "test_agent2" - for block in agent_state.memory.blocks: - db_block = client.server.block_manager.get_block_by_id(block.id, actor=client.user) - assert db_block is not None, "memory block not persisted on agent create" - assert db_block.value == block.value, "persisted block data does not match in-memory data" - - assert isinstance(agent_state.memory, Memory) - # update agent: name - new_name = "new_agent" - client.update_agent(agent_state_test.id, name=new_name) - assert client.get_agent(agent_state_test.id).name == new_name - - assert isinstance(agent_state.memory, Memory) - # update agent: system prompt - new_system_prompt = agent_state.system + "\nAlways respond with a !" - client.update_agent(agent_state_test.id, system=new_system_prompt) - assert client.get_agent(agent_state_test.id).system == new_system_prompt - - response = client.user_message(agent_id=agent_state_test.id, message="Hello") - agent_state = client.get_agent(agent_state_test.id) - assert isinstance(agent_state.memory, Memory) - # update agent: message_ids - old_message_ids = agent_state.message_ids - new_message_ids = old_message_ids.copy()[:-1] # pop one - assert len(old_message_ids) != len(new_message_ids) - client.update_agent(agent_state_test.id, message_ids=new_message_ids) - assert client.get_agent(agent_state_test.id).message_ids == new_message_ids - - assert isinstance(agent_state.memory, Memory) - # update agent: tools - tool_to_delete = "send_message" - assert tool_to_delete in [t.name for t in agent_state.tools] - new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete] - client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids) - assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids) - - assert isinstance(agent_state.memory, Memory) - # update agent: memory - new_human = "My name is Mr Test, 100 percent human." - new_persona = "I am an all-knowing AI." - assert agent_state.memory.get_block("human").value != new_human - assert agent_state.memory.get_block("persona").value != new_persona - - # client.update_agent(agent_state_test.id, memory=new_memory) - # update blocks: - client.update_agent_memory_block(agent_state_test.id, label="human", value=new_human) - client.update_agent_memory_block(agent_state_test.id, label="persona", value=new_persona) - assert client.get_agent(agent_state_test.id).memory.get_block("human").value == new_human - assert client.get_agent(agent_state_test.id).memory.get_block("persona").value == new_persona - - # update agent: llm config - new_llm_config = agent_state.llm_config.model_copy(deep=True) - new_llm_config.model = "fake_new_model" - new_llm_config.context_window = 1e6 - assert agent_state.llm_config != new_llm_config - client.update_agent(agent_state_test.id, llm_config=new_llm_config) - assert client.get_agent(agent_state_test.id).llm_config == new_llm_config - assert client.get_agent(agent_state_test.id).llm_config.model == "fake_new_model" - assert client.get_agent(agent_state_test.id).llm_config.context_window == 1e6 - - # update agent: embedding config - new_embed_config = agent_state.embedding_config.model_copy(deep=True) - new_embed_config.embedding_model = "fake_embed_model" - assert agent_state.embedding_config != new_embed_config - client.update_agent(agent_state_test.id, embedding_config=new_embed_config) - assert client.get_agent(agent_state_test.id).embedding_config == new_embed_config - assert client.get_agent(agent_state_test.id).embedding_config.embedding_model == "fake_embed_model" - - # delete agent - client.delete_agent(agent_state_test.id) - - -def test_agent_add_remove_tools(client: LocalClient, agent): - # Create and add two tools to the client - # tool 1 - from composio import Action - - github_tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # assert both got added - tools = client.list_tools() - assert github_tool.id in [t.id for t in tools] - - # Assert that all combinations of tool_names, organization id are unique - combinations = [(t.name, t.organization_id) for t in tools] - assert len(combinations) == len(set(combinations)) - - # create agent - agent_state = agent - curr_num_tools = len(agent_state.tools) - - # add both tools to agent in steps - agent_state = client.attach_tool(agent_id=agent_state.id, tool_id=github_tool.id) - - # confirm that both tools are in the agent state - # we could access it like agent_state.tools, but will use the client function instead - # this is obviously redundant as it requires retrieving the agent again - # but allows us to test the `get_tools_from_agent` pathway as well - curr_tools = client.get_tools_from_agent(agent_state.id) - curr_tool_names = [t.name for t in curr_tools] - assert len(curr_tool_names) == curr_num_tools + 1 - assert github_tool.name in curr_tool_names - - # remove only the github tool - agent_state = client.detach_tool(agent_id=agent_state.id, tool_id=github_tool.id) - - # confirm that only one tool left - curr_tools = client.get_tools_from_agent(agent_state.id) - curr_tool_names = [t.name for t in curr_tools] - assert len(curr_tool_names) == curr_num_tools - assert github_tool.name not in curr_tool_names - - -def test_agent_with_shared_blocks(client: LocalClient): - persona_block = client.create_block(template_name="persona", value="Here to test things!", label="persona") - human_block = client.create_block(template_name="human", value="Me Human, I swear. Beep boop.", label="human") - existing_non_template_blocks = [persona_block, human_block] - - existing_non_template_blocks_no_values = [] - for block in existing_non_template_blocks: - block_copy = block.copy() - block_copy.value = "" - existing_non_template_blocks_no_values.append(block_copy) - - # create agent - first_agent_state_test = None - second_agent_state_test = None - try: - first_agent_state_test = client.create_agent( - name="first_test_agent_shared_memory_blocks", - memory=BasicBlockMemory(blocks=existing_non_template_blocks), - description="This is a test agent using shared memory blocks", - ) - assert isinstance(first_agent_state_test.memory, Memory) - - # when this agent is created with the shared block references this agent's in-memory blocks should - # have this latest value set by the other agent. - second_agent_state_test = client.create_agent( - name="second_test_agent_shared_memory_blocks", - memory=BasicBlockMemory(blocks=existing_non_template_blocks_no_values), - description="This is a test agent using shared memory blocks", - ) - - first_memory = first_agent_state_test.memory - assert persona_block.id == first_memory.get_block("persona").id - assert human_block.id == first_memory.get_block("human").id - client.update_agent_memory_block(first_agent_state_test.id, label="human", value="I'm an analyst therapist.") - print("Updated human block value:", client.get_agent_memory_block(first_agent_state_test.id, label="human").value) - - # refresh agent state - second_agent_state_test = client.get_agent(second_agent_state_test.id) - - assert isinstance(second_agent_state_test.memory, Memory) - second_memory = second_agent_state_test.memory - assert persona_block.id == second_memory.get_block("persona").id - assert human_block.id == second_memory.get_block("human").id - # assert second_blocks_dict.get("human", {}).get("value") == "I'm an analyst therapist." - assert second_memory.get_block("human").value == "I'm an analyst therapist." - - finally: - if first_agent_state_test: - client.delete_agent(first_agent_state_test.id) - if second_agent_state_test: - client.delete_agent(second_agent_state_test.id) - - -def test_memory(client: LocalClient, agent: AgentState): - # get agent memory - original_memory = client.get_in_context_memory(agent.id) - assert original_memory is not None - original_memory_value = str(original_memory.get_block("human").value) - - # update core memory - updated_memory = client.update_in_context_memory(agent.id, section="human", value="I am a human") - - # get memory - assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated - - -def test_archival_memory(client: LocalClient, agent: AgentState): - """Test functions for interacting with archival memory store""" - - # add archival memory - memory_str = "I love chats" - passage = client.insert_archival_memory(agent.id, memory=memory_str)[0] - - # list archival memory - passages = client.get_archival_memory(agent.id) - assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}" - - # delete archival memory - client.delete_archival_memory(agent.id, passage.id) - - -def test_recall_memory(client: LocalClient, agent: AgentState): - """Test functions for interacting with recall memory store""" - - # send message to the agent - message_str = "Hello" - client.send_message(message=message_str, role="user", agent_id=agent.id) - - # list messages - messages = client.get_messages(agent.id) - exists = False - for m in messages: - if message_str in str(m): - exists = True - assert exists - - # get in-context messages - in_context_messages = client.get_in_context_messages(agent.id) - exists = False - for m in in_context_messages: - if message_str in m.content[0].text: - exists = True - assert exists - - -def test_tools(client: LocalClient): - def print_tool(message: str): - """ - A tool to print a message - - Args: - message (str): The message to print. - - Returns: - str: The message that was printed. - - """ - print(message) - return message - - def print_tool2(msg: str): - """ - Another tool to print a message - - Args: - msg (str): The message to print. - """ - print(msg) - - # Clean all tools first - for tool in client.list_tools(): - client.delete_tool(tool.id) - - # create tool - tool = client.create_or_update_tool(func=print_tool, tags=["extras"]) - - # list tools - tools = client.list_tools() - assert tool.name in [t.name for t in tools] - - # get tool id - assert tool.id == client.get_tool_id(name="print_tool") - - # update tool: extras - extras2 = ["extras2"] - client.update_tool(tool.id, tags=extras2) - assert client.get_tool(tool.id).tags == extras2 - - # update tool: source code - client.update_tool(tool.id, func=print_tool2) - assert client.get_tool(tool.id).name == "print_tool2" - - -def test_tools_from_composio_basic(client: LocalClient): - from composio import Action - - # Create a `LocalClient` (you can also use a `RESTClient`, see the letta_rest_client.py example) - client = create_client() - - # create tool - tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER) - - # list tools - tools = client.list_tools() - assert tool.name in [t.name for t in tools] - - # We end the test here as composio requires login to use the tools - # The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable - - -# TODO: Langchain seems to have issues with Pydantic -# TODO: Langchain tools are breaking every two weeks bc of changes on their side -# def test_tools_from_langchain(client: LocalClient): -# # create langchain tool -# from langchain_community.tools import WikipediaQueryRun -# from langchain_community.utilities import WikipediaAPIWrapper -# -# langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()) -# -# # Add the tool -# tool = client.load_langchain_tool( -# langchain_tool, additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"} -# ) -# -# # list tools -# tools = client.list_tools() -# assert tool.name in [t.name for t in tools] -# -# # get tool -# tool_id = client.get_tool_id(name=tool.name) -# retrieved_tool = client.get_tool(tool_id) -# source_code = retrieved_tool.source_code -# -# # Parse the function and attempt to use it -# local_scope = {} -# exec(source_code, {}, local_scope) -# func = local_scope[tool.name] -# -# expected_content = "Albert Einstein" -# assert expected_content in func(query="Albert Einstein") -# -# -# def test_tool_creation_langchain_missing_imports(client: LocalClient): -# # create langchain tool -# from langchain_community.tools import WikipediaQueryRun -# from langchain_community.utilities import WikipediaAPIWrapper -# -# api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=100) -# langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper) -# -# # Translate to memGPT Tool -# # Intentionally missing {"langchain_community.utilities": "WikipediaAPIWrapper"} -# with pytest.raises(RuntimeError): -# ToolCreate.from_langchain(langchain_tool) - - -def test_shared_blocks_without_send_message(client: LocalClient): - from letta import BasicBlockMemory - from letta.client.client import Block, create_client - from letta.schemas.agent import AgentType - from letta.schemas.embedding_config import EmbeddingConfig - from letta.schemas.llm_config import LLMConfig - - client = create_client() - shared_memory_block = Block(name="shared_memory", label="shared_memory", value="[empty]", limit=2000) - memory = BasicBlockMemory(blocks=[shared_memory_block]) - - agent_1 = client.create_agent( - agent_type=AgentType.memgpt_agent, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - memory=memory, - ) - - agent_2 = client.create_agent( - agent_type=AgentType.memgpt_agent, - llm_config=LLMConfig.default_config("gpt-4"), - embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"), - memory=memory, - ) - - block_id = agent_1.memory.get_block("shared_memory").id - client.update_block(block_id, value="I am no longer an [empty] memory") - agent_1 = client.get_agent(agent_1.id) - agent_2 = client.get_agent(agent_2.id) - assert agent_1.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" - assert agent_2.memory.get_block("shared_memory").value == "I am no longer an [empty] memory" diff --git a/tests/test_managers.py b/tests/test_managers.py index 695dec5db..201c12669 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -140,32 +140,32 @@ async def other_user_different_org(server: SyncServer, other_organization): @pytest.fixture -def default_source(server: SyncServer, default_user): +async def default_source(server: SyncServer, default_user): source_pydantic = PydanticSource( name="Test Source", description="This is a test source.", metadata={"type": "test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) yield source @pytest.fixture -def other_source(server: SyncServer, default_user): +async def other_source(server: SyncServer, default_user): source_pydantic = PydanticSource( name="Another Test Source", description="This is yet another test source.", metadata={"type": "another_test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) yield source @pytest.fixture -def default_file(server: SyncServer, default_source, default_user, default_organization): - file = server.source_manager.create_file( +async def default_file(server: SyncServer, default_source, default_user, default_organization): + file = await server.source_manager.create_file( PydanticFileMetadata(file_name="test_file", organization_id=default_organization.id, source_id=default_source.id), actor=default_user, ) @@ -1175,17 +1175,18 @@ async def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, de await server.agent_manager.list_attached_sources_async(agent_id="nonexistent-agent-id", actor=default_user) -def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user): +@pytest.mark.asyncio +async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user, event_loop): """Test listing agents that have a particular source attached.""" # Initially should have no attached agents - attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) assert len(attached_agents) == 0 # Attach source to first agent server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) # Verify one agent is now attached - attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) assert len(attached_agents) == 1 assert sarah_agent.id in [a.id for a in attached_agents] @@ -1193,7 +1194,7 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user) # Verify both agents are now attached - attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) assert len(attached_agents) == 2 attached_agent_ids = [a.id for a in attached_agents] assert sarah_agent.id in attached_agent_ids @@ -1203,15 +1204,16 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user) # Verify only second agent remains attached - attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) + attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user) assert len(attached_agents) == 1 assert charles_agent.id in [a.id for a in attached_agents] -def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user): """Test listing agents for a nonexistent source.""" with pytest.raises(NoResultFound): - server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user) + await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user) # ====================================================================================================================== @@ -2177,7 +2179,7 @@ async def test_passage_cascade_deletion( assert len(agentic_passages) == 0 # Delete source and verify its passages are deleted - server.source_manager.delete_source(default_source.id, default_user) + await server.source_manager.delete_source(default_source.id, default_user) with pytest.raises(NoResultFound): server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user) @@ -3847,7 +3849,10 @@ async def test_upsert_properties(server: SyncServer, default_user, event_loop): # ====================================================================================================================== # SourceManager Tests - Sources # ====================================================================================================================== -def test_create_source(server: SyncServer, default_user): + + +@pytest.mark.asyncio +async def test_create_source(server: SyncServer, default_user, event_loop): """Test creating a new source.""" source_pydantic = PydanticSource( name="Test Source", @@ -3855,7 +3860,7 @@ def test_create_source(server: SyncServer, default_user): metadata={"type": "test"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Assertions to check the created source assert source.name == source_pydantic.name @@ -3864,7 +3869,8 @@ def test_create_source(server: SyncServer, default_user): assert source.organization_id == default_user.organization_id -def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user): """Test creating a new source.""" name = "Test Source" source_pydantic = PydanticSource( @@ -3873,27 +3879,28 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul metadata={"type": "medical"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) source_pydantic = PydanticSource( name=name, description="This is a different test source.", metadata={"type": "legal"}, embedding_config=DEFAULT_EMBEDDING_CONFIG, ) - same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + same_source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) assert source.name == same_source.name assert source.id != same_source.id -def test_update_source(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_update_source(server: SyncServer, default_user): """Test updating an existing source.""" source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Update the source update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"}) - updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) + updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) # Assertions to verify update assert updated_source.name == update_data.name @@ -3901,21 +3908,22 @@ def test_update_source(server: SyncServer, default_user): assert updated_source.metadata == update_data.metadata -def test_delete_source(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_delete_source(server: SyncServer, default_user): """Test deleting a source.""" source_pydantic = PydanticSource( name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Delete the source - deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) + deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user) # Assertions to verify deletion assert deleted_source.id == source.id # Verify that the source no longer appears in list_sources - sources = server.source_manager.list_sources(actor=default_user) + sources = await server.source_manager.list_sources(actor=default_user) assert len(sources) == 0 @@ -3925,18 +3933,18 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u source_pydantic = PydanticSource( name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user) # Delete the source - deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) + deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user) # Assertions to verify deletion assert deleted_source.id == source.id # Verify that the source no longer appears in list_sources - sources = server.source_manager.list_sources(actor=default_user) + sources = await server.source_manager.list_sources(actor=default_user) assert len(sources) == 0 # Verify that agent is not deleted @@ -3944,37 +3952,43 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u assert agent is not None -def test_list_sources(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources - server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + await server.source_manager.create_source( + PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user + ) if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user) + await server.source_manager.create_source( + PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user + ) # List sources without pagination - sources = server.source_manager.list_sources(actor=default_user) + sources = await server.source_manager.list_sources(actor=default_user) assert len(sources) == 2 # List sources with pagination - paginated_sources = server.source_manager.list_sources(actor=default_user, limit=1) + paginated_sources = await server.source_manager.list_sources(actor=default_user, limit=1) assert len(paginated_sources) == 1 # Ensure cursor-based pagination works - next_page = server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1) + next_page = await server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1) assert len(next_page) == 1 assert next_page[0].name != paginated_sources[0].name -def test_get_source_by_id(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_get_source_by_id(server: SyncServer, default_user): """Test retrieving a source by ID.""" source_pydantic = PydanticSource( name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Retrieve the source by ID - retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user) + retrieved_source = await server.source_manager.get_source_by_id(source_id=source.id, actor=default_user) # Assertions to verify the retrieved source matches the created one assert retrieved_source.id == source.id @@ -3982,29 +3996,31 @@ def test_get_source_by_id(server: SyncServer, default_user): assert retrieved_source.description == source.description -def test_get_source_by_name(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_get_source_by_name(server: SyncServer, default_user): """Test retrieving a source by name.""" source_pydantic = PydanticSource( name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG ) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Retrieve the source by name - retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user) + retrieved_source = await server.source_manager.get_source_by_name(source_name=source.name, actor=default_user) # Assertions to verify the retrieved source matches the created one assert retrieved_source.name == source.name assert retrieved_source.description == source.description -def test_update_source_no_changes(server: SyncServer, default_user): +@pytest.mark.asyncio +async def test_update_source_no_changes(server: SyncServer, default_user): """Test update_source with no actual changes to verify logging and response.""" source_pydantic = PydanticSource(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG) - source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + source = await server.source_manager.create_source(source=source_pydantic, actor=default_user) # Attempt to update the source with identical data update_data = SourceUpdate(name="No Change Source", description="No changes") - updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) + updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user) # Assertions to ensure the update returned the source but made no modifications assert updated_source.id == source.id @@ -4017,7 +4033,8 @@ def test_update_source_no_changes(server: SyncServer, default_user): # ====================================================================================================================== -def test_get_file_by_id(server: SyncServer, default_user, default_source): +@pytest.mark.asyncio +async def test_get_file_by_id(server: SyncServer, default_user, default_source): """Test retrieving a file by ID.""" file_metadata = PydanticFileMetadata( file_name="Retrieve File", @@ -4026,10 +4043,10 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source): file_size=2048, source_id=default_source.id, ) - created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) + created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) # Retrieve the file by ID - retrieved_file = server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user) + retrieved_file = await server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user) # Assertions to verify the retrieved file matches the created one assert retrieved_file.id == created_file.id @@ -4038,49 +4055,53 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source): assert retrieved_file.file_type == created_file.file_type -def test_list_files(server: SyncServer, default_user, default_source): +@pytest.mark.asyncio +async def test_list_files(server: SyncServer, default_user, default_source): """Test listing files with pagination.""" # Create multiple files - server.source_manager.create_file( + await server.source_manager.create_file( PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id), actor=default_user, ) if USING_SQLITE: time.sleep(CREATE_DELAY_SQLITE) - server.source_manager.create_file( + await server.source_manager.create_file( PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id), actor=default_user, ) # List files without pagination - files = server.source_manager.list_files(source_id=default_source.id, actor=default_user) + files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user) assert len(files) == 2 # List files with pagination - paginated_files = server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1) + paginated_files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1) assert len(paginated_files) == 1 # Ensure cursor-based pagination works - next_page = server.source_manager.list_files(source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1) + next_page = await server.source_manager.list_files( + source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1 + ) assert len(next_page) == 1 assert next_page[0].file_name != paginated_files[0].file_name -def test_delete_file(server: SyncServer, default_user, default_source): +@pytest.mark.asyncio +async def test_delete_file(server: SyncServer, default_user, default_source): """Test deleting a file.""" file_metadata = PydanticFileMetadata( file_name="Delete File", file_path="/path/to/delete_file.txt", file_type="text/plain", source_id=default_source.id ) - created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) + created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user) # Delete the file - deleted_file = server.source_manager.delete_file(file_id=created_file.id, actor=default_user) + deleted_file = await server.source_manager.delete_file(file_id=created_file.id, actor=default_user) # Assertions to verify deletion assert deleted_file.id == created_file.id # Verify that the file no longer appears in list_files - files = server.source_manager.list_files(source_id=default_source.id, actor=default_user) + files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user) assert len(files) == 0 @@ -5126,7 +5147,7 @@ async def test_update_batch_status(server, default_user, dummy_beta_message_batc ) before = datetime.now(timezone.utc) - server.batch_manager.update_llm_batch_status( + await server.batch_manager.update_llm_batch_status_async( llm_batch_id=batch.id, status=JobStatus.completed, latest_polling_response=dummy_beta_message_batch, @@ -5151,7 +5172,7 @@ async def test_create_and_get_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5163,7 +5184,7 @@ async def test_create_and_get_batch_item( assert item.agent_id == sarah_agent.id assert item.step_state == dummy_step_state - fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert fetched.id == item.id @@ -5187,7 +5208,7 @@ async def test_update_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5197,7 +5218,7 @@ async def test_update_batch_item( updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver) - server.batch_manager.update_llm_batch_item( + await server.batch_manager.update_llm_batch_item_async( item_id=item.id, request_status=JobStatus.completed, step_status=AgentStepStatus.resumed, @@ -5206,7 +5227,7 @@ async def test_update_batch_item( actor=default_user, ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response @@ -5223,7 +5244,7 @@ async def test_delete_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5231,10 +5252,10 @@ async def test_delete_batch_item( actor=default_user, ) - server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user) + await server.batch_manager.delete_llm_batch_item_async(item_id=item.id, actor=default_user) with pytest.raises(NoResultFound): - server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) @pytest.mark.asyncio @@ -5262,7 +5283,7 @@ async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_messa letta_batch_job_id=letta_batch_job.id, ) - server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) + await server.batch_manager.bulk_update_llm_batch_statuses_async([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user) assert updated.status == JobStatus.completed @@ -5287,7 +5308,7 @@ async def test_bulk_update_batch_items_results_by_agent( actor=default_user, letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5295,11 +5316,11 @@ async def test_bulk_update_batch_items_results_by_agent( actor=default_user, ) - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)] ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response @@ -5314,7 +5335,7 @@ async def test_bulk_update_batch_items_step_status_by_agent( actor=default_user, letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5322,11 +5343,11 @@ async def test_bulk_update_batch_items_step_status_by_agent( actor=default_user, ) - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)] ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.step_status == AgentStepStatus.resumed @@ -5342,7 +5363,7 @@ async def test_list_batch_items_limit_and_filter( ) for _ in range(3): - server.batch_manager.create_llm_batch_item( + await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5372,7 +5393,7 @@ async def test_list_batch_items_pagination( # Create 10 batch items. created_items = [] for i in range(10): - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5435,7 +5456,7 @@ async def test_bulk_update_batch_items_request_status_by_agent( ) # Create a batch item - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5444,12 +5465,12 @@ async def test_bulk_update_batch_items_request_status_by_agent( ) # Update the request status using the bulk update method - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)] ) # Verify the update was applied - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.expired @@ -5478,20 +5499,20 @@ async def test_bulk_update_nonexistent_items_should_error( ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates) + await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] ) @@ -5515,21 +5536,21 @@ async def test_bulk_update_nonexistent_items( nonexistent_updates = [{"request_status": JobStatus.expired}] # This should not raise an error, just silently skip non-existent items - server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False) + await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates, strict=False) # Test with higher-level methods # Results by agent - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False ) # Step status by agent - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False ) # Request status by agent - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False ) @@ -5584,7 +5605,7 @@ async def test_create_batch_items_bulk( # Verify the IDs of created items match what's in the database created_ids = [item.id for item in created_items] for item_id in created_ids: - fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user) + fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item_id, actor=default_user) assert fetched.id in created_ids @@ -5604,7 +5625,7 @@ async def test_count_batch_items( # Create a specific number of batch items for this batch. num_items = 5 for _ in range(num_items): - server.batch_manager.create_llm_batch_item( + await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5613,7 +5634,7 @@ async def test_count_batch_items( ) # Use the count_llm_batch_items method to count the items. - count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id) + count = await server.batch_manager.count_llm_batch_items_async(llm_batch_id=batch.id) # Assert that the count matches the expected number. assert count == num_items, f"Expected {num_items} items, got {count}" diff --git a/tests/test_model_letta_performance.py b/tests/test_model_letta_performance.py deleted file mode 100644 index 41f2da648..000000000 --- a/tests/test_model_letta_performance.py +++ /dev/null @@ -1,439 +0,0 @@ -import os - -import pytest - -from tests.helpers.endpoints_helper import ( - check_agent_archival_memory_insert, - check_agent_archival_memory_retrieval, - check_agent_edit_core_memory, - check_agent_recall_chat_memory, - check_agent_uses_external_tool, - check_first_response_is_valid_for_llm_endpoint, - run_embedding_endpoint, -) -from tests.helpers.utils import retry_until_success, retry_until_threshold - -# directories -embedding_config_dir = "tests/configs/embedding_model_configs" -llm_config_dir = "tests/configs/llm_model_configs" - - -# ====================================================================================================================== -# OPENAI TESTS -# ====================================================================================================================== -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_archival_memory_insert(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_archival_memory_insert(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_openai_gpt_4o_edit_core_memory(): - filename = os.path.join(llm_config_dir, "openai-gpt-4o.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.openai_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_embedding_endpoint_openai(): - filename = os.path.join(embedding_config_dir, "openai_embed.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# AZURE TESTS -# ====================================================================================================================== -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_gpt_4o_mini_edit_core_memory(): - filename = os.path.join(llm_config_dir, "azure-gpt-4o-mini.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.azure_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_azure_embedding_endpoint(): - filename = os.path.join(embedding_config_dir, "azure_embed.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# LETTA HOSTED -# ====================================================================================================================== -def test_llm_endpoint_letta_hosted(): - filename = os.path.join(llm_config_dir, "letta-hosted.json") - check_first_response_is_valid_for_llm_endpoint(filename) - - -def test_embedding_endpoint_letta_hosted(): - filename = os.path.join(embedding_config_dir, "letta-hosted.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# LOCAL MODELS -# ====================================================================================================================== -def test_embedding_endpoint_local(): - filename = os.path.join(embedding_config_dir, "local.json") - run_embedding_endpoint(filename) - - -def test_llm_endpoint_ollama(): - filename = os.path.join(llm_config_dir, "ollama.json") - check_first_response_is_valid_for_llm_endpoint(filename) - - -def test_embedding_endpoint_ollama(): - filename = os.path.join(embedding_config_dir, "ollama.json") - run_embedding_endpoint(filename) - - -# ====================================================================================================================== -# ANTHROPIC TESTS -# ====================================================================================================================== -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_claude_haiku_3_5_edit_core_memory(): - filename = os.path.join(llm_config_dir, "claude-3-5-haiku.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GROQ TESTS -# ====================================================================================================================== -def test_groq_llama31_70b_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@retry_until_threshold(threshold=0.75, max_attempts=4) -def test_groq_llama31_70b_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_groq_llama31_70b_edit_core_memory(): - filename = os.path.join(llm_config_dir, "groq.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GEMINI TESTS -# ====================================================================================================================== -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.gemini_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_gemini_pro_15_edit_core_memory(): - filename = os.path.join(llm_config_dir, "gemini-pro.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# GOOGLE VERTEX TESTS -# ====================================================================================================================== -@pytest.mark.vertex_basic -@retry_until_success(max_attempts=1, sleep_time_seconds=2) -def test_vertex_gemini_pro_20_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "gemini-vertex.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# DEEPSEEK TESTS -# ====================================================================================================================== -@pytest.mark.deepseek_basic -def test_deepseek_reasoner_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "deepseek-reasoner.json") - # Don't validate that the inner monologue doesn't contain things like "function", since - # for the reasoners it might be quite meta (have analysis about functions etc.) - response = check_first_response_is_valid_for_llm_endpoint(filename, validate_inner_monologue_contents=False) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# xAI TESTS -# ====================================================================================================================== -@pytest.mark.xai_basic -def test_xai_grok2_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "xai-grok-2.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# TOGETHER TESTS -# ====================================================================================================================== -def test_together_llama_3_70b_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -def test_together_llama_3_70b_edit_core_memory(): - filename = os.path.join(llm_config_dir, "together-llama-3-70b.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -# ====================================================================================================================== -# ANTHROPIC BEDROCK TESTS -# ====================================================================================================================== -@pytest.mark.anthropic_bedrock_basic -def test_bedrock_claude_sonnet_3_5_valid_config(): - import json - - from letta.schemas.llm_config import LLMConfig - from letta.settings import model_settings - - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - config_data = json.load(open(filename, "r")) - llm_config = LLMConfig(**config_data) - model_region = llm_config.model.split(":")[3] - assert model_settings.aws_region == model_region, "Model region in config file does not match model region in ModelSettings" - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_returns_valid_first_message(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_first_response_is_valid_for_llm_endpoint(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_uses_external_tool(disable_e2b_api_key): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_uses_external_tool(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_recall_chat_memory(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_recall_chat_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_archival_memory_retrieval(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_archival_memory_retrieval(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") - - -@pytest.mark.anthropic_bedrock_basic -@retry_until_success(max_attempts=5, sleep_time_seconds=2) -def test_bedrock_claude_sonnet_3_5_edit_core_memory(): - filename = os.path.join(llm_config_dir, "bedrock-claude-3-5-sonnet.json") - response = check_agent_edit_core_memory(filename) - # Log out successful response - print(f"Got successful response from client: \n\n{response}") diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 91fd016cb..d2166d65e 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -680,3 +680,72 @@ def test_many_blocks(client: LettaSDKClient): client.agents.delete(agent1.id) client.agents.delete(agent2.id) + + +def test_sources(client: LettaSDKClient, agent: AgentState): + + # Clear existing sources + for source in client.sources.list(): + client.sources.delete(source_id=source.id) + + # Clear existing jobs + for job in client.jobs.list(): + client.jobs.delete(job_id=job.id) + + # Create a new source + source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + assert len(client.sources.list()) == 1 + + # delete the source + client.sources.delete(source_id=source.id) + assert len(client.sources.list()) == 0 + source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002") + + # Load files into the source + file_a_path = "tests/data/memgpt_paper.pdf" + file_b_path = "tests/data/test.txt" + + # Upload the files + with open(file_a_path, "rb") as f: + job_a = client.sources.files.upload(source_id=source.id, file=f) + + with open(file_b_path, "rb") as f: + job_b = client.sources.files.upload(source_id=source.id, file=f) + + # Wait for the jobs to complete + while job_a.status != "completed" or job_b.status != "completed": + time.sleep(1) + job_a = client.jobs.retrieve(job_id=job_a.id) + job_b = client.jobs.retrieve(job_id=job_b.id) + print("Waiting for jobs to complete...", job_a.status, job_b.status) + + # Get the first file with pagination + files_a = client.sources.files.list(source_id=source.id, limit=1) + assert len(files_a) == 1 + assert files_a[0].source_id == source.id + + # Use the cursor from files_a to get the remaining file + files_b = client.sources.files.list(source_id=source.id, limit=1, after=files_a[-1].id) + assert len(files_b) == 1 + assert files_b[0].source_id == source.id + + # Check files are different to ensure the cursor works + assert files_a[0].file_name != files_b[0].file_name + + # Use the cursor from files_b to list files, should be empty + files = client.sources.files.list(source_id=source.id, limit=1, after=files_b[-1].id) + assert len(files) == 0 # Should be empty + + # list passages + passages = client.sources.passages.list(source_id=source.id) + assert len(passages) > 0 + + # attach to an agent + assert len(client.agents.passages.list(agent_id=agent.id)) == 0 + client.agents.sources.attach(source_id=source.id, agent_id=agent.id) + assert len(client.agents.passages.list(agent_id=agent.id)) > 0 + assert len(client.agents.sources.list(agent_id=agent.id)) == 1 + + # detach from agent + client.agents.sources.detach(source_id=source.id, agent_id=agent.id) + assert len(client.agents.passages.list(agent_id=agent.id)) == 0 diff --git a/tests/test_server.py b/tests/test_server.py index 200ff54eb..cc1c6b65b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,3 +1,4 @@ +import asyncio import json import os import shutil @@ -24,15 +25,10 @@ from letta.server.db import db_registry utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent, UpdateAgent -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.job import Job as PydanticJob from letta.schemas.message import Message -from letta.schemas.source import Source as PydanticSource from letta.server.server import SyncServer from letta.system import unpack_message -from .utils import DummyDataConnector - WAR_AND_PEACE = """BOOK ONE: 1805 CHAPTER I @@ -270,8 +266,6 @@ start my apprenticeship as old maid.""" @pytest.fixture(scope="module") def server(): config = LettaConfig.load() - print("CONFIG PATH", config.config_path) - config.save() server = SyncServer() @@ -366,6 +360,14 @@ def other_agent_id(server, user_id, base_tools): server.agent_manager.delete_agent(agent_state.id, actor=actor) +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + def test_error_on_nonexistent_agent(server, user, agent_id): try: fake_agent_id = str(uuid.uuid4()) @@ -392,40 +394,6 @@ def test_user_message_memory(server, user, agent_id): server.run_command(user_id=user.id, agent_id=agent_id, command="/memory") -@pytest.mark.order(3) -def test_load_data(server, user, agent_id): - # create source - passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000) - assert len(passages_before) == 0 - - source = server.source_manager.create_source( - PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user - ) - - # load data - archival_memories = [ - "alpha", - "Cinderella wore a blue dress", - "Dog eat dog", - "ZZZ", - "Shishir loves indian food", - ] - connector = DummyDataConnector(archival_memories) - server.load_data(user.id, connector, source.name) - - # attach source - server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user) - - # check archival memory size - passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000) - assert len(passages_after) == 5 - - -def test_save_archival_memory(server, user_id, agent_id): - # TODO: insert into archival memory - pass - - @pytest.mark.order(4) def test_user_message(server, user, agent_id): # add data into recall memory @@ -458,59 +426,60 @@ def test_get_recall_memory(server, org_id, user, agent_id): assert message_id in message_ids, f"{message_id} not in {message_ids}" -@pytest.mark.order(6) -def test_get_archival_memory(server, user, agent_id): - # test archival memory cursor pagination - actor = user - - # List latest 2 passages - passages_1 = server.agent_manager.list_passages( - actor=actor, - agent_id=agent_id, - ascending=False, - limit=2, - ) - assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" - - # List next 3 passages (earliest 3) - cursor1 = passages_1[-1].id - passages_2 = server.agent_manager.list_passages( - actor=actor, - agent_id=agent_id, - ascending=False, - before=cursor1, - ) - - # List all 5 - cursor2 = passages_1[0].created_at - passages_3 = server.agent_manager.list_passages( - actor=actor, - agent_id=agent_id, - ascending=False, - end_date=cursor2, - limit=1000, - ) - assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test - assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test - - latest = passages_1[0] - earliest = passages_2[-1] - - # test archival memory - passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True) - assert len(passage_1) == 1 - assert passage_1[0].text == "alpha" - passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True) - assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test - assert all("alpha" not in passage.text for passage in passage_2) - # test safe empty return - passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True) - assert len(passage_none) == 0 +# @pytest.mark.order(6) +# def test_get_archival_memory(server, user, agent_id): +# # test archival memory cursor pagination +# actor = user +# +# # List latest 2 passages +# passages_1 = server.agent_manager.list_passages( +# actor=actor, +# agent_id=agent_id, +# ascending=False, +# limit=2, +# ) +# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" +# +# # List next 3 passages (earliest 3) +# cursor1 = passages_1[-1].id +# passages_2 = server.agent_manager.list_passages( +# actor=actor, +# agent_id=agent_id, +# ascending=False, +# before=cursor1, +# ) +# +# # List all 5 +# cursor2 = passages_1[0].created_at +# passages_3 = server.agent_manager.list_passages( +# actor=actor, +# agent_id=agent_id, +# ascending=False, +# end_date=cursor2, +# limit=1000, +# ) +# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test +# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test +# +# latest = passages_1[0] +# earliest = passages_2[-1] +# +# # test archival memory +# passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True) +# assert len(passage_1) == 1 +# assert passage_1[0].text == "alpha" +# passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True) +# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test +# assert all("alpha" not in passage.text for passage in passage_2) +# # test safe empty return +# passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True) +# assert len(passage_none) == 0 -def test_get_context_window_overview(server: SyncServer, user, agent_id): +@pytest.mark.asyncio +async def test_get_context_window_overview(server: SyncServer, user, agent_id): """Test that the context window overview fetch works""" - overview = server.get_agent_context_window(agent_id=agent_id, actor=user) + overview = await server.agent_manager.get_context_window(agent_id=agent_id, actor=user) assert overview is not None # Run some basic checks @@ -567,7 +536,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User): @pytest.mark.asyncio -async def test_read_local_llm_configs(server: SyncServer, user: User): +async def test_read_local_llm_configs(server: SyncServer, user: User, event_loop): configs_base_dir = os.path.join(os.path.expanduser("~"), ".letta", "llm_configs") clean_up_dir = False if not os.path.exists(configs_base_dir): @@ -604,7 +573,7 @@ async def test_read_local_llm_configs(server: SyncServer, user: User): # Try to use in agent creation context_window_override = 4000 - agent = server.create_agent( + agent = await server.create_agent_async( request=CreateAgent( model="caren/my-custom-model", context_window_limit=context_window_override, @@ -987,131 +956,6 @@ async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tool server.agent_manager.delete_agent(agent_state.id, actor=actor) -def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path): - actor = server.user_manager.get_user_or_default(user_id) - - existing_sources = server.source_manager.list_sources(actor=actor) - if len(existing_sources) > 0: - for source in existing_sources: - server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor) - initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) - assert initial_passage_count == 0 - - # Create a source - source = server.source_manager.create_source( - PydanticSource( - name="timber_source", - embedding_config=EmbeddingConfig.default_config(provider="openai"), - created_by_id=user_id, - ), - actor=actor, - ) - assert source.created_by_id == user_id - - # Create a test file with some content - test_file = tmp_path / "test.txt" - test_content = "We have a dog called Timber. He likes to sleep and eat chicken." - test_file.write_text(test_content) - - # Attach source to agent first - server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor) - - # Create a job for loading the first file - job = server.job_manager.create_job( - PydanticJob( - user_id=user_id, - metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id}, - ), - actor=actor, - ) - - # Load the first file to source - server.load_file_to_source( - source_id=source.id, - file_path=str(test_file), - job_id=job.id, - actor=actor, - ) - - # Verify job completed successfully - job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor) - assert job.status == "completed" - assert job.metadata["num_passages"] == 1 - assert job.metadata["num_documents"] == 1 - - # Verify passages were added - first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) - assert first_file_passage_count > initial_passage_count - - # Create a second test file with different content - test_file2 = tmp_path / "test2.txt" - test_file2.write_text(WAR_AND_PEACE) - - # Create a job for loading the second file - job2 = server.job_manager.create_job( - PydanticJob( - user_id=user_id, - metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id}, - ), - actor=actor, - ) - - # Load the second file to source - server.load_file_to_source( - source_id=source.id, - file_path=str(test_file2), - job_id=job2.id, - actor=actor, - ) - - # Verify second job completed successfully - job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor) - assert job2.status == "completed" - assert job2.metadata["num_passages"] >= 10 - assert job2.metadata["num_documents"] == 1 - - # Verify passages were appended (not replaced) - final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor) - assert final_passage_count > first_file_passage_count - - # Verify both old and new content is searchable - passages = server.agent_manager.list_passages( - agent_id=agent_id, - actor=actor, - query_text="what does Timber like to eat", - embedding_config=EmbeddingConfig.default_config(provider="openai"), - embed_query=True, - ) - assert len(passages) == final_passage_count - assert any("chicken" in passage.text.lower() for passage in passages) - assert any("Anna".lower() in passage.text.lower() for passage in passages) - - # Initially should have no passages - initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id) - assert initial_agent2_passages == 0 - - # Attach source to second agent - server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor) - - # Verify second agent has same number of passages as first agent - agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id) - agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id) - assert agent2_passages == agent1_passages - - # Verify second agent can query the same content - passages2 = server.agent_manager.list_passages( - actor=actor, - agent_id=other_agent_id, - source_id=source.id, - query_text="what does Timber like to eat", - embedding_config=EmbeddingConfig.default_config(provider="openai"), - embed_query=True, - ) - assert len(passages2) == len(passages) - assert any("chicken" in passage.text.lower() for passage in passages2) - assert any("Anna".lower() in passage.text.lower() for passage in passages2) - - def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools): actor = server.user_manager.get_user_or_default(user_id) @@ -1226,8 +1070,8 @@ def test_add_remove_tools_update_agent(server: SyncServer, user_id: str, base_to @pytest.mark.asyncio -async def test_messages_with_provider_override(server: SyncServer, user_id: str): - actor = server.user_manager.get_user_or_default(user_id) +async def test_messages_with_provider_override(server: SyncServer, user_id: str, event_loop): + actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id) provider = server.provider_manager.create_provider( request=ProviderCreate( name="caren-anthropic", @@ -1242,7 +1086,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str) models = await server.list_llm_models_async(actor=actor, provider_category=[ProviderCategory.base]) assert provider.name not in [model.provider_name for model in models] - agent = server.create_agent( + agent = await server.create_agent_async( request=CreateAgent( memory_blocks=[], model="caren-anthropic/claude-3-5-sonnet-20240620", @@ -1306,7 +1150,7 @@ async def test_messages_with_provider_override(server: SyncServer, user_id: str) @pytest.mark.asyncio -async def test_unique_handles_for_provider_configs(server: SyncServer, user: User): +async def test_unique_handles_for_provider_configs(server: SyncServer, user: User, event_loop): models = await server.list_llm_models_async(actor=user) model_handles = [model.handle for model in models] assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles" diff --git a/tests/test_streaming.py b/tests/test_streaming.py deleted file mode 100644 index d9a7a7f14..000000000 --- a/tests/test_streaming.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -import threading -import time - -import pytest -from dotenv import load_dotenv -from letta_client import AgentState, Letta, LlmConfig, MessageCreate - -from letta.schemas.message import Message - - -def run_server(): - load_dotenv() - - from letta.server.rest_api.app import start_server - - print("Starting server...") - start_server(debug=True) - - -@pytest.fixture( - scope="module", -) -def client(request): - # Get URL from environment or start server - api_url = os.getenv("LETTA_API_URL") - server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:8283") - if not os.getenv("LETTA_SERVER_URL"): - print("Starting server thread") - thread = threading.Thread(target=run_server, daemon=True) - thread.start() - time.sleep(5) - print("Running client tests with server:", server_url) - - # Overide the base_url if the LETTA_API_URL is set - base_url = api_url if api_url else server_url - # create the Letta client - yield Letta(base_url=base_url, token=None) - - -# Fixture for test agent -@pytest.fixture(scope="module") -def agent(client: Letta): - agent_state = client.agents.create( - name="test_client", - memory_blocks=[{"label": "human", "value": ""}, {"label": "persona", "value": ""}], - model="letta/letta-free", - embedding="letta/letta-free", - ) - - yield agent_state - - # delete agent - client.agents.delete(agent_state.id) - - -@pytest.mark.parametrize( - "stream_tokens,model", - [ - (True, "openai/gpt-4o-mini"), - (True, "anthropic/claude-3-sonnet-20240229"), - (False, "openai/gpt-4o-mini"), - (False, "anthropic/claude-3-sonnet-20240229"), - ], -) -def test_streaming_send_message( - disable_e2b_api_key, - client: Letta, - agent: AgentState, - stream_tokens: bool, - model: str, -): - # Update agent's model - config = client.agents.retrieve(agent_id=agent.id).llm_config - config_dump = config.model_dump() - config_dump["model"] = model - config = LlmConfig(**config_dump) - client.agents.modify(agent_id=agent.id, llm_config=config) - - # Send streaming message - user_message_otid = Message.generate_otid() - response = client.agents.messages.create_stream( - agent_id=agent.id, - messages=[ - MessageCreate( - role="user", - content="This is a test. Repeat after me: 'banana'", - otid=user_message_otid, - ), - ], - stream_tokens=stream_tokens, - ) - - # Tracking variables for test validation - inner_thoughts_exist = False - inner_thoughts_count = 0 - send_message_ran = False - done = False - last_message_id = client.agents.messages.list(agent_id=agent.id, limit=1)[0].id - letta_message_otids = [user_message_otid] - - assert response, "Sending message failed" - for chunk in response: - # Check chunk type and content based on the current client API - if hasattr(chunk, "message_type") and chunk.message_type == "reasoning_message": - inner_thoughts_exist = True - inner_thoughts_count += 1 - - if chunk.message_type == "tool_call_message" and hasattr(chunk, "tool_call") and chunk.tool_call.name == "send_message": - send_message_ran = True - if chunk.message_type == "assistant_message": - send_message_ran = True - - if chunk.message_type == "usage_statistics": - # Validate usage statistics - assert chunk.step_count == 1 - assert chunk.completion_tokens > 10 - assert chunk.prompt_tokens > 1000 - assert chunk.total_tokens > 1000 - done = True - else: - letta_message_otids.append(chunk.otid) - print(chunk) - - # If stream tokens, we expect at least one inner thought - assert inner_thoughts_count >= 1, "Expected more than one inner thought" - assert inner_thoughts_exist, "No inner thoughts found" - assert send_message_ran, "send_message function call not found" - assert done, "Message stream not done" - - messages = client.agents.messages.list(agent_id=agent.id, after=last_message_id) - assert [message.otid for message in messages] == letta_message_otids diff --git a/tests/test_system_prompt_compiler.py b/tests/test_system_prompt_compiler.py deleted file mode 100644 index d74236035..000000000 --- a/tests/test_system_prompt_compiler.py +++ /dev/null @@ -1,59 +0,0 @@ -from letta.services.helpers.agent_manager_helper import safe_format - -CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" -VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} - - -def test_formatter(): - - # Example system prompt that has no vars - NO_VARS = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - """ - - assert NO_VARS == safe_format(NO_VARS, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} - CORE_MEMORY_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {CORE_MEMORY} - """ - - CORE_MEMORY_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - My core memory is that I like to eat bananas - """ - - assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist) - UNUSED_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {USER_MEMORY} - {CORE_MEMORY} - """ - - UNUSED_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {USER_MEMORY} - My core memory is that I like to eat bananas - """ - - assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT) - - # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {} - UNUSED_AND_EMPRY_VAR = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {} - {USER_MEMORY} - {CORE_MEMORY} - """ - - UNUSED_AND_EMPRY_VAR_SOL = """ - THIS IS A SYSTEM PROMPT WITH NO VARS - {} - {USER_MEMORY} - My core memory is that I like to eat bananas - """ - - assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT) diff --git a/tests/test_utils.py b/tests/test_utils.py index 904e903e7..214dfcbb3 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,282 @@ import pytest from letta.constants import MAX_FILENAME_LENGTH +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.services.helpers.agent_manager_helper import safe_format from letta.utils import sanitize_filename +CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" +VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} + +# ----------------------------------------------------------------------- +# Example source code for testing multiple scenarios, including: +# 1) A class-based custom type (which we won't handle properly). +# 2) Functions with multiple argument types. +# 3) A function with default arguments. +# 4) A function with no arguments. +# 5) A function that shares the same name as another symbol. +# ----------------------------------------------------------------------- +example_source_code = r""" +class CustomClass: + def __init__(self, x): + self.x = x + +def unrelated_symbol(): + pass + +def no_args_func(): + pass + +def default_args_func(x: int = 5, y: str = "hello"): + return x, y + +def my_function(a: int, b: float, c: str, d: list, e: dict, f: CustomClass = None): + pass + +def my_function_duplicate(): + # This function shares the name "my_function" partially, but isn't an exact match + pass +""" + + +def test_get_function_annotations_found(): + """ + Test that we correctly parse annotations for a function + that includes multiple argument types and a custom class. + """ + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert annotations == { + "a": "int", + "b": "float", + "c": "str", + "d": "list", + "e": "dict", + "f": "CustomClass", + } + + +def test_get_function_annotations_not_found(): + """ + If the requested function name doesn't exist exactly, + we should raise a ValueError. + """ + with pytest.raises(ValueError, match="Function 'missing_function' not found"): + get_function_annotations_from_source(example_source_code, "missing_function") + + +def test_get_function_annotations_no_args(): + """ + Check that a function without arguments returns an empty annotations dict. + """ + annotations = get_function_annotations_from_source(example_source_code, "no_args_func") + assert annotations == {} + + +def test_get_function_annotations_with_default_values(): + """ + Ensure that a function with default arguments still captures the annotations. + """ + annotations = get_function_annotations_from_source(example_source_code, "default_args_func") + assert annotations == {"x": "int", "y": "str"} + + +def test_get_function_annotations_partial_name_collision(): + """ + Ensure we only match the exact function name, not partial collisions. + """ + # This will match 'my_function' exactly, ignoring 'my_function_duplicate' + annotations = get_function_annotations_from_source(example_source_code, "my_function") + assert "a" in annotations # Means it matched the correct function + # No error expected here, just making sure we didn't accidentally parse "my_function_duplicate". + + +# --------------------- coerce_dict_args_by_annotations TESTS --------------------- # + + +def test_coerce_dict_args_success(): + """ + Basic success scenario with standard types: + int, float, str, list, dict. + """ + annotations = {"a": "int", "b": "float", "c": "str", "d": "list", "e": "dict"} + function_args = {"a": "42", "b": "3.14", "c": 123, "d": "[1, 2, 3]", "e": '{"key": "value"}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == 3.14 + assert coerced_args["c"] == "123" + assert coerced_args["d"] == [1, 2, 3] + assert coerced_args["e"] == {"key": "value"} + + +def test_coerce_dict_args_invalid_type(): + """ + If the value cannot be coerced into the annotation, + a ValueError should be raised. + """ + annotations = {"a": "int"} + function_args = {"a": "invalid_int"} + + with pytest.raises(ValueError, match="Failed to coerce argument 'a' to int"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_no_annotations(): + """ + If there are no annotations, we do no coercion. + """ + annotations = {} + function_args = {"a": 42, "b": "hello"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args == function_args # Exactly the same dict back + + +def test_coerce_dict_args_partial_annotations(): + """ + Only coerce annotated arguments; leave unannotated ones unchanged. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "b": "no_annotation"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == "no_annotation" + + +def test_coerce_dict_args_with_missing_args(): + """ + If function_args lacks some keys listed in annotations, + those are simply not coerced. (We do not add them.) + """ + annotations = {"a": "int", "b": "float"} + function_args = {"a": "42"} # Missing 'b' + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert "b" not in coerced_args + + +def test_coerce_dict_args_unexpected_keys(): + """ + If function_args has extra keys not in annotations, + we leave them alone. + """ + annotations = {"a": "int"} + function_args = {"a": "42", "z": 999} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["z"] == 999 # unchanged + + +def test_coerce_dict_args_unsupported_custom_class(): + """ + If someone tries to pass an annotation that isn't supported (like a custom class), + we should raise a ValueError (or similarly handle the error) rather than silently + accept it. + """ + annotations = {"f": "CustomClass"} # We can't resolve this + function_args = {"f": {"x": 1}} + with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass: Unsupported annotation: CustomClass"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_complex_types(): + """ + Confirm the ability to parse built-in complex data (lists, dicts, etc.) + when given as strings. + """ + annotations = {"big_list": "list", "nested_dict": "dict"} + function_args = {"big_list": "[1, 2, [3, 4], {'five': 5}]", "nested_dict": '{"alpha": [10, 20], "beta": {"x": 1, "y": 2}}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["big_list"] == [1, 2, [3, 4], {"five": 5}] + assert coerced_args["nested_dict"] == { + "alpha": [10, 20], + "beta": {"x": 1, "y": 2}, + } + + +def test_coerce_dict_args_non_string_keys(): + """ + Validate behavior if `function_args` includes non-string keys. + (We should simply skip annotation checks for them.) + """ + annotations = {"a": "int"} + function_args = {123: "42", "a": "42"} + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + # 'a' is coerced to int + assert coerced_args["a"] == 42 + # 123 remains untouched + assert coerced_args[123] == "42" + + +def test_coerce_dict_args_non_parseable_list_or_dict(): + """ + Test passing incorrectly formatted JSON for a 'list' or 'dict' annotation. + """ + annotations = {"bad_list": "list", "bad_dict": "dict"} + function_args = {"bad_list": "[1, 2, 3", "bad_dict": '{"key": "value"'} # missing brackets + + with pytest.raises(ValueError, match="Failed to coerce argument 'bad_list' to list"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_complex_list_annotation(): + """ + Test coercion when list with type annotation (e.g., list[int]) is used. + """ + annotations = {"a": "list[int]"} + function_args = {"a": "[1, 2, 3]"} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == [1, 2, 3] + + +def test_coerce_dict_args_with_complex_dict_annotation(): + """ + Test coercion when dict with type annotation (e.g., dict[str, int]) is used. + """ + annotations = {"a": "dict[str, int]"} + function_args = {"a": '{"x": 1, "y": 2}'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == {"x": 1, "y": 2} + + +def test_coerce_dict_args_unsupported_complex_annotation(): + """ + If an unsupported complex annotation is used (e.g., a custom class), + a ValueError should be raised. + """ + annotations = {"f": "CustomClass[int]"} + function_args = {"f": "CustomClass(42)"} + + with pytest.raises(ValueError, match="Failed to coerce argument 'f' to CustomClass\[int\]: Unsupported annotation: CustomClass\[int\]"): + coerce_dict_args_by_annotations(function_args, annotations) + + +def test_coerce_dict_args_with_nested_complex_annotation(): + """ + Test coercion with complex nested types like list[dict[str, int]]. + """ + annotations = {"a": "list[dict[str, int]]"} + function_args = {"a": '[{"x": 1}, {"y": 2}]'} + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == [{"x": 1}, {"y": 2}] + + +def test_coerce_dict_args_with_default_arguments(): + """ + Test coercion with default arguments, where some arguments have defaults in the source code. + """ + annotations = {"a": "int", "b": "str"} + function_args = {"a": "42"} + + function_args.setdefault("b", "hello") # Setting the default value for 'b' + + coerced_args = coerce_dict_args_by_annotations(function_args, annotations) + assert coerced_args["a"] == 42 + assert coerced_args["b"] == "hello" + def test_valid_filename(): filename = "valid_filename.txt" @@ -64,3 +338,58 @@ def test_unique_filenames(): assert sanitized2.startswith("duplicate_") assert sanitized1.endswith(".txt") assert sanitized2.endswith(".txt") + + +def test_formatter(): + + # Example system prompt that has no vars + NO_VARS = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + """ + + assert NO_VARS == safe_format(NO_VARS, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} + CORE_MEMORY_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {CORE_MEMORY} + """ + + CORE_MEMORY_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + My core memory is that I like to eat bananas + """ + + assert CORE_MEMORY_VAR_SOL == safe_format(CORE_MEMORY_VAR, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist) + UNUSED_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {USER_MEMORY} + {CORE_MEMORY} + """ + + UNUSED_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {USER_MEMORY} + My core memory is that I like to eat bananas + """ + + assert UNUSED_VAR_SOL == safe_format(UNUSED_VAR, VARS_DICT) + + # Example system prompt that has {CORE_MEMORY} and {USER_MEMORY} (latter doesn't exist), AND an empty {} + UNUSED_AND_EMPRY_VAR = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {} + {USER_MEMORY} + {CORE_MEMORY} + """ + + UNUSED_AND_EMPRY_VAR_SOL = """ + THIS IS A SYSTEM PROMPT WITH NO VARS + {} + {USER_MEMORY} + My core memory is that I like to eat bananas + """ + + assert UNUSED_AND_EMPRY_VAR_SOL == safe_format(UNUSED_AND_EMPRY_VAR, VARS_DICT)