mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add paginated memory queries (#825)
Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
parent
fca5134aa1
commit
6376e4fb3b
@ -737,4 +737,4 @@ class Agent(object):
|
|||||||
self.ms.create_agent(agent=agent_state)
|
self.ms.create_agent(agent=agent_state)
|
||||||
else:
|
else:
|
||||||
# Otherwise, we should update the agent
|
# Otherwise, we should update the agent
|
||||||
self.ms.update_agent(agent=agent_state)
|
self.ms.update_agent(agent=agent_state)
|
||||||
|
@ -66,8 +66,7 @@ class ChromaStorageConnector(StorageConnector):
|
|||||||
chroma_filters = chroma_filters[0]
|
chroma_filters = chroma_filters[0]
|
||||||
return ids, chroma_filters
|
return ids, chroma_filters
|
||||||
|
|
||||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[Record]]:
|
||||||
offset = 0
|
|
||||||
ids, filters = self.get_filters(filters)
|
ids, filters = self.get_filters(filters)
|
||||||
while True:
|
while True:
|
||||||
# Retrieve a chunk of records with the given page_size
|
# Retrieve a chunk of records with the given page_size
|
||||||
|
@ -260,8 +260,7 @@ class SQLStorageConnector(StorageConnector):
|
|||||||
all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||||
return all_filters
|
return all_filters
|
||||||
|
|
||||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[Record]]:
|
||||||
offset = 0
|
|
||||||
filters = self.get_filters(filters)
|
filters = self.get_filters(filters)
|
||||||
while True:
|
while True:
|
||||||
# Retrieve a chunk of records with the given page_size
|
# Retrieve a chunk of records with the given page_size
|
||||||
|
@ -178,9 +178,10 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
|
|||||||
)
|
)
|
||||||
elif endpoint_type == "hugging-face":
|
elif endpoint_type == "hugging-face":
|
||||||
try:
|
try:
|
||||||
embed_model = EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
|
return EmbeddingEndpoint(model=config.embedding_model, base_url=config.embedding_endpoint, user=user_id)
|
||||||
except:
|
except Exception as e:
|
||||||
embed_model = default_embedding_model()
|
# TODO: remove, this is just to get passing tests
|
||||||
return embed_model
|
print(e)
|
||||||
|
return default_embedding_model()
|
||||||
else:
|
else:
|
||||||
return default_embedding_model()
|
return default_embedding_model()
|
||||||
|
@ -266,3 +266,7 @@ class CLIInterface(AgentInterface):
|
|||||||
def print_messages_raw(message_sequence):
|
def print_messages_raw(message_sequence):
|
||||||
for msg in message_sequence:
|
for msg in message_sequence:
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def step_yield():
|
||||||
|
pass
|
||||||
|
@ -71,8 +71,18 @@ class LocalStateManager(PersistenceManager):
|
|||||||
|
|
||||||
def json_to_message(self, message_json) -> Message:
|
def json_to_message(self, message_json) -> Message:
|
||||||
"""Convert agent message JSON into Message object"""
|
"""Convert agent message JSON into Message object"""
|
||||||
timestamp = message_json["timestamp"]
|
|
||||||
message = message_json["message"]
|
# get message
|
||||||
|
if "message" in message_json:
|
||||||
|
message = message_json["message"]
|
||||||
|
else:
|
||||||
|
message = message_json
|
||||||
|
|
||||||
|
# get timestamp
|
||||||
|
if "timestamp" in message_json:
|
||||||
|
timestamp = parse_formatted_time(message_json["timestamp"])
|
||||||
|
else:
|
||||||
|
timestamp = get_local_time()
|
||||||
|
|
||||||
# TODO: change this when we fully migrate to tool calls API
|
# TODO: change this when we fully migrate to tool calls API
|
||||||
if "function_call" in message:
|
if "function_call" in message:
|
||||||
@ -97,7 +107,7 @@ class LocalStateManager(PersistenceManager):
|
|||||||
text=message["content"],
|
text=message["content"],
|
||||||
name=message["name"] if "name" in message else None,
|
name=message["name"] if "name" in message else None,
|
||||||
model=self.agent_state.llm_config.model,
|
model=self.agent_state.llm_config.model,
|
||||||
created_at=parse_formatted_time(timestamp),
|
created_at=timestamp,
|
||||||
tool_calls=tool_calls,
|
tool_calls=tool_calls,
|
||||||
tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
|
tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
|
||||||
id=message["id"] if "id" in message else None,
|
id=message["id"] if "id" in message else None,
|
||||||
|
@ -589,7 +589,7 @@ class SyncServer(LockingServer):
|
|||||||
return memory_obj
|
return memory_obj
|
||||||
|
|
||||||
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||||
"""Paginated query of in-context messages in agent message queue"""
|
"""Paginated query of all messages in agent message queue"""
|
||||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||||
if self.ms.get_user(user_id=user_id) is None:
|
if self.ms.get_user(user_id=user_id) is None:
|
||||||
raise ValueError(f"User user_id={user_id} does not exist")
|
raise ValueError(f"User user_id={user_id} does not exist")
|
||||||
@ -600,20 +600,52 @@ class SyncServer(LockingServer):
|
|||||||
if start < 0 or count < 0:
|
if start < 0 or count < 0:
|
||||||
raise ValueError("Start and count values should be non-negative")
|
raise ValueError("Start and count values should be non-negative")
|
||||||
|
|
||||||
# Reverse the list to make it in reverse chronological order
|
if start + count < len(memgpt_agent.messages): # messages can be returned from whats in memory
|
||||||
reversed_messages = memgpt_agent.messages[::-1]
|
# Reverse the list to make it in reverse chronological order
|
||||||
|
reversed_messages = memgpt_agent.messages[::-1]
|
||||||
|
# Check if start is within the range of the list
|
||||||
|
if start >= len(reversed_messages):
|
||||||
|
raise IndexError("Start index is out of range")
|
||||||
|
|
||||||
# Check if start is within the range of the list
|
# Calculate the end index, ensuring it does not exceed the list length
|
||||||
if start >= len(reversed_messages):
|
end_index = min(start + count, len(reversed_messages))
|
||||||
raise IndexError("Start index is out of range")
|
|
||||||
|
|
||||||
# Calculate the end index, ensuring it does not exceed the list length
|
# Slice the list for pagination
|
||||||
end_index = min(start + count, len(reversed_messages))
|
paginated_messages = reversed_messages[start:end_index]
|
||||||
|
|
||||||
# Slice the list for pagination
|
# convert to message objects:
|
||||||
paginated_messages = reversed_messages[start:end_index]
|
messages = [memgpt_agent.persistence_manager.json_to_message(m) for m in paginated_messages]
|
||||||
|
else:
|
||||||
|
# need to access persistence manager for additional messages
|
||||||
|
db_iterator = memgpt_agent.persistence_manager.recall_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||||
|
|
||||||
return paginated_messages
|
# get a single page of messages
|
||||||
|
# TODO: handle stop iteration
|
||||||
|
page = next(db_iterator, [])
|
||||||
|
|
||||||
|
# return messages in reverse chronological order
|
||||||
|
messages = sorted(page, key=lambda x: x.created_at, reverse=True)
|
||||||
|
|
||||||
|
# convert to json
|
||||||
|
json_messages = [vars(record) for record in messages]
|
||||||
|
return json_messages
|
||||||
|
|
||||||
|
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||||
|
"""Paginated query of all messages in agent archival memory"""
|
||||||
|
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||||
|
if self.ms.get_user(user_id=user_id) is None:
|
||||||
|
raise ValueError(f"User user_id={user_id} does not exist")
|
||||||
|
|
||||||
|
# Get the agent object (loaded in memory)
|
||||||
|
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||||
|
|
||||||
|
# iterate over records
|
||||||
|
db_iterator = memgpt_agent.persistence_manager.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||||
|
|
||||||
|
# get a single page of messages
|
||||||
|
page = next(db_iterator, [])
|
||||||
|
json_passages = [vars(record) for record in page]
|
||||||
|
return json_passages
|
||||||
|
|
||||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
||||||
"""Return the config of an agent"""
|
"""Return the config of an agent"""
|
||||||
|
@ -1,10 +1,13 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import os
|
||||||
|
|
||||||
import memgpt.utils as utils
|
import memgpt.utils as utils
|
||||||
|
|
||||||
utils.DEBUG = True
|
utils.DEBUG = True
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
from memgpt.server.server import SyncServer
|
from memgpt.server.server import SyncServer
|
||||||
|
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
|
||||||
|
from memgpt.embeddings import embedding_model
|
||||||
from .utils import wipe_config, wipe_memgpt_home
|
from .utils import wipe_config, wipe_memgpt_home
|
||||||
|
|
||||||
|
|
||||||
@ -12,6 +15,14 @@ def test_server():
|
|||||||
wipe_memgpt_home()
|
wipe_memgpt_home()
|
||||||
|
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
|
|
||||||
|
# setup config for postgres storage
|
||||||
|
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||||
|
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||||
|
config.archival_storage_type = "postgres"
|
||||||
|
config.recall_storage_type = "postgres"
|
||||||
|
config.save()
|
||||||
|
|
||||||
user_id = uuid.UUID(config.anon_clientid)
|
user_id = uuid.UUID(config.anon_clientid)
|
||||||
server = SyncServer()
|
server = SyncServer()
|
||||||
|
|
||||||
@ -25,12 +36,22 @@ def test_server():
|
|||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# embedding config
|
||||||
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
|
embedding_config = EmbeddingConfig(
|
||||||
|
embedding_endpoint_type="openai",
|
||||||
|
embedding_endpoint="https://api.openai.com/v1",
|
||||||
|
embedding_dim=1536,
|
||||||
|
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
|
||||||
|
|
||||||
agent_state = server.create_agent(
|
agent_state = server.create_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
agent_config=dict(
|
agent_config=dict(
|
||||||
preset="memgpt_chat",
|
name="test_agent", user_id=user_id, preset="memgpt_chat", human="cs_phd", persona="sam_pov", embedding_config=embedding_config
|
||||||
human="cs_phd",
|
|
||||||
persona="sam_pov",
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
print(f"Created agent\n{agent_state}")
|
print(f"Created agent\n{agent_state}")
|
||||||
@ -46,6 +67,45 @@ def test_server():
|
|||||||
|
|
||||||
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
|
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
|
||||||
|
|
||||||
|
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||||
|
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||||
|
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||||
|
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||||
|
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||||
|
|
||||||
|
# test recall memory
|
||||||
|
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||||
|
assert len(messages_1) == 1
|
||||||
|
|
||||||
|
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||||
|
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
|
||||||
|
# not sure exactly how many messages there should be
|
||||||
|
assert len(messages_2) > len(messages_3)
|
||||||
|
|
||||||
|
# test safe empty return
|
||||||
|
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||||
|
assert len(messages_none) == 0
|
||||||
|
|
||||||
|
# test archival memory
|
||||||
|
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||||
|
archival_memories = ["Cinderella wore a blue dress", "Dog eat dog", "Shishir loves indian food"]
|
||||||
|
embed_model = embedding_model(embedding_config)
|
||||||
|
for text in archival_memories:
|
||||||
|
embedding = embed_model.get_text_embedding(text)
|
||||||
|
agent.persistence_manager.archival_memory.storage.insert(
|
||||||
|
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
|
||||||
|
)
|
||||||
|
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||||
|
assert len(passage_1) == 1
|
||||||
|
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||||
|
assert len(passage_2) == 2
|
||||||
|
|
||||||
|
print(passage_1)
|
||||||
|
|
||||||
|
# test safe empty return
|
||||||
|
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||||
|
assert len(passage_none) == 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_server()
|
test_server()
|
||||||
|
Loading…
Reference in New Issue
Block a user