import base64
from typing import List, Union

import numpy as np
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from sqlalchemy import JSON
from sqlalchemy.types import BINARY, TypeDecorator

from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule


class EmbeddingConfigColumn(TypeDecorator):
    """Custom type for storing EmbeddingConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value and isinstance(value, EmbeddingConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        if value:
            return EmbeddingConfig(**value)
        return value


class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value and isinstance(value, LLMConfig):
            return value.model_dump()
        return value

    def process_result_value(self, value, dialect):
        if value:
            return LLMConfig(**value)
        return value


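# Usage sketch (hypothetical, for illustration only): these column types are meant to be
# attached to ORM models so that Pydantic configs round-trip through JSON, e.g.:
#
#     from sqlalchemy import Column, String
#     from sqlalchemy.orm import declarative_base
#
#     Base = declarative_base()
#
#     class AgentModel(Base):  # hypothetical model, not defined in this module
#         __tablename__ = "agents"
#         id = Column(String, primary_key=True)
#         embedding_config = Column(EmbeddingConfigColumn)
#         llm_config = Column(LLMConfigColumn)
#
# On write, process_bind_param dumps the config to a JSON dict; on read,
# process_result_value rebuilds the EmbeddingConfig / LLMConfig object.

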
class ToolRulesColumn(TypeDecorator):
    """Custom type for storing a list of ToolRules as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        """Convert a list of ToolRules to a JSON-serializable format."""
        if value:
            data = [rule.model_dump() for rule in value]
            for d in data:
                # Serialize the ToolRuleType enum to its string value for JSON storage
                d["type"] = d["type"].value
                # A rule serialized with type "ToolRule" must carry its children field
                assert not (d["type"] == "ToolRule" and "children" not in d), "ToolRule does not have children field"
            return data
        return value

    def process_result_value(self, value, dialect) -> List[Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]]:
        """Convert JSON back to a list of ToolRules."""
        if value:
            return [self.deserialize_tool_rule(rule_data) for rule_data in value]
        return value

    @staticmethod
    def deserialize_tool_rule(data: dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
        """Deserialize a dictionary to the appropriate ToolRule subclass based on its 'type' field."""
        rule_type = ToolRuleType(data.get("type"))  # determine the rule subclass from the serialized 'type' value
        if rule_type == ToolRuleType.run_first or rule_type == "InitToolRule":
            return InitToolRule(**data)
        elif rule_type == ToolRuleType.exit_loop or rule_type == "TerminalToolRule":
            return TerminalToolRule(**data)
        elif rule_type == ToolRuleType.constrain_child_tools or rule_type == "ToolRule":
            return ChildToolRule(**data)
        elif rule_type == ToolRuleType.conditional:
            return ConditionalToolRule(**data)
        else:
            raise ValueError(f"Unknown tool rule type: {rule_type}")


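# Round-trip sketch (illustrative; field names such as tool_name are assumptions about the
# letta tool-rule schemas, which are Pydantic models carrying a ToolRuleType 'type' field):
#
#     rules = [InitToolRule(tool_name="start"), TerminalToolRule(tool_name="finish")]
#     col = ToolRulesColumn()
#     stored = col.process_bind_param(rules, dialect=None)    # list of JSON dicts with string 'type'
#     loaded = col.process_result_value(stored, dialect=None) # back to ToolRule objects

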
class ToolCallColumn(TypeDecorator):
    """Custom type for storing a list of OpenAI tool calls as JSON."""

    impl = JSON
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(JSON())

    def process_bind_param(self, value, dialect):
        if value:
            values = []
            for v in value:
                if isinstance(v, OpenAIToolCall):
                    values.append(v.model_dump())
                else:
                    values.append(v)
            return values
        return value

    def process_result_value(self, value, dialect):
        if value:
            tools = []
            for tool_value in value:
                if "function" in tool_value:
                    tool_call_function = OpenAIFunction(**tool_value["function"])
                    del tool_value["function"]
                else:
                    tool_call_function = None
                tools.append(OpenAIToolCall(function=tool_call_function, **tool_value))
            return tools
        return value


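# Round-trip sketch (illustrative): a ToolCallColumn value is stored as a list of plain
# dicts and rebuilt into OpenAIToolCall objects on load, e.g.:
#
#     call = OpenAIToolCall(
#         id="call_123",  # example values
#         type="function",
#         function=OpenAIFunction(name="send_message", arguments='{"message": "hi"}'),
#     )
#     col = ToolCallColumn()
#     stored = col.process_bind_param([call], dialect=None)    # [{"id": ..., "type": ..., "function": {...}}]
#     loaded = col.process_result_value(stored, dialect=None)  # [OpenAIToolCall(...)]

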
class CommonVector(TypeDecorator):
    """Common type for representing vectors in SQLite."""

    impl = BINARY
    cache_ok = True

    def load_dialect_impl(self, dialect):
        return dialect.type_descriptor(BINARY())

    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        if isinstance(value, list):
            value = np.array(value, dtype=np.float32)
        # Store the vector as base64-encoded float32 bytes
        return base64.b64encode(value.tobytes())

    def process_result_value(self, value, dialect):
        if not value:
            return value
        if dialect.name == "sqlite":
            # SQLite returns the base64 payload; decode it before reading the buffer
            value = base64.b64decode(value)
        return np.frombuffer(value, dtype=np.float32)


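# Minimal demo of the encoding CommonVector performs for SQLite storage. This is an
# illustrative sketch, not part of the original module's behavior.
if __name__ == "__main__":
    vector = [0.1, 0.2, 0.3]
    # Bind side: list -> float32 array -> raw bytes -> base64 (as in process_bind_param)
    encoded = base64.b64encode(np.array(vector, dtype=np.float32).tobytes())
    # Result side (SQLite): base64 -> raw bytes -> float32 array (as in process_result_value)
    decoded = np.frombuffer(base64.b64decode(encoded), dtype=np.float32)
    print(decoded)  # [0.1 0.2 0.3]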