feat: Add data loading and attaching to server (#1051)

Sarah Wooders, 2024-02-24 19:34:32 -08:00, committed by GitHub
parent 19725cac35 · commit ce1ce9d06f
7 changed files with 637 additions and 423 deletions


@@ -111,7 +111,6 @@ def load_directory(
         embedding_config=config.default_embedding_config,
         document_store=None,
         passage_store=passage_storage,
-        chunk_size=1000,
     )
     print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")


@@ -23,7 +23,6 @@ def load_data(
     embedding_config: EmbeddingConfig,
     passage_store: StorageConnector,
     document_store: Optional[StorageConnector] = None,
-    chunk_size: int = 1000,
 ):
     """Load data from a connector (generates documents and passages) into a specified source_id, associated with a user_id."""
@@ -49,7 +48,6 @@ def load_data(
         # generate passages
         for passage_text, passage_metadata in connector.generate_passages([document]):
-            print("passage", passage_text, passage_metadata)
             embedding = embed_model.get_text_embedding(passage_text)
             passage = Passage(
                 id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
@@ -64,7 +62,7 @@ def load_data(
             )
             passages.append(passage)
-            if len(passages) >= chunk_size:
+            if len(passages) >= embedding_config.embedding_chunk_size:
                 # insert passages into passage store
                 passage_store.insert_many(passages)
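In plain terms: passages accumulate in a list and are flushed to the passage store whenever the list reaches embedding_config.embedding_chunk_size, replacing the old hard-coded chunk_size=1000 argument (which is also why the CLI call above no longer passes it). A minimal sketch of that flush-on-threshold pattern; insert_in_batches and the duck-typed store are illustrative stand-ins, not functions from this codebase:

    from typing import Dict, Iterable, List, Tuple

    def insert_in_batches(items: Iterable[Tuple[str, Dict]], store, batch_size: int) -> None:
        """Accumulate items and flush them to the store whenever the batch fills up."""
        batch: List[Tuple[str, Dict]] = []
        for item in items:
            batch.append(item)
            if len(batch) >= batch_size:  # mirrors len(passages) >= embedding_config.embedding_chunk_size
                store.insert_many(batch)
                batch = []  # reset after flushing
        if batch:
            store.insert_many(batch)  # flush the final partial batch

Bounded batches keep each insert_many call a manageable size, which can matter for storage backends (e.g. Postgres) where very large single inserts are slow.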


@@ -17,14 +17,15 @@ import memgpt.constants as constants
 # from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
 from memgpt.cli.cli_config import get_model_options
-# from memgpt.agent_store.storage import StorageConnector
+from memgpt.data_sources.connectors import DataConnector, load_data
+from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.metadata import MetadataStore
 import memgpt.presets.presets as presets
 import memgpt.utils as utils
 import memgpt.server.utils as server_utils
 from memgpt.data_types import (
     User,
+    Source,
     Passage,
     AgentState,
     LLMConfig,
@@ -969,6 +970,10 @@ class SyncServer(LockingServer):
         return memgpt_agent.agent_state

+    def delete_user(self, user_id: uuid.UUID):
+        # TODO: delete user
+        pass
+
     def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
         """Delete an agent in the database"""
         if self.ms.get_user(user_id=user_id) is None:
@@ -1015,15 +1020,45 @@ class SyncServer(LockingServer):
         token = self.ms.create_api_key(user_id=user_id)
         return token

-    def create_source(self, name: str):  # TODO: add other fields
-        # craete a data source
-        pass
-
-    def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
-        # load a list of passages into a data source
-        pass
-
-    def attach_source_to_agent(self, agent_id: uuid.UUID, source_id: uuid.UUID):
+    def create_source(self, name: str, user_id: uuid.UUID) -> Source:  # TODO: add other fields
+        """Create a new data source"""
+        source = Source(name=name, user_id=user_id)
+        self.ms.create_source(source)
+        return source
+
+    def load_data(
+        self,
+        user_id: uuid.UUID,
+        connector: DataConnector,
+        source_name: Source,
+    ):
+        """Load data from a DataConnector into a source for a specified user_id"""
+        # TODO: this should be implemented as a batch job or at least async, since it may take a long time
+        # load data from a data source into the document store
+        source = self.ms.get_source(source_name=source_name, user_id=user_id)
+        if source is None:
+            raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
+
+        # get the data connectors
+        passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
+        # TODO: add document store support
+        document_store = None  # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
+
+        # load data into the document store
+        load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
+
+    def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
         # attach a data source to an agent
-        # TODO: insert passages into agent archival memory
-        pass
+        data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
+        if data_source is None:
+            raise ValueError(f"Data source {source_name} does not exist")
+
+        # get connection to data source storage
+        source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
+
+        # load agent
+        agent = self._get_or_load_agent(user_id, agent_id)
+
+        # attach source to agent
+        agent.attach_source(data_source.name, source_connector, self.ms)
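Taken together, the new server methods form a three-step ingestion flow: create a source, load passages into it through a connector, then attach the source to an agent so its passages land in archival memory. A hedged usage sketch, borrowing DummyDataConnector from the test utilities later in this diff; the names "example_docs" and "example_agent" are illustrative:

    from memgpt.server.server import SyncServer
    from tests.utils import DummyDataConnector  # test helper defined below

    server = SyncServer()
    user = server.create_user()
    server.initialize_default_presets(user.id)  # presets must exist before creating an agent

    agent_state = server.create_agent(
        user_id=user.id,
        name="example_agent",
        preset="memgpt_chat",
        human="cs_phd",
        persona="sam_pov",
    )

    # 1. create a named source owned by this user
    source = server.create_source("example_docs", user.id)

    # 2. load passages from a connector into the source (embeds and stores them)
    connector = DummyDataConnector(["first passage", "second passage"])
    server.load_data(user.id, connector, source.name)

    # 3. copy the source's passages into the agent's archival memory
    server.attach_source_to_agent(user.id, agent_state.id, source.name)

This mirrors how the new tests exercise the API; attach_source_to_agent is what makes the loaded passages visible to get_agent_archival.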

poetry.lock (generated)
File diff suppressed because it is too large (808 changes).


@@ -55,11 +55,12 @@ pyright = {version = "^1.1.347", optional = true}
 llama-index = "^0.10.6"
 llama-index-embeddings-openai = "^0.1.1"
 python-box = "^7.1.1"
+pytest-order = {version = "^1.2.0", optional = true}

 [tool.poetry.extras]
 local = ["torch", "huggingface-hub", "transformers"]
 postgres = ["pgvector", "pg8000"]
-dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright"]
+dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order"]
 server = ["websockets", "fastapi", "uvicorn"]
 autogen = ["pyautogen"]


@@ -1,4 +1,5 @@
 import uuid
+import pytest
 import os
 import memgpt.utils as utils
 from dotenv import load_dotenv
@@ -10,10 +11,11 @@ from memgpt.server.server import SyncServer
 from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
 from memgpt.embeddings import embedding_model
 from memgpt.presets.presets import add_default_presets
-from .utils import wipe_config, wipe_memgpt_home
+from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector

-def test_server():
+@pytest.fixture(scope="module")
+def server():
     load_dotenv()
     wipe_config()
     wipe_memgpt_home()
@@ -73,14 +75,44 @@ def test_server():
     credentials.save()

     server = SyncServer()
+    return server
+
+
+@pytest.fixture(scope="module")
+def user_id(server):
     # create user
     user = server.create_user()
     print(f"Created user\n{user.id}")
+    # initialize with default presets
+    server.initialize_default_presets(user.id)
+    yield user.id
+
+    # cleanup
+    server.delete_user(user.id)
+
+
+@pytest.fixture(scope="module")
+def agent_id(server, user_id):
+    # create agent
+    agent_state = server.create_agent(
+        user_id=user_id,
+        name="test_agent",
+        preset="memgpt_chat",
+        human="cs_phd",
+        persona="sam_pov",
+    )
+    print(f"Created agent\n{agent_state}")
+    yield agent_state.id
+
+    # cleanup
+    server.delete_agent(user_id, agent_state.id)
+
+
+def test_error_on_nonexistent_agent(server, user_id, agent_id):
     try:
         fake_agent_id = uuid.uuid4()
-        server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
+        server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
         raise Exception("user_message call should have failed")
     except (KeyError, ValueError) as e:
         # Error is expected
@@ -88,21 +120,11 @@ def test_server():
     except:
         raise

-    # create presets
-    add_default_presets(user.id, server.ms)
-
-    # create agent
-    agent_state = server.create_agent(
-        user_id=user.id,
-        name="test_agent",
-        preset="memgpt_chat",
-        human="cs_phd",
-        persona="sam_pov",
-    )
-    print(f"Created agent\n{agent_state}")

+@pytest.mark.order(1)
+def test_user_message(server, user_id, agent_id):
     try:
-        server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
+        server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
         raise Exception("user_message call should have failed")
     except ValueError as e:
         # Error is expected
@@ -110,62 +132,93 @@ def test_server():
     except:
         raise

-    print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
+    server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")

-    # add data into archival memory
-    agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
+
+@pytest.mark.order(3)
+def test_load_data(server, user_id, agent_id):
+    # create source
+    source = server.create_source("test_source", user_id)
+
+    # load data
     archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
-    embed_model = embedding_model(agent.agent_state.embedding_config)
-    for text in archival_memories:
-        embedding = embed_model.get_text_embedding(text)
-        agent.persistence_manager.archival_memory.storage.insert(
-            Passage(
-                user_id=user.id,
-                agent_id=agent_state.id,
-                text=text,
-                embedding=embedding,
-                embedding_dim=agent.agent_state.embedding_config.embedding_dim,
-                embedding_model=agent.agent_state.embedding_config.embedding_model,
-            )
-        )
+    connector = DummyDataConnector(archival_memories)
+    server.load_data(user_id, connector, source.name)
+
+
+@pytest.mark.order(3)
+def test_attach_source_to_agent(server, user_id, agent_id):
+    # check archival memory size
+    passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
+    assert len(passages_before) == 0
+
+    # attach source
+    server.attach_source_to_agent(user_id, agent_id, "test_source")
+
+    # check archival memory size
+    passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000)
+    assert len(passages_after) == 5
+
+
+def test_save_archival_memory(server, user_id, agent_id):
+    # TODO: insert into archival memory
+    pass
+
+
+@pytest.mark.order(4)
+def test_user_message(server, user_id, agent_id):
     # add data into recall memory
-    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
-    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
-    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
-    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
-    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
+    server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
+    server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
+    server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
+    server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
+    server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
+
+
+@pytest.mark.order(5)
+def test_get_recall_memory(server, user_id, agent_id):
     # test recall memory cursor pagination
-    cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2)
-    cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
-    cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000)
+    cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
+    cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
+    cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
     ids3 = [m["id"] for m in messages_3]
     ids2 = [m["id"] for m in messages_2]
     timestamps = [m["created_at"] for m in messages_3]
     print("timestamps", timestamps)
     assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
     assert len(messages_3) == len(messages_1) + len(messages_2)
-    cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
+    cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
     assert len(messages_4) == 1

     # test in-context message ids
-    in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
+    in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
     assert len(in_context_ids) == len(messages_3)
     assert isinstance(in_context_ids[0], uuid.UUID)
     message_ids = [m["id"] for m in messages_3]
     for message_id in message_ids:
         assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"

+    # test recall memory
+    messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1)
+    assert len(messages_1) == 1
+    messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=1000)
+    messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=5)
+    # not sure exactly how many messages there should be
+    assert len(messages_2) > len(messages_3)
+    # test safe empty return
+    messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
+    assert len(messages_none) == 0
+
+
+@pytest.mark.order(6)
+def test_get_archival_memory(server, user_id, agent_id):
     # test archival memory cursor pagination
-    cursor1, passages_1 = server.get_agent_archival_cursor(
-        user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
-    )
+    cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
     cursor2, passages_2 = server.get_agent_archival_cursor(
-        user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
+        user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
     )
     cursor3, passages_3 = server.get_agent_archival_cursor(
-        user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
+        user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
     )
     print("p1", [p["text"] for p in passages_1])
     print("p2", [p["text"] for p in passages_2])
@@ -174,26 +227,11 @@ def test_server():
     assert len(passages_2) == 3
     assert len(passages_3) == 4

-    # test recall memory
-    messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
-    assert len(messages_1) == 1
-    messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
-    messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
-    # not sure exactly how many messages there should be
-    assert len(messages_2) > len(messages_3)
-
-    # test safe empty return
-    messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
-    assert len(messages_none) == 0
-
     # test archival memory
-    passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
+    passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
     assert len(passage_1) == 1
-    passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
+    passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
     assert len(passage_2) == 4

     # test safe empty return
-    passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
+    passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
     assert len(passage_none) == 0
-
-
-if __name__ == "__main__":
-    test_server()


@@ -1,13 +1,32 @@
 import datetime
+from typing import Dict, List, Tuple, Iterator
 import os

 from memgpt.config import MemGPTConfig
 from memgpt.cli.cli import quickstart, QuickstartChoice
+from memgpt.data_sources.connectors import DataConnector
 from memgpt import Admin
+from memgpt.data_types import Document

 from .constants import TIMEOUT

+
+class DummyDataConnector(DataConnector):
+    """Fake data connector for testing which yields document/passage texts from a provided list"""
+
+    def __init__(self, texts: List[str]):
+        self.texts = texts
+
+    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
+        for text in self.texts:
+            yield text, {"metadata": "dummy"}
+
+    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
+        for doc in documents:
+            yield doc.text, doc.metadata
+
+
 def create_config(endpoint="openai"):
     """Create config file matching quickstart option"""
     if endpoint == "openai":
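DummyDataConnector echoes each document back as a single passage, which is enough for tests; a real connector would typically split documents into chunks before embedding. A hypothetical sketch against the same DataConnector interface — the class name and the naive character-window splitting are assumptions for illustration, not code from this commit:

    from typing import Dict, Iterator, List, Tuple

    from memgpt.data_sources.connectors import DataConnector
    from memgpt.data_types import Document


    class ListChunkConnector(DataConnector):
        """Hypothetical connector that chunks each input text into fixed-size windows."""

        def __init__(self, texts: List[str]):
            self.texts = texts

        def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
            # one document per input string
            for text in self.texts:
                yield text, {"source": "list"}

        def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
            # naive character-window chunking; a production connector might split on tokens or sentences
            for doc in documents:
                for start in range(0, len(doc.text), chunk_size):
                    yield doc.text[start : start + chunk_size], doc.metadata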