mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

Co-authored-by: Charles Packer <packercharles@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
149 lines
5.1 KiB
Python
149 lines
5.1 KiB
Python
from abc import ABC, abstractmethod
|
|
from datetime import datetime
|
|
from typing import List
|
|
|
|
from letta.memory import BaseRecallMemory, EmbeddingArchivalMemory
|
|
from letta.schemas.agent import AgentState
|
|
from letta.schemas.memory import Memory
|
|
from letta.schemas.message import Message
|
|
from letta.utils import printd
|
|
|
|
|
|
def parse_formatted_time(formatted_time: str) -> datetime:
    """Parse a timestamp string produced by letta.utils.get_formatted_time().

    Two layouts are accepted, tried in order:

    1. ``"%Y-%m-%d %I:%M:%S %p %Z%z"`` - 12-hour clock followed by a timezone
       name and UTC offset (yields a timezone-aware datetime).
    2. ``"%Y-%m-%d %I:%M:%S %p"`` - the same without any timezone suffix
       (yields a naive datetime).

    Leading and trailing whitespace is ignored.

    Args:
        formatted_time: The timestamp text to parse.

    Returns:
        The parsed datetime.

    Raises:
        ValueError: If the text matches neither layout.
    """
    cleaned = formatted_time.strip()
    try:
        return datetime.strptime(cleaned, "%Y-%m-%d %I:%M:%S %p %Z%z")
    except ValueError:
        # Only a format mismatch should trigger the fallback; anything else
        # (e.g. KeyboardInterrupt) must propagate instead of being swallowed.
        return datetime.strptime(cleaned, "%Y-%m-%d %I:%M:%S %p")
|
|
|
|
|
|
class PersistenceManager(ABC):
    """Abstract interface for objects that persist an agent's message/memory state.

    Concrete subclasses must implement every method below; the base
    implementations do nothing and return ``None``.
    """

    @abstractmethod
    def trim_messages(self, num):
        """Trim the stored message history by ``num`` (exact semantics are up to the subclass)."""

    @abstractmethod
    def prepend_to_messages(self, added_messages):
        """Record ``added_messages`` as having been added to the front of the message list."""

    @abstractmethod
    def append_to_messages(self, added_messages):
        """Record ``added_messages`` as having been added to the end of the message list."""

    @abstractmethod
    def swap_system_message(self, new_system_message):
        """Record that the system message was replaced by ``new_system_message``."""

    @abstractmethod
    def update_memory(self, new_memory):
        """Replace the managed memory with ``new_memory``."""
|
|
|
|
|
|
class LocalStateManager(PersistenceManager):
    """In-memory state manager has nothing to manage, all agents are held in-memory"""

    # Classes used by __init__ to construct the recall and archival stores.
    # Subclasses can override these to plug in different implementations.
    recall_memory_cls = BaseRecallMemory
    archival_memory_cls = EmbeddingArchivalMemory

    def __init__(self, agent_state: AgentState):
        """Set up the memory objects for a single agent.

        Args:
            agent_state: State of the agent being managed; its ``memory`` is
                kept on this manager and the state is handed to the archival
                and recall memory constructors.
        """
        # Memory held in-state useful for debugging stateful versions
        self.memory = agent_state.memory
        # self.messages = []  # current in-context messages
        # self.all_messages = []  # all messages seen in current session (needed if lazily synchronizing state with DB)

        # Build the stores through the class-level attributes (rather than
        # hard-coding the classes) so that overriding them actually takes effect.
        self.archival_memory = self.archival_memory_cls(agent_state)
        self.recall_memory = self.recall_memory_cls(agent_state)
        # self.agent_state = agent_state

    def save(self) -> None:
        """Ensure storage connectors save data"""
        self.archival_memory.save()
        self.recall_memory.save()

    # NOTE: the string literal below is disabled legacy code kept for reference;
    # it is a bare expression with no runtime effect.
    '''
    def json_to_message(self, message_json) -> Message:
        """Convert agent message JSON into Message object"""

        # get message
        if "message" in message_json:
            message = message_json["message"]
        else:
            message = message_json

        # get timestamp
        if "timestamp" in message_json:
            timestamp = parse_formatted_time(message_json["timestamp"])
        else:
            timestamp = get_local_time()

        # TODO: change this when we fully migrate to tool calls API
        if "function_call" in message:
            tool_calls = [
                ToolCall(
                    id=message["tool_call_id"],
                    tool_call_type="function",
                    function={
                        "name": message["function_call"]["name"],
                        "arguments": message["function_call"]["arguments"],
                    },
                )
            ]
            printd(f"Saving tool calls {[vars(tc) for tc in tool_calls]}")
        else:
            tool_calls = None

        # if message["role"] == "function":
        #     message["role"] = "tool"

        return Message(
            user_id=self.agent_state.user_id,
            agent_id=self.agent_state.id,
            role=message["role"],
            text=message["content"],
            name=message["name"] if "name" in message else None,
            model=self.agent_state.llm_config.model,
            created_at=timestamp,
            tool_calls=tool_calls,
            tool_call_id=message["tool_call_id"] if "tool_call_id" in message else None,
            id=message["id"] if "id" in message else None,
        )
    '''

    def trim_messages(self, num) -> None:
        """No-op: this manager keeps no message list of its own to trim."""
        # printd(f"InMemoryStateManager.trim_messages")
        # self.messages = [self.messages[0]] + self.messages[num:]
        pass

    def prepend_to_messages(self, added_messages: List[Message]) -> None:
        """Log that messages were prepended; nothing is stored.

        Args:
            added_messages: The messages that were prepended.
        """
        # first tag with timestamps
        # added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd(f"{self.__class__.__name__}.prepend_to_message")
        # self.messages = [self.messages[0]] + added_messages + self.messages[1:]

        # add to recall memory
        # NOTE(review): unlike append_to_messages, nothing is inserted into
        # recall memory here despite the comment above - confirm this is intentional.

    def append_to_messages(self, added_messages: List[Message]) -> None:
        """Log the append and insert the new messages into recall memory.

        Args:
            added_messages: The messages that were appended.
        """
        # first tag with timestamps
        # added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd(f"{self.__class__.__name__}.append_to_messages")
        # self.messages = self.messages + added_messages

        # add to recall memory (pass a fresh list, not the caller's object)
        self.recall_memory.insert_many(list(added_messages))

    def swap_system_message(self, new_system_message: Message) -> None:
        """Log the swap and insert the new system message into recall memory.

        Args:
            new_system_message: The replacement system message.
        """
        # first tag with timestamps
        # new_system_message = {"timestamp": get_local_time(), "message": new_system_message}

        printd(f"{self.__class__.__name__}.swap_system_message")
        # self.messages[0] = new_system_message

        # add to recall memory
        self.recall_memory.insert(new_system_message)

    def update_memory(self, new_memory: Memory) -> None:
        """Replace the in-state memory object.

        Args:
            new_memory: The new ``Memory`` instance to hold.

        Raises:
            AssertionError: If ``new_memory`` is not a ``Memory`` instance.
        """
        printd(f"{self.__class__.__name__}.update_memory")
        # Kept as an assert so callers still see AssertionError on a bad type.
        assert isinstance(new_memory, Memory), type(new_memory)
        self.memory = new_memory
|