mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
refactor: store presets in database via metadata store (#1013)
This commit is contained in:
parent
d8f3d04268
commit
c7fbc03e68
@ -344,8 +344,10 @@ def create_autogen_memgpt_agent(
|
|||||||
preset=agent_config["preset"],
|
preset=agent_config["preset"],
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
|
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
|
||||||
memgpt_agent = presets.create_agent_from_preset(
|
memgpt_agent = presets.create_agent_from_preset(
|
||||||
agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
|
preset=preset,
|
||||||
interface=interface,
|
interface=interface,
|
||||||
persona_is_file=False,
|
persona_is_file=False,
|
||||||
human_is_file=False,
|
human_is_file=False,
|
||||||
|
@ -606,8 +606,10 @@ def run(
|
|||||||
|
|
||||||
# create agent
|
# create agent
|
||||||
try:
|
try:
|
||||||
|
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id)
|
||||||
memgpt_agent = presets.create_agent_from_preset(
|
memgpt_agent = presets.create_agent_from_preset(
|
||||||
agent_state=agent_state,
|
agent_state=agent_state,
|
||||||
|
preset=preset,
|
||||||
interface=interface(),
|
interface=interface(),
|
||||||
)
|
)
|
||||||
save_agent(agent=memgpt_agent, ms=ms)
|
save_agent(agent=memgpt_agent, ms=ms)
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
import builtins
|
import builtins
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
@ -664,12 +665,18 @@ def configure():
|
|||||||
else:
|
else:
|
||||||
ms.create_user(user)
|
ms.create_user(user)
|
||||||
|
|
||||||
|
# create preset records in metadata store
|
||||||
|
from memgpt.presets.presets import add_default_presets
|
||||||
|
|
||||||
|
add_default_presets(user_id, ms)
|
||||||
|
|
||||||
|
|
||||||
class ListChoice(str, Enum):
|
class ListChoice(str, Enum):
|
||||||
agents = "agents"
|
agents = "agents"
|
||||||
humans = "humans"
|
humans = "humans"
|
||||||
personas = "personas"
|
personas = "personas"
|
||||||
sources = "sources"
|
sources = "sources"
|
||||||
|
presets = "presets"
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
@ -741,6 +748,22 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(table)
|
print(table)
|
||||||
|
elif arg == ListChoice.presets:
|
||||||
|
"""List all available presets"""
|
||||||
|
table = PrettyTable()
|
||||||
|
table.field_names = ["Name", "Description", "Sources", "Functions"]
|
||||||
|
for preset in ms.list_presets(user_id=user_id):
|
||||||
|
sources = ms.get_preset_sources(preset_id=preset.id)
|
||||||
|
table.add_row(
|
||||||
|
[
|
||||||
|
preset.name,
|
||||||
|
preset.description,
|
||||||
|
",".join([source.name for source in sources]),
|
||||||
|
# json.dumps(preset.functions_schema, indent=4)
|
||||||
|
",\n".join([f["name"] for f in preset.functions_schema]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
print(table)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown argument {arg}")
|
raise ValueError(f"Unknown argument {arg}")
|
||||||
|
|
||||||
|
@ -2,7 +2,7 @@ import os
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, List, Union, Optional, Tuple
|
from typing import Dict, List, Union, Optional, Tuple
|
||||||
|
|
||||||
from memgpt.data_types import AgentState, User
|
from memgpt.data_types import AgentState, User, Preset
|
||||||
from memgpt.cli.cli import QuickstartChoice
|
from memgpt.cli.cli import QuickstartChoice
|
||||||
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
@ -28,7 +28,6 @@ class Client(object):
|
|||||||
:param debug: indicates whether to display debug messages.
|
:param debug: indicates whether to display debug messages.
|
||||||
"""
|
"""
|
||||||
self.auto_save = auto_save
|
self.auto_save = auto_save
|
||||||
|
|
||||||
# make sure everything is set up properly
|
# make sure everything is set up properly
|
||||||
# TODO: remove this eventually? for multi-user, we can't have a shared config directory
|
# TODO: remove this eventually? for multi-user, we can't have a shared config directory
|
||||||
MemGPTConfig.create_config_dir()
|
MemGPTConfig.create_config_dir()
|
||||||
@ -80,6 +79,11 @@ class Client(object):
|
|||||||
else:
|
else:
|
||||||
ms.create_user(self.user)
|
ms.create_user(self.user)
|
||||||
|
|
||||||
|
# create preset records in metadata store
|
||||||
|
from memgpt.presets.presets import add_default_presets
|
||||||
|
|
||||||
|
add_default_presets(self.user_id, ms)
|
||||||
|
|
||||||
self.interface = QueuingInterface(debug=debug)
|
self.interface = QueuingInterface(debug=debug)
|
||||||
self.server = SyncServer(default_interface=self.interface)
|
self.server = SyncServer(default_interface=self.interface)
|
||||||
|
|
||||||
@ -114,6 +118,10 @@ class Client(object):
|
|||||||
agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config)
|
agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config)
|
||||||
return agent_state
|
return agent_state
|
||||||
|
|
||||||
|
def create_preset(self, preset: Preset):
|
||||||
|
preset = self.server.create_preset(preset=preset)
|
||||||
|
return preset
|
||||||
|
|
||||||
def get_agent_config(self, agent_id: str) -> Dict:
|
def get_agent_config(self, agent_id: str) -> Dict:
|
||||||
self.interface.clear()
|
self.interface.clear()
|
||||||
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
|
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
|
||||||
|
@ -9,6 +9,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
|
|||||||
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
||||||
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
||||||
from memgpt.models import chat_completion_response
|
from memgpt.models import chat_completion_response
|
||||||
|
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, Json
|
||||||
|
|
||||||
|
|
||||||
class Record:
|
class Record:
|
||||||
@ -494,3 +497,24 @@ class Source:
|
|||||||
# embedding info (optional)
|
# embedding info (optional)
|
||||||
self.embedding_dim = embedding_dim
|
self.embedding_dim = embedding_dim
|
||||||
self.embedding_model = embedding_model
|
self.embedding_model = embedding_model
|
||||||
|
|
||||||
|
|
||||||
|
class Preset(BaseModel):
|
||||||
|
name: str = Field(..., description="The name of the preset.")
|
||||||
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||||
|
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
|
||||||
|
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
|
||||||
|
system: str = Field(..., description="The system prompt of the preset.")
|
||||||
|
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||||
|
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||||
|
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||||
|
# functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID
|
||||||
|
# sources: List[str] = Field(..., description="The sources of the preset.") # TODO: convert to ID
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
|
||||||
|
name: str = Field(..., description="The name of the function.")
|
||||||
|
id: uuid.UUID = Field(..., description="The unique identifier of the function.")
|
||||||
|
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
|
||||||
|
# TODO: figure out how represent functions
|
||||||
|
@ -3,7 +3,7 @@ import os
|
|||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
|
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
|
||||||
from memgpt.utils import get_local_time, enforce_types
|
from memgpt.utils import get_local_time, enforce_types
|
||||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig
|
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Preset
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
from memgpt.agent import Agent
|
from memgpt.agent import Agent
|
||||||
|
|
||||||
@ -197,6 +197,67 @@ class AgentSourceMappingModel(Base):
|
|||||||
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
||||||
|
|
||||||
|
|
||||||
|
class PresetSourceMapping(Base):
|
||||||
|
__tablename__ = "preset_source_mapping"
|
||||||
|
|
||||||
|
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||||
|
user_id = Column(CommonUUID, nullable=False)
|
||||||
|
preset_id = Column(CommonUUID, nullable=False)
|
||||||
|
source_id = Column(CommonUUID, nullable=False)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
|
||||||
|
|
||||||
|
|
||||||
|
# class PresetFunctionMapping(Base):
|
||||||
|
# __tablename__ = "preset_function_mapping"
|
||||||
|
#
|
||||||
|
# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||||
|
# user_id = Column(CommonUUID, nullable=False)
|
||||||
|
# preset_id = Column(CommonUUID, nullable=False)
|
||||||
|
# #function_id = Column(CommonUUID, nullable=False)
|
||||||
|
# function = Column(String, nullable=False) # TODO: convert to ID eventually
|
||||||
|
#
|
||||||
|
# def __repr__(self) -> str:
|
||||||
|
# return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"
|
||||||
|
|
||||||
|
|
||||||
|
class PresetModel(Base):
|
||||||
|
"""Defines data model for storing Preset objects"""
|
||||||
|
|
||||||
|
__tablename__ = "presets"
|
||||||
|
__table_args__ = {"extend_existing": True}
|
||||||
|
|
||||||
|
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||||
|
user_id = Column(CommonUUID, nullable=False)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
description = Column(String)
|
||||||
|
system = Column(String)
|
||||||
|
human = Column(String)
|
||||||
|
persona = Column(String)
|
||||||
|
preset = Column(String)
|
||||||
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
|
functions_schema = Column(JSON)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Preset(id='{self.id}', name='{self.name}')>"
|
||||||
|
|
||||||
|
def to_record(self) -> Preset:
|
||||||
|
return Preset(
|
||||||
|
id=self.id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
name=self.name,
|
||||||
|
description=self.description,
|
||||||
|
system=self.system,
|
||||||
|
human=self.human,
|
||||||
|
persona=self.persona,
|
||||||
|
preset=self.preset,
|
||||||
|
created_at=self.created_at,
|
||||||
|
functions_schema=self.functions_schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class MetadataStore:
|
class MetadataStore:
|
||||||
def __init__(self, config: MemGPTConfig):
|
def __init__(self, config: MemGPTConfig):
|
||||||
# TODO: get DB URI or path
|
# TODO: get DB URI or path
|
||||||
@ -215,7 +276,15 @@ class MetadataStore:
|
|||||||
# Check if tables need to be created
|
# Check if tables need to be created
|
||||||
self.engine = create_engine(self.uri)
|
self.engine = create_engine(self.uri)
|
||||||
Base.metadata.create_all(
|
Base.metadata.create_all(
|
||||||
self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
|
self.engine,
|
||||||
|
tables=[
|
||||||
|
UserModel.__table__,
|
||||||
|
AgentModel.__table__,
|
||||||
|
SourceModel.__table__,
|
||||||
|
AgentSourceMappingModel.__table__,
|
||||||
|
PresetModel.__table__,
|
||||||
|
PresetSourceMapping.__table__,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
self.session_maker = sessionmaker(bind=self.engine)
|
self.session_maker = sessionmaker(bind=self.engine)
|
||||||
|
|
||||||
@ -246,6 +315,64 @@ class MetadataStore:
|
|||||||
session.add(UserModel(**vars(user)))
|
session.add(UserModel(**vars(user)))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def create_preset(self, preset: Preset):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
|
||||||
|
raise ValueError(f"User with id {preset.id} already exists")
|
||||||
|
session.add(PresetModel(**vars(preset)))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_preset(
|
||||||
|
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
|
||||||
|
) -> Optional[Preset]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
if preset_id:
|
||||||
|
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
|
||||||
|
elif preset_name and user_id:
|
||||||
|
results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all()
|
||||||
|
else:
|
||||||
|
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
|
return results[0].to_record()
|
||||||
|
|
||||||
|
# @enforce_types
|
||||||
|
# def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
|
||||||
|
# preset = self.get_preset(preset_id)
|
||||||
|
# if preset is None:
|
||||||
|
# raise ValueError(f"Preset with id {preset_id} does not exist")
|
||||||
|
# user_id = preset.user_id
|
||||||
|
# with self.session_maker() as session:
|
||||||
|
# for function in functions:
|
||||||
|
# session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
|
||||||
|
# session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
|
||||||
|
preset = self.get_preset(preset_id)
|
||||||
|
if preset is None:
|
||||||
|
raise ValueError(f"Preset with id {preset_id} does not exist")
|
||||||
|
user_id = preset.user_id
|
||||||
|
with self.session_maker() as session:
|
||||||
|
for source_id in sources:
|
||||||
|
session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# @enforce_types
|
||||||
|
# def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
|
||||||
|
# with self.session_maker() as session:
|
||||||
|
# results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
|
||||||
|
# return [r.function for r in results]
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
|
||||||
|
return [r.source_id for r in results]
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def update_agent(self, agent: AgentState):
|
def update_agent(self, agent: AgentState):
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
@ -298,6 +425,12 @@ class MetadataStore:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||||
|
return [r.to_record() for r in results]
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
|
@ -1,17 +1,71 @@
|
|||||||
from memgpt.data_types import AgentState
|
from typing import List
|
||||||
|
from memgpt.data_types import AgentState, Preset
|
||||||
from memgpt.interface import AgentInterface
|
from memgpt.interface import AgentInterface
|
||||||
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
||||||
from memgpt.utils import get_human_text, get_persona_text
|
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||||
from memgpt.prompts import gpt_system
|
from memgpt.prompts import gpt_system
|
||||||
from memgpt.functions.functions import load_all_function_sets
|
from memgpt.functions.functions import load_all_function_sets
|
||||||
|
from memgpt.metadata import MetadataStore
|
||||||
|
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
available_presets = load_all_presets()
|
available_presets = load_all_presets()
|
||||||
preset_options = list(available_presets.keys())
|
preset_options = list(available_presets.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||||
|
"""Add the default presets to the metadata store"""
|
||||||
|
for preset_name in preset_options:
|
||||||
|
preset_config = available_presets[preset_name]
|
||||||
|
preset_system_prompt = preset_config["system_prompt"]
|
||||||
|
preset_function_set_names = preset_config["functions"]
|
||||||
|
functions_schema = generate_functions_json(preset_function_set_names)
|
||||||
|
|
||||||
|
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
||||||
|
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||||
|
continue
|
||||||
|
|
||||||
|
preset = Preset(
|
||||||
|
user_id=user_id,
|
||||||
|
name=preset_name,
|
||||||
|
system=gpt_system.get_system_text(preset_system_prompt),
|
||||||
|
persona=get_persona_text(DEFAULT_PERSONA),
|
||||||
|
human=get_human_text(DEFAULT_HUMAN),
|
||||||
|
functions_schema=functions_schema,
|
||||||
|
)
|
||||||
|
ms.create_preset(preset)
|
||||||
|
|
||||||
|
|
||||||
|
def generate_functions_json(preset_functions: List[str]):
|
||||||
|
"""
|
||||||
|
Generate JSON schema for the functions based on what is locally available.
|
||||||
|
|
||||||
|
TODO: store function definitions in the DB, instead of locally
|
||||||
|
"""
|
||||||
|
# Available functions is a mapping from:
|
||||||
|
# function_name -> {
|
||||||
|
# json_schema: schema
|
||||||
|
# python_function: function
|
||||||
|
# }
|
||||||
|
available_functions = load_all_function_sets()
|
||||||
|
# Filter down the function set based on what the preset requested
|
||||||
|
preset_function_set = {}
|
||||||
|
for f_name in preset_functions:
|
||||||
|
if f_name not in available_functions:
|
||||||
|
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
||||||
|
preset_function_set[f_name] = available_functions[f_name]
|
||||||
|
assert len(preset_functions) == len(preset_function_set)
|
||||||
|
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
|
||||||
|
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
||||||
|
return preset_function_set_schemas
|
||||||
|
|
||||||
|
|
||||||
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
||||||
def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True):
|
def create_agent_from_preset(
|
||||||
|
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
|
||||||
|
):
|
||||||
"""Initialize a new agent from a preset (combination of system + function)"""
|
"""Initialize a new agent from a preset (combination of system + function)"""
|
||||||
|
|
||||||
# Input validation
|
# Input validation
|
||||||
@ -25,40 +79,21 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
|
|||||||
raise ValueError(f"'state' must be uninitialized (empty)")
|
raise ValueError(f"'state' must be uninitialized (empty)")
|
||||||
|
|
||||||
preset_name = agent_state.preset
|
preset_name = agent_state.preset
|
||||||
|
assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'"
|
||||||
persona = agent_state.persona
|
persona = agent_state.persona
|
||||||
human = agent_state.human
|
human = agent_state.human
|
||||||
model = agent_state.llm_config.model
|
model = agent_state.llm_config.model
|
||||||
|
|
||||||
from memgpt.agent import Agent
|
from memgpt.agent import Agent
|
||||||
from memgpt.utils import printd
|
|
||||||
|
|
||||||
# Available functions is a mapping from:
|
# available_presets = load_all_presets()
|
||||||
# function_name -> {
|
# if preset_name not in available_presets:
|
||||||
# json_schema: schema
|
# raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
||||||
# python_function: function
|
|
||||||
# }
|
|
||||||
available_functions = load_all_function_sets()
|
|
||||||
|
|
||||||
available_presets = load_all_presets()
|
# preset = available_presets[preset_name]
|
||||||
if preset_name not in available_presets:
|
# preset_system_prompt = preset["system_prompt"]
|
||||||
raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
# preset_function_set_names = preset["functions"]
|
||||||
|
# preset_function_set_schemas = generate_functions_json(preset_function_set_names)
|
||||||
preset = available_presets[preset_name]
|
|
||||||
if not is_valid_yaml_format(preset, list(available_functions.keys())):
|
|
||||||
raise ValueError(f"Preset '{preset_name}.yaml' is not valid")
|
|
||||||
|
|
||||||
preset_system_prompt = preset["system_prompt"]
|
|
||||||
preset_function_set_names = preset["functions"]
|
|
||||||
|
|
||||||
# Filter down the function set based on what the preset requested
|
|
||||||
preset_function_set = {}
|
|
||||||
for f_name in preset_function_set_names:
|
|
||||||
if f_name not in available_functions:
|
|
||||||
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
|
||||||
preset_function_set[f_name] = available_functions[f_name]
|
|
||||||
assert len(preset_function_set_names) == len(preset_function_set)
|
|
||||||
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
|
|
||||||
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
|
||||||
|
|
||||||
# Override the following in the AgentState:
|
# Override the following in the AgentState:
|
||||||
# persona: str # the current persona text
|
# persona: str # the current persona text
|
||||||
@ -69,8 +104,8 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
|
|||||||
agent_state.state = {
|
agent_state.state = {
|
||||||
"persona": get_persona_text(persona) if persona_is_file else persona,
|
"persona": get_persona_text(persona) if persona_is_file else persona,
|
||||||
"human": get_human_text(human) if human_is_file else human,
|
"human": get_human_text(human) if human_is_file else human,
|
||||||
"system": gpt_system.get_system_text(preset_system_prompt),
|
"system": preset.system,
|
||||||
"functions": preset_function_set_schemas,
|
"functions": preset.functions_schema,
|
||||||
"messages": None,
|
"messages": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -31,6 +31,7 @@ from memgpt.data_types import (
|
|||||||
EmbeddingConfig,
|
EmbeddingConfig,
|
||||||
Message,
|
Message,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
Preset,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -582,7 +583,10 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
|
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
|
||||||
try:
|
try:
|
||||||
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
|
preset = self.ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
|
||||||
|
assert preset is not None, f"preset {agent_state.preset} does not exist"
|
||||||
|
|
||||||
|
agent = presets.create_agent_from_preset(agent_state=agent_state, preset=preset, interface=interface)
|
||||||
save_agent(agent=agent, ms=self.ms)
|
save_agent(agent=agent, ms=self.ms)
|
||||||
|
|
||||||
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB
|
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB
|
||||||
@ -613,6 +617,20 @@ class SyncServer(LockingServer):
|
|||||||
if agent is not None:
|
if agent is not None:
|
||||||
self.ms.delete_agent(agent_id=agent_id)
|
self.ms.delete_agent(agent_id=agent_id)
|
||||||
|
|
||||||
|
def create_preset(self, preset: Preset):
|
||||||
|
"""Create a new preset using a config"""
|
||||||
|
if self.ms.get_user(user_id=preset.user_id) is None:
|
||||||
|
raise ValueError(f"User user_id={preset.user_id} does not exist")
|
||||||
|
|
||||||
|
self.ms.create_preset(preset)
|
||||||
|
return preset
|
||||||
|
|
||||||
|
def get_preset(
|
||||||
|
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None
|
||||||
|
) -> Preset:
|
||||||
|
"""Get the preset"""
|
||||||
|
return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id)
|
||||||
|
|
||||||
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
|
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
|
||||||
"""Convert AgentState to a dict for a JSON response"""
|
"""Convert AgentState to a dict for a JSON response"""
|
||||||
assert agent_state is not None
|
assert agent_state is not None
|
||||||
|
@ -937,8 +937,11 @@ def list_human_files():
|
|||||||
memgpt_defaults = os.listdir(defaults_dir)
|
memgpt_defaults = os.listdir(defaults_dir)
|
||||||
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
||||||
|
|
||||||
|
if os.path.exists(user_dir):
|
||||||
user_added = os.listdir(user_dir)
|
user_added = os.listdir(user_dir)
|
||||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||||
|
else:
|
||||||
|
user_added = []
|
||||||
return memgpt_defaults + user_added
|
return memgpt_defaults + user_added
|
||||||
|
|
||||||
|
|
||||||
@ -950,8 +953,11 @@ def list_persona_files():
|
|||||||
memgpt_defaults = os.listdir(defaults_dir)
|
memgpt_defaults = os.listdir(defaults_dir)
|
||||||
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
||||||
|
|
||||||
|
if os.path.exists(user_dir):
|
||||||
user_added = os.listdir(user_dir)
|
user_added = os.listdir(user_dir)
|
||||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||||
|
else:
|
||||||
|
user_added = []
|
||||||
return memgpt_defaults + user_added
|
return memgpt_defaults + user_added
|
||||||
|
|
||||||
|
|
||||||
|
@ -4,12 +4,18 @@ import os
|
|||||||
from memgpt import MemGPT
|
from memgpt import MemGPT
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
from memgpt import constants
|
from memgpt import constants
|
||||||
from memgpt.data_types import LLMConfig, EmbeddingConfig
|
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
|
||||||
|
from memgpt.functions.functions import load_all_function_sets
|
||||||
|
from memgpt.prompts import gpt_system
|
||||||
|
from memgpt.constants import DEFAULT_PRESET
|
||||||
|
|
||||||
|
|
||||||
from .utils import wipe_config
|
from .utils import wipe_config
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||||
|
test_preset_name = "test_preset"
|
||||||
test_agent_state = None
|
test_agent_state = None
|
||||||
client = None
|
client = None
|
||||||
|
|
||||||
@ -17,7 +23,7 @@ test_agent_state_post_message = None
|
|||||||
test_user_id = uuid.uuid4()
|
test_user_id = uuid.uuid4()
|
||||||
|
|
||||||
|
|
||||||
def test_create_agent():
|
def test_create_preset():
|
||||||
wipe_config()
|
wipe_config()
|
||||||
global client
|
global client
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
@ -25,6 +31,20 @@ def test_create_agent():
|
|||||||
else:
|
else:
|
||||||
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
||||||
|
|
||||||
|
available_functions = load_all_function_sets(merge=True)
|
||||||
|
functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
|
||||||
|
preset = Preset(
|
||||||
|
name=test_preset_name,
|
||||||
|
user_id=test_user_id,
|
||||||
|
description="A preset for testing the MemGPT client",
|
||||||
|
system=gpt_system.get_system_text(DEFAULT_PRESET),
|
||||||
|
functions_schema=functions_schema,
|
||||||
|
)
|
||||||
|
client.create_preset(preset)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_agent():
|
||||||
|
wipe_config()
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
|
|
||||||
# ensure user exists
|
# ensure user exists
|
||||||
@ -36,8 +56,7 @@ def test_create_agent():
|
|||||||
agent_config={
|
agent_config={
|
||||||
"user_id": test_user_id,
|
"user_id": test_user_id,
|
||||||
"name": test_agent_name,
|
"name": test_agent_name,
|
||||||
"persona": constants.DEFAULT_PERSONA,
|
"preset": test_preset_name,
|
||||||
"human": constants.DEFAULT_HUMAN,
|
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||||
@ -109,5 +128,6 @@ def test_save_load():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
test_create_preset()
|
||||||
test_create_agent()
|
test_create_agent()
|
||||||
test_user_message()
|
test_user_message()
|
||||||
|
@ -9,6 +9,7 @@ from memgpt.credentials import MemGPTCredentials
|
|||||||
from memgpt.server.server import SyncServer
|
from memgpt.server.server import SyncServer
|
||||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||||
from memgpt.embeddings import embedding_model
|
from memgpt.embeddings import embedding_model
|
||||||
|
from memgpt.presets.presets import add_default_presets
|
||||||
from .utils import wipe_config, wipe_memgpt_home
|
from .utils import wipe_config, wipe_memgpt_home
|
||||||
|
|
||||||
|
|
||||||
@ -86,6 +87,10 @@ def test_server():
|
|||||||
except:
|
except:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
# create presets
|
||||||
|
add_default_presets(user.id, server.ms)
|
||||||
|
|
||||||
|
# create agent
|
||||||
agent_state = server.create_agent(
|
agent_state = server.create_agent(
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
|
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
|
||||||
|
@ -20,17 +20,13 @@ def create_test_agent():
|
|||||||
wipe_config()
|
wipe_config()
|
||||||
global client
|
global client
|
||||||
if os.getenv("OPENAI_API_KEY"):
|
if os.getenv("OPENAI_API_KEY"):
|
||||||
client = MemGPT(quickstart="openai")
|
client = MemGPT(quickstart="openai", user_id=test_user_id)
|
||||||
else:
|
else:
|
||||||
client = MemGPT(quickstart="memgpt_hosted")
|
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
||||||
|
|
||||||
user = client.server.create_user({"id": test_user_id})
|
|
||||||
client.user_id = user.id
|
|
||||||
assert user is not None
|
|
||||||
|
|
||||||
agent_state = client.create_agent(
|
agent_state = client.create_agent(
|
||||||
agent_config={
|
agent_config={
|
||||||
"user_id": user.id,
|
"user_id": test_user_id,
|
||||||
"name": test_agent_name,
|
"name": test_agent_name,
|
||||||
"persona": constants.DEFAULT_PERSONA,
|
"persona": constants.DEFAULT_PERSONA,
|
||||||
"human": constants.DEFAULT_HUMAN,
|
"human": constants.DEFAULT_HUMAN,
|
||||||
|
@ -79,6 +79,7 @@ async def test_websockets():
|
|||||||
model_endpoint_type="openai",
|
model_endpoint_type="openai",
|
||||||
model_endpoint="https://api.openai.com/v1",
|
model_endpoint="https://api.openai.com/v1",
|
||||||
)
|
)
|
||||||
|
# TODO: get preset to pass in here
|
||||||
memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface)
|
memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface)
|
||||||
|
|
||||||
# Mock the user message packaging
|
# Mock the user message packaging
|
||||||
|
Loading…
Reference in New Issue
Block a user