mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Rewrite agents (#2232)
This commit is contained in:
parent
d42c1e5e72
commit
e49a8b4365
5
.github/workflows/integration_tests.yml
vendored
5
.github/workflows/integration_tests.yml
vendored
@ -18,7 +18,7 @@ on:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
run-integration-tests:
|
||||
integ-run:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
@ -27,6 +27,9 @@ jobs:
|
||||
integration_test_suite:
|
||||
- "integration_test_summarizer.py"
|
||||
- "integration_test_tool_execution_sandbox.py"
|
||||
- "integration_test_offline_memory_agent.py"
|
||||
- "integration_test_agent_tool_graph.py"
|
||||
- "integration_test_o1_agent.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
|
74
.github/workflows/tests.yml
vendored
74
.github/workflows/tests.yml
vendored
@ -13,25 +13,27 @@ on:
|
||||
pull_request:
|
||||
|
||||
jobs:
|
||||
run-core-unit-tests:
|
||||
unit-run:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 15
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
test_suite:
|
||||
- "test_vector_embeddings.py"
|
||||
- "test_client.py"
|
||||
- "test_local_client.py"
|
||||
- "test_client_legacy.py"
|
||||
- "test_server.py"
|
||||
- "test_managers.py"
|
||||
- "test_o1_agent.py"
|
||||
- "test_tool_rule_solver.py"
|
||||
- "test_agent_tool_graph.py"
|
||||
- "test_utils.py"
|
||||
- "test_tool_schema_parsing.py"
|
||||
- "test_v1_routes.py"
|
||||
- "test_offline_memory_agent.py"
|
||||
- "test_local_client.py"
|
||||
- "test_managers.py"
|
||||
- "test_base_functions.py"
|
||||
- "test_tool_schema_parsing.py"
|
||||
- "test_tool_rule_solver.py"
|
||||
- "test_memory.py"
|
||||
- "test_utils.py"
|
||||
- "test_stream_buffer_readers.py"
|
||||
- "test_summarize.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
@ -81,57 +83,3 @@ jobs:
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
run: |
|
||||
poetry run pytest -s -vv tests/${{ matrix.test_suite }}
|
||||
|
||||
misc-unit-tests:
|
||||
runs-on: ubuntu-latest
|
||||
needs: run-core-unit-tests
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
ports:
|
||||
- 6333:6333
|
||||
postgres:
|
||||
image: pgvector/pgvector:pg17
|
||||
ports:
|
||||
- 5432:5432
|
||||
env:
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
POSTGRES_DB: postgres
|
||||
POSTGRES_USER: postgres
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup Python, Poetry, and Dependencies
|
||||
uses: packetcoders/action-setup-cache-python-poetry@main
|
||||
with:
|
||||
python-version: "3.12"
|
||||
poetry-version: "1.8.2"
|
||||
install-args: "-E dev -E postgres -E external-tools -E tests -E cloud-tool-sandbox"
|
||||
- name: Migrate database
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
run: |
|
||||
psql -h localhost -U postgres -d postgres -c 'CREATE EXTENSION vector'
|
||||
poetry run alembic upgrade head
|
||||
- name: Run misc unit tests
|
||||
env:
|
||||
LETTA_PG_PORT: 5432
|
||||
LETTA_PG_USER: postgres
|
||||
LETTA_PG_PASSWORD: postgres
|
||||
LETTA_PG_DB: postgres
|
||||
LETTA_PG_HOST: localhost
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not test_offline_memory_agent.py and not test_v1_routes.py and not test_model_letta_perfomance.py and not test_utils.py and not test_client.py and not integration_test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
|
||||
|
175
alembic/versions/d05669b60ebe_migrate_agents_to_orm.py
Normal file
175
alembic/versions/d05669b60ebe_migrate_agents_to_orm.py
Normal file
@ -0,0 +1,175 @@
|
||||
"""Migrate agents to orm
|
||||
|
||||
Revision ID: d05669b60ebe
|
||||
Revises: c5d964280dff
|
||||
Create Date: 2024-12-12 10:25:31.825635
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d05669b60ebe"
|
||||
down_revision: Union[str, None] = "c5d964280dff"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"sources_agents",
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("source_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_id"],
|
||||
["agents.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["source_id"],
|
||||
["sources.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("agent_id", "source_id"),
|
||||
)
|
||||
op.drop_index("agent_source_mapping_idx_user", table_name="agent_source_mapping")
|
||||
op.drop_table("agent_source_mapping")
|
||||
op.add_column("agents", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("agents", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
op.add_column("agents", sa.Column("_created_by_id", sa.String(), nullable=True))
|
||||
op.add_column("agents", sa.Column("_last_updated_by_id", sa.String(), nullable=True))
|
||||
op.add_column("agents", sa.Column("organization_id", sa.String(), nullable=True))
|
||||
# Populate `organization_id` based on `user_id`
|
||||
# Use a raw SQL query to update the organization_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE agents
|
||||
SET organization_id = users.organization_id
|
||||
FROM users
|
||||
WHERE agents.user_id = users.id
|
||||
"""
|
||||
)
|
||||
op.alter_column("agents", "organization_id", nullable=False)
|
||||
op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=True)
|
||||
op.drop_index("agents_idx_user", table_name="agents")
|
||||
op.create_unique_constraint("unique_org_agent_name", "agents", ["organization_id", "name"])
|
||||
op.create_foreign_key(None, "agents", "organizations", ["organization_id"], ["id"])
|
||||
op.drop_column("agents", "tool_names")
|
||||
op.drop_column("agents", "user_id")
|
||||
op.drop_constraint("agents_tags_organization_id_fkey", "agents_tags", type_="foreignkey")
|
||||
op.drop_column("agents_tags", "_created_by_id")
|
||||
op.drop_column("agents_tags", "_last_updated_by_id")
|
||||
op.drop_column("agents_tags", "updated_at")
|
||||
op.drop_column("agents_tags", "id")
|
||||
op.drop_column("agents_tags", "is_deleted")
|
||||
op.drop_column("agents_tags", "created_at")
|
||||
op.drop_column("agents_tags", "organization_id")
|
||||
op.create_unique_constraint("unique_agent_block", "blocks_agents", ["agent_id", "block_id"])
|
||||
op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"], initially="DEFERRED", deferrable=True
|
||||
)
|
||||
op.drop_column("blocks_agents", "_created_by_id")
|
||||
op.drop_column("blocks_agents", "_last_updated_by_id")
|
||||
op.drop_column("blocks_agents", "updated_at")
|
||||
op.drop_column("blocks_agents", "id")
|
||||
op.drop_column("blocks_agents", "is_deleted")
|
||||
op.drop_column("blocks_agents", "created_at")
|
||||
op.drop_constraint("unique_tool_per_agent", "tools_agents", type_="unique")
|
||||
op.create_unique_constraint("unique_agent_tool", "tools_agents", ["agent_id", "tool_id"])
|
||||
op.drop_constraint("fk_tool_id", "tools_agents", type_="foreignkey")
|
||||
op.drop_constraint("tools_agents_agent_id_fkey", "tools_agents", type_="foreignkey")
|
||||
op.create_foreign_key(None, "tools_agents", "tools", ["tool_id"], ["id"], ondelete="CASCADE")
|
||||
op.create_foreign_key(None, "tools_agents", "agents", ["agent_id"], ["id"], ondelete="CASCADE")
|
||||
op.drop_column("tools_agents", "_created_by_id")
|
||||
op.drop_column("tools_agents", "tool_name")
|
||||
op.drop_column("tools_agents", "_last_updated_by_id")
|
||||
op.drop_column("tools_agents", "updated_at")
|
||||
op.drop_column("tools_agents", "id")
|
||||
op.drop_column("tools_agents", "is_deleted")
|
||||
op.drop_column("tools_agents", "created_at")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"tools_agents",
|
||||
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"tools_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
|
||||
)
|
||||
op.add_column("tools_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column(
|
||||
"tools_agents",
|
||||
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column("tools_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.add_column("tools_agents", sa.Column("tool_name", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column("tools_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint(None, "tools_agents", type_="foreignkey")
|
||||
op.drop_constraint(None, "tools_agents", type_="foreignkey")
|
||||
op.create_foreign_key("tools_agents_agent_id_fkey", "tools_agents", "agents", ["agent_id"], ["id"])
|
||||
op.create_foreign_key("fk_tool_id", "tools_agents", "tools", ["tool_id"], ["id"])
|
||||
op.drop_constraint("unique_agent_tool", "tools_agents", type_="unique")
|
||||
op.create_unique_constraint("unique_tool_per_agent", "tools_agents", ["agent_id", "tool_name"])
|
||||
op.add_column(
|
||||
"blocks_agents",
|
||||
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"blocks_agents", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
|
||||
)
|
||||
op.add_column("blocks_agents", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column(
|
||||
"blocks_agents",
|
||||
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column("blocks_agents", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.add_column("blocks_agents", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.drop_constraint("fk_block_id_label", "blocks_agents", type_="foreignkey")
|
||||
op.create_foreign_key("fk_block_id_label", "blocks_agents", "block", ["block_id", "block_label"], ["id", "label"])
|
||||
op.drop_constraint("unique_agent_block", "blocks_agents", type_="unique")
|
||||
op.add_column("agents_tags", sa.Column("organization_id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column(
|
||||
"agents_tags",
|
||||
sa.Column("created_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"agents_tags", sa.Column("is_deleted", sa.BOOLEAN(), server_default=sa.text("false"), autoincrement=False, nullable=False)
|
||||
)
|
||||
op.add_column("agents_tags", sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column(
|
||||
"agents_tags",
|
||||
sa.Column("updated_at", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.add_column("agents_tags", sa.Column("_last_updated_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.add_column("agents_tags", sa.Column("_created_by_id", sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
op.create_foreign_key("agents_tags_organization_id_fkey", "agents_tags", "organizations", ["organization_id"], ["id"])
|
||||
op.add_column("agents", sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.add_column("agents", sa.Column("tool_names", postgresql.JSON(astext_type=sa.Text()), autoincrement=False, nullable=True))
|
||||
op.drop_constraint(None, "agents", type_="foreignkey")
|
||||
op.drop_constraint("unique_org_agent_name", "agents", type_="unique")
|
||||
op.create_index("agents_idx_user", "agents", ["user_id"], unique=False)
|
||||
op.alter_column("agents", "name", existing_type=sa.VARCHAR(), nullable=False)
|
||||
op.drop_column("agents", "organization_id")
|
||||
op.drop_column("agents", "_last_updated_by_id")
|
||||
op.drop_column("agents", "_created_by_id")
|
||||
op.drop_column("agents", "is_deleted")
|
||||
op.drop_column("agents", "updated_at")
|
||||
op.create_table(
|
||||
"agent_source_mapping",
|
||||
sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("user_id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("agent_id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column("source_id", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.PrimaryKeyConstraint("id", name="agent_source_mapping_pkey"),
|
||||
)
|
||||
op.create_index("agent_source_mapping_idx_user", "agent_source_mapping", ["user_id", "agent_id", "source_id"], unique=False)
|
||||
op.drop_table("sources_agents")
|
||||
# ### end Alembic commands ###
|
@ -30,7 +30,7 @@ def upgrade() -> None:
|
||||
op.execute("DELETE FROM tools")
|
||||
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("agents", sa.Column("tool_rules", letta.metadata.ToolRulesColumn(), nullable=True))
|
||||
op.add_column("agents", sa.Column("tool_rules", letta.orm.agent.ToolRulesColumn(), nullable=True))
|
||||
op.alter_column("block", "name", new_column_name="template_name", nullable=True)
|
||||
op.add_column("organizations", sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True))
|
||||
op.add_column("organizations", sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False))
|
||||
|
@ -74,7 +74,7 @@ def main():
|
||||
"""
|
||||
|
||||
# Create an agent
|
||||
agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool.name])
|
||||
agent = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[tool.id])
|
||||
print(f"Created agent: {agent.name} with ID {str(agent.id)}")
|
||||
|
||||
# Send a message to the agent
|
||||
|
@ -29,7 +29,7 @@ agent_state = client.create_agent(
|
||||
# whether to include base letta tools (default: True)
|
||||
include_base_tools=True,
|
||||
# list of additional tools (by name) to add to the agent
|
||||
tools=[],
|
||||
tool_ids=[],
|
||||
)
|
||||
print(f"Created agent with name {agent_state.name} and unique ID {agent_state.id}")
|
||||
|
||||
|
@ -36,7 +36,7 @@ print(f"Created tool with name {tool.name}")
|
||||
# create a new agent
|
||||
agent_state = client.create_agent(
|
||||
# create the agent with an additional tool
|
||||
tools=[tool.name],
|
||||
tool_ids=[tool.id],
|
||||
# add tool rules that terminate execution after specific tools
|
||||
tool_rules=[
|
||||
# exit after roll_d20 is called
|
||||
@ -45,7 +45,7 @@ agent_state = client.create_agent(
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
],
|
||||
)
|
||||
print(f"Created agent with name {agent_state.name} with tools {agent_state.tool_names}")
|
||||
print(f"Created agent with name {agent_state.name} with tools {[t.name for t in agent_state.tools]}")
|
||||
|
||||
# Message an agent
|
||||
response = client.send_message(agent_id=agent_state.id, role="user", message="roll a dice")
|
||||
@ -61,7 +61,8 @@ client.add_tool_to_agent(agent_id=agent_state.id, tool_id=tool.id)
|
||||
client.delete_agent(agent_id=agent_state.id)
|
||||
|
||||
# create an agent with only a subset of default tools
|
||||
agent_state = client.create_agent(include_base_tools=False, tools=[tool.name, "send_message"])
|
||||
send_message_tool = client.get_tool_id("send_message")
|
||||
agent_state = client.create_agent(include_base_tools=False, tool_ids=[tool.id, send_message_tool])
|
||||
|
||||
# message the agent to search archival memory (will be unable to do so)
|
||||
response = client.send_message(agent_id=agent_state.id, role="user", message="search your archival memory")
|
||||
|
@ -67,7 +67,9 @@ def main():
|
||||
"""
|
||||
|
||||
# Create an agent
|
||||
agent_state = client.create_agent(name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tools=[tool_name])
|
||||
agent_state = client.create_agent(
|
||||
name=agent_uuid, memory=ChatMemory(human="My name is Matt.", persona=persona), tool_ids=[wikipedia_query_tool.id]
|
||||
)
|
||||
print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}")
|
||||
|
||||
# Send a message to the agent
|
||||
|
@ -108,7 +108,7 @@ def main():
|
||||
]
|
||||
|
||||
# 4. Create the agent
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# 5. Ask for the final secret word
|
||||
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
|
||||
|
@ -4,7 +4,7 @@ __version__ = "0.6.4"
|
||||
from letta.client.client import LocalClient, RESTClient, create_client
|
||||
|
||||
# imports for easier access
|
||||
from letta.schemas.agent import AgentState, PersistedAgentState
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
@ -6,8 +6,6 @@ import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
CLI_WARNING_PREFIX,
|
||||
@ -30,7 +28,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
|
||||
from letta.memory import summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
@ -49,12 +47,12 @@ from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import (
|
||||
get_heartbeat,
|
||||
@ -316,7 +314,7 @@ class Agent(BaseAgent):
|
||||
|
||||
else:
|
||||
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
|
||||
assert self.agent_state.id is not None and self.agent_state.user_id is not None
|
||||
assert self.agent_state.id is not None and self.agent_state.created_by_id is not None
|
||||
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
@ -335,7 +333,7 @@ class Agent(BaseAgent):
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
@ -358,7 +356,7 @@ class Agent(BaseAgent):
|
||||
# Cast to Message objects
|
||||
init_messages = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=msg
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.created_by_id, model=self.model, openai_message_dict=msg
|
||||
)
|
||||
for msg in init_messages
|
||||
]
|
||||
@ -439,11 +437,12 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run(
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.agent_state.created_by_id).run(
|
||||
agent_state=self.agent_state.__deepcopy__()
|
||||
)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
|
||||
self.update_memory_if_change(updated_agent_state.memory)
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else trunction wont happen
|
||||
@ -573,7 +572,7 @@ class Agent(BaseAgent):
|
||||
added_messages_objs = [
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=msg,
|
||||
)
|
||||
@ -603,7 +602,7 @@ class Agent(BaseAgent):
|
||||
response = create(
|
||||
llm_config=self.agent_state.llm_config,
|
||||
messages=message_sequence,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
functions=allowed_functions,
|
||||
functions_python=self.functions_python,
|
||||
function_call=function_call,
|
||||
@ -689,7 +688,7 @@ class Agent(BaseAgent):
|
||||
Message.dict_to_message(
|
||||
id=response_message_id,
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
)
|
||||
@ -722,7 +721,7 @@ class Agent(BaseAgent):
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "tool",
|
||||
@ -745,7 +744,7 @@ class Agent(BaseAgent):
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "tool",
|
||||
@ -823,7 +822,7 @@ class Agent(BaseAgent):
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "tool",
|
||||
@ -842,7 +841,7 @@ class Agent(BaseAgent):
|
||||
messages.append(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "tool",
|
||||
@ -861,7 +860,7 @@ class Agent(BaseAgent):
|
||||
Message.dict_to_message(
|
||||
id=response_message_id,
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
)
|
||||
@ -920,7 +919,7 @@ class Agent(BaseAgent):
|
||||
# logger.debug("Saving agent state")
|
||||
# save updated state
|
||||
if ms:
|
||||
save_agent(self, ms)
|
||||
save_agent(self)
|
||||
|
||||
# Chain stops
|
||||
if not chaining:
|
||||
@ -931,10 +930,10 @@ class Agent(BaseAgent):
|
||||
break
|
||||
# Chain handlers
|
||||
elif token_warning:
|
||||
assert self.agent_state.user_id is not None
|
||||
assert self.agent_state.created_by_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
@ -943,10 +942,10 @@ class Agent(BaseAgent):
|
||||
)
|
||||
continue # always chain
|
||||
elif function_failed:
|
||||
assert self.agent_state.user_id is not None
|
||||
assert self.agent_state.created_by_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
@ -955,10 +954,10 @@ class Agent(BaseAgent):
|
||||
)
|
||||
continue # always chain
|
||||
elif heartbeat_request:
|
||||
assert self.agent_state.user_id is not None
|
||||
assert self.agent_state.created_by_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
@ -1129,10 +1128,10 @@ class Agent(BaseAgent):
|
||||
openai_message_dict = {"role": "user", "content": cleaned_user_message_text, "name": name}
|
||||
|
||||
# Create the associated Message object (in the database)
|
||||
assert self.agent_state.user_id is not None, "User ID is not set"
|
||||
assert self.agent_state.created_by_id is not None, "User ID is not set"
|
||||
user_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=openai_message_dict,
|
||||
# created_at=timestamp,
|
||||
@ -1232,7 +1231,7 @@ class Agent(BaseAgent):
|
||||
[
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict=packed_summary_message,
|
||||
)
|
||||
@ -1260,7 +1259,7 @@ class Agent(BaseAgent):
|
||||
assert isinstance(new_system_message, str)
|
||||
new_system_message_obj = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
user_id=self.agent_state.created_by_id,
|
||||
model=self.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message},
|
||||
)
|
||||
@ -1371,7 +1370,14 @@ class Agent(BaseAgent):
|
||||
# TODO: recall memory
|
||||
raise NotImplementedError()
|
||||
|
||||
def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore, page_size: Optional[int] = None):
|
||||
def attach_source(
|
||||
self,
|
||||
user: PydanticUser,
|
||||
source_id: str,
|
||||
source_manager: SourceManager,
|
||||
agent_manager: AgentManager,
|
||||
page_size: Optional[int] = None,
|
||||
):
|
||||
"""Attach data with name `source_name` to the agent from source_connector."""
|
||||
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
|
||||
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
|
||||
@ -1384,7 +1390,7 @@ class Agent(BaseAgent):
|
||||
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
|
||||
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
|
||||
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
|
||||
assert len(agents_passages) == passage_size # sanity check
|
||||
assert len(agents_passages) == passage_size # sanity check
|
||||
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
|
||||
|
||||
# attach to agent
|
||||
@ -1393,7 +1399,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
|
||||
# TODO: delete @matt and remove
|
||||
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
|
||||
agent_manager.attach_source(agent_id=self.agent_state.id, source_id=source_id, actor=user)
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
|
||||
@ -1610,20 +1616,31 @@ class Agent(BaseAgent):
|
||||
return context_window_breakdown.context_window_size_current
|
||||
|
||||
|
||||
def save_agent(agent: Agent, ms: MetadataStore):
|
||||
def save_agent(agent: Agent):
|
||||
"""Save agent to metadata store"""
|
||||
|
||||
agent.update_state()
|
||||
agent_state = agent.agent_state
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
|
||||
|
||||
# TODO: move this to agent manager
|
||||
# TODO: Completely strip out metadata
|
||||
# convert to persisted model
|
||||
persisted_agent_state = agent.agent_state.to_persisted_agent_state()
|
||||
if ms.get_agent(agent_id=persisted_agent_state.id):
|
||||
ms.update_agent(persisted_agent_state)
|
||||
else:
|
||||
ms.create_agent(persisted_agent_state)
|
||||
agent_manager = AgentManager()
|
||||
update_agent = UpdateAgent(
|
||||
name=agent_state.name,
|
||||
tool_ids=[t.id for t in agent_state.tools],
|
||||
source_ids=[s.id for s in agent_state.sources],
|
||||
block_ids=[b.id for b in agent_state.memory.blocks],
|
||||
tags=agent_state.tags,
|
||||
system=agent_state.system,
|
||||
tool_rules=agent_state.tool_rules,
|
||||
llm_config=agent_state.llm_config,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
message_ids=agent_state.message_ids,
|
||||
description=agent_state.description,
|
||||
metadata_=agent_state.metadata_,
|
||||
)
|
||||
agent_manager.update_agent(agent_id=agent_state.id, agent_update=update_agent, actor=agent.user)
|
||||
|
||||
|
||||
def strip_name_field_from_user_message(user_message_text: str) -> Tuple[str, Optional[str]]:
|
||||
|
@ -2,7 +2,6 @@ from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from letta.agent import Agent
|
||||
|
||||
from letta.interface import AgentInterface
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.prompts import gpt_system
|
||||
@ -68,8 +67,10 @@ class ChatOnlyAgent(Agent):
|
||||
name="chat_agent_persona_new", label="chat_agent_persona_new", value=conversation_persona_block.value, limit=2000
|
||||
)
|
||||
|
||||
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit:]
|
||||
conversation_messages_block = Block(name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit)
|
||||
recent_convo = "".join([str(message) for message in self.messages[3:]])[-self.recent_convo_limit :]
|
||||
conversation_messages_block = Block(
|
||||
name="conversation_block", label="conversation_block", value=recent_convo, limit=self.recent_convo_limit
|
||||
)
|
||||
|
||||
offline_memory = BasicBlockMemory(
|
||||
blocks=[
|
||||
@ -89,7 +90,7 @@ class ChatOnlyAgent(Agent):
|
||||
memory=offline_memory,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
tools=self.agent_state.metadata_.get("offline_memory_tools", []),
|
||||
tool_ids=self.agent_state.metadata_.get("offline_memory_tools", []),
|
||||
include_base_tools=False,
|
||||
)
|
||||
self.offline_memory_agent.memory.update_block_value(label="conversation_block", value=recent_convo)
|
||||
|
@ -140,7 +140,6 @@ def run(
|
||||
# read user id from config
|
||||
ms = MetadataStore(config)
|
||||
client = create_client()
|
||||
server = client.server
|
||||
|
||||
# determine agent to use, if not provided
|
||||
if not yes and not agent:
|
||||
@ -165,8 +164,6 @@ def run(
|
||||
persona = persona if persona else config.persona
|
||||
if agent and agent_state: # use existing agent
|
||||
typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN)
|
||||
# agent_config = AgentConfig.load(agent)
|
||||
# agent_state = ms.get_agent(agent_name=agent, user_id=user_id)
|
||||
printd("Loading agent state:", agent_state.id)
|
||||
printd("Agent state:", agent_state.name)
|
||||
# printd("State path:", agent_config.save_state_dir())
|
||||
@ -224,8 +221,6 @@ def run(
|
||||
)
|
||||
|
||||
# create agent
|
||||
tools = [server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=client.user) for tool_name in agent_state.tool_names]
|
||||
agent_state.tools = tools
|
||||
letta_agent = Agent(agent_state=agent_state, interface=interface(), user=client.user)
|
||||
|
||||
else: # create new agent
|
||||
@ -317,7 +312,7 @@ def run(
|
||||
metadata=metadata,
|
||||
)
|
||||
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tool_names])}", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t.name for t in agent_state.tools])}", fg=typer.colors.WHITE)
|
||||
|
||||
letta_agent = Agent(
|
||||
interface=interface(),
|
||||
@ -326,7 +321,7 @@ def run(
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
user=client.user,
|
||||
)
|
||||
save_agent(agent=letta_agent, ms=ms)
|
||||
save_agent(agent=letta_agent)
|
||||
typer.secho(f"🎉 Created new agent '{letta_agent.agent_state.name}' (id={letta_agent.agent_state.id})", fg=typer.colors.GREEN)
|
||||
|
||||
# start event loop
|
||||
|
@ -15,7 +15,8 @@ from letta.constants import (
|
||||
)
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock, Human, Persona
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
@ -65,10 +66,8 @@ def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
|
||||
class AbstractClient(object):
|
||||
def __init__(
|
||||
self,
|
||||
auto_save: bool = False,
|
||||
debug: bool = False,
|
||||
):
|
||||
self.auto_save = auto_save
|
||||
self.debug = debug
|
||||
|
||||
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
|
||||
@ -81,8 +80,9 @@ class AbstractClient(object):
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
memory=None,
|
||||
block_ids: Optional[List[str]] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
@ -97,7 +97,7 @@ class AbstractClient(object):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@ -436,7 +436,6 @@ class RESTClient(AbstractClient):
|
||||
Initializes a new instance of Client class.
|
||||
|
||||
Args:
|
||||
auto_save (bool): Whether to automatically save changes.
|
||||
user_id (str): The user ID.
|
||||
debug (bool): Whether to print debug information.
|
||||
default_llm_config (Optional[LLMConfig]): The default LLM configuration.
|
||||
@ -456,6 +455,7 @@ class RESTClient(AbstractClient):
|
||||
params = {}
|
||||
if tags:
|
||||
params["tags"] = tags
|
||||
params["match_all_tags"] = False
|
||||
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
|
||||
return [AgentState(**agent) for agent in response.json()]
|
||||
@ -491,10 +491,12 @@ class RESTClient(AbstractClient):
|
||||
llm_config: LLMConfig = None,
|
||||
# memory
|
||||
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
|
||||
# Existing blocks
|
||||
block_ids: Optional[List[str]] = None,
|
||||
# system
|
||||
system: Optional[str] = None,
|
||||
# tools
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
# metadata
|
||||
@ -511,7 +513,7 @@ class RESTClient(AbstractClient):
|
||||
llm_config (LLMConfig): LLM configuration
|
||||
memory (Memory): Memory configuration
|
||||
system (str): System configuration
|
||||
tools (List[str]): List of tools
|
||||
tool_ids (List[str]): List of tool ids
|
||||
include_base_tools (bool): Include base tools
|
||||
metadata (Dict): Metadata
|
||||
description (str): Description
|
||||
@ -520,31 +522,54 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the created agent
|
||||
"""
|
||||
tool_ids = tool_ids or []
|
||||
tool_names = []
|
||||
if tools:
|
||||
tool_names += tools
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
tool_names += BASE_MEMORY_TOOLS
|
||||
tool_ids += [self.get_tool_id(tool_name=name) for name in tool_names]
|
||||
|
||||
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
|
||||
assert llm_config or self._default_llm_config, f"LLM config must be provided"
|
||||
|
||||
# TODO: This should not happen here, we need to have clear separation between create/add blocks
|
||||
# TODO: This is insanely hacky and a result of allowing free-floating blocks
|
||||
# TODO: When we create the block, it gets it's own block ID
|
||||
blocks = []
|
||||
for block in memory.get_blocks():
|
||||
blocks.append(
|
||||
self.create_block(
|
||||
label=block.label,
|
||||
value=block.value,
|
||||
limit=block.limit,
|
||||
template_name=block.template_name,
|
||||
is_template=block.is_template,
|
||||
)
|
||||
)
|
||||
memory.blocks = blocks
|
||||
block_ids = block_ids or []
|
||||
|
||||
# create agent
|
||||
request = CreateAgent(
|
||||
name=name,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
memory_blocks=[],
|
||||
tools=tool_names,
|
||||
tool_rules=tool_rules,
|
||||
system=system,
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
initial_message_sequence=initial_message_sequence,
|
||||
tags=tags,
|
||||
)
|
||||
create_params = {
|
||||
"description": description,
|
||||
"metadata_": metadata,
|
||||
"memory_blocks": [],
|
||||
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
|
||||
"tool_ids": tool_ids,
|
||||
"tool_rules": tool_rules,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config if llm_config else self._default_llm_config,
|
||||
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
|
||||
"initial_message_sequence": initial_message_sequence,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
# Only add name if it's not None
|
||||
if name is not None:
|
||||
create_params["name"] = name
|
||||
|
||||
request = CreateAgent(**create_params)
|
||||
|
||||
# Use model_dump_json() instead of model_dump()
|
||||
# If we use model_dump(), the datetime objects will not be serialized correctly
|
||||
@ -561,14 +586,6 @@ class RESTClient(AbstractClient):
|
||||
# gather agent state
|
||||
agent_state = AgentState(**response.json())
|
||||
|
||||
# create and link blocks
|
||||
for block in memory.get_blocks():
|
||||
if not self.get_block(block.id):
|
||||
# note: this does not update existing blocks
|
||||
# WARNING: this resets the block ID - this method is a hack for backwards compat, should eventually use CreateBlock not Memory
|
||||
block = self.create_block(label=block.label, value=block.value, limit=block.limit)
|
||||
self.link_agent_memory_block(agent_id=agent_state.id, block_id=block.id)
|
||||
|
||||
# refresh and return agent
|
||||
return self.get_agent(agent_state.id)
|
||||
|
||||
@ -602,7 +619,7 @@ class RESTClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
tool_names: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@ -617,7 +634,7 @@ class RESTClient(AbstractClient):
|
||||
name (str): Name of the agent
|
||||
description (str): Description of the agent
|
||||
system (str): System configuration
|
||||
tool_names (List[str]): List of tools
|
||||
tool_ids (List[str]): List of tools
|
||||
metadata (Dict): Metadata
|
||||
llm_config (LLMConfig): LLM configuration
|
||||
embedding_config (EmbeddingConfig): Embedding configuration
|
||||
@ -627,11 +644,10 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the updated agent
|
||||
"""
|
||||
request = UpdateAgentState(
|
||||
id=agent_id,
|
||||
request = UpdateAgent(
|
||||
name=name,
|
||||
system=system,
|
||||
tool_names=tool_names,
|
||||
tool_ids=tool_ids,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
@ -742,7 +758,7 @@ class RESTClient(AbstractClient):
|
||||
agents = [AgentState(**agent) for agent in response.json()]
|
||||
if len(agents) == 0:
|
||||
return None
|
||||
agents = [agents[0]] # TODO: @matt monkeypatched
|
||||
agents = [agents[0]] # TODO: @matt monkeypatched
|
||||
assert len(agents) == 1, f"Multiple agents with the same name: {[(agents.name, agents.id) for agents in agents]}"
|
||||
return agents[0].id
|
||||
|
||||
@ -1052,7 +1068,7 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to update block: {response.text}")
|
||||
return Block(**response.json())
|
||||
|
||||
def get_block(self, block_id: str) -> Block:
|
||||
def get_block(self, block_id: str) -> Optional[Block]:
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
@ -1607,23 +1623,6 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return Tool(**response.json())
|
||||
|
||||
def get_tool_id(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Get a tool ID by its name.
|
||||
|
||||
Args:
|
||||
id (str): ID of the tool
|
||||
|
||||
Returns:
|
||||
tool (Tool): Tool
|
||||
"""
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/tools/name/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return response.json()
|
||||
|
||||
def set_default_llm_config(self, llm_config: LLMConfig):
|
||||
"""
|
||||
Set the default LLM configuration
|
||||
@ -2006,7 +2005,6 @@ class LocalClient(AbstractClient):
|
||||
A local client for Letta, which corresponds to a single user.
|
||||
|
||||
Attributes:
|
||||
auto_save (bool): Whether to automatically save changes.
|
||||
user_id (str): The user ID.
|
||||
debug (bool): Whether to print debug information.
|
||||
interface (QueuingInterface): The interface for the client.
|
||||
@ -2015,7 +2013,6 @@ class LocalClient(AbstractClient):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_save: bool = False,
|
||||
user_id: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
debug: bool = False,
|
||||
@ -2026,11 +2023,9 @@ class LocalClient(AbstractClient):
|
||||
Initializes a new instance of Client class.
|
||||
|
||||
Args:
|
||||
auto_save (bool): Whether to automatically save changes.
|
||||
user_id (str): The user ID.
|
||||
debug (bool): Whether to print debug information.
|
||||
"""
|
||||
self.auto_save = auto_save
|
||||
|
||||
# set logging levels
|
||||
letta.utils.DEBUG = debug
|
||||
@ -2056,14 +2051,14 @@ class LocalClient(AbstractClient):
|
||||
# get default user
|
||||
self.user_id = self.server.user_manager.DEFAULT_USER_ID
|
||||
|
||||
self.user = self.server.get_user_or_default(self.user_id)
|
||||
self.user = self.server.user_manager.get_user_or_default(self.user_id)
|
||||
self.organization = self.server.get_organization_or_default(self.org_id)
|
||||
|
||||
# agents
|
||||
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
self.interface.clear()
|
||||
|
||||
return self.server.list_agents(user_id=self.user_id, tags=tags)
|
||||
return self.server.agent_manager.list_agents(actor=self.user, tags=tags)
|
||||
|
||||
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
@ -2097,6 +2092,7 @@ class LocalClient(AbstractClient):
|
||||
llm_config: LLMConfig = None,
|
||||
# memory
|
||||
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
|
||||
block_ids: Optional[List[str]] = None,
|
||||
# TODO: change to this when we are ready to migrate all the tests/examples (matches the REST API)
|
||||
# memory_blocks=[
|
||||
# {"label": "human", "value": get_human_text(DEFAULT_HUMAN), "limit": 5000},
|
||||
@ -2105,7 +2101,7 @@ class LocalClient(AbstractClient):
|
||||
# system
|
||||
system: Optional[str] = None,
|
||||
# tools
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
# metadata
|
||||
@ -2132,55 +2128,53 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State of the created agent
|
||||
"""
|
||||
|
||||
if name and self.agent_exists(agent_name=name):
|
||||
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
|
||||
|
||||
# construct list of tools
|
||||
tool_ids = tool_ids or []
|
||||
tool_names = []
|
||||
if tools:
|
||||
tool_names += tools
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
tool_names += BASE_MEMORY_TOOLS
|
||||
tool_ids += [self.server.tool_manager.get_tool_by_name(tool_name=name, actor=self.user).id for name in tool_names]
|
||||
|
||||
# check if default configs are provided
|
||||
assert embedding_config or self._default_embedding_config, f"Embedding config must be provided"
|
||||
assert llm_config or self._default_llm_config, f"LLM config must be provided"
|
||||
|
||||
# TODO: This should not happen here, we need to have clear separation between create/add blocks
|
||||
for block in memory.get_blocks():
|
||||
self.server.block_manager.create_or_update_block(block, actor=self.user)
|
||||
|
||||
# Also get any existing block_ids passed in
|
||||
block_ids = block_ids or []
|
||||
|
||||
# create agent
|
||||
# Create the base parameters
|
||||
create_params = {
|
||||
"description": description,
|
||||
"metadata_": metadata,
|
||||
"memory_blocks": [],
|
||||
"block_ids": [b.id for b in memory.get_blocks()] + block_ids,
|
||||
"tool_ids": tool_ids,
|
||||
"tool_rules": tool_rules,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config if llm_config else self._default_llm_config,
|
||||
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
|
||||
"initial_message_sequence": initial_message_sequence,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
# Only add name if it's not None
|
||||
if name is not None:
|
||||
create_params["name"] = name
|
||||
|
||||
agent_state = self.server.create_agent(
|
||||
CreateAgent(
|
||||
name=name,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
# memory=memory,
|
||||
memory_blocks=[],
|
||||
# memory_blocks = memory.get_blocks(),
|
||||
# memory_tools=memory_tools,
|
||||
tools=tool_names,
|
||||
tool_rules=tool_rules,
|
||||
system=system,
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
initial_message_sequence=initial_message_sequence,
|
||||
tags=tags,
|
||||
),
|
||||
CreateAgent(**create_params),
|
||||
actor=self.user,
|
||||
)
|
||||
|
||||
# TODO: remove when we fully migrate to block creation CreateAgent model
|
||||
# Link additional blocks to the agent (block ids created on the client)
|
||||
# This needs to happen since the create agent does not allow passing in blocks which have already been persisted and have an ID
|
||||
# So we create the agent and then link the blocks afterwards
|
||||
user = self.server.get_user_or_default(self.user_id)
|
||||
for block in memory.get_blocks():
|
||||
self.server.block_manager.create_or_update_block(block, actor=user)
|
||||
self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_state.id, block_id=block.id)
|
||||
|
||||
# TODO: get full agent state
|
||||
return self.server.get_agent(agent_state.id)
|
||||
return self.server.agent_manager.get_agent_by_id(agent_state.id, actor=self.user)
|
||||
|
||||
def update_message(
|
||||
self,
|
||||
@ -2202,6 +2196,7 @@ class LocalClient(AbstractClient):
|
||||
tool_calls=tool_calls,
|
||||
tool_call_id=tool_call_id,
|
||||
),
|
||||
actor=self.user,
|
||||
)
|
||||
return message
|
||||
|
||||
@ -2211,7 +2206,7 @@ class LocalClient(AbstractClient):
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
@ -2239,11 +2234,11 @@ class LocalClient(AbstractClient):
|
||||
# TODO: add the abilitty to reset linked block_ids
|
||||
self.interface.clear()
|
||||
agent_state = self.server.update_agent(
|
||||
UpdateAgentState(
|
||||
id=agent_id,
|
||||
agent_id,
|
||||
UpdateAgent(
|
||||
name=name,
|
||||
system=system,
|
||||
tool_names=tools,
|
||||
tool_ids=tool_ids,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
@ -2315,7 +2310,7 @@ class LocalClient(AbstractClient):
|
||||
Args:
|
||||
agent_id (str): ID of the agent to delete
|
||||
"""
|
||||
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
|
||||
self.server.agent_manager.delete_agent(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def get_agent_by_name(self, agent_name: str) -> AgentState:
|
||||
"""
|
||||
@ -2328,7 +2323,7 @@ class LocalClient(AbstractClient):
|
||||
agent_state (AgentState): State of the agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)
|
||||
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user)
|
||||
|
||||
def get_agent(self, agent_id: str) -> AgentState:
|
||||
"""
|
||||
@ -2340,9 +2335,8 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
agent_state (AgentState): State representation of the agent
|
||||
"""
|
||||
# TODO: include agent_name
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
|
||||
return self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def get_agent_id(self, agent_name: str) -> Optional[str]:
|
||||
"""
|
||||
@ -2357,7 +2351,12 @@ class LocalClient(AbstractClient):
|
||||
|
||||
self.interface.clear()
|
||||
assert agent_name, f"Agent name must be provided"
|
||||
return self.server.get_agent_id(name=agent_name, user_id=self.user_id)
|
||||
|
||||
# TODO: Refactor this futher to not have downstream users expect Optionals - this should just error
|
||||
try:
|
||||
return self.server.agent_manager.get_agent_by_name(agent_name=agent_name, actor=self.user).id
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
# memory
|
||||
def get_in_context_memory(self, agent_id: str) -> Memory:
|
||||
@ -2370,7 +2369,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
memory (Memory): In-context memory of the agent
|
||||
"""
|
||||
memory = self.server.get_agent_memory(agent_id=agent_id)
|
||||
memory = self.server.get_agent_memory(agent_id=agent_id, actor=self.user)
|
||||
return memory
|
||||
|
||||
def get_core_memory(self, agent_id: str) -> Memory:
|
||||
@ -2388,7 +2387,7 @@ class LocalClient(AbstractClient):
|
||||
|
||||
"""
|
||||
# TODO: implement this (not sure what it should look like)
|
||||
memory = self.server.update_agent_core_memory(user_id=self.user_id, agent_id=agent_id, label=section, value=value)
|
||||
memory = self.server.update_agent_core_memory(agent_id=agent_id, label=section, value=value, actor=self.user)
|
||||
return memory
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
|
||||
@ -2402,7 +2401,7 @@ class LocalClient(AbstractClient):
|
||||
summary (ArchivalMemorySummary): Summary of the archival memory
|
||||
|
||||
"""
|
||||
return self.server.get_archival_memory_summary(agent_id=agent_id)
|
||||
return self.server.get_archival_memory_summary(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
|
||||
"""
|
||||
@ -2414,7 +2413,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
summary (RecallMemorySummary): Summary of the recall memory
|
||||
"""
|
||||
return self.server.get_recall_memory_summary(agent_id=agent_id)
|
||||
return self.server.get_recall_memory_summary(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def get_in_context_messages(self, agent_id: str) -> List[Message]:
|
||||
"""
|
||||
@ -2426,7 +2425,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
messages (List[Message]): List of in-context messages
|
||||
"""
|
||||
return self.server.get_in_context_messages(agent_id=agent_id)
|
||||
return self.server.get_in_context_messages(agent_id=agent_id, actor=self.user)
|
||||
|
||||
# agent interactions
|
||||
|
||||
@ -2446,11 +2445,7 @@ class LocalClient(AbstractClient):
|
||||
response (LettaResponse): Response from the agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
usage = self.server.send_messages(user_id=self.user_id, agent_id=agent_id, messages=messages)
|
||||
|
||||
# auto-save
|
||||
if self.auto_save:
|
||||
self.save()
|
||||
usage = self.server.send_messages(actor=self.user, agent_id=agent_id, messages=messages)
|
||||
|
||||
# format messages
|
||||
return LettaResponse(messages=messages, usage=usage)
|
||||
@ -2490,15 +2485,11 @@ class LocalClient(AbstractClient):
|
||||
self.interface.clear()
|
||||
|
||||
usage = self.server.send_messages(
|
||||
user_id=self.user_id,
|
||||
actor=self.user,
|
||||
agent_id=agent_id,
|
||||
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
|
||||
)
|
||||
|
||||
# auto-save
|
||||
if self.auto_save:
|
||||
self.save()
|
||||
|
||||
## TODO: need to make sure date/timestamp is propely passed
|
||||
## TODO: update self.interface.to_list() to return actual Message objects
|
||||
## here, the message objects will have faulty created_by timestamps
|
||||
@ -2547,16 +2538,9 @@ class LocalClient(AbstractClient):
|
||||
self.interface.clear()
|
||||
usage = self.server.run_command(user_id=self.user_id, agent_id=agent_id, command=command)
|
||||
|
||||
# auto-save
|
||||
if self.auto_save:
|
||||
self.save()
|
||||
|
||||
# NOTE: messages/usage may be empty, depending on the command
|
||||
return LettaResponse(messages=self.interface.to_list(), usage=usage)
|
||||
|
||||
def save(self):
|
||||
self.server.save_agents()
|
||||
|
||||
# archival memory
|
||||
|
||||
# humans / personas
|
||||
@ -3036,7 +3020,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
sources (List[Source]): List of sources
|
||||
"""
|
||||
return self.server.list_attached_sources(agent_id=agent_id)
|
||||
return self.server.agent_manager.list_attached_sources(agent_id=agent_id, actor=self.user)
|
||||
|
||||
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
|
||||
"""
|
||||
@ -3080,7 +3064,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
passages (List[Passage]): List of inserted passages
|
||||
"""
|
||||
return self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory)
|
||||
return self.server.insert_archival_memory(agent_id=agent_id, memory_contents=memory, actor=self.user)
|
||||
|
||||
def delete_archival_memory(self, agent_id: str, memory_id: str):
|
||||
"""
|
||||
@ -3090,7 +3074,7 @@ class LocalClient(AbstractClient):
|
||||
agent_id (str): ID of the agent
|
||||
memory_id (str): ID of the memory
|
||||
"""
|
||||
self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id)
|
||||
self.server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=self.user)
|
||||
|
||||
def get_archival_memory(
|
||||
self, agent_id: str, before: Optional[str] = None, after: Optional[str] = None, limit: Optional[int] = 1000
|
||||
@ -3349,8 +3333,8 @@ class LocalClient(AbstractClient):
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = self.server.block_manager.create_or_update_block(actor=self.user, block=block_req)
|
||||
# Link the block to the agent
|
||||
updated_memory = self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block.id)
|
||||
return updated_memory
|
||||
agent = self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=self.user)
|
||||
return agent.memory
|
||||
|
||||
def link_agent_memory_block(self, agent_id: str, block_id: str) -> Memory:
|
||||
"""
|
||||
@ -3363,7 +3347,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
return self.server.link_block_to_agent_memory(user_id=self.user_id, agent_id=agent_id, block_id=block_id)
|
||||
return self.server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=self.user)
|
||||
|
||||
def remove_agent_memory_block(self, agent_id: str, block_label: str) -> Memory:
|
||||
"""
|
||||
@ -3376,7 +3360,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
memory (Memory): The updated memory
|
||||
"""
|
||||
return self.server.unlink_block_from_agent_memory(user_id=self.user_id, agent_id=agent_id, block_label=block_label)
|
||||
return self.server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=self.user)
|
||||
|
||||
def get_agent_memory_blocks(self, agent_id: str) -> List[Block]:
|
||||
"""
|
||||
@ -3388,8 +3372,8 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
blocks (List[Block]): The blocks in the agent's core memory
|
||||
"""
|
||||
block_ids = self.server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id)
|
||||
return [self.server.block_manager.get_block_by_id(block_id, actor=self.user) for block_id in block_ids]
|
||||
agent = self.server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=self.user)
|
||||
return agent.memory.blocks
|
||||
|
||||
def get_agent_memory_block(self, agent_id: str, label: str) -> Block:
|
||||
"""
|
||||
@ -3402,8 +3386,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
block (Block): The block corresponding to the label
|
||||
"""
|
||||
block_id = self.server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=label)
|
||||
return self.server.block_manager.get_block_by_id(block_id, actor=self.user)
|
||||
return self.server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=label, actor=self.user)
|
||||
|
||||
def update_agent_memory_block(
|
||||
self,
|
||||
|
161
letta/config.py
161
letta/config.py
@ -1,12 +1,9 @@
|
||||
import configparser
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import letta
|
||||
import letta.utils as utils
|
||||
from letta.constants import (
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
@ -16,7 +13,6 @@ from letta.constants import (
|
||||
LETTA_DIR,
|
||||
)
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import PersistedAgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
@ -312,160 +308,3 @@ class LettaConfig:
|
||||
for folder in folders:
|
||||
if not os.path.exists(os.path.join(LETTA_DIR, folder)):
|
||||
os.makedirs(os.path.join(LETTA_DIR, folder))
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentConfig:
|
||||
"""
|
||||
|
||||
NOTE: this is a deprecated class, use AgentState instead. This class is only used for backcompatibility.
|
||||
Configuration for a specific instance of an agent
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
persona,
|
||||
human,
|
||||
# model info
|
||||
model=None,
|
||||
model_endpoint_type=None,
|
||||
model_endpoint=None,
|
||||
model_wrapper=None,
|
||||
context_window=None,
|
||||
# embedding info
|
||||
embedding_endpoint_type=None,
|
||||
embedding_endpoint=None,
|
||||
embedding_model=None,
|
||||
embedding_dim=None,
|
||||
embedding_chunk_size=None,
|
||||
# other
|
||||
preset=None,
|
||||
data_sources=None,
|
||||
# agent info
|
||||
agent_config_path=None,
|
||||
name=None,
|
||||
create_time=None,
|
||||
letta_version=None,
|
||||
# functions
|
||||
functions=None, # schema definitions ONLY (linked at runtime)
|
||||
):
|
||||
|
||||
assert name, f"Agent name must be provided"
|
||||
self.name = name
|
||||
|
||||
config = LettaConfig.load() # get default values
|
||||
self.persona = config.persona if persona is None else persona
|
||||
self.human = config.human if human is None else human
|
||||
self.preset = config.preset if preset is None else preset
|
||||
self.context_window = config.default_llm_config.context_window if context_window is None else context_window
|
||||
self.model = config.default_llm_config.model if model is None else model
|
||||
self.model_endpoint_type = config.default_llm_config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type
|
||||
self.model_endpoint = config.default_llm_config.model_endpoint if model_endpoint is None else model_endpoint
|
||||
self.model_wrapper = config.default_llm_config.model_wrapper if model_wrapper is None else model_wrapper
|
||||
self.llm_config = LLMConfig(
|
||||
model=self.model,
|
||||
model_endpoint_type=self.model_endpoint_type,
|
||||
model_endpoint=self.model_endpoint,
|
||||
model_wrapper=self.model_wrapper,
|
||||
context_window=self.context_window,
|
||||
)
|
||||
self.embedding_endpoint_type = (
|
||||
config.default_embedding_config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type
|
||||
)
|
||||
self.embedding_endpoint = config.default_embedding_config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint
|
||||
self.embedding_model = config.default_embedding_config.embedding_model if embedding_model is None else embedding_model
|
||||
self.embedding_dim = config.default_embedding_config.embedding_dim if embedding_dim is None else embedding_dim
|
||||
self.embedding_chunk_size = (
|
||||
config.default_embedding_config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size
|
||||
)
|
||||
self.embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type=self.embedding_endpoint_type,
|
||||
embedding_endpoint=self.embedding_endpoint,
|
||||
embedding_model=self.embedding_model,
|
||||
embedding_dim=self.embedding_dim,
|
||||
embedding_chunk_size=self.embedding_chunk_size,
|
||||
)
|
||||
|
||||
# agent metadata
|
||||
self.data_sources = data_sources if data_sources is not None else []
|
||||
self.create_time = create_time if create_time is not None else utils.get_local_time()
|
||||
if letta_version is None:
|
||||
import letta
|
||||
|
||||
self.letta_version = letta.__version__
|
||||
else:
|
||||
self.letta_version = letta_version
|
||||
|
||||
# functions
|
||||
self.functions = functions
|
||||
|
||||
# save agent config
|
||||
self.agent_config_path = (
|
||||
os.path.join(LETTA_DIR, "agents", self.name, "config.json") if agent_config_path is None else agent_config_path
|
||||
)
|
||||
|
||||
def attach_data_source(self, data_source: str):
|
||||
# TODO: add warning that only once source can be attached
|
||||
# i.e. previous source will be overriden
|
||||
self.data_sources.append(data_source)
|
||||
self.save()
|
||||
|
||||
def save_dir(self):
|
||||
return os.path.join(LETTA_DIR, "agents", self.name)
|
||||
|
||||
def save_state_dir(self):
|
||||
# directory to save agent state
|
||||
return os.path.join(LETTA_DIR, "agents", self.name, "agent_state")
|
||||
|
||||
def save_persistence_manager_dir(self):
|
||||
# directory to save persistent manager state
|
||||
return os.path.join(LETTA_DIR, "agents", self.name, "persistence_manager")
|
||||
|
||||
def save_agent_index_dir(self):
|
||||
# save llama index inside of persistent manager directory
|
||||
return os.path.join(self.save_persistence_manager_dir(), "index")
|
||||
|
||||
def save(self):
|
||||
# save state of persistence manager
|
||||
os.makedirs(os.path.join(LETTA_DIR, "agents", self.name), exist_ok=True)
|
||||
# save version
|
||||
self.letta_version = letta.__version__
|
||||
with open(self.agent_config_path, "w", encoding="utf-8") as f:
|
||||
json.dump(vars(self), f, indent=4)
|
||||
|
||||
def to_agent_state(self):
|
||||
return PersistedAgentState(
|
||||
name=self.name,
|
||||
preset=self.preset,
|
||||
persona=self.persona,
|
||||
human=self.human,
|
||||
llm_config=self.llm_config,
|
||||
embedding_config=self.embedding_config,
|
||||
create_time=self.create_time,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def exists(name: str):
|
||||
"""Check if agent config exists"""
|
||||
agent_config_path = os.path.join(LETTA_DIR, "agents", name)
|
||||
return os.path.exists(agent_config_path)
|
||||
|
||||
@classmethod
|
||||
def load(cls, name: str):
|
||||
"""Load agent config from JSON file"""
|
||||
agent_config_path = os.path.join(LETTA_DIR, "agents", name, "config.json")
|
||||
assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}"
|
||||
with open(agent_config_path, "r", encoding="utf-8") as f:
|
||||
agent_config = json.load(f)
|
||||
# allow compatibility accross versions
|
||||
try:
|
||||
class_args = inspect.getargspec(cls.__init__).args
|
||||
except AttributeError:
|
||||
# https://github.com/pytorch/pytorch/issues/15344
|
||||
class_args = inspect.getfullargspec(cls.__init__).args
|
||||
agent_fields = list(agent_config.keys())
|
||||
for key in agent_fields:
|
||||
if key not in class_args:
|
||||
utils.printd(f"Removing missing argument {key} from agent config")
|
||||
del agent_config[key]
|
||||
return cls(**agent_config)
|
||||
|
@ -130,11 +130,11 @@ def run_agent_loop(
|
||||
# updated agent save functions
|
||||
if user_input.lower() == "/exit":
|
||||
# letta_agent.save()
|
||||
agent.save_agent(letta_agent, ms)
|
||||
agent.save_agent(letta_agent)
|
||||
break
|
||||
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
|
||||
# letta_agent.save()
|
||||
agent.save_agent(letta_agent, ms)
|
||||
agent.save_agent(letta_agent)
|
||||
continue
|
||||
elif user_input.lower() == "/attach":
|
||||
# TODO: check if agent already has it
|
||||
@ -394,7 +394,7 @@ def run_agent_loop(
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
step_response.usage
|
||||
|
||||
agent.save_agent(letta_agent, ms)
|
||||
agent.save_agent(letta_agent)
|
||||
skip_next_user_input = False
|
||||
if token_warning:
|
||||
user_message = system.get_token_limit_warning()
|
||||
|
@ -1,23 +1,13 @@
|
||||
import datetime
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, List, Tuple, Union
|
||||
from typing import Callable, Dict, List
|
||||
|
||||
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.utils import (
|
||||
count_tokens,
|
||||
extract_date_from_timestamp,
|
||||
get_local_time,
|
||||
printd,
|
||||
validate_date_format,
|
||||
)
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
|
||||
@ -67,7 +57,6 @@ def summarize_messages(
|
||||
+ message_sequence_to_summarize[cutoff:]
|
||||
)
|
||||
|
||||
agent_state.user_id
|
||||
dummy_agent_id = agent_state.id
|
||||
message_sequence = []
|
||||
message_sequence.append(Message(agent_id=dummy_agent_id, role=MessageRole.system, text=summary_prompt))
|
||||
@ -79,7 +68,7 @@ def summarize_messages(
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
response = create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.user_id,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
|
@ -2,23 +2,18 @@
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import JSON, Column, DateTime, Index, String, TypeDecorator
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import JSON, Column, Index, String, TypeDecorator
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.orm.base import Base
|
||||
from letta.schemas.agent import PersistedAgentState
|
||||
from letta.schemas.api_key import APIKey
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class LLMConfigColumn(TypeDecorator):
|
||||
@ -65,18 +60,6 @@ class EmbeddingConfigColumn(TypeDecorator):
|
||||
return value
|
||||
|
||||
|
||||
# TODO: eventually store providers?
|
||||
# class Provider(Base):
|
||||
# __tablename__ = "providers"
|
||||
# __table_args__ = {"extend_existing": True}
|
||||
#
|
||||
# id = Column(String, primary_key=True)
|
||||
# name = Column(String, nullable=False)
|
||||
# created_at = Column(DateTime(timezone=True))
|
||||
# api_key = Column(String, nullable=False)
|
||||
# base_url = Column(String, nullable=False)
|
||||
|
||||
|
||||
class APIKeyModel(Base):
|
||||
"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""
|
||||
|
||||
@ -113,115 +96,6 @@ def generate_api_key(prefix="sk-", length=51) -> str:
|
||||
return new_key
|
||||
|
||||
|
||||
class ToolRulesColumn(TypeDecorator):
|
||||
"""Custom type for storing a list of ToolRules as JSON"""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Convert a list of ToolRules to JSON-serializable format."""
|
||||
if value:
|
||||
data = [rule.model_dump() for rule in value]
|
||||
for d in data:
|
||||
d["type"] = d["type"].value
|
||||
|
||||
for d in data:
|
||||
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
|
||||
return data
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
|
||||
"""Convert JSON back to a list of ToolRules."""
|
||||
if value:
|
||||
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
|
||||
if rule_type == ToolRuleType.run_first:
|
||||
return InitToolRule(**data)
|
||||
elif rule_type == ToolRuleType.exit_loop:
|
||||
return TerminalToolRule(**data)
|
||||
elif rule_type == ToolRuleType.constrain_child_tools:
|
||||
rule = ChildToolRule(**data)
|
||||
return rule
|
||||
else:
|
||||
raise ValueError(f"Unknown tool rule type: {rule_type}")
|
||||
|
||||
|
||||
class AgentModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
__tablename__ = "agents"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
description = Column(String)
|
||||
|
||||
# state (context compilation)
|
||||
message_ids = Column(JSON)
|
||||
system = Column(String)
|
||||
|
||||
# configs
|
||||
agent_type = Column(String)
|
||||
llm_config = Column(LLMConfigColumn)
|
||||
embedding_config = Column(EmbeddingConfigColumn)
|
||||
|
||||
# state
|
||||
metadata_ = Column(JSON)
|
||||
|
||||
# tools
|
||||
tool_names = Column(JSON)
|
||||
tool_rules = Column(ToolRulesColumn)
|
||||
|
||||
Index(__tablename__ + "_idx_user", user_id),
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Agent(id='{self.id}', name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> PersistedAgentState:
|
||||
agent_state = PersistedAgentState(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
created_at=self.created_at,
|
||||
description=self.description,
|
||||
message_ids=self.message_ids,
|
||||
system=self.system,
|
||||
tool_names=self.tool_names,
|
||||
tool_rules=self.tool_rules,
|
||||
agent_type=self.agent_type,
|
||||
llm_config=self.llm_config,
|
||||
embedding_config=self.embedding_config,
|
||||
metadata_=self.metadata_,
|
||||
)
|
||||
return agent_state
|
||||
|
||||
|
||||
class AgentSourceMappingModel(Base):
|
||||
"""Stores mapping between agent -> source"""
|
||||
|
||||
__tablename__ = "agent_source_mapping"
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
source_id = Column(String, nullable=False)
|
||||
Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
||||
|
||||
|
||||
class MetadataStore:
|
||||
uri: Optional[str] = None
|
||||
|
||||
@ -281,127 +155,3 @@ class MetadataStore:
|
||||
results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
|
||||
tokens = [r.to_record() for r in results]
|
||||
return tokens
|
||||
|
||||
@enforce_types
|
||||
def create_agent(self, agent: PersistedAgentState):
|
||||
# insert into agent table
|
||||
# make sure agent.name does not already exist for user user_id
|
||||
with self.session_maker() as session:
|
||||
if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
|
||||
raise ValueError(f"Agent with name {agent.name} already exists")
|
||||
fields = vars(agent)
|
||||
# fields["memory"] = agent.memory.to_dict()
|
||||
# if "_internal_memory" in fields:
|
||||
# del fields["_internal_memory"]
|
||||
# else:
|
||||
# warnings.warn(f"Agent {agent.id} has no _internal_memory field")
|
||||
if "tags" in fields:
|
||||
del fields["tags"]
|
||||
# else:
|
||||
# warnings.warn(f"Agent {agent.id} has no tags field")
|
||||
session.add(AgentModel(**fields))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent: PersistedAgentState):
|
||||
with self.session_maker() as session:
|
||||
fields = vars(agent)
|
||||
# if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
|
||||
# fields["memory"] = agent.memory.to_dict()
|
||||
# if "_internal_memory" in fields:
|
||||
# del fields["_internal_memory"]
|
||||
# else:
|
||||
# warnings.warn(f"Agent {agent.id} has no _internal_memory field")
|
||||
if "tags" in fields:
|
||||
del fields["tags"]
|
||||
# else:
|
||||
# warnings.warn(f"Agent {agent.id} has no tags field")
|
||||
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager):
|
||||
# TODO: Remove this once Agent is on the ORM
|
||||
# TODO: To prevent unbounded growth
|
||||
per_agent_lock_manager.clear_lock(agent_id)
|
||||
|
||||
with self.session_maker() as session:
|
||||
|
||||
# delete agents
|
||||
session.query(AgentModel).filter(AgentModel.id == agent_id).delete()
|
||||
|
||||
# delete mappings
|
||||
session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).delete()
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: str) -> List[PersistedAgentState]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def get_agent(
|
||||
self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
|
||||
) -> Optional[PersistedAgentState]:
|
||||
with self.session_maker() as session:
|
||||
if agent_id:
|
||||
results = session.query(AgentModel).filter(AgentModel.id == agent_id).all()
|
||||
else:
|
||||
assert agent_name is not None and user_id is not None, "Must provide either agent_id or agent_name"
|
||||
results = session.query(AgentModel).filter(AgentModel.name == agent_name).filter(AgentModel.user_id == user_id).all()
|
||||
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
|
||||
return results[0].to_record()
|
||||
|
||||
# agent source metadata
|
||||
@enforce_types
|
||||
def attach_source(self, user_id: str, agent_id: str, source_id: str):
|
||||
with self.session_maker() as session:
|
||||
# TODO: remove this (is a hack)
|
||||
mapping_id = f"{user_id}-{agent_id}-{source_id}"
|
||||
existing = session.query(AgentSourceMappingModel).filter(
|
||||
AgentSourceMappingModel.id == mapping_id
|
||||
).first()
|
||||
|
||||
if existing is None:
|
||||
# Only create if it doesn't exist
|
||||
session.add(AgentSourceMappingModel(
|
||||
id=mapping_id,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
source_id=source_id
|
||||
))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_attached_source_ids(self, agent_id: str) -> List[str]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
|
||||
return [r.source_id for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_attached_agents(self, source_id: str) -> List[str]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
|
||||
|
||||
agent_ids = []
|
||||
# make sure agent exists
|
||||
for r in results:
|
||||
agent = self.get_agent(agent_id=r.agent_id)
|
||||
if agent:
|
||||
agent_ids.append(r.agent_id)
|
||||
else:
|
||||
printd(f"Warning: agent {r.agent_id} does not exist but exists in mapping database. This should never happen.")
|
||||
return agent_ids
|
||||
|
||||
@enforce_types
|
||||
def detach_source(self, agent_id: str, source_id: str):
|
||||
with self.session_maker() as session:
|
||||
session.query(AgentSourceMappingModel).filter(
|
||||
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
|
||||
).delete()
|
||||
session.commit()
|
||||
|
@ -85,6 +85,6 @@ class O1Agent(Agent):
|
||||
if step_response.messages[-1].name == "send_final_message":
|
||||
break
|
||||
if ms:
|
||||
save_agent(self, ms)
|
||||
save_agent(self)
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
@ -130,7 +130,7 @@ class OfflineMemoryAgent(Agent):
|
||||
# extras
|
||||
first_message_verify_mono: bool = False,
|
||||
max_memory_rethinks: int = 10,
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
initial_message_sequence: Optional[List[Message]] = None,
|
||||
):
|
||||
super().__init__(interface, agent_state, user, initial_message_sequence=initial_message_sequence)
|
||||
self.first_message_verify_mono = first_message_verify_mono
|
||||
@ -173,6 +173,6 @@ class OfflineMemoryAgent(Agent):
|
||||
self.interface.step_complete()
|
||||
|
||||
if ms:
|
||||
save_agent(self, ms)
|
||||
save_agent(self)
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
@ -1,3 +1,4 @@
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.base import Base
|
||||
from letta.orm.block import Block
|
||||
@ -9,6 +10,7 @@ from letta.orm.organization import Organization
|
||||
from letta.orm.passage import Passage
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.sources_agents import SourcesAgents
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
from letta.orm.user import User
|
||||
|
196
letta/orm/agent.py
Normal file
196
letta/orm/agent.py
Normal file
@ -0,0 +1,196 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
from sqlalchemy import JSON, String, TypeDecorator, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.block import Block
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule,
|
||||
ToolRule,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.tool import Tool
|
||||
|
||||
|
||||
class LLMConfigColumn(TypeDecorator):
|
||||
"""Custom type for storing LLMConfig as JSON"""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
# return vars(value)
|
||||
if isinstance(value, LLMConfig):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
return LLMConfig(**value)
|
||||
return value
|
||||
|
||||
|
||||
class EmbeddingConfigColumn(TypeDecorator):
|
||||
"""Custom type for storing EmbeddingConfig as JSON"""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value:
|
||||
# return vars(value)
|
||||
if isinstance(value, EmbeddingConfig):
|
||||
return value.model_dump()
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if value:
|
||||
return EmbeddingConfig(**value)
|
||||
return value
|
||||
|
||||
|
||||
class ToolRulesColumn(TypeDecorator):
|
||||
"""Custom type for storing a list of ToolRules as JSON"""
|
||||
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(JSON())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
"""Convert a list of ToolRules to JSON-serializable format."""
|
||||
if value:
|
||||
data = [rule.model_dump() for rule in value]
|
||||
for d in data:
|
||||
d["type"] = d["type"].value
|
||||
|
||||
for d in data:
|
||||
assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
|
||||
return data
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule]]:
|
||||
"""Convert JSON back to a list of ToolRules."""
|
||||
if value:
|
||||
return [self.deserialize_tool_rule(rule_data) for rule_data in value]
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule]:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on the 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type")) # Remove 'type' field if it exists since it is a class var
|
||||
if rule_type == ToolRuleType.run_first:
|
||||
return InitToolRule(**data)
|
||||
elif rule_type == ToolRuleType.exit_loop:
|
||||
return TerminalToolRule(**data)
|
||||
elif rule_type == ToolRuleType.constrain_child_tools:
|
||||
rule = ChildToolRule(**data)
|
||||
return rule
|
||||
else:
|
||||
raise ValueError(f"Unknown tool rule type: {rule_type}")
|
||||
|
||||
|
||||
class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
__tablename__ = "agents"
|
||||
__pydantic_model__ = PydanticAgentState
|
||||
__table_args__ = (UniqueConstraint("organization_id", "name", name="unique_org_agent_name"),)
|
||||
|
||||
# agent generates its own id
|
||||
# TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
# TODO: Move this in this PR? at the very end?
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agent-{uuid.uuid4()}")
|
||||
|
||||
# Descriptor fields
|
||||
agent_type: Mapped[Optional[AgentType]] = mapped_column(String, nullable=True, doc="The type of Agent")
|
||||
name: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="a human-readable identifier for an agent, non-unique.")
|
||||
description: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The description of the agent.")
|
||||
|
||||
# System prompt
|
||||
system: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The system prompt used by the agent.")
|
||||
|
||||
# In context memory
|
||||
# TODO: This should be a separate mapping table
|
||||
# This is dangerously flexible with the JSON type
|
||||
message_ids: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True, doc="List of message IDs in in-context memory.")
|
||||
|
||||
# Metadata and configs
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="metadata for the agent.")
|
||||
llm_config: Mapped[Optional[LLMConfig]] = mapped_column(
|
||||
LLMConfigColumn, nullable=True, doc="the LLM backend configuration object for this agent."
|
||||
)
|
||||
embedding_config: Mapped[Optional[EmbeddingConfig]] = mapped_column(
|
||||
EmbeddingConfigColumn, doc="the embedding configuration object for this agent."
|
||||
)
|
||||
|
||||
# Tool rules
|
||||
tool_rules: Mapped[Optional[List[ToolRule]]] = mapped_column(ToolRulesColumn, doc="the tool rules for this agent.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents")
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", secondary="tools_agents", lazy="selectin", passive_deletes=True)
|
||||
sources: Mapped[List["Source"]] = relationship("Source", secondary="sources_agents", lazy="selectin")
|
||||
core_memory: Mapped[List["Block"]] = relationship("Block", secondary="blocks_agents", lazy="selectin")
|
||||
messages: Mapped[List["Message"]] = relationship(
|
||||
"Message",
|
||||
back_populates="agent",
|
||||
lazy="selectin",
|
||||
cascade="all, delete-orphan", # Ensure messages are deleted when the agent is deleted
|
||||
passive_deletes=True,
|
||||
)
|
||||
tags: Mapped[List["AgentsTags"]] = relationship(
|
||||
"AgentsTags",
|
||||
back_populates="agent",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
doc="Tags associated with the agent.",
|
||||
)
|
||||
# passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="agent", lazy="selectin")
|
||||
|
||||
def to_pydantic(self) -> PydanticAgentState:
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
state = {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"message_ids": self.message_ids,
|
||||
"tools": self.tools,
|
||||
"sources": self.sources,
|
||||
"tags": [t.tag for t in self.tags],
|
||||
"tool_rules": self.tool_rules,
|
||||
"system": self.system,
|
||||
"agent_type": self.agent_type,
|
||||
"llm_config": self.llm_config,
|
||||
"embedding_config": self.embedding_config,
|
||||
"metadata_": self.metadata_,
|
||||
"memory": Memory(blocks=[b.to_pydantic() for b in self.core_memory]),
|
||||
"created_by_id": self.created_by_id,
|
||||
"last_updated_by_id": self.last_updated_by_id,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
return self.__pydantic_model__(**state)
|
@ -1,28 +1,20 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class AgentsTags(SqlalchemyBase, OrganizationMixin):
|
||||
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""
|
||||
|
||||
class AgentsTags(Base):
|
||||
__tablename__ = "agents_tags"
|
||||
__pydantic_model__ = PydanticAgentsTags
|
||||
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)
|
||||
|
||||
# The agent associated with this tag
|
||||
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
# # agent generates its own id
|
||||
# # TODO: We want to migrate all the ORM models to do this, so we will need to move this to the SqlalchemyBase
|
||||
# # TODO: Move this in this PR? at the very end?
|
||||
# id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"agents_tags-{uuid.uuid4()}")
|
||||
|
||||
# The name of the tag
|
||||
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")
|
||||
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
tag: Mapped[str] = mapped_column(String, doc="The name of the tag associated with the agent.", primary_key=True)
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")
|
||||
# Relationships
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="tags")
|
||||
|
@ -1,16 +1,17 @@
|
||||
from typing import TYPE_CHECKING, Optional, Type
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy import JSON, BigInteger, Integer, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, attributes, mapped_column, relationship
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import Human, Persona
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm import BlocksAgents, Organization
|
||||
from letta.orm import Organization
|
||||
|
||||
|
||||
class Block(OrganizationMixin, SqlalchemyBase):
|
||||
@ -35,7 +36,6 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
|
||||
# relationships
|
||||
organization: Mapped[Optional["Organization"]] = relationship("Organization")
|
||||
blocks_agents: Mapped[list["BlocksAgents"]] = relationship("BlocksAgents", back_populates="block", cascade="all, delete")
|
||||
|
||||
def to_pydantic(self) -> Type:
|
||||
match self.label:
|
||||
@ -46,3 +46,28 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
case _:
|
||||
Schema = PydanticBlock
|
||||
return Schema.model_validate(self)
|
||||
|
||||
|
||||
@event.listens_for(Block, "after_update") # Changed from 'before_update'
|
||||
def block_before_update(mapper, connection, target):
|
||||
"""Handle updating BlocksAgents when a block's label changes."""
|
||||
label_history = attributes.get_history(target, "label")
|
||||
if not label_history.has_changes():
|
||||
return
|
||||
|
||||
blocks_agents = BlocksAgents.__table__
|
||||
connection.execute(
|
||||
blocks_agents.update()
|
||||
.where(blocks_agents.c.block_id == target.id, blocks_agents.c.block_label == label_history.deleted[0])
|
||||
.values(block_label=label_history.added[0])
|
||||
)
|
||||
|
||||
|
||||
@event.listens_for(Block, "before_insert")
|
||||
@event.listens_for(Block, "before_update")
|
||||
def validate_value_length(mapper, connection, target):
|
||||
"""Ensure the value length does not exceed the limit."""
|
||||
if target.value and len(target.value) > target.limit:
|
||||
raise ValueError(
|
||||
f"Value length ({len(target.value)}) exceeds the limit ({target.limit}) for block with label '{target.label}' and id '{target.id}'."
|
||||
)
|
||||
|
@ -1,15 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class BlocksAgents(SqlalchemyBase):
|
||||
class BlocksAgents(Base):
|
||||
"""Agents must have one or many blocks to make up their core memory."""
|
||||
|
||||
__tablename__ = "blocks_agents"
|
||||
__pydantic_model__ = PydanticBlocksAgents
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"agent_id",
|
||||
@ -17,16 +15,12 @@ class BlocksAgents(SqlalchemyBase):
|
||||
name="unique_label_per_agent",
|
||||
),
|
||||
ForeignKeyConstraint(
|
||||
["block_id", "block_label"],
|
||||
["block.id", "block.label"],
|
||||
name="fk_block_id_label",
|
||||
["block_id", "block_label"], ["block.id", "block.label"], name="fk_block_id_label", deferrable=True, initially="DEFERRED"
|
||||
),
|
||||
UniqueConstraint("agent_id", "block_id", name="unique_agent_block"),
|
||||
)
|
||||
|
||||
# unique agent + block label
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
block_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
block_label: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
# relationships
|
||||
block: Mapped["Block"] = relationship("Block", back_populates="blocks_agents")
|
||||
|
@ -59,6 +59,5 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
||||
tool_call_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="ID of the tool call")
|
||||
|
||||
# Relationships
|
||||
# TODO: Add in after Agent ORM is created
|
||||
# agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
||||
|
@ -7,6 +7,7 @@ from letta.schemas.organization import Organization as PydanticOrganization
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from letta.orm.agent import Agent
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
@ -25,7 +26,6 @@ class Organization(SqlalchemyBase):
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
|
||||
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
|
||||
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
|
||||
sandbox_configs: Mapped[List["SandboxConfig"]] = relationship(
|
||||
"SandboxConfig", back_populates="organization", cascade="all, delete-orphan"
|
||||
@ -36,10 +36,5 @@ class Organization(SqlalchemyBase):
|
||||
|
||||
# relationships
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
# agents: Mapped[List["Agent"]] = relationship("Agent", back_populates="organization", cascade="all, delete-orphan")
|
||||
# tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
# documents: Mapped[List["Document"]] = relationship("Document", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
@ -1,19 +1,18 @@
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from sqlalchemy import Column, String, DateTime, JSON, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import TypeDecorator, BINARY
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
import base64
|
||||
|
||||
from letta.orm.source import EmbeddingConfigColumn
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.orm.mixins import FileMixin, OrganizationMixin
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from sqlalchemy import JSON, Column, DateTime, ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import BINARY, TypeDecorator
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.orm.mixins import FileMixin, OrganizationMixin
|
||||
from letta.orm.source import EmbeddingConfigColumn
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig()
|
||||
@ -21,8 +20,10 @@ config = LettaConfig()
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
|
||||
class CommonVector(TypeDecorator):
|
||||
"""Common type for representing vectors in SQLite"""
|
||||
|
||||
impl = BINARY
|
||||
cache_ok = True
|
||||
|
||||
@ -43,10 +44,12 @@ class CommonVector(TypeDecorator):
|
||||
value = base64.b64decode(value)
|
||||
return np.frombuffer(value, dtype=np.float32)
|
||||
|
||||
# TODO: After migration to Passage, will need to manually delete passages where files
|
||||
|
||||
# TODO: After migration to Passage, will need to manually delete passages where files
|
||||
# are deleted on web
|
||||
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
||||
"""Defines data model for storing Passages"""
|
||||
|
||||
__tablename__ = "passages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__pydantic_model__ = PydanticPassage
|
||||
@ -59,6 +62,7 @@ class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
else:
|
||||
embedding = Column(CommonVector)
|
||||
|
@ -48,4 +48,4 @@ class Source(SqlalchemyBase, OrganizationMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="sources")
|
||||
files: Mapped[List["Source"]] = relationship("FileMetadata", back_populates="source", cascade="all, delete-orphan")
|
||||
# agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
||||
agents: Mapped[List["Agent"]] = relationship("Agent", secondary="sources_agents", back_populates="sources")
|
||||
|
13
letta/orm/sources_agents.py
Normal file
13
letta/orm/sources_agents.py
Normal file
@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class SourcesAgents(Base):
|
||||
"""Agents can have zero to many sources"""
|
||||
|
||||
__tablename__ = "sources_agents"
|
||||
|
||||
agent_id: Mapped[String] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
source_id: Mapped[String] = mapped_column(String, ForeignKey("sources.id"), primary_key=True)
|
@ -1,7 +1,6 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
||||
import sqlite3
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
@ -9,12 +8,12 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
|
||||
from letta.orm.errors import (
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
UniqueConstraintViolationError,
|
||||
)
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic import BaseModel
|
||||
@ -64,11 +63,26 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
**kwargs,
|
||||
) -> List[Type["SqlalchemyBase"]]:
|
||||
) -> List["SqlalchemyBase"]:
|
||||
"""
|
||||
List records with cursor-based pagination, ordering by created_at.
|
||||
Cursor is an ID, but pagination is based on the cursor object's created_at value.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session
|
||||
cursor: ID of the last item seen (for pagination)
|
||||
start_date: Filter items after this date
|
||||
end_date: Filter items before this date
|
||||
limit: Maximum number of items to return
|
||||
query_text: Text to search for
|
||||
query_embedding: Vector to search for similar embeddings
|
||||
ascending: Sort direction
|
||||
tags: List of tags to filter by
|
||||
match_all_tags: If True, return items matching all tags. If False, match any tag.
|
||||
**kwargs: Additional filters to apply
|
||||
"""
|
||||
if start_date and end_date and start_date > end_date:
|
||||
raise ValueError("start_date must be earlier than or equal to end_date")
|
||||
@ -84,7 +98,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
query = select(cls)
|
||||
|
||||
# Apply filtering logic
|
||||
# Handle tag filtering if the model has tags
|
||||
if tags and hasattr(cls, "tags"):
|
||||
query = select(cls)
|
||||
|
||||
if match_all_tags:
|
||||
# Match ALL tags - use subqueries
|
||||
for tag in tags:
|
||||
subquery = select(cls.tags.property.mapper.class_.agent_id).where(cls.tags.property.mapper.class_.tag == tag)
|
||||
query = query.filter(cls.id.in_(subquery))
|
||||
else:
|
||||
# Match ANY tag - use join and filter
|
||||
query = (
|
||||
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).group_by(cls.id) # Deduplicate results
|
||||
)
|
||||
|
||||
# Group by primary key and all necessary columns to avoid JSON comparison
|
||||
query = query.group_by(cls.id)
|
||||
|
||||
# Apply filtering logic from kwargs
|
||||
for key, value in kwargs.items():
|
||||
column = getattr(cls, key)
|
||||
if isinstance(value, (list, tuple, set)):
|
||||
@ -98,9 +130,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if end_date:
|
||||
query = query.filter(cls.created_at < end_date)
|
||||
|
||||
# Cursor-based pagination using created_at
|
||||
# TODO: There is a really nasty race condition issue here with Sqlite
|
||||
# TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
|
||||
# Cursor-based pagination
|
||||
if cursor_obj:
|
||||
if ascending:
|
||||
query = query.where(cls.created_at >= cursor_obj.created_at).where(
|
||||
@ -111,40 +141,34 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
|
||||
)
|
||||
|
||||
# Apply text search
|
||||
# Text search
|
||||
if query_text:
|
||||
from sqlalchemy import func
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
|
||||
# Apply embedding search (Passages)
|
||||
# Embedding search (for Passages)
|
||||
is_ordered = False
|
||||
if query_embedding:
|
||||
# check if embedding column exists. should only exist for passages
|
||||
if not hasattr(cls, "embedding"):
|
||||
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
||||
|
||||
|
||||
from letta.settings import settings
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL with pgvector
|
||||
from pgvector.sqlalchemy import Vector
|
||||
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from sqlalchemy import func
|
||||
|
||||
query_embedding_binary = adapt_array(query_embedding)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
||||
cls.created_at.asc(),
|
||||
cls.id.asc()
|
||||
func.cosine_distance(cls.embedding, query_embedding_binary).asc(), cls.created_at.asc(), cls.id.asc()
|
||||
)
|
||||
is_ordered = True
|
||||
|
||||
# Handle ordering and soft deletes
|
||||
# Handle soft deletes
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
# Apply ordering by created_at
|
||||
|
||||
# Apply ordering
|
||||
if not is_ordered:
|
||||
if ascending:
|
||||
query = query.order_by(cls.created_at, cls.id)
|
||||
@ -164,7 +188,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
||||
access_type: AccessType = AccessType.ORGANIZATION,
|
||||
**kwargs,
|
||||
) -> Type["SqlalchemyBase"]:
|
||||
) -> "SqlalchemyBase":
|
||||
"""The primary accessor for an ORM record.
|
||||
Args:
|
||||
db_session: the database session to use when retrieving the record
|
||||
@ -207,7 +231,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
raise NoResultFound(f"{cls.__name__} not found with {conditions_str}")
|
||||
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
def create(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
if actor:
|
||||
@ -221,7 +245,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
except DBAPIError as e:
|
||||
self._handle_dbapi_error(e)
|
||||
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
|
||||
if actor:
|
||||
@ -245,7 +269,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
def update(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
logger.debug(f"Updating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
@ -388,14 +412,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
raise
|
||||
|
||||
@property
|
||||
def __pydantic_model__(self) -> Type["BaseModel"]:
|
||||
def __pydantic_model__(self) -> "BaseModel":
|
||||
raise NotImplementedError("Sqlalchemy models must declare a __pydantic_model__ property to be convertable.")
|
||||
|
||||
def to_pydantic(self) -> Type["BaseModel"]:
|
||||
def to_pydantic(self) -> "BaseModel":
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
return self.__pydantic_model__.model_validate(self)
|
||||
|
||||
def to_record(self) -> Type["BaseModel"]:
|
||||
def to_record(self) -> "BaseModel":
|
||||
"""Deprecated accessor for to_pydantic"""
|
||||
logger.warning("to_record is deprecated, use to_pydantic instead.")
|
||||
return self.to_pydantic()
|
||||
return self.to_pydantic()
|
||||
|
@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint, event
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
@ -11,7 +11,6 @@ from letta.schemas.tool import Tool as PydanticTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
|
||||
|
||||
class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
@ -42,20 +41,3 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
# Add event listener to update tool_name in ToolsAgents when Tool name changes
|
||||
@event.listens_for(Tool, "before_update")
|
||||
def update_tool_name_in_tools_agents(mapper, connection, target):
|
||||
"""Update tool_name in ToolsAgents when Tool name changes."""
|
||||
state = target._sa_instance_state
|
||||
history = state.get_history("name", passive=True)
|
||||
if not history.has_changes():
|
||||
return
|
||||
|
||||
# Get the new name and update all associated ToolsAgents records
|
||||
new_name = target.name
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
|
||||
connection.execute(ToolsAgents.__table__.update().where(ToolsAgents.tool_id == target.id).values(tool_name=new_name))
|
||||
|
@ -1,32 +1,15 @@
|
||||
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||
from letta.orm import Base
|
||||
|
||||
|
||||
class ToolsAgents(SqlalchemyBase):
|
||||
class ToolsAgents(Base):
|
||||
"""Agents can have one or many tools associated with them."""
|
||||
|
||||
__tablename__ = "tools_agents"
|
||||
__pydantic_model__ = PydanticToolsAgents
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"agent_id",
|
||||
"tool_name",
|
||||
name="unique_tool_per_agent",
|
||||
),
|
||||
ForeignKeyConstraint(
|
||||
["tool_id"],
|
||||
["tools.id"],
|
||||
name="fk_tool_id",
|
||||
),
|
||||
)
|
||||
__table_args__ = (UniqueConstraint("agent_id", "tool_id", name="unique_agent_tool"),)
|
||||
|
||||
# Each agent must have unique tool names
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
# relationships
|
||||
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id", ondelete="CASCADE"), primary_key=True)
|
||||
tool_id: Mapped[str] = mapped_column(String, ForeignKey("tools.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
@ -20,10 +20,9 @@ class User(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
||||
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
|
||||
jobs: Mapped[List["Job"]] = relationship(
|
||||
"Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# TODO: Add this back later potentially
|
||||
# agents: Mapped[List["Agent"]] = relationship(
|
||||
# "Agent", secondary="users_agents", back_populates="users", doc="the agents associated with this user."
|
||||
# )
|
||||
# tokens: Mapped[List["Token"]] = relationship("Token", back_populates="user", doc="the tokens associated with this user.")
|
||||
|
@ -1,13 +1,11 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
@ -15,15 +13,15 @@ from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.utils import create_random_username
|
||||
|
||||
|
||||
class BaseAgent(LettaBase, validate_assignment=True):
|
||||
class BaseAgent(OrmMetadataBase, validate_assignment=True):
|
||||
__id_prefix__ = "agent"
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
|
||||
# metadata
|
||||
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
||||
user_id: Optional[str] = Field(None, description="The user id of the agent.")
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
@ -38,37 +36,7 @@ class AgentType(str, Enum):
|
||||
chat_only_agent = "chat_only_agent"
|
||||
|
||||
|
||||
class PersistedAgentState(BaseAgent, validate_assignment=True):
|
||||
# NOTE: this has been changed to represent the data stored in the ORM, NOT what is passed around internally or returned to the user
|
||||
id: str = BaseAgent.generate_id_field()
|
||||
name: str = Field(..., description="The name of the agent.")
|
||||
created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now)
|
||||
|
||||
# in-context memory
|
||||
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
|
||||
# tools
|
||||
# TODO: move to ORM mapping
|
||||
tool_names: List[str] = Field(..., description="The tools used by the agent.")
|
||||
|
||||
# tool rules
|
||||
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
|
||||
|
||||
# system prompt
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
|
||||
# agent configuration
|
||||
agent_type: AgentType = Field(..., description="The type of agent.")
|
||||
|
||||
# llm information
|
||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
validate_assignment = True
|
||||
|
||||
|
||||
class AgentState(PersistedAgentState):
|
||||
class AgentState(BaseAgent):
|
||||
"""
|
||||
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
|
||||
|
||||
@ -86,42 +54,53 @@ class AgentState(PersistedAgentState):
|
||||
"""
|
||||
|
||||
# NOTE: this is what is returned to the client and also what is used to initialize `Agent`
|
||||
id: str = BaseAgent.generate_id_field()
|
||||
name: str = Field(..., description="The name of the agent.")
|
||||
# tool rules
|
||||
tool_rules: Optional[List[ToolRule]] = Field(default=None, description="The list of tool rules.")
|
||||
|
||||
# in-context memory
|
||||
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
|
||||
|
||||
# system prompt
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
|
||||
# agent configuration
|
||||
agent_type: AgentType = Field(..., description="The type of agent.")
|
||||
|
||||
# llm information
|
||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
||||
|
||||
# This is an object representing the in-process state of a running `Agent`
|
||||
# Field in this object can be theoretically edited by tools, and will be persisted by the ORM
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the agent.")
|
||||
|
||||
memory: Memory = Field(..., description="The in-context memory of the agent.")
|
||||
tools: List[Tool] = Field(..., description="The tools used by the agent.")
|
||||
sources: List[Source] = Field(..., description="The sources used by the agent.")
|
||||
tags: List[str] = Field(..., description="The tags associated with the agent.")
|
||||
# TODO: add in context message objects
|
||||
|
||||
def to_persisted_agent_state(self) -> PersistedAgentState:
|
||||
# turn back into persisted agent
|
||||
data = self.model_dump()
|
||||
del data["memory"]
|
||||
del data["tools"]
|
||||
del data["sources"]
|
||||
del data["tags"]
|
||||
return PersistedAgentState(**data)
|
||||
|
||||
|
||||
class CreateAgent(BaseAgent): #
|
||||
# all optional as server can generate defaults
|
||||
name: Optional[str] = Field(None, description="The name of the agent.")
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.")
|
||||
|
||||
# memory creation
|
||||
memory_blocks: List[CreateBlock] = Field(
|
||||
# [CreateHuman(), CreatePersona()], description="The blocks to create in the agent's in-context memory."
|
||||
...,
|
||||
description="The blocks to create in the agent's in-context memory.",
|
||||
)
|
||||
|
||||
tools: List[str] = Field(BASE_TOOLS + BASE_MEMORY_TOOLS, description="The tools used by the agent.")
|
||||
# TODO: This is a legacy field and should be removed ASAP to force `tool_ids` usage
|
||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.")
|
||||
source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.")
|
||||
block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.")
|
||||
tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
|
||||
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||
agent_type: AgentType = Field(AgentType.memgpt_agent, description="The type of agent.")
|
||||
agent_type: AgentType = Field(default_factory=lambda: AgentType.memgpt_agent, description="The type of agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
# Note: if this is None, then we'll populate with the standard "more human than human" initial message sequence
|
||||
@ -129,6 +108,7 @@ class CreateAgent(BaseAgent): #
|
||||
initial_message_sequence: Optional[List[MessageCreate]] = Field(
|
||||
None, description="The initial set of messages to put in the agent's in-context memory."
|
||||
)
|
||||
include_base_tools: bool = Field(True, description="The LLM configuration used by the agent.")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@ -156,18 +136,21 @@ class CreateAgent(BaseAgent): #
|
||||
return name
|
||||
|
||||
|
||||
class UpdateAgentState(BaseAgent):
|
||||
id: str = Field(..., description="The id of the agent.")
|
||||
class UpdateAgent(BaseAgent):
|
||||
name: Optional[str] = Field(None, description="The name of the agent.")
|
||||
tool_names: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
tool_ids: Optional[List[str]] = Field(None, description="The ids of the tools used by the agent.")
|
||||
source_ids: Optional[List[str]] = Field(None, description="The ids of the sources used by the agent.")
|
||||
block_ids: Optional[List[str]] = Field(None, description="The ids of the blocks used by the agent.")
|
||||
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||
tool_rules: Optional[List[ToolRule]] = Field(None, description="The tool rules governing the agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
|
||||
# TODO: determine if these should be editable via this schema?
|
||||
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
||||
|
||||
class AgentStepResponse(BaseModel):
|
||||
messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
|
||||
|
@ -1,33 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class AgentsTagsBase(LettaBase):
|
||||
__id_prefix__ = "agents_tags"
|
||||
|
||||
|
||||
class AgentsTags(AgentsTagsBase):
|
||||
"""
|
||||
Schema representing the relationship between tags and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
tag_id (str): The ID of the associated tag.
|
||||
tag_name (str): The name of the tag.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
"""
|
||||
|
||||
id: str = AgentsTagsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
tag: str = Field(..., description="The name of the tag.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
|
||||
is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")
|
||||
|
||||
|
||||
class AgentsTagsCreate(AgentsTagsBase):
|
||||
tag: str = Field(..., description="The tag name.")
|
@ -1,32 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class BlocksAgentsBase(LettaBase):
|
||||
__id_prefix__ = "blocks_agents"
|
||||
|
||||
|
||||
class BlocksAgents(BlocksAgentsBase):
|
||||
"""
|
||||
Schema representing the relationship between blocks and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
block_id (str): The ID of the associated block.
|
||||
block_label (str): The label of the block.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
updated_at (datetime): The date this relationship was last updated.
|
||||
is_deleted (bool): Whether this block-agent relationship is deleted or not.
|
||||
"""
|
||||
|
||||
id: str = BlocksAgentsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
block_id: str = Field(..., description="The ID of the associated block.")
|
||||
block_label: str = Field(..., description="The label of the block.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||
is_deleted: bool = Field(False, description="Whether this block-agent relationship is deleted or not.")
|
@ -87,7 +87,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
Template(prompt_template)
|
||||
|
||||
# Validate compatibility with current memory structure
|
||||
test_render = Template(prompt_template).render(blocks=self.blocks)
|
||||
Template(prompt_template).render(blocks=self.blocks)
|
||||
|
||||
# If we get here, the template is valid and compatible
|
||||
self.prompt_template = prompt_template
|
||||
@ -213,6 +213,7 @@ class ChatMemory(BasicBlockMemory):
|
||||
human (str): The starter value for the human block.
|
||||
limit (int): The character limit for each block.
|
||||
"""
|
||||
# TODO: Should these be CreateBlocks?
|
||||
super().__init__(blocks=[Block(value=persona, limit=limit, label="persona"), Block(value=human, limit=limit, label="human")])
|
||||
|
||||
|
||||
|
@ -1,32 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class ToolsAgentsBase(LettaBase):
|
||||
__id_prefix__ = "tools_agents"
|
||||
|
||||
|
||||
class ToolsAgents(ToolsAgentsBase):
|
||||
"""
|
||||
Schema representing the relationship between tools and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
tool_id (str): The ID of the associated tool.
|
||||
tool_name (str): The name of the tool.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
updated_at (datetime): The date this relationship was last updated.
|
||||
is_deleted (bool): Whether this tool-agent relationship is deleted or not.
|
||||
"""
|
||||
|
||||
id: str = ToolsAgentsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
tool_id: str = Field(..., description="The ID of the associated tool.")
|
||||
tool_name: str = Field(..., description="The name of the tool.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||
is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.")
|
@ -25,9 +25,6 @@ from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.routers.openai.assistants.assistants import (
|
||||
router as openai_assistants_router,
|
||||
)
|
||||
from letta.server.rest_api.routers.openai.assistants.threads import (
|
||||
router as openai_threads_router,
|
||||
)
|
||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import (
|
||||
router as openai_chat_completions_router,
|
||||
)
|
||||
@ -215,7 +212,6 @@ def create_application() -> "FastAPI":
|
||||
|
||||
# openai
|
||||
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
|
||||
app.include_router(openai_threads_router, prefix=OPENAI_API_PREFIX)
|
||||
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
|
||||
|
||||
# /api/auth endpoints
|
||||
@ -236,7 +232,6 @@ def create_application() -> "FastAPI":
|
||||
@app.on_event("shutdown")
|
||||
def on_shutdown():
|
||||
global server
|
||||
server.save_agents()
|
||||
# server = None
|
||||
|
||||
return app
|
||||
|
@ -1,338 +0,0 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path, Query
|
||||
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.openai import (
|
||||
MessageFile,
|
||||
OpenAIMessage,
|
||||
OpenAIRun,
|
||||
OpenAIRunStep,
|
||||
OpenAIThread,
|
||||
Text,
|
||||
)
|
||||
from letta.server.rest_api.routers.openai.assistants.schemas import (
|
||||
CreateMessageRequest,
|
||||
CreateRunRequest,
|
||||
CreateThreadRequest,
|
||||
CreateThreadRunRequest,
|
||||
DeleteThreadResponse,
|
||||
ListMessagesResponse,
|
||||
ModifyMessageRequest,
|
||||
ModifyRunRequest,
|
||||
ModifyThreadRequest,
|
||||
OpenAIThread,
|
||||
SubmitToolOutputsToRunRequest,
|
||||
)
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
|
||||
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
|
||||
router = APIRouter(prefix="/v1/threads", tags=["threads"])
|
||||
|
||||
|
||||
@router.post("/", response_model=OpenAIThread)
|
||||
def create_thread(
|
||||
request: CreateThreadRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
# TODO: use requests.description and requests.metadata fields
|
||||
# TODO: handle requests.file_ids and requests.tools
|
||||
# TODO: eventually allow request to override embedding/llm model
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
print("Create thread/agent", request)
|
||||
# create a letta agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(),
|
||||
user_id=actor.id,
|
||||
)
|
||||
# TODO: insert messages into recall memory
|
||||
return OpenAIThread(
|
||||
id=str(agent_state.id),
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
metadata={}, # TODO add metadata?
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=OpenAIThread)
|
||||
def retrieve_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent = server.get_agent(user_id=actor.id, agent_id=thread_id)
|
||||
assert agent is not None
|
||||
return OpenAIThread(
|
||||
id=str(agent.id),
|
||||
created_at=int(agent.created_at.timestamp()),
|
||||
metadata={}, # TODO add metadata?
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}", response_model=OpenAIThread)
|
||||
def modify_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: ModifyThreadRequest = Body(...),
|
||||
):
|
||||
# TODO: add agent metadata so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.delete("/{thread_id}", response_model=DeleteThreadResponse)
|
||||
def delete_thread(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
):
|
||||
# TODO: delete agent
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/messages", tags=["messages"], response_model=OpenAIMessage)
|
||||
def create_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateMessageRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
agent_id = thread_id
|
||||
# create message object
|
||||
message = Message(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
role=MessageRole(request.role),
|
||||
text=request.content,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
name=None,
|
||||
)
|
||||
agent = server.load_agent(agent_id=agent_id)
|
||||
# add message to agent
|
||||
agent._append_to_messages([message])
|
||||
|
||||
openai_message = OpenAIMessage(
|
||||
id=str(message.id),
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=(message.text if message.text else ""))],
|
||||
role=message.role,
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# TODO(sarah) fill in?
|
||||
run_id=None,
|
||||
file_ids=None,
|
||||
metadata=None,
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
return openai_message
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages", tags=["messages"], response_model=ListMessagesResponse)
|
||||
def list_messages(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
limit: int = Query(1000, description="How many messages to retrieve."),
|
||||
order: str = Query("asc", description="Order of messages to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_user_or_default(user_id)
|
||||
after_uuid = after if before else None
|
||||
before_uuid = before if before else None
|
||||
agent_id = thread_id
|
||||
reverse = True if (order == "desc") else False
|
||||
json_messages = server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
limit=limit,
|
||||
after=after_uuid,
|
||||
before=before_uuid,
|
||||
order_by="created_at",
|
||||
reverse=reverse,
|
||||
)
|
||||
assert isinstance(json_messages, List)
|
||||
assert all([isinstance(message, Message) for message in json_messages])
|
||||
assert isinstance(json_messages[0], Message)
|
||||
print(json_messages[0].text)
|
||||
# convert to openai style messages
|
||||
openai_messages = []
|
||||
for message in json_messages:
|
||||
assert isinstance(message, Message)
|
||||
openai_messages.append(
|
||||
OpenAIMessage(
|
||||
id=str(message.id),
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=(message.text if message.text else ""))],
|
||||
role=str(message.role),
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# TODO(sarah) fill in?
|
||||
run_id=None,
|
||||
file_ids=None,
|
||||
metadata=None,
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
)
|
||||
print("MESSAGES", openai_messages)
|
||||
# TODO: cast back to message objects
|
||||
return ListMessagesResponse(messages=openai_messages)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
|
||||
def retrieve_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
agent_id = thread_id
|
||||
message = server.get_agent_message(agent_id=agent_id, message_id=message_id)
|
||||
assert message is not None
|
||||
return OpenAIMessage(
|
||||
id=message_id,
|
||||
created_at=int(message.created_at.timestamp()),
|
||||
content=[Text(text=(message.text if message.text else ""))],
|
||||
role=message.role,
|
||||
thread_id=str(message.agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
# TODO(sarah) fill in?
|
||||
run_id=None,
|
||||
file_ids=None,
|
||||
metadata=None,
|
||||
# file_ids=message.file_ids,
|
||||
# metadata=message.metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{thread_id}/messages/{message_id}/files/{file_id}", tags=["messages"], response_model=MessageFile)
|
||||
def retrieve_message_file(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
file_id: str = Path(..., description="The unique identifier of the file."),
|
||||
):
|
||||
# TODO: implement?
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/messages/{message_id}", tags=["messages"], response_model=OpenAIMessage)
|
||||
def modify_message(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
message_id: str = Path(..., description="The unique identifier of the message."),
|
||||
request: ModifyMessageRequest = Body(...),
|
||||
):
|
||||
# TODO: add metada field to message so this can be modified
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
request: CreateRunRequest = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
|
||||
# TODO: add request.instructions as a message?
|
||||
agent_id = thread_id
|
||||
# TODO: override preset of agent with request.assistant_id
|
||||
agent = server.load_agent(agent_id=agent_id)
|
||||
agent.inner_step(messages=[]) # already has messages added
|
||||
run_id = str(uuid.uuid4())
|
||||
create_time = int(get_utc_time().timestamp())
|
||||
return OpenAIRun(
|
||||
id=run_id,
|
||||
created_at=create_time,
|
||||
thread_id=str(agent_id),
|
||||
assistant_id=DEFAULT_PRESET, # TODO: update this
|
||||
status="completed", # TODO: eventaully allow offline execution
|
||||
expires_at=create_time,
|
||||
model=agent.agent_state.llm_config.model,
|
||||
instructions=request.instructions,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/runs", tags=["runs"], response_model=OpenAIRun)
|
||||
def create_thread_and_run(
|
||||
request: CreateThreadRunRequest = Body(...),
|
||||
):
|
||||
# TODO: add a bunch of messages and execute
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs", tags=["runs"], response_model=List[OpenAIRun])
|
||||
def list_runs(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
limit: int = Query(1000, description="How many runs to retrieve."),
|
||||
order: str = Query("asc", description="Order of runs to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/steps", tags=["runs"], response_model=List[OpenAIRunStep])
|
||||
def list_run_steps(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
limit: int = Query(1000, description="How many run steps to retrieve."),
|
||||
order: str = Query("asc", description="Order of run steps to retrieve (either 'asc' or 'desc')."),
|
||||
after: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
before: str = Query(None, description="A cursor for use in pagination. `after` is an object ID that defines your place in the list."),
|
||||
):
|
||||
# TODO: store run information in a DB so it can be returned here
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def retrieve_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.get("/{thread_id}/runs/{run_id}/steps/{step_id}", tags=["runs"], response_model=OpenAIRunStep)
|
||||
def retrieve_run_step(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
step_id: str = Path(..., description="The unique identifier of the run step."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}", tags=["runs"], response_model=OpenAIRun)
|
||||
def modify_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: ModifyRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/submit_tool_outputs", tags=["runs"], response_model=OpenAIRun)
|
||||
def submit_tool_outputs_to_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
request: SubmitToolOutputsToRunRequest = Body(...),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
||||
|
||||
|
||||
@router.post("/{thread_id}/runs/{run_id}/cancel", tags=["runs"], response_model=OpenAIRun)
|
||||
def cancel_run(
|
||||
thread_id: str = Path(..., description="The unique identifier of the thread."),
|
||||
run_id: str = Path(..., description="The unique identifier of the run."),
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not yet implemented (coming soon)")
|
@ -36,7 +36,7 @@ async def create_chat_completion(
|
||||
The bearer token will be used to identify the user.
|
||||
The 'user' field in the completion_request should be set to the agent ID.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_id = completion_request.user
|
||||
if agent_id is None:
|
||||
|
@ -17,7 +17,8 @@ from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from pydantic import Field
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import ( # , BlockLabelUpdate, BlockLimitUpdate
|
||||
Block,
|
||||
BlockUpdate,
|
||||
@ -54,23 +55,38 @@ from letta.server.server import SyncServer
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
# TODO: This should be paginated
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
def list_agents(
|
||||
name: Optional[str] = Query(None, description="Name of the agent"),
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
|
||||
match_all_tags: bool = Query(
|
||||
False,
|
||||
description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.",
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agents = server.list_agents(user_id=actor.id, tags=tags)
|
||||
# TODO: move this logic to the ORM
|
||||
if name:
|
||||
agents = [a for a in agents if a.name == name]
|
||||
# Use dictionary comprehension to build kwargs dynamically
|
||||
kwargs = {
|
||||
key: value
|
||||
for key, value in {
|
||||
"tags": tags,
|
||||
"match_all_tags": match_all_tags,
|
||||
"name": name,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
# Call list_agents with the dynamic kwargs
|
||||
agents = server.agent_manager.list_agents(actor=actor, **kwargs)
|
||||
return agents
|
||||
|
||||
|
||||
@ -83,7 +99,7 @@ def get_agent_context_window(
|
||||
"""
|
||||
Retrieve the context window of a specific agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_context_window(user_id=actor.id, agent_id=agent_id)
|
||||
|
||||
@ -106,20 +122,20 @@ def create_agent(
|
||||
"""
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.create_agent(agent, actor=actor)
|
||||
|
||||
|
||||
@router.patch("/{agent_id}", response_model=AgentState, operation_id="update_agent")
|
||||
def update_agent(
|
||||
agent_id: str,
|
||||
update_agent: UpdateAgentState = Body(...),
|
||||
update_agent: UpdateAgent = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an exsiting agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.update_agent(update_agent, actor=actor)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.update_agent(agent_id, update_agent, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/tools", response_model=List[Tool], operation_id="get_tools_from_agent")
|
||||
@ -129,7 +145,7 @@ def get_tools_from_agent(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Get tools from an existing agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_tools_from_agent(agent_id=agent_id, user_id=actor.id)
|
||||
|
||||
|
||||
@ -141,7 +157,7 @@ def add_tool_to_agent(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.add_tool_to_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
|
||||
|
||||
@ -153,7 +169,7 @@ def remove_tool_from_agent(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Add tools to an existing agent"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.remove_tool_from_agent(agent_id=agent_id, tool_id=tool_id, user_id=actor.id)
|
||||
|
||||
|
||||
@ -166,13 +182,12 @@ def get_agent_state(
|
||||
"""
|
||||
Get the state of the agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
# agent does not exist
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
return server.get_agent_state(user_id=actor.id, agent_id=agent_id)
|
||||
try:
|
||||
return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{agent_id}", response_model=AgentState, operation_id="delete_agent")
|
||||
@ -184,38 +199,37 @@ def delete_agent(
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent = server.get_agent(agent_id)
|
||||
if not agent:
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
server.delete_agent(user_id=actor.id, agent_id=agent_id)
|
||||
return agent
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
try:
|
||||
return server.agent_manager.delete_agent(agent_id=agent_id, actor=actor)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found for user_id={actor.id}.")
|
||||
|
||||
|
||||
@router.get("/{agent_id}/sources", response_model=List[Source], operation_id="get_agent_sources")
|
||||
def get_agent_sources(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Get the sources associated with an agent.
|
||||
"""
|
||||
|
||||
return server.list_attached_sources(agent_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/messages", response_model=List[Message], operation_id="list_agent_in_context_messages")
|
||||
def get_agent_in_context_messages(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the messages in the context of a specific agent.
|
||||
"""
|
||||
|
||||
return server.get_in_context_messages(agent_id=agent_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_in_context_messages(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
# TODO: remove? can also get with agent blocks
|
||||
@ -223,13 +237,15 @@ def get_agent_in_context_messages(
|
||||
def get_agent_memory(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the memory state of a specific agent.
|
||||
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_memory(agent_id=agent_id)
|
||||
return server.get_agent_memory(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="get_agent_memory_block")
|
||||
@ -242,10 +258,12 @@ def get_agent_memory_block(
|
||||
"""
|
||||
Retrieve a memory block from an agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label)
|
||||
return server.block_manager.get_block_by_id(block_id, actor=actor)
|
||||
try:
|
||||
return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/block", response_model=List[Block], operation_id="get_agent_memory_blocks")
|
||||
@ -257,9 +275,12 @@ def get_agent_memory_blocks(
|
||||
"""
|
||||
Retrieve the memory blocks of a specific agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
block_ids = server.blocks_agents_manager.list_block_ids_for_agent(agent_id=agent_id)
|
||||
return [server.block_manager.get_block_by_id(block_id, actor=actor) for block_id in block_ids]
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
try:
|
||||
agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor)
|
||||
return agent.memory.blocks
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/{agent_id}/memory/block", response_model=Memory, operation_id="add_agent_memory_block")
|
||||
@ -272,16 +293,17 @@ def add_agent_memory_block(
|
||||
"""
|
||||
Creates a memory block and links it to the agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Copied from POST /blocks
|
||||
# TODO: Should have block_manager accept only CreateBlock
|
||||
# TODO: This will be possible once we move ID creation to the ORM
|
||||
block_req = Block(**create_block.model_dump())
|
||||
block = server.block_manager.create_or_update_block(actor=actor, block=block_req)
|
||||
|
||||
# Link the block to the agent
|
||||
updated_memory = server.link_block_to_agent_memory(user_id=actor.id, agent_id=agent_id, block_id=block.id)
|
||||
|
||||
return updated_memory
|
||||
agent = server.agent_manager.attach_block(agent_id=agent_id, block_id=block.id, actor=actor)
|
||||
return agent.memory
|
||||
|
||||
|
||||
@router.delete("/{agent_id}/memory/block/{block_label}", response_model=Memory, operation_id="remove_agent_memory_block_by_label")
|
||||
@ -296,56 +318,56 @@ def remove_agent_memory_block(
|
||||
"""
|
||||
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Unlink the block from the agent
|
||||
updated_memory = server.unlink_block_from_agent_memory(user_id=actor.id, agent_id=agent_id, block_label=block_label)
|
||||
agent = server.agent_manager.detach_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
|
||||
return updated_memory
|
||||
return agent.memory
|
||||
|
||||
|
||||
@router.patch("/{agent_id}/memory/block/{block_label}", response_model=Block, operation_id="update_agent_memory_block_by_label")
|
||||
def update_agent_memory_block(
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
update_block: BlockUpdate = Body(...),
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Removes a memory block from an agent by unlnking it. If the block is not linked to any other agent, it is deleted.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# get the block_id from the label
|
||||
block_id = server.blocks_agents_manager.get_block_id_for_label(agent_id=agent_id, block_label=block_label)
|
||||
|
||||
# update the block
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor)
|
||||
block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor)
|
||||
return server.block_manager.update_block(block.id, block_update=block_update, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/recall", response_model=RecallMemorySummary, operation_id="get_agent_recall_memory_summary")
|
||||
def get_agent_recall_memory_summary(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the recall memory of a specific agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_recall_memory_summary(agent_id=agent_id)
|
||||
return server.get_recall_memory_summary(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/memory/archival", response_model=ArchivalMemorySummary, operation_id="get_agent_archival_memory_summary")
|
||||
def get_agent_archival_memory_summary(
|
||||
agent_id: str,
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Retrieve the summary of the archival memory of a specific agent.
|
||||
"""
|
||||
|
||||
return server.get_archival_memory_summary(agent_id=agent_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_archival_memory_summary(agent_id=agent_id, actor=actor)
|
||||
|
||||
|
||||
@router.get("/{agent_id}/archival", response_model=List[Passage], operation_id="list_agent_archival_memory")
|
||||
@ -360,7 +382,7 @@ def get_agent_archival_memory(
|
||||
"""
|
||||
Retrieve the memories in an agent's archival memory store (paginated query).
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO need to add support for non-postgres here
|
||||
# chroma will throw:
|
||||
@ -369,7 +391,7 @@ def get_agent_archival_memory(
|
||||
return server.get_agent_archival_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
cursor=after, # TODO: deleting before, after. is this expected?
|
||||
cursor=after, # TODO: deleting before, after. is this expected?
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
@ -384,9 +406,9 @@ def insert_agent_archival_memory(
|
||||
"""
|
||||
Insert a memory into an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.insert_archival_memory(user_id=actor.id, agent_id=agent_id, memory_contents=request.text)
|
||||
return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor)
|
||||
|
||||
|
||||
# TODO(ethan): query or path parameter for memory_id?
|
||||
@ -402,9 +424,9 @@ def delete_agent_archival_memory(
|
||||
"""
|
||||
Delete a memory from an agent's archival memory store.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_archival_memory(user_id=actor.id, agent_id=agent_id, memory_id=memory_id)
|
||||
server.delete_archival_memory(agent_id=agent_id, memory_id=memory_id, actor=actor)
|
||||
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
|
||||
|
||||
|
||||
@ -429,7 +451,7 @@ def get_agent_messages(
|
||||
"""
|
||||
Retrieve message history for an agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.get_agent_recall_cursor(
|
||||
user_id=actor.id,
|
||||
@ -449,11 +471,13 @@ def update_message(
|
||||
message_id: str,
|
||||
request: MessageUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""
|
||||
Update the details of a message associated with an agent.
|
||||
"""
|
||||
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.update_agent_message(agent_id=agent_id, message_id=message_id, request=request, actor=actor)
|
||||
|
||||
|
||||
@router.post(
|
||||
@ -471,11 +495,11 @@ async def send_message(
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
result = await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
@ -511,11 +535,11 @@ async def send_message_streaming(
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
result = await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
actor=actor,
|
||||
messages=request.messages,
|
||||
stream_steps=True,
|
||||
stream_tokens=request.stream_tokens,
|
||||
@ -531,7 +555,6 @@ async def process_message_background(
|
||||
server: SyncServer,
|
||||
actor: User,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
messages: list,
|
||||
assistant_message_tool_name: str,
|
||||
assistant_message_tool_kwarg: str,
|
||||
@ -542,7 +565,7 @@ async def process_message_background(
|
||||
result = await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=user_id,
|
||||
actor=actor,
|
||||
messages=messages,
|
||||
stream_steps=False, # NOTE(matt)
|
||||
stream_tokens=False,
|
||||
@ -585,7 +608,7 @@ async def send_message_async(
|
||||
Asynchronously process a user message and return a job ID.
|
||||
The actual processing happens in the background, and the status can be checked using the job ID.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Create a new job
|
||||
job = Job(
|
||||
@ -605,7 +628,6 @@ async def send_message_async(
|
||||
server=server,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
messages=request.messages,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
@ -618,7 +640,7 @@ async def send_message_async(
|
||||
async def send_message_to_agent(
|
||||
server: SyncServer,
|
||||
agent_id: str,
|
||||
user_id: str,
|
||||
actor: User,
|
||||
# role: MessageRole,
|
||||
messages: Union[List[Message], List[MessageCreate]],
|
||||
stream_steps: bool,
|
||||
@ -645,8 +667,7 @@ async def send_message_to_agent(
|
||||
|
||||
# Get the generator object off of the agent's streaming interface
|
||||
# This will be attached to the POST SSE request used under-the-hood
|
||||
# letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Disable token streaming if not OpenAI
|
||||
# TODO: cleanup this logic
|
||||
@ -685,7 +706,7 @@ async def send_message_to_agent(
|
||||
task = asyncio.create_task(
|
||||
asyncio.to_thread(
|
||||
server.send_messages,
|
||||
user_id=user_id,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
messages=messages,
|
||||
interface=streaming_interface,
|
||||
|
@ -1,10 +1,9 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Response
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.schemas.memory import Memory
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@ -23,7 +22,7 @@ def list_blocks(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name)
|
||||
|
||||
|
||||
@ -33,7 +32,7 @@ def create_block(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
block = Block(**create_block.model_dump())
|
||||
return server.block_manager.create_or_update_block(actor=actor, block=block)
|
||||
|
||||
@ -41,12 +40,12 @@ def create_block(
|
||||
@router.patch("/{block_id}", response_model=Block, operation_id="update_memory_block")
|
||||
def update_block(
|
||||
block_id: str,
|
||||
update_block: BlockUpdate = Body(...),
|
||||
block_update: BlockUpdate = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=update_block, actor=actor)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor)
|
||||
|
||||
|
||||
@router.delete("/{block_id}", response_model=Block, operation_id="delete_memory_block")
|
||||
@ -55,7 +54,7 @@ def delete_block(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.block_manager.delete_block(block_id=block_id, actor=actor)
|
||||
|
||||
|
||||
@ -66,7 +65,7 @@ def get_block(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
print("call get block", block_id)
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
try:
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
@ -76,7 +75,7 @@ def get_block(
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
|
||||
|
||||
@router.patch("/{block_id}/attach", response_model=Block, operation_id="link_agent_memory_block")
|
||||
@router.patch("/{block_id}/attach", response_model=None, status_code=204, operation_id="link_agent_memory_block")
|
||||
def link_agent_memory_block(
|
||||
block_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
@ -86,17 +85,16 @@ def link_agent_memory_block(
|
||||
"""
|
||||
Link a memory block to an agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block.label)
|
||||
return block
|
||||
try:
|
||||
server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
return Response(status_code=204)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/{block_id}/detach", response_model=Memory, operation_id="unlink_agent_memory_block")
|
||||
@router.patch("/{block_id}/detach", response_model=None, status_code=204, operation_id="unlink_agent_memory_block")
|
||||
def unlink_agent_memory_block(
|
||||
block_id: str,
|
||||
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
|
||||
@ -106,11 +104,10 @@ def unlink_agent_memory_block(
|
||||
"""
|
||||
Unlink a memory block from an agent
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor)
|
||||
if block is None:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
# Link the block to the agent
|
||||
server.blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)
|
||||
return block
|
||||
try:
|
||||
server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor)
|
||||
return Response(status_code=204)
|
||||
except NoResultFound as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
@ -20,7 +20,7 @@ def list_jobs(
|
||||
"""
|
||||
List all jobs.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# TODO: add filtering by status
|
||||
jobs = server.job_manager.list_jobs(actor=actor)
|
||||
@ -40,7 +40,7 @@ def list_active_jobs(
|
||||
"""
|
||||
List all active jobs.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running])
|
||||
|
||||
@ -54,7 +54,7 @@ def get_job(
|
||||
"""
|
||||
Get the status of a job.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
return server.job_manager.get_job_by_id(job_id=job_id, actor=actor)
|
||||
@ -71,7 +71,7 @@ def delete_job(
|
||||
"""
|
||||
Delete a job by its job_id.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor)
|
||||
|
@ -25,7 +25,7 @@ def create_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor)
|
||||
|
||||
@ -35,7 +35,7 @@ def create_default_e2b_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor)
|
||||
|
||||
|
||||
@ -44,7 +44,7 @@ def create_default_local_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor)
|
||||
|
||||
|
||||
@ -55,7 +55,7 @@ def update_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor)
|
||||
|
||||
|
||||
@ -65,7 +65,7 @@ def delete_sandbox_config(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor)
|
||||
|
||||
|
||||
@ -76,7 +76,7 @@ def list_sandbox_configs(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, cursor=cursor)
|
||||
|
||||
|
||||
@ -90,7 +90,7 @@ def create_sandbox_env_var(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor)
|
||||
|
||||
|
||||
@ -101,7 +101,7 @@ def update_sandbox_env_var(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor)
|
||||
|
||||
|
||||
@ -111,7 +111,7 @@ def delete_sandbox_env_var(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor)
|
||||
|
||||
|
||||
@ -123,5 +123,5 @@ def list_sandbox_env_vars(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: str = Depends(get_user_id),
|
||||
):
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, cursor=cursor)
|
||||
|
@ -36,7 +36,7 @@ def get_source(
|
||||
"""
|
||||
Get all sources
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
if not source:
|
||||
@ -53,7 +53,7 @@ def get_source_id_by_name(
|
||||
"""
|
||||
Get a source by name
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
|
||||
if not source:
|
||||
@ -69,7 +69,7 @@ def list_sources(
|
||||
"""
|
||||
List all data sources created by a user.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.list_all_sources(actor=actor)
|
||||
|
||||
@ -83,7 +83,7 @@ def create_source(
|
||||
"""
|
||||
Create a new data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
source = Source(**source_create.model_dump())
|
||||
|
||||
return server.source_manager.create_source(source=source, actor=actor)
|
||||
@ -99,7 +99,7 @@ def update_source(
|
||||
"""
|
||||
Update the name or documentation of an existing data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
|
||||
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
|
||||
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
|
||||
@ -114,7 +114,7 @@ def delete_source(
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
server.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
@ -129,7 +129,7 @@ def attach_source_to_agent(
|
||||
"""
|
||||
Attach a data source to an existing agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@ -147,7 +147,7 @@ def detach_source_from_agent(
|
||||
"""
|
||||
Detach a data source from an existing agent.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
return server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=actor.id)
|
||||
|
||||
@ -163,7 +163,7 @@ def upload_file_to_source(
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
|
||||
assert source is not None, f"Source with id={source_id} not found."
|
||||
@ -197,7 +197,7 @@ def list_passages(
|
||||
"""
|
||||
List all passages associated with a data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
|
||||
return passages
|
||||
|
||||
@ -213,7 +213,7 @@ def list_files_from_source(
|
||||
"""
|
||||
List paginated files associated with a data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.source_manager.list_files(source_id=source_id, limit=limit, cursor=cursor, actor=actor)
|
||||
|
||||
|
||||
@ -229,7 +229,7 @@ def delete_file_from_source(
|
||||
"""
|
||||
Delete a data source.
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
|
||||
if deleted_file is None:
|
||||
|
@ -25,7 +25,7 @@ def delete_tool(
|
||||
"""
|
||||
Delete a tool by name
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
|
||||
|
||||
@ -38,7 +38,7 @@ def get_tool(
|
||||
"""
|
||||
Get a tool by ID
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor)
|
||||
if tool is None:
|
||||
# return 404 error
|
||||
@ -55,7 +55,7 @@ def get_tool_id(
|
||||
"""
|
||||
Get a tool ID by name
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
tool = server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if tool:
|
||||
return tool.id
|
||||
@ -74,7 +74,7 @@ def list_tools(
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.tool_manager.list_tools(actor=actor, cursor=cursor, limit=limit)
|
||||
except Exception as e:
|
||||
# Log or print the full exception here for debugging
|
||||
@ -92,7 +92,7 @@ def create_tool(
|
||||
Create a new tool
|
||||
"""
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
tool = Tool(**request.model_dump())
|
||||
return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor)
|
||||
except UniqueConstraintViolationError as e:
|
||||
@ -124,7 +124,7 @@ def upsert_tool(
|
||||
Create or update a tool
|
||||
"""
|
||||
try:
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor)
|
||||
return tool
|
||||
except UniqueConstraintViolationError as e:
|
||||
@ -147,7 +147,7 @@ def update_tool(
|
||||
"""
|
||||
Update an existing tool
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor)
|
||||
|
||||
|
||||
@ -159,7 +159,7 @@ def add_base_tools(
|
||||
"""
|
||||
Add base tools
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.tool_manager.add_base_tools(actor=actor)
|
||||
|
||||
|
||||
@ -173,7 +173,7 @@ def add_base_tools(
|
||||
# """
|
||||
# Run an existing tool on provided arguments
|
||||
# """
|
||||
# actor = server.get_user_or_default(user_id=user_id)
|
||||
# actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
|
||||
|
||||
@ -187,7 +187,7 @@ def run_tool_from_source(
|
||||
"""
|
||||
Attempt to build a tool from source, then run it on the provided arguments
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
try:
|
||||
return server.run_tool_from_source(
|
||||
@ -220,7 +220,7 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id:
|
||||
"""
|
||||
Get a list of all Composio apps
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
composio_api_key = get_composio_key(server, actor=actor)
|
||||
return server.get_composio_apps(api_key=composio_api_key)
|
||||
|
||||
@ -234,7 +234,7 @@ def list_composio_actions_by_app(
|
||||
"""
|
||||
Get a list of all Composio actions for a specific app
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
composio_api_key = get_composio_key(server, actor=actor)
|
||||
return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key)
|
||||
|
||||
@ -248,7 +248,7 @@ def add_composio_tool(
|
||||
"""
|
||||
Add a new Composio tool by action name (Composio refers to each tool as an `Action`)
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
composio_api_key = get_composio_key(server, actor=actor)
|
||||
|
||||
try:
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -19,11 +19,6 @@ class WebSocketServer:
|
||||
self.server = SyncServer(default_interface=self.interface)
|
||||
|
||||
def shutdown_server(self):
|
||||
try:
|
||||
self.server.save_agents()
|
||||
print(f"Saved agents")
|
||||
except Exception as e:
|
||||
print(f"Saving agents failed with: {e}")
|
||||
try:
|
||||
self.interface.close()
|
||||
print(f"Closed the WS interface")
|
||||
|
405
letta/services/agent_manager.py
Normal file
405
letta/services/agent_manager.py
Normal file
@ -0,0 +1,405 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
derive_system_message,
|
||||
)
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
# Agent Manager Class
|
||||
class AgentManager:
|
||||
"""Manager class to handle business logic related to Agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
self.block_manager = BlockManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.source_manager = SourceManager()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def create_agent(
|
||||
self,
|
||||
agent_create: CreateAgent,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
system = derive_system_message(agent_type=agent_create.agent_type, system=agent_create.system)
|
||||
|
||||
# create blocks (note: cannot be linked into the agent_id is created)
|
||||
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
||||
for create_block in agent_create.memory_blocks:
|
||||
block = self.block_manager.create_or_update_block(PydanticBlock(**create_block.model_dump()), actor=actor)
|
||||
block_ids.append(block.id)
|
||||
|
||||
# TODO: Remove this block once we deprecate the legacy `tools` field
|
||||
# create passed in `tools`
|
||||
tool_names = []
|
||||
if agent_create.include_base_tools:
|
||||
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
if agent_create.tools:
|
||||
tool_names.extend(agent_create.tools)
|
||||
|
||||
tool_ids = agent_create.tool_ids or []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
if tool:
|
||||
tool_ids.append(tool.id)
|
||||
# Remove duplicates
|
||||
tool_ids = list(set(tool_ids))
|
||||
|
||||
return self._create_agent(
|
||||
name=agent_create.name,
|
||||
system=system,
|
||||
agent_type=agent_create.agent_type,
|
||||
llm_config=agent_create.llm_config,
|
||||
embedding_config=agent_create.embedding_config,
|
||||
block_ids=block_ids,
|
||||
tool_ids=tool_ids,
|
||||
source_ids=agent_create.source_ids or [],
|
||||
tags=agent_create.tags or [],
|
||||
description=agent_create.description,
|
||||
metadata_=agent_create.metadata_,
|
||||
tool_rules=agent_create.tool_rules,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
def _create_agent(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
name: str,
|
||||
system: str,
|
||||
agent_type: AgentType,
|
||||
llm_config: LLMConfig,
|
||||
embedding_config: EmbeddingConfig,
|
||||
block_ids: List[str],
|
||||
tool_ids: List[str],
|
||||
source_ids: List[str],
|
||||
tags: List[str],
|
||||
description: Optional[str] = None,
|
||||
metadata_: Optional[Dict] = None,
|
||||
tool_rules: Optional[List[PydanticToolRule]] = None,
|
||||
) -> PydanticAgentState:
|
||||
"""Create a new agent."""
|
||||
with self.session_maker() as session:
|
||||
# Prepare the agent data
|
||||
data = {
|
||||
"name": name,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config,
|
||||
"embedding_config": embedding_config,
|
||||
"organization_id": actor.organization_id,
|
||||
"description": description,
|
||||
"metadata_": metadata_,
|
||||
"tool_rules": tool_rules,
|
||||
}
|
||||
|
||||
# Create the new agent using SqlalchemyBase.create
|
||||
new_agent = AgentModel(**data)
|
||||
_process_relationship(session, new_agent, "tools", ToolModel, tool_ids, replace=True)
|
||||
_process_relationship(session, new_agent, "sources", SourceModel, source_ids, replace=True)
|
||||
_process_relationship(session, new_agent, "core_memory", BlockModel, block_ids, replace=True)
|
||||
_process_tags(new_agent, tags, replace=True)
|
||||
new_agent.create(session, actor=actor)
|
||||
|
||||
# Convert to PydanticAgentState and return
|
||||
return new_agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update an existing agent.
|
||||
|
||||
Args:
|
||||
agent_id: The ID of the agent to update.
|
||||
agent_update: UpdateAgent object containing the updated fields.
|
||||
actor: User performing the action.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The updated agent as a Pydantic model.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the existing agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Update scalar fields directly
|
||||
scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata_"}
|
||||
for field in scalar_fields:
|
||||
value = getattr(agent_update, field, None)
|
||||
if value is not None:
|
||||
setattr(agent, field, value)
|
||||
|
||||
# Update relationships using _process_relationship and _process_tags
|
||||
if agent_update.tool_ids is not None:
|
||||
_process_relationship(session, agent, "tools", ToolModel, agent_update.tool_ids, replace=True)
|
||||
if agent_update.source_ids is not None:
|
||||
_process_relationship(session, agent, "sources", SourceModel, agent_update.source_ids, replace=True)
|
||||
if agent_update.block_ids is not None:
|
||||
_process_relationship(session, agent, "core_memory", BlockModel, agent_update.block_ids, replace=True)
|
||||
if agent_update.tags is not None:
|
||||
_process_tags(agent, agent_update.tags, replace=True)
|
||||
|
||||
# Commit and refresh the agent
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# Convert to PydanticAgentState and return
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_agents(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
tags: Optional[List[str]] = None,
|
||||
match_all_tags: bool = False,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
**kwargs,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
List agents that have the specified tags.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
agents = AgentModel.list(
|
||||
db_session=session,
|
||||
tags=tags,
|
||||
match_all_tags=match_all_tags,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return [agent.to_pydantic() for agent in agents]
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_agent(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Deletes an agent and its associated relationships.
|
||||
Ensures proper permission checks and cascades where applicable.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to be deleted.
|
||||
actor: User performing the action.
|
||||
|
||||
Returns:
|
||||
PydanticAgentState: The deleted agent state
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
||||
# TODO: This is done very hacky on purpose
|
||||
# TODO: 1000 limit is also wack
|
||||
passage_manager = PassageManager()
|
||||
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
||||
|
||||
agent_state = agent.to_pydantic()
|
||||
agent.hard_delete(session)
|
||||
return agent_state
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Attaches a source to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to attach the source to
|
||||
source_id: ID of the source to attach
|
||||
actor: User performing the action
|
||||
|
||||
Raises:
|
||||
ValueError: If either agent or source doesn't exist
|
||||
IntegrityError: If the source is already attached to the agent
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify both agent and source exist and user has permission to access them
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# The _process_relationship helper already handles duplicate checking via unique constraint
|
||||
_process_relationship(
|
||||
session=session,
|
||||
agent=agent,
|
||||
relationship_name="sources",
|
||||
model_class=SourceModel,
|
||||
item_ids=[source_id],
|
||||
allow_partial=False,
|
||||
replace=False, # Extend existing sources rather than replace
|
||||
)
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
|
||||
"""
|
||||
Lists all sources attached to an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to list sources for
|
||||
actor: User performing the action
|
||||
|
||||
Returns:
|
||||
List[str]: List of source IDs attached to the agent
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Use the lazy-loaded relationship to get sources
|
||||
return [source.to_pydantic() for source in agent.sources]
|
||||
|
||||
@enforce_types
|
||||
def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
Detaches a source from an agent.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent to detach the source from
|
||||
source_id: ID of the source to detach
|
||||
actor: User performing the action
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify agent exists and user has permission to access it
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# Remove the source from the relationship
|
||||
agent.sources = [s for s in agent.sources if s.id != source_id]
|
||||
|
||||
# Commit the changes
|
||||
agent.update(session, actor=actor)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Block management
|
||||
# ======================================================================================================================
|
||||
@enforce_types
|
||||
def get_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticBlock:
|
||||
"""Gets a block attached to an agent by its label."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
for block in agent.core_memory:
|
||||
if block.label == block_label:
|
||||
return block.to_pydantic()
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")
|
||||
|
||||
@enforce_types
|
||||
def update_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
new_block_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
"""Updates which block is assigned to a specific label for an agent."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)
|
||||
|
||||
if new_block.label != block_label:
|
||||
raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'")
|
||||
|
||||
# Remove old block with this label if it exists
|
||||
agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
|
||||
|
||||
# Add new block
|
||||
agent.core_memory.append(new_block)
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Attaches a block to an agent."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
|
||||
agent.core_memory.append(block)
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def detach_block(
|
||||
self,
|
||||
agent_id: str,
|
||||
block_id: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
"""Detaches a block from an agent."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
original_length = len(agent.core_memory)
|
||||
|
||||
agent.core_memory = [b for b in agent.core_memory if b.id != block_id]
|
||||
|
||||
if len(agent.core_memory) == original_length:
|
||||
raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")
|
||||
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def detach_block_with_label(
|
||||
self,
|
||||
agent_id: str,
|
||||
block_label: str,
|
||||
actor: PydanticUser,
|
||||
) -> PydanticAgentState:
|
||||
"""Detaches a block with the specified label from an agent."""
|
||||
with self.session_maker() as session:
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
original_length = len(agent.core_memory)
|
||||
|
||||
agent.core_memory = [b for b in agent.core_memory if b.label != block_label]
|
||||
|
||||
if len(agent.core_memory) == original_length:
|
||||
raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'")
|
||||
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
@ -1,64 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags as AgentsTagsModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class AgentsTagsManager:
|
||||
"""Manager class to handle business logic related to Tags."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags:
|
||||
"""Add a tag to an agent."""
|
||||
with self.session_maker() as session:
|
||||
# Check if the tag already exists for this agent
|
||||
try:
|
||||
agents_tags_model = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
|
||||
return agents_tags_model.to_pydantic()
|
||||
except NoResultFound:
|
||||
agents_tags = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True)
|
||||
new_tag = AgentsTagsModel(**agents_tags, organization_id=actor.organization_id)
|
||||
new_tag.create(session, actor=actor)
|
||||
return new_tag.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser):
|
||||
"""Delete a tag from an agent. This is a permanent hard delete."""
|
||||
tags = self.get_tags_for_agent(agent_id=agent_id, actor=actor)
|
||||
for tag in tags:
|
||||
self.delete_tag_from_agent(agent_id=agent_id, tag=tag, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser):
|
||||
"""Delete a tag from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Retrieve and delete the tag association
|
||||
tag_association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
|
||||
tag_association.hard_delete(session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]:
|
||||
"""Retrieve all agent IDs associated with a specific tag."""
|
||||
with self.session_maker() as session:
|
||||
# Query for all agents with the given tag
|
||||
agents_with_tag = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id)
|
||||
return [record.agent_id for record in agents_with_tag]
|
||||
|
||||
@enforce_types
|
||||
def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""Retrieve all tags associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
# Query for all tags associated with the given agent
|
||||
tags_for_agent = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id)
|
||||
return [record.tag for record in tags_for_agent]
|
@ -7,7 +7,6 @@ from letta.schemas.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, Human, Persona
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.utils import enforce_types, list_human_files, list_persona_files
|
||||
|
||||
|
||||
@ -37,33 +36,17 @@ class BlockManager:
|
||||
@enforce_types
|
||||
def update_block(self, block_id: str, block_update: BlockUpdate, actor: PydanticUser) -> PydanticBlock:
|
||||
"""Update a block by its ID with the given BlockUpdate object."""
|
||||
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
|
||||
blocks_agents_manager = BlocksAgentsManager()
|
||||
agent_ids = []
|
||||
if block_update.label:
|
||||
agent_ids = blocks_agents_manager.list_agent_ids_with_block(block_id=block_id)
|
||||
for agent_id in agent_ids:
|
||||
blocks_agents_manager.remove_block_with_id_from_agent(agent_id=agent_id, block_id=block_id)
|
||||
# Safety check for block
|
||||
|
||||
with self.session_maker() as session:
|
||||
# Update block
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
update_data = block_update.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(block, key, value)
|
||||
try:
|
||||
block.to_pydantic()
|
||||
except Exception as e:
|
||||
# invalid pydantic model
|
||||
raise ValueError(f"Failed to create pydantic model: {e}")
|
||||
|
||||
block.update(db_session=session, actor=actor)
|
||||
|
||||
# TODO: REMOVE THIS ONCE AGENT IS ON ORM -> Update blocks_agents
|
||||
if block_update.label:
|
||||
for agent_id in agent_ids:
|
||||
blocks_agents_manager.add_block_to_agent(agent_id=agent_id, block_id=block_id, block_label=block_update.label)
|
||||
|
||||
return block.to_pydantic()
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
|
||||
@ -111,6 +94,15 @@ class BlockManager:
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
|
||||
# TODO: We can do this much more efficiently by listing, instead of executing individual queries per block_id
|
||||
blocks = []
|
||||
for block_id in block_ids:
|
||||
block = self.get_block_by_id(block_id, actor=actor)
|
||||
blocks.append(block)
|
||||
return blocks
|
||||
|
||||
@enforce_types
|
||||
def add_default_blocks(self, actor: PydanticUser):
|
||||
for persona_file in list_persona_files():
|
||||
|
@ -1,106 +0,0 @@
|
||||
import warnings
|
||||
from typing import List
|
||||
|
||||
from letta.orm.blocks_agents import BlocksAgents as BlocksAgentsModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.blocks_agents import BlocksAgents as PydanticBlocksAgents
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
# TODO: DELETE THIS ASAP
|
||||
# TODO: So we have a patch where we manually specify CRUD operations
|
||||
# TODO: This is because Agent is NOT migrated to the ORM yet
|
||||
# TODO: Once we migrate Agent to the ORM, we should deprecate any agents relationship table managers
|
||||
class BlocksAgentsManager:
|
||||
"""Manager class to handle business logic related to Blocks and Agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def add_block_to_agent(self, agent_id: str, block_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||
"""Add a block to an agent. If the label already exists on that agent, this will error."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Check if the block-label combination already exists for this agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
warnings.warn(f"Block label '{block_label}' already exists for agent '{agent_id}'.")
|
||||
except NoResultFound:
|
||||
blocks_agents_record = PydanticBlocksAgents(agent_id=agent_id, block_id=block_id, block_label=block_label)
|
||||
blocks_agents_record = BlocksAgentsModel(**blocks_agents_record.model_dump(exclude_none=True))
|
||||
blocks_agents_record.create(session)
|
||||
|
||||
return blocks_agents_record.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def remove_block_with_label_from_agent(self, agent_id: str, block_label: str) -> PydanticBlocksAgents:
|
||||
"""Remove a block with a label from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the block-label association for the agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
blocks_agents_record.hard_delete(session)
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def remove_block_with_id_from_agent(self, agent_id: str, block_id: str) -> PydanticBlocksAgents:
|
||||
"""Remove a block with a label from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the block-label association for the agent
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_id=block_id)
|
||||
blocks_agents_record.hard_delete(session)
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block id '{block_id}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def update_block_id_for_agent(self, agent_id: str, block_label: str, new_block_id: str) -> PydanticBlocksAgents:
|
||||
"""Update the block ID for a specific block label for an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
blocks_agents_record.block_id = new_block_id
|
||||
return blocks_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def list_block_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all block ids associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.block_id for record in blocks_agents_record]
|
||||
|
||||
@enforce_types
|
||||
def list_block_labels_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all block labels associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
blocks_agents_record = BlocksAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.block_label for record in blocks_agents_record]
|
||||
|
||||
@enforce_types
|
||||
def list_agent_ids_with_block(self, block_id: str) -> List[str]:
|
||||
"""List all agents associated with a specific block."""
|
||||
with self.session_maker() as session:
|
||||
blocks_agents_record = BlocksAgentsModel.list(db_session=session, block_id=block_id)
|
||||
return [record.agent_id for record in blocks_agents_record]
|
||||
|
||||
@enforce_types
|
||||
def get_block_id_for_label(self, agent_id: str, block_label: str) -> str:
|
||||
"""Get the block ID for a specific block label for an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
blocks_agents_record = BlocksAgentsModel.read(db_session=session, agent_id=agent_id, block_label=block_label)
|
||||
return blocks_agents_record.block_id
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Block label '{block_label}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def remove_all_agent_blocks(self, agent_id: str):
|
||||
for block_id in self.list_block_ids_for_agent(agent_id):
|
||||
self.remove_block_with_id_from_agent(agent_id, block_id)
|
90
letta/services/helpers/agent_manager_helper.py
Normal file
90
letta/services/helpers/agent_manager_helper.py
Normal file
@ -0,0 +1,90 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.prompts import gpt_system
|
||||
from letta.schemas.agent import AgentType
|
||||
|
||||
|
||||
# Static methods
|
||||
def _process_relationship(
|
||||
session, agent: AgentModel, relationship_name: str, model_class, item_ids: List[str], allow_partial=False, replace=True
|
||||
):
|
||||
"""
|
||||
Generalized function to handle relationships like tools, sources, and blocks using item IDs.
|
||||
|
||||
Args:
|
||||
session: The database session.
|
||||
agent: The AgentModel instance.
|
||||
relationship_name: The name of the relationship attribute (e.g., 'tools', 'sources').
|
||||
model_class: The ORM class corresponding to the related items.
|
||||
item_ids: List of IDs to set or update.
|
||||
allow_partial: If True, allows missing items without raising errors.
|
||||
replace: If True, replaces the entire relationship; otherwise, extends it.
|
||||
|
||||
Raises:
|
||||
ValueError: If `allow_partial` is False and some IDs are missing.
|
||||
"""
|
||||
current_relationship = getattr(agent, relationship_name, [])
|
||||
if not item_ids:
|
||||
if replace:
|
||||
setattr(agent, relationship_name, [])
|
||||
return
|
||||
|
||||
# Retrieve models for the provided IDs
|
||||
found_items = session.query(model_class).filter(model_class.id.in_(item_ids)).all()
|
||||
|
||||
# Validate all items are found if allow_partial is False
|
||||
if not allow_partial and len(found_items) != len(item_ids):
|
||||
missing = set(item_ids) - {item.id for item in found_items}
|
||||
raise NoResultFound(f"Items not found in {relationship_name}: {missing}")
|
||||
|
||||
if replace:
|
||||
# Replace the relationship
|
||||
setattr(agent, relationship_name, found_items)
|
||||
else:
|
||||
# Extend the relationship (only add new items)
|
||||
current_ids = {item.id for item in current_relationship}
|
||||
new_items = [item for item in found_items if item.id not in current_ids]
|
||||
current_relationship.extend(new_items)
|
||||
|
||||
|
||||
def _process_tags(agent: AgentModel, tags: List[str], replace=True):
|
||||
"""
|
||||
Handles tags for an agent.
|
||||
|
||||
Args:
|
||||
agent: The AgentModel instance.
|
||||
tags: List of tags to set or update.
|
||||
replace: If True, replaces all tags; otherwise, extends them.
|
||||
"""
|
||||
if not tags:
|
||||
if replace:
|
||||
agent.tags = []
|
||||
return
|
||||
|
||||
# Ensure tags are unique and prepare for replacement/extension
|
||||
new_tags = {AgentsTags(agent_id=agent.id, tag=tag) for tag in set(tags)}
|
||||
if replace:
|
||||
agent.tags = list(new_tags)
|
||||
else:
|
||||
existing_tags = {t.tag for t in agent.tags}
|
||||
agent.tags.extend([tag for tag in new_tags if tag.tag not in existing_tags])
|
||||
|
||||
|
||||
def derive_system_message(agent_type: AgentType, system: Optional[str] = None):
|
||||
if system is None:
|
||||
# TODO: don't hardcode
|
||||
if agent_type == AgentType.memgpt_agent:
|
||||
system = gpt_system.get_system_text("memgpt_chat")
|
||||
elif agent_type == AgentType.o1_agent:
|
||||
system = gpt_system.get_system_text("memgpt_modified_o1")
|
||||
elif agent_type == AgentType.offline_memory_agent:
|
||||
system = gpt_system.get_system_text("memgpt_offline_memory")
|
||||
elif agent_type == AgentType.chat_only_agent:
|
||||
system = gpt_system.get_system_text("memgpt_convo_only")
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type: {agent_type}")
|
||||
|
||||
return system
|
@ -1,25 +1,25 @@
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.utils import enforce_types
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.passage import Passage as PassageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
@ -43,20 +43,20 @@ class PassageManager:
|
||||
return [self.create_passage(p, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
def insert_passage(self,
|
||||
def insert_passage(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
agent_id: str,
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
return_ids: bool = False
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
) -> List[PydanticPassage]:
|
||||
""" Insert passage(s) into archival memory """
|
||||
"""Insert passage(s) into archival memory"""
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
embed_model = embedding_model(agent_state.embedding_config)
|
||||
|
||||
passages = []
|
||||
|
||||
|
||||
try:
|
||||
# breakup string into passages
|
||||
for text in parse_and_chunk_text(text, embedding_chunk_size):
|
||||
@ -75,12 +75,12 @@ class PassageManager:
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config
|
||||
embedding_config=agent_state.embedding_config,
|
||||
),
|
||||
actor=actor
|
||||
actor=actor,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
@ -125,20 +125,21 @@ class PassageManager:
|
||||
raise ValueError(f"Passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def list_passages(self,
|
||||
actor : PydanticUser,
|
||||
agent_id : Optional[str] = None,
|
||||
file_id : Optional[str] = None,
|
||||
cursor : Optional[str] = None,
|
||||
limit : Optional[int] = 50,
|
||||
query_text : Optional[str] = None,
|
||||
start_date : Optional[datetime] = None,
|
||||
end_date : Optional[datetime] = None,
|
||||
ascending : bool = True,
|
||||
source_id : Optional[str] = None,
|
||||
embed_query : bool = False,
|
||||
embedding_config: Optional[EmbeddingConfig] = None
|
||||
) -> List[PydanticPassage]:
|
||||
def list_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
ascending: bool = True,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""List passages with pagination."""
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
@ -148,7 +149,7 @@ class PassageManager:
|
||||
filters["file_id"] = file_id
|
||||
if source_id:
|
||||
filters["source_id"] = source_id
|
||||
|
||||
|
||||
embedded_text = None
|
||||
if embed_query:
|
||||
assert embedding_config is not None
|
||||
@ -161,7 +162,7 @@ class PassageManager:
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
results = PassageModel.list(
|
||||
db_session=session,
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
@ -169,17 +170,12 @@ class PassageManager:
|
||||
ascending=ascending,
|
||||
query_text=query_text if not embedded_text else None,
|
||||
query_embedding=embedded_text,
|
||||
**filters
|
||||
**filters,
|
||||
)
|
||||
return [p.to_pydantic() for p in results]
|
||||
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor : PydanticUser,
|
||||
agent_id : Optional[str] = None,
|
||||
**kwargs
|
||||
) -> int:
|
||||
def size(self, actor: PydanticUser, agent_id: Optional[str] = None, **kwargs) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
|
||||
Args:
|
||||
@ -189,28 +185,32 @@ class PassageManager:
|
||||
with self.session_maker() as session:
|
||||
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
|
||||
|
||||
def delete_passages(self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
cursor: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
source_id: Optional[str] = None
|
||||
) -> bool:
|
||||
|
||||
def delete_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
cursor: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
|
||||
passages = self.list_passages(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
cursor=cursor,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_text=query_text,
|
||||
source_id=source_id)
|
||||
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_text=query_text,
|
||||
source_id=source_id,
|
||||
)
|
||||
|
||||
# TODO: This is very inefficient
|
||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
|
@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
from letta.orm.source import Source as SourceModel
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
@ -60,7 +61,7 @@ class SourceManager:
|
||||
"""Delete a source by its ID."""
|
||||
with self.session_maker() as session:
|
||||
source = SourceModel.read(db_session=session, identifier=source_id)
|
||||
source.delete(db_session=session, actor=actor)
|
||||
source.hard_delete(db_session=session, actor=actor)
|
||||
return source.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@ -76,6 +77,26 @@ class SourceManager:
|
||||
)
|
||||
return [source.to_pydantic() for source in sources]
|
||||
|
||||
@enforce_types
|
||||
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Lists all agents that have the specified source attached.
|
||||
|
||||
Args:
|
||||
source_id: ID of the source to find attached agents for
|
||||
actor: User performing the action (optional for now, following existing pattern)
|
||||
|
||||
Returns:
|
||||
List[PydanticAgentState]: List of agents that have this source attached
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
# Verify source exists and user has permission to access it
|
||||
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
|
||||
|
||||
# The agents relationship is already loaded due to lazy="selectin" in the Source model
|
||||
# and will be properly filtered by organization_id due to the OrganizationMixin
|
||||
return [agent.to_pydantic() for agent in source.agents]
|
||||
|
||||
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
|
||||
@enforce_types
|
||||
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
|
||||
|
@ -1,94 +0,0 @@
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel
|
||||
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||
|
||||
class ToolsAgentsManager:
|
||||
"""Manages the relationship between tools and agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
self.session_maker = db_context
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents:
|
||||
"""Add a tool to an agent.
|
||||
|
||||
When a tool is added to an agent, it will be added to all agents in the same organization.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Check if the tool-agent combination already exists for this agent
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.")
|
||||
except NoResultFound:
|
||||
tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name)
|
||||
tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True))
|
||||
tools_agents_record.create(session)
|
||||
|
||||
return tools_agents_record.to_pydantic()
|
||||
|
||||
def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None:
|
||||
"""Remove a tool from an agent by its name.
|
||||
|
||||
When a tool is removed from an agent, it will be removed from all agents in the same organization.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the tool-agent association for the agent
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
tools_agents_record.hard_delete(session)
|
||||
return tools_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||
|
||||
def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents:
|
||||
"""Remove a tool with an ID from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id)
|
||||
tools_agents_record.hard_delete(session)
|
||||
return tools_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.")
|
||||
|
||||
def list_tool_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all tool IDs associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.tool_id for record in tools_agents_record]
|
||||
|
||||
def list_tool_names_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all tool names associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.tool_name for record in tools_agents_record]
|
||||
|
||||
def list_agent_ids_with_tool(self, tool_id: str) -> List[str]:
|
||||
"""List all agents associated with a specific tool."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id)
|
||||
return [record.agent_id for record in tools_agents_record]
|
||||
|
||||
def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str:
|
||||
"""Get the tool ID for a specific tool name for an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
return tools_agents_record.tool_id
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||
|
||||
def remove_all_agent_tools(self, agent_id: str) -> None:
|
||||
"""Remove all tools associated with an agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
for record in tools_agents_records:
|
||||
record.hard_delete(session)
|
@ -73,12 +73,6 @@ class UserManager:
|
||||
user = UserModel.read(db_session=session, identifier=user_id)
|
||||
user.hard_delete(session)
|
||||
|
||||
# TODO: Integrate this via the ORM models for the Agent, Source, and AgentSourceMapping
|
||||
# Cascade delete for related models: Agent, Source, AgentSourceMapping
|
||||
# session.query(AgentModel).filter(AgentModel.user_id == user_id).delete()
|
||||
# session.query(SourceModel).filter(SourceModel.user_id == user_id).delete()
|
||||
# session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.user_id == user_id).delete()
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
@ -93,6 +87,17 @@ class UserManager:
|
||||
"""Fetch the default user."""
|
||||
return self.get_user_by_id(self.DEFAULT_USER_ID)
|
||||
|
||||
@enforce_types
|
||||
def get_user_or_default(self, user_id: Optional[str] = None):
|
||||
"""Fetch the user or default user."""
|
||||
if not user_id:
|
||||
return self.get_default_user()
|
||||
|
||||
try:
|
||||
return self.get_user_by_id(user_id=user_id)
|
||||
except NoResultFound:
|
||||
return self.get_default_user()
|
||||
|
||||
@enforce_types
|
||||
def list_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> Tuple[Optional[str], List[PydanticUser]]:
|
||||
"""List users with pagination using cursor (id) and limit."""
|
||||
|
@ -548,13 +548,13 @@ def enforce_types(func):
|
||||
for arg_name, arg_value in args_with_hints.items():
|
||||
hint = hints.get(arg_name)
|
||||
if hint and not matches_type(arg_value, hint):
|
||||
raise ValueError(f"Argument {arg_name} does not match type {hint}")
|
||||
raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")
|
||||
|
||||
# Check types of keyword arguments
|
||||
for arg_name, arg_value in kwargs.items():
|
||||
hint = hints.get(arg_name)
|
||||
if hint and not matches_type(arg_value, hint):
|
||||
raise ValueError(f"Argument {arg_name} does not match type {hint}")
|
||||
raise ValueError(f"Argument {arg_name} does not match type {hint}; is {arg_value}")
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
@ -4,7 +4,7 @@ import string
|
||||
from locust import HttpUser, between, task
|
||||
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.schemas.agent import CreateAgent, PersistedAgentState
|
||||
from letta.schemas.agent import AgentState, CreateAgent
|
||||
from letta.schemas.letta_request import LettaRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.memory import ChatMemory
|
||||
@ -49,7 +49,7 @@ class LettaUser(HttpUser):
|
||||
response.failure(f"Failed to create agent: {response.text}")
|
||||
|
||||
response_json = response.json()
|
||||
agent_state = PersistedAgentState(**response_json)
|
||||
agent_state = AgentState(**response_json)
|
||||
self.agent_id = agent_state.id
|
||||
print("Created agent", self.agent_id, agent_state.name)
|
||||
|
||||
|
@ -1,90 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import MetaData, Table, create_engine
|
||||
|
||||
from letta import create_client
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_types import AgentState, EmbeddingConfig, LLMConfig
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.presets.presets import add_default_tools
|
||||
from letta.prompts import gpt_system
|
||||
|
||||
# Replace this with your actual database connection URL
|
||||
config = LettaConfig.load()
|
||||
if config.recall_storage_type == "sqlite":
|
||||
DATABASE_URL = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
||||
else:
|
||||
DATABASE_URL = config.recall_storage_uri
|
||||
print(DATABASE_URL)
|
||||
engine = create_engine(DATABASE_URL)
|
||||
metadata = MetaData()
|
||||
|
||||
# defaults
|
||||
system_prompt = gpt_system.get_system_text("memgpt_chat")
|
||||
|
||||
# Reflect the existing table
|
||||
table = Table("agents", metadata, autoload_with=engine)
|
||||
|
||||
|
||||
# get all agent rows
|
||||
agent_states = []
|
||||
with engine.connect() as conn:
|
||||
agents = conn.execute(table.select()).fetchall()
|
||||
for agent in agents:
|
||||
id = uuid.UUID(agent[0])
|
||||
user_id = uuid.UUID(agent[1])
|
||||
name = agent[2]
|
||||
print(f"Migrating agent {name}")
|
||||
persona = agent[3]
|
||||
human = agent[4]
|
||||
system = agent[5]
|
||||
preset = agent[6]
|
||||
created_at = agent[7]
|
||||
llm_config = LLMConfig(**agent[8])
|
||||
embedding_config = EmbeddingConfig(**agent[9])
|
||||
state = agent[10]
|
||||
tools = agent[11]
|
||||
|
||||
state["memory"] = {"human": {"value": human, "limit": 2000}, "persona": {"value": persona, "limit": 2000}}
|
||||
|
||||
agent_state = AgentState(
|
||||
id=id,
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
system=system,
|
||||
created_at=created_at,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=state,
|
||||
tools=tools,
|
||||
_metadata={"human": "migrated", "persona": "migrated"},
|
||||
)
|
||||
|
||||
agent_states.append(agent_state)
|
||||
|
||||
# remove agents table
|
||||
agents_model = Table("agents", metadata, autoload_with=engine)
|
||||
agents_model.drop(engine)
|
||||
|
||||
# remove tool table
|
||||
tool_model = Table("toolmodel", metadata, autoload_with=engine)
|
||||
tool_model.drop(engine)
|
||||
|
||||
# re-create tables and add default tools
|
||||
ms = MetadataStore(config)
|
||||
add_default_tools(None, ms)
|
||||
print("Tools", [tool.name for tool in ms.list_tools()])
|
||||
|
||||
|
||||
for agent in agent_states:
|
||||
ms.create_agent(agent)
|
||||
print(f"Agent {agent.name} migrated successfully!")
|
||||
|
||||
# add another agent to create core memory tool
|
||||
client = create_client()
|
||||
dummy_agent = client.create_agent(name="dummy_agent")
|
||||
tools = client.list_tools()
|
||||
assert "core_memory_append" in [tool.name for tool in tools]
|
||||
|
||||
print("Migration completed successfully!")
|
@ -2,8 +2,6 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@ -11,6 +9,8 @@ def pytest_configure(config):
|
||||
|
||||
@pytest.fixture
|
||||
def mock_e2b_api_key_none():
|
||||
from letta.settings import tool_settings
|
||||
|
||||
# Store the original value of e2b_api_key
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
|
||||
|
@ -61,7 +61,7 @@ def setup_agent(
|
||||
filename: str,
|
||||
memory_human_str: str = get_human_text(DEFAULT_HUMAN),
|
||||
memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
|
||||
tools: Optional[List[str]] = None,
|
||||
tool_ids: Optional[List[str]] = None,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
agent_uuid: str = agent_uuid,
|
||||
) -> AgentState:
|
||||
@ -77,7 +77,7 @@ def setup_agent(
|
||||
|
||||
memory = ChatMemory(human=memory_human_str, persona=memory_persona_str)
|
||||
agent_state = client.create_agent(
|
||||
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tools=tools, tool_rules=tool_rules
|
||||
name=agent_uuid, llm_config=llm_config, embedding_config=embedding_config, memory=memory, tool_ids=tool_ids, tool_rules=tool_rules
|
||||
)
|
||||
|
||||
return agent_state
|
||||
@ -103,7 +103,6 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
agent_state = setup_agent(client, filename)
|
||||
|
||||
tools = [client.get_tool(client.get_tool_id(name=name)) for name in agent_state.tool_names]
|
||||
full_agent_state = client.get_agent(agent_state.id)
|
||||
agent = Agent(agent_state=full_agent_state, interface=None, user=client.user)
|
||||
|
||||
@ -171,19 +170,18 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
|
||||
tool_name = tool.name
|
||||
|
||||
# Set up persona for tool usage
|
||||
persona = f"""
|
||||
|
||||
My name is Letta.
|
||||
|
||||
I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool_name} which will search `example.com` and answer the relevant question.
|
||||
I am a personal assistant who answers a user's questions about a website `example.com`. When a user asks me a question about `example.com`, I will use a tool called {tool.name} which will search `example.com` and answer the relevant question.
|
||||
|
||||
Don’t forget - inner monologue / inner thoughts should always be different than the contents of send_message! send_message is how you communicate with the user, whereas inner thoughts are your own personal inner thoughts.
|
||||
"""
|
||||
|
||||
agent_state = setup_agent(client, filename, memory_persona_str=persona, tools=[tool_name])
|
||||
agent_state = setup_agent(client, filename, memory_persona_str=persona, tool_ids=[tool.id])
|
||||
|
||||
response = client.user_message(agent_id=agent_state.id, message="What's on the example.com website?")
|
||||
|
||||
@ -191,7 +189,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Make sure the tool was called
|
||||
assert_invoked_function_call(response.messages, tool_name)
|
||||
assert_invoked_function_call(response.messages, tool.name)
|
||||
|
||||
# Make sure some inner monologue is present
|
||||
assert_inner_monologue_is_present_and_valid(response.messages)
|
||||
@ -334,7 +332,7 @@ def check_agent_summarize_memory_simple(filename: str) -> LettaResponse:
|
||||
client.user_message(agent_id=agent_state.id, message="Does the number 42 ring a bell?")
|
||||
|
||||
# Summarize
|
||||
agent = client.server.load_agent(agent_id=agent_state.id)
|
||||
agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
agent.summarize_messages_inplace()
|
||||
print(f"Summarization succeeded: messages[1] = \n\n{json_dumps(agent.messages[1])}\n")
|
||||
|
||||
|
@ -3,6 +3,7 @@ from typing import Union
|
||||
from letta import LocalClient, RESTClient
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
||||
from letta.schemas.tool import Tool
|
||||
|
||||
|
||||
@ -24,3 +25,57 @@ def create_tool_from_func(func: callable):
|
||||
source_code=parse_source_code(func),
|
||||
json_schema=generate_schema(func, None),
|
||||
)
|
||||
|
||||
|
||||
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent]):
|
||||
# Assert scalar fields
|
||||
assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
|
||||
assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
|
||||
assert agent.metadata_ == request.metadata_, f"Metadata mismatch: {agent.metadata_} != {request.metadata_}"
|
||||
|
||||
# Assert agent type
|
||||
if hasattr(request, "agent_type"):
|
||||
assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"
|
||||
|
||||
# Assert LLM configuration
|
||||
assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"
|
||||
|
||||
# Assert embedding configuration
|
||||
assert (
|
||||
agent.embedding_config == request.embedding_config
|
||||
), f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
|
||||
|
||||
# Assert memory blocks
|
||||
if hasattr(request, "memory_blocks"):
|
||||
assert len(agent.memory.blocks) == len(request.memory_blocks) + len(
|
||||
request.block_ids
|
||||
), f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
|
||||
memory_block_values = {block.value for block in agent.memory.blocks}
|
||||
expected_block_values = {block.value for block in request.memory_blocks}
|
||||
assert expected_block_values.issubset(
|
||||
memory_block_values
|
||||
), f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
|
||||
|
||||
# Assert tools
|
||||
assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
|
||||
assert {tool.id for tool in agent.tools} == set(
|
||||
request.tool_ids
|
||||
), f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
|
||||
|
||||
# Assert sources
|
||||
assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
|
||||
assert {source.id for source in agent.sources} == set(
|
||||
request.source_ids
|
||||
), f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
|
||||
|
||||
# Assert tags
|
||||
assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"
|
||||
|
||||
# Assert tool rules
|
||||
if request.tool_rules:
|
||||
assert len(agent.tool_rules) == len(
|
||||
request.tool_rules
|
||||
), f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
|
||||
assert all(
|
||||
any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules
|
||||
), f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
|
||||
|
@ -99,7 +99,7 @@ def test_single_path_agent_tool_call_graph(mock_e2b_api_key_none):
|
||||
]
|
||||
|
||||
# Make agent state
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tools=[t.name for t in tools], tool_rules=tool_rules)
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="What is the fourth secret word?")
|
||||
|
||||
# Make checks
|
@ -17,7 +17,7 @@ def test_o1_agent():
|
||||
|
||||
agent_state = client.create_agent(
|
||||
agent_type=AgentType.o1_agent,
|
||||
tools=[thinking_tool.name, final_tool.name],
|
||||
tool_ids=[thinking_tool.id, final_tool.id],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
memory=ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text("o1_persona")),
|
@ -32,8 +32,10 @@ def clear_agents(client):
|
||||
for agent in client.list_agents():
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_ripple_edit(client, mock_e2b_api_key_none):
|
||||
trigger_rethink_memory_tool = client.create_or_update_tool(trigger_rethink_memory)
|
||||
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
|
||||
|
||||
conversation_human_block = Block(name="human", label="human", value=get_human_text(DEFAULT_HUMAN), limit=2000)
|
||||
conversation_persona_block = Block(name="persona", label="persona", value=get_persona_text(DEFAULT_PERSONA), limit=2000)
|
||||
@ -64,7 +66,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none):
|
||||
system=gpt_system.get_system_text("memgpt_convo_only"),
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
tools=["send_message", trigger_rethink_memory_tool.name],
|
||||
tool_ids=[send_message.id, trigger_rethink_memory_tool.id],
|
||||
memory=conversation_memory,
|
||||
include_base_tools=False,
|
||||
)
|
||||
@ -81,7 +83,7 @@ def test_ripple_edit(client, mock_e2b_api_key_none):
|
||||
memory=offline_memory,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
tools=[rethink_memory_tool.name, finish_rethinking_memory_tool.name],
|
||||
tool_ids=[rethink_memory_tool.id, finish_rethinking_memory_tool.id],
|
||||
tool_rules=[TerminalToolRule(tool_name=finish_rethinking_memory_tool.name)],
|
||||
include_base_tools=False,
|
||||
)
|
||||
@ -111,16 +113,16 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
|
||||
)
|
||||
conversation_memory = BasicBlockMemory(blocks=[conversation_persona_block, conversation_human_block])
|
||||
|
||||
client = create_client()
|
||||
send_message = client.server.tool_manager.get_tool_by_name(tool_name="send_message", actor=client.user)
|
||||
chat_only_agent = client.create_agent(
|
||||
name="conversation_agent",
|
||||
agent_type=AgentType.chat_only_agent,
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config("text-embedding-ada-002"),
|
||||
tools=["send_message"],
|
||||
tool_ids=[send_message.id],
|
||||
memory=conversation_memory,
|
||||
include_base_tools=False,
|
||||
metadata={"offline_memory_tools": [rethink_memory.name, finish_rethinking_memory.name]},
|
||||
metadata={"offline_memory_tools": [rethink_memory.id, finish_rethinking_memory.id]},
|
||||
)
|
||||
assert chat_only_agent is not None
|
||||
assert set(chat_only_agent.memory.list_block_labels()) == {"chat_agent_persona", "chat_agent_human"}
|
||||
@ -135,6 +137,7 @@ def test_chat_only_agent(client, mock_e2b_api_key_none):
|
||||
# Clean up agent
|
||||
client.delete_agent(chat_only_agent.id)
|
||||
|
||||
|
||||
def test_initial_message_sequence(client, mock_e2b_api_key_none):
|
||||
"""
|
||||
Test that when we set the initial sequence to an empty list,
|
||||
@ -150,8 +153,6 @@ def test_initial_message_sequence(client, mock_e2b_api_key_none):
|
||||
initial_message_sequence=[],
|
||||
)
|
||||
assert offline_memory_agent is not None
|
||||
assert len(offline_memory_agent.message_ids) == 1 # There should just the system message
|
||||
assert len(offline_memory_agent.message_ids) == 1 # There should just the system message
|
||||
|
||||
client.delete_agent(offline_memory_agent.id)
|
||||
|
||||
|
@ -1,41 +0,0 @@
|
||||
# TODO: add back
|
||||
|
||||
# import os
|
||||
# import subprocess
|
||||
#
|
||||
# import pytest
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
|
||||
# def test_agent_groupchat():
|
||||
#
|
||||
# # Define the path to the script you want to test
|
||||
# script_path = "letta/autogen/examples/agent_groupchat.py"
|
||||
#
|
||||
# # Dynamically get the project's root directory (assuming this script is run from the root)
|
||||
# # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
# # print(project_root)
|
||||
# # project_root = os.path.join(project_root, "Letta")
|
||||
# # print(project_root)
|
||||
# # sys.exit(1)
|
||||
#
|
||||
# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
# project_root = os.path.join(project_root, "letta")
|
||||
# print(f"Adding the following to PATH: {project_root}")
|
||||
#
|
||||
# # Prepare the environment, adding the project root to PYTHONPATH
|
||||
# env = os.environ.copy()
|
||||
# env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
|
||||
#
|
||||
# # Run the script using subprocess.run
|
||||
# # Capture the output (stdout) and the exit code
|
||||
# # result = subprocess.run(["python", script_path], capture_output=True, text=True)
|
||||
# result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
|
||||
#
|
||||
# # Check the exit code (0 indicates success)
|
||||
# assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
|
||||
#
|
||||
# # Optionally, check the output for expected content
|
||||
# # For example, if you expect a specific line in the output, uncomment and adapt the following line:
|
||||
# # assert "expected output" in result.stdout, "Expected output not found in script's output"
|
||||
#
|
@ -23,7 +23,7 @@ def agent_obj():
|
||||
agent_state = client.create_agent()
|
||||
|
||||
global agent_obj
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id)
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
yield agent_obj
|
||||
|
||||
client.delete_agent(agent_obj.agent_state.id)
|
||||
@ -35,49 +35,50 @@ def query_in_search_results(search_results, query):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def test_archival(agent_obj):
|
||||
"""Test archival memory functions comprehensively."""
|
||||
# Test 1: Basic insertion and retrieval
|
||||
base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat")
|
||||
base_functions.archival_memory_insert(agent_obj, "The dog plays in the park")
|
||||
base_functions.archival_memory_insert(agent_obj, "Python is a programming language")
|
||||
|
||||
|
||||
# Test exact text search
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "cat")
|
||||
assert query_in_search_results(results, "cat")
|
||||
|
||||
|
||||
# Test semantic search (should return animal-related content)
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "animal pets")
|
||||
assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog")
|
||||
|
||||
|
||||
# Test unrelated search (should not return animal content)
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "programming computers")
|
||||
assert query_in_search_results(results, "python")
|
||||
|
||||
|
||||
# Test 2: Test pagination
|
||||
# Insert more items to test pagination
|
||||
for i in range(10):
|
||||
base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}")
|
||||
|
||||
|
||||
# Get first page
|
||||
page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0)
|
||||
# Get second page
|
||||
page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page)
|
||||
|
||||
|
||||
assert page0_results != page1_results
|
||||
assert query_in_search_results(page0_results, "Test passage")
|
||||
assert query_in_search_results(page1_results, "Test passage")
|
||||
|
||||
|
||||
# Test 3: Test complex text patterns
|
||||
base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John")
|
||||
base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week")
|
||||
base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching")
|
||||
|
||||
|
||||
# Search for meeting-related content
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule")
|
||||
assert query_in_search_results(results, "meeting")
|
||||
assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week")
|
||||
|
||||
|
||||
# Test 4: Test error handling
|
||||
# Test invalid page number
|
||||
try:
|
||||
@ -85,7 +86,7 @@ def test_archival(agent_obj):
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def test_recall(agent_obj):
|
||||
base_functions.conversation_search(agent_obj, "banana")
|
||||
|
@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@ -42,8 +41,8 @@ def run_server():
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
# params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
params=[{"server": False}], # whether to use REST API server
|
||||
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
# params=[{"server": True}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
@ -69,6 +68,7 @@ def client(request):
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
@ -86,6 +86,47 @@ def clear_tables():
|
||||
session.commit()
|
||||
|
||||
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient]):
|
||||
# _reset_config()
|
||||
|
||||
# create a block
|
||||
block = client.create_block(label="human", value="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
# persona1_block = client.create_block(label="persona", value="you are agent 1")
|
||||
# persona2_block = client.create_block(label="persona", value="you are agent 2")
|
||||
# create agents
|
||||
agent_state1 = client.create_agent(
|
||||
name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1")]), block_ids=[block.id]
|
||||
)
|
||||
agent_state2 = client.create_agent(
|
||||
name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2")]), block_ids=[block.id]
|
||||
)
|
||||
|
||||
## attach shared block to both agents
|
||||
# client.link_agent_memory_block(agent_state1.id, block.id)
|
||||
# client.link_agent_memory_block(agent_state2.id, block.id)
|
||||
|
||||
# update memory
|
||||
client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
||||
|
||||
def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Test sandbox config and environment variable functions for both LocalClient and RESTClient.
|
||||
@ -137,15 +178,15 @@ def test_sandbox_config_and_env_var_basic(client: Union[LocalClient, RESTClient]
|
||||
client.delete_sandbox_config(sandbox_config_id=sandbox_config.id)
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient]):
|
||||
"""
|
||||
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
|
||||
"""
|
||||
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
|
||||
|
||||
# Step 0: create an agent with tags
|
||||
tagged_agent = client.create_agent(tags=tags_to_add)
|
||||
assert set(tagged_agent.tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {tagged_agent.tags}"
|
||||
# Step 0: create an agent with no tags
|
||||
agent = client.create_agent()
|
||||
assert len(agent.tags) == 0
|
||||
|
||||
# Step 1: Add multiple tags to the agent
|
||||
client.update_agent(agent_id=agent.id, tags=tags_to_add)
|
||||
@ -175,6 +216,9 @@ def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], a
|
||||
final_tags = client.get_agent(agent_id=agent.id).tags
|
||||
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
|
||||
|
||||
# Remove agent
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_label(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test that we can update the label of a block in an agent's memory"""
|
||||
@ -255,35 +299,33 @@ def test_add_remove_agent_memory_block(client: Union[LocalClient, RESTClient], a
|
||||
# client.delete_agent(new_agent.id)
|
||||
|
||||
|
||||
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient]):
|
||||
"""Test that we can update the limit of a block in an agent's memory"""
|
||||
|
||||
agent = client.create_agent(name=create_random_username())
|
||||
agent = client.create_agent()
|
||||
|
||||
try:
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_label = current_labels[0]
|
||||
example_new_limit = 1
|
||||
current_block = agent.memory.get_block(label=example_label)
|
||||
current_block_length = len(current_block.value)
|
||||
current_labels = agent.memory.list_block_labels()
|
||||
example_label = current_labels[0]
|
||||
example_new_limit = 1
|
||||
current_block = agent.memory.get_block(label=example_label)
|
||||
current_block_length = len(current_block.value)
|
||||
|
||||
assert example_new_limit != agent.memory.get_block(label=example_label).limit
|
||||
assert example_new_limit < current_block_length
|
||||
assert example_new_limit != agent.memory.get_block(label=example_label).limit
|
||||
assert example_new_limit < current_block_length
|
||||
|
||||
# We expect this to throw a value error
|
||||
with pytest.raises(ValueError):
|
||||
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
|
||||
|
||||
# Now try the same thing with a higher limit
|
||||
example_new_limit = current_block_length + 10000
|
||||
assert example_new_limit > current_block_length
|
||||
# We expect this to throw a value error
|
||||
with pytest.raises(ValueError):
|
||||
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
|
||||
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
|
||||
# Now try the same thing with a higher limit
|
||||
example_new_limit = current_block_length + 10000
|
||||
assert example_new_limit > current_block_length
|
||||
client.update_agent_memory_block(agent_id=agent.id, label=example_label, limit=example_new_limit)
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
updated_agent = client.get_agent(agent_id=agent.id)
|
||||
assert example_new_limit == updated_agent.memory.get_block(label=example_label).limit
|
||||
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
@ -316,7 +358,7 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
padding = len("[NOTE: function output was truncated since it exceeded the character limit (100000 > 1000)]") + 50
|
||||
tool = client.create_or_update_tool(func=big_return, return_char_limit=1000)
|
||||
agent = client.create_agent(name="agent1", tools=[tool.name])
|
||||
agent = client.create_agent(tool_ids=[tool.id])
|
||||
# get function response
|
||||
response = client.send_message(agent_id=agent.id, message="call the big_return function", role="user")
|
||||
print(response.messages)
|
||||
@ -330,10 +372,14 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]):
|
||||
assert response_message, "FunctionReturn message not found in response"
|
||||
res = response_message.function_return
|
||||
assert "function output was truncated " in res
|
||||
res_json = json.loads(res)
|
||||
assert (
|
||||
len(res_json["message"]) <= 1000 + padding
|
||||
), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
|
||||
|
||||
# TODO: Re-enable later
|
||||
# res_json = json.loads(res)
|
||||
# assert (
|
||||
# len(res_json["message"]) <= 1000 + padding
|
||||
# ), f"Expected length to be less than or equal to 1000 + {padding}, but got {len(res_json['message'])}"
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
@ -583,43 +583,6 @@ def test_list_llm_models(client: RESTClient):
|
||||
assert has_model_endpoint_type(models, "anthropic")
|
||||
|
||||
|
||||
def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
# create a block
|
||||
block = client.create_block(label="human", value="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
# persona1_block = client.create_block(label="persona", value="you are agent 1")
|
||||
# persona2_block = client.create_block(label="persona", value="you are agent 2")
|
||||
# create agnets
|
||||
agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory([Block(label="persona", value="you are agent 1"), block]))
|
||||
agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory([Block(label="persona", value="you are agent 2"), block]))
|
||||
|
||||
## attach shared block to both agents
|
||||
# client.link_agent_memory_block(agent_state1.id, block.id)
|
||||
# client.link_agent_memory_block(agent_state2.id, block.id)
|
||||
|
||||
# update memory
|
||||
response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
response = client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cleanup_agents(client):
|
||||
created_agents = []
|
||||
|
@ -1,142 +0,0 @@
|
||||
# TODO: add back when messaging works
|
||||
|
||||
# import os
|
||||
# import threading
|
||||
# import time
|
||||
# import uuid
|
||||
#
|
||||
# import pytest
|
||||
# from dotenv import load_dotenv
|
||||
#
|
||||
# from letta import Admin, create_client
|
||||
# from letta.config import LettaConfig
|
||||
# from letta.credentials import LettaCredentials
|
||||
# from letta.settings import settings
|
||||
# from tests.utils import create_config
|
||||
#
|
||||
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
## test_preset_name = "test_preset"
|
||||
# test_agent_state = None
|
||||
# client = None
|
||||
#
|
||||
# test_agent_state_post_message = None
|
||||
# test_user_id = uuid.uuid4()
|
||||
#
|
||||
#
|
||||
## admin credentials
|
||||
# test_server_token = "test_server_token"
|
||||
#
|
||||
#
|
||||
# def _reset_config():
|
||||
#
|
||||
# # Use os.getenv with a fallback to os.environ.get
|
||||
# db_url = settings.letta_pg_uri
|
||||
#
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# create_config("openai")
|
||||
# credentials = LettaCredentials(
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# )
|
||||
# else: # hosted
|
||||
# create_config("letta_hosted")
|
||||
# credentials = LettaCredentials()
|
||||
#
|
||||
# config = LettaConfig.load()
|
||||
#
|
||||
# # set to use postgres
|
||||
# config.archival_storage_uri = db_url
|
||||
# config.recall_storage_uri = db_url
|
||||
# config.metadata_storage_uri = db_url
|
||||
# config.archival_storage_type = "postgres"
|
||||
# config.recall_storage_type = "postgres"
|
||||
# config.metadata_storage_type = "postgres"
|
||||
#
|
||||
# config.save()
|
||||
# credentials.save()
|
||||
# print("_reset_config :: ", config.config_path)
|
||||
#
|
||||
#
|
||||
# def run_server():
|
||||
#
|
||||
# load_dotenv()
|
||||
#
|
||||
# _reset_config()
|
||||
#
|
||||
# from letta.server.rest_api.server import start_server
|
||||
#
|
||||
# print("Starting server...")
|
||||
# start_server(debug=True)
|
||||
#
|
||||
#
|
||||
## Fixture to create clients with different configurations
|
||||
# @pytest.fixture(
|
||||
# params=[ # whether to use REST API server
|
||||
# {"server": True},
|
||||
# # {"server": False} # TODO: add when implemented
|
||||
# ],
|
||||
# scope="module",
|
||||
# )
|
||||
# def admin_client(request):
|
||||
# if request.param["server"]:
|
||||
# # get URL from enviornment
|
||||
# server_url = os.getenv("MEMGPT_SERVER_URL")
|
||||
# if server_url is None:
|
||||
# # run server in thread
|
||||
# # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
|
||||
# server_url = "http://localhost:8283"
|
||||
# print("Starting server thread")
|
||||
# thread = threading.Thread(target=run_server, daemon=True)
|
||||
# thread.start()
|
||||
# time.sleep(5)
|
||||
# print("Running client tests with server:", server_url)
|
||||
# # create user via admin client
|
||||
# admin = Admin(server_url, test_server_token)
|
||||
# response = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
#
|
||||
# yield admin
|
||||
#
|
||||
#
|
||||
# def test_concurrent_messages(admin_client):
|
||||
# # test concurrent messages
|
||||
#
|
||||
# # create three
|
||||
#
|
||||
# results = []
|
||||
#
|
||||
# def _send_message():
|
||||
# try:
|
||||
# print("START SEND MESSAGE")
|
||||
# response = admin_client.create_user()
|
||||
# token = response.api_key
|
||||
# client = create_client(base_url=admin_client.base_url, token=token)
|
||||
# agent = client.create_agent()
|
||||
#
|
||||
# print("Agent created", agent.id)
|
||||
#
|
||||
# st = time.time()
|
||||
# message = "Hello, how are you?"
|
||||
# response = client.send_message(agent_id=agent.id, message=message, role="user")
|
||||
# et = time.time()
|
||||
# print(f"Message sent from {st} to {et}")
|
||||
# print(response.messages)
|
||||
# results.append((st, et))
|
||||
# except Exception as e:
|
||||
# print("ERROR", e)
|
||||
#
|
||||
# threads = []
|
||||
# print("Starting threads...")
|
||||
# for i in range(5):
|
||||
# thread = threading.Thread(target=_send_message)
|
||||
# threads.append(thread)
|
||||
# thread.start()
|
||||
# print("CREATED THREAD")
|
||||
#
|
||||
# print("waiting for threads to finish...")
|
||||
# for thread in threads:
|
||||
# print(thread.join())
|
||||
#
|
||||
# # make sure runtime are overlapping
|
||||
# assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
|
||||
# results[1][0] < results[0][0] and results[1][1] > results[0][0]
|
||||
# ), f"Threads should have overlapping runtimes {results}"
|
||||
#
|
@ -1,121 +0,0 @@
|
||||
# TODO: add back once tests are cleaned up
|
||||
|
||||
# import os
|
||||
# import uuid
|
||||
#
|
||||
# from letta import create_client
|
||||
# from letta.agent_store.storage import StorageConnector, TableType
|
||||
# from letta.schemas.passage import Passage
|
||||
# from letta.embeddings import embedding_model
|
||||
# from tests import TEST_MEMGPT_CONFIG
|
||||
#
|
||||
# from .utils import create_config, wipe_config
|
||||
#
|
||||
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
# test_agent_state = None
|
||||
# client = None
|
||||
#
|
||||
# test_agent_state_post_message = None
|
||||
# test_user_id = uuid.uuid4()
|
||||
#
|
||||
#
|
||||
# def generate_passages(user, agent):
|
||||
# # Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
# texts = [
|
||||
# "This is a test passage",
|
||||
# "This is another test passage",
|
||||
# "Cinderella wept",
|
||||
# ]
|
||||
# embed_model = embedding_model(agent.embedding_config)
|
||||
# orig_embeddings = []
|
||||
# passages = []
|
||||
# for text in texts:
|
||||
# embedding = embed_model.get_text_embedding(text)
|
||||
# orig_embeddings.append(list(embedding))
|
||||
# passages.append(
|
||||
# Passage(
|
||||
# user_id=user.id,
|
||||
# agent_id=agent.id,
|
||||
# text=text,
|
||||
# embedding=embedding,
|
||||
# embedding_dim=agent.embedding_config.embedding_dim,
|
||||
# embedding_model=agent.embedding_config.embedding_model,
|
||||
# )
|
||||
# )
|
||||
# return passages, orig_embeddings
|
||||
#
|
||||
#
|
||||
# def test_create_user():
|
||||
# if not os.getenv("OPENAI_API_KEY"):
|
||||
# print("Skipping test, missing OPENAI_API_KEY")
|
||||
# return
|
||||
#
|
||||
# wipe_config()
|
||||
#
|
||||
# # create client
|
||||
# create_config("openai")
|
||||
# client = create_client()
|
||||
#
|
||||
# # openai: create agent
|
||||
# openai_agent = client.create_agent(
|
||||
# name="openai_agent",
|
||||
# )
|
||||
# assert (
|
||||
# openai_agent.embedding_config.embedding_endpoint_type == "openai"
|
||||
# ), f"openai_agent.embedding_config.embedding_endpoint_type={openai_agent.embedding_config.embedding_endpoint_type}"
|
||||
#
|
||||
# # openai: add passages
|
||||
# passages, openai_embeddings = generate_passages(client.user, openai_agent)
|
||||
# openai_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=openai_agent.id)
|
||||
# openai_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
#
|
||||
# # create client
|
||||
# create_config("letta_hosted")
|
||||
# client = create_client()
|
||||
#
|
||||
# # hosted: create agent
|
||||
# hosted_agent = client.create_agent(
|
||||
# name="hosted_agent",
|
||||
# )
|
||||
# # check to make sure endpoint overriden
|
||||
# assert (
|
||||
# hosted_agent.embedding_config.embedding_endpoint_type == "hugging-face"
|
||||
# ), f"hosted_agent.embedding_config.embedding_endpoint_type={hosted_agent.embedding_config.embedding_endpoint_type}"
|
||||
#
|
||||
# # hosted: add passages
|
||||
# passages, hosted_embeddings = generate_passages(client.user, hosted_agent)
|
||||
# hosted_agent_run = client.server.load_agent(user_id=client.user.id, agent_id=hosted_agent.id)
|
||||
# hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
#
|
||||
# # test passage dimentionality
|
||||
# storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
|
||||
# storage.filters = {} # clear filters to be able to get all passages
|
||||
# passages = storage.get_all()
|
||||
# for passage in passages:
|
||||
# if passage.agent_id == hosted_agent.id:
|
||||
# assert (
|
||||
# passage.embedding_dim == hosted_agent.embedding_config.embedding_dim
|
||||
# ), f"passage.embedding_dim={passage.embedding_dim} != hosted_agent.embedding_config.embedding_dim={hosted_agent.embedding_config.embedding_dim}"
|
||||
#
|
||||
# # ensure was in original embeddings
|
||||
# embedding = passage.embedding[: passage.embedding_dim]
|
||||
# assert embedding in hosted_embeddings, f"embedding={embedding} not in hosted_embeddings={hosted_embeddings}"
|
||||
#
|
||||
# # make sure all zeros
|
||||
# assert not any(
|
||||
# passage.embedding[passage.embedding_dim :]
|
||||
# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"
|
||||
# elif passage.agent_id == openai_agent.id:
|
||||
# assert (
|
||||
# passage.embedding_dim == openai_agent.embedding_config.embedding_dim
|
||||
# ), f"passage.embedding_dim={passage.embedding_dim} != openai_agent.embedding_config.embedding_dim={openai_agent.embedding_config.embedding_dim}"
|
||||
#
|
||||
# # ensure was in original embeddings
|
||||
# embedding = passage.embedding[: passage.embedding_dim]
|
||||
# assert embedding in openai_embeddings, f"embedding={embedding} not in openai_embeddings={openai_embeddings}"
|
||||
#
|
||||
# # make sure all zeros
|
||||
# assert not any(
|
||||
# passage.embedding[passage.embedding_dim :]
|
||||
# ), f"passage.embedding[passage.embedding_dim:]={passage.embedding[passage.embedding_dim:]}"
|
||||
#
|
@ -1,48 +0,0 @@
|
||||
import letta.system as system
|
||||
from letta.local_llm.function_parser import patch_function
|
||||
from letta.utils import json_dumps
|
||||
|
||||
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
|
||||
"message_history": [
|
||||
{"role": "user", "content": system.package_user_message("hello")},
|
||||
],
|
||||
# "new_message": {
|
||||
# "role": "function",
|
||||
# "name": "send_message",
|
||||
# "content": system.package_function_response(was_success=True, response_string="None"),
|
||||
# },
|
||||
"new_message": {
|
||||
"role": "assistant",
|
||||
"content": "I'll send a message.",
|
||||
"function_call": {
|
||||
"name": "send_message",
|
||||
"arguments": "null",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
|
||||
"message_history": [
|
||||
{"role": "user", "content": system.package_user_message("hello")},
|
||||
],
|
||||
"new_message": {
|
||||
"role": "assistant",
|
||||
"content": "I'll append to memory.",
|
||||
"function_call": {
|
||||
"name": "core_memory_append",
|
||||
"arguments": json_dumps({"content": "new_stuff"}),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def test_function_parsers():
|
||||
"""Try various broken JSON and check that the parsers can fix it"""
|
||||
|
||||
og_message = EXAMPLE_FUNCTION_CALL_SEND_MESSAGE["new_message"]
|
||||
corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_SEND_MESSAGE)
|
||||
assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"
|
||||
|
||||
og_message = EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING["new_message"].copy()
|
||||
corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING)
|
||||
assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"
|
@ -1,99 +0,0 @@
|
||||
import letta.local_llm.json_parser as json_parser
|
||||
from letta.utils import json_loads
|
||||
|
||||
EXAMPLE_ESCAPED_UNDERSCORES = """{
|
||||
"function":"send\_message",
|
||||
"params": {
|
||||
"inner\_thoughts": "User is asking for information about themselves. Retrieving data from core memory.",
|
||||
"message": "I know that you are Chad. Is there something specific you would like to know or talk about regarding yourself?"
|
||||
"""
|
||||
|
||||
|
||||
EXAMPLE_MISSING_CLOSING_BRACE = """{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
"inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.",
|
||||
"message": "Sorry about that! I assumed you were Chad. Welcome, Brad! "
|
||||
}
|
||||
"""
|
||||
|
||||
EXAMPLE_BAD_TOKEN_END = """{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
"inner_thoughts": "Oops, I got their name wrong! I should apologize and correct myself.",
|
||||
"message": "Sorry about that! I assumed you were Chad. Welcome, Brad! "
|
||||
}
|
||||
}<|>"""
|
||||
|
||||
EXAMPLE_DOUBLE_JSON = """{
|
||||
"function": "core_memory_append",
|
||||
"params": {
|
||||
"name": "human",
|
||||
"content": "Brad, 42 years old, from Germany."
|
||||
}
|
||||
}
|
||||
{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
"message": "Got it! Your age and nationality are now saved in my memory."
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
EXAMPLE_HARD_LINE_FEEDS = """{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
"message": "Let's create a list:
|
||||
- First, we can do X
|
||||
- Then, we can do Y!
|
||||
- Lastly, we can do Z :)"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# Situation where beginning of send_message call is fine (and thus can be extracted)
|
||||
# but has a long training garbage string that comes after
|
||||
EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD = """{
|
||||
"function": "send_message",
|
||||
"params": {
|
||||
"inner_thoughts": "User request for debug assistance",
|
||||
"message": "Of course, Chad. Please check the system log file for 'assistant.json' and send me the JSON output you're getting. Armed with that data, I'll assist you in debugging the issue.",
|
||||
GARBAGEGARBAGEGARBAGEGARBAGE
|
||||
GARBAGEGARBAGEGARBAGEGARBAGE
|
||||
GARBAGEGARBAGEGARBAGEGARBAGE
|
||||
"""
|
||||
|
||||
EXAMPLE_ARCHIVAL_SEARCH = """
|
||||
|
||||
{
|
||||
"function": "archival_memory_search",
|
||||
"params": {
|
||||
"inner_thoughts": "Looking for WaitingForAction.",
|
||||
"query": "WaitingForAction",
|
||||
"""
|
||||
|
||||
|
||||
def test_json_parsers():
|
||||
"""Try various broken JSON and check that the parsers can fix it"""
|
||||
|
||||
test_strings = [
|
||||
EXAMPLE_ESCAPED_UNDERSCORES,
|
||||
EXAMPLE_MISSING_CLOSING_BRACE,
|
||||
EXAMPLE_BAD_TOKEN_END,
|
||||
EXAMPLE_DOUBLE_JSON,
|
||||
EXAMPLE_HARD_LINE_FEEDS,
|
||||
EXAMPLE_SEND_MESSAGE_PREFIX_OK_REST_BAD,
|
||||
EXAMPLE_ARCHIVAL_SEARCH,
|
||||
]
|
||||
|
||||
for string in test_strings:
|
||||
try:
|
||||
json_loads(string)
|
||||
assert False, f"Test JSON string should have failed basic JSON parsing:\n{string}"
|
||||
except:
|
||||
print("String failed (expectedly)")
|
||||
try:
|
||||
json_parser.clean_json(string)
|
||||
except:
|
||||
f"Failed to repair test JSON string:\n{string}"
|
||||
raise
|
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient
|
||||
from letta.schemas.agent import PersistedAgentState
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
|
||||
@ -13,6 +13,7 @@ from letta.schemas.memory import BasicBlockMemory, ChatMemory, Memory
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
client = create_client()
|
||||
# client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
|
||||
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
|
||||
|
||||
@ -29,7 +30,6 @@ def agent(client):
|
||||
yield agent_state
|
||||
|
||||
client.delete_agent(agent_state.id)
|
||||
assert client.get_agent(agent_state.id) is None, f"Failed to properly delete agent {agent_state.id}"
|
||||
|
||||
|
||||
def test_agent(client: LocalClient):
|
||||
@ -80,16 +80,15 @@ def test_agent(client: LocalClient):
|
||||
assert isinstance(agent_state.memory, Memory)
|
||||
# update agent: tools
|
||||
tool_to_delete = "send_message"
|
||||
assert tool_to_delete in agent_state.tool_names
|
||||
new_agent_tools = [t_name for t_name in agent_state.tool_names if t_name != tool_to_delete]
|
||||
client.update_agent(agent_state_test.id, tools=new_agent_tools)
|
||||
assert client.get_agent(agent_state_test.id).tool_names == new_agent_tools
|
||||
assert tool_to_delete in [t.name for t in agent_state.tools]
|
||||
new_agent_tool_ids = [t.id for t in agent_state.tools if t.name != tool_to_delete]
|
||||
client.update_agent(agent_state_test.id, tool_ids=new_agent_tool_ids)
|
||||
assert sorted([t.id for t in client.get_agent(agent_state_test.id).tools]) == sorted(new_agent_tool_ids)
|
||||
|
||||
assert isinstance(agent_state.memory, Memory)
|
||||
# update agent: memory
|
||||
new_human = "My name is Mr Test, 100 percent human."
|
||||
new_persona = "I am an all-knowing AI."
|
||||
new_memory = ChatMemory(human=new_human, persona=new_persona)
|
||||
assert agent_state.memory.get_block("human").value != new_human
|
||||
assert agent_state.memory.get_block("persona").value != new_persona
|
||||
|
||||
@ -216,7 +215,7 @@ def test_agent_with_shared_blocks(client: LocalClient):
|
||||
client.delete_agent(second_agent_state_test.id)
|
||||
|
||||
|
||||
def test_memory(client: LocalClient, agent: PersistedAgentState):
|
||||
def test_memory(client: LocalClient, agent: AgentState):
|
||||
# get agent memory
|
||||
original_memory = client.get_in_context_memory(agent.id)
|
||||
assert original_memory is not None
|
||||
@ -229,7 +228,7 @@ def test_memory(client: LocalClient, agent: PersistedAgentState):
|
||||
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
|
||||
|
||||
|
||||
def test_archival_memory(client: LocalClient, agent: PersistedAgentState):
|
||||
def test_archival_memory(client: LocalClient, agent: AgentState):
|
||||
"""Test functions for interacting with archival memory store"""
|
||||
|
||||
# add archival memory
|
||||
@ -244,7 +243,7 @@ def test_archival_memory(client: LocalClient, agent: PersistedAgentState):
|
||||
client.delete_archival_memory(agent.id, passage.id)
|
||||
|
||||
|
||||
def test_recall_memory(client: LocalClient, agent: PersistedAgentState):
|
||||
def test_recall_memory(client: LocalClient, agent: AgentState):
|
||||
"""Test functions for interacting with recall memory store"""
|
||||
|
||||
# send message to the agent
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,126 +0,0 @@
|
||||
# TODO: fix later
|
||||
|
||||
# import os
|
||||
# import random
|
||||
# import string
|
||||
# import unittest.mock
|
||||
#
|
||||
# import pytest
|
||||
#
|
||||
# from letta.cli.cli_config import add, delete, list
|
||||
# from letta.config import LettaConfig
|
||||
# from letta.credentials import LettaCredentials
|
||||
# from tests.utils import create_config
|
||||
#
|
||||
#
|
||||
# def _reset_config():
|
||||
#
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# create_config("openai")
|
||||
# credentials = LettaCredentials(
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# )
|
||||
# else: # hosted
|
||||
# create_config("letta_hosted")
|
||||
# credentials = LettaCredentials()
|
||||
#
|
||||
# config = LettaConfig.load()
|
||||
# config.save()
|
||||
# credentials.save()
|
||||
# print("_reset_config :: ", config.config_path)
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skip(reason="This is a helper function.")
|
||||
# def generate_random_string(length):
|
||||
# characters = string.ascii_letters + string.digits
|
||||
# random_string = "".join(random.choices(characters, k=length))
|
||||
# return random_string
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skip(reason="Ensures LocalClient is used during testing.")
|
||||
# def unset_env_variables():
|
||||
# server_url = os.environ.pop("MEMGPT_BASE_URL", None)
|
||||
# token = os.environ.pop("MEMGPT_SERVER_PASS", None)
|
||||
# return server_url, token
|
||||
#
|
||||
#
|
||||
# @pytest.mark.skip(reason="Set env variables back to values before test.")
|
||||
# def reset_env_variables(server_url, token):
|
||||
# if server_url is not None:
|
||||
# os.environ["MEMGPT_BASE_URL"] = server_url
|
||||
# if token is not None:
|
||||
# os.environ["MEMGPT_SERVER_PASS"] = token
|
||||
#
|
||||
#
|
||||
# def test_crud_human(capsys):
|
||||
# _reset_config()
|
||||
#
|
||||
# server_url, token = unset_env_variables()
|
||||
#
|
||||
# # Initialize values that won't interfere with existing ones
|
||||
# human_1 = generate_random_string(16)
|
||||
# text_1 = generate_random_string(32)
|
||||
# human_2 = generate_random_string(16)
|
||||
# text_2 = generate_random_string(32)
|
||||
# text_3 = generate_random_string(32)
|
||||
#
|
||||
# # Add inital human
|
||||
# add("human", human_1, text_1)
|
||||
#
|
||||
# # Expect inital human to be listed
|
||||
# list("humans")
|
||||
# captured = capsys.readouterr()
|
||||
# output = captured.out[captured.out.find(human_1) :]
|
||||
#
|
||||
# assert human_1 in output
|
||||
# assert text_1 in output
|
||||
#
|
||||
# # Add second human
|
||||
# add("human", human_2, text_2)
|
||||
#
|
||||
# # Expect to see second human
|
||||
# list("humans")
|
||||
# captured = capsys.readouterr()
|
||||
# output = captured.out[captured.out.find(human_1) :]
|
||||
#
|
||||
# assert human_1 in output
|
||||
# assert text_1 in output
|
||||
# assert human_2 in output
|
||||
# assert text_2 in output
|
||||
#
|
||||
# with unittest.mock.patch("questionary.confirm") as mock_confirm:
|
||||
# mock_confirm.return_value.ask.return_value = True
|
||||
#
|
||||
# # Update second human
|
||||
# add("human", human_2, text_3)
|
||||
#
|
||||
# # Expect to see update text
|
||||
# list("humans")
|
||||
# captured = capsys.readouterr()
|
||||
# output = captured.out[captured.out.find(human_1) :]
|
||||
#
|
||||
# assert human_1 in output
|
||||
# assert text_1 in output
|
||||
# assert human_2 in output
|
||||
# assert output.count(human_2) == 1
|
||||
# assert text_3 in output
|
||||
# assert text_2 not in output
|
||||
#
|
||||
# # Delete second human
|
||||
# delete("human", human_2)
|
||||
#
|
||||
# # Expect second human to be deleted
|
||||
# list("humans")
|
||||
# captured = capsys.readouterr()
|
||||
# output = captured.out[captured.out.find(human_1) :]
|
||||
#
|
||||
# assert human_1 in output
|
||||
# assert text_1 in output
|
||||
# assert human_2 not in output
|
||||
# assert text_2 not in output
|
||||
#
|
||||
# # Clean up
|
||||
# delete("human", human_1)
|
||||
#
|
||||
# reset_env_variables(server_url, token)
|
||||
#
|
@ -1,93 +0,0 @@
|
||||
from logging import getLogger
|
||||
|
||||
from openai import APIConnectionError, OpenAI
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def test_openai_assistant():
|
||||
client = OpenAI(base_url="http://127.0.0.1:8080/v1")
|
||||
# create assistant
|
||||
try:
|
||||
assistant = client.beta.assistants.create(
|
||||
name="Math Tutor",
|
||||
instructions="You are a personal math tutor. Write and run code to answer math questions.",
|
||||
# tools=[{"type": "code_interpreter"}],
|
||||
model="gpt-4-turbo-preview",
|
||||
)
|
||||
except APIConnectionError as e:
|
||||
logger.error("Connection issue with localhost openai stub: %s", e)
|
||||
return
|
||||
# create thread
|
||||
thread = client.beta.threads.create()
|
||||
|
||||
message = client.beta.threads.messages.create(
|
||||
thread_id=thread.id, role="user", content="I need to solve the equation `3x + 11 = 14`. Can you help me?"
|
||||
)
|
||||
|
||||
run = client.beta.threads.runs.create(
|
||||
thread_id=thread.id, assistant_id=assistant.id, instructions="Please address the user as Jane Doe. The user has a premium account."
|
||||
)
|
||||
|
||||
# run = client.beta.threads.runs.create(
|
||||
# thread_id=thread.id,
|
||||
# assistant_id=assistant.id,
|
||||
# model="gpt-4-turbo-preview",
|
||||
# instructions="New instructions that override the Assistant instructions",
|
||||
# tools=[{"type": "code_interpreter"}, {"type": "retrieval"}]
|
||||
# )
|
||||
|
||||
# Store the run ID
|
||||
run_id = run.id
|
||||
print(run_id)
|
||||
|
||||
# NOTE: Letta does not support polling yet, so run status is always "completed"
|
||||
# Retrieve all messages from the thread
|
||||
messages = client.beta.threads.messages.list(thread_id=thread.id)
|
||||
|
||||
# Print all messages from the thread
|
||||
for msg in messages.messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"][0]
|
||||
print(f"{role.capitalize()}: {content}")
|
||||
|
||||
# TODO: add once polling works
|
||||
## Polling for the run status
|
||||
# while True:
|
||||
# # Retrieve the run status
|
||||
# run_status = client.beta.threads.runs.retrieve(
|
||||
# thread_id=thread.id,
|
||||
# run_id=run_id
|
||||
# )
|
||||
|
||||
# # Check and print the step details
|
||||
# run_steps = client.beta.threads.runs.steps.list(
|
||||
# thread_id=thread.id,
|
||||
# run_id=run_id
|
||||
# )
|
||||
# for step in run_steps.data:
|
||||
# if step.type == 'tool_calls':
|
||||
# print(f"Tool {step.type} invoked.")
|
||||
|
||||
# # If step involves code execution, print the code
|
||||
# if step.type == 'code_interpreter':
|
||||
# print(f"Python Code Executed: {step.step_details['code_interpreter']['input']}")
|
||||
|
||||
# if run_status.status == 'completed':
|
||||
# # Retrieve all messages from the thread
|
||||
# messages = client.beta.threads.messages.list(
|
||||
# thread_id=thread.id
|
||||
# )
|
||||
|
||||
# # Print all messages from the thread
|
||||
# for msg in messages.data:
|
||||
# role = msg.role
|
||||
# content = msg.content[0].text.value
|
||||
# print(f"{role.capitalize()}: {content}")
|
||||
# break # Exit the polling loop since the run is complete
|
||||
# elif run_status.status in ['queued', 'in_progress']:
|
||||
# print(f'{run_status.status.capitalize()}... Please wait.')
|
||||
# time.sleep(1.5) # Wait before checking again
|
||||
# else:
|
||||
# print(f"Run status: {run_status.status}")
|
||||
# break # Exit the polling loop if the status is neither 'in_progress' nor 'completed'
|
@ -1,52 +0,0 @@
|
||||
# test state saving between client session
|
||||
# TODO: update this test with correct imports
|
||||
|
||||
|
||||
# def test_save_load(client):
|
||||
# """Test that state is being persisted correctly after an /exit
|
||||
#
|
||||
# Create a new agent, and request a message
|
||||
#
|
||||
# Then trigger
|
||||
# """
|
||||
# assert client is not None, "Run create_agent test first"
|
||||
# assert test_agent_state is not None, "Run create_agent test first"
|
||||
# assert test_agent_state_post_message is not None, "Run test_user_message test first"
|
||||
#
|
||||
# # Create a new client (not thread safe), and load the same agent
|
||||
# # The agent state inside should correspond to the initial state pre-message
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# client2 = Letta(quickstart="openai", user_id=test_user_id)
|
||||
# else:
|
||||
# client2 = Letta(quickstart="letta_hosted", user_id=test_user_id)
|
||||
# print(f"\n\n[3] CREATING CLIENT2, LOADING AGENT {test_agent_state.id}!")
|
||||
# client2_agent_obj = client2.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id)
|
||||
# client2_agent_state = client2_agent_obj.update_state()
|
||||
# print(f"[3] LOADED AGENT! AGENT {client2_agent_state.id}\n\tmessages={client2_agent_state.state['messages']}")
|
||||
#
|
||||
# # assert test_agent_state == client2_agent_state, f"{vars(test_agent_state)}\n{vars(client2_agent_state)}"
|
||||
# def check_state_equivalence(state_1, state_2):
|
||||
# """Helper function that checks the equivalence of two AgentState objects"""
|
||||
# assert state_1.keys() == state_2.keys(), f"{state_1.keys()}\n{state_2.keys}"
|
||||
# for k, v1 in state_1.items():
|
||||
# v2 = state_2[k]
|
||||
# if isinstance(v1, LLMConfig) or isinstance(v1, EmbeddingConfig):
|
||||
# assert vars(v1) == vars(v2), f"{vars(v1)}\n{vars(v2)}"
|
||||
# else:
|
||||
# assert v1 == v2, f"{v1}\n{v2}"
|
||||
#
|
||||
# check_state_equivalence(vars(test_agent_state), vars(client2_agent_state))
|
||||
#
|
||||
# # Now, write out the save from the original client
|
||||
# # This should persist the test message into the agent state
|
||||
# client.save()
|
||||
#
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# client3 = Letta(quickstart="openai", user_id=test_user_id)
|
||||
# else:
|
||||
# client3 = Letta(quickstart="letta_hosted", user_id=test_user_id)
|
||||
# client3_agent_obj = client3.server.load_agent(user_id=test_user_id, agent_id=test_agent_state.id)
|
||||
# client3_agent_state = client3_agent_obj.update_state()
|
||||
#
|
||||
# check_state_equivalence(vars(test_agent_state_post_message), vars(client3_agent_state))
|
||||
#
|
@ -1,62 +0,0 @@
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
|
||||
|
||||
def send_message(self, message: str):
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
Args:
|
||||
message (str): Message contents. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def send_message_missing_types(self, message):
|
||||
"""
|
||||
Sends a message to the human user.
|
||||
|
||||
Args:
|
||||
message (str): Message contents. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
return None
|
||||
|
||||
|
||||
def send_message_missing_docstring(self, message: str):
|
||||
return None
|
||||
|
||||
|
||||
def test_schema_generator():
|
||||
# Check that a basic function schema converts correctly
|
||||
correct_schema = {
|
||||
"name": "send_message",
|
||||
"description": "Sends a message to the human user.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}},
|
||||
"required": ["message"],
|
||||
},
|
||||
}
|
||||
generated_schema = generate_schema(send_message)
|
||||
print(f"\n\nreference_schema={correct_schema}")
|
||||
print(f"\n\ngenerated_schema={generated_schema}")
|
||||
assert correct_schema == generated_schema
|
||||
|
||||
# Check that missing types results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_types)
|
||||
assert False
|
||||
except:
|
||||
pass
|
||||
|
||||
# Check that missing docstring results in an error
|
||||
try:
|
||||
_ = generate_schema(send_message_missing_docstring)
|
||||
assert False
|
||||
except:
|
||||
pass
|
@ -19,8 +19,6 @@ from letta.schemas.letta_message import (
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
|
||||
from .test_managers import DEFAULT_EMBEDDING_CONFIG
|
||||
|
||||
utils.DEBUG = True
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
@ -266,6 +264,7 @@ Lise, young Bolkónski's wife, this very evening, and perhaps the
|
||||
thing can be arranged. It shall be on your family's behalf that I'll
|
||||
start my apprenticeship as old maid."""
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@ -302,42 +301,66 @@ def user_id(server, org_id):
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_id(server, user_id):
|
||||
def base_tools(server, user_id):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
tools = []
|
||||
for tool_name in BASE_TOOLS:
|
||||
tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor))
|
||||
|
||||
yield tools
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def base_memory_tools(server, user_id):
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
tools = []
|
||||
for tool_name in BASE_MEMORY_TOOLS:
|
||||
tools.append(server.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor))
|
||||
|
||||
yield tools
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent_id(server, user_id, base_tools):
|
||||
# create agent
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="test_agent",
|
||||
tools=BASE_TOOLS,
|
||||
tool_ids=[t.id for t in base_tools],
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=server.get_user_or_default(user_id),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
# cleanup
|
||||
server.delete_agent(user_id, agent_state.id)
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def other_agent_id(server, user_id):
|
||||
def other_agent_id(server, user_id, base_tools):
|
||||
# create agent
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="test_agent_other",
|
||||
tools=BASE_TOOLS,
|
||||
tool_ids=[t.id for t in base_tools],
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=server.get_user_or_default(user_id),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
|
||||
# cleanup
|
||||
server.delete_agent(user_id, agent_state.id)
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
try:
|
||||
@ -416,6 +439,7 @@ def test_user_message(server, user_id, agent_id):
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
cursor1 = messages_1[-1].id
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
@ -427,7 +451,9 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
|
||||
message_ids = [m.id for m in messages_3]
|
||||
for message_id in in_context_ids:
|
||||
assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
@ -437,10 +463,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
|
||||
# List latest 2 passages
|
||||
passages_1 = server.passage_manager.list_passages(
|
||||
actor=user, agent_id=agent_id, ascending=False, limit=2,
|
||||
actor=user,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
limit=2,
|
||||
)
|
||||
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
|
||||
|
||||
@ -483,12 +512,13 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
- "rewrite" replaces the text of the last assistant message
|
||||
- "retry" retries the last assistant message
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# Send an initial message
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
|
||||
# Grab the raw Agent object
|
||||
letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
@ -496,10 +526,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
# Try "rethink"
|
||||
new_thought = "I am thinking about the meaning of life, the universe, and everything. Bananas?"
|
||||
assert last_agent_message.text is not None and last_agent_message.text != new_thought
|
||||
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought)
|
||||
server.rethink_agent_message(agent_id=agent_id, new_thought=new_thought, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
@ -513,10 +543,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
assert "message" in args_json and args_json["message"] is not None and args_json["message"] != ""
|
||||
|
||||
new_text = "Why hello there my good friend! Is 42 what you're looking for? Bananas?"
|
||||
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text)
|
||||
server.rewrite_agent_message(agent_id=agent_id, new_text=new_text, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
@ -524,10 +554,10 @@ def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
assert "message" in args_json and args_json["message"] is not None and args_json["message"] == new_text
|
||||
|
||||
# Try retry
|
||||
server.retry_agent_message(agent_id=agent_id)
|
||||
server.retry_agent_message(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Grab the agent object again (make sure it's live)
|
||||
letta_agent = server.load_agent(agent_id=agent_id)
|
||||
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
|
||||
assert letta_agent._messages[-1].role == MessageRole.tool
|
||||
assert letta_agent._messages[-2].role == MessageRole.assistant
|
||||
last_agent_message = letta_agent._messages[-2]
|
||||
@ -581,33 +611,6 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
|
||||
)
|
||||
|
||||
|
||||
def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServer, user_id: str):
|
||||
fake_tool_name = "blahblahblah"
|
||||
tools = BASE_TOOLS + [fake_tool_name]
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="nonexistent_tools_agent",
|
||||
tools=tools,
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=server.get_user_or_default(user_id),
|
||||
)
|
||||
|
||||
# Check that the tools in agent_state do NOT include the fake name
|
||||
assert fake_tool_name not in agent_state.tool_names
|
||||
assert set(BASE_TOOLS).issubset(set(agent_state.tool_names))
|
||||
|
||||
# Load the agent from the database and check that it doesn't error / tools are correct
|
||||
saved_tools = server.get_tools_from_agent(agent_id=agent_state.id, user_id=user_id)
|
||||
assert fake_tool_name not in agent_state.tool_names
|
||||
assert set(BASE_TOOLS).issubset(set(agent_state.tool_names))
|
||||
|
||||
# cleanup
|
||||
server.delete_agent(user_id, agent_state.id)
|
||||
|
||||
|
||||
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
@ -616,14 +619,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=server.get_user_or_default(user_id),
|
||||
actor=server.user_manager.get_user_or_default(user_id),
|
||||
)
|
||||
|
||||
# create another user in the same org
|
||||
another_user = server.user_manager.create_user(User(organization_id=org_id, name="another"))
|
||||
|
||||
# test that another user in the same org can delete the agent
|
||||
server.delete_agent(another_user.id, agent_state.id)
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=another_user)
|
||||
|
||||
|
||||
def _test_get_messages_letta_format(
|
||||
@ -887,14 +890,14 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="memory_rebuild_test_agent",
|
||||
tools=BASE_TOOLS + BASE_MEMORY_TOOLS,
|
||||
tool_ids=[t.id for t in base_tools + base_memory_tools],
|
||||
memory_blocks=[
|
||||
CreateBlock(label="human", value="The human's name is Bob."),
|
||||
CreateBlock(label="persona", value="My name is Alice."),
|
||||
@ -902,7 +905,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=server.get_user_or_default(user_id),
|
||||
actor=actor,
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
|
||||
@ -929,31 +932,28 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
|
||||
try:
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
# assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 2, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
|
||||
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
|
||||
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
# At this stage, there should be 2 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
# assert num_system_messages == 2, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 3, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 2, (num_system_messages, all_messages)
|
||||
|
||||
# Run server.load_agent, and make sure that the number of system messages is still 2
|
||||
server.load_agent(agent_id=agent_state.id)
|
||||
server.load_agent(agent_id=agent_state.id, actor=actor)
|
||||
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
# assert num_system_messages == 2, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 3, (num_system_messages, all_messages)
|
||||
assert num_system_messages == 2, (num_system_messages, all_messages)
|
||||
|
||||
finally:
|
||||
# cleanup
|
||||
server.delete_agent(user_id, agent_state.id)
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
|
||||
user = server.get_user_or_default(user_id)
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
|
||||
# Create a source
|
||||
source = server.source_manager.create_source(
|
||||
@ -962,7 +962,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
created_by_id=user_id,
|
||||
),
|
||||
actor=user
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Create a test file with some content
|
||||
@ -971,11 +971,10 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
test_file.write_text(test_content)
|
||||
|
||||
# Attach source to agent first
|
||||
agent = server.load_agent(agent_id=agent_id)
|
||||
agent.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
|
||||
|
||||
# Get initial passage count
|
||||
initial_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
||||
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
assert initial_passage_count == 0
|
||||
|
||||
# Create a job for loading the first file
|
||||
@ -984,7 +983,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
user_id=user_id,
|
||||
metadata_={"type": "embedding", "filename": test_file.name, "source_id": source.id},
|
||||
),
|
||||
actor=user
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Load the first file to source
|
||||
@ -992,17 +991,17 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
source_id=source.id,
|
||||
file_path=str(test_file),
|
||||
job_id=job.id,
|
||||
actor=user,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Verify job completed successfully
|
||||
job = server.job_manager.get_job_by_id(job_id=job.id, actor=user)
|
||||
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
|
||||
assert job.status == "completed"
|
||||
assert job.metadata_["num_passages"] == 1
|
||||
assert job.metadata_["num_passages"] == 1
|
||||
assert job.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were added
|
||||
first_file_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
||||
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
assert first_file_passage_count > initial_passage_count
|
||||
|
||||
# Create a second test file with different content
|
||||
@ -1015,7 +1014,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
user_id=user_id,
|
||||
metadata_={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
|
||||
),
|
||||
actor=user
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Load the second file to source
|
||||
@ -1023,22 +1022,22 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
|
||||
source_id=source.id,
|
||||
file_path=str(test_file2),
|
||||
job_id=job2.id,
|
||||
actor=user,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Verify second job completed successfully
|
||||
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=user)
|
||||
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
|
||||
assert job2.status == "completed"
|
||||
assert job2.metadata_["num_passages"] >= 10
|
||||
assert job2.metadata_["num_documents"] == 1
|
||||
|
||||
# Verify passages were appended (not replaced)
|
||||
final_passage_count = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
|
||||
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
|
||||
assert final_passage_count > first_file_passage_count
|
||||
|
||||
# Verify both old and new content is searchable
|
||||
passages = server.passage_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
source_id=source.id,
|
||||
query_text="what does Timber like to eat",
|
||||
|
@ -33,7 +33,7 @@ def create_test_agent():
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id)
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
|
||||
|
||||
def test_summarize_messages_inplace(mock_e2b_api_key_none):
|
||||
@ -74,7 +74,7 @@ def test_summarize_messages_inplace(mock_e2b_api_key_none):
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
# reload agent object
|
||||
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id)
|
||||
agent_obj = client.server.load_agent(agent_id=agent_obj.agent_state.id, actor=client.user)
|
||||
|
||||
agent_obj.summarize_messages_inplace()
|
||||
print(f"Summarization succeeded: messages[1] = \n{agent_obj.messages[1]}")
|
||||
@ -121,7 +121,7 @@ def test_auto_summarize(mock_e2b_api_key_none):
|
||||
|
||||
# check if the summarize message is inside the messages
|
||||
assert isinstance(client, LocalClient), "Test only works with LocalClient"
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id)
|
||||
agent_obj = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
print("SUMMARY", summarize_message_exists(agent_obj._messages))
|
||||
if summarize_message_exists(agent_obj._messages):
|
||||
break
|
||||
|
@ -169,7 +169,7 @@ def configure_mock_sync_server(mock_sync_server):
|
||||
mock_sync_server.sandbox_config_manager.list_sandbox_env_vars_by_key.return_value = [mock_api_key]
|
||||
|
||||
# Mock user retrieval
|
||||
mock_sync_server.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
|
||||
mock_sync_server.user_manager.get_user_or_default.return_value = Mock() # Provide additional attributes if needed
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@ -182,7 +182,7 @@ def test_delete_tool(client, mock_sync_server, add_integers_tool):
|
||||
|
||||
assert response.status_code == 200
|
||||
mock_sync_server.tool_manager.delete_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
@ -195,7 +195,7 @@ def test_get_tool(client, mock_sync_server, add_integers_tool):
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
assert response.json()["source_code"] == add_integers_tool.source_code
|
||||
mock_sync_server.tool_manager.get_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.get_user_or_default.return_value
|
||||
tool_id=add_integers_tool.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
@ -216,7 +216,7 @@ def test_get_tool_id(client, mock_sync_server, add_integers_tool):
|
||||
assert response.status_code == 200
|
||||
assert response.json() == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.get_tool_by_name.assert_called_once_with(
|
||||
tool_name=add_integers_tool.name, actor=mock_sync_server.get_user_or_default.return_value
|
||||
tool_name=add_integers_tool.name, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
@ -268,7 +268,7 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
|
||||
assert response.status_code == 200
|
||||
assert response.json()["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.update_tool_by_id.assert_called_once_with(
|
||||
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.get_user_or_default.return_value
|
||||
tool_id=add_integers_tool.id, tool_update=update_integers_tool, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
@ -280,7 +280,9 @@ def test_add_base_tools(client, mock_sync_server, add_integers_tool):
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(actor=mock_sync_server.get_user_or_default.return_value)
|
||||
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_list_composio_apps(client, mock_sync_server, composio_apps):
|
||||
|
@ -1,42 +1,39 @@
|
||||
import numpy as np
|
||||
import sqlite3
|
||||
import base64
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
||||
import pytest
|
||||
from letta.orm.sqlalchemy_base import adapt_array
|
||||
from letta.orm.sqlite_functions import convert_array, verify_embedding_dimension
|
||||
|
||||
from letta.orm.sqlalchemy_base import adapt_array, convert_array
|
||||
from letta.orm.sqlite_functions import verify_embedding_dimension
|
||||
|
||||
def test_vector_conversions():
|
||||
"""Test the vector conversion functions"""
|
||||
# Create test data
|
||||
original = np.random.random(4096).astype(np.float32)
|
||||
print(f"Original shape: {original.shape}")
|
||||
|
||||
|
||||
# Test full conversion cycle
|
||||
encoded = adapt_array(original)
|
||||
print(f"Encoded type: {type(encoded)}")
|
||||
print(f"Encoded length: {len(encoded)}")
|
||||
|
||||
|
||||
decoded = convert_array(encoded)
|
||||
print(f"Decoded shape: {decoded.shape}")
|
||||
print(f"Dimension verification: {verify_embedding_dimension(decoded)}")
|
||||
|
||||
|
||||
# Verify data integrity
|
||||
np.testing.assert_array_almost_equal(original, decoded)
|
||||
print("✓ Data integrity verified")
|
||||
|
||||
|
||||
# Test with a list
|
||||
list_data = original.tolist()
|
||||
encoded_list = adapt_array(list_data)
|
||||
decoded_list = convert_array(encoded_list)
|
||||
np.testing.assert_array_almost_equal(original, decoded_list)
|
||||
print("✓ List conversion verified")
|
||||
|
||||
|
||||
# Test None handling
|
||||
assert adapt_array(None) is None
|
||||
assert convert_array(None) is None
|
||||
print("✓ None handling verified")
|
||||
|
||||
# Run the tests
|
||||
|
||||
# Run the tests
|
||||
|
Loading…
Reference in New Issue
Block a user