mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

Co-authored-by: Charles Packer <packercharles@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
69 lines
2.6 KiB
Python
69 lines
2.6 KiB
Python
import copy
|
|
import json
|
|
|
|
from letta.utils import json_dumps, json_loads
|
|
|
|
# Names of functions for which heartbeat_correction() does NOT force
# request_heartbeat on: an assistant call to one of these is left unpatched.
NO_HEARTBEAT_FUNCS = ["send_message", "pause_heartbeats"]
|
|
|
|
|
|
def insert_heartbeat(message):
    """Return a deep copy of ``message`` with ``request_heartbeat`` set to True.

    The flag is added to the JSON-encoded ``arguments`` of the message's call.
    The legacy ``function_call`` field takes precedence; otherwise the first
    entry of ``tool_calls`` is patched. If the message carries neither, the
    copy is returned unchanged. The input ``message`` is never mutated.

    Args:
        message: An assistant message dict containing either a
            ``function_call`` or a ``tool_calls`` list.

    Returns:
        A (possibly patched) deep copy of ``message``.

    Raises:
        Propagates any error from ``json_loads`` when the call arguments
        cannot be decoded.
    """
    # Deep copy so that editing the nested call dicts never touches the caller's message.
    message_copy = copy.deepcopy(message)

    if message_copy.get("function_call"):
        call = message_copy["function_call"]
    elif message_copy.get("tool_calls"):
        # Only the first tool call is patched (the same one heartbeat_correction
        # reads the function name from).
        call = message_copy["tool_calls"][0]["function"]
    else:
        return message_copy

    params = json_loads(call["arguments"])
    params["request_heartbeat"] = True
    call["arguments"] = json_dumps(params)

    return message_copy
|
|
|
|
|
|
def heartbeat_correction(message_history, new_message):
    """Add heartbeats where we think the agent forgot to add them themselves.

    If the last message in the history is a user message and the new message is
    an assistant function/tool call to a function not listed in
    ``NO_HEARTBEAT_FUNCS``, return a copy of the new message with
    ``request_heartbeat`` forced on.

    See: https://github.com/cpacker/Letta/issues/601

    Args:
        message_history: List of message dicts; only the last one is inspected.
        new_message: The candidate assistant message.

    Returns:
        A patched copy of ``new_message`` (built by ``insert_heartbeat``), or
        ``None`` when no correction applies.
    """
    if not message_history:
        return None

    last_message = message_history[-1]
    if last_message["role"] != "user":
        return None

    # "user"-role content is expected to be a JSON object whose "type" separates
    # user messages from system messages. Content that is not JSON, or that is
    # JSON but not an object, cannot be classified, so no correction is made.
    try:
        content = json_loads(last_message["content"])
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(content, dict) or content.get("type") != "user_message":
        return None

    if new_message["role"] != "assistant":
        return None

    if new_message.get("function_call"):
        function_name = new_message["function_call"].get("name")
    elif new_message.get("tool_calls"):
        function_name = new_message["tool_calls"][0].get("function").get("name")
    else:
        return None

    if function_name in NO_HEARTBEAT_FUNCS:
        return None

    return insert_heartbeat(new_message)
|
|
|
|
|
|
def patch_function(message_history, new_message):
    """Return ``new_message``, patched with a heartbeat if one was forgotten.

    Delegates to ``heartbeat_correction``. When that returns ``None`` (no
    correction needed), the very same ``new_message`` object is returned.
    """
    patched = heartbeat_correction(message_history=message_history, new_message=new_message)
    if patched is None:
        return new_message
    return patched
|