Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

feat: async db client (#2076)

This commit is contained in:
parent 845005451f
commit e85c558ddc
@@ -3,14 +3,21 @@ from typing import Any, AsyncGenerator, List, Optional, Union

import openai

from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.utils import united_diff

logger = get_logger(__name__)


class BaseAgent(ABC):
@@ -64,3 +71,107 @@ class BaseAgent(ABC):
            return ""

        return [{"role": input_message.role.value, "content": get_content(input_message)} for input_message in input_messages]

    def _rebuild_memory(
        self,
        in_context_messages: List[Message],
        agent_state: AgentState,
        num_messages: int | None = None,  # storing these calculations is specific to the voice agent
        num_archival_memories: int | None = None,
    ) -> List[Message]:
        try:
            # Refresh Memory
            # TODO: This only happens for the summary block (voice?)
            # [DB Call] loading blocks (modifies: agent_state.memory.blocks)
            self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)

            # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
            curr_system_message = in_context_messages[0]
            curr_memory_str = agent_state.memory.compile()
            curr_system_message_text = curr_system_message.content[0].text
            if curr_memory_str in curr_system_message_text:
                # NOTE: could this cause issues if a block is removed? (substring match would still work)
                logger.debug(
                    f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
                )
                return in_context_messages

            memory_edit_timestamp = get_utc_time()

            # [DB Call] size of messages and archival memories
            num_messages = num_messages or self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
            num_archival_memories = num_archival_memories or self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

            new_system_message_str = compile_system_message(
                system_prompt=agent_state.system,
                in_context_memory=agent_state.memory,
                in_context_memory_last_edit=memory_edit_timestamp,
                previous_message_count=num_messages,
                archival_memory_size=num_archival_memories,
            )

            diff = united_diff(curr_system_message_text, new_system_message_str)
            if len(diff) > 0:
                logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

                # [DB Call] Update Messages
                new_system_message = self.message_manager.update_message_by_id(
                    curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
                )
                # Skip pulling down the agent's memory again to save on a db call
                return [new_system_message] + in_context_messages[1:]

            else:
                return in_context_messages
        except:
            logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
            raise

    async def _rebuild_memory_async(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        """
        Async version of the function above. For now, until the components are broken up, changes should be made in both places.
        """
        try:
            # [DB Call] loading blocks (modifies: agent_state.memory.blocks)
            await self.agent_manager.refresh_memory_async(agent_state=agent_state, actor=self.actor)

            # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
            curr_system_message = in_context_messages[0]
            curr_memory_str = agent_state.memory.compile()
            curr_system_message_text = curr_system_message.content[0].text
            if curr_memory_str in curr_system_message_text:
                logger.debug(
                    f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
                )
                return in_context_messages

            memory_edit_timestamp = get_utc_time()

            # [DB Call] size of messages and archival memories
            # todo: blocking for now
            num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
            num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

            new_system_message_str = compile_system_message(
                system_prompt=agent_state.system,
                in_context_memory=agent_state.memory,
                in_context_memory_last_edit=memory_edit_timestamp,
                previous_message_count=num_messages,
                archival_memory_size=num_archival_memories,
            )

            diff = united_diff(curr_system_message_text, new_system_message_str)
            if len(diff) > 0:
                logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

                # [DB Call] Update Messages
                new_system_message = await self.message_manager.update_message_by_id_async(
                    curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
                )
                return [new_system_message] + in_context_messages[1:]

            else:
                return in_context_messages
        except:
            logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
            raise
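Note on the hunk above (letta/agents/base_agent.py): _rebuild_memory_async awaits the block refresh via refresh_memory_async, but per its "todo: blocking for now" comment the two size() counts still run synchronously. A sketch of a possible follow-up, meant as drop-in lines inside _rebuild_memory_async; size_async is hypothetical and not part of this commit:

    import asyncio

    # Hypothetical follow-up (size_async does not exist in this commit):
    # run both [DB Call] counts concurrently instead of blocking the event loop.
    num_messages, num_archival_memories = await asyncio.gather(
        self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id),
        self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id),
    )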
@@ -32,6 +32,7 @@ from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
from letta.settings import settings
from letta.system import package_function_response
from letta.tracing import log_event, trace_method
from letta.utils import united_diff
@@ -171,6 +172,7 @@ class LettaAgent(BaseAgent):
        yield f"data: {MessageStreamStatus.done.model_dump_json()}\n\n"

    @trace_method
    # When raising an error this doesn't show up
    async def _get_ai_reply(
        self,
        llm_client: LLMClientBase,
@@ -179,7 +181,10 @@ class LettaAgent(BaseAgent):
        tool_rules_solver: ToolRulesSolver,
        stream: bool,
    ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
        in_context_messages = self._rebuild_memory(in_context_messages, agent_state)
        if settings.experimental_enable_async_db_engine:
            in_context_messages = await self._rebuild_memory_async(in_context_messages, agent_state)
        else:
            in_context_messages = self._rebuild_memory(in_context_messages, agent_state)

        tools = [
            t
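The call site above is gated on settings.experimental_enable_async_db_engine, so the async path stays opt-in. The settings change itself is not shown in this diff; a plausible declaration with pydantic-settings, where the default and the env-var mapping are assumptions:

    from pydantic_settings import BaseSettings

    class Settings(BaseSettings):
        # Assumed: off by default, settable via an environment variable
        # following pydantic-settings conventions
        experimental_enable_async_db_engine: bool = False

    settings = Settings()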
@@ -296,51 +301,6 @@ class LettaAgent(BaseAgent):

        return persisted_messages, continue_stepping

    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        try:
            self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)

            # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
            curr_system_message = in_context_messages[0]
            curr_memory_str = agent_state.memory.compile()
            curr_system_message_text = curr_system_message.content[0].text
            if curr_memory_str in curr_system_message_text:
                # NOTE: could this cause issues if a block is removed? (substring match would still work)
                logger.debug(
                    f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
                )
                return in_context_messages

            memory_edit_timestamp = get_utc_time()

            num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
            num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

            new_system_message_str = compile_system_message(
                system_prompt=agent_state.system,
                in_context_memory=agent_state.memory,
                in_context_memory_last_edit=memory_edit_timestamp,
                previous_message_count=num_messages,
                archival_memory_size=num_archival_memories,
            )

            diff = united_diff(curr_system_message_text, new_system_message_str)
            if len(diff) > 0:
                logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

                new_system_message = self.message_manager.update_message_by_id(
                    curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
                )

                # Skip pulling down the agent's memory again to save on a db call
                return [new_system_message] + in_context_messages[1:]

            else:
                return in_context_messages
        except:
            logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})")
            raise

    @trace_method
    async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
        """
@@ -1,11 +1,12 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncGenerator, Dict, List, Optional, Sequence, Tuple, Union

from aiomultiprocess import Pool
from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, BetaMessageBatchSucceededResult

from letta.agents.base_agent import BaseAgent
from letta.agents.helpers import _prepare_in_context_messages
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
@@ -16,11 +17,12 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.orm.enums import ToolType
from letta.schemas.agent import AgentState, AgentStepState
from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType
from letta.schemas.enums import AgentStepStatus, JobStatus, MessageStreamStatus, ProviderType
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.letta_response import LettaBatchResponse
from letta.schemas.letta_response import LettaBatchResponse, LettaResponse
from letta.schemas.llm_batch_job import LLMBatchItem
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall
@@ -95,7 +97,7 @@ async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[

# TODO: Limitations ->
# TODO: Only works with anthropic for now
class LettaAgentBatch:
class LettaAgentBatch(BaseAgent):

    def __init__(
        self,
@@ -539,43 +541,20 @@ class LettaAgentBatch:
        return in_context_messages

    # TODO: Make this a bulk function
    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        agent_state = self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
    def _rebuild_memory(
        self,
        in_context_messages: List[Message],
        agent_state: AgentState,
        num_messages: int | None = None,
        num_archival_memories: int | None = None,
    ) -> List[Message]:
        return super()._rebuild_memory(in_context_messages, agent_state)

        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages
    # Not used in batch.
    async def step(self, input_messages: List[MessageCreate], max_steps: int = 10) -> LettaResponse:
        raise NotImplementedError

        memory_edit_timestamp = get_utc_time()

        num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
        num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=num_messages,
            archival_memory_size=num_archival_memories,
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if len(diff) > 0:
            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )

            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]

        else:
            return in_context_messages
    async def step_stream(
        self, input_messages: List[MessageCreate], max_steps: int = 10
    ) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
        raise NotImplementedError
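In the hunk above, LettaAgentBatch now subclasses BaseAgent: it inherits the shared _rebuild_memory (its own copy is deleted) and stubs the interactive entry points it cannot serve. A toy sketch of that pattern, using made-up names rather than the Letta classes:

    from abc import ABC, abstractmethod

    class AgentBase(ABC):
        @abstractmethod
        async def step(self) -> str: ...

        def rebuild(self) -> str:
            # shared logic every subclass reuses
            return "rebuilt"

    class BatchAgent(AgentBase):
        async def step(self) -> str:
            # batch execution has no interactive step; mirror the diff's stubs
            raise NotImplementedError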
@@ -293,48 +293,17 @@ class VoiceAgent(BaseAgent):
            agent_id=self.agent_id, message_ids=[m.id for m in new_in_context_messages], actor=self.actor
        )

    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
        # Refresh memory
        # TODO: This only happens for the summary block
        # TODO: We want to extend this refresh to be general, and stick it in agent_manager
        block_ids = [block.id for block in agent_state.memory.blocks]
        agent_state.memory.blocks = self.block_manager.get_all_blocks_by_ids(block_ids=block_ids, actor=self.actor)

        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
        curr_system_message = in_context_messages[0]
        curr_memory_str = agent_state.memory.compile()
        curr_system_message_text = curr_system_message.content[0].text
        if curr_memory_str in curr_system_message_text:
            # NOTE: could this cause issues if a block is removed? (substring match would still work)
            logger.debug(
                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
            )
            return in_context_messages

        memory_edit_timestamp = get_utc_time()

        new_system_message_str = compile_system_message(
            system_prompt=agent_state.system,
            in_context_memory=agent_state.memory,
            in_context_memory_last_edit=memory_edit_timestamp,
            previous_message_count=self.num_messages,
            archival_memory_size=self.num_archival_memories,
    def _rebuild_memory(
        self,
        in_context_messages: List[Message],
        agent_state: AgentState,
        num_messages: int | None = None,
        num_archival_memories: int | None = None,
    ) -> List[Message]:
        return super()._rebuild_memory(
            in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
        )

        diff = united_diff(curr_system_message_text, new_system_message_str)
        if len(diff) > 0:
            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")

            new_system_message = self.message_manager.update_message_by_id(
                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
            )

            # Skip pulling down the agent's memory again to save on a db call
            return [new_system_message] + in_context_messages[1:]

        else:
            return in_context_messages

    def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest:
        tool_schemas = self._build_tool_schemas(agent_state)
        tool_choice = "auto" if tool_schemas else None
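VoiceAgent's override now simply forwards to the BaseAgent implementation, substituting its cached self.num_messages and self.num_archival_memories so the shared path skips both size() queries. A sketch of a call site that refreshes those caches once per turn; where the refresh actually happens is an assumption, not shown in this diff:

    # Hypothetical call site: refresh cached counts once, then rebuild cheaply.
    self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
    self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
    in_context_messages = self._rebuild_memory(in_context_messages, agent_state)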
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union

from sqlalchemy import String, and_, func, or_, select
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, Session, mapped_column

from letta.log import get_logger
@@ -300,6 +301,44 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
            raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
        return found[0]

    @classmethod
    @handle_db_timeout
    async def read_async(
        cls,
        db_session: "Session",
        identifier: Optional[str] = None,
        actor: Optional["User"] = None,
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        **kwargs,
    ) -> "SqlalchemyBase":
        """The primary accessor for an ORM record. Async version of the read method.

        Args:
            db_session: the database session to use when retrieving the record
            identifier: the identifier of the record to read; can be the id string or the UUID object for backwards compatibility
            actor: if specified, results will be scoped only to records the user is able to access
            access: if actor is specified, records will be filtered to the minimum permission level for the actor
            kwargs: additional arguments to pass to the read, used for more complex objects
        Returns:
            The matching object
        Raises:
            NoResultFound: if the object is not found
        """
        # this is ok because read_multiple will check if the
        identifiers = [] if identifier is None else [identifier]
        found = await cls.read_multiple_async(db_session, identifiers, actor, access, access_type, **kwargs)
        if len(found) == 0:
            # for backwards compatibility.
            conditions = []
            if identifier:
                conditions.append(f"id={identifier}")
            if actor:
                conditions.append(f"access level in {access} for {actor}")
            if hasattr(cls, "is_deleted"):
                conditions.append("is_deleted=False")
            raise NoResultFound(f"{cls.__name__} not found with {', '.join(conditions if conditions else ['no conditions'])}")
        return found[0]

    @classmethod
    @handle_db_timeout
    def read_multiple(
@@ -323,6 +362,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
        Raises:
            NoResultFound: if the object is not found
        """
        query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs)
        results = db_session.execute(query).scalars().all()
        return cls._read_multiple_postprocess(results, identifiers, query_conditions)

    @classmethod
    @handle_db_timeout
    async def read_multiple_async(
        cls,
        db_session: "AsyncSession",
        identifiers: List[str] = [],
        actor: Optional["User"] = None,
        access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
        access_type: AccessType = AccessType.ORGANIZATION,
        **kwargs,
    ) -> List["SqlalchemyBase"]:
        """
        Async version of read_multiple(...)
        The primary accessor for ORM record(s)
        """
        query, query_conditions = cls._read_multiple_preprocess(identifiers, actor, access, access_type, **kwargs)
        results = await db_session.execute(query)
        return cls._read_multiple_postprocess(results.scalars().all(), identifiers, query_conditions)

    @classmethod
    def _read_multiple_preprocess(
        cls,
        identifiers: List[str],
        actor: Optional["User"],
        access: Optional[List[Literal["read", "write", "admin"]]],
        access_type: AccessType,
        **kwargs,
    ):
        logger.debug(f"Reading {cls.__name__} with ID(s): {identifiers} with actor={actor}")

        # Start the query
@@ -350,7 +421,10 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
            query = query.where(cls.is_deleted == False)
            query_conditions.append("is_deleted=False")

        results = db_session.execute(query).scalars().all()
        return query, query_conditions

    @classmethod
    def _read_multiple_postprocess(cls, results, identifiers: List[str], query_conditions) -> List["SqlalchemyBase"]:
        if results:  # if empty list a.k.a. no results
            if len(identifiers) > 0:
                # find which identifiers were not found
@@ -471,6 +545,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
        db_session.refresh(self)
        return self

    @handle_db_timeout
    async def update_async(self, db_session: AsyncSession, actor: "User | None" = None, no_commit: bool = False) -> "SqlalchemyBase":
        """Async version of the update function"""
        logger.debug(...)
        if actor:
            self._set_created_and_updated_by_fields(actor.id)
        self.set_updated_at()

        db_session.add(self)
        if no_commit:
            await db_session.flush()
        else:
            await db_session.commit()
        await db_session.refresh(self)
        return self

    @classmethod
    @handle_db_timeout
    def size(
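A sketch of how the new async accessors compose: read_async funnels into read_multiple_async, which awaits execute() on an AsyncSession and reuses the shared pre/post-processing. Block and its module path here stand in for any SqlalchemyBase subclass; both are assumptions for illustration:

    from letta.orm.block import Block  # assumed path to a SqlalchemyBase subclass
    from letta.server.db import db_registry

    async def load_block(block_id: str, actor):
        async with db_registry.async_session() as session:
            # raises NoResultFound with the same backwards-compatible
            # message as the sync read() when nothing matches
            return await Block.read_async(db_session=session, identifier=block_id, actor=actor)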
@@ -1,6 +1,7 @@
from typing import Dict

from marshmallow import fields, post_dump, pre_load
from sqlalchemy.orm import sessionmaker

import letta
from letta.orm import Agent
@@ -14,7 +15,6 @@ from letta.serialize_schemas.marshmallow_custom_fields import EmbeddingConfigFie
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tag import SerializedAgentTagSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.server.db import SessionLocal


class MarshmallowAgentSchema(BaseSchema):
@@ -41,7 +41,7 @@ class MarshmallowAgentSchema(BaseSchema):
    tool_exec_environment_variables = fields.List(fields.Nested(SerializedAgentEnvironmentVariableSchema))
    tags = fields.List(fields.Nested(SerializedAgentTagSchema))

    def __init__(self, *args, session: SessionLocal, actor: User, **kwargs):
    def __init__(self, *args, session: sessionmaker, actor: User, **kwargs):
        super().__init__(*args, actor=actor, **kwargs)
        self.session = session

@@ -60,9 +60,9 @@ class MarshmallowAgentSchema(BaseSchema):
        After dumping the agent, load all its Message rows and serialize them here.
        """
        # TODO: This is hacky, but want to move fast, please refactor moving forward
        from letta.server.db import db_context as session_maker
        from letta.server.db import db_registry

        with session_maker() as session:
        with db_registry.session() as session:
            agent_id = data.get("id")
            msgs = (
                session.query(MessageModel)
@@ -1,28 +1,19 @@
import os
import threading
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager
from typing import Any, AsyncGenerator, Generator

from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy import Engine, create_engine
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker

from letta.config import LettaConfig
from letta.log import get_logger
from letta.orm import Base
from letta.settings import settings

# Use globals for the lock and initialization flag
_engine_lock = threading.Lock()
_engine_initialized = False

# Create variables in global scope but don't initialize them yet
config = LettaConfig.load()
logger = get_logger(__name__)
engine = None
SessionLocal = None


def print_sqlite_schema_error():
    """Print a formatted error message for SQLite schema issues"""
@@ -54,86 +45,187 @@ def db_error_handler():
        exit(1)


def initialize_engine():
    """Initialize the database engine only when needed."""
    global engine, SessionLocal, _engine_initialized
class DatabaseRegistry:
    """Registry for database connections and sessions.

    with _engine_lock:
        # Check again inside the lock to prevent race conditions
        if _engine_initialized:
            return
    This class manages both synchronous and asynchronous database connections
    and provides context managers for session handling.
    """

        if settings.letta_pg_uri_no_default:
            logger.info("Creating postgres engine")
            config.recall_storage_type = "postgres"
            config.recall_storage_uri = settings.letta_pg_uri_no_default
            config.archival_storage_type = "postgres"
            config.archival_storage_uri = settings.letta_pg_uri_no_default
    def __init__(self):
        self._engines: dict[str, Engine] = {}
        self._async_engines: dict[str, AsyncEngine] = {}
        self._session_factories: dict[str, sessionmaker] = {}
        self._async_session_factories: dict[str, async_sessionmaker] = {}
        self._initialized: dict[str, bool] = {"sync": False, "async": False}
        self._lock = threading.Lock()
        self.config = LettaConfig.load()
        self.logger = get_logger(__name__)

            # create engine
            engine = create_engine(
                settings.letta_pg_uri,
                # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
                pool_size=settings.pg_pool_size,
                max_overflow=settings.pg_max_overflow,
                pool_timeout=settings.pg_pool_timeout,
                pool_recycle=settings.pg_pool_recycle,
                echo=settings.pg_echo,
                # connect_args={"client_encoding": "utf8"},
            )
        else:
            # TODO: don't rely on config storage
            engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
            logger.info("Creating sqlite engine " + engine_path)
    def initialize_sync(self, force: bool = False) -> None:
        """Initialize the synchronous database engine if not already initialized."""
        with self._lock:
            if self._initialized.get("sync") and not force:
                return

            engine = create_engine(engine_path)
            # Postgres engine
            if settings.letta_pg_uri_no_default:
                self.logger.info("Creating postgres engine")
                self.config.recall_storage_type = "postgres"
                self.config.recall_storage_uri = settings.letta_pg_uri_no_default
                self.config.archival_storage_type = "postgres"
                self.config.archival_storage_uri = settings.letta_pg_uri_no_default

            # Store the original connect method
            original_connect = engine.connect
                engine = create_engine(
                    settings.letta_pg_uri,
                    # f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
                    pool_size=settings.pg_pool_size,
                    max_overflow=settings.pg_max_overflow,
                    pool_timeout=settings.pg_pool_timeout,
                    pool_recycle=settings.pg_pool_recycle,
                    echo=settings.pg_echo,
                    # connect_args={"client_encoding": "utf8"},
                )

            def wrapped_connect(*args, **kwargs):
                with db_error_handler():
                    # Get the connection
                    connection = original_connect(*args, **kwargs)
                self._engines["default"] = engine
            # SQLite engine
            else:
                from letta.orm import Base

                    # Store the original execution method
                    original_execute = connection.execute
                # TODO: don't rely on config storage
                engine_path = "sqlite:///" + os.path.join(self.config.recall_storage_path, "sqlite.db")
                self.logger.info("Creating sqlite engine " + engine_path)

                    # Wrap the execute method of the connection
                    def wrapped_execute(*args, **kwargs):
                        with db_error_handler():
                            return original_execute(*args, **kwargs)
                engine = create_engine(engine_path)

                    # Replace the connection's execute method
                    connection.execute = wrapped_execute
                # Wrap the engine with error handling
                self._wrap_sqlite_engine(engine)

                    return connection
                Base.metadata.create_all(bind=engine)
                self._engines["default"] = engine

            # Replace the engine's connect method
            engine.connect = wrapped_connect
            # Create session factory
            self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"])
            self._initialized["sync"] = True

            Base.metadata.create_all(bind=engine)
    def initialize_async(self, force: bool = False) -> None:
        """Initialize the asynchronous database engine if not already initialized."""
        with self._lock:
            if self._initialized.get("async") and not force:
                return

        # Create the session factory
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        _engine_initialized = True
            if settings.letta_pg_uri_no_default:
                self.logger.info("Creating async postgres engine")

                # Create async engine - convert URI to async format
                pg_uri = settings.letta_pg_uri
                if pg_uri.startswith("postgresql://"):
                    async_pg_uri = pg_uri.replace("postgresql://", "postgresql+asyncpg://")
                else:
                    async_pg_uri = f"postgresql+asyncpg://{pg_uri.split('://', 1)[1]}" if "://" in pg_uri else pg_uri

                async_engine = create_async_engine(
                    async_pg_uri,
                    pool_size=settings.pg_pool_size,
                    max_overflow=settings.pg_max_overflow,
                    pool_timeout=settings.pg_pool_timeout,
                    pool_recycle=settings.pg_pool_recycle,
                    echo=settings.pg_echo,
                )

                self._async_engines["default"] = async_engine

                # Create async session factory
                self._async_session_factories["default"] = async_sessionmaker(
                    autocommit=False, autoflush=False, bind=self._async_engines["default"], class_=AsyncSession
                )
                self._initialized["async"] = True
            else:
                self.logger.warning("Async SQLite is currently not supported. Please use PostgreSQL for async database operations.")
                # TODO (cliandy): unclear around async sqlite support in sqlalchemy, we will not currently support this
                self._initialized["async"] = False

    def _wrap_sqlite_engine(self, engine: Engine) -> None:
        """Wrap SQLite engine with error handling."""
        original_connect = engine.connect

        def wrapped_connect(*args, **kwargs):
            with db_error_handler():
                connection = original_connect(*args, **kwargs)
                original_execute = connection.execute

                def wrapped_execute(*args, **kwargs):
                    with db_error_handler():
                        return original_execute(*args, **kwargs)

                connection.execute = wrapped_execute
                return connection

        engine.connect = wrapped_connect

    def get_engine(self, name: str = "default") -> Engine:
        """Get a database engine by name."""
        self.initialize_sync()
        return self._engines.get(name)

    def get_async_engine(self, name: str = "default") -> AsyncEngine:
        """Get an async database engine by name."""
        self.initialize_async()
        return self._async_engines.get(name)

    def get_session_factory(self, name: str = "default") -> sessionmaker:
        """Get a session factory by name."""
        self.initialize_sync()
        return self._session_factories.get(name)

    def get_async_session_factory(self, name: str = "default") -> async_sessionmaker:
        """Get an async session factory by name."""
        self.initialize_async()
        return self._async_session_factories.get(name)

    @contextmanager
    def session(self, name: str = "default") -> Generator[Any, None, None]:
        """Context manager for database sessions."""
        session_factory = self.get_session_factory(name)
        if not session_factory:
            raise ValueError(f"No session factory found for '{name}'")

        session = session_factory()
        try:
            yield session
        finally:
            session.close()

    @asynccontextmanager
    async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]:
        """Async context manager for database sessions."""
        session_factory = self.get_async_session_factory(name)
        if not session_factory:
            raise ValueError(f"No async session factory found for '{name}' or async database is not configured")

        session = session_factory()
        try:
            yield session
        finally:
            await session.close()


# Create a singleton instance
db_registry = DatabaseRegistry()


def get_db():
    """Get a database session, initializing the engine if needed."""
    global engine, SessionLocal

    # Make sure engine is initialized
    if not _engine_initialized:
        initialize_engine()

    # Now SessionLocal should be defined and callable
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
    """Get a database session."""
    with db_registry.session() as session:
        yield session


# Define db_context as a context manager that uses get_db
async def get_db_async():
    """Get an async database session."""
    async with db_registry.async_session() as session:
        yield session


# Prefer calling db_registry.session() or db_registry.async_session() directly
# This is for backwards compatibility
db_context = contextmanager(get_db)
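Taken together, the hunk above replaces the module-global engine/SessionLocal pair with a lazily initialized DatabaseRegistry singleton; get_db and db_context survive only as thin backwards-compatibility wrappers. A minimal usage sketch, assuming a Postgres URI is configured (the async engine requires it, per the warning above):

    from sqlalchemy import text

    from letta.server.db import db_registry

    def ping_sync() -> int:
        # first use lazily triggers initialize_sync() under the registry lock
        with db_registry.session() as session:
            return session.execute(text("SELECT 1")).scalar_one()

    async def ping_async() -> int:
        # first use lazily triggers initialize_async(); on SQLite this raises
        # ValueError because no async engine is registered
        async with db_registry.async_session() as session:
            result = await session.execute(text("SELECT 1"))
            return result.scalar_one()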
@@ -56,6 +56,7 @@ from letta.serialize_schemas import MarshmallowAgentSchema
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
from letta.serialize_schemas.marshmallow_tool import SerializedToolSchema
from letta.serialize_schemas.pydantic_agent_schema import AgentSchema
from letta.server.db import db_registry
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import (
    _apply_filters,
@@ -85,9 +86,6 @@ class AgentManager:
    """Manager class to handle business logic related to Agents."""

    def __init__(self):
        from letta.server.db import db_context

        self.session_maker = db_context
        self.block_manager = BlockManager()
        self.tool_manager = ToolManager()
        self.source_manager = SourceManager()
@@ -200,7 +198,7 @@ class AgentManager:
        identity_ids = agent_create.identity_ids or []
        tag_values = agent_create.tags or []

        with self.session_maker() as session:
        with db_registry.session() as session:
            with session.begin():
                name_to_id, id_to_name = self._resolve_tools(
                    session,
@@ -356,7 +354,7 @@ class AgentManager:
        new_idents = set(agent_update.identity_ids or [])
        new_tags = set(agent_update.tags or [])

        with self.session_maker() as session, session.begin():
        with db_registry.session() as session, session.begin():

            agent: AgentModel = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            agent.updated_at = datetime.now(timezone.utc)
@@ -503,7 +501,7 @@ class AgentManager:
        Returns:
            List[PydanticAgentState]: The filtered list of matching agents.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            query = select(AgentModel).distinct(AgentModel.created_at, AgentModel.id)
            query = AgentModel.apply_access_predicate(query, actor, ["read"], AccessType.ORGANIZATION)

@@ -541,7 +539,7 @@ class AgentManager:
        Returns:
            List[PydanticAgentState]: The filtered list of matching agents.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            query = select(AgentModel).where(AgentModel.organization_id == actor.organization_id)

            if match_all:
@@ -569,20 +567,20 @@ class AgentManager:
        """
        Get the total count of agents for the given user.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            return AgentModel.size(db_session=session, actor=actor)

    @enforce_types
    def get_agent_by_id(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its ID."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            return agent.to_pydantic()

    @enforce_types
    def get_agent_by_name(self, agent_name: str, actor: PydanticUser) -> PydanticAgentState:
        """Fetch an agent by its name."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, name=agent_name, actor=actor)
            return agent.to_pydantic()

@@ -599,7 +597,7 @@ class AgentManager:
        Raises:
            NoResultFound: If agent doesn't exist
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Retrieve the agent
            logger.debug(f"Hard deleting Agent with ID: {agent_id} with actor={actor}")
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
@@ -635,7 +633,7 @@ class AgentManager:

    @enforce_types
    def serialize(self, agent_id: str, actor: PydanticUser) -> AgentSchema:
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            schema = MarshmallowAgentSchema(session=session, actor=actor)
            data = schema.dump(agent)
@@ -665,7 +663,7 @@ class AgentManager:

        serialized_agent_dict[MarshmallowAgentSchema.FIELD_MESSAGE_IDS] = message_ids

        with self.session_maker() as session:
        with db_registry.session() as session:
            schema = MarshmallowAgentSchema(session=session, actor=actor)
            agent = schema.load(serialized_agent_dict, session=session)

@@ -728,7 +726,7 @@ class AgentManager:
        Returns:
            PydanticAgentState: The updated agent as a Pydantic model.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Retrieve the agent
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -767,7 +765,7 @@ class AgentManager:

    @enforce_types
    def list_groups(self, agent_id: str, actor: PydanticUser, manager_type: Optional[str] = None) -> List[PydanticGroup]:
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            if manager_type:
                return [group.to_pydantic() for group in agent.groups if group.manager_type == manager_type]
@@ -908,7 +906,7 @@ class AgentManager:
        Returns:
            PydanticAgentState: The updated agent state with no linked messages.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Retrieve the existing agent (will raise NoResultFound if invalid)
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -985,6 +983,17 @@ class AgentManager:
        )
        return agent_state

    @enforce_types
    async def refresh_memory_async(self, agent_state: PydanticAgentState, actor: PydanticUser) -> PydanticAgentState:
        block_ids = [b.id for b in agent_state.memory.blocks]
        if not block_ids:
            return agent_state

        agent_state.memory.blocks = await self.block_manager.get_all_blocks_by_ids_async(
            block_ids=[b.id for b in agent_state.memory.blocks], actor=actor
        )
        return agent_state

    # ======================================================================================================================
    # Source Management
    # ======================================================================================================================
@@ -1003,7 +1012,7 @@ class AgentManager:
            IntegrityError: If the source is already attached to the agent
        """

        with self.session_maker() as session:
        with db_registry.session() as session:
            # Verify both agent and source exist and user has permission to access them
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -1056,7 +1065,7 @@ class AgentManager:
        Returns:
            List[str]: List of source IDs attached to the agent
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -1073,7 +1082,7 @@ class AgentManager:
            source_id: ID of the source to detach
            actor: User performing the action
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Verify agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -1101,7 +1110,7 @@ class AgentManager:
        actor: PydanticUser,
    ) -> PydanticBlock:
        """Gets a block attached to an agent by its label."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            for block in agent.core_memory:
                if block.label == block_label:
@@ -1117,7 +1126,7 @@ class AgentManager:
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Updates which block is assigned to a specific label for an agent."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            new_block = BlockModel.read(db_session=session, identifier=new_block_id, actor=actor)

@@ -1135,7 +1144,7 @@ class AgentManager:
    @enforce_types
    def attach_block(self, agent_id: str, block_id: str, actor: PydanticUser) -> PydanticAgentState:
        """Attaches a block to an agent."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)

@@ -1151,7 +1160,7 @@ class AgentManager:
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block from an agent."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            original_length = len(agent.core_memory)

@@ -1171,7 +1180,7 @@ class AgentManager:
        actor: PydanticUser,
    ) -> PydanticAgentState:
        """Detaches a block with the specified label from an agent."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            original_length = len(agent.core_memory)

@@ -1215,7 +1224,7 @@ class AgentManager:
        embedded_text = np.array(embedded_text)
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

        with self.session_maker() as session:
        with db_registry.session() as session:
            # Start with base query for source passages
            source_passages = None
            if not agent_only:  # Include source passages
@@ -1389,7 +1398,7 @@ class AgentManager:
        agent_only: bool = False,
    ) -> List[PydanticPassage]:
        """Lists all passages attached to an agent."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            main_query = self._build_passage_query(
                actor=actor,
                agent_id=agent_id,
@@ -1447,7 +1456,7 @@ class AgentManager:
        agent_only: bool = False,
    ) -> int:
        """Returns the count of passages matching the given criteria."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            main_query = self._build_passage_query(
                actor=actor,
                agent_id=agent_id,
@@ -1487,7 +1496,7 @@ class AgentManager:
        Returns:
            PydanticAgentState: The updated agent state.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Verify the agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -1522,7 +1531,7 @@ class AgentManager:
        Returns:
            PydanticAgentState: The updated agent state.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Verify the agent exists and user has permission to access it
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@@ -1551,7 +1560,7 @@ class AgentManager:
        Returns:
            List[PydanticTool]: List of tools attached to the agent.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
            return [tool.to_pydantic() for tool in agent.tools]

@@ -1574,7 +1583,7 @@ class AgentManager:
        Returns:
            List[str]: List of all tags.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            query = (
                session.query(AgentsTags.tag)
                .join(AgentModel, AgentModel.id == AgentsTags.agent_id)
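The new refresh_memory_async mirrors refresh_memory but returns early when the agent has no blocks, skipping the DB roundtrip entirely. A small usage sketch; the caller names are assumed:

    # Sketch: refresh an agent's memory blocks without blocking the event loop.
    async def refresh(agent_manager, agent_state, actor):
        return await agent_manager.refresh_memory_async(agent_state=agent_state, actor=actor)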
@@ -12,6 +12,7 @@ from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, list_human_files, list_persona_files

logger = get_logger(__name__)
@@ -20,12 +21,6 @@ logger = get_logger(__name__)
class BlockManager:
    """Manager class to handle business logic related to Blocks."""

    def __init__(self):
        # Fetching the db_context similarly as in ToolManager
        from letta.server.db import db_context

        self.session_maker = db_context

    @enforce_types
    def create_or_update_block(self, block: PydanticBlock, actor: PydanticUser) -> PydanticBlock:
        """Create a new block based on the Block schema."""
@@ -34,7 +29,7 @@ class BlockManager:
            update_data = BlockUpdate(**block.model_dump(to_orm=True, exclude_none=True))
            self.update_block(block.id, update_data, actor)
        else:
            with self.session_maker() as session:
            with db_registry.session() as session:
                data = block.model_dump(to_orm=True, exclude_none=True)
                block = BlockModel(**data, organization_id=actor.organization_id)
                block.create(session, actor=actor)
@@ -53,7 +48,7 @@ class BlockManager:
        if not blocks:
            return []

        with self.session_maker() as session:
        with db_registry.session() as session:
            block_models = [
                BlockModel(**block.model_dump(to_orm=True, exclude_none=True), organization_id=actor.organization_id) for block in blocks
            ]
@@ -68,7 +63,7 @@ class BlockManager:
        """Update a block by its ID with the given BlockUpdate object."""
        # Safety check for block

        with self.session_maker() as session:
        with db_registry.session() as session:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            update_data = block_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

@@ -81,7 +76,7 @@ class BlockManager:
    @enforce_types
    def delete_block(self, block_id: str, actor: PydanticUser) -> PydanticBlock:
        """Delete a block by its ID."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            block = BlockModel.read(db_session=session, identifier=block_id)
            block.hard_delete(db_session=session, actor=actor)
            return block.to_pydantic()
@@ -100,7 +95,7 @@ class BlockManager:
        limit: Optional[int] = 50,
    ) -> List[PydanticBlock]:
        """Retrieve blocks based on various optional filters."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            # Prepare filters
            filters = {"organization_id": actor.organization_id}
            if label:
@@ -126,7 +121,7 @@ class BlockManager:
    @enforce_types
    def get_block_by_id(self, block_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticBlock]:
        """Retrieve a block by its ID."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            try:
                block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
                return block.to_pydantic()
@@ -136,12 +131,24 @@ class BlockManager:
    @enforce_types
    def get_all_blocks_by_ids(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
        """Retrieve blocks by their ids."""
        with self.session_maker() as session:
        with db_registry.session() as session:
            blocks = [block.to_pydantic() for block in BlockModel.read_multiple(db_session=session, identifiers=block_ids, actor=actor)]
            # backwards compatibility. previous implementation added None for every block not found.
            blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
            return blocks

    @enforce_types
    async def get_all_blocks_by_ids_async(self, block_ids: List[str], actor: Optional[PydanticUser] = None) -> List[PydanticBlock]:
        """Retrieve blocks by their ids. Async implementation."""
        async with db_registry.async_session() as session:
            blocks = [
                block.to_pydantic()
                for block in await BlockModel.read_multiple_async(db_session=session, identifiers=block_ids, actor=actor)
            ]
            # backwards compatibility. previous implementation added None for every block not found.
            blocks.extend([None for _ in range(len(block_ids) - len(blocks))])
            return blocks

    @enforce_types
    def add_default_blocks(self, actor: PydanticUser):
        for persona_file in list_persona_files():
@@ -161,7 +168,7 @@ class BlockManager:
        """
        Retrieve all agents associated with a given block.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
            agents_orm = block.agents
            agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
@@ -176,7 +183,7 @@ class BlockManager:
        """
        Get the total count of blocks for the given user.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            return BlockModel.size(db_session=session, actor=actor)

    # Block History Functions
@@ -199,7 +206,7 @@ class BlockManager:
        strictly linear history.
        - A single commit at the end ensures atomicity.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # 1) Load the Block
            if use_preloaded_block is not None:
                block = session.merge(use_preloaded_block)
@@ -291,7 +298,7 @@ class BlockManager:
        If older sequences have been pruned, we jump to the largest sequence
        number that is still < current_seq.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            # 1) Load the current block
            block = (
                session.merge(use_preloaded_block)
@@ -333,7 +340,7 @@ class BlockManager:
        If some middle checkpoints have been pruned, we jump to the smallest
        sequence > current_seq that remains.
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            block = (
                session.merge(use_preloaded_block)
                if use_preloaded_block
@@ -383,7 +390,7 @@ class BlockManager:
        NoResultFound if any block_id doesn’t exist or isn’t visible to this actor
        ValueError if any new value exceeds its block’s limit
        """
        with self.session_maker() as session:
        with db_registry.session() as session:
            q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
            blocks = q.all()
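Both block accessors pad their result with None entries so the returned list always has len(block_ids) elements, preserving the previous contract. A sketch with hypothetical ids, one found and one missing:

    blocks = await block_manager.get_all_blocks_by_ids_async(
        block_ids=["block-1", "block-missing"], actor=actor
    )
    assert len(blocks) == 2
    assert blocks[-1] is None  # missing ids are padded with None at the tail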
@ -11,16 +11,12 @@ from letta.schemas.group import GroupCreate, GroupUpdate, ManagerType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types


class GroupManager:

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def list_groups(
self,
@ -31,7 +27,7 @@ class GroupManager:
after: Optional[str] = None,
limit: Optional[int] = 50,
) -> list[PydanticGroup]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@ -48,13 +44,13 @@ class GroupManager:

@enforce_types
def retrieve_group(self, group_id: str, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
return group.to_pydantic()

@enforce_types
def create_group(self, group: GroupCreate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
new_group = GroupModel()
new_group.organization_id = actor.organization_id
new_group.description = group.description
@ -99,7 +95,7 @@ class GroupManager:

@enforce_types
def modify_group(self, group_id: str, group_update: GroupUpdate, actor: PydanticUser) -> PydanticGroup:
with self.session_maker() as session:
with db_registry.session() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)

sleeptime_agent_frequency = None
@ -161,7 +157,7 @@ class GroupManager:

@enforce_types
def delete_group(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the agent
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
group.hard_delete(session)
@ -178,7 +174,7 @@ class GroupManager:
assistant_message_tool_name: str = "send_message",
assistant_message_tool_kwarg: str = "message",
) -> list[LettaMessage]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {
"organization_id": actor.organization_id,
"group_id": group_id,
@ -204,7 +200,7 @@ class GroupManager:

@enforce_types
def reset_messages(self, group_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)

@ -217,7 +213,7 @@ class GroupManager:

@enforce_types
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)

@ -228,7 +224,7 @@ class GroupManager:

@enforce_types
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
with self.session_maker() as session:
with db_registry.session() as session:
# Ensure group is loadable by user
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)

@ -247,7 +243,7 @@ class GroupManager:
"""
Get the total count of groups for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return GroupModel.size(db_session=session, actor=actor)

def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):

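GroupManager shows the mechanical half of the migration that repeats in every file below: the deferred `from letta.server.db import db_context` import and the `self.session_maker = db_context` attribute drop out of __init__, and each `with self.session_maker() as session:` becomes `with db_registry.session() as session:`. The registry itself lives in letta/server/db.py and is not reproduced in this diff; a hedged sketch of the shape it plausibly takes (only session() and async_session() are attested by the diff, everything else here is an assumption):

    from contextlib import asynccontextmanager, contextmanager

    from sqlalchemy import create_engine
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
    from sqlalchemy.orm import sessionmaker

    class DatabaseRegistry:
        """Illustrative only -- the committed implementation is not in this diff."""

        def __init__(self, sync_url: str, async_url: str):
            self._engine = create_engine(sync_url)
            self._session_maker = sessionmaker(bind=self._engine)
            self._async_engine = create_async_engine(async_url)
            self._async_session_maker = async_sessionmaker(bind=self._async_engine)

        @contextmanager
        def session(self):
            # Sync path: backs every `with db_registry.session() as session:` above.
            with self._session_maker() as session:
                yield session

        @asynccontextmanager
        async def async_session(self):
            # Async path: backs `async with db_registry.async_session() as session:`.
            async with self._async_session_maker() as session:
                yield session
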
@ -10,16 +10,12 @@ from letta.orm.identity import Identity as IdentityModel
from letta.schemas.identity import Identity as PydanticIdentity
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types


class IdentityManager:

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def list_identities(
self,
@ -32,7 +28,7 @@ class IdentityManager:
limit: Optional[int] = 50,
actor: PydanticUser = None,
) -> list[PydanticIdentity]:
with self.session_maker() as session:
with db_registry.session() as session:
filters = {"organization_id": actor.organization_id}
if project_id:
filters["project_id"] = project_id
@ -52,13 +48,13 @@ class IdentityManager:

@enforce_types
def get_identity(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
return identity.to_pydantic()

@enforce_types
def create_identity(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
new_identity.organization_id = actor.organization_id
self._process_relationship(
@ -82,7 +78,7 @@ class IdentityManager:

@enforce_types
def upsert_identity(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
existing_identity = IdentityModel.read(
db_session=session,
identifier_key=identity.identifier_key,
@ -107,7 +103,7 @@ class IdentityManager:

@enforce_types
def update_identity(self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
try:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
except NoResultFound:
@ -167,7 +163,7 @@ class IdentityManager:

@enforce_types
def upsert_identity_properties(self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser) -> PydanticIdentity:
with self.session_maker() as session:
with db_registry.session() as session:
existing_identity = IdentityModel.read(db_session=session, identifier=identity_id, actor=actor)
if existing_identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
@ -181,7 +177,7 @@ class IdentityManager:

@enforce_types
def delete_identity(self, identity_id: str, actor: PydanticUser) -> None:
with self.session_maker() as session:
with db_registry.session() as session:
identity = IdentityModel.read(db_session=session, identifier=identity_id)
if identity is None:
raise HTTPException(status_code=404, detail="Identity not found")
@ -198,7 +194,7 @@ class IdentityManager:
"""
Get the total count of identities for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return IdentityModel.size(db_session=session, actor=actor)

def _process_relationship(

@ -24,24 +24,19 @@ from letta.schemas.run import Run as PydanticRun
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types


class JobManager:
"""Manager class to handle business logic related to Jobs."""

def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def create_job(
self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
"""Create a new job based on the JobCreate schema."""
with self.session_maker() as session:
with db_registry.session() as session:
# Associate the job with the user
pydantic_job.user_id = actor.id
job_data = pydantic_job.model_dump(to_orm=True)
@ -52,7 +47,7 @@ class JobManager:
@enforce_types
def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
"""Update a job by its ID with the given JobUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch the job by ID
job = self._verify_job_access(session=session, job_id=job_id, actor=actor, access=["write"])

@ -76,7 +71,7 @@ class JobManager:
@enforce_types
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Fetch a job by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve job by ID using the Job model's read method
job = JobModel.read(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER)
return job.to_pydantic()
@ -93,7 +88,7 @@ class JobManager:
ascending: bool = True,
) -> List[PydanticJob]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
with db_registry.session() as session:
filter_kwargs = {"user_id": actor.id, "job_type": job_type}

# Add status filter if provided
@ -113,7 +108,7 @@ class JobManager:
@enforce_types
def delete_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
"""Delete a job by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
job = self._verify_job_access(session=session, job_id=job_id, actor=actor)
job.hard_delete(db_session=session, actor=actor)
return job.to_pydantic()
@ -147,7 +142,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Build filters
filters = {}
if role is not None:
@ -195,7 +190,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Build filters
filters = {}
filters["job_id"] = job_id
@ -227,7 +222,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"])

@ -251,7 +246,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor)

@ -293,7 +288,7 @@ class JobManager:
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
with db_registry.session() as session:
# First verify job exists and user has access
self._verify_job_access(session, job_id, actor, access=["write"])

@ -453,7 +448,7 @@ class JobManager:
Returns:
The request config for the job
"""
with self.session_maker() as session:
with db_registry.session() as session:
job = session.query(JobModel).filter(JobModel.id == run_id).first()
request_config = job.request_config or LettaRequestConfig()
return request_config

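Note how JobManager funnels reads and writes through _verify_job_access(...) inside the freshly opened session, so the access check and the mutation share one transaction. The helper body is outside this diff; a sketch consistent with the contract the docstrings state (NoResultFound when the job is missing or not visible to the actor; the exact filter columns are an assumption):

    from letta.orm.errors import NoResultFound

    def _verify_job_access(self, session, job_id, actor, access=("read",)):
        # Assumed filter: ownership via user_id; the real helper may do more.
        job = (
            session.query(JobModel)
            .filter(JobModel.id == job_id, JobModel.user_id == actor.id)
            .first()
        )
        if job is None:
            # Documented contract: missing and inaccessible look identical.
            raise NoResultFound(f"Job with id {job_id} does not exist or is not accessible")
        return job
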
@ -16,6 +16,7 @@ from letta.schemas.llm_batch_job import LLMBatchJob as PydanticLLMBatchJob
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types

logger = get_logger(__name__)
@ -24,11 +25,6 @@ logger = get_logger(__name__)
class LLMBatchManager:
"""Manager for handling both LLMBatchJob and LLMBatchItem operations."""

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def create_llm_batch_job(
self,
@ -39,7 +35,7 @@ class LLMBatchManager:
status: JobStatus = JobStatus.created,
) -> PydanticLLMBatchJob:
"""Create a new LLM batch job."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob(
status=status,
llm_provider=llm_provider,
@ -53,7 +49,7 @@ class LLMBatchManager:
@enforce_types
def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob:
"""Retrieve a single batch job by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
return batch.to_pydantic()

@ -66,7 +62,7 @@ class LLMBatchManager:
latest_polling_response: Optional[BetaMessageBatch] = None,
) -> PydanticLLMBatchJob:
"""Update a batch job’s status and optionally its polling response."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.status = status
batch.latest_polling_response = latest_polling_response
@ -85,7 +81,7 @@ class LLMBatchManager:
"""
now = datetime.datetime.now(datetime.timezone.utc)

with self.session_maker() as session:
with db_registry.session() as session:
mappings = []
for llm_batch_id, status, response in updates:
mappings.append(
@ -119,7 +115,7 @@ class LLMBatchManager:

The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id)

if actor is not None:
@ -140,7 +136,7 @@ class LLMBatchManager:
@enforce_types
def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch job by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor)
batch.hard_delete(db_session=session, actor=actor)

@ -158,7 +154,7 @@ class LLMBatchManager:
Retrieve messages across all LLM batch jobs associated with a Letta batch job.
Optimized for PostgreSQL performance using ID-based keyset pagination.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# If cursor is provided, get sequence_id for that message
cursor_sequence_id = None
if cursor:
@ -203,7 +199,7 @@ class LLMBatchManager:
@enforce_types
def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]:
"""Return all running LLM batch jobs, optionally filtered by actor's organization."""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running)

if actor is not None:
@ -224,7 +220,7 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Create a new batch item."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem(
llm_batch_id=llm_batch_id,
agent_id=agent_id,
@ -249,7 +245,7 @@ class LLMBatchManager:
Returns:
List of created batch items as Pydantic models
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Convert Pydantic models to ORM objects
orm_items = []
for item in llm_batch_items:
@ -274,7 +270,7 @@ class LLMBatchManager:
@enforce_types
def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem:
"""Retrieve a single batch item by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
return item.to_pydantic()

@ -289,7 +285,7 @@ class LLMBatchManager:
step_state: Optional[AgentStepState] = None,
) -> PydanticLLMBatchItem:
"""Update fields on a batch item."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)

if request_status:
@ -325,7 +321,7 @@ class LLMBatchManager:

The results are ordered by their id in ascending order.
"""
with self.session_maker() as session:
with db_registry.session() as session:
query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id)

if actor is not None:
@ -367,7 +363,7 @@ class LLMBatchManager:
if len(llm_batch_id_agent_id_pairs) != len(field_updates):
raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length")

with self.session_maker() as session:
with db_registry.session() as session:
# Lookup primary keys for all requested (batch_id, agent_id) pairs
items = (
session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id)
@ -434,7 +430,7 @@ class LLMBatchManager:
@enforce_types
def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None:
"""Hard delete a batch item by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor)
item.hard_delete(db_session=session, actor=actor)

@ -449,6 +445,6 @@ class LLMBatchManager:
Returns:
int: The total number of batch items associated with the given llm_batch_id.
"""
with self.session_maker() as session:
with db_registry.session() as session:
count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar()
return count or 0

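The message-listing docstring above promises "ID-based keyset pagination": the cursor message's sequence_id is resolved once, then pages filter on an indexed sequence_id bound rather than OFFSET, which would rescan every skipped row. The full query is truncated in this hunk; a sketch of the named technique (model and column names follow the hunk, the query shape is an assumption):

    def list_after_cursor(session, cursor, limit):
        # Resolve the cursor message to its sequence_id once...
        cursor_sequence_id = None
        if cursor:
            cursor_sequence_id = (
                session.query(MessageModel.sequence_id)
                .filter(MessageModel.id == cursor)
                .scalar()
            )
        query = session.query(MessageModel)
        if cursor_sequence_id is not None:
            # ...then page with an indexed range predicate instead of OFFSET.
            query = query.filter(MessageModel.sequence_id > cursor_sequence_id)
        return query.order_by(MessageModel.sequence_id.asc()).limit(limit).all()
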
@ -12,6 +12,7 @@ from letta.schemas.letta_message import LettaMessageUpdateUnion
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.message import MessageUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types

logger = get_logger(__name__)
@ -20,15 +21,10 @@ logger = get_logger(__name__)
class MessageManager:
"""Manager class to handle business logic related to Messages."""

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def get_message_by_id(self, message_id: str, actor: PydanticUser) -> Optional[PydanticMessage]:
"""Fetch a message by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
message = MessageModel.read(db_session=session, identifier=message_id, actor=actor)
return message.to_pydantic()
@ -38,7 +34,7 @@ class MessageManager:
@enforce_types
def get_messages_by_ids(self, message_ids: List[str], actor: PydanticUser) -> List[PydanticMessage]:
"""Fetch messages by ID and return them in the requested order."""
with self.session_maker() as session:
with db_registry.session() as session:
results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids))

if len(results) != len(message_ids):
@ -53,7 +49,7 @@ class MessageManager:
@enforce_types
def create_message(self, pydantic_msg: PydanticMessage, actor: PydanticUser) -> PydanticMessage:
"""Create a new message."""
with self.session_maker() as session:
with db_registry.session() as session:
# Set the organization id of the Pydantic message
pydantic_msg.organization_id = actor.organization_id
msg_data = pydantic_msg.model_dump(to_orm=True)
@ -86,7 +82,7 @@ class MessageManager:
orm_messages.append(MessageModel(**msg_data))

# Use the batch_create method for efficient creation
with self.session_maker() as session:
with db_registry.session() as session:
created_messages = MessageModel.batch_create(orm_messages, session, actor=actor)

# Convert back to Pydantic models
@ -173,7 +169,7 @@ class MessageManager:
"""
Updates an existing record in the database with values from the provided record object.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch existing message from database
message = MessageModel.read(
db_session=session,
@ -181,31 +177,57 @@ class MessageManager:
actor=actor,
)

# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)

# get update dictionary
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}

for key, value in update_data.items():
setattr(message, key, value)
message = self._update_message_by_id_impl(message_id, message_update, actor, message)
message.update(db_session=session, actor=actor)

return message.to_pydantic()

@enforce_types
async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
Async version of the function above.
"""
async with db_registry.async_session() as session:
# Fetch existing message from database
message = await MessageModel.read_async(
db_session=session,
identifier=message_id,
actor=actor,
)

message = self._update_message_by_id_impl(message_id, message_update, actor, message)
await message.update_async(db_session=session, actor=actor)
return message.to_pydantic()

def _update_message_by_id_impl(
self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel
) -> MessageModel:
"""
Modifies the existing message object to update the database in the sync/async functions.
"""
# Some safety checks specific to messages
if message_update.tool_calls and message.role != MessageRole.assistant:
raise ValueError(
f"Tool calls {message_update.tool_calls} can only be added to assistant messages. Message {message_id} has role {message.role}."
)
if message_update.tool_call_id and message.role != MessageRole.tool:
raise ValueError(
f"Tool call IDs {message_update.tool_call_id} can only be added to tool messages. Message {message_id} has role {message.role}."
)

# get update dictionary
update_data = message_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Remove redundant update fields
update_data = {key: value for key, value in update_data.items() if getattr(message, key) != value}

for key, value in update_data.items():
setattr(message, key, value)
return message

@enforce_types
def delete_message_by_id(self, message_id: str, actor: PydanticUser) -> bool:
"""Delete a message."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
msg = MessageModel.read(
db_session=session,
@ -229,7 +251,7 @@ class MessageManager:
actor: The user requesting the count
role: The role of the message
"""
with self.session_maker() as session:
with db_registry.session() as session:
return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id)

@enforce_types
@ -293,7 +315,7 @@ class MessageManager:
NoResultFound: If the provided after/before message IDs do not exist.
"""

with self.session_maker() as session:
with db_registry.session() as session:
# Permission check: raise if the agent doesn't exist or actor is not allowed.
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

@ -356,7 +378,7 @@ class MessageManager:
Efficiently deletes all messages associated with a given agent_id,
while enforcing permission checks and avoiding any ORM‑level loads.
"""
with self.session_maker() as session:
with db_registry.session() as session:
# 1) verify the agent exists and the actor has access
AgentModel.read(db_session=session, identifier=agent_id, actor=actor)

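The MessageManager refactor above is the template this PR uses for adding an async twin without duplicating logic: everything session-free (the role checks, computing the changed-field dict, mutating the ORM object) moves into _update_message_by_id_impl, and the sync and async wrappers differ only in how they open the session, read, and persist. Reduced to its shape (a sketch with the letta-specific checks elided; ThingModel is a stand-in, not a real class):

    def update_thing(self, thing_id, update, actor):
        with db_registry.session() as session:
            thing = ThingModel.read(db_session=session, identifier=thing_id, actor=actor)
            thing = self._update_impl(update, thing)
            thing.update(db_session=session, actor=actor)
            return thing.to_pydantic()

    async def update_thing_async(self, thing_id, update, actor):
        async with db_registry.async_session() as session:
            thing = await ThingModel.read_async(db_session=session, identifier=thing_id, actor=actor)
            thing = self._update_impl(update, thing)
            await thing.update_async(db_session=session, actor=actor)
            return thing.to_pydantic()

    def _update_impl(self, update, thing):
        # Pure mutation, no session access: both wrappers can share it.
        for key, value in update.model_dump(exclude_unset=True, exclude_none=True).items():
            if getattr(thing, key) != value:
                setattr(thing, key, value)
        return thing
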
@ -4,6 +4,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.organization import OrganizationUpdate
from letta.server.db import db_registry
from letta.utils import enforce_types


@ -13,14 +14,6 @@ class OrganizationManager:
DEFAULT_ORG_ID = "org-00000000-0000-4000-8000-000000000000"
DEFAULT_ORG_NAME = "default_org"

def __init__(self):
# TODO: Please refactor this out
# I am currently working on a ORM refactor and would like to make a more minimal set of changes
# - Matt
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def get_default_organization(self) -> PydanticOrganization:
"""Fetch the default organization."""
@ -29,7 +22,7 @@ class OrganizationManager:
@enforce_types
def get_organization_by_id(self, org_id: str) -> Optional[PydanticOrganization]:
"""Fetch an organization by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
organization = OrganizationModel.read(db_session=session, identifier=org_id)
return organization.to_pydantic()

@ -44,7 +37,7 @@ class OrganizationManager:

@enforce_types
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
org.create(session)
return org.to_pydantic()
@ -57,7 +50,7 @@ class OrganizationManager:
@enforce_types
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if name:
org.name = name
@ -67,7 +60,7 @@ class OrganizationManager:
@enforce_types
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with self.session_maker() as session:
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
if org_update.name:
org.name = org_update.name
@ -79,14 +72,14 @@ class OrganizationManager:
@enforce_types
def delete_organization_by_id(self, org_id: str):
"""Delete an organization by marking it as deleted."""
with self.session_maker() as session:
with db_registry.session() as session:
organization = OrganizationModel.read(db_session=session, identifier=org_id)
organization.hard_delete(session)

@enforce_types
def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List all organizations with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
organizations = OrganizationModel.list(
db_session=session,
after=after,

@ -10,21 +10,17 @@ from letta.orm.passage import AgentPassage, SourcePassage
from letta.schemas.agent import AgentState
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types


class PassageManager:
"""Manager class to handle business logic related to Passages."""

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
"""Fetch a passage by ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
@ -69,7 +65,7 @@ class PassageManager:
else:
raise ValueError("Passage must have either agent_id or source_id")

with self.session_maker() as session:
with db_registry.session() as session:
passage.create(session, actor=actor)
return passage.to_pydantic()

@ -145,7 +141,7 @@ class PassageManager:
if not passage_id:
raise ValueError("Passage ID must be provided.")

with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
curr_passage = SourcePassage.read(
@ -179,7 +175,7 @@ class PassageManager:
if not passage_id:
raise ValueError("Passage ID must be provided.")

with self.session_maker() as session:
with db_registry.session() as session:
# Try source passages first
try:
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
@ -217,7 +213,7 @@ class PassageManager:
actor: The user requesting the count
agent_id: The agent ID of the messages
"""
with self.session_maker() as session:
with db_registry.session() as session:
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)

def estimate_embeddings_size(

@ -5,20 +5,16 @@ from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types


class ProviderManager:

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def create_provider(self, request: ProviderCreate, actor: PydanticUser) -> PydanticProvider:
"""Create a new provider if it doesn't already exist."""
with self.session_maker() as session:
with db_registry.session() as session:
provider_create_args = {**request.model_dump(), "provider_category": ProviderCategory.byok}
provider = PydanticProvider(**provider_create_args)

@ -38,7 +34,7 @@ class ProviderManager:
@enforce_types
def update_provider(self, provider_id: str, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider:
"""Update provider details."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve the existing provider by ID
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)

@ -54,7 +50,7 @@ class ProviderManager:
@enforce_types
def delete_provider_by_id(self, provider_id: str, actor: PydanticUser):
"""Delete a provider."""
with self.session_maker() as session:
with db_registry.session() as session:
# Clear api key field
existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor)
existing_provider.api_key = None
@ -80,7 +76,7 @@ class ProviderManager:
filter_kwargs["name"] = name
if provider_type:
filter_kwargs["provider_type"] = provider_type
with self.session_maker() as session:
with db_registry.session() as session:
providers = ProviderModel.list(
db_session=session,
after=after,

@ -11,6 +11,7 @@ from letta.schemas.sandbox_config import LocalSandboxConfig
from letta.schemas.sandbox_config import SandboxConfig as PydanticSandboxConfig
from letta.schemas.sandbox_config import SandboxConfigCreate, SandboxConfigUpdate, SandboxType
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd

logger = get_logger(__name__)
@ -19,11 +20,6 @@ logger = get_logger(__name__)
class SandboxConfigManager:
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def get_or_create_default_sandbox_config(self, sandbox_type: SandboxType, actor: PydanticUser) -> PydanticSandboxConfig:
sandbox_config = self.get_sandbox_config_by_type(sandbox_type, actor=actor)
@ -69,7 +65,7 @@ class SandboxConfigManager:
return db_sandbox
else:
# If the sandbox configuration doesn't exist, create a new one
with self.session_maker() as session:
with db_registry.session() as session:
db_sandbox = SandboxConfigModel(**sandbox_config.model_dump(exclude_none=True))
db_sandbox.create(session, actor=actor)
return db_sandbox.to_pydantic()
@ -79,7 +75,7 @@ class SandboxConfigManager:
self, sandbox_config_id: str, sandbox_update: SandboxConfigUpdate, actor: PydanticUser
) -> PydanticSandboxConfig:
"""Update an existing sandbox configuration."""
with self.session_maker() as session:
with db_registry.session() as session:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
# We need to check that the sandbox_update provided is the same type as the original sandbox
if sandbox.type != sandbox_update.config.type:
@ -104,7 +100,7 @@ class SandboxConfigManager:
@enforce_types
def delete_sandbox_config(self, sandbox_config_id: str, actor: PydanticUser) -> PydanticSandboxConfig:
"""Delete a sandbox configuration by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
sandbox.hard_delete(db_session=session, actor=actor)
return sandbox.to_pydantic()
@ -122,14 +118,14 @@ class SandboxConfigManager:
if sandbox_type:
kwargs.update({"type": sandbox_type})

with self.session_maker() as session:
with db_registry.session() as session:
sandboxes = SandboxConfigModel.list(db_session=session, after=after, limit=limit, **kwargs)
return [sandbox.to_pydantic() for sandbox in sandboxes]

@enforce_types
def get_sandbox_config_by_id(self, sandbox_config_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox configuration by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
sandbox = SandboxConfigModel.read(db_session=session, identifier=sandbox_config_id, actor=actor)
return sandbox.to_pydantic()
@ -139,7 +135,7 @@ class SandboxConfigManager:
@enforce_types
def get_sandbox_config_by_type(self, type: SandboxType, actor: Optional[PydanticUser] = None) -> Optional[PydanticSandboxConfig]:
"""Retrieve a sandbox config by its type."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
sandboxes = SandboxConfigModel.list(
db_session=session,
@ -175,7 +171,7 @@ class SandboxConfigManager:

return db_env_var
else:
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel(**env_var.model_dump(to_orm=True, exclude_none=True))
env_var.create(session, actor=actor)
return env_var.to_pydantic()
@ -185,7 +181,7 @@ class SandboxConfigManager:
self, env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, actor: PydanticUser
) -> PydanticEnvVar:
"""Update an existing sandbox environment variable."""
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
update_data = env_var_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
update_data = {key: value for key, value in update_data.items() if getattr(env_var, key) != value}
@ -204,7 +200,7 @@ class SandboxConfigManager:
@enforce_types
def delete_sandbox_env_var(self, env_var_id: str, actor: PydanticUser) -> PydanticEnvVar:
"""Delete a sandbox environment variable by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
env_var = SandboxEnvVarModel.read(db_session=session, identifier=env_var_id, actor=actor)
env_var.hard_delete(db_session=session, actor=actor)
return env_var.to_pydantic()
@ -218,7 +214,7 @@ class SandboxConfigManager:
limit: Optional[int] = 50,
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
env_vars = SandboxEnvVarModel.list(
db_session=session,
after=after,
@ -233,7 +229,7 @@ class SandboxConfigManager:
self, key: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticEnvVar]:
"""List all sandbox environment variables with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
env_vars = SandboxEnvVarModel.list(
db_session=session,
after=after,
@ -258,7 +254,7 @@ class SandboxConfigManager:
self, key: str, sandbox_config_id: str, actor: Optional[PydanticUser] = None
) -> Optional[PydanticEnvVar]:
"""Retrieve a sandbox environment variable by its key and sandbox_config_id."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
env_var = SandboxEnvVarModel.list(
db_session=session,

@ -8,17 +8,13 @@ from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd


class SourceManager:
"""Manager class to handle business logic related to Sources."""

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
@ -27,7 +23,7 @@ class SourceManager:
if db_source:
return db_source
else:
with self.session_maker() as session:
with db_registry.session() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
@ -37,7 +33,7 @@ class SourceManager:
@enforce_types
def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
"""Update a source by its ID with the given SourceUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)

# get update dictionary
@ -59,7 +55,7 @@ class SourceManager:
@enforce_types
def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
"""Delete a source by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id)
source.hard_delete(db_session=session, actor=actor)
return source.to_pydantic()
@ -67,7 +63,7 @@ class SourceManager:
@enforce_types
def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]:
"""List all sources with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
sources = SourceModel.list(
db_session=session,
after=after,
@ -85,7 +81,7 @@ class SourceManager:
"""
Get the total count of sources for the given user.
"""
with self.session_maker() as session:
with db_registry.session() as session:
return SourceModel.size(db_session=session, actor=actor)

@enforce_types
@ -100,7 +96,7 @@ class SourceManager:
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
with self.session_maker() as session:
with db_registry.session() as session:
# Verify source exists and user has permission to access it
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)

@ -112,7 +108,7 @@ class SourceManager:
@enforce_types
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
return source.to_pydantic()
@ -122,7 +118,7 @@ class SourceManager:
@enforce_types
def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
"""Retrieve a source by its name."""
with self.session_maker() as session:
with db_registry.session() as session:
sources = SourceModel.list(
db_session=session,
name=source_name,
@ -141,7 +137,7 @@ class SourceManager:
if db_file:
return db_file
else:
with self.session_maker() as session:
with db_registry.session() as session:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
file_metadata.create(session, actor=actor)
@ -151,7 +147,7 @@ class SourceManager:
@enforce_types
def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor)
return file.to_pydantic()
@ -163,7 +159,7 @@ class SourceManager:
self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
files = FileMetadataModel.list(
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
)
@ -172,7 +168,7 @@ class SourceManager:
@enforce_types
def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
file = FileMetadataModel.read(db_session=session, identifier=file_id)
file.hard_delete(db_session=session, actor=actor)
return file.to_pydantic()

@ -11,17 +11,13 @@ from letta.orm.step import Step as StepModel
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.tracing import get_trace_id
from letta.utils import enforce_types


class StepManager:

def __init__(self):
from letta.server.db import db_context

self.session_maker = db_context

@enforce_types
def list_steps(
self,
@ -36,7 +32,7 @@ class StepManager:
agent_id: Optional[str] = None,
) -> List[PydanticStep]:
"""List all jobs with optional pagination and status filter."""
with self.session_maker() as session:
with db_registry.session() as session:
filter_kwargs = {"organization_id": actor.organization_id}
if model:
filter_kwargs["model"] = model
@ -85,7 +81,7 @@ class StepManager:
"tid": None,
"trace_id": get_trace_id(),  # Get the current trace ID
}
with self.session_maker() as session:
with db_registry.session() as session:
if job_id:
self._verify_job_access(session, job_id, actor, access=["write"])
new_step = StepModel(**step_data)
@ -94,7 +90,7 @@ class StepManager:

@enforce_types
def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep:
with self.session_maker() as session:
with db_registry.session() as session:
step = StepModel.read(db_session=session, identifier=step_id, actor=actor)
return step.to_pydantic()

@ -113,7 +109,7 @@ class StepManager:
Raises:
NoResultFound: If the step does not exist
"""
with self.session_maker() as session:
with db_registry.session() as session:
step = session.get(StepModel, step_id)
if not step:
raise NoResultFound(f"Step with id {step_id} does not exist")

@ -23,6 +23,7 @@ from letta.orm.tool import Tool as ToolModel
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd

logger = get_logger(__name__)
@ -31,12 +32,6 @@ logger = get_logger(__name__)
class ToolManager:
"""Manager class to handle business logic related to Tools."""

def __init__(self):
# Fetching the db_context similarly as in OrganizationManager
from letta.server.db import db_context

self.session_maker = db_context

# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
@enforce_types
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
@ -89,7 +84,7 @@ class ToolManager:
@enforce_types
def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
"""Create a new tool based on the ToolCreate schema."""
with self.session_maker() as session:
with db_registry.session() as session:
# Set the organization id at the ORM layer
pydantic_tool.organization_id = actor.organization_id
# Auto-generate description if not provided
@ -104,7 +99,7 @@ class ToolManager:
@enforce_types
def get_tool_by_id(self, tool_id: str, actor: PydanticUser) -> PydanticTool:
"""Fetch a tool by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
# Retrieve tool by id using the Tool model's read method
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
# Convert the SQLAlchemy Tool object to PydanticTool
@ -114,7 +109,7 @@ class ToolManager:
def get_tool_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
with self.session_maker() as session:
with db_registry.session() as session:
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
return tool.to_pydantic()
except NoResultFound:
@ -124,7 +119,7 @@ class ToolManager:
def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]:
"""Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool."""
try:
with self.session_maker() as session:
with db_registry.session() as session:
tool = ToolModel.read(db_session=session, name=tool_name, actor=actor)
return tool.id
except NoResultFound:
@ -133,7 +128,7 @@ class ToolManager:
@enforce_types
def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]:
"""List all tools with optional pagination."""
with self.session_maker() as session:
with db_registry.session() as session:
tools = ToolModel.list(
db_session=session,
after=after,
@ -166,7 +161,7 @@ class ToolManager:

If include_builtin is True, it will also count the built-in tools.
"""
with self.session_maker() as session:
with db_registry.session() as session:
if include_base_tools:
return ToolModel.size(db_session=session, actor=actor)
return ToolModel.size(db_session=session, actor=actor, name=LETTA_TOOL_SET)
@ -176,7 +171,7 @@ class ToolManager:
self, tool_id: str, tool_update: ToolUpdate, actor: PydanticUser, updated_tool_type: Optional[ToolType] = None
) -> PydanticTool:
"""Update a tool by its ID with the given ToolUpdate object."""
with self.session_maker() as session:
with db_registry.session() as session:
# Fetch the tool by ID
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)

@ -202,7 +197,7 @@ class ToolManager:
@enforce_types
def delete_tool_by_id(self, tool_id: str, actor: PydanticUser) -> None:
"""Delete a tool by its ID."""
with self.session_maker() as session:
with db_registry.session() as session:
try:
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
tool.hard_delete(db_session=session, actor=actor)

@ -5,6 +5,7 @@ from letta.orm.organization import Organization as OrganizationModel
|
||||
from letta.orm.user import User as UserModel
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.schemas.user import UserUpdate
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
@ -15,16 +16,10 @@ class UserManager:
|
||||
DEFAULT_USER_NAME = "default_user"
|
||||
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
|
||||
|
||||
def __init__(self):
|
||||
# Fetching the db_context similarly as in OrganizationManager
|
||||
from letta.server.db import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def create_default_user(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
|
||||
"""Create the default user."""
|
||||
with self.session_maker() as session:
|
||||
with db_registry.session() as session:
|
||||
# Make sure the org id exists
|
||||
try:
|
||||
OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
@ -44,7 +39,7 @@ class UserManager:
|
||||
@enforce_types
|
||||
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
|
||||
"""Create a new user if it doesn't already exist."""
|
||||
with self.session_maker() as session:
|
||||
with db_registry.session() as session:
|
||||
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
|
||||
new_user.create(session)
|
||||
return new_user.to_pydantic()
|
||||
@@ -52,7 +47,7 @@ class UserManager:
     @enforce_types
     def update_user(self, user_update: UserUpdate) -> PydanticUser:
         """Update user details."""
-        with self.session_maker() as session:
+        with db_registry.session() as session:
             # Retrieve the existing user by ID
             existing_user = UserModel.read(db_session=session, identifier=user_update.id)

@@ -68,7 +63,7 @@ class UserManager:
     @enforce_types
     def delete_user_by_id(self, user_id: str):
         """Delete a user and their associated records (agents, sources, mappings)."""
-        with self.session_maker() as session:
+        with db_registry.session() as session:
             # Delete from user table
             user = UserModel.read(db_session=session, identifier=user_id)
             user.hard_delete(session)
@@ -78,7 +73,7 @@ class UserManager:
     @enforce_types
     def get_user_by_id(self, user_id: str) -> PydanticUser:
         """Fetch a user by ID."""
-        with self.session_maker() as session:
+        with db_registry.session() as session:
             user = UserModel.read(db_session=session, identifier=user_id)
             return user.to_pydantic()

@@ -104,7 +99,7 @@ class UserManager:
     @enforce_types
     def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
         """List all users with optional pagination."""
-        with self.session_maker() as session:
+        with db_registry.session() as session:
             users = UserModel.list(
                 db_session=session,
                 after=after,
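Since the commit's goal is an async db client, these synchronous call sites are presumably the stepping stone toward awaitable sessions. A hedged sketch of what an async counterpart could look like, assuming a hypothetical `async_session()` helper (the name and internals are not confirmed by this diff):

# Hypothetical async counterpart; `async_session()` is an assumed name.
from contextlib import asynccontextmanager
from typing import AsyncIterator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine


class AsyncDatabaseRegistry:
    def __init__(self, url: str):
        # An asyncio driver URL, e.g. "postgresql+asyncpg://user:pass@localhost/letta".
        self._engine = create_async_engine(url)
        self._sessionmaker = async_sessionmaker(bind=self._engine, expire_on_commit=False)

    @asynccontextmanager
    async def async_session(self) -> AsyncIterator[AsyncSession]:
        async with self._sessionmaker() as session:
            # begin() commits on clean exit and rolls back on exception.
            async with session.begin():
                yield session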
@@ -203,6 +203,7 @@ class Settings(BaseSettings):
     use_experimental: bool = False
     use_vertex_structured_outputs_experimental: bool = False
     use_vertex_async_loop_experimental: bool = False
+    experimental_enable_async_db_engine: bool = False

     # LLM provider client settings
     httpx_max_retries: int = 5
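The new `experimental_enable_async_db_engine` flag defaults to off. A hedged example of how such a flag is typically consumed when constructing the engine; the wiring below is an assumption for illustration, not the actual letta.server.db logic:

# Hypothetical engine selection gated by the experimental flag.
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine


def build_engine(settings, sync_url: str, async_url: str):
    """Return an async engine only when the experimental flag is enabled."""
    if settings.experimental_enable_async_db_engine:
        # async_url must use an asyncio driver, e.g. "postgresql+asyncpg://...".
        return create_async_engine(async_url)
    return create_engine(sync_url)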
@@ -812,7 +812,7 @@ def printd(*args, **kwargs):
         print(*args, **kwargs)


-def united_diff(str1, str2):
+def united_diff(str1: str, str2: str) -> str:
    lines1 = str1.splitlines(True)
    lines2 = str2.splitlines(True)
    diff = difflib.unified_diff(lines1, lines2)
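The `united_diff` change only adds type annotations; behavior is unchanged. For reference, a small usage example (the sample strings are illustrative):

# united_diff wraps difflib.unified_diff and returns the diff as one string.
from letta.utils import united_diff

old = "persona: helpful assistant\n"
new = "persona: concise assistant\n"
diff = united_diff(old, new)
if diff:
    print(diff)  # standard unified-diff output: ---, +++, -, + lines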
poetry.lock (generated, 69 changes)
@@ -326,6 +326,73 @@ files = [
     {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"},
 ]

+[[package]]
+name = "asyncpg"
+version = "0.30.0"
+description = "An asyncio PostgreSQL driver"
+optional = false
+python-versions = ">=3.8.0"
+groups = ["main"]
+files = [
+    {file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"},
+    {file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"},
+    {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"},
+    {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"},
+    {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"},
+    {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"},
+    {file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"},
+    {file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"},
+    {file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"},
+    {file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"},
+    {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"},
+    {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"},
+    {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"},
+    {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"},
+    {file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"},
+    {file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"},
+    {file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"},
+    {file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"},
+    {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"},
+    {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"},
+    {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"},
+    {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"},
+    {file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"},
+    {file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"},
+    {file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"},
+    {file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"},
+    {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"},
+    {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"},
+    {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"},
+    {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"},
+    {file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"},
+    {file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"},
+    {file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"},
+    {file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"},
+    {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"},
+    {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"},
+    {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"},
+    {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"},
+    {file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"},
+    {file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"},
+    {file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"},
+    {file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"},
+    {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"},
+    {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"},
+    {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"},
+    {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"},
+    {file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"},
+    {file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"},
+    {file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"},
+]
+
+[package.dependencies]
+async-timeout = {version = ">=4.0.3", markers = "python_version < \"3.11.0\""}
+
+[package.extras]
+docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"]
+gssauth = ["gssapi ; platform_system != \"Windows\"", "sspilib ; platform_system == \"Windows\""]
+test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi ; platform_system == \"Linux\"", "k5test ; platform_system == \"Linux\"", "mypy (>=1.8.0,<1.9.0)", "sspilib ; platform_system == \"Windows\"", "uvloop (>=0.15.3) ; platform_system != \"Windows\" and python_version < \"3.14.0\""]
+
 [[package]]
 name = "attrs"
 version = "25.3.0"
@@ -7503,4 +7570,4 @@ tests = ["wikipedia"]
 [metadata]
 lock-version = "2.1"
 python-versions = "<3.14,>=3.10"
-content-hash = "f82fec7b3f35d4222c43b692db8cd005eaf8bcf6761bb202d0dbf64121c6b2ab"
+content-hash = "862dc5a31d4385e89dc9a751cd171a611da3102c6832447a5f61926b25f03e06"
@@ -90,6 +90,7 @@ firecrawl-py = "^1.15.0"
 apscheduler = "^3.11.0"
 aiomultiprocess = "^0.9.1"
 matplotlib = "^3.10.1"
+asyncpg = "^0.30.0"


 [tool.poetry.extras]
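asyncpg is the asyncio PostgreSQL driver that SQLAlchemy's async engine builds on (via postgresql+asyncpg:// URLs). A minimal standalone connectivity check; the DSN is a placeholder, not a real deployment value:

# Smoke test for the new asyncpg dependency (placeholder DSN).
import asyncio

import asyncpg


async def main() -> None:
    conn = await asyncpg.connect("postgresql://letta:letta@localhost:5432/letta")
    try:
        print(await conn.fetchval("SELECT version()"))
    finally:
        await conn.close()


if __name__ == "__main__":
    asyncio.run(main())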
@@ -15,6 +15,7 @@ from letta.schemas.enums import JobStatus, ToolRuleType
 from letta.schemas.group import GroupUpdate, ManagerType, SleeptimeManagerUpdate
 from letta.schemas.message import MessageCreate
 from letta.schemas.run import Run
+from letta.server.db import db_registry
 from letta.server.server import SyncServer
 from letta.utils import get_human_text, get_persona_text

@@ -37,7 +38,7 @@ def org_id(server):
     yield org.id

     # cleanup
-    with server.organization_manager.session_maker() as session:
+    with db_registry.session() as session:
         session.execute(delete(Step))
         session.execute(delete(Provider))
         session.commit()
@@ -15,6 +15,7 @@ from letta.schemas.group import (
     SupervisorManager,
 )
 from letta.schemas.message import MessageCreate
+from letta.server.db import db_registry
 from letta.server.server import SyncServer


@@ -36,7 +37,7 @@ def org_id(server):
     yield org.id

     # cleanup
-    with server.organization_manager.session_maker() as session:
+    with db_registry.session() as session:
         session.execute(delete(Step))
         session.execute(delete(Provider))
         session.commit()
@@ -19,6 +19,7 @@ from letta.schemas.llm_config import LLMConfig
 from letta.schemas.providers import ProviderCreate
 from letta.schemas.sandbox_config import SandboxType
 from letta.schemas.user import User
+from letta.server.db import db_registry

 utils.DEBUG = True
 from letta.config import LettaConfig
@@ -284,7 +285,7 @@ def org_id(server):
     yield org.id

     # cleanup
-    with server.organization_manager.session_maker() as session:
+    with db_registry.session() as session:
         session.execute(delete(Step))
         session.execute(delete(Provider))
         session.commit()