refactor: make Agent.step() multi-step (#1884)

Charles Packer 2024-10-15 13:32:13 -07:00 committed by GitHub
parent 94d2a18c27
commit 4fd82ee81a
7 changed files with 148 additions and 125 deletions
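
The shape of the refactor: Agent.step() becomes the multi-step entry point that owns the heartbeat-chaining loop and returns aggregate LettaUsageStatistics, while the old single-LLM-call logic moves to Agent.inner_step(). A minimal usage sketch of the new surface, assuming an already-constructed Agent and MetadataStore (the setup is illustrative, not part of the diff):

from letta.schemas.message import Message

# agent: Agent and ms: MetadataStore are assumed to exist already.
user_message = Message.dict_to_message(
    agent_id=agent.agent_state.id,
    user_id=agent.agent_state.user_id,
    model=agent.model,
    openai_message_dict={"role": "user", "content": "Hello!"},
)
usage = agent.step(
    messages=[user_message],
    chaining=True,          # keep stepping while the agent requests heartbeats
    max_chaining_steps=10,  # hard cap on the loop
    ms=ms,                  # metadata store used to save state after each step
)
print(usage.step_count)     # number of inner steps the loop ran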

View File

@@ -11,13 +11,17 @@ from letta.agent_store.storage import StorageConnector
from letta.constants import (
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
IN_CONTEXT_MEMORY_KEYWORD,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import LLMError
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
@@ -32,11 +36,15 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
)
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.passage import Passage
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
get_login_event,
get_token_limit_warning,
package_function_response,
package_summarize_message,
package_user_message,
@@ -56,9 +64,6 @@ from letta.utils import (
verify_first_message_correctness,
)
from .errors import LLMError
from .llm_api.helpers import is_context_overflow_error
def compile_memory_metadata_block(
memory_edit_timestamp: datetime.datetime,
@@ -202,7 +207,7 @@ class BaseAgent(ABC):
def step(
self,
messages: Union[Message, List[Message]],
) -> AgentStepResponse:
) -> LettaUsageStatistics:
"""
Top-level event message handler for the agent.
"""
@@ -721,18 +726,105 @@ class Agent(BaseAgent):
return messages, heartbeat_request, function_failed
def step(
self,
messages: Union[Message, List[Message]],
# additional args
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
ms: Optional[MetadataStore] = None,
**kwargs,
) -> LettaUsageStatistics:
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
# assert ms is not None, "MetadataStore is required"
next_input_message = messages if isinstance(messages, list) else [messages]
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
kwargs["ms"] = ms
kwargs["first_message"] = False
step_response = self.inner_step(
messages=next_input_message,
**kwargs,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
step_count += 1
total_usage += usage
counter += 1
self.interface.step_complete()
# logger.debug("Saving agent state")
# save updated state
if ms:
save_agent(self, ms)
# Chain stops
if not chaining:
printd("No chaining, stopping after one step")
break
elif max_chaining_steps is not None and counter > max_chaining_steps:
printd(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
break
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
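
The loop above folds each AgentStepResponse.usage into a running UsageStatistics total and attaches step_count only at the end. A minimal sketch of that accumulation, assuming UsageStatistics supports in-place addition (the loop itself relies on total_usage += usage):

from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.usage import LettaUsageStatistics

# Illustrative numbers; real values come from each AgentStepResponse.
per_step = [
    UsageStatistics(prompt_tokens=120, completion_tokens=30, total_tokens=150),
    UsageStatistics(prompt_tokens=210, completion_tokens=45, total_tokens=255),
]
total = UsageStatistics()
for u in per_step:
    total += u  # same in-place addition the chaining loop uses
stats = LettaUsageStatistics(**total.model_dump(), step_count=len(per_step))
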
def inner_step(
self,
messages: Union[Message, List[Message]],
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True,
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
"""Top-level event message handler for the Letta agent"""
"""Runs a single step in the agent loop (generates at most one LLM call)"""
try:
@@ -834,13 +926,12 @@ class Agent(BaseAgent):
)
self._append_to_messages(all_new_messages)
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
# update state after each step
self.update_state()
return AgentStepResponse(
messages=messages_to_return,
messages=all_new_messages,
heartbeat_request=heartbeat_request,
function_failed=function_failed,
in_context_memory_warning=active_memory_warning,
@@ -856,15 +947,12 @@
self.summarize_messages_inplace()
# Try step again
return self.step(
return self.inner_step(
messages=messages,
first_message=first_message,
first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify,
return_dicts=return_dicts,
# recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
# timestamp=timestamp,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms,
)
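
This retry path (catch a context-overflow error, summarize_messages_inplace(), then run inner_step again with identical arguments) is the only recursion left in the single-step logic. A condensed sketch of the control flow, with the argument pass-through elided:

from letta.llm_api.helpers import is_context_overflow_error

def run_one_step(agent, messages, **kwargs):
    # Condensed version of the retry above; the real method forwards every
    # inner_step keyword argument unchanged on the retry.
    try:
        return agent.inner_step(messages=messages, **kwargs)
    except Exception as e:
        if is_context_overflow_error(e):
            agent.summarize_messages_inplace()  # compress older history
            return agent.inner_step(messages=messages, **kwargs)
        raise
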
@@ -905,7 +993,7 @@ class Agent(BaseAgent):
# created_at=timestamp,
)
return self.step(messages=[user_message], **kwargs)
return self.inner_step(messages=[user_message], **kwargs)
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
@@ -1326,7 +1414,7 @@ class Agent(BaseAgent):
self.pop_until_user()
user_message = self.pop_message(count=1)[0]
assert user_message.text is not None, "User message text is None"
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
step_response = self.step_user_message(user_message_str=user_message.text)
messages = step_response.messages
assert messages is not None

View File

@@ -747,8 +747,9 @@ class RESTClient(AbstractClient):
# simplify messages
if not include_full_message:
messages = []
for message in response.messages:
messages += message.to_letta_message()
for m in response.messages:
assert isinstance(m, Message)
messages += m.to_letta_message()
response.messages = messages
return response
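
The assert isinstance(m, Message) narrows the union type of response.messages, and += is used because one backend Message can expand into several user-facing LettaMessages (for example, inner thoughts plus a function call). A hypothetical one-expression equivalent of the loop (note it filters where the original asserts):

from letta.schemas.message import Message

response.messages = [
    lm
    for m in response.messages
    if isinstance(m, Message)       # original code asserts instead
    for lm in m.to_letta_message()  # returns a list per Message
]
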
@@ -1677,7 +1678,7 @@ class LocalClient(AbstractClient):
self.interface.clear()
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
def get_agent_id(self, agent_name: str) -> AgentState:
def get_agent_id(self, agent_name: str) -> Optional[str]:
"""
Get the ID of an agent by name (names are unique per user)
@@ -1767,6 +1768,7 @@
self,
message: str,
role: str,
name: Optional[str] = None,
agent_id: Optional[str] = None,
agent_name: Optional[str] = None,
stream_steps: bool = False,
@@ -1790,19 +1792,18 @@
# lookup agent by name
assert agent_name, f"Either agent_id or agent_name must be provided"
agent_id = self.get_agent_id(agent_name=agent_name)
agent_state = self.get_agent(agent_id=agent_id)
assert agent_id, f"Agent with name {agent_name} not found"
if stream_steps or stream_tokens:
# TODO: implement streaming with stream=True/False
raise NotImplementedError
self.interface.clear()
if role == "system":
usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
elif role == "user":
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
else:
raise ValueError(f"Role {role} not supported")
usage = self.server.send_messages(
user_id=self.user_id,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
)
# auto-save
if self.auto_save:
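
LocalClient.send_message no longer branches on role: both system and user messages are wrapped in MessageCreate and routed through server.send_messages, and the new name parameter is threaded through to the message. A usage sketch, assuming a configured local setup (agent creation details are illustrative):

from letta import create_client

client = create_client()  # LocalClient when no base_url is given
agent_state = client.create_agent()
response = client.send_message(
    message="Remember that my favorite color is green.",
    role="user",
    name=None,                    # optional participant name, newly supported
    agent_name=agent_state.name,  # agent_id=... works as well
)
print(response.usage.step_count)  # the response is assumed to carry usage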

View File

@@ -361,8 +361,10 @@ def run_agent_loop(
skip_next_user_input = False
def process_agent_step(user_message, no_verify):
# TODO(charles): update to use agent.step() instead of inner_step()
if user_message is None:
step_response = letta_agent.step(
step_response = letta_agent.inner_step(
messages=[],
first_message=False,
skip_verify=no_verify,
@@ -402,15 +404,15 @@
while True:
try:
if strip_ui:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
else:
if stream:
# Don't display the "Thinking..." if streaming
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
else:
with console.status("[bold cyan]Thinking...") as status:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
except KeyboardInterrupt:
print("User interrupt occurred.")

View File

@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator
@@ -121,8 +121,7 @@ class UpdateAgentState(BaseAgent):
class AgentStepResponse(BaseModel):
# TODO remove support for list of dicts
messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
in_context_memory_warning: bool = Field(
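
With this schema change, AgentStepResponse.messages is always List[Message]; callers that want OpenAI-style dicts convert at the edge rather than asking the agent for dicts (return_dicts is gone). A minimal sketch of consuming a step response:

# step_response: AgentStepResponse from Agent.inner_step(...)
for msg in step_response.messages:      # always Message objects now
    openai_dict = msg.to_openai_dict()  # convert at the call site if needed
    print(openai_dict["role"])
if step_response.heartbeat_request:
    print("agent asked for a follow-up step")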

View File

@@ -248,7 +248,7 @@ def create_run(
agent_id = thread_id
# TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(agent_id=agent_id)
agent.step(user_message=None) # already has messages added
agent.inner_step(messages=[]) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp())
return OpenAIRun(
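
Because Agent.step() now loops, the assistants-API shim switches to the single-step entry point: the thread's messages are already in the agent's history, so a run is one inner_step with no new input. A hypothetical contrast of the two calls:

step_response = agent.inner_step(messages=[])  # at most one LLM call
usage = agent.step(messages=[], ms=ms)         # full loop with chaining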

View File

@@ -74,7 +74,6 @@ from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
@@ -411,6 +410,7 @@ class SyncServer(Server):
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
logger.debug(f"Got input messages: {input_messages}")
letta_agent = None
try:
# Get the agent object (loaded in memory)
@@ -422,83 +422,14 @@
token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False
logger.debug(f"Starting agent step")
no_verify = True
next_input_message = input_messages
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
step_response = letta_agent.step(
messages=next_input_message,
first_message=False,
skip_verify=no_verify,
return_dicts=False,
stream=token_streaming,
# timestamp=timestamp,
ms=self.ms,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
step_count += 1
total_usage += usage
counter += 1
letta_agent.interface.step_complete()
logger.debug("Saving agent state")
# save updated state
save_agent(letta_agent, self.ms)
# Chain stops
if not self.chaining:
logger.debug("No chaining, stopping after one step")
break
elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
break
usage_stats = letta_agent.step(
messages=input_messages,
chaining=self.chaining,
max_chaining_steps=self.max_chaining_steps,
stream=token_streaming,
ms=self.ms,
skip_verify=True,
)
except Exception as e:
logger.error(f"Error in server._step: {e}")
@@ -506,9 +437,10 @@
raise
finally:
logger.debug("Calling step_yield()")
letta_agent.interface.step_yield()
if letta_agent:
letta_agent.interface.step_yield()
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
return usage_stats
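
The letta_agent = None initialization plus the if letta_agent: guard makes the finally block safe even when _get_or_load_agent itself raises. The shape of the pattern:

letta_agent = None
try:
    letta_agent = self._get_or_load_agent(agent_id=agent_id)  # may raise
    usage_stats = letta_agent.step(messages=input_messages, ms=self.ms)
finally:
    if letta_agent:  # skip the yield if the agent never loaded
        letta_agent.interface.step_yield()
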
def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
"""Process a CLI command"""

View File

@@ -4,6 +4,7 @@ import pytest
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.schemas.agent import AgentState
from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
@@ -113,7 +114,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state_test.id)
def test_agent_with_shared_blocks(client):
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
existing_non_template_blocks = [persona_block, human_block]
@@ -164,7 +165,7 @@ def test_agent_with_shared_blocks(client):
client.delete_agent(second_agent_state_test.id)
def test_memory(client, agent):
def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
# get agent memory
original_memory = client.get_in_context_memory(agent.id)
assert original_memory is not None
@@ -177,7 +178,7 @@ def test_memory(client, agent):
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
def test_archival_memory(client, agent):
def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test functions for interacting with archival memory store"""
# add archival memory
@@ -192,12 +193,12 @@ def test_archival_memory(client, agent):
client.delete_archival_memory(agent.id, passage.id)
def test_recall_memory(client, agent):
def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test functions for interacting with recall memory store"""
# send message to the agent
message_str = "Hello"
client.send_message(message_str, "user", agent.id)
client.send_message(message=message_str, role="user", agent_id=agent.id)
# list messages
messages = client.get_messages(agent.id)
@@ -216,7 +217,7 @@ def test_recall_memory(client, agent):
assert exists
def test_tools(client):
def test_tools(client: Union[LocalClient, RESTClient]):
def print_tool(message: str):
"""
A tool to print a message
@@ -265,7 +266,7 @@ def test_tools(client):
# assert len(client.list_tools()) == orig_tool_length
def test_tools_from_composio_basic(client):
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
from composio_langchain import Action
from letta.schemas.tool import Tool
@@ -286,7 +287,7 @@ def test_tools_from_composio_basic(client):
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
def test_tools_from_crewai(client):
def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
# create crewAI tool
from crewai_tools import ScrapeWebsiteTool
@@ -323,7 +324,7 @@ def test_tools_from_crewai(client):
assert expected_content in func(website_url=simple_webpage_url)
def test_tools_from_crewai_with_params(client):
def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
# create crewAI tool
from crewai_tools import ScrapeWebsiteTool
@@ -357,7 +358,7 @@ def test_tools_from_crewai_with_params(client):
assert expected_content in func()
def test_tools_from_langchain(client):
def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
@@ -391,7 +392,7 @@ def test_tools_from_langchain(client):
assert expected_content in func(query="Albert Einstein")
def test_tool_creation_langchain_missing_imports(client):
def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]):
# create langchain tool
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
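
The annotations added throughout this test module follow one pattern: the shared client fixture may be either client type, so each test spells out the union for type-checker and editor support. Applied to a hypothetical new test:

from typing import Union

from letta.client.client import LocalClient, RESTClient
from letta.schemas.agent import AgentState

def test_example(client: Union[LocalClient, RESTClient], agent: AgentState):
    # Works identically against both client implementations.
    fetched = client.get_agent(agent_id=agent.id)
    assert fetched.id == agent.id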