feat: Add data loading and attaching to server (#1051)

This commit is contained in:
Sarah Wooders 2024-02-24 19:34:32 -08:00 committed by GitHub
parent 19725cac35
commit ce1ce9d06f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 637 additions and 423 deletions

View File

@ -111,7 +111,6 @@ def load_directory(
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
chunk_size=1000,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

View File

@ -23,7 +23,6 @@ def load_data(
embedding_config: EmbeddingConfig,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
chunk_size: int = 1000,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
@ -49,7 +48,6 @@ def load_data(
# generate passages
for passage_text, passage_metadata in connector.generate_passages([document]):
print("passage", passage_text, passage_metadata)
embedding = embed_model.get_text_embedding(passage_text)
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
@ -64,7 +62,7 @@ def load_data(
)
passages.append(passage)
if len(passages) >= chunk_size:
if len(passages) >= embedding_config.embedding_chunk_size:
# insert passages into passage store
passage_store.insert_many(passages)

View File

@ -17,14 +17,15 @@ import memgpt.constants as constants
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options
# from memgpt.agent_store.storage import StorageConnector
from memgpt.data_sources.connectors import DataConnector, load_data
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.data_types import (
User,
Source,
Passage,
AgentState,
LLMConfig,
@ -969,6 +970,10 @@ class SyncServer(LockingServer):
return memgpt_agent.agent_state
def delete_user(self, user_id: uuid.UUID):
# TODO: delete user
pass
def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
"""Delete an agent in the database"""
if self.ms.get_user(user_id=user_id) is None:
@ -1015,15 +1020,45 @@ class SyncServer(LockingServer):
token = self.ms.create_api_key(user_id=user_id)
return token
def create_source(self, name: str): # TODO: add other fields
# create a data source
pass
def create_source(self, name: str, user_id: uuid.UUID) -> Source:  # TODO: add other fields
    """Create a new data source owned by *user_id*, persist it, and return it."""
    new_source = Source(name=name, user_id=user_id)
    self.ms.create_source(new_source)
    return new_source
def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
# load a list of passages into a data source
pass
def load_data(
    self,
    user_id: uuid.UUID,
    connector: DataConnector,
    source_name: str,
):
    """Load data from a DataConnector into the named source for a specified user_id.

    Raises:
        ValueError: if no source named ``source_name`` exists for this user.
    """
    # TODO: this should be implemented as a batch job or at least async, since it may take a long time

    # look up the target source; fail fast if it was never created
    source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if source is None:
        raise ValueError(f"Data source {source_name} does not exist for user {user_id}")

    # get the storage connectors
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # TODO: add document store support
    document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)

    # load data into storage (this resolves to the module-level load_data function, not this method)
    load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)

def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
    """Attach an existing data source's passages to an agent's archival memory.

    Raises:
        ValueError: if no source named ``source_name`` exists for this user.
    """
    # attach a data source to an agent
    data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
    if data_source is None:
        raise ValueError(f"Data source {source_name} does not exist")
    # get connection to data source storage
    source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
    # load agent
    agent = self._get_or_load_agent(user_id, agent_id)
    # attach source to agent
    agent.attach_source(data_source.name, source_connector, self.ms)

808
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -55,11 +55,12 @@ pyright = {version = "^1.1.347", optional = true}
llama-index = "^0.10.6"
llama-index-embeddings-openai = "^0.1.1"
python-box = "^7.1.1"
pytest-order = {version = "^1.2.0", optional = true}
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]
postgres = ["pgvector", "pg8000"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order"]
server = ["websockets", "fastapi", "uvicorn"]
autogen = ["pyautogen"]

View File

@ -1,4 +1,5 @@
import uuid
import pytest
import os
import memgpt.utils as utils
from dotenv import load_dotenv
@ -10,10 +11,11 @@ from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from .utils import wipe_config, wipe_memgpt_home
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
def test_server():
@pytest.fixture(scope="module")
def server():
load_dotenv()
wipe_config()
wipe_memgpt_home()
@ -73,14 +75,44 @@ def test_server():
credentials.save()
server = SyncServer()
return server
@pytest.fixture(scope="module")
def user_id(server):
    """Module-scoped fixture: create a user with default presets, then delete it on teardown."""
    new_user = server.create_user()
    print(f"Created user\n{new_user.id}")
    server.initialize_default_presets(new_user.id)
    yield new_user.id
    # teardown: remove the user created above
    server.delete_user(new_user.id)
@pytest.fixture(scope="module")
def agent_id(server, user_id):
    """Module-scoped fixture: create a test agent, then delete it on teardown."""
    state = server.create_agent(
        user_id=user_id,
        name="test_agent",
        preset="memgpt_chat",
        human="cs_phd",
        persona="sam_pov",
    )
    print(f"Created agent\n{state}")
    yield state.id
    # teardown: remove the agent created above
    server.delete_agent(user_id, state.id)
def test_error_on_nonexistent_agent(server, user_id, agent_id):
try:
fake_agent_id = uuid.uuid4()
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected
@ -88,21 +120,11 @@ def test_server():
except:
raise
# create presets
add_default_presets(user.id, server.ms)
# create agent
agent_state = server.create_agent(
user_id=user.id,
name="test_agent",
preset="memgpt_chat",
human="cs_phd",
persona="sam_pov",
)
print(f"Created agent\n{agent_state}")
@pytest.mark.order(1)
def test_user_message(server, user_id, agent_id):
try:
server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
raise Exception("user_message call should have failed")
except ValueError as e:
# Error is expected
@ -110,62 +132,93 @@ def test_server():
except:
raise
print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
# add data into archival memory
agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
    """Create a source and load passages into it via a DummyDataConnector."""
    # create source
    source = server.create_source("test_source", user_id)
    # load data: each string becomes one document/passage via the dummy connector
    archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
    connector = DummyDataConnector(archival_memories)
    server.load_data(user_id, connector, source.name)
@pytest.mark.order(3)
def test_attach_source_to_agent(server, user_id, agent_id):
    """Attaching a source should populate the agent's archival memory with its passages."""
    def archival_count():
        # count passages currently visible in the agent's archival memory
        return len(server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000))

    # archival memory starts empty
    assert archival_count() == 0
    # attach the "test_source" created by test_load_data
    server.attach_source_to_agent(user_id, agent_id, "test_source")
    # all 5 loaded passages should now be visible
    assert archival_count() == 5
def test_save_archival_memory(server, user_id, agent_id):
# TODO: insert into archival memory
pass
@pytest.mark.order(4)
def test_user_message(server, user_id, agent_id):
# add data into recall memory
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
@pytest.mark.order(5)
def test_get_recall_memory(server, user_id, agent_id):
    """Recall-memory cursor pagination and start/count paging should be consistent."""
    # test recall memory cursor pagination
    cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
    cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
    cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
    timestamps = [m["created_at"] for m in messages_3]
    print("timestamps", timestamps)
    # reverse=True should return newest-first ordering
    assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
    # the first page plus the remainder should cover everything
    assert len(messages_3) == len(messages_1) + len(messages_2)
    cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
    assert len(messages_4) == 1

    # test in-context message ids
    in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
    assert len(in_context_ids) == len(messages_3)
    assert isinstance(in_context_ids[0], uuid.UUID)
    message_ids = [m["id"] for m in messages_3]
    for message_id in message_ids:
        assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"

    # test recall memory
    messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1)
    assert len(messages_1) == 1
    messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=1000)
    messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=5)
    # not sure exactly how many messages there should be
    assert len(messages_2) > len(messages_3)
    # test safe empty return
    messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
    assert len(messages_none) == 0
@pytest.mark.order(6)
def test_get_archival_memory(server, user_id, agent_id):
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
)
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])
@ -174,26 +227,11 @@ def test_server():
assert len(passages_2) == 3
assert len(passages_3) == 4
# test recall memory
messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
assert len(passage_2) == 4
# test safe empty return
passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
assert len(passage_none) == 0
if __name__ == "__main__":
test_server()

View File

@ -1,13 +1,32 @@
import datetime
from typing import Dict, List, Tuple, Iterator
import os
from memgpt.config import MemGPTConfig
from memgpt.cli.cli import quickstart, QuickstartChoice
from memgpt.data_sources.connectors import DataConnector
from memgpt import Admin
from memgpt.data_types import Document
from .constants import TIMEOUT
class DummyDataConnector(DataConnector):
    """Fake data connector for testing which yields document/passage texts from a provided list."""

    def __init__(self, texts: List[str]):
        self.texts = texts

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
        """Yield one (text, metadata) pair per input string, with dummy metadata."""
        for text in self.texts:
            yield text, {"metadata": "dummy"}

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        """Yield one (text, metadata) passage per document; no chunking in this dummy implementation."""
        for doc in documents:
            yield doc.text, doc.metadata
def create_config(endpoint="openai"):
"""Create config file matching quickstart option"""
if endpoint == "openai":