mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Composio tools execute on-the-fly (#999)
This commit is contained in:
parent
8097b1e98a
commit
a734f99d8d
@ -9,7 +9,6 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.sandbox_config import SandboxType
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
"""
|
||||
Setup here.
|
||||
@ -31,7 +30,7 @@ for agent_state in client.list_agents():
|
||||
|
||||
|
||||
# Add sandbox env
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
# Ensure you have e2b key set
|
||||
sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=client.user)
|
||||
manager.create_sandbox_env_var(
|
||||
|
@ -3,12 +3,13 @@ import time
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY,
|
||||
ERROR_MESSAGE_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
@ -20,7 +21,11 @@ from letta.constants import (
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
@ -51,6 +56,7 @@ from letta.services.passage_manager import PassageManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
from letta.system import get_heartbeat, get_token_limit_warning, package_function_response, package_summarize_message, package_user_message
|
||||
@ -58,9 +64,6 @@ from letta.utils import (
|
||||
count_tokens,
|
||||
get_friendly_error_msg,
|
||||
get_tool_call_id,
|
||||
get_utc_time,
|
||||
json_dumps,
|
||||
json_loads,
|
||||
log_telemetry,
|
||||
parse_json,
|
||||
printd,
|
||||
@ -202,7 +205,7 @@ class Agent(BaseAgent):
|
||||
|
||||
def execute_tool_and_persist_state(
|
||||
self, function_name: str, function_args: dict, target_letta_tool: Tool
|
||||
) -> tuple[str, Optional[SandboxRunResult]]:
|
||||
) -> tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""
|
||||
Execute tool modifications and persist the state of the agent.
|
||||
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
|
||||
@ -228,6 +231,18 @@ class Agent(BaseAgent):
|
||||
function_args["agent_state"] = agent_state_copy # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
self.update_memory_if_changed(agent_state_copy.memory)
|
||||
elif target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO:
|
||||
action_name = generate_composio_action_from_func_name(target_letta_tool.name)
|
||||
# Get entity ID from the agent_state
|
||||
entity_id = None
|
||||
for env_var in self.agent_state.tool_exec_environment_variables:
|
||||
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
|
||||
entity_id = env_var.value
|
||||
# Get composio_api_key
|
||||
composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger)
|
||||
function_response = execute_composio_action(
|
||||
action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
|
||||
)
|
||||
else:
|
||||
# Parse the source code to extract function annotations
|
||||
annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
|
||||
@ -460,7 +475,10 @@ class Agent(BaseAgent):
|
||||
target_letta_tool = None
|
||||
for t in self.agent_state.tools:
|
||||
if t.name == function_name:
|
||||
target_letta_tool = t
|
||||
# This force refreshes the target_letta_tool from the database
|
||||
# We only do this on name match to confirm that the agent state contains a specific tool with the right name
|
||||
target_letta_tool = ToolManager().get_tool_by_name(tool_name=function_name, actor=self.user)
|
||||
break
|
||||
|
||||
if not target_letta_tool:
|
||||
error_msg = f"No function named {function_name}"
|
||||
|
@ -8,7 +8,7 @@ import typer
|
||||
from prettytable.colortable import ColorTable, Themes
|
||||
from tqdm import tqdm
|
||||
|
||||
from letta import utils
|
||||
import letta.helpers.datetime_helpers
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@ -51,7 +51,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
agent.memory.get_block("persona").value[:100] + "...",
|
||||
agent.memory.get_block("human").value[:100] + "...",
|
||||
",".join(source_names),
|
||||
utils.format_datetime(agent.created_at),
|
||||
letta.helpers.datetime_helpers.format_datetime(agent.created_at),
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
@ -84,7 +84,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
source.description,
|
||||
source.embedding_config.embedding_model,
|
||||
source.embedding_config.embedding_dim,
|
||||
utils.format_datetime(source.created_at),
|
||||
letta.helpers.datetime_helpers.format_datetime(source.created_at),
|
||||
]
|
||||
)
|
||||
|
||||
|
@ -33,7 +33,7 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
|
||||
import math
|
||||
|
||||
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.utils import json_dumps
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
|
||||
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
|
||||
page = 0
|
||||
|
@ -5,9 +5,9 @@ from typing import Optional
|
||||
import requests
|
||||
|
||||
from letta.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.schemas.message import Message, TextContent
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
|
||||
def message_chatgpt(self, message: str):
|
||||
|
@ -72,6 +72,22 @@ def {func_name}(**kwargs):
|
||||
return func_name, wrapper_function_str
|
||||
|
||||
|
||||
def execute_composio_action(
|
||||
action_name: str, args: dict, api_key: Optional[str] = None, entity_id: Optional[str] = None
|
||||
) -> tuple[str, str]:
|
||||
import os
|
||||
|
||||
from composio_langchain import ComposioToolSet
|
||||
|
||||
entity_id = entity_id or os.getenv(COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_ENTITY_ID)
|
||||
composio_toolset = ComposioToolSet(api_key=api_key, entity_id=entity_id)
|
||||
response = composio_toolset.execute_action(action=action_name, params=args)
|
||||
|
||||
if response["error"]:
|
||||
raise RuntimeError(response["error"])
|
||||
return response["data"]
|
||||
|
||||
|
||||
def generate_langchain_tool_wrapper(
|
||||
tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None
|
||||
) -> tuple[str, str]:
|
||||
|
21
letta/helpers/composio_helpers.py
Normal file
21
letta/helpers/composio_helpers.py
Normal file
@ -0,0 +1,21 @@
|
||||
from logging import Logger
|
||||
from typing import Optional
|
||||
|
||||
from letta.schemas.user import User
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]:
|
||||
api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
|
||||
if not api_keys:
|
||||
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
if tool_settings.composio_api_key:
|
||||
return tool_settings.composio_api_key
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
# TODO: Add more protections around this
|
||||
# Ideally, not tied to a specific sandbox, but for now we just get the first one
|
||||
# Theoretically possible for someone to have different composio api keys per sandbox
|
||||
return api_keys[0].value
|
90
letta/helpers/datetime_helpers.py
Normal file
90
letta/helpers/datetime_helpers.py
Normal file
@ -0,0 +1,90 @@
|
||||
import re
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytz
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
|
||||
# parse times returned by letta.utils.get_formatted_time()
|
||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
def datetime_to_timestamp(dt):
|
||||
# convert datetime object to integer timestamp
|
||||
return int(dt.timestamp())
|
||||
|
||||
|
||||
def timestamp_to_datetime(ts):
|
||||
# convert integer timestamp to datetime object
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
|
||||
def get_local_time_military():
|
||||
# Get the current time in UTC
|
||||
current_time_utc = datetime.now(pytz.utc)
|
||||
|
||||
# Convert to San Francisco's time zone (PST/PDT)
|
||||
sf_time_zone = pytz.timezone("America/Los_Angeles")
|
||||
local_time = current_time_utc.astimezone(sf_time_zone)
|
||||
|
||||
# You may format it as you desire
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")
|
||||
|
||||
return formatted_time
|
||||
|
||||
|
||||
def get_local_time_timezone(timezone="America/Los_Angeles"):
|
||||
# Get the current time in UTC
|
||||
current_time_utc = datetime.now(pytz.utc)
|
||||
|
||||
# Convert to San Francisco's time zone (PST/PDT)
|
||||
sf_time_zone = pytz.timezone(timezone)
|
||||
local_time = current_time_utc.astimezone(sf_time_zone)
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
return formatted_time
|
||||
|
||||
|
||||
def get_local_time(timezone=None):
|
||||
if timezone is not None:
|
||||
time_str = get_local_time_timezone(timezone)
|
||||
else:
|
||||
# Get the current time, which will be in the local timezone of the computer
|
||||
local_time = datetime.now().astimezone()
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
return time_str.strip()
|
||||
|
||||
|
||||
def get_utc_time() -> datetime:
|
||||
"""Get the current UTC time"""
|
||||
# return datetime.now(pytz.utc)
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def format_datetime(dt):
|
||||
return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
def validate_date_format(date_str):
|
||||
"""Validate the given date string in the format 'YYYY-MM-DD'."""
|
||||
try:
|
||||
datetime.strptime(date_str, "%Y-%m-%d")
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def extract_date_from_timestamp(timestamp):
|
||||
"""Extracts and returns the date from the given timestamp."""
|
||||
# Extracts the date (ignoring the time and timezone)
|
||||
match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
def is_utc_datetime(dt: datetime) -> bool:
|
||||
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0)
|
15
letta/helpers/json_helpers.py
Normal file
15
letta/helpers/json_helpers.py
Normal file
@ -0,0 +1,15 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
def json_loads(data):
|
||||
return json.loads(data, strict=False)
|
||||
|
||||
|
||||
def json_dumps(data, indent=2):
|
||||
def safe_serializer(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Type {type(obj)} not serializable")
|
||||
|
||||
return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False)
|
@ -5,9 +5,10 @@ from typing import List, Optional
|
||||
from colorama import Fore, Style, init
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.helpers.json_helpers import json_loads
|
||||
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL, INNER_THOUGHTS_CLI_SYMBOL
|
||||
from letta.schemas.message import Message
|
||||
from letta.utils import json_loads, printd
|
||||
from letta.utils import printd
|
||||
|
||||
init(autoreset=True)
|
||||
|
||||
|
@ -18,6 +18,7 @@ from anthropic.types.beta import (
|
||||
)
|
||||
|
||||
from letta.errors import BedrockError, BedrockPermissionError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.llm_api.aws_bedrock import get_bedrock_client
|
||||
from letta.llm_api.helpers import add_inner_thoughts_to_functions
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
|
||||
@ -39,7 +40,6 @@ from letta.schemas.openai.chat_completion_response import MessageDelta, ToolCall
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.settings import model_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
BASE_URL = "https://api.anthropic.com/v1"
|
||||
|
||||
|
@ -4,6 +4,8 @@ from typing import List, Optional, Union
|
||||
|
||||
import requests
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
|
||||
@ -12,7 +14,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
Message as ChoiceMessage, # NOTE: avoid conflict with our own Letta Message datatype
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
|
||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps, smart_urljoin
|
||||
from letta.utils import get_tool_call_id, smart_urljoin
|
||||
|
||||
BASE_URL = "https://api.cohere.ai/v1"
|
||||
|
||||
|
@ -4,12 +4,14 @@ from typing import List, Optional, Tuple
|
||||
import requests
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.helpers import make_post_request
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
def get_gemini_endpoint_and_headers(
|
||||
|
@ -2,11 +2,13 @@ import uuid
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.constants import NON_USER_MSG_PREFIX
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.schemas.openai.chat_completion_request import Tool
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
|
||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
|
||||
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
|
||||
|
@ -7,10 +7,11 @@ from typing import Any, List, Union
|
||||
import requests
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import count_tokens, json_dumps, printd
|
||||
from letta.utils import count_tokens, printd
|
||||
|
||||
|
||||
def _convert_to_structured_output_helper(property: dict) -> dict:
|
||||
|
@ -6,6 +6,8 @@ import requests
|
||||
|
||||
from letta.constants import CLI_WARNING_PREFIX
|
||||
from letta.errors import LocalLLMConnectionError, LocalLLMError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import DEFAULT_WRAPPER
|
||||
from letta.local_llm.function_parser import patch_function
|
||||
from letta.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model_from_function, generate_gbnf_grammar_and_documentation
|
||||
@ -20,7 +22,7 @@ from letta.local_llm.webui.api import get_webui_completion
|
||||
from letta.local_llm.webui.legacy_api import get_webui_completion as get_webui_completion_legacy
|
||||
from letta.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, Message, ToolCall, UsageStatistics
|
||||
from letta.utils import get_tool_call_id, get_utc_time, json_dumps
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
has_shown_warning = False
|
||||
grammar_supported_backends = ["koboldcpp", "llamacpp", "webui", "webui-legacy"]
|
||||
|
@ -1,7 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
from letta.utils import json_dumps, json_loads
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
|
||||
NO_HEARTBEAT_FUNCS = ["send_message"]
|
||||
|
||||
|
@ -10,7 +10,7 @@ from typing import Any, Callable, List, Optional, Tuple, Type, Union, _GenericAl
|
||||
from docstring_parser import parse
|
||||
from pydantic import BaseModel, create_model
|
||||
|
||||
from letta.utils import json_dumps
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
|
||||
|
||||
class PydanticDataType(Enum):
|
||||
|
@ -2,7 +2,7 @@ import json
|
||||
import re
|
||||
|
||||
from letta.errors import LLMJSONParsingError
|
||||
from letta.utils import json_loads
|
||||
from letta.helpers.json_helpers import json_loads
|
||||
|
||||
|
||||
def clean_json_string_extra_backslash(s):
|
||||
|
@ -1,6 +1,5 @@
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ...helpers.json_helpers import json_dumps, json_loads
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
from letta.errors import LLMJSONParsingError
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.local_llm.json_parser import clean_json
|
||||
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
PREFIX_HINT = """# Reminders:
|
||||
# Important information about yourself and the user is stored in (limited) core memory
|
||||
|
@ -1,8 +1,7 @@
|
||||
import yaml
|
||||
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ...helpers.json_helpers import json_dumps, json_loads
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ...helpers.json_helpers import json_dumps, json_loads
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
@ -1,7 +1,7 @@
|
||||
from letta.errors import LLMJSONParsingError
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.local_llm.json_parser import clean_json
|
||||
from letta.local_llm.llm_chat_completion_wrappers.wrapper_base import LLMChatCompletionWrapper
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
PREFIX_HINT = """# Reminders:
|
||||
# Important information about yourself and the user is stored in (limited) core memory
|
||||
|
@ -1,5 +1,4 @@
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...helpers.json_helpers import json_dumps, json_loads
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
||||
|
@ -1,6 +1,5 @@
|
||||
from letta.utils import json_dumps, json_loads
|
||||
|
||||
from ...errors import LLMJSONParsingError
|
||||
from ...helpers.json_helpers import json_dumps, json_loads
|
||||
from ..json_parser import clean_json
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
@ -4,7 +4,7 @@ from copy import deepcopy
|
||||
from enum import Enum
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from letta.utils import json_dumps
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
|
||||
api_requestor = None
|
||||
api_resources = None
|
||||
|
@ -5,10 +5,10 @@ from typing import List, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.utils import json_dumps
|
||||
|
||||
# TODO: consider moving into own file
|
||||
|
||||
|
@ -12,6 +12,8 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, TOOL_CALL_ID_MAX_LEN
|
||||
from letta.helpers.datetime_helpers import get_utc_time, is_utc_datetime
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageContentType, MessageRole
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
@ -28,7 +30,6 @@ from letta.schemas.letta_message import (
|
||||
UserMessage,
|
||||
)
|
||||
from letta.system import unpack_message
|
||||
from letta.utils import get_utc_time, is_utc_datetime, json_dumps
|
||||
|
||||
|
||||
def add_inner_thoughts_to_tool_call(
|
||||
|
@ -3,8 +3,9 @@ from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.utils import create_random_username, get_utc_time
|
||||
from letta.utils import create_random_username
|
||||
|
||||
|
||||
class OrganizationBase(LettaBase):
|
||||
|
@ -4,9 +4,9 @@ from typing import Dict, List, Optional
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.utils import get_utc_time
|
||||
|
||||
|
||||
class PassageBase(OrmMetadataBase):
|
||||
|
@ -7,6 +7,7 @@ from datetime import datetime
|
||||
from typing import AsyncGenerator, Literal, Optional, Union
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.helpers.datetime_helpers import is_utc_datetime
|
||||
from letta.interface import AgentInterface
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
@ -25,7 +26,6 @@ from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
|
||||
from letta.utils import is_utc_datetime
|
||||
|
||||
|
||||
# TODO strip from code / deprecate
|
||||
|
@ -8,14 +8,13 @@ from composio.tools.base.abs import InvalidClassDefinition
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import tool_settings
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
@ -205,15 +204,18 @@ def run_tool_from_source(
|
||||
|
||||
|
||||
# Specific routes for Composio
|
||||
|
||||
|
||||
@router.get("/composio/apps", response_model=List[AppModel], operation_id="list_composio_apps")
|
||||
def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
||||
"""
|
||||
Get a list of all Composio apps
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
composio_api_key = get_composio_key(server, actor=actor)
|
||||
composio_api_key = get_composio_api_key(actor=actor, logger=logger)
|
||||
if not composio_api_key:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail=f"No API keys found for Composio. Please add your Composio API Key as an environment variable for your sandbox configuration, or set it as environment variable COMPOSIO_API_KEY.",
|
||||
)
|
||||
return server.get_composio_apps(api_key=composio_api_key)
|
||||
|
||||
|
||||
@ -227,7 +229,12 @@ def list_composio_actions_by_app(
|
||||
Get a list of all Composio actions for a specific app
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
composio_api_key = get_composio_key(server, actor=actor)
|
||||
composio_api_key = get_composio_api_key(actor=actor, logger=logger)
|
||||
if not composio_api_key:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail=f"No API keys found for Composio. Please add your Composio API Key as an environment variable for your sandbox configuration, or set it as environment variable COMPOSIO_API_KEY.",
|
||||
)
|
||||
return server.get_composio_actions_from_app_name(composio_app_name=composio_app_name, api_key=composio_api_key)
|
||||
|
||||
|
||||
@ -308,24 +315,3 @@ def add_composio_tool(
|
||||
"composio_action_name": composio_action_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# TODO: Factor this out to somewhere else
|
||||
def get_composio_key(server: SyncServer, actor: User):
|
||||
api_keys = server.sandbox_config_manager.list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor)
|
||||
if not api_keys:
|
||||
logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...")
|
||||
|
||||
if tool_settings.composio_api_key:
|
||||
return tool_settings.composio_api_key
|
||||
else:
|
||||
# Nothing, raise fatal warning
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail=f"No API keys found for Composio. Please add your Composio API Key as an environment variable for your sandbox configuration, or set it as environment variable COMPOSIO_API_KEY.",
|
||||
)
|
||||
else:
|
||||
# TODO: Add more protections around this
|
||||
# Ideally, not tied to a specific sandbox, but for now we just get the first one
|
||||
# Theoretically possible for someone to have different composio api keys per sandbox
|
||||
return api_keys[0].value
|
||||
|
@ -19,6 +19,8 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.chat_only_agent import ChatOnlyAgent
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
@ -80,7 +82,7 @@ from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.utils import get_friendly_error_msg, get_utc_time, json_dumps, json_loads
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -296,7 +298,7 @@ class SyncServer(Server):
|
||||
self.tool_manager = ToolManager()
|
||||
self.block_manager = BlockManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
self.sandbox_config_manager = SandboxConfigManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.job_manager = JobManager()
|
||||
self.agent_manager = AgentManager()
|
||||
@ -315,7 +317,7 @@ class SyncServer(Server):
|
||||
|
||||
# Add composio keys to the tool sandbox env vars of the org
|
||||
if tool_settings.composio_api_key:
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
sandbox_config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=self.default_user)
|
||||
|
||||
manager.create_sandbox_env_var(
|
||||
|
@ -1,4 +1,4 @@
|
||||
from letta.utils import json_dumps
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
|
||||
# Server -> client
|
||||
|
||||
|
@ -6,6 +6,7 @@ from sqlalchemy import Select, and_, func, literal, or_, select, union_all
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM, MULTI_AGENT_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import AgentPassage, AgentsTags
|
||||
@ -42,7 +43,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, get_utc_time, united_diff
|
||||
from letta.utils import enforce_types, united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
@ -4,6 +4,7 @@ from typing import List, Literal, Optional
|
||||
from letta import system
|
||||
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_local_time
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
from letta.orm.errors import NoResultFound
|
||||
@ -15,7 +16,6 @@ from letta.schemas.message import Message, MessageCreate, TextContent
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.schemas.user import User
|
||||
from letta.system import get_initial_boot_messages, get_login_event
|
||||
from letta.utils import get_local_time
|
||||
|
||||
|
||||
# Static methods
|
||||
|
@ -5,6 +5,7 @@ from typing import List, Literal, Optional, Union
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.orm.enums import JobType
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.job import Job as JobModel
|
||||
@ -20,7 +21,7 @@ from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types, get_utc_time
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class JobManager:
|
||||
|
@ -19,7 +19,7 @@ logger = get_logger(__name__)
|
||||
class SandboxConfigManager:
|
||||
"""Manager class to handle business logic related to SandboxConfig and SandboxEnvironmentVariable."""
|
||||
|
||||
def __init__(self, settings):
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
@ -62,7 +62,7 @@ class ToolExecutionSandbox:
|
||||
f"Agent attempted to invoke tool {self.tool_name} that does not exist for organization {self.user.organization_id}"
|
||||
)
|
||||
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
self.sandbox_config_manager = SandboxConfigManager()
|
||||
self.force_recreate = force_recreate
|
||||
self.force_recreate_venv = force_recreate_venv
|
||||
|
||||
|
@ -9,7 +9,8 @@ from .constants import (
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_THOUGHT,
|
||||
MESSAGE_SUMMARY_WARNING_STR,
|
||||
)
|
||||
from .utils import get_local_time, json_dumps
|
||||
from .helpers.datetime_helpers import get_local_time
|
||||
from .helpers.json_helpers import json_dumps
|
||||
|
||||
|
||||
def get_initial_boot_messages(version="startup"):
|
||||
|
105
letta/utils.py
105
letta/utils.py
@ -4,7 +4,6 @@ import difflib
|
||||
import hashlib
|
||||
import inspect
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
@ -14,14 +13,13 @@ import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from logging import Logger
|
||||
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import demjson3 as demjson
|
||||
import pytz
|
||||
import tiktoken
|
||||
from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename
|
||||
|
||||
@ -35,6 +33,7 @@ from letta.constants import (
|
||||
MAX_FILENAME_LENGTH,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
)
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
|
||||
DEBUG = False
|
||||
@ -487,10 +486,6 @@ def smart_urljoin(base_url: str, relative_url: str) -> str:
|
||||
return urljoin(base_url, relative_url)
|
||||
|
||||
|
||||
def is_utc_datetime(dt: datetime) -> bool:
|
||||
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) == timedelta(0)
|
||||
|
||||
|
||||
def get_tool_call_id() -> str:
|
||||
# TODO(sarah) make this a slug-style string?
|
||||
# e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C"
|
||||
@ -824,72 +819,6 @@ def united_diff(str1, str2):
|
||||
return "".join(diff)
|
||||
|
||||
|
||||
def parse_formatted_time(formatted_time):
|
||||
# parse times returned by letta.utils.get_formatted_time()
|
||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
def datetime_to_timestamp(dt):
|
||||
# convert datetime object to integer timestamp
|
||||
return int(dt.timestamp())
|
||||
|
||||
|
||||
def timestamp_to_datetime(ts):
|
||||
# convert integer timestamp to datetime object
|
||||
return datetime.fromtimestamp(ts)
|
||||
|
||||
|
||||
def get_local_time_military():
|
||||
# Get the current time in UTC
|
||||
current_time_utc = datetime.now(pytz.utc)
|
||||
|
||||
# Convert to San Francisco's time zone (PST/PDT)
|
||||
sf_time_zone = pytz.timezone("America/Los_Angeles")
|
||||
local_time = current_time_utc.astimezone(sf_time_zone)
|
||||
|
||||
# You may format it as you desire
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %H:%M:%S %Z%z")
|
||||
|
||||
return formatted_time
|
||||
|
||||
|
||||
def get_local_time_timezone(timezone="America/Los_Angeles"):
|
||||
# Get the current time in UTC
|
||||
current_time_utc = datetime.now(pytz.utc)
|
||||
|
||||
# Convert to San Francisco's time zone (PST/PDT)
|
||||
sf_time_zone = pytz.timezone(timezone)
|
||||
local_time = current_time_utc.astimezone(sf_time_zone)
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
return formatted_time
|
||||
|
||||
|
||||
def get_local_time(timezone=None):
|
||||
if timezone is not None:
|
||||
time_str = get_local_time_timezone(timezone)
|
||||
else:
|
||||
# Get the current time, which will be in the local timezone of the computer
|
||||
local_time = datetime.now().astimezone()
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
return time_str.strip()
|
||||
|
||||
|
||||
def get_utc_time() -> datetime:
|
||||
"""Get the current UTC time"""
|
||||
# return datetime.now(pytz.utc)
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def format_datetime(dt):
|
||||
return dt.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
|
||||
def parse_json(string) -> dict:
|
||||
"""Parse JSON string into JSON with both json and demjson"""
|
||||
result = None
|
||||
@ -1046,23 +975,6 @@ def get_schema_diff(schema_a, schema_b):
|
||||
return "".join(difference)
|
||||
|
||||
|
||||
# datetime related
|
||||
def validate_date_format(date_str):
|
||||
"""Validate the given date string in the format 'YYYY-MM-DD'."""
|
||||
try:
|
||||
datetime.strptime(date_str, "%Y-%m-%d")
|
||||
return True
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def extract_date_from_timestamp(timestamp):
|
||||
"""Extracts and returns the date from the given timestamp."""
|
||||
# Extracts the date (ignoring the time and timezone)
|
||||
match = re.match(r"(\d{4}-\d{2}-\d{2})", timestamp)
|
||||
return match.group(1) if match else None
|
||||
|
||||
|
||||
def create_uuid_from_string(val: str):
|
||||
"""
|
||||
Generate consistent UUID from a string
|
||||
@ -1072,19 +984,6 @@ def create_uuid_from_string(val: str):
|
||||
return uuid.UUID(hex=hex_string)
|
||||
|
||||
|
||||
def json_dumps(data, indent=2):
|
||||
def safe_serializer(obj):
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
raise TypeError(f"Type {type(obj)} not serializable")
|
||||
|
||||
return json.dumps(data, indent=indent, default=safe_serializer, ensure_ascii=False)
|
||||
|
||||
|
||||
def json_loads(data):
|
||||
return json.loads(data, strict=False)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
"""
|
||||
Sanitize the given filename to prevent directory traversal, invalid characters,
|
||||
|
@ -31,6 +31,7 @@ import openai
|
||||
from icml_experiments.utils import get_experiment_config, load_gzipped_file
|
||||
from tqdm import tqdm
|
||||
|
||||
import letta.helpers.json_helpers
|
||||
from letta import utils
|
||||
from letta.cli.cli_config import delete
|
||||
from letta.config import LettaConfig
|
||||
@ -70,7 +71,7 @@ def archival_memory_text_search(self, query: str, page: Optional[int] = 0) -> Op
|
||||
else:
|
||||
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
|
||||
results_formatted = [f"memory: {d.text}" for d in results]
|
||||
results_str = f"{results_pref} {utils.json_dumps(results_formatted)}"
|
||||
results_str = f"{results_pref} {letta.helpers.json_helpers.json_dumps(results_formatted)}"
|
||||
return results_str
|
||||
|
||||
|
||||
|
@ -2,6 +2,10 @@ import logging
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
@ -31,3 +35,26 @@ def check_e2b_key_is_set():
|
||||
original_api_key = tool_settings.e2b_api_key
|
||||
assert original_api_key is not None, "Missing e2b key! Cannot execute these tests."
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization():
|
||||
"""Fixture to create and return the default organization."""
|
||||
manager = OrganizationManager()
|
||||
org = manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
manager = UserManager()
|
||||
user = manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_composio_key_set():
|
||||
original_api_key = tool_settings.composio_api_key
|
||||
assert original_api_key is not None, "Missing composio key! Cannot execute this test."
|
||||
yield
|
||||
|
@ -15,6 +15,7 @@ from letta.config import LettaConfig
|
||||
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.schemas.agent import AgentState
|
||||
@ -24,7 +25,7 @@ from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message
|
||||
from letta.utils import get_human_text, get_persona_text, json_dumps
|
||||
from letta.utils import get_human_text, get_persona_text
|
||||
from tests.helpers.utils import cleanup
|
||||
|
||||
# Generate uuid for agent name for this example
|
||||
|
@ -1,9 +1,15 @@
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.settings import tool_settings
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@ -13,6 +19,24 @@ def fastapi_client():
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
config.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(server, default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=default_user)
|
||||
yield tool
|
||||
|
||||
|
||||
def test_list_composio_apps(fastapi_client):
|
||||
response = fastapi_client.get("/v1/tools/composio/apps")
|
||||
assert response.status_code == 200
|
||||
@ -32,28 +56,26 @@ def test_add_composio_tool(fastapi_client):
|
||||
assert "name" in response.json()
|
||||
|
||||
|
||||
def test_composio_version_on_e2b_matches_server(check_e2b_key_is_set):
|
||||
import composio
|
||||
from e2b_code_interpreter import Sandbox
|
||||
from packaging.version import Version
|
||||
|
||||
sbx = Sandbox(tool_settings.e2b_sandbox_template_id)
|
||||
result = sbx.run_code(
|
||||
"""
|
||||
import composio
|
||||
print(str(composio.__version__))
|
||||
"""
|
||||
def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_profile_tool, server: SyncServer, default_user):
|
||||
agent_state = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
e2b_composio_version = result.logs.stdout[0].strip()
|
||||
composio_version = str(composio.__version__)
|
||||
agent = server.load_agent(agent_state.id, actor=default_user)
|
||||
response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool)
|
||||
assert response[0]["response_data"]["emailAddress"] == "sarah@letta.com"
|
||||
|
||||
# Compare versions
|
||||
if Version(composio_version) > Version(e2b_composio_version):
|
||||
raise AssertionError(f"Local composio version {composio_version} is greater than server version {e2b_composio_version}")
|
||||
elif Version(composio_version) < Version(e2b_composio_version):
|
||||
logger.warning(
|
||||
f"Local version of composio {composio_version} is less than the E2B version: {e2b_composio_version}. Please upgrade your local composio version."
|
||||
)
|
||||
|
||||
# Print concise summary
|
||||
logger.info(f"Server version: {composio_version}, E2B version: {e2b_composio_version}")
|
||||
# Add agent variable changing the entity ID
|
||||
agent_state = server.agent_manager.update_agent(
|
||||
agent_id=agent_state.id,
|
||||
agent_update=UpdateAgent(tool_exec_environment_variables={COMPOSIO_ENTITY_ENV_VAR_KEY: "matt"}),
|
||||
actor=default_user,
|
||||
)
|
||||
agent = server.load_agent(agent_state.id, actor=default_user)
|
||||
response = agent.execute_tool_and_persist_state(composio_gmail_get_profile_tool.name, {}, composio_gmail_get_profile_tool)
|
||||
assert response[0]["response_data"]["emailAddress"] == "matt@letta.com"
|
||||
|
@ -25,7 +25,7 @@ from letta.schemas.sandbox_config import (
|
||||
SandboxConfigUpdate,
|
||||
SandboxType,
|
||||
)
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
@ -53,13 +53,6 @@ def clear_tables():
|
||||
session.commit() # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def check_composio_key_set():
|
||||
original_api_key = tool_settings.composio_api_key
|
||||
assert original_api_key is not None, "Missing composio key! Cannot execute this test."
|
||||
yield
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_organization():
|
||||
"""Fixture to create and return the default organization."""
|
||||
@ -74,6 +67,14 @@ def test_user(test_organization):
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(test_user):
|
||||
tool_manager = ToolManager()
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=test_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def add_integers_tool(test_user):
|
||||
def add(x: int, y: int) -> int:
|
||||
@ -194,22 +195,6 @@ def composio_github_star_tool(test_user):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(test_user):
|
||||
tool_manager = ToolManager()
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=test_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_gmail_get_profile_tool(test_user):
|
||||
tool_manager = ToolManager()
|
||||
tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE")
|
||||
tool = tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def clear_core_memory_tool(test_user):
|
||||
def clear_memory(agent_state: "AgentState"):
|
||||
@ -237,7 +222,7 @@ def agent_state():
|
||||
agent_state = client.create_agent(
|
||||
memory=ChatMemory(persona="This is the persona", human="My name is Chad"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4"),
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
)
|
||||
agent_state.tool_rules = []
|
||||
yield agent_state
|
||||
@ -255,7 +240,7 @@ def custom_test_sandbox_config(test_user):
|
||||
A tuple containing the SandboxConfigManager and the created sandbox configuration.
|
||||
"""
|
||||
# Create the SandboxConfigManager
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
|
||||
# Set the sandbox to be within the external codebase path and use a venv
|
||||
external_codebase_path = str(Path(__file__).parent / "test_tool_sandbox" / "restaurant_management_system")
|
||||
@ -327,7 +312,7 @@ def test_local_sandbox_with_list_rv(mock_e2b_api_key_none, list_tool, test_user)
|
||||
|
||||
@pytest.mark.local_sandbox
|
||||
def test_local_sandbox_env(mock_e2b_api_key_none, get_env_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
|
||||
# Make a custom local sandbox config
|
||||
sandbox_dir = str(Path(__file__).parent / "test_tool_sandbox")
|
||||
@ -353,7 +338,7 @@ def test_local_sandbox_env(mock_e2b_api_key_none, get_env_tool, test_user):
|
||||
|
||||
@pytest.mark.local_sandbox
|
||||
def test_local_sandbox_per_agent_env(mock_e2b_api_key_none, get_env_tool, agent_state, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
key = "secret_word"
|
||||
|
||||
# Make a custom local sandbox config
|
||||
@ -389,7 +374,7 @@ def test_local_sandbox_per_agent_env(mock_e2b_api_key_none, get_env_tool, agent_
|
||||
@pytest.mark.local_sandbox
|
||||
def test_local_sandbox_e2e_composio_star_github(mock_e2b_api_key_none, check_composio_key_set, composio_github_star_tool, test_user):
|
||||
# Add the composio key
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=test_user)
|
||||
|
||||
manager.create_sandbox_env_var(
|
||||
@ -482,7 +467,7 @@ def test_local_sandbox_with_venv_errors(mock_e2b_api_key_none, custom_test_sandb
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_local_sandbox_with_venv_pip_installs_basic(mock_e2b_api_key_none, cowsay_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(
|
||||
config=LocalSandboxConfig(use_venv=True, pip_requirements=[PipRequirement(name="cowsay")]).model_dump()
|
||||
)
|
||||
@ -502,7 +487,7 @@ def test_local_sandbox_with_venv_pip_installs_basic(mock_e2b_api_key_none, cowsa
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_local_sandbox_with_venv_pip_installs_with_update(mock_e2b_api_key_none, cowsay_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(config=LocalSandboxConfig(use_venv=True).model_dump())
|
||||
config = manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
@ -554,7 +539,7 @@ def test_e2b_sandbox_default(check_e2b_key_is_set, add_integers_tool, test_user)
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_sandbox_pip_installs(check_e2b_key_is_set, cowsay_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(config=E2BSandboxConfig(pip_requirements=["cowsay"]).model_dump())
|
||||
config = manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
@ -598,7 +583,7 @@ def test_e2b_sandbox_stateful_tool(check_e2b_key_is_set, clear_core_memory_tool,
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_env_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
config_create = SandboxConfigCreate(config=E2BSandboxConfig().model_dump())
|
||||
config = manager.create_or_update_sandbox_config(config_create, test_user)
|
||||
|
||||
@ -624,7 +609,7 @@ def test_e2b_sandbox_inject_env_var_existing_sandbox(check_e2b_key_is_set, get_e
|
||||
# TODO: There is a near dupe of this test above for local sandbox - we should try to make it parameterized tests to minimize code bloat
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_state, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
key = "secret_word"
|
||||
|
||||
# Make a custom local sandbox config
|
||||
@ -659,7 +644,7 @@ def test_e2b_sandbox_per_agent_env(check_e2b_key_is_set, get_env_tool, agent_sta
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_sandbox_config_change_force_recreates_sandbox(check_e2b_key_is_set, list_tool, test_user):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
manager = SandboxConfigManager()
|
||||
old_timeout = 5 * 60
|
||||
new_timeout = 10 * 60
|
||||
|
||||
@ -693,58 +678,6 @@ def test_e2b_sandbox_with_list_rv(check_e2b_key_is_set, list_tool, test_user):
|
||||
assert len(result.func_return) == 5
|
||||
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_e2e_composio_star_github(check_e2b_key_is_set, check_composio_key_set, composio_github_star_tool, test_user):
|
||||
# Add the composio key
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=test_user)
|
||||
|
||||
manager.create_sandbox_env_var(
|
||||
SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key),
|
||||
sandbox_config_id=config.id,
|
||||
actor=test_user,
|
||||
)
|
||||
|
||||
result = ToolExecutionSandbox(composio_github_star_tool.name, {"owner": "letta-ai", "repo": "letta"}, user=test_user).run()
|
||||
assert result.func_return["details"] == "Action executed successfully"
|
||||
|
||||
# Missing args causes error
|
||||
result = ToolExecutionSandbox(composio_github_star_tool.name, {}, user=test_user).run()
|
||||
assert "Invalid request data provided" in result.func_return
|
||||
|
||||
|
||||
@pytest.mark.e2b_sandbox
|
||||
def test_e2b_multiple_composio_entities(
|
||||
check_e2b_key_is_set, check_composio_key_set, composio_gmail_get_profile_tool, agent_state, test_user
|
||||
):
|
||||
manager = SandboxConfigManager(tool_settings)
|
||||
config = manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=test_user)
|
||||
|
||||
manager.create_sandbox_env_var(
|
||||
SandboxEnvironmentVariableCreate(key="COMPOSIO_API_KEY", value=tool_settings.composio_api_key),
|
||||
sandbox_config_id=config.id,
|
||||
actor=test_user,
|
||||
)
|
||||
|
||||
# Agent state with no composio entity ID
|
||||
result = ToolExecutionSandbox(composio_gmail_get_profile_tool.name, {}, user=test_user).run(agent_state=agent_state)
|
||||
assert result.func_return["response_data"]["emailAddress"] == "sarah@letta.com"
|
||||
|
||||
# Agent state with the composio entity set to 'matt'
|
||||
agent_state.tool_exec_environment_variables = [
|
||||
AgentEnvironmentVariable(key=COMPOSIO_ENTITY_ENV_VAR_KEY, value="matt", agent_id=agent_state.id)
|
||||
]
|
||||
result = ToolExecutionSandbox(composio_gmail_get_profile_tool.name, {}, user=test_user).run(agent_state=agent_state)
|
||||
assert result.func_return["response_data"]["emailAddress"] == "matt@letta.com"
|
||||
|
||||
# Agent state with composio entity ID set to default
|
||||
agent_state.tool_exec_environment_variables = [
|
||||
AgentEnvironmentVariable(key=COMPOSIO_ENTITY_ENV_VAR_KEY, value="default", agent_id=agent_state.id)
|
||||
]
|
||||
result = ToolExecutionSandbox(composio_gmail_get_profile_tool.name, {}, user=test_user).run(agent_state=agent_state)
|
||||
assert result.func_return["response_data"]["emailAddress"] == "sarah@letta.com"
|
||||
|
||||
|
||||
# Core memory integration tests
|
||||
class TestCoreMemoryTools:
|
||||
"""
|
||||
|
@ -12,6 +12,7 @@ from sqlalchemy import delete
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET, MULTI_AGENT_TOOLS
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.orm import FileMetadata, Source
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@ -33,7 +34,6 @@ from letta.services.helpers.agent_manager_helper import initialize_message_seque
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings
|
||||
from letta.utils import get_utc_time
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
|
Loading…
Reference in New Issue
Block a user