MemGPT/letta/local_llm/function_parser.py
Sarah Wooders 85faf5f474
chore: migrate package name to letta (#1775)
Co-authored-by: Charles Packer <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
2024-09-23 09:15:18 -07:00

69 lines
2.6 KiB
Python

import copy
import json
from letta.utils import json_dumps, json_loads
# Function names that heartbeat_correction never patches with
# request_heartbeat=True, even when they directly follow a user message.
NO_HEARTBEAT_FUNCS = [
    "send_message",
    "pause_heartbeats",
]
def insert_heartbeat(message):
    """Return a deep copy of ``message`` with ``request_heartbeat=True`` in its call arguments.

    ``message`` itself is never mutated. Two call shapes are handled:

    * a ``"function_call"`` entry, whose ``"arguments"`` JSON string is patched;
    * a ``"tool_calls"`` list, in which case only the FIRST tool call's
      ``"function"`` -> ``"arguments"`` JSON string is patched.

    If ``message`` has neither (non-empty) entry, the copy is returned unchanged.

    Args:
        message: Message dict that may contain ``"function_call"`` or ``"tool_calls"``.

    Returns:
        The (possibly patched) deep copy of ``message``.
    """
    # Deep copy so patching the nested arguments never leaks into the caller's message.
    message_copy = copy.deepcopy(message)

    if message_copy.get("function_call"):
        call = message_copy.get("function_call")
    elif message_copy.get("tool_calls"):
        # Key is "tool_calls" (plural) -- the same key heartbeat_correction checks
        # before routing tool-call messages here.
        call = message_copy.get("tool_calls")[0].get("function")
    else:
        return message_copy

    params = json_loads(call.get("arguments"))
    params["request_heartbeat"] = True
    # `call` is a reference into message_copy, so this updates the copy in place.
    call["arguments"] = json_dumps(params)
    return message_copy
def heartbeat_correction(message_history, new_message):
    """Add a heartbeat request that the agent likely forgot to set itself.

    A correction applies when BOTH hold:

    * the last entry of ``message_history`` has ``role == "user"`` and its
      ``content`` is a JSON object whose ``"type"`` is ``"user_message"``; and
    * ``new_message`` is an assistant function/tool call to a function that is
      not in ``NO_HEARTBEAT_FUNCS``.

    See: https://github.com/cpacker/Letta/issues/601

    Args:
        message_history: Prior message dicts (each read via ``"role"``/``"content"``).
        new_message: Candidate assistant message dict.

    Returns:
        A corrected copy of ``new_message`` (from ``insert_heartbeat``), or
        ``None`` when no correction applies.
    """
    if not message_history:
        return None

    last_message = message_history[-1]
    if last_message["role"] != "user":
        return None

    try:
        content = json_loads(last_message["content"])
    except (json.JSONDecodeError, TypeError):
        # Content that is not a JSON string (plain text, None, a list of parts)
        # cannot be a packed user message, so there is nothing to correct.
        return None

    # Role "user" entries may also wrap system-generated events; only a JSON
    # object tagged "user_message" counts. Non-dict JSON (e.g. a bare string)
    # or a dict without "type" previously crashed here.
    if not isinstance(content, dict) or content.get("type") != "user_message":
        return None

    if new_message["role"] != "assistant":
        return None

    if new_message.get("function_call"):
        function_name = new_message.get("function_call").get("name")
    elif new_message.get("tool_calls"):
        # Only the first tool call is inspected, mirroring insert_heartbeat.
        function_name = new_message.get("tool_calls")[0].get("function").get("name")
    else:
        return None

    if function_name in NO_HEARTBEAT_FUNCS:
        return None

    return insert_heartbeat(new_message)
def patch_function(message_history, new_message):
    """Return ``new_message``, heartbeat-corrected if ``heartbeat_correction`` says so.

    Args:
        message_history: Prior message dicts passed through to ``heartbeat_correction``.
        new_message: Candidate assistant message dict.

    Returns:
        The corrected copy produced by ``heartbeat_correction`` when it returns a
        non-``None`` value; otherwise ``new_message`` itself (same object).
    """
    patched = heartbeat_correction(message_history=message_history, new_message=new_message)
    if patched is None:
        return new_message
    return patched