refactor: make Agent.step() multi-step (#1884)

This commit is contained in:
Charles Packer 2024-10-15 13:32:13 -07:00 committed by GitHub
parent 94d2a18c27
commit 4fd82ee81a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 148 additions and 125 deletions

View File

@ -11,13 +11,17 @@ from letta.agent_store.storage import StorageConnector
from letta.constants import ( from letta.constants import (
CLI_WARNING_PREFIX, CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS, FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
IN_CONTEXT_MEMORY_KEYWORD, IN_CONTEXT_MEMORY_KEYWORD,
LLM_MAX_TOKENS, LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC, MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC, MESSAGE_SUMMARY_WARNING_FRAC,
REQ_HEARTBEAT_MESSAGE,
) )
from letta.errors import LLMError
from letta.interface import AgentInterface from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create from letta.llm_api.llm_api_tools import create
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore from letta.metadata import MetadataStore
@ -32,11 +36,15 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import ( from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage, Message as ChatCompletionMessage,
) )
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.system import ( from letta.system import (
get_heartbeat,
get_initial_boot_messages, get_initial_boot_messages,
get_login_event, get_login_event,
get_token_limit_warning,
package_function_response, package_function_response,
package_summarize_message, package_summarize_message,
package_user_message, package_user_message,
@ -56,9 +64,6 @@ from letta.utils import (
verify_first_message_correctness, verify_first_message_correctness,
) )
from .errors import LLMError
from .llm_api.helpers import is_context_overflow_error
def compile_memory_metadata_block( def compile_memory_metadata_block(
memory_edit_timestamp: datetime.datetime, memory_edit_timestamp: datetime.datetime,
@ -202,7 +207,7 @@ class BaseAgent(ABC):
def step( def step(
self, self,
messages: Union[Message, List[Message]], messages: Union[Message, List[Message]],
) -> AgentStepResponse: ) -> LettaUsageStatistics:
""" """
Top-level event message handler for the agent. Top-level event message handler for the agent.
""" """
@ -721,18 +726,105 @@ class Agent(BaseAgent):
return messages, heartbeat_request, function_failed return messages, heartbeat_request, function_failed
def step( def step(
self,
messages: Union[Message, List[Message]],
# additional args
chaining: bool = True,
max_chaining_steps: Optional[int] = None,
ms: Optional[MetadataStore] = None,
**kwargs,
) -> LettaUsageStatistics:
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
# assert ms is not None, "MetadataStore is required"
next_input_message = messages if isinstance(messages, list) else [messages]
counter = 0
total_usage = UsageStatistics()
step_count = 0
while True:
kwargs["ms"] = ms
kwargs["first_message"] = False
step_response = self.inner_step(
messages=next_input_message,
**kwargs,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
step_count += 1
total_usage += usage
counter += 1
self.interface.step_complete()
# logger.debug("Saving agent state")
# save updated state
if ms:
save_agent(self, ms)
# Chain stops
if not chaining:
printd("No chaining, stopping after one step")
break
elif max_chaining_steps is not None and counter > max_chaining_steps:
printd(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
assert self.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.user_id,
model=self.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
break
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
def inner_step(
self, self,
messages: Union[Message, List[Message]], messages: Union[Message, List[Message]],
first_message: bool = False, first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False, skip_verify: bool = False,
return_dicts: bool = True,
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config? stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT, inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None, ms: Optional[MetadataStore] = None,
) -> AgentStepResponse: ) -> AgentStepResponse:
"""Top-level event message handler for the Letta agent""" """Runs a single step in the agent loop (generates at most one LLM call)"""
try: try:
@ -834,13 +926,12 @@ class Agent(BaseAgent):
) )
self._append_to_messages(all_new_messages) self._append_to_messages(all_new_messages)
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
# update state after each step # update state after each step
self.update_state() self.update_state()
return AgentStepResponse( return AgentStepResponse(
messages=messages_to_return, messages=all_new_messages,
heartbeat_request=heartbeat_request, heartbeat_request=heartbeat_request,
function_failed=function_failed, function_failed=function_failed,
in_context_memory_warning=active_memory_warning, in_context_memory_warning=active_memory_warning,
@ -856,15 +947,12 @@ class Agent(BaseAgent):
self.summarize_messages_inplace() self.summarize_messages_inplace()
# Try step again # Try step again
return self.step( return self.inner_step(
messages=messages, messages=messages,
first_message=first_message, first_message=first_message,
first_message_retry_limit=first_message_retry_limit, first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify, skip_verify=skip_verify,
return_dicts=return_dicts,
# recreate_message_timestamp=recreate_message_timestamp,
stream=stream, stream=stream,
# timestamp=timestamp,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option, inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms, ms=ms,
) )
@ -905,7 +993,7 @@ class Agent(BaseAgent):
# created_at=timestamp, # created_at=timestamp,
) )
return self.step(messages=[user_message], **kwargs) return self.inner_step(messages=[user_message], **kwargs)
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True): def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
@ -1326,7 +1414,7 @@ class Agent(BaseAgent):
self.pop_until_user() self.pop_until_user()
user_message = self.pop_message(count=1)[0] user_message = self.pop_message(count=1)[0]
assert user_message.text is not None, "User message text is None" assert user_message.text is not None, "User message text is None"
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False) step_response = self.step_user_message(user_message_str=user_message.text)
messages = step_response.messages messages = step_response.messages
assert messages is not None assert messages is not None

View File

@ -747,8 +747,9 @@ class RESTClient(AbstractClient):
# simplify messages # simplify messages
if not include_full_message: if not include_full_message:
messages = [] messages = []
for message in response.messages: for m in response.messages:
messages += message.to_letta_message() assert isinstance(m, Message)
messages += m.to_letta_message()
response.messages = messages response.messages = messages
return response return response
@ -1677,7 +1678,7 @@ class LocalClient(AbstractClient):
self.interface.clear() self.interface.clear()
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id) return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
def get_agent_id(self, agent_name: str) -> AgentState: def get_agent_id(self, agent_name: str) -> Optional[str]:
""" """
Get the ID of an agent by name (names are unique per user) Get the ID of an agent by name (names are unique per user)
@ -1767,6 +1768,7 @@ class LocalClient(AbstractClient):
self, self,
message: str, message: str,
role: str, role: str,
name: Optional[str] = None,
agent_id: Optional[str] = None, agent_id: Optional[str] = None,
agent_name: Optional[str] = None, agent_name: Optional[str] = None,
stream_steps: bool = False, stream_steps: bool = False,
@ -1790,19 +1792,18 @@ class LocalClient(AbstractClient):
# lookup agent by name # lookup agent by name
assert agent_name, f"Either agent_id or agent_name must be provided" assert agent_name, f"Either agent_id or agent_name must be provided"
agent_id = self.get_agent_id(agent_name=agent_name) agent_id = self.get_agent_id(agent_name=agent_name)
assert agent_id, f"Agent with name {agent_name} not found"
agent_state = self.get_agent(agent_id=agent_id)
if stream_steps or stream_tokens: if stream_steps or stream_tokens:
# TODO: implement streaming with stream=True/False # TODO: implement streaming with stream=True/False
raise NotImplementedError raise NotImplementedError
self.interface.clear() self.interface.clear()
if role == "system":
usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message) usage = self.server.send_messages(
elif role == "user": user_id=self.user_id,
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message) agent_id=agent_id,
else: messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
raise ValueError(f"Role {role} not supported") )
# auto-save # auto-save
if self.auto_save: if self.auto_save:

View File

@ -361,8 +361,10 @@ def run_agent_loop(
skip_next_user_input = False skip_next_user_input = False
def process_agent_step(user_message, no_verify): def process_agent_step(user_message, no_verify):
# TODO(charles): update to use agent.step() instead of inner_step()
if user_message is None: if user_message is None:
step_response = letta_agent.step( step_response = letta_agent.inner_step(
messages=[], messages=[],
first_message=False, first_message=False,
skip_verify=no_verify, skip_verify=no_verify,
@ -402,15 +404,15 @@ def run_agent_loop(
while True: while True:
try: try:
if strip_ui: if strip_ui:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break break
else: else:
if stream: if stream:
# Don't display the "Thinking..." if streaming # Don't display the "Thinking..." if streaming
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
else: else:
with console.status("[bold cyan]Thinking...") as status: with console.status("[bold cyan]Thinking...") as status:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify) _, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break break
except KeyboardInterrupt: except KeyboardInterrupt:
print("User interrupt occurred.") print("User interrupt occurred.")

View File

@ -1,7 +1,7 @@
import uuid import uuid
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Dict, List, Optional, Union from typing import Dict, List, Optional
from pydantic import BaseModel, Field, field_validator from pydantic import BaseModel, Field, field_validator
@ -121,8 +121,7 @@ class UpdateAgentState(BaseAgent):
class AgentStepResponse(BaseModel): class AgentStepResponse(BaseModel):
# TODO remove support for list of dicts messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).") heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.") function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
in_context_memory_warning: bool = Field( in_context_memory_warning: bool = Field(

View File

@ -248,7 +248,7 @@ def create_run(
agent_id = thread_id agent_id = thread_id
# TODO: override preset of agent with request.assistant_id # TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(agent_id=agent_id) agent = server._get_or_load_agent(agent_id=agent_id)
agent.step(user_message=None) # already has messages added agent.inner_step(messages=[]) # already has messages added
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp()) create_time = int(get_utc_time().timestamp())
return OpenAIRun( return OpenAIRun(

View File

@ -74,7 +74,6 @@ from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, OrganizationCreate from letta.schemas.organization import Organization, OrganizationCreate
from letta.schemas.passage import Passage from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.source import Source, SourceCreate, SourceUpdate
@ -411,6 +410,7 @@ class SyncServer(Server):
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}") raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
logger.debug(f"Got input messages: {input_messages}") logger.debug(f"Got input messages: {input_messages}")
letta_agent = None
try: try:
# Get the agent object (loaded in memory) # Get the agent object (loaded in memory)
@ -422,83 +422,14 @@ class SyncServer(Server):
token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False
logger.debug(f"Starting agent step") logger.debug(f"Starting agent step")
no_verify = True usage_stats = letta_agent.step(
next_input_message = input_messages messages=input_messages,
counter = 0 chaining=self.chaining,
total_usage = UsageStatistics() max_chaining_steps=self.max_chaining_steps,
step_count = 0 stream=token_streaming,
while True: ms=self.ms,
step_response = letta_agent.step( skip_verify=True,
messages=next_input_message, )
first_message=False,
skip_verify=no_verify,
return_dicts=False,
stream=token_streaming,
# timestamp=timestamp,
ms=self.ms,
)
step_response.messages
heartbeat_request = step_response.heartbeat_request
function_failed = step_response.function_failed
token_warning = step_response.in_context_memory_warning
usage = step_response.usage
step_count += 1
total_usage += usage
counter += 1
letta_agent.interface.step_complete()
logger.debug("Saving agent state")
# save updated state
save_agent(letta_agent, self.ms)
# Chain stops
if not self.chaining:
logger.debug("No chaining, stopping after one step")
break
elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
break
# Chain handlers
elif token_warning:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_token_limit_warning(),
},
)
continue # always chain
elif function_failed:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
},
)
continue # always chain
elif heartbeat_request:
assert letta_agent.agent_state.user_id is not None
next_input_message = Message.dict_to_message(
agent_id=letta_agent.agent_state.id,
user_id=letta_agent.agent_state.user_id,
model=letta_agent.model,
openai_message_dict={
"role": "user", # TODO: change to system?
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
},
)
continue # always chain
# Letta no-op / yield
else:
break
except Exception as e: except Exception as e:
logger.error(f"Error in server._step: {e}") logger.error(f"Error in server._step: {e}")
@ -506,9 +437,10 @@ class SyncServer(Server):
raise raise
finally: finally:
logger.debug("Calling step_yield()") logger.debug("Calling step_yield()")
letta_agent.interface.step_yield() if letta_agent:
letta_agent.interface.step_yield()
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count) return usage_stats
def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics: def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
"""Process a CLI command""" """Process a CLI command"""

View File

@ -4,6 +4,7 @@ import pytest
from letta import create_client from letta import create_client
from letta.client.client import LocalClient, RESTClient from letta.client.client import LocalClient, RESTClient
from letta.schemas.agent import AgentState
from letta.schemas.block import Block from letta.schemas.block import Block
from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig from letta.schemas.llm_config import LLMConfig
@ -113,7 +114,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state_test.id) client.delete_agent(agent_state_test.id)
def test_agent_with_shared_blocks(client): def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id) persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id) human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
existing_non_template_blocks = [persona_block, human_block] existing_non_template_blocks = [persona_block, human_block]
@ -164,7 +165,7 @@ def test_agent_with_shared_blocks(client):
client.delete_agent(second_agent_state_test.id) client.delete_agent(second_agent_state_test.id)
def test_memory(client, agent): def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
# get agent memory # get agent memory
original_memory = client.get_in_context_memory(agent.id) original_memory = client.get_in_context_memory(agent.id)
assert original_memory is not None assert original_memory is not None
@ -177,7 +178,7 @@ def test_memory(client, agent):
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
def test_archival_memory(client, agent): def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test functions for interacting with archival memory store""" """Test functions for interacting with archival memory store"""
# add archival memory # add archival memory
@ -192,12 +193,12 @@ def test_archival_memory(client, agent):
client.delete_archival_memory(agent.id, passage.id) client.delete_archival_memory(agent.id, passage.id)
def test_recall_memory(client, agent): def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
"""Test functions for interacting with recall memory store""" """Test functions for interacting with recall memory store"""
# send message to the agent # send message to the agent
message_str = "Hello" message_str = "Hello"
client.send_message(message_str, "user", agent.id) client.send_message(message=message_str, role="user", agent_id=agent.id)
# list messages # list messages
messages = client.get_messages(agent.id) messages = client.get_messages(agent.id)
@ -216,7 +217,7 @@ def test_recall_memory(client, agent):
assert exists assert exists
def test_tools(client): def test_tools(client: Union[LocalClient, RESTClient]):
def print_tool(message: str): def print_tool(message: str):
""" """
A tool to print a message A tool to print a message
@ -265,7 +266,7 @@ def test_tools(client):
# assert len(client.list_tools()) == orig_tool_length # assert len(client.list_tools()) == orig_tool_length
def test_tools_from_composio_basic(client): def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
from composio_langchain import Action from composio_langchain import Action
from letta.schemas.tool import Tool from letta.schemas.tool import Tool
@ -286,7 +287,7 @@ def test_tools_from_composio_basic(client):
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable # The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
def test_tools_from_crewai(client): def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
# create crewAI tool # create crewAI tool
from crewai_tools import ScrapeWebsiteTool from crewai_tools import ScrapeWebsiteTool
@ -323,7 +324,7 @@ def test_tools_from_crewai(client):
assert expected_content in func(website_url=simple_webpage_url) assert expected_content in func(website_url=simple_webpage_url)
def test_tools_from_crewai_with_params(client): def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
# create crewAI tool # create crewAI tool
from crewai_tools import ScrapeWebsiteTool from crewai_tools import ScrapeWebsiteTool
@ -357,7 +358,7 @@ def test_tools_from_crewai_with_params(client):
assert expected_content in func() assert expected_content in func()
def test_tools_from_langchain(client): def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
# create langchain tool # create langchain tool
from langchain_community.tools import WikipediaQueryRun from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper from langchain_community.utilities import WikipediaAPIWrapper
@ -391,7 +392,7 @@ def test_tools_from_langchain(client):
assert expected_content in func(query="Albert Einstein") assert expected_content in func(query="Albert Einstein")
def test_tool_creation_langchain_missing_imports(client): def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]):
# create langchain tool # create langchain tool
from langchain_community.tools import WikipediaQueryRun from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper from langchain_community.utilities import WikipediaAPIWrapper