mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add data loading and attaching to server (#1051)
This commit is contained in:
parent
19725cac35
commit
ce1ce9d06f
@ -111,7 +111,6 @@ def load_directory(
|
||||
embedding_config=config.default_embedding_config,
|
||||
document_store=None,
|
||||
passage_store=passage_storage,
|
||||
chunk_size=1000,
|
||||
)
|
||||
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
|
||||
|
@ -23,7 +23,6 @@ def load_data(
|
||||
embedding_config: EmbeddingConfig,
|
||||
passage_store: StorageConnector,
|
||||
document_store: Optional[StorageConnector] = None,
|
||||
chunk_size: int = 1000,
|
||||
):
|
||||
"""Load data from a connector (generates documents and passages) into a specified source_id, associated with a user_id."""
|
||||
|
||||
@ -49,7 +48,6 @@ def load_data(
|
||||
|
||||
# generate passages
|
||||
for passage_text, passage_metadata in connector.generate_passages([document]):
|
||||
print("passage", passage_text, passage_metadata)
|
||||
embedding = embed_model.get_text_embedding(passage_text)
|
||||
passage = Passage(
|
||||
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
||||
@ -64,7 +62,7 @@ def load_data(
|
||||
)
|
||||
|
||||
passages.append(passage)
|
||||
if len(passages) >= chunk_size:
|
||||
if len(passages) >= embedding_config.embedding_chunk_size:
|
||||
# insert passages into passage store
|
||||
passage_store.insert_many(passages)
|
||||
|
||||
|
@ -17,14 +17,15 @@ import memgpt.constants as constants
|
||||
|
||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
from memgpt.cli.cli_config import get_model_options
|
||||
|
||||
# from memgpt.agent_store.storage import StorageConnector
|
||||
from memgpt.data_sources.connectors import DataConnector, load_data
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.metadata import MetadataStore
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
import memgpt.server.utils as server_utils
|
||||
from memgpt.data_types import (
|
||||
User,
|
||||
Source,
|
||||
Passage,
|
||||
AgentState,
|
||||
LLMConfig,
|
||||
@ -969,6 +970,10 @@ class SyncServer(LockingServer):
|
||||
|
||||
return memgpt_agent.agent_state
|
||||
|
||||
def delete_user(self, user_id: uuid.UUID):
    """Delete a user (placeholder: currently does nothing and returns None)."""
    # TODO: delete user
    return None
|
||||
|
||||
def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
|
||||
"""Delete an agent in the database"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
@ -1015,15 +1020,45 @@ class SyncServer(LockingServer):
|
||||
token = self.ms.create_api_key(user_id=user_id)
|
||||
return token
|
||||
|
||||
def create_source(self, name: str): # TODO: add other fields
|
||||
# create a data source
|
||||
pass
|
||||
def create_source(self, name: str, user_id: uuid.UUID) -> Source:  # TODO: add other fields
    """Create a new data source.

    Builds a `Source` with the given name and owner, registers it with the
    metadata store (`self.ms`), and returns the created object.
    """
    new_source = Source(user_id=user_id, name=name)
    self.ms.create_source(new_source)
    return new_source
|
||||
|
||||
def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
|
||||
# load a list of passages into a data source
|
||||
pass
|
||||
def load_data(
    self,
    user_id: uuid.UUID,
    connector: DataConnector,
    source_name: str,
):
    """Load data from a DataConnector into a source for a specified user_id

    Args:
        user_id: owner of the data source.
        connector: connector that produces the documents/passages to load.
        source_name: name of an existing data source (looked up by name in
            the metadata store).

    Raises:
        ValueError: if no source named `source_name` exists for `user_id`.
    """
    # TODO: this should be implemented as a batch job or at least async, since it may take a long time

    # look up the data source the passages will be loaded into
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

    # get the data connectors
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # TODO: add document store support
    document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)

    # load data into the document store
    # (inside a method body this name resolves to the module-level `load_data`
    # imported from memgpt.data_sources.connectors, not to this method)
    load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
||||
|
||||
def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
    """Attach a data source to an agent.

    Looks the source up by name for `user_id`, obtains the PASSAGES storage
    connector, loads the agent, and delegates to `agent.attach_source`.

    Raises:
        ValueError: if the named data source does not exist.
    """
    found_source = self.ms.get_source(user_id=user_id, source_name=source_name)
    if found_source is None:
        raise ValueError(f"Data source {source_name} does not exist")

    # connection to the storage holding the source's passages
    passages_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)

    # fetch the agent (loading it if necessary) and hand the source over to it
    target_agent = self._get_or_load_agent(user_id, agent_id)
    target_agent.attach_source(found_source.name, passages_connector, self.ms)
|
||||
|
808
poetry.lock
generated
808
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -55,11 +55,12 @@ pyright = {version = "^1.1.347", optional = true}
|
||||
llama-index = "^0.10.6"
|
||||
llama-index-embeddings-openai = "^0.1.1"
|
||||
python-box = "^7.1.1"
|
||||
pytest-order = {version = "^1.2.0", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
postgres = ["pgvector", "pg8000"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
autogen = ["pyautogen"]
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import uuid
|
||||
import pytest
|
||||
import os
|
||||
import memgpt.utils as utils
|
||||
from dotenv import load_dotenv
|
||||
@ -10,10 +11,11 @@ from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
from .utils import wipe_config, wipe_memgpt_home
|
||||
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
|
||||
|
||||
|
||||
def test_server():
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
load_dotenv()
|
||||
wipe_config()
|
||||
wipe_memgpt_home()
|
||||
@ -73,14 +75,44 @@ def test_server():
|
||||
credentials.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def user_id(server):
    """Module-scoped fixture: create a user with default presets and yield its id.

    After the module's tests finish, `server.delete_user` is called for the
    created user.
    """
    # setup: new user plus the default presets
    created_user = server.create_user()
    print(f"Created user\n{created_user.id}")
    server.initialize_default_presets(created_user.id)

    yield created_user.id

    # teardown
    server.delete_user(created_user.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def agent_id(server, user_id):
    """Module-scoped fixture: create a test agent for `user_id` and yield its id.

    The agent is deleted through `server.delete_agent` once the module's
    tests have finished.
    """
    # setup: describe the agent, then create it
    agent_spec = {
        "user_id": user_id,
        "name": "test_agent",
        "preset": "memgpt_chat",
        "human": "cs_phd",
        "persona": "sam_pov",
    }
    created_state = server.create_agent(**agent_spec)
    print(f"Created agent\n{created_state}")

    yield created_state.id

    # teardown
    server.delete_agent(user_id, created_state.id)
|
||||
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
try:
|
||||
fake_agent_id = uuid.uuid4()
|
||||
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
# Error is expected
|
||||
@ -88,21 +120,11 @@ def test_server():
|
||||
except:
|
||||
raise
|
||||
|
||||
# create presets
|
||||
add_default_presets(user.id, server.ms)
|
||||
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
user_id=user.id,
|
||||
name="test_agent",
|
||||
preset="memgpt_chat",
|
||||
human="cs_phd",
|
||||
persona="sam_pov",
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
|
||||
@pytest.mark.order(1)
|
||||
def test_user_message(server, user_id, agent_id):
|
||||
try:
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||
raise Exception("user_message call should have failed")
|
||||
except ValueError as e:
|
||||
# Error is expected
|
||||
@ -110,62 +132,93 @@ def test_server():
|
||||
except:
|
||||
raise
|
||||
|
||||
print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
|
||||
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
|
||||
|
||||
# add data into archival memory
|
||||
agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
|
||||
|
||||
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
    """Create a data source and load five short texts into it via a dummy connector."""
    # create source
    source = server.create_source("test_source", user_id)

    # load data
    archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
    connector = DummyDataConnector(archival_memories)
    server.load_data(user_id, connector, source.name)
|
||||
|
||||
|
||||
@pytest.mark.order(3)
def test_attach_source_to_agent(server, user_id, agent_id):
    """Attaching "test_source" should take the agent's archival memory from 0 to 5 passages."""
    archival_query = {"user_id": user_id, "agent_id": agent_id, "start": 0, "count": 10000}

    # archival memory starts out empty
    passages_before = server.get_agent_archival(**archival_query)
    assert len(passages_before) == 0

    # attach source
    server.attach_source_to_agent(user_id, agent_id, "test_source")

    # archival memory now holds the passages of the source
    passages_after = server.get_agent_archival(**archival_query)
    assert len(passages_after) == 5
|
||||
|
||||
|
||||
def test_save_archival_memory(server, user_id, agent_id):
    """Placeholder test: nothing is exercised or asserted yet."""
    # TODO: insert into archival memory
|
||||
|
||||
|
||||
@pytest.mark.order(4)
def test_user_message(server, user_id, agent_id):
    """Send five user messages so that recall memory has data for later tests."""
    # NOTE(review): a function with this same name is defined earlier in this
    # module (marked order(1)); this later definition rebinds the name, so the
    # earlier test is no longer collected. Consider renaming one of them.

    # add data into recall memory
    for _ in range(5):
        server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
|
||||
|
||||
@pytest.mark.order(5)
def test_get_recall_memory(server, user_id, agent_id):
    """Check recall-memory cursor pagination, in-context message ids, and offset paging."""
    # test recall memory cursor pagination
    cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
    _, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
    _, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
    timestamps = [m["created_at"] for m in messages_3]
    print("timestamps", timestamps)
    # with reverse=True the first returned message must be newer than the last
    assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
    # the first page plus "everything after it" must equal the full listing
    assert len(messages_3) == len(messages_1) + len(messages_2)
    _, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
    assert len(messages_4) == 1

    # test in-context message ids
    in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
    assert len(in_context_ids) == len(messages_3)
    assert isinstance(in_context_ids[0], uuid.UUID)
    message_ids = [m["id"] for m in messages_3]
    for message_id in message_ids:
        assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"

    # test recall memory
    messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1)
    assert len(messages_1) == 1
    messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=1000)
    messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=5)
    # not sure exactly how many messages there should be
    assert len(messages_2) > len(messages_3)
    # test safe empty return
    messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
    assert len(messages_none) == 0
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
||||
)
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
|
||||
)
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
)
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
@ -174,26 +227,11 @@ def test_server():
|
||||
assert len(passages_2) == 3
|
||||
assert len(passages_3) == 4
|
||||
|
||||
# test recall memory
|
||||
messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(messages_1) == 1
|
||||
messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
|
||||
# not sure exactly how many messages there should be
|
||||
assert len(messages_2) > len(messages_3)
|
||||
# test safe empty return
|
||||
messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(messages_none) == 0
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||
assert len(passage_2) == 4
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_server()
|
||||
|
@ -1,13 +1,32 @@
|
||||
import datetime
|
||||
from typing import Dict, List, Tuple, Iterator
|
||||
import os
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.cli.cli import quickstart, QuickstartChoice
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt import Admin
|
||||
from memgpt.data_types import Document
|
||||
|
||||
from .constants import TIMEOUT
|
||||
|
||||
|
||||
class DummyDataConnector(DataConnector):
    """Fake data connector for testing which yields document/passage texts from a provided list"""

    def __init__(self, texts: List[str]):
        """Store the list of texts that will be served as documents."""
        self.texts = texts

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
        """Yield one `(text, metadata)` pair per stored text, with fixed dummy metadata."""
        for text in self.texts:
            yield text, {"metadata": "dummy"}

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        """Yield each document unchanged as a single `(text, metadata)` passage.

        `chunk_size` is accepted but not used: documents are never split.
        """
        for doc in documents:
            yield doc.text, doc.metadata
|
||||
|
||||
|
||||
def create_config(endpoint="openai"):
|
||||
"""Create config file matching quickstart option"""
|
||||
if endpoint == "openai":
|
||||
|
Loading…
Reference in New Issue
Block a user