diff --git a/main.py b/main.py index 10d14b5d2..d628b49c3 100644 --- a/main.py +++ b/main.py @@ -132,6 +132,18 @@ async def main(): print(f"Saved checkpoint to: {filename}") except Exception as e: print(f"Saving state to {filename} failed with: {e}") + + # save the persistence manager too + filename = filename.replace('.json', '.persistence.pickle') + try: + memgpt_agent.persistence_manager.save(filename) + # with open(filename, 'wb') as fh: + # p_dump = memgpt_agent.persistence_manager.save() + # pickle.dump(p_dump, fh, protocol=pickle.HIGHEST_PROTOCOL) + print(f"Saved persistence manager to: {filename}") + except Exception as e: + print(f"Saving persistence manager to {filename} failed with: {e}") + continue elif user_input.lower() == "/load" or user_input.lower().startswith("/load "): @@ -145,6 +157,15 @@ async def main(): print(f"Loading {filename} failed with: {e}") else: print(f"/load error: no checkpoint specified") + + # need to load persistence manager too + filename = filename.replace('.json', '.persistence.pickle') + try: + memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods + print(f"Loaded persistence manager from: {filename}") + except Exception as e: + print(f"/load error: loading persistence manager from {filename} failed with: {e}") + continue elif user_input.lower() == "/dump": diff --git a/memgpt/memory.py b/memgpt/memory.py index c36dabdb6..659076d56 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -250,7 +250,7 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): """Dummy in-memory version of an archival memory database, using a FAISS index for fast nearest-neighbors embedding search. - + Archival memory is effectively "infinite" overflow for core memory, and is read-only via string queries. @@ -291,7 +291,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): """Simple embedding-based search (inefficient, no caching)""" # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb - # query_embedding = get_embedding(query_string, model=self.embedding_model) + # query_embedding = get_embedding(query_string, model=self.embedding_model) # our wrapped version supports backoff/rate-limits if query_string in self.embeddings_dict: query_embedding = self.embeddings_dict[query_string] diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 4950d1030..4874fe020 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import pickle from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss from .utils import get_local_time, printd @@ -39,6 +40,15 @@ class InMemoryStateManager(PersistenceManager): self.messages = [] self.all_messages = [] + @staticmethod + def load(filename): + with open(filename, 'rb') as f: + return pickle.load(f) + + def save(self, filename): + with open(filename, 'wb') as fh: + pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL) + def init(self, agent): printd(f"Initializing InMemoryStateManager with agent object") self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] @@ -91,7 +101,7 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager): def __init__(self, archival_memory_db): self.archival_memory_db = archival_memory_db - + def init(self, agent): print(f"Initializing InMemoryStateManager with agent object") self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] @@ -117,7 +127,10 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager): self.archival_index = archival_index self.archival_memory_db = archival_memory_db self.a_k = a_k - + + def save(self, _filename): + raise NotImplementedError + def init(self, agent): print(f"Initializing InMemoryStateManager with agent object") self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]