feat: async db client (#2076)

This commit is contained in:
Andy Li 2025-05-12 17:15:14 -07:00 committed by GitHub
parent 845005451f
commit e85c558ddc
29 changed files with 723 additions and 466 deletions

View File

@ -3,14 +3,21 @@ from typing import Any, AsyncGenerator, List, Optional, Union
import openai
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.utils import united_diff
logger = get_logger(__name__)
class BaseAgent(ABC):
@ -64,3 +71,107 @@ class BaseAgent(ABC):
return ""
return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]
def _rebuild_memory(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None, # storing these calculations is specific to the voice agent
num_archival_memories: int | None = None,
) -> List[Message]:
try:
# Refresh Memory
# TODO: This only happens for the summary block (voice?)
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
# [DB Call] size of messages and archival memories
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# [DB Call] Update Messages
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
async def _rebuild_memory_async(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
"""
Async version of function above. For now before breaking up components, changes should be made in both places.
"""
try:
# [DB Call] loading blocks (modifies: agent_state.memory.blocks)
await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
# [DB Call] size of messages and archival memories
# todo: blocking for now
num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# [DB Call] Update Messages
new_system_message = self.message_manager.update_message_by_id_async(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise

View File

@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method
from letta.utils import united_diff
@ -171,6 +172,7 @@ class LettaAgent(BaseAgent):
yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"
@trace_method
# When raising an error this doesn't show up
async def _get_ai_reply(
self,
llm_client: LLMClientBase,
@ -179,7 +181,10 @@ class LettaAgent(BaseAgent):
tool_rules_solver: ToolRulesSolver,
stream: bool,
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
if settings.experimental_enable_async_db_engine:
in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
else:
in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
tools = [
t
@ -296,51 +301,6 @@ class LettaAgent(BaseAgent):
return persisted_messages, continue_stepping
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
try:
self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
except:
logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
raise
@trace_method
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
"""

View File

@ -1,11 +1,12 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union
from aiomultiprocess import Pool
from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult
from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _prepare_in_context_messages
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
@ -16,11 +17,12 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageStreamStatus, ProviderType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.letta_response import LettaBatchResponse
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
from letta.schemas.llm_batch_job import LLMBatchItem
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall
@ -95,7 +97,7 @@ async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[
# TODO: Limitations ->
# TODO: Only works with anthropic for now
class LettaAgentBatch:
class LettaAgentBatch(BaseAgent):
def __init__(
self,
@ -539,43 +541,20 @@ class LettaAgentBatch:
return in_context_messages
# TODO: Make this a bullk function
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
agent_state = self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
def _rebuild_memory(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory(in_context_messages, agent_state)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
# Not used in batch.
async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
raise NotImplementedError
memory_edit_timestamp = get_utc_time()
num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
async def step_stream(
self, input_messages: List[MessageCreate], max_steps: int = 10
) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
raise NotImplementedError

View File

@ -293,48 +293,17 @@ class VoiceAgent(BaseAgent):
agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
)
def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
# Refresh memory
# TODO: This only happens for the summary block
# TODO: We want to extend this refresh to be general, and stick it in agent_manager
block_ids = [block.id for block in agent_state.memory.blocks]
agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=self.actor)
# TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
curr_system_message = in_context_messages[0]
curr_memory_str = agent_state.memory.compile()
curr_system_message_text = curr_system_message.content[0].text
if curr_memory_str in curr_system_message_text:
# NOTE: could this cause issues if a block is removed? (substring match would still work)
logger.debug(
f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
)
return in_context_messages
memory_edit_timestamp = get_utc_time()
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=self.num_messages,
archival_memory_size=self.num_archival_memories,
def _rebuild_memory(
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None,
num_archival_memories: int | None = None,
) -> List[Message]:
return super()._rebuild_memory(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
)
diff = united_diff(curr_system_message_text, new_system_message_str)
if len(diff) > 0:
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
new_system_message = self.message_manager.update_message_by_id(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
)
# Skip pulling down the agent's memory again to save on a db call
return [new_system_message] + in_context_messages[1:]
else:
return in_context_messages
def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
tool_schemas = self._build_tool_schemas(agent_state)
tool_choice = "auto" if tool_schemas else None

View File

@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
from sqlalchemy import String, and_, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, Session, mapped_column
from letta.log import get_logger
@ -300,6 +301,44 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
return found[0]
@classmethod
@handle_db_timeout
async def read_async(
cls,
db_session: "Session",
identifier: Optional[str] = None,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
**kwargs,
) -> "SqlalchemyBase":
"""The primary accessor for an ORM record. Async version of read method.
Args:
db_session: the database session to use when retrieving the record
identifier: the identifier of the record to read, can be the id string or the UUID object for backwards compatibility
actor: if specified, results will be scoped only to records the user is able to access
access: if actor is specified, records will be filtered to the minimum permission level for the actor
kwargs: additional arguments to pass to the read, used for more complex objects
Returns:
The matching object
Raises:
NoResultFound: if the object is not found
"""
# this is ok because read_multiple will check if the
identifiers = [] if identifier is None else [identifier]
found = await cls.read_multiple_async(db_session, identifiers, actor, access, access_type, **kwargs)
if len(found) == 0:
# for backwards compatibility.
conditions = []
if identifier:
conditions.append(f"id={identifier}")
if actor:
conditions.append(f"access level in {access} for {actor}")
if hasattr(cls, "is_deleted"):
conditions.append("is_deleted=False")
raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
return found[0]
@classmethod
@handle_db_timeout
def read_multiple(
@ -323,6 +362,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
Raises:
NoResultFound: if the object is not found
"""
query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs)
results = db_session.execute(query).scalars().all()
return cls._read_multiple_postprocess(results, identifiers, query_conditions)
@classmethod
@handle_db_timeout
async def read_multiple_async(
cls,
db_session: "AsyncSession",
identifiers: List[str] = [],
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
**kwargs,
) -> List["SqlalchemyBase"]:
"""
Async version of read_multiple(...)
The primary accessor for ORM record(s)
"""
query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs)
results = await db_session.execute(query)
return cls._read_multiple_postprocess(results.scalars().all(), identifiers, query_conditions)
@classmethod
def _read_multiple_preprocess(
cls,
identifiers: List[str],
actor: Optional["User"],
access: Optional[List[Literal["read", "write", "admin"]]],
access_type: AccessType,
**kwargs,
):
logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")
# Start the query
@ -350,7 +421,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query = query.where(cls.is_deleted == False)
query_conditions.append("is_deleted=False")
results = db_session.execute(query).scalars().all()
return query, query_conditions
@classmethod
def _read_multiple_postprocess(cls, results, identifiers: List[str], query_conditions) -> List["SqlalchemyBase"]:
if results: # if empty list a.k.a. no results
if len(identifiers) > 0:
# find which identifiers were not found
@ -471,6 +545,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
db_session.refresh(self)
return self
@handle_db_timeout
async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase":
"""Async version of update function"""
logger.debug(...)
if actor:
self._set_created_and_updated_by_fields(actor.id)
self.set_updated_at()
db_session.add(self)
if no_commit:
await db_session.flush()
else:
await db_session.commit()
await db_session.refresh(self)
return self
@classmethod
@handle_db_timeout
def size(

View File

@ -1,6 +1,7 @@
from typing import Dict
from marshmallow import fields, post_dump, pre_load
from sqlalchemy.orm import sessionmaker
import letta
from letta.orm import Agent
@ -14,7 +15,6 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.server.db import SessionLocal
class MarshmallowAgentSchema(BaseSchema):
@ -41,7 +41,7 @@ class MarshmallowAgentSchema(BaseSchema):
tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
tags = fields.List(fields.Nested(SerializedAgentTagSchema))
def __init__(self, *args, session: SessionLocal, actor: User, **kwargs):
def __init__(self, *args, session: sessionmaker, actor: User, **kwargs):
super().__init__(*args, actor=actor, **kwargs)
self.session = session
@ -60,9 +60,9 @@ class MarshmallowAgentSchema(BaseSchema):
After dumping the agent, load all its Message rows and serialize them here.
"""
# TODO: This is hacky, but want to move fast, please refactor moving forward
from letta.server.db import db_context as session_maker
from letta.server.db import db_registry
with session_maker() as session:
with db_registry.session() as session:
agent_id = data.get("id")
msgs = (
session.query(MessageModel)

View File

@ -1,28 +1,19 @@
import os
import threading
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.orm import Base
from letta.settings import settings
# Use globals for the lock and initialization flag
_engine_lock = threading.Lock()
_engine_initialized = False
# Create variables in global scope but don't initialize them yet
config = LettaConfig.load()
logger = get_logger(__name__)
engine = None
SessionLocal = None
def print_sqlite_schema_error():
"""Print a formatted error message for SQLite schema issues"""
@ -54,86 +45,187 @@ def db_error_handler():
exit(1)
def initialize_engine():
"""Initialize the database engine only when needed."""
global engine, SessionLocal, _engine_initialized
class DatabaseRegistry:
"""Registry for database connections and sessions.
with _engine_lock:
# Check again inside the lock to prevent race conditions
if _engine_initialized:
return
This class manages both synchronous and asynchronous database connections
and provides context managers for session handling.
"""
if settings.letta_pg_uri_no_default:
logger.info("Creating postgres engine")
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
config.archival_storage_type = "postgres"
config.archival_storage_uri = settings.letta_pg_uri_no_default
def __init__(self):
self._engines: dict[str, Engine] = {}
self._async_engines: dict[str, AsyncEngine] = {}
self._session_factories: dict[str, sessionmaker] = {}
self._async_session_factories: dict[str, async_sessionmaker] = {}
self._initialized: dict[str, bool] = {"sync": False, "async": False}
self._lock = threading.Lock()
self.config = LettaConfig.load()
self.logger = get_logger(__name__)
# create engine
engine = create_engine(
settings.letta_pg_uri,
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
# connect_args={"client_encoding": "utf8"},
)
else:
# TODO: don't rely on config storage
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
logger.info("Creating sqlite engine " + engine_path)
def initialize_sync(self, force: bool = False) -> None:
"""Initialize the synchronous database engine if not already initialized."""
with self._lock:
if self._initialized.get("sync") and not force:
return
engine = create_engine(engine_path)
# Postgres engine
if settings.letta_pg_uri_no_default:
self.logger.info("Creating postgres engine")
self.config.recall_storage_type = "postgres"
self.config.recall_storage_uri = settings.letta_pg_uri_no_default
self.config.archival_storage_type = "postgres"
self.config.archival_storage_uri = settings.letta_pg_uri_no_default
# Store the original connect method
original_connect = engine.connect
engine = create_engine(
settings.letta_pg_uri,
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
# connect_args={"client_encoding": "utf8"},
)
def wrapped_connect(*args, **kwargs):
with db_error_handler():
# Get the connection
connection = original_connect(*args, **kwargs)
self._engines["default"] = engine
# SQLite engine
else:
from letta.orm import Base
# Store the original execution method
original_execute = connection.execute
# TODO: don't rely on config storage
engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db")
self.logger.info("Creating sqlite engine " + engine_path)
# Wrap the execute method of the connection
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
engine = create_engine(engine_path)
# Replace the connection's execute method
connection.execute = wrapped_execute
# Wrap the engine with error handling
self._wrap_sqlite_engine(engine)
return connection
Base.metadata.create_all(bind=engine)
self._engines["default"] = engine
# Replace the engine's connect method
engine.connect = wrapped_connect
# Create session factory
self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"])
self._initialized["sync"] = True
Base.metadata.create_all(bind=engine)
def initialize_async(self, force: bool = False) -> None:
"""Initialize the asynchronous database engine if not already initialized."""
with self._lock:
if self._initialized.get("async") and not force:
return
# Create the session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
_engine_initialized = True
if settings.letta_pg_uri_no_default:
self.logger.info("Creating async postgres engine")
# Create async engine - convert URI to async format
pg_uri = settings.letta_pg_uri
if pg_uri.startswith("postgresql://"):
async_pg_uri = pg_uri.replace("postgresql://", "postgresql+asyncpg://")
else:
async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri
async_engine = create_async_engine(
async_pg_uri,
pool_size=settings.pg_pool_size,
max_overflow=settings.pg_max_overflow,
pool_timeout=settings.pg_pool_timeout,
pool_recycle=settings.pg_pool_recycle,
echo=settings.pg_echo,
)
self._async_engines["default"] = async_engine
# Create async session factory
self._async_session_factories["default"] = async_sessionmaker(
autocommit=False, autoflush=False, bind=self._async_engines["default"], class_=AsyncSession
)
self._initialized["async"] = True
else:
self.logger.warning("Async SQLite is currently not supported. Please use PostgreSQL for async database operations.")
# TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this
self._initialized["async"] = False
def _wrap_sqlite_engine(self, engine: Engine) -> None:
"""Wrap SQLite engine with error handling."""
original_connect = engine.connect
def wrapped_connect(*args, **kwargs):
with db_error_handler():
connection = original_connect(*args, **kwargs)
original_execute = connection.execute
def wrapped_execute(*args, **kwargs):
with db_error_handler():
return original_execute(*args, **kwargs)
connection.execute = wrapped_execute
return connection
engine.connect = wrapped_connect
def get_engine(self, name: str = "default") -> Engine:
"""Get a database engine by name."""
self.initialize_sync()
return self._engines.get(name)
def get_async_engine(self, name: str = "default") -> AsyncEngine:
"""Get an async database engine by name."""
self.initialize_async()
return self._async_engines.get(name)
def get_session_factory(self, name: str = "default") -> sessionmaker:
"""Get a session factory by name."""
self.initialize_sync()
return self._session_factories.get(name)
def get_async_session_factory(self, name: str = "default") -> async_sessionmaker:
"""Get an async session factory by name."""
self.initialize_async()
return self._async_session_factories.get(name)
@contextmanager
def session(self, name: str = "default") -> Generator[Any, None, None]:
"""Context manager for database sessions."""
session_factory = self.get_session_factory(name)
if not session_factory:
raise ValueError(f"No session factory found for '{name}'")
session = session_factory()
try:
yield session
finally:
session.close()
@asynccontextmanager
async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]:
"""Async context manager for database sessions."""
session_factory = self.get_async_session_factory(name)
if not session_factory:
raise ValueError(f"No async session factory found for '{name}' or async database is not configured")
session = session_factory()
try:
yield session
finally:
await session.close()
# Create a singleton instance
db_registry = DatabaseRegistry()
def get_db():
"""Get a database session, initializing the engine if needed."""
global engine, SessionLocal
# Make sure engine is initialized
if not _engine_initialized:
initialize_engine()
# Now SessionLocal should be defined and callable
db = SessionLocal()
try:
yield db
finally:
db.close()
"""Get a database session."""
with db_registry.session() as session:
yield session
# Define db_context as a context manager that uses get_db
async def get_db_async():
"""Get an async database session."""
async with db_registry.async_session() as session:
yield session
# Prefer calling db_registry.session() or db_registry.async_session() directly
# This is for backwards compatibility
db_context = contextmanager(get_db)

View File

@ -56,6 +56,7 @@ from letta.serialize_schemas import MarshmallowAgentSchema
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.db import db_registry
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
_apply_filters,
@ -85,9 +86,6 @@ class AgentManager:
"""Manager class to handle business logic related to Agents."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
self.block_manager = BlockManager()
self.tool_manager = ToolManager()
self.source_manager = SourceManager()
@ -200,7 +198,7 @@ class AgentManager:
identity_ids = agent_create.identity_ids or []
tag_values = agent_create.tags or []
with self.session_maker() as session:
with db_registry.session() as session:
with session.begin():
name_to_id, id_to_name = self._resolve_tools(
session,
@ -356,7 +354,7 @@ class AgentManager:
new_idents = set(agent_update.identity_ids or [])
new_tags = set(agent_update.tags or [])
with self.session_maker() as session, session.begin():
with db_registry.session() as session, session.begin():
agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
agent.updated_at = datetime.now(timezone.utc)
@ -503,7 +501,7 @@ class AgentManager:
Returns:
List[PydanticAgentState]: The filtered list of matching agents.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)
@ -541,7 +539,7 @@ class AgentManager:
Returns:
List[PydanticAgentState: The filtered list of matching agents.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)
if match_all:
@ -569,20 +567,20 @@ class AgentManager:
"""
Get the total count of agents for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return AgentModel.size(db_session=session, actor=actor)
@enforce_types
def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return agent.to_pydantic()
@enforce_types
def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
"""Fetch an agent by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
return agent.to_pydantic()
@ -599,7 +597,7 @@ class AgentManager:
Raises:
NoResultFound: If agent doesn't exist
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the agent
logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -635,7 +633,7 @@ class AgentManager:
@enforce_types
def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
schema = MarshmallowAgentSchema(session=session, actor=actor)
data = schema.dump(agent)
@ -665,7 +663,7 @@ class AgentManager:
serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids
with self.session_maker() as session:
with db_registry.session() as session:
schema = MarshmallowAgentSchema(session=session, actor=actor)
agent = schema.load(serialized_agent_dict, session=session)
@ -728,7 +726,7 @@ class AgentManager:
Returns:
PydanticAgentState: The updated agent as a Pydantic model.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -767,7 +765,7 @@ class AgentManager:
@enforce_types
def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
if manager_type:
return [group.to_pydantic() for group in agent.groups if group.manager_type == manager_type]
@ -908,7 +906,7 @@ class AgentManager:
Returns:
PydanticAgentState: The updated agent state with no linked messages.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the existing agent (will raise NoResultFound if invalid)
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -985,6 +983,17 @@ class AgentManager:
)
return agent_state
@enforce_types
async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
block_ids = [b.id for b in agent_state.memory.blocks]
if not block_ids:
return agent_state
agent_state.memory.blocks = await self.block_manager.get_all_blocks_by_ids_async(
block_ids=[b.id for b in agent_state.memory.blocks], actor=actor
)
return agent_state
# ======================================================================================================================
# Source Management
# ======================================================================================================================
@ -1003,7 +1012,7 @@ class AgentManager:
IntegrityError: If the source is already attached to the agent
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify both agent and source exist and user has permission to access them
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -1056,7 +1065,7 @@ class AgentManager:
Returns:
List[str]: List of source IDs attached to the agent
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -1073,7 +1082,7 @@ class AgentManager:
source_id: ID of the source to detach
actor: User performing the action
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -1101,7 +1110,7 @@ class AgentManager:
actor: PydanticUser,
) -> PydanticBlock:
"""Gets a block attached to an agent by its label."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
for block in agent.core_memory:
if block.label == block_label:
@ -1117,7 +1126,7 @@ class AgentManager:
actor: PydanticUser,
) -> PydanticAgentState:
"""Updates which block is assigned to a specific label for an agent."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)
@ -1135,7 +1144,7 @@ class AgentManager:
@enforce_types
def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
"""Attaches a block to an agent."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
@ -1151,7 +1160,7 @@ class AgentManager:
actor: PydanticUser,
) -> PydanticAgentState:
"""Detaches a block from an agent."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
original_length = len(agent.core_memory)
@ -1171,7 +1180,7 @@ class AgentManager:
actor: PydanticUser,
) -> PydanticAgentState:
"""Detaches a block with the specified label from an agent."""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
original_length = len(agent.core_memory)
@ -1215,7 +1224,7 @@ class AgentManager:
embedded_text = np.array(embedded_text)
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
with self.session_maker() as session:
with db_registry.session() as session:
# Start with base query for source passages
source_passages = None
if not agent_only: # Include source passages
@ -1389,7 +1398,7 @@ class AgentManager:
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:
with db_registry.session() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
@ -1447,7 +1456,7 @@ class AgentManager:
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:
with db_registry.session() as session:
main_query = self._build_passage_query(
actor=actor,
agent_id=agent_id,
@ -1487,7 +1496,7 @@ class AgentManager:
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -1522,7 +1531,7 @@ class AgentManager:
Returns:
PydanticAgentState: The updated agent state.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify the agent exists and user has permission to access it
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -1551,7 +1560,7 @@ class AgentManager:
Returns:
List[PydanticTool]: List of tools attached to the agent.
"""
with self.session_maker() as session:
with db_registry.session() as session:
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
return [tool.to_pydantic() for tool in agent.tools]
@ -1574,7 +1583,7 @@ class AgentManager:
Returns:
List[str]: List of all tags.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = (
session.query(AgentsTags.tag)
.join(AgentModel, AgentModel.id == AgentsTags.agent_id)

View File

@ -12,6 +12,7 @@ from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, list_human_files, list_persona_files
logger = get_logger(__name__)
@ -20,12 +21,6 @@ logger = get_logger(__name__)
class BlockManager:
"""Manager class to handle business logic related to Blocks."""
def __init__(self):
# Fetching the db_context similarly as in ToolManager
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
"""Create a new block based on the Block schema."""
@ -34,7 +29,7 @@ class BlockManager:
update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
self.update_block(block.id, update_data, actor)
else:
with self.session_maker() as session:
with db_registry.session() as session:
data = block.model_dump(to_orm=True, exclude_none=True)
block = BlockModel(**data, organization_id=actor.organization_id)
block.create(session, actor=actor)
@ -53,7 +48,7 @@ class BlockManager:
if not blocks:
return []
with self.session_maker() as session:
with db_registry.session() as session:
block_models = [
BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks
]
@ -68,7 +63,7 @@ class BlockManager:
"""Update a block by its ID with the given BlockUpdate object."""
# Safety check for block
with self.session_maker() as session:
with db_registry.session() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@ -81,7 +76,7 @@ class BlockManager:
@enforce_types
def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
"""Delete a block by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
block = BlockModel.read(db_session=session, identifier=block_id)
block.hard_delete(db_session=session, actor=actor)
return block.to_pydantic()
@ -100,7 +95,7 @@ class BlockManager:
limit: Optional[int] = 50,
) -> List[PydanticBlock]:
"""Retrieve blocks based on various optional filters."""
with self.session_maker() as session:
with db_registry.session() as session:
# Prepare filters
filters = {"organization_id": actor.organization_id}
if label:
@ -126,7 +121,7 @@ class BlockManager:
@enforce_types
def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
"""Retrieve a block by its name."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
return block.to_pydantic()
@ -136,12 +131,24 @@ class BlockManager:
@enforce_types
def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids."""
with self.session_maker() as session:
with db_registry.session() as session:
blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)]
# backwards compatibility. previous implementation added None for every block not found.
blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
return blocks
@enforce_types
async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
"""Retrieve blocks by their ids. Async implementation."""
async with db_registry.async_session() as session:
blocks = [
block.to_pydantic()
for block in await BlockModel.read_multiple_async(db_session=session, identifiers=block_ids, actor=actor)
]
# backwards compatibility. previous implementation added None for every block not found.
blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
return blocks
@enforce_types
def add_default_blocks(self, actor: PydanticUser):
for persona_file in list_persona_files():
@ -161,7 +168,7 @@ class BlockManager:
"""
Retrieve all agents associated with a given block.
"""
with self.session_maker() as session:
with db_registry.session() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
agents_orm = block.agents
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
@ -176,7 +183,7 @@ class BlockManager:
"""
Get the total count of blocks for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return BlockModel.size(db_session=session, actor=actor)
# Block History Functions
@ -199,7 +206,7 @@ class BlockManager:
strictly linear history.
- A single commit at the end ensures atomicity.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# 1) Load the Block
if use_preloaded_block is not None:
block = session.merge(use_preloaded_block)
@ -291,7 +298,7 @@ class BlockManager:
If older sequences have been pruned, we jump to the largest sequence
number that is still < current_seq.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# 1) Load the current block
block = (
session.merge(use_preloaded_block)
@ -333,7 +340,7 @@ class BlockManager:
If some middle checkpoints have been pruned, we jump to the smallest
sequence > current_seq that remains.
"""
with self.session_maker() as session:
with db_registry.session() as session:
block = (
session.merge(use_preloaded_block)
if use_preloaded_block
@ -383,7 +390,7 @@ class BlockManager:
NoResultFound if any block_id doesnt exist or isnt visible to this actor
ValueError if any new value exceeds its blocks limit
"""
with self.session_maker() as session:
with db_registry.session() as session:
q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
blocks = q.all()

View File

@ -11,16 +11,12 @@ from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class GroupManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def list_groups(
self,
@ -31,7 +27,7 @@ class GroupManager:
after: Optional[str] = None,
limit: Optional[int] = 50,
) -> list[PydanticGroup]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@ -48,13 +44,13 @@ class GroupManager:
@enforce_types
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
return group.to_pydantic()
@enforce_types
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
new_group = GroupModel()
new_group.organization_id = actor.organization_id
new_group.description = group.description
@ -99,7 +95,7 @@ class GroupManager:
@enforce_types
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
sleeptime_agent_frequency = None
@ -161,7 +157,7 @@ class GroupManager:
@enforce_types
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the agent
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
group.hard_delete(session)
@ -178,7 +174,7 @@ class GroupManager:
assistant_message_tool_name: str = "send_message",
assistant_message_tool_kwarg: str = "message",
) -> list[LettaMessage]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {
"organization_id": actor.organization_id,
"group_id": group_id,
@ -204,7 +200,7 @@ class GroupManager:
@enforce_types
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
@ -217,7 +213,7 @@ class GroupManager:
@enforce_types
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
@ -228,7 +224,7 @@ class GroupManager:
@enforce_types
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
@ -247,7 +243,7 @@ class GroupManager:
"""
Get the total count of groups for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return GroupModel.size(db_session=session, actor=actor)
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):

View File

@ -10,16 +10,12 @@ from letta.orm.identity import Identity as IdentityModel
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class IdentityManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def list_identities(
self,
@ -32,7 +28,7 @@ class IdentityManager:
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticIdentity]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@ -52,13 +48,13 @@ class IdentityManager:
@enforce_types
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()
@enforce_types
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
self._process_relationship(
@ -82,7 +78,7 @@ class IdentityManager:
@enforce_types
def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
existing_identity = IdentityModel.read(
db_session=session,
identifier_key=identity.identifier_key,
@ -107,7 +103,7 @@ class IdentityManager:
@enforce_types
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
except NoResultFound:
@ -167,7 +163,7 @@ class IdentityManager:
@enforce_types
def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
if existing_identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
@ -181,7 +177,7 @@ class IdentityManager:
@enforce_types
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id)
if identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
@ -198,7 +194,7 @@ class IdentityManager:
"""
Get the total count of identities for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return IdentityModel.size(db_session=session, actor=actor)
def _process_relationship(

View File

@ -24,24 +24,19 @@ from letta.schemas.run import Run as PydanticRun
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class JobManager:
"""Manager class to handle business logic related to Jobs."""
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_job(
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
"""Create a new job based on the JobCreate schema."""
with self.session_maker() as session:
with db_registry.session() as session:
# Associate the job with the user
pydantic_job.user_id = actor.id
job_data = pydantic_job.model_dump(to_orm=True)
@ -52,7 +47,7 @@ class JobManager:
@enforce_types
def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
"""Update a job by its ID with the given JobUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch the job by ID
job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])
@ -76,7 +71,7 @@ class JobManager:
@enforce_types
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Fetch a job by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve job by ID using the Job model's read method
job = JobModel.read(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER)
return job.to_pydantic()
@ -93,7 +88,7 @@ class JobManager:
ascending: bool = True,
) -> List[PydanticJob]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
with db_registry.session() as session:
filter_kwargs = {"user_id": actor.id, "job_type": job_type}
# Add status filter if provided
@ -113,7 +108,7 @@ class JobManager:
@enforce_types
def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Delete a job by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
job = self._verify_job_access(session=session, job_id=job_id, actor=actor)
job.hard_delete(db_session=session, actor=actor)
return job.to_pydantic()
@ -147,7 +142,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Build filters
filters = {}
if role is not None:
@ -195,7 +190,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Build filters
filters = {}
filters["job_id"] = job_id
@ -227,7 +222,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"])
@ -251,7 +246,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor)
@ -293,7 +288,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"])
@ -453,7 +448,7 @@ class JobManager:
Returns:
The request config for the job
"""
with self.session_maker() as session:
with db_registry.session() as session:
job = session.query(JobModel).filter(JobModel.id == run_id).first()
request_config = job.request_config or LettaRequestConfig()
return request_config

View File

@ -16,6 +16,7 @@ from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
logger = get_logger(__name__)
@ -24,11 +25,6 @@ logger = get_logger(__name__)
class LLMBatchManager:
"""Manager for handling both LLMBatchJob and LLMBatchItem operations."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_llm_batch_job(
self,
@ -39,7 +35,7 @@ class LLMBatchManager:
status: JobStatus = JobStatus.created,
) -> PydanticLLMBatchJob:
"""Create a new LLM batch job."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob(
status=status,
llm_provider=llm_provider,
@ -53,7 +49,7 @@ class LLMBatchManager:
@enforce_types
def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
"""Retrieve a single batch job by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
return batch.to_pydantic()
@ -66,7 +62,7 @@ class LLMBatchManager:
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch jobs status and optionally its polling response."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.status = status
batch.latest_polling_response = latest_polling_response
@ -85,7 +81,7 @@ class LLMBatchManager:
"""
now = datetime.datetime.now(datetime.timezone.utc)
with self.session_maker() as session:
with db_registry.session() as session:
mappings = []
for llm_batch_id, status, response in updates:
mappings.append(
@ -119,7 +115,7 @@ class LLMBatchManager:
The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)
if actor is not None:
@ -140,7 +136,7 @@ class LLMBatchManager:
@enforce_types
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch job by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.hard_delete(db_session=session, actor=actor)
@ -158,7 +154,7 @@ class LLMBatchManager:
Retrieve messages across all LLM batch jobs associated with a Letta batch job.
Optimized for PostgreSQL performance using ID-based keyset pagination.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# If cursor is provided, get sequence_id for that message
cursor_sequence_id = None
if cursor:
@ -203,7 +199,7 @@ class LLMBatchManager:
@enforce_types
def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)
if actor is not None:
@ -224,7 +220,7 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Create a new batch item."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem(
llm_batch_id=llm_batch_id,
agent_id=agent_id,
@ -249,7 +245,7 @@ class LLMBatchManager:
Returns:
List of created batch items as Pydantic models
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Convert Pydantic models to ORM objects
orm_items = []
for item in llm_batch_items:
@ -274,7 +270,7 @@ class LLMBatchManager:
@enforce_types
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
"""Retrieve a single batch item by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
return item.to_pydantic()
@ -289,7 +285,7 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Update fields on a batch item."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
if request_status:
@ -325,7 +321,7 @@ class LLMBatchManager:
The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id)
if actor is not None:
@ -367,7 +363,7 @@ class LLMBatchManager:
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")
with self.session_maker() as session:
with db_registry.session() as session:
# Lookup primary keys for all requested (batch_id, agent_id) pairs
items = (
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
@ -434,7 +430,7 @@ class LLMBatchManager:
@enforce_types
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch item by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
item.hard_delete(db_session=session, actor=actor)
@ -449,6 +445,6 @@ class LLMBatchManager:
Returns:
int: The total number of batch items associated with the given llm_batch_id.
"""
with self.session_maker() as session:
with db_registry.session() as session:
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
return count or 0

View File

@ -12,6 +12,7 @@ from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
logger = get_logger(__name__)
@ -20,15 +21,10 @@ logger = get_logger(__name__)
class MessageManager:
"""Manager class to handle business logic related to Messages."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
return message.to_pydantic()
@ -38,7 +34,7 @@ class MessageManager:
@enforce_types
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch messages by ID and return them in the requested order."""
with self.session_maker() as session:
with db_registry.session() as session:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))
if len(results) != len(message_ids):
@ -53,7 +49,7 @@ class MessageManager:
@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
"""Create a new message."""
with self.session_maker() as session:
with db_registry.session() as session:
# Set the organization id of the Pydantic message
pydantic_msg.organization_id = actor.organization_id
msg_data = pydantic_msg.model_dump(to_orm=True)
@ -86,7 +82,7 @@ class MessageManager:
orm_messages.append(MessageModel(**msg_data))
# Use the batch_create method for efficient creation
with self.session_maker() as session:
with db_registry.session() as session:
created_messages = MessageModel.batch_create(orm_messages, session, actor=actor)
# Convert back to Pydantic models
@ -173,7 +169,7 @@ class MessageManager:
"""
Updates an existing record in the database with values from the provided record object.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch existing message from database
message = MessageModel.read(
db_session=session,
@ -181,31 +177,57 @@ class MessageManager:
actor=actor,
)
# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)
# get update dictionary
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
for key, value in update_data.items():
setattr(message, key, value)
message = self._update_message_by_id_impl(message_id, message_update, actor, message)
message.update(db_session=session, actor=actor)
return message.to_pydantic()
@enforce_types
async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
Async version of the function above.
"""
async with db_registry.async_session() as session:
# Fetch existing message from database
message = await MessageModel.read_async(
db_session=session,
identifier=message_id,
actor=actor,
)
message = self._update_message_by_id_impl(message_id, message_update, actor, message)
await message.update_async(db_session=session, actor=actor)
return message.to_pydantic()
def _update_message_by_id_impl(
self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel
) -> MessageModel:
"""
Modifies the existing message object to update the database in the sync/async functions.
"""
# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)
# get update dictionary
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}
for key, value in update_data.items():
setattr(message, key, value)
return message
@enforce_types
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
"""Delete a message."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
msg = MessageModel.read(
db_session=session,
@ -229,7 +251,7 @@ class MessageManager:
actor: The user requesting the count
role: The role of the message
"""
with self.session_maker() as session:
with db_registry.session() as session:
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)
@enforce_types
@ -293,7 +315,7 @@ class MessageManager:
NoResultFound: If the provided after/before message IDs do not exist.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@ -356,7 +378,7 @@ class MessageManager:
Efficiently deletes all messages associated with a given agent_id,
while enforcing permission checks and avoiding any ORMlevel loads.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# 1) verify the agent exists and the actor has access
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

View File

@ -4,6 +4,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.server.db import db_registry
from letta.utils import enforce_types
@ -13,14 +14,6 @@ class OrganizationManager:
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"
def __init__(self):
# TODO: Please refactor this out
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
# - Matt
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def get_default_organization(self) -> PydanticOrganization:
"""Fetch the default organization."""
@ -29,7 +22,7 @@ class OrganizationManager:
@enforce_types
def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
"""Fetch an organization by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
organization = OrganizationModel.read(db_session=session, identifier=org_id)
return organization.to_pydantic()
@ -44,7 +37,7 @@ class OrganizationManager:
@enforce_types
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
org.create(session)
return org.to_pydantic()
@ -57,7 +50,7 @@ class OrganizationManager:
@enforce_types
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if name:
org.name = name
@ -67,7 +60,7 @@ class OrganizationManager:
@enforce_types
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if org_update.name:
org.name = org_update.name
@ -79,14 +72,14 @@ class OrganizationManager:
@enforce_types
def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted."""
with self.session_maker() as session:
with db_registry.session() as session:
organization = OrganizationModel.read(db_session=session, identifier=org_id)
organization.hard_delete(session)
@enforce_types
def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List all organizations with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
organizations = OrganizationModel.list(
db_session=session,
after=after,

View File

@ -10,21 +10,17 @@ from letta.orm.passage import AgentPassage, SourcePassage
from letta.schemas.agent import AgentState
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class PassageManager:
"""Manager class to handle business logic related to Passages."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
"""Fetch a passage by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
@ -69,7 +65,7 @@ class PassageManager:
else:
raise ValueError("Passage must have either agent_id or source_id")
with self.session_maker() as session:
with db_registry.session() as session:
passage.create(session, actor=actor)
return passage.to_pydantic()
@ -145,7 +141,7 @@ class PassageManager:
if not passage_id:
raise ValueError("Passage ID must be provided.")
with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
curr_passage = SourcePassage.read(
@ -179,7 +175,7 @@ class PassageManager:
if not passage_id:
raise ValueError("Passage ID must be provided.")
with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
@ -217,7 +213,7 @@ class PassageManager:
actor: The user requesting the count
agent_id: The agent ID of the messages
"""
with self.session_maker() as session:
with db_registry.session() as session:
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
def estimate_embeddings_size(

View File

@ -5,20 +5,16 @@ from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
class ProviderManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
with db_registry.session() as session:
provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
provider = PydanticProvider(**provider_create_args)
@ -38,7 +34,7 @@ class ProviderManager:
@enforce_types
def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
"""Update provider details."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the existing provider by ID
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
@ -54,7 +50,7 @@ class ProviderManager:
@enforce_types
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
with self.session_maker() as session:
with db_registry.session() as session:
# Clear api key field
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
existing_provider.api_key = None
@ -80,7 +76,7 @@ class ProviderManager:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session:
with db_registry.session() as session:
providers = ProviderModel.list(
db_session=session,
after=after,

View File

@ -11,6 +11,7 @@ from letta.schemas.sandbox_config import LocalSandboxConfig
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
logger = get_logger(__name__)
@ -19,11 +20,6 @@ logger = get_logger(__name__)
class SandboxConfigManager:
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig:
sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor)
@ -69,7 +65,7 @@ class SandboxConfigManager:
return db_sandbox
else:
# If the sandbox configuration doesn't exist, create a new one
with self.session_maker() as session:
with db_registry.session() as session:
db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True))
db_sandbox.create(session, actor=actor)
return db_sandbox.to_pydantic()
@ -79,7 +75,7 @@ class SandboxConfigManager:
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
) -> PydanticSandboxConfig:
"""Update an existing sandbox configuration."""
with self.session_maker() as session:
with db_registry.session() as session:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
# We need to check that the sandbox_update provided is the same type as the original sandbox
if sandbox.type != sandbox_update.config.type:
@ -104,7 +100,7 @@ class SandboxConfigManager:
@enforce_types
def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
"""Delete a sandbox configuration by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
sandbox.hard_delete(db_session=session, actor=actor)
return sandbox.to_pydantic()
@ -122,14 +118,14 @@ class SandboxConfigManager:
if sandbox_type:
kwargs.update({"type": sandbox_type})
with self.session_maker() as session:
with db_registry.session() as session:
sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs)
return [sandbox.to_pydantic() for sandbox in sandboxes]
@enforce_types
def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox configuration by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
return sandbox.to_pydantic()
@ -139,7 +135,7 @@ class SandboxConfigManager:
@enforce_types
def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
sandboxes = SandboxConfigModel.list(
db_session=session,
@ -175,7 +171,7 @@ class SandboxConfigManager:
return db_env_var
else:
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
env_var.create(session, actor=actor)
return env_var.to_pydantic()
@ -185,7 +181,7 @@ class SandboxConfigManager:
self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser
) -> PydanticEnvVar:
"""Update an existing sandbox environment variable."""
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value}
@ -204,7 +200,7 @@ class SandboxConfigManager:
@enforce_types
def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar:
"""Delete a sandbox environment variable by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
env_var.hard_delete(db_session=session, actor=actor)
return env_var.to_pydantic()
@ -218,7 +214,7 @@ class SandboxConfigManager:
limit: Optional[int] = 50,
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
env_vars = SandboxEnvVarModel.list(
db_session=session,
after=after,
@ -233,7 +229,7 @@ class SandboxConfigManager:
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
env_vars = SandboxEnvVarModel.list(
db_session=session,
after=after,
@ -258,7 +254,7 @@ class SandboxConfigManager:
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
) -> Optional[PydanticEnvVar]:
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
env_var = SandboxEnvVarModel.list(
db_session=session,

View File

@ -8,17 +8,13 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
class SourceManager:
"""Manager class to handle business logic related to Sources."""
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
@ -27,7 +23,7 @@ class SourceManager:
if db_source:
return db_source
else:
with self.session_maker() as session:
with db_registry.session() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
@ -37,7 +33,7 @@ class SourceManager:
@enforce_types
def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
"""Update a source by its ID with the given SourceUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
# get update dictionary
@ -59,7 +55,7 @@ class SourceManager:
@enforce_types
def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
"""Delete a source by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id)
source.hard_delete(db_session=session, actor=actor)
return source.to_pydantic()
@ -67,7 +63,7 @@ class SourceManager:
@enforce_types
def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]:
"""List all sources with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
sources = SourceModel.list(
db_session=session,
after=after,
@ -85,7 +81,7 @@ class SourceManager:
"""
Get the total count of sources for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return SourceModel.size(db_session=session, actor=actor)
@enforce_types
@ -100,7 +96,7 @@ class SourceManager:
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify source exists and user has permission to access it
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
@ -112,7 +108,7 @@ class SourceManager:
@enforce_types
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
return source.to_pydantic()
@ -122,7 +118,7 @@ class SourceManager:
@enforce_types
def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
"""Retrieve a source by its name."""
with self.session_maker() as session:
with db_registry.session() as session:
sources = SourceModel.list(
db_session=session,
name=source_name,
@ -141,7 +137,7 @@ class SourceManager:
if db_file:
return db_file
else:
with self.session_maker() as session:
with db_registry.session() as session:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
file_metadata.create(session, actor=actor)
@ -151,7 +147,7 @@ class SourceManager:
@enforce_types
def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor)
return file.to_pydantic()
@ -163,7 +159,7 @@ class SourceManager:
self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
files = FileMetadataModel.list(
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
)
@ -172,7 +168,7 @@ class SourceManager:
@enforce_types
def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
file = FileMetadataModel.read(db_session=session, identifier=file_id)
file.hard_delete(db_session=session, actor=actor)
return file.to_pydantic()

View File

@ -11,17 +11,13 @@ from letta.orm.step import Step as StepModel
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import get_trace_id
from letta.utils import enforce_types
class StepManager:
def __init__(self):
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def list_steps(
self,
@ -36,7 +32,7 @@ class StepManager:
agent_id: Optional[str] = None,
) -> List[PydanticStep]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
with db_registry.session() as session:
filter_kwargs = {"organization_id": actor.organization_id}
if model:
filter_kwargs["model"] = model
@ -85,7 +81,7 @@ class StepManager:
"tid": None,
"trace_id": get_trace_id(), # Get the current trace ID
}
with self.session_maker() as session:
with db_registry.session() as session:
if job_id:
self._verify_job_access(session, job_id, actor, access=["write"])
new_step = StepModel(**step_data)
@ -94,7 +90,7 @@ class StepManager:
@enforce_types
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with self.session_maker() as session:
with db_registry.session() as session:
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
return step.to_pydantic()
@ -113,7 +109,7 @@ class StepManager:
Raises:
NoResultFound: If the step does not exist
"""
with self.session_maker() as session:
with db_registry.session() as session:
step = session.get(StepModel, step_id)
if not step:
raise NoResultFound(f"Step with id {step_id} does not exist")

View File

@ -23,6 +23,7 @@ from letta.orm.tool import Tool as ToolModel
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
logger = get_logger(__name__)
@ -31,12 +32,6 @@ logger = get_logger(__name__)
class ToolManager:
"""Manager class to handle business logic related to Tools."""
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.db import db_context
self.session_maker = db_context
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
@enforce_types
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
@ -89,7 +84,7 @@ class ToolManager:
@enforce_types
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
with self.session_maker() as session:
with db_registry.session() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
@ -104,7 +99,7 @@ class ToolManager:
@enforce_types
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve tool by id using the Tool model's read method
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
# Convert the SQLAlchemy Tool object to PydanticTool
@ -114,7 +109,7 @@ class ToolManager:
def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
with self.session_maker() as session:
with db_registry.session() as session:
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
return tool.to_pydantic()
except NoResultFound:
@ -124,7 +119,7 @@ class ToolManager:
def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
with self.session_maker() as session:
with db_registry.session() as session:
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
return tool.id
except NoResultFound:
@ -133,7 +128,7 @@ class ToolManager:
@enforce_types
def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
tools = ToolModel.list(
db_session=session,
after=after,
@ -166,7 +161,7 @@ class ToolManager:
If include_builtin is True, it will also count the built-in tools.
"""
with self.session_maker() as session:
with db_registry.session() as session:
if include_base_tools:
return ToolModel.size(db_session=session, actor=actor)
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
@ -176,7 +171,7 @@ class ToolManager:
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch the tool by ID
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
@ -202,7 +197,7 @@ class ToolManager:
@enforce_types
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
tool.hard_delete(db_session=session, actor=actor)

View File

@ -5,6 +5,7 @@ from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.services.organization_manager import OrganizationManager
from letta.utils import enforce_types
@ -15,16 +16,10 @@ class UserManager:
DEFAULT_USER_NAME = "default_user"
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.db import db_context
self.session_maker = db_context
@enforce_types
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
with self.session_maker() as session:
with db_registry.session() as session:
# Make sure the org id exists
try:
OrganizationModel.read(db_session=session, identifier=org_id)
@ -44,7 +39,7 @@ class UserManager:
@enforce_types
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist."""
with self.session_maker() as session:
with db_registry.session() as session:
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
new_user.create(session)
return new_user.to_pydantic()
@ -52,7 +47,7 @@ class UserManager:
@enforce_types
def update_user(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the existing user by ID
existing_user = UserModel.read(db_session=session, identifier=user_update.id)
@ -68,7 +63,7 @@ class UserManager:
@enforce_types
def delete_user_by_id(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings)."""
with self.session_maker() as session:
with db_registry.session() as session:
# Delete from user table
user = UserModel.read(db_session=session, identifier=user_id)
user.hard_delete(session)
@ -78,7 +73,7 @@ class UserManager:
@enforce_types
def get_user_by_id(self, user_id: str) -> PydanticUser:
"""Fetch a user by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
user = UserModel.read(db_session=session, identifier=user_id)
return user.to_pydantic()
@ -104,7 +99,7 @@ class UserManager:
@enforce_types
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
users = UserModel.list(
db_session=session,
after=after,

View File

@ -203,6 +203,7 @@ class Settings(BaseSettings):
use_experimental: bool = False
use_vertex_structured_outputs_experimental: bool = False
use_vertex_async_loop_experimental: bool = False
experimental_enable_async_db_engine: bool = False
# LLM provider client settings
httpx_max_retries: int = 5

View File

@ -812,7 +812,7 @@ def printd(*args, **kwargs):
print(*args, **kwargs)
def united_diff(str1, str2):
def united_diff(str1: str, str2: str) -> str:
lines1 = str1.splitlines(True)
lines2 = str2.splitlines(True)
diff = difflib.unified_diff(lines1, lines2)

69
poetry.lock generated
View File

@ -326,6 +326,73 @@ files = [
{file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
]
[[package]]
name = "asyncpg"
version = "0.30.0"
description = "An asyncio PostgreSQL driver"
optional = false
python-versions = ">=3.8.0"
groups = ["main"]
files = [
{file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"},
{file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"},
{file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"},
{file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"},
{file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"},
{file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"},
{file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"},
{file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"},
{file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"},
{file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"},
{file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"},
{file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"},
{file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"},
{file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"},
{file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"},
{file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"},
{file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"},
{file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"},
{file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"},
{file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"},
{file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"},
{file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"},
{file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"},
{file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"},
{file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"},
{file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"},
{file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"},
{file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"},
{file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"},
{file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"},
{file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"},
{file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"},
{file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"},
{file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"},
{file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"},
{file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"},
{file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"},
{file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"},
{file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"},
{file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"},
{file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"},
{file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"},
{file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"},
{file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"},
{file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"},
{file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"},
{file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"},
{file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"},
{file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"},
]
[package.dependencies]
async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.11.0\""}
[package.extras]
docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"]
gssauth = ["gssapi ; platform_system != \"Windows\"", "sspilib ; platform_system == \"Windows\""]
test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi ; platform_system == \"Linux\"", "k5test ; platform_system == \"Linux\"", "mypy (>=1.8.0,<1.9.0)", "sspilib ; platform_system == \"Windows\"", "uvloop (>=0.15.3) ; platform_system != \"Windows\" and python_version < \"3.14.0\""]
[[package]]
name = "attrs"
version = "25.3.0"
@ -7503,4 +7570,4 @@ tests = ["wikipedia"]
[metadata]
lock-version = "2.1"
python-versions = "<3.14,>=3.10"
content-hash = "f82fec7b3f35d4222c43b692db8cd005eaf8bcf6761bb202d0dbf64121c6b2ab"
content-hash = "862dc5a31d4385e89dc9a751cd171a611da3102c6832447a5f61926b25f03e06"

View File

@ -90,6 +90,7 @@ firecrawl-py = "^1.15.0"
apscheduler = "^3.11.0"
aiomultiprocess = "^0.9.1"
matplotlib = "^3.10.1"
asyncpg = "^0.30.0"
[tool.poetry.extras]

View File

@ -15,6 +15,7 @@ from letta.schemas.enums import JobStatus, ToolRuleType
from letta.schemas.group import GroupUpdate, ManagerType, SleeptimeManagerUpdate
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.utils import get_human_text, get_persona_text
@ -37,7 +38,7 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
with db_registry.session() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()

View File

@ -15,6 +15,7 @@ from letta.schemas.group import (
SupervisorManager,
)
from letta.schemas.message import MessageCreate
from letta.server.db import db_registry
from letta.server.server import SyncServer
@ -36,7 +37,7 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
with db_registry.session() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()

View File

@ -19,6 +19,7 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import ProviderCreate
from letta.schemas.sandbox_config import SandboxType
from letta.schemas.user import User
from letta.server.db import db_registry
utils.DEBUG = True
from letta.config import LettaConfig
@ -284,7 +285,7 @@ def org_id(server):
yield org.id
# cleanup
with server.organization_manager.session_maker() as session:
with db_registry.session() as session:
session.execute(delete(Step))
session.execute(delete(Provider))
session.commit()