mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)
feat: add streaming support for OpenAI-compatible endpoints (#1262)
This commit is contained in: parent e22f3572fd, commit aeb4a94e0b
@ -403,6 +403,7 @@ class Agent(object):
message_sequence: List[Message],
function_call: str = "auto",
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
) -> chat_completion_response.ChatCompletionResponse:
"""Get response from LLM API"""
try:
@ -414,6 +415,9 @@ class Agent(object):
function_call=function_call,
# hint
first_message=first_message,
# streaming
stream=stream,
stream_inferface=self.interface,
)
# special case for 'length'
if response.choices[0].finish_reason == "length":
@ -628,6 +632,7 @@ class Agent(object):
skip_verify: bool = False,
return_dicts: bool = True, # if True, return dicts, if False, return Message objects
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""

@ -710,6 +715,7 @@ class Agent(object):
response = self._get_ai_reply(
message_sequence=input_message_sequence,
first_message=True, # passed through to the prompt formatter
stream=stream,
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
@ -721,6 +727,7 @@ class Agent(object):
else:
response = self._get_ai_reply(
message_sequence=input_message_sequence,
stream=stream,
)

# Step 2: check if LLM wanted to call a function
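For context, the new stream flag is simply threaded from Agent.step() into _get_ai_reply() and on to create(), which also receives the agent's interface as stream_inferface. A minimal, hedged usage sketch (the surrounding agent setup and user_message are assumed to already exist):

new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
    user_message,
    first_message=False,
    skip_verify=False,
    stream=True,  # ask the backend for a streaming response, rendered via the agent's interface
)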
@ -13,7 +13,9 @@ import typer
import questionary

from memgpt.log import logger
from memgpt.interface import CLIInterface as interface # for printing to terminal

# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import StreamingRefreshCLIInterface as interface # for printing to terminal
from memgpt.cli.cli_config import configure
import memgpt.presets.presets as presets
import memgpt.utils as utils
@ -445,6 +447,8 @@ def run(
debug: Annotated[bool, typer.Option(help="Use --debug to enable debugging output")] = False,
no_verify: Annotated[bool, typer.Option(help="Bypass message verification")] = False,
yes: Annotated[bool, typer.Option("-y", help="Skip confirmation prompt and use defaults")] = False,
# streaming
stream: Annotated[bool, typer.Option(help="Enables message streaming in the CLI (if the backend supports it)")] = False,
):
"""Start chatting with a MemGPT agent

@ -710,7 +714,9 @@ def run(
from memgpt.main import run_agent_loop

print() # extra space
run_agent_loop(memgpt_agent, config, first, ms, no_verify) # TODO: add back no_verify
run_agent_loop(
memgpt_agent=memgpt_agent, config=config, first=first, ms=ms, no_verify=no_verify, stream=stream
) # TODO: add back no_verify


def delete_agent(
@ -90,6 +90,7 @@ class MemGPTConfig:
def load(cls) -> "MemGPTConfig":
# avoid circular import
from memgpt.migrate import config_is_compatible, VERSION_CUTOFF
from memgpt.utils import printd

if not config_is_compatible(allow_empty=True):
error_message = " ".join(
@ -110,7 +111,7 @@ class MemGPTConfig:

# ensure all configuration directories exist
cls.create_config_dir()
print(f"Loading config from {config_path}")
printd(f"Loading config from {config_path}")
if os.path.exists(config_path):
# read existing config
config.read(config_path)
@ -3,17 +3,18 @@ import time
import requests
import os
import time
from typing import List
from typing import List, Optional, Union

from memgpt.credentials import MemGPTCredentials
from memgpt.local_llm.chat_completion_proxy import get_chat_completion
from memgpt.constants import CLI_WARNING_PREFIX
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool, cast_message_to_subtype
from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface

from memgpt.data_types import AgentState, Message

from memgpt.llm_api.openai import openai_chat_completions_request
from memgpt.llm_api.openai import openai_chat_completions_request, openai_chat_completions_process_stream
from memgpt.llm_api.azure_openai import azure_openai_chat_completions_request, MODEL_TO_AZURE_ENGINE
from memgpt.llm_api.google_ai import (
google_ai_chat_completions_request,
@ -126,14 +127,17 @@ def retry_with_exponential_backoff(
def create(
agent_state: AgentState,
messages: List[Message],
functions=None,
functions_python=None,
function_call="auto",
functions: list = None,
functions_python: list = None,
function_call: str = "auto",
# hint
first_message=False,
first_message: bool = False,
# use tool naming?
# if false, will use deprecated 'functions' style
use_tool_naming=True,
use_tool_naming: bool = True,
# streaming?
stream: bool = False,
stream_inferface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from memgpt.utils import printd
@ -169,11 +173,25 @@ def create(
function_call=function_call,
user=str(agent_state.user_id),
)
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
data=data,
)

if stream:
data.stream = True
assert isinstance(stream_inferface, AgentChunkStreamingInterface) or isinstance(
stream_inferface, AgentRefreshStreamingInterface
), type(stream_inferface)
return openai_chat_completions_process_stream(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
stream_inferface=stream_inferface,
)
else:
data.stream = False
return openai_chat_completions_request(
url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=credentials.openai_key,
chat_completion_request=data,
)

# azure
elif agent_state.llm_config.model_endpoint_type == "azure":
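A hedged sketch of calling create() directly with streaming enabled. The module path and the pre-built agent_state/messages/functions are assumptions for illustration, not part of this diff:

from memgpt.llm_api_tools import create  # assumed module path for the create() shown above
from memgpt.streaming_interface import StreamingRefreshCLIInterface

interface = StreamingRefreshCLIInterface()
response = create(
    agent_state=agent_state,    # assumed to exist, with an OpenAI-compatible llm_config
    messages=messages,          # assumed List[Message] history
    functions=functions,        # assumed function schemas, or None
    stream=True,
    stream_inferface=interface, # note: the parameter is spelled 'stream_inferface' in this commit
)
print(response.usage.completion_tokens)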
@ -1,11 +1,28 @@
|
||||
import requests
|
||||
import time
|
||||
from typing import Union, Optional
|
||||
import json
|
||||
import httpx
|
||||
from httpx_sse import connect_sse
|
||||
from httpx_sse._exceptions import SSEError
|
||||
from typing import Union, Optional, Generator
|
||||
|
||||
from memgpt.models.chat_completion_response import ChatCompletionResponse
|
||||
from memgpt.models.chat_completion_response import (
|
||||
ChatCompletionResponse,
|
||||
Choice,
|
||||
Message,
|
||||
ToolCall,
|
||||
FunctionCall,
|
||||
UsageStatistics,
|
||||
ChatCompletionChunkResponse,
|
||||
)
|
||||
from memgpt.models.chat_completion_request import ChatCompletionRequest
|
||||
from memgpt.models.embedding_response import EmbeddingResponse
|
||||
from memgpt.utils import smart_urljoin
|
||||
from memgpt.utils import smart_urljoin, get_utc_time
|
||||
from memgpt.local_llm.utils import num_tokens_from_messages, num_tokens_from_functions
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
|
||||
|
||||
OPENAI_SSE_DONE = "[DONE]"
|
||||
|
||||
|
||||
def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional[bool] = False) -> dict:
|
||||
@ -58,13 +75,233 @@ def openai_get_model_list(url: str, api_key: Union[str, None], fix_url: Optional
|
||||
raise e
|
||||
|
||||
|
||||
def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||
"""https://platform.openai.com/docs/guides/text-generation?lang=curl"""
|
||||
def openai_chat_completions_process_stream(
|
||||
url: str,
|
||||
api_key: str,
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
stream_inferface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
|
||||
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionResponse at the end.
|
||||
|
||||
To "stream" the response in MemGPT, we want to call a streaming-compatible interface function
|
||||
on the chunks received from the OpenAI-compatible server POST SSE response.
|
||||
"""
|
||||
assert chat_completion_request.stream == True
|
||||
|
||||
# Count the prompt tokens
|
||||
# TODO move to post-request?
|
||||
chat_history = [m.model_dump(exclude_none=True) for m in chat_completion_request.messages]
|
||||
# print(chat_history)
|
||||
|
||||
prompt_tokens = num_tokens_from_messages(
|
||||
messages=chat_history,
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
# We also need to add the cost of including the functions list to the input prompt
|
||||
if chat_completion_request.tools is not None:
|
||||
assert chat_completion_request.functions is None
|
||||
prompt_tokens += num_tokens_from_functions(
|
||||
functions=[t.function.model_dump() for t in chat_completion_request.tools],
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
elif chat_completion_request.functions is not None:
|
||||
assert chat_completion_request.tools is None
|
||||
prompt_tokens += num_tokens_from_functions(
|
||||
functions=[f.model_dump() for f in chat_completion_request.functions],
|
||||
model=chat_completion_request.model,
|
||||
)
|
||||
|
||||
TEMP_STREAM_RESPONSE_ID = "temp_id"
|
||||
TEMP_STREAM_FINISH_REASON = "temp_null"
|
||||
TEMP_STREAM_TOOL_CALL_ID = "temp_id"
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
id=TEMP_STREAM_RESPONSE_ID,
|
||||
choices=[],
|
||||
created=get_utc_time(),
|
||||
model=chat_completion_request.model,
|
||||
usage=UsageStatistics(
|
||||
completion_tokens=0,
|
||||
prompt_tokens=prompt_tokens,
|
||||
total_tokens=prompt_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
if stream_inferface:
|
||||
stream_inferface.stream_start()
|
||||
|
||||
n_chunks = 0 # approx == n_tokens
|
||||
try:
|
||||
for chunk_idx, chat_completion_chunk in enumerate(
|
||||
openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request)
|
||||
):
|
||||
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
|
||||
# print(chat_completion_chunk)
|
||||
|
||||
if stream_inferface:
|
||||
if isinstance(stream_inferface, AgentChunkStreamingInterface):
|
||||
stream_inferface.process_chunk(chat_completion_chunk)
|
||||
elif isinstance(stream_inferface, AgentRefreshStreamingInterface):
|
||||
stream_inferface.process_refresh(chat_completion_response)
|
||||
else:
|
||||
raise TypeError(stream_inferface)
|
||||
|
||||
if chunk_idx == 0:
|
||||
# initialize the choice objects which we will increment with the deltas
|
||||
num_choices = len(chat_completion_chunk.choices)
|
||||
assert num_choices > 0
|
||||
chat_completion_response.choices = [
|
||||
Choice(
|
||||
finish_reason=TEMP_STREAM_FINISH_REASON, # NOTE: needs to be overwritten
|
||||
index=i,
|
||||
message=Message(
|
||||
role="assistant",
|
||||
),
|
||||
)
|
||||
for i in range(len(chat_completion_chunk.choices))
|
||||
]
|
||||
|
||||
# add the choice delta
|
||||
assert len(chat_completion_chunk.choices) == len(chat_completion_response.choices), chat_completion_chunk
|
||||
for chunk_choice in chat_completion_chunk.choices:
|
||||
if chunk_choice.finish_reason is not None:
|
||||
chat_completion_response.choices[chunk_choice.index].finish_reason = chunk_choice.finish_reason
|
||||
|
||||
if chunk_choice.logprobs is not None:
|
||||
chat_completion_response.choices[chunk_choice.index].logprobs = chunk_choice.logprobs
|
||||
|
||||
accum_message = chat_completion_response.choices[chunk_choice.index].message
|
||||
message_delta = chunk_choice.delta
|
||||
|
||||
if message_delta.content is not None:
|
||||
content_delta = message_delta.content
|
||||
if accum_message.content is None:
|
||||
accum_message.content = content_delta
|
||||
else:
|
||||
accum_message.content += content_delta
|
||||
|
||||
if message_delta.tool_calls is not None:
|
||||
tool_calls_delta = message_delta.tool_calls
|
||||
|
||||
# If this is the first tool call showing up in a chunk, initialize the list with it
|
||||
if accum_message.tool_calls is None:
|
||||
accum_message.tool_calls = [
|
||||
ToolCall(id=TEMP_STREAM_TOOL_CALL_ID, function=FunctionCall(name="", arguments=""))
|
||||
for _ in range(len(tool_calls_delta))
|
||||
]
|
||||
|
||||
for tool_call_delta in tool_calls_delta:
|
||||
if tool_call_delta.id is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
# TODO += instead of =?
|
||||
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
|
||||
if tool_call_delta.function is not None:
|
||||
if tool_call_delta.function.name is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
# TODO += instead of =?
|
||||
accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name
|
||||
if tool_call_delta.function.arguments is not None:
|
||||
accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments
|
||||
|
||||
if message_delta.function_call is not None:
|
||||
raise NotImplementedError("Old function_call style not supported with stream=True")
|
||||
|
||||
# overwrite response fields based on latest chunk
|
||||
chat_completion_response.id = chat_completion_chunk.id
|
||||
chat_completion_response.system_fingerprint = chat_completion_chunk.system_fingerprint
|
||||
chat_completion_response.created = chat_completion_chunk.created
|
||||
chat_completion_response.model = chat_completion_chunk.model
|
||||
|
||||
# increment chunk counter
|
||||
n_chunks += 1
|
||||
|
||||
except Exception as e:
|
||||
if stream_inferface:
|
||||
stream_inferface.stream_end()
|
||||
print(f"Parsing ChatCompletion stream failed with error:\n{str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
if stream_inferface:
|
||||
stream_inferface.stream_end()
|
||||
|
||||
# make sure we didn't leave temp stuff in
|
||||
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
|
||||
assert all(
|
||||
[
|
||||
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
|
||||
for c in chat_completion_response.choices
|
||||
]
|
||||
)
|
||||
assert chat_completion_response.id != TEMP_STREAM_RESPONSE_ID
|
||||
|
||||
# compute token usage before returning
|
||||
# TODO try actually counting the completion tokens instead of assuming each chunk is one token
|
||||
chat_completion_response.usage.completion_tokens = n_chunks
|
||||
chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks
|
||||
|
||||
# printd(chat_completion_response)
|
||||
return chat_completion_response
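A hedged usage sketch of the new streaming entrypoint. The credentials and request contents are assumptions; with stream_inferface=None the chunks are still accumulated, just not rendered live:

from memgpt.llm_api.openai import openai_chat_completions_process_stream
from memgpt.models.chat_completion_request import ChatCompletionRequest

request = ChatCompletionRequest(
    model="gpt-4",
    messages=[{"role": "user", "content": "Hello!"}],  # assumed message shape
    stream=True,  # required; the function asserts this
)
response = openai_chat_completions_process_stream(
    url="https://api.openai.com/v1",
    api_key="sk-...",
    chat_completion_request=request,
    stream_inferface=None,  # or an AgentChunkStreamingInterface / AgentRefreshStreamingInterface
)
print(response.choices[0].message.content)
print(response.usage.completion_tokens)  # approximated as the number of chunks received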
|
||||
|
||||
|
||||
def _sse_post(url: str, data: dict, headers: dict) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
|
||||
with httpx.Client() as client:
|
||||
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
||||
try:
|
||||
for sse in event_source.iter_sse():
|
||||
# printd(sse.event, sse.data, sse.id, sse.retry)
|
||||
if sse.data == OPENAI_SSE_DONE:
|
||||
# print("finished")
|
||||
break
|
||||
else:
|
||||
chunk_data = json.loads(sse.data)
|
||||
# print("chunk_data::", chunk_data)
|
||||
chunk_object = ChatCompletionChunkResponse(**chunk_data)
|
||||
# print("chunk_object::", chunk_object)
|
||||
# id=chunk_data["id"],
|
||||
# choices=[ChunkChoice],
|
||||
# model=chunk_data["model"],
|
||||
# system_fingerprint=chunk_data["system_fingerprint"]
|
||||
# )
|
||||
yield chunk_object
|
||||
|
||||
except SSEError as e:
|
||||
if "application/json" in str(e): # Check if the error is because of JSON response
|
||||
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
|
||||
if response.headers["Content-Type"].startswith("application/json"):
|
||||
error_details = response.json() # Parse the JSON to get the error message
|
||||
print("Error:", error_details)
|
||||
print("Request:", vars(response.request))
|
||||
else:
|
||||
print("Failed to retrieve JSON error message.")
|
||||
else:
|
||||
print("SSEError not related to 'application/json' content type.")
|
||||
|
||||
# Optionally re-raise the exception if you need to propagate it
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
if event_source.response.request is not None:
|
||||
print("HTTP Request:", vars(event_source.response.request))
|
||||
if event_source.response is not None:
|
||||
print("HTTP Status:", event_source.response.status_code)
|
||||
print("HTTP Headers:", event_source.response.headers)
|
||||
# print("HTTP Body:", event_source.response.text)
|
||||
print("Exception message:", str(e))
|
||||
raise e
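For reference, a hypothetical example of the SSE payloads this generator parses. The chunk shape follows the OpenAI streaming API (and the example in the MessageDelta docstring further down), and the stream ends with the literal "[DONE]" sentinel:

import json
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse

sample_events = [
    '{"id": "chatcmpl-abc123", "object": "chat.completion.chunk", "created": 1713216662,'
    ' "model": "gpt-4-0613", "system_fingerprint": null,'
    ' "choices": [{"index": 0, "delta": {"content": "Hi"}, "logprobs": null, "finish_reason": null}]}',
    "[DONE]",
]
for data in sample_events:
    if data == "[DONE]":  # OPENAI_SSE_DONE
        break
    chunk = ChatCompletionChunkResponse(**json.loads(data))
    print(chunk.choices[0].delta.content)  # -> "Hi"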
|
||||
|
||||
|
||||
def openai_chat_completions_request_stream(
|
||||
url: str,
|
||||
api_key: str,
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
) -> Generator[ChatCompletionChunkResponse, None, None]:
|
||||
from memgpt.utils import printd
|
||||
|
||||
url = smart_urljoin(url, "chat/completions")
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
data = data.model_dump(exclude_none=True)
|
||||
data = chat_completion_request.model_dump(exclude_none=True)
|
||||
|
||||
printd("Request:\n", json.dumps(data, indent=2))
|
||||
|
||||
# If functions == None, strip from the payload
|
||||
if "functions" in data and data["functions"] is None:
|
||||
@ -77,21 +314,59 @@ def openai_chat_completions_request(url: str, api_key: str, data: ChatCompletion
|
||||
|
||||
printd(f"Sending request to {url}")
|
||||
try:
|
||||
# Example code to trigger a rate limit response:
|
||||
# mock_response = requests.Response()
|
||||
# mock_response.status_code = 429
|
||||
# http_error = requests.exceptions.HTTPError("429 Client Error: Too Many Requests")
|
||||
# http_error.response = mock_response
|
||||
# raise http_error
|
||||
return _sse_post(url=url, data=data, headers=headers)
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
# Handle HTTP errors (e.g., response 4XX, 5XX)
|
||||
printd(f"Got HTTPError, exception={http_err}, payload={data}")
|
||||
raise http_err
|
||||
except requests.exceptions.RequestException as req_err:
|
||||
# Handle other requests-related errors (e.g., connection error)
|
||||
printd(f"Got RequestException, exception={req_err}")
|
||||
raise req_err
|
||||
except Exception as e:
|
||||
# Handle other potential errors
|
||||
printd(f"Got unknown Exception, exception={e}")
|
||||
raise e
|
||||
|
||||
# Example code to trigger a context overflow response (for an 8k model)
|
||||
# data["messages"][-1]["content"] = " ".join(["repeat after me this is not a fluke"] * 1000)
|
||||
|
||||
def openai_chat_completions_request(
|
||||
url: str,
|
||||
api_key: str,
|
||||
chat_completion_request: ChatCompletionRequest,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Send a ChatCompletion request to an OpenAI-compatible server
|
||||
|
||||
If request.stream == True, will yield ChatCompletionChunkResponses
|
||||
If request.stream == False, will return a ChatCompletionResponse
|
||||
|
||||
https://platform.openai.com/docs/guides/text-generation?lang=curl
|
||||
"""
|
||||
from memgpt.utils import printd
|
||||
|
||||
url = smart_urljoin(url, "chat/completions")
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
data = chat_completion_request.model_dump(exclude_none=True)
|
||||
|
||||
printd("Request:\n", json.dumps(data, indent=2))
|
||||
|
||||
# If functions == None, strip from the payload
|
||||
if "functions" in data and data["functions"] is None:
|
||||
data.pop("functions")
|
||||
data.pop("function_call", None) # extra safe, should exist always (default="auto")
|
||||
|
||||
if "tools" in data and data["tools"] is None:
|
||||
data.pop("tools")
|
||||
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
|
||||
|
||||
printd(f"Sending request to {url}")
|
||||
try:
|
||||
response = requests.post(url, headers=headers, json=data)
|
||||
printd(f"response = {response}")
|
||||
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
|
||||
|
||||
response = response.json() # convert to dict from string
|
||||
printd(f"response.json = {response}")
|
||||
|
||||
response = ChatCompletionResponse(**response) # convert to 'dot-dict' style which is the openai python client default
|
||||
return response
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
|
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import requests
|
||||
import tiktoken
|
||||
from typing import List
|
||||
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
|
||||
import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
|
||||
@ -74,6 +75,148 @@ def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
return len(encoding.encode(s))
|
||||
|
||||
|
||||
def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
||||
"""Return the number of tokens used by a list of functions.
|
||||
|
||||
Copied from https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = 0
|
||||
for function in functions:
|
||||
function_tokens = len(encoding.encode(function["name"]))
|
||||
function_tokens += len(encoding.encode(function["description"]))
|
||||
|
||||
if "parameters" in function:
|
||||
parameters = function["parameters"]
|
||||
if "properties" in parameters:
|
||||
for propertiesKey in parameters["properties"]:
|
||||
function_tokens += len(encoding.encode(propertiesKey))
|
||||
v = parameters["properties"][propertiesKey]
|
||||
for field in v:
|
||||
if field == "type":
|
||||
function_tokens += 2
|
||||
function_tokens += len(encoding.encode(v["type"]))
|
||||
elif field == "description":
|
||||
function_tokens += 2
|
||||
function_tokens += len(encoding.encode(v["description"]))
|
||||
elif field == "enum":
|
||||
function_tokens -= 3
|
||||
for o in v["enum"]:
|
||||
function_tokens += 3
|
||||
function_tokens += len(encoding.encode(o))
|
||||
else:
|
||||
print(f"Warning: unsupported field {field}")
|
||||
function_tokens += 11
|
||||
|
||||
num_tokens += function_tokens
|
||||
|
||||
num_tokens += 12
|
||||
return num_tokens
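A hedged usage sketch (hypothetical schema) showing how the prompt overhead of advertising one function would be estimated:

example_functions = [
    {
        "name": "send_message",
        "description": "Sends a message to the human user.",
        "parameters": {
            "type": "object",
            "properties": {
                "message": {"type": "string", "description": "Message contents."},
            },
        },
    }
]
print(num_tokens_from_functions(example_functions, model="gpt-4"))  # rough token estimate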
|
||||
|
||||
|
||||
def num_tokens_from_tool_calls(tool_calls: List[dict], model: str = "gpt-4"):
|
||||
"""Based on above code (num_tokens_from_functions).
|
||||
|
||||
Example to encode:
|
||||
[{
|
||||
'id': '8b6707cf-2352-4804-93db-0423f',
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'send_message',
|
||||
'arguments': '{\n "message": "More human than human is our motto."\n}'
|
||||
}
|
||||
}]
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
# print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = 0
|
||||
for tool_call in tool_calls:
|
||||
function_tokens = len(encoding.encode(tool_call["id"]))
|
||||
function_tokens += 2 + len(encoding.encode(tool_call["type"]))
|
||||
function_tokens += 2 + len(encoding.encode(tool_call["function"]["name"]))
|
||||
function_tokens += 2 + len(encoding.encode(tool_call["function"]["arguments"]))
|
||||
|
||||
num_tokens += function_tokens
|
||||
|
||||
# TODO adjust?
|
||||
num_tokens += 12
|
||||
return num_tokens
|
||||
|
||||
|
||||
def num_tokens_from_messages(messages: List[dict], model: str = "gpt-4") -> int:
|
||||
"""Return the number of tokens used by a list of messages.
|
||||
|
||||
From: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
For counting tokens in function calling RESPONSES, see:
|
||||
https://hmarr.com/blog/counting-openai-tokens/, https://github.com/hmarr/openai-chat-tokens
|
||||
|
||||
For counting tokens in function calling REQUESTS, see:
|
||||
https://community.openai.com/t/how-to-calculate-the-tokens-when-using-function-call/266573/11
|
||||
"""
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
# print("Warning: model not found. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
if model in {
|
||||
"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613",
|
||||
}:
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
elif model == "gpt-3.5-turbo-0301":
|
||||
tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
tokens_per_name = -1 # if there's a name, the role is omitted
|
||||
elif "gpt-3.5-turbo" in model:
|
||||
# print("Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.")
|
||||
return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
|
||||
elif "gpt-4" in model:
|
||||
# print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.")
|
||||
return num_tokens_from_messages(messages, model="gpt-4-0613")
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"""num_tokens_from_messages() is not implemented for model {model}. See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens."""
|
||||
)
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
try:
|
||||
|
||||
if isinstance(value, list) and key == "tool_calls":
|
||||
num_tokens += num_tokens_from_tool_calls(tool_calls=value, model=model)
|
||||
# special case for tool calling (list)
|
||||
# num_tokens += len(encoding.encode(value["name"]))
|
||||
# num_tokens += len(encoding.encode(value["arguments"]))
|
||||
|
||||
else:
|
||||
num_tokens += len(encoding.encode(value))
|
||||
|
||||
if key == "name":
|
||||
num_tokens += tokens_per_name
|
||||
|
||||
except TypeError as e:
|
||||
print(f"tiktoken encoding failed on: {value}")
|
||||
raise e
|
||||
|
||||
num_tokens += 3 # every reply is primed with <|start|>assistant<|message|>
|
||||
return num_tokens
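A hedged usage sketch with a hypothetical two-message history:

history = [
    {"role": "system", "content": "You are MemGPT, a helpful assistant."},
    {"role": "user", "content": "Hi there!"},
]
print(num_tokens_from_messages(history, model="gpt-4"))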
|
||||
|
||||
|
||||
def get_available_wrappers() -> dict:
|
||||
return {
|
||||
"experimental-wrapper-neural-chat-grammar-noforce": configurable_wrapper.ConfigurableJSONWrapper(
|
||||
|
@ -10,11 +10,10 @@ import typer
|
||||
from rich.console import Console
|
||||
from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, JSON_LOADS_STRICT, REQ_HEARTBEAT_MESSAGE
|
||||
|
||||
console = Console()
|
||||
|
||||
from memgpt.agent import save_agent
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
|
||||
# from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
from memgpt.streaming_interface import AgentRefreshStreamingInterface
|
||||
from memgpt.config import MemGPTConfig
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
@ -27,6 +26,8 @@ from memgpt.metadata import MetadataStore
|
||||
# import benchmark
|
||||
from memgpt.benchmark.benchmark import bench
|
||||
|
||||
# interface = interface()
|
||||
|
||||
app = typer.Typer(pretty_exceptions_enable=False)
|
||||
app.command(name="run")(run)
|
||||
app.command(name="version")(version)
|
||||
@ -47,7 +48,7 @@ app.command(name="benchmark")(bench)
|
||||
app.command(name="delete-agent")(delete_agent)
|
||||
|
||||
|
||||
def clear_line(strip_ui=False):
|
||||
def clear_line(console, strip_ui=False):
|
||||
if strip_ui:
|
||||
return
|
||||
if os.name == "nt": # for windows
|
||||
@ -57,7 +58,19 @@ def clear_line(strip_ui=False):
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False):
|
||||
def run_agent_loop(
|
||||
memgpt_agent: agent.Agent, config: MemGPTConfig, first, ms: MetadataStore, no_verify=False, cfg=None, strip_ui=False, stream=False
|
||||
):
|
||||
if isinstance(memgpt_agent.interface, AgentRefreshStreamingInterface):
|
||||
# memgpt_agent.interface.toggle_streaming(on=stream)
|
||||
if not stream:
|
||||
memgpt_agent.interface = memgpt_agent.interface.nonstreaming_interface
|
||||
|
||||
if hasattr(memgpt_agent.interface, "console"):
|
||||
console = memgpt_agent.interface.console
|
||||
else:
|
||||
console = Console()
|
||||
|
||||
counter = 0
|
||||
user_input = None
|
||||
skip_next_user_input = False
|
||||
@ -65,8 +78,8 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
USER_GOES_FIRST = first
|
||||
|
||||
if not USER_GOES_FIRST:
|
||||
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]")
|
||||
clear_line(strip_ui)
|
||||
console.input("[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]\n")
|
||||
clear_line(console, strip_ui=strip_ui)
|
||||
print()
|
||||
|
||||
multiline_input = False
|
||||
@ -74,12 +87,16 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
while True:
|
||||
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
|
||||
# Ask for user input
|
||||
if not stream:
|
||||
print()
|
||||
user_input = questionary.text(
|
||||
"Enter your message:",
|
||||
multiline=multiline_input,
|
||||
qmark=">",
|
||||
).ask()
|
||||
clear_line(strip_ui)
|
||||
clear_line(console, strip_ui=strip_ui)
|
||||
if not stream:
|
||||
print()
|
||||
|
||||
# Gracefully exit on Ctrl-C/D
|
||||
if user_input is None:
|
||||
@ -157,13 +174,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
command = user_input.strip().split()
|
||||
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0
|
||||
if amount == 0:
|
||||
interface.print_messages(memgpt_agent._messages, dump=True)
|
||||
memgpt_agent.interface.print_messages(memgpt_agent._messages, dump=True)
|
||||
else:
|
||||
interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
|
||||
memgpt_agent.interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/dumpraw":
|
||||
interface.print_messages_raw(memgpt_agent._messages)
|
||||
memgpt_agent.interface.print_messages_raw(memgpt_agent._messages)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/memory":
|
||||
@ -194,9 +211,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
else:
|
||||
print(f"Popping last {pop_amount} messages from stack")
|
||||
for _ in range(min(pop_amount, len(memgpt_agent.messages))):
|
||||
memgpt_agent._messages.pop()
|
||||
# Persist the state
|
||||
save_agent(agent=memgpt_agent, ms=ms)
|
||||
memgpt_agent.messages.pop()
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/retry":
|
||||
@ -218,13 +233,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
for x in range(len(memgpt_agent.messages) - 1, 0, -1):
|
||||
if memgpt_agent.messages[x].get("role") == "assistant":
|
||||
text = user_input[len("/rethink ") :].strip()
|
||||
|
||||
# Do the /rethink-ing
|
||||
message_obj = memgpt_agent._messages[x]
|
||||
message_obj.text = text
|
||||
|
||||
# To persist to the database, all we need to do is "re-insert" into recall memory
|
||||
memgpt_agent.persistence_manager.recall_memory.storage.update(record=message_obj)
|
||||
memgpt_agent.messages[x].update({"content": text})
|
||||
break
|
||||
continue
|
||||
|
||||
@ -321,7 +330,7 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
|
||||
# No skip options
|
||||
elif user_input.lower() == "/wipe":
|
||||
memgpt_agent = agent.Agent(interface)
|
||||
memgpt_agent = agent.Agent(memgpt_agent.interface)
|
||||
user_message = None
|
||||
|
||||
elif user_input.lower() == "/heartbeat":
|
||||
@ -354,7 +363,10 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
|
||||
def process_agent_step(user_message, no_verify):
|
||||
new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step(
|
||||
user_message, first_message=False, skip_verify=no_verify
|
||||
user_message,
|
||||
first_message=False,
|
||||
skip_verify=no_verify,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
skip_next_user_input = False
|
||||
@ -376,9 +388,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
break
|
||||
else:
|
||||
with console.status("[bold cyan]Thinking...") as status:
|
||||
if stream:
|
||||
# Don't display the "Thinking..." if streaming
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
break
|
||||
else:
|
||||
with console.status("[bold cyan]Thinking...") as status:
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
print("User interrupt occurred.")
|
||||
retry = questionary.confirm("Retry agent.step()?").ask()
|
||||
|
@ -55,6 +55,8 @@ class UsageStatistics(BaseModel):
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
"""https://platform.openai.com/docs/api-reference/chat/object"""
|
||||
|
||||
id: str
|
||||
choices: List[Choice]
|
||||
created: datetime.datetime
|
||||
@ -64,3 +66,64 @@ class ChatCompletionResponse(BaseModel):
|
||||
# object: str = Field(default="chat.completion")
|
||||
object: Literal["chat.completion"] = "chat.completion"
|
||||
usage: UsageStatistics
|
||||
|
||||
|
||||
class FunctionCallDelta(BaseModel):
|
||||
# arguments: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
arguments: str
|
||||
# name: str
|
||||
|
||||
|
||||
class ToolCallDelta(BaseModel):
|
||||
index: int
|
||||
id: Optional[str] = None
|
||||
# "Currently, only function is supported"
|
||||
type: Literal["function"] = "function"
|
||||
# function: ToolCallFunction
|
||||
function: Optional[FunctionCallDelta] = None
|
||||
|
||||
|
||||
class MessageDelta(BaseModel):
|
||||
"""Partial delta stream of a Message
|
||||
|
||||
Example ChunkResponse:
|
||||
{
|
||||
'id': 'chatcmpl-9EOCkKdicNo1tiL1956kPvCnL2lLS',
|
||||
'object': 'chat.completion.chunk',
|
||||
'created': 1713216662,
|
||||
'model': 'gpt-4-0613',
|
||||
'system_fingerprint': None,
|
||||
'choices': [{
|
||||
'index': 0,
|
||||
'delta': {'content': 'User'},
|
||||
'logprobs': None,
|
||||
'finish_reason': None
|
||||
}]
|
||||
}
|
||||
"""
|
||||
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCallDelta]] = None
|
||||
# role: Optional[str] = None
|
||||
function_call: Optional[FunctionCallDelta] = None # Deprecated
|
||||
|
||||
|
||||
class ChunkChoice(BaseModel):
|
||||
finish_reason: Optional[str] = None # NOTE: when streaming will be null
|
||||
index: int
|
||||
delta: MessageDelta
|
||||
logprobs: Optional[Dict[str, Union[List[MessageContentLogProb], None]]] = None
|
||||
|
||||
|
||||
class ChatCompletionChunkResponse(BaseModel):
|
||||
"""https://platform.openai.com/docs/api-reference/chat/streaming"""
|
||||
|
||||
id: str
|
||||
choices: List[ChunkChoice]
|
||||
created: datetime.datetime
|
||||
model: str
|
||||
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
|
||||
system_fingerprint: Optional[str] = None
|
||||
# object: str = Field(default="chat.completion")
|
||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||
|
memgpt/streaming_interface.py (new file, 398 lines)
@ -0,0 +1,398 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
# from colorama import Fore, Style, init
|
||||
from rich.console import Console
|
||||
from rich.live import Live
|
||||
from rich.markup import escape
|
||||
from rich.style import Style
|
||||
from rich.text import Text
|
||||
|
||||
from memgpt.utils import printd
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse
|
||||
from memgpt.interface import AgentInterface, CLIInterface
|
||||
|
||||
# init(autoreset=True)
|
||||
|
||||
# DEBUG = True # puts full message outputs in the terminal
|
||||
DEBUG = False # only dumps important messages in the terminal
|
||||
|
||||
STRIP_UI = False
|
||||
|
||||
|
||||
class AgentChunkStreamingInterface(ABC):
|
||||
"""Interfaces handle MemGPT-related events (observer pattern)
|
||||
|
||||
The 'msg' arg provides the scoped message, and the optional Message arg can provide additional metadata.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT receives a user message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT generates some internal monologue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT uses send_message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT calls a function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_chunk(self, chunk: ChatCompletionChunkResponse):
|
||||
"""Process a streaming chunk from an OpenAI-compatible server"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stream_start(self):
|
||||
"""Any setup required before streaming begins"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stream_end(self):
|
||||
"""Any cleanup required after streaming ends"""
|
||||
raise NotImplementedError
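As an illustration (not part of this file), a minimal concrete subclass that just echoes content deltas might look like:

class EchoChunksInterface(AgentChunkStreamingInterface):
    """Hypothetical example: print content deltas as they arrive, ignore everything else."""

    def user_message(self, msg: str, msg_obj=None):
        pass

    def internal_monologue(self, msg: str, msg_obj=None):
        pass

    def assistant_message(self, msg: str, msg_obj=None):
        pass

    def function_message(self, msg: str, msg_obj=None):
        pass

    def stream_start(self):
        pass

    def stream_end(self):
        print()  # finish the line once the stream closes

    def process_chunk(self, chunk: ChatCompletionChunkResponse):
        delta = chunk.choices[0].delta
        if delta.content:
            print(delta.content, end="", flush=True)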
|
||||
|
||||
|
||||
class StreamingCLIInterface(AgentChunkStreamingInterface):
|
||||
"""Version of the CLI interface that attaches to a stream generator and prints along the way.
|
||||
|
||||
When a chunk is received, we write the delta to the buffer. If the buffer type has changed,
|
||||
we write out a newline + set the formatting for the new line.
|
||||
|
||||
The two buffer types are:
|
||||
(1) content (inner thoughts)
|
||||
(2) tool_calls (function calling)
|
||||
|
||||
NOTE: this assumes that the deltas received in the chunks are in-order, e.g.
|
||||
that once 'content' deltas stop streaming, they won't be received again. See notes
|
||||
on alternative version of the StreamingCLIInterface that does not have this same problem below:
|
||||
|
||||
An alternative implementation could instead maintain the partial message state, and on each
|
||||
process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen.
|
||||
"""
|
||||
|
||||
# CLIInterface is static/stateless
|
||||
nonstreaming_interface = CLIInterface()
|
||||
|
||||
def __init__(self):
|
||||
"""The streaming CLI interface state for determining which buffer is currently being written to"""
|
||||
|
||||
self.streaming_buffer_type = None
|
||||
|
||||
def _flush(self):
|
||||
pass
|
||||
|
||||
def process_chunk(self, chunk: ChatCompletionChunkResponse):
|
||||
assert len(chunk.choices) == 1, chunk
|
||||
|
||||
message_delta = chunk.choices[0].delta
|
||||
|
||||
# Starting a new buffer line
|
||||
if not self.streaming_buffer_type:
|
||||
assert not (
|
||||
message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
|
||||
), f"Error: got both content and tool_calls in message stream\n{message_delta}"
|
||||
|
||||
if message_delta.content is not None:
|
||||
# Write out the prefix for inner thoughts
|
||||
print("Inner thoughts: ", end="", flush=True)
|
||||
elif message_delta.tool_calls is not None:
|
||||
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
|
||||
# Write out the prefix for function calling
|
||||
print("Calling function: ", end="", flush=True)
|
||||
|
||||
# Potentially switch/flush a buffer line
|
||||
else:
|
||||
pass
|
||||
|
||||
# Write out the delta
|
||||
if message_delta.content is not None:
|
||||
if self.streaming_buffer_type and self.streaming_buffer_type != "content":
|
||||
print()
|
||||
self.streaming_buffer_type = "content"
|
||||
|
||||
# Simple, just write out to the buffer
|
||||
print(message_delta.content, end="", flush=True)
|
||||
|
||||
elif message_delta.tool_calls is not None:
|
||||
if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
|
||||
print()
|
||||
self.streaming_buffer_type = "tool_calls"
|
||||
|
||||
assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
|
||||
function_call = message_delta.tool_calls[0].function
|
||||
|
||||
# Slightly more complex - want to write parameters in a certain way (paren-style)
|
||||
# function_name(function_args)
|
||||
if function_call.name:
|
||||
# NOTE: need to account for closing the brace later
|
||||
print(f"{function_call.name}(", end="", flush=True)
|
||||
if function_call.arguments:
|
||||
print(function_call.arguments, end="", flush=True)
|
||||
|
||||
def stream_start(self):
|
||||
# should be handled by stream_end(), but just in case
|
||||
self.streaming_buffer_type = None
|
||||
|
||||
def stream_end(self):
|
||||
if self.streaming_buffer_type is not None:
|
||||
# TODO: should have a separate self.tool_call_open_paren flag
|
||||
if self.streaming_buffer_type == "tool_calls":
|
||||
print(")", end="", flush=True)
|
||||
|
||||
print() # newline to move the cursor
|
||||
self.streaming_buffer_type = None # reset buffer tracker
|
||||
|
||||
@staticmethod
|
||||
def important_message(msg: str):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg)
|
||||
|
||||
@staticmethod
|
||||
def warning_message(msg: str):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg)
|
||||
|
||||
@staticmethod
|
||||
def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def assistant_message(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def memory_message(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def system_message(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
|
||||
StreamingCLIInterface.nonstreaming_interface(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def print_messages(message_sequence: List[Message], dump=False):
|
||||
StreamingCLIInterface.nonstreaming_interface(message_sequence, dump)
|
||||
|
||||
@staticmethod
|
||||
def print_messages_simple(message_sequence: List[Message]):
|
||||
StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)
|
||||
|
||||
@staticmethod
|
||||
def print_messages_raw(message_sequence: List[Message]):
|
||||
StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)
|
||||
|
||||
@staticmethod
|
||||
def step_yield():
|
||||
pass
|
||||
|
||||
|
||||
class AgentRefreshStreamingInterface(ABC):
"""Same as the ChunkStreamingInterface, but instead of processing chunk deltas one-by-one, the interface re-renders from the accumulated partial response on each update (see process_refresh).
|
||||
|
||||
The 'msg' arg provides the scoped message, and the optional Message arg can provide additional metadata.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT receives a user message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT generates some internal monologue"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT uses send_message"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
"""MemGPT calls a function"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_refresh(self, response: ChatCompletionResponse):
"""Re-render the output from the latest accumulated ChatCompletionResponse (built up from streamed chunks)"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stream_start(self):
|
||||
"""Any setup required before streaming begins"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def stream_end(self):
|
||||
"""Any cleanup required after streaming ends"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def toggle_streaming(self, on: bool):
|
||||
"""Toggle streaming on/off (off = regular CLI interface)"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
|
||||
"""Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.
|
||||
|
||||
We maintain the partial message state in the interface state, and on each
|
||||
process chunk we:
|
||||
(1) update the partial message state,
|
||||
(2) refresh/rewrite the state to the screen.
|
||||
"""
|
||||
|
||||
nonstreaming_interface = CLIInterface
|
||||
|
||||
def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
|
||||
"""Initialize the streaming CLI interface state."""
|
||||
self.console = Console()
|
||||
|
||||
# Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates
|
||||
self.live = Live("", console=self.console, refresh_per_second=10)
|
||||
# self.live.start() # Start the Live display context and keep it running
|
||||
|
||||
# Use italics / emoji?
|
||||
self.fancy = fancy
|
||||
|
||||
self.streaming = True
|
||||
self.separate_send_message = separate_send_message
|
||||
self.disable_inner_mono_call = disable_inner_mono_call
|
||||
|
||||
def toggle_streaming(self, on: bool):
|
||||
self.streaming = on
|
||||
if on:
|
||||
self.separate_send_message = True
|
||||
self.disable_inner_mono_call = True
|
||||
else:
|
||||
self.separate_send_message = False
|
||||
self.disable_inner_mono_call = False
|
||||
|
||||
def update_output(self, content: str):
|
||||
"""Update the displayed output with new content."""
|
||||
# We use the `Live` object's update mechanism to refresh content without clearing the console
|
||||
if not self.fancy:
|
||||
content = escape(content)
|
||||
self.live.update(self.console.render_str(content), refresh=True)
|
||||
|
||||
def process_refresh(self, response: ChatCompletionResponse):
|
||||
"""Process the response to rewrite the current output buffer."""
|
||||
if not response.choices:
|
||||
self.update_output("💭 [italic]...[/italic]")
|
||||
return # Early exit if there are no choices
|
||||
|
||||
choice = response.choices[0]
|
||||
inner_thoughts = choice.message.content if choice.message.content else ""
|
||||
tool_calls = choice.message.tool_calls if choice.message.tool_calls else []
|
||||
|
||||
if self.fancy:
|
||||
message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else ""
|
||||
else:
|
||||
message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else ""
|
||||
|
||||
if tool_calls:
|
||||
function_call = tool_calls[0].function
|
||||
function_name = function_call.name # Function name, can be an empty string
|
||||
function_args = function_call.arguments # Function arguments, can be an empty string
|
||||
if message_string:
|
||||
message_string += "\n"
|
||||
# special case here for send_message
|
||||
if self.separate_send_message and function_name == "send_message":
|
||||
try:
|
||||
message = json.loads(function_args)["message"]
|
||||
except:
|
||||
prefix = '{\n "message": "'
|
||||
if len(function_args) < len(prefix):
|
||||
message = "..."
|
||||
elif function_args.startswith(prefix):
|
||||
message = function_args[len(prefix) :]
|
||||
else:
|
||||
message = function_args
|
||||
message_string += f"🤖 [bold yellow]{message}[/bold yellow]"
|
||||
else:
|
||||
message_string += f"{function_name}({function_args})"
|
||||
|
||||
self.update_output(message_string)
|
||||
|
||||
def stream_start(self):
|
||||
if self.streaming:
|
||||
print()
|
||||
self.live.start() # Start the Live display context and keep it running
|
||||
self.update_output("💭 [italic]...[/italic]")
|
||||
|
||||
def stream_end(self):
|
||||
if self.streaming:
|
||||
if self.live.is_started:
|
||||
self.live.stop()
|
||||
print()
|
||||
self.live = Live("", console=self.console, refresh_per_second=10)
|
||||
|
||||
@staticmethod
|
||||
def important_message(msg: str):
|
||||
StreamingCLIInterface.nonstreaming_interface.important_message(msg)
|
||||
|
||||
@staticmethod
|
||||
def warning_message(msg: str):
|
||||
StreamingCLIInterface.nonstreaming_interface.warning_message(msg)
|
||||
|
||||
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
if self.disable_inner_mono_call:
|
||||
return
|
||||
StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)
|
||||
|
||||
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
|
||||
if self.separate_send_message:
|
||||
return
|
||||
StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def memory_message(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def system_message(msg: str, msg_obj: Optional[Message] = None):
|
||||
StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
|
||||
StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
|
||||
StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)
|
||||
|
||||
@staticmethod
|
||||
def print_messages(message_sequence: List[Message], dump=False):
|
||||
StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)
|
||||
|
||||
@staticmethod
|
||||
def print_messages_simple(message_sequence: List[Message]):
|
||||
StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)
|
||||
|
||||
@staticmethod
|
||||
def print_messages_raw(message_sequence: List[Message]):
|
||||
StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)
|
||||
|
||||
@staticmethod
|
||||
def step_yield():
|
||||
pass
|
poetry.lock (generated, 15 changed lines)
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.0 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
@ -1524,6 +1524,17 @@ cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"]
|
||||
http2 = ["h2 (>=3,<5)"]
|
||||
socks = ["socksio (==1.*)"]
|
||||
|
||||
[[package]]
|
||||
name = "httpx-sse"
|
||||
version = "0.4.0"
|
||||
description = "Consume Server-Sent Event (SSE) messages with HTTPX."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "httpx-sse-0.4.0.tar.gz", hash = "sha256:1e81a3a3070ce322add1d3529ed42eb5f70817f45ed6ec915ab753f961139721"},
|
||||
{file = "httpx_sse-0.4.0-py3-none-any.whl", hash = "sha256:f329af6eae57eaa2bdfd962b42524764af68075ea87370a2de920af5341e318f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "huggingface-hub"
|
||||
version = "0.22.2"
|
||||
@ -6094,4 +6105,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "5c36931d717323eab3eea32bf383b27578ea8f3467fd230ce543af364caffa92"
|
||||
content-hash = "a9635dccf8bd7d826f776e36a9d6fbc845a1b7de0586d06c6a9ce7230a5a14bc"
|
||||
|
@ -57,6 +57,7 @@ llama-index-embeddings-openai = "^0.1.1"
llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
llama-index-embeddings-azure-openai = "^0.1.6"
python-multipart = "^0.0.9"
httpx-sse = "^0.4.0"

[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]