mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com> Co-authored-by: Kevin Lin <klin5061@gmail.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: jnjpng <jin@letta.com> Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
190 lines
8.6 KiB
Python
190 lines
8.6 KiB
Python
import asyncio
|
|
from typing import List, Optional
|
|
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.orm.file import FileMetadata as FileMetadataModel
|
|
from letta.orm.source import Source as SourceModel
|
|
from letta.schemas.agent import AgentState as PydanticAgentState
|
|
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
|
from letta.schemas.source import Source as PydanticSource
|
|
from letta.schemas.source import SourceUpdate
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.tracing import trace_method
|
|
from letta.utils import enforce_types, printd
|
|
|
|
|
|
class SourceManager:
    """Manager class to handle business logic related to Sources and their files.

    Every method opens its own session via ``db_registry.async_session()`` and
    returns Pydantic schema objects (via ``to_pydantic``), never ORM instances.
    """

    @enforce_types
    @trace_method
    async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
        """Create a new source based on the PydanticSource schema.

        If ``get_source_by_id`` already finds a source with ``source.id`` for this
        actor, that existing source is returned and nothing is written.

        Args:
            source: Schema of the source to create. Note: its ``organization_id``
                is overwritten in place with the actor's organization id.
            actor: User performing the action; supplies the owning organization.

        Returns:
            PydanticSource: The pre-existing source, or the newly created one.
        """
        # Creation is idempotent on id: return the existing source if there is one.
        existing = await self.get_source_by_id(source.id, actor=actor)
        if existing:
            return existing

        async with db_registry.async_session() as session:
            # The new source always belongs to the creating actor's organization.
            source.organization_id = actor.organization_id
            source_orm = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
            await source_orm.create_async(session, actor=actor)
            return source_orm.to_pydantic()

    @enforce_types
    @trace_method
    async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
        """Update a source by its ID with the given SourceUpdate object.

        Only fields that were explicitly set (and are not None) on
        ``source_update`` and that differ from the stored values are written.
        When nothing differs, no write is issued and a debug message is printed.

        Args:
            source_id: ID of the source to update.
            source_update: Partial update payload.
            actor: User performing the action.

        Returns:
            PydanticSource: The source after applying the update.

        Raises:
            NoResultFound: If ``read_async`` finds no matching source for ``actor``.
        """
        async with db_registry.async_session() as session:
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

            # Only fields explicitly provided by the caller, ignoring explicit Nones.
            update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
            # Drop fields whose new value already equals the stored value.
            update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value}

            if update_data:
                for key, value in update_data.items():
                    setattr(source, key, value)
                # `session` is an async session, so persist with the awaited async
                # helper (as every other ORM call in this class does); the sync
                # `update()` is not awaited and cannot commit through it.
                await source.update_async(db_session=session, actor=actor)
            else:
                printd(
                    f"`update_source` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={source.name}, but found existing source with nothing to update."
                )

            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
        """Permanently delete a source by its ID.

        Args:
            source_id: ID of the source to delete.
            actor: User performing the action.

        Returns:
            PydanticSource: A snapshot of the source as it was before deletion.

        Raises:
            NoResultFound: If ``read_async`` finds no matching source for ``actor``.
        """
        async with db_registry.async_session() as session:
            # Look the source up with the actor, the same way update_source and
            # get_source_by_id do, so deletion is subject to the same access scoping.
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            # Snapshot before deleting: the row no longer exists once the delete commits.
            deleted = source.to_pydantic()
            await source.hard_delete_async(db_session=session, actor=actor)
            return deleted

    @enforce_types
    @trace_method
    async def list_sources(
        self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs
    ) -> List[PydanticSource]:
        """List sources in the actor's organization with optional cursor pagination.

        Args:
            actor: User whose organization the sources are listed for.
            after: Optional cursor passed through to ``SourceModel.list_async``.
            limit: Maximum number of sources to return (default 50).
            **kwargs: Extra filters forwarded unchanged to ``SourceModel.list_async``.

        Returns:
            List[PydanticSource]: The matching sources.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                after=after,
                limit=limit,
                organization_id=actor.organization_id,
                **kwargs,
            )
            return [source.to_pydantic() for source in sources]

    @enforce_types
    @trace_method
    async def size(self, actor: PydanticUser) -> int:
        """Return the total count of sources for the given actor.

        Args:
            actor: User whose sources are counted (passed to ``SourceModel.size_async``).

        Returns:
            int: Number of sources.
        """
        async with db_registry.async_session() as session:
            return await SourceModel.size_async(db_session=session, actor=actor)

    @enforce_types
    @trace_method
    async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
        """List all agents that have the specified source attached.

        Args:
            source_id: ID of the source to find attached agents for.
            actor: User performing the action (optional for now, following the
                existing pattern in this class).

        Returns:
            List[PydanticAgentState]: Agents attached to this source.

        Raises:
            NoResultFound: If ``read_async`` finds no matching source for ``actor``.
        """
        async with db_registry.async_session() as session:
            # Verify the source exists and is readable by the actor.
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

            # NOTE(review): relies on `source.agents` already being loaded (the Source
            # model is said to use lazy="selectin"); an unloaded relationship cannot be
            # lazily fetched inside an async session — confirm against the ORM model.
            agents_orm = source.agents
            # Convert all agents concurrently; gather returns results in input order.
            return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])

    # TODO: We make actor optional for now, but should most likely be enforced due to security reasons
    @enforce_types
    @trace_method
    async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
        """Retrieve a source by its ID.

        Args:
            source_id: ID of the source.
            actor: Optional user performing the action, forwarded to ``read_async``.

        Returns:
            Optional[PydanticSource]: The source, or None if ``read_async`` raises
            NoResultFound.
        """
        async with db_registry.async_session() as session:
            try:
                source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            except NoResultFound:
                return None
            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
        """Retrieve a source by its name within the actor's organization.

        Args:
            source_name: Exact name of the source.
            actor: User whose organization is searched.

        Returns:
            Optional[PydanticSource]: The first matching source, or None if none match.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                name=source_name,
                organization_id=actor.organization_id,
                limit=1,
            )
            return sources[0].to_pydantic() if sources else None

    @enforce_types
    @trace_method
    async def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata:
        """Create a new file based on the PydanticFileMetadata schema.

        If ``get_file_by_id`` already finds a file with ``file_metadata.id`` for
        this actor, that existing file is returned and nothing is written.

        Args:
            file_metadata: Schema of the file to create. Note: its
                ``organization_id`` is overwritten in place with the actor's
                organization id.
            actor: User performing the action; supplies the owning organization.

        Returns:
            PydanticFileMetadata: The pre-existing file, or the newly created one.
        """
        # Creation is idempotent on id: return the existing file if there is one.
        existing = await self.get_file_by_id(file_metadata.id, actor=actor)
        if existing:
            return existing

        async with db_registry.async_session() as session:
            # The new file always belongs to the creating actor's organization.
            file_metadata.organization_id = actor.organization_id
            file_orm = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
            await file_orm.create_async(session, actor=actor)
            return file_orm.to_pydantic()

    # TODO: We make actor optional for now, but should most likely be enforced due to security reasons
    @enforce_types
    @trace_method
    async def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
        """Retrieve a file by its ID.

        Args:
            file_id: ID of the file.
            actor: Optional user performing the action, forwarded to ``read_async``.

        Returns:
            Optional[PydanticFileMetadata]: The file, or None if ``read_async``
            raises NoResultFound.
        """
        async with db_registry.async_session() as session:
            try:
                file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
            except NoResultFound:
                return None
            return file.to_pydantic()

    @enforce_types
    @trace_method
    async def list_files(
        self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
    ) -> List[PydanticFileMetadata]:
        """List files of a source in the actor's organization with optional pagination.

        Args:
            source_id: ID of the source whose files are listed.
            actor: User whose organization the files are listed for.
            after: Optional cursor passed through to ``FileMetadataModel.list_async``.
            limit: Maximum number of files to return (default 50).

        Returns:
            List[PydanticFileMetadata]: The matching files.
        """
        async with db_registry.async_session() as session:
            # A single paginated query; the previous extra unpaginated query over all
            # files of the source was never used.
            files = await FileMetadataModel.list_async(
                db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
            )
            return [file.to_pydantic() for file in files]

    @enforce_types
    @trace_method
    async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
        """Permanently delete a file by its ID.

        Args:
            file_id: ID of the file to delete.
            actor: User performing the action.

        Returns:
            PydanticFileMetadata: A snapshot of the file as it was before deletion.

        Raises:
            NoResultFound: If ``read_async`` finds no matching file for ``actor``.
        """
        async with db_registry.async_session() as session:
            # Look the file up with the actor, the same way get_file_by_id does, so
            # deletion is subject to the same access scoping.
            file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
            # Snapshot before deleting: the row no longer exists once the delete commits.
            deleted = file.to_pydantic()
            await file.hard_delete_async(db_session=session, actor=actor)
            return deleted