feat: pass message UUIDs during message streaming (POST SSE send_message) (#1120)

This commit is contained in:
Charles Packer 2024-03-10 15:34:37 -07:00 committed by GitHub
parent f193aa58fa
commit 2ca92d6955
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 128 additions and 60 deletions

View File

@ -430,9 +430,6 @@ class Agent(object):
response_message.tool_calls = [response_message.tool_calls[0]]
assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0
# The content is then the internal monologue, not chat
self.interface.internal_monologue(response_message.content)
# generate UUID for tool call
if override_tool_call_id or response_message.function_call:
tool_call_id = get_tool_call_id() # needs to be a string for JSON
@ -456,6 +453,9 @@ class Agent(object):
) # extend conversation with assistant's reply
printd(f"Function call message: {messages[-1]}")
# The content is then the internal monologue, not chat
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
# Step 3: call the function
# Note: the JSON response may not always be valid; be sure to handle errors
@ -483,7 +483,7 @@ class Agent(object):
},
)
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error
# Failure case 2: function name is OK, but function args are bad JSON
@ -506,7 +506,7 @@ class Agent(object):
},
)
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error
# (Still parsing function args)
@ -519,7 +519,9 @@ class Agent(object):
heartbeat_request = False
# Failure case 3: function failed during execution
self.interface.function_message(f"Running {function_name}({function_args})")
# NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message
# this is because the function/tool role message is only created once the function/tool has executed/returned
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
try:
spec = inspect.getfullargspec(function_to_call).annotations
@ -562,12 +564,12 @@ class Agent(object):
},
)
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}")
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error
# If no failures happened along the way: ...
# Step 4: send the info on the function call and function response to GPT
self.interface.function_message(f"Success: {function_response_string}")
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
@ -581,10 +583,11 @@ class Agent(object):
},
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1])
else:
# Standard non-function reply
self.interface.internal_monologue(response_message.content)
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
@ -593,6 +596,7 @@ class Agent(object):
openai_message_dict=response_message.model_dump(),
)
) # extend conversation with assistant's reply
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])
heartbeat_request = False
function_failed = False
@ -604,7 +608,8 @@ class Agent(object):
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
) -> Tuple[List[dict], bool, bool, bool]:
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""
try:
@ -617,7 +622,6 @@ class Agent(object):
else:
raise ValueError(f"Bad type for user_message: {type(user_message)}")
self.interface.user_message(user_message_text)
packed_user_message = {"role": "user", "content": user_message_text}
# Special handling for AutoGen messages with 'name' field
try:
@ -639,6 +643,7 @@ class Agent(object):
model=self.model,
openai_message_dict=packed_user_message,
)
self.interface.user_message(user_message_text, msg_obj=packed_user_message_obj)
input_message_sequence = self.messages + [packed_user_message]
# Alternatively, the requestor can send an empty user message
@ -729,8 +734,8 @@ class Agent(object):
)
self._append_to_messages(all_new_messages)
all_new_messages_dicts = [msg.to_openai_dict() for msg in all_new_messages]
return all_new_messages_dicts, heartbeat_request, function_failed, active_memory_warning, response.usage.completion_tokens
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage.completion_tokens
except Exception as e:
printd(f"step() failed\nuser_message = {user_message}\nerror = {e}")
@ -741,7 +746,7 @@ class Agent(object):
self.summarize_messages_inplace()
# Try step again
return self.step(user_message, first_message=first_message)
return self.step(user_message, first_message=first_message, return_dicts=return_dicts)
else:
printd(f"step() failed with an unrecognized exception: '{str(e)}'")
raise e

View File

@ -1,8 +1,10 @@
import json
import re
from typing import Optional
from colorama import Fore, Style, init
from memgpt.data_types import Message
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
init(autoreset=True)
@ -64,7 +66,7 @@ class AutoGenInterface(object):
"""Clears the buffer. Call before every agent.step() when using MemGPT+AutoGen"""
self.message_list = []
def internal_monologue(self, msg):
def internal_monologue(self, msg: str, msg_obj: Optional[Message]):
# NOTE: never gets appended
if self.debug:
print(f"inner thoughts :: {msg}")
@ -74,14 +76,14 @@ class AutoGenInterface(object):
message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[MemGPT agent's inner thoughts] {msg}"
print(message)
def assistant_message(self, msg):
def assistant_message(self, msg: str, msg_obj: Optional[Message]):
# NOTE: gets appended
if self.debug:
print(f"assistant :: {msg}")
# message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg
self.message_list.append(msg)
def memory_message(self, msg):
def memory_message(self, msg: str):
# NOTE: never gets appended
if self.debug:
print(f"memory :: {msg}")
@ -90,7 +92,7 @@ class AutoGenInterface(object):
)
print(message)
def system_message(self, msg):
def system_message(self, msg: str):
# NOTE: gets appended
if self.debug:
print(f"system :: {msg}")
@ -98,7 +100,7 @@ class AutoGenInterface(object):
print(message)
self.message_list.append(msg)
def user_message(self, msg, raw=False):
def user_message(self, msg: str, msg_obj: Optional[Message], raw=False):
if self.debug:
print(f"user :: {msg}")
if not self.show_user_message:
@ -136,7 +138,7 @@ class AutoGenInterface(object):
# TODO should we ever be appending this?
self.message_list.append(message)
def function_message(self, msg):
def function_message(self, msg: str, msg_obj: Optional[Message]):
if self.debug:
print(f"function :: {msg}")
if not self.show_function_outputs:

View File

@ -4,13 +4,14 @@ import json
import math
from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, JSON_ENSURE_ASCII
from memgpt.agent import Agent
### Functions / tools the agent can use
# All functions should return a response string (or None)
# If the function fails, throw an exception
def send_message(self, message: str) -> Optional[str]:
def send_message(self: Agent, message: str) -> Optional[str]:
"""
Sends a message to the human user.
@ -20,7 +21,8 @@ def send_message(self, message: str) -> Optional[str]:
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.interface.assistant_message(message)
# FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference
self.interface.assistant_message(message, msg_obj=self._messages[-1])
return None
@ -36,7 +38,7 @@ Returns:
"""
def pause_heartbeats(self, minutes: int) -> Optional[str]:
def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
minutes = min(MAX_PAUSE_HEARTBEATS, minutes)
# Record the current time
@ -50,7 +52,7 @@ def pause_heartbeats(self, minutes: int) -> Optional[str]:
pause_heartbeats.__doc__ = pause_heartbeats_docstring
def core_memory_append(self, name: str, content: str) -> Optional[str]:
def core_memory_append(self: Agent, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
@ -66,7 +68,7 @@ def core_memory_append(self, name: str, content: str) -> Optional[str]:
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
def core_memory_replace(self: Agent, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
@ -83,7 +85,7 @@ def core_memory_replace(self, name: str, old_content: str, new_content: str) ->
return None
def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.
@ -112,7 +114,7 @@ def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[s
return results_str
def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.
@ -142,7 +144,7 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optiona
return results_str
def archival_memory_insert(self, content: str) -> Optional[str]:
def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
@ -156,7 +158,7 @@ def archival_memory_insert(self, content: str) -> Optional[str]:
return None
def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optional[str]:
def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search archival memory using semantic (embedding-based) search.

View File

@ -1,11 +1,13 @@
from abc import ABC, abstractmethod
import json
import re
from typing import List, Optional
from colorama import Fore, Style, init
from memgpt.utils import printd
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.data_types import Message
init(autoreset=True)
@ -16,25 +18,28 @@ STRIP_UI = False
class AgentInterface(ABC):
"""Interfaces handle MemGPT-related events (observer pattern)"""
"""Interfaces handle MemGPT-related events (observer pattern)
The 'msg' arg provides the scoped message, and the optional Message arg can provide additional metadata.
"""
@abstractmethod
def user_message(self, msg):
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT receives a user message"""
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg):
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT generates some internal monologue"""
raise NotImplementedError
@abstractmethod
def assistant_message(self, msg):
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT uses send_message"""
raise NotImplementedError
@abstractmethod
def function_message(self, msg):
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT calls a function"""
raise NotImplementedError
@ -58,14 +63,14 @@ class CLIInterface(AgentInterface):
"""Basic interface for dumping agent events to the command-line"""
@staticmethod
def important_message(msg):
def important_message(msg: str):
fstr = f"{Fore.MAGENTA}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
print(fstr.format(msg=msg))
@staticmethod
def warning_message(msg):
def warning_message(msg: str):
fstr = f"{Fore.RED}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
@ -73,7 +78,7 @@ class CLIInterface(AgentInterface):
print(fstr.format(msg=msg))
@staticmethod
def internal_monologue(msg):
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
# ANSI escape code for italic is '\x1B[3m'
fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}"
if STRIP_UI:
@ -81,28 +86,28 @@ class CLIInterface(AgentInterface):
print(fstr.format(msg=msg))
@staticmethod
def assistant_message(msg):
def assistant_message(msg: str, msg_obj: Optional[Message] = None):
fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
print(fstr.format(msg=msg))
@staticmethod
def memory_message(msg):
def memory_message(msg: str, msg_obj: Optional[Message] = None):
fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
print(fstr.format(msg=msg))
@staticmethod
def system_message(msg):
def system_message(msg: str, msg_obj: Optional[Message] = None):
fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
if STRIP_UI:
fstr = "{msg}"
print(fstr.format(msg=msg))
@staticmethod
def user_message(msg, raw=False, dump=False, debug=DEBUG):
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
def print_user_message(icon, msg, printf=print):
if STRIP_UI:
printf(f"{icon} {msg}")
@ -148,7 +153,8 @@ class CLIInterface(AgentInterface):
printd_user_message("🧑", msg_json)
@staticmethod
def function_message(msg, debug=DEBUG):
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
def print_function_message(icon, msg, color=Fore.RED, printf=print):
if STRIP_UI:
printf(f"{icon} [function] {msg}")
@ -166,6 +172,9 @@ class CLIInterface(AgentInterface):
printd_function_message("🟢", msg)
elif msg.startswith("Error: "):
printd_function_message("🔴", msg)
elif msg.startswith("Ran "):
# NOTE: ignore 'ran' messages that come post-execution
return
elif msg.startswith("Running "):
if debug:
printd_function_message("", msg)
@ -230,7 +239,10 @@ class CLIInterface(AgentInterface):
printd_function_message("", msg)
@staticmethod
def print_messages(message_sequence, dump=False):
def print_messages(message_sequence: List[Message], dump=False):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
idx = len(message_sequence)
for msg in message_sequence:
if dump:
@ -270,7 +282,10 @@ class CLIInterface(AgentInterface):
print(f"Unknown role: {content}")
@staticmethod
def print_messages_simple(message_sequence):
def print_messages_simple(message_sequence: List[Message]):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
for msg in message_sequence:
role = msg["role"]
content = msg["content"]
@ -285,7 +300,10 @@ class CLIInterface(AgentInterface):
print(f"Unknown role: {content}")
@staticmethod
def print_messages_raw(message_sequence):
def print_messages_raw(message_sequence: List[Message]):
# rewrite to dict format
message_sequence = [msg.to_openai_dict() for msg in message_sequence]
for msg in message_sequence:
print(msg)

View File

@ -155,13 +155,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
command = user_input.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
if amount == 0:
interface.print_messages(memgpt_agent.messages, dump=True)
interface.print_messages(memgpt_agent._messages, dump=True)
else:
interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
continue
elif user_input.lower() == "/dumpraw":
interface.print_messages_raw(memgpt_agent.messages)
interface.print_messages_raw(memgpt_agent._messages)
continue
elif user_input.lower() == "/memory":

View File

@ -1,10 +1,12 @@
import asyncio
import queue
from datetime import datetime
from typing import Optional
import pytz
from memgpt.interface import AgentInterface
from memgpt.data_types import Message
class QueuingInterface(AgentInterface):
@ -38,7 +40,8 @@ class QueuingInterface(AgentInterface):
message = self.buffer.get()
if message == "STOP":
break
yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
# yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
yield message
else:
await asyncio.sleep(0.1) # Small sleep to prevent a busy loop
@ -51,38 +54,73 @@ class QueuingInterface(AgentInterface):
self.buffer.put({"internal_error": error})
self.buffer.put("STOP")
def user_message(self, msg: str):
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Handle reception of a user message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
def internal_monologue(self, msg: str) -> None:
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent's internal monologue"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
self.buffer.put({"internal_monologue": msg})
def assistant_message(self, msg: str) -> None:
new_message = {"internal_monologue": msg}
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
new_message["date"] = msg_obj.created_at.isoformat()
self.buffer.put(new_message)
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent sending a message"""
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
self.buffer.put({"assistant_message": msg})
def function_message(self, msg: str) -> None:
new_message = {"assistant_message": msg}
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
new_message["date"] = msg_obj.created_at.isoformat()
self.buffer.put(new_message)
def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
"""Handle the agent calling a function"""
# TODO handle 'function' messages that indicate the start of a function call
assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata"
if self.debug:
print(msg)
if msg.startswith("Running "):
msg = msg.replace("Running ", "")
self.buffer.put({"function_call": msg})
new_message = {"function_call": msg}
elif msg.startswith("Ran "):
if not include_ran_messages:
return
msg = msg.replace("Ran ", "Function call returned: ")
new_message = {"function_call": msg}
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
self.buffer.put({"function_return": msg, "status": "success"})
new_message = {"function_return": msg, "status": "success"}
elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
self.buffer.put({"function_return": msg, "status": "error"})
new_message = {"function_return": msg, "status": "error"}
else:
# NOTE: generic, should not happen
self.buffer.put({"function_message": msg})
new_message = {"function_message": msg}
# add extra metadata
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
new_message["date"] = msg_obj.created_at.isoformat()
self.buffer.put(new_message)

View File

@ -336,7 +336,10 @@ class SyncServer(LockingServer):
counter = 0
while True:
new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
next_input_message, first_message=False, skip_verify=no_verify
next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
)
counter += 1