mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
fixed bug where persistence manager was not saving in demo CLI
This commit is contained in:
parent
e3f784a925
commit
5714cda986
21
main.py
21
main.py
@ -132,6 +132,18 @@ async def main():
|
||||
print(f"Saved checkpoint to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving state to {filename} failed with: {e}")
|
||||
|
||||
# save the persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager.save(filename)
|
||||
# with open(filename, 'wb') as fh:
|
||||
# p_dump = memgpt_agent.persistence_manager.save()
|
||||
# pickle.dump(p_dump, fh, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
print(f"Saved persistence manager to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving persistence manager to {filename} failed with: {e}")
|
||||
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
|
||||
@ -145,6 +157,15 @@ async def main():
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
else:
|
||||
print(f"/load error: no checkpoint specified")
|
||||
|
||||
# need to load persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
|
||||
print(f"Loaded persistence manager from: {filename}")
|
||||
except Exception as e:
|
||||
print(f"/load error: loading persistence manager from {filename} failed with: {e}")
|
||||
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/dump":
|
||||
|
@ -250,7 +250,7 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
|
||||
class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
"""Dummy in-memory version of an archival memory database, using a FAISS
|
||||
index for fast nearest-neighbors embedding search.
|
||||
|
||||
|
||||
Archival memory is effectively "infinite" overflow for core memory,
|
||||
and is read-only via string queries.
|
||||
|
||||
@ -291,7 +291,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
"""Simple embedding-based search (inefficient, no caching)"""
|
||||
# see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
|
||||
|
||||
# query_embedding = get_embedding(query_string, model=self.embedding_model)
|
||||
# query_embedding = get_embedding(query_string, model=self.embedding_model)
|
||||
# our wrapped version supports backoff/rate-limits
|
||||
if query_string in self.embeddings_dict:
|
||||
query_embedding = self.embeddings_dict[query_string]
|
||||
|
@ -1,4 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import pickle
|
||||
|
||||
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
|
||||
from .utils import get_local_time, printd
|
||||
@ -39,6 +40,15 @@ class InMemoryStateManager(PersistenceManager):
|
||||
self.messages = []
|
||||
self.all_messages = []
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
with open(filename, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, 'wb') as fh:
|
||||
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def init(self, agent):
|
||||
printd(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
@ -91,7 +101,7 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
|
||||
|
||||
def __init__(self, archival_memory_db):
|
||||
self.archival_memory_db = archival_memory_db
|
||||
|
||||
|
||||
def init(self, agent):
|
||||
print(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
@ -117,7 +127,10 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
|
||||
self.archival_index = archival_index
|
||||
self.archival_memory_db = archival_memory_db
|
||||
self.a_k = a_k
|
||||
|
||||
|
||||
def save(self, _filename):
|
||||
raise NotImplementedError
|
||||
|
||||
def init(self, agent):
|
||||
print(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
|
||||
|
Loading…
Reference in New Issue
Block a user