mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Add tool_type
column (#576)
This commit is contained in:
parent
696940e481
commit
14a13b80b0
70
alembic/versions/e20573fe9b86_add_tool_types.py
Normal file
70
alembic/versions/e20573fe9b86_add_tool_types.py
Normal file
@ -0,0 +1,70 @@
|
||||
"""Add tool types
|
||||
|
||||
Revision ID: e20573fe9b86
|
||||
Revises: 915b68780108
|
||||
Create Date: 2025-01-09 15:11:47.779646
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
from alembic import op
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.orm.enums import ToolType
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e20573fe9b86"
|
||||
down_revision: Union[str, None] = "915b68780108"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Add the column as nullable with no default
|
||||
op.add_column("tools", sa.Column("tool_type", sa.String(), nullable=True))
|
||||
|
||||
# Step 2: Backpopulate the tool_type column based on tool name
|
||||
# Define the list of Letta core tools
|
||||
letta_core_value = ToolType.LETTA_CORE.value
|
||||
letta_memory_core_value = ToolType.LETTA_MEMORY_CORE.value
|
||||
custom_value = ToolType.CUSTOM.value
|
||||
|
||||
# Update tool_type for Letta core tools
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE tools
|
||||
SET tool_type = '{letta_core_value}'
|
||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_TOOLS)});
|
||||
"""
|
||||
)
|
||||
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE tools
|
||||
SET tool_type = '{letta_memory_core_value}'
|
||||
WHERE name IN ({','.join(f"'{name}'" for name in BASE_MEMORY_TOOLS)});
|
||||
"""
|
||||
)
|
||||
|
||||
# Update tool_type for all other tools
|
||||
op.execute(
|
||||
f"""
|
||||
UPDATE tools
|
||||
SET tool_type = '{custom_value}'
|
||||
WHERE tool_type IS NULL;
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 3: Alter the column to be non-nullable
|
||||
op.alter_column("tools", "tool_type", nullable=False)
|
||||
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert the changes made during the upgrade
|
||||
op.alter_column("tools", "json_schema", existing_type=postgresql.JSON(astext_type=sa.Text()), nullable=False)
|
||||
op.drop_column("tools", "tool_type")
|
||||
# ### end Alembic commands ###
|
@ -7,11 +7,11 @@ from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
CLI_WARNING_PREFIX,
|
||||
ERROR_MESSAGE_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
LETTA_CORE_TOOL_MODULE_NAME,
|
||||
LLM_MAX_TOKENS,
|
||||
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
|
||||
@ -19,6 +19,7 @@ from letta.constants import (
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
@ -26,6 +27,7 @@ from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.memory import summarize_messages
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@ -153,7 +155,7 @@ class Agent(BaseAgent):
|
||||
raise ValueError(f"Invalid JSON format in message: {msg.text}")
|
||||
return None
|
||||
|
||||
def update_memory_if_change(self, new_memory: Memory) -> bool:
|
||||
def update_memory_if_changed(self, new_memory: Memory) -> bool:
|
||||
"""
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
|
||||
@ -192,39 +194,45 @@ class Agent(BaseAgent):
|
||||
Execute tool modifications and persist the state of the agent.
|
||||
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
|
||||
"""
|
||||
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
|
||||
env = {}
|
||||
env.update(globals())
|
||||
exec(target_letta_tool.source_code, env)
|
||||
callable_func = env[target_letta_tool.json_schema["name"]]
|
||||
spec = inspect.getfullargspec(callable_func).annotations
|
||||
for name, arg in function_args.items():
|
||||
if isinstance(function_args[name], dict):
|
||||
function_args[name] = spec[name](**function_args[name])
|
||||
|
||||
# TODO: add agent manager here
|
||||
orig_memory_str = self.agent_state.memory.compile()
|
||||
|
||||
# TODO: need to have an AgentState object that actually has full access to the block data
|
||||
# this is because the sandbox tools need to be able to access block.value to edit this data
|
||||
try:
|
||||
# TODO: This is NO BUENO
|
||||
# TODO: Matching purely by names is extremely problematic, users can create tools with these names and run them in the agent loop
|
||||
# TODO: We will have probably have to match the function strings exactly for safety
|
||||
if function_name in BASE_TOOLS:
|
||||
if target_letta_tool.tool_type == ToolType.LETTA_CORE:
|
||||
# base tools are allowed to access the `Agent` object and run on the database
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
function_args["self"] = self # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
elif target_letta_tool.tool_type == ToolType.LETTA_MEMORY_CORE:
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
agent_state_copy = self.agent_state.__deepcopy__()
|
||||
function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
self.update_memory_if_changed(agent_state_copy.memory)
|
||||
else:
|
||||
# TODO: Get rid of this. This whole piece is pretty shady, that we exec the function to just get the type hints for args.
|
||||
env = {}
|
||||
env.update(globals())
|
||||
exec(target_letta_tool.source_code, env)
|
||||
callable_func = env[target_letta_tool.json_schema["name"]]
|
||||
spec = inspect.getfullargspec(callable_func).annotations
|
||||
for name, arg in function_args.items():
|
||||
if isinstance(function_args[name], dict):
|
||||
function_args[name] = spec[name](**function_args[name])
|
||||
|
||||
# execute tool in a sandbox
|
||||
# TODO: allow agent_state to specify which sandbox to execute tools in
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(
|
||||
agent_state=self.agent_state.__deepcopy__()
|
||||
)
|
||||
# TODO: This is only temporary, can remove after we publish a pip package with this object
|
||||
agent_state_copy = self.agent_state.__deepcopy__()
|
||||
agent_state_copy.tools = []
|
||||
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy)
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
if updated_agent_state is not None:
|
||||
self.update_memory_if_change(updated_agent_state.memory)
|
||||
self.update_memory_if_changed(updated_agent_state.memory)
|
||||
except Exception as e:
|
||||
# Need to catch error here, or else trunction wont happen
|
||||
# TODO: modify to function execution error
|
||||
@ -677,7 +685,7 @@ class Agent(BaseAgent):
|
||||
current_persisted_memory = Memory(
|
||||
blocks=[self.block_manager.get_block_by_id(block.id, actor=self.user) for block in self.agent_state.memory.get_blocks()]
|
||||
) # read blocks from DB
|
||||
self.update_memory_if_change(current_persisted_memory)
|
||||
self.update_memory_if_changed(current_persisted_memory)
|
||||
|
||||
# Step 1: add user message
|
||||
if isinstance(messages, Message):
|
||||
|
@ -8,6 +8,9 @@ API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY = "COMPOSIO_ENTITY"
|
||||
COMPOSIO_TOOL_TAG_NAME = "composio"
|
||||
|
||||
LETTA_CORE_TOOL_MODULE_NAME = "letta.functions.function_sets.base"
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
# Example full message:
|
||||
|
@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
import inspect
|
||||
from textwrap import dedent # remove indentation
|
||||
from types import ModuleType
|
||||
@ -64,6 +65,70 @@ def parse_source_code(func) -> str:
|
||||
return source_code
|
||||
|
||||
|
||||
def get_function_from_module(module_name: str, function_name: str):
|
||||
"""
|
||||
Dynamically imports a function from a specified module.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import (e.g., 'base').
|
||||
function_name (str): The name of the function to retrieve.
|
||||
|
||||
Returns:
|
||||
Callable: The imported function.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If the specified module cannot be found.
|
||||
AttributeError: If the function is not found in the module.
|
||||
"""
|
||||
try:
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
# Retrieve the function
|
||||
return getattr(module, function_name)
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")
|
||||
|
||||
|
||||
def get_json_schema_from_module(module_name: str, function_name: str) -> dict:
|
||||
"""
|
||||
Dynamically loads a specific function from a module and generates its JSON schema.
|
||||
|
||||
Args:
|
||||
module_name (str): The name of the module to import (e.g., 'base').
|
||||
function_name (str): The name of the function to retrieve.
|
||||
|
||||
Returns:
|
||||
dict: The JSON schema for the specified function.
|
||||
|
||||
Raises:
|
||||
ModuleNotFoundError: If the specified module cannot be found.
|
||||
AttributeError: If the function is not found in the module.
|
||||
ValueError: If the attribute is not a user-defined function.
|
||||
"""
|
||||
try:
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Retrieve the function
|
||||
attr = getattr(module, function_name, None)
|
||||
|
||||
# Check if it's a user-defined function
|
||||
if not (inspect.isfunction(attr) and attr.__module__ == module.__name__):
|
||||
raise ValueError(f"'{function_name}' is not a user-defined function in module '{module_name}'")
|
||||
|
||||
# Generate schema (assuming a `generate_schema` function exists)
|
||||
generated_schema = generate_schema(attr)
|
||||
|
||||
return generated_schema
|
||||
|
||||
except ModuleNotFoundError:
|
||||
raise ModuleNotFoundError(f"Module '{module_name}' not found.")
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Function '{function_name}' not found in module '{module_name}'.")
|
||||
|
||||
|
||||
def load_function_set(module: ModuleType) -> dict:
|
||||
"""Load the functions and generate schema for them, given a module object"""
|
||||
function_dict = {}
|
||||
|
@ -109,6 +109,7 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
state = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"message_ids": self.message_ids,
|
||||
|
@ -4,7 +4,7 @@ from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
from letta.orm.enums import ToolSourceType
|
||||
from letta.orm.enums import ToolSourceType, ToolType
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
@ -29,12 +29,17 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
__table_args__ = (UniqueConstraint("name", "organization_id", name="uix_name_organization"),)
|
||||
|
||||
name: Mapped[str] = mapped_column(doc="The display name of the tool.")
|
||||
tool_type: Mapped[ToolType] = mapped_column(
|
||||
String,
|
||||
default=ToolType.CUSTOM,
|
||||
doc="The type of tool. This affects whether or not we generate json_schema and source_code on the fly.",
|
||||
)
|
||||
return_char_limit: Mapped[int] = mapped_column(nullable=True, doc="The maximum number of characters the tool can return.")
|
||||
description: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The description of the tool.")
|
||||
tags: Mapped[List] = mapped_column(JSON, doc="Metadata tags used to filter tools.")
|
||||
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
|
||||
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
||||
json_schema: Mapped[dict] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
||||
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
||||
module: Mapped[Optional[str]] = mapped_column(
|
||||
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
|
||||
)
|
||||
|
@ -2,10 +2,11 @@ from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from letta.constants import FUNCTION_RETURN_CHAR_LIMIT
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.constants import COMPOSIO_TOOL_TAG_NAME, FUNCTION_RETURN_CHAR_LIMIT, LETTA_CORE_TOOL_MODULE_NAME
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema_v2
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
|
||||
@ -28,6 +29,7 @@ class Tool(BaseTool):
|
||||
"""
|
||||
|
||||
id: str = BaseTool.generate_id_field()
|
||||
tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.")
|
||||
description: Optional[str] = Field(None, description="The description of the tool.")
|
||||
source_type: Optional[str] = Field(None, description="The type of the source code.")
|
||||
module: Optional[str] = Field(None, description="The module of the function.")
|
||||
@ -36,7 +38,7 @@ class Tool(BaseTool):
|
||||
tags: List[str] = Field([], description="Metadata tags.")
|
||||
|
||||
# code
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_code: Optional[str] = Field(None, description="The source code of the function.")
|
||||
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
|
||||
|
||||
# tool configuration
|
||||
@ -51,9 +53,19 @@ class Tool(BaseTool):
|
||||
"""
|
||||
Populate missing fields: name, description, and json_schema.
|
||||
"""
|
||||
# Derive JSON schema if not provided
|
||||
if not self.json_schema:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
# If it's a custom tool, we need to ensure source_code is present
|
||||
if not self.source_code:
|
||||
raise ValueError(f"Custom tool with id={self.id} is missing source_code field.")
|
||||
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
# TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType
|
||||
# TODO: We skip this for Composio bc composio json schemas are derived differently
|
||||
if not (COMPOSIO_TOOL_TAG_NAME in self.tags):
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
|
||||
# Derive name from the JSON schema if not provided
|
||||
if not self.name:
|
||||
@ -125,7 +137,7 @@ class ToolCreate(LettaBase):
|
||||
|
||||
description = composio_tool.description
|
||||
source_type = "python"
|
||||
tags = ["composio"]
|
||||
tags = [COMPOSIO_TOOL_TAG_NAME]
|
||||
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action_name)
|
||||
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.functions.functions import derive_openai_json_schema, load_function_set
|
||||
from letta.orm.enums import ToolType
|
||||
|
||||
# TODO: Remove this once we translate all of these to the ORM
|
||||
from letta.orm.errors import NoResultFound
|
||||
@ -32,10 +32,10 @@ class ToolManager:
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
# TODO: Refactor this across the codebase to use CreateTool instead of passing in a Tool object
|
||||
@enforce_types
|
||||
def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
# Derive json_schema
|
||||
tool = self.get_tool_by_name(tool_name=pydantic_tool.name, actor=actor)
|
||||
if tool:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
@ -63,6 +63,7 @@ class ToolManager:
|
||||
if pydantic_tool.description is None:
|
||||
pydantic_tool.description = pydantic_tool.json_schema.get("description", None)
|
||||
tool_data = pydantic_tool.model_dump()
|
||||
|
||||
tool = ToolModel(**tool_data)
|
||||
tool.create(session, actor=actor) # Re-raise other database-related errors
|
||||
return tool.to_pydantic()
|
||||
@ -113,8 +114,6 @@ class ToolManager:
|
||||
# If source code is changed and a new json_schema is not provided, we want to auto-refresh the schema
|
||||
if "source_code" in update_data.keys() and "json_schema" not in update_data.keys():
|
||||
pydantic_tool = tool.to_pydantic()
|
||||
|
||||
update_data["name"] if "name" in update_data.keys() else None
|
||||
new_schema = derive_openai_json_schema(source_code=pydantic_tool.source_code)
|
||||
|
||||
tool.json_schema = new_schema
|
||||
@ -155,12 +154,19 @@ class ToolManager:
|
||||
tools = []
|
||||
for name, schema in functions_to_schema.items():
|
||||
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
|
||||
# print([str(inspect.getsource(line)) for line in schema["imports"]])
|
||||
source_code = inspect.getsource(schema["python_function"])
|
||||
tags = [module_name]
|
||||
if module_name == "base":
|
||||
tags.append("letta-base")
|
||||
|
||||
# BASE_MEMORY_TOOLS should be executed in an e2b sandbox
|
||||
# so they should NOT be letta_core tools, instead, treated as custom tools
|
||||
if name in BASE_TOOLS:
|
||||
tool_type = ToolType.LETTA_CORE
|
||||
elif name in BASE_MEMORY_TOOLS:
|
||||
tool_type = ToolType.LETTA_MEMORY_CORE
|
||||
else:
|
||||
raise ValueError(f"Tool name {name} is not in the list of base tool names: {BASE_TOOLS + BASE_MEMORY_TOOLS}")
|
||||
|
||||
# create to tool
|
||||
tools.append(
|
||||
self.create_or_update_tool(
|
||||
@ -168,9 +174,7 @@ class ToolManager:
|
||||
name=name,
|
||||
tags=tags,
|
||||
source_type="python",
|
||||
module=schema["module"],
|
||||
source_code=source_code,
|
||||
json_schema=schema["json_schema"],
|
||||
tool_type=tool_type,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
@ -197,7 +197,7 @@ def composio_gmail_get_profile_tool(test_user):
|
||||
|
||||
@pytest.fixture
|
||||
def clear_core_memory_tool(test_user):
|
||||
def clear_memory(agent_state: AgentState):
|
||||
def clear_memory(agent_state: "AgentState"):
|
||||
"""Clear the core memory"""
|
||||
agent_state.memory.get_block("human").value = ""
|
||||
agent_state.memory.get_block("persona").value = ""
|
||||
|
@ -42,7 +42,9 @@ def run_server():
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[{"server": False}, {"server": True}], # whether to use REST API server
|
||||
params=[
|
||||
{"server": False},
|
||||
], # {"server": True}], # whether to use REST API server
|
||||
# params=[{"server": False}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
@ -121,7 +123,6 @@ def test_shared_blocks(mock_e2b_api_key_none, client: Union[LocalClient, RESTCli
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
|
@ -30,6 +30,7 @@ from letta.orm import (
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
@ -1368,6 +1369,7 @@ def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
assert fetched_tool.tags == print_tool.tags
|
||||
assert fetched_tool.source_code == print_tool.source_code
|
||||
assert fetched_tool.source_type == print_tool.source_type
|
||||
assert fetched_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
|
||||
@ -1382,6 +1384,7 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user):
|
||||
assert fetched_tool.tags == print_tool.tags
|
||||
assert fetched_tool.source_code == print_tool.source_code
|
||||
assert fetched_tool.source_type == print_tool.source_type
|
||||
assert fetched_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
def test_list_tools(server: SyncServer, print_tool, default_user):
|
||||
@ -1445,6 +1448,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, p
|
||||
|
||||
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code)
|
||||
assert updated_tool.json_schema == new_schema
|
||||
assert updated_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print_tool, default_user):
|
||||
@ -1483,6 +1487,7 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, print
|
||||
new_schema = derive_openai_json_schema(source_code=updated_tool.source_code, name=updated_tool.name)
|
||||
assert updated_tool.json_schema == new_schema
|
||||
assert updated_tool.name == name
|
||||
assert updated_tool.tool_type == ToolType.CUSTOM
|
||||
|
||||
|
||||
def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, other_user):
|
||||
@ -1519,6 +1524,15 @@ def test_upsert_base_tools(server: SyncServer, default_user):
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
# Confirm that the return tools have no source_code, but a json_schema
|
||||
for t in tools:
|
||||
if t.name in BASE_TOOLS:
|
||||
assert t.tool_type == ToolType.LETTA_CORE
|
||||
else:
|
||||
assert t.tool_type == ToolType.LETTA_MEMORY_CORE
|
||||
assert t.source_code is None
|
||||
assert t.json_schema
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
|
Loading…
Reference in New Issue
Block a user