MemGPT/letta/orm/custom_columns.py

import base64
from typing import List, Union

import numpy as np
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        if value:
            return EmbeddingConfig(**value)
        return value
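
# Illustrative usage sketch (not part of this module; the model class and table
# name are hypothetical, and Mapped/mapped_column would come from sqlalchemy.orm):
# a TypeDecorator like EmbeddingConfigColumn is attached to an ORM model as the
# column type, and SQLAlchemy then calls process_bind_param on writes and
# process_result_value on reads.
#
#     class AgentModel(Base):
#         __tablename__ = "agents"
#         id: Mapped[str] = mapped_column(primary_key=True)
#         embedding_config: Mapped[EmbeddingConfig] = mapped_column(EmbeddingConfigColumn, nullable=True)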


class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value and isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        if value:
            return LLMConfig(**value)
        return value


class ToolRulesColumn(TypeDecorator):
    """Custom type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to a JSON-serializable format."""
        if value:
            data = [rule.model_dump() for rule in value]
            for d in data:
                d["type"] = d["type"].value

            for d in data:
                assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
            return data
        return value

    def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
        """Convert JSON back to a list of ToolRules."""
        if value:
            return [self.deserialize_tool_rule(rule_data) for rule_data in value]
        return value

    @staticmethod
    def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
        """Deserialize a dictionary to the appropriate ToolRule subclass based on its 'type' field."""
        rule_type = ToolRuleType(data.get("type"))
        if rule_type == ToolRuleType.run_first or rule_type == "InitToolRule":
            data["type"] = ToolRuleType.run_first
            return InitToolRule(**data)
        elif rule_type == ToolRuleType.exit_loop or rule_type == "TerminalToolRule":
            data["type"] = ToolRuleType.exit_loop
            return TerminalToolRule(**data)
        elif rule_type == ToolRuleType.constrain_child_tools or rule_type == "ToolRule":
            data["type"] = ToolRuleType.constrain_child_tools
            return ChildToolRule(**data)
        elif rule_type == ToolRuleType.conditional:
            return ConditionalToolRule(**data)
        else:
            raise ValueError(f"Unknown tool rule type: {rule_type}")


class ToolCallColumn(TypeDecorator):
    """Custom type for storing a list of OpenAI tool calls as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value:
            values = []
            for v in value:
                if isinstance(v, OpenAIToolCall):
                    values.append(v.model_dump())
                else:
                    values.append(v)
            return values
        return value

    def process_result_value(self, value, dialect):
        if value:
            tools = []
            for tool_value in value:
                if "function" in tool_value:
                    tool_call_function = OpenAIFunction(**tool_value["function"])
                    del tool_value["function"]
                else:
                    tool_call_function = None
                tools.append(OpenAIToolCall(function=tool_call_function, **tool_value))
            return tools
        return value
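
# Illustrative sketch (assumed values): an OpenAI tool call is stored as plain
# JSON and rebuilt into OpenAIToolCall / OpenAIFunction objects on read.
#
#     call = OpenAIToolCall(
#         id="call_123",
#         type="function",
#         function=OpenAIFunction(name="send_message", arguments='{"message": "hi"}'),
#     )
#     col = ToolCallColumn()
#     stored = col.process_bind_param([call], dialect=None)       # list of dicts
#     restored = col.process_result_value(stored, dialect=None)   # list of OpenAIToolCall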


class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite."""

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        if isinstance(value, list):
            value = np.array(value, dtype=np.float32)
        return base64.b64encode(value.tobytes())

    def process_result_value(self, value, dialect):
        if not value:
            return value
        if dialect.name == "sqlite":
            value = base64.b64decode(value)
        return np.frombuffer(value, dtype=np.float32)
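
# Illustrative round trip for CommonVector (a sketch; the dialect stand-in below
# is hypothetical, and only the "sqlite" branch triggers base64 decoding):
#
#     class _SQLiteDialect:
#         name = "sqlite"
#
#     col = CommonVector()
#     encoded = col.process_bind_param([0.1, 0.2, 0.3], dialect=None)
#     # encoded is the base64-encoded float32 byte string
#     decoded = col.process_result_value(encoded, _SQLiteDialect())
#     # decoded is a numpy float32 array approximately equal to [0.1, 0.2, 0.3]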