mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Support recall and archival memory for postgres
working test
This commit is contained in:
parent
9f3806dfcb
commit
9b3d59e016
@ -7,7 +7,7 @@ import traceback
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from memgpt.memory import CoreMemory as Memory, summarize_messages
|
||||
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages
|
||||
from memgpt.openai_tools import create, is_context_overflow_error
|
||||
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response
|
||||
from memgpt.constants import (
|
||||
@ -29,7 +29,7 @@ def initialize_memory(ai_notes, human_notes):
|
||||
raise ValueError(ai_notes)
|
||||
if human_notes is None:
|
||||
raise ValueError(human_notes)
|
||||
memory = Memory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
|
||||
memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
|
||||
memory.edit_persona(ai_notes)
|
||||
memory.edit_human(human_notes)
|
||||
return memory
|
||||
@ -240,6 +240,7 @@ class Agent(object):
|
||||
|
||||
### Local state management
|
||||
def to_dict(self):
|
||||
# TODO: select specific variables for the saved state (to eventually move to a DB) rather than checkpointing everything in the class
|
||||
return {
|
||||
"model": self.model,
|
||||
"system": self.system,
|
||||
@ -249,32 +250,29 @@ class Agent(object):
|
||||
"memory": self.memory.to_dict(),
|
||||
}
|
||||
|
||||
def save_to_json_file(self, filename):
|
||||
def save_agent_state_json(self, filename):
|
||||
"""Save agent state to JSON"""
|
||||
with open(filename, "w") as file:
|
||||
json.dump(self.to_dict(), file)
|
||||
|
||||
def save(self):
|
||||
"""Save agent state locally"""
|
||||
|
||||
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
|
||||
agent_name = self.config.name # TODO: fix
|
||||
|
||||
# save config
|
||||
self.config.save()
|
||||
|
||||
# save agent state
|
||||
# save agent state to timestamped file
|
||||
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
|
||||
filename = f"{timestamp}.json"
|
||||
os.makedirs(self.config.save_state_dir(), exist_ok=True)
|
||||
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))
|
||||
self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename))
|
||||
|
||||
# save the persistence manager too
|
||||
filename = f"{timestamp}.persistence.pickle"
|
||||
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
|
||||
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))
|
||||
# save the persistence manager too (recall/archival memory)
|
||||
self.persistence_manager.save()
|
||||
|
||||
@classmethod
|
||||
def load_agent(cls, interface, agent_config: AgentConfig):
|
||||
"""Load saved agent state"""
|
||||
"""Load saved agent state based on agent_config"""
|
||||
# TODO: support loading from specific file
|
||||
agent_name = agent_config.name
|
||||
|
||||
@ -290,10 +288,7 @@ class Agent(object):
|
||||
state = json.load(open(filename, "r"))
|
||||
|
||||
# load persistence manager
|
||||
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
|
||||
directory = agent_config.save_persistence_manager_dir()
|
||||
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
|
||||
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)
|
||||
persistence_manager = LocalStateManager.load(agent_config)
|
||||
|
||||
# need to dynamically link the functions
|
||||
# the saved agent.functions will just have the schemas, but we need to
|
||||
@ -354,70 +349,6 @@ class Agent(object):
|
||||
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
|
||||
return agent
|
||||
|
||||
@classmethod
|
||||
def load(cls, state, interface, persistence_manager):
|
||||
model = state["model"]
|
||||
system = state["system"]
|
||||
functions = state["functions"]
|
||||
messages = state["messages"]
|
||||
try:
|
||||
messages_total = state["messages_total"]
|
||||
except KeyError:
|
||||
messages_total = len(messages) - 1
|
||||
# memory requires a nested load
|
||||
memory_dict = state["memory"]
|
||||
persona_notes = memory_dict["persona"]
|
||||
human_notes = memory_dict["human"]
|
||||
|
||||
# Two-part load
|
||||
new_agent = cls(
|
||||
model=model,
|
||||
system=system,
|
||||
functions=functions,
|
||||
interface=interface,
|
||||
persistence_manager=persistence_manager,
|
||||
persistence_manager_init=False,
|
||||
persona_notes=persona_notes,
|
||||
human_notes=human_notes,
|
||||
messages_total=messages_total,
|
||||
)
|
||||
new_agent._messages = messages
|
||||
return new_agent
|
||||
|
||||
def load_inplace(self, state):
|
||||
self.model = state["model"]
|
||||
self.system = state["system"]
|
||||
self.functions = state["functions"]
|
||||
# memory requires a nested load
|
||||
memory_dict = state["memory"]
|
||||
persona_notes = memory_dict["persona"]
|
||||
human_notes = memory_dict["human"]
|
||||
self.memory = initialize_memory(persona_notes, human_notes)
|
||||
# messages also
|
||||
self._messages = state["messages"]
|
||||
try:
|
||||
self.messages_total = state["messages_total"]
|
||||
except KeyError:
|
||||
self.messages_total = len(self.messages) - 1 # -system
|
||||
|
||||
@classmethod
|
||||
def load_from_json(cls, json_state, interface, persistence_manager):
|
||||
state = json.loads(json_state)
|
||||
return cls.load(state, interface, persistence_manager)
|
||||
|
||||
@classmethod
|
||||
def load_from_json_file(cls, json_file, interface, persistence_manager):
|
||||
with open(json_file, "r") as file:
|
||||
state = json.load(file)
|
||||
return cls.load(state, interface, persistence_manager)
|
||||
|
||||
def load_from_json_file_inplace(self, json_file):
|
||||
# Load in-place
|
||||
# No interface arg needed, we can use the current one
|
||||
with open(json_file, "r") as file:
|
||||
state = json.load(file)
|
||||
self.load_inplace(state)
|
||||
|
||||
def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
|
||||
"""Can be used to enforce that the first message always uses send_message"""
|
||||
response_message = response.choices[0].message
|
||||
|
@ -418,6 +418,22 @@ def configure_archival_storage(config: MemGPTConfig):
|
||||
# TODO: allow configuring embedding model
|
||||
|
||||
|
||||
def configure_recall_storage(config: MemGPTConfig):
|
||||
# Configure recall storage backend
|
||||
recall_storage_options = ["local", "postgres"]
|
||||
recall_storage_type = questionary.select(
|
||||
"Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type
|
||||
).ask()
|
||||
recall_storage_uri, recall_storage_path = None, None
|
||||
# configure postgres
|
||||
if recall_storage_type == "postgres":
|
||||
recall_storage_uri = questionary.text(
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.recall_storage_uri if config.recall_storage_uri else "",
|
||||
).ask()
|
||||
return recall_storage_type, recall_storage_uri, recall_storage_path
|
||||
|
||||
|
||||
@app.command()
|
||||
def configure():
|
||||
"""Updates default MemGPT configurations"""
|
||||
@ -430,17 +446,30 @@ def configure():
|
||||
|
||||
# Will pre-populate with defaults, or what the user previously set
|
||||
config = MemGPTConfig.load()
|
||||
try:
|
||||
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||
model, model_wrapper, context_window = configure_model(
|
||||
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
|
||||
)
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
|
||||
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
|
||||
except ValueError as e:
|
||||
typer.secho(str(e), fg=typer.colors.RED)
|
||||
return
|
||||
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
|
||||
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
|
||||
recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage(config)
|
||||
|
||||
# check credentials
|
||||
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
|
||||
openai_key = get_openai_credentials()
|
||||
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
|
||||
if all([azure_key, azure_endpoint, azure_version]):
|
||||
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
||||
if all([azure_deployment, azure_embedding_deployment]):
|
||||
print(f"Using deployment id {azure_deployment}")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
|
||||
)
|
||||
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
|
||||
if not openai_key:
|
||||
raise ValueError(
|
||||
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
|
||||
)
|
||||
|
||||
config = MemGPTConfig(
|
||||
# model configs
|
||||
@ -470,6 +499,10 @@ def configure():
|
||||
archival_storage_type=archival_storage_type,
|
||||
archival_storage_uri=archival_storage_uri,
|
||||
archival_storage_path=archival_storage_path,
|
||||
# recall storage
|
||||
recall_storage_type=recall_storage_type,
|
||||
recall_storage_uri=recall_storage_uri,
|
||||
recall_storage_path=recall_storage_path,
|
||||
)
|
||||
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
|
||||
config.save()
|
||||
|
@ -134,6 +134,9 @@ class MemGPTConfig:
|
||||
"archival_storage_type": get_field(config, "archival_storage", "type"),
|
||||
"archival_storage_path": get_field(config, "archival_storage", "path"),
|
||||
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
|
||||
"recall_storage_type": get_field(config, "recall_storage", "type"),
|
||||
"recall_storage_path": get_field(config, "recall_storage", "path"),
|
||||
"recall_storage_uri": get_field(config, "recall_storage", "uri"),
|
||||
"anon_clientid": get_field(config, "client", "anon_clientid"),
|
||||
"config_path": config_path,
|
||||
"memgpt_version": get_field(config, "version", "memgpt_version"),
|
||||
@ -187,6 +190,11 @@ class MemGPTConfig:
|
||||
set_field(config, "archival_storage", "path", self.archival_storage_path)
|
||||
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
|
||||
|
||||
# recall storage
|
||||
set_field(config, "recall_storage", "type", self.recall_storage_type)
|
||||
set_field(config, "recall_storage", "path", self.recall_storage_path)
|
||||
set_field(config, "recall_storage", "uri", self.recall_storage_uri)
|
||||
|
||||
# set version
|
||||
set_field(config, "version", "memgpt_version", memgpt.__version__)
|
||||
|
||||
|
@ -2,22 +2,10 @@ import chromadb
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, List, Iterator
|
||||
from memgpt.connectors.storage import StorageConnector, Passage
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.utils import printd
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
|
||||
|
||||
def create_chroma_client():
|
||||
config = MemGPTConfig.load()
|
||||
# create chroma client
|
||||
if config.archival_storage_path:
|
||||
client = chromadb.PersistentClient(config.archival_storage_path)
|
||||
else:
|
||||
# assume uri={ip}:{port}
|
||||
ip = config.archival_storage_uri.split(":")[0]
|
||||
port = config.archival_storage_uri.split(":")[1]
|
||||
client = chromadb.HttpClient(host=ip, port=port)
|
||||
return client
|
||||
from memgpt.data_types import Record, Message, Passage
|
||||
|
||||
|
||||
class ChromaStorageConnector(StorageConnector):
|
||||
@ -25,21 +13,24 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
|
||||
|
||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
# determine table name
|
||||
if agent_config:
|
||||
assert name is None, f"Cannot specify both agent config and name {name}"
|
||||
self.table_name = self.generate_table_name_agent(agent_config)
|
||||
elif name:
|
||||
assert agent_config is None, f"Cannot specify both agent config and name {name}"
|
||||
self.table_name = self.generate_table_name(name)
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# supported table types
|
||||
self.supported_types = [TableType.ARCHIVAL_MEMORY]
|
||||
|
||||
if table_type not in self.supported_types:
|
||||
raise ValueError(f"Table type {table_type} not supported by Chroma")
|
||||
|
||||
# create chroma client
|
||||
if config.archival_storage_path:
|
||||
self.client = chromadb.PersistentClient(config.archival_storage_path)
|
||||
else:
|
||||
raise ValueError("Must specify either agent config or name")
|
||||
|
||||
printd(f"Using table name {self.table_name}")
|
||||
|
||||
# create client
|
||||
self.client = create_chroma_client()
|
||||
# assume uri={ip}:{port}
|
||||
ip = config.archival_storage_uri.split(":")[0]
|
||||
port = config.archival_storage_uri.split(":")[1]
|
||||
self.client = chromadb.HttpClient(host=ip, port=port)
|
||||
|
||||
# get a collection or create if it doesn't exist already
|
||||
self.collection = self.client.get_or_create_collection(self.table_name)
|
||||
@ -98,28 +89,5 @@ class ChromaStorageConnector(StorageConnector):
|
||||
collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
|
||||
return collections
|
||||
|
||||
def sanitize_table_name(self, name: str) -> str:
|
||||
# Remove leading and trailing whitespace
|
||||
name = name.strip()
|
||||
|
||||
# Replace spaces and invalid characters with underscores
|
||||
name = re.sub(r"\s+|\W+", "_", name)
|
||||
|
||||
# Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
|
||||
max_length = 63
|
||||
if len(name) > max_length:
|
||||
name = name[:max_length].rstrip("_")
|
||||
|
||||
# Convert to lowercase
|
||||
name = name.lower()
|
||||
|
||||
return name
|
||||
|
||||
def generate_table_name_agent(self, agent_config: AgentConfig):
|
||||
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
|
||||
|
||||
def generate_table_name(self, name: str):
|
||||
return f"memgpt_{self.sanitize_table_name(name)}"
|
||||
|
||||
def size(self) -> int:
|
||||
return self.collection.count()
|
||||
|
@ -1,5 +1,5 @@
|
||||
from pgvector.psycopg import register_vector
|
||||
from pgvector.sqlalchemy import Vector, JSON, Text
|
||||
from pgvector.sqlalchemy import Vector
|
||||
import psycopg
|
||||
|
||||
|
||||
@ -8,6 +8,8 @@ from sqlalchemy.orm import sessionmaker, mapped_column
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy import Column, BIGINT, String, DateTime
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy_json import mutable_json_type
|
||||
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
@ -28,15 +30,10 @@ from datetime import datetime
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
|
||||
# parse times returned by memgpt.utils.get_formatted_time()
|
||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
def get_db_model(table_name: str, table_type: TableType):
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
if table_name == TableType.ARCHIVAL_MEMORY:
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
# create schema for archival memory
|
||||
class PassageModel(Base):
|
||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||
@ -45,16 +42,29 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
|
||||
# Assuming id is the primary key
|
||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||
user_id = Column(String, nullable=False)
|
||||
text = Column(String, nullable=False)
|
||||
doc_id = Column(String)
|
||||
agent_id = Column(String)
|
||||
data_source = Column(String) # agent_name if agent, data_source name if from data source
|
||||
text = Column(String, nullable=False)
|
||||
embedding = mapped_column(Vector(config.embedding_dim))
|
||||
metadata_ = Column(JSON(astext_type=Text()))
|
||||
metadata_ = Column(mutable_json_type(dbtype=JSONB, nested=True))
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Passage(
|
||||
text=self.text,
|
||||
embedding=self.embedding,
|
||||
doc_id=self.doc_id,
|
||||
user_id=self.user_id,
|
||||
id=self.id,
|
||||
data_source=self.data_source,
|
||||
agent_id=self.agent_id,
|
||||
metadata=self.metadata_,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
class_name = f"{table_name.capitalize()}Model"
|
||||
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
||||
@ -71,7 +81,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
user_id = Column(String, nullable=False)
|
||||
agent_id = Column(String, nullable=False)
|
||||
role = Column(String, nullable=False)
|
||||
content = Column(String, nullable=False)
|
||||
text = Column(String, nullable=False)
|
||||
model = Column(String, nullable=False)
|
||||
function_name = Column(String)
|
||||
function_args = Column(String)
|
||||
@ -82,7 +92,22 @@ def get_db_model(table_name: str, table_type: TableType):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Message(message_id='{self.id}', content='{self.content}', embedding='{self.embedding})>"
|
||||
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||
|
||||
def to_record(self):
|
||||
return Message(
|
||||
user_id=self.user_id,
|
||||
agent_id=self.agent_id,
|
||||
role=self.role,
|
||||
text=self.text,
|
||||
model=self.model,
|
||||
function_name=self.function_name,
|
||||
function_args=self.function_args,
|
||||
function_response=self.function_response,
|
||||
embedding=self.embedding,
|
||||
created_at=self.created_at,
|
||||
id=self.id,
|
||||
)
|
||||
|
||||
"""Create database model for table_name"""
|
||||
class_name = f"{table_name.capitalize()}Model"
|
||||
@ -101,11 +126,20 @@ class PostgresStorageConnector(StorageConnector):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# get storage URI
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
self.uri = config.archival_storage_uri
|
||||
if config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
self.uri = config.recall_storage_uri
|
||||
if config.recall_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
# create table
|
||||
self.uri = config.archival_storage_uri
|
||||
if config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
||||
self.db_model = get_db_model(self.table_name)
|
||||
self.db_model = get_db_model(self.table_name, table_type)
|
||||
self.engine = create_engine(self.uri)
|
||||
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
|
||||
self.Session = sessionmaker(bind=self.engine)
|
||||
@ -132,47 +166,47 @@ class PostgresStorageConnector(StorageConnector):
|
||||
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
db_passages = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
return [self.type(**p.to_dict()) for p in db_passages]
|
||||
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||
return [record.to_record() for record in db_records]
|
||||
|
||||
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]:
|
||||
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
db_passage = session.query(self.db_model).filter(*filters).get(id)
|
||||
if db_passage is None:
|
||||
db_record = session.query(self.db_model).filter(*filters).get(id)
|
||||
if db_record is None:
|
||||
return None
|
||||
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
|
||||
return db_record.to_record()
|
||||
|
||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||
# return size of table
|
||||
print("size")
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
return session.query(self.db_model).filter(*filters).count()
|
||||
|
||||
def insert(self, passage: Passage):
|
||||
def insert(self, record: Record):
|
||||
session = self.Session()
|
||||
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
||||
session.add(db_passage)
|
||||
db_record = self.db_model(**vars(record))
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
session = self.Session()
|
||||
iterable = tqdm(records) if show_progress else records
|
||||
for passage in iterable:
|
||||
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
||||
session.add(db_passage)
|
||||
for record in iterable:
|
||||
db_record = self.db_model(**vars(record))
|
||||
session.add(db_record)
|
||||
session.commit()
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
# Assuming PassageModel.embedding has the capability of computing l2_distance
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
records = [self.type(**vars(result)) for result in results]
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
|
||||
def save(self):
|
||||
@ -195,6 +229,26 @@ class PostgresStorageConnector(StorageConnector):
|
||||
tables = [table[start_chars:] for table in tables]
|
||||
return tables
|
||||
|
||||
def query_date(self, start_date, end_date):
|
||||
session = self.Session()
|
||||
filters = self.get_filters({})
|
||||
results = (
|
||||
session.query(self.db_model)
|
||||
.filter(*filters)
|
||||
.filter(self.db_model.created_at >= start_date)
|
||||
.filter(self.db_model.created_at <= end_date)
|
||||
.all()
|
||||
)
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
def query_text(self, query):
|
||||
# TODO: make this a fuzzy / full-text search, see https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
|
||||
session = self.Session()
|
||||
filters = self.get_filters({})
|
||||
results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
|
||||
print(results)
|
||||
# return [self.type(**vars(result)) for result in results]
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
class LanceDBConnector(StorageConnector):
|
||||
"""Storage via LanceDB"""
|
||||
@ -251,7 +305,7 @@ class LanceDBConnector(StorageConnector):
|
||||
|
||||
def get(self, id: str) -> Optional[Passage]:
|
||||
db_passage = self.table.where(f"passage_id={id}").to_list()
|
||||
if len(db_passage) == 0:
|
||||
if len(db_passage) == 0:
|
||||
return None
|
||||
return Passage(
|
||||
text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]
|
||||
|
@ -6,7 +6,8 @@ import pickle
|
||||
import os
|
||||
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context
|
||||
from llama_index.indices.empty.base import EmptyIndex
|
||||
@ -145,11 +146,9 @@ class InMemoryStorageConnector(StorageConnector):
|
||||
|
||||
# TODO: maybe replace this with SQLite?
|
||||
|
||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
||||
from memgpt.embeddings import embedding_model
|
||||
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
config = MemGPTConfig.load()
|
||||
# TODO: figure out save location
|
||||
|
||||
self.rows = []
|
||||
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, List, Iterator
|
||||
import re
|
||||
import pickle
|
||||
import os
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from typing import List, Optional, Dict
|
||||
from tqdm import tqdm
|
||||
@ -14,6 +14,7 @@ from tqdm import tqdm
|
||||
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
from memgpt.data_types import Record, Passage, Document, Message
|
||||
from memgpt.utils import printd
|
||||
|
||||
|
||||
# ENUM representing table types in MemGPT
|
||||
@ -28,9 +29,9 @@ class TableType:
|
||||
|
||||
|
||||
# table names used by MemGPT
|
||||
RECALL_TABLE_NAME = "memgpt_recall_memory"
|
||||
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory"
|
||||
PASSAGE_TABLE_NAME = "memgpt_passages"
|
||||
RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory
|
||||
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory
|
||||
PASSAGE_TABLE_NAME = "memgpt_passages" # loads data sources
|
||||
DOCUMENT_TABLE_NAME = "memgpt_documents"
|
||||
|
||||
|
||||
@ -65,9 +66,10 @@ class StorageConnector:
|
||||
# get all filters for query
|
||||
if filters is not None:
|
||||
filter_conditions = {**self.filters, **filters}
|
||||
return self.filters + [self.db_model[key] == value for key, value in filter_conditions.items()]
|
||||
else:
|
||||
return self.filters
|
||||
filter_conditions = self.filters
|
||||
print("FILTERS", filter_conditions)
|
||||
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||
|
||||
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
|
||||
|
||||
@ -102,18 +104,20 @@ class StorageConnector:
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||
|
||||
return VectorIndexStorageConnector(agent_config=agent_config)
|
||||
return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
|
||||
elif storage_type == "postgres":
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector(agent_config=agent_config)
|
||||
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
elif storage_type == "chroma":
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
|
||||
return ChromaStorageConnector(name=name, agent_config=agent_config)
|
||||
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector(agent_config=agent_config)
|
||||
return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
@ -122,6 +126,8 @@ class StorageConnector:
|
||||
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||
storage_type = MemGPTConfig.load().recall_storage_type
|
||||
|
||||
print("Recall storage type", storage_type)
|
||||
|
||||
if storage_type == "local":
|
||||
from memgpt.connectors.local import InMemoryStorageConnector
|
||||
|
||||
|
@ -20,11 +20,10 @@ class Record:
|
||||
self.text = text
|
||||
self.id = id
|
||||
# todo: generate unique uuid
|
||||
# todo: timestamp
|
||||
# todo: self.role = role (?)
|
||||
|
||||
def __repr__(self):
|
||||
pass
|
||||
# def __repr__(self):
|
||||
# pass
|
||||
|
||||
|
||||
class Message(Record):
|
||||
@ -35,17 +34,19 @@ class Message(Record):
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
text: str,
|
||||
model: str, # model used to make function call
|
||||
created_at: Optional[str] = None,
|
||||
function_name: Optional[str] = None, # name of function called
|
||||
function_args: Optional[str] = None, # args of function called
|
||||
function_response: Optional[str] = None, # response of function called
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
id: Optional[str] = None,
|
||||
):
|
||||
super().__init__(user_id, agent_id, content, id)
|
||||
super().__init__(user_id, agent_id, text, id)
|
||||
self.role = role # role (agent/user/function)
|
||||
self.model = model # model name (e.g. gpt-4)
|
||||
self.created_at = created_at
|
||||
|
||||
# function call info (optional)
|
||||
self.function_name = function_name
|
||||
@ -55,15 +56,15 @@ class Message(Record):
|
||||
# embedding (optional)
|
||||
self.embedding = embedding
|
||||
|
||||
def __repr__(self):
|
||||
pass
|
||||
# def __repr__(self):
|
||||
# pass
|
||||
|
||||
|
||||
class Document(Record):
|
||||
"""A document represents a document loaded into MemGPT, which is broken down into passages."""
|
||||
|
||||
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
|
||||
super().__init__(user_id)
|
||||
super().__init__(user_id, agent_id, text, id)
|
||||
self.text = text
|
||||
self.document_id = document_id
|
||||
self.data_source = data_source
|
||||
@ -76,25 +77,26 @@ class Document(Record):
|
||||
class Passage(Record):
|
||||
"""A passage is a single unit of memory, and a standard format accross all storage backends.
|
||||
|
||||
It is a string of text with an associated embedding.
|
||||
It is a string of text with an associated embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str,
|
||||
text: str,
|
||||
data_source: str,
|
||||
embedding: np.ndarray,
|
||||
agent_id: Optional[str] = None, # set if contained in agent memory
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
data_source: Optional[str] = None, # None if created by agent
|
||||
doc_id: Optional[str] = None,
|
||||
passage_id: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
metadata: Optional[dict] = {},
|
||||
):
|
||||
super().__init__(user_id)
|
||||
super().__init__(user_id, agent_id, text, id)
|
||||
self.text = text
|
||||
self.data_source = data_source
|
||||
self.embedding = embedding
|
||||
self.doc_id = doc_id
|
||||
self.passage_id = passage_id
|
||||
self.metadata = {}
|
||||
self.metadata = metadata
|
||||
|
||||
def __repr__(self):
|
||||
return f"Passage(text={self.text}, embedding={self.embedding})"
|
||||
|
@ -284,7 +284,10 @@ class DummyRecallMemory(RecallMemory):
|
||||
return matches, len(matches)
|
||||
|
||||
|
||||
class RecallMemorySQL(RecallMemory):
|
||||
class BaseRecallMemory(RecallMemory):
|
||||
|
||||
"""Recall memory based on base functions implemented by storage connectors"""
|
||||
|
||||
def __init__(self, agent_config, restrict_search_to_summaries=False):
|
||||
|
||||
# If true, the pool of messages that can be queried is limited to the automated summaries only
|
||||
@ -304,20 +307,39 @@ class RecallMemorySQL(RecallMemory):
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
@abstractmethod
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
pass
|
||||
self.storage.query_text(query_string, count, start)
|
||||
|
||||
@abstractmethod
|
||||
def date_search(self, query_string, count=None, start=None):
|
||||
pass
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
self.storage.query_date(start_date, end_date, count, start)
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
pass
|
||||
total = self.storage.size()
|
||||
system_count = self.storage.size(filters={"role": "system"})
|
||||
user_count = self.storage.size(filters={"role": "user"})
|
||||
assistant_count = self.storage.size(filters={"role": "assistant"})
|
||||
function_count = self.storage.size(filters={"role": "function"})
|
||||
other_count = total - (system_count + user_count + assistant_count + function_count)
|
||||
|
||||
memory_str = (
|
||||
f"Statistics:"
|
||||
+ f"\n{total} total messages"
|
||||
+ f"\n{system_count} system"
|
||||
+ f"\n{user_count} user"
|
||||
+ f"\n{assistant_count} assistant"
|
||||
+ f"\n{function_count} function"
|
||||
+ f"\n{other_count} other"
|
||||
)
|
||||
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
|
||||
|
||||
def insert(self, message: Message):
|
||||
pass
|
||||
self.storage.insert(message)
|
||||
|
||||
def insert_many(self, messages: List[Message]):
|
||||
self.storage.insert_many(messages)
|
||||
|
||||
def save(self):
|
||||
self.storage.save()
|
||||
|
||||
|
||||
class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
@ -333,24 +355,31 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
|
||||
self.top_k = top_k
|
||||
self.agent_config = agent_config
|
||||
config = MemGPTConfig.load()
|
||||
self.config = MemGPTConfig.load()
|
||||
|
||||
# create embedding model
|
||||
self.embed_model = embedding_model()
|
||||
self.embedding_chunk_size = config.embedding_chunk_size
|
||||
self.embedding_chunk_size = self.config.embedding_chunk_size
|
||||
|
||||
# create storage backend
|
||||
self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
def create_passage(self, text, embedding):
|
||||
return Passage(
|
||||
user_id=self.config.anon_clientid,
|
||||
agent_id=self.agent_config.name,
|
||||
text=text,
|
||||
embedding=embedding,
|
||||
)
|
||||
|
||||
def save(self):
|
||||
"""Save the index to disk"""
|
||||
self.storage.save()
|
||||
|
||||
def insert(self, memory_string):
|
||||
"""Embed and save memory string"""
|
||||
from memgpt.connectors.storage import Passage
|
||||
|
||||
if not isinstance(memory_string, str):
|
||||
return TypeError("memory must be a string")
|
||||
@ -364,17 +393,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
# breakup string into passages
|
||||
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
|
||||
embedding = self.embed_model.get_text_embedding(node.text)
|
||||
# fixing weird bug where type returned isn't a list, but instead is an object
|
||||
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
|
||||
if isinstance(embedding, dict):
|
||||
try:
|
||||
embedding = embedding["data"][0]["embedding"]
|
||||
except (KeyError, IndexError):
|
||||
# TODO as a fallback, see if we can find any lists in the payload
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passages.append(Passage(text=node.text, embedding=embedding, doc_id=f"agent_{self.agent_config.name}_memory"))
|
||||
passages.append(self.create_passage(node.text, embedding))
|
||||
|
||||
# insert passages
|
||||
self.storage.insert_many(passages)
|
||||
|
@ -3,9 +3,19 @@ import pickle
|
||||
from memgpt.config import AgentConfig
|
||||
from memgpt.memory import (
|
||||
DummyRecallMemory,
|
||||
BaseRecallMemory,
|
||||
EmbeddingArchivalMemory,
|
||||
)
|
||||
from memgpt.utils import get_local_time, printd, OpenAIBackcompatUnpickler
|
||||
from memgpt.utils import get_local_time, printd
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
|
||||
# parse times returned by memgpt.utils.get_formatted_time()
|
||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
class PersistenceManager(ABC):
|
||||
@ -33,83 +43,60 @@ class PersistenceManager(ABC):
|
||||
class LocalStateManager(PersistenceManager):
|
||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||
|
||||
recall_memory_cls = DummyRecallMemory
|
||||
recall_memory_cls = BaseRecallMemory
|
||||
archival_memory_cls = EmbeddingArchivalMemory
|
||||
|
||||
def __init__(self, agent_config: AgentConfig):
|
||||
# Memory held in-state useful for debugging stateful versions
|
||||
self.memory = None
|
||||
self.messages = []
|
||||
self.all_messages = []
|
||||
self.messages = [] # current in-context messages
|
||||
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_config)
|
||||
self.recall_memory = None
|
||||
self.recall_memory = BaseRecallMemory(agent_config)
|
||||
self.agent_config = agent_config
|
||||
self.config = MemGPTConfig.load()
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename, agent_config: AgentConfig):
|
||||
def load(cls, agent_config: AgentConfig):
|
||||
""" Load a LocalStateManager from a file. """ ""
|
||||
try:
|
||||
with open(filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
except ModuleNotFoundError as e:
|
||||
# Patch for stripped openai package
|
||||
# ModuleNotFoundError: No module named 'openai.openai_object'
|
||||
with open(filename, "rb") as f:
|
||||
unpickler = OpenAIBackcompatUnpickler(f)
|
||||
data = unpickler.load()
|
||||
# print(f"Unpickled data:\n{data.keys()}")
|
||||
|
||||
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
||||
|
||||
def convert_openai_objects_to_dict(obj):
|
||||
if isinstance(obj, OpenAIObject):
|
||||
# Convert to dict or handle as needed
|
||||
# print(f"detected OpenAIObject on {obj}")
|
||||
return obj.to_dict_recursive()
|
||||
elif isinstance(obj, dict):
|
||||
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
|
||||
elif isinstance(obj, list):
|
||||
return [convert_openai_objects_to_dict(v) for v in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
data = convert_openai_objects_to_dict(data)
|
||||
# print(f"Converted data:\n{data.keys()}")
|
||||
|
||||
# TODO: remove this class and just init the class
|
||||
manager = cls(agent_config)
|
||||
manager.all_messages = data["all_messages"]
|
||||
manager.messages = data["messages"]
|
||||
manager.recall_memory = data["recall_memory"]
|
||||
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
|
||||
return manager
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, "wb") as fh:
|
||||
## TODO: fix this hacky solution to pickle the retriever
|
||||
self.archival_memory.save()
|
||||
pickle.dump(
|
||||
{
|
||||
"recall_memory": self.recall_memory,
|
||||
"messages": self.messages,
|
||||
"all_messages": self.all_messages,
|
||||
},
|
||||
fh,
|
||||
protocol=pickle.HIGHEST_PROTOCOL,
|
||||
)
|
||||
printd(f"Saved state to {fh}")
|
||||
def save(self):
|
||||
"""Ensure storage connectors save data"""
|
||||
self.archival_memory.save()
|
||||
self.recall_memory.save()
|
||||
|
||||
def init(self, agent):
|
||||
"""Connect persistent state manager to agent"""
|
||||
printd(f"Initializing {self.__class__.__name__} with agent object")
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
# self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
||||
# printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
||||
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
|
||||
|
||||
# Persistence manager also handles DB-related state
|
||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||
# self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||
|
||||
# TODO: init archival memory here?
|
||||
def json_to_message(self, message_json) -> Message:
|
||||
"""Convert agent message JSON into Message object"""
|
||||
timestamp = message_json["timestamp"]
|
||||
message = message_json["message"]
|
||||
|
||||
return Message(
|
||||
user_id=self.config.anon_clientid,
|
||||
agent_id=self.agent_config.name,
|
||||
role=message["role"],
|
||||
text=message["content"],
|
||||
model=self.agent_config.model,
|
||||
created_at=parse_formatted_time(timestamp),
|
||||
function_name=message["function_name"] if "function_name" in message else None,
|
||||
function_args=message["function_args"] if "function_args" in message else None,
|
||||
function_response=message["function_response"] if "function_response" in message else None,
|
||||
id=message["id"] if "id" in message else None,
|
||||
)
|
||||
|
||||
def trim_messages(self, num):
|
||||
# printd(f"InMemoryStateManager.trim_messages")
|
||||
@ -121,7 +108,9 @@ class LocalStateManager(PersistenceManager):
|
||||
|
||||
printd(f"{self.__class__.__name__}.prepend_to_message")
|
||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
|
||||
|
||||
def append_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
@ -129,7 +118,9 @@ class LocalStateManager(PersistenceManager):
|
||||
|
||||
printd(f"{self.__class__.__name__}.append_to_messages")
|
||||
self.messages = self.messages + added_messages
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
|
||||
|
||||
def swap_system_message(self, new_system_message):
|
||||
# first tag with timestamps
|
||||
@ -137,96 +128,9 @@ class LocalStateManager(PersistenceManager):
|
||||
|
||||
printd(f"{self.__class__.__name__}.swap_system_message")
|
||||
self.messages[0] = new_system_message
|
||||
self.all_messages.append(new_system_message)
|
||||
|
||||
def update_memory(self, new_memory):
|
||||
printd(f"{self.__class__.__name__}.update_memory")
|
||||
self.memory = new_memory
|
||||
|
||||
|
||||
class StateManager(PersistenceManager):
|
||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||
|
||||
recall_memory_cls = RecallMemory
|
||||
archival_memory_cls = EmbeddingArchivalMemory
|
||||
|
||||
def __init__(self, agent_config: AgentConfig):
|
||||
# Memory held in-state useful for debugging stateful versions
|
||||
self.memory = None
|
||||
self.messages = []
|
||||
self.all_messages = []
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_config)
|
||||
self.recall_memory = None
|
||||
self.agent_config = agent_config
|
||||
|
||||
@classmethod
|
||||
def load(cls, filename, agent_config: AgentConfig):
|
||||
""" Load a LocalStateManager from a file. """ ""
|
||||
with open(filename, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
|
||||
manager = cls(agent_config)
|
||||
manager.all_messages = data["all_messages"]
|
||||
manager.messages = data["messages"]
|
||||
manager.recall_memory = data["recall_memory"]
|
||||
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
|
||||
return manager
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, "wb") as fh:
|
||||
## TODO: fix this hacky solution to pickle the retriever
|
||||
self.archival_memory.save()
|
||||
pickle.dump(
|
||||
{
|
||||
"recall_memory": self.recall_memory,
|
||||
"messages": self.messages,
|
||||
"all_messages": self.all_messages,
|
||||
},
|
||||
fh,
|
||||
protocol=pickle.HIGHEST_PROTOCOL,
|
||||
)
|
||||
printd(f"Saved state to {fh}")
|
||||
|
||||
def init(self, agent):
|
||||
printd(f"Initializing {self.__class__.__name__} with agent object")
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
||||
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
|
||||
|
||||
# Persistence manager also handles DB-related state
|
||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||
|
||||
# TODO: init archival memory here?
|
||||
|
||||
def trim_messages(self, num):
|
||||
# printd(f"InMemoryStateManager.trim_messages")
|
||||
self.messages = [self.messages[0]] + self.messages[num:]
|
||||
|
||||
def prepend_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"{self.__class__.__name__}.prepend_to_message")
|
||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
def append_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"{self.__class__.__name__}.append_to_messages")
|
||||
self.messages = self.messages + added_messages
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
def swap_system_message(self, new_system_message):
|
||||
# first tag with timestamps
|
||||
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
|
||||
|
||||
printd(f"{self.__class__.__name__}.swap_system_message")
|
||||
self.messages[0] = new_system_message
|
||||
self.all_messages.append(new_system_message)
|
||||
|
||||
# add to recall memory
|
||||
self.recall_memory.insert(self.json_to_message(new_system_message))
|
||||
|
||||
def update_memory(self, new_memory):
|
||||
printd(f"{self.__class__.__name__}.update_memory")
|
||||
|
@ -117,10 +117,11 @@ def get_local_time(timezone=None):
|
||||
time_str = get_local_time_timezone(timezone)
|
||||
else:
|
||||
# Get the current time, which will be in the local timezone of the computer
|
||||
local_time = datetime.now()
|
||||
local_time = datetime.now().astimezone()
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
print("formatted_time", formatted_time)
|
||||
|
||||
return time_str.strip()
|
||||
|
||||
|
19
poetry.lock
generated
19
poetry.lock
generated
@ -3822,6 +3822,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
||||
pymysql = ["pymysql"]
|
||||
sqlcipher = ["sqlcipher3-binary"]
|
||||
|
||||
[[package]]
|
||||
name = "sqlalchemy-json"
|
||||
version = "0.7.0"
|
||||
description = "JSON type with nested change tracking for SQLAlchemy"
|
||||
optional = false
|
||||
python-versions = ">= 3.6"
|
||||
files = [
|
||||
{file = "sqlalchemy-json-0.7.0.tar.gz", hash = "sha256:620d0b26f648f21a8fa9127df66f55f83a5ab4ae010e5397a5c6989a08238561"},
|
||||
{file = "sqlalchemy_json-0.7.0-py3-none-any.whl", hash = "sha256:27881d662ca18363a4ac28175cc47ea2a6f2bef997ae1159c151026b741818e6"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
sqlalchemy = ">=0.7"
|
||||
|
||||
[package.extras]
|
||||
dev = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "starlette"
|
||||
version = "0.27.0"
|
||||
@ -4913,4 +4930,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.9"
|
||||
content-hash = "12010863b2b9c1e26dceace00ea4e1ea7cc95932ab77b1ef37a5473c2e375575"
|
||||
content-hash = "a6b8cdeda433007b6d33441b82417e435df336cea458af5dedce9a94f07f1225"
|
||||
|
@ -49,14 +49,8 @@ tiktoken = "^0.5.1"
|
||||
python-box = "^7.1.1"
|
||||
pypdf = "^3.17.1"
|
||||
pyyaml = "^6.0.1"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
chromadb = "^0.4.18"
|
||||
pytest-asyncio = {version = "^0.23.2", optional = true}
|
||||
pydantic = "^2.5.2"
|
||||
pyautogen = {version = "0.2.0", optional = true}
|
||||
html2text = "^2020.1.16"
|
||||
docx2txt = "^0.8"
|
||||
chromadb = {version = "^0.4.18", optional = true}
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
|
@ -3,63 +3,90 @@ import subprocess
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
||||
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||
# subprocess.check_call(
|
||||
# [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
||||
# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
||||
#
|
||||
# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||
import pgvector # Try to import again after installing
|
||||
|
||||
from memgpt.connectors.storage import StorageConnector, Passage
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.utils import get_local_time
|
||||
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
def test_recall_db() -> None:
|
||||
def test_recall_db():
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
|
||||
storage_type = "postgres"
|
||||
storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
|
||||
config = MemGPTConfig(
|
||||
recall_storage_type=storage_type,
|
||||
recall_storage_uri=storage_uri,
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt4",
|
||||
)
|
||||
print(config.config_path)
|
||||
assert config.recall_storage_uri is not None
|
||||
config.save()
|
||||
print(config)
|
||||
|
||||
conn = StorageConnector.get_recall_storage_connector()
|
||||
agent_config = AgentConfig(
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
model=config.model,
|
||||
)
|
||||
|
||||
conn = StorageConnector.get_recall_storage_connector(agent_config)
|
||||
|
||||
# construct recall memory messages
|
||||
message1 = Message(
|
||||
agent_id="test_agent1",
|
||||
agent_id=agent_config.name,
|
||||
role="agent",
|
||||
content="This is a test message",
|
||||
id="test_id1",
|
||||
text="This is a test message",
|
||||
user_id=config.anon_clientid,
|
||||
model=agent_config.model,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
message2 = Message(
|
||||
agent_id="test_agent2",
|
||||
agent_id=agent_config.name,
|
||||
role="user",
|
||||
content="This is a test message",
|
||||
id="test_id2",
|
||||
text="This is a test message",
|
||||
user_id=config.anon_clientid,
|
||||
model=agent_config.model,
|
||||
created_at=datetime.now(),
|
||||
)
|
||||
print(vars(message1))
|
||||
|
||||
# test insert
|
||||
conn.insert(message1)
|
||||
conn.insert_many([message2])
|
||||
|
||||
# test size
|
||||
assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}"
|
||||
assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}"
|
||||
assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
|
||||
assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
|
||||
|
||||
# test get
|
||||
assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}"
|
||||
assert (
|
||||
len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1
|
||||
), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}"
|
||||
# test text query
|
||||
res = conn.query_text("test")
|
||||
print(res)
|
||||
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
|
||||
# test date query
|
||||
current_time = datetime.now()
|
||||
ten_weeks_ago = current_time - timedelta(weeks=1)
|
||||
res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time)
|
||||
print(res)
|
||||
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||
|
||||
print(conn.get_all())
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||
@ -83,10 +110,28 @@ def test_postgres_openai():
|
||||
|
||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
|
||||
db = PostgresStorageConnector(name="test-openai")
|
||||
agent_config = AgentConfig(
|
||||
name="test_agent",
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
model=config.model,
|
||||
)
|
||||
|
||||
db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||
|
||||
# db.delete()
|
||||
# return
|
||||
for passage in passage:
|
||||
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
||||
db.insert(
|
||||
Passage(
|
||||
text=passage,
|
||||
embedding=embed_model.get_text_embedding(passage),
|
||||
user_id=config.anon_clientid,
|
||||
agent_id="test_agent",
|
||||
data_source="test",
|
||||
metadata={"test_metadata_key": "test_metadata_value"},
|
||||
)
|
||||
)
|
||||
|
||||
print(db.get_all())
|
||||
|
||||
@ -246,3 +291,6 @@ def test_lancedb_local():
|
||||
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
|
||||
test_recall_db()
|
||||
|
Loading…
Reference in New Issue
Block a user