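# Integration tests for MemGPT's storage connectors: recall storage backed by
# Postgres, and archival storage backed by Postgres/pgvector, Chroma, and
# LanceDB. They require live backends and are gated on the
# PGVECTOR_TEST_DB_URL, LANCEDB_TEST_URL, and OPENAI_API_KEY environment
# variables. The pip installs below pull in the DB client libraries at import
# time so the module can be collected in a clean environment.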
import os
import subprocess
import sys

import pytest

subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
)
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])

import pgvector  # try to import again after installing

from memgpt.connectors.storage import StorageConnector
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig

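
# Recall storage round trip: write two Messages through the Postgres-backed
# recall connector, then check size, filtered size, get, and filtered get_all.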
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_recall_db() -> None:
    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"

    storage_type = "postgres"
    storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
    config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
    print(config.config_path)
    assert config.recall_storage_uri is not None
    config.save()
    print(config)

    conn = StorageConnector.get_recall_storage_connector()

    # construct recall memory messages
    message1 = Message(
        agent_id="test_agent1",
        role="agent",
        content="This is a test message",
        id="test_id1",
    )
    message2 = Message(
        agent_id="test_agent2",
        role="user",
        content="This is a test message",
        id="test_id2",
    )

    # test insert
    conn.insert(message1)
    conn.insert_many([message2])

    # test size
    assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}"
    assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 1 message, got {conn.size(filters={'agent_id': 'test_agent2'})}"

    # test get
    assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}"
    assert (
        len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1
    ), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}"

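
# Archival storage against Postgres/pgvector with OpenAI embeddings: insert a
# few passages, run a top-k vector similarity query, and check that the
# expected passage comes back first.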
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
    if not os.getenv("PGVECTOR_TEST_DB_URL"):
        return  # soft pass
    if not os.getenv("OPENAI_API_KEY"):
        return  # soft pass

    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
    config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
    print(config.config_path)
    assert config.archival_storage_uri is not None
    config.archival_storage_uri = config.archival_storage_uri.replace(
        "postgres://", "postgresql://"
    )  # SQLAlchemy expects the "postgresql://" scheme: https://stackoverflow.com/a/64698899
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = PostgresStorageConnector(name="test-openai")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO fix (causes a hang for some reason)
    # print("deleting...")
    # db.delete()
    # print("...finished")

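
# Same archival round trip against a local Chroma collection, again embedding
# the passages and the query with the OpenAI endpoint.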
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
def test_chroma_openai():
    if not os.getenv("OPENAI_API_KEY"):
        return  # soft pass

    config = MemGPTConfig(
        archival_storage_type="chroma",
        archival_storage_path="./test_chroma",
        embedding_endpoint_type="openai",
        embedding_dim=1536,
        model="gpt4",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
    )
    config.save()
    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = ChromaStorageConnector(name="test-openai")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(query, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    print(res[0].text)

    print("deleting")
    db.delete()

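
# Archival round trip against LanceDB (URI from LANCEDB_TEST_URL) with OpenAI
# embeddings.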
@pytest.mark.skipif(
    not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LanceDB URI and/or OpenAI API key"
)
def test_lancedb_openai():
    assert os.getenv("LANCEDB_TEST_URL") is not None
    if os.getenv("OPENAI_API_KEY") is None:
        return  # soft pass

    config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
    print(config.config_path)
    assert config.archival_storage_uri is not None
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-openai")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

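
# Same Postgres/pgvector round trip, but with a local (Hugging Face) embedding
# endpoint at 384 dimensions instead of OpenAI.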
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
    if not os.getenv("PGVECTOR_TEST_DB_URL"):
        return  # soft pass
    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"

    config = MemGPTConfig(
        archival_storage_type="postgres",
        archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
        embedding_endpoint_type="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None
    config.archival_storage_uri = config.archival_storage_uri.replace(
        "postgres://", "postgresql://"
    )  # SQLAlchemy expects the "postgresql://" scheme: https://stackoverflow.com/a/64698899
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = PostgresStorageConnector(name="test-local")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO fix (causes a hang for some reason)
    # print("deleting...")
    # db.delete()
    # print("...finished")

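
# LanceDB round trip with the local 384-dimension embedding endpoint.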
@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
def test_lancedb_local():
    assert os.getenv("LANCEDB_TEST_URL") is not None

    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        embedding_endpoint_type="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-local")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"