mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: orm passage migration (#2180)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
parent
c9c2cca4f4
commit
31d2774193
3
.gitignore
vendored
3
.gitignore
vendored
@ -1022,3 +1022,6 @@ memgpy/pytest.ini
|
||||
|
||||
## ignore venvs
|
||||
tests/test_tool_sandbox/restaurant_management_system/venv
|
||||
|
||||
## custom scripts
|
||||
test
|
||||
|
@ -0,0 +1,88 @@
|
||||
"""Add Passages ORM, drop legacy passages, cascading deletes for file-passages and user-jobs
|
||||
|
||||
Revision ID: c5d964280dff
|
||||
Revises: a91994b9752f
|
||||
Create Date: 2024-12-10 15:05:32.335519
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'c5d964280dff'
|
||||
down_revision: Union[str, None] = 'a91994b9752f'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('passages', sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True))
|
||||
op.add_column('passages', sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False))
|
||||
op.add_column('passages', sa.Column('_created_by_id', sa.String(), nullable=True))
|
||||
op.add_column('passages', sa.Column('_last_updated_by_id', sa.String(), nullable=True))
|
||||
|
||||
# Data migration step:
|
||||
op.add_column("passages", sa.Column("organization_id", sa.String(), nullable=True))
|
||||
# Populate `organization_id` based on `user_id`
|
||||
# Use a raw SQL query to update the organization_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE passages
|
||||
SET organization_id = users.organization_id
|
||||
FROM users
|
||||
WHERE passages.user_id = users.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Set `organization_id` as non-nullable after population
|
||||
op.alter_column("passages", "organization_id", nullable=False)
|
||||
|
||||
op.alter_column('passages', 'text',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'embedding_config',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'metadata_',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=False)
|
||||
op.alter_column('passages', 'created_at',
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=False)
|
||||
op.drop_index('passage_idx_user', table_name='passages')
|
||||
op.create_foreign_key(None, 'passages', 'organizations', ['organization_id'], ['id'])
|
||||
op.create_foreign_key(None, 'passages', 'agents', ['agent_id'], ['id'])
|
||||
op.create_foreign_key(None, 'passages', 'files', ['file_id'], ['id'], ondelete='CASCADE')
|
||||
op.drop_column('passages', 'user_id')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('passages', sa.Column('user_id', sa.VARCHAR(), autoincrement=False, nullable=False))
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.drop_constraint(None, 'passages', type_='foreignkey')
|
||||
op.create_index('passage_idx_user', 'passages', ['user_id', 'agent_id', 'file_id'], unique=False)
|
||||
op.alter_column('passages', 'created_at',
|
||||
existing_type=postgresql.TIMESTAMP(timezone=True),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'metadata_',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'embedding_config',
|
||||
existing_type=postgresql.JSON(astext_type=sa.Text()),
|
||||
nullable=True)
|
||||
op.alter_column('passages', 'text',
|
||||
existing_type=sa.VARCHAR(),
|
||||
nullable=True)
|
||||
op.drop_column('passages', 'organization_id')
|
||||
op.drop_column('passages', '_last_updated_by_id')
|
||||
op.drop_column('passages', '_created_by_id')
|
||||
op.drop_column('passages', 'is_deleted')
|
||||
op.drop_column('passages', 'updated_at')
|
||||
# ### end Alembic commands ###
|
@ -28,7 +28,7 @@ from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
|
||||
from letta.memory import summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
@ -52,6 +52,7 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.user_manager import UserManager
|
||||
@ -85,7 +86,7 @@ def compile_memory_metadata_block(
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
@ -96,7 +97,7 @@ def compile_memory_metadata_block(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
@ -109,7 +110,7 @@ def compile_system_message(
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
@ -138,7 +139,7 @@ def compile_system_message(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
archival_memory=archival_memory,
|
||||
passage_manager=passage_manager,
|
||||
message_manager=message_manager,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
@ -175,7 +176,7 @@ def initialize_message_sequence(
|
||||
agent_id: str,
|
||||
memory: Memory,
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
@ -184,7 +185,7 @@ def initialize_message_sequence(
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
# full_system_message = construct_system_with_memory(
|
||||
# system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
|
||||
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
|
||||
# )
|
||||
full_system_message = compile_system_message(
|
||||
agent_id=agent_id,
|
||||
@ -192,7 +193,7 @@ def initialize_message_sequence(
|
||||
in_context_memory=memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=actor,
|
||||
archival_memory=archival_memory,
|
||||
passage_manager=passage_manager,
|
||||
message_manager=message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@ -294,7 +295,7 @@ class Agent(BaseAgent):
|
||||
self.interface = interface
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_state)
|
||||
self.passage_manager = PassageManager()
|
||||
self.message_manager = MessageManager()
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
@ -325,7 +326,7 @@ class Agent(BaseAgent):
|
||||
agent_id=self.agent_state.id,
|
||||
memory=self.agent_state.memory,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
passage_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@ -350,7 +351,7 @@ class Agent(BaseAgent):
|
||||
memory=self.agent_state.memory,
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
passage_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@ -1306,7 +1307,7 @@ class Agent(BaseAgent):
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=self.user,
|
||||
archival_memory=self.archival_memory,
|
||||
passage_manager=self.passage_manager,
|
||||
message_manager=self.message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@ -1371,45 +1372,33 @@ class Agent(BaseAgent):
|
||||
# TODO: recall memory
|
||||
raise NotImplementedError()
|
||||
|
||||
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
|
||||
def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore):
|
||||
"""Attach data with name `source_name` to the agent from source_connector."""
|
||||
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
|
||||
user = UserManager().get_user_by_id(self.agent_state.user_id)
|
||||
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
|
||||
size = source_connector.size(filters)
|
||||
page_size = 100
|
||||
generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
|
||||
all_passages = []
|
||||
for i in tqdm(range(0, size, page_size)):
|
||||
passages = next(generator)
|
||||
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
|
||||
|
||||
# need to associated passage with agent (for filtering)
|
||||
for passage in passages:
|
||||
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
|
||||
passage.agent_id = self.agent_state.id
|
||||
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
|
||||
|
||||
# regenerate passage ID (avoid duplicates)
|
||||
# TODO: need to find another solution to the text duplication issue
|
||||
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
|
||||
|
||||
# insert into agent archival memory
|
||||
self.archival_memory.storage.insert_many(passages)
|
||||
all_passages += passages
|
||||
|
||||
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
|
||||
|
||||
# save destination storage
|
||||
self.archival_memory.storage.save()
|
||||
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
|
||||
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
|
||||
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
|
||||
assert len(agents_passages) == passage_size # sanity check
|
||||
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
|
||||
|
||||
# attach to agent
|
||||
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
|
||||
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
||||
assert source is not None, f"Source {source_id} not found in metadata store"
|
||||
|
||||
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
|
||||
# TODO: delete @matt and remove
|
||||
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
|
||||
|
||||
total_agent_passages = self.archival_memory.storage.size()
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
|
||||
)
|
||||
|
||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||
@ -1565,13 +1554,13 @@ class Agent(BaseAgent):
|
||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||
)
|
||||
|
||||
num_archival_memory = self.archival_memory.storage.size()
|
||||
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||
archival_memory=self.archival_memory,
|
||||
passage_manager=self.passage_manager,
|
||||
message_manager=self.message_manager,
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
@ -1597,7 +1586,7 @@ class Agent(BaseAgent):
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(self._messages),
|
||||
num_archival_memory=num_archival_memory,
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
# top-level information
|
||||
|
@ -1,297 +0,0 @@
|
||||
from typing import Dict, List, Optional, Tuple, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.types import Include
|
||||
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime
|
||||
|
||||
|
||||
class ChromaStorageConnector(StorageConnector):
|
||||
"""Storage via Chroma"""
|
||||
|
||||
# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
|
||||
# Timestamps are converted to integer timestamps for chroma (datetime not supported)
|
||||
|
||||
def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
|
||||
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
|
||||
|
||||
assert table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES, "Chroma only supports archival memory"
|
||||
|
||||
# create chroma client
|
||||
if config.archival_storage_path:
|
||||
self.client = chromadb.PersistentClient(config.archival_storage_path)
|
||||
else:
|
||||
# assume uri={ip}:{port}
|
||||
ip = config.archival_storage_uri.split(":")[0]
|
||||
port = config.archival_storage_uri.split(":")[1]
|
||||
self.client = chromadb.HttpClient(host=ip, port=port)
|
||||
|
||||
# get a collection or create if it doesn't exist already
|
||||
self.collection = self.client.get_or_create_collection(self.table_name)
|
||||
self.include: Include = ["documents", "embeddings", "metadatas"]
|
||||
|
||||
def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
|
||||
# get all filters for query
|
||||
if filters is not None:
|
||||
filter_conditions = {**self.filters, **filters}
|
||||
else:
|
||||
filter_conditions = self.filters
|
||||
|
||||
# convert to chroma format
|
||||
chroma_filters = []
|
||||
ids = []
|
||||
for key, value in filter_conditions.items():
|
||||
# filter by id
|
||||
if key == "id":
|
||||
ids = [str(value)]
|
||||
continue
|
||||
|
||||
# filter by other keys
|
||||
chroma_filters.append({key: {"$eq": value}})
|
||||
|
||||
if len(chroma_filters) > 1:
|
||||
chroma_filters = {"$and": chroma_filters}
|
||||
elif len(chroma_filters) == 0:
|
||||
chroma_filters = {}
|
||||
else:
|
||||
chroma_filters = chroma_filters[0]
|
||||
return ids, chroma_filters
|
||||
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0):
|
||||
ids, filters = self.get_filters(filters)
|
||||
while True:
|
||||
# Retrieve a chunk of records with the given page_size
|
||||
results = self.collection.get(ids=ids, offset=offset, limit=page_size, include=self.include, where=filters)
|
||||
|
||||
# If the chunk is empty, we've retrieved all records
|
||||
assert results["embeddings"] is not None, f"results['embeddings'] was None"
|
||||
if len(results["embeddings"]) == 0:
|
||||
break
|
||||
|
||||
# Yield a list of Record objects converted from the chunk
|
||||
yield self.results_to_records(results)
|
||||
|
||||
# Increment the offset to get the next chunk in the next iteration
|
||||
offset += page_size
|
||||
|
||||
def results_to_records(self, results):
|
||||
# convert timestamps to datetime
|
||||
for metadata in results["metadatas"]:
|
||||
if "created_at" in metadata:
|
||||
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
|
||||
if results["embeddings"]: # may not be returned, depending on table type
|
||||
passages = []
|
||||
for text, record_id, embedding, metadata in zip(
|
||||
results["documents"], results["ids"], results["embeddings"], results["metadatas"]
|
||||
):
|
||||
args = {}
|
||||
for field in EmbeddingConfig.__fields__.keys():
|
||||
if field in metadata:
|
||||
args[field] = metadata[field]
|
||||
del metadata[field]
|
||||
embedding_config = EmbeddingConfig(**args)
|
||||
passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata))
|
||||
# return [
|
||||
# Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas)
|
||||
# for (text, record_id, embedding, metadatas) in zip(
|
||||
# results["documents"], results["ids"], results["embeddings"], results["metadatas"]
|
||||
# )
|
||||
# ]
|
||||
return passages
|
||||
else:
|
||||
# no embeddings
|
||||
passages = []
|
||||
for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]):
|
||||
args = {}
|
||||
for field in EmbeddingConfig.__fields__.keys():
|
||||
if field in metadata:
|
||||
args[field] = metadata[field]
|
||||
del metadata[field]
|
||||
embedding_config = EmbeddingConfig(**args)
|
||||
passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata))
|
||||
return passages
|
||||
|
||||
# return [
|
||||
# #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
|
||||
# Passage(text=text, embedding=None, id=id, **metadatas)
|
||||
# for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
|
||||
# ]
|
||||
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=None):
|
||||
ids, filters = self.get_filters(filters)
|
||||
if self.collection.count() == 0:
|
||||
return []
|
||||
if ids == []:
|
||||
ids = None
|
||||
if limit:
|
||||
results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit)
|
||||
else:
|
||||
results = self.collection.get(ids=ids, include=self.include, where=filters)
|
||||
return self.results_to_records(results)
|
||||
|
||||
def get(self, id: str):
|
||||
results = self.collection.get(ids=[str(id)])
|
||||
if len(results["ids"]) == 0:
|
||||
return None
|
||||
return self.results_to_records(results)[0]
|
||||
|
||||
def format_records(self, records):
|
||||
assert all([isinstance(r, Passage) for r in records])
|
||||
|
||||
recs = []
|
||||
ids = []
|
||||
documents = []
|
||||
embeddings = []
|
||||
|
||||
# de-duplication of ids
|
||||
exist_ids = set()
|
||||
for i in range(len(records)):
|
||||
record = records[i]
|
||||
if record.id in exist_ids:
|
||||
continue
|
||||
exist_ids.add(record.id)
|
||||
recs.append(cast(Passage, record))
|
||||
ids.append(str(record.id))
|
||||
documents.append(record.text)
|
||||
embeddings.append(record.embedding)
|
||||
|
||||
# collect/format record metadata
|
||||
metadatas = []
|
||||
for record in recs:
|
||||
embedding_config = vars(record.embedding_config)
|
||||
metadata = vars(record)
|
||||
metadata.pop("id")
|
||||
metadata.pop("text")
|
||||
metadata.pop("embedding")
|
||||
metadata.pop("embedding_config")
|
||||
metadata.pop("metadata_")
|
||||
if "created_at" in metadata:
|
||||
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
|
||||
if "metadata_" in metadata and metadata["metadata_"] is not None:
|
||||
record_metadata = dict(metadata["metadata_"])
|
||||
metadata.pop("metadata_")
|
||||
else:
|
||||
record_metadata = {}
|
||||
|
||||
metadata = {**metadata, **record_metadata} # merge with metadata
|
||||
metadata = {**metadata, **embedding_config} # merge with embedding config
|
||||
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
|
||||
|
||||
# convert uuids to strings
|
||||
metadatas.append(metadata)
|
||||
return ids, documents, embeddings, metadatas
|
||||
|
||||
def insert(self, record):
|
||||
ids, documents, embeddings, metadatas = self.format_records([record])
|
||||
if any([e is None for e in embeddings]):
|
||||
raise ValueError("Embeddings must be provided to chroma")
|
||||
self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)
|
||||
|
||||
def insert_many(self, records, show_progress=False):
|
||||
ids, documents, embeddings, metadatas = self.format_records(records)
|
||||
if any([e is None for e in embeddings]):
|
||||
raise ValueError("Embeddings must be provided to chroma")
|
||||
self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
ids, filters = self.get_filters(filters)
|
||||
self.collection.delete(ids=ids, where=filters)
|
||||
|
||||
def delete_table(self):
|
||||
# drop collection
|
||||
self.client.delete_collection(self.collection.name)
|
||||
|
||||
def save(self):
|
||||
# save to persistence file (nothing needs to be done)
|
||||
printd("Saving chroma")
|
||||
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# unfortuantely, need to use pagination to get filtering
|
||||
# warning: poor performance for large datasets
|
||||
return len(self.get_all(filters=filters))
|
||||
|
||||
def list_data_sources(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
|
||||
ids, filters = self.get_filters(filters)
|
||||
results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)
|
||||
|
||||
# flatten, since we only have one query vector
|
||||
flattened_results = {}
|
||||
for key, value in results.items():
|
||||
if value:
|
||||
# value is an Optional[List] type according to chromadb.api.types
|
||||
flattened_results[key] = value[0] # type: ignore
|
||||
assert len(value) == 1, f"Value is size {len(value)}: {value}" # type: ignore
|
||||
else:
|
||||
flattened_results[key] = value
|
||||
|
||||
return self.results_to_records(flattened_results)
|
||||
|
||||
def query_date(self, start_date, end_date, start=None, count=None):
|
||||
raise ValueError("Cannot run query_date with chroma")
|
||||
# filters = self.get_filters(filters)
|
||||
# filters["created_at"] = {
|
||||
# "$gte": start_date,
|
||||
# "$lte": end_date,
|
||||
# }
|
||||
# results = self.collection.query(where=filters)
|
||||
# start = 0 if start is None else start
|
||||
# count = len(results) if count is None else count
|
||||
# results = results[start : start + count]
|
||||
# return self.results_to_records(results)
|
||||
|
||||
def query_text(self, query, count=None, start=None, filters: Optional[Dict] = {}):
|
||||
raise ValueError("Cannot run query_text with chroma")
|
||||
# filters = self.get_filters(filters)
|
||||
# results = self.collection.query(where_document={"$contains": {"text": query}}, where=filters)
|
||||
# start = 0 if start is None else start
|
||||
# count = len(results) if count is None else count
|
||||
# results = results[start : start + count]
|
||||
# return self.results_to_records(results)
|
||||
|
||||
def get_all_cursor(
|
||||
self,
|
||||
filters: Optional[Dict] = {},
|
||||
after: str = None,
|
||||
before: str = None,
|
||||
limit: Optional[int] = 1000,
|
||||
order_by: str = "created_at",
|
||||
reverse: bool = False,
|
||||
):
|
||||
records = self.get_all(filters=filters)
|
||||
|
||||
# WARNING: very hacky and slow implementation
|
||||
def get_index(id, record_list):
|
||||
for i in range(len(record_list)):
|
||||
if record_list[i].id == id:
|
||||
return i
|
||||
assert False, f"Could not find id {id} in record list"
|
||||
|
||||
# sort by custom field
|
||||
records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse)
|
||||
if after:
|
||||
index = get_index(after, records)
|
||||
if index + 1 >= len(records):
|
||||
return None, []
|
||||
records = records[index + 1 :]
|
||||
if before:
|
||||
index = get_index(before, records)
|
||||
if index == 0:
|
||||
return None, []
|
||||
|
||||
# TODO: not sure if this is correct
|
||||
records = records[:index]
|
||||
|
||||
if len(records) == 0:
|
||||
return None, []
|
||||
|
||||
# enforce limit
|
||||
if limit:
|
||||
records = records[:limit]
|
||||
return records[-1].id, records
|
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
@ -32,7 +33,7 @@ from letta.orm.base import Base
|
||||
from letta.orm.file import FileMetadata as FileMetadataModel
|
||||
|
||||
# from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.orm.passage import Passage as PassageModel
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig()
|
||||
@ -66,56 +67,6 @@ class CommonVector(TypeDecorator):
|
||||
# For PostgreSQL, value is already in bytes
|
||||
return np.frombuffer(value, dtype=np.float32)
|
||||
|
||||
|
||||
class PassageModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
|
||||
__tablename__ = "passages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
# Assuming passage_id is the primary key
|
||||
id = Column(String, primary_key=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String)
|
||||
file_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
source_id = Column(String)
|
||||
|
||||
# vector storage
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
elif config.archival_storage_type == "sqlite" or config.archival_storage_type == "chroma":
|
||||
embedding = Column(CommonVector)
|
||||
else:
|
||||
raise ValueError(f"Unsupported archival_storage_type: {config.archival_storage_type}")
|
||||
embedding_config = Column(EmbeddingConfigColumn)
|
||||
metadata_ = Column(MutableJson)
|
||||
|
||||
# Add a datetime column, with default value as the current time
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
Index("passage_idx_user", user_id, agent_id, file_id),
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
embedding_config=self.embedding_config,
|
||||
file_id=self.file_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
source_id=self.source_id,
|
||||
agent_id=self.agent_id,
|
||||
metadata_=self.metadata_,
|
||||
created_at=self.created_at,
|
||||
)
|
||||
|
||||
|
||||
class SQLStorageConnector(StorageConnector):
|
||||
def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
|
||||
super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
|
||||
@ -320,6 +271,7 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
self.session_maker = db_context
|
||||
|
||||
# TODO: move to DB init
|
||||
if settings.pg_uri:
|
||||
with self.session_maker() as session:
|
||||
session.execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
@ -419,7 +371,13 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
|
||||
# get storage URI
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
self.db_model = PassageModel
|
||||
if settings.letta_pg_uri_no_default:
|
||||
self.uri = settings.letta_pg_uri_no_default
|
||||
else:
|
||||
# For SQLite, use the archival storage path
|
||||
self.path = config.archival_storage_path
|
||||
self.uri = f"sqlite:///{os.path.join(config.archival_storage_path, 'letta.db')}"
|
||||
elif table_type == TableType.FILES:
|
||||
self.path = self.config.metadata_storage_path
|
||||
if self.path is None:
|
||||
|
@ -45,11 +45,13 @@ class StorageConnector:
|
||||
self,
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
|
||||
config: LettaConfig,
|
||||
user_id,
|
||||
agent_id=None,
|
||||
user_id: str,
|
||||
agent_id: Optional[str] = None,
|
||||
organization_id: Optional[str] = None,
|
||||
):
|
||||
self.user_id = user_id
|
||||
self.agent_id = agent_id
|
||||
self.organization_id = organization_id
|
||||
self.table_type = table_type
|
||||
|
||||
# get object type
|
||||
@ -74,10 +76,12 @@ class StorageConnector:
|
||||
# agent-specific table
|
||||
assert agent_id is not None, "Agent ID must be provided for agent-specific tables"
|
||||
self.filters = {"user_id": self.user_id, "agent_id": self.agent_id}
|
||||
elif self.table_type == TableType.PASSAGES or self.table_type == TableType.FILES:
|
||||
elif self.table_type == TableType.FILES:
|
||||
# setup base filters for user-specific tables
|
||||
assert agent_id is None, "Agent ID must not be provided for user-specific tables"
|
||||
self.filters = {"user_id": self.user_id}
|
||||
elif self.table_type == TableType.PASSAGES:
|
||||
self.filters = {"organization_id": self.organization_id}
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
@ -85,8 +89,9 @@ class StorageConnector:
|
||||
def get_storage_connector(
|
||||
table_type: Union[TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY, TableType.PASSAGES, TableType.FILES],
|
||||
config: LettaConfig,
|
||||
user_id,
|
||||
agent_id=None,
|
||||
user_id: str,
|
||||
organization_id: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
storage_type = config.archival_storage_type
|
||||
@ -101,10 +106,6 @@ class StorageConnector:
|
||||
from letta.agent_store.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(table_type, config, user_id, agent_id)
|
||||
elif storage_type == "chroma":
|
||||
from letta.agent_store.chroma import ChromaStorageConnector
|
||||
|
||||
return ChromaStorageConnector(table_type, config, user_id, agent_id)
|
||||
|
||||
elif storage_type == "qdrant":
|
||||
from letta.agent_store.qdrant import QdrantStorageConnector
|
||||
|
@ -742,7 +742,8 @@ class RESTClient(AbstractClient):
|
||||
agents = [AgentState(**agent) for agent in response.json()]
|
||||
if len(agents) == 0:
|
||||
return None
|
||||
assert len(agents) == 1, f"Multiple agents with the same name: {agents}"
|
||||
agents = [agents[0]] # TODO: @matt monkeypatched
|
||||
assert len(agents) == 1, f"Multiple agents with the same name: {[(agents.name, agents.id) for agents in agents]}"
|
||||
return agents[0].id
|
||||
|
||||
# memory
|
||||
@ -3107,7 +3108,7 @@ class LocalClient(AbstractClient):
|
||||
passages (List[Passage]): List of passages
|
||||
"""
|
||||
|
||||
return self.server.get_agent_archival_cursor(user_id=self.user_id, agent_id=agent_id, before=before, after=after, limit=limit)
|
||||
return self.server.get_agent_archival_cursor(user_id=self.user_id, agent_id=agent_id, limit=limit)
|
||||
|
||||
# recall memory
|
||||
|
||||
|
@ -62,8 +62,8 @@ class LettaConfig:
|
||||
# @norton120 these are the metdadatastore
|
||||
|
||||
# database configs: archival
|
||||
archival_storage_type: str = "chroma" # local, db
|
||||
archival_storage_path: str = os.path.join(LETTA_DIR, "chroma")
|
||||
archival_storage_type: str = "sqlite" # local, db
|
||||
archival_storage_path: str = LETTA_DIR
|
||||
archival_storage_uri: str = None # TODO: eventually allow external vector DB
|
||||
|
||||
# database configs: recall
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Iterator, List, Tuple
|
||||
from typing import Dict, Iterator, List, Tuple, Optional
|
||||
|
||||
import typer
|
||||
|
||||
@ -42,7 +42,7 @@ class DataConnector:
|
||||
"""
|
||||
|
||||
|
||||
def load_data(connector: DataConnector, source: Source, passage_store: StorageConnector, source_manager: SourceManager, actor: "User"):
|
||||
def load_data(connector: DataConnector, source: Source, passage_store: StorageConnector, source_manager: SourceManager, actor: "User", agent_id: Optional[str] = None):
|
||||
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
|
||||
embedding_config = source.embedding_config
|
||||
|
||||
@ -82,9 +82,10 @@ def load_data(connector: DataConnector, source: Source, passage_store: StorageCo
|
||||
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
||||
text=passage_text,
|
||||
file_id=file_metadata.id,
|
||||
agent_id=agent_id,
|
||||
source_id=source.id,
|
||||
metadata_=passage_metadata,
|
||||
user_id=source.created_by_id,
|
||||
organization_id=source.organization_id,
|
||||
embedding_config=source.embedding_config,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
@ -164,17 +164,23 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.archival_memory.insert(content)
|
||||
self.passage_manager.insert_passage(
|
||||
agent_state=self.agent_state,
|
||||
agent_id=self.agent_state.id,
|
||||
text=content,
|
||||
actor=self.user,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0, start: Optional[int] = 0) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
|
||||
Args:
|
||||
query (str): String to search for.
|
||||
page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
|
||||
start (Optional[int]): Starting index for the search results. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
str: Query result string
|
||||
@ -191,15 +197,34 @@ def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -
|
||||
except:
|
||||
raise ValueError(f"'page' argument must be an integer")
|
||||
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
results, total = self.archival_memory.search(query, count=count, start=page * count)
|
||||
num_pages = math.ceil(total / count) - 1 # 0 index
|
||||
if len(results) == 0:
|
||||
results_str = f"No results found."
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"timestamp: {d['timestamp']}, memory: {d['content']}" for d in results]
|
||||
results_str = f"{results_pref} {json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
try:
|
||||
# Get results using passage manager
|
||||
all_results = self.passage_manager.list_passages(
|
||||
actor=self.user,
|
||||
query_text=query,
|
||||
limit=count + start, # Request enough results to handle offset
|
||||
embedding_config=self.agent_state.embedding_config,
|
||||
embed_query=True
|
||||
)
|
||||
|
||||
# Apply pagination
|
||||
end = min(count + start, len(all_results))
|
||||
paged_results = all_results[start:end]
|
||||
|
||||
# Format results to match previous implementation
|
||||
formatted_results = [
|
||||
{
|
||||
"timestamp": str(result.created_at),
|
||||
"content": result.text
|
||||
}
|
||||
for result in paged_results
|
||||
]
|
||||
|
||||
return formatted_results, len(formatted_results)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def core_memory_append(agent_state: "AgentState", label: str, content: str) -> Optional[str]: # type: ignore
|
||||
|
@ -363,7 +363,18 @@ class MetadataStore:
|
||||
with self.session_maker() as session:
|
||||
# TODO: remove this (is a hack)
|
||||
mapping_id = f"{user_id}-{agent_id}-{source_id}"
|
||||
session.add(AgentSourceMappingModel(id=mapping_id, user_id=user_id, agent_id=agent_id, source_id=source_id))
|
||||
existing = session.query(AgentSourceMappingModel).filter(
|
||||
AgentSourceMappingModel.id == mapping_id
|
||||
).first()
|
||||
|
||||
if existing is None:
|
||||
# Only create if it doesn't exist
|
||||
session.add(AgentSourceMappingModel(
|
||||
id=mapping_id,
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
source_id=source_id
|
||||
))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
|
@ -6,6 +6,7 @@ from letta.orm.file import FileMetadata
|
||||
from letta.orm.job import Job
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import Passage
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.tool import Tool
|
||||
|
@ -27,3 +27,4 @@ class FileMetadata(SqlalchemyBase, OrganizationMixin, SourceMixin):
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="files", lazy="selectin")
|
||||
source: Mapped["Source"] = relationship("Source", back_populates="files", lazy="selectin")
|
||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="file", lazy="selectin", cascade="all, delete-orphan")
|
||||
|
@ -1,3 +1,4 @@
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
@ -30,6 +31,12 @@ class UserMixin(Base):
|
||||
|
||||
user_id: Mapped[str] = mapped_column(String, ForeignKey("users.id"))
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
||||
|
||||
class AgentMixin(Base):
|
||||
"""Mixin for models that belong to an agent."""
|
||||
@ -38,13 +45,16 @@ class AgentMixin(Base):
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"))
|
||||
|
||||
|
||||
class FileMixin(Base):
|
||||
"""Mixin for models that belong to a file."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
file_id: Mapped[str] = mapped_column(String, ForeignKey("files.id"))
|
||||
file_id: Mapped[Optional[str]] = mapped_column(
|
||||
String,
|
||||
ForeignKey("files.id", ondelete="CASCADE"),
|
||||
nullable=True
|
||||
)
|
||||
|
||||
|
||||
class SourceMixin(Base):
|
||||
|
@ -33,7 +33,10 @@ class Organization(SqlalchemyBase):
|
||||
sandbox_environment_variables: Mapped[List["SandboxEnvironmentVariable"]] = relationship(
|
||||
"SandboxEnvironmentVariable", back_populates="organization", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# relationships
|
||||
messages: Mapped[List["Message"]] = relationship("Message", back_populates="organization", cascade="all, delete-orphan")
|
||||
passages: Mapped[List["Passage"]] = relationship("Passage", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
|
72
letta/orm/passage.py
Normal file
72
letta/orm/passage.py
Normal file
@ -0,0 +1,72 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
from sqlalchemy import Column, String, DateTime, Index, JSON, UniqueConstraint, ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import TypeDecorator, BINARY
|
||||
|
||||
import numpy as np
|
||||
import base64
|
||||
|
||||
from letta.orm.source import EmbeddingConfigColumn
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.orm.mixins import AgentMixin, FileMixin, OrganizationMixin
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.settings import settings
|
||||
|
||||
config = LettaConfig()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.file import File
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
class CommonVector(TypeDecorator):
|
||||
"""Common type for representing vectors in SQLite"""
|
||||
impl = BINARY
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
return dialect.type_descriptor(BINARY())
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
value = np.array(value, dtype=np.float32)
|
||||
return base64.b64encode(value.tobytes())
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if not value:
|
||||
return value
|
||||
if dialect.name == "sqlite":
|
||||
value = base64.b64decode(value)
|
||||
return np.frombuffer(value, dtype=np.float32)
|
||||
|
||||
# TODO: After migration to Passage, will need to manually delete passages where files
|
||||
# are deleted on web
|
||||
class Passage(SqlalchemyBase, OrganizationMixin, FileMixin):
|
||||
"""Defines data model for storing Passages"""
|
||||
__tablename__ = "passages"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
__pydantic_model__ = PydanticPassage
|
||||
|
||||
id: Mapped[str] = mapped_column(primary_key=True, doc="Unique passage identifier")
|
||||
text: Mapped[str] = mapped_column(doc="Passage text content")
|
||||
source_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="Source identifier")
|
||||
embedding_config: Mapped[dict] = mapped_column(EmbeddingConfigColumn, doc="Embedding configuration")
|
||||
metadata_: Mapped[dict] = mapped_column(JSON, doc="Additional metadata")
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=datetime.utcnow)
|
||||
if settings.letta_pg_uri_no_default:
|
||||
from pgvector.sqlalchemy import Vector
|
||||
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
|
||||
else:
|
||||
embedding = Column(CommonVector)
|
||||
|
||||
# Foreign keys
|
||||
agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id"), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="passages", lazy="selectin")
|
||||
file: Mapped["FileMetadata"] = relationship("FileMetadata", back_populates="passages", lazy="selectin")
|
@ -1,6 +1,7 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, List, Literal, Optional, Type
|
||||
import sqlite3
|
||||
|
||||
from sqlalchemy import String, desc, func, or_, select
|
||||
from sqlalchemy.exc import DBAPIError
|
||||
@ -8,6 +9,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.base import Base, CommonSqlalchemyMetaMixins
|
||||
from letta.orm.sqlite_functions import adapt_array, convert_array, cosine_distance
|
||||
from letta.orm.errors import (
|
||||
ForeignKeyConstraintViolationError,
|
||||
NoResultFound,
|
||||
@ -60,6 +62,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
query_embedding: Optional[List[float]] = None,
|
||||
ascending: bool = True,
|
||||
**kwargs,
|
||||
) -> List[Type["SqlalchemyBase"]]:
|
||||
@ -110,13 +113,39 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
# Apply text search
|
||||
if query_text:
|
||||
from sqlalchemy import func
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
|
||||
# Handle soft deletes
|
||||
# Apply embedding search (Passages)
|
||||
is_ordered = False
|
||||
if query_embedding:
|
||||
# check if embedding column exists. should only exist for passages
|
||||
if not hasattr(cls, "embedding"):
|
||||
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
||||
|
||||
from letta.settings import settings
|
||||
if settings.letta_pg_uri_no_default:
|
||||
# PostgreSQL with pgvector
|
||||
from pgvector.sqlalchemy import Vector
|
||||
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
||||
else:
|
||||
# SQLite with custom vector type
|
||||
from sqlalchemy import func
|
||||
|
||||
query_embedding_binary = adapt_array(query_embedding)
|
||||
query = query.order_by(
|
||||
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
||||
cls.created_at.asc(),
|
||||
cls.id.asc()
|
||||
)
|
||||
is_ordered = True
|
||||
|
||||
# Handle ordering and soft deletes
|
||||
if hasattr(cls, "is_deleted"):
|
||||
query = query.where(cls.is_deleted == False)
|
||||
|
||||
# Apply ordering by created_at
|
||||
if not is_ordered:
|
||||
if ascending:
|
||||
query = query.order_by(cls.created_at, cls.id)
|
||||
else:
|
||||
|
140
letta/orm/sqlite_functions.py
Normal file
140
letta/orm/sqlite_functions.py
Normal file
@ -0,0 +1,140 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
import base64
|
||||
import numpy as np
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
import sqlite3
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
|
||||
def adapt_array(arr):
|
||||
"""
|
||||
Converts numpy array to binary for SQLite storage
|
||||
"""
|
||||
if arr is None:
|
||||
return None
|
||||
|
||||
if isinstance(arr, list):
|
||||
arr = np.array(arr, dtype=np.float32)
|
||||
elif not isinstance(arr, np.ndarray):
|
||||
raise ValueError(f"Unsupported type: {type(arr)}")
|
||||
|
||||
# Convert to bytes and then base64 encode
|
||||
bytes_data = arr.tobytes()
|
||||
base64_data = base64.b64encode(bytes_data)
|
||||
return sqlite3.Binary(base64_data)
|
||||
|
||||
def convert_array(text):
|
||||
"""
|
||||
Converts binary back to numpy array
|
||||
"""
|
||||
if text is None:
|
||||
return None
|
||||
if isinstance(text, list):
|
||||
return np.array(text, dtype=np.float32)
|
||||
if isinstance(text, np.ndarray):
|
||||
return text
|
||||
|
||||
# Handle both bytes and sqlite3.Binary
|
||||
binary_data = bytes(text) if isinstance(text, sqlite3.Binary) else text
|
||||
|
||||
try:
|
||||
# First decode base64
|
||||
decoded_data = base64.b64decode(binary_data)
|
||||
# Then convert to numpy array
|
||||
return np.frombuffer(decoded_data, dtype=np.float32)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def verify_embedding_dimension(embedding: np.ndarray, expected_dim: int = MAX_EMBEDDING_DIM) -> bool:
|
||||
"""
|
||||
Verifies that an embedding has the expected dimension
|
||||
|
||||
Args:
|
||||
embedding: Input embedding array
|
||||
expected_dim: Expected embedding dimension (default: 4096)
|
||||
|
||||
Returns:
|
||||
bool: True if dimension matches, False otherwise
|
||||
"""
|
||||
if embedding is None:
|
||||
return False
|
||||
return embedding.shape[0] == expected_dim
|
||||
|
||||
def validate_and_transform_embedding(
|
||||
embedding: Union[bytes, sqlite3.Binary, list, np.ndarray],
|
||||
expected_dim: int = MAX_EMBEDDING_DIM,
|
||||
dtype: np.dtype = np.float32
|
||||
) -> Optional[np.ndarray]:
|
||||
"""
|
||||
Validates and transforms embeddings to ensure correct dimensionality.
|
||||
|
||||
Args:
|
||||
embedding: Input embedding in various possible formats
|
||||
expected_dim: Expected embedding dimension (default 4096)
|
||||
dtype: NumPy dtype for the embedding (default float32)
|
||||
|
||||
Returns:
|
||||
np.ndarray: Validated and transformed embedding
|
||||
|
||||
Raises:
|
||||
ValueError: If embedding dimension doesn't match expected dimension
|
||||
"""
|
||||
if embedding is None:
|
||||
return None
|
||||
|
||||
# Convert to numpy array based on input type
|
||||
if isinstance(embedding, (bytes, sqlite3.Binary)):
|
||||
vec = convert_array(embedding)
|
||||
elif isinstance(embedding, list):
|
||||
vec = np.array(embedding, dtype=dtype)
|
||||
elif isinstance(embedding, np.ndarray):
|
||||
vec = embedding.astype(dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported embedding type: {type(embedding)}")
|
||||
|
||||
# Validate dimension
|
||||
if vec.shape[0] != expected_dim:
|
||||
raise ValueError(
|
||||
f"Invalid embedding dimension: got {vec.shape[0]}, expected {expected_dim}"
|
||||
)
|
||||
|
||||
return vec
|
||||
|
||||
def cosine_distance(embedding1, embedding2, expected_dim=MAX_EMBEDDING_DIM):
|
||||
"""
|
||||
Calculate cosine distance between two embeddings
|
||||
|
||||
Args:
|
||||
embedding1: First embedding
|
||||
embedding2: Second embedding
|
||||
expected_dim: Expected embedding dimension (default 4096)
|
||||
|
||||
Returns:
|
||||
float: Cosine distance
|
||||
"""
|
||||
|
||||
if embedding1 is None or embedding2 is None:
|
||||
return 0.0 # Maximum distance if either embedding is None
|
||||
|
||||
try:
|
||||
vec1 = validate_and_transform_embedding(embedding1, expected_dim)
|
||||
vec2 = validate_and_transform_embedding(embedding2, expected_dim)
|
||||
except ValueError as e:
|
||||
return 0.0
|
||||
|
||||
similarity = np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))
|
||||
distance = float(1.0 - similarity)
|
||||
|
||||
return distance
|
||||
|
||||
@event.listens_for(Engine, "connect")
|
||||
def register_functions(dbapi_connection, connection_record):
|
||||
"""Register SQLite functions"""
|
||||
if isinstance(dbapi_connection, sqlite3.Connection):
|
||||
dbapi_connection.create_function("cosine_distance", 2, cosine_distance)
|
||||
|
||||
# Register adapters and converters for numpy arrays
|
||||
sqlite3.register_adapter(np.ndarray, adapt_array)
|
||||
sqlite3.register_converter("ARRAY", convert_array)
|
@ -20,7 +20,7 @@ class User(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="users")
|
||||
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.")
|
||||
jobs: Mapped[List["Job"]] = relationship("Job", back_populates="user", doc="the jobs associated with this user.", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Add this back later potentially
|
||||
# agents: Mapped[List["Agent"]] = relationship(
|
||||
|
@ -5,15 +5,17 @@ from pydantic import Field, field_validator
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
|
||||
class PassageBase(LettaBase):
|
||||
__id_prefix__ = "passage"
|
||||
class PassageBase(OrmMetadataBase):
|
||||
__id_prefix__ = "passage_legacy"
|
||||
|
||||
is_deleted: bool = Field(False, description="Whether this passage is deleted or not.")
|
||||
|
||||
# associated user/agent
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
|
||||
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")
|
||||
|
||||
# origin data source
|
||||
|
@ -369,8 +369,7 @@ def get_agent_archival_memory(
|
||||
return server.get_agent_archival_cursor(
|
||||
user_id=actor.id,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
cursor=after, # TODO: deleting before, after. is this expected?
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
@ -16,7 +16,6 @@ import letta.constants as constants
|
||||
import letta.server.utils as server_utils
|
||||
import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.agent_store.db import attach_base
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.chat_only_agent import ChatOnlyAgent
|
||||
from letta.credentials import LettaCredentials
|
||||
@ -70,17 +69,18 @@ from letta.schemas.memory import (
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@ -125,7 +125,7 @@ class Server(object):
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> AgentState:
|
||||
@ -166,8 +166,6 @@ from letta.settings import model_settings, settings, tool_settings
|
||||
|
||||
config = LettaConfig.load()
|
||||
|
||||
attach_base()
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
@ -245,6 +243,7 @@ class SyncServer(Server):
|
||||
|
||||
# Managers that interface with data models
|
||||
self.organization_manager = OrganizationManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.block_manager = BlockManager()
|
||||
@ -498,7 +497,12 @@ class SyncServer(Server):
|
||||
|
||||
# attach data to agent from source
|
||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
letta_agent.attach_source(data_source, source_connector, self.ms)
|
||||
letta_agent.attach_source(
|
||||
user=self.user_manager.get_user_by_id(user_id=user_id),
|
||||
source_id=data_source,
|
||||
source_manager=letta_agent.source_manager,
|
||||
ms=self.ms
|
||||
)
|
||||
|
||||
elif command.lower() == "dump" or command.lower().startswith("dump "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
@ -513,7 +517,7 @@ class SyncServer(Server):
|
||||
letta_agent.interface.print_messages_raw(letta_agent.messages)
|
||||
|
||||
elif command.lower() == "memory":
|
||||
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}"
|
||||
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}"
|
||||
return ret_str
|
||||
|
||||
elif command.lower() == "pop" or command.lower().startswith("pop "):
|
||||
@ -769,7 +773,7 @@ class SyncServer(Server):
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> AgentState:
|
||||
@ -921,6 +925,7 @@ class SyncServer(Server):
|
||||
|
||||
# get `Tool` objects
|
||||
tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names]
|
||||
tools = [tool for tool in tools if tool is not None]
|
||||
|
||||
# get `Source` objects
|
||||
sources = self.list_attached_sources(agent_id=agent_id)
|
||||
@ -934,7 +939,7 @@ class SyncServer(Server):
|
||||
def update_agent(
|
||||
self,
|
||||
request: UpdateAgentState,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
) -> AgentState:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
try:
|
||||
@ -1151,7 +1156,7 @@ class SyncServer(Server):
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
return ArchivalMemorySummary(size=len(agent.archival_memory))
|
||||
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
@ -1176,7 +1181,56 @@ class SyncServer(Server):
|
||||
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
|
||||
return message
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
|
||||
def get_agent_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
start: int,
|
||||
count: int,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
if start < 0 or count < 0:
|
||||
raise ValueError("Start and count values should be non-negative")
|
||||
|
||||
if start + count < len(letta_agent._messages): # messages can be returned from whats in memory
|
||||
# Reverse the list to make it in reverse chronological order
|
||||
reversed_messages = letta_agent._messages[::-1]
|
||||
# Check if start is within the range of the list
|
||||
if start >= len(reversed_messages):
|
||||
raise IndexError("Start index is out of range")
|
||||
|
||||
# Calculate the end index, ensuring it does not exceed the list length
|
||||
end_index = min(start + count, len(reversed_messages))
|
||||
|
||||
# Slice the list for pagination
|
||||
messages = reversed_messages[start:end_index]
|
||||
|
||||
else:
|
||||
# need to access persistence manager for additional messages
|
||||
|
||||
# get messages using message manager
|
||||
page = letta_agent.message_manager.list_user_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=self.default_user,
|
||||
cursor=start,
|
||||
limit=count,
|
||||
)
|
||||
|
||||
messages = page
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
## Convert to json
|
||||
## Add a tag indicating in-context or not
|
||||
# json_messages = [record.to_json() for record in messages]
|
||||
# in_context_message_ids = [str(m.id) for m in letta_agent._messages]
|
||||
# for d in json_messages:
|
||||
# d["in_context"] = True if str(d["id"]) in in_context_message_ids else False
|
||||
|
||||
return messages
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[PydanticPassage]:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
@ -1187,22 +1241,22 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over records
|
||||
db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=self.default_user,
|
||||
agent_id=agent_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# get a single page of messages
|
||||
page = next(db_iterator, [])
|
||||
return page
|
||||
return records
|
||||
|
||||
def get_agent_archival_cursor(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
order_by: Optional[str] = "created_at",
|
||||
reverse: Optional[bool] = False,
|
||||
) -> List[Passage]:
|
||||
) -> List[PydanticPassage]:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise LettaUserNotFoundError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
@ -1211,14 +1265,15 @@ class SyncServer(Server):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over recorde
|
||||
cursor, records = letta_agent.archival_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
# iterate over records
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit,
|
||||
)
|
||||
return records
|
||||
|
||||
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[PydanticPassage]:
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if actor is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
@ -1227,17 +1282,20 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# Insert into archival memory
|
||||
passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True)
|
||||
passage_ids = self.passage_manager.insert_passage(
|
||||
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor, return_ids=True
|
||||
)
|
||||
|
||||
# Update the agent
|
||||
# TODO: should this update the system prompt?
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
# TODO: this is gross, fix
|
||||
return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
|
||||
return [self.passage_manager.get_passage_by_id(passage_id=passage_id, actor=actor) for passage_id in passage_ids]
|
||||
|
||||
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if actor is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
@ -1249,7 +1307,7 @@ class SyncServer(Server):
|
||||
|
||||
# Delete by ID
|
||||
# TODO check if it exists first, and throw error if not
|
||||
letta_agent.archival_memory.storage.delete({"id": memory_id})
|
||||
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# TODO: return archival memory
|
||||
|
||||
@ -1395,6 +1453,12 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}")
|
||||
|
||||
# delete all passages associated with this agent
|
||||
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM
|
||||
passages = self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id)
|
||||
for passage in passages:
|
||||
self.passage_manager.delete_passage_by_id(passage.id, actor=actor)
|
||||
|
||||
# First, if the agent is in the in-memory cache we should remove it
|
||||
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
|
||||
try:
|
||||
@ -1437,7 +1501,7 @@ class SyncServer(Server):
|
||||
self.ms.delete_api_key(api_key=api_key)
|
||||
return api_key_obj
|
||||
|
||||
def delete_source(self, source_id: str, actor: User):
|
||||
def delete_source(self, source_id: str, actor: PydanticUser):
|
||||
"""Delete a data source"""
|
||||
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
@ -1447,7 +1511,7 @@ class SyncServer(Server):
|
||||
|
||||
# TODO: delete data from agent passage stores (?)
|
||||
|
||||
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
|
||||
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: PydanticUser) -> Job:
|
||||
|
||||
# update job
|
||||
job = self.job_manager.get_job_by_id(job_id, actor=actor)
|
||||
@ -1474,6 +1538,7 @@ class SyncServer(Server):
|
||||
user_id: str,
|
||||
connector: DataConnector,
|
||||
source_name: str,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""Load data from a DataConnector into a source for a specified user_id"""
|
||||
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
||||
@ -1488,14 +1553,13 @@ class SyncServer(Server):
|
||||
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
|
||||
# load data into the document store
|
||||
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user)
|
||||
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user, agent_id=agent_id)
|
||||
return passage_count, document_count
|
||||
|
||||
def attach_source_to_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
# source_id: str,
|
||||
source_id: Optional[str] = None,
|
||||
source_name: Optional[str] = None,
|
||||
) -> Source:
|
||||
@ -1507,15 +1571,14 @@ class SyncServer(Server):
|
||||
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
|
||||
else:
|
||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||
# get connection to data source storage
|
||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
|
||||
assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"
|
||||
|
||||
# load agent
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# attach source to agent
|
||||
agent.attach_source(data_source.id, source_connector, self.ms)
|
||||
agent.attach_source(user=user, source_id=data_source.id, source_manager=self.source_manager, ms=self.ms)
|
||||
|
||||
return data_source
|
||||
|
||||
@ -1538,8 +1601,7 @@ class SyncServer(Server):
|
||||
|
||||
# delete all Passage objects with source_id==source_id from agent's archival memory
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
archival_memory = agent.archival_memory
|
||||
archival_memory.storage.delete({"source_id": source_id})
|
||||
agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id)
|
||||
|
||||
# delete agent-source mapping
|
||||
self.ms.detach_source(agent_id=agent_id, source_id=source_id)
|
||||
@ -1553,11 +1615,11 @@ class SyncServer(Server):
|
||||
|
||||
return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids]
|
||||
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[PydanticPassage]:
|
||||
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
|
||||
return []
|
||||
|
||||
def list_all_sources(self, actor: User) -> List[Source]:
|
||||
def list_all_sources(self, actor: PydanticUser) -> List[Source]:
|
||||
"""List all sources (w/ extra metadata) belonging to a user"""
|
||||
|
||||
sources = self.source_manager.list_sources(actor=actor)
|
||||
@ -1597,7 +1659,7 @@ class SyncServer(Server):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def add_default_external_tools(self, actor: User) -> bool:
|
||||
def add_default_external_tools(self, actor: PydanticUser) -> bool:
|
||||
"""Add default langchain tools. Return true if successful, false otherwise."""
|
||||
success = True
|
||||
tool_creates = ToolCreate.load_default_langchain_tools()
|
||||
@ -1654,7 +1716,7 @@ class SyncServer(Server):
|
||||
save_agent(letta_agent, self.ms)
|
||||
return response
|
||||
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> User:
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> PydanticUser:
|
||||
"""Get the user object for user_id if it exists, otherwise return the default user object"""
|
||||
if user_id is None:
|
||||
user_id = self.user_manager.DEFAULT_USER_ID
|
||||
|
225
letta/services/passage_manager.py
Normal file
225
letta/services/passage_manager.py
Normal file
@ -0,0 +1,225 @@
|
||||
from typing import List, Optional, Dict, Tuple
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.utils import enforce_types
|
||||
|
||||
from letta.embeddings import embedding_model, parse_and_chunk_text
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
from letta.orm.passage import Passage as PassageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
|
||||
class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a passage by ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage."""
|
||||
with self.session_maker() as session:
|
||||
passage = PassageModel(**pydantic_passage.model_dump())
|
||||
passage.create(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple passages."""
|
||||
return [self.create_passage(p, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
def insert_passage(self,
|
||||
agent_state: AgentState,
|
||||
agent_id: str,
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
return_ids: bool = False
|
||||
) -> List[PydanticPassage]:
|
||||
""" Insert passage(s) into archival memory """
|
||||
|
||||
embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
|
||||
embed_model = embedding_model(agent_state.embedding_config)
|
||||
|
||||
passages = []
|
||||
|
||||
try:
|
||||
# breakup string into passages
|
||||
for text in parse_and_chunk_text(text, embedding_chunk_size):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
if isinstance(embedding, dict):
|
||||
try:
|
||||
embedding = embedding["data"][0]["embedding"]
|
||||
except (KeyError, IndexError):
|
||||
# TODO as a fallback, see if we can find any lists in the payload
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passage = self.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config
|
||||
),
|
||||
actor=actor
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
ids = [str(p.id) for p in passages]
|
||||
|
||||
if return_ids:
|
||||
return ids
|
||||
|
||||
return passages
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@enforce_types
|
||||
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
|
||||
"""Update a passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Fetch existing message from database
|
||||
curr_passage = PassageModel.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
if not curr_passage:
|
||||
raise ValueError(f"Passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
# Commit changes
|
||||
curr_passage.update(session, actor=actor)
|
||||
return curr_passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
passage = PassageModel.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage.hard_delete(session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
def list_passages(self,
|
||||
actor : PydanticUser,
|
||||
agent_id : Optional[str] = None,
|
||||
file_id : Optional[str] = None,
|
||||
cursor : Optional[str] = None,
|
||||
limit : Optional[int] = 50,
|
||||
query_text : Optional[str] = None,
|
||||
start_date : Optional[datetime] = None,
|
||||
end_date : Optional[datetime] = None,
|
||||
source_id : Optional[str] = None,
|
||||
embed_query : bool = False,
|
||||
embedding_config: Optional[EmbeddingConfig] = None
|
||||
) -> List[PydanticPassage]:
|
||||
"""List passages with pagination."""
|
||||
with self.session_maker() as session:
|
||||
filters = {"organization_id": actor.organization_id}
|
||||
if agent_id:
|
||||
filters["agent_id"] = agent_id
|
||||
if file_id:
|
||||
filters["file_id"] = file_id
|
||||
if source_id:
|
||||
filters["source_id"] = source_id
|
||||
|
||||
embedded_text = None
|
||||
if embed_query:
|
||||
assert embedding_config is not None
|
||||
|
||||
# Embed the text
|
||||
embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
|
||||
|
||||
# Pad the embedding with zeros
|
||||
embedded_text = np.array(embedded_text)
|
||||
embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()
|
||||
|
||||
results = PassageModel.list(
|
||||
db_session=session,
|
||||
cursor=cursor,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
limit=limit,
|
||||
query_text=query_text if not embedded_text else None,
|
||||
query_embedding=embedded_text,
|
||||
**filters
|
||||
)
|
||||
return [p.to_pydantic() for p in results]
|
||||
|
||||
@enforce_types
|
||||
def size(
|
||||
self,
|
||||
actor : PydanticUser,
|
||||
agent_id : Optional[str] = None,
|
||||
**kwargs
|
||||
) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
|
||||
Args:
|
||||
actor : The user requesting the count
|
||||
agent_id: The agent ID
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
return PassageModel.size(db_session=session, actor=actor, agent_id=agent_id, **kwargs)
|
||||
|
||||
def delete_passages(self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
limit: Optional[int] = 50,
|
||||
cursor: Optional[str] = None,
|
||||
query_text: Optional[str] = None,
|
||||
source_id: Optional[str] = None
|
||||
) -> bool:
|
||||
|
||||
passages = self.list_passages(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
file_id=file_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
query_text=query_text,
|
||||
source_id=source_id)
|
||||
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
@ -64,7 +64,7 @@ class SourceManager:
|
||||
return source.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticSource]:
|
||||
def list_sources(self, actor: PydanticUser, cursor: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]:
|
||||
"""List all sources with optional pagination."""
|
||||
with self.session_maker() as session:
|
||||
sources = SourceModel.list(
|
||||
@ -72,6 +72,7 @@ class SourceManager:
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
organization_id=actor.organization_id,
|
||||
**kwargs,
|
||||
)
|
||||
return [source.to_pydantic() for source in sources]
|
||||
|
||||
|
@ -17,6 +17,8 @@ class ToolSettings(BaseSettings):
|
||||
|
||||
class ModelSettings(BaseSettings):
|
||||
|
||||
model_config = SettingsConfigDict(env_file='.env')
|
||||
|
||||
# env_prefix='my_prefix_'
|
||||
|
||||
# when we use /completions APIs (instead of /chat/completions), we need to specify a model wrapper
|
||||
|
@ -29,10 +29,62 @@ def agent_obj():
|
||||
client.delete_agent(agent_obj.agent_state.id)
|
||||
|
||||
|
||||
def query_in_search_results(search_results, query):
|
||||
for result in search_results:
|
||||
if query.lower() in result["content"].lower():
|
||||
return True
|
||||
return False
|
||||
|
||||
def test_archival(agent_obj):
|
||||
base_functions.archival_memory_insert(agent_obj, "banana")
|
||||
base_functions.archival_memory_search(agent_obj, "banana")
|
||||
base_functions.archival_memory_search(agent_obj, "banana", page=0)
|
||||
"""Test archival memory functions comprehensively."""
|
||||
# Test 1: Basic insertion and retrieval
|
||||
base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat")
|
||||
base_functions.archival_memory_insert(agent_obj, "The dog plays in the park")
|
||||
base_functions.archival_memory_insert(agent_obj, "Python is a programming language")
|
||||
|
||||
# Test exact text search
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "cat")
|
||||
assert query_in_search_results(results, "cat")
|
||||
|
||||
# Test semantic search (should return animal-related content)
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "animal pets")
|
||||
assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog")
|
||||
|
||||
# Test unrelated search (should not return animal content)
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "programming computers")
|
||||
assert query_in_search_results(results, "python")
|
||||
|
||||
# Test 2: Test pagination
|
||||
# Insert more items to test pagination
|
||||
for i in range(10):
|
||||
base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}")
|
||||
|
||||
# Get first page
|
||||
page0_results, next_page = base_functions.archival_memory_search(agent_obj, "Test passage", page=0)
|
||||
# Get second page
|
||||
page1_results, _ = base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page)
|
||||
|
||||
assert page0_results != page1_results
|
||||
assert query_in_search_results(page0_results, "Test passage")
|
||||
assert query_in_search_results(page1_results, "Test passage")
|
||||
|
||||
# Test 3: Test complex text patterns
|
||||
base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John")
|
||||
base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week")
|
||||
base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching")
|
||||
|
||||
# Search for meeting-related content
|
||||
results, _ = base_functions.archival_memory_search(agent_obj, "meeting schedule")
|
||||
assert query_in_search_results(results, "meeting")
|
||||
assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week")
|
||||
|
||||
# Test 4: Test error handling
|
||||
# Test invalid page number
|
||||
try:
|
||||
base_functions.archival_memory_search(agent_obj, "test", page="invalid")
|
||||
assert False, "Should have raised ValueError"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
def test_recall(agent_obj):
|
||||
|
@ -649,7 +649,6 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
system=agent.system,
|
||||
agent_id=agent.id,
|
||||
memory=agent.memory,
|
||||
archival_memory=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
actor=default_user,
|
||||
|
@ -5,9 +5,11 @@ from datetime import datetime, timedelta
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.embeddings import embedding_model
|
||||
import letta.utils as utils
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.metadata import AgentModel
|
||||
from letta.orm.sqlite_functions import verify_embedding_dimension, convert_array
|
||||
from letta.orm import (
|
||||
Block,
|
||||
BlocksAgents,
|
||||
@ -15,6 +17,7 @@ from letta.orm import (
|
||||
Job,
|
||||
Message,
|
||||
Organization,
|
||||
Passage,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
@ -40,6 +43,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import MessageUpdate
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.sandbox_config import (
|
||||
E2BSandboxConfig,
|
||||
LocalSandboxConfig,
|
||||
@ -55,6 +59,7 @@ from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
@ -83,6 +88,7 @@ def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Message))
|
||||
session.execute(delete(Passage))
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
@ -132,6 +138,16 @@ def default_source(server: SyncServer, default_user):
|
||||
yield source
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_file(server: SyncServer, default_source, default_user, default_organization):
|
||||
file = server.source_manager.create_file(
|
||||
PydanticFileMetadata(
|
||||
file_name="test_file", organization_id=default_organization.id, source_id=default_source.id),
|
||||
actor=default_user,
|
||||
)
|
||||
yield file
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
@ -197,6 +213,41 @@ def print_tool(server: SyncServer, default_user, default_organization):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_passage_fixture(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
# Set up passage
|
||||
dummy_embedding = [0.0] * 2
|
||||
message = PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text="Hello, world!",
|
||||
embedding=dummy_embedding,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG
|
||||
)
|
||||
|
||||
msg = server.passage_manager.create_passage(message, actor=default_user)
|
||||
yield msg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_test_passages(server: SyncServer, default_file, default_user, sarah_agent) -> list[PydanticPassage]:
|
||||
"""Helper function to create test passages for all tests"""
|
||||
dummy_embedding = [0] * 2
|
||||
passages = [
|
||||
PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=dummy_embedding,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG
|
||||
) for i in range(4)
|
||||
]
|
||||
server.passage_manager.create_many_passages(passages, actor=default_user)
|
||||
return passages
|
||||
|
||||
@pytest.fixture
|
||||
def hello_world_message_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
@ -353,6 +404,288 @@ def test_list_organizations_pagination(server: SyncServer):
|
||||
assert len(orgs) == 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Passage Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
def test_passage_create(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test creating a passage using hello_world_passage_fixture fixture"""
|
||||
assert hello_world_passage_fixture.id is not None
|
||||
assert hello_world_passage_fixture.text == "Hello, world!"
|
||||
|
||||
# Verify we can retrieve it
|
||||
retrieved = server.passage_manager.get_passage_by_id(
|
||||
hello_world_passage_fixture.id,
|
||||
actor=default_user,
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_get_by_id(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test retrieving a passage by ID"""
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == hello_world_passage_fixture.id
|
||||
assert retrieved.text == hello_world_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_update(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test updating a passage"""
|
||||
new_text = "Updated text"
|
||||
hello_world_passage_fixture.text = new_text
|
||||
updated = server.passage_manager.update_passage_by_id(hello_world_passage_fixture.id, hello_world_passage_fixture, actor=default_user)
|
||||
assert updated is not None
|
||||
assert updated.text == new_text
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
assert retrieved.text == new_text
|
||||
|
||||
|
||||
def test_passage_delete(server: SyncServer, hello_world_passage_fixture, default_user):
|
||||
"""Test deleting a passage"""
|
||||
server.passage_manager.delete_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
retrieved = server.passage_manager.get_passage_by_id(hello_world_passage_fixture.id, actor=default_user)
|
||||
assert retrieved is None
|
||||
|
||||
|
||||
def test_passage_size(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test counting passages with filters"""
|
||||
base_passage = hello_world_passage_fixture
|
||||
|
||||
# Test total count
|
||||
total = server.passage_manager.size(actor=default_user)
|
||||
assert total == 5 # base passage + 4 test passages
|
||||
# TODO: change login passage to be a system not user passage
|
||||
|
||||
# Test count with agent filter
|
||||
agent_count = server.passage_manager.size(actor=default_user, agent_id=base_passage.agent_id)
|
||||
assert agent_count == 5
|
||||
|
||||
# Test count with role filter
|
||||
role_count = server.passage_manager.size(actor=default_user)
|
||||
assert role_count == 5
|
||||
|
||||
# Test count with non-existent filter
|
||||
empty_count = server.passage_manager.size(actor=default_user, agent_id="non-existent")
|
||||
assert empty_count == 0
|
||||
|
||||
|
||||
def test_passage_listing_basic(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test basic passage listing with limit"""
|
||||
results = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(results) == 3
|
||||
|
||||
|
||||
def test_passage_listing_cursor(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user):
|
||||
"""Test cursor-based pagination functionality"""
|
||||
|
||||
# Make sure there are 5 passages
|
||||
assert server.passage_manager.size(actor=default_user) == 5
|
||||
|
||||
# Get first page
|
||||
first_page = server.passage_manager.list_passages(actor=default_user, limit=3)
|
||||
assert len(first_page) == 3
|
||||
|
||||
last_id_on_first_page = first_page[-1].id
|
||||
|
||||
# Get second page
|
||||
second_page = server.passage_manager.list_passages(
|
||||
actor=default_user, cursor=last_id_on_first_page, limit=3
|
||||
)
|
||||
assert len(second_page) == 2 # Should have 2 remaining passages
|
||||
assert all(r1.id != r2.id for r1 in first_page for r2 in second_page)
|
||||
|
||||
|
||||
def test_passage_listing_filtering(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test filtering passages by agent ID"""
|
||||
agent_results = server.passage_manager.list_passages(agent_id=sarah_agent.id, actor=default_user, limit=10)
|
||||
assert len(agent_results) == 5 # base passage + 4 test passages
|
||||
assert all(msg.agent_id == hello_world_passage_fixture.agent_id for msg in agent_results)
|
||||
|
||||
|
||||
def test_passage_listing_text_search(server: SyncServer, hello_world_passage_fixture, create_test_passages, default_user, sarah_agent):
|
||||
"""Test searching passages by text content"""
|
||||
search_results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, query_text="Test passage", limit=10
|
||||
)
|
||||
assert len(search_results) == 4
|
||||
assert all("Test passage" in msg.text for msg in search_results)
|
||||
|
||||
# Test no results
|
||||
search_results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id, actor=default_user, query_text="Letta", limit=10
|
||||
)
|
||||
assert len(search_results) == 0
|
||||
|
||||
|
||||
def test_passage_listing_date_range_filtering(server: SyncServer, hello_world_passage_fixture, default_user, default_file, sarah_agent):
|
||||
"""Test filtering passages by date range with various scenarios"""
|
||||
# Set up test data with known dates
|
||||
base_time = datetime.utcnow()
|
||||
|
||||
# Create passages at different times
|
||||
passages = []
|
||||
time_offsets = [
|
||||
timedelta(days=-2), # 2 days ago
|
||||
timedelta(days=-1), # Yesterday
|
||||
timedelta(hours=-2), # 2 hours ago
|
||||
timedelta(minutes=-30), # 30 minutes ago
|
||||
timedelta(minutes=-1), # 1 minute ago
|
||||
timedelta(minutes=0), # Now
|
||||
]
|
||||
|
||||
for i, offset in enumerate(time_offsets):
|
||||
timestamp = base_time + offset
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Test passage {i}",
|
||||
embedding=[0.1, 0.2, 0.3],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
created_at=timestamp
|
||||
),
|
||||
actor=default_user
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
# Test cases
|
||||
test_cases = [
|
||||
{
|
||||
"name": "Recent passages (last hour)",
|
||||
"start_date": base_time - timedelta(hours=1),
|
||||
"end_date": base_time + timedelta(minutes=1),
|
||||
"expected_count": 1 + 3, # Should include base + -30min, -1min, and now
|
||||
},
|
||||
{
|
||||
"name": "Yesterday's passages",
|
||||
"start_date": base_time - timedelta(days=1, hours=12),
|
||||
"end_date": base_time - timedelta(hours=12),
|
||||
"expected_count": 1, # Should only include yesterday's passage
|
||||
},
|
||||
{
|
||||
"name": "Future time range",
|
||||
"start_date": base_time + timedelta(days=1),
|
||||
"end_date": base_time + timedelta(days=2),
|
||||
"expected_count": 0, # Should find no passages
|
||||
},
|
||||
{
|
||||
"name": "All time",
|
||||
"start_date": base_time - timedelta(days=3),
|
||||
"end_date": base_time + timedelta(days=1),
|
||||
"expected_count": 1 + len(passages), # Should find all passages
|
||||
},
|
||||
{
|
||||
"name": "Exact timestamp match",
|
||||
"start_date": passages[0].created_at - timedelta(microseconds=1),
|
||||
"end_date": passages[0].created_at + timedelta(microseconds=1),
|
||||
"expected_count": 1, # Should find exactly one passage
|
||||
},
|
||||
{
|
||||
"name": "Small time window",
|
||||
"start_date": base_time - timedelta(seconds=30),
|
||||
"end_date": base_time + timedelta(seconds=30),
|
||||
"expected_count": 1 + 1, # date + "now"
|
||||
}
|
||||
]
|
||||
|
||||
# Run test cases
|
||||
for case in test_cases:
|
||||
results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=case["start_date"],
|
||||
end_date=case["end_date"],
|
||||
limit=10
|
||||
)
|
||||
|
||||
# Verify count
|
||||
assert len(results) == case["expected_count"], \
|
||||
f"Test case '{case['name']}' failed: expected {case['expected_count']} passages, got {len(results)}"
|
||||
|
||||
# Test edge cases
|
||||
|
||||
# Test with start_date but no end_date
|
||||
results_start_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=base_time - timedelta(minutes=2),
|
||||
end_date=None,
|
||||
limit=10
|
||||
)
|
||||
assert len(results_start_only) >= 2, "Should find passages after start_date"
|
||||
|
||||
# Test with end_date but no start_date
|
||||
results_end_only = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=None,
|
||||
end_date=base_time - timedelta(days=1),
|
||||
limit=10
|
||||
)
|
||||
assert len(results_end_only) >= 1, "Should find passages before end_date"
|
||||
|
||||
# Test limit enforcement
|
||||
limited_results = server.passage_manager.list_passages(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
start_date=base_time - timedelta(days=3),
|
||||
end_date=base_time + timedelta(days=1),
|
||||
limit=3
|
||||
)
|
||||
assert len(limited_results) <= 3, "Should respect the limit parameter"
|
||||
|
||||
|
||||
def test_passage_vector_search(server: SyncServer, default_user, default_file, sarah_agent):
|
||||
"""Test vector search functionality for passages."""
|
||||
passage_manager = server.passage_manager
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
"random text",
|
||||
"blue shoes",
|
||||
]
|
||||
|
||||
for text in test_passages:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
)
|
||||
created_passage = passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
assert passage_manager.size(actor=default_user) == len(passages)
|
||||
|
||||
# Query vector similar to "cats" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
# List passages with vector search
|
||||
results = passage_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text=query_key,
|
||||
limit=3,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
assert results[1].text == "random text" # For some reason the embedding model doesn't like "blue shoes"
|
||||
assert results[2].text == "blue shoes"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
@ -834,8 +1167,6 @@ def test_delete_block(server: SyncServer, default_user):
|
||||
# ======================================================================================================================
|
||||
# Source Manager Tests - Sources
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_source(server: SyncServer, default_user):
|
||||
"""Test creating a new source."""
|
||||
source_pydantic = PydanticSource(
|
||||
@ -1049,8 +1380,6 @@ def test_delete_file(server: SyncServer, default_user, default_source):
|
||||
# ======================================================================================================================
|
||||
# AgentsTagsManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user):
|
||||
# Add a tag to the agent
|
||||
tag_name = "test_tag"
|
||||
|
@ -117,6 +117,9 @@ def test_user_message_memory(server, user_id, agent_id):
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
# create source
|
||||
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_before) == 0
|
||||
|
||||
source = server.source_manager.create_source(
|
||||
Source(name="test_source", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=server.default_user
|
||||
)
|
||||
@ -130,19 +133,17 @@ def test_load_data(server, user_id, agent_id):
|
||||
"Shishir loves indian food",
|
||||
]
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
server.load_data(user_id, connector, source.name, agent_id=agent_id)
|
||||
|
||||
# @pytest.mark.order(3)
|
||||
# def test_attach_source_to_agent(server, user_id, agent_id):
|
||||
# check archival memory size
|
||||
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
|
||||
assert len(passages_before) == 0
|
||||
|
||||
# attach source
|
||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||
|
||||
# check archival memory size
|
||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
|
||||
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_after) == 5
|
||||
|
||||
|
||||
@ -182,41 +183,42 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
after=cursor1,
|
||||
order_by="text",
|
||||
)
|
||||
cursor2 = passages_2[-1].id
|
||||
passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
before=cursor2,
|
||||
limit=1000,
|
||||
order_by="text",
|
||||
)
|
||||
passages_3[-1].id
|
||||
# assert passages_1[0].text == "Cinderella wore a blue dress"
|
||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
# TODO: Out-of-date test. pagination commands are off
|
||||
# @pytest.mark.order(6)
|
||||
# def test_get_archival_memory(server, user_id, agent_id):
|
||||
# # test archival memory cursor pagination
|
||||
# passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
|
||||
# cursor1 = passages_1[-1].id
|
||||
# passages_2 = server.get_agent_archival_cursor(
|
||||
# user_id=user_id,
|
||||
# agent_id=agent_id,
|
||||
# reverse=False,
|
||||
# after=cursor1,
|
||||
# order_by="text",
|
||||
# )
|
||||
# cursor2 = passages_2[-1].id
|
||||
# passages_3 = server.get_agent_archival_cursor(
|
||||
# user_id=user_id,
|
||||
# agent_id=agent_id,
|
||||
# reverse=False,
|
||||
# before=cursor2,
|
||||
# limit=1000,
|
||||
# order_by="text",
|
||||
# )
|
||||
# passages_3[-1].id
|
||||
# # assert passages_1[0].text == "Cinderella wore a blue dress"
|
||||
# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||
assert len(passage_none) == 0
|
||||
# # test archival memory
|
||||
# passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||
# assert len(passage_1) == 1
|
||||
# passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||
# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
# # test safe empty return
|
||||
# passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||
# assert len(passage_none) == 0
|
||||
|
||||
|
||||
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
|
42
tests/test_vector_embeddings.py
Normal file
42
tests/test_vector_embeddings.py
Normal file
@ -0,0 +1,42 @@
|
||||
import numpy as np
|
||||
import sqlite3
|
||||
import base64
|
||||
from numpy.testing import assert_array_almost_equal
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.orm.sqlalchemy_base import adapt_array, convert_array
|
||||
from letta.orm.sqlite_functions import verify_embedding_dimension
|
||||
|
||||
def test_vector_conversions():
|
||||
"""Test the vector conversion functions"""
|
||||
# Create test data
|
||||
original = np.random.random(4096).astype(np.float32)
|
||||
print(f"Original shape: {original.shape}")
|
||||
|
||||
# Test full conversion cycle
|
||||
encoded = adapt_array(original)
|
||||
print(f"Encoded type: {type(encoded)}")
|
||||
print(f"Encoded length: {len(encoded)}")
|
||||
|
||||
decoded = convert_array(encoded)
|
||||
print(f"Decoded shape: {decoded.shape}")
|
||||
print(f"Dimension verification: {verify_embedding_dimension(decoded)}")
|
||||
|
||||
# Verify data integrity
|
||||
np.testing.assert_array_almost_equal(original, decoded)
|
||||
print("✓ Data integrity verified")
|
||||
|
||||
# Test with a list
|
||||
list_data = original.tolist()
|
||||
encoded_list = adapt_array(list_data)
|
||||
decoded_list = convert_array(encoded_list)
|
||||
np.testing.assert_array_almost_equal(original, decoded_list)
|
||||
print("✓ List conversion verified")
|
||||
|
||||
# Test None handling
|
||||
assert adapt_array(None) is None
|
||||
assert convert_array(None) is None
|
||||
print("✓ None handling verified")
|
||||
|
||||
# Run the tests
|
Loading…
Reference in New Issue
Block a user