fix: Patch dummy message and fix test (#2192)

This commit is contained in:
Matthew Zhou 2024-12-07 13:11:46 -08:00 committed by GitHub
parent 3ee3793a4f
commit ee0ab8d7d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 24 additions and 13 deletions

View File

@@ -22,7 +22,7 @@ from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.file import FileMetadata
from letta.schemas.job import Job
from letta.schemas.letta_request import LettaRequest
from letta.schemas.letta_request import LettaRequest, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
@@ -965,8 +965,10 @@ class RESTClient(AbstractClient):
if stream_tokens or stream_steps:
from letta.client.streaming import _sse_post
request = LettaStreamingRequest(messages=messages, stream_tokens=stream_tokens)
return _sse_post(f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages/stream", request.model_dump(), self.headers)
else:
request = LettaRequest(messages=messages)
response = requests.post(
f"{self.base_url}/{self.api_prefix}/agents/{agent_id}/messages", json=request.model_dump(), headers=self.headers
)

View File

@@ -217,7 +217,6 @@ def openai_chat_completions_process_stream(
dummy_message = _Message(
role=_MessageRole.assistant,
text="",
user_id="",
agent_id="",
model="",
name=None,

View File

@@ -182,14 +182,16 @@ class ToolCreate(LettaBase):
@classmethod
def load_default_composio_tools(cls) -> List["ToolCreate"]:
from composio_langchain import Action
pass
calculator = ToolCreate.from_composio(action_name=Action.MATHEMATICAL_CALCULATOR.name)
serp_news = ToolCreate.from_composio(action_name=Action.SERPAPI_NEWS_SEARCH.name)
serp_google_search = ToolCreate.from_composio(action_name=Action.SERPAPI_SEARCH.name)
serp_google_maps = ToolCreate.from_composio(action_name=Action.SERPAPI_GOOGLE_MAPS_SEARCH.name)
# TODO: Disable composio tools for now
# TODO: Naming is causing issues
# calculator = ToolCreate.from_composio(action_name=Action.MATHEMATICAL_CALCULATOR.name)
# serp_news = ToolCreate.from_composio(action_name=Action.SERPAPI_NEWS_SEARCH.name)
# serp_google_search = ToolCreate.from_composio(action_name=Action.SERPAPI_SEARCH.name)
# serp_google_maps = ToolCreate.from_composio(action_name=Action.SERPAPI_GOOGLE_MAPS_SEARCH.name)
return [calculator, serp_news, serp_google_search, serp_google_maps]
return []
class ToolUpdate(LettaBase):

View File

@@ -170,8 +170,7 @@ def check_agent_uses_external_tool(filename: str) -> LettaResponse:
# Set up client
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
# tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
tool = client.load_composio_tool(action=Action.WEBTOOL_SCRAPE_WEBSITE_CONTENT)
tool = client.load_composio_tool(action=Action.GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER)
tool_name = tool.name
# Set up persona for tool usage

View File

@@ -81,7 +81,7 @@ def client(request):
# use local client (no server)
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
@@ -223,7 +223,8 @@ def test_core_memory(mock_e2b_api_key_none, client: Union[LocalClient, RESTClien
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
def test_streaming_send_message(mock_e2b_api_key_none, client: Union[LocalClient, RESTClient], agent: AgentState):
@pytest.mark.parametrize("stream_tokens", [True, False])
def test_streaming_send_message(mock_e2b_api_key_none, client: RESTClient, agent: AgentState, stream_tokens):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
assert isinstance(client, RESTClient), client
@@ -236,12 +237,13 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: Union[LocalClient
message="This is a test. Repeat after me: 'banana'",
role="user",
stream_steps=True,
stream_tokens=True,
stream_tokens=stream_tokens,
)
# Some manual checks to run
# 1. Check that there were inner thoughts
inner_thoughts_exist = False
inner_thoughts_count = 0
# 2. Check that the agent runs `send_message`
send_message_ran = False
# 3. Check that we get all the start/stop/end tokens we want
@@ -256,6 +258,7 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: Union[LocalClient
assert isinstance(chunk, LettaStreamingResponse)
if isinstance(chunk, InternalMonologue) and chunk.internal_monologue and chunk.internal_monologue != "":
inner_thoughts_exist = True
inner_thoughts_count += 1
if isinstance(chunk, FunctionCallMessage) and chunk.function_call and chunk.function_call.name == "send_message":
send_message_ran = True
if isinstance(chunk, MessageStreamStatus):
@@ -275,6 +278,12 @@ def test_streaming_send_message(mock_e2b_api_key_none, client: Union[LocalClient
assert chunk.prompt_tokens > 1000
assert chunk.total_tokens > 1000
# If stream tokens, we expect at least one inner thought
if stream_tokens:
assert inner_thoughts_count > 1, "Expected more than one inner thought"
else:
assert inner_thoughts_count == 1, "Expected one inner thought"
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
assert done, "Message stream not done"