MemGPT/letta/functions/ast_parsers.py
cthomas 1b58fae4fb
chore: bump version 0.7.22 (#2655)
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
2025-05-23 01:13:05 -07:00

158 lines
5.1 KiB
Python

import ast
import builtins
import json
import typing
from typing import Dict, Optional, Tuple
from letta.errors import LettaToolCreateError
# Builtin types that an annotation string may name directly. Each entry is
# keyed by the type's own name, so "int" maps to int, "list" to list, etc.
BUILTIN_TYPES = {
    _builtin.__name__: _builtin
    for _builtin in (int, float, str, dict, list, set, tuple, bool)
}
def resolve_type(annotation: str):
    """
    Resolve a type annotation string into a Python type.

    Names listed in BUILTIN_TYPES are returned directly. Any other string is
    evaluated as a Python expression against the names exported by the
    ``typing`` and ``builtins`` modules.

    Args:
        annotation (str): The annotation string (e.g., 'int', 'list[int]', 'dict[str, int]').

    Returns:
        The object the annotation evaluates to: a plain type for simple names,
        or a generic alias / typing construct for expressions such as 'list[int]'.

    Raises:
        ValueError: If the annotation cannot be evaluated. The underlying error
            is attached as ``__cause__``.
    """
    if annotation in BUILTIN_TYPES:
        return BUILTIN_TYPES[annotation]

    # SECURITY: eval() runs arbitrary expressions and this namespace exposes
    # every builtin (including __import__, open, ...). It is NOT a sandbox;
    # only pass annotation strings that come from trusted source code.
    namespace = {
        **vars(typing),
        **vars(builtins),
        # Explicit entries guarantee the builtin generics win over any
        # same-named entry pulled in from the modules above.
        "list": list,
        "dict": dict,
        "tuple": tuple,
        "set": set,
    }
    try:
        return eval(annotation, namespace)
    except Exception as e:
        # Chain the cause so the real failure (NameError, SyntaxError, ...) stays visible.
        raise ValueError(f"Unsupported annotation: {annotation}") from e
def get_function_annotations_from_source(source_code: str, function_name: str) -> Dict[str, str]:
    """
    Parse the source code to extract annotations for a given function name.

    Only top-level definitions of the module are searched. Both ``def`` and
    ``async def`` functions match. Annotations are collected for
    positional-only, regular and keyword-only parameters; ``*args`` and
    ``**kwargs`` are not included.

    Args:
        source_code (str): The Python source code containing the function.
        function_name (str): The name of the function to extract annotations for.

    Returns:
        Dict[str, str]: A dictionary of argument names to their annotation strings.
            Parameters without an annotation are omitted.

    Raises:
        ValueError: If the function is not found in the source code.
        SyntaxError: If ``source_code`` is not valid Python (raised by ``ast.parse``).
    """
    tree = ast.parse(source_code)
    for node in ast.iter_child_nodes(tree):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == function_name:
            # Every parameter that can be addressed by name, in signature order.
            named_params = [*node.args.posonlyargs, *node.args.args, *node.args.kwonlyargs]
            return {
                param.arg: ast.unparse(param.annotation)
                for param in named_params
                if param.annotation is not None
            }
    raise ValueError(f"Function '{function_name}' not found in the provided source code.")
def coerce_dict_args_by_annotations(function_args: dict, annotations: Dict[str, str]) -> dict:
    """
    Coerce argument values to the types named by their annotation strings.

    For every argument that has an annotation, the annotation is resolved with
    ``resolve_type``. String values are parsed (JSON first, then Python literal
    syntax) before being converted, unless the target type is ``str``, in which
    case the string is kept exactly as given. Arguments without an annotation
    are passed through unchanged.

    Args:
        function_args (dict): Argument names mapped to raw values. Not mutated.
        annotations (Dict[str, str]): Argument names mapped to annotation strings.

    Returns:
        dict: A shallow copy of ``function_args`` with annotated values coerced.

    Raises:
        ValueError: If an annotation cannot be resolved or a value cannot be
            parsed/converted to the annotated type.
    """
    coerced_args = dict(function_args)  # Shallow copy; the caller's dict is left untouched

    for arg_name, value in function_args.items():
        if arg_name not in annotations:
            continue

        annotation_str = annotations[arg_name]
        try:
            arg_type = resolve_type(annotation_str)

            if isinstance(value, str):
                if arg_type is str:
                    # Already the requested type. Parsing and re-stringifying
                    # would corrupt it (e.g. "true" -> "True", JSON -> Python repr).
                    continue
                try:
                    value = json.loads(value)
                except json.JSONDecodeError:
                    # Fall back to Python literal syntax (e.g. "True", "(1, 2)").
                    # A failure here propagates to the handler below.
                    value = ast.literal_eval(value)

            origin = typing.get_origin(arg_type)
            if origin in (list, dict, tuple, set):
                # Let the origin (e.g., list) handle coercion
                coerced_args[arg_name] = origin(value)
            else:
                # Coerce simple types (e.g., int, float)
                coerced_args[arg_name] = arg_type(value)
        except Exception as e:
            raise ValueError(f"Failed to coerce argument '{arg_name}' to {annotation_str}: {e}") from e

    return coerced_args
def get_function_name_and_description(source_code: str, name: Optional[str] = None) -> Tuple[str, str]:
    """Gets the name and description for a given function source code by parsing the AST.

    The last top-level function definition (``def`` or ``async def``) is used.
    If the module has no top-level function, the last function definition found
    anywhere in the tree is used instead.

    Args:
        source_code: The source code to parse
        name: Optional override for the function name

    Returns:
        Tuple of (function_name, docstring)

    Raises:
        LettaToolCreateError: If the source cannot be parsed, contains no
            function definition, yields an empty function name, or the
            selected function has no docstring.
    """
    function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    try:
        # Parse the source code into an AST
        tree = ast.parse(source_code)

        # Prefer top-level definitions: ast.walk is breadth-first, so taking its
        # last hit would select a nested helper over the function containing it.
        candidates = [node for node in tree.body if isinstance(node, function_node_types)]
        if not candidates:
            candidates = [node for node in ast.walk(tree) if isinstance(node, function_node_types)]
        if not candidates:
            raise LettaToolCreateError("No function definition found in source code")
        function_def = candidates[-1]

        # Get the function name
        function_name = name if name is not None else function_def.name

        # Get the docstring if it exists
        docstring = ast.get_docstring(function_def)

        if not function_name:
            raise LettaToolCreateError("Could not determine function name")
        if not docstring:
            raise LettaToolCreateError("Docstring is missing")

        return function_name, docstring
    except Exception as e:
        # Single boundary: every failure above (including the explicit raises)
        # is reported as a LettaToolCreateError with the cause chained.
        raise LettaToolCreateError(f"Failed to parse function name and docstring: {str(e)}") from e