feat: Rewrite agents (#2232)

This commit is contained in:
Matthew Zhou 2024-12-13 14:43:19 -08:00 committed by GitHub
parent d42c1e5e72
commit e49a8b4365
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
86 changed files with 2495 additions and 3980 deletions

View File

@ -18,7 +18,7 @@ on:
branches: [ main ]
jobs:
run-integration-tests:
integ-run:
runs-on: ubuntu-latest
timeout-minutes: 15
strategy:
@ -27,6 +27,9 @@ jobs:
integration_test_suite:
- "integration_test_summarizer.py"
- "integration_test_tool_execution_sandbox.py"
- "integration_test_offline_memory_agent.py"
- "integration_test_agent_tool_graph.py"
- "integration_test_o1_agent.py"
services:
qdrant:
image: qdrant/qdrant

View File

@ -13,25 +13,27 @@ on:
pull_request:
jobs:
run-core-unit-tests:
unit-run:
runs-on: ubuntu-latest
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
test_suite:
- "test_vector_embeddings.py"
- "test_client.py"
- "test_local_client.py"
- "test_client_legacy.py"
- "test_server.py"
- "test_managers.py"
- "test_o1_agent.py"
- "test_tool_rule_solver.py"
- "test_agent_tool_graph.py"
- "test_utils.py"
- "test_tool_schema_parsing.py"
- "test_v1_routes.py"
- "test_offline_memory_agent.py"
- "test_local_client.py"
- "test_managers.py"
- "test_base_functions.py"
- "test_tool_schema_parsing.py"
- "test_tool_rule_solver.py"
- "test_memory.py"
- "test_utils.py"
- "test_stream_buffer_readers.py"
- "test_summarize.py"
services:
qdrant:
image: qdrant/qdrant
@ -81,57 +83,3 @@ jobs:
LETTA_SERVER_PASS: test_server_token
run: |
poetry run pytest -s -vv tests/${{ matrix.test_suite }}
misc-unit-tests:
runs-on: ubuntu-latest
needs: run-core-unit-tests
services:
qdrant:
image: qdrant/qdrant
ports:
- 6333:6333
postgres:
image: pgvector/pgvector:pg17
ports:
- 5432:5432
env:
POSTGRES_HOST_AUTH_METHOD: trust
POSTGRES_DB: postgres
POSTGRES_USER: postgres
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Python, Poetry, and Dependencies
uses: packetcoders/action-setup-cache-python-poetry@main
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database
env:
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
run: |
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
poetry run alembic upgrade head
- name: Run misc unit tests
env:
LETTA_PG_PORT: 5432
LETTA_PG_USER: postgres
LETTA_PG_PASSWORD: postgres
LETTA_PG_DB: postgres
LETTA_PG_HOST: localhost
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_offline_memory_agent.py and not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests

View File

@ -0,0 +1,175 @@
"""Migrate agents to orm
Revision ID: d05669b60ebe
Revises: c5d964280dff
Create Date: 2024-12-12 10:25:31.825635
"""
from typing import Sequence, Union
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "d05669b60ebe"
down_revision: Union[str, None] = "c5d964280dff"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Forward migration: move agents onto the ORM schema.

    - Replaces ``agent_source_mapping`` with a plain ``sources_agents``
      association table keyed by (agent_id, source_id).
    - Re-parents agents from ``user_id`` to ``organization_id`` (backfilled
      from the owning user) and adds the ORM audit columns.
    - Strips audit/identity columns from the ``agents_tags``,
      ``blocks_agents`` and ``tools_agents`` join tables so they become
      pure association tables.

    Fix over the autogenerated version: the foreign keys below are created
    with explicit names (autogen emitted ``None``), so ``downgrade()`` can
    drop them by name. The names follow Postgres' default
    ``<table>_<column>_fkey`` convention.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "sources_agents",
        sa.Column("agent_id", sa.String(), nullable=False),
        sa.Column("source_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["agent_id"],
            ["agents.id"],
        ),
        sa.ForeignKeyConstraint(
            ["source_id"],
            ["sources.id"],
        ),
        sa.PrimaryKeyConstraint("agent_id", "source_id"),
    )
    op.drop_index("agent_source_mapping_idx_user", table_name="agent_source_mapping")
    op.drop_table("agent_source_mapping")
    op.add_column("agents", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
    op.add_column("agents", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
    op.add_column("agents", sa.Column("_created_by_id", sa.String(), nullable=True))
    op.add_column("agents", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
    op.add_column("agents", sa.Column("organization_id", sa.String(), nullable=True))
    # Backfill `organization_id` from the owning user BEFORE tightening the
    # column to NOT NULL below; raw SQL because the ORM models are not
    # available inside a migration.
    op.execute(
        """
        UPDATE agents
        SET organization_id = users.organization_id
        FROM users
        WHERE agents.user_id = users.id
        """
    )
    op.alter_column("agents", "organization_id", nullable=False)
    op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=True)
    op.drop_index("agents_idx_user", table_name="agents")
    op.create_unique_constraint("unique_org_agent_name", "agents", ["organization_id", "name"])
    # Explicit FK name (was None) so downgrade() can drop it by name.
    op.create_foreign_key("agents_organization_id_fkey", "agents", "organizations", ["organization_id"], ["id"])
    op.drop_column("agents", "tool_names")
    op.drop_column("agents", "user_id")
    op.drop_constraint("agents_tags_organization_id_fkey", "agents_tags", type_="foreignkey")
    op.drop_column("agents_tags", "_created_by_id")
    op.drop_column("agents_tags", "_last_updated_by_id")
    op.drop_column("agents_tags", "updated_at")
    op.drop_column("agents_tags", "id")
    op.drop_column("agents_tags", "is_deleted")
    op.drop_column("agents_tags", "created_at")
    op.drop_column("agents_tags", "organization_id")
    op.create_unique_constraint("unique_agent_block", "blocks_agents", ["agent_id", "block_id"])
    op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey")
    # Recreate the composite FK as DEFERRABLE INITIALLY DEFERRED so a block
    # label rename inside one transaction doesn't trip the constraint
    # mid-flight.
    op.create_foreign_key(
        "fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"], initially="DEFERRED", deferrable=True
    )
    op.drop_column("blocks_agents", "_created_by_id")
    op.drop_column("blocks_agents", "_last_updated_by_id")
    op.drop_column("blocks_agents", "updated_at")
    op.drop_column("blocks_agents", "id")
    op.drop_column("blocks_agents", "is_deleted")
    op.drop_column("blocks_agents", "created_at")
    op.drop_constraint("unique_tool_per_agent", "tools_agents", type_="unique")
    op.create_unique_constraint("unique_agent_tool", "tools_agents", ["agent_id", "tool_id"])
    op.drop_constraint("fk_tool_id", "tools_agents", type_="foreignkey")
    op.drop_constraint("tools_agents_agent_id_fkey", "tools_agents", type_="foreignkey")
    # Explicit FK names (were None) so downgrade() can drop them by name;
    # ON DELETE CASCADE cleans up join rows when a tool or agent is removed.
    op.create_foreign_key("tools_agents_tool_id_fkey", "tools_agents", "tools", ["tool_id"], ["id"], ondelete="CASCADE")
    op.create_foreign_key("tools_agents_agent_id_fkey", "tools_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
    op.drop_column("tools_agents", "_created_by_id")
    op.drop_column("tools_agents", "tool_name")
    op.drop_column("tools_agents", "_last_updated_by_id")
    op.drop_column("tools_agents", "updated_at")
    op.drop_column("tools_agents", "id")
    op.drop_column("tools_agents", "is_deleted")
    op.drop_column("tools_agents", "created_at")
    # ### end Alembic commands ###
def downgrade() -> None:
    """Reverse migration: restore the pre-ORM agent schema.

    Fix over the autogenerated version: ``op.drop_constraint(None, ...)``
    raises at runtime (Alembic needs a constraint name to drop), so the three
    anonymous drops now reference the foreign keys by the Postgres-default
    ``<table>_<column>_fkey`` names that the upgrade's creates produce.

    NOTE(review): several columns are re-added with ``nullable=False`` and no
    server default (``tools_agents.id``, ``tools_agents.tool_name``,
    ``agents_tags.id``, ``agents.user_id``, ...). Against tables that already
    contain rows this will fail without a data backfill — confirm before
    running this downgrade on populated data.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "tools_agents",
        sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column(
        "tools_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
    )
    op.add_column("tools_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column(
        "tools_agents",
        sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column("tools_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.add_column("tools_agents", sa.Column("tool_name", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column("tools_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    # Drop the CASCADE foreign keys created by upgrade() by their
    # Postgres-default names (autogen emitted None here, which cannot drop).
    op.drop_constraint("tools_agents_tool_id_fkey", "tools_agents", type_="foreignkey")
    op.drop_constraint("tools_agents_agent_id_fkey", "tools_agents", type_="foreignkey")
    # Recreate the original (non-cascading) foreign keys.
    op.create_foreign_key("tools_agents_agent_id_fkey", "tools_agents", "agents", ["agent_id"], ["id"])
    op.create_foreign_key("fk_tool_id", "tools_agents", "tools", ["tool_id"], ["id"])
    op.drop_constraint("unique_agent_tool", "tools_agents", type_="unique")
    op.create_unique_constraint("unique_tool_per_agent", "tools_agents", ["agent_id", "tool_name"])
    op.add_column(
        "blocks_agents",
        sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column(
        "blocks_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
    )
    op.add_column("blocks_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column(
        "blocks_agents",
        sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column("blocks_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.add_column("blocks_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    # Restore the composite FK without the DEFERRABLE modifiers added by
    # upgrade().
    op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey")
    op.create_foreign_key("fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"])
    op.drop_constraint("unique_agent_block", "blocks_agents", type_="unique")
    op.add_column("agents_tags", sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column(
        "agents_tags",
        sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column(
        "agents_tags", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
    )
    op.add_column("agents_tags", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column(
        "agents_tags",
        sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
    )
    op.add_column("agents_tags", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.add_column("agents_tags", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
    op.create_foreign_key("agents_tags_organization_id_fkey", "agents_tags", "organizations", ["organization_id"], ["id"])
    op.add_column("agents", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
    op.add_column("agents", sa.Column("tool_names", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True))
    # Drop the agents -> organizations FK created by upgrade() by its
    # Postgres-default name (autogen emitted None here, which cannot drop).
    op.drop_constraint("agents_organization_id_fkey", "agents", type_="foreignkey")
    op.drop_constraint("unique_org_agent_name", "agents", type_="unique")
    op.create_index("agents_idx_user", "agents", ["user_id"], unique=False)
    op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=False)
    op.drop_column("agents", "organization_id")
    op.drop_column("agents", "_last_updated_by_id")
    op.drop_column("agents", "_created_by_id")
    op.drop_column("agents", "is_deleted")
    op.drop_column("agents", "updated_at")
    op.create_table(
        "agent_source_mapping",
        sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.PrimaryKeyConstraint("id", name="agent_source_mapping_pkey"),
    )
    op.create_index("agent_source_mapping_idx_user", "agent_source_mapping", ["user_id", "agent_id", "source_id"], unique=False)
    op.drop_table("sources_agents")
    # ### end Alembic commands ###

View File

@ -30,7 +30,7 @@ def upgrade() -> None:
op.execute("DELETE FROM tools")
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
op.add_column("agents", sa.Column("tool_rules", letta.orm.agent.ToolRulesColumn(), nullable=True))
op.alter_column("block", "name", new_column_name="template_name", nullable=True)
op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))

View File

@ -74,7 +74,7 @@ def main():
"""
# Create an agent
agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool.name])
agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id])
print(f"Created agent: {agent.name} with ID {str(agent.id)}")
# Send a message to the agent

View File

@ -29,7 +29,7 @@ agent_state = client.create_agent(
# whether to include base letta tools (default: True)
include_base_tools=True,
# list of additional tools (by name) to add to the agent
tools=[],
tool_ids=[],
)
print(f"Created agent with name {agent_state.name} and unique ID {agent_state.id}")

View File

@ -36,7 +36,7 @@ print(f"Created tool with name {tool.name}")
# create a new agent
agent_state = client.create_agent(
# create the agent with an additional tool
tools=[tool.name],
tool_ids=[tool.id],
# add tool rules that terminate execution after specific tools
tool_rules=[
# exit after roll_d20 is called
@ -45,7 +45,7 @@ agent_state = client.create_agent(
TerminalToolRule(tool_name="send_message"),
],
)
print(f"Created agent with name {agent_state.name} with tools {agent_state.tool_names}")
print(f"Created agent with name {agent_state.name} with tools {[t.name for t in agent_state.tools]}")
# Message an agent
response = client.send_message(agent_id=agent_state.id, role="user", message="roll a dice")
@ -61,7 +61,8 @@ client.add_tool_to_agent(agent_id=agent_state.id, tool_id=tool.id)
client.delete_agent(agent_id=agent_state.id)
# create an agent with only a subset of default tools
agent_state = client.create_agent(include_base_tools=False, tools=[tool.name, "send_message"])
send_message_tool = client.get_tool_id("send_message")
agent_state = client.create_agent(include_base_tools=False, tool_ids=[tool.id, send_message_tool])
# message the agent to search archival memory (will be unable to do so)
response = client.send_message(agent_id=agent_state.id, role="user", message="search your archival memory")

View File

@ -67,7 +67,9 @@ def main():
"""
# Create an agent
agent_state = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool_name])
agent_state = client.create_agent(
name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id]
)
print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}")
# Send a message to the agent

View File

@ -108,7 +108,7 @@ def main():
]
# 4. Create the agent
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
# 5. Ask for the final secret word
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")

View File

@ -4,7 +4,7 @@ __version__ = "0.6.4"
from letta.client.client import LocalClient, RESTClient, create_client
# imports for easier access
from letta.schemas.agent import AgentState, PersistedAgentState
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus

View File

@ -6,8 +6,6 @@ import warnings
from abc import ABC, abstractmethod
from typing import List, Literal, Optional, Tuple, Union
from tqdm import tqdm
from letta.constants import (
BASE_TOOLS,
CLI_WARNING_PREFIX,
@ -30,7 +28,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
from letta.memory import summarize_messages
from letta.metadata import MetadataStore
from letta.orm import User
from letta.schemas.agent import AgentState, AgentStepResponse
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageRole
@ -49,12 +47,12 @@ from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.user_manager import UserManager
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
get_heartbeat,
@ -316,7 +314,7 @@ class Agent(BaseAgent):
else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
assert self.agent_state.id is not None and self.agent_state.user_id is not None
assert self.agent_state.id is not None and self.agent_state.created_by_id is not None
# Generate a sequence of initial messages to put in the buffer
init_messages = initialize_message_sequence(
@ -335,7 +333,7 @@ class Agent(BaseAgent):
# We always need the system prompt up front
system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=init_messages[0],
)
@ -358,7 +356,7 @@ class Agent(BaseAgent):
# Cast to Message objects
init_messages = [
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg
)
for msg in init_messages
]
@ -439,11 +437,12 @@ class Agent(BaseAgent):
else:
# execute tool in a sandbox
# TODO: allow agent_state to specify which sandbox to execute tools in
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run(
agent_state=self.agent_state.__deepcopy__()
)
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
self.update_memory_if_change(updated_agent_state.memory)
except Exception as e:
# Need to catch the error here, or else truncation won't happen
@ -573,7 +572,7 @@ class Agent(BaseAgent):
added_messages_objs = [
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=msg,
)
@ -603,7 +602,7 @@ class Agent(BaseAgent):
response = create(
llm_config=self.agent_state.llm_config,
messages=message_sequence,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
functions=allowed_functions,
functions_python=self.functions_python,
function_call=function_call,
@ -689,7 +688,7 @@ class Agent(BaseAgent):
Message.dict_to_message(
id=response_message_id,
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=response_message.model_dump(),
)
@ -722,7 +721,7 @@ class Agent(BaseAgent):
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
@ -745,7 +744,7 @@ class Agent(BaseAgent):
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
@ -823,7 +822,7 @@ class Agent(BaseAgent):
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
@ -842,7 +841,7 @@ class Agent(BaseAgent):
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
@ -861,7 +860,7 @@ class Agent(BaseAgent):
Message.dict_to_message(
id=response_message_id,
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=response_message.model_dump(),
)
@ -920,7 +919,7 @@ class Agent(BaseAgent):
# logger.debug("Saving agent state")
# save updated state
if ms:
save_agent(self, ms)
save_agent(self)
# Chain stops
if not chaining:
@ -931,10 +930,10 @@ class Agent(BaseAgent):
break
# Chain handlers
elif token_warning:
assert self.agent_state.user_id is not None
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
@ -943,10 +942,10 @@ class Agent(BaseAgent):
)
continue # always chain
elif function_failed:
assert self.agent_state.user_id is not None
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
@ -955,10 +954,10 @@ class Agent(BaseAgent):
)
continue # always chain
elif heartbeat_request:
assert self.agent_state.user_id is not None
assert self.agent_state.created_by_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
@ -1129,10 +1128,10 @@ class Agent(BaseAgent):
openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name}
# Create the associated Message object (in the database)
assert self.agent_state.user_id is not None, "User ID is not set"
assert self.agent_state.created_by_id is not None, "User ID is not set"
user_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=openai_message_dict,
# created_at=timestamp,
@ -1232,7 +1231,7 @@ class Agent(BaseAgent):
[
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict=packed_summary_message,
)
@ -1260,7 +1259,7 @@ class Agent(BaseAgent):
assert isinstance(new_system_message, str)
new_system_message_obj = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={"role": "system", "content": new_system_message},
)
@ -1371,7 +1370,14 @@ class Agent(BaseAgent):
# TODO: recall memory
raise NotImplementedError()
def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore, page_size: Optional[int] = None):
def attach_source(
self,
user: PydanticUser,
source_id: str,
source_manager: SourceManager,
agent_manager: AgentManager,
page_size: Optional[int] = None,
):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
@ -1384,7 +1390,7 @@ class Agent(BaseAgent):
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
assert len(agents_passages) == passage_size # sanity check
assert len(agents_passages) == passage_size # sanity check
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
# attach to agent
@ -1393,7 +1399,7 @@ class Agent(BaseAgent):
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
# TODO: delete @matt and remove
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
@ -1610,20 +1616,31 @@ class Agent(BaseAgent):
return context_window_breakdown.context_window_size_current
def save_agent(agent: Agent, ms: MetadataStore):
def save_agent(agent: Agent):
"""Save agent to metadata store"""
agent.update_state()
agent_state = agent.agent_state
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
# TODO: move this to agent manager
# TODO: Completely strip out metadata
# convert to persisted model
persisted_agent_state = agent.agent_state.to_persisted_agent_state()
if ms.get_agent(agent_id=persisted_agent_state.id):
ms.update_agent(persisted_agent_state)
else:
ms.create_agent(persisted_agent_state)
agent_manager = AgentManager()
update_agent = UpdateAgent(
name=agent_state.name,
tool_ids=[t.id for t in agent_state.tools],
source_ids=[s.id for s in agent_state.sources],
block_ids=[b.id for b in agent_state.memory.blocks],
tags=agent_state.tags,
system=agent_state.system,
tool_rules=agent_state.tool_rules,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
message_ids=agent_state.message_ids,
description=agent_state.description,
metadata_=agent_state.metadata_,
)
agent_manager.update_agent(agent_id=agent_state.id, agent_update=update_agent, actor=agent.user)
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:

View File

@ -2,7 +2,6 @@ from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Union
from letta.agent import Agent
from letta.interface import AgentInterface
from letta.metadata import MetadataStore
from letta.prompts import gpt_system
@ -68,8 +67,10 @@ class ChatOnlyAgent(Agent):
name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000
)
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit:]
conversation_messages_block = Block(name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit)
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :]
conversation_messages_block = Block(
name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit
)
offline_memory = BasicBlockMemory(
blocks=[
@ -89,7 +90,7 @@ class ChatOnlyAgent(Agent):
memory=offline_memory,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
tools=self.agent_state.metadata_.get("offline_memory_tools", []),
tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []),
include_base_tools=False,
)
self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo)

View File

@ -140,7 +140,6 @@ def run(
# read user id from config
ms = MetadataStore(config)
client = create_client()
server = client.server
# determine agent to use, if not provided
if not yes and not agent:
@ -165,8 +164,6 @@ def run(
persona = persona if persona else config.persona
if agent and agent_state: # use existing agent
typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN)
# agent_config = AgentConfig.load(agent)
# agent_state = ms.get_agent(agent_name=agent, user_id=user_id)
printd("Loading agent state:", agent_state.id)
printd("Agent state:", agent_state.name)
# printd("State path:", agent_config.save_state_dir())
@ -224,8 +221,6 @@ def run(
)
# create agent
tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tool_names]
agent_state.tools = tools
letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user)
else: # create new agent
@ -317,7 +312,7 @@ def run(
metadata=metadata,
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tool_names])}", fg=typer.colors.WHITE)
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE)
letta_agent = Agent(
interface=interface(),
@ -326,7 +321,7 @@ def run(
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
user=client.user,
)
save_agent(agent=letta_agent, ms=ms)
save_agent(agent=letta_agent)
typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN)
# start event loop

View File

@ -15,7 +15,8 @@ from letta.constants import (
)
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona
from letta.schemas.embedding_config import EmbeddingConfig
@ -65,10 +66,8 @@ def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
class AbstractClient(object):
def __init__(
self,
auto_save: bool = False,
debug: bool = False,
):
self.auto_save = auto_save
self.debug = debug
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
@ -81,8 +80,9 @@ class AbstractClient(object):
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
memory=None,
block_ids: Optional[List[str]] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
@ -97,7 +97,7 @@ class AbstractClient(object):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@ -436,7 +436,6 @@ class RESTClient(AbstractClient):
Initializes a new instance of Client class.
Args:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
default_llm_config (Optional[LLMConfig]): The default LLM configuration.
@ -456,6 +455,7 @@ class RESTClient(AbstractClient):
params = {}
if tags:
params["tags"] = tags
params["match_all_tags"] = False
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
return [AgentState(**agent) for agent in response.json()]
@ -491,10 +491,12 @@ class RESTClient(AbstractClient):
llm_config: LLMConfig = None,
# memory
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
# Existing blocks
block_ids: Optional[List[str]] = None,
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
# metadata
@ -511,7 +513,7 @@ class RESTClient(AbstractClient):
llm_config (LLMConfig): LLM configuration
memory (Memory): Memory configuration
system (str): System configuration
tools (List[str]): List of tools
tool_ids (List[str]): List of tool ids
include_base_tools (bool): Include base tools
metadata (Dict): Metadata
description (str): Description
@ -520,31 +522,54 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the created agent
"""
tool_ids = tool_ids or []
tool_names = []
if tools:
tool_names += tools
if include_base_tools:
tool_names += BASE_TOOLS
tool_names += BASE_MEMORY_TOOLS
tool_ids += [self.get_tool_id(tool_name=name) for name in tool_names]
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
assert llm_config or self._default_llm_config, f"LLM config must be provided"
# TODO: This should not happen here, we need to have clear separation between create/add blocks
# TODO: This is insanely hacky and a result of allowing free-floating blocks
# TODO: When we create the block, it gets it's own block ID
blocks = []
for block in memory.get_blocks():
blocks.append(
self.create_block(
label=block.label,
value=block.value,
limit=block.limit,
template_name=block.template_name,
is_template=block.is_template,
)
)
memory.blocks = blocks
block_ids = block_ids or []
# create agent
request = CreateAgent(
name=name,
description=description,
metadata_=metadata,
memory_blocks=[],
tools=tool_names,
tool_rules=tool_rules,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
)
create_params = {
"description": description,
"metadata_": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
"initial_message_sequence": initial_message_sequence,
"tags": tags,
}
# Only add name if it's not None
if name is not None:
create_params["name"] = name
request = CreateAgent(**create_params)
# Use model_dump_json() instead of model_dump()
# If we use model_dump(), the datetime objects will not be serialized correctly
@ -561,14 +586,6 @@ class RESTClient(AbstractClient):
# gather agent state
agent_state = AgentState(**response.json())
# create and link blocks
for block in memory.get_blocks():
if not self.get_block(block.id):
# note: this does not update existing blocks
# WARNING: this resets the block ID - this method is a hack for backwards compat, should eventually use CreateBlock not Memory
block = self.create_block(label=block.label, value=block.value, limit=block.limit)
self.link_agent_memory_block(agent_id=agent_state.id, block_id=block.id)
# refresh and return agent
return self.get_agent(agent_state.id)
@ -602,7 +619,7 @@ class RESTClient(AbstractClient):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tool_names: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@ -617,7 +634,7 @@ class RESTClient(AbstractClient):
name (str): Name of the agent
description (str): Description of the agent
system (str): System configuration
tool_names (List[str]): List of tools
tool_ids (List[str]): List of tools
metadata (Dict): Metadata
llm_config (LLMConfig): LLM configuration
embedding_config (EmbeddingConfig): Embedding configuration
@ -627,11 +644,10 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the updated agent
"""
request = UpdateAgentState(
id=agent_id,
request = UpdateAgent(
name=name,
system=system,
tool_names=tool_names,
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
@ -742,7 +758,7 @@ class RESTClient(AbstractClient):
agents = [AgentState(**agent) for agent in response.json()]
if len(agents) == 0:
return None
agents = [agents[0]] # TODO: @matt monkeypatched
agents = [agents[0]] # TODO: @matt monkeypatched
assert len(agents) == 1, f"Multiple agents with the same name: {[(agents.name, agents.id) for agents in agents]}"
return agents[0].id
@ -1052,7 +1068,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to update block: {response.text}")
return Block(**response.json())
def get_block(self, block_id: str) -> Block:
def get_block(self, block_id: str) -> Optional[Block]:
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers)
if response.status_code == 404:
return None
@ -1607,23 +1623,6 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tool: {response.text}")
return Tool(**response.json())
def get_tool_id(self, name: str) -> Optional[str]:
"""
Get a tool ID by its name.
Args:
id (str): ID of the tool
Returns:
tool (Tool): Tool
"""
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers)
if response.status_code == 404:
return None
elif response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()
def set_default_llm_config(self, llm_config: LLMConfig):
"""
Set the default LLM configuration
@ -2006,7 +2005,6 @@ class LocalClient(AbstractClient):
A local client for Letta, which corresponds to a single user.
Attributes:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
interface (QueuingInterface): The interface for the client.
@ -2015,7 +2013,6 @@ class LocalClient(AbstractClient):
def __init__(
self,
auto_save: bool = False,
user_id: Optional[str] = None,
org_id: Optional[str] = None,
debug: bool = False,
@ -2026,11 +2023,9 @@ class LocalClient(AbstractClient):
Initializes a new instance of Client class.
Args:
auto_save (bool): Whether to automatically save changes.
user_id (str): The user ID.
debug (bool): Whether to print debug information.
"""
self.auto_save = auto_save
# set logging levels
letta.utils.DEBUG = debug
@ -2056,14 +2051,14 @@ class LocalClient(AbstractClient):
# get default user
self.user_id = self.server.user_manager.DEFAULT_USER_ID
self.user = self.server.get_user_or_default(self.user_id)
self.user = self.server.user_manager.get_user_or_default(self.user_id)
self.organization = self.server.get_organization_or_default(self.org_id)
# agents
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
self.interface.clear()
return self.server.list_agents(user_id=self.user_id, tags=tags)
return self.server.agent_manager.list_agents(actor=self.user, tags=tags)
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""
@ -2097,6 +2092,7 @@ class LocalClient(AbstractClient):
llm_config: LLMConfig = None,
# memory
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
block_ids: Optional[List[str]] = None,
# TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API)
# memory_blocks=[
# {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000},
@ -2105,7 +2101,7 @@ class LocalClient(AbstractClient):
# system
system: Optional[str] = None,
# tools
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
# metadata
@ -2132,55 +2128,53 @@ class LocalClient(AbstractClient):
Returns:
agent_state (AgentState): State of the created agent
"""
if name and self.agent_exists(agent_name=name):
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
# construct list of tools
tool_ids = tool_ids or []
tool_names = []
if tools:
tool_names += tools
if include_base_tools:
tool_names += BASE_TOOLS
tool_names += BASE_MEMORY_TOOLS
tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names]
# check if default configs are provided
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
assert llm_config or self._default_llm_config, f"LLM config must be provided"
# TODO: This should not happen here, we need to have clear separation between create/add blocks
for block in memory.get_blocks():
self.server.block_manager.create_or_update_block(block, actor=self.user)
# Also get any existing block_ids passed in
block_ids = block_ids or []
# create agent
# Create the base parameters
create_params = {
"description": description,
"metadata_": metadata,
"memory_blocks": [],
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
"tool_ids": tool_ids,
"tool_rules": tool_rules,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
"initial_message_sequence": initial_message_sequence,
"tags": tags,
}
# Only add name if it's not None
if name is not None:
create_params["name"] = name
agent_state = self.server.create_agent(
CreateAgent(
name=name,
description=description,
metadata_=metadata,
# memory=memory,
memory_blocks=[],
# memory_blocks = memory.get_blocks(),
# memory_tools=memory_tools,
tools=tool_names,
tool_rules=tool_rules,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
initial_message_sequence=initial_message_sequence,
tags=tags,
),
CreateAgent(**create_params),
actor=self.user,
)
# TODO: remove when we fully migrate to block creation CreateAgent model
# Link additional blocks to the agent (block ids created on the client)
# This needs to happen since the create agent does not allow passing in blocks which have already been persisted and have an ID
# So we create the agent and then link the blocks afterwards
user = self.server.get_user_or_default(self.user_id)
for block in memory.get_blocks():
self.server.block_manager.create_or_update_block(block, actor=user)
self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_state.id, block_id=block.id)
# TODO: get full agent state
return self.server.get_agent(agent_state.id)
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
def update_message(
self,
@ -2202,6 +2196,7 @@ class LocalClient(AbstractClient):
tool_calls=tool_calls,
tool_call_id=tool_call_id,
),
actor=self.user,
)
return message
@ -2211,7 +2206,7 @@ class LocalClient(AbstractClient):
name: Optional[str] = None,
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
@ -2239,11 +2234,11 @@ class LocalClient(AbstractClient):
# TODO: add the abilitty to reset linked block_ids
self.interface.clear()
agent_state = self.server.update_agent(
UpdateAgentState(
id=agent_id,
agent_id,
UpdateAgent(
name=name,
system=system,
tool_names=tools,
tool_ids=tool_ids,
tags=tags,
description=description,
metadata_=metadata,
@ -2315,7 +2310,7 @@ class LocalClient(AbstractClient):
Args:
agent_id (str): ID of the agent to delete
"""
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user)
def get_agent_by_name(self, agent_name: str) -> AgentState:
"""
@ -2328,7 +2323,7 @@ class LocalClient(AbstractClient):
agent_state (AgentState): State of the agent
"""
self.interface.clear()
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user)
def get_agent(self, agent_id: str) -> AgentState:
"""
@ -2340,9 +2335,8 @@ class LocalClient(AbstractClient):
Returns:
agent_state (AgentState): State representation of the agent
"""
# TODO: include agent_name
self.interface.clear()
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
def get_agent_id(self, agent_name: str) -> Optional[str]:
"""
@ -2357,7 +2351,12 @@ class LocalClient(AbstractClient):
self.interface.clear()
assert agent_name, f"Agent name must be provided"
return self.server.get_agent_id(name=agent_name, user_id=self.user_id)
# TODO: Refactor this futher to not have downstream users expect Optionals - this should just error
try:
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id
except NoResultFound:
return None
# memory
def get_in_context_memory(self, agent_id: str) -> Memory:
@ -2370,7 +2369,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): In-context memory of the agent
"""
memory = self.server.get_agent_memory(agent_id=agent_id)
memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user)
return memory
def get_core_memory(self, agent_id: str) -> Memory:
@ -2388,7 +2387,7 @@ class LocalClient(AbstractClient):
"""
# TODO: implement this (not sure what it should look like)
memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, label=section, value=value)
memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user)
return memory
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
@ -2402,7 +2401,7 @@ class LocalClient(AbstractClient):
summary (ArchivalMemorySummary): Summary of the archival memory
"""
return self.server.get_archival_memory_summary(agent_id=agent_id)
return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user)
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
"""
@ -2414,7 +2413,7 @@ class LocalClient(AbstractClient):
Returns:
summary (RecallMemorySummary): Summary of the recall memory
"""
return self.server.get_recall_memory_summary(agent_id=agent_id)
return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user)
def get_in_context_messages(self, agent_id: str) -> List[Message]:
"""
@ -2426,7 +2425,7 @@ class LocalClient(AbstractClient):
Returns:
messages (List[Message]): List of in-context messages
"""
return self.server.get_in_context_messages(agent_id=agent_id)
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
# agent interactions
@ -2446,11 +2445,7 @@ class LocalClient(AbstractClient):
response (LettaResponse): Response from the agent
"""
self.interface.clear()
usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages)
# auto-save
if self.auto_save:
self.save()
usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages)
# format messages
return LettaResponse(messages=messages, usage=usage)
@ -2490,15 +2485,11 @@ class LocalClient(AbstractClient):
self.interface.clear()
usage = self.server.send_messages(
user_id=self.user_id,
actor=self.user,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
)
# auto-save
if self.auto_save:
self.save()
## TODO: need to make sure date/timestamp is propely passed
## TODO: update self.interface.to_list() to return actual Message objects
## here, the message objects will have faulty created_by timestamps
@ -2547,16 +2538,9 @@ class LocalClient(AbstractClient):
self.interface.clear()
usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command)
# auto-save
if self.auto_save:
self.save()
# NOTE: messages/usage may be empty, depending on the command
return LettaResponse(messages=self.interface.to_list(), usage=usage)
def save(self):
self.server.save_agents()
# archival memory
# humans / personas
@ -3036,7 +3020,7 @@ class LocalClient(AbstractClient):
Returns:
sources (List[Source]): List of sources
"""
return self.server.list_attached_sources(agent_id=agent_id)
return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user)
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
"""
@ -3080,7 +3064,7 @@ class LocalClient(AbstractClient):
Returns:
passages (List[Passage]): List of inserted passages
"""
return self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory)
return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user)
def delete_archival_memory(self, agent_id: str, memory_id: str):
"""
@ -3090,7 +3074,7 @@ class LocalClient(AbstractClient):
agent_id (str): ID of the agent
memory_id (str): ID of the memory
"""
self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id)
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
def get_archival_memory(
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
@ -3349,8 +3333,8 @@ class LocalClient(AbstractClient):
block_req = Block(**create_block.model_dump())
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
# Link the block to the agent
updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id)
return updated_memory
agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user)
return agent.memory
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
"""
@ -3363,7 +3347,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): The updated memory
"""
return self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block_id)
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
"""
@ -3376,7 +3360,7 @@ class LocalClient(AbstractClient):
Returns:
memory (Memory): The updated memory
"""
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)
return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user)
def get_agent_memory_blocks(self, agent_id: str) -> List[Block]:
"""
@ -3388,8 +3372,8 @@ class LocalClient(AbstractClient):
Returns:
blocks (List[Block]): The blocks in the agent's core memory
"""
block_ids = self.server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id)
return [self.server.block_manager.get_block_by_id(block_id, actor=self.user) for block_id in block_ids]
agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
return agent.memory.blocks
def get_agent_memory_block(self, agent_id: str, label: str) -> Block:
"""
@ -3402,8 +3386,7 @@ class LocalClient(AbstractClient):
Returns:
block (Block): The block corresponding to the label
"""
block_id = self.server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=label)
return self.server.block_manager.get_block_by_id(block_id, actor=self.user)
return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user)
def update_agent_memory_block(
self,

View File

@ -1,12 +1,9 @@
import configparser
import inspect
import json
import os
from dataclasses import dataclass
from typing import Optional
import letta
import letta.utils as utils
from letta.constants import (
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
@ -16,7 +13,6 @@ from letta.constants import (
LETTA_DIR,
)
from letta.log import get_logger
from letta.schemas.agent import PersistedAgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
@ -312,160 +308,3 @@ class LettaConfig:
for folder in folders:
if not os.path.exists(os.path.join(LETTA_DIR, folder)):
os.makedirs(os.path.join(LETTA_DIR, folder))
@dataclass
class AgentConfig:
"""
NOTE: this is a deprecated class, use AgentState instead. This class is only used for backcompatibility.
Configuration for a specific instance of an agent
"""
def __init__(
self,
persona,
human,
# model info
model=None,
model_endpoint_type=None,
model_endpoint=None,
model_wrapper=None,
context_window=None,
# embedding info
embedding_endpoint_type=None,
embedding_endpoint=None,
embedding_model=None,
embedding_dim=None,
embedding_chunk_size=None,
# other
preset=None,
data_sources=None,
# agent info
agent_config_path=None,
name=None,
create_time=None,
letta_version=None,
# functions
functions=None, # schema definitions ONLY (linked at runtime)
):
assert name, f"Agent name must be provided"
self.name = name
config = LettaConfig.load() # get default values
self.persona = config.persona if persona is None else persona
self.human = config.human if human is None else human
self.preset = config.preset if preset is None else preset
self.context_window = config.default_llm_config.context_window if context_window is None else context_window
self.model = config.default_llm_config.model if model is None else model
self.model_endpoint_type = config.default_llm_config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
self.model_endpoint = config.default_llm_config.model_endpoint if model_endpoint is None else model_endpoint
self.model_wrapper = config.default_llm_config.model_wrapper if model_wrapper is None else model_wrapper
self.llm_config = LLMConfig(
model=self.model,
model_endpoint_type=self.model_endpoint_type,
model_endpoint=self.model_endpoint,
model_wrapper=self.model_wrapper,
context_window=self.context_window,
)
self.embedding_endpoint_type = (
config.default_embedding_config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
)
self.embedding_endpoint = config.default_embedding_config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
self.embedding_model = config.default_embedding_config.embedding_model if embedding_model is None else embedding_model
self.embedding_dim = config.default_embedding_config.embedding_dim if embedding_dim is None else embedding_dim
self.embedding_chunk_size = (
config.default_embedding_config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
)
self.embedding_config = EmbeddingConfig(
embedding_endpoint_type=self.embedding_endpoint_type,
embedding_endpoint=self.embedding_endpoint,
embedding_model=self.embedding_model,
embedding_dim=self.embedding_dim,
embedding_chunk_size=self.embedding_chunk_size,
)
# agent metadata
self.data_sources = data_sources if data_sources is not None else []
self.create_time = create_time if create_time is not None else utils.get_local_time()
if letta_version is None:
import letta
self.letta_version = letta.__version__
else:
self.letta_version = letta_version
# functions
self.functions = functions
# save agent config
self.agent_config_path = (
os.path.join(LETTA_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path
)
def attach_data_source(self, data_source: str):
# TODO: add warning that only once source can be attached
# i.e. previous source will be overriden
self.data_sources.append(data_source)
self.save()
def save_dir(self):
return os.path.join(LETTA_DIR, "agents", self.name)
def save_state_dir(self):
# directory to save agent state
return os.path.join(LETTA_DIR, "agents", self.name, "agent_state")
def save_persistence_manager_dir(self):
# directory to save persistent manager state
return os.path.join(LETTA_DIR, "agents", self.name, "persistence_manager")
def save_agent_index_dir(self):
# save llama index inside of persistent manager directory
return os.path.join(self.save_persistence_manager_dir(), "index")
def save(self):
# save state of persistence manager
os.makedirs(os.path.join(LETTA_DIR, "agents", self.name), exist_ok=True)
# save version
self.letta_version = letta.__version__
with open(self.agent_config_path, "w", encoding="utf-8") as f:
json.dump(vars(self), f, indent=4)
def to_agent_state(self):
return PersistedAgentState(
name=self.name,
preset=self.preset,
persona=self.persona,
human=self.human,
llm_config=self.llm_config,
embedding_config=self.embedding_config,
create_time=self.create_time,
)
@staticmethod
def exists(name: str):
"""Check if agent config exists"""
agent_config_path = os.path.join(LETTA_DIR, "agents", name)
return os.path.exists(agent_config_path)
@classmethod
def load(cls, name: str):
"""Load agent config from JSON file"""
agent_config_path = os.path.join(LETTA_DIR, "agents", name, "config.json")
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
with open(agent_config_path, "r", encoding="utf-8") as f:
agent_config = json.load(f)
# allow compatibility accross versions
try:
class_args = inspect.getargspec(cls.__init__).args
except AttributeError:
# https://github.com/pytorch/pytorch/issues/15344
class_args = inspect.getfullargspec(cls.__init__).args
agent_fields = list(agent_config.keys())
for key in agent_fields:
if key not in class_args:
utils.printd(f"Removing missing argument {key} from agent config")
del agent_config[key]
return cls(**agent_config)

View File

@ -130,11 +130,11 @@ def run_agent_loop(
# updated agent save functions
if user_input.lower() == "/exit":
# letta_agent.save()
agent.save_agent(letta_agent, ms)
agent.save_agent(letta_agent)
break
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
# letta_agent.save()
agent.save_agent(letta_agent, ms)
agent.save_agent(letta_agent)
continue
elif user_input.lower() == "/attach":
# TODO: check if agent already has it
@ -394,7 +394,7 @@ def run_agent_loop(
token_warning = step_response.in_context_memory_warning
step_response.usage
agent.save_agent(letta_agent, ms)
agent.save_agent(letta_agent)
skip_next_user_input = False
if token_warning:
user_message = system.get_token_limit_warning()

View File

@ -1,23 +1,13 @@
import datetime
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Tuple, Union
from typing import Callable, Dict, List
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding
from letta.llm_api.llm_api_tools import create
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.schemas.passage import Passage
from letta.utils import (
count_tokens,
extract_date_from_timestamp,
get_local_time,
printd,
validate_date_format,
)
from letta.utils import count_tokens, printd
def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
@ -67,7 +57,6 @@ def summarize_messages(
+ message_sequence_to_summarize[cutoff:]
)
agent_state.user_id
dummy_agent_id = agent_state.id
message_sequence = []
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
@ -79,7 +68,7 @@ def summarize_messages(
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
response = create(
llm_config=llm_config_no_inner_thoughts,
user_id=agent_state.user_id,
user_id=agent_state.created_by_id,
messages=message_sequence,
stream=False,
)

View File

@ -2,23 +2,18 @@
import os
import secrets
from typing import List, Optional, Union
from typing import List, Optional
from sqlalchemy import JSON, Column, DateTime, Index, String, TypeDecorator
from sqlalchemy.sql import func
from sqlalchemy import JSON, Column, Index, String, TypeDecorator
from letta.config import LettaConfig
from letta.orm.base import Base
from letta.schemas.agent import PersistedAgentState
from letta.schemas.api_key import APIKey
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
from letta.schemas.user import User
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.settings import settings
from letta.utils import enforce_types, printd
from letta.utils import enforce_types
class LLMConfigColumn(TypeDecorator):
@ -65,18 +60,6 @@ class EmbeddingConfigColumn(TypeDecorator):
return value
# TODO: eventually store providers?
# class Provider(Base):
# __tablename__ = "providers"
# __table_args__ = {"extend_existing": True}
#
# id = Column(String, primary_key=True)
# name = Column(String, nullable=False)
# created_at = Column(DateTime(timezone=True))
# api_key = Column(String, nullable=False)
# base_url = Column(String, nullable=False)
class APIKeyModel(Base):
"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""
@ -113,115 +96,6 @@ def generate_api_key(prefix="sk-", length=51) -> str:
return new_key
class ToolRulesColumn(TypeDecorator):
"""Custom type for storing a list of ToolRules as JSON"""
impl = JSON
cache_ok = True
def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())
def process_bind_param(self, value, dialect):
"""Convert a list of ToolRules to JSON-serializable format."""
if value:
data = [rule.model_dump() for rule in value]
for d in data:
d["type"] = d["type"].value
for d in data:
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
return data
return value
def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
"""Convert JSON back to a list of ToolRules."""
if value:
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
return value
@staticmethod
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
if rule_type == ToolRuleType.run_first:
return InitToolRule(**data)
elif rule_type == ToolRuleType.exit_loop:
return TerminalToolRule(**data)
elif rule_type == ToolRuleType.constrain_child_tools:
rule = ChildToolRule(**data)
return rule
else:
raise ValueError(f"Unknown tool rule type: {rule_type}")
class AgentModel(Base):
    """Data model for persisted agent state in the `agents` table (configs, system prompt, in-context message ids, tools)."""

    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    description = Column(String)

    # state (context compilation)
    message_ids = Column(JSON)
    system = Column(String)

    # configs
    agent_type = Column(String)
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    metadata_ = Column(JSON)

    # tools
    tool_names = Column(JSON)
    tool_rules = Column(ToolRulesColumn)

    # secondary index for per-user lookups
    Index(__tablename__ + "_idx_user", user_id),

    def __repr__(self) -> str:
        return f"<Agent(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> PersistedAgentState:
        # Convert this ORM row into its Pydantic counterpart, field for field.
        agent_state = PersistedAgentState(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            description=self.description,
            message_ids=self.message_ids,
            system=self.system,
            tool_names=self.tool_names,
            tool_rules=self.tool_rules,
            agent_type=self.agent_type,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            metadata_=self.metadata_,
        )
        return agent_state
class AgentSourceMappingModel(Base):
    """Stores mapping between agent -> source.

    One row per (agent, source) attachment; rows are created by
    MetadataStore.attach_source and removed by detach_source/delete_agent.
    """

    __tablename__ = "agent_source_mapping"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)
    source_id = Column(String, nullable=False)

    # composite secondary index covering the common lookup patterns
    Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),

    def __repr__(self) -> str:
        return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
class MetadataStore:
uri: Optional[str] = None
@ -281,127 +155,3 @@ class MetadataStore:
results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
tokens = [r.to_record() for r in results]
return tokens
@enforce_types
def create_agent(self, agent: PersistedAgentState):
    """Insert a new agent row.

    Args:
        agent: the agent state to persist.

    Raises:
        ValueError: if an agent with the same name already exists for this user.
    """
    # insert into agent table
    # make sure agent.name does not already exist for user user_id
    with self.session_maker() as session:
        if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
            raise ValueError(f"Agent with name {agent.name} already exists")
        # Copy before pruning: vars() returns the live __dict__, so deleting
        # from it directly would strip the attribute off the caller's object.
        fields = dict(vars(agent))
        # "tags" is not a column on AgentModel, so drop it from the payload.
        fields.pop("tags", None)
        session.add(AgentModel(**fields))
        session.commit()
@enforce_types
def update_agent(self, agent: PersistedAgentState):
    """Overwrite the stored row for `agent.id` with the agent's current field values."""
    with self.session_maker() as session:
        # Copy before pruning: vars() returns the live __dict__, so deleting
        # from it directly would strip the attribute off the caller's object.
        fields = dict(vars(agent))
        # "tags" is not a column on AgentModel, so drop it from the payload.
        fields.pop("tags", None)
        session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
        session.commit()
@enforce_types
def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager):
    """Hard-delete an agent row and its source mappings, releasing the per-agent lock first."""
    # TODO: Remove this once Agent is on the ORM
    # TODO: To prevent unbounded growth
    per_agent_lock_manager.clear_lock(agent_id)

    with self.session_maker() as session:
        # delete agents
        session.query(AgentModel).filter(AgentModel.id == agent_id).delete()

        # delete mappings
        session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).delete()

        session.commit()
@enforce_types
def list_agents(self, user_id: str) -> List[PersistedAgentState]:
    """Return every persisted agent owned by `user_id`."""
    with self.session_maker() as session:
        rows = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
        return [row.to_record() for row in rows]
@enforce_types
def get_agent(
    self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[PersistedAgentState]:
    """Fetch one agent, either by id or by (name, user_id); returns None when no match exists."""
    with self.session_maker() as session:
        if agent_id:
            query = session.query(AgentModel).filter(AgentModel.id == agent_id)
        else:
            assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
            query = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id)

        matches = query.all()
        if not matches:
            return None
        assert len(matches) == 1, f"Expected 1 result, got {len(matches)}"  # should only be one result
        return matches[0].to_record()
# agent source metadata
@enforce_types
def attach_source(self, user_id: str, agent_id: str, source_id: str):
    """Record that `source_id` is attached to `agent_id`; idempotent if the mapping already exists."""
    with self.session_maker() as session:
        # TODO: remove this (is a hack)
        mapping_id = f"{user_id}-{agent_id}-{source_id}"
        already_mapped = (
            session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.id == mapping_id).first() is not None
        )
        if not already_mapped:
            # Only create if it doesn't exist
            session.add(
                AgentSourceMappingModel(
                    id=mapping_id,
                    user_id=user_id,
                    agent_id=agent_id,
                    source_id=source_id,
                )
            )
        session.commit()
@enforce_types
def list_attached_source_ids(self, agent_id: str) -> List[str]:
    """Return the ids of every source attached to `agent_id`."""
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
        return [mapping.source_id for mapping in mappings]
@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
    """Return ids of agents attached to `source_id`, skipping mappings whose agent no longer exists."""
    with self.session_maker() as session:
        mappings = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()

    agent_ids = []
    for mapping in mappings:
        # make sure agent exists
        if self.get_agent(agent_id=mapping.agent_id):
            agent_ids.append(mapping.agent_id)
        else:
            printd(f"Warning: agent {mapping.agent_id} does not exist but exists in mapping database. This should never happen.")
    return agent_ids
@enforce_types
def detach_source(self, agent_id: str, source_id: str):
    """Delete the mapping row linking `agent_id` and `source_id`, if one exists."""
    with self.session_maker() as session:
        (
            session.query(AgentSourceMappingModel)
            .filter(AgentSourceMappingModel.agent_id == agent_id)
            .filter(AgentSourceMappingModel.source_id == source_id)
            .delete()
        )
        session.commit()

View File

@ -85,6 +85,6 @@ class O1Agent(Agent):
if step_response.messages[-1].name == "send_final_message":
break
if ms:
save_agent(self, ms)
save_agent(self)
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

View File

@ -130,7 +130,7 @@ class OfflineMemoryAgent(Agent):
# extras
first_message_verify_mono: bool = False,
max_memory_rethinks: int = 10,
initial_message_sequence: Optional[List[Message]] = None,
initial_message_sequence: Optional[List[Message]] = None,
):
super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence)
self.first_message_verify_mono = first_message_verify_mono
@ -173,6 +173,6 @@ class OfflineMemoryAgent(Agent):
self.interface.step_complete()
if ms:
save_agent(self, ms)
save_agent(self)
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

View File

@ -1,3 +1,4 @@
from letta.orm.agent import Agent
from letta.orm.agents_tags import AgentsTags
from letta.orm.base import Base
from letta.orm.block import Block
@ -9,6 +10,7 @@ from letta.orm.organization import Organization
from letta.orm.passage import Passage
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
from letta.orm.source import Source
from letta.orm.sources_agents import SourcesAgents
from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents
from letta.orm.user import User

196
letta/orm/agent.py Normal file
View File

@ -0,0 +1,196 @@
import uuid
from typing import TYPE_CHECKING, List, Optional, Union
from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.block import Block
from letta.orm.message import Message
from letta.orm.mixins import OrganizationMixin
from letta.orm.organization import Organization
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.tool_rule import (
ChildToolRule,
InitToolRule,
TerminalToolRule,
ToolRule,
)
if TYPE_CHECKING:
from letta.orm.agents_tags import AgentsTags
from letta.orm.organization import Organization
from letta.orm.source import Source
from letta.orm.tool import Tool
class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Serialize pydantic configs on the way in; pass dicts/None through untouched.
        if value and isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Rehydrate stored JSON into an LLMConfig; leave falsy values (None/{}) as-is.
        return LLMConfig(**value) if value else value
class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        # Serialize pydantic configs on the way in; pass dicts/None through untouched.
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        # Rehydrate stored JSON into an EmbeddingConfig; leave falsy values (None/{}) as-is.
        return EmbeddingConfig(**value) if value else value
class ToolRulesColumn(TypeDecorator):
    """Custom type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to JSON-serializable format."""
        if not value:
            return value
        payload = [rule.model_dump() for rule in value]
        for entry in payload:
            # model_dump leaves the enum member in place; store its string value.
            entry["type"] = entry["type"].value
        for entry in payload:
            assert not (entry["type"] == "ToolRule" and "children" not in entry), "ToolRule does not have children field"
        return payload

    def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
        """Convert JSON back to a list of ToolRules."""
        if not value:
            return value
        return [self.deserialize_tool_rule(entry) for entry in value]

    @staticmethod
    def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
        """Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
        rule_type = ToolRuleType(data.get("type"))  # raises ValueError on unrecognized strings
        # Dispatch table mapping rule type -> pydantic constructor.
        constructors = {
            ToolRuleType.run_first: InitToolRule,
            ToolRuleType.exit_loop: TerminalToolRule,
            ToolRuleType.constrain_child_tools: ChildToolRule,
        }
        builder = constructors.get(rule_type)
        if builder is None:
            raise ValueError(f"Unknown tool rule type: {rule_type}")
        return builder(**data)
class Agent(SqlalchemyBase, OrganizationMixin):
    """ORM model for an agent: configs, system prompt, in-context message ids, and
    relationships to tools, sources, core-memory blocks, messages, and tags.

    Agent names are unique per organization (see __table_args__).
    """

    __tablename__ = "agents"
    __pydantic_model__ = PydanticAgentState
    __table_args__ = (UniqueConstraint("organization_id", "name", name="unique_org_agent_name"),)

    # agent generates its own id
    # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
    # TODO: Move this in this PR? at the very end?
    id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}")

    # Descriptor fields
    agent_type: Mapped[Optional[AgentType]] = mapped_column(String, nullable=True, doc="The type of Agent")
    name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a human-readable identifier for an agent, non-unique.")
    description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The description of the agent.")

    # System prompt
    system: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The system prompt used by the agent.")

    # In context memory
    # TODO: This should be a separate mapping table
    # This is dangerously flexible with the JSON type
    message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")

    # Metadata and configs
    metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
    llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
        LLMConfigColumn, nullable=True, doc="the LLM backend configuration object for this agent."
    )
    embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(
        EmbeddingConfigColumn, doc="the embedding configuration object for this agent."
    )

    # Tool rules
    tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")

    # relationships
    organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
    tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True)
    sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin")
    core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin")
    messages: Mapped[List["Message"]] = relationship(
        "Message",
        back_populates="agent",
        lazy="selectin",
        cascade="all, delete-orphan",  # Ensure messages are deleted when the agent is deleted
        passive_deletes=True,
    )
    tags: Mapped[List["AgentsTags"]] = relationship(
        "AgentsTags",
        back_populates="agent",
        cascade="all, delete-orphan",
        lazy="selectin",
        doc="Tags associated with the agent.",
    )
    # passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")

    def to_pydantic(self) -> PydanticAgentState:
        """converts to the basic pydantic model counterpart"""
        state = {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "message_ids": self.message_ids,
            "tools": self.tools,
            "sources": self.sources,
            # flatten AgentsTags rows down to their plain string tags
            "tags": [t.tag for t in self.tags],
            "tool_rules": self.tool_rules,
            "system": self.system,
            "agent_type": self.agent_type,
            "llm_config": self.llm_config,
            "embedding_config": self.embedding_config,
            "metadata_": self.metadata_,
            # core memory is materialized from the related Block rows
            "memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
            "created_by_id": self.created_by_id,
            "last_updated_by_id": self.last_updated_by_id,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }
        return self.__pydantic_model__(**state)

View File

@ -1,28 +1,20 @@
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.base import Base
class AgentsTags(SqlalchemyBase, OrganizationMixin):
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""
class AgentsTags(Base):
__tablename__ = "agents_tags"
__pydantic_model__ = PydanticAgentsTags
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)
# The agent associated with this tag
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
# # agent generates its own id
# # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
# # TODO: Move this in this PR? at the very end?
# id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agents_tags-{uuid.uuid4()}")
# The name of the tag
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True)
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")
# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags")

View File

@ -1,16 +1,17 @@
from typing import TYPE_CHECKING, Optional, Type
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
from letta.orm.blocks_agents import BlocksAgents
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import Human, Persona
if TYPE_CHECKING:
from letta.orm import BlocksAgents, Organization
from letta.orm import Organization
class Block(OrganizationMixin, SqlalchemyBase):
@ -35,7 +36,6 @@ class Block(OrganizationMixin, SqlalchemyBase):
# relationships
organization: Mapped[Optional["Organization"]] = relationship("Organization")
blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete")
def to_pydantic(self) -> Type:
match self.label:
@ -46,3 +46,28 @@ class Block(OrganizationMixin, SqlalchemyBase):
case _:
Schema = PydanticBlock
return Schema.model_validate(self)
@event.listens_for(Block, "after_update")  # Changed from 'before_update'
def block_before_update(mapper, connection, target):
    """Handle updating BlocksAgents when a block's label changes.

    Keeps the denormalized `block_label` column on the blocks_agents
    association table in sync whenever a Block's label is renamed.
    """
    label_history = attributes.get_history(target, "label")
    if not label_history.has_changes():
        return

    blocks_agents = BlocksAgents.__table__
    # NOTE(review): assumes exactly one old->new label transition per flush
    # (label_history.deleted[0] / added[0]) — verify for bulk-update scenarios.
    connection.execute(
        blocks_agents.update()
        .where(blocks_agents.c.block_id == target.id, blocks_agents.c.block_label == label_history.deleted[0])
        .values(block_label=label_history.added[0])
    )
@event.listens_for(Block, "before_insert")
@event.listens_for(Block, "before_update")
def validate_value_length(mapper, connection, target):
    """Ensure the value length does not exceed the limit."""
    content = target.value
    # Empty/None values are always acceptable; only measure real content.
    if not content:
        return
    if len(content) > target.limit:
        raise ValueError(
            f"Value length ({len(target.value)}) exceeds the limit ({target.limit}) for block with label '{target.label}' and id '{target.id}'."
        )

View File

@ -1,15 +1,13 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
from letta.orm.base import Base
class BlocksAgents(SqlalchemyBase):
class BlocksAgents(Base):
"""Agents must have one or many blocks to make up their core memory."""
__tablename__ = "blocks_agents"
__pydantic_model__ = PydanticBlocksAgents
__table_args__ = (
UniqueConstraint(
"agent_id",
@ -17,16 +15,12 @@ class BlocksAgents(SqlalchemyBase):
name="unique_label_per_agent",
),
ForeignKeyConstraint(
["block_id", "block_label"],
["block.id", "block.label"],
name="fk_block_id_label",
["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label", deferrable=True, initially="DEFERRED"
),
UniqueConstraint("agent_id", "block_id", name="unique_agent_block"),
)
# unique agent + block label
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
block_id: Mapped[str] = mapped_column(String, primary_key=True)
block_label: Mapped[str] = mapped_column(String, primary_key=True)
# relationships
block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents")

View File

@ -59,6 +59,5 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
# Relationships
# TODO: Add in after Agent ORM is created
# agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")

View File

@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
if TYPE_CHECKING:
from letta.orm.agent import Agent
from letta.orm.file import FileMetadata
from letta.orm.tool import Tool
from letta.orm.user import User
@ -25,7 +26,6 @@ class Organization(SqlalchemyBase):
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
@ -36,10 +36,5 @@ class Organization(SqlalchemyBase):
# relationships
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")

View File

@ -1,19 +1,18 @@
import base64
from datetime import datetime
from typing import Optional, TYPE_CHECKING
from sqlalchemy import Column, String, DateTime, JSON, ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import TypeDecorator, BINARY
from typing import TYPE_CHECKING, Optional
import numpy as np
import base64
from letta.orm.source import EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.schemas.passage import Passage as PydanticPassage
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy.types import BINARY, TypeDecorator
from letta.config import LettaConfig
from letta.constants import MAX_EMBEDDING_DIM
from letta.orm.mixins import FileMixin, OrganizationMixin
from letta.orm.source import EmbeddingConfigColumn
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.passage import Passage as PydanticPassage
from letta.settings import settings
config = LettaConfig()
@ -21,8 +20,10 @@ config = LettaConfig()
if TYPE_CHECKING:
from letta.orm.organization import Organization
class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""
impl = BINARY
cache_ok = True
@ -43,10 +44,12 @@ class CommonVector(TypeDecorator):
value = base64.b64decode(value)
return np.frombuffer(value, dtype=np.float32)
# TODO: After migration to Passage, will need to manually delete passages where files
# TODO: After migration to Passage, will need to manually delete passages where files
# are deleted on web
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
"""Defines data model for storing Passages"""
__tablename__ = "passages"
__table_args__ = {"extend_existing": True}
__pydantic_model__ = PydanticPassage
@ -59,6 +62,7 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
if settings.letta_pg_uri_no_default:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
else:
embedding = Column(CommonVector)

View File

@ -48,4 +48,4 @@ class Source(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
# agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")

View File

@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class SourcesAgents(Base):
    """Agents can have zero to many sources.

    Association table backing the many-to-many Agent <-> Source relationship;
    each row records one (agent_id, source_id) attachment.
    """

    __tablename__ = "sources_agents"

    # NOTE: annotations use Mapped[str]; the SQL type is still the String column below.
    agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
    source_id: Mapped[str] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)

View File

@ -1,7 +1,6 @@
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional, Type
import sqlite3
from typing import TYPE_CHECKING, List, Literal, Optional
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError
@ -9,12 +8,12 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
from letta.orm.errors import (
ForeignKeyConstraintViolationError,
NoResultFound,
UniqueConstraintViolationError,
)
from letta.orm.sqlite_functions import adapt_array
if TYPE_CHECKING:
from pydantic import BaseModel
@ -64,11 +63,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
**kwargs,
) -> List[Type["SqlalchemyBase"]]:
) -> List["SqlalchemyBase"]:
"""
List records with cursor-based pagination, ordering by created_at.
Cursor is an ID, but pagination is based on the cursor object's created_at value.
Args:
db_session: SQLAlchemy session
cursor: ID of the last item seen (for pagination)
start_date: Filter items after this date
end_date: Filter items before this date
limit: Maximum number of items to return
query_text: Text to search for
query_embedding: Vector to search for similar embeddings
ascending: Sort direction
tags: List of tags to filter by
match_all_tags: If True, return items matching all tags. If False, match any tag.
**kwargs: Additional filters to apply
"""
if start_date and end_date and start_date > end_date:
raise ValueError("start_date must be earlier than or equal to end_date")
@ -84,7 +98,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query = select(cls)
# Apply filtering logic
# Handle tag filtering if the model has tags
if tags and hasattr(cls, "tags"):
query = select(cls)
if match_all_tags:
# Match ALL tags - use subqueries
for tag in tags:
subquery = select(cls.tags.property.mapper.class_.agent_id).where(cls.tags.property.mapper.class_.tag == tag)
query = query.filter(cls.id.in_(subquery))
else:
# Match ANY tag - use join and filter
query = (
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) # Deduplicate results
)
# Group by primary key and all necessary columns to avoid JSON comparison
query = query.group_by(cls.id)
# Apply filtering logic from kwargs
for key, value in kwargs.items():
column = getattr(cls, key)
if isinstance(value, (list, tuple, set)):
@ -98,9 +130,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if end_date:
query = query.filter(cls.created_at < end_date)
# Cursor-based pagination using created_at
# TODO: There is a really nasty race condition issue here with Sqlite
# TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
# Cursor-based pagination
if cursor_obj:
if ascending:
query = query.where(cls.created_at >= cursor_obj.created_at).where(
@ -111,40 +141,34 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
)
# Apply text search
# Text search
if query_text:
from sqlalchemy import func
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
# Apply embedding search (Passages)
# Embedding search (for Passages)
is_ordered = False
if query_embedding:
# check if embedding column exists. should only exist for passages
if not hasattr(cls, "embedding"):
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
from letta.settings import settings
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
from pgvector.sqlalchemy import Vector
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
else:
# SQLite with custom vector type
from sqlalchemy import func
query_embedding_binary = adapt_array(query_embedding)
query = query.order_by(
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
cls.created_at.asc(),
cls.id.asc()
func.cosine_distance(cls.embedding, query_embedding_binary).asc(), cls.created_at.asc(), cls.id.asc()
)
is_ordered = True
# Handle ordering and soft deletes
# Handle soft deletes
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
# Apply ordering by created_at
# Apply ordering
if not is_ordered:
if ascending:
query = query.order_by(cls.created_at, cls.id)
@ -164,7 +188,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
**kwargs,
) -> Type["SqlalchemyBase"]:
) -> "SqlalchemyBase":
"""The primary accessor for an ORM record.
Args:
db_session: the database session to use when retrieving the record
@ -207,7 +231,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
@ -221,7 +245,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
except DBAPIError as e:
self._handle_dbapi_error(e)
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
@ -245,7 +269,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
else:
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
if actor:
self._set_created_and_updated_by_fields(actor.id)
@ -388,14 +412,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
raise
@property
def __pydantic_model__(self) -> Type["BaseModel"]:
def __pydantic_model__(self) -> "BaseModel":
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
def to_pydantic(self) -> Type["BaseModel"]:
def to_pydantic(self) -> "BaseModel":
"""converts to the basic pydantic model counterpart"""
return self.__pydantic_model__.model_validate(self)
def to_record(self) -> Type["BaseModel"]:
def to_record(self) -> "BaseModel":
"""Deprecated accessor for to_pydantic"""
logger.warning("to_record is deprecated, use to_pydantic instead.")
return self.to_pydantic()
return self.to_pydantic()

View File

@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional
from sqlalchemy import JSON, String, UniqueConstraint, event
from sqlalchemy import JSON, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
# TODO everything in functions should live in this model
@ -11,7 +11,6 @@ from letta.schemas.tool import Tool as PydanticTool
if TYPE_CHECKING:
from letta.orm.organization import Organization
from letta.orm.tools_agents import ToolsAgents
class Tool(SqlalchemyBase, OrganizationMixin):
@ -42,20 +41,3 @@ class Tool(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
# Add event listener to update tool_name in ToolsAgents when Tool name changes
@event.listens_for(Tool, "before_update")
def update_tool_name_in_tools_agents(mapper, connection, target):
    """Propagate a changed Tool.name to the denormalized ToolsAgents.tool_name column.

    Registered as a SQLAlchemy ``before_update`` listener on ``Tool``; it is a
    no-op whenever the ``name`` attribute has no pending change in this flush.
    """
    name_history = target._sa_instance_state.get_history("name", passive=True)
    if name_history.has_changes():
        # Imported lazily to avoid a circular import at module load time.
        from letta.orm.tools_agents import ToolsAgents

        sync_stmt = (
            ToolsAgents.__table__.update()
            .where(ToolsAgents.tool_id == target.id)
            .values(tool_name=target.name)
        )
        connection.execute(sync_stmt)

View File

@ -1,32 +1,15 @@
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
from letta.orm import Base
class ToolsAgents(SqlalchemyBase):
class ToolsAgents(Base):
"""Agents can have one or many tools associated with them."""
__tablename__ = "tools_agents"
__pydantic_model__ = PydanticToolsAgents
__table_args__ = (
UniqueConstraint(
"agent_id",
"tool_name",
name="unique_tool_per_agent",
),
ForeignKeyConstraint(
["tool_id"],
["tools.id"],
name="fk_tool_id",
),
)
__table_args__ = (UniqueConstraint("agent_id", "tool_id", name="unique_agent_tool"),)
# Each agent must have unique tool names
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
# relationships
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
tool_id: Mapped[str] = mapped_column(String, ForeignKey("tools.id", ondelete="CASCADE"), primary_key=True)

View File

@ -20,10 +20,9 @@ class User(SqlalchemyBase, OrganizationMixin):
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
jobs: Mapped[List["Job"]] = relationship(
"Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan"
)
# TODO: Add this back later potentially
# agents: Mapped[List["Agent"]] = relationship(
# "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user."
# )
# tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.")

View File

@ -1,13 +1,11 @@
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.schemas.block import CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_base import LettaBase
from letta.schemas.letta_base import OrmMetadataBase
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import Memory
from letta.schemas.message import Message, MessageCreate
@ -15,15 +13,15 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule
from letta.utils import create_random_username
class BaseAgent(LettaBase, validate_assignment=True):
class BaseAgent(OrmMetadataBase, validate_assignment=True):
__id_prefix__ = "agent"
description: Optional[str] = Field(None, description="The description of the agent.")
# metadata
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
user_id: Optional[str] = Field(None, description="The user id of the agent.")
class AgentType(str, Enum):
@ -38,37 +36,7 @@ class AgentType(str, Enum):
chat_only_agent = "chat_only_agent"
class PersistedAgentState(BaseAgent, validate_assignment=True):
    """Agent state in the shape stored by the ORM.

    NOTE: this has been changed to represent the data stored in the ORM, NOT
    what is passed around internally or returned to the user.
    """

    id: str = BaseAgent.generate_id_field()
    name: str = Field(..., description="The name of the agent.")
    # NOTE(review): default_factory=datetime.now produces a naive local timestamp — confirm UTC is not expected
    created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now)

    # in-context memory
    message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")

    # tools
    # TODO: move to ORM mapping
    tool_names: List[str] = Field(..., description="The tools used by the agent.")

    # tool rules
    tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")

    # system prompt
    system: str = Field(..., description="The system prompt used by the agent.")

    # agent configuration
    agent_type: AgentType = Field(..., description="The type of agent.")

    # llm information
    llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
    embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")

    class Config:
        arbitrary_types_allowed = True
        validate_assignment = True
class AgentState(PersistedAgentState):
class AgentState(BaseAgent):
"""
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@ -86,42 +54,53 @@ class AgentState(PersistedAgentState):
"""
# NOTE: this is what is returned to the client and also what is used to initialize `Agent`
id: str = BaseAgent.generate_id_field()
name: str = Field(..., description="The name of the agent.")
# tool rules
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
# in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")
# agent configuration
agent_type: AgentType = Field(..., description="The type of agent.")
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
# This is an object representing the in-process state of a running `Agent`
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.")
memory: Memory = Field(..., description="The in-context memory of the agent.")
tools: List[Tool] = Field(..., description="The tools used by the agent.")
sources: List[Source] = Field(..., description="The sources used by the agent.")
tags: List[str] = Field(..., description="The tags associated with the agent.")
# TODO: add in context message objects
def to_persisted_agent_state(self) -> PersistedAgentState:
# turn back into persisted agent
data = self.model_dump()
del data["memory"]
del data["tools"]
del data["sources"]
del data["tags"]
return PersistedAgentState(**data)
class CreateAgent(BaseAgent): #
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.")
# memory creation
memory_blocks: List[CreateBlock] = Field(
# [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory."
...,
description="The blocks to create in the agent's in-context memory.",
)
tools: List[str] = Field(BASE_TOOLS + BASE_MEMORY_TOOLS, description="The tools used by the agent.")
# TODO: This is a legacy field and should be removed ASAP to force `tool_ids` usage
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.")
source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.")
block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.")
tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: AgentType = Field(AgentType.memgpt_agent, description="The type of agent.")
agent_type: AgentType = Field(default_factory=lambda: AgentType.memgpt_agent, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
@ -129,6 +108,7 @@ class CreateAgent(BaseAgent): #
initial_message_sequence: Optional[List[MessageCreate]] = Field(
None, description="The initial set of messages to put in the agent's in-context memory."
)
include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
@field_validator("name")
@classmethod
@ -156,18 +136,21 @@ class CreateAgent(BaseAgent): #
return name
class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
class UpdateAgent(BaseAgent):
name: Optional[str] = Field(None, description="The name of the agent.")
tool_names: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.")
source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.")
block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# TODO: determine if these should be editable via this schema?
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
class Config:
extra = "ignore" # Ignores extra fields
class AgentStepResponse(BaseModel):
    """Payload describing the outcome of one agent step."""

    messages: List[Message] = Field(..., description="The messages generated during the agent's step.")

View File

@ -1,33 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class AgentsTagsBase(LettaBase):
    # prefix for generated ids (consumed via AgentsTagsBase.generate_id_field below)
    __id_prefix__ = "agents_tags"
class AgentsTags(AgentsTagsBase):
    """
    Schema representing the relationship between tags and agents.

    Parameters:
        agent_id (str): The ID of the associated agent.
        tag (str): The name of the tag.
        created_at (datetime): The date this relationship was created.
        updated_at (datetime): The date this relationship was last updated.
        is_deleted (bool): Whether this tag is deleted or not.
    """

    id: str = AgentsTagsBase.generate_id_field()
    agent_id: str = Field(..., description="The ID of the associated agent.")
    tag: str = Field(..., description="The name of the tag.")
    created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
    updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
    is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")
class AgentsTagsCreate(AgentsTagsBase):
    """Request payload for attaching a tag to an agent."""

    tag: str = Field(..., description="The tag name.")

View File

@ -1,32 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class BlocksAgentsBase(LettaBase):
    # prefix for generated ids (consumed via BlocksAgentsBase.generate_id_field below)
    __id_prefix__ = "blocks_agents"
class BlocksAgents(BlocksAgentsBase):
    """
    Schema representing the relationship between blocks and agents.

    Parameters:
        agent_id (str): The ID of the associated agent.
        block_id (str): The ID of the associated block.
        block_label (str): The label of the block.
        created_at (datetime): The date this relationship was created.
        updated_at (datetime): The date this relationship was last updated.
        is_deleted (bool): Whether this block-agent relationship is deleted or not.
    """

    id: str = BlocksAgentsBase.generate_id_field()
    agent_id: str = Field(..., description="The ID of the associated agent.")
    block_id: str = Field(..., description="The ID of the associated block.")
    block_label: str = Field(..., description="The label of the block.")
    created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
    updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
    is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")

View File

@ -87,7 +87,7 @@ class Memory(BaseModel, validate_assignment=True):
Template(prompt_template)
# Validate compatibility with current memory structure
test_render = Template(prompt_template).render(blocks=self.blocks)
Template(prompt_template).render(blocks=self.blocks)
# If we get here, the template is valid and compatible
self.prompt_template = prompt_template
@ -213,6 +213,7 @@ class ChatMemory(BasicBlockMemory):
human (str): The starter value for the human block.
limit (int): The character limit for each block.
"""
# TODO: Should these be CreateBlocks?
super().__init__(blocks=[Block(value=persona, limit=limit, label="persona"), Block(value=human, limit=limit, label="human")])

View File

@ -1,32 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class ToolsAgentsBase(LettaBase):
    # prefix for generated ids (consumed via ToolsAgentsBase.generate_id_field below)
    __id_prefix__ = "tools_agents"
class ToolsAgents(ToolsAgentsBase):
    """
    Schema representing the relationship between tools and agents.

    Parameters:
        agent_id (str): The ID of the associated agent.
        tool_id (str): The ID of the associated tool.
        tool_name (str): The name of the tool.
        created_at (datetime): The date this relationship was created.
        updated_at (datetime): The date this relationship was last updated.
        is_deleted (bool): Whether this tool-agent relationship is deleted or not.
    """

    id: str = ToolsAgentsBase.generate_id_field()
    agent_id: str = Field(..., description="The ID of the associated agent.")
    tool_id: str = Field(..., description="The ID of the associated tool.")
    tool_name: str = Field(..., description="The name of the tool.")
    created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
    updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
    is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.")

View File

@ -25,9 +25,6 @@ from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.assistants.assistants import (
router as openai_assistants_router,
)
from letta.server.rest_api.routers.openai.assistants.threads import (
router as openai_threads_router,
)
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import (
router as openai_chat_completions_router,
)
@ -215,7 +212,6 @@ def create_application() -> "FastAPI":
# openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
@ -236,7 +232,6 @@ def create_application() -> "FastAPI":
@app.on_event("shutdown")
def on_shutdown():
global server
server.save_agents()
# server = None
return app

View File

@ -1,338 +0,0 @@
import uuid
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import CreateAgent
from letta.schemas.enums import MessageRole
from letta.schemas.message import Message
from letta.schemas.openai.openai import (
MessageFile,
OpenAIMessage,
OpenAIRun,
OpenAIRunStep,
OpenAIThread,
Text,
)
from letta.server.rest_api.routers.openai.assistants.schemas import (
CreateMessageRequest,
CreateRunRequest,
CreateThreadRequest,
CreateThreadRunRequest,
DeleteThreadResponse,
ListMessagesResponse,
ModifyMessageRequest,
ModifyRunRequest,
ModifyThreadRequest,
OpenAIThread,
SubmitToolOutputsToRunRequest,
)
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
if TYPE_CHECKING:
from letta.utils import get_utc_time
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
router = APIRouter(prefix="/v1/threads", tags=["threads"])
@router.post("/", response_model=OpenAIThread)
def create_thread(
request: CreateThreadRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
# TODO: use requests.description and requests.metadata fields
# TODO: handle requests.file_ids and requests.tools
# TODO: eventually allow request to override embedding/llm model
actor = server.get_user_or_default(user_id=user_id)
print("Create thread/agent", request)
# create a letta agent
agent_state = server.create_agent(
request=CreateAgent(),
user_id=actor.id,
)
# TODO: insert messages into recall memory
return OpenAIThread(
id=str(agent_state.id),
created_at=int(agent_state.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def retrieve_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
assert agent is not None
return OpenAIThread(
id=str(agent.id),
created_at=int(agent.created_at.timestamp()),
metadata={}, # TODO add metadata?
)
@router.get("/{thread_id}", response_model=OpenAIThread)
def modify_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: ModifyThreadRequest = Body(...),
):
# TODO: add agent metadata so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.delete("/{thread_id}", response_model=DeleteThreadResponse)
def delete_thread(
thread_id: str = Path(..., description="The unique identifier of the thread."),
):
# TODO: delete agent
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
def create_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateMessageRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
agent_id = thread_id
# create message object
message = Message(
user_id=actor.id,
agent_id=agent_id,
role=MessageRole(request.role),
text=request.content,
model=None,
tool_calls=None,
tool_call_id=None,
name=None,
)
agent = server.load_agent(agent_id=agent_id)
# add message to agent
agent._append_to_messages([message])
openai_message = OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
return openai_message
@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
def list_messages(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many messages to retrieve."),
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id)
after_uuid = after if before else None
before_uuid = before if before else None
agent_id = thread_id
reverse = True if (order == "desc") else False
json_messages = server.get_agent_recall_cursor(
user_id=actor.id,
agent_id=agent_id,
limit=limit,
after=after_uuid,
before=before_uuid,
order_by="created_at",
reverse=reverse,
)
assert isinstance(json_messages, List)
assert all([isinstance(message, Message) for message in json_messages])
assert isinstance(json_messages[0], Message)
print(json_messages[0].text)
# convert to openai style messages
openai_messages = []
for message in json_messages:
assert isinstance(message, Message)
openai_messages.append(
OpenAIMessage(
id=str(message.id),
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=str(message.role),
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
)
print("MESSAGES", openai_messages)
# TODO: cast back to message objects
return ListMessagesResponse(messages=openai_messages)
@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def retrieve_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
server: SyncServer = Depends(get_letta_server),
):
agent_id = thread_id
message = server.get_agent_message(agent_id=agent_id, message_id=message_id)
assert message is not None
return OpenAIMessage(
id=message_id,
created_at=int(message.created_at.timestamp()),
content=[Text(text=(message.text if message.text else ""))],
role=message.role,
thread_id=str(message.agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
# TODO(sarah) fill in?
run_id=None,
file_ids=None,
metadata=None,
# file_ids=message.file_ids,
# metadata=message.metadata,
)
@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
def retrieve_message_file(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
file_id: str = Path(..., description="The unique identifier of the file."),
):
# TODO: implement?
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
def modify_message(
thread_id: str = Path(..., description="The unique identifier of the thread."),
message_id: str = Path(..., description="The unique identifier of the message."),
request: ModifyMessageRequest = Body(...),
):
# TODO: add metada field to message so this can be modified
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
def create_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
request: CreateRunRequest = Body(...),
server: SyncServer = Depends(get_letta_server),
):
# TODO: add request.instructions as a message?
agent_id = thread_id
# TODO: override preset of agent with request.assistant_id
agent = server.load_agent(agent_id=agent_id)
agent.inner_step(messages=[]) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp())
return OpenAIRun(
id=run_id,
created_at=create_time,
thread_id=str(agent_id),
assistant_id=DEFAULT_PRESET, # TODO: update this
status="completed", # TODO: eventaully allow offline execution
expires_at=create_time,
model=agent.agent_state.llm_config.model,
instructions=request.instructions,
)
@router.post("/runs", tags=["runs"], response_model=OpenAIRun)
def create_thread_and_run(
request: CreateThreadRunRequest = Body(...),
):
# TODO: add a bunch of messages and execute
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
def list_runs(
thread_id: str = Path(..., description="The unique identifier of the thread."),
limit: int = Query(1000, description="How many runs to retrieve."),
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
def list_run_steps(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
limit: int = Query(1000, description="How many run steps to retrieve."),
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
):
# TODO: store run information in a DB so it can be returned here
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def retrieve_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
def retrieve_run_step(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
step_id: str = Path(..., description="The unique identifier of the run step."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
def modify_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: ModifyRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
def submit_tool_outputs_to_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
request: SubmitToolOutputsToRunRequest = Body(...),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
def cancel_run(
thread_id: str = Path(..., description="The unique identifier of the thread."),
run_id: str = Path(..., description="The unique identifier of the run."),
):
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")

View File

@ -36,7 +36,7 @@ async def create_chat_completion(
The bearer token will be used to identify the user.
The 'user' field in the completion_request should be set to the agent ID.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = completion_request.user
if agent_id is None:

View File

@ -17,7 +17,8 @@ from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import Field
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate
Block,
BlockUpdate,
@ -54,23 +55,38 @@ from letta.server.server import SyncServer
router = APIRouter(prefix="/agents", tags=["agents"])
# TODO: This should be paginated
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
name: Optional[str] = Query(None, description="Name of the agent"),
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
match_all_tags: bool = Query(
False,
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
user_id: Optional[str] = Header(None, alias="user_id"),
# Extract user_id from header, default to None if not present
):
"""
List all agents associated with a given user.
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
agents = server.list_agents(user_id=actor.id, tags=tags)
# TODO: move this logic to the ORM
if name:
agents = [a for a in agents if a.name == name]
# Use dictionary comprehension to build kwargs dynamically
kwargs = {
key: value
for key, value in {
"tags": tags,
"match_all_tags": match_all_tags,
"name": name,
}.items()
if value is not None
}
# Call list_agents with the dynamic kwargs
agents = server.agent_manager.list_agents(actor=actor, **kwargs)
return agents
@ -83,7 +99,7 @@ def get_agent_context_window(
"""
Retrieve the context window of a specific agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
@ -106,20 +122,20 @@ def create_agent(
"""
Create a new agent with the specified configuration.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.create_agent(agent, actor=actor)
@router.patch("/{agent_id}", response_model=AgentState, operation_id="update_agent")
def update_agent(
agent_id: str,
update_agent: UpdateAgentState = Body(...),
update_agent: UpdateAgent = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Update an exsiting agent"""
actor = server.get_user_or_default(user_id=user_id)
return server.update_agent(update_agent, actor=actor)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent(agent_id, update_agent, actor=actor)
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
@ -129,7 +145,7 @@ def get_tools_from_agent(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Get tools from an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)
@ -141,7 +157,7 @@ def add_tool_to_agent(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
@ -153,7 +169,7 @@ def remove_tool_from_agent(
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""Add tools to an existing agent"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
@ -166,13 +182,12 @@ def get_agent_state(
"""
Get the state of the agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
return server.get_agent_state(user_id=actor.id, agent_id=agent_id)
try:
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent")
@ -184,38 +199,37 @@ def delete_agent(
"""
Delete an agent.
"""
actor = server.get_user_or_default(user_id=user_id)
agent = server.get_agent(agent_id)
if not agent:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
server.delete_agent(user_id=actor.id, agent_id=agent_id)
return agent
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
return server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="get_agent_sources")
def get_agent_sources(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Get the sources associated with an agent.
"""
return server.list_attached_sources(agent_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/memory/messages", response_model=List[Message], operation_id="list_agent_in_context_messages")
def get_agent_in_context_messages(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the messages in the context of a specific agent.
"""
return server.get_in_context_messages(agent_id=agent_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_in_context_messages(agent_id=agent_id, actor=actor)
# TODO: remove? can also get with agent blocks
@ -223,13 +237,15 @@ def get_agent_in_context_messages(
def get_agent_memory(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the memory state of a specific agent.
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_agent_memory(agent_id=agent_id)
return server.get_agent_memory(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="get_agent_memory_block")
@ -242,10 +258,12 @@ def get_agent_memory_block(
"""
Retrieve a memory block from an agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label)
return server.block_manager.get_block_by_id(block_id, actor=actor)
try:
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.get("/{agent_id}/memory/block", response_model=List[Block], operation_id="get_agent_memory_blocks")
@ -257,9 +275,12 @@ def get_agent_memory_blocks(
"""
Retrieve the memory blocks of a specific agent.
"""
actor = server.get_user_or_default(user_id=user_id)
block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id)
return [server.block_manager.get_block_by_id(block_id, actor=actor) for block_id in block_ids]
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
return agent.memory.blocks
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
@ -272,16 +293,17 @@ def add_agent_memory_block(
"""
Creates a memory block and links it to the agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Copied from POST /blocks
# TODO: Should have block_manager accept only CreateBlock
# TODO: This will be possible once we move ID creation to the ORM
block_req = Block(**create_block.model_dump())
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
# Link the block to the agent
updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id)
return updated_memory
agent = server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=actor)
return agent.memory
@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label")
@ -296,56 +318,56 @@ def remove_agent_memory_block(
"""
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Unlink the block from the agent
updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label)
agent = server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
return updated_memory
return agent.memory
@router.patch("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label")
def update_agent_memory_block(
agent_id: str,
block_label: str,
update_block: BlockUpdate = Body(...),
block_update: BlockUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# get the block_id from the label
block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label)
# update the block
return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor)
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
return server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
def get_agent_recall_memory_summary(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the summary of the recall memory of a specific agent.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_recall_memory_summary(agent_id=agent_id)
return server.get_recall_memory_summary(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary, operation_id="get_agent_archival_memory_summary")
def get_agent_archival_memory_summary(
agent_id: str,
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Retrieve the summary of the archival memory of a specific agent.
"""
return server.get_archival_memory_summary(agent_id=agent_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_archival_memory_summary(agent_id=agent_id, actor=actor)
@router.get("/{agent_id}/archival", response_model=List[Passage], operation_id="list_agent_archival_memory")
@ -360,7 +382,7 @@ def get_agent_archival_memory(
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# TODO need to add support for non-postgres here
# chroma will throw:
@ -369,7 +391,7 @@ def get_agent_archival_memory(
return server.get_agent_archival_cursor(
user_id=actor.id,
agent_id=agent_id,
cursor=after, # TODO: deleting before, after. is this expected?
cursor=after, # TODO: deleting before, after. is this expected?
limit=limit,
)
@ -384,9 +406,9 @@ def insert_agent_archival_memory(
"""
Insert a memory into an agent's archival memory store.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
# TODO(ethan): query or path parameter for memory_id?
@ -402,9 +424,9 @@ def delete_agent_archival_memory(
"""
Delete a memory from an agent's archival memory store.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
@ -429,7 +451,7 @@ def get_agent_messages(
"""
Retrieve message history for an agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.get_agent_recall_cursor(
user_id=actor.id,
@ -449,11 +471,13 @@ def update_message(
message_id: str,
request: MessageUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Update the details of a message associated with an agent.
"""
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor)
@router.post(
@ -471,11 +495,11 @@ async def send_message(
Process a user message and return the agent's response.
This endpoint accepts a message from a user and processes it through the agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
actor=actor,
messages=request.messages,
stream_steps=False,
stream_tokens=False,
@ -511,11 +535,11 @@ async def send_message_streaming(
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
actor=actor,
messages=request.messages,
stream_steps=True,
stream_tokens=request.stream_tokens,
@ -531,7 +555,6 @@ async def process_message_background(
server: SyncServer,
actor: User,
agent_id: str,
user_id: str,
messages: list,
assistant_message_tool_name: str,
assistant_message_tool_kwarg: str,
@ -542,7 +565,7 @@ async def process_message_background(
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=user_id,
actor=actor,
messages=messages,
stream_steps=False, # NOTE(matt)
stream_tokens=False,
@ -585,7 +608,7 @@ async def send_message_async(
Asynchronously process a user message and return a job ID.
The actual processing happens in the background, and the status can be checked using the job ID.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# Create a new job
job = Job(
@ -605,7 +628,6 @@ async def send_message_async(
server=server,
actor=actor,
agent_id=agent_id,
user_id=actor.id,
messages=request.messages,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
@ -618,7 +640,7 @@ async def send_message_async(
async def send_message_to_agent(
server: SyncServer,
agent_id: str,
user_id: str,
actor: User,
# role: MessageRole,
messages: Union[List[Message], List[MessageCreate]],
stream_steps: bool,
@ -645,8 +667,7 @@ async def send_message_to_agent(
# Get the generator object off of the agent's streaming interface
# This will be attached to the POST SSE request used under-the-hood
# letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
# Disable token streaming if not OpenAI
# TODO: cleanup this logic
@ -685,7 +706,7 @@ async def send_message_to_agent(
task = asyncio.create_task(
asyncio.to_thread(
server.send_messages,
user_id=user_id,
actor=actor,
agent_id=agent_id,
messages=messages,
interface=streaming_interface,

View File

@ -1,10 +1,9 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.memory import Memory
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@ -23,7 +22,7 @@ def list_blocks(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
@ -33,7 +32,7 @@ def create_block(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
block = Block(**create_block.model_dump())
return server.block_manager.create_or_update_block(actor=actor, block=block)
@ -41,12 +40,12 @@ def create_block(
@router.patch("/{block_id}", response_model=Block, operation_id="update_memory_block")
def update_block(
block_id: str,
update_block: BlockUpdate = Body(...),
block_update: BlockUpdate = Body(...),
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.get_user_or_default(user_id=user_id)
return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
@router.delete("/{block_id}", response_model=Block, operation_id="delete_memory_block")
@ -55,7 +54,7 @@ def delete_block(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.block_manager.delete_block(block_id=block_id, actor=actor)
@ -66,7 +65,7 @@ def get_block(
user_id: Optional[str] = Header(None, alias="user_id"),
):
print("call get block", block_id)
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
if block is None:
@ -76,7 +75,7 @@ def get_block(
raise HTTPException(status_code=404, detail="Block not found")
@router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block")
@router.patch("/{block_id}/attach", response_model=None, status_code=204, operation_id="link_agent_memory_block")
def link_agent_memory_block(
block_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
@ -86,17 +85,16 @@ def link_agent_memory_block(
"""
Link a memory block to an agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
if block is None:
raise HTTPException(status_code=404, detail="Block not found")
server.blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block.label)
return block
try:
server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
return Response(status_code=204)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block")
@router.patch("/{block_id}/detach", response_model=None, status_code=204, operation_id="unlink_agent_memory_block")
def unlink_agent_memory_block(
block_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
@ -106,11 +104,10 @@ def unlink_agent_memory_block(
"""
Unlink a memory block from an agent
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
if block is None:
raise HTTPException(status_code=404, detail="Block not found")
# Link the block to the agent
server.blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)
return block
try:
server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
return Response(status_code=204)
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))

View File

@ -20,7 +20,7 @@ def list_jobs(
"""
List all jobs.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
# TODO: add filtering by status
jobs = server.job_manager.list_jobs(actor=actor)
@ -40,7 +40,7 @@ def list_active_jobs(
"""
List all active jobs.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
@ -54,7 +54,7 @@ def get_job(
"""
Get the status of a job.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
@ -71,7 +71,7 @@ def delete_job(
"""
Delete a job by its job_id.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)

View File

@ -25,7 +25,7 @@ def create_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
@ -35,7 +35,7 @@ def create_default_e2b_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
@ -44,7 +44,7 @@ def create_default_local_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
@ -55,7 +55,7 @@ def update_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
@ -65,7 +65,7 @@ def delete_sandbox_config(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
@ -76,7 +76,7 @@ def list_sandbox_configs(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor)
@ -90,7 +90,7 @@ def create_sandbox_env_var(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
@ -101,7 +101,7 @@ def update_sandbox_env_var(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
@ -111,7 +111,7 @@ def delete_sandbox_env_var(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
@ -123,5 +123,5 @@ def list_sandbox_env_vars(
server: SyncServer = Depends(get_letta_server),
user_id: str = Depends(get_user_id),
):
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, cursor=cursor)

View File

@ -36,7 +36,7 @@ def get_source(
"""
Get all sources
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
@ -53,7 +53,7 @@ def get_source_id_by_name(
"""
Get a source by name
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
if not source:
@ -69,7 +69,7 @@ def list_sources(
"""
List all data sources created by a user.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.list_all_sources(actor=actor)
@ -83,7 +83,7 @@ def create_source(
"""
Create a new data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = Source(**source_create.model_dump())
return server.source_manager.create_source(source=source, actor=actor)
@ -99,7 +99,7 @@ def update_source(
"""
Update the name or documentation of an existing data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
@ -114,7 +114,7 @@ def delete_source(
"""
Delete a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.delete_source(source_id=source_id, actor=actor)
@ -129,7 +129,7 @@ def attach_source_to_agent(
"""
Attach a data source to an existing agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
@ -147,7 +147,7 @@ def detach_source_from_agent(
"""
Detach a data source from an existing agent.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
@ -163,7 +163,7 @@ def upload_file_to_source(
"""
Upload a file to a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
@ -197,7 +197,7 @@ def list_passages(
"""
List all passages associated with a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
return passages
@ -213,7 +213,7 @@ def list_files_from_source(
"""
List paginated files associated with a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)
@ -229,7 +229,7 @@ def delete_file_from_source(
"""
Delete a data source.
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
if deleted_file is None:

View File

@ -25,7 +25,7 @@ def delete_tool(
"""
Delete a tool by name
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
@ -38,7 +38,7 @@ def get_tool(
"""
Get a tool by ID
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
if tool is None:
# return 404 error
@ -55,7 +55,7 @@ def get_tool_id(
"""
Get a tool ID by name
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
if tool:
return tool.id
@ -74,7 +74,7 @@ def list_tools(
Get a list of all tools available to agents belonging to the org of the user
"""
try:
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit)
except Exception as e:
# Log or print the full exception here for debugging
@ -92,7 +92,7 @@ def create_tool(
Create a new tool
"""
try:
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
tool = Tool(**request.model_dump())
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
except UniqueConstraintViolationError as e:
@ -124,7 +124,7 @@ def upsert_tool(
Create or update a tool
"""
try:
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
return tool
except UniqueConstraintViolationError as e:
@ -147,7 +147,7 @@ def update_tool(
"""
Update an existing tool
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor)
@ -159,7 +159,7 @@ def add_base_tools(
"""
Add base tools
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)
@ -173,7 +173,7 @@ def add_base_tools(
# """
# Run an existing tool on provided arguments
# """
# actor = server.get_user_or_default(user_id=user_id)
# actor = server.user_manager.get_user_or_default(user_id=user_id)
# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
@ -187,7 +187,7 @@ def run_tool_from_source(
"""
Attempt to build a tool from source, then run it on the provided arguments
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
return server.run_tool_from_source(
@ -220,7 +220,7 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id:
"""
Get a list of all Composio apps
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
composio_api_key = get_composio_key(server, actor=actor)
return server.get_composio_apps(api_key=composio_api_key)
@ -234,7 +234,7 @@ def list_composio_actions_by_app(
"""
Get a list of all Composio actions for a specific app
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
composio_api_key = get_composio_key(server, actor=actor)
return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key)
@ -248,7 +248,7 @@ def add_composio_tool(
"""
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
"""
actor = server.get_user_or_default(user_id=user_id)
actor = server.user_manager.get_user_or_default(user_id=user_id)
composio_api_key = get_composio_key(server, actor=actor)
try:

File diff suppressed because it is too large Load Diff

View File

@ -19,11 +19,6 @@ class WebSocketServer:
self.server = SyncServer(default_interface=self.interface)
def shutdown_server(self):
try:
self.server.save_agents()
print(f"Saved agents")
except Exception as e:
print(f"Saving agents failed with: {e}")
try:
self.interface.close()
print(f"Closed the WS interface")

View File

@ -0,0 +1,405 @@
from typing import Dict, List, Optional
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.orm import Agent as AgentModel
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import Tool as ToolModel
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_process_relationship,
_process_tags,
derive_system_message,
)
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.utils import enforce_types
# Agent Manager Class
class AgentManager:
    """Manager class to handle business logic related to Agents.

    All operations open a fresh SQLAlchemy session via ``self.session_maker``
    and return Pydantic models (never ORM instances) to callers.
    """
    def __init__(self):
        # Deferred import avoids a circular import with the server module.
        from letta.server.server import db_context
        self.session_maker = db_context
        self.block_manager = BlockManager()
        self.tool_manager = ToolManager()
        self.source_manager = SourceManager()
    # ======================================================================================================================
    # Basic CRUD operations
    # ======================================================================================================================
    @enforce_types
    def create_agent(
        self,
        agent_create: CreateAgent,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Create a new agent from a CreateAgent request.

        Resolves memory blocks and tool names to IDs first, then delegates
        to ``_create_agent`` for the actual persistence.

        Args:
            agent_create: Creation request (memory blocks, tools, sources, tags, ...).
            actor: User performing the action.

        Returns:
            PydanticAgentState: The newly created agent.
        """
        system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
        # create blocks (note: blocks cannot be linked until the agent_id is created)
        block_ids = list(agent_create.block_ids or [])  # Create a local copy to avoid modifying the original
        for create_block in agent_create.memory_blocks:
            block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor)
            block_ids.append(block.id)
        # TODO: Remove this block once we deprecate the legacy `tools` field
        # create passed in `tools`
        tool_names = []
        if agent_create.include_base_tools:
            tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
        if agent_create.tools:
            tool_names.extend(agent_create.tools)
        tool_ids = agent_create.tool_ids or []
        for tool_name in tool_names:
            # Names that do not resolve to an existing tool are silently skipped.
            tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
            if tool:
                tool_ids.append(tool.id)
        # Remove duplicates
        tool_ids = list(set(tool_ids))
        return self._create_agent(
            name=agent_create.name,
            system=system,
            agent_type=agent_create.agent_type,
            llm_config=agent_create.llm_config,
            embedding_config=agent_create.embedding_config,
            block_ids=block_ids,
            tool_ids=tool_ids,
            source_ids=agent_create.source_ids or [],
            tags=agent_create.tags or [],
            description=agent_create.description,
            metadata_=agent_create.metadata_,
            tool_rules=agent_create.tool_rules,
            actor=actor,
        )
    @enforce_types
    def _create_agent(
        self,
        actor: PydanticUser,
        name: str,
        system: str,
        agent_type: AgentType,
        llm_config: LLMConfig,
        embedding_config: EmbeddingConfig,
        block_ids: List[str],
        tool_ids: List[str],
        source_ids: List[str],
        tags: List[str],
        description: Optional[str] = None,
        metadata_: Optional[Dict] = None,
        tool_rules: Optional[List[PydanticToolRule]] = None,
    ) -> PydanticAgentState:
        """Create a new agent.

        Persists the agent row plus its tool/source/block relationships and
        tags in a single session. All relationship IDs must already exist
        (``_process_relationship`` raises if any are missing).
        """
        with self.session_maker() as session:
            # Prepare the agent data
            data = {
                "name": name,
                "system": system,
                "agent_type": agent_type,
                "llm_config": llm_config,
                "embedding_config": embedding_config,
                "organization_id": actor.organization_id,
                "description": description,
                "metadata_": metadata_,
                "tool_rules": tool_rules,
            }
            # Create the new agent using SqlalchemyBase.create
            new_agent = AgentModel(**data)
            _process_relationship(session, new_agent, "tools", ToolModel, tool_ids, replace=True)
            _process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
            _process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
            _process_tags(new_agent, tags, replace=True)
            new_agent.create(session, actor=actor)
            # Convert to PydanticAgentState and return
            return new_agent.to_pydantic()
    @enforce_types
    def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
        """
        Update an existing agent.
        Args:
            agent_id: The ID of the agent to update.
            agent_update: UpdateAgent object containing the updated fields.
            actor: User performing the action.
        Returns:
            PydanticAgentState: The updated agent as a Pydantic model.
        """
        with self.session_maker() as session:
            # Retrieve the existing agent
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # Update scalar fields directly
            # NOTE: fields set to None in agent_update are treated as "not provided"
            # and left unchanged — there is no way to null out a field here.
            scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"}
            for field in scalar_fields:
                value = getattr(agent_update, field, None)
                if value is not None:
                    setattr(agent, field, value)
            # Update relationships using _process_relationship and _process_tags
            if agent_update.tool_ids is not None:
                _process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True)
            if agent_update.source_ids is not None:
                _process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True)
            if agent_update.block_ids is not None:
                _process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
            if agent_update.tags is not None:
                _process_tags(agent, agent_update.tags, replace=True)
            # Commit and refresh the agent
            agent.update(session, actor=actor)
            # Convert to PydanticAgentState and return
            return agent.to_pydantic()
    @enforce_types
    def list_agents(
        self,
        actor: PydanticUser,
        tags: Optional[List[str]] = None,
        match_all_tags: bool = False,
        cursor: Optional[str] = None,
        limit: Optional[int] = 50,
        **kwargs,
    ) -> List[PydanticAgentState]:
        """
        List agents that have the specified tags.

        Extra ``kwargs`` are forwarded to ``AgentModel.list`` as additional
        filters. Results are scoped to the actor's organization when an
        actor is given.
        """
        with self.session_maker() as session:
            agents = AgentModel.list(
                db_session=session,
                tags=tags,
                match_all_tags=match_all_tags,
                cursor=cursor,
                limit=limit,
                organization_id=actor.organization_id if actor else None,
                **kwargs,
            )
            return [agent.to_pydantic() for agent in agents]
    @enforce_types
    def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its ID."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            return agent.to_pydantic()
    @enforce_types
    def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its name."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
            return agent.to_pydantic()
    @enforce_types
    def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        """
        Deletes an agent and its associated relationships.
        Ensures proper permission checks and cascades where applicable.
        Args:
            agent_id: ID of the agent to be deleted.
            actor: User performing the action.
        Returns:
            PydanticAgentState: The deleted agent state
        """
        with self.session_maker() as session:
            # Retrieve the agent
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # TODO: @mindy delete this piece when we have a proper passages/sources implementation
            # TODO: This is done very hacky on purpose
            # TODO: 1000 limit is also wack
            passage_manager = PassageManager()
            passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
            # Snapshot the state before the hard delete so we can return it.
            agent_state = agent.to_pydantic()
            agent.hard_delete(session)
            return agent_state
    # ======================================================================================================================
    # Source Management
    # ======================================================================================================================
    @enforce_types
    def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
        """
        Attaches a source to an agent.
        Args:
            agent_id: ID of the agent to attach the source to
            source_id: ID of the source to attach
            actor: User performing the action
        Raises:
            ValueError: If either agent or source doesn't exist
            IntegrityError: If the source is already attached to the agent
        """
        with self.session_maker() as session:
            # Verify both agent and source exist and user has permission to access them
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # The _process_relationship helper already handles duplicate checking via unique constraint
            _process_relationship(
                session=session,
                agent=agent,
                relationship_name="sources",
                model_class=SourceModel,
                item_ids=[source_id],
                allow_partial=False,
                replace=False,  # Extend existing sources rather than replace
            )
            # Commit the changes
            agent.update(session, actor=actor)
    @enforce_types
    def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
        """
        Lists all sources attached to an agent.
        Args:
            agent_id: ID of the agent to list sources for
            actor: User performing the action
        Returns:
            List[PydanticSource]: The sources attached to the agent
        """
        with self.session_maker() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # Use the lazy-loaded relationship to get sources
            return [source.to_pydantic() for source in agent.sources]
    @enforce_types
    def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
        """
        Detaches a source from an agent.

        No error is raised if the source was not attached (the filter below
        simply leaves the relationship unchanged).
        Args:
            agent_id: ID of the agent to detach the source from
            source_id: ID of the source to detach
            actor: User performing the action
        """
        with self.session_maker() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # Remove the source from the relationship
            agent.sources = [s for s in agent.sources if s.id != source_id]
            # Commit the changes
            agent.update(session, actor=actor)
    # ======================================================================================================================
    # Block management
    # ======================================================================================================================
    @enforce_types
    def get_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        actor: PydanticUser,
    ) -> PydanticBlock:
        """Gets a block attached to an agent by its label.

        Raises:
            NoResultFound: If no attached block carries the given label.
        """
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            for block in agent.core_memory:
                if block.label == block_label:
                    return block.to_pydantic()
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
    @enforce_types
    def update_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        new_block_id: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Updates which block is assigned to a specific label for an agent.

        The new block's own label must already equal ``block_label``;
        otherwise a ValueError is raised before any mutation happens.
        """
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)
            if new_block.label != block_label:
                raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'")
            # Remove old block with this label if it exists
            agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
            # Add new block
            agent.core_memory.append(new_block)
            agent.update(session, actor=actor)
            return agent.to_pydantic()
    @enforce_types
    def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Attaches a block to an agent."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            agent.core_memory.append(block)
            agent.update(session, actor=actor)
            return agent.to_pydantic()
    @enforce_types
    def detach_block(
        self,
        agent_id: str,
        block_id: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block from an agent.

        Raises:
            NoResultFound: If the block was not attached to the agent.
        """
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # Compare lengths before/after the filter to detect a no-op detach.
            original_length = len(agent.core_memory)
            agent.core_memory = [b for b in agent.core_memory if b.id != block_id]
            if len(agent.core_memory) == original_length:
                raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
            agent.update(session, actor=actor)
            return agent.to_pydantic()
    @enforce_types
    def detach_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block with the specified label from an agent.

        Raises:
            NoResultFound: If no attached block carries the given label.
        """
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            # Compare lengths before/after the filter to detect a no-op detach.
            original_length = len(agent.core_memory)
            agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
            if len(agent.core_memory) == original_length:
                raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'")
            agent.update(session, actor=actor)
            return agent.to_pydantic()

View File

@ -1,64 +0,0 @@
from typing import List
from letta.orm.agents_tags import AgentsTags as AgentsTagsModel
from letta.orm.errors import NoResultFound
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class AgentsTagsManager:
    """Manager class to handle business logic related to Tags."""

    def __init__(self):
        # Deferred import avoids a circular import with the server module.
        from letta.server.server import db_context

        self.session_maker = db_context

    @enforce_types
    def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags:
        """Associate a tag with an agent; returns the existing record if the pair is already present."""
        with self.session_maker() as session:
            try:
                # Reuse the existing association when the agent already carries this tag.
                existing_record = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
                return existing_record.to_pydantic()
            except NoResultFound:
                payload = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True)
                created_record = AgentsTagsModel(**payload, organization_id=actor.organization_id)
                created_record.create(session, actor=actor)
                return created_record.to_pydantic()

    @enforce_types
    def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser):
        """Remove every tag from an agent. This is a permanent hard delete."""
        for current_tag in self.get_tags_for_agent(agent_id=agent_id, actor=actor):
            self.delete_tag_from_agent(agent_id=agent_id, tag=current_tag, actor=actor)

    @enforce_types
    def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser):
        """Remove a single tag from an agent; raises ValueError if the association does not exist."""
        with self.session_maker() as session:
            try:
                # Look up the association row and permanently remove it.
                association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
                association.hard_delete(session, actor=actor)
            except NoResultFound:
                raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.")

    @enforce_types
    def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]:
        """Retrieve all agent IDs associated with a specific tag."""
        with self.session_maker() as session:
            matches = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id)
            agent_ids = []
            for row in matches:
                agent_ids.append(row.agent_id)
            return agent_ids

    @enforce_types
    def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]:
        """Retrieve all tags associated with a specific agent."""
        with self.session_maker() as session:
            matches = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id)
            tag_values = []
            for row in matches:
                tag_values.append(row.tag)
            return tag_values

View File

@ -7,7 +7,6 @@ from letta.schemas.block import Block
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.utils import enforce_types, list_human_files, list_persona_files
@ -37,33 +36,17 @@ class BlockManager:
@enforce_types
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
"""Update a block by its ID with the given BlockUpdate object."""
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
blocks_agents_manager = BlocksAgentsManager()
agent_ids = []
if block_update.label:
agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id)
for agent_id in agent_ids:
blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)
# Safety check for block
with self.session_maker() as session:
# Update block
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(block, key, value)
try:
block.to_pydantic()
except Exception as e:
# invalid pydantic model
raise ValueError(f"Failed to create pydantic model: {e}")
block.update(db_session=session, actor=actor)
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
if block_update.label:
for agent_id in agent_ids:
blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label)
return block.to_pydantic()
return block.to_pydantic()
@enforce_types
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
@ -111,6 +94,15 @@ class BlockManager:
except NoResultFound:
return None
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
# TODO: We can do this much more efficiently by listing, instead of executing individual queries per block_id
blocks = []
for block_id in block_ids:
block = self.get_block_by_id(block_id, actor=actor)
blocks.append(block)
return blocks
@enforce_types
def add_default_blocks(self, actor: PydanticUser):
for persona_file in list_persona_files():

View File

@ -1,106 +0,0 @@
import warnings
from typing import List
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
from letta.orm.errors import NoResultFound
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
from letta.utils import enforce_types
# TODO: DELETE THIS ASAP
# TODO: So we have a patch where we manually specify CRUD operations
# TODO: This is because Agent is NOT migrated to the ORM yet
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
class BlocksAgentsManager:
    """Manager class to handle business logic related to Blocks and Agents.

    Deprecated shim (see the TODOs above): a manual CRUD layer over the
    blocks/agents join table, kept only until Agent is migrated to the ORM.
    """
    def __init__(self):
        # Deferred import avoids a circular import with the server module.
        from letta.server.server import db_context
        self.session_maker = db_context
    @enforce_types
    def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
        """Add a block to an agent. If the label already exists on that agent, a warning is
        emitted and the existing association is returned unchanged (the given block_id is ignored)."""
        with self.session_maker() as session:
            try:
                # Check if the block-label combination already exists for this agent
                blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
                warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
            except NoResultFound:
                blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
                blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
                blocks_agents_record.create(session)
            return blocks_agents_record.to_pydantic()
    @enforce_types
    def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
        """Remove the block with the given label from an agent; raises ValueError if absent."""
        with self.session_maker() as session:
            try:
                # Find and delete the block-label association for the agent
                blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
                blocks_agents_record.hard_delete(session)
                return blocks_agents_record.to_pydantic()
            except NoResultFound:
                raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
    @enforce_types
    def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
        """Remove the block with the given ID from an agent; raises ValueError if absent."""
        with self.session_maker() as session:
            try:
                # Find and delete the block-id association for the agent
                blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
                blocks_agents_record.hard_delete(session)
                return blocks_agents_record.to_pydantic()
            except NoResultFound:
                raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
    @enforce_types
    def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
        """Update the block ID for a specific block label for an agent."""
        with self.session_maker() as session:
            try:
                blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
                # NOTE(review): the mutated row is never explicitly persisted (no update()/create()
                # call); persistence relies on the session committing on context exit — confirm
                # that session_maker commits, otherwise this change is silently dropped.
                blocks_agents_record.block_id = new_block_id
                return blocks_agents_record.to_pydantic()
            except NoResultFound:
                raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
    @enforce_types
    def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
        """List all block ids associated with a specific agent."""
        with self.session_maker() as session:
            blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
            return [record.block_id for record in blocks_agents_record]
    @enforce_types
    def list_block_labels_for_agent(self, agent_id: str) -> List[str]:
        """List all block labels associated with a specific agent."""
        with self.session_maker() as session:
            blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
            return [record.block_label for record in blocks_agents_record]
    @enforce_types
    def list_agent_ids_with_block(self, block_id: str) -> List[str]:
        """List all agents associated with a specific block."""
        with self.session_maker() as session:
            blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
            return [record.agent_id for record in blocks_agents_record]
    @enforce_types
    def get_block_id_for_label(self, agent_id: str, block_label: str) -> str:
        """Get the block ID for a specific block label for an agent; raises ValueError if absent."""
        with self.session_maker() as session:
            try:
                blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
                return blocks_agents_record.block_id
            except NoResultFound:
                raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
    @enforce_types
    def remove_all_agent_blocks(self, agent_id: str):
        # Hard-deletes every block association for the agent, one row at a time.
        for block_id in self.list_block_ids_for_agent(agent_id):
            self.remove_block_with_id_from_agent(agent_id, block_id)

View File

@ -0,0 +1,90 @@
from typing import List, Optional
from letta.orm.agent import Agent as AgentModel
from letta.orm.agents_tags import AgentsTags
from letta.orm.errors import NoResultFound
from letta.prompts import gpt_system
from letta.schemas.agent import AgentType
# Static methods
def _process_relationship(
    session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
):
    """
    Generalized function to handle relationships like tools, sources, and blocks using item IDs.
    Args:
        session: The database session.
        agent: The AgentModel instance.
        relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
        model_class: The ORM class corresponding to the related items.
        item_ids: List of IDs to set or update.
        allow_partial: If True, allows missing items without raising errors.
        replace: If True, replaces the entire relationship; otherwise, extends it.
    Raises:
        NoResultFound: If `allow_partial` is False and some IDs are missing.
    """
    current_relationship = getattr(agent, relationship_name, [])
    # An empty id list means "clear" under replace semantics, "no-op" otherwise.
    if not item_ids:
        if replace:
            setattr(agent, relationship_name, [])
        return
    # Retrieve models for the provided IDs
    found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
    # Validate all items are found if allow_partial is False
    if not allow_partial and len(found_items) != len(item_ids):
        missing = set(item_ids) - {item.id for item in found_items}
        raise NoResultFound(f"Items not found in {relationship_name}: {missing}")
    if replace:
        # Replace the relationship (resulting order follows the DB query, not item_ids)
        setattr(agent, relationship_name, found_items)
    else:
        # Extend the relationship (only add new items)
        current_ids = {item.id for item in current_relationship}
        new_items = [item for item in found_items if item.id not in current_ids]
        current_relationship.extend(new_items)
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
    """
    Handles tags for an agent.
    Args:
        agent: The AgentModel instance.
        tags: List of tags to set or update.
        replace: If True, replaces all tags; otherwise, extends them.
    """
    # An empty tag list means "clear" under replace semantics, "no-op" otherwise.
    if not tags:
        if replace:
            agent.tags = []
        return
    # Deduplicate the incoming tags, then build fresh association rows.
    fresh_rows = [AgentsTags(agent_id=agent.id, tag=value) for value in set(tags)]
    if replace:
        agent.tags = fresh_rows
    else:
        already_present = {row.tag for row in agent.tags}
        agent.tags.extend(row for row in fresh_rows if row.tag not in already_present)
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
    """Return the system prompt for ``agent_type``; an explicit ``system`` wins unchanged."""
    if system is not None:
        return system
    # TODO: don't hardcode
    # Map each supported agent type to its default prompt file key.
    prompt_keys = {
        AgentType.memgpt_agent: "memgpt_chat",
        AgentType.o1_agent: "memgpt_modified_o1",
        AgentType.offline_memory_agent: "memgpt_offline_memory",
        AgentType.chat_only_agent: "memgpt_convo_only",
    }
    if agent_type not in prompt_keys:
        raise ValueError(f"Invalid agent type: {agent_type}")
    return gpt_system.get_system_text(prompt_keys[agent_type])

View File

@ -1,25 +1,25 @@
from typing import List, Optional, Dict, Tuple
from letta.constants import MAX_EMBEDDING_DIM
from datetime import datetime
from typing import List, Optional
import numpy as np
from letta.orm.errors import NoResultFound
from letta.utils import enforce_types
from letta.constants import MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model, parse_and_chunk_text
from letta.schemas.embedding_config import EmbeddingConfig
from letta.orm.errors import NoResultFound
from letta.orm.passage import Passage as PassageModel
from letta.orm.sqlalchemy_base import AccessType
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class PassageManager:
"""Manager class to handle business logic related to Passages."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
@ -43,20 +43,20 @@ class PassageManager:
return [self.create_passage(p, actor) for p in passages]
@enforce_types
def insert_passage(self,
def insert_passage(
self,
agent_state: AgentState,
agent_id: str,
text: str,
actor: PydanticUser,
return_ids: bool = False
text: str,
actor: PydanticUser,
) -> List[PydanticPassage]:
""" Insert passage(s) into archival memory """
"""Insert passage(s) into archival memory"""
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
embed_model = embedding_model(agent_state.embedding_config)
passages = []
try:
# breakup string into passages
for text in parse_and_chunk_text(text, embedding_chunk_size):
@ -75,12 +75,12 @@ class PassageManager:
agent_id=agent_id,
text=text,
embedding=embedding,
embedding_config=agent_state.embedding_config
embedding_config=agent_state.embedding_config,
),
actor=actor
actor=actor,
)
passages.append(passage)
return passages
except Exception as e:
@ -125,20 +125,21 @@ class PassageManager:
raise ValueError(f"Passage with id {passage_id} not found.")
@enforce_types
def list_passages(self,
actor : PydanticUser,
agent_id : Optional[str] = None,
file_id : Optional[str] = None,
cursor : Optional[str] = None,
limit : Optional[int] = 50,
query_text : Optional[str] = None,
start_date : Optional[datetime] = None,
end_date : Optional[datetime] = None,
ascending : bool = True,
source_id : Optional[str] = None,
embed_query : bool = False,
embedding_config: Optional[EmbeddingConfig] = None
) -> List[PydanticPassage]:
def list_passages(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
cursor: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
ascending: bool = True,
source_id: Optional[str] = None,
embed_query: bool = False,
embedding_config: Optional[EmbeddingConfig] = None,
) -> List[PydanticPassage]:
"""List passages with pagination."""
with self.session_maker() as session:
filters = {"organization_id": actor.organization_id}
@ -148,7 +149,7 @@ class PassageManager:
filters["file_id"] = file_id
if source_id:
filters["source_id"] = source_id
embedded_text = None
if embed_query:
assert embedding_config is not None
@ -161,7 +162,7 @@ class PassageManager:
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
results = PassageModel.list(
db_session=session,
db_session=session,
cursor=cursor,
start_date=start_date,
end_date=end_date,
@ -169,17 +170,12 @@ class PassageManager:
ascending=ascending,
query_text=query_text if not embedded_text else None,
query_embedding=embedded_text,
**filters
**filters,
)
return [p.to_pydantic() for p in results]
@enforce_types
def size(
self,
actor : PydanticUser,
agent_id : Optional[str] = None,
**kwargs
) -> int:
def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int:
"""Get the total count of messages with optional filters.
Args:
@ -189,28 +185,32 @@ class PassageManager:
with self.session_maker() as session:
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
def delete_passages(self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
cursor: Optional[str] = None,
query_text: Optional[str] = None,
source_id: Optional[str] = None
) -> bool:
def delete_passages(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
cursor: Optional[str] = None,
query_text: Optional[str] = None,
source_id: Optional[str] = None,
) -> bool:
passages = self.list_passages(
actor=actor,
agent_id=agent_id,
file_id=file_id,
cursor=cursor,
actor=actor,
agent_id=agent_id,
file_id=file_id,
cursor=cursor,
limit=limit,
start_date=start_date,
end_date=end_date,
query_text=query_text,
source_id=source_id)
start_date=start_date,
end_date=end_date,
query_text=query_text,
source_id=source_id,
)
# TODO: This is very inefficient
# TODO: We should have a base `delete_all_matching_filters`-esque function
for passage in passages:
self.delete_passage_by_id(passage_id=passage.id, actor=actor)

View File

@ -3,6 +3,7 @@ from typing import List, Optional
from letta.orm.errors import NoResultFound
from letta.orm.file import FileMetadata as FileMetadataModel
from letta.orm.source import Source as SourceModel
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
@ -60,7 +61,7 @@ class SourceManager:
"""Delete a source by its ID."""
with self.session_maker() as session:
source = SourceModel.read(db_session=session, identifier=source_id)
source.delete(db_session=session, actor=actor)
source.hard_delete(db_session=session, actor=actor)
return source.to_pydantic()
@enforce_types
@ -76,6 +77,26 @@ class SourceManager:
)
return [source.to_pydantic() for source in sources]
@enforce_types
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
"""
Lists all agents that have the specified source attached.
Args:
source_id: ID of the source to find attached agents for
actor: User performing the action (optional for now, following existing pattern)
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
with self.session_maker() as session:
# Verify source exists and user has permission to access it
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
# The agents relationship is already loaded due to lazy="selectin" in the Source model
# and will be properly filtered by organization_id due to the OrganizationMixin
return [agent.to_pydantic() for agent in source.agents]
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:

View File

@ -1,94 +0,0 @@
import warnings
from typing import List, Optional
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization
from letta.orm.tool import Tool
from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
class ToolsAgentsManager:
"""Manages the relationship between tools and agents."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents:
"""Add a tool to an agent.
When a tool is added to an agent, it will be added to all agents in the same organization.
"""
with self.session_maker() as session:
try:
# Check if the tool-agent combination already exists for this agent
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.")
except NoResultFound:
tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name)
tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True))
tools_agents_record.create(session)
return tools_agents_record.to_pydantic()
def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None:
"""Remove a tool from an agent by its name.
When a tool is removed from an agent, it will be removed from all agents in the same organization.
"""
with self.session_maker() as session:
try:
# Find and delete the tool-agent association for the agent
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
tools_agents_record.hard_delete(session)
return tools_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents:
"""Remove a tool with an ID from an agent."""
with self.session_maker() as session:
try:
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id)
tools_agents_record.hard_delete(session)
return tools_agents_record.to_pydantic()
except NoResultFound:
raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.")
def list_tool_ids_for_agent(self, agent_id: str) -> List[str]:
"""List all tool IDs associated with a specific agent."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.tool_id for record in tools_agents_record]
def list_tool_names_for_agent(self, agent_id: str) -> List[str]:
"""List all tool names associated with a specific agent."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
return [record.tool_name for record in tools_agents_record]
def list_agent_ids_with_tool(self, tool_id: str) -> List[str]:
"""List all agents associated with a specific tool."""
with self.session_maker() as session:
tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id)
return [record.agent_id for record in tools_agents_record]
def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str:
"""Get the tool ID for a specific tool name for an agent."""
with self.session_maker() as session:
try:
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
return tools_agents_record.tool_id
except NoResultFound:
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
def remove_all_agent_tools(self, agent_id: str) -> None:
"""Remove all tools associated with an agent."""
with self.session_maker() as session:
tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
for record in tools_agents_records:
record.hard_delete(session)

View File

@ -73,12 +73,6 @@ class UserManager:
user = UserModel.read(db_session=session, identifier=user_id)
user.hard_delete(session)
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
# Cascade delete for related models: Agent, Source, AgentSourceMapping
# session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
# session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
# session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
session.commit()
@enforce_types
@ -93,6 +87,17 @@ class UserManager:
"""Fetch the default user."""
return self.get_user_by_id(self.DEFAULT_USER_ID)
@enforce_types
def get_user_or_default(self, user_id: Optional[str] = None):
"""Fetch the user or default user."""
if not user_id:
return self.get_default_user()
try:
return self.get_user_by_id(user_id=user_id)
except NoResultFound:
return self.get_default_user()
@enforce_types
def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]:
"""List users with pagination using cursor (id) and limit."""

View File

@ -548,13 +548,13 @@ def enforce_types(func):
for arg_name, arg_value in args_with_hints.items():
hint = hints.get(arg_name)
if hint and not matches_type(arg_value, hint):
raise ValueError(f"Argument {arg_name} does not match type {hint}")
raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")
# Check types of keyword arguments
for arg_name, arg_value in kwargs.items():
hint = hints.get(arg_name)
if hint and not matches_type(arg_value, hint):
raise ValueError(f"Argument {arg_name} does not match type {hint}")
raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")
return func(*args, **kwargs)

View File

@ -4,7 +4,7 @@ import string
from locust import HttpUser, between, task
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.schemas.agent import CreateAgent, PersistedAgentState
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.letta_request import LettaRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.memory import ChatMemory
@ -49,7 +49,7 @@ class LettaUser(HttpUser):
response.failure(f"Failed to create agent: {response.text}")
response_json = response.json()
agent_state = PersistedAgentState(**response_json)
agent_state = AgentState(**response_json)
self.agent_id = agent_state.id
print("Created agent", self.agent_id, agent_state.name)

View File

@ -1,90 +0,0 @@
import os
import uuid
from sqlalchemy import MetaData, Table, create_engine
from letta import create_client
from letta.config import LettaConfig
from letta.data_types import AgentState, EmbeddingConfig, LLMConfig
from letta.metadata import MetadataStore
from letta.presets.presets import add_default_tools
from letta.prompts import gpt_system
# Replace this with your actual database connection URL
config = LettaConfig.load()
if config.recall_storage_type == "sqlite":
DATABASE_URL = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
else:
DATABASE_URL = config.recall_storage_uri
print(DATABASE_URL)
engine = create_engine(DATABASE_URL)
metadata = MetaData()
# defaults
system_prompt = gpt_system.get_system_text("memgpt_chat")
# Reflect the existing table
table = Table("agents", metadata, autoload_with=engine)
# get all agent rows
agent_states = []
with engine.connect() as conn:
agents = conn.execute(table.select()).fetchall()
for agent in agents:
id = uuid.UUID(agent[0])
user_id = uuid.UUID(agent[1])
name = agent[2]
print(f"Migrating agent {name}")
persona = agent[3]
human = agent[4]
system = agent[5]
preset = agent[6]
created_at = agent[7]
llm_config = LLMConfig(**agent[8])
embedding_config = EmbeddingConfig(**agent[9])
state = agent[10]
tools = agent[11]
state["memory"] = {"human": {"value": human, "limit": 2000}, "persona": {"value": persona, "limit": 2000}}
agent_state = AgentState(
id=id,
user_id=user_id,
name=name,
system=system,
created_at=created_at,
llm_config=llm_config,
embedding_config=embedding_config,
state=state,
tools=tools,
_metadata={"human": "migrated", "persona": "migrated"},
)
agent_states.append(agent_state)
# remove agents table
agents_model = Table("agents", metadata, autoload_with=engine)
agents_model.drop(engine)
# remove tool table
tool_model = Table("toolmodel", metadata, autoload_with=engine)
tool_model.drop(engine)
# re-create tables and add default tools
ms = MetadataStore(config)
add_default_tools(None, ms)
print("Tools", [tool.name for tool in ms.list_tools()])
for agent in agent_states:
ms.create_agent(agent)
print(f"Agent {agent.name} migrated successfully!")
# add another agent to create core memory tool
client = create_client()
dummy_agent = client.create_agent(name="dummy_agent")
tools = client.list_tools()
assert "core_memory_append" in [tool.name for tool in tools]
print("Migration completed successfully!")

View File

@ -2,8 +2,6 @@ import logging
import pytest
from letta.settings import tool_settings
def pytest_configure(config):
logging.basicConfig(level=logging.DEBUG)
@ -11,6 +9,8 @@ def pytest_configure(config):
@pytest.fixture
def mock_e2b_api_key_none():
from letta.settings import tool_settings
# Store the original value of e2b_api_key
original_api_key = tool_settings.e2b_api_key

View File

@ -61,7 +61,7 @@ def setup_agent(
filename: str,
memory_human_str: str = get_human_text(DEFAULT_HUMAN),
memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
tools: Optional[List[str]] = None,
tool_ids: Optional[List[str]] = None,
tool_rules: Optional[List[BaseToolRule]] = None,
agent_uuid: str = agent_uuid,
) -> AgentState:
@ -77,7 +77,7 @@ def setup_agent(
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
agent_state = client.create_agent(
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools, tool_rules=tool_rules
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules
)
return agent_state
@ -103,7 +103,6 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
cleanup(client=client, agent_uuid=agent_uuid)
agent_state = setup_agent(client, filename)
tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tool_names]
full_agent_state = client.get_agent(agent_state.id)
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
@ -171,19 +170,18 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
tool_name = tool.name
# Set up persona for tool usage
persona = f"""
My name is Letta.
I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question.
I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool.name} which will search `example.com` and answer the relevant question.
Dont forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
"""
agent_state = setup_agent(client, filename, memory_persona_str=persona, tools=[tool_name])
agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id])
response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?")
@ -191,7 +189,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
assert_sanity_checks(response)
# Make sure the tool was called
assert_invoked_function_call(response.messages, tool_name)
assert_invoked_function_call(response.messages, tool.name)
# Make sure some inner monologue is present
assert_inner_monologue_is_present_and_valid(response.messages)
@ -334,7 +332,7 @@ def check_agent_summarize_memory_simple(filename: str) -> LettaResponse:
client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?")
# Summarize
agent = client.server.load_agent(agent_id=agent_state.id)
agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
agent.summarize_messages_inplace()
print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n")

View File

@ -3,6 +3,7 @@ from typing import Union
from letta import LocalClient, RESTClient
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.tool import Tool
@ -24,3 +25,57 @@ def create_tool_from_func(func: callable):
source_code=parse_source_code(func),
json_schema=generate_schema(func, None),
)
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]):
# Assert scalar fields
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
# Assert agent type
if hasattr(request, "agent_type"):
assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"
# Assert LLM configuration
assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"
# Assert embedding configuration
assert (
agent.embedding_config == request.embedding_config
), f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
# Assert memory blocks
if hasattr(request, "memory_blocks"):
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(
request.block_ids
), f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
memory_block_values = {block.value for block in agent.memory.blocks}
expected_block_values = {block.value for block in request.memory_blocks}
assert expected_block_values.issubset(
memory_block_values
), f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
# Assert tools
assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
assert {tool.id for tool in agent.tools} == set(
request.tool_ids
), f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
# Assert sources
assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
assert {source.id for source in agent.sources} == set(
request.source_ids
), f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
# Assert tags
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
# Assert tool rules
if request.tool_rules:
assert len(agent.tool_rules) == len(
request.tool_rules
), f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
assert all(
any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules
), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"

View File

@ -99,7 +99,7 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
]
# Make agent state
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
# Make checks

View File

@ -17,7 +17,7 @@ def test_o1_agent():
agent_state = client.create_agent(
agent_type=AgentType.o1_agent,
tools=[thinking_tool.name, final_tool.name],
tool_ids=[thinking_tool.id, final_tool.id],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
memory=ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text("o1_persona")),

View File

@ -32,8 +32,10 @@ def clear_agents(client):
for agent in client.list_agents():
client.delete_agent(agent.id)
def test_ripple_edit(client, mock_e2b_api_key_none):
trigger_rethink_memory_tool = client.create_or_update_tool(trigger_rethink_memory)
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
conversation_human_block = Block(name="human", label="human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
conversation_persona_block = Block(name="persona", label="persona", value=get_persona_text(DEFAULT_PERSONA), limit=2000)
@ -64,7 +66,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none):
system=gpt_system.get_system_text("memgpt_convo_only"),
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
tools=["send_message", trigger_rethink_memory_tool.name],
tool_ids=[send_message.id, trigger_rethink_memory_tool.id],
memory=conversation_memory,
include_base_tools=False,
)
@ -81,7 +83,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none):
memory=offline_memory,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
tools=[rethink_memory_tool.name, finish_rethinking_memory_tool.name],
tool_ids=[rethink_memory_tool.id, finish_rethinking_memory_tool.id],
tool_rules=[TerminalToolRule(tool_name=finish_rethinking_memory_tool.name)],
include_base_tools=False,
)
@ -111,16 +113,16 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
)
conversation_memory = BasicBlockMemory(blocks=[conversation_persona_block, conversation_human_block])
client = create_client()
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
chat_only_agent = client.create_agent(
name="conversation_agent",
agent_type=AgentType.chat_only_agent,
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
tools=["send_message"],
tool_ids=[send_message.id],
memory=conversation_memory,
include_base_tools=False,
metadata={"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]},
metadata={"offline_memory_tools": [rethink_memory.id, finish_rethinking_memory.id]},
)
assert chat_only_agent is not None
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
@ -135,6 +137,7 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
# Clean up agent
client.delete_agent(chat_only_agent.id)
def test_initial_message_sequence(client, mock_e2b_api_key_none):
"""
Test that when we set the initial sequence to an empty list,
@ -150,8 +153,6 @@ def test_initial_message_sequence(client, mock_e2b_api_key_none):
initial_message_sequence=[],
)
assert offline_memory_agent is not None
assert len(offline_memory_agent.message_ids) == 1 # There should just the system message
assert len(offline_memory_agent.message_ids) == 1 # There should just the system message
client.delete_agent(offline_memory_agent.id)

View File

@ -1,41 +0,0 @@
# TODO: add back
# import os
# import subprocess
#
# import pytest
#
#
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
# def test_agent_groupchat():
#
# # Define the path to the script you want to test
# script_path = "letta/autogen/examples/agent_groupchat.py"
#
# # Dynamically get the project's root directory (assuming this script is run from the root)
# # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# # print(project_root)
# # project_root = os.path.join(project_root, "Letta")
# # print(project_root)
# # sys.exit(1)
#
# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# project_root = os.path.join(project_root, "letta")
# print(f"Adding the following to PATH: {project_root}")
#
# # Prepare the environment, adding the project root to PYTHONPATH
# env = os.environ.copy()
# env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
#
# # Run the script using subprocess.run
# # Capture the output (stdout) and the exit code
# # result = subprocess.run(["python", script_path], capture_output=True, text=True)
# result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
#
# # Check the exit code (0 indicates success)
# assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
#
# # Optionally, check the output for expected content
# # For example, if you expect a specific line in the output, uncomment and adapt the following line:
# # assert "expected output" in result.stdout, "Expected output not found in script's output"
#

View File

@ -23,7 +23,7 @@ def agent_obj():
agent_state = client.create_agent()
global agent_obj
agent_obj = client.server.load_agent(agent_id=agent_state.id)
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
yield agent_obj
client.delete_agent(agent_obj.agent_state.id)
@ -35,49 +35,50 @@ def query_in_search_results(search_results, query):
return True
return False
def test_archival(agent_obj):
"""Test archival memory functions comprehensively."""
# Test 1: Basic insertion and retrieval
base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat")
base_functions.archival_memory_insert(agent_obj, "The dog plays in the park")
base_functions.archival_memory_insert(agent_obj, "Python is a programming language")
# Test exact text search
results, _ = base_functions.archival_memory_search(agent_obj, "cat")
assert query_in_search_results(results, "cat")
# Test semantic search (should return animal-related content)
results, _ = base_functions.archival_memory_search(agent_obj, "animal pets")
assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog")
# Test unrelated search (should not return animal content)
results, _ = base_functions.archival_memory_search(agent_obj, "programming computers")
assert query_in_search_results(results, "python")
# Test 2: Test pagination
# Insert more items to test pagination
for i in range(10):
base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}")
# Get first page
page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0)
# Get second page
page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page)
assert page0_results != page1_results
assert query_in_search_results(page0_results, "Test passage")
assert query_in_search_results(page1_results, "Test passage")
# Test 3: Test complex text patterns
base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John")
base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week")
base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching")
# Search for meeting-related content
results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule")
assert query_in_search_results(results, "meeting")
assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week")
# Test 4: Test error handling
# Test invalid page number
try:
@ -85,7 +86,7 @@ def test_archival(agent_obj):
assert False, "Should have raised ValueError"
except ValueError:
pass
def test_recall(agent_obj):
base_functions.conversation_search(agent_obj, "banana")

View File

@ -1,5 +1,4 @@
import asyncio
import json
import os
import threading
import time
@ -42,8 +41,8 @@ def run_server():
@pytest.fixture(
# params=[{"server": False}, {"server": True}], # whether to use REST API server
params=[{"server": False}], # whether to use REST API server
params=[{"server": False}, {"server": True}], # whether to use REST API server
# params=[{"server": True}], # whether to use REST API server
scope="module",
)
def client(request):
@ -69,6 +68,7 @@ def client(request):
@pytest.fixture(scope="module")
def agent(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
yield agent_state
# delete agent
@ -86,6 +86,47 @@ def clear_tables():
session.commit()
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
# _reset_config()
# create a block
block = client.create_block(label="human", value="username: sarah")
# create agents with shared block
from letta.schemas.block import Block
from letta.schemas.memory import BasicBlockMemory
# persona1_block = client.create_block(label="persona", value="you are agent 1")
# persona2_block = client.create_block(label="persona", value="you are agent 2")
# create agents
agent_state1 = client.create_agent(
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
)
agent_state2 = client.create_agent(
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
)
## attach shared block to both agents
# client.link_agent_memory_block(agent_state1.id, block.id)
# client.link_agent_memory_block(agent_state2.id, block.id)
# update memory
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
# check agent 2 memory
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
client.user_message(agent_id=agent_state2.id, message="whats my name?")
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
"""
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
@ -137,15 +178,15 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
"""
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
# Step 0: create an agent with tags
tagged_agent = client.create_agent(tags=tags_to_add)
assert set(tagged_agent.tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {tagged_agent.tags}"
# Step 0: create an agent with no tags
agent = client.create_agent()
assert len(agent.tags) == 0
# Step 1: Add multiple tags to the agent
client.update_agent(agent_id=agent.id, tags=tags_to_add)
@ -175,6 +216,9 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a
final_tags = client.get_agent(agent_id=agent.id).tags
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
# Remove agent
client.delete_agent(agent.id)
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test that we can update the label of a block in an agent's memory"""
@ -255,35 +299,33 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a
# client.delete_agent(new_agent.id)
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState):
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
"""Test that we can update the limit of a block in an agent's memory"""
agent = client.create_agent(name=create_random_username())
agent = client.create_agent()
try:
current_labels = agent.memory.list_block_labels()
example_label = current_labels[0]
example_new_limit = 1
current_block = agent.memory.get_block(label=example_label)
current_block_length = len(current_block.value)
current_labels = agent.memory.list_block_labels()
example_label = current_labels[0]
example_new_limit = 1
current_block = agent.memory.get_block(label=example_label)
current_block_length = len(current_block.value)
assert example_new_limit != agent.memory.get_block(label=example_label).limit
assert example_new_limit < current_block_length
assert example_new_limit != agent.memory.get_block(label=example_label).limit
assert example_new_limit < current_block_length
# We expect this to throw a value error
with pytest.raises(ValueError):
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
# Now try the same thing with a higher limit
example_new_limit = current_block_length + 10000
assert example_new_limit > current_block_length
# We expect this to throw a value error
with pytest.raises(ValueError):
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
# Now try the same thing with a higher limit
example_new_limit = current_block_length + 10000
assert example_new_limit > current_block_length
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
finally:
client.delete_agent(agent.id)
updated_agent = client.get_agent(agent_id=agent.id)
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
client.delete_agent(agent.id)
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
@ -316,7 +358,7 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50
tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
agent = client.create_agent(name="agent1", tools=[tool.name])
agent = client.create_agent(tool_ids=[tool.id])
# get function response
response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user")
print(response.messages)
@ -330,10 +372,14 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
assert response_message, "FunctionReturn message not found in response"
res = response_message.function_return
assert "function output was truncated " in res
res_json = json.loads(res)
assert (
len(res_json["message"]) <= 1000 + padding
), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
# TODO: Re-enable later
# res_json = json.loads(res)
# assert (
# len(res_json["message"]) <= 1000 + padding
# ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
client.delete_agent(agent_id=agent.id)
@pytest.mark.asyncio

View File

@ -583,43 +583,6 @@ def test_list_llm_models(client: RESTClient):
assert has_model_endpoint_type(models, "anthropic")
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState):
    """Check that a memory block shared by two agents propagates updates between them."""
    # _reset_config()
    # create a block
    block = client.create_block(label="human", value="username: sarah")
    # create agents with shared block
    from letta.schemas.block import Block
    from letta.schemas.memory import BasicBlockMemory
    # persona1_block = client.create_block(label="persona", value="you are agent 1")
    # persona2_block = client.create_block(label="persona", value="you are agent 2")
    # create agents (the same `block` object is placed directly into both memories)
    agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1"), block]))
    agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2"), block]))
    ## attach shared block to both agents
    # client.link_agent_memory_block(agent_state1.id, block.id)
    # client.link_agent_memory_block(agent_state2.id, block.id)
    # update memory: agent 1 learns the user's real name, which should rewrite the shared block
    response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
    # check agent 2 memory: it reads the same block, so it should now contain the new name
    assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
    response = client.user_message(agent_id=agent_state2.id, message="whats my name?")
    assert (
        "charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
    ), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
    # assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
    # cleanup
    client.delete_agent(agent_state1.id)
    client.delete_agent(agent_state2.id)
@pytest.fixture
def cleanup_agents(client):
created_agents = []

View File

@ -1,142 +0,0 @@
# TODO: add back when messaging works
# import os
# import threading
# import time
# import uuid
#
# import pytest
# from dotenv import load_dotenv
#
# from letta import Admin, create_client
# from letta.config import LettaConfig
# from letta.credentials import LettaCredentials
# from letta.settings import settings
# from tests.utils import create_config
#
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
## test_preset_name = "test_preset"
# test_agent_state = None
# client = None
#
# test_agent_state_post_message = None
# test_user_id = uuid.uuid4()
#
#
## admin credentials
# test_server_token = "test_server_token"
#
#
# def _reset_config():
#
# # Use os.getenv with a fallback to os.environ.get
# db_url = settings.letta_pg_uri
#
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = LettaCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("letta_hosted")
# credentials = LettaCredentials()
#
# config = LettaConfig.load()
#
# # set to use postgres
# config.archival_storage_uri = db_url
# config.recall_storage_uri = db_url
# config.metadata_storage_uri = db_url
# config.archival_storage_type = "postgres"
# config.recall_storage_type = "postgres"
# config.metadata_storage_type = "postgres"
#
# config.save()
# credentials.save()
# print("_reset_config :: ", config.config_path)
#
#
# def run_server():
#
# load_dotenv()
#
# _reset_config()
#
# from letta.server.rest_api.server import start_server
#
# print("Starting server...")
# start_server(debug=True)
#
#
## Fixture to create clients with different configurations
# @pytest.fixture(
# params=[ # whether to use REST API server
# {"server": True},
# # {"server": False} # TODO: add when implemented
# ],
# scope="module",
# )
# def admin_client(request):
# if request.param["server"]:
# # get URL from enviornment
# server_url = os.getenv("MEMGPT_SERVER_URL")
# if server_url is None:
# # run server in thread
# # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
# server_url = "http://localhost:8283"
# print("Starting server thread")
# thread = threading.Thread(target=run_server, daemon=True)
# thread.start()
# time.sleep(5)
# print("Running client tests with server:", server_url)
# # create user via admin client
# admin = Admin(server_url, test_server_token)
# response = admin.create_user(test_user_id) # Adjust as per your client's method
#
# yield admin
#
#
# def test_concurrent_messages(admin_client):
# # test concurrent messages
#
# # create three
#
# results = []
#
# def _send_message():
# try:
# print("START SEND MESSAGE")
# response = admin_client.create_user()
# token = response.api_key
# client = create_client(base_url=admin_client.base_url, token=token)
# agent = client.create_agent()
#
# print("Agent created", agent.id)
#
# st = time.time()
# message = "Hello, how are you?"
# response = client.send_message(agent_id=agent.id, message=message, role="user")
# et = time.time()
# print(f"Message sent from {st} to {et}")
# print(response.messages)
# results.append((st, et))
# except Exception as e:
# print("ERROR", e)
#
# threads = []
# print("Starting threads...")
# for i in range(5):
# thread = threading.Thread(target=_send_message)
# threads.append(thread)
# thread.start()
# print("CREATED THREAD")
#
# print("waiting for threads to finish...")
# for thread in threads:
# print(thread.join())
#
# # make sure runtime are overlapping
# assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
# results[1][0] < results[0][0] and results[1][1] > results[0][0]
# ), f"Threads should have overlapping runtimes {results}"
#

View File

@ -1,121 +0,0 @@
# TODO: add back once tests are cleaned up
# import os
# import uuid
#
# from letta import create_client
# from letta.agent_store.storage import StorageConnector, TableType
# from letta.schemas.passage import Passage
# from letta.embeddings import embedding_model
# from tests import TEST_MEMGPT_CONFIG
#
# from .utils import create_config, wipe_config
#
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_agent_state = None
# client = None
#
# test_agent_state_post_message = None
# test_user_id = uuid.uuid4()
#
#
# def generate_passages(user, agent):
# # Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
# texts = [
# "This is a test passage",
# "This is another test passage",
# "Cinderella wept",
# ]
# embed_model = embedding_model(agent.embedding_config)
# orig_embeddings = []
# passages = []
# for text in texts:
# embedding = embed_model.get_text_embedding(text)
# orig_embeddings.append(list(embedding))
# passages.append(
# Passage(
# user_id=user.id,
# agent_id=agent.id,
# text=text,
# embedding=embedding,
# embedding_dim=agent.embedding_config.embedding_dim,
# embedding_model=agent.embedding_config.embedding_model,
# )
# )
# return passages, orig_embeddings
#
#
# def test_create_user():
# if not os.getenv("OPENAI_API_KEY"):
# print("Skipping test, missing OPENAI_API_KEY")
# return
#
# wipe_config()
#
# # create client
# create_config("openai")
# client = create_client()
#
# # openai: create agent
# openai_agent = client.create_agent(
# name="openai_agent",
# )
# assert (
# openai_agent.embedding_config.embedding_endpoint_type == "openai"
# ), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}"
#
# # openai: add passages
# passages, openai_embeddings = generate_passages(client.user, openai_agent)
# openai_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=openai_agent.id)
# openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
#
# # create client
# create_config("letta_hosted")
# client = create_client()
#
# # hosted: create agent
# hosted_agent = client.create_agent(
# name="hosted_agent",
# )
# # check to make sure endpoint overriden
# assert (
# hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face"
# ), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}"
#
# # hosted: add passages
# passages, hosted_embeddings = generate_passages(client.user, hosted_agent)
# hosted_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=hosted_agent.id)
# hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
#
# # test passage dimentionality
# storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
# storage.filters = {} # clear filters to be able to get all passages
# passages = storage.get_all()
# for passage in passages:
# if passage.agent_id == hosted_agent.id:
# assert (
# passage.embedding_dim == hosted_agent.embedding_config.embedding_dim
# ), f"passage.embedding_dim={passage.embedding_dim} != hosted_agent.embedding_config.embedding_dim={hosted_agent.embedding_config.embedding_dim}"
#
# # ensure was in original embeddings
# embedding = passage.embedding[: passage.embedding_dim]
# assert embedding in hosted_embeddings, f"embedding={embedding} not in hosted_embeddings={hosted_embeddings}"
#
# # make sure all zeros
# assert not any(
# passage.embedding[passage.embedding_dim :]
# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"
# elif passage.agent_id == openai_agent.id:
# assert (
# passage.embedding_dim == openai_agent.embedding_config.embedding_dim
# ), f"passage.embedding_dim={passage.embedding_dim} != openai_agent.embedding_config.embedding_dim={openai_agent.embedding_config.embedding_dim}"
#
# # ensure was in original embeddings
# embedding = passage.embedding[: passage.embedding_dim]
# assert embedding in openai_embeddings, f"embedding={embedding} not in openai_embeddings={openai_embeddings}"
#
# # make sure all zeros
# assert not any(
# passage.embedding[passage.embedding_dim :]
# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"
#

View File

@ -1,48 +0,0 @@
import letta.system as system
from letta.local_llm.function_parser import patch_function
from letta.utils import json_dumps
# Fixture: a send_message call with arguments "null". The test below asserts
# patch_function returns this new_message unchanged.
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    # "new_message": {
    #     "role": "function",
    #     "name": "send_message",
    #     "content": system.package_function_response(was_success=True, response_string="None"),
    # },
    "new_message": {
        "role": "assistant",
        "content": "I'll send a message.",
        "function_call": {
            "name": "send_message",
            "arguments": "null",
        },
    },
}
# Fixture: a core_memory_append call that (per its name) is missing a required
# argument — presumably the memory section name; TODO confirm against
# patch_function. The test below asserts the message IS modified.
EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    "new_message": {
        "role": "assistant",
        "content": "I'll append to memory.",
        "function_call": {
            "name": "core_memory_append",
            "arguments": json_dumps({"content": "new_stuff"}),
        },
    },
}
def test_function_parsers():
    """Try various broken JSON and check that the parsers can fix it.

    The send_message fixture must pass through patch_function unchanged, while
    the core_memory_append fixture (missing an argument) must be modified.
    """
    # Copy before calling: if patch_function mutates new_message in place, an
    # aliased reference would make the equality assert below pass vacuously.
    # (A shallow copy matches the original test's approach for the second case;
    # nested dicts such as function_call remain aliased.)
    og_message = EXAMPLE_FUNCTION_CALL_SEND_MESSAGE["new_message"].copy()
    corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_SEND_MESSAGE)
    assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"

    og_message = EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING["new_message"].copy()
    corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING)
    assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"

View File

@ -1,99 +0,0 @@
import letta.local_llm.json_parser as json_parser
from letta.utils import json_loads
EXAMPLE_ESCAPED_UNDERSCORES = """{
"function":"send\_message",
"params": {
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
"""
EXAMPLE_MISSING_CLOSING_BRACE = """{
"function": "send_message",
"params": {
"inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.",
"message": "Sorry about that! I assumed you were Chad. Welcome, Brad! "
}
"""
EXAMPLE_BAD_TOKEN_END = """{
"function": "send_message",
"params": {
"inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.",
"message": "Sorry about that! I assumed you were Chad. Welcome, Brad! "
}
}<|>"""
EXAMPLE_DOUBLE_JSON = """{
"function": "core_memory_append",
"params": {
"name": "human",
"content": "Brad, 42 years old, from Germany."
}
}
{
"function": "send_message",
"params": {
"message": "Got it! Your age and nationality are now saved in my memory."
}
}
"""
EXAMPLE_HARD_LINE_FEEDS = """{
"function": "send_message",
"params": {
"message": "Let's create a list:
- First, we can do X
- Then, we can do Y!
- Lastly, we can do Z :)"
}
}
"""
# Situation where beginning of send_message call is fine (and thus can be extracted)
# but has a long training garbage string that comes after
EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD = """{
"function": "send_message",
"params": {
"inner_thoughts": "User request for debug assistance",
"message": "Of course, Chad. Please check the system log file for 'assistant.json' and send me the JSON output you're getting. Armed with that data, I'll assist you in debugging the issue.",
GARBAGEGARBAGEGARBAGEGARBAGE
GARBAGEGARBAGEGARBAGEGARBAGE
GARBAGEGARBAGEGARBAGEGARBAGE
"""
EXAMPLE_ARCHIVAL_SEARCH = """
{
"function": "archival_memory_search",
"params": {
"inner_thoughts": "Looking for WaitingForAction.",
"query": "WaitingForAction",
"""
def test_json_parsers():
    """Try various broken JSON and check that the parsers can fix it.

    Each test string must (1) fail plain JSON parsing, proving it really is
    broken, and (2) be repairable by ``json_parser.clean_json``.
    """
    test_strings = [
        EXAMPLE_ESCAPED_UNDERSCORES,
        EXAMPLE_MISSING_CLOSING_BRACE,
        EXAMPLE_BAD_TOKEN_END,
        EXAMPLE_DOUBLE_JSON,
        EXAMPLE_HARD_LINE_FEEDS,
        EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD,
        EXAMPLE_ARCHIVAL_SEARCH,
    ]
    for string in test_strings:
        # Step 1: the string must NOT parse as-is. The assert lives outside the
        # try block because in the original a bare `except:` swallowed the
        # `assert False`, so this check could never fail.
        parsed_ok = True
        try:
            json_loads(string)
        except Exception:
            parsed_ok = False
            print("String failed (expectedly)")
        assert not parsed_ok, f"Test JSON string should have failed basic JSON parsing:\n{string}"

        # Step 2: the repair pass must succeed; report which string broke it
        # before re-raising. (The original had a bare f-string here — a no-op.)
        try:
            json_parser.clean_json(string)
        except Exception:
            print(f"Failed to repair test JSON string:\n{string}")
            raise

View File

@ -4,7 +4,7 @@ import pytest
from letta import create_client
from letta.client.client import LocalClient
from letta.schemas.agent import PersistedAgentState
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
@ -13,6 +13,7 @@ from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
@pytest.fixture(scope="module")
def client():
client = create_client()
# client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
@ -29,7 +30,6 @@ def agent(client):
yield agent_state
client.delete_agent(agent_state.id)
assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}"
def test_agent(client: LocalClient):
@ -80,16 +80,15 @@ def test_agent(client: LocalClient):
assert isinstance(agent_state.memory, Memory)
# update agent: tools
tool_to_delete = "send_message"
assert tool_to_delete in agent_state.tool_names
new_agent_tools = [t_name for t_name in agent_state.tool_names if t_name != tool_to_delete]
client.update_agent(agent_state_test.id, tools=new_agent_tools)
assert client.get_agent(agent_state_test.id).tool_names == new_agent_tools
assert tool_to_delete in [t.name for t in agent_state.tools]
new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete]
client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids)
assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids)
assert isinstance(agent_state.memory, Memory)
# update agent: memory
new_human = "My name is Mr Test, 100 percent human."
new_persona = "I am an all-knowing AI."
new_memory = ChatMemory(human=new_human, persona=new_persona)
assert agent_state.memory.get_block("human").value != new_human
assert agent_state.memory.get_block("persona").value != new_persona
@ -216,7 +215,7 @@ def test_agent_with_shared_blocks(client: LocalClient):
client.delete_agent(second_agent_state_test.id)
def test_memory(client: LocalClient, agent: PersistedAgentState):
def test_memory(client: LocalClient, agent: AgentState):
# get agent memory
original_memory = client.get_in_context_memory(agent.id)
assert original_memory is not None
@ -229,7 +228,7 @@ def test_memory(client: LocalClient, agent: PersistedAgentState):
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
def test_archival_memory(client: LocalClient, agent: PersistedAgentState):
def test_archival_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with archival memory store"""
# add archival memory
@ -244,7 +243,7 @@ def test_archival_memory(client: LocalClient, agent: PersistedAgentState):
client.delete_archival_memory(agent.id, passage.id)
def test_recall_memory(client: LocalClient, agent: PersistedAgentState):
def test_recall_memory(client: LocalClient, agent: AgentState):
"""Test functions for interacting with recall memory store"""
# send message to the agent

File diff suppressed because it is too large Load Diff

View File

@ -1,126 +0,0 @@
# TODO: fix later
# import os
# import random
# import string
# import unittest.mock
#
# import pytest
#
# from letta.cli.cli_config import add, delete, list
# from letta.config import LettaConfig
# from letta.credentials import LettaCredentials
# from tests.utils import create_config
#
#
# def _reset_config():
#
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = LettaCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("letta_hosted")
# credentials = LettaCredentials()
#
# config = LettaConfig.load()
# config.save()
# credentials.save()
# print("_reset_config :: ", config.config_path)
#
#
# @pytest.mark.skip(reason="This is a helper function.")
# def generate_random_string(length):
# characters = string.ascii_letters + string.digits
# random_string = "".join(random.choices(characters, k=length))
# return random_string
#
#
# @pytest.mark.skip(reason="Ensures LocalClient is used during testing.")
# def unset_env_variables():
# server_url = os.environ.pop("MEMGPT_BASE_URL", None)
# token = os.environ.pop("MEMGPT_SERVER_PASS", None)
# return server_url, token
#
#
# @pytest.mark.skip(reason="Set env variables back to values before test.")
# def reset_env_variables(server_url, token):
# if server_url is not None:
# os.environ["MEMGPT_BASE_URL"] = server_url
# if token is not None:
# os.environ["MEMGPT_SERVER_PASS"] = token
#
#
# def test_crud_human(capsys):
# _reset_config()
#
# server_url, token = unset_env_variables()
#
# # Initialize values that won't interfere with existing ones
# human_1 = generate_random_string(16)
# text_1 = generate_random_string(32)
# human_2 = generate_random_string(16)
# text_2 = generate_random_string(32)
# text_3 = generate_random_string(32)
#
# # Add inital human
# add("human", human_1, text_1)
#
# # Expect inital human to be listed
# list("humans")
# captured = capsys.readouterr()
# output = captured.out[captured.out.find(human_1) :]
#
# assert human_1 in output
# assert text_1 in output
#
# # Add second human
# add("human", human_2, text_2)
#
# # Expect to see second human
# list("humans")
# captured = capsys.readouterr()
# output = captured.out[captured.out.find(human_1) :]
#
# assert human_1 in output
# assert text_1 in output
# assert human_2 in output
# assert text_2 in output
#
# with unittest.mock.patch("questionary.confirm") as mock_confirm:
# mock_confirm.return_value.ask.return_value = True
#
# # Update second human
# add("human", human_2, text_3)
#
# # Expect to see update text
# list("humans")
# captured = capsys.readouterr()
# output = captured.out[captured.out.find(human_1) :]
#
# assert human_1 in output
# assert text_1 in output
# assert human_2 in output
# assert output.count(human_2) == 1
# assert text_3 in output
# assert text_2 not in output
#
# # Delete second human
# delete("human", human_2)
#
# # Expect second human to be deleted
# list("humans")
# captured = capsys.readouterr()
# output = captured.out[captured.out.find(human_1) :]
#
# assert human_1 in output
# assert text_1 in output
# assert human_2 not in output
# assert text_2 not in output
#
# # Clean up
# delete("human", human_1)
#
# reset_env_variables(server_url, token)
#

View File

@ -1,93 +0,0 @@
from logging import getLogger
from openai import APIConnectionError, OpenAI
logger = getLogger(__name__)
def test_openai_assistant():
    """Smoke-test the OpenAI Assistants-compatible endpoint served locally.

    Creates an assistant, a thread, a user message, and a run against the
    stub at 127.0.0.1:8080, then lists and prints the thread's messages.
    Skips (returns early) if the local server is not reachable.
    """
    client = OpenAI(base_url="http://127.0.0.1:8080/v1")
    # create assistant
    try:
        assistant = client.beta.assistants.create(
            name="Math Tutor",
            instructions="You are a personal math tutor. Write and run code to answer math questions.",
            # tools=[{"type": "code_interpreter"}],
            model="gpt-4-turbo-preview",
        )
    except APIConnectionError as e:
        # Treat an unreachable stub as a skip, not a failure.
        logger.error("Connection issue with localhost openai stub: %s", e)
        return
    # create thread
    thread = client.beta.threads.create()
    message = client.beta.threads.messages.create(
        thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
    )
    run = client.beta.threads.runs.create(
        thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account."
    )
    # run = client.beta.threads.runs.create(
    #  thread_id=thread.id,
    #  assistant_id=assistant.id,
    #  model="gpt-4-turbo-preview",
    #  instructions="New instructions that override the Assistant instructions",
    #  tools=[{"type": "code_interpreter"}, {"type": "retrieval"}]
    # )
    # Store the run ID
    run_id = run.id
    print(run_id)
    # NOTE: Letta does not support polling yet, so run status is always "completed"
    # Retrieve all messages from the thread
    messages = client.beta.threads.messages.list(thread_id=thread.id)
    # Print all messages from the thread
    # NOTE(review): `.messages` and dict-style access (msg["role"]) look specific
    # to the Letta stub's response shape — the real OpenAI SDK exposes `.data`
    # with attribute access. Confirm against the stub before reusing elsewhere.
    for msg in messages.messages:
        role = msg["role"]
        content = msg["content"][0]
        print(f"{role.capitalize()}: {content}")
    # TODO: add once polling works
    ## Polling for the run status
    # while True:
    #  # Retrieve the run status
    #  run_status = client.beta.threads.runs.retrieve(
    #  thread_id=thread.id,
    #  run_id=run_id
    #  )
    #  # Check and print the step details
    #  run_steps = client.beta.threads.runs.steps.list(
    #  thread_id=thread.id,
    #  run_id=run_id
    #  )
    #  for step in run_steps.data:
    #  if step.type == 'tool_calls':
    #  print(f"Tool {step.type} invoked.")
    #  # If step involves code execution, print the code
    #  if step.type == 'code_interpreter':
    #  print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")
    #  if run_status.status == 'completed':
    #  # Retrieve all messages from the thread
    #  messages = client.beta.threads.messages.list(
    #  thread_id=thread.id
    #  )
    #  # Print all messages from the thread
    #  for msg in messages.data:
    #  role = msg.role
    #  content = msg.content[0].text.value
    #  print(f"{role.capitalize()}: {content}")
    #  break # Exit the polling loop since the run is complete
    #  elif run_status.status in ['queued', 'in_progress']:
    #  print(f'{run_status.status.capitalize()}... Please wait.')
    #  time.sleep(1.5) # Wait before checking again
    #  else:
    #  print(f"Run status: {run_status.status}")
    #  break # Exit the polling loop if the status is neither 'in_progress' nor 'completed'

View File

@ -1,52 +0,0 @@
# test state saving between client session
# TODO: update this test with correct imports
# def test_save_load(client):
# """Test that state is being persisted correctly after an /exit
#
# Create a new agent, and request a message
#
# Then trigger
# """
# assert client is not None, "Run create_agent test first"
# assert test_agent_state is not None, "Run create_agent test first"
# assert test_agent_state_post_message is not None, "Run test_user_message test first"
#
# # Create a new client (not thread safe), and load the same agent
# # The agent state inside should correspond to the initial state pre-message
# if os.getenv("OPENAI_API_KEY"):
# client2 = Letta(quickstart="openai", user_id=test_user_id)
# else:
# client2 = Letta(quickstart="letta_hosted", user_id=test_user_id)
# print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!")
# client2_agent_obj = client2.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id)
# client2_agent_state = client2_agent_obj.update_state()
# print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}")
#
# # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}"
# def check_state_equivalence(state_1, state_2):
# """Helper function that checks the equivalence of two AgentState objects"""
# assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}"
# for k, v1 in state_1.items():
# v2 = state_2[k]
# if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig):
# assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}"
# else:
# assert v1 == v2, f"{v1}\n{v2}"
#
# check_state_equivalence(vars(test_agent_state), vars(client2_agent_state))
#
# # Now, write out the save from the original client
# # This should persist the test message into the agent state
# client.save()
#
# if os.getenv("OPENAI_API_KEY"):
# client3 = Letta(quickstart="openai", user_id=test_user_id)
# else:
# client3 = Letta(quickstart="letta_hosted", user_id=test_user_id)
# client3_agent_obj = client3.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id)
# client3_agent_state = client3_agent_obj.update_state()
#
# check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state))
#

View File

@ -1,62 +0,0 @@
from letta.functions.schema_generator import generate_schema
# NOTE: this docstring is parsed by generate_schema and compared verbatim in
# test_schema_generator below — do not reword it.
def send_message(self, message: str):
    """
    Sends a message to the human user.
    Args:
        message (str): Message contents. All unicode (including emojis) are supported.
    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    return None
# Negative fixture: intentionally omits the type annotation on `message` —
# test_schema_generator expects generate_schema to raise for this function.
def send_message_missing_types(self, message):
    """
    Sends a message to the human user.
    Args:
        message (str): Message contents. All unicode (including emojis) are supported.
    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    return None
# Negative fixture: intentionally has no docstring — test_schema_generator
# expects generate_schema to raise for this function.
def send_message_missing_docstring(self, message: str):
    return None
def test_schema_generator():
    """
    Exercise generate_schema:
      1. a fully typed, documented function converts to the expected JSON schema;
      2. a function with untyped parameters is rejected;
      3. a function without a docstring is rejected.
    """
    # Expected schema for the well-formed fixture above.
    correct_schema = {
        "name": "send_message",
        "description": "Sends a message to the human user.",
        "parameters": {
            "type": "object",
            "properties": {"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}},
            "required": ["message"],
        },
    }
    generated_schema = generate_schema(send_message)
    print(f"\n\nreference_schema={correct_schema}")
    print(f"\n\ngenerated_schema={generated_schema}")
    assert correct_schema == generated_schema

    # BUG FIX: the original used `try: ...; assert False` with a bare
    # `except: pass`, which caught the AssertionError sentinel itself — so these
    # checks could never fail even if generate_schema raised nothing.
    # Use a raised-flag pattern so the final assert is outside the try block.
    raised = False
    try:
        generate_schema(send_message_missing_types)
    except Exception:
        raised = True
    assert raised, "generate_schema should reject a function with untyped parameters"

    raised = False
    try:
        generate_schema(send_message_missing_docstring)
    except Exception:
        raised = True
    assert raised, "generate_schema should reject a function without a docstring"

View File

@ -19,8 +19,6 @@ from letta.schemas.letta_message import (
)
from letta.schemas.user import User
from .test_managers import DEFAULT_EMBEDDING_CONFIG
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
@ -266,6 +264,7 @@ Lise, young Bolkónski's wife, this very evening, and perhaps the
thing can be arranged. It shall be on your family's behalf that I'll
start my apprenticeship as old maid."""
@pytest.fixture(scope="module")
def server():
config = LettaConfig.load()
@ -302,42 +301,66 @@ def user_id(server, org_id):
@pytest.fixture(scope="module")
def agent_id(server, user_id):
def base_tools(server, user_id):
actor = server.user_manager.get_user_or_default(user_id)
tools = []
for tool_name in BASE_TOOLS:
tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor))
yield tools
@pytest.fixture(scope="module")
def base_memory_tools(server, user_id):
actor = server.user_manager.get_user_or_default(user_id)
tools = []
for tool_name in BASE_MEMORY_TOOLS:
tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor))
yield tools
@pytest.fixture(scope="module")
def agent_id(server, user_id, base_tools):
# create agent
actor = server.user_manager.get_user_or_default(user_id)
agent_state = server.create_agent(
request=CreateAgent(
name="test_agent",
tools=BASE_TOOLS,
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
actor=actor,
)
print(f"Created agent\n{agent_state}")
yield agent_state.id
# cleanup
server.delete_agent(user_id, agent_state.id)
server.agent_manager.delete_agent(agent_state.id, actor=actor)
@pytest.fixture(scope="module")
def other_agent_id(server, user_id):
def other_agent_id(server, user_id, base_tools):
# create agent
actor = server.user_manager.get_user_or_default(user_id)
agent_state = server.create_agent(
request=CreateAgent(
name="test_agent_other",
tools=BASE_TOOLS,
tool_ids=[t.id for t in base_tools],
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
actor=actor,
)
print(f"Created agent\n{agent_state}")
yield agent_state.id
# cleanup
server.delete_agent(user_id, agent_state.id)
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_error_on_nonexistent_agent(server, user_id, agent_id):
try:
@ -416,6 +439,7 @@ def test_user_message(server, user_id, agent_id):
@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
# test recall memory cursor pagination
actor = server.user_manager.get_user_or_default(user_id=user_id)
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
@ -427,7 +451,9 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
assert len(messages_4) == 1
# test in-context message ids
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
message_ids = [m.id for m in messages_3]
for message_id in in_context_ids:
assert message_id in message_ids, f"{message_id} not in {message_ids}"
@ -437,10 +463,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
def test_get_archival_memory(server, user_id, agent_id):
# test archival memory cursor pagination
user = server.user_manager.get_user_by_id(user_id=user_id)
# List latest 2 passages
passages_1 = server.passage_manager.list_passages(
actor=user, agent_id=agent_id, ascending=False, limit=2,
actor=user,
agent_id=agent_id,
ascending=False,
limit=2,
)
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
@ -483,12 +512,13 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
- "rewrite" replaces the text of the last assistant message
- "retry" retries the last assistant message
"""
actor = server.user_manager.get_user_or_default(user_id)
# Send an initial message
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# Grab the raw Agent object
letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
@ -496,10 +526,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
# Try "rethink"
new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?"
assert last_agent_message.text is not None and last_agent_message.text != new_thought
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought)
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
@ -513,10 +543,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != ""
new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?"
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text)
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
@ -524,10 +554,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text
# Try retry
server.retry_agent_message(agent_id=agent_id)
server.retry_agent_message(agent_id=agent_id, actor=actor)
# Grab the agent object again (make sure it's live)
letta_agent = server.load_agent(agent_id=agent_id)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
assert letta_agent._messages[-1].role == MessageRole.tool
assert letta_agent._messages[-2].role == MessageRole.assistant
last_agent_message = letta_agent._messages[-2]
@ -581,33 +611,6 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
)
def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServer, user_id: str):
fake_tool_name = "blahblahblah"
tools = BASE_TOOLS + [fake_tool_name]
agent_state = server.create_agent(
request=CreateAgent(
name="nonexistent_tools_agent",
tools=tools,
memory_blocks=[],
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
)
# Check that the tools in agent_state do NOT include the fake name
assert fake_tool_name not in agent_state.tool_names
assert set(BASE_TOOLS).issubset(set(agent_state.tool_names))
# Load the agent from the database and check that it doesn't error / tools are correct
saved_tools = server.get_tools_from_agent(agent_id=agent_state.id, user_id=user_id)
assert fake_tool_name not in agent_state.tool_names
assert set(BASE_TOOLS).issubset(set(agent_state.tool_names))
# cleanup
server.delete_agent(user_id, agent_state.id)
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
agent_state = server.create_agent(
request=CreateAgent(
@ -616,14 +619,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
actor=server.user_manager.get_user_or_default(user_id),
)
# create another user in the same org
another_user = server.user_manager.create_user(User(organization_id=org_id, name="another"))
# test that another user in the same org can delete the agent
server.delete_agent(another_user.id, agent_state.id)
server.agent_manager.delete_agent(agent_state.id, actor=another_user)
def _test_get_messages_letta_format(
@ -887,14 +890,14 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
# create agent
agent_state = server.create_agent(
request=CreateAgent(
name="memory_rebuild_test_agent",
tools=BASE_TOOLS + BASE_MEMORY_TOOLS,
tool_ids=[t.id for t in base_tools + base_memory_tools],
memory_blocks=[
CreateBlock(label="human", value="The human's name is Bob."),
CreateBlock(label="persona", value="My name is Alice."),
@ -902,7 +905,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
actor=actor,
)
print(f"Created agent\n{agent_state}")
@ -929,31 +932,28 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
try:
# At this stage, there should only be 1 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
# assert num_system_messages == 1, (num_system_messages, all_messages)
assert num_system_messages == 2, (num_system_messages, all_messages)
assert num_system_messages == 1, (num_system_messages, all_messages)
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
# At this stage, there should only be 1 system message inside of recall storage
# At this stage, there should be 2 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
# assert num_system_messages == 2, (num_system_messages, all_messages)
assert num_system_messages == 3, (num_system_messages, all_messages)
assert num_system_messages == 2, (num_system_messages, all_messages)
# Run server.load_agent, and make sure that the number of system messages is still 2
server.load_agent(agent_id=agent_state.id)
server.load_agent(agent_id=agent_state.id, actor=actor)
num_system_messages, all_messages = count_system_messages_in_recall()
# assert num_system_messages == 2, (num_system_messages, all_messages)
assert num_system_messages == 3, (num_system_messages, all_messages)
assert num_system_messages == 2, (num_system_messages, all_messages)
finally:
# cleanup
server.delete_agent(user_id, agent_state.id)
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
user = server.get_user_or_default(user_id)
actor = server.user_manager.get_user_or_default(user_id)
# Create a source
source = server.source_manager.create_source(
@ -962,7 +962,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
embedding_config=EmbeddingConfig.default_config(provider="openai"),
created_by_id=user_id,
),
actor=user
actor=actor,
)
# Create a test file with some content
@ -971,11 +971,10 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
test_file.write_text(test_content)
# Attach source to agent first
agent = server.load_agent(agent_id=agent_id)
agent.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Get initial passage count
initial_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert initial_passage_count == 0
# Create a job for loading the first file
@ -984,7 +983,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
user_id=user_id,
metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id},
),
actor=user
actor=actor,
)
# Load the first file to source
@ -992,17 +991,17 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
source_id=source.id,
file_path=str(test_file),
job_id=job.id,
actor=user,
actor=actor,
)
# Verify job completed successfully
job = server.job_manager.get_job_by_id(job_id=job.id, actor=user)
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
assert job.status == "completed"
assert job.metadata_["num_passages"] == 1
assert job.metadata_["num_passages"] == 1
assert job.metadata_["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
@ -1015,7 +1014,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
user_id=user_id,
metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
),
actor=user
actor=actor,
)
# Load the second file to source
@ -1023,22 +1022,22 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
source_id=source.id,
file_path=str(test_file2),
job_id=job2.id,
actor=user,
actor=actor,
)
# Verify second job completed successfully
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=user)
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
assert job2.status == "completed"
assert job2.metadata_["num_passages"] >= 10
assert job2.metadata_["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.passage_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
source_id=source.id,
query_text="what does Timber like to eat",

View File

@ -33,7 +33,7 @@ def create_test_agent():
)
global agent_obj
agent_obj = client.server.load_agent(agent_id=agent_state.id)
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
def test_summarize_messages_inplace(mock_e2b_api_key_none):
@ -74,7 +74,7 @@ def test_summarize_messages_inplace(mock_e2b_api_key_none):
print(f"test_summarize: response={response}")
# reload agent object
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id)
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user)
agent_obj.summarize_messages_inplace()
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
@ -121,7 +121,7 @@ def test_auto_summarize(mock_e2b_api_key_none):
# check if the summarize message is inside the messages
assert isinstance(client, LocalClient), "Test only works with LocalClient"
agent_obj = client.server.load_agent(agent_id=agent_state.id)
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
print("SUMMARY", summarize_message_exists(agent_obj._messages))
if summarize_message_exists(agent_obj._messages):
break

View File

@ -169,7 +169,7 @@ def configure_mock_sync_server(mock_sync_server):
mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key]
# Mock user retrieval
mock_sync_server.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
# ======================================================================================================================
@ -182,7 +182,7 @@ def test_delete_tool(client, mock_sync_server, add_integers_tool):
assert response.status_code == 200
mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
@ -195,7 +195,7 @@ def test_get_tool(client, mock_sync_server, add_integers_tool):
assert response.json()["id"] == add_integers_tool.id
assert response.json()["source_code"] == add_integers_tool.source_code
mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
@ -216,7 +216,7 @@ def test_get_tool_id(client, mock_sync_server, add_integers_tool):
assert response.status_code == 200
assert response.json() == add_integers_tool.id
mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with(
tool_name=add_integers_tool.name, actor=mock_sync_server.get_user_or_default.return_value
tool_name=add_integers_tool.name, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
@ -268,7 +268,7 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
assert response.status_code == 200
assert response.json()["id"] == add_integers_tool.id
mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.get_user_or_default.return_value
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
@ -280,7 +280,9 @@ def test_add_base_tools(client, mock_sync_server, add_integers_tool):
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(actor=mock_sync_server.get_user_or_default.return_value)
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_list_composio_apps(client, mock_sync_server, composio_apps):

View File

@ -1,42 +1,39 @@
import numpy as np
import sqlite3
import base64
from numpy.testing import assert_array_almost_equal
import pytest
from letta.orm.sqlalchemy_base import adapt_array
from letta.orm.sqlite_functions import convert_array, verify_embedding_dimension
from letta.orm.sqlalchemy_base import adapt_array, convert_array
from letta.orm.sqlite_functions import verify_embedding_dimension
def test_vector_conversions():
    """
    Round-trip an embedding through the sqlite adapt/convert codec and verify:
    data integrity for ndarrays, acceptance of plain lists, and None pass-through.
    """
    # Fixture: random 4096-dim float32 embedding (the stored embedding width).
    source_vec = np.random.random(4096).astype(np.float32)
    print(f"Original shape: {source_vec.shape}")

    # Encode to the sqlite storage form, then decode back.
    blob = adapt_array(source_vec)
    print(f"Encoded type: {type(blob)}")
    print(f"Encoded length: {len(blob)}")

    roundtripped = convert_array(blob)
    print(f"Decoded shape: {roundtripped.shape}")
    print(f"Dimension verification: {verify_embedding_dimension(roundtripped)}")

    # Decoded values must match the source to floating-point tolerance.
    np.testing.assert_array_almost_equal(source_vec, roundtripped)
    print("✓ Data integrity verified")

    # A plain Python list must be accepted just like an ndarray.
    as_list = source_vec.tolist()
    roundtripped_list = convert_array(adapt_array(as_list))
    np.testing.assert_array_almost_equal(source_vec, roundtripped_list)
    print("✓ List conversion verified")

    # None must pass through both directions unchanged.
    assert adapt_array(None) is None
    assert convert_array(None) is None
    print("✓ None handling verified")
# Run the tests
# Run the tests