feat: Reverse inner thoughts for chat completions endpoint (#1081)

This commit is contained in:
Matthew Zhou 2025-02-20 12:30:34 -08:00 committed by GitHub
parent 81956d7d3a
commit 9ac089a143
8 changed files with 111 additions and 35 deletions

View File

@@ -322,6 +322,7 @@ class Agent(BaseAgent):
max_delay: float = 10.0, # max delay between retries
step_count: Optional[int] = None,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""
log_telemetry(self.logger, "_get_ai_reply start")
@@ -367,6 +368,7 @@ class Agent(BaseAgent):
force_tool_call=force_tool_call,
stream=stream,
stream_interface=self.interface,
put_inner_thoughts_first=put_inner_thoughts_first,
)
log_telemetry(self.logger, "_get_ai_reply create finish")
@@ -648,6 +650,7 @@ class Agent(BaseAgent):
# additional args
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
put_inner_thoughts_first: bool = True,
**kwargs,
) -> LettaUsageStatistics:
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
@@ -662,6 +665,7 @@ class Agent(BaseAgent):
kwargs["last_function_failed"] = function_failed
step_response = self.inner_step(
messages=next_input_message,
put_inner_thoughts_first=put_inner_thoughts_first,
**kwargs,
)
@@ -743,6 +747,7 @@ class Agent(BaseAgent):
metadata: Optional[dict] = None,
summarize_attempt_count: int = 0,
last_function_failed: bool = False,
put_inner_thoughts_first: bool = True,
) -> AgentStepResponse:
"""Runs a single step in the agent loop (generates at most one LLM call)"""
@@ -778,6 +783,7 @@ class Agent(BaseAgent):
stream=stream,
step_count=step_count,
last_function_failed=last_function_failed,
put_inner_thoughts_first=put_inner_thoughts_first,
)
if not response:
# EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early
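Taken together, the agent.py hunks only thread a single boolean from the public step loop down to the LLM request. A minimal stand-in sketch of that plumbing (the real method bodies are elided; only the keyword forwarding is the point):

def _get_ai_reply(put_inner_thoughts_first: bool = True, **kwargs):
    # stand-in for the call into the LLM API create(...) shown further below
    return {"put_inner_thoughts_first": put_inner_thoughts_first, **kwargs}

def inner_step(put_inner_thoughts_first: bool = True, **kwargs):
    return _get_ai_reply(put_inner_thoughts_first=put_inner_thoughts_first, **kwargs)

def step(put_inner_thoughts_first: bool = True, **kwargs):
    return inner_step(put_inner_thoughts_first=put_inner_thoughts_first, **kwargs)

print(step(put_inner_thoughts_first=False))  # {'put_inner_thoughts_first': False}

The default of True keeps the existing behavior for every caller that does not pass the flag.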

View File

@@ -202,21 +202,29 @@ def add_inner_thoughts_to_functions(
inner_thoughts_key: str,
inner_thoughts_description: str,
inner_thoughts_required: bool = True,
put_inner_thoughts_first: bool = True,
) -> List[dict]:
"""Add an inner_thoughts kwarg to every function in the provided list, ensuring it's the first parameter"""
new_functions = []
for function_object in functions:
new_function_object = copy.deepcopy(function_object)
# Create a new OrderedDict with inner_thoughts as the first item
new_properties = OrderedDict()
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# Add the rest of the properties
new_properties.update(function_object["parameters"]["properties"])
# For chat completions, we want inner thoughts to come later
if put_inner_thoughts_first:
# Create with inner_thoughts as the first item
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# Add the rest of the properties
new_properties.update(function_object["parameters"]["properties"])
else:
new_properties.update(function_object["parameters"]["properties"])
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# Cast OrderedDict back to a regular dict
new_function_object["parameters"]["properties"] = dict(new_properties)
@@ -225,9 +233,11 @@ def add_inner_thoughts_to_functions(
if inner_thoughts_required:
required_params = new_function_object["parameters"].get("required", [])
if inner_thoughts_key not in required_params:
required_params.insert(0, inner_thoughts_key)
if put_inner_thoughts_first:
required_params.insert(0, inner_thoughts_key)
else:
required_params.append(inner_thoughts_key)
new_function_object["parameters"]["required"] = required_params
new_functions.append(new_function_object)
return new_functions
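For reference, a self-contained sketch of what the new ordering logic produces for a send_message-style schema (the helper name and schemas here are illustrative, not the module's actual call sites):

import copy
from collections import OrderedDict

def add_inner_thoughts(functions, key, description, put_first=True):
    # Mirrors the ordering logic above: prepend or append the inner-thoughts kwarg.
    out = []
    for fn in functions:
        fn = copy.deepcopy(fn)
        props = OrderedDict()
        thought_schema = {"type": "string", "description": description}
        if put_first:
            props[key] = thought_schema
            props.update(fn["parameters"]["properties"])
        else:
            props.update(fn["parameters"]["properties"])
            props[key] = thought_schema
        fn["parameters"]["properties"] = dict(props)
        required = fn["parameters"].setdefault("required", [])
        if key not in required:
            if put_first:
                required.insert(0, key)
            else:
                required.append(key)
        out.append(fn)
    return out

send_message = {
    "name": "send_message",
    "parameters": {"properties": {"message": {"type": "string"}}, "required": ["message"]},
}
first = add_inner_thoughts([send_message], "inner_thoughts", "private reasoning", put_first=True)
last = add_inner_thoughts([send_message], "inner_thoughts", "private reasoning", put_first=False)
print(list(first[0]["parameters"]["properties"]))  # ['inner_thoughts', 'message']
print(list(last[0]["parameters"]["properties"]))   # ['message', 'inner_thoughts']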

View File

@@ -140,6 +140,7 @@ def create(
stream: bool = False,
stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None,
model_settings: Optional[dict] = None, # TODO: eventually pass from server
put_inner_thoughts_first: bool = True,
) -> ChatCompletionResponse:
"""Return response to chat completion with backoff"""
from letta.utils import printd
@@ -185,7 +186,9 @@ def create(
else:
function_call = "required"
data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming)
data = build_openai_chat_completions_request(
llm_config, messages, user_id, functions, function_call, use_tool_naming, put_inner_thoughts_first=put_inner_thoughts_first
)
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(

View File

@@ -94,6 +94,7 @@ def build_openai_chat_completions_request(
functions: Optional[list],
function_call: Optional[str],
use_tool_naming: bool,
put_inner_thoughts_first: bool = True,
) -> ChatCompletionRequest:
if functions and llm_config.put_inner_thoughts_in_kwargs:
# Special case for LM Studio backend since it needs extra guidance to force out the thoughts first
@@ -105,6 +106,7 @@ def build_openai_chat_completions_request(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
inner_thoughts_description=inner_thoughts_desc,
put_inner_thoughts_first=put_inner_thoughts_first,
)
openai_message_list = [

View File

@@ -56,6 +56,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
self.current_function_name = ""
self.current_function_arguments = []
self.current_json_parse_result = {}
self._found_message_tool_kwarg = False
# Internal chunk buffer and event for async notification
self._chunks = deque()
@@ -160,8 +161,10 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
Called externally with a ChatCompletionChunkResponse. Transforms
it if necessary, then enqueues partial messages for streaming back.
"""
# print("RECEIVED CHUNK...")
# print(chunk)
processed_chunk = self._process_chunk_to_openai_style(chunk)
# print(processed_chunk)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
if processed_chunk is not None:
self._push_to_buffer(processed_chunk)
@@ -199,6 +202,10 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
content (especially from a 'send_message' tool) is exposed as text
deltas in 'content'. Otherwise, pass through or yield finish reasons.
"""
# If we've already sent the final chunk, ignore everything.
if self._found_message_tool_kwarg:
return None
choice = chunk.choices[0]
delta = choice.delta
@@ -221,25 +228,43 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
combined_args = "".join(self.current_function_arguments)
parsed_args = OptimisticJSONParser().parse(combined_args)
# If the parsed result is different
# This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
if parsed_args != self.current_json_parse_result:
self.current_json_parse_result = parsed_args
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
# TODO: Make this less brittle! This depends on `message` coming first!
# This is a heuristic we use to know if we're done with the `message` part of `send_message`
if len(parsed_args.keys()) > 1:
self._found_message_tool_kwarg = True
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(),
finish_reason="stop",
)
],
)
else:
# If the parsed result is different
# This is an edge case we need to consider. E.g. if the last streamed token is '}', we shouldn't stream that out
if parsed_args != self.current_json_parse_result:
self.current_json_parse_result = parsed_args
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
# If there's a finish reason, pass that along
if choice.finish_reason is not None:
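The cut-off heuristic above keys off partial JSON parses of the accumulated tool-call arguments: while only the message kwarg is visible, its newest fragment is forwarded as a content delta; as soon as a second key shows up, the message text is treated as complete and a single stop chunk is emitted. A toy, type-free sketch of that control flow (hand-written partial parses stand in for OptimisticJSONParser and the chunk objects):

# Simulated sequence of partial parses of the streamed function arguments.
partial_parses = [
    {},                                              # nothing usable yet
    {"message": "Hi"},                               # partial message text
    {"message": "Hi there"},                         # more message text
    {"message": "Hi there", "inner_thoughts": ""},   # second key appears -> stop
]

found_message_tool_kwarg = False
last_parse = {}
for parsed in partial_parses:
    if found_message_tool_kwarg:
        continue  # everything after the stop chunk is ignored
    if len(parsed.keys()) > 1:
        found_message_tool_kwarg = True
        print("emit chunk with finish_reason='stop'")
    elif parsed != last_parse and parsed.get("message"):
        last_parse = parsed
        print(f"emit content delta: {parsed['message']!r}")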

View File

@@ -138,6 +138,7 @@ async def send_message_to_agent_chat_completions(
agent_id=letta_agent.agent_state.id,
messages=messages,
interface=streaming_interface,
put_inner_thoughts_first=False,
)
)
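The endpoint passes put_inner_thoughts_first=False so that, in the streamed tool-call arguments, the user-visible message kwarg is expected to arrive before the inner thoughts, which is exactly the ordering the streaming heuristic above depends on (see its TODO). A hedged illustration of the two argument shapes, key order only, assuming the model follows the schema's property order:

import json

# Default agent behavior (put_inner_thoughts_first=True):
args_thoughts_first = '{"inner_thoughts": "thinking...", "message": "Hi there!"}'
# Chat completions endpoint (put_inner_thoughts_first=False):
args_message_first = '{"message": "Hi there!", "inner_thoughts": "thinking..."}'

print(list(json.loads(args_thoughts_first)))  # ['inner_thoughts', 'message']
print(list(json.loads(args_message_first)))   # ['message', 'inner_thoughts'] -> message streams first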

View File

@@ -336,6 +336,7 @@ class SyncServer(Server):
agent_id: str,
input_messages: Union[Message, List[Message]],
interface: Union[AgentInterface, None] = None, # needed for getting responses
put_inner_thoughts_first: bool = True,
# timestamp: Optional[datetime],
) -> LettaUsageStatistics:
"""Send the input message through the agent"""
@@ -368,6 +369,7 @@ class SyncServer(Server):
stream=token_streaming,
skip_verify=True,
metadata=metadata,
put_inner_thoughts_first=put_inner_thoughts_first,
)
except Exception as e:
@@ -625,6 +627,7 @@ class SyncServer(Server):
wrap_system_message: bool = True,
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for getting responses
metadata: Optional[dict] = None, # Pass through metadata to interface
put_inner_thoughts_first: bool = True,
) -> LettaUsageStatistics:
"""Send a list of messages to the agent
@@ -675,7 +678,13 @@ class SyncServer(Server):
interface.metadata = metadata
# Run the agent state forward
return self._step(actor=actor, agent_id=agent_id, input_messages=message_objects, interface=interface)
return self._step(
actor=actor,
agent_id=agent_id,
input_messages=message_objects,
interface=interface,
put_inner_thoughts_first=put_inner_thoughts_first,
)
# @LockingServer.agent_lock_decorator
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:

View File

@@ -122,7 +122,7 @@ def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
try:
chunks = list(response)
assert len(chunks) > 1, "Streaming response did not return enough chunks (may have failed silently)."
assert len(chunks) > 5, "Streaming response did not return enough chunks (may have failed silently)."
for idx, chunk in enumerate(chunks):
assert chunk, f"Empty chunk received at index {idx}."
@@ -142,14 +142,34 @@ async def test_chat_completions_streaming_async(client, agent, message):
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
received_chunks = 0
stop_chunk_count = 0
last_chunk = None
try:
async with stream:
async for chunk in stream:
print(chunk)
assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
assert chunk.choices[0].delta.content, f"Chunk at index 0 has no content: {chunk.model_dump_json(indent=4)}"
# Track last chunk for final verification
last_chunk = chunk
# If this chunk has a finish reason of "stop", track it
if chunk.choices[0].finish_reason == "stop":
stop_chunk_count += 1
# Fail early if more than one stop chunk is sent
assert stop_chunk_count == 1, f"Multiple stop chunks detected: {chunk.model_dump_json(indent=4)}"
continue
# Validate regular content chunks
assert chunk.choices[0].delta.content, f"Chunk at index {received_chunks} has no content: {chunk.model_dump_json(indent=4)}"
received_chunks += 1
except Exception as e:
pytest.fail(f"Streaming failed with exception: {e}")
assert received_chunks > 1, "No valid streaming chunks were received."
assert received_chunks > 0, "No valid streaming chunks were received."
# Ensure the last chunk is the expected stop chunk
assert last_chunk is not None, "No last chunk received."
assert last_chunk.choices[0].finish_reason == "stop", f"Last chunk did not indicate stop: {last_chunk.model_dump_json(indent=4)}"
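For completeness, a hedged client-side sketch mirroring what these tests now assert: content deltas first, then exactly one chunk with finish_reason="stop". The base_url, API key handling, and how the target agent is selected are assumptions here, not taken from the diff:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8283/v1", api_key="not-used")  # assumed local Letta server
stream = client.chat.completions.create(
    model="<agent-id>",  # placeholder: agent addressing follows the test fixtures, not shown in this diff
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)

stop_chunks = 0
for chunk in stream:
    choice = chunk.choices[0]
    if choice.finish_reason == "stop":
        stop_chunks += 1
    elif choice.delta.content:
        print(choice.delta.content, end="", flush=True)
print(f"\nstop chunks: {stop_chunks}")  # expected: exactly 1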