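# End-to-end test of the MemGPT SyncServer against a Postgres (pgvector) backend:
# configures the server, creates a user and an agent, populates archival and recall
# memory, and exercises the cursor / paging APIs.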
import uuid
import os
import memgpt.utils as utils
from dotenv import load_dotenv

utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from .utils import wipe_config, wipe_memgpt_home


def test_server():
    load_dotenv()
    wipe_memgpt_home()
    # read the Postgres (pgvector) test database URI from the environment
    db_url = os.getenv("PGVECTOR_TEST_DB_URL")
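    # configure for OpenAI when an API key is set; otherwise fall back to the hosted
    # MemGPT endpoints (vLLM chat + hugging-face embeddings) so the test still runs
    # without OpenAI credentials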
    if os.getenv("OPENAI_API_KEY"):
        config = MemGPTConfig(
            archival_storage_uri=db_url,
            recall_storage_uri=db_url,
            metadata_storage_uri=db_url,
            archival_storage_type="postgres",
            recall_storage_type="postgres",
            metadata_storage_type="postgres",
            # embeddings
            default_embedding_config=EmbeddingConfig(
                embedding_endpoint_type="openai",
                embedding_endpoint="https://api.openai.com/v1",
                embedding_dim=1536,
            ),
            # llms
            default_llm_config=LLMConfig(
                model_endpoint_type="openai",
                model_endpoint="https://api.openai.com/v1",
                model="gpt-4",
            ),
        )
        credentials = MemGPTCredentials(
            openai_key=os.getenv("OPENAI_API_KEY"),
        )
    else:  # hosted
        config = MemGPTConfig(
            archival_storage_uri=db_url,
            recall_storage_uri=db_url,
            metadata_storage_uri=db_url,
            archival_storage_type="postgres",
            recall_storage_type="postgres",
            metadata_storage_type="postgres",
            # embeddings
            default_embedding_config=EmbeddingConfig(
                embedding_endpoint_type="hugging-face",
                embedding_endpoint="https://embeddings.memgpt.ai",
                embedding_model="BAAI/bge-large-en-v1.5",
                embedding_dim=1024,
            ),
            # llms
            default_llm_config=LLMConfig(
                model_endpoint_type="vllm",
                model_endpoint="https://api.memgpt.ai",
                model="ehartford/dolphin-2.5-mixtral-8x7b",
            ),
        )
        credentials = MemGPTCredentials()

    config.save()
    credentials.save()
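    # the SyncServer below is expected to pick up the config and credentials saved here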
    server = SyncServer()

    # create user
    user = server.create_user()
    print(f"Created user\n{user.id}")
    try:
        fake_agent_id = uuid.uuid4()
        server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
        raise Exception("user_message call should have failed")
    except (KeyError, ValueError) as e:
        # Error is expected
        print(e)
    except:
        raise
    # create presets
    add_default_presets(user.id, server.ms)

    # create agent
    agent_state = server.create_agent(
        user_id=user.id,
        agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
    )
    print(f"Created agent\n{agent_state}")
    try:
        server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
        raise Exception("user_message call should have failed")
    except ValueError as e:
        # Error is expected
        print(e)
    except:
        raise

    print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
    # add data into archival memory
    agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
    archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
    embed_model = embedding_model(agent.agent_state.embedding_config)
    for text in archival_memories:
        embedding = embed_model.get_text_embedding(text)
        agent.persistence_manager.archival_memory.storage.insert(
            Passage(
                user_id=user.id,
                agent_id=agent_state.id,
                text=text,
                embedding=embedding,
                embedding_dim=agent.agent_state.embedding_config.embedding_dim,
                embedding_model=agent.agent_state.embedding_config.embedding_model,
            )
        )

    # add data into recall memory
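    # (a few identical user messages, just so the recall pagination tests below have data)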
    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
    server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")

    # test recall memory cursor pagination
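    # page 1: the 2 newest messages; page 2: everything after cursor1 (i.e. older);
    # page 3: all messages, newest first. pages 1 and 2 together should equal page 3,
    # and before=cursor1 should return only the single newest message.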
    cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2)
    cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
    cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000)
    ids3 = [m["id"] for m in messages_3]
    ids2 = [m["id"] for m in messages_2]
    timestamps = [m["created_at"] for m in messages_3]
    print("timestamps", timestamps)
    assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
    assert len(messages_3) == len(messages_1) + len(messages_2)
    cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
    assert len(messages_4) == 1
    # test in-context message ids
    in_context_ids = server.get_in_context_message_ids(user_id=user.id, agent_id=agent_state.id)
    assert len(in_context_ids) == len(messages_3)
    assert isinstance(in_context_ids[0], uuid.UUID)
    message_ids = [m["id"] for m in messages_3]
    for message_id in message_ids:
        assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"

    # test archival memory cursor pagination
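    # 5 passages ordered by text: page 1 holds the first 2 (starting with "alpha"),
    # page 2 (after=cursor1) the remaining 3, and before=cursor2 should yield the 4
    # passages that precede the last one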
    cursor1, passages_1 = server.get_agent_archival_cursor(
        user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
    )
    cursor2, passages_2 = server.get_agent_archival_cursor(
        user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
    )
    cursor3, passages_3 = server.get_agent_archival_cursor(
        user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
    )
    print("p1", [p["text"] for p in passages_1])
    print("p2", [p["text"] for p in passages_2])
    print("p3", [p["text"] for p in passages_3])
    assert passages_1[0]["text"] == "alpha"
    assert len(passages_2) == 3
    assert len(passages_3) == 4
    # test recall memory
    messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
    assert len(messages_1) == 1
    messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
    messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
    # not sure exactly how many messages there should be
    assert len(messages_2) > len(messages_3)
    # test safe empty return
    messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
    assert len(messages_none) == 0
    # test archival memory
    passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
    assert len(passage_1) == 1
    passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
    assert len(passage_2) == 4
    # test safe empty return
    passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
    assert len(passage_none) == 0


if __name__ == "__main__":
    test_server()