Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03)
fix: Modify the list ORM function (#2208)

parent af5ef6d174
commit 666e4259cf
@@ -3128,7 +3128,7 @@ class LocalClient(AbstractClient):
         return self.server.get_agent_recall_cursor(
             user_id=self.user_id,
             agent_id=agent_id,
-            cursor=cursor,
+            before=cursor,
             limit=limit,
             reverse=True,
         )
@@ -2,7 +2,7 @@ from datetime import datetime
 from enum import Enum
 from typing import TYPE_CHECKING, List, Literal, Optional, Type
 
-from sqlalchemy import String, func, select
+from sqlalchemy import String, desc, func, or_, select
 from sqlalchemy.exc import DBAPIError
 from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -60,14 +60,25 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
         end_date: Optional[datetime] = None,
         limit: Optional[int] = 50,
         query_text: Optional[str] = None,
+        ascending: bool = True,
         **kwargs,
     ) -> List[Type["SqlalchemyBase"]]:
-        """List records with advanced filtering and pagination options."""
+        """
+        List records with cursor-based pagination, ordering by created_at.
+        Cursor is an ID, but pagination is based on the cursor object's created_at value.
+        """
         if start_date and end_date and start_date > end_date:
             raise ValueError("start_date must be earlier than or equal to end_date")
 
         logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
         with db_session as session:
+            # If cursor provided, get the reference object
+            cursor_obj = None
+            if cursor:
+                cursor_obj = session.get(cls, cursor)
+                if not cursor_obj:
+                    raise NoResultFound(f"No {cls.__name__} found with id {cursor}")
+
             query = select(cls)
 
             # Apply filtering logic
@@ -80,22 +91,38 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
 
             # Date range filtering
             if start_date:
-                query = query.filter(cls.created_at >= start_date)
+                query = query.filter(cls.created_at > start_date)
             if end_date:
-                query = query.filter(cls.created_at <= end_date)
+                query = query.filter(cls.created_at < end_date)
 
-            # Cursor-based pagination
-            if cursor:
-                query = query.where(cls.id > cursor)
+            # Cursor-based pagination using created_at
+            # TODO: There is a really nasty race condition issue here with Sqlite
+            # TODO: If they have the same created_at timestamp, this query does NOT match for whatever reason
+            if cursor_obj:
+                if ascending:
+                    query = query.where(cls.created_at >= cursor_obj.created_at).where(
+                        or_(cls.created_at > cursor_obj.created_at, cls.id > cursor_obj.id)
+                    )
+                else:
+                    query = query.where(cls.created_at <= cursor_obj.created_at).where(
+                        or_(cls.created_at < cursor_obj.created_at, cls.id < cursor_obj.id)
+                    )
 
             # Apply text search
             if query_text:
                 query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
 
-            # Handle ordering and soft deletes
+            # Handle soft deletes
             if hasattr(cls, "is_deleted"):
                 query = query.where(cls.is_deleted == False)
-            query = query.order_by(cls.id).limit(limit)
+
+            # Apply ordering by created_at
+            if ascending:
+                query = query.order_by(cls.created_at, cls.id)
+            else:
+                query = query.order_by(desc(cls.created_at), desc(cls.id))
+
+            query = query.limit(limit)
 
             return list(session.execute(query).scalars())
 
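Note on the hunk above: the ORM now performs keyset (cursor) pagination keyed on (created_at, id) instead of paginating on id alone. A minimal, self-contained sketch of the same idea follows; the Item model, the in-memory SQLite engine, and the list_after() helper are illustrative stand-ins, not code from this commit.

```python
# Standalone sketch of keyset pagination on (created_at, id), mirroring the new list() logic.
# The Item model and the in-memory SQLite engine are illustrative only, not part of this commit.
from datetime import datetime

from sqlalchemy import DateTime, String, create_engine, desc, or_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Item(Base):
    __tablename__ = "items"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    created_at: Mapped[datetime] = mapped_column(DateTime)


def list_after(session: Session, cursor_id=None, limit=2, ascending=True):
    """Return up to `limit` rows strictly after the cursor row, ordered by (created_at, id)."""
    query = select(Item)
    if cursor_id is not None:
        cursor_obj = session.get(Item, cursor_id)
        # "After" the cursor means: a later created_at, or the same created_at with a larger id.
        if ascending:
            query = query.where(Item.created_at >= cursor_obj.created_at).where(
                or_(Item.created_at > cursor_obj.created_at, Item.id > cursor_obj.id)
            )
        else:
            query = query.where(Item.created_at <= cursor_obj.created_at).where(
                or_(Item.created_at < cursor_obj.created_at, Item.id < cursor_obj.id)
            )
    order = (Item.created_at, Item.id) if ascending else (desc(Item.created_at), desc(Item.id))
    return list(session.execute(query.order_by(*order).limit(limit)).scalars())


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    now = datetime(2025, 1, 1, 12, 0, 0)
    # Four rows sharing the same created_at: the id tie-breaker keeps pages disjoint.
    session.add_all(Item(id=f"item-{i}", created_at=now) for i in range(4))
    session.commit()
    page_1 = list_after(session)                            # item-0, item-1
    page_2 = list_after(session, cursor_id=page_1[-1].id)   # item-2, item-3
    assert [r.id for r in page_1] == ["item-0", "item-1"]
    assert [r.id for r in page_2] == ["item-2", "item-3"]
```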
@@ -420,7 +420,7 @@ def get_agent_messages(
     return server.get_agent_recall_cursor(
         user_id=actor.id,
         agent_id=agent_id,
-        cursor=before,
+        before=before,
         limit=limit,
         reverse=True,
         return_message_object=msg_object,
@@ -101,11 +101,6 @@ class Server(object):
         """List all available agents to a user"""
         raise NotImplementedError
 
-    @abstractmethod
-    def get_agent_messages(self, user_id: str, agent_id: str, start: int, count: int) -> list:
-        """Paginated query of in-context messages in agent message queue"""
-        raise NotImplementedError
-
     @abstractmethod
     def get_agent_memory(self, user_id: str, agent_id: str) -> dict:
         """Return the memory of an agent (core memory + non-core statistics)"""
@@ -1173,55 +1168,6 @@ class SyncServer(Server):
         message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
         return message
 
-    def get_agent_messages(
-        self,
-        agent_id: str,
-        start: int,
-        count: int,
-    ) -> Union[List[Message], List[LettaMessage]]:
-        """Paginated query of all messages in agent message queue"""
-        # Get the agent object (loaded in memory)
-        letta_agent = self.load_agent(agent_id=agent_id)
-
-        if start < 0 or count < 0:
-            raise ValueError("Start and count values should be non-negative")
-
-        if start + count < len(letta_agent._messages):  # messages can be returned from whats in memory
-            # Reverse the list to make it in reverse chronological order
-            reversed_messages = letta_agent._messages[::-1]
-            # Check if start is within the range of the list
-            if start >= len(reversed_messages):
-                raise IndexError("Start index is out of range")
-
-            # Calculate the end index, ensuring it does not exceed the list length
-            end_index = min(start + count, len(reversed_messages))
-
-            # Slice the list for pagination
-            messages = reversed_messages[start:end_index]
-
-        else:
-            # need to access persistence manager for additional messages
-
-            # get messages using message manager
-            page = letta_agent.message_manager.list_user_messages_for_agent(
-                agent_id=agent_id,
-                actor=self.default_user,
-                cursor=start,
-                limit=count,
-            )
-
-            messages = page
-            assert all(isinstance(m, Message) for m in messages)
-
-        ## Convert to json
-        ## Add a tag indicating in-context or not
-        # json_messages = [record.to_json() for record in messages]
-        # in_context_message_ids = [str(m.id) for m in letta_agent._messages]
-        # for d in json_messages:
-        #     d["in_context"] = True if str(d["id"]) in in_context_message_ids else False
-
-        return messages
-
     def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
         """Paginated query of all messages in agent archival memory"""
         if self.user_manager.get_user_by_id(user_id=user_id) is None:
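The deleted SyncServer.get_agent_messages above paged by position (start, count); the surviving get_agent_recall_cursor path pages by cursor instead. A toy, purely illustrative comparison of the two access patterns, with plain Python lists standing in for the message store:

```python
# Purely illustrative: offset paging vs. cursor paging over a toy, pre-sorted store.
store = [{"id": f"msg-{i}", "created_at": i} for i in range(10)]


def page_by_offset(start, count):
    # Position-based (like the removed start/count API): results shift if rows are inserted.
    return store[start : start + count]


def page_after_cursor(cursor_id, limit):
    # Anchor-based (like the cursor API): the page starts strictly after a known row.
    if cursor_id is None:
        return store[:limit]
    anchor = next(row for row in store if row["id"] == cursor_id)
    return [row for row in store if row["created_at"] > anchor["created_at"]][:limit]


first = page_after_cursor(None, 3)
second = page_after_cursor(first[-1]["id"], 3)
assert [r["id"] for r in first] == ["msg-0", "msg-1", "msg-2"]
assert [r["id"] for r in second] == ["msg-3", "msg-4", "msg-5"]
assert page_by_offset(3, 3) == second  # equivalent only while no rows are inserted in between
```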
@@ -1303,7 +1249,8 @@ class SyncServer(Server):
         self,
         user_id: str,
         agent_id: str,
-        cursor: Optional[str] = None,
+        after: Optional[str] = None,
+        before: Optional[str] = None,
         limit: Optional[int] = 100,
         reverse: Optional[bool] = False,
         return_message_object: bool = True,
@@ -1320,12 +1267,15 @@ class SyncServer(Server):
         letta_agent = self.load_agent(agent_id=agent_id)
 
         # iterate over records
-        # TODO: Check "order_by", "order"
+        start_date = self.message_manager.get_message_by_id(after, actor=actor).created_at if after else None
+        end_date = self.message_manager.get_message_by_id(before, actor=actor).created_at if before else None
         records = letta_agent.message_manager.list_messages_for_agent(
             agent_id=agent_id,
             actor=actor,
-            cursor=cursor,
+            start_date=start_date,
+            end_date=end_date,
             limit=limit,
+            ascending=not reverse,
         )
 
         assert all(isinstance(m, Message) for m in records)
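After this change, `after` and `before` are message IDs that the server resolves to created_at bounds before calling the ORM list(). A rough paging-loop sketch using the same server API the tests exercise; the `server`, `user_id`, and `agent_id` objects are assumed to come from the surrounding application or test fixtures and are not defined in this commit:

```python
# Hypothetical helper showing how a caller would page forward with the new after/before cursors.
def fetch_all_recall(server, user_id, agent_id, page_size=100):
    messages = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=page_size)
    collected = list(messages)
    while len(messages) == page_size:
        # The last message's ID becomes the `after` cursor; the server resolves it to a
        # created_at lower bound (start_date above) before calling the ORM list().
        messages = server.get_agent_recall_cursor(
            user_id=user_id, agent_id=agent_id, after=collected[-1].id, limit=page_size
        )
        collected.extend(messages)
    return collected
```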
@@ -119,6 +119,7 @@ class MessageManager:
         limit: Optional[int] = 50,
         filters: Optional[Dict] = None,
         query_text: Optional[str] = None,
+        ascending: bool = True,
     ) -> List[PydanticMessage]:
         """List user messages with flexible filtering and pagination options.
 
@@ -159,6 +160,7 @@ class MessageManager:
         limit: Optional[int] = 50,
         filters: Optional[Dict] = None,
         query_text: Optional[str] = None,
+        ascending: bool = True,
     ) -> List[PydanticMessage]:
         """List messages with flexible filtering and pagination options.
 
@@ -188,6 +190,7 @@ class MessageManager:
             end_date=end_date,
             limit=limit,
             query_text=query_text,
+            ascending=ascending,
             **message_filters,
         )
 
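The new `ascending` flag is what lets SyncServer pass `ascending=not reverse` above. A small usage sketch; the `message_manager`, `agent_id`, and `actor` objects are assumed to already exist in the application and are not defined in this commit:

```python
# Hypothetical call sites; message_manager, agent_id, and actor come from existing app context.
def newest_and_oldest(message_manager, agent_id, actor, limit=50):
    # Oldest-first page (ascending=True is the default).
    oldest_first = message_manager.list_messages_for_agent(agent_id=agent_id, actor=actor, limit=limit)
    # Newest-first page, matching the reverse=True recall-cursor path.
    newest_first = message_manager.list_messages_for_agent(
        agent_id=agent_id, actor=actor, limit=limit, ascending=False
    )
    return oldest_first, newest_first
```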
@@ -459,7 +459,7 @@ class ToolExecutionSandbox:
         Generate the code string to call the function.
 
         Args:
-            inject_agent_state (bool): Whether to inject the agent's state as an input into the tool
+            inject_agent_state (bool): Whether to inject the axgent's state as an input into the tool
 
         Returns:
             str: Generated code string for calling the tool
@@ -1,4 +1,5 @@
 import os
+import time
 from datetime import datetime, timedelta
 
 import pytest
@@ -73,8 +74,8 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
     azure_version=None,
     azure_deployment=None,
 )
 
-using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
-
+CREATE_DELAY_SQLITE = 1
+USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
 
 
 @pytest.fixture(autouse=True)
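CREATE_DELAY_SQLITE exists because of the TODO in the ORM hunk above: on SQLite, rows created back-to-back can end up with the same created_at, and the cursor filter then misbehaves, so the tests space out creations by a second. A tiny, deterministic illustration of the timestamp collision the sleep avoids; the second-level precision shown here is an assumption made for the example:

```python
# Illustrative only: why back-to-back creations need a spacer on SQLite.
from datetime import datetime, timedelta

created_1 = datetime(2025, 1, 1, 12, 0, 0)
created_2 = datetime(2025, 1, 1, 12, 0, 0)   # created immediately after row 1, same timestamp
created_3 = created_1 + timedelta(seconds=1)  # what time.sleep(CREATE_DELAY_SQLITE) buys

assert not created_2 > created_1  # equal timestamps: created_at alone cannot order the rows
assert created_3 > created_1      # after the 1s delay, created_at ordering is unambiguous
```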
@@ -911,6 +912,8 @@ def test_list_sources(server: SyncServer, default_user):
     """Test listing sources with pagination."""
     # Create multiple sources
     server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
+    if USING_SQLITE:
+        time.sleep(CREATE_DELAY_SQLITE)
     server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
 
     # List sources without pagination
@@ -1004,6 +1007,8 @@ def test_list_files(server: SyncServer, default_user, default_source):
         PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
         actor=default_user,
     )
+    if USING_SQLITE:
+        time.sleep(CREATE_DELAY_SQLITE)
     server.source_manager.create_file(
         PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
         actor=default_user,
@@ -1184,6 +1189,8 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
         config=LocalSandboxConfig(sandbox_dir=""),
     )
     server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user)
+    if USING_SQLITE:
+        time.sleep(CREATE_DELAY_SQLITE)
     server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user)
 
     # List configs without pagination
@@ -1239,6 +1246,8 @@ def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, defau
     env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1")
     env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2")
     server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
+    if USING_SQLITE:
+        time.sleep(CREATE_DELAY_SQLITE)
     server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
 
     # List env vars without pagination
@@ -1299,7 +1308,7 @@ def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agen
     assert default_block.label not in labels
 
 
-@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
+@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
 def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
     with pytest.raises(ForeignKeyConstraintViolationError):
         server.blocks_agents_manager.add_block_to_agent(
@@ -1361,7 +1370,7 @@ def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_u
     assert len(agent_ids) == 2
 
 
-@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
+@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
 def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
     block_manager = BlockManager()
     block_manager.delete_block(block_id=default_block.id, actor=default_user)
@@ -1401,7 +1410,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
     assert print_tool.name not in names
 
 
-@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
+@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
 def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
     with pytest.raises(ForeignKeyConstraintViolationError):
         server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
@@ -1447,7 +1456,7 @@ def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_us
     assert len(agent_ids) == 2
 
 
-@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
+@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
 def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
     tool_manager = ToolManager()
     tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
@@ -161,37 +161,25 @@ def test_user_message(server, user_id, agent_id):
     # server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
 
 
-# TODO: Add this back, this is broken on main
-# @pytest.mark.order(5)
-# def test_get_recall_memory(server, org_id, user_id, agent_id):
-#     # test recall memory cursor pagination
-#     messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
-#     cursor1 = messages_1[-1].id
-#     messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
-#     messages_2[-1].id
-#     messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
-#     messages_3[-1].id
-#     assert messages_3[-1].created_at >= messages_3[0].created_at
-#     assert len(messages_3) == len(messages_1) + len(messages_2)
-#     messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
-#     assert len(messages_4) == 1
-#
-#     # test in-context message ids
-#     in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
-#     message_ids = [m.id for m in messages_3]
-#     for message_id in in_context_ids:
-#         assert message_id in message_ids, f"{message_id} not in {message_ids}"
-#
-#     # test recall memory
-#     messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1)
-#     assert len(messages_1) == 1
-#     messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000)
-#     messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2)
-#     # not sure exactly how many messages there should be
-#     assert len(messages_2) > len(messages_3)
-#     # test safe empty return
-#     messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000)
-#     assert len(messages_none) == 0
+@pytest.mark.order(5)
+def test_get_recall_memory(server, org_id, user_id, agent_id):
+    # test recall memory cursor pagination
+    messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
+    cursor1 = messages_1[-1].id
+    messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
+    messages_2[-1].id
+    messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
+    messages_3[-1].id
+    assert messages_3[-1].created_at >= messages_3[0].created_at
+    assert len(messages_3) == len(messages_1) + len(messages_2)
+    messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
+    assert len(messages_4) == 1
+
+    # test in-context message ids
+    in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
+    message_ids = [m.id for m in messages_3]
+    for message_id in in_context_ids:
+        assert message_id in message_ids, f"{message_id} not in {message_ids}"
 
 
 @pytest.mark.order(6)