From 9b3d59e0162cfa749c6dfb627c0251c837f79dbe Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 8 Dec 2023 18:27:26 -0800 Subject: [PATCH] Support recall and archival memory for postgres working test --- memgpt/agent.py | 93 ++-------------- memgpt/cli/cli_config.py | 55 ++++++++-- memgpt/config.py | 8 ++ memgpt/connectors/chroma.py | 70 ++++-------- memgpt/connectors/db.py | 114 ++++++++++++++----- memgpt/connectors/local.py | 9 +- memgpt/connectors/storage.py | 26 +++-- memgpt/data_types.py | 32 +++--- memgpt/memory.py | 65 +++++++---- memgpt/persistence_manager.py | 200 +++++++++------------------------- memgpt/utils.py | 5 +- poetry.lock | 19 +++- pyproject.toml | 10 +- tests/test_storage.py | 96 ++++++++++++---- 14 files changed, 393 insertions(+), 409 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index c150eaf6e..c09c64e11 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -7,7 +7,7 @@ import traceback from memgpt.persistence_manager import LocalStateManager from memgpt.config import AgentConfig, MemGPTConfig from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages -from memgpt.memory import CoreMemory as Memory, summarize_messages +from memgpt.memory import CoreMemory as InContextMemory, summarize_messages from memgpt.openai_tools import create, is_context_overflow_error from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response from memgpt.constants import ( @@ -29,7 +29,7 @@ def initialize_memory(ai_notes, human_notes): raise ValueError(ai_notes) if human_notes is None: raise ValueError(human_notes) - memory = Memory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT) + memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT) memory.edit_persona(ai_notes) memory.edit_human(human_notes) return memory @@ -240,6 +240,7 @@ class Agent(object): ### Local state management def to_dict(self): + # TODO: select specific variables for the saves state (to eventually move to a DB) rather than checkpointing everything in the class return { "model": self.model, "system": self.system, @@ -249,32 +250,29 @@ class Agent(object): "memory": self.memory.to_dict(), } - def save_to_json_file(self, filename): + def save_agent_state_json(self, filename): + """Save agent state to JSON""" with open(filename, "w") as file: json.dump(self.to_dict(), file) def save(self): """Save agent state locally""" - timestamp = get_local_time().replace(" ", "_").replace(":", "_") - agent_name = self.config.name # TODO: fix - # save config self.config.save() - # save agent state + # save agent state to timestamped file + timestamp = get_local_time().replace(" ", "_").replace(":", "_") filename = f"{timestamp}.json" os.makedirs(self.config.save_state_dir(), exist_ok=True) - self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename)) + self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename)) - # save the persistence manager too - filename = f"{timestamp}.persistence.pickle" - os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True) - self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename)) + # save the persistence manager too (recall/archival memory) + self.persistence_manager.save() @classmethod def load_agent(cls, interface, agent_config: AgentConfig): - """Load saved agent state""" + """Load saved agent state based on agent_config""" # TODO: support loading from specific file agent_name = agent_config.name @@ -290,10 +288,7 @@ class Agent(object): state = json.load(open(filename, "r")) # load persistence manager - filename = os.path.basename(filename).replace(".json", ".persistence.pickle") - directory = agent_config.save_persistence_manager_dir() - printd(f"Loading persistence manager from {os.path.join(directory, filename)}") - persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config) + persistence_manager = LocalStateManager.load(agent_config) # need to dynamically link the functions # the saved agent.functions will just have the schemas, but we need to @@ -354,70 +349,6 @@ class Agent(object): agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"]) return agent - @classmethod - def load(cls, state, interface, persistence_manager): - model = state["model"] - system = state["system"] - functions = state["functions"] - messages = state["messages"] - try: - messages_total = state["messages_total"] - except KeyError: - messages_total = len(messages) - 1 - # memory requires a nested load - memory_dict = state["memory"] - persona_notes = memory_dict["persona"] - human_notes = memory_dict["human"] - - # Two-part load - new_agent = cls( - model=model, - system=system, - functions=functions, - interface=interface, - persistence_manager=persistence_manager, - persistence_manager_init=False, - persona_notes=persona_notes, - human_notes=human_notes, - messages_total=messages_total, - ) - new_agent._messages = messages - return new_agent - - def load_inplace(self, state): - self.model = state["model"] - self.system = state["system"] - self.functions = state["functions"] - # memory requires a nested load - memory_dict = state["memory"] - persona_notes = memory_dict["persona"] - human_notes = memory_dict["human"] - self.memory = initialize_memory(persona_notes, human_notes) - # messages also - self._messages = state["messages"] - try: - self.messages_total = state["messages_total"] - except KeyError: - self.messages_total = len(self.messages) - 1 # -system - - @classmethod - def load_from_json(cls, json_state, interface, persistence_manager): - state = json.loads(json_state) - return cls.load(state, interface, persistence_manager) - - @classmethod - def load_from_json_file(cls, json_file, interface, persistence_manager): - with open(json_file, "r") as file: - state = json.load(file) - return cls.load(state, interface, persistence_manager) - - def load_from_json_file_inplace(self, json_file): - # Load in-place - # No interface arg needed, we can use the current one - with open(json_file, "r") as file: - state = json.load(file) - self.load_inplace(state) - def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False): """Can be used to enforce that the first message always uses send_message""" response_message = response.choices[0].message diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index d17e70372..e7ede919c 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -418,6 +418,22 @@ def configure_archival_storage(config: MemGPTConfig): # TODO: allow configuring embedding model +def configure_recall_storage(config: MemGPTConfig): + # Configure recall storage backend + recall_storage_options = ["local", "postgres"] + recall_storage_type = questionary.select( + "Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type + ).ask() + recall_storage_uri, recall_storage_path = None, None + # configure postgres + if recall_storage_type == "postgres": + recall_storage_uri = questionary.text( + "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", + default=config.recall_storage_uri if config.recall_storage_uri else "", + ).ask() + return recall_storage_type, recall_storage_uri, recall_storage_path + + @app.command() def configure(): """Updates default MemGPT configurations""" @@ -430,17 +446,30 @@ def configure(): # Will pre-populate with defaults, or what the user previously set config = MemGPTConfig.load() - try: - model_endpoint_type, model_endpoint = configure_llm_endpoint(config) - model, model_wrapper, context_window = configure_model( - config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint - ) - embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config) - default_preset, default_persona, default_human, default_agent = configure_cli(config) - archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config) - except ValueError as e: - typer.secho(str(e), fg=typer.colors.RED) - return + model_endpoint_type, model_endpoint = configure_llm_endpoint(config) + model, model_wrapper, context_window = configure_model(config, model_endpoint_type) + embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config) + default_preset, default_persona, default_human, default_agent = configure_cli(config) + archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config) + recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage(config) + + # check credentials + azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials() + openai_key = get_openai_credentials() + if model_endpoint_type == "azure" or embedding_endpoint_type == "azure": + if all([azure_key, azure_endpoint, azure_version]): + print(f"Using Microsoft endpoint {azure_endpoint}.") + if all([azure_deployment, azure_embedding_deployment]): + print(f"Using deployment id {azure_deployment}") + else: + raise ValueError( + "Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again." + ) + if model_endpoint_type == "openai" or embedding_endpoint_type == "openai": + if not openai_key: + raise ValueError( + "Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again." + ) config = MemGPTConfig( # model configs @@ -470,6 +499,10 @@ def configure(): archival_storage_type=archival_storage_type, archival_storage_uri=archival_storage_uri, archival_storage_path=archival_storage_path, + # recall storage + recall_storage_type=recall_storage_type, + recall_storage_uri=recall_storage_uri, + recall_storage_path=recall_storage_path, ) typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN) config.save() diff --git a/memgpt/config.py b/memgpt/config.py index 10d01095a..327ee4f93 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -134,6 +134,9 @@ class MemGPTConfig: "archival_storage_type": get_field(config, "archival_storage", "type"), "archival_storage_path": get_field(config, "archival_storage", "path"), "archival_storage_uri": get_field(config, "archival_storage", "uri"), + "recall_storage_type": get_field(config, "recall_storage", "type"), + "recall_storage_path": get_field(config, "recall_storage", "path"), + "recall_storage_uri": get_field(config, "recall_storage", "uri"), "anon_clientid": get_field(config, "client", "anon_clientid"), "config_path": config_path, "memgpt_version": get_field(config, "version", "memgpt_version"), @@ -187,6 +190,11 @@ class MemGPTConfig: set_field(config, "archival_storage", "path", self.archival_storage_path) set_field(config, "archival_storage", "uri", self.archival_storage_uri) + # recall storage + set_field(config, "recall_storage", "type", self.recall_storage_type) + set_field(config, "recall_storage", "path", self.recall_storage_path) + set_field(config, "recall_storage", "uri", self.recall_storage_uri) + # set version set_field(config, "version", "memgpt_version", memgpt.__version__) diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py index 8db7aa2ee..38ed3a1bf 100644 --- a/memgpt/connectors/chroma.py +++ b/memgpt/connectors/chroma.py @@ -2,22 +2,10 @@ import chromadb import json import re from typing import Optional, List, Iterator -from memgpt.connectors.storage import StorageConnector, Passage +from memgpt.connectors.storage import StorageConnector, TableType from memgpt.utils import printd from memgpt.config import AgentConfig, MemGPTConfig - - -def create_chroma_client(): - config = MemGPTConfig.load() - # create chroma client - if config.archival_storage_path: - client = chromadb.PersistentClient(config.archival_storage_path) - else: - # assume uri={ip}:{port} - ip = config.archival_storage_uri.split(":")[0] - port = config.archival_storage_uri.split(":")[1] - client = chromadb.HttpClient(host=ip, port=port) - return client +from memgpt.data_types import Record, Message, Passage class ChromaStorageConnector(StorageConnector): @@ -25,21 +13,24 @@ class ChromaStorageConnector(StorageConnector): # WARNING: This is not thread safe. Do NOT do concurrent access to the same collection. - def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): - # determine table name - if agent_config: - assert name is None, f"Cannot specify both agent config and name {name}" - self.table_name = self.generate_table_name_agent(agent_config) - elif name: - assert agent_config is None, f"Cannot specify both agent config and name {name}" - self.table_name = self.generate_table_name(name) + def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): + super().__init__(table_type=table_type, agent_config=agent_config) + config = MemGPTConfig.load() + + # supported table types + self.supported_types = [TableType.ARCHIVAL_MEMORY] + + if table_type not in self.supported_types: + raise ValueError(f"Table type {table_type} not supported by Chroma") + + # create chroma client + if config.archival_storage_path: + self.client = chromadb.PersistentClient(config.archival_storage_path) else: - raise ValueError("Must specify either agent config or name") - - printd(f"Using table name {self.table_name}") - - # create client - self.client = create_chroma_client() + # assume uri={ip}:{port} + ip = config.archival_storage_uri.split(":")[0] + port = config.archival_storage_uri.split(":")[1] + self.client = chromadb.HttpClient(host=ip, port=port) # get a collection or create if it doesn't exist already self.collection = self.client.get_or_create_collection(self.table_name) @@ -98,28 +89,5 @@ class ChromaStorageConnector(StorageConnector): collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")] return collections - def sanitize_table_name(self, name: str) -> str: - # Remove leading and trailing whitespace - name = name.strip() - - # Replace spaces and invalid characters with underscores - name = re.sub(r"\s+|\W+", "_", name) - - # Truncate to the maximum identifier length (e.g., 63 for PostgreSQL) - max_length = 63 - if len(name) > max_length: - name = name[:max_length].rstrip("_") - - # Convert to lowercase - name = name.lower() - - return name - - def generate_table_name_agent(self, agent_config: AgentConfig): - return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}" - - def generate_table_name(self, name: str): - return f"memgpt_{self.sanitize_table_name(name)}" - def size(self) -> int: return self.collection.count() diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 9c833968a..e2da2f04e 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -1,5 +1,5 @@ from pgvector.psycopg import register_vector -from pgvector.sqlalchemy import Vector, JSON, Text +from pgvector.sqlalchemy import Vector import psycopg @@ -8,6 +8,8 @@ from sqlalchemy.orm import sessionmaker, mapped_column from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.sql import func from sqlalchemy import Column, BIGINT, String, DateTime +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy_json import mutable_json_type import re from tqdm import tqdm @@ -28,15 +30,10 @@ from datetime import datetime Base = declarative_base() -def parse_formatted_time(formatted_time): - # parse times returned by memgpt.utils.get_formatted_time() - return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z") - - def get_db_model(table_name: str, table_type: TableType): config = MemGPTConfig.load() - if table_name == TableType.ARCHIVAL_MEMORY: + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: # create schema for archival memory class PassageModel(Base): """Defines data model for storing Passages (consisting of text, embedding)""" @@ -45,16 +42,29 @@ def get_db_model(table_name: str, table_type: TableType): # Assuming passage_id is the primary key id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True) + user_id = Column(String, nullable=False) + text = Column(String, nullable=False) doc_id = Column(String) agent_id = Column(String) data_source = Column(String) # agent_name if agent, data_source name if from data source - text = Column(String, nullable=False) embedding = mapped_column(Vector(config.embedding_dim)) - metadata_ = Column(JSON(astext_type=Text())) + metadata_ = Column(mutable_json_type(dbtype=JSONB, nested=True)) def __repr__(self): return f" List[Record]: session = self.Session() filters = self.get_filters(filters) - db_passages = session.query(self.db_model).filter(*filters).limit(limit).all() - return [self.type(**p.to_dict()) for p in db_passages] + db_records = session.query(self.db_model).filter(*filters).limit(limit).all() + return [record.to_record() for record in db_records] - def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]: + def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]: session = self.Session() filters = self.get_filters(filters) - db_passage = session.query(self.db_model).filter(*filters).get(id) - if db_passage is None: + db_record = session.query(self.db_model).filter(*filters).get(id) + if db_record is None: return None - return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id) + return db_record.to_record() def size(self, filters: Optional[Dict] = {}) -> int: # return size of table + print("size") session = self.Session() filters = self.get_filters(filters) return session.query(self.db_model).filter(*filters).count() - def insert(self, passage: Passage): + def insert(self, record: Record): session = self.Session() - db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding) - session.add(db_passage) + db_record = self.db_model(**vars(record)) + session.add(db_record) session.commit() def insert_many(self, records: List[Record], show_progress=True): session = self.Session() iterable = tqdm(records) if show_progress else records - for passage in iterable: - db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding) - session.add(db_passage) + for record in iterable: + db_record = self.db_model(**vars(record)) + session.add(db_record) session.commit() def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: session = self.Session() - # Assuming PassageModel.embedding has the capability of computing l2_distance filters = self.get_filters(filters) results = session.scalars( select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k) ).all() # Convert the results into Passage objects - records = [self.type(**vars(result)) for result in results] + records = [result.to_record() for result in results] return records def save(self): @@ -195,6 +229,26 @@ class PostgresStorageConnector(StorageConnector): tables = [table[start_chars:] for table in tables] return tables + def query_date(self, start_date, end_date): + session = self.Session() + filters = self.get_filters({}) + results = ( + session.query(self.db_model) + .filter(*filters) + .filter(self.db_model.created_at >= start_date) + .filter(self.db_model.created_at <= end_date) + .all() + ) + return [result.to_record() for result in results] + + def query_text(self, query): + # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204 + session = self.Session() + filters = self.get_filters({}) + results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all() + print(results) + # return [self.type(**vars(result)) for result in results] + return [result.to_record() for result in results] class LanceDBConnector(StorageConnector): """Storage via LanceDB""" @@ -251,7 +305,7 @@ class LanceDBConnector(StorageConnector): def get(self, id: str) -> Optional[Passage]: db_passage = self.table.where(f"passage_id={id}").to_list() - if len(db_passage) == 0: + if len(db_passage) == 0: return None return Passage( text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"] diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py index 1566df554..a47c2f53f 100644 --- a/memgpt/connectors/local.py +++ b/memgpt/connectors/local.py @@ -6,7 +6,8 @@ import pickle import os -from typing import List, Optional +from typing import List, Optional, Dict +from abc import abstractmethod from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context from llama_index.indices.empty.base import EmptyIndex @@ -145,11 +146,9 @@ class InMemoryStorageConnector(StorageConnector): # TODO: maybae replace this with sqllite? - def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None): - from memgpt.embeddings import embedding_model - + def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None): + super().__init__(table_type=table_type, agent_config=agent_config) config = MemGPTConfig.load() - # TODO: figure out save location self.rows = [] diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py index ed2d37ebd..61f666a51 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -6,7 +6,7 @@ from typing import Any, Optional, List, Iterator import re import pickle import os - +from abc import abstractmethod from typing import List, Optional, Dict from tqdm import tqdm @@ -14,6 +14,7 @@ from tqdm import tqdm from memgpt.config import AgentConfig, MemGPTConfig from memgpt.data_types import Record, Passage, Document, Message +from memgpt.utils import printd # ENUM representing table types in MemGPT @@ -28,9 +29,9 @@ class TableType: # table names used by MemGPT -RECALL_TABLE_NAME = "memgpt_recall_memory" -ARCHIVAL_TABLE_NAME = "memgpt_archival_memory" -PASSAGE_TABLE_NAME = "memgpt_passages" +RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory +ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory +PASSAGE_TABLE_NAME = "memgpt_passages" # loads data sources DOCUMENT_TABLE_NAME = "memgpt_documents" @@ -65,9 +66,10 @@ class StorageConnector: # get all filters for query if filters is not None: filter_conditions = {**self.filters, **filters} - return self.filters + [self.db_model[key] == value for key, value in filter_conditions.items()] else: - return self.filters + filter_conditions = self.filters + print("FILTERS", filter_conditions) + return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()] def generate_table_name(self, agent_config: AgentConfig, table_type: TableType): @@ -102,18 +104,20 @@ class StorageConnector: if storage_type == "local": from memgpt.connectors.local import VectorIndexStorageConnector - return VectorIndexStorageConnector(agent_config=agent_config) + return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) elif storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector - return PostgresStorageConnector(agent_config=agent_config) + return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) + elif storage_type == "chroma": + from memgpt.connectors.chroma import ChromaStorageConnector - return ChromaStorageConnector(name=name, agent_config=agent_config) + return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) elif storage_type == "lancedb": from memgpt.connectors.db import LanceDBConnector - return LanceDBConnector(agent_config=agent_config) + return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) else: raise NotImplementedError(f"Storage type {storage_type} not implemented") @@ -122,6 +126,8 @@ class StorageConnector: def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None): storage_type = MemGPTConfig.load().recall_storage_type + print("Recall storage type", storage_type) + if storage_type == "local": from memgpt.connectors.local import InMemoryStorageConnector diff --git a/memgpt/data_types.py b/memgpt/data_types.py index e408f7232..08c1e11fc 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -20,11 +20,10 @@ class Record: self.text = text self.id = id # todo: generate unique uuid - # todo: timestamp # todo: self.role = role (?) - def __repr__(self): - pass + # def __repr__(self): + # pass class Message(Record): @@ -35,17 +34,19 @@ class Message(Record): user_id: str, agent_id: str, role: str, - content: str, + text: str, model: str, # model used to make function call + created_at: Optional[str] = None, function_name: Optional[str] = None, # name of function called function_args: Optional[str] = None, # args of function called function_response: Optional[str] = None, # response of function called embedding: Optional[np.ndarray] = None, id: Optional[str] = None, ): - super().__init__(user_id, agent_id, content, id) + super().__init__(user_id, agent_id, text, id) self.role = role # role (agent/user/function) self.model = model # model name (e.g. gpt-4) + self.created_at = created_at # function call info (optional) self.function_name = function_name @@ -55,15 +56,15 @@ class Message(Record): # embedding (optional) self.embedding = embedding - def __repr__(self): - pass + # def __repr__(self): + # pass class Document(Record): """A document represent a document loaded into MemGPT, which is broken down into passages.""" def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None): - super().__init__(user_id) + super().__init__(user_id, agent_id, text, id) self.text = text self.document_id = document_id self.data_source = data_source @@ -76,25 +77,26 @@ class Document(Record): class Passage(Record): """A passage is a single unit of memory, and a standard format accross all storage backends. - It is a string of text with an associated embedding. + It is a string of text with an assoidciated embedding. """ def __init__( self, user_id: str, text: str, - data_source: str, - embedding: np.ndarray, + agent_id: Optional[str] = None, # set if contained in agent memory + embedding: Optional[np.ndarray] = None, + data_source: Optional[str] = None, # None if created by agent doc_id: Optional[str] = None, - passage_id: Optional[str] = None, + id: Optional[str] = None, + metadata: Optional[dict] = {}, ): - super().__init__(user_id) + super().__init__(user_id, agent_id, text, id) self.text = text self.data_source = data_source self.embedding = embedding self.doc_id = doc_id - self.passage_id = passage_id - self.metadata = {} + self.metadata = metadata def __repr__(self): return f"Passage(text={self.text}, embedding={self.embedding})" diff --git a/memgpt/memory.py b/memgpt/memory.py index 0ac27aab6..482860ce7 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -284,7 +284,10 @@ class DummyRecallMemory(RecallMemory): return matches, len(matches) -class RecallMemorySQL(RecallMemory): +class BaseRecallMemory(RecallMemory): + + """Recall memory based on base functions implemented by storage connectors""" + def __init__(self, agent_config, restrict_search_to_summaries=False): # If true, the pool of messages that can be queried are the automated summaries only @@ -304,20 +307,39 @@ class RecallMemorySQL(RecallMemory): # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} - @abstractmethod def text_search(self, query_string, count=None, start=None): - pass + self.storage.query_text(query_string, count, start) - @abstractmethod - def date_search(self, query_string, count=None, start=None): - pass + def date_search(self, start_date, end_date, count=None, start=None): + self.storage.query_date(start_date, end_date, count, start) - @abstractmethod def __repr__(self) -> str: - pass + total = self.storage.size() + system_count = self.storage.size(filters={"role": "system"}) + user_count = self.storage.size(filters={"role": "user"}) + assistant_count = self.storage.size(filters={"role": "assistant"}) + function_count = self.storage.size(filters={"role": "function"}) + other_count = total - (system_count + user_count + assistant_count + function_count) + + memory_str = ( + f"Statistics:" + + f"\n{total} total messages" + + f"\n{system_count} system" + + f"\n{user_count} user" + + f"\n{assistant_count} assistant" + + f"\n{function_count} function" + + f"\n{other_count} other" + ) + return f"\n### RECALL MEMORY ###" + f"\n{memory_str}" def insert(self, message: Message): - pass + self.storage.insert(message) + + def insert_many(self, messages: List[Message]): + self.storage.insert_many(messages) + + def save(self): + self.storage.save() class EmbeddingArchivalMemory(ArchivalMemory): @@ -333,24 +355,31 @@ class EmbeddingArchivalMemory(ArchivalMemory): self.top_k = top_k self.agent_config = agent_config - config = MemGPTConfig.load() + self.config = MemGPTConfig.load() # create embedding model self.embed_model = embedding_model() - self.embedding_chunk_size = config.embedding_chunk_size + self.embedding_chunk_size = self.config.embedding_chunk_size # create storage backend self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config) # TODO: have some mechanism for cleanup otherwise will lead to OOM self.cache = {} + def create_passage(self, text, embedding): + return Passage( + user_id=self.config.anon_clientid, + agent_id=self.agent_config.name, + text=text, + embedding=embedding, + ) + def save(self): """Save the index to disk""" self.storage.save() def insert(self, memory_string): """Embed and save memory string""" - from memgpt.connectors.storage import Passage if not isinstance(memory_string, str): return TypeError("memory must be a string") @@ -364,17 +393,7 @@ class EmbeddingArchivalMemory(ArchivalMemory): # breakup string into passages for node in parser.get_nodes_from_documents([Document(text=memory_string)]): embedding = self.embed_model.get_text_embedding(node.text) - # fixing weird bug where type returned isn't a list, but instead is an object - # eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023, - if isinstance(embedding, dict): - try: - embedding = embedding["data"][0]["embedding"] - except (KeyError, IndexError): - # TODO as a fallback, see if we can find any lists in the payload - raise TypeError( - f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" - ) - passages.append(Passage(text=node.text, embedding=embedding, doc_id=f"agent_{self.agent_config.name}_memory")) + passages.append(self.create_passage(node.text, embedding)) # insert passages self.storage.insert_many(passages) diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 80f3c39b5..6e0d3dcf9 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -3,9 +3,19 @@ import pickle from memgpt.config import AgentConfig from memgpt.memory import ( DummyRecallMemory, + BaseRecallMemory, EmbeddingArchivalMemory, ) -from memgpt.utils import get_local_time, printd, OpenAIBackcompatUnpickler +from memgpt.utils import get_local_time, printd +from memgpt.data_types import Message +from memgpt.config import MemGPTConfig + +from datetime import datetime + + +def parse_formatted_time(formatted_time): + # parse times returned by memgpt.utils.get_formatted_time() + return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z") class PersistenceManager(ABC): @@ -33,83 +43,60 @@ class PersistenceManager(ABC): class LocalStateManager(PersistenceManager): """In-memory state manager has nothing to manage, all agents are held in-memory""" - recall_memory_cls = DummyRecallMemory + recall_memory_cls = BaseRecallMemory archival_memory_cls = EmbeddingArchivalMemory def __init__(self, agent_config: AgentConfig): # Memory held in-state useful for debugging stateful versions self.memory = None - self.messages = [] - self.all_messages = [] + self.messages = [] # current in-context messages + # self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB) self.archival_memory = EmbeddingArchivalMemory(agent_config) - self.recall_memory = None + self.recall_memory = BaseRecallMemory(agent_config) self.agent_config = agent_config + self.config = MemGPTConfig.load() @classmethod - def load(cls, filename, agent_config: AgentConfig): + def load(cls, agent_config: AgentConfig): """ Load a LocalStateManager from a file. """ "" - try: - with open(filename, "rb") as f: - data = pickle.load(f) - except ModuleNotFoundError as e: - # Patch for stripped openai package - # ModuleNotFoundError: No module named 'openai.openai_object' - with open(filename, "rb") as f: - unpickler = OpenAIBackcompatUnpickler(f) - data = unpickler.load() - # print(f"Unpickled data:\n{data.keys()}") - - from memgpt.openai_backcompat.openai_object import OpenAIObject - - def convert_openai_objects_to_dict(obj): - if isinstance(obj, OpenAIObject): - # Convert to dict or handle as needed - # print(f"detected OpenAIObject on {obj}") - return obj.to_dict_recursive() - elif isinstance(obj, dict): - return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [convert_openai_objects_to_dict(v) for v in obj] - else: - return obj - - data = convert_openai_objects_to_dict(data) - # print(f"Converted data:\n{data.keys()}") - + # TODO: remove this class and just init the class manager = cls(agent_config) - manager.all_messages = data["all_messages"] - manager.messages = data["messages"] - manager.recall_memory = data["recall_memory"] - manager.archival_memory = EmbeddingArchivalMemory(agent_config) return manager - def save(self, filename): - with open(filename, "wb") as fh: - ## TODO: fix this hacky solution to pickle the retriever - self.archival_memory.save() - pickle.dump( - { - "recall_memory": self.recall_memory, - "messages": self.messages, - "all_messages": self.all_messages, - }, - fh, - protocol=pickle.HIGHEST_PROTOCOL, - ) - printd(f"Saved state to {fh}") + def save(self): + """Ensure storage connectors save data""" + self.archival_memory.save() + self.recall_memory.save() def init(self, agent): + """Connect persistent state manager to agent""" printd(f"Initializing {self.__class__.__name__} with agent object") - self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] + # self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] self.memory = agent.memory - printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}") + # printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}") printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}") # Persistence manager also handles DB-related state - self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) + # self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) - # TODO: init archival memory here? + def json_to_message(self, message_json) -> Message: + """Convert agent message JSON into Message object""" + timestamp = message_json["timestamp"] + message = message_json["message"] + + return Message( + user_id=self.config.anon_clientid, + agent_id=self.agent_config.name, + role=message["role"], + text=message["content"], + model=self.agent_config.model, + created_at=parse_formatted_time(timestamp), + function_name=message["function_name"] if "function_name" in message else None, + function_args=message["function_args"] if "function_args" in message else None, + function_response=message["function_response"] if "function_response" in message else None, + id=message["id"] if "id" in message else None, + ) def trim_messages(self, num): # printd(f"InMemoryStateManager.trim_messages") @@ -121,7 +108,9 @@ class LocalStateManager(PersistenceManager): printd(f"{self.__class__.__name__}.prepend_to_message") self.messages = [self.messages[0]] + added_messages + self.messages[1:] - self.all_messages.extend(added_messages) + + # add to recall memory + self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages]) def append_to_messages(self, added_messages): # first tag with timestamps @@ -129,7 +118,9 @@ class LocalStateManager(PersistenceManager): printd(f"{self.__class__.__name__}.append_to_messages") self.messages = self.messages + added_messages - self.all_messages.extend(added_messages) + + # add to recall memory + self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages]) def swap_system_message(self, new_system_message): # first tag with timestamps @@ -137,96 +128,9 @@ class LocalStateManager(PersistenceManager): printd(f"{self.__class__.__name__}.swap_system_message") self.messages[0] = new_system_message - self.all_messages.append(new_system_message) - - def update_memory(self, new_memory): - printd(f"{self.__class__.__name__}.update_memory") - self.memory = new_memory - - -class StateManager(PersistenceManager): - """In-memory state manager has nothing to manage, all agents are held in-memory""" - - recall_memory_cls = RecallMemory - archival_memory_cls = EmbeddingArchivalMemory - - def __init__(self, agent_config: AgentConfig): - # Memory held in-state useful for debugging stateful versions - self.memory = None - self.messages = [] - self.all_messages = [] - self.archival_memory = EmbeddingArchivalMemory(agent_config) - self.recall_memory = None - self.agent_config = agent_config - - @classmethod - def load(cls, filename, agent_config: AgentConfig): - """ Load a LocalStateManager from a file. """ "" - with open(filename, "rb") as f: - data = pickle.load(f) - - manager = cls(agent_config) - manager.all_messages = data["all_messages"] - manager.messages = data["messages"] - manager.recall_memory = data["recall_memory"] - manager.archival_memory = EmbeddingArchivalMemory(agent_config) - return manager - - def save(self, filename): - with open(filename, "wb") as fh: - ## TODO: fix this hacky solution to pickle the retriever - self.archival_memory.save() - pickle.dump( - { - "recall_memory": self.recall_memory, - "messages": self.messages, - "all_messages": self.all_messages, - }, - fh, - protocol=pickle.HIGHEST_PROTOCOL, - ) - printd(f"Saved state to {fh}") - - def init(self, agent): - printd(f"Initializing {self.__class__.__name__} with agent object") - self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] - self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()] - self.memory = agent.memory - printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}") - printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}") - - # Persistence manager also handles DB-related state - self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) - - # TODO: init archival memory here? - - def trim_messages(self, num): - # printd(f"InMemoryStateManager.trim_messages") - self.messages = [self.messages[0]] + self.messages[num:] - - def prepend_to_messages(self, added_messages): - # first tag with timestamps - added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] - - printd(f"{self.__class__.__name__}.prepend_to_message") - self.messages = [self.messages[0]] + added_messages + self.messages[1:] - self.all_messages.extend(added_messages) - - def append_to_messages(self, added_messages): - # first tag with timestamps - added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages] - - printd(f"{self.__class__.__name__}.append_to_messages") - self.messages = self.messages + added_messages - self.all_messages.extend(added_messages) - - def swap_system_message(self, new_system_message): - # first tag with timestamps - new_system_message = {"timestamp": get_local_time(), "message": new_system_message} - - printd(f"{self.__class__.__name__}.swap_system_message") - self.messages[0] = new_system_message - self.all_messages.append(new_system_message) + + # add to recall memory + self.recall_memory.insert(self.json_to_message(new_system_message)) def update_memory(self, new_memory): printd(f"{self.__class__.__name__}.update_memory") diff --git a/memgpt/utils.py b/memgpt/utils.py index b0e858a47..f3b8c39fb 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -117,10 +117,11 @@ def get_local_time(timezone=None): time_str = get_local_time_timezone(timezone) else: # Get the current time, which will be in the local timezone of the computer - local_time = datetime.now() + local_time = datetime.now().astimezone() # You may format it as you desire, including AM/PM - time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + print("formatted_time", formatted_time) return time_str.strip() diff --git a/poetry.lock b/poetry.lock index ef90397cd..8e311ee99 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3822,6 +3822,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"] pymysql = ["pymysql"] sqlcipher = ["sqlcipher3-binary"] +[[package]] +name = "sqlalchemy-json" +version = "0.7.0" +description = "JSON type with nested change tracking for SQLAlchemy" +optional = false +python-versions = ">= 3.6" +files = [ + {file = "sqlalchemy-json-0.7.0.tar.gz", hash = "sha256:620d0b26f648f21a8fa9127df66f55f83a5ab4ae010e5397a5c6989a08238561"}, + {file = "sqlalchemy_json-0.7.0-py3-none-any.whl", hash = "sha256:27881d662ca18363a4ac28175cc47ea2a6f2bef997ae1159c151026b741818e6"}, +] + +[package.dependencies] +sqlalchemy = ">=0.7" + +[package.extras] +dev = ["pytest"] + [[package]] name = "starlette" version = "0.27.0" @@ -4913,4 +4930,4 @@ server = ["fastapi", "uvicorn", "websockets"] [metadata] lock-version = "2.0" python-versions = "<3.12,>=3.9" -content-hash = "12010863b2b9c1e26dceace00ea4e1ea7cc95932ab77b1ef37a5473c2e375575" +content-hash = "a6b8cdeda433007b6d33441b82417e435df336cea458af5dedce9a94f07f1225" diff --git a/pyproject.toml b/pyproject.toml index 7ebfe3b2d..e214e6307 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,14 +49,8 @@ tiktoken = "^0.5.1" python-box = "^7.1.1" pypdf = "^3.17.1" pyyaml = "^6.0.1" -fastapi = {version = "^0.104.1", optional = true} -uvicorn = {version = "^0.24.0.post1", optional = true} -chromadb = "^0.4.18" -pytest-asyncio = {version = "^0.23.2", optional = true} -pydantic = "^2.5.2" -pyautogen = {version = "0.2.0", optional = true} -html2text = "^2020.1.16" -docx2txt = "^0.8" +chromadb = {version = "^0.4.18", optional = true} +sqlalchemy-json = "^0.7.0" [tool.poetry.extras] local = ["torch", "huggingface-hub", "transformers"] diff --git a/tests/test_storage.py b/tests/test_storage.py index 6688a550b..44d3b8fb0 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -3,63 +3,90 @@ import subprocess import sys import pytest -subprocess.check_call( - [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"] -) # , "psycopg_binary"]) # "psycopg", "libpq-dev"]) - -subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) +# subprocess.check_call( +# [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"] +# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"]) +# +# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) import pgvector # Try to import again after installing -from memgpt.connectors.storage import StorageConnector, Passage +from memgpt.connectors.storage import StorageConnector, TableType from memgpt.connectors.chroma import ChromaStorageConnector from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector from memgpt.embeddings import embedding_model from memgpt.data_types import Message, Passage from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.utils import get_local_time import argparse +from datetime import datetime, timedelta -def test_recall_db() -> None: +def test_recall_db(): # os.environ["MEMGPT_CONFIG_PATH"] = "./config" storage_type = "postgres" storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") - config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri) + config = MemGPTConfig( + recall_storage_type=storage_type, + recall_storage_uri=storage_uri, + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt4", + ) print(config.config_path) assert config.recall_storage_uri is not None config.save() print(config) - conn = StorageConnector.get_recall_storage_connector() + agent_config = AgentConfig( + persona=config.persona, + human=config.human, + model=config.model, + ) + + conn = StorageConnector.get_recall_storage_connector(agent_config) # construct recall memory messages message1 = Message( - agent_id="test_agent1", + agent_id=agent_config.name, role="agent", - content="This is a test message", - id="test_id1", + text="This is a test message", + user_id=config.anon_clientid, + model=agent_config.model, + created_at=datetime.now(), ) message2 = Message( - agent_id="test_agent2", + agent_id=agent_config.name, role="user", - content="This is a test message", - id="test_id2", + text="This is a test message", + user_id=config.anon_clientid, + model=agent_config.model, + created_at=datetime.now(), ) + print(vars(message1)) # test insert conn.insert(message1) conn.insert_many([message2]) # test size - assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}" - assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}" + assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}" + assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}' - # test get - assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}" - assert ( - len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1 - ), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}" + # test text query + res = conn.query_text("test") + print(res) + assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" + + # test date query + current_time = datetime.now() + ten_weeks_ago = current_time - timedelta(weeks=1) + res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time) + print(res) + assert len(res) >= 2, f"Expected 2 messages, got {len(res)}" + + print(conn.get_all()) @pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key") @@ -83,10 +110,28 @@ def test_postgres_openai(): passage = ["This is a test passage", "This is another test passage", "Cinderella wept"] - db = PostgresStorageConnector(name="test-openai") + agent_config = AgentConfig( + name="test_agent", + persona=config.persona, + human=config.human, + model=config.model, + ) + db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY) + + # db.delete() + # return for passage in passage: - db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage))) + db.insert( + Passage( + text=passage, + embedding=embed_model.get_text_embedding(passage), + user_id=config.anon_clientid, + agent_id="test_agent", + data_source="test", + metadata={"test_metadata_key": "test_metadata_value"}, + ) + ) print(db.get_all()) @@ -246,3 +291,6 @@ def test_lancedb_local(): assert len(res) == 2, f"Expected 2 results, got {len(res)}" assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" + + +test_recall_db()