mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add data loading and attaching to server (#1051)
This commit is contained in:
parent
19725cac35
commit
ce1ce9d06f
@ -111,7 +111,6 @@ def load_directory(
|
|||||||
embedding_config=config.default_embedding_config,
|
embedding_config=config.default_embedding_config,
|
||||||
document_store=None,
|
document_store=None,
|
||||||
passage_store=passage_storage,
|
passage_store=passage_storage,
|
||||||
chunk_size=1000,
|
|
||||||
)
|
)
|
||||||
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ def load_data(
|
|||||||
embedding_config: EmbeddingConfig,
|
embedding_config: EmbeddingConfig,
|
||||||
passage_store: StorageConnector,
|
passage_store: StorageConnector,
|
||||||
document_store: Optional[StorageConnector] = None,
|
document_store: Optional[StorageConnector] = None,
|
||||||
chunk_size: int = 1000,
|
|
||||||
):
|
):
|
||||||
"""Load data from a connector (generates documents and passages) into a specified source_id, associated with a user_id."""
|
"""Load data from a connector (generates documents and passages) into a specified source_id, associated with a user_id."""
|
||||||
|
|
||||||
@ -49,7 +48,6 @@ def load_data(
|
|||||||
|
|
||||||
# generate passages
|
# generate passages
|
||||||
for passage_text, passage_metadata in connector.generate_passages([document]):
|
for passage_text, passage_metadata in connector.generate_passages([document]):
|
||||||
print("passage", passage_text, passage_metadata)
|
|
||||||
embedding = embed_model.get_text_embedding(passage_text)
|
embedding = embed_model.get_text_embedding(passage_text)
|
||||||
passage = Passage(
|
passage = Passage(
|
||||||
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
|
||||||
@ -64,7 +62,7 @@ def load_data(
|
|||||||
)
|
)
|
||||||
|
|
||||||
passages.append(passage)
|
passages.append(passage)
|
||||||
if len(passages) >= chunk_size:
|
if len(passages) >= embedding_config.embedding_chunk_size:
|
||||||
# insert passages into passage store
|
# insert passages into passage store
|
||||||
passage_store.insert_many(passages)
|
passage_store.insert_many(passages)
|
||||||
|
|
||||||
|
@ -17,14 +17,15 @@ import memgpt.constants as constants
|
|||||||
|
|
||||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||||
from memgpt.cli.cli_config import get_model_options
|
from memgpt.cli.cli_config import get_model_options
|
||||||
|
from memgpt.data_sources.connectors import DataConnector, load_data
|
||||||
# from memgpt.agent_store.storage import StorageConnector
|
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||||
from memgpt.metadata import MetadataStore
|
from memgpt.metadata import MetadataStore
|
||||||
import memgpt.presets.presets as presets
|
import memgpt.presets.presets as presets
|
||||||
import memgpt.utils as utils
|
import memgpt.utils as utils
|
||||||
import memgpt.server.utils as server_utils
|
import memgpt.server.utils as server_utils
|
||||||
from memgpt.data_types import (
|
from memgpt.data_types import (
|
||||||
User,
|
User,
|
||||||
|
Source,
|
||||||
Passage,
|
Passage,
|
||||||
AgentState,
|
AgentState,
|
||||||
LLMConfig,
|
LLMConfig,
|
||||||
@ -969,6 +970,10 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
return memgpt_agent.agent_state
|
return memgpt_agent.agent_state
|
||||||
|
|
||||||
|
def delete_user(self, user_id: uuid.UUID):
    """Delete a user (placeholder: currently does nothing and returns None)."""
    # TODO: delete user
    pass
|
||||||
|
|
||||||
def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
|
def delete_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
"""Delete an agent in the database"""
|
"""Delete an agent in the database"""
|
||||||
if self.ms.get_user(user_id=user_id) is None:
|
if self.ms.get_user(user_id=user_id) is None:
|
||||||
@ -1015,15 +1020,45 @@ class SyncServer(LockingServer):
|
|||||||
token = self.ms.create_api_key(user_id=user_id)
|
token = self.ms.create_api_key(user_id=user_id)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def create_source(self, name: str): # TODO: add other fields
|
def create_source(self, name: str, user_id: uuid.UUID) -> Source: # TODO: add other fields
|
||||||
# create a data source
|
"""Create a new data source"""
|
||||||
pass
|
source = Source(name=name, user_id=user_id)
|
||||||
|
self.ms.create_source(source)
|
||||||
|
return source
|
||||||
|
|
||||||
def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
|
def load_data(
|
||||||
# load a list of passages into a data source
|
self,
|
||||||
pass
|
user_id: uuid.UUID,
|
||||||
|
connector: DataConnector,
|
||||||
|
source_name: str,
|
||||||
|
):
|
||||||
|
"""Load data from a DataConnector into a source for a specified user_id"""
|
||||||
|
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
||||||
|
|
||||||
def attach_source_to_agent(self, agent_id: uuid.UUID, source_id: uuid.UUID):
|
# load data from a data source into the document store
|
||||||
|
source = self.ms.get_source(source_name=source_name, user_id=user_id)
|
||||||
|
if source is None:
|
||||||
|
raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
|
||||||
|
|
||||||
|
# get the data connectors
|
||||||
|
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||||
|
# TODO: add document store support
|
||||||
|
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
||||||
|
|
||||||
|
# load data into the document store
|
||||||
|
load_data(connector, source, self.config.default_embedding_config, passage_store, document_store)
|
||||||
|
|
||||||
|
def attach_source_to_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_name: str):
|
||||||
# attach a data source to an agent
|
# attach a data source to an agent
|
||||||
# TODO: insert passages into agent archival memory
|
data_source = self.ms.get_source(source_name=source_name, user_id=user_id)
|
||||||
pass
|
if data_source is None:
|
||||||
|
raise ValueError(f"Data source {source_name} does not exist")
|
||||||
|
|
||||||
|
# get connection to data source storage
|
||||||
|
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||||
|
|
||||||
|
# load agent
|
||||||
|
agent = self._get_or_load_agent(user_id, agent_id)
|
||||||
|
|
||||||
|
# attach source to agent
|
||||||
|
agent.attach_source(data_source.name, source_connector, self.ms)
|
||||||
|
808
poetry.lock
generated
808
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -55,11 +55,12 @@ pyright = {version = "^1.1.347", optional = true}
|
|||||||
llama-index = "^0.10.6"
|
llama-index = "^0.10.6"
|
||||||
llama-index-embeddings-openai = "^0.1.1"
|
llama-index-embeddings-openai = "^0.1.1"
|
||||||
python-box = "^7.1.1"
|
python-box = "^7.1.1"
|
||||||
|
pytest-order = {version = "^1.2.0", optional = true}
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
local = ["torch", "huggingface-hub", "transformers"]
|
local = ["torch", "huggingface-hub", "transformers"]
|
||||||
postgres = ["pgvector", "pg8000"]
|
postgres = ["pgvector", "pg8000"]
|
||||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright"]
|
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order"]
|
||||||
server = ["websockets", "fastapi", "uvicorn"]
|
server = ["websockets", "fastapi", "uvicorn"]
|
||||||
autogen = ["pyautogen"]
|
autogen = ["pyautogen"]
|
||||||
|
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import uuid
|
import uuid
|
||||||
|
import pytest
|
||||||
import os
|
import os
|
||||||
import memgpt.utils as utils
|
import memgpt.utils as utils
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
@ -10,10 +11,11 @@ from memgpt.server.server import SyncServer
|
|||||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||||
from memgpt.embeddings import embedding_model
|
from memgpt.embeddings import embedding_model
|
||||||
from memgpt.presets.presets import add_default_presets
|
from memgpt.presets.presets import add_default_presets
|
||||||
from .utils import wipe_config, wipe_memgpt_home
|
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
|
||||||
|
|
||||||
|
|
||||||
def test_server():
|
@pytest.fixture(scope="module")
|
||||||
|
def server():
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
wipe_config()
|
wipe_config()
|
||||||
wipe_memgpt_home()
|
wipe_memgpt_home()
|
||||||
@ -73,14 +75,44 @@ def test_server():
|
|||||||
credentials.save()
|
credentials.save()
|
||||||
|
|
||||||
server = SyncServer()
|
server = SyncServer()
|
||||||
|
return server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def user_id(server):
    """Module-scoped fixture that creates a user on the server and yields its id.

    Default presets are initialized for the new user before the id is yielded.
    After the yield, server.delete_user is called with the same id as cleanup.
    """
    # create user
    user = server.create_user()
    print(f"Created user\n{user.id}")

    # initialize with default presets
    server.initialize_default_presets(user.id)
    yield user.id

    # cleanup
    server.delete_user(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def agent_id(server, user_id):
    """Module-scoped fixture that creates the agent "test_agent" and yields its id.

    The agent is created for the user from the user_id fixture. After the yield,
    server.delete_agent is called for the same user and agent as cleanup.
    """
    # create agent
    agent_state = server.create_agent(
        user_id=user_id,
        name="test_agent",
        preset="memgpt_chat",
        human="cs_phd",
        persona="sam_pov",
    )
    print(f"Created agent\n{agent_state}")
    yield agent_state.id

    # cleanup
    server.delete_agent(user_id, agent_state.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||||
try:
|
try:
|
||||||
fake_agent_id = uuid.uuid4()
|
fake_agent_id = uuid.uuid4()
|
||||||
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||||
raise Exception("user_message call should have failed")
|
raise Exception("user_message call should have failed")
|
||||||
except (KeyError, ValueError) as e:
|
except (KeyError, ValueError) as e:
|
||||||
# Error is expected
|
# Error is expected
|
||||||
@ -88,21 +120,11 @@ def test_server():
|
|||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# create presets
|
|
||||||
add_default_presets(user.id, server.ms)
|
|
||||||
|
|
||||||
# create agent
|
|
||||||
agent_state = server.create_agent(
|
|
||||||
user_id=user.id,
|
|
||||||
name="test_agent",
|
|
||||||
preset="memgpt_chat",
|
|
||||||
human="cs_phd",
|
|
||||||
persona="sam_pov",
|
|
||||||
)
|
|
||||||
print(f"Created agent\n{agent_state}")
|
|
||||||
|
|
||||||
|
@pytest.mark.order(1)
|
||||||
|
def test_user_message_memory(server, user_id, agent_id):
|
||||||
try:
|
try:
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||||
raise Exception("user_message call should have failed")
|
raise Exception("user_message call should have failed")
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
# Error is expected
|
# Error is expected
|
||||||
@ -110,62 +132,93 @@ def test_server():
|
|||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
|
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
|
||||||
|
|
||||||
# add data into archival memory
|
|
||||||
agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
|
@pytest.mark.order(3)
|
||||||
|
def test_load_data(server, user_id, agent_id):
|
||||||
|
# create source
|
||||||
|
source = server.create_source("test_source", user_id)
|
||||||
|
|
||||||
|
# load data
|
||||||
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
||||||
embed_model = embedding_model(agent.agent_state.embedding_config)
|
connector = DummyDataConnector(archival_memories)
|
||||||
for text in archival_memories:
|
server.load_data(user_id, connector, source.name)
|
||||||
embedding = embed_model.get_text_embedding(text)
|
|
||||||
agent.persistence_manager.archival_memory.storage.insert(
|
|
||||||
Passage(
|
|
||||||
user_id=user.id,
|
|
||||||
agent_id=agent_state.id,
|
|
||||||
text=text,
|
|
||||||
embedding=embedding,
|
|
||||||
embedding_dim=agent.agent_state.embedding_config.embedding_dim,
|
|
||||||
embedding_model=agent.agent_state.embedding_config.embedding_model,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.order(3)
def test_attach_source_to_agent(server, user_id, agent_id):
    """Attaching "test_source" to the agent takes its archival memory from 0 to 5 passages."""

    def archival_count():
        # Number of archival passages currently returned for the agent (window of 10000).
        return len(server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=10000))

    # archival memory is empty before the source is attached
    assert archival_count() == 0

    # attach source
    server.attach_source_to_agent(user_id, agent_id, "test_source")

    # archival memory now holds the five passages
    assert archival_count() == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_archival_memory(server, user_id, agent_id):
    """Placeholder test for saving into archival memory; it currently asserts nothing."""
    # TODO: insert into archival memory
    pass
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.order(4)
|
||||||
|
def test_user_message(server, user_id, agent_id):
|
||||||
# add data into recall memory
|
# add data into recall memory
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.order(5)
|
||||||
|
def test_get_recall_memory(server, user_id, agent_id):
|
||||||
# test recall memory cursor pagination
|
# test recall memory cursor pagination
|
||||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2)
|
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=2)
|
||||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
|
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, after=cursor1, limit=1000)
|
||||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000)
|
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, limit=1000)
|
||||||
ids3 = [m["id"] for m in messages_3]
|
ids3 = [m["id"] for m in messages_3]
|
||||||
ids2 = [m["id"] for m in messages_2]
|
ids2 = [m["id"] for m in messages_2]
|
||||||
timestamps = [m["created_at"] for m in messages_3]
|
timestamps = [m["created_at"] for m in messages_3]
|
||||||
print("timestamps", timestamps)
|
print("timestamps", timestamps)
|
||||||
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
|
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
|
||||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
|
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||||
assert len(messages_4) == 1
|
assert len(messages_4) == 1
|
||||||
|
|
||||||
# test in-context message ids
|
# test in-context message ids
|
||||||
in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
|
in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
|
||||||
assert len(in_context_ids) == len(messages_3)
|
assert len(in_context_ids) == len(messages_3)
|
||||||
assert isinstance(in_context_ids[0], uuid.UUID)
|
assert isinstance(in_context_ids[0], uuid.UUID)
|
||||||
message_ids = [m["id"] for m in messages_3]
|
message_ids = [m["id"] for m in messages_3]
|
||||||
for message_id in message_ids:
|
for message_id in message_ids:
|
||||||
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
|
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
|
||||||
|
|
||||||
|
# test recall memory
|
||||||
|
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||||
|
assert len(messages_1) == 1
|
||||||
|
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||||
|
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=5)
|
||||||
|
# not sure exactly how many messages there should be
|
||||||
|
assert len(messages_2) > len(messages_3)
|
||||||
|
# test safe empty return
|
||||||
|
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||||
|
assert len(messages_none) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.order(6)
|
||||||
|
def test_get_archival_memory(server, user_id, agent_id):
|
||||||
# test archival memory cursor pagination
|
# test archival memory cursor pagination
|
||||||
cursor1, passages_1 = server.get_agent_archival_cursor(
|
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||||
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
|
||||||
)
|
|
||||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||||
user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
|
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
|
||||||
)
|
)
|
||||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||||
user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||||
)
|
)
|
||||||
print("p1", [p["text"] for p in passages_1])
|
print("p1", [p["text"] for p in passages_1])
|
||||||
print("p2", [p["text"] for p in passages_2])
|
print("p2", [p["text"] for p in passages_2])
|
||||||
@ -174,26 +227,11 @@ def test_server():
|
|||||||
assert len(passages_2) == 3
|
assert len(passages_2) == 3
|
||||||
assert len(passages_3) == 4
|
assert len(passages_3) == 4
|
||||||
|
|
||||||
# test recall memory
|
|
||||||
messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
|
||||||
assert len(messages_1) == 1
|
|
||||||
messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
|
||||||
messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
|
|
||||||
# not sure exactly how many messages there should be
|
|
||||||
assert len(messages_2) > len(messages_3)
|
|
||||||
# test safe empty return
|
|
||||||
messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
|
||||||
assert len(messages_none) == 0
|
|
||||||
|
|
||||||
# test archival memory
|
# test archival memory
|
||||||
passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||||
assert len(passage_1) == 1
|
assert len(passage_1) == 1
|
||||||
passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||||
assert len(passage_2) == 4
|
assert len(passage_2) == 4
|
||||||
# test safe empty return
|
# test safe empty return
|
||||||
passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||||
assert len(passage_none) == 0
|
assert len(passage_none) == 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
test_server()
|
|
||||||
|
@ -1,13 +1,32 @@
|
|||||||
import datetime
|
import datetime
|
||||||
|
from typing import Dict, List, Tuple, Iterator
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
from memgpt.cli.cli import quickstart, QuickstartChoice
|
from memgpt.cli.cli import quickstart, QuickstartChoice
|
||||||
|
from memgpt.data_sources.connectors import DataConnector
|
||||||
from memgpt import Admin
|
from memgpt import Admin
|
||||||
|
from memgpt.data_types import Document
|
||||||
|
|
||||||
from .constants import TIMEOUT
|
from .constants import TIMEOUT
|
||||||
|
|
||||||
|
|
||||||
|
class DummyDataConnector(DataConnector):
    """Fake data connector for testing which yields document/passage texts from a provided list"""

    def __init__(self, texts: List[str]):
        # Texts to emit; generate_documents yields one document per entry.
        self.texts = texts

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
        """Yield a (text, metadata) pair for each provided text, with fixed dummy metadata."""
        for text in self.texts:
            yield text, {"metadata": "dummy"}

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        """Yield each document's text and metadata unchanged, one passage per document.

        chunk_size is not used here: documents are never split into smaller passages.
        """
        for doc in documents:
            yield doc.text, doc.metadata
|
||||||
|
|
||||||
|
|
||||||
def create_config(endpoint="openai"):
|
def create_config(endpoint="openai"):
|
||||||
"""Create config file matching quickstart option"""
|
"""Create config file matching quickstart option"""
|
||||||
if endpoint == "openai":
|
if endpoint == "openai":
|
||||||
|
Loading…
Reference in New Issue
Block a user