from datetime import datetime, timezone
from typing import Dict, List, Optional, Set, Tuple

import numpy as np
import sqlalchemy as sa
from sqlalchemy import Select, and_, delete, func, insert, literal, or_, select, union_all
from sqlalchemy.dialects.postgresql import insert as pg_insert

from letta.constants import (
    BASE_MEMORY_TOOLS,
    BASE_SLEEPTIME_CHAT_TOOLS,
    BASE_SLEEPTIME_TOOLS,
    BASE_TOOLS,
    DATA_SOURCE_ATTACH_ALERT,
    MAX_EMBEDDING_DIM,
    MULTI_AGENT_TOOLS,
)
from letta.embeddings import embedding_model
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage, AgentsTags
from letta.orm import Block as BlockModel
from letta.orm import BlocksAgents
from letta.orm import Group as GroupModel
from letta.orm import IdentitiesAgents
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm import ToolsAgents
from letta.orm.enums import ToolType
from letta.orm.errors import NoResultFound
from letta.orm.sandbox_config import AgentEnvironmentVariable
from letta.orm.sandbox_config import AgentEnvironmentVariable as AgentEnvironmentVariableModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent, get_prompt_template_for_agent_type
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.group import Group as PydanticGroup
from letta.schemas.group import ManagerType
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ContinueToolRule, TerminalToolRule
from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import MarshmallowAgentSchema
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
    _apply_filters,
    _apply_identity_filters,
    _apply_pagination,
    _apply_tag_filter,
    _process_relationship,
    check_supports_structured_output,
    compile_system_message,
    derive_system_message,
    initialize_message_sequence,
    package_initial_message_sequence,
)
from letta.services.identity_manager import IdentityManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.tracing import trace_method
from letta.utils import enforce_types, united_diff

logger = get_logger(__name__)


class AgentManager:
    """Manager class to handle business logic related to Agents."""

    def __init__(self):
        from letta.server.db import db_context

        self.session_maker = db_context
        self.block_manager = BlockManager()
        self.tool_manager = ToolManager()
        self.source_manager = SourceManager()
        self.message_manager = MessageManager()
        self.passage_manager = PassageManager()
        self.identity_manager = IdentityManager()

    @staticmethod
    def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
        """
        Bulk-fetch all ToolModel rows matching either name ∈ names or id ∈ ids
        (and scoped to this organization), and return two maps:
        name_to_id, id_to_name.
        Raises if any requested name or id was not found.
        """
        stmt = select(ToolModel.id, ToolModel.name).where(
            ToolModel.organization_id == org_id,
            or_(
                ToolModel.name.in_(names),
                ToolModel.id.in_(ids),
            ),
        )
        rows = session.execute(stmt).all()
        name_to_id = {name: tid for tid, name in rows}
        id_to_name = {tid: name for tid, name in rows}

        missing_names = names - set(name_to_id.keys())
        missing_ids = ids - set(id_to_name.keys())
        if missing_names:
            raise ValueError(f"Tools not found by name: {missing_names}")
        if missing_ids:
            raise ValueError(f"Tools not found by id: {missing_ids}")

        return name_to_id, id_to_name

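    # Illustrative sketch (editor's addition, not part of the original source; the
    # tool ids below are hypothetical): with "send_message" -> "tool-1" and
    # "web_search" -> "tool-2" registered in the org, resolving by mixed names/ids
    # returns both directions for *all* matched rows:
    #
    #   name_to_id, id_to_name = AgentManager._resolve_tools(
    #       session, names={"send_message"}, ids={"tool-2"}, org_id=actor.organization_id
    #   )
    #   # name_to_id == {"send_message": "tool-1", "web_search": "tool-2"}
    #   # id_to_name == {"tool-1": "send_message", "tool-2": "web_search"}
    #
    # Unknown names or ids raise ValueError rather than being silently dropped.
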
    @staticmethod
    @trace_method
    def _bulk_insert_pivot(session, table, rows: list[dict]):
        if not rows:
            return

        dialect = session.bind.dialect.name
        if dialect == "postgresql":
            stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
        elif dialect == "sqlite":
            stmt = sa.insert(table).values(rows).prefix_with("OR IGNORE")
        else:
            # fallback: filter out exact-duplicate dicts in Python
            seen = set()
            filtered = []
            for row in rows:
                key = tuple(sorted(row.items()))
                if key not in seen:
                    seen.add(key)
                    filtered.append(row)
            stmt = sa.insert(table).values(filtered)

        session.execute(stmt)

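    # Illustrative sketch (editor's addition): the statement emitted above is an
    # idempotent bulk insert, so re-running the same rows is a no-op instead of an
    # IntegrityError on the two supported dialects (hypothetical row shown):
    #
    #   rows = [{"agent_id": "agent-1", "tool_id": "tool-1"}]
    #   AgentManager._bulk_insert_pivot(session, ToolsAgents.__table__, rows)
    #   AgentManager._bulk_insert_pivot(session, ToolsAgents.__table__, rows)  # ignored
    #
    # postgresql: INSERT ... ON CONFLICT DO NOTHING
    # sqlite:     INSERT OR IGNORE ...
    # other:      plain INSERT after de-duplicating rows in Python, which may still
    #             raise if the rows already exist in the table
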
    @staticmethod
    @trace_method
    def _replace_pivot_rows(session, table, agent_id: str, rows: list[dict]):
        """
        Replace all pivot rows for an agent with *exactly* the provided list.
        Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
        """
        # delete all existing rows for this agent
        session.execute(delete(table).where(table.c.agent_id == agent_id))
        if rows:
            AgentManager._bulk_insert_pivot(session, table, rows)

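    # Illustrative sketch (editor's addition): _replace_pivot_rows makes the pivot
    # table match the given list exactly. Starting from tags {"a", "b"} on a
    # hypothetical agent "agent-1":
    #
    #   AgentManager._replace_pivot_rows(
    #       session, AgentsTags.__table__, "agent-1",
    #       [{"agent_id": "agent-1", "tag": "b"}, {"agent_id": "agent-1", "tag": "c"}],
    #   )
    #   # resulting tags for agent-1: {"b", "c"}  ("a" deleted, "c" inserted)
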
    # ======================================================================================================================
    # Basic CRUD operations
    # ======================================================================================================================
    @trace_method
    def create_agent(self, agent_create: CreateAgent, actor: PydanticUser) -> PydanticAgentState:
        # validate required configs
        if not agent_create.llm_config or not agent_create.embedding_config:
            raise ValueError("llm_config and embedding_config are required")

        # blocks
        block_ids = list(agent_create.block_ids or [])
        if agent_create.memory_blocks:
            pydantic_blocks = [PydanticBlock(**b.model_dump(to_orm=True)) for b in agent_create.memory_blocks]
            created_blocks = self.block_manager.batch_create_blocks(
                pydantic_blocks,
                actor=actor,
            )
            block_ids.extend([blk.id for blk in created_blocks])

        # tools
        tool_names = set(agent_create.tools or [])
        if agent_create.include_base_tools:
            if agent_create.agent_type == AgentType.sleeptime_agent:
                tool_names |= set(BASE_SLEEPTIME_TOOLS)
            elif agent_create.enable_sleeptime:
                tool_names |= set(BASE_SLEEPTIME_CHAT_TOOLS)
            else:
                tool_names |= set(BASE_TOOLS + BASE_MEMORY_TOOLS)
        if agent_create.include_multi_agent_tools:
            tool_names |= set(MULTI_AGENT_TOOLS)

        supplied_ids = set(agent_create.tool_ids or [])

        source_ids = agent_create.source_ids or []
        identity_ids = agent_create.identity_ids or []
        tag_values = agent_create.tags or []

        with self.session_maker() as session:
            with session.begin():
                name_to_id, id_to_name = self._resolve_tools(
                    session,
                    tool_names,
                    supplied_ids,
                    actor.organization_id,
                )

                tool_ids = set(name_to_id.values()) | set(id_to_name.keys())
                tool_names = set(name_to_id.keys())  # now canonical

                tool_rules = list(agent_create.tool_rules or [])
                if agent_create.include_base_tool_rules:
                    for tn in tool_names:
                        if tn in {"send_message", "send_message_to_agent_async", "memory_finish_edits"}:
                            tool_rules.append(TerminalToolRule(tool_name=tn))
                        elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_SLEEPTIME_TOOLS):
                            tool_rules.append(ContinueToolRule(tool_name=tn))

                if tool_rules:
                    check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)

                new_agent = AgentModel(
                    name=agent_create.name,
                    system=derive_system_message(
                        agent_type=agent_create.agent_type,
                        enable_sleeptime=agent_create.enable_sleeptime,
                        system=agent_create.system,
                    ),
                    agent_type=agent_create.agent_type,
                    llm_config=agent_create.llm_config,
                    embedding_config=agent_create.embedding_config,
                    organization_id=actor.organization_id,
                    description=agent_create.description,
                    metadata_=agent_create.metadata,
                    tool_rules=tool_rules,
                    project_id=agent_create.project_id,
                    template_id=agent_create.template_id,
                    base_template_id=agent_create.base_template_id,
                    message_buffer_autoclear=agent_create.message_buffer_autoclear,
                    enable_sleeptime=agent_create.enable_sleeptime,
                    created_by_id=actor.id,
                    last_updated_by_id=actor.id,
                )
                session.add(new_agent)
                session.flush()
                aid = new_agent.id

                self._bulk_insert_pivot(
                    session,
                    ToolsAgents.__table__,
                    [{"agent_id": aid, "tool_id": tid} for tid in tool_ids],
                )

                if block_ids:
                    rows = [
                        {"agent_id": aid, "block_id": bid, "block_label": lbl}
                        for bid, lbl in session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(block_ids))).all()
                    ]
                    self._bulk_insert_pivot(session, BlocksAgents.__table__, rows)

                self._bulk_insert_pivot(
                    session,
                    SourcesAgents.__table__,
                    [{"agent_id": aid, "source_id": sid} for sid in source_ids],
                )
                self._bulk_insert_pivot(
                    session,
                    AgentsTags.__table__,
                    [{"agent_id": aid, "tag": tag} for tag in tag_values],
                )
                self._bulk_insert_pivot(
                    session,
                    IdentitiesAgents.__table__,
                    [{"agent_id": aid, "identity_id": iid} for iid in identity_ids],
                )

                if agent_create.tool_exec_environment_variables:
                    env_rows = [
                        {
                            "agent_id": aid,
                            "key": key,
                            "value": val,
                            "organization_id": actor.organization_id,
                        }
                        for key, val in agent_create.tool_exec_environment_variables.items()
                    ]
                    session.execute(insert(AgentEnvironmentVariable).values(env_rows))

                # initial message sequence
                init_messages = self._generate_initial_message_sequence(
                    actor,
                    agent_state=new_agent.to_pydantic(include_relationships={"memory"}),
                    supplied_initial_message_sequence=agent_create.initial_message_sequence,
                )
                new_agent.message_ids = [msg.id for msg in init_messages]

            session.refresh(new_agent)

        self.message_manager.create_many_messages(pydantic_msgs=init_messages, actor=actor)
        return new_agent.to_pydantic()

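    # Illustrative usage sketch (editor's addition; the config objects and field
    # values are placeholders, see letta.schemas.agent.CreateAgent for the real
    # schema):
    #
    #   agent_state = AgentManager().create_agent(
    #       CreateAgent(
    #           name="my-agent",
    #           llm_config=llm_config,              # required
    #           embedding_config=embedding_config,  # required
    #           include_base_tools=True,            # resolves BASE_TOOLS + BASE_MEMORY_TOOLS
    #       ),
    #       actor=actor,
    #   )
    #
    # The agent row, pivot rows, and env vars are written in a single transaction;
    # the initial messages are persisted afterwards via the message manager.
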
    @enforce_types
    def _generate_initial_message_sequence(
        self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
    ) -> List[Message]:
        init_messages = initialize_message_sequence(
            agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
        )
        if supplied_initial_message_sequence is not None:
            # We always need the system prompt up front
            system_message_obj = PydanticMessage.dict_to_message(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                openai_message_dict=init_messages[0],
            )
            # Don't use anything else in the pregen sequence, instead use the provided sequence
            init_messages = [system_message_obj]
            init_messages.extend(
                package_initial_message_sequence(agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, actor)
            )
        else:
            init_messages = [
                PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
                for msg in init_messages
            ]

        return init_messages

    @enforce_types
    def append_initial_message_sequence_to_in_context_messages(
        self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
    ) -> PydanticAgentState:
        init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
        return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)

    @enforce_types
    def update_agent(
        self,
        agent_id: str,
        agent_update: UpdateAgent,
        actor: PydanticUser,
    ) -> PydanticAgentState:

        new_tools = set(agent_update.tool_ids or [])
        new_sources = set(agent_update.source_ids or [])
        new_blocks = set(agent_update.block_ids or [])
        new_idents = set(agent_update.identity_ids or [])
        new_tags = set(agent_update.tags or [])

        with self.session_maker() as session, session.begin():

            agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            agent.updated_at = datetime.now(timezone.utc)
            agent.last_updated_by_id = actor.id

            scalar_updates = {
                "name": agent_update.name,
                "system": agent_update.system,
                "llm_config": agent_update.llm_config,
                "embedding_config": agent_update.embedding_config,
                "message_ids": agent_update.message_ids,
                "tool_rules": agent_update.tool_rules,
                "description": agent_update.description,
                "project_id": agent_update.project_id,
                "template_id": agent_update.template_id,
                "base_template_id": agent_update.base_template_id,
                "message_buffer_autoclear": agent_update.message_buffer_autoclear,
                "enable_sleeptime": agent_update.enable_sleeptime,
                "response_format": agent_update.response_format,
            }
            for col, val in scalar_updates.items():
                if val is not None:
                    setattr(agent, col, val)

            if agent_update.metadata is not None:
                agent.metadata_ = agent_update.metadata

            aid = agent.id

            if agent_update.tool_ids is not None:
                self._replace_pivot_rows(
                    session,
                    ToolsAgents.__table__,
                    aid,
                    [{"agent_id": aid, "tool_id": tid} for tid in new_tools],
                )
                session.expire(agent, ["tools"])

            if agent_update.source_ids is not None:
                self._replace_pivot_rows(
                    session,
                    SourcesAgents.__table__,
                    aid,
                    [{"agent_id": aid, "source_id": sid} for sid in new_sources],
                )
                session.expire(agent, ["sources"])

            if agent_update.block_ids is not None:
                rows = []
                if new_blocks:
                    label_map = {
                        bid: lbl
                        for bid, lbl in session.execute(select(BlockModel.id, BlockModel.label).where(BlockModel.id.in_(new_blocks)))
                    }
                    rows = [{"agent_id": aid, "block_id": bid, "block_label": label_map[bid]} for bid in new_blocks]

                self._replace_pivot_rows(session, BlocksAgents.__table__, aid, rows)
                session.expire(agent, ["core_memory"])

            if agent_update.identity_ids is not None:
                self._replace_pivot_rows(
                    session,
                    IdentitiesAgents.__table__,
                    aid,
                    [{"agent_id": aid, "identity_id": iid} for iid in new_idents],
                )
                session.expire(agent, ["identities"])

            if agent_update.tags is not None:
                self._replace_pivot_rows(
                    session,
                    AgentsTags.__table__,
                    aid,
                    [{"agent_id": aid, "tag": tag} for tag in new_tags],
                )
                session.expire(agent, ["tags"])

            if agent_update.tool_exec_environment_variables is not None:
                session.execute(delete(AgentEnvironmentVariable).where(AgentEnvironmentVariable.agent_id == aid))
                env_rows = [
                    {
                        "agent_id": aid,
                        "key": k,
                        "value": v,
                        "organization_id": agent.organization_id,
                    }
                    for k, v in agent_update.tool_exec_environment_variables.items()
                ]
                if env_rows:
                    self._bulk_insert_pivot(session, AgentEnvironmentVariable.__table__, env_rows)
                session.expire(agent, ["tool_exec_environment_variables"])

            if agent_update.enable_sleeptime and agent_update.system is None:
                agent.system = derive_system_message(
                    agent_type=agent.agent_type,
                    enable_sleeptime=agent_update.enable_sleeptime,
                    system=agent.system,
                )

            session.flush()
            session.refresh(agent)

            return agent.to_pydantic()

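    # Illustrative sketch (editor's addition): UpdateAgent follows PATCH semantics.
    # Only non-None fields are applied; relationship lists are replaced wholesale:
    #
    #   # renames the agent, leaves every other column and relationship untouched
    #   manager.update_agent(agent_id, UpdateAgent(name="renamed"), actor=actor)
    #
    #   # replaces the agent's tags with exactly this set (an empty list clears them)
    #   manager.update_agent(agent_id, UpdateAgent(tags=["prod"]), actor=actor)
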
    # TODO: Make this general and think about how to roll this into sqlalchemybase
    def list_agents(
        self,
        actor: PydanticUser,
        name: Optional[str] = None,
        tags: Optional[List[str]] = None,
        match_all_tags: bool = False,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        query_text: Optional[str] = None,
        project_id: Optional[str] = None,
        template_id: Optional[str] = None,
        base_template_id: Optional[str] = None,
        identity_id: Optional[str] = None,
        identifier_keys: Optional[List[str]] = None,
        include_relationships: Optional[List[str]] = None,
        ascending: bool = True,
    ) -> List[PydanticAgentState]:
        """
        Retrieves agents with optimized filtering and optional field selection.

        Args:
            actor: The User requesting the list
            name (Optional[str]): Filter by agent name.
            tags (Optional[List[str]]): Filter agents by tags.
            match_all_tags (bool): If True, only return agents that match ALL given tags.
            before (Optional[str]): Cursor for pagination.
            after (Optional[str]): Cursor for pagination.
            limit (Optional[int]): Maximum number of agents to return.
            query_text (Optional[str]): Search agents by name.
            project_id (Optional[str]): Filter by project ID.
            template_id (Optional[str]): Filter by template ID.
            base_template_id (Optional[str]): Filter by base template ID.
            identity_id (Optional[str]): Filter by identifier ID.
            identifier_keys (Optional[List[str]]): Search agents by identifier keys.
            include_relationships (Optional[List[str]]): List of fields to load for performance optimization.
            ascending (bool): If True, sort agents oldest to newest by creation time; otherwise newest first.

        Returns:
            List[PydanticAgentState]: The filtered list of matching agents.
        """
        with self.session_maker() as session:
            query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
            query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)

            # Apply filters
            query = _apply_filters(query, name, query_text, project_id, template_id, base_template_id)
            query = _apply_identity_filters(query, identity_id, identifier_keys)
            query = _apply_tag_filter(query, tags, match_all_tags)
            query = _apply_pagination(query, before, after, session, ascending=ascending)

            if limit:
                query = query.limit(limit)

            agents = session.execute(query).scalars().all()
            return [agent.to_pydantic(include_relationships=include_relationships) for agent in agents]

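    # Illustrative sketch (editor's addition): keyset pagination via the
    # before/after cursors (agent ids), combined with tag filtering:
    #
    #   page1 = manager.list_agents(actor=actor, tags=["prod"], match_all_tags=True, limit=10)
    #   page2 = manager.list_agents(actor=actor, tags=["prod"], match_all_tags=True,
    #                               limit=10, after=page1[-1].id) if page1 else []
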
    @enforce_types
    def list_agents_matching_tags(
        self,
        actor: PydanticUser,
        match_all: List[str],
        match_some: List[str],
        limit: Optional[int] = 50,
    ) -> List[PydanticAgentState]:
        """
        Retrieves agents in the same organization that match all specified `match_all` tags
        and at least one tag from `match_some`. The query is optimized for efficiency by
        leveraging indexed filtering and aggregation.

        Args:
            actor (PydanticUser): The user requesting the agent list.
            match_all (List[str]): Agents must have all these tags.
            match_some (List[str]): Agents must have at least one of these tags.
            limit (Optional[int]): Maximum number of agents to return.

        Returns:
            List[PydanticAgentState]: The filtered list of matching agents.
        """
        with self.session_maker() as session:
            query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)

            if match_all:
                # Subquery to find agent IDs that contain all match_all tags
                subquery = (
                    select(AgentsTags.agent_id)
                    .where(AgentsTags.tag.in_(match_all))
                    .group_by(AgentsTags.agent_id)
                    .having(func.count(AgentsTags.tag) == literal(len(match_all)))
                )
                query = query.where(AgentModel.id.in_(subquery))

            if match_some:
                # Ensures agents match at least one tag in match_some
                query = query.join(AgentsTags).where(AgentsTags.tag.in_(match_some))

            query = query.distinct(AgentModel.id).order_by(AgentModel.id).limit(limit)

            return list(session.execute(query).scalars())

    @enforce_types
    def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its ID."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its name."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def delete_agent(self, agent_id: str, actor: PydanticUser) -> None:
        """
        Deletes an agent and its associated relationships.
        Ensures proper permission checks and cascades where applicable.

        Args:
            agent_id: ID of the agent to be deleted.
            actor: User performing the action.

        Raises:
            NoResultFound: If agent doesn't exist
        """
        with self.session_maker() as session:
            # Retrieve the agent
            logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            agents_to_delete = [agent]
            sleeptime_group_to_delete = None

            # Delete sleeptime agent and group
            if agent.multi_agent_group:
                participant_agent_ids = agent.multi_agent_group.agent_ids
                if agent.multi_agent_group.manager_type == ManagerType.sleeptime and len(participant_agent_ids) == 1:
                    sleeptime_agent = AgentModel.read(db_session=session, identifier=participant_agent_ids[0], actor=actor)
                    if sleeptime_agent.agent_type == AgentType.sleeptime_agent:
                        sleeptime_agent_group = GroupModel.read(db_session=session, identifier=agent.multi_agent_group.id, actor=actor)
                        sleeptime_group_to_delete = sleeptime_agent_group
                        agents_to_delete.append(sleeptime_agent)
            try:
                if sleeptime_group_to_delete is not None:
                    session.delete(sleeptime_group_to_delete)
                    session.commit()
                for agent in agents_to_delete:
                    session.delete(agent)
                session.commit()
            except Exception as e:
                session.rollback()
                logger.exception(f"Failed to hard delete Agent with ID {agent_id}")
                raise ValueError(f"Failed to hard delete Agent with ID {agent_id}: {e}")
            else:
                logger.debug(f"Agent with ID {agent_id} successfully hard deleted")

    @enforce_types
    def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            schema = MarshmallowAgentSchema(session=session, actor=actor)
            data = schema.dump(agent)
            return AgentSchema(**data)

    @enforce_types
    def deserialize(
        self,
        serialized_agent: AgentSchema,
        actor: PydanticUser,
        append_copy_suffix: bool = True,
        override_existing_tools: bool = True,
        project_id: Optional[str] = None,
        strip_messages: Optional[bool] = False,
    ) -> PydanticAgentState:
        serialized_agent_dict = serialized_agent.model_dump()
        tool_data_list = serialized_agent_dict.pop("tools", [])
        messages = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_MESSAGES, [])

        for msg in messages:
            msg[MarshmallowAgentSchema.FIELD_ID] = SerializedMessageSchema.generate_id()  # Generate new ID

        message_ids = []
        in_context_message_indices = serialized_agent_dict.pop(MarshmallowAgentSchema.FIELD_IN_CONTEXT_INDICES)
        for idx in in_context_message_indices:
            message_ids.append(messages[idx][MarshmallowAgentSchema.FIELD_ID])

        serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids

        with self.session_maker() as session:
            schema = MarshmallowAgentSchema(session=session, actor=actor)
            agent = schema.load(serialized_agent_dict, session=session)

            if append_copy_suffix:
                agent.name += "_copy"
            if project_id:
                agent.project_id = project_id

            if strip_messages:
                # we want to strip all but the first (system) message
                agent.message_ids = [agent.message_ids[0]]
            agent = agent.create(session, actor=actor)

            pydantic_agent = agent.to_pydantic()

            pyd_msgs = []
            message_schema = SerializedMessageSchema(session=session, actor=actor)

            for serialized_message in messages:
                pydantic_message = message_schema.load(serialized_message, session=session).to_pydantic()
                pydantic_message.agent_id = agent.id
                pyd_msgs.append(pydantic_message)
            self.message_manager.create_many_messages(pyd_msgs, actor=actor)

        # Need to do this separately as there's some fancy upsert logic that SqlAlchemy cannot handle
        for tool_data in tool_data_list:
            pydantic_tool = SerializedToolSchema(actor=actor).load(tool_data, transient=True).to_pydantic()

            existing_pydantic_tool = self.tool_manager.get_tool_by_name(pydantic_tool.name, actor=actor)
            if existing_pydantic_tool and (
                existing_pydantic_tool.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MULTI_AGENT_CORE, ToolType.LETTA_MEMORY_CORE}
                or not override_existing_tools
            ):
                pydantic_tool = existing_pydantic_tool
            else:
                pydantic_tool = self.tool_manager.create_or_update_tool(pydantic_tool, actor=actor)

            pydantic_agent = self.attach_tool(agent_id=pydantic_agent.id, tool_id=pydantic_tool.id, actor=actor)

        return pydantic_agent

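    # Illustrative sketch (editor's addition): serialize/deserialize form an
    # export/import round trip. Message ids are regenerated on import, so the
    # result is a new, independent agent:
    #
    #   snapshot = manager.serialize(agent_id=agent_state.id, actor=actor)
    #   copy = manager.deserialize(snapshot, actor=actor)   # name gets a "_copy" suffix
    #   bare = manager.deserialize(snapshot, actor=actor, append_copy_suffix=False,
    #                              strip_messages=True)     # keep only the system message
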
    # ======================================================================================================================
    # Per Agent Environment Variable Management
    # ======================================================================================================================
    @enforce_types
    def _set_environment_variables(
        self,
        agent_id: str,
        env_vars: Dict[str, str],
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """
        Adds or replaces the environment variables for the specified agent.

        Args:
            agent_id: The agent id.
            env_vars: A dictionary of environment variable key-value pairs.
            actor: The user performing the action.

        Returns:
            PydanticAgentState: The updated agent as a Pydantic model.
        """
        with self.session_maker() as session:
            # Retrieve the agent
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Fetch existing environment variables as a dictionary
            existing_vars = {var.key: var for var in agent.tool_exec_environment_variables}

            # Update or create environment variables
            updated_vars = []
            for key, value in env_vars.items():
                if key in existing_vars:
                    # Update existing variable
                    existing_vars[key].value = value
                    updated_vars.append(existing_vars[key])
                else:
                    # Create new variable
                    updated_vars.append(
                        AgentEnvironmentVariableModel(
                            key=key,
                            value=value,
                            agent_id=agent_id,
                            organization_id=actor.organization_id,
                            created_by_id=actor.id,
                            last_updated_by_id=actor.id,
                        )
                    )

            # Remove stale variables
            stale_keys = set(existing_vars) - set(env_vars)
            agent.tool_exec_environment_variables = [var for var in updated_vars if var.key not in stale_keys]

            # Update the agent in the database
            agent.update(session, actor=actor)

            # Return the updated agent state
            return agent.to_pydantic()

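    # Illustrative sketch (editor's addition): the env var map is replaced as a
    # whole, so keys missing from the new dict are removed. Starting from
    # {"A": "1", "B": "2"}:
    #
    #   manager._set_environment_variables(agent_id, {"A": "9", "C": "3"}, actor)
    #   # result: A updated to "9", C created, B deleted
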
    @enforce_types
    def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            if manager_type:
                return [group.to_pydantic() for group in agent.groups if group.manager_type == manager_type]
            return [group.to_pydantic() for group in agent.groups]

    # ======================================================================================================================
    # In Context Messages Management
    # ======================================================================================================================
    # TODO: There are several assumptions here that are not explicitly checked
    # TODO: 1) These message ids are valid
    # TODO: 2) These messages are ordered from oldest to newest
    # TODO: This can be fixed by having an actual relationship in the ORM for message_ids
    # TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
    @enforce_types
    def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
        return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)

    @enforce_types
    def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
        return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)

    # TODO: This is duplicated below
    # TODO: This is legacy code and should be cleaned up
    # TODO: A lot of the memory "compilation" should be offset to a separate class
    @enforce_types
    def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
        """Rebuilds the system message with the latest memory object and any shared memory block updates

        Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object

        Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
        """
        agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)

        curr_system_message = self.get_system_message(
            agent_id=agent_id, actor=actor
        )  # this is the system + memory bank, not just the system prompt
        curr_system_message_openai = curr_system_message.to_openai_dict()

        # note: we only update the system prompt if the core memory is changed
        # this means that the archival/recall memory statistics may be somewhat out of date
        curr_memory_str = agent_state.memory.compile()
        if curr_memory_str in curr_system_message_openai["content"] and not force:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
            )
            return agent_state

        # If the memory didn't update, we probably don't want to update the timestamp inside
        # For example, if we're doing a system prompt swap, this should probably be False
        if update_timestamp:
            memory_edit_timestamp = get_utc_time()
        else:
            # NOTE: a bit of a hack - we pull the timestamp from the message's created_at
            memory_edit_timestamp = curr_system_message.created_at

        num_messages = self.message_manager.size(actor=actor, agent_id=agent_id)
        num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id)

        # update memory (TODO: potentially update recall/archival stats separately)
        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10),
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
        if len(diff) > 0:  # there was a diff
            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            # Swap the system message out (only if there is a diff)
            message = PydanticMessage.dict_to_message(
                agent_id=agent_id,
                model=agent_state.llm_config.model,
                openai_message_dict={"role": "system", "content": new_system_message_str},
            )
            message = self.message_manager.update_message_by_id(
                message_id=curr_system_message.id,
                message_update=MessageUpdate(**message.model_dump()),
                actor=actor,
            )
            return self.set_in_context_messages(agent_id=agent_id, message_ids=agent_state.message_ids, actor=actor)
        else:
            return agent_state

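    # Illustrative sketch (editor's addition): how the two flags above interact.
    #
    #   manager.rebuild_system_prompt(agent_id, actor)               # no-op if memory unchanged
    #   manager.rebuild_system_prompt(agent_id, actor, force=True)   # always recompile
    #   manager.rebuild_system_prompt(agent_id, actor, force=True,
    #                                 update_timestamp=False)        # keep old "last edited" time
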
    @enforce_types
    def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
        return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)

    @enforce_types
    def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
        new_messages = [message_ids[0]] + message_ids[num:]  # 0 is system message
        return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)

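    # Illustrative sketch (editor's addition): with in-context ids
    # [sys, m1, m2, m3, m4], trim_older_in_context_messages(num=3, ...) keeps
    # [sys, m3, m4] - the system message plus message_ids[3:]. Note that `num` is
    # the slice index, not a count: num=3 drops only the two messages m1 and m2.
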
    @enforce_types
    def trim_all_in_context_messages_except_system(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
        # TODO: How do we know this?
        new_messages = [message_ids[0]]  # 0 is system message
        return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)

    @enforce_types
    def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
        new_messages = self.message_manager.create_many_messages(messages, actor=actor)
        message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
        return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)

    @enforce_types
    def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        messages = self.message_manager.create_many_messages(messages, actor=actor)
        message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
        message_ids += [m.id for m in messages]
        return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)

    @enforce_types
    def reset_messages(self, agent_id: str, actor: PydanticUser, add_default_initial_messages: bool = False) -> PydanticAgentState:
        """
        Removes all in-context messages for the specified agent by:
        1) Resetting the agent's message_ids list to empty and committing that update.
        2) Deleting all of the agent's messages via the message manager.
        3) Re-adding either the default initial message sequence or, otherwise, a fresh system message.

        This action is destructive and cannot be undone once committed.

        Args:
            agent_id (str): The ID of the agent whose messages will be reset.
            actor (PydanticUser): The user performing this action.
            add_default_initial_messages: If true, adds the default initial messages after resetting.

        Returns:
            PydanticAgentState: The updated agent state.
        """
        with self.session_maker() as session:
            # Retrieve the existing agent (will raise NoResultFound if invalid)
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Also clear out the message_ids field to keep in-context memory consistent
            agent.message_ids = []

            # Commit the update
            agent.update(db_session=session, actor=actor)

            agent_state = agent.to_pydantic()

        self.message_manager.delete_all_messages_for_agent(agent_id=agent_id, actor=actor)

        if add_default_initial_messages:
            return self.append_initial_message_sequence_to_in_context_messages(actor, agent_state)
        else:
            # We still want to always have a system message
            init_messages = initialize_message_sequence(
                agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
            )
            system_message = PydanticMessage.dict_to_message(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                openai_message_dict=init_messages[0],
            )
            return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor)

    # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version
    @enforce_types
    def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState:
        """
        Update internal memory object and system prompt if there have been modifications.

        Args:
            agent_id (str): ID of the agent whose memory to update.
            new_memory (Memory): the new memory object to compare to the current memory object
            actor (PydanticUser): The user performing the action.

        Returns:
            PydanticAgentState: The agent state, refreshed if the memory changed.
        """
        agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
        if agent_state.memory.compile() != new_memory.compile():
            # update the blocks (LRW) in the DB
            for label in agent_state.memory.list_block_labels():
                updated_value = new_memory.get_block(label).value
                if updated_value != agent_state.memory.get_block(label).value:
                    # update the block if it's changed
                    block_id = agent_state.memory.get_block(label).id
                    self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor)

            # refresh memory from DB (using block ids)
            agent_state.memory = Memory(
                blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()],
                prompt_template=get_prompt_template_for_agent_type(agent_state.agent_type),
            )

            # NOTE: don't do this since rebuilding the memory is handled at the start of the step
            # rebuild memory - this records the last edited timestamp of the memory
            # TODO: pass in update timestamp from block edit time
            agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor)

        return agent_state

    @enforce_types
    def refresh_memory(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
        block_ids = [b.id for b in agent_state.memory.blocks]
        if not block_ids:
            return agent_state

        agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=actor)
        return agent_state

    # ======================================================================================================================
    # Source Management
    # ======================================================================================================================
    @enforce_types
    def attach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
        """
        Attaches a source to an agent.

        Args:
            agent_id: ID of the agent to attach the source to
            source_id: ID of the source to attach
            actor: User performing the action

        Raises:
            ValueError: If either agent or source doesn't exist
            IntegrityError: If the source is already attached to the agent
        """

        with self.session_maker() as session:
            # Verify both agent and source exist and user has permission to access them
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # The _process_relationship helper already handles duplicate checking via unique constraint
            _process_relationship(
                session=session,
                agent=agent,
                relationship_name="sources",
                model_class=SourceModel,
                item_ids=[source_id],
                allow_partial=False,
                replace=False,  # Extend existing sources rather than replace
            )

            # Commit the changes
            agent.update(session, actor=actor)

            # Force rebuild of system prompt so that the agent is updated with passage count
            # and recent passages and add system message alert to agent
            self.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True)
            self.append_system_message(
                agent_id=agent_id,
                content=DATA_SOURCE_ATTACH_ALERT,
                actor=actor,
            )

            return agent.to_pydantic()

    @enforce_types
    def append_system_message(self, agent_id: str, content: str, actor: PydanticUser):

        # get the agent
        agent = self.get_agent_by_id(agent_id=agent_id, actor=actor)
        message = PydanticMessage.dict_to_message(
            agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
        )

        # update agent in-context message IDs
        self.append_to_in_context_messages(messages=[message], agent_id=agent_id, actor=actor)

    @enforce_types
    def list_attached_sources(self, agent_id: str, actor: PydanticUser) -> List[PydanticSource]:
        """
        Lists all sources attached to an agent.

        Args:
            agent_id: ID of the agent to list sources for
            actor: User performing the action

        Returns:
            List[PydanticSource]: List of sources attached to the agent
        """
        with self.session_maker() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Use the lazy-loaded relationship to get sources
            return [source.to_pydantic() for source in agent.sources]

    @enforce_types
    def detach_source(self, agent_id: str, source_id: str, actor: PydanticUser) -> PydanticAgentState:
        """
        Detaches a source from an agent.

        Args:
            agent_id: ID of the agent to detach the source from
            source_id: ID of the source to detach
            actor: User performing the action
        """
        with self.session_maker() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Remove the source from the relationship
            remaining_sources = [s for s in agent.sources if s.id != source_id]

            if len(remaining_sources) == len(agent.sources):  # Source ID was not in the relationship
                logger.warning(f"Attempted to remove unattached source id={source_id} from agent id={agent_id} by actor={actor}")

            # Update the sources relationship
            agent.sources = remaining_sources

            # Commit the changes
            agent.update(session, actor=actor)
            return agent.to_pydantic()

    # ======================================================================================================================
    # Block management
    # ======================================================================================================================
    @enforce_types
    def get_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        actor: PydanticUser,
    ) -> PydanticBlock:
        """Gets a block attached to an agent by its label."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            for block in agent.core_memory:
                if block.label == block_label:
                    return block.to_pydantic()
            raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}'")

    @enforce_types
    def update_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        new_block_id: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Updates which block is assigned to a specific label for an agent."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)

            if new_block.label != block_label:
                raise ValueError(f"New block label '{new_block.label}' doesn't match required label '{block_label}'")

            # Remove old block with this label if it exists
            agent.core_memory = [b for b in agent.core_memory if b.label != block_label]

            # Add new block
            agent.core_memory.append(new_block)
            agent.update(session, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Attaches a block to an agent."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

            agent.core_memory.append(block)
            agent.update(session, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def detach_block(
        self,
        agent_id: str,
        block_id: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block from an agent."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            original_length = len(agent.core_memory)

            agent.core_memory = [b for b in agent.core_memory if b.id != block_id]

            if len(agent.core_memory) == original_length:
                raise NoResultFound(f"No block with id '{block_id}' found for agent '{agent_id}' with actor id: '{actor.id}'")

            agent.update(session, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def detach_block_with_label(
        self,
        agent_id: str,
        block_label: str,
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block with the specified label from an agent."""
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            original_length = len(agent.core_memory)

            agent.core_memory = [b for b in agent.core_memory if b.label != block_label]

            if len(agent.core_memory) == original_length:
                raise NoResultFound(f"No block with label '{block_label}' found for agent '{agent_id}' with actor id: '{actor.id}'")

            agent.update(session, actor=actor)
            return agent.to_pydantic()

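    # Illustrative sketch (editor's addition): blocks are addressed either by id or
    # by their per-agent label:
    #
    #   persona = manager.get_block_with_label(agent_id, "persona", actor)
    #   manager.update_block_with_label(agent_id, "persona", new_block_id, actor)  # swap by label
    #   manager.detach_block(agent_id=agent_id, block_id=persona.id, actor=actor)  # detach by id
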
    # ======================================================================================================================
    # Passage Management
    # ======================================================================================================================
    def _build_passage_query(
        self,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        file_id: Optional[str] = None,
        query_text: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        source_id: Optional[str] = None,
        embed_query: bool = False,
        ascending: bool = True,
        embedding_config: Optional[EmbeddingConfig] = None,
        agent_only: bool = False,
    ) -> Select:
        """Helper function to build the base passage query with all filters applied.
        Supports both before and after pagination across merged source and agent passages.

        Returns the query before any limit or count operations are applied.
        """
        embedded_text = None
        if embed_query:
            assert embedding_config is not None, "embedding_config must be specified for vector search"
            assert query_text is not None, "query_text must be specified for vector search"
            embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
            embedded_text = np.array(embedded_text)
            embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

        with self.session_maker() as session:
            # Start with base query for source passages
            source_passages = None
            if not agent_only:  # Include source passages
                if agent_id is not None:
                    source_passages = (
                        select(SourcePassage, literal(None).label("agent_id"))
                        .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
                        .where(SourcesAgents.agent_id == agent_id)
                        .where(SourcePassage.organization_id == actor.organization_id)
                    )
                else:
                    source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
                        SourcePassage.organization_id == actor.organization_id
                    )

                if source_id:
                    source_passages = source_passages.where(SourcePassage.source_id == source_id)
                if file_id:
                    source_passages = source_passages.where(SourcePassage.file_id == file_id)

            # Add agent passages query
            agent_passages = None
            if agent_id is not None:
                agent_passages = (
                    select(
                        AgentPassage.id,
                        AgentPassage.text,
                        AgentPassage.embedding_config,
                        AgentPassage.metadata_,
                        AgentPassage.embedding,
                        AgentPassage.created_at,
                        AgentPassage.updated_at,
                        AgentPassage.is_deleted,
                        AgentPassage._created_by_id,
                        AgentPassage._last_updated_by_id,
                        AgentPassage.organization_id,
                        literal(None).label("file_id"),
                        literal(None).label("source_id"),
                        AgentPassage.agent_id,
                    )
                    .where(AgentPassage.agent_id == agent_id)
                    .where(AgentPassage.organization_id == actor.organization_id)
                )

            # Combine queries
            if source_passages is not None and agent_passages is not None:
                combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
            elif agent_passages is not None:
                combined_query = agent_passages.cte("combined_passages")
            elif source_passages is not None:
                combined_query = source_passages.cte("combined_passages")
            else:
                raise ValueError("No passages found")

            # Build main query from combined CTE
            main_query = select(combined_query)

            # Apply filters
            if start_date:
                main_query = main_query.where(combined_query.c.created_at >= start_date)
            if end_date:
                main_query = main_query.where(combined_query.c.created_at <= end_date)
            if source_id:
                main_query = main_query.where(combined_query.c.source_id == source_id)
            if file_id:
                main_query = main_query.where(combined_query.c.file_id == file_id)

            # Vector search
            if embedded_text:
                if settings.letta_pg_uri_no_default:
                    # PostgreSQL with pgvector
                    main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
                else:
                    # SQLite with custom vector type
                    query_embedding_binary = adapt_array(embedded_text)
                    main_query = main_query.order_by(
                        func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
                        combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc(),
                        combined_query.c.id.asc(),
                    )
            else:
                if query_text:
                    main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))

            # Handle pagination
            if before or after:
                # Create reference CTEs
                if before:
                    before_ref = (
                        select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == before).cte("before_ref")
                    )
                if after:
                    after_ref = (
                        select(combined_query.c.created_at, combined_query.c.id).where(combined_query.c.id == after).cte("after_ref")
                    )

                if before and after:
                    # Window-based query (get records between before and after)
                    main_query = main_query.where(
                        or_(
                            combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(),
                            and_(
                                combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(),
                                combined_query.c.id < select(before_ref.c.id).scalar_subquery(),
                            ),
                        )
                    )
                    main_query = main_query.where(
                        or_(
                            combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(),
                            and_(
                                combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(),
                                combined_query.c.id > select(after_ref.c.id).scalar_subquery(),
                            ),
                        )
                    )
                else:
                    # Pure pagination (only before or only after)
                    if before:
                        main_query = main_query.where(
                            or_(
                                combined_query.c.created_at < select(before_ref.c.created_at).scalar_subquery(),
                                and_(
                                    combined_query.c.created_at == select(before_ref.c.created_at).scalar_subquery(),
                                    combined_query.c.id < select(before_ref.c.id).scalar_subquery(),
                                ),
                            )
                        )
                    if after:
                        main_query = main_query.where(
                            or_(
                                combined_query.c.created_at > select(after_ref.c.created_at).scalar_subquery(),
                                and_(
                                    combined_query.c.created_at == select(after_ref.c.created_at).scalar_subquery(),
                                    combined_query.c.id > select(after_ref.c.id).scalar_subquery(),
                                ),
                            )
                        )

            # Add ordering if not already ordered by similarity
            if not embed_query:
                if ascending:
                    main_query = main_query.order_by(
                        combined_query.c.created_at.asc(),
                        combined_query.c.id.asc(),
                    )
                else:
                    main_query = main_query.order_by(
                        combined_query.c.created_at.desc(),
                        combined_query.c.id.asc(),
                    )

            return main_query

    @enforce_types
    def list_passages(
        self,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        file_id: Optional[str] = None,
        limit: Optional[int] = 50,
        query_text: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        source_id: Optional[str] = None,
        embed_query: bool = False,
        ascending: bool = True,
        embedding_config: Optional[EmbeddingConfig] = None,
        agent_only: bool = False,
    ) -> List[PydanticPassage]:
        """Lists all passages attached to an agent."""
        with self.session_maker() as session:
            main_query = self._build_passage_query(
                actor=actor,
                agent_id=agent_id,
                file_id=file_id,
                query_text=query_text,
                start_date=start_date,
                end_date=end_date,
                before=before,
                after=after,
                source_id=source_id,
                embed_query=embed_query,
                ascending=ascending,
                embedding_config=embedding_config,
                agent_only=agent_only,
            )

            # Add limit
            if limit:
                main_query = main_query.limit(limit)

            # Execute query
            results = list(session.execute(main_query))

            passages = []
            for row in results:
                data = dict(row._mapping)
                if data["agent_id"] is not None:
                    # This is an AgentPassage - remove source fields
                    data.pop("source_id", None)
                    data.pop("file_id", None)
                    passage = AgentPassage(**data)
                else:
                    # This is a SourcePassage - remove agent field
                    data.pop("agent_id", None)
                    passage = SourcePassage(**data)
                passages.append(passage)

            return [p.to_pydantic() for p in passages]

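    # Illustrative sketch (editor's addition): semantic search over both archival
    # (agent) and source passages requires the embedding config used at write time:
    #
    #   hits = manager.list_passages(
    #       actor=actor,
    #       agent_id=agent_id,
    #       query_text="user's favorite color",
    #       embed_query=True,
    #       embedding_config=agent_state.embedding_config,
    #       limit=5,
    #   )
    #
    # Without embed_query, the same call degrades to a case-insensitive substring
    # filter on the passage text, ordered by created_at.
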
    @enforce_types
    def passage_size(
        self,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        file_id: Optional[str] = None,
        query_text: Optional[str] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        source_id: Optional[str] = None,
        embed_query: bool = False,
        ascending: bool = True,
        embedding_config: Optional[EmbeddingConfig] = None,
        agent_only: bool = False,
    ) -> int:
        """Returns the count of passages matching the given criteria."""
        with self.session_maker() as session:
            main_query = self._build_passage_query(
                actor=actor,
                agent_id=agent_id,
                file_id=file_id,
                query_text=query_text,
                start_date=start_date,
                end_date=end_date,
                before=before,
                after=after,
                source_id=source_id,
                embed_query=embed_query,
                ascending=ascending,
                embedding_config=embedding_config,
                agent_only=agent_only,
            )

            # Convert to count query
            count_query = select(func.count()).select_from(main_query.subquery())
            return session.scalar(count_query) or 0

    # ======================================================================================================================
    # Tool Management
    # ======================================================================================================================
    @enforce_types
    def attach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
        """
        Attaches a tool to an agent.

        Args:
            agent_id: ID of the agent to attach the tool to.
            tool_id: ID of the tool to attach.
            actor: User performing the action.

        Raises:
            NoResultFound: If the agent or tool is not found.

        Returns:
            PydanticAgentState: The updated agent state.
        """
        with self.session_maker() as session:
            # Verify the agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Use the _process_relationship helper to attach the tool
            _process_relationship(
                session=session,
                agent=agent,
                relationship_name="tools",
                model_class=ToolModel,
                item_ids=[tool_id],
                allow_partial=False,  # Ensure the tool exists
                replace=False,  # Extend the existing tools
            )

            # Commit and refresh the agent
            agent.update(session, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def detach_tool(self, agent_id: str, tool_id: str, actor: PydanticUser) -> PydanticAgentState:
        """
        Detaches a tool from an agent.

        Args:
            agent_id: ID of the agent to detach the tool from.
            tool_id: ID of the tool to detach.
            actor: User performing the action.

        Raises:
            NoResultFound: If the agent or tool is not found.

        Returns:
            PydanticAgentState: The updated agent state.
        """
        with self.session_maker() as session:
            # Verify the agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

            # Filter out the tool to be detached
            remaining_tools = [tool for tool in agent.tools if tool.id != tool_id]

            if len(remaining_tools) == len(agent.tools):  # Tool ID was not in the relationship
                logger.warning(f"Attempted to remove unattached tool id={tool_id} from agent id={agent_id} by actor={actor}")

            # Update the tools relationship
            agent.tools = remaining_tools

            # Commit and refresh the agent
            agent.update(session, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def list_attached_tools(self, agent_id: str, actor: PydanticUser) -> List[PydanticTool]:
        """
        List all tools attached to an agent.

        Args:
            agent_id: ID of the agent to list tools for.
            actor: User performing the action.

        Returns:
            List[PydanticTool]: List of tools attached to the agent.
        """
        with self.session_maker() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            return [tool.to_pydantic() for tool in agent.tools]

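    # Illustrative sketch (editor's addition): attach is strict (a missing tool
    # raises via _process_relationship with allow_partial=False), while detach is
    # lenient (an unattached tool id only logs a warning):
    #
    #   manager.attach_tool(agent_id=agent_id, tool_id=tool.id, actor=actor)
    #   manager.detach_tool(agent_id=agent_id, tool_id=tool.id, actor=actor)
    #   assert tool.id not in [t.id for t in manager.list_attached_tools(agent_id, actor)]
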
    # ======================================================================================================================
    # Tag Management
    # ======================================================================================================================
    @enforce_types
    def list_tags(
        self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None
    ) -> List[str]:
        """
        Get all tags a user has created, ordered alphabetically.

        Args:
            actor: User performing the action.
            after: Cursor for forward pagination.
            limit: Maximum number of tags to return.
            query_text: Query text to filter tags by.

        Returns:
            List[str]: List of all tags.
        """
        with self.session_maker() as session:
            query = (
                session.query(AgentsTags.tag)
                .join(AgentModel, AgentModel.id == AgentsTags.agent_id)
                .filter(AgentModel.organization_id == actor.organization_id)
                .distinct()
            )

            if query_text:
                query = query.filter(AgentsTags.tag.ilike(f"%{query_text}%"))

            if after:
                query = query.filter(AgentsTags.tag > after)

            query = query.order_by(AgentsTags.tag).limit(limit)
            results = [tag[0] for tag in query.all()]
            return results

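    # Illustrative sketch (editor's addition): list_tags paginates with a keyset
    # cursor on the tag value itself, since tags are returned alphabetically:
    #
    #   first = manager.list_tags(actor=actor, limit=100)
    #   rest = manager.list_tags(actor=actor, limit=100, after=first[-1]) if first else []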