diff --git a/.gitignore b/.gitignore index 12042451e..baaeabfbe 100644 --- a/.gitignore +++ b/.gitignore @@ -1022,3 +1022,6 @@ memgpy/pytest.ini ## ignore venvs tests/test_tool_sandbox/restaurant_management_system/venv + +## custom scripts +test diff --git a/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py b/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py new file mode 100644 index 000000000..a16fdae44 --- /dev/null +++ b/alembic/versions/c5d964280dff_add_passages_orm_drop_legacy_passages_.py @@ -0,0 +1,88 @@ +"""Add Passages ORM, drop legacy passages, cascading deletes for file-passages and user-jobs + +Revision ID: c5d964280dff +Revises: a91994b9752f +Create Date: 2024-12-10 15:05:32.335519 + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = 'c5d964280dff' +down_revision: Union[str, None] = 'a91994b9752f' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True)) + op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False)) + op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True)) + op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True)) + + # Data migration step: + op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True)) + # Populate `organization_id` based on `user_id` + # Use a raw SQL query to update the organization_id + op.execute( + """ + UPDATE passages + SET organization_id = users.organization_id + FROM users + WHERE passages.user_id = users.id + """ + ) + + # Set `organization_id` as non-nullable after population + op.alter_column("passages", "organization_id", nullable=False) + + op.alter_column('passages', 'text', + existing_type=sa.VARCHAR(), + nullable=False) + op.alter_column('passages', 'embedding_config', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + op.alter_column('passages', 'metadata_', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=False) + op.alter_column('passages', 'created_at', + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=False) + op.drop_index('passage_idx_user', table_name='passages') + op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id']) + op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id']) + op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE') + op.drop_column('passages', 'user_id') + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False)) + op.drop_constraint(None, 'passages', type_='foreignkey') + op.drop_constraint(None, 'passages', type_='foreignkey') + op.drop_constraint(None, 'passages', type_='foreignkey') + op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False) + op.alter_column('passages', 'created_at', + existing_type=postgresql.TIMESTAMP(timezone=True), + nullable=True) + op.alter_column('passages', 'metadata_', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + op.alter_column('passages', 'embedding_config', + existing_type=postgresql.JSON(astext_type=sa.Text()), + nullable=True) + op.alter_column('passages', 'text', + existing_type=sa.VARCHAR(), + nullable=True) + op.drop_column('passages', 'organization_id') + op.drop_column('passages', '_last_updated_by_id') + op.drop_column('passages', '_created_by_id') + op.drop_column('passages', 'is_deleted') + op.drop_column('passages', 'updated_at') + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 81924c2ea..efb850e28 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -28,7 +28,7 @@ from letta.interface import AgentInterface from letta.llm_api.helpers import is_context_overflow_error from letta.llm_api.llm_api_tools import create from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages -from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages +from letta.memory import summarize_messages from letta.metadata import MetadataStore from letta.orm import User from letta.schemas.agent import AgentState, AgentStepResponse @@ -52,6 +52,7 @@ from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser from letta.services.block_manager import BlockManager from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_execution_sandbox import ToolExecutionSandbox from letta.services.user_manager import UserManager @@ -85,7 +86,7 @@ def compile_memory_metadata_block( actor: PydanticUser, agent_id: str, memory_edit_timestamp: datetime.datetime, - archival_memory: Optional[ArchivalMemory] = None, + passage_manager: Optional[PassageManager] = None, message_manager: Optional[MessageManager] = None, ) -> str: # Put the timestamp in the local timezone (mimicking get_local_time()) @@ -96,7 +97,7 @@ def compile_memory_metadata_block( [ f"### Memory [last modified: {timestamp_str}]", f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)", - f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)", + f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)", "\nCore memory shown below (limited in size, additional information stored in archival / recall memory):", ] ) @@ -109,7 +110,7 @@ def compile_system_message( in_context_memory: Memory, in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory? 
actor: PydanticUser, - archival_memory: Optional[ArchivalMemory] = None, + passage_manager: Optional[PassageManager] = None, message_manager: Optional[MessageManager] = None, user_defined_variables: Optional[dict] = None, append_icm_if_missing: bool = True, @@ -138,7 +139,7 @@ def compile_system_message( actor=actor, agent_id=agent_id, memory_edit_timestamp=in_context_memory_last_edit, - archival_memory=archival_memory, + passage_manager=passage_manager, message_manager=message_manager, ) full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile() @@ -175,7 +176,7 @@ def initialize_message_sequence( agent_id: str, memory: Memory, actor: PydanticUser, - archival_memory: Optional[ArchivalMemory] = None, + passage_manager: Optional[PassageManager] = None, message_manager: Optional[MessageManager] = None, memory_edit_timestamp: Optional[datetime.datetime] = None, include_initial_boot_message: bool = True, @@ -184,7 +185,7 @@ def initialize_message_sequence( memory_edit_timestamp = get_local_time() # full_system_message = construct_system_with_memory( - # system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory + # system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory # ) full_system_message = compile_system_message( agent_id=agent_id, @@ -192,7 +193,7 @@ def initialize_message_sequence( in_context_memory=memory, in_context_memory_last_edit=memory_edit_timestamp, actor=actor, - archival_memory=archival_memory, + passage_manager=passage_manager, message_manager=message_manager, user_defined_variables=None, append_icm_if_missing=True, @@ -294,7 +295,7 @@ class Agent(BaseAgent): self.interface = interface # Create the persistence manager object based on the AgentState info - self.archival_memory = EmbeddingArchivalMemory(agent_state) + self.passage_manager = PassageManager() self.message_manager = MessageManager() # State needed for heartbeat pausing @@ -325,7 +326,7 @@ class Agent(BaseAgent): agent_id=self.agent_state.id, memory=self.agent_state.memory, actor=self.user, - archival_memory=None, + passage_manager=None, message_manager=None, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, @@ -350,7 +351,7 @@ class Agent(BaseAgent): memory=self.agent_state.memory, agent_id=self.agent_state.id, actor=self.user, - archival_memory=None, + passage_manager=None, message_manager=None, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, @@ -1306,7 +1307,7 @@ class Agent(BaseAgent): in_context_memory=self.agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, actor=self.user, - archival_memory=self.archival_memory, + passage_manager=self.passage_manager, message_manager=self.message_manager, user_defined_variables=None, append_icm_if_missing=True, @@ -1371,45 +1372,33 @@ class Agent(BaseAgent): # TODO: recall memory raise NotImplementedError() - def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore): + def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore): """Attach data with name `source_name` to the agent from source_connector.""" # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory - user = UserManager().get_user_by_id(self.agent_state.user_id) - filters = {"user_id": self.agent_state.user_id, "source_id": source_id} - size = source_connector.size(filters) 
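For context on the new call shape: attach_source no longer receives a raw StorageConnector; the caller resolves the acting user and passes a SourceManager explicitly, and the passages are re-tagged through PassageManager. A rough caller-side sketch, assuming an already-constructed Agent (`agent`), a MetadataStore (`ms`), and a known `user_id`; the source name is a placeholder, and only the method names and keyword arguments are taken from this diff:

    from letta.services.source_manager import SourceManager
    from letta.services.user_manager import UserManager

    user = UserManager().get_user_by_id(user_id=user_id)  # acting user
    source = SourceManager().get_source_by_name(source_name="my_docs", actor=user)  # "my_docs" is a placeholder
    # re-tags the source's passages with agent.agent_state.id via PassageManager
    agent.attach_source(user=user, source_id=source.id, source_manager=SourceManager(), ms=ms)
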
page_size = 100 - generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage] - all_passages = [] - for i in tqdm(range(0, size, page_size)): - passages = next(generator) + passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size) - # need to associated passage with agent (for filtering) - for passage in passages: - assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}" - passage.agent_id = self.agent_state.id + for passage in passages: + assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}" + passage.agent_id = self.agent_state.id + self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user) - # regenerate passage ID (avoid duplicates) - # TODO: need to find another solution to the text duplication issue - # passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}") - - # insert into agent archival memory - self.archival_memory.storage.insert_many(passages) - all_passages += passages - - assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}" - - # save destination storage - self.archival_memory.storage.save() + agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size) + passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id) + assert all([p.agent_id == self.agent_state.id for p in agents_passages]) + assert len(agents_passages) == passage_size # sanity check + assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}" # attach to agent - source = SourceManager().get_source_by_id(source_id=source_id, actor=user) + source = source_manager.get_source_by_id(source_id=source_id, actor=user) assert source is not None, f"Source {source_id} not found in metadata store" + + # NOTE: need this redundant line here because we haven't migrated agent to ORM yet + # TODO: delete @matt and remove ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id) - total_agent_passages = self.archival_memory.storage.size() - printd( - f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.", + f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. 
Agent now has {passage_size} embeddings in archival memory.", ) def update_message(self, message_id: str, request: MessageUpdate) -> Message: @@ -1565,13 +1554,13 @@ class Agent(BaseAgent): num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0 ) - num_archival_memory = self.archival_memory.storage.size() + passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id) message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id) external_memory_summary = compile_memory_metadata_block( actor=self.user, agent_id=self.agent_state.id, memory_edit_timestamp=get_utc_time(), # dummy timestamp - archival_memory=self.archival_memory, + passage_manager=self.passage_manager, message_manager=self.message_manager, ) num_tokens_external_memory_summary = count_tokens(external_memory_summary) @@ -1597,7 +1586,7 @@ class Agent(BaseAgent): return ContextWindowOverview( # context window breakdown (in messages) num_messages=len(self._messages), - num_archival_memory=num_archival_memory, + num_archival_memory=passage_manager_size, num_recall_memory=message_manager_size, num_tokens_external_memory_summary=num_tokens_external_memory_summary, # top-level information diff --git a/letta/agent_store/chroma.py b/letta/agent_store/chroma.py deleted file mode 100644 index eace737b3..000000000 --- a/letta/agent_store/chroma.py +++ /dev/null @@ -1,297 +0,0 @@ -from typing import Dict, List, Optional, Tuple, cast - -import chromadb -from chromadb.api.types import Include - -from letta.agent_store.storage import StorageConnector, TableType -from letta.config import LettaConfig -from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.passage import Passage -from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime - - -class ChromaStorageConnector(StorageConnector): - """Storage via Chroma""" - - # WARNING: This is not thread safe. Do NOT do concurrent access to the same collection. 
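With the Chroma connector removed, archival passages live in the `passages` ORM table and similarity search happens in SQL: pgvector's cosine distance on Postgres, or a registered `cosine_distance` function over base64-encoded float32 blobs on SQLite (both added later in this diff). The tool-side query pattern is roughly the following sketch; it mirrors the `archival_memory_search` change further down, with `agent_id` added as a filter that `list_passages` accepts elsewhere in this diff, and the query text as a placeholder:

    from letta.services.passage_manager import PassageManager

    passage_manager = PassageManager()
    results = passage_manager.list_passages(
        actor=user,                                    # acting PydanticUser, assumed in scope
        agent_id=agent_state.id,                       # scope to one agent's archival memory
        query_text="project deadlines",                # placeholder query
        embedding_config=agent_state.embedding_config,
        embed_query=True,                              # embed query_text and rank by cosine distance
        limit=10,
    )
    for p in results:
        print(p.created_at, p.text)
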
- # Timestamps are converted to integer timestamps for chroma (datetime not supported) - - def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None): - super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id) - - assert table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES, "Chroma only supports archival memory" - - # create chroma client - if config.archival_storage_path: - self.client = chromadb.PersistentClient(config.archival_storage_path) - else: - # assume uri={ip}:{port} - ip = config.archival_storage_uri.split(":")[0] - port = config.archival_storage_uri.split(":")[1] - self.client = chromadb.HttpClient(host=ip, port=port) - - # get a collection or create if it doesn't exist already - self.collection = self.client.get_or_create_collection(self.table_name) - self.include: Include = ["documents", "embeddings", "metadatas"] - - def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]: - # get all filters for query - if filters is not None: - filter_conditions = {**self.filters, **filters} - else: - filter_conditions = self.filters - - # convert to chroma format - chroma_filters = [] - ids = [] - for key, value in filter_conditions.items(): - # filter by id - if key == "id": - ids = [str(value)] - continue - - # filter by other keys - chroma_filters.append({key: {"$eq": value}}) - - if len(chroma_filters) > 1: - chroma_filters = {"$and": chroma_filters} - elif len(chroma_filters) == 0: - chroma_filters = {} - else: - chroma_filters = chroma_filters[0] - return ids, chroma_filters - - def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0): - ids, filters = self.get_filters(filters) - while True: - # Retrieve a chunk of records with the given page_size - results = self.collection.get(ids=ids, offset=offset, limit=page_size, include=self.include, where=filters) - - # If the chunk is empty, we've retrieved all records - assert results["embeddings"] is not None, f"results['embeddings'] was None" - if len(results["embeddings"]) == 0: - break - - # Yield a list of Record objects converted from the chunk - yield self.results_to_records(results) - - # Increment the offset to get the next chunk in the next iteration - offset += page_size - - def results_to_records(self, results): - # convert timestamps to datetime - for metadata in results["metadatas"]: - if "created_at" in metadata: - metadata["created_at"] = timestamp_to_datetime(metadata["created_at"]) - if results["embeddings"]: # may not be returned, depending on table type - passages = [] - for text, record_id, embedding, metadata in zip( - results["documents"], results["ids"], results["embeddings"], results["metadatas"] - ): - args = {} - for field in EmbeddingConfig.__fields__.keys(): - if field in metadata: - args[field] = metadata[field] - del metadata[field] - embedding_config = EmbeddingConfig(**args) - passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata)) - # return [ - # Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas) - # for (text, record_id, embedding, metadatas) in zip( - # results["documents"], results["ids"], results["embeddings"], results["metadatas"] - # ) - # ] - return passages - else: - # no embeddings - passages = [] - for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]): - args = {} - for field in EmbeddingConfig.__fields__.keys(): 
- if field in metadata: - args[field] = metadata[field] - del metadata[field] - embedding_config = EmbeddingConfig(**args) - passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata)) - return passages - - # return [ - # #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore - # Passage(text=text, embedding=None, id=id, **metadatas) - # for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"]) - # ] - - def get_all(self, filters: Optional[Dict] = {}, limit=None): - ids, filters = self.get_filters(filters) - if self.collection.count() == 0: - return [] - if ids == []: - ids = None - if limit: - results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit) - else: - results = self.collection.get(ids=ids, include=self.include, where=filters) - return self.results_to_records(results) - - def get(self, id: str): - results = self.collection.get(ids=[str(id)]) - if len(results["ids"]) == 0: - return None - return self.results_to_records(results)[0] - - def format_records(self, records): - assert all([isinstance(r, Passage) for r in records]) - - recs = [] - ids = [] - documents = [] - embeddings = [] - - # de-duplication of ids - exist_ids = set() - for i in range(len(records)): - record = records[i] - if record.id in exist_ids: - continue - exist_ids.add(record.id) - recs.append(cast(Passage, record)) - ids.append(str(record.id)) - documents.append(record.text) - embeddings.append(record.embedding) - - # collect/format record metadata - metadatas = [] - for record in recs: - embedding_config = vars(record.embedding_config) - metadata = vars(record) - metadata.pop("id") - metadata.pop("text") - metadata.pop("embedding") - metadata.pop("embedding_config") - metadata.pop("metadata_") - if "created_at" in metadata: - metadata["created_at"] = datetime_to_timestamp(metadata["created_at"]) - if "metadata_" in metadata and metadata["metadata_"] is not None: - record_metadata = dict(metadata["metadata_"]) - metadata.pop("metadata_") - else: - record_metadata = {} - - metadata = {**metadata, **record_metadata} # merge with metadata - metadata = {**metadata, **embedding_config} # merge with embedding config - metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed - - # convert uuids to strings - metadatas.append(metadata) - return ids, documents, embeddings, metadatas - - def insert(self, record): - ids, documents, embeddings, metadatas = self.format_records([record]) - if any([e is None for e in embeddings]): - raise ValueError("Embeddings must be provided to chroma") - self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas) - - def insert_many(self, records, show_progress=False): - ids, documents, embeddings, metadatas = self.format_records(records) - if any([e is None for e in embeddings]): - raise ValueError("Embeddings must be provided to chroma") - self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas) - - def delete(self, filters: Optional[Dict] = {}): - ids, filters = self.get_filters(filters) - self.collection.delete(ids=ids, where=filters) - - def delete_table(self): - # drop collection - self.client.delete_collection(self.collection.name) - - def save(self): - # save to persistence file (nothing needs to be done) - printd("Saving chroma") - - def size(self, filters: 
Optional[Dict] = {}) -> int: - # unfortuantely, need to use pagination to get filtering - # warning: poor performance for large datasets - return len(self.get_all(filters=filters)) - - def list_data_sources(self): - raise NotImplementedError - - def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}): - ids, filters = self.get_filters(filters) - results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters) - - # flatten, since we only have one query vector - flattened_results = {} - for key, value in results.items(): - if value: - # value is an Optional[List] type according to chromadb.api.types - flattened_results[key] = value[0] # type: ignore - assert len(value) == 1, f"Value is size {len(value)}: {value}" # type: ignore - else: - flattened_results[key] = value - - return self.results_to_records(flattened_results) - - def query_date(self, start_date, end_date, start=None, count=None): - raise ValueError("Cannot run query_date with chroma") - # filters = self.get_filters(filters) - # filters["created_at"] = { - # "$gte": start_date, - # "$lte": end_date, - # } - # results = self.collection.query(where=filters) - # start = 0 if start is None else start - # count = len(results) if count is None else count - # results = results[start : start + count] - # return self.results_to_records(results) - - def query_text(self, query, count=None, start=None, filters: Optional[Dict] = {}): - raise ValueError("Cannot run query_text with chroma") - # filters = self.get_filters(filters) - # results = self.collection.query(where_document={"$contains": {"text": query}}, where=filters) - # start = 0 if start is None else start - # count = len(results) if count is None else count - # results = results[start : start + count] - # return self.results_to_records(results) - - def get_all_cursor( - self, - filters: Optional[Dict] = {}, - after: str = None, - before: str = None, - limit: Optional[int] = 1000, - order_by: str = "created_at", - reverse: bool = False, - ): - records = self.get_all(filters=filters) - - # WARNING: very hacky and slow implementation - def get_index(id, record_list): - for i in range(len(record_list)): - if record_list[i].id == id: - return i - assert False, f"Could not find id {id} in record list" - - # sort by custom field - records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse) - if after: - index = get_index(after, records) - if index + 1 >= len(records): - return None, [] - records = records[index + 1 :] - if before: - index = get_index(before, records) - if index == 0: - return None, [] - - # TODO: not sure if this is correct - records = records[:index] - - if len(records) == 0: - return None, [] - - # enforce limit - if limit: - records = records[:limit] - return records[-1].id, records diff --git a/letta/agent_store/db.py b/letta/agent_store/db.py index 56d35edc7..095a0e82c 100644 --- a/letta/agent_store/db.py +++ b/letta/agent_store/db.py @@ -1,4 +1,5 @@ import base64 +import json import os from datetime import datetime from typing import Dict, List, Optional @@ -32,7 +33,7 @@ from letta.orm.base import Base from letta.orm.file import FileMetadata as FileMetadataModel # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall -from letta.schemas.passage import Passage +from letta.orm.passage import Passage as PassageModel from letta.settings import settings config = LettaConfig() @@ -66,56 +67,6 @@ class CommonVector(TypeDecorator): # 
For PostgreSQL, value is already in bytes return np.frombuffer(value, dtype=np.float32) - -class PassageModel(Base): - """Defines data model for storing Passages (consisting of text, embedding)""" - - __tablename__ = "passages" - __table_args__ = {"extend_existing": True} - - # Assuming passage_id is the primary key - id = Column(String, primary_key=True) - user_id = Column(String, nullable=False) - text = Column(String) - file_id = Column(String) - agent_id = Column(String) - source_id = Column(String) - - # vector storage - if settings.letta_pg_uri_no_default: - from pgvector.sqlalchemy import Vector - - embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) - elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma": - embedding = Column(CommonVector) - else: - raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}") - embedding_config = Column(EmbeddingConfigColumn) - metadata_ = Column(MutableJson) - - # Add a datetime column, with default value as the current time - created_at = Column(DateTime(timezone=True)) - - Index("passage_idx_user", user_id, agent_id, file_id), - - def __repr__(self): - return f" Optional[str]: Returns: Optional[str]: None is always returned as this function does not produce a response. """ - self.archival_memory.insert(content) + self.passage_manager.insert_passage( + agent_state=self.agent_state, + agent_id=self.agent_state.id, + text=content, + actor=self.user, + ) return None -def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]: +def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. Args: query (str): String to search for. page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page). + start (Optional[int]): Starting index for the search results. Defaults to 0. Returns: str: Query result string @@ -191,15 +197,34 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) - except: raise ValueError(f"'page' argument must be an integer") count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE - results, total = self.archival_memory.search(query, count=count, start=page * count) - num_pages = math.ceil(total / count) - 1 # 0 index - if len(results) == 0: - results_str = f"No results found." 
- else: - results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):" - results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results] - results_str = f"{results_pref} {json_dumps(results_formatted)}" - return results_str + + try: + # Get results using passage manager + all_results = self.passage_manager.list_passages( + actor=self.user, + query_text=query, + limit=count + start, # Request enough results to handle offset + embedding_config=self.agent_state.embedding_config, + embed_query=True + ) + + # Apply pagination + end = min(count + start, len(all_results)) + paged_results = all_results[start:end] + + # Format results to match previous implementation + formatted_results = [ + { + "timestamp": str(result.created_at), + "content": result.text + } + for result in paged_results + ] + + return formatted_results, len(formatted_results) + + except Exception as e: + raise e def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore diff --git a/letta/metadata.py b/letta/metadata.py index 56d852eac..017e546e4 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -363,8 +363,19 @@ class MetadataStore: with self.session_maker() as session: # TODO: remove this (is a hack) mapping_id = f"{user_id}-{agent_id}-{source_id}" - session.add(AgentSourceMappingModel(id=mapping_id, user_id=user_id, agent_id=agent_id, source_id=source_id)) - session.commit() + existing = session.query(AgentSourceMappingModel).filter( + AgentSourceMappingModel.id == mapping_id + ).first() + + if existing is None: + # Only create if it doesn't exist + session.add(AgentSourceMappingModel( + id=mapping_id, + user_id=user_id, + agent_id=agent_id, + source_id=source_id + )) + session.commit() @enforce_types def list_attached_source_ids(self, agent_id: str) -> List[str]: diff --git a/letta/orm/__init__.py b/letta/orm/__init__.py index 85b4b7eb3..b7f7bb96f 100644 --- a/letta/orm/__init__.py +++ b/letta/orm/__init__.py @@ -6,6 +6,7 @@ from letta.orm.file import FileMetadata from letta.orm.job import Job from letta.orm.message import Message from letta.orm.organization import Organization +from letta.orm.passage import Passage from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable from letta.orm.source import Source from letta.orm.tool import Tool diff --git a/letta/orm/file.py b/letta/orm/file.py index 187ebbd88..6f7111639 100644 --- a/letta/orm/file.py +++ b/letta/orm/file.py @@ -27,3 +27,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin") source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin") + passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan") diff --git a/letta/orm/mixins.py b/letta/orm/mixins.py index 355a8b2ce..60c319d98 100644 --- a/letta/orm/mixins.py +++ b/letta/orm/mixins.py @@ -1,3 +1,4 @@ +from typing import Optional from uuid import UUID from sqlalchemy import ForeignKey, String @@ -30,6 +31,12 @@ class UserMixin(Base): user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id")) +class FileMixin(Base): + """Mixin for models that belong to a file.""" + + __abstract__ = True + + file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id")) class AgentMixin(Base): """Mixin for models that belong to an agent.""" @@ 
-38,13 +45,16 @@ class AgentMixin(Base): agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id")) - class FileMixin(Base): """Mixin for models that belong to a file.""" __abstract__ = True - file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id")) + file_id: Mapped[Optional[str]] = mapped_column( + String, + ForeignKey("files.id", ondelete="CASCADE"), + nullable=True + ) class SourceMixin(Base): diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 4e5b6d12c..8dc56e162 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -33,7 +33,10 @@ class Organization(SqlalchemyBase): sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship( "SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan" ) + + # relationships messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan") + passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan") # TODO: Map these relationships later when we actually make these models # below is just a suggestion diff --git a/letta/orm/passage.py b/letta/orm/passage.py new file mode 100644 index 000000000..847c8ddd7 --- /dev/null +++ b/letta/orm/passage.py @@ -0,0 +1,72 @@ +from datetime import datetime +from typing import List, Optional, TYPE_CHECKING +from sqlalchemy import Column, String, DateTime, Index, JSON, UniqueConstraint, ForeignKey +from sqlalchemy.orm import Mapped, mapped_column, relationship +from sqlalchemy.types import TypeDecorator, BINARY + +import numpy as np +import base64 + +from letta.orm.source import EmbeddingConfigColumn +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin +from letta.schemas.passage import Passage as PydanticPassage + +from letta.config import LettaConfig +from letta.constants import MAX_EMBEDDING_DIM +from letta.settings import settings + +config = LettaConfig() + +if TYPE_CHECKING: + from letta.orm.file import File + from letta.orm.organization import Organization + +class CommonVector(TypeDecorator): + """Common type for representing vectors in SQLite""" + impl = BINARY + cache_ok = True + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(BINARY()) + + def process_bind_param(self, value, dialect): + if value is None: + return value + if isinstance(value, list): + value = np.array(value, dtype=np.float32) + return base64.b64encode(value.tobytes()) + + def process_result_value(self, value, dialect): + if not value: + return value + if dialect.name == "sqlite": + value = base64.b64decode(value) + return np.frombuffer(value, dtype=np.float32) + +# TODO: After migration to Passage, will need to manually delete passages where files +# are deleted on web +class Passage(SqlalchemyBase, OrganizationMixin, FileMixin): + """Defines data model for storing Passages""" + __tablename__ = "passages" + __table_args__ = {"extend_existing": True} + __pydantic_model__ = PydanticPassage + + id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier") + text: Mapped[str] = mapped_column(doc="Passage text content") + source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier") + embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration") + metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata") + 
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow) + if settings.letta_pg_uri_no_default: + from pgvector.sqlalchemy import Vector + embedding = mapped_column(Vector(MAX_EMBEDDING_DIM)) + else: + embedding = Column(CommonVector) + + # Foreign keys + agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True) + + # Relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin") + file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin") diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 6f8a76440..74d3f3be4 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -1,6 +1,7 @@ from datetime import datetime from enum import Enum from typing import TYPE_CHECKING, List, Literal, Optional, Type +import sqlite3 from sqlalchemy import String, desc, func, or_, select from sqlalchemy.exc import DBAPIError @@ -8,6 +9,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from letta.log import get_logger from letta.orm.base import Base, CommonSqlalchemyMetaMixins +from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance from letta.orm.errors import ( ForeignKeyConstraintViolationError, NoResultFound, @@ -60,6 +62,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): end_date: Optional[datetime] = None, limit: Optional[int] = 50, query_text: Optional[str] = None, + query_embedding: Optional[List[float]] = None, ascending: bool = True, **kwargs, ) -> List[Type["SqlalchemyBase"]]: @@ -110,17 +113,43 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Apply text search if query_text: + from sqlalchemy import func query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) - # Handle soft deletes + # Apply embedding search (Passages) + is_ordered = False + if query_embedding: + # check if embedding column exists. 
should only exist for passages + if not hasattr(cls, "embedding"): + raise ValueError(f"Class {cls.__name__} does not have an embedding column") + + from letta.settings import settings + if settings.letta_pg_uri_no_default: + # PostgreSQL with pgvector + from pgvector.sqlalchemy import Vector + query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc()) + else: + # SQLite with custom vector type + from sqlalchemy import func + + query_embedding_binary = adapt_array(query_embedding) + query = query.order_by( + func.cosine_distance(cls.embedding, query_embedding_binary).asc(), + cls.created_at.asc(), + cls.id.asc() + ) + is_ordered = True + + # Handle ordering and soft deletes if hasattr(cls, "is_deleted"): query = query.where(cls.is_deleted == False) - + # Apply ordering by created_at - if ascending: - query = query.order_by(cls.created_at, cls.id) - else: - query = query.order_by(desc(cls.created_at), desc(cls.id)) + if not is_ordered: + if ascending: + query = query.order_by(cls.created_at, cls.id) + else: + query = query.order_by(desc(cls.created_at), desc(cls.id)) query = query.limit(limit) @@ -369,4 +398,4 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): def to_record(self) -> Type["BaseModel"]: """Deprecated accessor for to_pydantic""" logger.warning("to_record is deprecated, use to_pydantic instead.") - return self.to_pydantic() + return self.to_pydantic() \ No newline at end of file diff --git a/letta/orm/sqlite_functions.py b/letta/orm/sqlite_functions.py new file mode 100644 index 000000000..a5b741aa5 --- /dev/null +++ b/letta/orm/sqlite_functions.py @@ -0,0 +1,140 @@ +from typing import Optional, Union + +import base64 +import numpy as np +from sqlalchemy import event +from sqlalchemy.engine import Engine +import sqlite3 + +from letta.constants import MAX_EMBEDDING_DIM + +def adapt_array(arr): + """ + Converts numpy array to binary for SQLite storage + """ + if arr is None: + return None + + if isinstance(arr, list): + arr = np.array(arr, dtype=np.float32) + elif not isinstance(arr, np.ndarray): + raise ValueError(f"Unsupported type: {type(arr)}") + + # Convert to bytes and then base64 encode + bytes_data = arr.tobytes() + base64_data = base64.b64encode(bytes_data) + return sqlite3.Binary(base64_data) + +def convert_array(text): + """ + Converts binary back to numpy array + """ + if text is None: + return None + if isinstance(text, list): + return np.array(text, dtype=np.float32) + if isinstance(text, np.ndarray): + return text + + # Handle both bytes and sqlite3.Binary + binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text + + try: + # First decode base64 + decoded_data = base64.b64decode(binary_data) + # Then convert to numpy array + return np.frombuffer(decoded_data, dtype=np.float32) + except Exception as e: + return None + +def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool: + """ + Verifies that an embedding has the expected dimension + + Args: + embedding: Input embedding array + expected_dim: Expected embedding dimension (default: 4096) + + Returns: + bool: True if dimension matches, False otherwise + """ + if embedding is None: + return False + return embedding.shape[0] == expected_dim + +def validate_and_transform_embedding( + embedding: Union[bytes, sqlite3.Binary, list, np.ndarray], + expected_dim: int = MAX_EMBEDDING_DIM, + dtype: np.dtype = np.float32 +) -> Optional[np.ndarray]: + """ + Validates and transforms embeddings to ensure correct dimensionality. 
+ + Args: + embedding: Input embedding in various possible formats + expected_dim: Expected embedding dimension (default 4096) + dtype: NumPy dtype for the embedding (default float32) + + Returns: + np.ndarray: Validated and transformed embedding + + Raises: + ValueError: If embedding dimension doesn't match expected dimension + """ + if embedding is None: + return None + + # Convert to numpy array based on input type + if isinstance(embedding, (bytes, sqlite3.Binary)): + vec = convert_array(embedding) + elif isinstance(embedding, list): + vec = np.array(embedding, dtype=dtype) + elif isinstance(embedding, np.ndarray): + vec = embedding.astype(dtype) + else: + raise ValueError(f"Unsupported embedding type: {type(embedding)}") + + # Validate dimension + if vec.shape[0] != expected_dim: + raise ValueError( + f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}" + ) + + return vec + +def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM): + """ + Calculate cosine distance between two embeddings + + Args: + embedding1: First embedding + embedding2: Second embedding + expected_dim: Expected embedding dimension (default 4096) + + Returns: + float: Cosine distance + """ + + if embedding1 is None or embedding2 is None: + return 0.0 # Maximum distance if either embedding is None + + try: + vec1 = validate_and_transform_embedding(embedding1, expected_dim) + vec2 = validate_and_transform_embedding(embedding2, expected_dim) + except ValueError as e: + return 0.0 + + similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)) + distance = float(1.0 - similarity) + + return distance + +@event.listens_for(Engine, "connect") +def register_functions(dbapi_connection, connection_record): + """Register SQLite functions""" + if isinstance(dbapi_connection, sqlite3.Connection): + dbapi_connection.create_function("cosine_distance", 2, cosine_distance) + +# Register adapters and converters for numpy arrays +sqlite3.register_adapter(np.ndarray, adapt_array) +sqlite3.register_converter("ARRAY", convert_array) diff --git a/letta/orm/user.py b/letta/orm/user.py index a44c31ab0..62a3c0e60 100644 --- a/letta/orm/user.py +++ b/letta/orm/user.py @@ -20,7 +20,7 @@ class User(SqlalchemyBase, OrganizationMixin): # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="users") - jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.") + jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan") # TODO: Add this back later potentially # agents: Mapped[List["Agent"]] = relationship( diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index 2ecc5e9ac..faa520c03 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -5,15 +5,17 @@ from pydantic import Field, field_validator from letta.constants import MAX_EMBEDDING_DIM from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.letta_base import LettaBase +from letta.schemas.letta_base import OrmMetadataBase from letta.utils import get_utc_time -class PassageBase(LettaBase): - __id_prefix__ = "passage" +class PassageBase(OrmMetadataBase): + __id_prefix__ = "passage_legacy" + + is_deleted: bool = Field(False, description="Whether this passage is deleted or not.") # associated user/agent - user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with 
the passage.") + organization_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.") agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.") # origin data source diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9dfd7e2bf..06b0acd60 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -369,8 +369,7 @@ def get_agent_archival_memory( return server.get_agent_archival_cursor( user_id=actor.id, agent_id=agent_id, - after=after, - before=before, + cursor=after, # TODO: deleting before, after. is this expected? limit=limit, ) diff --git a/letta/server/server.py b/letta/server/server.py index c60abe484..1a5de01e9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -16,7 +16,6 @@ import letta.constants as constants import letta.server.utils as server_utils import letta.system as system from letta.agent import Agent, save_agent -from letta.agent_store.db import attach_base from letta.agent_store.storage import StorageConnector, TableType from letta.chat_only_agent import ChatOnlyAgent from letta.credentials import LettaCredentials @@ -70,17 +69,18 @@ from letta.schemas.memory import ( ) from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate from letta.schemas.organization import Organization -from letta.schemas.passage import Passage +from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics -from letta.schemas.user import User +from letta.schemas.user import User as PydanticUser from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.block_manager import BlockManager from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager +from letta.services.passage_manager import PassageManager from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager @@ -125,7 +125,7 @@ class Server(object): def create_agent( self, request: CreateAgent, - actor: User, + actor: PydanticUser, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: @@ -166,8 +166,6 @@ from letta.settings import model_settings, settings, tool_settings config = LettaConfig.load() -attach_base() - if settings.letta_pg_uri_no_default: config.recall_storage_type = "postgres" config.recall_storage_uri = settings.letta_pg_uri_no_default @@ -245,6 +243,7 @@ class SyncServer(Server): # Managers that interface with data models self.organization_manager = OrganizationManager() + self.passage_manager = PassageManager() self.user_manager = UserManager() self.tool_manager = ToolManager() self.block_manager = BlockManager() @@ -498,7 +497,12 @@ class SyncServer(Server): # attach data to agent from source source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) - letta_agent.attach_source(data_source, source_connector, self.ms) + letta_agent.attach_source( + user=self.user_manager.get_user_by_id(user_id=user_id), + 
source_id=data_source, + source_manager=letta_agent.source_manager, + ms=self.ms + ) elif command.lower() == "dump" or command.lower().startswith("dump "): # Check if there's an additional argument that's an integer @@ -513,7 +517,7 @@ class SyncServer(Server): letta_agent.interface.print_messages_raw(letta_agent.messages) elif command.lower() == "memory": - ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}" + ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}" return ret_str elif command.lower() == "pop" or command.lower().startswith("pop "): @@ -769,7 +773,7 @@ class SyncServer(Server): def create_agent( self, request: CreateAgent, - actor: User, + actor: PydanticUser, # interface interface: Union[AgentInterface, None] = None, ) -> AgentState: @@ -921,6 +925,7 @@ class SyncServer(Server): # get `Tool` objects tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names] + tools = [tool for tool in tools if tool is not None] # get `Source` objects sources = self.list_attached_sources(agent_id=agent_id) @@ -934,7 +939,7 @@ class SyncServer(Server): def update_agent( self, request: UpdateAgentState, - actor: User, + actor: PydanticUser, ) -> AgentState: """Update the agents core memory block, return the new state""" try: @@ -1151,7 +1156,7 @@ class SyncServer(Server): def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary: agent = self.load_agent(agent_id=agent_id) - return ArchivalMemorySummary(size=len(agent.archival_memory)) + return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user)) def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary: agent = self.load_agent(agent_id=agent_id) @@ -1176,7 +1181,56 @@ class SyncServer(Server): message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user) return message - def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]: + def get_agent_messages( + self, + agent_id: str, + start: int, + count: int, + ) -> Union[List[Message], List[LettaMessage]]: + """Paginated query of all messages in agent message queue""" + # Get the agent object (loaded in memory) + letta_agent = self.load_agent(agent_id=agent_id) + + if start < 0 or count < 0: + raise ValueError("Start and count values should be non-negative") + + if start + count < len(letta_agent._messages): # messages can be returned from whats in memory + # Reverse the list to make it in reverse chronological order + reversed_messages = letta_agent._messages[::-1] + # Check if start is within the range of the list + if start >= len(reversed_messages): + raise IndexError("Start index is out of range") + + # Calculate the end index, ensuring it does not exceed the list length + end_index = min(start + count, len(reversed_messages)) + + # Slice the list for pagination + messages = reversed_messages[start:end_index] + + else: + # need to access persistence manager for additional messages + + # get messages using message manager + page = letta_agent.message_manager.list_user_messages_for_agent( + agent_id=agent_id, + actor=self.default_user, + cursor=start, + limit=count, + ) + + messages = page + assert all(isinstance(m, Message) for m in messages) + + ## Convert to json + ## Add a tag indicating in-context or not + # json_messages = [record.to_json() for record 
in messages] + # in_context_message_ids = [str(m.id) for m in letta_agent._messages] + # for d in json_messages: + # d["in_context"] = True if str(d["id"]) in in_context_message_ids else False + + return messages + + def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[PydanticPassage]: """Paginated query of all messages in agent archival memory""" if self.user_manager.get_user_by_id(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -1187,22 +1241,22 @@ class SyncServer(Server): letta_agent = self.load_agent(agent_id=agent_id) # iterate over records - db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start) + records = letta_agent.passage_manager.list_passages( + actor=self.default_user, + agent_id=agent_id, + cursor=cursor, + limit=limit, + ) - # get a single page of messages - page = next(db_iterator, []) - return page + return records def get_agent_archival_cursor( self, user_id: str, agent_id: str, - after: Optional[str] = None, - before: Optional[str] = None, + cursor: Optional[str] = None, limit: Optional[int] = 100, - order_by: Optional[str] = "created_at", - reverse: Optional[bool] = False, - ) -> List[Passage]: + ) -> List[PydanticPassage]: if self.user_manager.get_user_by_id(user_id=user_id) is None: raise LettaUserNotFoundError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: @@ -1211,14 +1265,15 @@ class SyncServer(Server): # Get the agent object (loaded in memory) letta_agent = self.load_agent(agent_id=agent_id) - # iterate over recorde - cursor, records = letta_agent.archival_memory.storage.get_all_cursor( - after=after, before=before, limit=limit, order_by=order_by, reverse=reverse + # iterate over records + records = letta_agent.passage_manager.list_passages( + actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit, ) return records - def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]: - if self.user_manager.get_user_by_id(user_id=user_id) is None: + def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[PydanticPassage]: + actor = self.user_manager.get_user_by_id(user_id=user_id) + if actor is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1227,17 +1282,20 @@ class SyncServer(Server): letta_agent = self.load_agent(agent_id=agent_id) # Insert into archival memory - passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True) + passage_ids = self.passage_manager.insert_passage( + agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor, return_ids=True + ) # Update the agent # TODO: should this update the system prompt? 
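The archival listing shown above now takes a single `cursor` that is forwarded to `PassageManager.list_passages`. A hedged sketch of how a caller might page through archival memory with it; the assumption that the next cursor is the id of the last returned passage is not stated in this diff:

    # page through an agent's archival memory via the cursor-based listing
    cursor = None
    while True:
        page = server.get_agent_archival_cursor(user_id=user.id, agent_id=agent_id, cursor=cursor, limit=100)
        if not page:
            break
        for passage in page:
            ...  # consume PydanticPassage objects
        cursor = page[-1].id  # assumption: continue from the last passage returned
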
save_agent(letta_agent, self.ms) # TODO: this is gross, fix - return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids] + return [self.passage_manager.get_passage_by_id(passage_id=passage_id, actor=actor) for passage_id in passage_ids] def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str): - if self.user_manager.get_user_by_id(user_id=user_id) is None: + actor = self.user_manager.get_user_by_id(user_id=user_id) + if actor is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") @@ -1249,7 +1307,7 @@ class SyncServer(Server): # Delete by ID # TODO check if it exists first, and throw error if not - letta_agent.archival_memory.storage.delete({"id": memory_id}) + letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor) # TODO: return archival memory @@ -1395,6 +1453,12 @@ class SyncServer(Server): except NoResultFound: logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}") + # delete all passages associated with this agent + # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM + passages = self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id) + for passage in passages: + self.passage_manager.delete_passage_by_id(passage.id, actor=actor) + # First, if the agent is in the in-memory cache we should remove it # List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts try: @@ -1437,7 +1501,7 @@ class SyncServer(Server): self.ms.delete_api_key(api_key=api_key) return api_key_obj - def delete_source(self, source_id: str, actor: User): + def delete_source(self, source_id: str, actor: PydanticUser): """Delete a data source""" self.source_manager.delete_source(source_id=source_id, actor=actor) @@ -1447,7 +1511,7 @@ class SyncServer(Server): # TODO: delete data from agent passage stores (?) 
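Passage cleanup is now split between the database and the manager: the migration above adds ON DELETE CASCADE from passages.file_id to files, while source detachment and agent deletion remove rows through PassageManager, as in the hunks nearby. Roughly, as a sketch using only calls that appear in this diff (`user`, `actor`, and the ids are assumed to be in scope):

    # detach a source: drop that source's passages from the agent's archival memory
    agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id)

    # delete an agent: remove its passages one by one (until agents are migrated to the ORM)
    for passage in self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id):
        self.passage_manager.delete_passage_by_id(passage.id, actor=actor)
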
- def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job: + def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: PydanticUser) -> Job: # update job job = self.job_manager.get_job_by_id(job_id, actor=actor) @@ -1474,6 +1538,7 @@ class SyncServer(Server): user_id: str, connector: DataConnector, source_name: str, + agent_id: Optional[str] = None, ) -> Tuple[int, int]: """Load data from a DataConnector into a source for a specified user_id""" # TODO: this should be implemented as a batch job or at least async, since it may take a long time @@ -1488,14 +1553,13 @@ class SyncServer(Server): passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) # load data into the document store - passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user) + passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user, agent_id=agent_id) return passage_count, document_count def attach_source_to_agent( self, user_id: str, agent_id: str, - # source_id: str, source_id: Optional[str] = None, source_name: Optional[str] = None, ) -> Source: @@ -1507,15 +1571,14 @@ class SyncServer(Server): data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user) else: raise ValueError(f"Need to provide at least source_id or source_name to find the source.") - # get connection to data source storage - source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + assert data_source, f"Data source with id={source_id} or name={source_name} does not exist" # load agent agent = self.load_agent(agent_id=agent_id) # attach source to agent - agent.attach_source(data_source.id, source_connector, self.ms) + agent.attach_source(user=user, source_id=data_source.id, source_manager=self.source_manager, ms=self.ms) return data_source @@ -1538,8 +1601,7 @@ class SyncServer(Server): # delete all Passage objects with source_id==source_id from agent's archival memory agent = self.load_agent(agent_id=agent_id) - archival_memory = agent.archival_memory - archival_memory.storage.delete({"source_id": source_id}) + agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id) # delete agent-source mapping self.ms.detach_source(agent_id=agent_id, source_id=source_id) @@ -1553,11 +1615,11 @@ class SyncServer(Server): return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids] - def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: + def list_data_source_passages(self, user_id: str, source_id: str) -> List[PydanticPassage]: warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) return [] - def list_all_sources(self, actor: User) -> List[Source]: + def list_all_sources(self, actor: PydanticUser) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" sources = self.source_manager.list_sources(actor=actor) @@ -1597,7 +1659,7 @@ class SyncServer(Server): return sources_with_metadata - def add_default_external_tools(self, actor: User) -> bool: + def add_default_external_tools(self, actor: PydanticUser) -> bool: """Add default langchain tools. 
Return true if successful, false otherwise.""" success = True tool_creates = ToolCreate.load_default_langchain_tools() @@ -1654,7 +1716,7 @@ class SyncServer(Server): save_agent(letta_agent, self.ms) return response - def get_user_or_default(self, user_id: Optional[str]) -> User: + def get_user_or_default(self, user_id: Optional[str]) -> PydanticUser: """Get the user object for user_id if it exists, otherwise return the default user object""" if user_id is None: user_id = self.user_manager.DEFAULT_USER_ID diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py new file mode 100644 index 000000000..c1933b394 --- /dev/null +++ b/letta/services/passage_manager.py @@ -0,0 +1,225 @@ +from typing import List, Optional, Dict, Tuple +from letta.constants import MAX_EMBEDDING_DIM +from datetime import datetime +import numpy as np + +from letta.orm.errors import NoResultFound +from letta.utils import enforce_types + +from letta.embeddings import embedding_model, parse_and_chunk_text +from letta.schemas.embedding_config import EmbeddingConfig + +from letta.orm.passage import Passage as PassageModel +from letta.orm.sqlalchemy_base import AccessType +from letta.schemas.agent import AgentState +from letta.schemas.passage import Passage as PydanticPassage +from letta.schemas.user import User as PydanticUser + +class PassageManager: + """Manager class to handle business logic related to Passages.""" + + def __init__(self): + from letta.server.server import db_context + self.session_maker = db_context + + @enforce_types + def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: + """Fetch a passage by ID.""" + with self.session_maker() as session: + try: + passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: + """Create a new passage.""" + with self.session_maker() as session: + passage = PassageModel(**pydantic_passage.model_dump()) + passage.create(session, actor=actor) + return passage.to_pydantic() + + @enforce_types + def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: + """Create multiple passages.""" + return [self.create_passage(p, actor) for p in passages] + + @enforce_types + def insert_passage(self, + agent_state: AgentState, + agent_id: str, + text: str, + actor: PydanticUser, + return_ids: bool = False + ) -> List[PydanticPassage]: + """ Insert passage(s) into archival memory """ + + embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size + embed_model = embedding_model(agent_state.embedding_config) + + passages = [] + + try: + # breakup string into passages + for text in parse_and_chunk_text(text, embedding_chunk_size): + embedding = embed_model.get_text_embedding(text) + if isinstance(embedding, dict): + try: + embedding = embedding["data"][0]["embedding"] + except (KeyError, IndexError): + # TODO as a fallback, see if we can find any lists in the payload + raise TypeError( + f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" + ) + passage = self.create_passage( + PydanticPassage( + organization_id=actor.organization_id, + agent_id=agent_id, + text=text, + embedding=embedding, + embedding_config=agent_state.embedding_config + ), + actor=actor + ) + passages.append(passage) + + ids 
= [str(p.id) for p in passages] + + if return_ids: + return ids + + return passages + + except Exception as e: + raise e + + @enforce_types + def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]: + """Update a passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with self.session_maker() as session: + try: + # Fetch existing message from database + curr_passage = PassageModel.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + if not curr_passage: + raise ValueError(f"Passage with id {passage_id} does not exist.") + + # Update the database record with values from the provided record + update_data = passage.model_dump(exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) + + # Commit changes + curr_passage.update(session, actor=actor) + return curr_passage.to_pydantic() + except NoResultFound: + return None + + @enforce_types + def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: + """Delete a passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with self.session_maker() as session: + try: + passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor) + passage.hard_delete(session, actor=actor) + except NoResultFound: + raise ValueError(f"Passage with id {passage_id} not found.") + + @enforce_types + def list_passages(self, + actor : PydanticUser, + agent_id : Optional[str] = None, + file_id : Optional[str] = None, + cursor : Optional[str] = None, + limit : Optional[int] = 50, + query_text : Optional[str] = None, + start_date : Optional[datetime] = None, + end_date : Optional[datetime] = None, + source_id : Optional[str] = None, + embed_query : bool = False, + embedding_config: Optional[EmbeddingConfig] = None + ) -> List[PydanticPassage]: + """List passages with pagination.""" + with self.session_maker() as session: + filters = {"organization_id": actor.organization_id} + if agent_id: + filters["agent_id"] = agent_id + if file_id: + filters["file_id"] = file_id + if source_id: + filters["source_id"] = source_id + + embedded_text = None + if embed_query: + assert embedding_config is not None + + # Embed the text + embedded_text = embedding_model(embedding_config).get_text_embedding(query_text) + + # Pad the embedding with zeros + embedded_text = np.array(embedded_text) + embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist() + + results = PassageModel.list( + db_session=session, + cursor=cursor, + start_date=start_date, + end_date=end_date, + limit=limit, + query_text=query_text if not embedded_text else None, + query_embedding=embedded_text, + **filters + ) + return [p.to_pydantic() for p in results] + + @enforce_types + def size( + self, + actor : PydanticUser, + agent_id : Optional[str] = None, + **kwargs + ) -> int: + """Get the total count of messages with optional filters. 
+ + Args: + actor : The user requesting the count + agent_id: The agent ID + """ + with self.session_maker() as session: + return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs) + + def delete_passages(self, + actor: PydanticUser, + agent_id: Optional[str] = None, + file_id: Optional[str] = None, + start_date: Optional[datetime] = None, + end_date: Optional[datetime] = None, + limit: Optional[int] = 50, + cursor: Optional[str] = None, + query_text: Optional[str] = None, + source_id: Optional[str] = None + ) -> bool: + + passages = self.list_passages( + actor=actor, + agent_id=agent_id, + file_id=file_id, + cursor=cursor, + limit=limit, + start_date=start_date, + end_date=end_date, + query_text=query_text, + source_id=source_id) + + for passage in passages: + self.delete_passage_by_id(passage_id=passage.id, actor=actor) diff --git a/letta/services/source_manager.py b/letta/services/source_manager.py index f2b48e9ba..a6745cec9 100644 --- a/letta/services/source_manager.py +++ b/letta/services/source_manager.py @@ -64,7 +64,7 @@ class SourceManager: return source.to_pydantic() @enforce_types - def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticSource]: + def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]: """List all sources with optional pagination.""" with self.session_maker() as session: sources = SourceModel.list( @@ -72,6 +72,7 @@ class SourceManager: cursor=cursor, limit=limit, organization_id=actor.organization_id, + **kwargs, ) return [source.to_pydantic() for source in sources] diff --git a/letta/settings.py b/letta/settings.py index 7271ff29d..41aca91d0 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -17,6 +17,8 @@ class ToolSettings(BaseSettings): class ModelSettings(BaseSettings): + model_config = SettingsConfigDict(env_file='.env') + # env_prefix='my_prefix_' # when we use /completions APIs (instead of /chat/completions), we need to specify a model wrapper diff --git a/pyproject.toml b/pyproject.toml index cf07689a6..9b8ff2931 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,4 +104,4 @@ extend-exclude = "examples/*" [build-system] requires = ["poetry-core"] -build-backend = "poetry.core.masonry.api" +build-backend = "poetry.core.masonry.api" \ No newline at end of file diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 0668f2fd0..199800eb0 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -29,11 +29,63 @@ def agent_obj(): client.delete_agent(agent_obj.agent_state.id) -def test_archival(agent_obj): - base_functions.archival_memory_insert(agent_obj, "banana") - base_functions.archival_memory_search(agent_obj, "banana") - base_functions.archival_memory_search(agent_obj, "banana", page=0) +def query_in_search_results(search_results, query): + for result in search_results: + if query.lower() in result["content"].lower(): + return True + return False +def test_archival(agent_obj): + """Test archival memory functions comprehensively.""" + # Test 1: Basic insertion and retrieval + base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat") + base_functions.archival_memory_insert(agent_obj, "The dog plays in the park") + base_functions.archival_memory_insert(agent_obj, "Python is a programming language") + + # Test exact text search + results, _ = base_functions.archival_memory_search(agent_obj, "cat") + assert 
query_in_search_results(results, "cat") + + # Test semantic search (should return animal-related content) + results, _ = base_functions.archival_memory_search(agent_obj, "animal pets") + assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog") + + # Test unrelated search (should not return animal content) + results, _ = base_functions.archival_memory_search(agent_obj, "programming computers") + assert query_in_search_results(results, "python") + + # Test 2: Test pagination + # Insert more items to test pagination + for i in range(10): + base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}") + + # Get first page + page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0) + # Get second page + page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page) + + assert page0_results != page1_results + assert query_in_search_results(page0_results, "Test passage") + assert query_in_search_results(page1_results, "Test passage") + + # Test 3: Test complex text patterns + base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John") + base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week") + base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching") + + # Search for meeting-related content + results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule") + assert query_in_search_results(results, "meeting") + assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week") + + # Test 4: Test error handling + # Test invalid page number + try: + base_functions.archival_memory_search(agent_obj, "test", page="invalid") + assert False, "Should have raised ValueError" + except ValueError: + pass + def test_recall(agent_obj): base_functions.conversation_search(agent_obj, "banana") diff --git a/tests/test_client_legacy.py b/tests/test_client_legacy.py index 65bd5f16b..fff0e466c 100644 --- a/tests/test_client_legacy.py +++ b/tests/test_client_legacy.py @@ -649,7 +649,6 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: system=agent.system, agent_id=agent.id, memory=agent.memory, - archival_memory=None, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True, actor=default_user, diff --git a/tests/test_managers.py b/tests/test_managers.py index 74675b4cb..f6e543664 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5,9 +5,11 @@ from datetime import datetime, timedelta import pytest from sqlalchemy import delete +from letta.embeddings import embedding_model import letta.utils as utils from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.metadata import AgentModel +from letta.orm.sqlite_functions import verify_embedding_dimension, convert_array from letta.orm import ( Block, BlocksAgents, @@ -15,6 +17,7 @@ from letta.orm import ( Job, Message, Organization, + Passage, SandboxConfig, SandboxEnvironmentVariable, Source, @@ -40,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageUpdate from letta.schemas.organization import Organization as PydanticOrganization +from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.sandbox_config import ( E2BSandboxConfig, LocalSandboxConfig, @@ -55,6 +59,7 
@@ from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate from letta.services.block_manager import BlockManager from letta.services.organization_manager import OrganizationManager +from letta.services.passage_manager import PassageManager from letta.services.tool_manager import ToolManager from letta.settings import tool_settings @@ -83,6 +88,7 @@ def clear_tables(server: SyncServer): """Fixture to clear the organization table before each test.""" with server.organization_manager.session_maker() as session: session.execute(delete(Message)) + session.execute(delete(Passage)) session.execute(delete(Job)) session.execute(delete(ToolsAgents)) # Clear ToolsAgents first session.execute(delete(BlocksAgents)) @@ -132,6 +138,16 @@ def default_source(server: SyncServer, default_user): yield source +@pytest.fixture +def default_file(server: SyncServer, default_source, default_user, default_organization): + file = server.source_manager.create_file( + PydanticFileMetadata( + file_name="test_file", organization_id=default_organization.id, source_id=default_source.id), + actor=default_user, + ) + yield file + + @pytest.fixture def sarah_agent(server: SyncServer, default_user, default_organization): """Fixture to create and return a sample agent within the default organization.""" @@ -197,6 +213,41 @@ def print_tool(server: SyncServer, default_user, default_organization): yield tool +@pytest.fixture +def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent): + """Fixture to create a tool with default settings and clean up after the test.""" + # Set up passage + dummy_embedding = [0.0] * 2 + message = PydanticPassage( + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + file_id=default_file.id, + text="Hello, world!", + embedding=dummy_embedding, + embedding_config=DEFAULT_EMBEDDING_CONFIG + ) + + msg = server.passage_manager.create_passage(message, actor=default_user) + yield msg + + +@pytest.fixture +def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]: + """Helper function to create test passages for all tests""" + dummy_embedding = [0] * 2 + passages = [ + PydanticPassage( + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + file_id=default_file.id, + text=f"Test passage {i}", + embedding=dummy_embedding, + embedding_config=DEFAULT_EMBEDDING_CONFIG + ) for i in range(4) + ] + server.passage_manager.create_many_passages(passages, actor=default_user) + return passages + @pytest.fixture def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent): """Fixture to create a tool with default settings and clean up after the test.""" @@ -353,6 +404,288 @@ def test_list_organizations_pagination(server: SyncServer): assert len(orgs) == 0 +# ====================================================================================================================== +# Passage Manager Tests +# ====================================================================================================================== + +def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user): + """Test creating a passage using hello_world_passage_fixture fixture""" + assert hello_world_passage_fixture.id is not None + assert hello_world_passage_fixture.text == "Hello, world!" 
+ + # Verify we can retrieve it + retrieved = server.passage_manager.get_passage_by_id( + hello_world_passage_fixture.id, + actor=default_user, + ) + assert retrieved is not None + assert retrieved.id == hello_world_passage_fixture.id + assert retrieved.text == hello_world_passage_fixture.text + + +def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user): + """Test retrieving a passage by ID""" + retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user) + assert retrieved is not None + assert retrieved.id == hello_world_passage_fixture.id + assert retrieved.text == hello_world_passage_fixture.text + + +def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user): + """Test updating a passage""" + new_text = "Updated text" + hello_world_passage_fixture.text = new_text + updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user) + assert updated is not None + assert updated.text == new_text + retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user) + assert retrieved.text == new_text + + +def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user): + """Test deleting a passage""" + server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user) + retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user) + assert retrieved is None + + +def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user): + """Test counting passages with filters""" + base_passage = hello_world_passage_fixture + + # Test total count + total = server.passage_manager.size(actor=default_user) + assert total == 5 # base passage + 4 test passages + # TODO: change login passage to be a system not user passage + + # Test count with agent filter + agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id) + assert agent_count == 5 + + # Test count with role filter + role_count = server.passage_manager.size(actor=default_user) + assert role_count == 5 + + # Test count with non-existent filter + empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent") + assert empty_count == 0 + + +def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user): + """Test basic passage listing with limit""" + results = server.passage_manager.list_passages(actor=default_user, limit=3) + assert len(results) == 3 + + +def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user): + """Test cursor-based pagination functionality""" + + # Make sure there are 5 passages + assert server.passage_manager.size(actor=default_user) == 5 + + # Get first page + first_page = server.passage_manager.list_passages(actor=default_user, limit=3) + assert len(first_page) == 3 + + last_id_on_first_page = first_page[-1].id + + # Get second page + second_page = server.passage_manager.list_passages( + actor=default_user, cursor=last_id_on_first_page, limit=3 + ) + assert len(second_page) == 2 # Should have 2 remaining passages + assert all(r1.id != r2.id for r1 in first_page for r2 in second_page) + + +def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, 
sarah_agent): + """Test filtering passages by agent ID""" + agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10) + assert len(agent_results) == 5 # base passage + 4 test passages + assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results) + + +def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent): + """Test searching passages by text content""" + search_results = server.passage_manager.list_passages( + agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10 + ) + assert len(search_results) == 4 + assert all("Test passage" in msg.text for msg in search_results) + + # Test no results + search_results = server.passage_manager.list_passages( + agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10 + ) + assert len(search_results) == 0 + + +def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent): + """Test filtering passages by date range with various scenarios""" + # Set up test data with known dates + base_time = datetime.utcnow() + + # Create passages at different times + passages = [] + time_offsets = [ + timedelta(days=-2), # 2 days ago + timedelta(days=-1), # Yesterday + timedelta(hours=-2), # 2 hours ago + timedelta(minutes=-30), # 30 minutes ago + timedelta(minutes=-1), # 1 minute ago + timedelta(minutes=0), # Now + ] + + for i, offset in enumerate(time_offsets): + timestamp = base_time + offset + passage = server.passage_manager.create_passage( + PydanticPassage( + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + file_id=default_file.id, + text=f"Test passage {i}", + embedding=[0.1, 0.2, 0.3], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + created_at=timestamp + ), + actor=default_user + ) + passages.append(passage) + + # Test cases + test_cases = [ + { + "name": "Recent passages (last hour)", + "start_date": base_time - timedelta(hours=1), + "end_date": base_time + timedelta(minutes=1), + "expected_count": 1 + 3, # Should include base + -30min, -1min, and now + }, + { + "name": "Yesterday's passages", + "start_date": base_time - timedelta(days=1, hours=12), + "end_date": base_time - timedelta(hours=12), + "expected_count": 1, # Should only include yesterday's passage + }, + { + "name": "Future time range", + "start_date": base_time + timedelta(days=1), + "end_date": base_time + timedelta(days=2), + "expected_count": 0, # Should find no passages + }, + { + "name": "All time", + "start_date": base_time - timedelta(days=3), + "end_date": base_time + timedelta(days=1), + "expected_count": 1 + len(passages), # Should find all passages + }, + { + "name": "Exact timestamp match", + "start_date": passages[0].created_at - timedelta(microseconds=1), + "end_date": passages[0].created_at + timedelta(microseconds=1), + "expected_count": 1, # Should find exactly one passage + }, + { + "name": "Small time window", + "start_date": base_time - timedelta(seconds=30), + "end_date": base_time + timedelta(seconds=30), + "expected_count": 1 + 1, # date + "now" + } + ] + + # Run test cases + for case in test_cases: + results = server.passage_manager.list_passages( + agent_id=sarah_agent.id, + actor=default_user, + start_date=case["start_date"], + end_date=case["end_date"], + limit=10 + ) + + # Verify count + assert len(results) == case["expected_count"], \ + f"Test case '{case['name']}' failed: 
expected {case['expected_count']} passages, got {len(results)}" + + # Test edge cases + + # Test with start_date but no end_date + results_start_only = server.passage_manager.list_passages( + agent_id=sarah_agent.id, + actor=default_user, + start_date=base_time - timedelta(minutes=2), + end_date=None, + limit=10 + ) + assert len(results_start_only) >= 2, "Should find passages after start_date" + + # Test with end_date but no start_date + results_end_only = server.passage_manager.list_passages( + agent_id=sarah_agent.id, + actor=default_user, + start_date=None, + end_date=base_time - timedelta(days=1), + limit=10 + ) + assert len(results_end_only) >= 1, "Should find passages before end_date" + + # Test limit enforcement + limited_results = server.passage_manager.list_passages( + agent_id=sarah_agent.id, + actor=default_user, + start_date=base_time - timedelta(days=3), + end_date=base_time + timedelta(days=1), + limit=3 + ) + assert len(limited_results) <= 3, "Should respect the limit parameter" + + +def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent): + """Test vector search functionality for passages.""" + passage_manager = server.passage_manager + embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) + + # Create passages with known embeddings + passages = [] + + # Create passages with different embeddings + test_passages = [ + "I like red", + "random text", + "blue shoes", + ] + + for text in test_passages: + embedding = embed_model.get_text_embedding(text) + passage = PydanticPassage( + text=text, + organization_id=default_user.organization_id, + agent_id=sarah_agent.id, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + embedding=embedding + ) + created_passage = passage_manager.create_passage(passage, default_user) + passages.append(created_passage) + assert passage_manager.size(actor=default_user) == len(passages) + + # Query vector similar to "cats" embedding + query_key = "What's my favorite color?" 
+ + # List passages with vector search + results = passage_manager.list_passages( + actor=default_user, + agent_id=sarah_agent.id, + query_text=query_key, + limit=3, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + embed_query=True, + ) + + # Verify results are ordered by similarity + assert len(results) == 3 + assert results[0].text == "I like red" + assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes" + assert results[2].text == "blue shoes" + + # ====================================================================================================================== # User Manager Tests # ====================================================================================================================== @@ -834,8 +1167,6 @@ def test_delete_block(server: SyncServer, default_user): # ====================================================================================================================== # Source Manager Tests - Sources # ====================================================================================================================== - - def test_create_source(server: SyncServer, default_user): """Test creating a new source.""" source_pydantic = PydanticSource( @@ -1049,8 +1380,6 @@ def test_delete_file(server: SyncServer, default_user, default_source): # ====================================================================================================================== # AgentsTagsManager Tests # ====================================================================================================================== - - def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user): # Add a tag to the agent tag_name = "test_tag" diff --git a/tests/test_server.py b/tests/test_server.py index 32e5da690..555b7c07e 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -117,6 +117,9 @@ def test_user_message_memory(server, user_id, agent_id): @pytest.mark.order(3) def test_load_data(server, user_id, agent_id): # create source + passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000) + assert len(passages_before) == 0 + source = server.source_manager.create_source( Source(name="test_source", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=server.default_user ) @@ -130,19 +133,17 @@ def test_load_data(server, user_id, agent_id): "Shishir loves indian food", ] connector = DummyDataConnector(archival_memories) - server.load_data(user_id, connector, source.name) + server.load_data(user_id, connector, source.name, agent_id=agent_id) # @pytest.mark.order(3) # def test_attach_source_to_agent(server, user_id, agent_id): # check archival memory size - passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000) - assert len(passages_before) == 0 # attach source server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source") # check archival memory size - passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000) + passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000) assert len(passages_after) == 5 @@ -182,41 +183,42 @@ def test_get_recall_memory(server, org_id, user_id, agent_id): assert message_id in message_ids, f"{message_id} not in {message_ids}" -@pytest.mark.order(6) -def test_get_archival_memory(server, user_id, agent_id): - # test archival memory cursor pagination - passages_1 = 
server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text") - assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" - cursor1 = passages_1[-1].id - passages_2 = server.get_agent_archival_cursor( - user_id=user_id, - agent_id=agent_id, - reverse=False, - after=cursor1, - order_by="text", - ) - cursor2 = passages_2[-1].id - passages_3 = server.get_agent_archival_cursor( - user_id=user_id, - agent_id=agent_id, - reverse=False, - before=cursor2, - limit=1000, - order_by="text", - ) - passages_3[-1].id - # assert passages_1[0].text == "Cinderella wore a blue dress" - assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test - assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test +# TODO: Out-of-date test. pagination commands are off +# @pytest.mark.order(6) +# def test_get_archival_memory(server, user_id, agent_id): +# # test archival memory cursor pagination +# passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text") +# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2" +# cursor1 = passages_1[-1].id +# passages_2 = server.get_agent_archival_cursor( +# user_id=user_id, +# agent_id=agent_id, +# reverse=False, +# after=cursor1, +# order_by="text", +# ) +# cursor2 = passages_2[-1].id +# passages_3 = server.get_agent_archival_cursor( +# user_id=user_id, +# agent_id=agent_id, +# reverse=False, +# before=cursor2, +# limit=1000, +# order_by="text", +# ) +# passages_3[-1].id +# # assert passages_1[0].text == "Cinderella wore a blue dress" +# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test +# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test - # test archival memory - passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1) - assert len(passage_1) == 1 - passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000) - assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test - # test safe empty return - passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000) - assert len(passage_none) == 0 +# # test archival memory +# passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1) +# assert len(passage_1) == 1 +# passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000) +# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test +# # test safe empty return +# passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000) +# assert len(passage_none) == 0 def test_agent_rethink_rewrite_retry(server, user_id, agent_id): diff --git a/tests/test_vector_embeddings.py b/tests/test_vector_embeddings.py new file mode 100644 index 000000000..0ad25071a --- /dev/null +++ b/tests/test_vector_embeddings.py @@ -0,0 +1,42 @@ +import numpy as np +import sqlite3 +import base64 +from numpy.testing import assert_array_almost_equal + +import pytest + +from letta.orm.sqlalchemy_base import adapt_array, convert_array +from letta.orm.sqlite_functions import verify_embedding_dimension + +def test_vector_conversions(): + """Test the vector conversion functions""" + # Create test data + original = 
np.random.random(4096).astype(np.float32)
+    print(f"Original shape: {original.shape}")
+
+    # Test full conversion cycle
+    encoded = adapt_array(original)
+    print(f"Encoded type: {type(encoded)}")
+    print(f"Encoded length: {len(encoded)}")
+
+    decoded = convert_array(encoded)
+    print(f"Decoded shape: {decoded.shape}")
+    print(f"Dimension verification: {verify_embedding_dimension(decoded)}")
+
+    # Verify data integrity
+    np.testing.assert_array_almost_equal(original, decoded)
+    print("✓ Data integrity verified")
+
+    # Test with a list
+    list_data = original.tolist()
+    encoded_list = adapt_array(list_data)
+    decoded_list = convert_array(encoded_list)
+    np.testing.assert_array_almost_equal(original, decoded_list)
+    print("✓ List conversion verified")
+
+    # Test None handling
+    assert adapt_array(None) is None
+    assert convert_array(None) is None
+    print("✓ None handling verified")
+
+# Run the tests
\ No newline at end of file
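
Below is a minimal usage sketch (not part of the diff above) of the PassageManager API introduced in letta/services/passage_manager.py. It assumes a SyncServer-style object that exposes `passage_manager`, an existing agent id, and an EmbeddingConfig; the names `server`, `actor`, `agent_id`, `embedding_cfg`, and the helper `archival_roundtrip` are illustrative placeholders, while the manager methods and Passage fields are taken from the new module and its tests.

from letta.schemas.passage import Passage as PydanticPassage

def archival_roundtrip(server, actor, agent_id, embedding_cfg):
    """Create, count, list, and delete one archival passage via the new manager (sketch)."""
    # Create a passage with a pre-computed dummy embedding, mirroring the
    # hello_world_passage_fixture in tests/test_managers.py.
    passage = server.passage_manager.create_passage(
        PydanticPassage(
            organization_id=actor.organization_id,
            agent_id=agent_id,
            text="Hello, world!",
            embedding=[0.0] * 2,
            embedding_config=embedding_cfg,
        ),
        actor=actor,
    )

    # Count passages scoped to this agent; size() also accepts extra filters via kwargs.
    assert server.passage_manager.size(actor=actor, agent_id=agent_id) >= 1

    # Plain listing scoped to the agent; passing embed_query=True together with an
    # embedding_config switches this to vector search, as exercised in
    # test_passage_vector_search.
    results = server.passage_manager.list_passages(actor=actor, agent_id=agent_id, limit=10)

    # Delete by id, matching the delete_archival_memory change above.
    server.passage_manager.delete_passage_by_id(passage_id=passage.id, actor=actor)
    return results

Note that listing is cursor/limit based (see the updated get_agent_archival calls in tests/test_server.py and test_passage_listing_cursor), so callers paginate by passing the id of the last returned passage as `cursor` rather than the old start/count offsets.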