fixed bug where persistence manager was not saving in demo CLI

This commit is contained in:
Charles Packer 2023-10-17 23:40:31 -07:00
parent e3f784a925
commit 5714cda986
3 changed files with 38 additions and 4 deletions

21
main.py
View File

@ -132,6 +132,18 @@ async def main():
print(f"Saved checkpoint to: {filename}")
except Exception as e:
print(f"Saving state to {filename} failed with: {e}")
# save the persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
try:
memgpt_agent.persistence_manager.save(filename)
# with open(filename, 'wb') as fh:
# p_dump = memgpt_agent.persistence_manager.save()
# pickle.dump(p_dump, fh, protocol=pickle.HIGHEST_PROTOCOL)
print(f"Saved persistence manager to: {filename}")
except Exception as e:
print(f"Saving persistence manager to {filename} failed with: {e}")
continue
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
@ -145,6 +157,15 @@ async def main():
print(f"Loading {filename} failed with: {e}")
else:
print(f"/load error: no checkpoint specified")
# need to load persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
try:
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from: {filename}")
except Exception as e:
print(f"/load error: loading persistence manager from {filename} failed with: {e}")
continue
elif user_input.lower() == "/dump":

View File

@ -250,7 +250,7 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
"""Dummy in-memory version of an archival memory database, using a FAISS
index for fast nearest-neighbors embedding search.
Archival memory is effectively "infinite" overflow for core memory,
and is read-only via string queries.
@ -291,7 +291,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
"""Simple embedding-based search (inefficient, no caching)"""
# see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# our wrapped version supports backoff/rate-limits
if query_string in self.embeddings_dict:
query_embedding = self.embeddings_dict[query_string]

View File

@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import pickle
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
from .utils import get_local_time, printd
@ -39,6 +40,15 @@ class InMemoryStateManager(PersistenceManager):
self.messages = []
self.all_messages = []
@staticmethod
def load(filename):
with open(filename, 'rb') as f:
return pickle.load(f)
def save(self, filename):
with open(filename, 'wb') as fh:
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
def init(self, agent):
printd(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
@ -91,7 +101,7 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
def __init__(self, archival_memory_db):
self.archival_memory_db = archival_memory_db
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
@ -117,7 +127,10 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
self.archival_index = archival_index
self.archival_memory_db = archival_memory_db
self.a_k = a_k
def save(self, _filename):
raise NotImplementedError
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]