mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
refactor: make Agent.step()
multi-step (#1884)
This commit is contained in:
parent
94d2a18c27
commit
4fd82ee81a
118
letta/agent.py
118
letta/agent.py
@ -11,13 +11,17 @@ from letta.agent_store.storage import StorageConnector
|
||||
from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
IN_CONTEXT_MEMORY_KEYWORD,
|
||||
LLM_MAX_TOKENS,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
MESSAGE_SUMMARY_WARNING_FRAC,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from letta.errors import LLMError
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
@ -32,11 +36,15 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import (
|
||||
Message as ChatCompletionMessage,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.system import (
|
||||
get_heartbeat,
|
||||
get_initial_boot_messages,
|
||||
get_login_event,
|
||||
get_token_limit_warning,
|
||||
package_function_response,
|
||||
package_summarize_message,
|
||||
package_user_message,
|
||||
@ -56,9 +64,6 @@ from letta.utils import (
|
||||
verify_first_message_correctness,
|
||||
)
|
||||
|
||||
from .errors import LLMError
|
||||
from .llm_api.helpers import is_context_overflow_error
|
||||
|
||||
|
||||
def compile_memory_metadata_block(
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
@ -202,7 +207,7 @@ class BaseAgent(ABC):
|
||||
def step(
|
||||
self,
|
||||
messages: Union[Message, List[Message]],
|
||||
) -> AgentStepResponse:
|
||||
) -> LettaUsageStatistics:
|
||||
"""
|
||||
Top-level event message handler for the agent.
|
||||
"""
|
||||
@ -721,18 +726,105 @@ class Agent(BaseAgent):
|
||||
return messages, heartbeat_request, function_failed
|
||||
|
||||
def step(
|
||||
self,
|
||||
messages: Union[Message, List[Message]],
|
||||
# additional args
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
ms: Optional[MetadataStore] = None,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
|
||||
# assert ms is not None, "MetadataStore is required"
|
||||
|
||||
next_input_message = messages if isinstance(messages, list) else [messages]
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
while True:
|
||||
kwargs["ms"] = ms
|
||||
kwargs["first_message"] = False
|
||||
step_response = self.inner_step(
|
||||
messages=next_input_message,
|
||||
**kwargs,
|
||||
)
|
||||
step_response.messages
|
||||
heartbeat_request = step_response.heartbeat_request
|
||||
function_failed = step_response.function_failed
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
usage = step_response.usage
|
||||
|
||||
step_count += 1
|
||||
total_usage += usage
|
||||
counter += 1
|
||||
self.interface.step_complete()
|
||||
|
||||
# logger.debug("Saving agent state")
|
||||
# save updated state
|
||||
if ms:
|
||||
save_agent(self, ms)
|
||||
|
||||
# Chain stops
|
||||
if not chaining:
|
||||
printd("No chaining, stopping after one step")
|
||||
break
|
||||
elif max_chaining_steps is not None and counter > max_chaining_steps:
|
||||
printd(f"Hit max chaining steps, stopping after {counter} steps")
|
||||
break
|
||||
# Chain handlers
|
||||
elif token_warning:
|
||||
assert self.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": get_token_limit_warning(),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
elif function_failed:
|
||||
assert self.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
elif heartbeat_request:
|
||||
assert self.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=self.agent_state.id,
|
||||
user_id=self.agent_state.user_id,
|
||||
model=self.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": get_heartbeat(REQ_HEARTBEAT_MESSAGE),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
# Letta no-op / yield
|
||||
else:
|
||||
break
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
|
||||
def inner_step(
|
||||
self,
|
||||
messages: Union[Message, List[Message]],
|
||||
first_message: bool = False,
|
||||
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
|
||||
skip_verify: bool = False,
|
||||
return_dicts: bool = True,
|
||||
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
|
||||
stream: bool = False, # TODO move to config?
|
||||
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
|
||||
ms: Optional[MetadataStore] = None,
|
||||
) -> AgentStepResponse:
|
||||
"""Top-level event message handler for the Letta agent"""
|
||||
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
||||
|
||||
try:
|
||||
|
||||
@ -834,13 +926,12 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
self._append_to_messages(all_new_messages)
|
||||
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
|
||||
|
||||
# update state after each step
|
||||
self.update_state()
|
||||
|
||||
return AgentStepResponse(
|
||||
messages=messages_to_return,
|
||||
messages=all_new_messages,
|
||||
heartbeat_request=heartbeat_request,
|
||||
function_failed=function_failed,
|
||||
in_context_memory_warning=active_memory_warning,
|
||||
@ -856,15 +947,12 @@ class Agent(BaseAgent):
|
||||
self.summarize_messages_inplace()
|
||||
|
||||
# Try step again
|
||||
return self.step(
|
||||
return self.inner_step(
|
||||
messages=messages,
|
||||
first_message=first_message,
|
||||
first_message_retry_limit=first_message_retry_limit,
|
||||
skip_verify=skip_verify,
|
||||
return_dicts=return_dicts,
|
||||
# recreate_message_timestamp=recreate_message_timestamp,
|
||||
stream=stream,
|
||||
# timestamp=timestamp,
|
||||
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
|
||||
ms=ms,
|
||||
)
|
||||
@ -905,7 +993,7 @@ class Agent(BaseAgent):
|
||||
# created_at=timestamp,
|
||||
)
|
||||
|
||||
return self.step(messages=[user_message], **kwargs)
|
||||
return self.inner_step(messages=[user_message], **kwargs)
|
||||
|
||||
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
|
||||
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
|
||||
@ -1326,7 +1414,7 @@ class Agent(BaseAgent):
|
||||
self.pop_until_user()
|
||||
user_message = self.pop_message(count=1)[0]
|
||||
assert user_message.text is not None, "User message text is None"
|
||||
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
|
||||
step_response = self.step_user_message(user_message_str=user_message.text)
|
||||
messages = step_response.messages
|
||||
|
||||
assert messages is not None
|
||||
|
@ -747,8 +747,9 @@ class RESTClient(AbstractClient):
|
||||
# simplify messages
|
||||
if not include_full_message:
|
||||
messages = []
|
||||
for message in response.messages:
|
||||
messages += message.to_letta_message()
|
||||
for m in response.messages:
|
||||
assert isinstance(m, Message)
|
||||
messages += m.to_letta_message()
|
||||
response.messages = messages
|
||||
|
||||
return response
|
||||
@ -1677,7 +1678,7 @@ class LocalClient(AbstractClient):
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)
|
||||
|
||||
def get_agent_id(self, agent_name: str) -> AgentState:
|
||||
def get_agent_id(self, agent_name: str) -> Optional[str]:
|
||||
"""
|
||||
Get the ID of an agent by name (names are unique per user)
|
||||
|
||||
@ -1767,6 +1768,7 @@ class LocalClient(AbstractClient):
|
||||
self,
|
||||
message: str,
|
||||
role: str,
|
||||
name: Optional[str] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
agent_name: Optional[str] = None,
|
||||
stream_steps: bool = False,
|
||||
@ -1790,19 +1792,18 @@ class LocalClient(AbstractClient):
|
||||
# lookup agent by name
|
||||
assert agent_name, f"Either agent_id or agent_name must be provided"
|
||||
agent_id = self.get_agent_id(agent_name=agent_name)
|
||||
|
||||
agent_state = self.get_agent(agent_id=agent_id)
|
||||
assert agent_id, f"Agent with name {agent_name} not found"
|
||||
|
||||
if stream_steps or stream_tokens:
|
||||
# TODO: implement streaming with stream=True/False
|
||||
raise NotImplementedError
|
||||
self.interface.clear()
|
||||
if role == "system":
|
||||
usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
|
||||
elif role == "user":
|
||||
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
|
||||
else:
|
||||
raise ValueError(f"Role {role} not supported")
|
||||
|
||||
usage = self.server.send_messages(
|
||||
user_id=self.user_id,
|
||||
agent_id=agent_id,
|
||||
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
|
||||
)
|
||||
|
||||
# auto-save
|
||||
if self.auto_save:
|
||||
|
@ -361,8 +361,10 @@ def run_agent_loop(
|
||||
skip_next_user_input = False
|
||||
|
||||
def process_agent_step(user_message, no_verify):
|
||||
# TODO(charles): update to use agent.step() instead of inner_step()
|
||||
|
||||
if user_message is None:
|
||||
step_response = letta_agent.step(
|
||||
step_response = letta_agent.inner_step(
|
||||
messages=[],
|
||||
first_message=False,
|
||||
skip_verify=no_verify,
|
||||
@ -402,15 +404,15 @@ def run_agent_loop(
|
||||
while True:
|
||||
try:
|
||||
if strip_ui:
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
break
|
||||
else:
|
||||
if stream:
|
||||
# Don't display the "Thinking..." if streaming
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
else:
|
||||
with console.status("[bold cyan]Thinking...") as status:
|
||||
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
print("User interrupt occurred.")
|
||||
|
@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
@ -121,8 +121,7 @@ class UpdateAgentState(BaseAgent):
|
||||
|
||||
|
||||
class AgentStepResponse(BaseModel):
|
||||
# TODO remove support for list of dicts
|
||||
messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
|
||||
messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
|
||||
heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
|
||||
function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
|
||||
in_context_memory_warning: bool = Field(
|
||||
|
@ -248,7 +248,7 @@ def create_run(
|
||||
agent_id = thread_id
|
||||
# TODO: override preset of agent with request.assistant_id
|
||||
agent = server._get_or_load_agent(agent_id=agent_id)
|
||||
agent.step(user_message=None) # already has messages added
|
||||
agent.inner_step(messages=[]) # already has messages added
|
||||
run_id = str(uuid.uuid4())
|
||||
create_time = int(get_utc_time().timestamp())
|
||||
return OpenAIRun(
|
||||
|
@ -74,7 +74,6 @@ from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.organization import Organization, OrganizationCreate
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
@ -411,6 +410,7 @@ class SyncServer(Server):
|
||||
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
|
||||
|
||||
logger.debug(f"Got input messages: {input_messages}")
|
||||
letta_agent = None
|
||||
try:
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
@ -422,83 +422,14 @@ class SyncServer(Server):
|
||||
token_streaming = letta_agent.interface.streaming_mode if hasattr(letta_agent.interface, "streaming_mode") else False
|
||||
|
||||
logger.debug(f"Starting agent step")
|
||||
no_verify = True
|
||||
next_input_message = input_messages
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
while True:
|
||||
step_response = letta_agent.step(
|
||||
messages=next_input_message,
|
||||
first_message=False,
|
||||
skip_verify=no_verify,
|
||||
return_dicts=False,
|
||||
stream=token_streaming,
|
||||
# timestamp=timestamp,
|
||||
ms=self.ms,
|
||||
)
|
||||
step_response.messages
|
||||
heartbeat_request = step_response.heartbeat_request
|
||||
function_failed = step_response.function_failed
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
usage = step_response.usage
|
||||
|
||||
step_count += 1
|
||||
total_usage += usage
|
||||
counter += 1
|
||||
letta_agent.interface.step_complete()
|
||||
|
||||
logger.debug("Saving agent state")
|
||||
# save updated state
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
# Chain stops
|
||||
if not self.chaining:
|
||||
logger.debug("No chaining, stopping after one step")
|
||||
break
|
||||
elif self.max_chaining_steps is not None and counter > self.max_chaining_steps:
|
||||
logger.debug(f"Hit max chaining steps, stopping after {counter} steps")
|
||||
break
|
||||
# Chain handlers
|
||||
elif token_warning:
|
||||
assert letta_agent.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=letta_agent.agent_state.id,
|
||||
user_id=letta_agent.agent_state.user_id,
|
||||
model=letta_agent.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": system.get_token_limit_warning(),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
elif function_failed:
|
||||
assert letta_agent.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=letta_agent.agent_state.id,
|
||||
user_id=letta_agent.agent_state.user_id,
|
||||
model=letta_agent.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
elif heartbeat_request:
|
||||
assert letta_agent.agent_state.user_id is not None
|
||||
next_input_message = Message.dict_to_message(
|
||||
agent_id=letta_agent.agent_state.id,
|
||||
user_id=letta_agent.agent_state.user_id,
|
||||
model=letta_agent.model,
|
||||
openai_message_dict={
|
||||
"role": "user", # TODO: change to system?
|
||||
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
|
||||
},
|
||||
)
|
||||
continue # always chain
|
||||
# Letta no-op / yield
|
||||
else:
|
||||
break
|
||||
usage_stats = letta_agent.step(
|
||||
messages=input_messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
ms=self.ms,
|
||||
skip_verify=True,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in server._step: {e}")
|
||||
@ -506,9 +437,10 @@ class SyncServer(Server):
|
||||
raise
|
||||
finally:
|
||||
logger.debug("Calling step_yield()")
|
||||
letta_agent.interface.step_yield()
|
||||
if letta_agent:
|
||||
letta_agent.interface.step_yield()
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
return usage_stats
|
||||
|
||||
def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
|
||||
"""Process a CLI command"""
|
||||
|
@ -4,6 +4,7 @@ import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
@ -113,7 +114,7 @@ def test_agent(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_state_test.id)
|
||||
|
||||
|
||||
def test_agent_with_shared_blocks(client):
|
||||
def test_agent_with_shared_blocks(client: Union[LocalClient, RESTClient]):
|
||||
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
|
||||
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
|
||||
existing_non_template_blocks = [persona_block, human_block]
|
||||
@ -164,7 +165,7 @@ def test_agent_with_shared_blocks(client):
|
||||
client.delete_agent(second_agent_state_test.id)
|
||||
|
||||
|
||||
def test_memory(client, agent):
|
||||
def test_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# get agent memory
|
||||
original_memory = client.get_in_context_memory(agent.id)
|
||||
assert original_memory is not None
|
||||
@ -177,7 +178,7 @@ def test_memory(client, agent):
|
||||
assert updated_memory.get_block("human").value != original_memory_value # check if the memory has been updated
|
||||
|
||||
|
||||
def test_archival_memory(client, agent):
|
||||
def test_archival_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test functions for interacting with archival memory store"""
|
||||
|
||||
# add archival memory
|
||||
@ -192,12 +193,12 @@ def test_archival_memory(client, agent):
|
||||
client.delete_archival_memory(agent.id, passage.id)
|
||||
|
||||
|
||||
def test_recall_memory(client, agent):
|
||||
def test_recall_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""Test functions for interacting with recall memory store"""
|
||||
|
||||
# send message to the agent
|
||||
message_str = "Hello"
|
||||
client.send_message(message_str, "user", agent.id)
|
||||
client.send_message(message=message_str, role="user", agent_id=agent.id)
|
||||
|
||||
# list messages
|
||||
messages = client.get_messages(agent.id)
|
||||
@ -216,7 +217,7 @@ def test_recall_memory(client, agent):
|
||||
assert exists
|
||||
|
||||
|
||||
def test_tools(client):
|
||||
def test_tools(client: Union[LocalClient, RESTClient]):
|
||||
def print_tool(message: str):
|
||||
"""
|
||||
A tool to print a message
|
||||
@ -265,7 +266,7 @@ def test_tools(client):
|
||||
# assert len(client.list_tools()) == orig_tool_length
|
||||
|
||||
|
||||
def test_tools_from_composio_basic(client):
|
||||
def test_tools_from_composio_basic(client: Union[LocalClient, RESTClient]):
|
||||
from composio_langchain import Action
|
||||
|
||||
from letta.schemas.tool import Tool
|
||||
@ -286,7 +287,7 @@ def test_tools_from_composio_basic(client):
|
||||
# The tool creation includes a compile safety check, so if this test doesn't error out, at least the code is compilable
|
||||
|
||||
|
||||
def test_tools_from_crewai(client):
|
||||
def test_tools_from_crewai(client: Union[LocalClient, RESTClient]):
|
||||
# create crewAI tool
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
@ -323,7 +324,7 @@ def test_tools_from_crewai(client):
|
||||
assert expected_content in func(website_url=simple_webpage_url)
|
||||
|
||||
|
||||
def test_tools_from_crewai_with_params(client):
|
||||
def test_tools_from_crewai_with_params(client: Union[LocalClient, RESTClient]):
|
||||
# create crewAI tool
|
||||
|
||||
from crewai_tools import ScrapeWebsiteTool
|
||||
@ -357,7 +358,7 @@ def test_tools_from_crewai_with_params(client):
|
||||
assert expected_content in func()
|
||||
|
||||
|
||||
def test_tools_from_langchain(client):
|
||||
def test_tools_from_langchain(client: Union[LocalClient, RESTClient]):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
@ -391,7 +392,7 @@ def test_tools_from_langchain(client):
|
||||
assert expected_content in func(query="Albert Einstein")
|
||||
|
||||
|
||||
def test_tool_creation_langchain_missing_imports(client):
|
||||
def test_tool_creation_langchain_missing_imports(client: Union[LocalClient, RESTClient]):
|
||||
# create langchain tool
|
||||
from langchain_community.tools import WikipediaQueryRun
|
||||
from langchain_community.utilities import WikipediaAPIWrapper
|
||||
|
Loading…
Reference in New Issue
Block a user