Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
fix: patch incorrect use of name in function response (#1642)
parent: 16e59b07a3
commit: db3744f208
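Summary: when an agent records a function/tool result, the resulting Message must carry the called function's name in its `name` field, not the agent's name; `user_id` is also dropped from `Message.dict_to_message(...)` call sites. A minimal, self-contained sketch of the corrected shape (plain dicts only; the helper name and IDs below are hypothetical, not part of the patch):

    # Hypothetical helper mirroring the openai_message_dict built in the hunks below.
    def package_tool_response(function_name: str, tool_call_id: str, result: str) -> dict:
        return {
            "role": "tool",
            "name": function_name,  # for role "tool", this must be the function name (the agent name was used before)
            "content": result,
            "tool_call_id": tool_call_id,
        }

    msg = package_tool_response("archival_memory_search", "call_123", '{"status": "OK"}')
    assert msg["name"] == "archival_memory_search"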
@@ -110,19 +110,19 @@ class Agent(BaseAgent):
         self.user = user

         # initialize a tool rules solver
-        if agent_state.tool_rules:
-            # if there are tool rules, print out a warning
-            for rule in agent_state.tool_rules:
-                if not isinstance(rule, TerminalToolRule):
-                    warnings.warn("Tool rules only work reliably for the latest OpenAI models that support structured outputs.")
-                    break
-
         self.tool_rules_solver = ToolRulesSolver(tool_rules=agent_state.tool_rules)

         # gpt-4, gpt-3.5-turbo, ...
         self.model = self.agent_state.llm_config.model
+        self.supports_structured_output = check_supports_structured_output(model=self.model, tool_rules=agent_state.tool_rules)
+
+        # if there are tool rules, print out a warning
+        if not self.supports_structured_output and agent_state.tool_rules:
+            for rule in agent_state.tool_rules:
+                if not isinstance(rule, TerminalToolRule):
+                    warnings.warn("Tool rules only work reliably for model backends that support structured outputs (e.g. OpenAI gpt-4o).")
+                    break

         # state managers
         self.block_manager = BlockManager()

@@ -236,17 +236,15 @@ class Agent(BaseAgent):

                 # Extend conversation with function response
                 function_response = package_function_response(False, error_msg)
-                new_message = Message.dict_to_message(
+                new_message = Message(
                     agent_id=self.agent_state.id,
-                    user_id=self.agent_state.created_by_id,
+                    # Base info OpenAI-style
                     model=self.model,
-                    openai_message_dict={
-                        "role": "tool",
-                        "name": function_name,
-                        "content": function_response,
-                        "tool_call_id": tool_call_id,
-                    },
-                    name=self.agent_state.name,
+                    role="tool",
+                    name=function_name,  # NOTE: when role is 'tool', the 'name' is the function name, not agent name
+                    content=[TextContent(text=function_response)],
+                    tool_call_id=tool_call_id,
+                    # Letta extras
                     tool_returns=tool_returns,
                     group_id=group_id,
                 )

@@ -455,7 +453,6 @@ class Agent(BaseAgent):
                 Message.dict_to_message(
                     id=response_message_id,
                     agent_id=self.agent_state.id,
-                    user_id=self.agent_state.created_by_id,
                     model=self.model,
                     openai_message_dict=response_message.model_dump(),
                     name=self.agent_state.name,

@@ -659,17 +656,15 @@ class Agent(BaseAgent):
                     else None
                 )
                 messages.append(
-                    Message.dict_to_message(
+                    Message(
                         agent_id=self.agent_state.id,
-                        user_id=self.agent_state.created_by_id,
+                        # Base info OpenAI-style
                         model=self.model,
-                        openai_message_dict={
-                            "role": "tool",
-                            "name": function_name,
-                            "content": function_response,
-                            "tool_call_id": tool_call_id,
-                        },
-                        name=self.agent_state.name,
+                        role="tool",
+                        name=function_name,  # NOTE: when role is 'tool', the 'name' is the function name, not agent name
+                        content=[TextContent(text=function_response)],
+                        tool_call_id=tool_call_id,
+                        # Letta extras
                         tool_returns=[tool_return] if sandbox_run_result else None,
                         group_id=group_id,
                     )

@@ -686,7 +681,6 @@ class Agent(BaseAgent):
                 Message.dict_to_message(
                     id=response_message_id,
                     agent_id=self.agent_state.id,
-                    user_id=self.agent_state.created_by_id,
                     model=self.model,
                     openai_message_dict=response_message.model_dump(),
                     name=self.agent_state.name,

@@ -777,7 +771,6 @@ class Agent(BaseAgent):
             assert self.agent_state.created_by_id is not None
             next_input_message = Message.dict_to_message(
                 agent_id=self.agent_state.id,
-                user_id=self.agent_state.created_by_id,
                 model=self.model,
                 openai_message_dict={
                     "role": "user",  # TODO: change to system?

@@ -789,7 +782,6 @@ class Agent(BaseAgent):
             assert self.agent_state.created_by_id is not None
             next_input_message = Message.dict_to_message(
                 agent_id=self.agent_state.id,
-                user_id=self.agent_state.created_by_id,
                 model=self.model,
                 openai_message_dict={
                     "role": "user",  # TODO: change to system?

@@ -801,7 +793,6 @@ class Agent(BaseAgent):
             assert self.agent_state.created_by_id is not None
             next_input_message = Message.dict_to_message(
                 agent_id=self.agent_state.id,
-                user_id=self.agent_state.created_by_id,
                 model=self.model,
                 openai_message_dict={
                     "role": "user",  # TODO: change to system?

@@ -1057,7 +1048,6 @@ class Agent(BaseAgent):
             assert self.agent_state.created_by_id is not None, "User ID is not set"
             user_message = Message.dict_to_message(
                 agent_id=self.agent_state.id,
-                user_id=self.agent_state.created_by_id,
                 model=self.model,
                 openai_message_dict=openai_message_dict,
                 # created_at=timestamp,

@@ -1117,7 +1107,6 @@ class Agent(BaseAgent):
             messages=[
                 Message.dict_to_message(
                     agent_id=self.agent_state.id,
-                    user_id=self.agent_state.created_by_id,
                     model=self.model,
                     openai_message_dict=packed_summary_message,
                 )
@@ -691,7 +691,6 @@ def _prepare_anthropic_request(
     # Convert to Anthropic format
     msg_objs = [
         _Message.dict_to_message(
-            user_id=None,
             agent_id=None,
             openai_message_dict=m,
         )
@@ -315,7 +315,7 @@ def cohere_chat_completions_request(
     data.pop("tool_choice", None)  # extra safe, should exist always (default="auto")

     # Convert messages to Cohere format
-    msg_objs = [Message.dict_to_message(user_id=uuid.uuid4(), agent_id=uuid.uuid4(), openai_message_dict=m) for m in data["messages"]]
+    msg_objs = [Message.dict_to_message(agent_id=uuid.uuid4(), openai_message_dict=m) for m in data["messages"]]

     # System message 0 should instead be a "preamble"
     # See: https://docs.cohere.com/reference/chat
@@ -137,19 +137,26 @@ class Message(BaseMessage):
     """

     id: str = BaseMessage.generate_id_field()
-    role: MessageRole = Field(..., description="The role of the participant.")
-    content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.")
     organization_id: Optional[str] = Field(None, description="The unique identifier of the organization.")
     agent_id: Optional[str] = Field(None, description="The unique identifier of the agent.")
     model: Optional[str] = Field(None, description="The model used to make the function call.")
-    name: Optional[str] = Field(None, description="The name of the participant.")
-    tool_calls: Optional[List[OpenAIToolCall]] = Field(None, description="The list of tool calls requested.")
-    tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
+    # Basic OpenAI-style fields
+    role: MessageRole = Field(..., description="The role of the participant.")
+    content: Optional[List[LettaMessageContentUnion]] = Field(None, description="The content of the message.")
+    # NOTE: in OpenAI, this field is only used for roles 'user', 'assistant', and 'function' (now deprecated). 'tool' does not use it.
+    name: Optional[str] = Field(
+        None,
+        description="For role user/assistant: the (optional) name of the participant. For role tool/function: the name of the function called.",
+    )
+    tool_calls: Optional[List[OpenAIToolCall]] = Field(
+        None, description="The list of tool calls requested. Only applicable for role assistant."
+    )
+    tool_call_id: Optional[str] = Field(None, description="The ID of the tool call. Only applicable for role tool.")
+    # Extras
     step_id: Optional[str] = Field(None, description="The id of the step that this message was created in.")
     otid: Optional[str] = Field(None, description="The offline threading id associated with this message")
     tool_returns: Optional[List[ToolReturn]] = Field(None, description="Tool execution return information for prior tool calls")
     group_id: Optional[str] = Field(None, description="The multi-agent group that the message was sent in")

     # This overrides the optional base orm schema, created_at MUST exist on all messages objects
     created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
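Reviewer note: the reorganized schema documents role-dependent semantics — `name` is the participant name for user/assistant but the function name for tool/function, and `tool_call_id` applies only to role tool. A toy pydantic sketch of that constraint (illustration only, not the real Letta class):

    from typing import Optional

    from pydantic import BaseModel, Field, model_validator

    class ToyMessage(BaseModel):
        # Pared-down stand-in for the Message schema above.
        role: str
        name: Optional[str] = Field(
            None, description="user/assistant: participant name; tool/function: function name"
        )
        tool_call_id: Optional[str] = Field(None, description="Only applicable for role 'tool'.")

        @model_validator(mode="after")
        def _check_tool_fields(self):
            # Tool responses are keyed to their originating call via tool_call_id.
            if self.role == "tool" and self.tool_call_id is None:
                raise ValueError("role 'tool' requires tool_call_id")
            return self

    ToyMessage(role="tool", name="send_message", tool_call_id="call_abc")  # validates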
@@ -406,7 +413,6 @@ class Message(BaseMessage):

    @staticmethod
    def dict_to_message(
-        user_id: str,
        agent_id: str,
        openai_message_dict: dict,
        model: Optional[str] = None,  # model used to make function call

@@ -560,7 +566,7 @@ class Message(BaseMessage):
            # standard fields expected in an OpenAI ChatCompletion message object
            role=MessageRole(openai_message_dict["role"]),
            content=content,
-            name=name,
+            name=openai_message_dict["name"] if "name" in openai_message_dict else name,
            tool_calls=tool_calls,
            tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
            created_at=created_at,

@@ -575,7 +581,7 @@ class Message(BaseMessage):
            # standard fields expected in an OpenAI ChatCompletion message object
            role=MessageRole(openai_message_dict["role"]),
            content=content,
-            name=name,
+            name=openai_message_dict["name"] if "name" in openai_message_dict else name,
            tool_calls=tool_calls,
            tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
            created_at=created_at,

@@ -809,7 +815,7 @@ class Message(BaseMessage):
        text_content = None

        if self.role != "tool" and self.name is not None:
-            warnings.warn(f"Using Google AI with non-null 'name' field ({self.name}) not yet supported.")
+            warnings.warn(f"Using Google AI with non-null 'name' field (name={self.name} role={self.role}), not yet supported.")

        if self.role == "system":
            # NOTE: Gemini API doesn't have a 'system' role, use 'user' instead

@@ -908,7 +914,9 @@ class Message(BaseMessage):
        if "parts" not in google_ai_message or not google_ai_message["parts"]:
            # If parts is empty, add a default text part
            google_ai_message["parts"] = [{"text": "empty message"}]
-            warnings.warn(f"Empty 'parts' detected in message with role '{self.role}'. Added default empty text part.")
+            warnings.warn(
+                f"Empty 'parts' detected in message with role '{self.role}'. Added default empty text part. Full message:\n{vars(self)}"
+            )

        return google_ai_message
@@ -212,7 +212,6 @@ class AgentManager:
            # We always need the system prompt up front
            system_message_obj = PydanticMessage.dict_to_message(
                agent_id=agent_state.id,
-                user_id=agent_state.created_by_id,
                model=agent_state.llm_config.model,
                openai_message_dict=init_messages[0],
            )

@@ -223,9 +222,7 @@ class AgentManager:
            )
        else:
            init_messages = [
-                PydanticMessage.dict_to_message(
-                    agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg
-                )
+                PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
                for msg in init_messages
            ]

@@ -713,7 +710,6 @@ class AgentManager:
        # Swap the system message out (only if there is a diff)
        message = PydanticMessage.dict_to_message(
            agent_id=agent_id,
-            user_id=actor.id,
            model=agent_state.llm_config.model,
            openai_message_dict={"role": "system", "content": new_system_message_str},
        )

@@ -800,7 +796,6 @@ class AgentManager:
        )
        system_message = PydanticMessage.dict_to_message(
            agent_id=agent_state.id,
-            user_id=agent_state.created_by_id,
            model=agent_state.llm_config.model,
            openai_message_dict=init_messages[0],
        )

@@ -902,7 +897,7 @@ class AgentManager:
        # get the agent
        agent = self.get_agent_by_id(agent_id=agent_id, actor=actor)
        message = PydanticMessage.dict_to_message(
-            agent_id=agent.id, user_id=actor.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
+            agent_id=agent.id, model=agent.llm_config.model, openai_message_dict={"role": "system", "content": content}
        )

        # update agent in-context message IDs
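Taken together, call sites now follow the pattern below: no `user_id`, and a `name` present in the OpenAI-style dict survives the round-trip (per the `name=openai_message_dict["name"] ...` hunks above). A hedged usage sketch; the import path and IDs are assumed for illustration:

    from letta.schemas.message import Message  # import path assumed

    msg = Message.dict_to_message(
        agent_id="agent-00000000",  # hypothetical ID
        model="gpt-4o",
        openai_message_dict={
            "role": "tool",
            "name": "send_message",  # function name, preserved by this patch
            "content": '{"status": "OK"}',
            "tool_call_id": "call_abc",  # hypothetical
        },
    )
    assert msg.name == "send_message"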