refactor: store presets in database via metadata store (#1013)

This commit is contained in:
Sarah Wooders 2024-02-15 18:49:16 -08:00 committed by GitHub
parent d8f3d04268
commit c7fbc03e68
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 325 additions and 52 deletions

View File

@ -344,8 +344,10 @@ def create_autogen_memgpt_agent(
preset=agent_config["preset"], preset=agent_config["preset"],
) )
try: try:
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
memgpt_agent = presets.create_agent_from_preset( memgpt_agent = presets.create_agent_from_preset(
agent_state=agent_state, agent_state=agent_state,
preset=preset,
interface=interface, interface=interface,
persona_is_file=False, persona_is_file=False,
human_is_file=False, human_is_file=False,

View File

@ -606,8 +606,10 @@ def run(
# create agent # create agent
try: try:
preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id)
memgpt_agent = presets.create_agent_from_preset( memgpt_agent = presets.create_agent_from_preset(
agent_state=agent_state, agent_state=agent_state,
preset=preset,
interface=interface(), interface=interface(),
) )
save_agent(agent=memgpt_agent, ms=ms) save_agent(agent=memgpt_agent, ms=ms)

View File

@ -1,4 +1,5 @@
import builtins import builtins
import json
import os import os
import shutil import shutil
import uuid import uuid
@ -664,12 +665,18 @@ def configure():
else: else:
ms.create_user(user) ms.create_user(user)
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(user_id, ms)
class ListChoice(str, Enum): class ListChoice(str, Enum):
agents = "agents" agents = "agents"
humans = "humans" humans = "humans"
personas = "personas" personas = "personas"
sources = "sources" sources = "sources"
presets = "presets"
@app.command() @app.command()
@ -741,6 +748,22 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
) )
print(table) print(table)
elif arg == ListChoice.presets:
"""List all available presets"""
table = PrettyTable()
table.field_names = ["Name", "Description", "Sources", "Functions"]
for preset in ms.list_presets(user_id=user_id):
sources = ms.get_preset_sources(preset_id=preset.id)
table.add_row(
[
preset.name,
preset.description,
",".join([source.name for source in sources]),
# json.dumps(preset.functions_schema, indent=4)
",\n".join([f["name"] for f in preset.functions_schema]),
]
)
print(table)
else: else:
raise ValueError(f"Unknown argument {arg}") raise ValueError(f"Unknown argument {arg}")

View File

@ -2,7 +2,7 @@ import os
import uuid import uuid
from typing import Dict, List, Union, Optional, Tuple from typing import Dict, List, Union, Optional, Tuple
from memgpt.data_types import AgentState, User from memgpt.data_types import AgentState, User, Preset
from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
@ -28,7 +28,6 @@ class Client(object):
:param debug: indicates whether to display debug messages. :param debug: indicates whether to display debug messages.
""" """
self.auto_save = auto_save self.auto_save = auto_save
# make sure everything is set up properly # make sure everything is set up properly
# TODO: remove this eventually? for multi-user, we can't have a shared config directory # TODO: remove this eventually? for multi-user, we can't have a shared config directory
MemGPTConfig.create_config_dir() MemGPTConfig.create_config_dir()
@ -80,6 +79,11 @@ class Client(object):
else: else:
ms.create_user(self.user) ms.create_user(self.user)
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(self.user_id, ms)
self.interface = QueuingInterface(debug=debug) self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface=self.interface) self.server = SyncServer(default_interface=self.interface)
@ -114,6 +118,10 @@ class Client(object):
agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config) agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config)
return agent_state return agent_state
def create_preset(self, preset: Preset):
preset = self.server.create_preset(preset=preset)
return preset
def get_agent_config(self, agent_id: str) -> Dict: def get_agent_config(self, agent_id: str) -> Dict:
self.interface.clear() self.interface.clear()
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)

View File

@ -9,6 +9,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd
from pydantic import BaseModel, Field, Json
class Record: class Record:
@ -494,3 +497,24 @@ class Source:
# embedding info (optional) # embedding info (optional)
self.embedding_dim = embedding_dim self.embedding_dim = embedding_dim
self.embedding_model = embedding_model self.embedding_model = embedding_model
class Preset(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID
# sources: List[str] = Field(..., description="The sources of the preset.") # TODO: convert to ID
class Function(BaseModel):
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(..., description="The unique identifier of the function.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
# TODO: figure out how represent functions

View File

@ -3,7 +3,7 @@ import os
from typing import Optional, List, Dict from typing import Optional, List, Dict
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS
from memgpt.utils import get_local_time, enforce_types from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Preset
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
from memgpt.agent import Agent from memgpt.agent import Agent
@ -197,6 +197,67 @@ class AgentSourceMappingModel(Base):
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>" return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
class PresetSourceMapping(Base):
__tablename__ = "preset_source_mapping"
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
preset_id = Column(CommonUUID, nullable=False)
source_id = Column(CommonUUID, nullable=False)
def __repr__(self) -> str:
return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
# class PresetFunctionMapping(Base):
# __tablename__ = "preset_function_mapping"
#
# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# user_id = Column(CommonUUID, nullable=False)
# preset_id = Column(CommonUUID, nullable=False)
# #function_id = Column(CommonUUID, nullable=False)
# function = Column(String, nullable=False) # TODO: convert to ID eventually
#
# def __repr__(self) -> str:
# return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"
class PresetModel(Base):
"""Defines data model for storing Preset objects"""
__tablename__ = "presets"
__table_args__ = {"extend_existing": True}
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
description = Column(String)
system = Column(String)
human = Column(String)
persona = Column(String)
preset = Column(String)
created_at = Column(DateTime(timezone=True), server_default=func.now())
functions_schema = Column(JSON)
def __repr__(self) -> str:
return f"<Preset(id='{self.id}', name='{self.name}')>"
def to_record(self) -> Preset:
return Preset(
id=self.id,
user_id=self.user_id,
name=self.name,
description=self.description,
system=self.system,
human=self.human,
persona=self.persona,
preset=self.preset,
created_at=self.created_at,
functions_schema=self.functions_schema,
)
class MetadataStore: class MetadataStore:
def __init__(self, config: MemGPTConfig): def __init__(self, config: MemGPTConfig):
# TODO: get DB URI or path # TODO: get DB URI or path
@ -215,7 +276,15 @@ class MetadataStore:
# Check if tables need to be created # Check if tables need to be created
self.engine = create_engine(self.uri) self.engine = create_engine(self.uri)
Base.metadata.create_all( Base.metadata.create_all(
self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__] self.engine,
tables=[
UserModel.__table__,
AgentModel.__table__,
SourceModel.__table__,
AgentSourceMappingModel.__table__,
PresetModel.__table__,
PresetSourceMapping.__table__,
],
) )
self.session_maker = sessionmaker(bind=self.engine) self.session_maker = sessionmaker(bind=self.engine)
@ -246,6 +315,64 @@ class MetadataStore:
session.add(UserModel(**vars(user))) session.add(UserModel(**vars(user)))
session.commit() session.commit()
@enforce_types
def create_preset(self, preset: Preset):
with self.session_maker() as session:
if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
raise ValueError(f"User with id {preset.id} already exists")
session.add(PresetModel(**vars(preset)))
session.commit()
@enforce_types
def get_preset(
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[Preset]:
with self.session_maker() as session:
if preset_id:
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
elif preset_name and user_id:
results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all()
else:
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
# @enforce_types
# def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
# preset = self.get_preset(preset_id)
# if preset is None:
# raise ValueError(f"Preset with id {preset_id} does not exist")
# user_id = preset.user_id
# with self.session_maker() as session:
# for function in functions:
# session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
# session.commit()
@enforce_types
def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
preset = self.get_preset(preset_id)
if preset is None:
raise ValueError(f"Preset with id {preset_id} does not exist")
user_id = preset.user_id
with self.session_maker() as session:
for source_id in sources:
session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
session.commit()
# @enforce_types
# def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
# with self.session_maker() as session:
# results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
# return [r.function for r in results]
@enforce_types
def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
return [r.source_id for r in results]
@enforce_types @enforce_types
def update_agent(self, agent: AgentState): def update_agent(self, agent: AgentState):
with self.session_maker() as session: with self.session_maker() as session:
@ -298,6 +425,12 @@ class MetadataStore:
session.commit() session.commit()
@enforce_types
def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]
@enforce_types @enforce_types
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session: with self.session_maker() as session:

View File

@ -1,17 +1,71 @@
from memgpt.data_types import AgentState from typing import List
from memgpt.data_types import AgentState, Preset
from memgpt.interface import AgentInterface from memgpt.interface import AgentInterface
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
from memgpt.utils import get_human_text, get_persona_text from memgpt.utils import get_human_text, get_persona_text, printd
from memgpt.prompts import gpt_system from memgpt.prompts import gpt_system
from memgpt.functions.functions import load_all_function_sets from memgpt.functions.functions import load_all_function_sets
from memgpt.metadata import MetadataStore
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
import uuid
available_presets = load_all_presets() available_presets = load_all_presets()
preset_options = list(available_presets.keys()) preset_options = list(available_presets.keys())
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
"""Add the default presets to the metadata store"""
for preset_name in preset_options:
preset_config = available_presets[preset_name]
preset_system_prompt = preset_config["system_prompt"]
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue
preset = Preset(
user_id=user_id,
name=preset_name,
system=gpt_system.get_system_text(preset_system_prompt),
persona=get_persona_text(DEFAULT_PERSONA),
human=get_human_text(DEFAULT_HUMAN),
functions_schema=functions_schema,
)
ms.create_preset(preset)
def generate_functions_json(preset_functions: List[str]):
"""
Generate JSON schema for the functions based on what is locally available.
TODO: store function definitions in the DB, instead of locally
"""
# Available functions is a mapping from:
# function_name -> {
# json_schema: schema
# python_function: function
# }
available_functions = load_all_function_sets()
# Filter down the function set based on what the preset requested
preset_function_set = {}
for f_name in preset_functions:
if f_name not in available_functions:
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
preset_function_set[f_name] = available_functions[f_name]
assert len(preset_functions) == len(preset_function_set)
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
printd(f"Available functions:\n", list(preset_function_set.keys()))
return preset_function_set_schemas
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): # def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True): def create_agent_from_preset(
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
):
"""Initialize a new agent from a preset (combination of system + function)""" """Initialize a new agent from a preset (combination of system + function)"""
# Input validation # Input validation
@ -25,40 +79,21 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
raise ValueError(f"'state' must be uninitialized (empty)") raise ValueError(f"'state' must be uninitialized (empty)")
preset_name = agent_state.preset preset_name = agent_state.preset
assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'"
persona = agent_state.persona persona = agent_state.persona
human = agent_state.human human = agent_state.human
model = agent_state.llm_config.model model = agent_state.llm_config.model
from memgpt.agent import Agent from memgpt.agent import Agent
from memgpt.utils import printd
# Available functions is a mapping from: # available_presets = load_all_presets()
# function_name -> { # if preset_name not in available_presets:
# json_schema: schema # raise ValueError(f"Preset '{preset_name}.yaml' not found")
# python_function: function
# }
available_functions = load_all_function_sets()
available_presets = load_all_presets() # preset = available_presets[preset_name]
if preset_name not in available_presets: # preset_system_prompt = preset["system_prompt"]
raise ValueError(f"Preset '{preset_name}.yaml' not found") # preset_function_set_names = preset["functions"]
# preset_function_set_schemas = generate_functions_json(preset_function_set_names)
preset = available_presets[preset_name]
if not is_valid_yaml_format(preset, list(available_functions.keys())):
raise ValueError(f"Preset '{preset_name}.yaml' is not valid")
preset_system_prompt = preset["system_prompt"]
preset_function_set_names = preset["functions"]
# Filter down the function set based on what the preset requested
preset_function_set = {}
for f_name in preset_function_set_names:
if f_name not in available_functions:
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
preset_function_set[f_name] = available_functions[f_name]
assert len(preset_function_set_names) == len(preset_function_set)
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
printd(f"Available functions:\n", list(preset_function_set.keys()))
# Override the following in the AgentState: # Override the following in the AgentState:
# persona: str # the current persona text # persona: str # the current persona text
@ -69,8 +104,8 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface,
agent_state.state = { agent_state.state = {
"persona": get_persona_text(persona) if persona_is_file else persona, "persona": get_persona_text(persona) if persona_is_file else persona,
"human": get_human_text(human) if human_is_file else human, "human": get_human_text(human) if human_is_file else human,
"system": gpt_system.get_system_text(preset_system_prompt), "system": preset.system,
"functions": preset_function_set_schemas, "functions": preset.functions_schema,
"messages": None, "messages": None,
} }

View File

@ -31,6 +31,7 @@ from memgpt.data_types import (
EmbeddingConfig, EmbeddingConfig,
Message, Message,
ToolCall, ToolCall,
Preset,
) )
@ -582,7 +583,10 @@ class SyncServer(LockingServer):
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}") logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
try: try:
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface) preset = self.ms.get_preset(preset_name=agent_state.preset, user_id=user_id)
assert preset is not None, f"preset {agent_state.preset} does not exist"
agent = presets.create_agent_from_preset(agent_state=agent_state, preset=preset, interface=interface)
save_agent(agent=agent, ms=self.ms) save_agent(agent=agent, ms=self.ms)
# FIXME: this is a hacky way to get the system prompts injected into agent into the DB # FIXME: this is a hacky way to get the system prompts injected into agent into the DB
@ -613,6 +617,20 @@ class SyncServer(LockingServer):
if agent is not None: if agent is not None:
self.ms.delete_agent(agent_id=agent_id) self.ms.delete_agent(agent_id=agent_id)
def create_preset(self, preset: Preset):
"""Create a new preset using a config"""
if self.ms.get_user(user_id=preset.user_id) is None:
raise ValueError(f"User user_id={preset.user_id} does not exist")
self.ms.create_preset(preset)
return preset
def get_preset(
self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None
) -> Preset:
"""Get the preset"""
return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id)
def _agent_state_to_config(self, agent_state: AgentState) -> dict: def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response""" """Convert AgentState to a dict for a JSON response"""
assert agent_state is not None assert agent_state is not None

View File

@ -937,8 +937,11 @@ def list_human_files():
memgpt_defaults = os.listdir(defaults_dir) memgpt_defaults = os.listdir(defaults_dir)
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
if os.path.exists(user_dir):
user_added = os.listdir(user_dir) user_added = os.listdir(user_dir)
user_added = [os.path.join(user_dir, f) for f in user_added] user_added = [os.path.join(user_dir, f) for f in user_added]
else:
user_added = []
return memgpt_defaults + user_added return memgpt_defaults + user_added
@ -950,8 +953,11 @@ def list_persona_files():
memgpt_defaults = os.listdir(defaults_dir) memgpt_defaults = os.listdir(defaults_dir)
memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")]
if os.path.exists(user_dir):
user_added = os.listdir(user_dir) user_added = os.listdir(user_dir)
user_added = [os.path.join(user_dir, f) for f in user_added] user_added = [os.path.join(user_dir, f) for f in user_added]
else:
user_added = []
return memgpt_defaults + user_added return memgpt_defaults + user_added

View File

@ -4,12 +4,18 @@ import os
from memgpt import MemGPT from memgpt import MemGPT
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
from memgpt import constants from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.prompts import gpt_system
from memgpt.constants import DEFAULT_PRESET
from .utils import wipe_config from .utils import wipe_config
import uuid import uuid
test_agent_name = f"test_client_{str(uuid.uuid4())}" test_agent_name = f"test_client_{str(uuid.uuid4())}"
test_preset_name = "test_preset"
test_agent_state = None test_agent_state = None
client = None client = None
@ -17,7 +23,7 @@ test_agent_state_post_message = None
test_user_id = uuid.uuid4() test_user_id = uuid.uuid4()
def test_create_agent(): def test_create_preset():
wipe_config() wipe_config()
global client global client
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
@ -25,6 +31,20 @@ def test_create_agent():
else: else:
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
available_functions = load_all_function_sets(merge=True)
functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
preset = Preset(
name=test_preset_name,
user_id=test_user_id,
description="A preset for testing the MemGPT client",
system=gpt_system.get_system_text(DEFAULT_PRESET),
functions_schema=functions_schema,
)
client.create_preset(preset)
def test_create_agent():
wipe_config()
config = MemGPTConfig.load() config = MemGPTConfig.load()
# ensure user exists # ensure user exists
@ -36,8 +56,7 @@ def test_create_agent():
agent_config={ agent_config={
"user_id": test_user_id, "user_id": test_user_id,
"name": test_agent_name, "name": test_agent_name,
"persona": constants.DEFAULT_PERSONA, "preset": test_preset_name,
"human": constants.DEFAULT_HUMAN,
} }
) )
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
@ -109,5 +128,6 @@ def test_save_load():
if __name__ == "__main__": if __name__ == "__main__":
test_create_preset()
test_create_agent() test_create_agent()
test_user_message() test_user_message()

View File

@ -9,6 +9,7 @@ from memgpt.credentials import MemGPTCredentials
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from .utils import wipe_config, wipe_memgpt_home from .utils import wipe_config, wipe_memgpt_home
@ -86,6 +87,10 @@ def test_server():
except: except:
raise raise
# create presets
add_default_presets(user.id, server.ms)
# create agent
agent_state = server.create_agent( agent_state = server.create_agent(
user_id=user.id, user_id=user.id,
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"), agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),

View File

@ -20,17 +20,13 @@ def create_test_agent():
wipe_config() wipe_config()
global client global client
if os.getenv("OPENAI_API_KEY"): if os.getenv("OPENAI_API_KEY"):
client = MemGPT(quickstart="openai") client = MemGPT(quickstart="openai", user_id=test_user_id)
else: else:
client = MemGPT(quickstart="memgpt_hosted") client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
user = client.server.create_user({"id": test_user_id})
client.user_id = user.id
assert user is not None
agent_state = client.create_agent( agent_state = client.create_agent(
agent_config={ agent_config={
"user_id": user.id, "user_id": test_user_id,
"name": test_agent_name, "name": test_agent_name,
"persona": constants.DEFAULT_PERSONA, "persona": constants.DEFAULT_PERSONA,
"human": constants.DEFAULT_HUMAN, "human": constants.DEFAULT_HUMAN,

View File

@ -79,6 +79,7 @@ async def test_websockets():
model_endpoint_type="openai", model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1", model_endpoint="https://api.openai.com/v1",
) )
# TODO: get preset to pass in here
memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface) memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface)
# Mock the user message packaging # Mock the user message packaging