mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
refactor: store presets in database via metadata store (#1013)
This commit is contained in:
parent
d8f3d04268
commit
c7fbc03e68
@ -344,8 +344,10 @@ def create_autogen_memgpt_agent(
|
||||
preset=agent_config["preset"],
|
||||
)
|
||||
try:
|
||||
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
|
||||
memgpt_agent = presets.create_agent_from_preset(
|
||||
agent_state=agent_state,
|
||||
preset=preset,
|
||||
interface=interface,
|
||||
persona_is_file=False,
|
||||
human_is_file=False,
|
||||
|
@ -606,8 +606,10 @@ def run(
|
||||
|
||||
# create agent
|
||||
try:
|
||||
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id)
|
||||
memgpt_agent = presets.create_agent_from_preset(
|
||||
agent_state=agent_state,
|
||||
preset=preset,
|
||||
interface=interface(),
|
||||
)
|
||||
save_agent(agent=memgpt_agent, ms=ms)
|
||||
|
@ -1,4 +1,5 @@
|
||||
import builtins
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import uuid
|
||||
@ -664,12 +665,18 @@ def configure():
|
||||
else:
|
||||
ms.create_user(user)
|
||||
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user_id, ms)
|
||||
|
||||
|
||||
class ListChoice(str, Enum):
|
||||
agents = "agents"
|
||||
humans = "humans"
|
||||
personas = "personas"
|
||||
sources = "sources"
|
||||
presets = "presets"
|
||||
|
||||
|
||||
@app.command()
|
||||
@ -741,6 +748,22 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
)
|
||||
|
||||
print(table)
|
||||
elif arg == ListChoice.presets:
|
||||
"""List all available presets"""
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Description", "Sources", "Functions"]
|
||||
for preset in ms.list_presets(user_id=user_id):
|
||||
sources = ms.get_preset_sources(preset_id=preset.id)
|
||||
table.add_row(
|
||||
[
|
||||
preset.name,
|
||||
preset.description,
|
||||
",".join([source.name for source in sources]),
|
||||
# json.dumps(preset.functions_schema, indent=4)
|
||||
",\n".join([f["name"] for f in preset.functions_schema]),
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
else:
|
||||
raise ValueError(f"Unknown argument {arg}")
|
||||
|
||||
|
@ -2,7 +2,7 @@ import os
|
||||
import uuid
|
||||
from typing import Dict, List, Union, Optional, Tuple
|
||||
|
||||
from memgpt.data_types import AgentState, User
|
||||
from memgpt.data_types import AgentState, User, Preset
|
||||
from memgpt.cli.cli import QuickstartChoice
|
||||
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
||||
from memgpt.config import MemGPTConfig
|
||||
@ -28,7 +28,6 @@ class Client(object):
|
||||
:param debug: indicates whether to display debug messages.
|
||||
"""
|
||||
self.auto_save = auto_save
|
||||
|
||||
# make sure everything is set up properly
|
||||
# TODO: remove this eventually? for multi-user, we can't have a shared config directory
|
||||
MemGPTConfig.create_config_dir()
|
||||
@ -80,6 +79,11 @@ class Client(object):
|
||||
else:
|
||||
ms.create_user(self.user)
|
||||
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(self.user_id, ms)
|
||||
|
||||
self.interface = QueuingInterface(debug=debug)
|
||||
self.server = SyncServer(default_interface=self.interface)
|
||||
|
||||
@ -114,6 +118,10 @@ class Client(object):
|
||||
agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config)
|
||||
return agent_state
|
||||
|
||||
def create_preset(self, preset: Preset):
|
||||
preset = self.server.create_preset(preset=preset)
|
||||
return preset
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> Dict:
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
|
||||
|
@ -9,6 +9,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
|
||||
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
||||
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
|
||||
from pydantic import BaseModel, Field, Json
|
||||
|
||||
|
||||
class Record:
|
||||
@ -494,3 +497,24 @@ class Source:
|
||||
# embedding info (optional)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
|
||||
class Preset(BaseModel):
|
||||
name: str = Field(..., description="The name of the preset.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
|
||||
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
|
||||
system: str = Field(..., description="The system prompt of the preset.")
|
||||
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||
# functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID
|
||||
# sources: List[str] = Field(..., description="The sources of the preset.") # TODO: convert to ID
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
name: str = Field(..., description="The name of the function.")
|
||||
id: uuid.UUID = Field(..., description="The unique identifier of the function.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
|
||||
# TODO: figure out how represent functions
|
||||
|
@ -3,7 +3,7 @@ import os
|
||||
from typing import Optional, List, Dict
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
|
||||
from memgpt.utils import get_local_time, enforce_types
|
||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig
|
||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Preset
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.agent import Agent
|
||||
|
||||
@ -197,6 +197,67 @@ class AgentSourceMappingModel(Base):
|
||||
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
|
||||
|
||||
|
||||
class PresetSourceMapping(Base):
|
||||
__tablename__ = "preset_source_mapping"
|
||||
|
||||
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(CommonUUID, nullable=False)
|
||||
preset_id = Column(CommonUUID, nullable=False)
|
||||
source_id = Column(CommonUUID, nullable=False)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
|
||||
|
||||
|
||||
# class PresetFunctionMapping(Base):
|
||||
# __tablename__ = "preset_function_mapping"
|
||||
#
|
||||
# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||
# user_id = Column(CommonUUID, nullable=False)
|
||||
# preset_id = Column(CommonUUID, nullable=False)
|
||||
# #function_id = Column(CommonUUID, nullable=False)
|
||||
# function = Column(String, nullable=False) # TODO: convert to ID eventually
|
||||
#
|
||||
# def __repr__(self) -> str:
|
||||
# return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"
|
||||
|
||||
|
||||
class PresetModel(Base):
|
||||
"""Defines data model for storing Preset objects"""
|
||||
|
||||
__tablename__ = "presets"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(CommonUUID, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(String)
|
||||
system = Column(String)
|
||||
human = Column(String)
|
||||
persona = Column(String)
|
||||
preset = Column(String)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
functions_schema = Column(JSON)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Preset(id='{self.id}', name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> Preset:
|
||||
return Preset(
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
description=self.description,
|
||||
system=self.system,
|
||||
human=self.human,
|
||||
persona=self.persona,
|
||||
preset=self.preset,
|
||||
created_at=self.created_at,
|
||||
functions_schema=self.functions_schema,
|
||||
)
|
||||
|
||||
|
||||
class MetadataStore:
|
||||
def __init__(self, config: MemGPTConfig):
|
||||
# TODO: get DB URI or path
|
||||
@ -215,7 +276,15 @@ class MetadataStore:
|
||||
# Check if tables need to be created
|
||||
self.engine = create_engine(self.uri)
|
||||
Base.metadata.create_all(
|
||||
self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__]
|
||||
self.engine,
|
||||
tables=[
|
||||
UserModel.__table__,
|
||||
AgentModel.__table__,
|
||||
SourceModel.__table__,
|
||||
AgentSourceMappingModel.__table__,
|
||||
PresetModel.__table__,
|
||||
PresetSourceMapping.__table__,
|
||||
],
|
||||
)
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
|
||||
@ -246,6 +315,64 @@ class MetadataStore:
|
||||
session.add(UserModel(**vars(user)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_preset(self, preset: Preset):
|
||||
with self.session_maker() as session:
|
||||
if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
|
||||
raise ValueError(f"User with id {preset.id} already exists")
|
||||
session.add(PresetModel(**vars(preset)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def get_preset(
|
||||
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
|
||||
) -> Optional[Preset]:
|
||||
with self.session_maker() as session:
|
||||
if preset_id:
|
||||
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
|
||||
elif preset_name and user_id:
|
||||
results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all()
|
||||
else:
|
||||
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
# @enforce_types
|
||||
# def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
|
||||
# preset = self.get_preset(preset_id)
|
||||
# if preset is None:
|
||||
# raise ValueError(f"Preset with id {preset_id} does not exist")
|
||||
# user_id = preset.user_id
|
||||
# with self.session_maker() as session:
|
||||
# for function in functions:
|
||||
# session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
|
||||
# session.commit()
|
||||
|
||||
@enforce_types
|
||||
def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
|
||||
preset = self.get_preset(preset_id)
|
||||
if preset is None:
|
||||
raise ValueError(f"Preset with id {preset_id} does not exist")
|
||||
user_id = preset.user_id
|
||||
with self.session_maker() as session:
|
||||
for source_id in sources:
|
||||
session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
|
||||
session.commit()
|
||||
|
||||
# @enforce_types
|
||||
# def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
|
||||
# with self.session_maker() as session:
|
||||
# results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
|
||||
# return [r.function for r in results]
|
||||
|
||||
@enforce_types
|
||||
def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
|
||||
return [r.source_id for r in results]
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent: AgentState):
|
||||
with self.session_maker() as session:
|
||||
@ -298,6 +425,12 @@ class MetadataStore:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
|
||||
return [r.to_record() for r in results]
|
||||
|
||||
@enforce_types
|
||||
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
|
||||
with self.session_maker() as session:
|
||||
|
@ -1,17 +1,71 @@
|
||||
from memgpt.data_types import AgentState
|
||||
from typing import List
|
||||
from memgpt.data_types import AgentState, Preset
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||
|
||||
import uuid
|
||||
|
||||
|
||||
available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
|
||||
|
||||
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
"""Add the default presets to the metadata store"""
|
||||
for preset_name in preset_options:
|
||||
preset_config = available_presets[preset_name]
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
name=preset_name,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
ms.create_preset(preset)
|
||||
|
||||
|
||||
def generate_functions_json(preset_functions: List[str]):
|
||||
"""
|
||||
Generate JSON schema for the functions based on what is locally available.
|
||||
|
||||
TODO: store function definitions in the DB, instead of locally
|
||||
"""
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
available_functions = load_all_function_sets()
|
||||
# Filter down the function set based on what the preset requested
|
||||
preset_function_set = {}
|
||||
for f_name in preset_functions:
|
||||
if f_name not in available_functions:
|
||||
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
||||
preset_function_set[f_name] = available_functions[f_name]
|
||||
assert len(preset_functions) == len(preset_function_set)
|
||||
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
|
||||
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
||||
return preset_function_set_schemas
|
||||
|
||||
|
||||
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
||||
def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True):
|
||||
def create_agent_from_preset(
|
||||
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
|
||||
):
|
||||
"""Initialize a new agent from a preset (combination of system + function)"""
|
||||
|
||||
# Input validation
|
||||
@ -25,40 +79,21 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
|
||||
raise ValueError(f"'state' must be uninitialized (empty)")
|
||||
|
||||
preset_name = agent_state.preset
|
||||
assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'"
|
||||
persona = agent_state.persona
|
||||
human = agent_state.human
|
||||
model = agent_state.llm_config.model
|
||||
|
||||
from memgpt.agent import Agent
|
||||
from memgpt.utils import printd
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
available_functions = load_all_function_sets()
|
||||
# available_presets = load_all_presets()
|
||||
# if preset_name not in available_presets:
|
||||
# raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
||||
|
||||
available_presets = load_all_presets()
|
||||
if preset_name not in available_presets:
|
||||
raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
||||
|
||||
preset = available_presets[preset_name]
|
||||
if not is_valid_yaml_format(preset, list(available_functions.keys())):
|
||||
raise ValueError(f"Preset '{preset_name}.yaml' is not valid")
|
||||
|
||||
preset_system_prompt = preset["system_prompt"]
|
||||
preset_function_set_names = preset["functions"]
|
||||
|
||||
# Filter down the function set based on what the preset requested
|
||||
preset_function_set = {}
|
||||
for f_name in preset_function_set_names:
|
||||
if f_name not in available_functions:
|
||||
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
||||
preset_function_set[f_name] = available_functions[f_name]
|
||||
assert len(preset_function_set_names) == len(preset_function_set)
|
||||
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
|
||||
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
||||
# preset = available_presets[preset_name]
|
||||
# preset_system_prompt = preset["system_prompt"]
|
||||
# preset_function_set_names = preset["functions"]
|
||||
# preset_function_set_schemas = generate_functions_json(preset_function_set_names)
|
||||
|
||||
# Override the following in the AgentState:
|
||||
# persona: str # the current persona text
|
||||
@ -69,8 +104,8 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
|
||||
agent_state.state = {
|
||||
"persona": get_persona_text(persona) if persona_is_file else persona,
|
||||
"human": get_human_text(human) if human_is_file else human,
|
||||
"system": gpt_system.get_system_text(preset_system_prompt),
|
||||
"functions": preset_function_set_schemas,
|
||||
"system": preset.system,
|
||||
"functions": preset.functions_schema,
|
||||
"messages": None,
|
||||
}
|
||||
|
||||
|
@ -31,6 +31,7 @@ from memgpt.data_types import (
|
||||
EmbeddingConfig,
|
||||
Message,
|
||||
ToolCall,
|
||||
Preset,
|
||||
)
|
||||
|
||||
|
||||
@ -582,7 +583,10 @@ class SyncServer(LockingServer):
|
||||
|
||||
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
|
||||
try:
|
||||
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
|
||||
preset = self.ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
|
||||
assert preset is not None, f"preset {agent_state.preset} does not exist"
|
||||
|
||||
agent = presets.create_agent_from_preset(agent_state=agent_state, preset=preset, interface=interface)
|
||||
save_agent(agent=agent, ms=self.ms)
|
||||
|
||||
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB
|
||||
@ -613,6 +617,20 @@ class SyncServer(LockingServer):
|
||||
if agent is not None:
|
||||
self.ms.delete_agent(agent_id=agent_id)
|
||||
|
||||
def create_preset(self, preset: Preset):
|
||||
"""Create a new preset using a config"""
|
||||
if self.ms.get_user(user_id=preset.user_id) is None:
|
||||
raise ValueError(f"User user_id={preset.user_id} does not exist")
|
||||
|
||||
self.ms.create_preset(preset)
|
||||
return preset
|
||||
|
||||
def get_preset(
|
||||
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None
|
||||
) -> Preset:
|
||||
"""Get the preset"""
|
||||
return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id)
|
||||
|
||||
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
|
||||
"""Convert AgentState to a dict for a JSON response"""
|
||||
assert agent_state is not None
|
||||
|
@ -937,8 +937,11 @@ def list_human_files():
|
||||
memgpt_defaults = os.listdir(defaults_dir)
|
||||
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
||||
|
||||
user_added = os.listdir(user_dir)
|
||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||
if os.path.exists(user_dir):
|
||||
user_added = os.listdir(user_dir)
|
||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||
else:
|
||||
user_added = []
|
||||
return memgpt_defaults + user_added
|
||||
|
||||
|
||||
@ -950,8 +953,11 @@ def list_persona_files():
|
||||
memgpt_defaults = os.listdir(defaults_dir)
|
||||
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
|
||||
|
||||
user_added = os.listdir(user_dir)
|
||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||
if os.path.exists(user_dir):
|
||||
user_added = os.listdir(user_dir)
|
||||
user_added = [os.path.join(user_dir, f) for f in user_added]
|
||||
else:
|
||||
user_added = []
|
||||
return memgpt_defaults + user_added
|
||||
|
||||
|
||||
|
@ -4,12 +4,18 @@ import os
|
||||
from memgpt import MemGPT
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
|
||||
from .utils import wipe_config
|
||||
import uuid
|
||||
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
test_preset_name = "test_preset"
|
||||
test_agent_state = None
|
||||
client = None
|
||||
|
||||
@ -17,7 +23,7 @@ test_agent_state_post_message = None
|
||||
test_user_id = uuid.uuid4()
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
def test_create_preset():
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
@ -25,6 +31,20 @@ def test_create_agent():
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
||||
|
||||
available_functions = load_all_function_sets(merge=True)
|
||||
functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
|
||||
preset = Preset(
|
||||
name=test_preset_name,
|
||||
user_id=test_user_id,
|
||||
description="A preset for testing the MemGPT client",
|
||||
system=gpt_system.get_system_text(DEFAULT_PRESET),
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
client.create_preset(preset)
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
wipe_config()
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# ensure user exists
|
||||
@ -36,8 +56,7 @@ def test_create_agent():
|
||||
agent_config={
|
||||
"user_id": test_user_id,
|
||||
"name": test_agent_name,
|
||||
"persona": constants.DEFAULT_PERSONA,
|
||||
"human": constants.DEFAULT_HUMAN,
|
||||
"preset": test_preset_name,
|
||||
}
|
||||
)
|
||||
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||
@ -109,5 +128,6 @@ def test_save_load():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_preset()
|
||||
test_create_agent()
|
||||
test_user_message()
|
||||
|
@ -9,6 +9,7 @@ from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
from .utils import wipe_config, wipe_memgpt_home
|
||||
|
||||
|
||||
@ -86,6 +87,10 @@ def test_server():
|
||||
except:
|
||||
raise
|
||||
|
||||
# create presets
|
||||
add_default_presets(user.id, server.ms)
|
||||
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
user_id=user.id,
|
||||
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
|
||||
|
@ -20,17 +20,13 @@ def create_test_agent():
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
client = MemGPT(quickstart="openai")
|
||||
client = MemGPT(quickstart="openai", user_id=test_user_id)
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted")
|
||||
|
||||
user = client.server.create_user({"id": test_user_id})
|
||||
client.user_id = user.id
|
||||
assert user is not None
|
||||
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
||||
|
||||
agent_state = client.create_agent(
|
||||
agent_config={
|
||||
"user_id": user.id,
|
||||
"user_id": test_user_id,
|
||||
"name": test_agent_name,
|
||||
"persona": constants.DEFAULT_PERSONA,
|
||||
"human": constants.DEFAULT_HUMAN,
|
||||
|
@ -79,6 +79,7 @@ async def test_websockets():
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
)
|
||||
# TODO: get preset to pass in here
|
||||
memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface)
|
||||
|
||||
# Mock the user message packaging
|
||||
|
Loading…
Reference in New Issue
Block a user