# mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
# import tempfile
# import asyncio
import os

import pytest

from memgpt.connectors.storage import StorageConnector, TableType

# import memgpt
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
from memgpt.cli.cli import attach
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
from memgpt.config import AgentConfig, MemGPTConfig
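
# End-to-end test of the data loading pipeline: load a local file into a named
# data source, check that the resulting passages land in passage storage,
# attach the source to a fresh agent's archival memory, then delete the source.
# Running against Postgres requires PGVECTOR_TEST_DB_URL to point at a database
# with the pgvector extension available; the test is skipped otherwise.
# Illustrative value only (host and credentials are placeholders):
#   export PGVECTOR_TEST_DB_URL="postgresql://user:pass@localhost:5432/memgpt_test"
# Run just this test with, e.g.: pytest -k test_load_directory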


# @pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
# @pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
@pytest.mark.parametrize("metadata_storage_connector", ["postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["postgres"])
def test_load_directory(metadata_storage_connector, passage_storage_connector):
    # setup config
    config = MemGPTConfig()
    if metadata_storage_connector == "postgres":
        if not os.getenv("PGVECTOR_TEST_DB_URL"):
            # skip (rather than silently pass) when the Postgres URI is not configured
            pytest.skip("Skipping test, missing PGVECTOR_TEST_DB_URL")
        config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
        config.metadata_storage_type = "postgres"
    elif metadata_storage_connector == "sqlite":
        print("testing sqlite metadata")
        # nothing to do (should be config defaults)
    else:
        raise NotImplementedError(f"Storage type {metadata_storage_connector} not implemented")
    if passage_storage_connector == "postgres":
        if not os.getenv("PGVECTOR_TEST_DB_URL"):
            pytest.skip("Skipping test, missing PGVECTOR_TEST_DB_URL")
        config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
        config.archival_storage_type = "postgres"
    elif passage_storage_connector == "chroma":
        print("testing chroma passage storage")
        # nothing to do (should be config defaults)
    else:
        raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
    config.save()
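
    # Two tables are involved below: DATA_SOURCES holds one metadata row per
    # named source, while PASSAGES holds the chunked text (with embeddings,
    # given the pgvector backend) produced by loading.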

    # setup storage connectors
    print("Creating storage connectors...")
    data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
    passages_conn = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES, storage_type=passage_storage_connector)

    # load data
    name = "test_dataset"
    input_file = "CONTRIBUTING.md"  # a small file from the repo root used as the test input
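
    # "test_dataset" is the handle used everywhere below: to query the source
    # row, to filter passages, to attach to the agent, and to delete.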

    # TODO: load two different data sources
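
    # delete_table drops the backing table entirely, so fresh connector handles
    # are re-created afterwards before asserting that both stores start empty.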

    # clear out data
    print("Resetting tables with delete_table...")
    data_source_conn.delete_table()
    passages_conn.delete_table()
    print("Re-creating tables...")
    data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
    passages_conn = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES, storage_type=passage_storage_connector)
    assert (
        data_source_conn.size() == 0
    ), f"Expected 0 records, got {data_source_conn.size()}: {[vars(r) for r in data_source_conn.get_all()]}"
    assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

    # test: load directory
    print("Loading directory")
    load_directory(name=name, input_dir=None, input_files=[input_file], recursive=False)
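
    # At this point load_directory should have registered a single source row
    # named "test_dataset" and written its chunked passages to passage storage;
    # the queries below verify both.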

    # test to see if the source is contained in metadata storage
    print("Querying table...")
    sources = data_source_conn.get_all({"name": name})
    assert len(sources) == 1, f"Expected 1 source, but got {len(sources)}"
    assert sources[0].name == name, f"Expected name {name}, but got {sources[0].name}"
    print("Source", sources)

    # test to see if passages are contained in passage storage
    assert (
        len(passages_conn.get_all()) == passages_conn.size()
    ), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}"
    passages = passages_conn.get_all({"data_source": name})
    print("Passage sources", [p.data_source for p in passages])
    print("All passage sources", [p.data_source for p in passages_conn.get_all()])
    assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
    assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}"
    # all() is required here: a bare non-empty list comprehension is always truthy
    assert all(p.data_source == name for p in passages)
    print("Passages", passages)

    # test: listing sources
    print("Querying all...")
    sources = data_source_conn.get_all()
    print("All sources", [s.name for s in sources])

    # test loading into an agent
    # create agent
    agent_config = AgentConfig(
        name="memgpt_test_agent",
        persona=DEFAULT_PERSONA,
        human=DEFAULT_HUMAN,
        model=DEFAULT_MEMGPT_MODEL,
    )
    agent_config.save()
    # create storage connector for the agent's archival memory
    print("Creating agent archival storage connector...")
    conn = StorageConnector.get_storage_connector(
        storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
    )
    # drop any leftover archival table from previous runs, then re-create a fresh handle
    print("Deleting agent archival table...")
    conn.delete_table()
    conn = StorageConnector.get_storage_connector(
        storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
    )
    assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
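
    # attach is expected to copy the source's passages into the agent's
    # archival memory table; the size checks below verify a one-to-one copy.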

    # attach data
    print("Attaching data...")
    attach(agent=agent_config.name, data_source=name)

    # test to see if the attached passages are contained in the agent's archival storage
    assert len(passages) == conn.size()
    assert len(passages) == len(conn.get_all({"data_source": name}))
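
    # note: the agent's archival copy created by attach is not removed below;
    # only the source row and its original passages are deleted.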

    # test: delete source
    data_source_conn.delete({"name": name})
    passages_conn.delete({"data_source": name})
    assert len(data_source_conn.get_all({"name": name})) == 0
    assert len(passages_conn.get_all({"data_source": name})) == 0