fix: Modify the list ORM function (#2208)

This commit is contained in:
Matthew Zhou 2024-12-09 19:35:58 -08:00 committed by GitHub
parent af5ef6d174
commit 666e4259cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 83 additions and 106 deletions

View File

@ -3128,7 +3128,7 @@ class LocalClient(AbstractClient):
return self.server.get_agent_recall_cursor(
user_id=self.user_id,
agent_id=agent_id,
cursor=cursor,
before=cursor,
limit=limit,
reverse=True,
)

View File

@ -2,7 +2,7 @@ from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional, Type
from sqlalchemy import String, func, select
from sqlalchemy import String, desc, func, or_, select
from sqlalchemy.exc import DBAPIError
from sqlalchemy.orm import Mapped, Session, mapped_column
@ -60,14 +60,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
end_date: Optional[datetime] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
ascending: bool = True,
**kwargs,
) -> List[Type["SqlalchemyBase"]]:
"""List records with advanced filtering and pagination options."""
"""
List records with cursor-based pagination, ordering by created_at.
Cursor is an ID, but pagination is based on the cursor object's created_at value.
"""
if start_date and end_date and start_date > end_date:
raise ValueError("start_date must be earlier than or equal to end_date")
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
with db_session as session:
# If cursor provided, get the reference object
cursor_obj = None
if cursor:
cursor_obj = session.get(cls, cursor)
if not cursor_obj:
raise NoResultFound(f"No {cls.__name__} found with id {cursor}")
query = select(cls)
# Apply filtering logic
@ -80,22 +91,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Date range filtering
if start_date:
query = query.filter(cls.created_at >= start_date)
query = query.filter(cls.created_at > start_date)
if end_date:
query = query.filter(cls.created_at <= end_date)
query = query.filter(cls.created_at < end_date)
# Cursor-based pagination
if cursor:
query = query.where(cls.id > cursor)
# Cursor-based pagination using created_at
# TODO: There is a nasty race condition here with SQLite.
# TODO: On SQLite, rows that share the cursor's created_at timestamp are NOT matched by this query, so the id tie-breaker does not take effect (cause not yet understood).
if cursor_obj:
if ascending:
query = query.where(cls.created_at >= cursor_obj.created_at).where(
or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id)
)
else:
query = query.where(cls.created_at <= cursor_obj.created_at).where(
or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
)
# Apply text search
if query_text:
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
# Handle ordering and soft deletes
# Handle soft deletes
if hasattr(cls, "is_deleted"):
query = query.where(cls.is_deleted == False)
query = query.order_by(cls.id).limit(limit)
# Apply ordering by created_at
if ascending:
query = query.order_by(cls.created_at, cls.id)
else:
query = query.order_by(desc(cls.created_at), desc(cls.id))
query = query.limit(limit)
return list(session.execute(query).scalars())

View File

@ -420,7 +420,7 @@ def get_agent_messages(
return server.get_agent_recall_cursor(
user_id=actor.id,
agent_id=agent_id,
cursor=before,
before=before,
limit=limit,
reverse=True,
return_message_object=msg_object,

View File

@ -101,11 +101,6 @@ class Server(object):
"""List all available agents to a user"""
raise NotImplementedError
@abstractmethod
def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
    """Paginated query of in-context messages in agent message queue.

    Abstract method: this base implementation only raises, so concrete
    servers are expected to override it.

    Args:
        user_id: Identifier of the user (presumably the requester; confirm against callers).
        agent_id: Identifier of the agent whose messages are queried.
        start: Pagination start; presumably an offset into the message list (confirm against implementations).
        count: Presumably the maximum number of messages to return (confirm against implementations).

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError
@abstractmethod
def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
@ -1173,55 +1168,6 @@ class SyncServer(Server):
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
return message
def get_agent_messages(
    self,
    agent_id: str,
    start: int,
    count: int,
) -> Union[List[Message], List[LettaMessage]]:
    """Paginated query of all messages in agent message queue.

    When the requested window ends before the end of the agent's in-memory
    message list, the page is sliced from the reversed in-memory list (reverse
    chronological order). Otherwise the request is forwarded to the agent's
    message manager.

    Args:
        agent_id: ID of the agent to load.
        start: Non-negative offset into the reversed message list.
        count: Non-negative maximum number of messages to return.

    Returns:
        The requested page of messages.

    Raises:
        ValueError: If ``start`` or ``count`` is negative.
    """
    # Get the agent object (loaded in memory)
    letta_agent = self.load_agent(agent_id=agent_id)

    if start < 0 or count < 0:
        raise ValueError("Start and count values should be non-negative")

    in_memory = letta_agent._messages
    if start + count < len(in_memory):
        # The whole window lies inside the in-memory list: with count >= 0 this
        # guarantees start < len(in_memory) and start + count < len(in_memory),
        # so no separate range check or end-index clamping is needed.
        messages = in_memory[::-1][start : start + count]
    else:
        # Need to access the persistence layer for additional messages.
        # NOTE(review): ``start`` is an integer offset but is passed as ``cursor``;
        # confirm that the message manager accepts an offset here.
        messages = letta_agent.message_manager.list_user_messages_for_agent(
            agent_id=agent_id,
            actor=self.default_user,
            cursor=start,
            limit=count,
        )

    assert all(isinstance(m, Message) for m in messages)
    return messages
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
"""Paginated query of all passages in agent archival memory"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
@ -1303,7 +1249,8 @@ class SyncServer(Server):
self,
user_id: str,
agent_id: str,
cursor: Optional[str] = None,
after: Optional[str] = None,
before: Optional[str] = None,
limit: Optional[int] = 100,
reverse: Optional[bool] = False,
return_message_object: bool = True,
@ -1320,12 +1267,15 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over records
# TODO: Check "order_by", "order"
start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None
end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None
records = letta_agent.message_manager.list_messages_for_agent(
agent_id=agent_id,
actor=actor,
cursor=cursor,
start_date=start_date,
end_date=end_date,
limit=limit,
ascending=not reverse,
)
assert all(isinstance(m, Message) for m in records)

View File

@ -119,6 +119,7 @@ class MessageManager:
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
query_text: Optional[str] = None,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List user messages with flexible filtering and pagination options.
@ -159,6 +160,7 @@ class MessageManager:
limit: Optional[int] = 50,
filters: Optional[Dict] = None,
query_text: Optional[str] = None,
ascending: bool = True,
) -> List[PydanticMessage]:
"""List messages with flexible filtering and pagination options.
@ -188,6 +190,7 @@ class MessageManager:
end_date=end_date,
limit=limit,
query_text=query_text,
ascending=ascending,
**message_filters,
)

View File

@ -459,7 +459,7 @@ class ToolExecutionSandbox:
Generate the code string to call the function.
Args:
inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
Returns:
str: Generated code string for calling the tool

View File

@ -1,4 +1,5 @@
import os
import time
from datetime import datetime, timedelta
import pytest
@ -73,8 +74,8 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
azure_version=None,
azure_deployment=None,
)
using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
CREATE_DELAY_SQLITE = 1
USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
@pytest.fixture(autouse=True)
@ -911,6 +912,8 @@ def test_list_sources(server: SyncServer, default_user):
"""Test listing sources with pagination."""
# Create multiple sources
server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
# List sources without pagination
@ -1004,6 +1007,8 @@ def test_list_files(server: SyncServer, default_user, default_source):
PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_file(
PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
@ -1184,6 +1189,8 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
config=LocalSandboxConfig(sandbox_dir=""),
)
server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user)
# List configs without pagination
@ -1239,6 +1246,8 @@ def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, defau
env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1")
env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2")
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
# List env vars without pagination
@ -1299,7 +1308,7 @@ def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agen
assert default_block.label not in labels
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(
@ -1361,7 +1370,7 @@ def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_u
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
block_manager = BlockManager()
block_manager.delete_block(block_id=default_block.id, actor=default_user)
@ -1401,7 +1410,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
assert print_tool.name not in names
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
@ -1447,7 +1456,7 @@ def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_us
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
tool_manager = ToolManager()
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)

View File

@ -161,37 +161,25 @@ def test_user_message(server, user_id, agent_id):
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# TODO: Add this back, this is broken on main
# @pytest.mark.order(5)
# def test_get_recall_memory(server, org_id, user_id, agent_id):
# # test recall memory cursor pagination
# messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
# cursor1 = messages_1[-1].id
# messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
# messages_2[-1].id
# messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
# messages_3[-1].id
# assert messages_3[-1].created_at >= messages_3[0].created_at
# assert len(messages_3) == len(messages_1) + len(messages_2)
# messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
# assert len(messages_4) == 1
#
# # test in-context message ids
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
# message_ids = [m.id for m in messages_3]
# for message_id in in_context_ids:
# assert message_id in message_ids, f"{message_id} not in {message_ids}"
#
# # test recall memory
# messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1)
# assert len(messages_1) == 1
# messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000)
# messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2)
# # not sure exactly how many messages there should be
# assert len(messages_2) > len(messages_3)
# # test safe empty return
# messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000)
# assert len(messages_none) == 0
@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
    """Exercise cursor pagination of recall memory and in-context message coverage.

    Verifies that a first page plus the page fetched ``after`` its last message
    add up to the full listing, that the full listing is not in descending
    ``created_at`` order, that a reverse query ``before`` the cursor returns
    exactly one message, and that every in-context message ID appears in the
    full listing.
    """
    # Page 1: the first two messages.
    messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
    assert messages_1, "expected at least one recall message"
    cursor1 = messages_1[-1].id

    # Page 2: everything after the last message of page 1.
    messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
    assert messages_2, "expected messages after the first page"

    # Full listing, used as the reference for the checks below.
    messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
    assert messages_3, "expected a non-empty full listing"
    assert messages_3[-1].created_at >= messages_3[0].created_at
    assert len(messages_3) == len(messages_1) + len(messages_2)

    # Reverse query before the cursor should yield only the first message.
    messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
    assert len(messages_4) == 1

    # test in-context message ids
    in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
    message_ids = [m.id for m in messages_3]
    for message_id in in_context_ids:
        assert message_id in message_ids, f"{message_id} not in {message_ids}"
@pytest.mark.order(6)