MemGPT/letta/memory.py

import datetime
from abc import ABC, abstractmethod
from typing import Callable, Dict, List, Tuple, Union
from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
from letta.embeddings import embedding_model, parse_and_chunk_text, query_embedding
from letta.llm_api.llm_api_tools import create
from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from letta.schemas.agent import AgentState
from letta.schemas.memory import Memory
from letta.schemas.message import Message
from letta.schemas.passage import Passage
from letta.utils import (
count_tokens,
extract_date_from_timestamp,
get_local_time,
printd,
validate_date_format,
)
def get_memory_functions(cls: Memory) -> Dict[str, Callable]:
"""Get memory functions for a memory class"""
functions = {}
    # collect the base Memory class's own functions (these are excluded from the result)
base_functions = []
for func_name in dir(Memory):
funct = getattr(Memory, func_name)
if callable(funct):
base_functions.append(func_name)
for func_name in dir(cls):
        if func_name.startswith("_") or func_name in ["load", "to_dict"]:  # skip private helpers and serialization methods
            continue
        if func_name in base_functions:  # don't include functions inherited from the Memory base class
            continue
func = getattr(cls, func_name)
if not callable(func): # not a function
continue
functions[func_name] = func
return functions
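# Example (hypothetical): for a Memory subclass that defines its own editing
# method, say `core_memory_append`, get_memory_functions(ThatSubclass) returns
# {"core_memory_append": <function ...>}; anything inherited from Memory itself,
# private helpers, and load/to_dict are excluded.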
def _format_summary_history(message_history: List[Message]):
    # TODO: use existing prompt formatters for this (e.g. ChatML)
return "\n".join([f"{m.role}: {m.text}" for m in message_history])
def summarize_messages(
agent_state: AgentState,
message_sequence_to_summarize: List[Message],
insert_acknowledgement_assistant_message: bool = True,
):
"""Summarize a message sequence using GPT"""
# we need the context_window
context_window = agent_state.llm_config.context_window
summary_prompt = SUMMARY_PROMPT_SYSTEM
summary_input = _format_summary_history(message_sequence_to_summarize)
summary_input_tkns = count_tokens(summary_input)
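    # If the input is itself too large for the context window, recursively summarize
    # the head of the message sequence, then rebuild the summarizer input from that
    # summary string plus the remaining (most recent) messages.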
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_FRAC * context_window:
trunc_ratio = (MESSAGE_SUMMARY_WARNING_FRAC * context_window / summary_input_tkns) * 0.8 # For good measure...
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
summary_input = str(
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff])]
+ message_sequence_to_summarize[cutoff:]
)
dummy_user_id = agent_state.user_id
dummy_agent_id = agent_state.id
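    # The messages below are constructed only for this one summarization request and
    # are never persisted; the user/agent IDs exist to satisfy the Message schema.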
message_sequence = []
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
if insert_acknowledgement_assistant_message:
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="assistant", text=MESSAGE_SUMMARY_REQUEST_ACK))
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="user", text=summary_input))
response = create(
llm_config=agent_state.llm_config,
user_id=agent_state.user_id,
messages=message_sequence,
stream=False,
)
printd(f"summarize_messages gpt reply: {response.choices[0]}")
reply = response.choices[0].message.content
return reply
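# Example usage (hypothetical): given a configured AgentState `agent_state` and a
# List[Message] `evicted` that is being trimmed from the context window:
#
#   summary_text = summarize_messages(agent_state, evicted)
#
# The returned string can then be re-inserted into the context as a summary.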
class ArchivalMemory(ABC):
@abstractmethod
def insert(self, memory_string: str):
"""Insert new archival memory
:param memory_string: Memory string to insert
:type memory_string: str
"""
@abstractmethod
def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
"""Search archival memory
:param query_string: Query string
:type query_string: str
:param count: Number of results to return (None for all)
:type count: Optional[int]
:param start: Offset to start returning results from (None if 0)
:type start: Optional[int]
:return: Tuple of (list of results, total number of results)
"""
@abstractmethod
def compile(self) -> str:
"""Convert archival memory into a string representation for a prompt"""
@abstractmethod
def count(self) -> int:
"""Count the number of memories in the archival memory"""
class RecallMemory(ABC):
@abstractmethod
def text_search(self, query_string, count=None, start=None):
"""Search messages that match query_string in recall memory"""
@abstractmethod
def date_search(self, start_date, end_date, count=None, start=None):
"""Search messages between start_date and end_date in recall memory"""
@abstractmethod
def compile(self) -> str:
"""Convert recall memory into a string representation for a prompt"""
@abstractmethod
def count(self) -> int:
"""Count the number of memories in the recall memory"""
@abstractmethod
def insert(self, message: Message):
"""Insert message into recall memory"""
class DummyRecallMemory(RecallMemory):
"""Dummy in-memory version of a recall memory database (eg run on MongoDB)
Recall memory here is basically just a full conversation history with the user.
Queryable via string matching, or date matching.
Recall Memory: The AI's capability to search through past interactions,
effectively allowing it to 'remember' prior engagements with a user.
"""
def __init__(self, message_database=None, restrict_search_to_summaries=False):
self._message_logs = [] if message_database is None else message_database # consists of full message dicts
# If true, the pool of messages that can be queried are the automated summaries only
# (generated when the conversation window needs to be shortened)
self.restrict_search_to_summaries = restrict_search_to_summaries
def __len__(self):
return len(self._message_logs)
def count(self) -> int:
return len(self)
def compile(self) -> str:
# don't dump all the conversations, just statistics
system_count = user_count = assistant_count = function_count = other_count = 0
for msg in self._message_logs:
role = msg["message"]["role"]
if role == "system":
system_count += 1
elif role == "user":
user_count += 1
elif role == "assistant":
assistant_count += 1
elif role == "function":
function_count += 1
else:
other_count += 1
        memory_str = (
            "Statistics:"
            f"\n{len(self._message_logs)} total messages"
            f"\n{system_count} system"
            f"\n{user_count} user"
            f"\n{assistant_count} assistant"
            f"\n{function_count} function"
            f"\n{other_count} other"
        )
        return f"\n### RECALL MEMORY ###\n{memory_str}"
def insert(self, message):
raise NotImplementedError("This should be handled by the PersistenceManager, recall memory is just a search layer on top")
def text_search(self, query_string, count=None, start=None):
# in the dummy version, run an (inefficient) case-insensitive match search
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
        # start/count support paging through results; count=None returns all matches from start
        start = 0 if start is None else int(start)
        count = None if count is None else int(count)
        printd(
            f"recall_memory.text_search: searching for {query_string} (c={count}, s={start}) in {len(self._message_logs)} total messages"
        )
        matches = [
            d for d in message_pool if d["message"]["content"] is not None and query_string.lower() in d["message"]["content"].lower()
        ]
        page = matches[start:] if count is None else matches[start : start + count]
        printd(f"recall_memory - matches:\n{page}")
        return page, len(matches)
def date_search(self, start_date, end_date, count=None, start=None):
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
# First, validate the start_date and end_date format
if not validate_date_format(start_date) or not validate_date_format(end_date):
raise ValueError("Invalid date format. Expected format: YYYY-MM-DD")
# Convert dates to datetime objects for comparison
start_date_dt = datetime.datetime.strptime(start_date, "%Y-%m-%d")
end_date_dt = datetime.datetime.strptime(end_date, "%Y-%m-%d")
# Next, match items inside self._message_logs
matches = [
d
for d in message_pool
if start_date_dt <= datetime.datetime.strptime(extract_date_from_timestamp(d["timestamp"]), "%Y-%m-%d") <= end_date_dt
]
        # start/count support paging through results; count=None returns all matches from start
        start = 0 if start is None else int(start)
        if count is None:
            return matches[start:], len(matches)
        return matches[start : start + int(count)], len(matches)
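# Example usage (hypothetical log shape, matching the dict access patterns above):
#
#   logs = [{"timestamp": "2024-01-02 12:00:00", "message": {"role": "user", "content": "hi there"}}]
#   recall = DummyRecallMemory(message_database=logs)
#   hits, total = recall.text_search("hi", count=10)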
class BaseRecallMemory(RecallMemory):
"""Recall memory based on base functions implemented by storage connectors"""
def __init__(self, agent_state, restrict_search_to_summaries=False):
# If true, the pool of messages that can be queried are the automated summaries only
# (generated when the conversation window needs to be shortened)
self.restrict_search_to_summaries = restrict_search_to_summaries
from letta.agent_store.storage import StorageConnector
self.agent_state = agent_state
# create embedding model
self.embed_model = embedding_model(agent_state.embedding_config)
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
# create storage backend
self.storage = StorageConnector.get_recall_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
        # TODO: add a cleanup/eviction mechanism for this cache; otherwise it grows without bound and can OOM
self.cache = {}
def get_all(self, start=0, count=None):
start = 0 if start is None else int(start)
count = 0 if count is None else int(count)
results = self.storage.get_all(start, count)
results_json = [message.to_openai_dict() for message in results]
return results_json, len(results)
def text_search(self, query_string, count=None, start=None):
start = 0 if start is None else int(start)
count = 0 if count is None else int(count)
results = self.storage.query_text(query_string, count, start)
results_json = [message.to_openai_dict_search_results() for message in results]
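        # NOTE: len(results) below is the size of the returned page, not the total
        # number of matching messages in storage.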
return results_json, len(results)
def date_search(self, start_date, end_date, count=None, start=None):
start = 0 if start is None else int(start)
count = 0 if count is None else int(count)
results = self.storage.query_date(start_date, end_date, count, start)
results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)
def compile(self) -> str:
total = self.storage.size()
system_count = self.storage.size(filters={"role": "system"})
user_count = self.storage.size(filters={"role": "user"})
assistant_count = self.storage.size(filters={"role": "assistant"})
function_count = self.storage.size(filters={"role": "function"})
other_count = total - (system_count + user_count + assistant_count + function_count)
        memory_str = (
            "Statistics:"
            f"\n{total} total messages"
            f"\n{system_count} system"
            f"\n{user_count} user"
            f"\n{assistant_count} assistant"
            f"\n{function_count} function"
            f"\n{other_count} other"
        )
        return f"\n### RECALL MEMORY ###\n{memory_str}"
def insert(self, message: Message):
self.storage.insert(message)
def insert_many(self, messages: List[Message]):
self.storage.insert_many(messages)
def save(self):
self.storage.save()
def __len__(self):
return self.storage.size()
def count(self) -> int:
return len(self)
class EmbeddingArchivalMemory(ArchivalMemory):
"""Archival memory with embedding based search"""
def __init__(self, agent_state: AgentState, top_k: int = 100):
"""Init function for archival memory
:param archival_memory_database: name of dataset to pre-fill archival with
:type archival_memory_database: str
"""
from letta.agent_store.storage import StorageConnector
self.top_k = top_k
self.agent_state = agent_state
# create embedding model
self.embed_model = embedding_model(agent_state.embedding_config)
        if agent_state.embedding_config.embedding_chunk_size is None:
            raise ValueError("Must set embedding_chunk_size in agent_state.embedding_config")
else:
self.embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size
# create storage backend
self.storage = StorageConnector.get_archival_storage_connector(user_id=agent_state.user_id, agent_id=agent_state.id)
        # TODO: add a cleanup/eviction mechanism for this cache; otherwise it grows without bound and can OOM
self.cache = {}
def create_passage(self, text, embedding):
return Passage(
user_id=self.agent_state.user_id,
agent_id=self.agent_state.id,
text=text,
embedding=embedding,
embedding_config=self.agent_state.embedding_config,
)
def save(self):
"""Save the index to disk"""
self.storage.save()
def insert(self, memory_string, return_ids=False) -> Union[bool, List[str]]:
"""Embed and save memory string"""
if not isinstance(memory_string, str):
raise TypeError("memory must be a string")
try:
passages = []
# breakup string into passages
for text in parse_and_chunk_text(memory_string, self.embedding_chunk_size):
embedding = self.embed_model.get_text_embedding(text)
# fixing weird bug where type returned isn't a list, but instead is an object
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
if isinstance(embedding, dict):
try:
embedding = embedding["data"][0]["embedding"]
except (KeyError, IndexError):
# TODO as a fallback, see if we can find any lists in the payload
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passages.append(self.create_passage(text, embedding))
# grab the return IDs before the list gets modified
ids = [str(p.id) for p in passages]
# insert passages
self.storage.insert_many(passages)
if return_ids:
return ids
else:
return True
except Exception as e:
print("Archival insert error", e)
raise e
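    # Example (hypothetical): given an EmbeddingArchivalMemory instance `archival`,
    # persist a fact and keep the new passage IDs so the caller can reference or
    # delete those passages later:
    #
    #   ids = archival.insert("User prefers dark mode.", return_ids=True)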
def search(self, query_string, count=None, start=None):
"""Search query string"""
start = 0 if start is None else int(start)
count = self.top_k if count is None else int(count)
if not isinstance(query_string, str):
            raise TypeError("query must be a string")
try:
if query_string not in self.cache:
# self.cache[query_string] = self.retriever.retrieve(query_string)
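                # Cache the full top_k result set per query string so that paging
                # with start/count over the same query does not re-embed the query
                # or hit the vector store again.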
query_vec = query_embedding(self.embed_model, query_string)
self.cache[query_string] = self.storage.query(query_string, query_vec, top_k=self.top_k)
end = min(count + start, len(self.cache[query_string]))
results = self.cache[query_string][start:end]
results = [{"timestamp": get_local_time(), "content": node.text} for node in results]
return results, len(results)
except Exception as e:
print("Archival search error", e)
raise e
def compile(self) -> str:
limit = 10
passages = []
        for passage in list(self.storage.get_all(limit=limit)):  # only preview the first `limit` passages
passages.append(str(passage.text))
memory_str = "\n".join(passages)
return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}" + f"\nSize: {self.storage.size()}"
def __len__(self):
return self.storage.size()
def count(self) -> int:
return len(self)
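# Example end-to-end usage (hypothetical; assumes a fully configured AgentState,
# i.e. a valid embedding_config plus archival storage connectors):
#
#   archival = EmbeddingArchivalMemory(agent_state, top_k=50)
#   archival.insert("The user's favorite color is blue.")
#   results, n = archival.search("favorite color", count=5)
#   for r in results:
#       print(r["timestamp"], r["content"])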