mirror of https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

feat: Reverse inner thoughts for chat completions endpoint (#1081)

parent 81956d7d3a
commit 9ac089a143
@@ -322,6 +322,7 @@ class Agent(BaseAgent):
         max_delay: float = 10.0,  # max delay between retries
         step_count: Optional[int] = None,
         last_function_failed: bool = False,
+        put_inner_thoughts_first: bool = True,
     ) -> ChatCompletionResponse:
         """Get response from LLM API with robust retry mechanism."""
         log_telemetry(self.logger, "_get_ai_reply start")

@@ -367,6 +368,7 @@ class Agent(BaseAgent):
             force_tool_call=force_tool_call,
             stream=stream,
             stream_interface=self.interface,
+            put_inner_thoughts_first=put_inner_thoughts_first,
         )
         log_telemetry(self.logger, "_get_ai_reply create finish")

@@ -648,6 +650,7 @@ class Agent(BaseAgent):
         # additional args
         chaining: bool = True,
         max_chaining_steps: Optional[int] = None,
+        put_inner_thoughts_first: bool = True,
         **kwargs,
     ) -> LettaUsageStatistics:
         """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""

@@ -662,6 +665,7 @@ class Agent(BaseAgent):
             kwargs["last_function_failed"] = function_failed
             step_response = self.inner_step(
                 messages=next_input_message,
+                put_inner_thoughts_first=put_inner_thoughts_first,
                 **kwargs,
             )

@@ -743,6 +747,7 @@ class Agent(BaseAgent):
         metadata: Optional[dict] = None,
         summarize_attempt_count: int = 0,
         last_function_failed: bool = False,
+        put_inner_thoughts_first: bool = True,
     ) -> AgentStepResponse:
         """Runs a single step in the agent loop (generates at most one LLM call)"""

@@ -778,6 +783,7 @@ class Agent(BaseAgent):
             stream=stream,
             step_count=step_count,
             last_function_failed=last_function_failed,
+            put_inner_thoughts_first=put_inner_thoughts_first,
         )
         if not response:
             # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
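Taken together, these hunks thread a single `put_inner_thoughts_first` flag from the outer step loop (`step` -> `inner_step` -> `_get_ai_reply`) down to the LLM `create` call. A minimal call sketch, assuming an already-constructed `agent` instance and an input `messages` list, and assuming `step` takes its inputs via the same `messages=` keyword that `inner_step` uses (none of these names are confirmed by the diff itself):

    # Sketch only: `agent` and `messages` are assumed to exist already.
    # Default: inner thoughts remain the first kwarg of every tool call.
    usage = agent.step(messages=messages)

    # Chat-completions path (see the endpoint hunk further down): inner thoughts are
    # moved after the user-facing `message` argument, so the message is generated
    # and streamed before the agent's private reasoning.
    usage = agent.step(messages=messages, put_inner_thoughts_first=False)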
@@ -202,21 +202,29 @@ def add_inner_thoughts_to_functions(
     inner_thoughts_key: str,
     inner_thoughts_description: str,
     inner_thoughts_required: bool = True,
+    put_inner_thoughts_first: bool = True,
 ) -> List[dict]:
     """Add an inner_thoughts kwarg to every function in the provided list, ensuring it's the first parameter"""
     new_functions = []
     for function_object in functions:
         new_function_object = copy.deepcopy(function_object)

-        # Create a new OrderedDict with inner_thoughts as the first item
         new_properties = OrderedDict()
-        new_properties[inner_thoughts_key] = {
-            "type": "string",
-            "description": inner_thoughts_description,
-        }
-
-        # Add the rest of the properties
-        new_properties.update(function_object["parameters"]["properties"])
+        # For chat completions, we want inner thoughts to come later
+        if put_inner_thoughts_first:
+            # Create with inner_thoughts as the first item
+            new_properties[inner_thoughts_key] = {
+                "type": "string",
+                "description": inner_thoughts_description,
+            }
+            # Add the rest of the properties
+            new_properties.update(function_object["parameters"]["properties"])
+        else:
+            new_properties.update(function_object["parameters"]["properties"])
+            new_properties[inner_thoughts_key] = {
+                "type": "string",
+                "description": inner_thoughts_description,
+            }

         # Cast OrderedDict back to a regular dict
         new_function_object["parameters"]["properties"] = dict(new_properties)

@@ -225,9 +233,11 @@ def add_inner_thoughts_to_functions(
         if inner_thoughts_required:
             required_params = new_function_object["parameters"].get("required", [])
-            if inner_thoughts_key not in required_params:
-                required_params.insert(0, inner_thoughts_key)
+            if put_inner_thoughts_first:
+                required_params.insert(0, inner_thoughts_key)
+            else:
+                required_params.append(inner_thoughts_key)
             new_function_object["parameters"]["required"] = required_params

         new_functions.append(new_function_object)

     return new_functions
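With `put_inner_thoughts_first=False`, the inner-thoughts kwarg is appended after the tool's own parameters (and after `message` in `required`), which is what lets the chat completions endpoint stream the user-facing text before the agent's reasoning. A self-contained sketch of the reordering effect on a toy `send_message` schema (the helper and schema below are illustrative stand-ins, not the library's own objects):

    import copy
    from collections import OrderedDict

    # Toy tool schema, shaped like an OpenAI function definition.
    send_message = {
        "name": "send_message",
        "parameters": {
            "type": "object",
            "properties": {"message": {"type": "string", "description": "Message to show the user."}},
            "required": ["message"],
        },
    }

    def reorder_inner_thoughts(function_object, key, description, put_inner_thoughts_first=True):
        """Sketch of the hunk above: prepend or append the inner-thoughts kwarg."""
        new_fn = copy.deepcopy(function_object)
        thought_schema = {"type": "string", "description": description}
        props = OrderedDict()
        if put_inner_thoughts_first:
            props[key] = thought_schema
            props.update(new_fn["parameters"]["properties"])
            new_fn["parameters"]["required"].insert(0, key)
        else:
            props.update(new_fn["parameters"]["properties"])
            props[key] = thought_schema
            new_fn["parameters"]["required"].append(key)
        new_fn["parameters"]["properties"] = dict(props)
        return new_fn

    reversed_fn = reorder_inner_thoughts(send_message, "inner_thoughts", "Private reasoning.", put_inner_thoughts_first=False)
    print(list(reversed_fn["parameters"]["properties"]))  # ['message', 'inner_thoughts']
    print(reversed_fn["parameters"]["required"])          # ['message', 'inner_thoughts']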
@@ -140,6 +140,7 @@ def create(
     stream: bool = False,
     stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
     model_settings: Optional[dict] = None,  # TODO: eventually pass from server
+    put_inner_thoughts_first: bool = True,
 ) -> ChatCompletionResponse:
     """Return response to chat completion with backoff"""
     from letta.utils import printd

@@ -185,7 +186,9 @@ def create(
         else:
             function_call = "required"

-        data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
+        data = build_openai_chat_completions_request(
+            llm_config, messages, user_id, functions, function_call, use_tool_naming, put_inner_thoughts_first=put_inner_thoughts_first
+        )
         if stream:  # Client requested token streaming
             data.stream = True
             assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
@@ -94,6 +94,7 @@ def build_openai_chat_completions_request(
     functions: Optional[list],
     function_call: Optional[str],
     use_tool_naming: bool,
+    put_inner_thoughts_first: bool = True,
 ) -> ChatCompletionRequest:
     if functions and llm_config.put_inner_thoughts_in_kwargs:
         # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first

@@ -105,6 +106,7 @@ def build_openai_chat_completions_request(
             functions=functions,
             inner_thoughts_key=INNER_THOUGHTS_KWARG,
             inner_thoughts_description=inner_thoughts_desc,
+            put_inner_thoughts_first=put_inner_thoughts_first,
         )

     openai_message_list = [
@@ -56,6 +56,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         self.current_function_name = ""
         self.current_function_arguments = []
         self.current_json_parse_result = {}
+        self._found_message_tool_kwarg = False

         # Internal chunk buffer and event for async notification
         self._chunks = deque()
@@ -160,8 +161,10 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         Called externally with a ChatCompletionChunkResponse. Transforms
         it if necessary, then enqueues partial messages for streaming back.
         """
         # print("RECEIVED CHUNK...")
         # print(chunk)
         processed_chunk = self._process_chunk_to_openai_style(chunk)
+        # print(processed_chunk)
+        # print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
         if processed_chunk is not None:
             self._push_to_buffer(processed_chunk)
@@ -199,6 +202,10 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         content (especially from a 'send_message' tool) is exposed as text
         deltas in 'content'. Otherwise, pass through or yield finish reasons.
         """
+        # If we've already sent the final chunk, ignore everything.
+        if self._found_message_tool_kwarg:
+            return None
+
         choice = chunk.choices[0]
         delta = choice.delta

@@ -221,25 +228,43 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
             combined_args = "".join(self.current_function_arguments)
             parsed_args = OptimisticJSONParser().parse(combined_args)

-            # If the parsed result is different
-            # This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
-            if parsed_args != self.current_json_parse_result:
-                self.current_json_parse_result = parsed_args
-                # If we can see a "message" field, return it as partial content
-                if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
-                    return ChatCompletionChunk(
-                        id=chunk.id,
-                        object=chunk.object,
-                        created=chunk.created.timestamp(),
-                        model=chunk.model,
-                        choices=[
-                            Choice(
-                                index=choice.index,
-                                delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
-                                finish_reason=None,
-                            )
-                        ],
-                    )
+            # TODO: Make this less brittle! This depends on `message` coming first!
+            # This is a heuristic we use to know if we're done with the `message` part of `send_message`
+            if len(parsed_args.keys()) > 1:
+                self._found_message_tool_kwarg = True
+                return ChatCompletionChunk(
+                    id=chunk.id,
+                    object=chunk.object,
+                    created=chunk.created.timestamp(),
+                    model=chunk.model,
+                    choices=[
+                        Choice(
+                            index=choice.index,
+                            delta=ChoiceDelta(),
+                            finish_reason="stop",
+                        )
+                    ],
+                )
+            else:
+                # If the parsed result is different
+                # This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
+                if parsed_args != self.current_json_parse_result:
+                    self.current_json_parse_result = parsed_args
+                    # If we can see a "message" field, return it as partial content
+                    if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
+                        return ChatCompletionChunk(
+                            id=chunk.id,
+                            object=chunk.object,
+                            created=chunk.created.timestamp(),
+                            model=chunk.model,
+                            choices=[
+                                Choice(
+                                    index=choice.index,
+                                    delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
+                                    finish_reason=None,
+                                )
+                            ],
+                        )

         # If there's a finish reason, pass that along
         if choice.finish_reason is not None:
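The stop heuristic above assumes the `message` argument of `send_message` is generated before `inner_thoughts`: the interface keeps concatenating the streamed tool-call argument fragments, optimistically parses the partial JSON, streams new `message` text as content deltas, and emits a single terminal stop chunk the moment a second key shows up. A minimal, self-contained sketch of that idea, using a naive `try_parse_partial_json` helper in place of letta's `OptimisticJSONParser` (the helper, class, and sample fragments are all illustrative):

    import json
    from typing import Dict, List, Optional

    def try_parse_partial_json(buffer: str) -> Dict:
        """Best-effort parse of a possibly incomplete JSON object (stand-in for OptimisticJSONParser)."""
        for candidate in (buffer, buffer + '"}', buffer + "}"):
            try:
                return json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return {}

    class MessageStopDetector:
        """Accumulates tool-call argument fragments and decides when `message` is complete."""

        def __init__(self, message_key: str = "message"):
            self.fragments: List[str] = []
            self.last_parse: Dict = {}
            self.message_key = message_key
            self.done = False

        def feed(self, fragment: str) -> Optional[str]:
            """Return "stop" once a second kwarg appears, the newest delta while `message` grows, else None."""
            if self.done:
                return None
            self.fragments.append(fragment)
            parsed = try_parse_partial_json("".join(self.fragments))
            if len(parsed.keys()) > 1:  # e.g. inner_thoughts started, so `message` must be finished
                self.done = True
                return "stop"
            if parsed != self.last_parse and parsed.get(self.message_key):
                self.last_parse = parsed
                # Like the interface above, stream the raw latest fragment as the content delta.
                return fragment
            return None

    # Fragments as a model might stream them when inner thoughts come last:
    detector = MessageStopDetector()
    for piece in ['{"message": "', "Hello", " world", '", "inner_thoughts": "Sounding friendly."}']:
        print(detector.feed(piece))
    # -> None, then "Hello" and " world" as content deltas, then "stop"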
@@ -138,6 +138,7 @@ async def send_message_to_agent_chat_completions(
             agent_id=letta_agent.agent_state.id,
             messages=messages,
             interface=streaming_interface,
+            put_inner_thoughts_first=False,
         )
     )
@@ -336,6 +336,7 @@ class SyncServer(Server):
         agent_id: str,
         input_messages: Union[Message, List[Message]],
         interface: Union[AgentInterface, None] = None,  # needed to getting responses
+        put_inner_thoughts_first: bool = True,
         # timestamp: Optional[datetime],
     ) -> LettaUsageStatistics:
         """Send the input message through the agent"""

@@ -368,6 +369,7 @@ class SyncServer(Server):
                 stream=token_streaming,
                 skip_verify=True,
                 metadata=metadata,
+                put_inner_thoughts_first=put_inner_thoughts_first,
             )

         except Exception as e:

@@ -625,6 +627,7 @@ class SyncServer(Server):
         wrap_system_message: bool = True,
         interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None,  # needed to getting responses
         metadata: Optional[dict] = None,  # Pass through metadata to interface
+        put_inner_thoughts_first: bool = True,
     ) -> LettaUsageStatistics:
         """Send a list of messages to the agent

@@ -675,7 +678,13 @@ class SyncServer(Server):
             interface.metadata = metadata

         # Run the agent state forward
-        return self._step(actor=actor, agent_id=agent_id, input_messages=message_objects, interface=interface)
+        return self._step(
+            actor=actor,
+            agent_id=agent_id,
+            input_messages=message_objects,
+            interface=interface,
+            put_inner_thoughts_first=put_inner_thoughts_first,
+        )

     # @LockingServer.agent_lock_decorator
     def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
@@ -122,7 +122,7 @@ def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):

     try:
         chunks = list(response)
-        assert len(chunks) > 1, "Streaming response did not return enough chunks (may have failed silently)."
+        assert len(chunks) > 5, "Streaming response did not return enough chunks (may have failed silently)."

         for idx, chunk in enumerate(chunks):
             assert chunk, f"Empty chunk received at index {idx}."

@@ -142,14 +142,34 @@ async def test_chat_completions_streaming_async(client, agent, message):
     stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))

     received_chunks = 0
+    stop_chunk_count = 0
+    last_chunk = None
+
     try:
         async with stream:
             async for chunk in stream:
+                print(chunk)
                 assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
                 assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
-                assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
+
+                # Track last chunk for final verification
+                last_chunk = chunk
+
+                # If this chunk has a finish reason of "stop", track it
+                if chunk.choices[0].finish_reason == "stop":
+                    stop_chunk_count += 1
+                    # Fail early if more than one stop chunk is sent
+                    assert stop_chunk_count == 1, f"Multiple stop chunks detected: {chunk.model_dump_json(indent=4)}"
+                    continue
+
+                # Validate regular content chunks
+                assert chunk.choices[0].delta.content, f"Chunk at index {received_chunks} has no content: {chunk.model_dump_json(indent=4)}"
                 received_chunks += 1
     except Exception as e:
         pytest.fail(f"Streaming failed with exception: {e}")

-    assert received_chunks > 1, "No valid streaming chunks were received."
+    assert received_chunks > 0, "No valid streaming chunks were received."
+
+    # Ensure the last chunk is the expected stop chunk
+    assert last_chunk is not None, "No last chunk received."
+    assert last_chunk.choices[0].finish_reason == "stop", f"Last chunk did not indicate stop: {last_chunk.model_dump_json(indent=4)}"
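The async test above doubles as the intended usage pattern: with inner thoughts reversed, the OpenAI-compatible stream yields plain assistant content deltas for `message` followed by exactly one `finish_reason="stop"` chunk. A rough client-side sketch using the official `openai` async client, assuming a locally running Letta server and that the target agent is addressed via the `model` field (the base URL, API key, and agent id below are placeholders, not taken from this diff):

    import asyncio

    from openai import AsyncOpenAI

    async def main() -> None:
        # Placeholder endpoint and credentials; adjust for your deployment.
        client = AsyncOpenAI(base_url="http://localhost:8283/v1", api_key="dummy")
        stream = await client.chat.completions.create(
            model="agent-00000000-0000-0000-0000-000000000000",  # hypothetical agent id
            messages=[{"role": "user", "content": "Hello!"}],
            stream=True,
        )
        async with stream:
            async for chunk in stream:
                if chunk.choices[0].finish_reason == "stop":
                    break  # single terminal chunk, as asserted in the test above
                print(chunk.choices[0].delta.content or "", end="", flush=True)

    asyncio.run(main())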