feat: add streaming support for OpenAI-compatible endpoints (#1262)

Charles Packer 2024-04-17 23:40:52 -07:00 committed by GitHub
parent e22f3572fd
commit aeb4a94e0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 997 additions and 58 deletions
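In short: memgpt run gains a --stream flag that is threaded through run_agent_loop and Agent.step down to the LLM request, where the SSE response from an OpenAI-compatible endpoint is consumed chunk by chunk and rendered live by a new streaming CLI interface. A minimal sketch of the step-level wiring (the agent object and inputs below are placeholders, not code from this change):

# Hedged sketch: how the new stream flag reaches the LLM call.
# memgpt run --stream -> run_agent_loop(..., stream=True) -> Agent.step(..., stream=True)
# -> Agent._get_ai_reply(..., stream=True) -> create(..., stream=True, stream_inferface=...)
def chat_once(memgpt_agent, user_message: str, stream: bool = True):
    # memgpt_agent is assumed to be an already-configured memgpt.agent.Agent
    new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
        user_message,
        first_message=False,
        skip_verify=False,
        stream=stream,  # new in this change; tokens are rendered by the streaming interface as they arrive
    )
    return new_messages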

View File

@ -403,6 +403,7 @@ class Agent(object):
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
) -> chat_completion_response.ChatCompletionResponse:
"""Get response from LLM API"""
try:
@ -414,6 +415,9 @@ class Agent(object):
function_call=function_call,
# hint
first_message=first_message,
# streaming
stream=stream,
stream_inferface=self.interface,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
@ -628,6 +632,7 @@ class Agent(object):
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""
@ -710,6 +715,7 @@ class Agent(object):
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=True, # passed through to the prompt formatter
stream=stream,
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
@ -721,6 +727,7 @@ class Agent(object):
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
stream=stream,
)
# Step 2: check if LLM wanted to call a function

View File

@ -13,7 +13,9 @@ import typer
import questionary
from memgpt.log import logger
from memgpt.interface import CLIInterface as interface # for printing to terminal
# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.presets.presets as presets
import memgpt.utils as utils
@ -445,6 +447,8 @@ def run(
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False,
yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
# streaming
stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
):
"""Start chatting with an MemGPT agent
@ -710,7 +714,9 @@ def run(
from memgpt.main import run_agent_loop
print() # extra space
run_agent_loop(memgpt_agent, config, first, ms, no_verify) # TODO: add back no_verify
run_agent_loop(
memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
) # TODO: add back no_verify
def delete_agent(

View File

@ -90,6 +90,7 @@ class MemGPTConfig:
def load(cls) -> "MemGPTConfig":
# avoid circular import
from memgpt.migrate import config_is_compatible, VERSION_CUTOFF
from memgpt.utils import printd
if not config_is_compatible(allow_empty=True):
error_message = " ".join(
@ -110,7 +111,7 @@ class MemGPTConfig:
# ensure all configuration directories exist
cls.create_config_dir()
print(f"Loading config from {config_path}")
printd(f"Loading config from {config_path}")
if os.path.exists(config_path):
# read existing config
config.read(config_path)

View File

@ -3,17 +3,18 @@ import time
import requests
import os
import time
from typing import List
from typing import List, Optional, Union
from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from memgpt.data_types import AgentState, Message
from memgpt.llm_api.openai import openai_chat_completions_request
from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream
from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE
from memgpt.llm_api.google_ai import (
google_ai_chat_completions_request,
@ -126,14 +127,17 @@ def retry_with_exponential_backoff(
def create(
agent_state: AgentState,
messages: List[Message],
functions=None,
functions_python=None,
function_call="auto",
functions: list = None,
functions_python: list = None,
function_call: str = "auto",
# hint
first_message=False,
first_message: bool = False,
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming=True,
use_tool_naming: bool = True,
# streaming?
stream: bool = False,
stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from memgpt.utils import printd
@ -169,11 +173,25 @@ def create(
function_call=function_call,
user=str(agent_state.user_id),
)
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
data=data,
)
if stream:
data.stream = True
assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance(
stream_inferface, AgentRefreshStreamingInterface
), type(stream_inferface)
return openai_chat_completions_process_stream(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
stream_inferface=stream_inferface,
)
else:
data.stream = False
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
)
# azure
elif agent_state.llm_config.model_endpoint_type == "azure":

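Caller-side, the create() dispatch above can be exercised roughly as follows. This is a sketch only: the import path for create is assumed, agent_state, message_sequence, and functions are assumed to come from an existing agent, and the parameter is spelled stream_inferface as in this change.

# Hedged sketch: enabling streaming at the create() level. The streaming interface is fed
# chunks/refreshes as they arrive; the fully accumulated ChatCompletionResponse is still
# returned at the end, so downstream handling is unchanged.
from memgpt.llm_api.llm_api_tools import create  # module path assumed
from memgpt.streaming_interface import StreamingRefreshCLIInterface

stream_interface = StreamingRefreshCLIInterface()
response = create(
    agent_state=agent_state,        # assumed: an existing AgentState
    messages=message_sequence,      # assumed: List[Message] built by the agent
    functions=functions,            # assumed: the agent's function schemas
    stream=True,
    stream_inferface=stream_interface,  # parameter name as spelled in this commit
)
print(response.choices[0].finish_reason)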
View File

@ -1,11 +1,28 @@
import requests
import time
from typing import Union, Optional
import json
import httpx
from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError
from typing import Union, Optional, Generator
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
Message,
ToolCall,
FunctionCall,
UsageStatistics,
ChatCompletionChunkResponse,
)
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin
from memgpt.utils import smart_urljoin, get_utc_time
from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions
from memgpt.interface import AgentInterface
from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
OPENAI_SSE_DONE = "[DONE]"
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
@ -58,13 +75,233 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional
raise e
def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse:
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
def openai_chat_completions_process_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
To "stream" the response in MemGPT, we want to call a streaming-compatible interface function
on the chunks received from the OpenAI-compatible server POST SSE response.
"""
assert chat_completion_request.stream == True
# Count the prompt tokens
# TODO move to post-request?
chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages]
# print(chat_history)
prompt_tokens = num_tokens_from_messages(
messages=chat_history,
model=chat_completion_request.model,
)
# We also need to add the cost of including the functions list to the input prompt
if chat_completion_request.tools is not None:
assert chat_completion_request.functions is None
prompt_tokens += num_tokens_from_functions(
functions=[t.function.model_dump() for t in chat_completion_request.tools],
model=chat_completion_request.model,
)
elif chat_completion_request.functions is not None:
assert chat_completion_request.tools is None
prompt_tokens += num_tokens_from_functions(
functions=[f.model_dump() for f in chat_completion_request.functions],
model=chat_completion_request.model,
)
TEMP_STREAM_RESPONSE_ID = "temp_id"
TEMP_STREAM_FINISH_REASON = "temp_null"
TEMP_STREAM_TOOL_CALL_ID = "temp_id"
chat_completion_response = ChatCompletionResponse(
id=TEMP_STREAM_RESPONSE_ID,
choices=[],
created=get_utc_time(),
model=chat_completion_request.model,
usage=UsageStatistics(
completion_tokens=0,
prompt_tokens=prompt_tokens,
total_tokens=prompt_tokens,
),
)
if stream_inferface:
stream_inferface.stream_start()
n_chunks = 0 # approx == n_tokens
try:
for chunk_idx, chat_completion_chunk in enumerate(
openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request)
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
# print(chat_completion_chunk)
if stream_inferface:
if isinstance(stream_inferface, AgentChunkStreamingInterface):
stream_inferface.process_chunk(chat_completion_chunk)
elif isinstance(stream_inferface, AgentRefreshStreamingInterface):
stream_inferface.process_refresh(chat_completion_response)
else:
raise TypeError(stream_inferface)
if chunk_idx == 0:
# initialize the choice objects which we will increment with the deltas
num_choices = len(chat_completion_chunk.choices)
assert num_choices > 0
chat_completion_response.choices = [
Choice(
finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be overwritten
index=i,
message=Message(
role="assistant",
),
)
for i in range(len(chat_completion_chunk.choices))
]
# add the choice delta
assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk
for chunk_choice in chat_completion_chunk.choices:
if chunk_choice.finish_reason is not None:
chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason
if chunk_choice.logprobs is not None:
chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs
accum_message = chat_completion_response.choices[chunk_choice.index].message
message_delta = chunk_choice.delta
if message_delta.content is not None:
content_delta = message_delta.content
if accum_message.content is None:
accum_message.content = content_delta
else:
accum_message.content += content_delta
if message_delta.tool_calls is not None:
tool_calls_delta = message_delta.tool_calls
# If this is the first tool call showing up in a chunk, initialize the list with it
if accum_message.tool_calls is None:
accum_message.tool_calls = [
ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments=""))
for _ in range(len(tool_calls_delta))
]
for tool_call_delta in tool_calls_delta:
if tool_call_delta.id is not None:
# TODO assert that we're not overwriting?
# TODO += instead of =?
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
if tool_call_delta.function is not None:
if tool_call_delta.function.name is not None:
# TODO assert that we're not overwriting?
# TODO += instead of =?
accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name
if tool_call_delta.function.arguments is not None:
accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments
if message_delta.function_call is not None:
raise NotImplementedError("Old function_call style not supported with stream=True")
# overwrite response fields based on latest chunk
chat_completion_response.id = chat_completion_chunk.id
chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint
chat_completion_response.created = chat_completion_chunk.created
chat_completion_response.model = chat_completion_chunk.model
# increment chunk counter
n_chunks += 1
except Exception as e:
if stream_inferface:
stream_inferface.stream_end()
print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
raise e
finally:
if stream_inferface:
stream_inferface.stream_end()
# make sure we didn't leave temp stuff in
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
assert all(
[
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
for c in chat_completion_response.choices
]
)
assert chat_completion_response.id != TEMP_STREAM_RESPONSE_ID
# compute token usage before returning
# TODO try actually counting the completion tokens instead of assuming one token per chunk
chat_completion_response.usage.completion_tokens = n_chunks
chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks
# printd(chat_completion_response)
return chat_completion_response
def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]:
with httpx.Client() as client:
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
try:
for sse in event_source.iter_sse():
# printd(sse.event, sse.data, sse.id, sse.retry)
if sse.data == OPENAI_SSE_DONE:
# print("finished")
break
else:
chunk_data = json.loads(sse.data)
# print("chunk_data::", chunk_data)
chunk_object = ChatCompletionChunkResponse(**chunk_data)
# print("chunk_object::", chunk_object)
# id=chunk_data["id"],
# choices=[ChunkChoice],
# model=chunk_data["model"],
# system_fingerprint=chunk_data["system_fingerprint"]
# )
yield chunk_object
except SSEError as e:
if "application/json" in str(e): # Check if the error is because of JSON response
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
if response.headers["Content-Type"].startswith("application/json"):
error_details = response.json() # Parse the JSON to get the error message
print("Error:", error_details)
print("Reqeust:", vars(response.request))
else:
print("Failed to retrieve JSON error message.")
else:
print("SSEError not related to 'application/json' content type.")
# Optionally re-raise the exception if you need to propagate it
raise e
except Exception as e:
if event_source.response.request is not None:
print("HTTP Request:", vars(event_source.response.request))
if event_source.response is not None:
print("HTTP Status:", event_source.response.status_code)
print("HTTP Headers:", event_source.response.headers)
# print("HTTP Body:", event_source.response.text)
print("Exception message:", str(e))
raise e
def openai_chat_completions_request_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
) -> Generator[ChatCompletionChunkResponse, None, None]:
from memgpt.utils import printd
url = smart_urljoin(url, "chat/completions")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = data.model_dump(exclude_none=True)
data = chat_completion_request.model_dump(exclude_none=True)
printd("Request:\n", json.dumps(data, indent=2))
# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
@ -77,21 +314,59 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion
printd(f"Sending request to {url}")
try:
# Example code to trigger a rate limit response:
# mock_response = requests.Response()
# mock_response.status_code = 429
# http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests")
# http_error.response = mock_response
# raise http_error
return _sse_post(url=url, data=data, headers=headers)
except requests.exceptions.HTTPError as http_err:
# Handle HTTP errors (e.g., response 4XX, 5XX)
printd(f"Got HTTPError, exception={http_err}, payload={data}")
raise http_err
except requests.exceptions.RequestException as req_err:
# Handle other requests-related errors (e.g., connection error)
printd(f"Got RequestException, exception={req_err}")
raise req_err
except Exception as e:
# Handle other potential errors
printd(f"Got unknown Exception, exception={e}")
raise e
# Example code to trigger a context overflow response (for an 8k model)
# data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000)
def openai_chat_completions_request(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
) -> ChatCompletionResponse:
"""Send a ChatCompletion request to an OpenAI-compatible server
If request.stream == True, will yield ChatCompletionChunkResponses
If request.stream == False, will return a ChatCompletionResponse
https://platform.openai.com/docs/guides/text-generation?lang=curl
"""
from memgpt.utils import printd
url = smart_urljoin(url, "chat/completions")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = chat_completion_request.model_dump(exclude_none=True)
printd("Request:\n", json.dumps(data, indent=2))
# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
data.pop("functions")
data.pop("function_call", None) # extra safe, should exist always (default="auto")
if "tools" in data and data["tools"] is None:
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
printd(f"Sending request to {url}")
try:
response = requests.post(url, headers=headers, json=data)
printd(f"response = {response}")
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default
return response
except requests.exceptions.HTTPError as http_err:

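For reference, the low-level chunk generator added above can also be consumed directly. A rough sketch (endpoint, key, and model values are placeholders; the plain-dict message is shown for brevity, whereas the real code builds typed message objects via cast_message_to_subtype):

# Hedged sketch: iterate over ChatCompletionChunkResponse objects from the SSE stream
# and print content deltas as they arrive.
from memgpt.llm_api.openai import openai_chat_completions_request_stream
from memgpt.models.chat_completion_request import ChatCompletionRequest

request = ChatCompletionRequest(
    model="gpt-4",
    messages=[{"role": "user", "content": "hello"}],  # plain dict for brevity
    stream=True,  # required: the server only responds with SSE when stream is set
)
for chunk in openai_chat_completions_request_stream(
    url="https://api.openai.com/v1",  # base URL; "chat/completions" is joined internally
    api_key="sk-...",  # placeholder
    chat_completion_request=request,
):
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="", flush=True)
print()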
View File

@ -1,6 +1,7 @@
import os
import requests
import tiktoken
from typing import List
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
@ -74,6 +75,148 @@ def count_tokens(s: str, model: str = "gpt-4") -> int:
return len(encoding.encode(s))
def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
"""Return the number of tokens used by a list of functions.
Copied from https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for function in functions:
function_tokens = len(encoding.encode(function["name"]))
function_tokens += len(encoding.encode(function["description"]))
if "parameters" in function:
parameters = function["parameters"]
if "properties" in parameters:
for propertiesKey in parameters["properties"]:
function_tokens += len(encoding.encode(propertiesKey))
v = parameters["properties"][propertiesKey]
for field in v:
if field == "type":
function_tokens += 2
function_tokens += len(encoding.encode(v["type"]))
elif field == "description":
function_tokens += 2
function_tokens += len(encoding.encode(v["description"]))
elif field == "enum":
function_tokens -= 3
for o in v["enum"]:
function_tokens += 3
function_tokens += len(encoding.encode(o))
else:
print(f"Warning: not supported field {field}")
function_tokens += 11
num_tokens += function_tokens
num_tokens += 12
return num_tokens
def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):
"""Based on above code (num_tokens_from_functions).
Example to encode:
[{
'id': '8b6707cf-2352-4804-93db-0423f',
'type': 'function',
'function': {
'name': 'send_message',
'arguments': '{\n "message": "More human than human is our motto."\n}'
}
}]
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for tool_call in tool_calls:
function_tokens = len(encoding.encode(tool_call["id"]))
function_tokens += 2 + len(encoding.encode(tool_call["type"]))
function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"]))
function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"]))
num_tokens += function_tokens
# TODO adjust?
num_tokens += 12
return num_tokens
def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
"""Return the number of tokens used by a list of messages.
From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
For counting tokens in function calling RESPONSES, see:
https://hmarr.com/blog/counting-openai-tokens/, https://github.com/hmarr/openai-chat-tokens
For counting tokens in function calling REQUESTS, see:
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11
"""
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
# print("Warning: model not found. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
if model in {
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613",
}:
tokens_per_message = 3
tokens_per_name = 1
elif model == "gpt-3.5-turbo-0301":
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
tokens_per_name = -1 # if there's a name, the role is omitted
elif "gpt-3.5-turbo" in model:
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
elif "gpt-4" in model:
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
return num_tokens_from_messages(messages, model="gpt-4-0613")
else:
raise NotImplementedError(
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
)
num_tokens = 0
for message in messages:
num_tokens += tokens_per_message
for key, value in message.items():
try:
if isinstance(value, list) and key == "tool_calls":
num_tokens += num_tokens_from_tool_calls(tool_calls=value, model=model)
# special case for tool calling (list)
# num_tokens += len(encoding.encode(value["name"]))
# num_tokens += len(encoding.encode(value["arguments"]))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
except TypeError as e:
print(f"tiktoken encoding failed on: {value}")
raise e
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
return num_tokens
def get_available_wrappers() -> dict:
return {
"experimental-wrapper-neural-chat-grammar-noforce": configurable_wrapper.ConfigurableJSONWrapper(

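These counters let the streaming path estimate prompt tokens up front (the non-streaming path gets usage back from the server). A small usage sketch, with illustrative message and function schemas:

# Hedged sketch: estimating prompt tokens for a message history plus function schemas,
# mirroring what openai_chat_completions_process_stream does before sending the request.
from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions

messages = [
    {"role": "system", "content": "You are MemGPT."},
    {"role": "user", "content": "hello"},
]
functions = [
    {
        "name": "send_message",
        "description": "Send a message to the user.",
        "parameters": {"type": "object", "properties": {"message": {"type": "string"}}},
    }
]
prompt_tokens = num_tokens_from_messages(messages=messages, model="gpt-4")
prompt_tokens += num_tokens_from_functions(functions=functions, model="gpt-4")
print(prompt_tokens)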
View File

@ -10,11 +10,10 @@ import typer
from rich.console import Console
from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, JSON_LOADS_STRICT, REQ_HEARTBEAT_MESSAGE
console = Console()
from memgpt.agent import save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.interface import CLIInterface as interface # for printing to terminal
# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import AgentRefreshStreamingInterface
from memgpt.config import MemGPTConfig
import memgpt.agent as agent
import memgpt.system as system
@ -27,6 +26,8 @@ from memgpt.metadata import MetadataStore
# import benchmark
from memgpt.benchmark.benchmark import bench
# interface = interface()
app = typer.Typer(pretty_exceptions_enable=False)
app.command(name="run")(run)
app.command(name="version")(version)
@ -47,7 +48,7 @@ app.command(name="benchmark")(bench)
app.command(name="delete-agent")(delete_agent)
def clear_line(strip_ui=False):
def clear_line(console, strip_ui=False):
if strip_ui:
return
if os.name == "nt": # for windows
@ -57,7 +58,19 @@ def clear_line(strip_ui=False):
sys.stdout.flush()
def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False):
def run_agent_loop(
memgpt_agent: agent.Agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False
):
if isinstance(memgpt_agent.interface, AgentRefreshStreamingInterface):
# memgpt_agent.interface.toggle_streaming(on=stream)
if not stream:
memgpt_agent.interface = memgpt_agent.interface.nonstreaming_interface
if hasattr(memgpt_agent.interface, "console"):
console = memgpt_agent.interface.console
else:
console = Console()
counter = 0
user_input = None
skip_next_user_input = False
@ -65,8 +78,8 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
USER_GOES_FIRST = first
if not USER_GOES_FIRST:
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
clear_line(strip_ui)
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]\n")
clear_line(console, strip_ui=strip_ui)
print()
multiline_input = False
@ -74,12 +87,16 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
while True:
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
# Ask for user input
if not stream:
print()
user_input = questionary.text(
"Enter your message:",
multiline=multiline_input,
qmark=">",
).ask()
clear_line(strip_ui)
clear_line(console, strip_ui=strip_ui)
if not stream:
print()
# Gracefully exit on Ctrl-C/D
if user_input is None:
@ -157,13 +174,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
command = user_input.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
if amount == 0:
interface.print_messages(memgpt_agent._messages, dump=True)
memgpt_agent.interface.print_messages(memgpt_agent._messages, dump=True)
else:
interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
memgpt_agent.interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
continue
elif user_input.lower() == "/dumpraw":
interface.print_messages_raw(memgpt_agent._messages)
memgpt_agent.interface.print_messages_raw(memgpt_agent._messages)
continue
elif user_input.lower() == "/memory":
@ -194,9 +211,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
else:
print(f"Popping last {pop_amount} messages from stack")
for _ in range(min(pop_amount, len(memgpt_agent.messages))):
memgpt_agent._messages.pop()
# Persist the state
save_agent(agent=memgpt_agent, ms=ms)
memgpt_agent.messages.pop()
continue
elif user_input.lower() == "/retry":
@ -218,13 +233,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
if memgpt_agent.messages[x].get("role") == "assistant":
text = user_input[len("/rethink ") :].strip()
# Do the /rethink-ing
message_obj = memgpt_agent._messages[x]
message_obj.text = text
# To persist to the database, all we need to do is "re-insert" into recall memory
memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj)
memgpt_agent.messages[x].update({"content": text})
break
continue
@ -321,7 +330,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
# No skip options
elif user_input.lower() == "/wipe":
memgpt_agent = agent.Agent(interface)
memgpt_agent = agent.Agent(memgpt_agent.interface)
user_message = None
elif user_input.lower() == "/heartbeat":
@ -354,7 +363,10 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
def process_agent_step(user_message, no_verify):
new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
user_message, first_message=False, skip_verify=no_verify
user_message,
first_message=False,
skip_verify=no_verify,
stream=stream,
)
skip_next_user_input = False
@ -376,9 +388,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
else:
with console.status("[bold cyan]Thinking...") as status:
if stream:
# Don't display the "Thinking..." if streaming
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
else:
with console.status("[bold cyan]Thinking...") as status:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
except KeyboardInterrupt:
print("User interrupt occurred.")
retry = questionary.confirm("Retry agent.step()?").ask()

View File

@ -55,6 +55,8 @@ class UsageStatistics(BaseModel):
class ChatCompletionResponse(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/object"""
id: str
choices: List[Choice]
created: datetime.datetime
@ -64,3 +66,64 @@ class ChatCompletionResponse(BaseModel):
# object: str = Field(default="chat.completion")
object: Literal["chat.completion"] = "chat.completion"
usage: UsageStatistics
class FunctionCallDelta(BaseModel):
# arguments: Optional[str] = None
name: Optional[str] = None
arguments: str
# name: str
class ToolCallDelta(BaseModel):
index: int
id: Optional[str] = None
# "Currently, only function is supported"
type: Literal["function"] = "function"
# function: ToolCallFunction
function: Optional[FunctionCallDelta] = None
class MessageDelta(BaseModel):
"""Partial delta stream of a Message
Example ChunkResponse:
{
'id': 'chatcmpl-9EOCkKdicNo1tiL1956kPvCnL2lLS',
'object': 'chat.completion.chunk',
'created': 1713216662,
'model': 'gpt-4-0613',
'system_fingerprint': None,
'choices': [{
'index': 0,
'delta': {'content': 'User'},
'logprobs': None,
'finish_reason': None
}]
}
"""
content: Optional[str] = None
tool_calls: Optional[List[ToolCallDelta]] = None
# role: Optional[str] = None
function_call: Optional[FunctionCallDelta] = None # Deprecated
class ChunkChoice(BaseModel):
finish_reason: Optional[str] = None # NOTE: when streaming will be null
index: int
delta: MessageDelta
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
class ChatCompletionChunkResponse(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/streaming"""
id: str
choices: List[ChunkChoice]
created: datetime.datetime
model: str
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
system_fingerprint: Optional[str] = None
# object: str = Field(default="chat.completion")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"

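A quick sketch of how a raw SSE data payload maps onto the new chunk model (the payload below is the example from the MessageDelta docstring; this is what _sse_post does for each "data:" event):

# Hedged sketch: validate a raw chunk payload into ChatCompletionChunkResponse.
import json
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse

raw_sse_data = json.dumps({
    "id": "chatcmpl-9EOCkKdicNo1tiL1956kPvCnL2lLS",
    "object": "chat.completion.chunk",
    "created": 1713216662,
    "model": "gpt-4-0613",
    "system_fingerprint": None,
    "choices": [{"index": 0, "delta": {"content": "User"}, "logprobs": None, "finish_reason": None}],
})
chunk = ChatCompletionChunkResponse(**json.loads(raw_sse_data))
print(chunk.choices[0].delta.content)  # "User"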
View File

@ -0,0 +1,398 @@
from abc import ABC, abstractmethod
import json
import re
import sys
from typing import List, Optional
# from colorama import Fore, Style, init
from rich.console import Console
from rich.live import Live
from rich.markup import escape
from rich.style import Style
from rich.text import Text
from memgpt.utils import printd
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.data_types import Message
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse
from memgpt.interface import AgentInterface, CLIInterface
# init(autoreset=True)
# DEBUG = True # puts full message outputs in the terminal
DEBUG = False # only dumps important messages in the terminal
STRIP_UI = False
class AgentChunkStreamingInterface(ABC):
"""Interfaces handle MemGPT-related events (observer pattern)
The 'msg' arg provides the scoped message, and the optional Message arg can provide additional metadata.
"""
@abstractmethod
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT receives a user message"""
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT generates some internal monologue"""
raise NotImplementedError
@abstractmethod
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT uses send_message"""
raise NotImplementedError
@abstractmethod
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT calls a function"""
raise NotImplementedError
@abstractmethod
def process_chunk(self, chunk: ChatCompletionChunkResponse):
"""Process a streaming chunk from an OpenAI-compatible server"""
raise NotImplementedError
@abstractmethod
def stream_start(self):
"""Any setup required before streaming begins"""
raise NotImplementedError
@abstractmethod
def stream_end(self):
"""Any cleanup required after streaming ends"""
raise NotImplementedError
class StreamingCLIInterface(AgentChunkStreamingInterface):
"""Version of the CLI interface that attaches to a stream generator and prints along the way.
When a chunk is received, we write the delta to the buffer. If the buffer type has changed,
we write out a newline + set the formatting for the new line.
The two buffer types are:
(1) content (inner thoughts)
(2) tool_calls (function calling)
NOTE: this assumes that the deltas received in the chunks arrive in order, e.g.
that once 'content' deltas stop streaming, they won't be received again. See the note
below on an alternative version of the StreamingCLIInterface that avoids this assumption:
an alternative implementation could instead maintain the partial message state, and on each
processed chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen.
"""
# CLIInterface is static/stateless
nonstreaming_interface = CLIInterface()
def __init__(self):
"""The streaming CLI interface state for determining which buffer is currently being written to"""
self.streaming_buffer_type = None
def _flush(self):
pass
def process_chunk(self, chunk: ChatCompletionChunkResponse):
assert len(chunk.choices) == 1, chunk
message_delta = chunk.choices[0].delta
# Starting a new buffer line
if not self.streaming_buffer_type:
assert not (
message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
), f"Error: got both content and tool_calls in message stream\n{message_delta}"
if message_delta.content is not None:
# Write out the prefix for inner thoughts
print("Inner thoughts: ", end="", flush=True)
elif message_delta.tool_calls is not None:
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
# Write out the prefix for function calling
print("Calling function: ", end="", flush=True)
# Potentially switch/flush a buffer line
else:
pass
# Write out the delta
if message_delta.content is not None:
if self.streaming_buffer_type and self.streaming_buffer_type != "content":
print()
self.streaming_buffer_type = "content"
# Simple, just write out to the buffer
print(message_delta.content, end="", flush=True)
elif message_delta.tool_calls is not None:
if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
print()
self.streaming_buffer_type = "tool_calls"
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
function_call = message_delta.tool_calls[0].function
# Slightly more complex - want to write parameters in a certain way (paren-style)
# function_name(function_args)
if function_call.name:
# NOTE: need to account for closing the brace later
print(f"{function_call.name}(", end="", flush=True)
if function_call.arguments:
print(function_call.arguments, end="", flush=True)
def stream_start(self):
# should be handled by stream_end(), but just in case
self.streaming_buffer_type = None
def stream_end(self):
if self.streaming_buffer_type is not None:
# TODO: should have a separate self.tool_call_open_paren flag
if self.streaming_buffer_type == "tool_calls":
print(")", end="", flush=True)
print() # newline to move the cursor
self.streaming_buffer_type = None # reset buffer tracker
@staticmethod
def important_message(msg: str):
    StreamingCLIInterface.nonstreaming_interface.important_message(msg)
@staticmethod
def warning_message(msg: str):
    StreamingCLIInterface.nonstreaming_interface.warning_message(msg)
@staticmethod
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
    StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)
@staticmethod
def assistant_message(msg: str, msg_obj: Optional[Message] = None):
    StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)
@staticmethod
def memory_message(msg: str, msg_obj: Optional[Message] = None):
    StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)
@staticmethod
def system_message(msg: str, msg_obj: Optional[Message] = None):
    StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)
@staticmethod
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
    StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)
@staticmethod
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
    StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)
@staticmethod
def print_messages(message_sequence: List[Message], dump=False):
    StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)
@staticmethod
def print_messages_simple(message_sequence: List[Message]):
StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)
@staticmethod
def print_messages_raw(message_sequence: List[Message]):
StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)
@staticmethod
def step_yield():
pass
class AgentRefreshStreamingInterface(ABC):
"""Same as the ChunkStreamingInterface, but
The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata.
"""
@abstractmethod
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT receives a user message"""
raise NotImplementedError
@abstractmethod
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT generates some internal monologue"""
raise NotImplementedError
@abstractmethod
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT uses send_message"""
raise NotImplementedError
@abstractmethod
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT calls a function"""
raise NotImplementedError
@abstractmethod
def process_refresh(self, response: ChatCompletionResponse):
"""Process a streaming chunk from an OpenAI-compatible server"""
raise NotImplementedError
@abstractmethod
def stream_start(self):
"""Any setup required before streaming begins"""
raise NotImplementedError
@abstractmethod
def stream_end(self):
"""Any cleanup required after streaming ends"""
raise NotImplementedError
@abstractmethod
def toggle_streaming(self, on: bool):
"""Toggle streaming on/off (off = regular CLI interface)"""
raise NotImplementedError
class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
"""Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.
We maintain the partial message state in the interface state, and on each
processed chunk we:
(1) update the partial message state,
(2) refresh/rewrite the state to the screen.
"""
nonstreaming_interface = CLIInterface
def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
"""Initialize the streaming CLI interface state."""
self.console = Console()
# Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates
self.live = Live("", console=self.console, refresh_per_second=10)
# self.live.start() # Start the Live display context and keep it running
# Use italics / emoji?
self.fancy = fancy
self.streaming = True
self.separate_send_message = separate_send_message
self.disable_inner_mono_call = disable_inner_mono_call
def toggle_streaming(self, on: bool):
self.streaming = on
if on:
self.separate_send_message = True
self.disable_inner_mono_call = True
else:
self.separate_send_message = False
self.disable_inner_mono_call = False
def update_output(self, content: str):
"""Update the displayed output with new content."""
# We use the `Live` object's update mechanism to refresh content without clearing the console
if not self.fancy:
content = escape(content)
self.live.update(self.console.render_str(content), refresh=True)
def process_refresh(self, response: ChatCompletionResponse):
"""Process the response to rewrite the current output buffer."""
if not response.choices:
self.update_output("💭 [italic]...[/italic]")
return # Early exit if there are no choices
choice = response.choices[0]
inner_thoughts = choice.message.content if choice.message.content else ""
tool_calls = choice.message.tool_calls if choice.message.tool_calls else []
if self.fancy:
message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else ""
else:
message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else ""
if tool_calls:
function_call = tool_calls[0].function
function_name = function_call.name # Function name, can be an empty string
function_args = function_call.arguments # Function arguments, can be an empty string
if message_string:
message_string += "\n"
# special case here for send_message
if self.separate_send_message and function_name == "send_message":
try:
message = json.loads(function_args)["message"]
except:
prefix = '{\n "message": "'
if len(function_args) < len(prefix):
message = "..."
elif function_args.startswith(prefix):
message = function_args[len(prefix) :]
else:
message = function_args
message_string += f"🤖 [bold yellow]{message}[/bold yellow]"
else:
message_string += f"{function_name}({function_args})"
self.update_output(message_string)
def stream_start(self):
if self.streaming:
print()
self.live.start() # Start the Live display context and keep it running
self.update_output("💭 [italic]...[/italic]")
def stream_end(self):
if self.streaming:
if self.live.is_started:
self.live.stop()
print()
self.live = Live("", console=self.console, refresh_per_second=10)
@staticmethod
def important_message(msg: str):
StreamingCLIInterface.nonstreaming_interface.important_message(msg)
@staticmethod
def warning_message(msg: str):
StreamingCLIInterface.nonstreaming_interface.warning_message(msg)
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
if self.disable_inner_mono_call:
return
StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
if self.separate_send_message:
return
StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)
@staticmethod
def memory_message(msg: str, msg_obj: Optional[Message] = None):
StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)
@staticmethod
def system_message(msg: str, msg_obj: Optional[Message] = None):
StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)
@staticmethod
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)
@staticmethod
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)
@staticmethod
def print_messages(message_sequence: List[Message], dump=False):
StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)
@staticmethod
def print_messages_simple(message_sequence: List[Message]):
StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)
@staticmethod
def print_messages_raw(message_sequence: List[Message]):
StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)
@staticmethod
def step_yield():
pass

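A rough usage sketch of the refresh-style interface. In the real path, openai_chat_completions_process_stream calls process_refresh with the partially accumulated response after each chunk; the single hand-built snapshot below is only for illustration.

# Hedged sketch: drive StreamingRefreshCLIInterface with one response snapshot.
from memgpt.models.chat_completion_response import ChatCompletionResponse, Choice, Message, UsageStatistics
from memgpt.streaming_interface import StreamingRefreshCLIInterface
from memgpt.utils import get_utc_time

interface = StreamingRefreshCLIInterface()
interface.stream_start()  # starts the rich Live display and shows the "thinking" placeholder
try:
    snapshot = ChatCompletionResponse(
        id="example",
        created=get_utc_time(),
        model="gpt-4",
        choices=[Choice(index=0, finish_reason="stop", message=Message(role="assistant", content="Thinking about the user..."))],
        usage=UsageStatistics(completion_tokens=0, prompt_tokens=0, total_tokens=0),
    )
    interface.process_refresh(snapshot)  # re-renders the inner monologue / tool call accumulated so far
finally:
    interface.stream_end()  # stops Live and resets it for the next turn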
poetry.lock (generated, 15 changed lines)
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -1524,6 +1524,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
http2 = ["h2 (>=3,<5)"]
socks = ["socksio (==1.*)"]
[[package]]
name = "httpx-sse"
version = "0.4.0"
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
optional = false
python-versions = ">=3.8"
files = [
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
]
[[package]]
name = "huggingface-hub"
version = "0.22.2"
@ -6094,4 +6105,4 @@ server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.13,>=3.10"
content-hash = "5c36931d717323eab3eea32bf383b27578ea8f3467fd230ce543af364caffa92"
content-hash = "a9635dccf8bd7d826f776e36a9d6fbc845a1b7de0586d06c6a9ce7230a5a14bc"

View File

@ -57,6 +57,7 @@ llama-index-embeddings-openai = "^0.1.1"
llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
llama-index-embeddings-azure-openai = "^0.1.6"
python-multipart = "^0.0.9"
httpx-sse = "^0.4.0"
[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]