Support recall and archival memory for postgres

working test
This commit is contained in:
Sarah Wooders 2023-12-08 18:27:26 -08:00
parent 9f3806dfcb
commit 9b3d59e016
14 changed files with 393 additions and 409 deletions

View File

@ -7,7 +7,7 @@ import traceback
from memgpt.persistence_manager import LocalStateManager
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
from memgpt.memory import CoreMemory as Memory, summarize_messages
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages
from memgpt.openai_tools import create, is_context_overflow_error
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response
from memgpt.constants import (
@ -29,7 +29,7 @@ def initialize_memory(ai_notes, human_notes):
raise ValueError(ai_notes)
if human_notes is None:
raise ValueError(human_notes)
memory = Memory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
memory.edit_persona(ai_notes)
memory.edit_human(human_notes)
return memory
@ -240,6 +240,7 @@ class Agent(object):
### Local state management
def to_dict(self):
# TODO: select specific variables for the saves state (to eventually move to a DB) rather than checkpointing everything in the class
return {
"model": self.model,
"system": self.system,
@ -249,32 +250,29 @@ class Agent(object):
"memory": self.memory.to_dict(),
}
def save_to_json_file(self, filename):
def save_agent_state_json(self, filename):
"""Save agent state to JSON"""
with open(filename, "w") as file:
json.dump(self.to_dict(), file)
def save(self):
"""Save agent state locally"""
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
agent_name = self.config.name # TODO: fix
# save config
self.config.save()
# save agent state
# save agent state to timestamped file
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
filename = f"{timestamp}.json"
os.makedirs(self.config.save_state_dir(), exist_ok=True)
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))
self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename))
# save the persistence manager too
filename = f"{timestamp}.persistence.pickle"
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))
# save the persistence manager too (recall/archival memory)
self.persistence_manager.save()
@classmethod
def load_agent(cls, interface, agent_config: AgentConfig):
"""Load saved agent state"""
"""Load saved agent state based on agent_config"""
# TODO: support loading from specific file
agent_name = agent_config.name
@ -290,10 +288,7 @@ class Agent(object):
state = json.load(open(filename, "r"))
# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)
persistence_manager = LocalStateManager.load(agent_config)
# need to dynamically link the functions
# the saved agent.functions will just have the schemas, but we need to
@ -354,70 +349,6 @@ class Agent(object):
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
return agent
@classmethod
def load(cls, state, interface, persistence_manager):
model = state["model"]
system = state["system"]
functions = state["functions"]
messages = state["messages"]
try:
messages_total = state["messages_total"]
except KeyError:
messages_total = len(messages) - 1
# memory requires a nested load
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
# Two-part load
new_agent = cls(
model=model,
system=system,
functions=functions,
interface=interface,
persistence_manager=persistence_manager,
persistence_manager_init=False,
persona_notes=persona_notes,
human_notes=human_notes,
messages_total=messages_total,
)
new_agent._messages = messages
return new_agent
def load_inplace(self, state):
self.model = state["model"]
self.system = state["system"]
self.functions = state["functions"]
# memory requires a nested load
memory_dict = state["memory"]
persona_notes = memory_dict["persona"]
human_notes = memory_dict["human"]
self.memory = initialize_memory(persona_notes, human_notes)
# messages also
self._messages = state["messages"]
try:
self.messages_total = state["messages_total"]
except KeyError:
self.messages_total = len(self.messages) - 1 # -system
@classmethod
def load_from_json(cls, json_state, interface, persistence_manager):
state = json.loads(json_state)
return cls.load(state, interface, persistence_manager)
@classmethod
def load_from_json_file(cls, json_file, interface, persistence_manager):
with open(json_file, "r") as file:
state = json.load(file)
return cls.load(state, interface, persistence_manager)
def load_from_json_file_inplace(self, json_file):
# Load in-place
# No interface arg needed, we can use the current one
with open(json_file, "r") as file:
state = json.load(file)
self.load_inplace(state)
def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
"""Can be used to enforce that the first message always uses send_message"""
response_message = response.choices[0].message

View File

@ -418,6 +418,22 @@ def configure_archival_storage(config: MemGPTConfig):
# TODO: allow configuring embedding model
def configure_recall_storage(config: MemGPTConfig):
# Configure recall storage backend
recall_storage_options = ["local", "postgres"]
recall_storage_type = questionary.select(
"Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type
).ask()
recall_storage_uri, recall_storage_path = None, None
# configure postgres
if recall_storage_type == "postgres":
recall_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.recall_storage_uri if config.recall_storage_uri else "",
).ask()
return recall_storage_type, recall_storage_uri, recall_storage_path
@app.command()
def configure():
"""Updates default MemGPT configurations"""
@ -430,17 +446,30 @@ def configure():
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
try:
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
except ValueError as e:
typer.secho(str(e), fg=typer.colors.RED)
return
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
default_preset, default_persona, default_human, default_agent = configure_cli(config)
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage(config)
# check credentials
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
openai_key = get_openai_credentials()
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
if all([azure_key, azure_endpoint, azure_version]):
print(f"Using Microsoft endpoint {azure_endpoint}.")
if all([azure_deployment, azure_embedding_deployment]):
print(f"Using deployment id {azure_deployment}")
else:
raise ValueError(
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
)
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
if not openai_key:
raise ValueError(
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
)
config = MemGPTConfig(
# model configs
@ -470,6 +499,10 @@ def configure():
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,
archival_storage_path=archival_storage_path,
# recall storage
recall_storage_type=recall_storage_type,
recall_storage_uri=recall_storage_uri,
recall_storage_path=recall_storage_path,
)
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
config.save()

View File

@ -134,6 +134,9 @@ class MemGPTConfig:
"archival_storage_type": get_field(config, "archival_storage", "type"),
"archival_storage_path": get_field(config, "archival_storage", "path"),
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
"recall_storage_type": get_field(config, "recall_storage", "type"),
"recall_storage_path": get_field(config, "recall_storage", "path"),
"recall_storage_uri": get_field(config, "recall_storage", "uri"),
"anon_clientid": get_field(config, "client", "anon_clientid"),
"config_path": config_path,
"memgpt_version": get_field(config, "version", "memgpt_version"),
@ -187,6 +190,11 @@ class MemGPTConfig:
set_field(config, "archival_storage", "path", self.archival_storage_path)
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
# recall storage
set_field(config, "recall_storage", "type", self.recall_storage_type)
set_field(config, "recall_storage", "path", self.recall_storage_path)
set_field(config, "recall_storage", "uri", self.recall_storage_uri)
# set version
set_field(config, "version", "memgpt_version", memgpt.__version__)

View File

@ -2,22 +2,10 @@ import chromadb
import json
import re
from typing import Optional, List, Iterator
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.utils import printd
from memgpt.config import AgentConfig, MemGPTConfig
def create_chroma_client():
config = MemGPTConfig.load()
# create chroma client
if config.archival_storage_path:
client = chromadb.PersistentClient(config.archival_storage_path)
else:
# assume uri={ip}:{port}
ip = config.archival_storage_uri.split(":")[0]
port = config.archival_storage_uri.split(":")[1]
client = chromadb.HttpClient(host=ip, port=port)
return client
from memgpt.data_types import Record, Message, Passage
class ChromaStorageConnector(StorageConnector):
@ -25,21 +13,24 @@ class ChromaStorageConnector(StorageConnector):
# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
# determine table name
if agent_config:
assert name is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name_agent(agent_config)
elif name:
assert agent_config is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name(name)
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# supported table types
self.supported_types = [TableType.ARCHIVAL_MEMORY]
if table_type not in self.supported_types:
raise ValueError(f"Table type {table_type} not supported by Chroma")
# create chroma client
if config.archival_storage_path:
self.client = chromadb.PersistentClient(config.archival_storage_path)
else:
raise ValueError("Must specify either agent config or name")
printd(f"Using table name {self.table_name}")
# create client
self.client = create_chroma_client()
# assume uri={ip}:{port}
ip = config.archival_storage_uri.split(":")[0]
port = config.archival_storage_uri.split(":")[1]
self.client = chromadb.HttpClient(host=ip, port=port)
# get a collection or create if it doesn't exist already
self.collection = self.client.get_or_create_collection(self.table_name)
@ -98,28 +89,5 @@ class ChromaStorageConnector(StorageConnector):
collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
return collections
def sanitize_table_name(self, name: str) -> str:
# Remove leading and trailing whitespace
name = name.strip()
# Replace spaces and invalid characters with underscores
name = re.sub(r"\s+|\W+", "_", name)
# Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
max_length = 63
if len(name) > max_length:
name = name[:max_length].rstrip("_")
# Convert to lowercase
name = name.lower()
return name
def generate_table_name_agent(self, agent_config: AgentConfig):
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
def generate_table_name(self, name: str):
return f"memgpt_{self.sanitize_table_name(name)}"
def size(self) -> int:
return self.collection.count()

View File

@ -1,5 +1,5 @@
from pgvector.psycopg import register_vector
from pgvector.sqlalchemy import Vector, JSON, Text
from pgvector.sqlalchemy import Vector
import psycopg
@ -8,6 +8,8 @@ from sqlalchemy.orm import sessionmaker, mapped_column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy_json import mutable_json_type
import re
from tqdm import tqdm
@ -28,15 +30,10 @@ from datetime import datetime
Base = declarative_base()
def parse_formatted_time(formatted_time):
# parse times returned by memgpt.utils.get_formatted_time()
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
def get_db_model(table_name: str, table_type: TableType):
config = MemGPTConfig.load()
if table_name == TableType.ARCHIVAL_MEMORY:
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
# create schema for archival memory
class PassageModel(Base):
"""Defines data model for storing Passages (consisting of text, embedding)"""
@ -45,16 +42,29 @@ def get_db_model(table_name: str, table_type: TableType):
# Assuming passage_id is the primary key
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
user_id = Column(String, nullable=False)
text = Column(String, nullable=False)
doc_id = Column(String)
agent_id = Column(String)
data_source = Column(String) # agent_name if agent, data_source name if from data source
text = Column(String, nullable=False)
embedding = mapped_column(Vector(config.embedding_dim))
metadata_ = Column(JSON(astext_type=Text()))
metadata_ = Column(mutable_json_type(dbtype=JSONB, nested=True))
def __repr__(self):
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Passage(
text=self.text,
embedding=self.embedding,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
data_source=self.data_source,
agent_id=self.agent_id,
metadata=self.metadata_,
)
"""Create database model for table_name"""
class_name = f"{table_name.capitalize()}Model"
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
@ -71,7 +81,7 @@ def get_db_model(table_name: str, table_type: TableType):
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
role = Column(String, nullable=False)
content = Column(String, nullable=False)
text = Column(String, nullable=False)
model = Column(String, nullable=False)
function_name = Column(String)
function_args = Column(String)
@ -82,7 +92,22 @@ def get_db_model(table_name: str, table_type: TableType):
created_at = Column(DateTime(timezone=True), server_default=func.now())
def __repr__(self):
return f"<Message(message_id='{self.id}', content='{self.content}', embedding='{self.embedding})>"
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
return Message(
user_id=self.user_id,
agent_id=self.agent_id,
role=self.role,
text=self.text,
model=self.model,
function_name=self.function_name,
function_args=self.function_args,
function_response=self.function_response,
embedding=self.embedding,
created_at=self.created_at,
id=self.id,
)
"""Create database model for table_name"""
class_name = f"{table_name.capitalize()}Model"
@ -101,11 +126,20 @@ class PostgresStorageConnector(StorageConnector):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY:
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = config.recall_storage_uri
if config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
self.db_model = get_db_model(self.table_name)
self.db_model = get_db_model(self.table_name, table_type)
self.engine = create_engine(self.uri)
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
@ -132,47 +166,47 @@ class PostgresStorageConnector(StorageConnector):
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
db_passages = session.query(self.db_model).filter(*filters).limit(limit).all()
return [self.type(**p.to_dict()) for p in db_passages]
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
return [record.to_record() for record in db_records]
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]:
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
session = self.Session()
filters = self.get_filters(filters)
db_passage = session.query(self.db_model).filter(*filters).get(id)
if db_passage is None:
db_record = session.query(self.db_model).filter(*filters).get(id)
if db_record is None:
return None
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
return db_record.to_record()
def size(self, filters: Optional[Dict] = {}) -> int:
# return size of table
print("size")
session = self.Session()
filters = self.get_filters(filters)
return session.query(self.db_model).filter(*filters).count()
def insert(self, passage: Passage):
def insert(self, record: Record):
session = self.Session()
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
session.add(db_passage)
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
def insert_many(self, records: List[Record], show_progress=True):
session = self.Session()
iterable = tqdm(records) if show_progress else records
for passage in iterable:
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
session.add(db_passage)
for record in iterable:
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
# Assuming PassageModel.embedding has the capability of computing l2_distance
filters = self.get_filters(filters)
results = session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
# Convert the results into Passage objects
records = [self.type(**vars(result)) for result in results]
records = [result.to_record() for result in results]
return records
def save(self):
@ -195,6 +229,26 @@ class PostgresStorageConnector(StorageConnector):
tables = [table[start_chars:] for table in tables]
return tables
def query_date(self, start_date, end_date):
session = self.Session()
filters = self.get_filters({})
results = (
session.query(self.db_model)
.filter(*filters)
.filter(self.db_model.created_at >= start_date)
.filter(self.db_model.created_at <= end_date)
.all()
)
return [result.to_record() for result in results]
def query_text(self, query):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
session = self.Session()
filters = self.get_filters({})
results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
print(results)
# return [self.type(**vars(result)) for result in results]
return [result.to_record() for result in results]
class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""
@ -251,7 +305,7 @@ class LanceDBConnector(StorageConnector):
def get(self, id: str) -> Optional[Passage]:
db_passage = self.table.where(f"passage_id={id}").to_list()
if len(db_passage) == 0:
if len(db_passage) == 0:
return None
return Passage(
text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]

View File

@ -6,7 +6,8 @@ import pickle
import os
from typing import List, Optional
from typing import List, Optional, Dict
from abc import abstractmethod
from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context
from llama_index.indices.empty.base import EmptyIndex
@ -145,11 +146,9 @@ class InMemoryStorageConnector(StorageConnector):
# TODO: maybae replace this with sqllite?
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
from memgpt.embeddings import embedding_model
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# TODO: figure out save location
self.rows = []

View File

@ -6,7 +6,7 @@ from typing import Any, Optional, List, Iterator
import re
import pickle
import os
from abc import abstractmethod
from typing import List, Optional, Dict
from tqdm import tqdm
@ -14,6 +14,7 @@ from tqdm import tqdm
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.data_types import Record, Passage, Document, Message
from memgpt.utils import printd
# ENUM representing table types in MemGPT
@ -28,9 +29,9 @@ class TableType:
# table names used by MemGPT
RECALL_TABLE_NAME = "memgpt_recall_memory"
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory"
PASSAGE_TABLE_NAME = "memgpt_passages"
RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory
PASSAGE_TABLE_NAME = "memgpt_passages" # loads data sources
DOCUMENT_TABLE_NAME = "memgpt_documents"
@ -65,9 +66,10 @@ class StorageConnector:
# get all filters for query
if filters is not None:
filter_conditions = {**self.filters, **filters}
return self.filters + [self.db_model[key] == value for key, value in filter_conditions.items()]
else:
return self.filters
filter_conditions = self.filters
print("FILTERS", filter_conditions)
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
@ -102,18 +104,20 @@ class StorageConnector:
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector
return VectorIndexStorageConnector(agent_config=agent_config)
return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
elif storage_type == "postgres":
from memgpt.connectors.db import PostgresStorageConnector
return PostgresStorageConnector(agent_config=agent_config)
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
elif storage_type == "chroma":
from memgpt.connectors.chroma import ChromaStorageConnector
return ChromaStorageConnector(name=name, agent_config=agent_config)
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector
return LanceDBConnector(agent_config=agent_config)
return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@ -122,6 +126,8 @@ class StorageConnector:
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
storage_type = MemGPTConfig.load().recall_storage_type
print("Recall storage type", storage_type)
if storage_type == "local":
from memgpt.connectors.local import InMemoryStorageConnector

View File

@ -20,11 +20,10 @@ class Record:
self.text = text
self.id = id
# todo: generate unique uuid
# todo: timestamp
# todo: self.role = role (?)
def __repr__(self):
pass
# def __repr__(self):
# pass
class Message(Record):
@ -35,17 +34,19 @@ class Message(Record):
user_id: str,
agent_id: str,
role: str,
content: str,
text: str,
model: str, # model used to make function call
created_at: Optional[str] = None,
function_name: Optional[str] = None, # name of function called
function_args: Optional[str] = None, # args of function called
function_response: Optional[str] = None, # response of function called
embedding: Optional[np.ndarray] = None,
id: Optional[str] = None,
):
super().__init__(user_id, agent_id, content, id)
super().__init__(user_id, agent_id, text, id)
self.role = role # role (agent/user/function)
self.model = model # model name (e.g. gpt-4)
self.created_at = created_at
# function call info (optional)
self.function_name = function_name
@ -55,15 +56,15 @@ class Message(Record):
# embedding (optional)
self.embedding = embedding
def __repr__(self):
pass
# def __repr__(self):
# pass
class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
super().__init__(user_id)
super().__init__(user_id, agent_id, text, id)
self.text = text
self.document_id = document_id
self.data_source = data_source
@ -76,25 +77,26 @@ class Document(Record):
class Passage(Record):
"""A passage is a single unit of memory, and a standard format accross all storage backends.
It is a string of text with an associated embedding.
It is a string of text with an assoidciated embedding.
"""
def __init__(
self,
user_id: str,
text: str,
data_source: str,
embedding: np.ndarray,
agent_id: Optional[str] = None, # set if contained in agent memory
embedding: Optional[np.ndarray] = None,
data_source: Optional[str] = None, # None if created by agent
doc_id: Optional[str] = None,
passage_id: Optional[str] = None,
id: Optional[str] = None,
metadata: Optional[dict] = {},
):
super().__init__(user_id)
super().__init__(user_id, agent_id, text, id)
self.text = text
self.data_source = data_source
self.embedding = embedding
self.doc_id = doc_id
self.passage_id = passage_id
self.metadata = {}
self.metadata = metadata
def __repr__(self):
return f"Passage(text={self.text}, embedding={self.embedding})"

View File

@ -284,7 +284,10 @@ class DummyRecallMemory(RecallMemory):
return matches, len(matches)
class RecallMemorySQL(RecallMemory):
class BaseRecallMemory(RecallMemory):
"""Recall memory based on base functions implemented by storage connectors"""
def __init__(self, agent_config, restrict_search_to_summaries=False):
# If true, the pool of messages that can be queried are the automated summaries only
@ -304,20 +307,39 @@ class RecallMemorySQL(RecallMemory):
# TODO: have some mechanism for cleanup otherwise will lead to OOM
self.cache = {}
@abstractmethod
def text_search(self, query_string, count=None, start=None):
pass
self.storage.query_text(query_string, count, start)
@abstractmethod
def date_search(self, query_string, count=None, start=None):
pass
def date_search(self, start_date, end_date, count=None, start=None):
self.storage.query_date(start_date, end_date, count, start)
@abstractmethod
def __repr__(self) -> str:
pass
total = self.storage.size()
system_count = self.storage.size(filters={"role": "system"})
user_count = self.storage.size(filters={"role": "user"})
assistant_count = self.storage.size(filters={"role": "assistant"})
function_count = self.storage.size(filters={"role": "function"})
other_count = total - (system_count + user_count + assistant_count + function_count)
memory_str = (
f"Statistics:"
+ f"\n{total} total messages"
+ f"\n{system_count} system"
+ f"\n{user_count} user"
+ f"\n{assistant_count} assistant"
+ f"\n{function_count} function"
+ f"\n{other_count} other"
)
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
def insert(self, message: Message):
pass
self.storage.insert(message)
def insert_many(self, messages: List[Message]):
self.storage.insert_many(messages)
def save(self):
self.storage.save()
class EmbeddingArchivalMemory(ArchivalMemory):
@ -333,24 +355,31 @@ class EmbeddingArchivalMemory(ArchivalMemory):
self.top_k = top_k
self.agent_config = agent_config
config = MemGPTConfig.load()
self.config = MemGPTConfig.load()
# create embedding model
self.embed_model = embedding_model()
self.embedding_chunk_size = config.embedding_chunk_size
self.embedding_chunk_size = self.config.embedding_chunk_size
# create storage backend
self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
# TODO: have some mechanism for cleanup otherwise will lead to OOM
self.cache = {}
def create_passage(self, text, embedding):
return Passage(
user_id=self.config.anon_clientid,
agent_id=self.agent_config.name,
text=text,
embedding=embedding,
)
def save(self):
"""Save the index to disk"""
self.storage.save()
def insert(self, memory_string):
"""Embed and save memory string"""
from memgpt.connectors.storage import Passage
if not isinstance(memory_string, str):
return TypeError("memory must be a string")
@ -364,17 +393,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
# breakup string into passages
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
embedding = self.embed_model.get_text_embedding(node.text)
# fixing weird bug where type returned isn't a list, but instead is an object
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
if isinstance(embedding, dict):
try:
embedding = embedding["data"][0]["embedding"]
except (KeyError, IndexError):
# TODO as a fallback, see if we can find any lists in the payload
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passages.append(Passage(text=node.text, embedding=embedding, doc_id=f"agent_{self.agent_config.name}_memory"))
passages.append(self.create_passage(node.text, embedding))
# insert passages
self.storage.insert_many(passages)

View File

@ -3,9 +3,19 @@ import pickle
from memgpt.config import AgentConfig
from memgpt.memory import (
DummyRecallMemory,
BaseRecallMemory,
EmbeddingArchivalMemory,
)
from memgpt.utils import get_local_time, printd, OpenAIBackcompatUnpickler
from memgpt.utils import get_local_time, printd
from memgpt.data_types import Message
from memgpt.config import MemGPTConfig
from datetime import datetime
def parse_formatted_time(formatted_time):
# parse times returned by memgpt.utils.get_formatted_time()
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
class PersistenceManager(ABC):
@ -33,83 +43,60 @@ class PersistenceManager(ABC):
class LocalStateManager(PersistenceManager):
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
recall_memory_cls = DummyRecallMemory
recall_memory_cls = BaseRecallMemory
archival_memory_cls = EmbeddingArchivalMemory
def __init__(self, agent_config: AgentConfig):
# Memory held in-state useful for debugging stateful versions
self.memory = None
self.messages = []
self.all_messages = []
self.messages = [] # current in-context messages
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
self.archival_memory = EmbeddingArchivalMemory(agent_config)
self.recall_memory = None
self.recall_memory = BaseRecallMemory(agent_config)
self.agent_config = agent_config
self.config = MemGPTConfig.load()
@classmethod
def load(cls, filename, agent_config: AgentConfig):
def load(cls, agent_config: AgentConfig):
""" Load a LocalStateManager from a file. """ ""
try:
with open(filename, "rb") as f:
data = pickle.load(f)
except ModuleNotFoundError as e:
# Patch for stripped openai package
# ModuleNotFoundError: No module named 'openai.openai_object'
with open(filename, "rb") as f:
unpickler = OpenAIBackcompatUnpickler(f)
data = unpickler.load()
# print(f"Unpickled data:\n{data.keys()}")
from memgpt.openai_backcompat.openai_object import OpenAIObject
def convert_openai_objects_to_dict(obj):
if isinstance(obj, OpenAIObject):
# Convert to dict or handle as needed
# print(f"detected OpenAIObject on {obj}")
return obj.to_dict_recursive()
elif isinstance(obj, dict):
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_openai_objects_to_dict(v) for v in obj]
else:
return obj
data = convert_openai_objects_to_dict(data)
# print(f"Converted data:\n{data.keys()}")
# TODO: remove this class and just init the class
manager = cls(agent_config)
manager.all_messages = data["all_messages"]
manager.messages = data["messages"]
manager.recall_memory = data["recall_memory"]
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
return manager
def save(self, filename):
with open(filename, "wb") as fh:
## TODO: fix this hacky solution to pickle the retriever
self.archival_memory.save()
pickle.dump(
{
"recall_memory": self.recall_memory,
"messages": self.messages,
"all_messages": self.all_messages,
},
fh,
protocol=pickle.HIGHEST_PROTOCOL,
)
printd(f"Saved state to {fh}")
def save(self):
"""Ensure storage connectors save data"""
self.archival_memory.save()
self.recall_memory.save()
def init(self, agent):
"""Connect persistent state manager to agent"""
printd(f"Initializing {self.__class__.__name__} with agent object")
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
# self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
# printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
# self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
# TODO: init archival memory here?
def json_to_message(self, message_json) -> Message:
"""Convert agent message JSON into Message object"""
timestamp = message_json["timestamp"]
message = message_json["message"]
return Message(
user_id=self.config.anon_clientid,
agent_id=self.agent_config.name,
role=message["role"],
text=message["content"],
model=self.agent_config.model,
created_at=parse_formatted_time(timestamp),
function_name=message["function_name"] if "function_name" in message else None,
function_args=message["function_args"] if "function_args" in message else None,
function_response=message["function_response"] if "function_response" in message else None,
id=message["id"] if "id" in message else None,
)
def trim_messages(self, num):
# printd(f"InMemoryStateManager.trim_messages")
@ -121,7 +108,9 @@ class LocalStateManager(PersistenceManager):
printd(f"{self.__class__.__name__}.prepend_to_message")
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
self.all_messages.extend(added_messages)
# add to recall memory
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
def append_to_messages(self, added_messages):
# first tag with timestamps
@ -129,7 +118,9 @@ class LocalStateManager(PersistenceManager):
printd(f"{self.__class__.__name__}.append_to_messages")
self.messages = self.messages + added_messages
self.all_messages.extend(added_messages)
# add to recall memory
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
def swap_system_message(self, new_system_message):
# first tag with timestamps
@ -137,96 +128,9 @@ class LocalStateManager(PersistenceManager):
printd(f"{self.__class__.__name__}.swap_system_message")
self.messages[0] = new_system_message
self.all_messages.append(new_system_message)
def update_memory(self, new_memory):
printd(f"{self.__class__.__name__}.update_memory")
self.memory = new_memory
class StateManager(PersistenceManager):
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
recall_memory_cls = RecallMemory
archival_memory_cls = EmbeddingArchivalMemory
def __init__(self, agent_config: AgentConfig):
# Memory held in-state useful for debugging stateful versions
self.memory = None
self.messages = []
self.all_messages = []
self.archival_memory = EmbeddingArchivalMemory(agent_config)
self.recall_memory = None
self.agent_config = agent_config
@classmethod
def load(cls, filename, agent_config: AgentConfig):
""" Load a LocalStateManager from a file. """ ""
with open(filename, "rb") as f:
data = pickle.load(f)
manager = cls(agent_config)
manager.all_messages = data["all_messages"]
manager.messages = data["messages"]
manager.recall_memory = data["recall_memory"]
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
return manager
def save(self, filename):
with open(filename, "wb") as fh:
## TODO: fix this hacky solution to pickle the retriever
self.archival_memory.save()
pickle.dump(
{
"recall_memory": self.recall_memory,
"messages": self.messages,
"all_messages": self.all_messages,
},
fh,
protocol=pickle.HIGHEST_PROTOCOL,
)
printd(f"Saved state to {fh}")
def init(self, agent):
printd(f"Initializing {self.__class__.__name__} with agent object")
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
# Persistence manager also handles DB-related state
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
# TODO: init archival memory here?
def trim_messages(self, num):
# printd(f"InMemoryStateManager.trim_messages")
self.messages = [self.messages[0]] + self.messages[num:]
def prepend_to_messages(self, added_messages):
# first tag with timestamps
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"{self.__class__.__name__}.prepend_to_message")
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
self.all_messages.extend(added_messages)
def append_to_messages(self, added_messages):
# first tag with timestamps
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
printd(f"{self.__class__.__name__}.append_to_messages")
self.messages = self.messages + added_messages
self.all_messages.extend(added_messages)
def swap_system_message(self, new_system_message):
# first tag with timestamps
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
printd(f"{self.__class__.__name__}.swap_system_message")
self.messages[0] = new_system_message
self.all_messages.append(new_system_message)
# add to recall memory
self.recall_memory.insert(self.json_to_message(new_system_message))
def update_memory(self, new_memory):
printd(f"{self.__class__.__name__}.update_memory")

View File

@ -117,10 +117,11 @@ def get_local_time(timezone=None):
time_str = get_local_time_timezone(timezone)
else:
# Get the current time, which will be in the local timezone of the computer
local_time = datetime.now()
local_time = datetime.now().astimezone()
# You may format it as you desire, including AM/PM
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
print("formatted_time", formatted_time)
return time_str.strip()

19
poetry.lock generated
View File

@ -3822,6 +3822,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
pymysql = ["pymysql"]
sqlcipher = ["sqlcipher3-binary"]
[[package]]
name = "sqlalchemy-json"
version = "0.7.0"
description = "JSON type with nested change tracking for SQLAlchemy"
optional = false
python-versions = ">= 3.6"
files = [
{file = "sqlalchemy-json-0.7.0.tar.gz", hash = "sha256:620d0b26f648f21a8fa9127df66f55f83a5ab4ae010e5397a5c6989a08238561"},
{file = "sqlalchemy_json-0.7.0-py3-none-any.whl", hash = "sha256:27881d662ca18363a4ac28175cc47ea2a6f2bef997ae1159c151026b741818e6"},
]
[package.dependencies]
sqlalchemy = ">=0.7"
[package.extras]
dev = ["pytest"]
[[package]]
name = "starlette"
version = "0.27.0"
@ -4913,4 +4930,4 @@ server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.12,>=3.9"
content-hash = "12010863b2b9c1e26dceace00ea4e1ea7cc95932ab77b1ef37a5473c2e375575"
content-hash = "a6b8cdeda433007b6d33441b82417e435df336cea458af5dedce9a94f07f1225"

View File

@ -49,14 +49,8 @@ tiktoken = "^0.5.1"
python-box = "^7.1.1"
pypdf = "^3.17.1"
pyyaml = "^6.0.1"
fastapi = {version = "^0.104.1", optional = true}
uvicorn = {version = "^0.24.0.post1", optional = true}
chromadb = "^0.4.18"
pytest-asyncio = {version = "^0.23.2", optional = true}
pydantic = "^2.5.2"
pyautogen = {version = "0.2.0", optional = true}
html2text = "^2020.1.16"
docx2txt = "^0.8"
chromadb = {version = "^0.4.18", optional = true}
sqlalchemy-json = "^0.7.0"
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]

View File

@ -3,63 +3,90 @@ import subprocess
import sys
import pytest
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
# subprocess.check_call(
# [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
#
# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
import pgvector # Try to import again after installing
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.utils import get_local_time
import argparse
from datetime import datetime, timedelta
def test_recall_db() -> None:
def test_recall_db():
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
storage_type = "postgres"
storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
config = MemGPTConfig(
recall_storage_type=storage_type,
recall_storage_uri=storage_uri,
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt4",
)
print(config.config_path)
assert config.recall_storage_uri is not None
config.save()
print(config)
conn = StorageConnector.get_recall_storage_connector()
agent_config = AgentConfig(
persona=config.persona,
human=config.human,
model=config.model,
)
conn = StorageConnector.get_recall_storage_connector(agent_config)
# construct recall memory messages
message1 = Message(
agent_id="test_agent1",
agent_id=agent_config.name,
role="agent",
content="This is a test message",
id="test_id1",
text="This is a test message",
user_id=config.anon_clientid,
model=agent_config.model,
created_at=datetime.now(),
)
message2 = Message(
agent_id="test_agent2",
agent_id=agent_config.name,
role="user",
content="This is a test message",
id="test_id2",
text="This is a test message",
user_id=config.anon_clientid,
model=agent_config.model,
created_at=datetime.now(),
)
print(vars(message1))
# test insert
conn.insert(message1)
conn.insert_many([message2])
# test size
assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}"
assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}"
assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
# test get
assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}"
assert (
len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1
), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}"
# test text query
res = conn.query_text("test")
print(res)
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
# test date query
current_time = datetime.now()
ten_weeks_ago = current_time - timedelta(weeks=1)
res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time)
print(res)
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
print(conn.get_all())
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
@ -83,10 +110,28 @@ def test_postgres_openai():
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
db = PostgresStorageConnector(name="test-openai")
agent_config = AgentConfig(
name="test_agent",
persona=config.persona,
human=config.human,
model=config.model,
)
db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
# db.delete()
# return
for passage in passage:
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
db.insert(
Passage(
text=passage,
embedding=embed_model.get_text_embedding(passage),
user_id=config.anon_clientid,
agent_id="test_agent",
data_source="test",
metadata={"test_metadata_key": "test_metadata_value"},
)
)
print(db.get_all())
@ -246,3 +291,6 @@ def test_lancedb_local():
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
test_recall_db()