From c7fbc03e682effdda4f1861a30c2179bb008cd1d Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 15 Feb 2024 18:49:16 -0800 Subject: [PATCH] refactor: store presets in database via metadata store (#1013) --- memgpt/autogen/memgpt_agent.py | 2 + memgpt/cli/cli.py | 2 + memgpt/cli/cli_config.py | 23 +++++ memgpt/client/client.py | 12 ++- memgpt/data_types.py | 24 ++++++ memgpt/metadata.py | 137 +++++++++++++++++++++++++++++- memgpt/presets/presets.py | 99 ++++++++++++++------- memgpt/server/server.py | 20 ++++- memgpt/utils.py | 14 ++- tests/test_client.py | 28 +++++- tests/test_server.py | 5 ++ tests/test_summarize.py | 10 +-- tests/test_websocket_interface.py | 1 + 13 files changed, 325 insertions(+), 52 deletions(-) diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index e06f5bed6..35dab922b 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -344,8 +344,10 @@ def create_autogen_memgpt_agent( preset=agent_config["preset"], ) try: + preset = ms.get_preset(preset_name=agent_state.preset, user_id=user_id) memgpt_agent = presets.create_agent_from_preset( agent_state=agent_state, + preset=preset, interface=interface, persona_is_file=False, human_is_file=False, diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index a8352e0d1..8a778933c 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -606,8 +606,10 @@ def run( # create agent try: + preset = ms.get_preset(preset_name=agent_state.preset, user_id=user.id) memgpt_agent = presets.create_agent_from_preset( agent_state=agent_state, + preset=preset, interface=interface(), ) save_agent(agent=memgpt_agent, ms=ms) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 6e8aaf898..ed141bfbb 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -1,4 +1,5 @@ import builtins +import json import os import shutil import uuid @@ -664,12 +665,18 @@ def configure(): else: ms.create_user(user) + # create preset records in metadata store + from memgpt.presets.presets import add_default_presets + + add_default_presets(user_id, ms) + class ListChoice(str, Enum): agents = "agents" humans = "humans" personas = "personas" sources = "sources" + presets = "presets" @app.command() @@ -741,6 +748,22 @@ def list(arg: Annotated[ListChoice, typer.Argument]): ) print(table) + elif arg == ListChoice.presets: + """List all available presets""" + table = PrettyTable() + table.field_names = ["Name", "Description", "Sources", "Functions"] + for preset in ms.list_presets(user_id=user_id): + sources = ms.get_preset_sources(preset_id=preset.id) + table.add_row( + [ + preset.name, + preset.description, + ",".join([source.name for source in sources]), + # json.dumps(preset.functions_schema, indent=4) + ",\n".join([f["name"] for f in preset.functions_schema]), + ] + ) + print(table) else: raise ValueError(f"Unknown argument {arg}") diff --git a/memgpt/client/client.py b/memgpt/client/client.py index e84cab5f0..0e315e7e0 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -2,7 +2,7 @@ import os import uuid from typing import Dict, List, Union, Optional, Tuple -from memgpt.data_types import AgentState, User +from memgpt.data_types import AgentState, User, Preset from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig @@ -28,7 +28,6 @@ class Client(object): :param debug: indicates whether to display debug messages. """ self.auto_save = auto_save - # make sure everything is set up properly # TODO: remove this eventually? for multi-user, we can't have a shared config directory MemGPTConfig.create_config_dir() @@ -80,6 +79,11 @@ class Client(object): else: ms.create_user(self.user) + # create preset records in metadata store + from memgpt.presets.presets import add_default_presets + + add_default_presets(self.user_id, ms) + self.interface = QueuingInterface(debug=debug) self.server = SyncServer(default_interface=self.interface) @@ -114,6 +118,10 @@ class Client(object): agent_state = self.server.create_agent(user_id=self.user_id, agent_config=agent_config) return agent_state + def create_preset(self, preset: Preset): + preset = self.server.create_preset(preset=preset) + return preset + def get_agent_config(self, agent_id: str) -> Dict: self.interface.clear() return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 8dda6f44b..2fd347d37 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -9,6 +9,9 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string from memgpt.models import chat_completion_response +from memgpt.utils import get_human_text, get_persona_text, printd + +from pydantic import BaseModel, Field, Json class Record: @@ -494,3 +497,24 @@ class Source: # embedding info (optional) self.embedding_dim = embedding_dim self.embedding_model = embedding_model + + +class Preset(BaseModel): + name: str = Field(..., description="The name of the preset.") + id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.") + user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.") + description: Optional[str] = Field(None, description="The description of the preset.") + created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.") + system: str = Field(..., description="The system prompt of the preset.") + persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.") + human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.") + functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.") + # functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID + # sources: List[str] = Field(..., description="The sources of the preset.") # TODO: convert to ID + + +class Function(BaseModel): + name: str = Field(..., description="The name of the function.") + id: uuid.UUID = Field(..., description="The unique identifier of the function.") + user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.") + # TODO: figure out how represent functions diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 2741e02e1..32a8e2877 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -3,7 +3,7 @@ import os from typing import Optional, List, Dict from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS from memgpt.utils import get_local_time, enforce_types -from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig +from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Preset from memgpt.config import MemGPTConfig from memgpt.agent import Agent @@ -197,6 +197,67 @@ class AgentSourceMappingModel(Base): return f"" +class PresetSourceMapping(Base): + __tablename__ = "preset_source_mapping" + + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(CommonUUID, nullable=False) + preset_id = Column(CommonUUID, nullable=False) + source_id = Column(CommonUUID, nullable=False) + + def __repr__(self) -> str: + return f"" + + +# class PresetFunctionMapping(Base): +# __tablename__ = "preset_function_mapping" +# +# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) +# user_id = Column(CommonUUID, nullable=False) +# preset_id = Column(CommonUUID, nullable=False) +# #function_id = Column(CommonUUID, nullable=False) +# function = Column(String, nullable=False) # TODO: convert to ID eventually +# +# def __repr__(self) -> str: +# return f"" + + +class PresetModel(Base): + """Defines data model for storing Preset objects""" + + __tablename__ = "presets" + __table_args__ = {"extend_existing": True} + + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(CommonUUID, nullable=False) + name = Column(String, nullable=False) + description = Column(String) + system = Column(String) + human = Column(String) + persona = Column(String) + preset = Column(String) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + functions_schema = Column(JSON) + + def __repr__(self) -> str: + return f"" + + def to_record(self) -> Preset: + return Preset( + id=self.id, + user_id=self.user_id, + name=self.name, + description=self.description, + system=self.system, + human=self.human, + persona=self.persona, + preset=self.preset, + created_at=self.created_at, + functions_schema=self.functions_schema, + ) + + class MetadataStore: def __init__(self, config: MemGPTConfig): # TODO: get DB URI or path @@ -215,7 +276,15 @@ class MetadataStore: # Check if tables need to be created self.engine = create_engine(self.uri) Base.metadata.create_all( - self.engine, tables=[UserModel.__table__, AgentModel.__table__, SourceModel.__table__, AgentSourceMappingModel.__table__] + self.engine, + tables=[ + UserModel.__table__, + AgentModel.__table__, + SourceModel.__table__, + AgentSourceMappingModel.__table__, + PresetModel.__table__, + PresetSourceMapping.__table__, + ], ) self.session_maker = sessionmaker(bind=self.engine) @@ -246,6 +315,64 @@ class MetadataStore: session.add(UserModel(**vars(user))) session.commit() + @enforce_types + def create_preset(self, preset: Preset): + with self.session_maker() as session: + if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0: + raise ValueError(f"User with id {preset.id} already exists") + session.add(PresetModel(**vars(preset))) + session.commit() + + @enforce_types + def get_preset( + self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None + ) -> Optional[Preset]: + with self.session_maker() as session: + if preset_id: + results = session.query(PresetModel).filter(PresetModel.id == preset_id).all() + elif preset_name and user_id: + results = session.query(PresetModel).filter(PresetModel.name == preset_name).filter(PresetModel.user_id == user_id).all() + else: + raise ValueError("Must provide either preset_id or (preset_name and user_id)") + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0].to_record() + + # @enforce_types + # def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]): + # preset = self.get_preset(preset_id) + # if preset is None: + # raise ValueError(f"Preset with id {preset_id} does not exist") + # user_id = preset.user_id + # with self.session_maker() as session: + # for function in functions: + # session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function)) + # session.commit() + + @enforce_types + def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]): + preset = self.get_preset(preset_id) + if preset is None: + raise ValueError(f"Preset with id {preset_id} does not exist") + user_id = preset.user_id + with self.session_maker() as session: + for source_id in sources: + session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id)) + session.commit() + + # @enforce_types + # def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]: + # with self.session_maker() as session: + # results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all() + # return [r.function for r in results] + + @enforce_types + def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]: + with self.session_maker() as session: + results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all() + return [r.source_id for r in results] + @enforce_types def update_agent(self, agent: AgentState): with self.session_maker() as session: @@ -298,6 +425,12 @@ class MetadataStore: session.commit() + @enforce_types + def list_presets(self, user_id: uuid.UUID) -> List[Preset]: + with self.session_maker() as session: + results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all() + return [r.to_record() for r in results] + @enforce_types def list_agents(self, user_id: uuid.UUID) -> List[AgentState]: with self.session_maker() as session: diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py index 6325f930a..96ede89b3 100644 --- a/memgpt/presets/presets.py +++ b/memgpt/presets/presets.py @@ -1,17 +1,71 @@ -from memgpt.data_types import AgentState +from typing import List +from memgpt.data_types import AgentState, Preset from memgpt.interface import AgentInterface from memgpt.presets.utils import load_all_presets, is_valid_yaml_format -from memgpt.utils import get_human_text, get_persona_text +from memgpt.utils import get_human_text, get_persona_text, printd from memgpt.prompts import gpt_system from memgpt.functions.functions import load_all_function_sets +from memgpt.metadata import MetadataStore +from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET + +import uuid available_presets = load_all_presets() preset_options = list(available_presets.keys()) +def add_default_presets(user_id: uuid.UUID, ms: MetadataStore): + """Add the default presets to the metadata store""" + for preset_name in preset_options: + preset_config = available_presets[preset_name] + preset_system_prompt = preset_config["system_prompt"] + preset_function_set_names = preset_config["functions"] + functions_schema = generate_functions_json(preset_function_set_names) + + if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None: + printd(f"Preset '{preset_name}' already exists for user '{user_id}'") + continue + + preset = Preset( + user_id=user_id, + name=preset_name, + system=gpt_system.get_system_text(preset_system_prompt), + persona=get_persona_text(DEFAULT_PERSONA), + human=get_human_text(DEFAULT_HUMAN), + functions_schema=functions_schema, + ) + ms.create_preset(preset) + + +def generate_functions_json(preset_functions: List[str]): + """ + Generate JSON schema for the functions based on what is locally available. + + TODO: store function definitions in the DB, instead of locally + """ + # Available functions is a mapping from: + # function_name -> { + # json_schema: schema + # python_function: function + # } + available_functions = load_all_function_sets() + # Filter down the function set based on what the preset requested + preset_function_set = {} + for f_name in preset_functions: + if f_name not in available_functions: + raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}") + preset_function_set[f_name] = available_functions[f_name] + assert len(preset_functions) == len(preset_function_set) + preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()] + printd(f"Available functions:\n", list(preset_function_set.keys())) + return preset_function_set_schemas + + # def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager): -def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True): +def create_agent_from_preset( + agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True +): """Initialize a new agent from a preset (combination of system + function)""" # Input validation @@ -25,40 +79,21 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, raise ValueError(f"'state' must be uninitialized (empty)") preset_name = agent_state.preset + assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'" persona = agent_state.persona human = agent_state.human model = agent_state.llm_config.model from memgpt.agent import Agent - from memgpt.utils import printd - # Available functions is a mapping from: - # function_name -> { - # json_schema: schema - # python_function: function - # } - available_functions = load_all_function_sets() + # available_presets = load_all_presets() + # if preset_name not in available_presets: + # raise ValueError(f"Preset '{preset_name}.yaml' not found") - available_presets = load_all_presets() - if preset_name not in available_presets: - raise ValueError(f"Preset '{preset_name}.yaml' not found") - - preset = available_presets[preset_name] - if not is_valid_yaml_format(preset, list(available_functions.keys())): - raise ValueError(f"Preset '{preset_name}.yaml' is not valid") - - preset_system_prompt = preset["system_prompt"] - preset_function_set_names = preset["functions"] - - # Filter down the function set based on what the preset requested - preset_function_set = {} - for f_name in preset_function_set_names: - if f_name not in available_functions: - raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}") - preset_function_set[f_name] = available_functions[f_name] - assert len(preset_function_set_names) == len(preset_function_set) - preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()] - printd(f"Available functions:\n", list(preset_function_set.keys())) + # preset = available_presets[preset_name] + # preset_system_prompt = preset["system_prompt"] + # preset_function_set_names = preset["functions"] + # preset_function_set_schemas = generate_functions_json(preset_function_set_names) # Override the following in the AgentState: # persona: str # the current persona text @@ -69,8 +104,8 @@ def create_agent_from_preset(agent_state: AgentState, interface: AgentInterface, agent_state.state = { "persona": get_persona_text(persona) if persona_is_file else persona, "human": get_human_text(human) if human_is_file else human, - "system": gpt_system.get_system_text(preset_system_prompt), - "functions": preset_function_set_schemas, + "system": preset.system, + "functions": preset.functions_schema, "messages": None, } diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 2a15e4e48..ccd71ce05 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -31,6 +31,7 @@ from memgpt.data_types import ( EmbeddingConfig, Message, ToolCall, + Preset, ) @@ -582,7 +583,10 @@ class SyncServer(LockingServer): logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}") try: - agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface) + preset = self.ms.get_preset(preset_name=agent_state.preset, user_id=user_id) + assert preset is not None, f"preset {agent_state.preset} does not exist" + + agent = presets.create_agent_from_preset(agent_state=agent_state, preset=preset, interface=interface) save_agent(agent=agent, ms=self.ms) # FIXME: this is a hacky way to get the system prompts injected into agent into the DB @@ -613,6 +617,20 @@ class SyncServer(LockingServer): if agent is not None: self.ms.delete_agent(agent_id=agent_id) + def create_preset(self, preset: Preset): + """Create a new preset using a config""" + if self.ms.get_user(user_id=preset.user_id) is None: + raise ValueError(f"User user_id={preset.user_id} does not exist") + + self.ms.create_preset(preset) + return preset + + def get_preset( + self, preset_id: Optional[uuid.UUID] = None, preset_name: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None + ) -> Preset: + """Get the preset""" + return self.ms.get_preset(preset_id=preset_id, preset_name=preset_name, user_id=user_id) + def _agent_state_to_config(self, agent_state: AgentState) -> dict: """Convert AgentState to a dict for a JSON response""" assert agent_state is not None diff --git a/memgpt/utils.py b/memgpt/utils.py index 05e79526c..3527e17e0 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -937,8 +937,11 @@ def list_human_files(): memgpt_defaults = os.listdir(defaults_dir) memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] - user_added = os.listdir(user_dir) - user_added = [os.path.join(user_dir, f) for f in user_added] + if os.path.exists(user_dir): + user_added = os.listdir(user_dir) + user_added = [os.path.join(user_dir, f) for f in user_added] + else: + user_added = [] return memgpt_defaults + user_added @@ -950,8 +953,11 @@ def list_persona_files(): memgpt_defaults = os.listdir(defaults_dir) memgpt_defaults = [os.path.join(defaults_dir, f) for f in memgpt_defaults if f.endswith(".txt")] - user_added = os.listdir(user_dir) - user_added = [os.path.join(user_dir, f) for f in user_added] + if os.path.exists(user_dir): + user_added = os.listdir(user_dir) + user_added = [os.path.join(user_dir, f) for f in user_added] + else: + user_added = [] return memgpt_defaults + user_added diff --git a/tests/test_client.py b/tests/test_client.py index f4ff2eeb7..daaf7b74c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -4,12 +4,18 @@ import os from memgpt import MemGPT from memgpt.config import MemGPTConfig from memgpt import constants -from memgpt.data_types import LLMConfig, EmbeddingConfig +from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset +from memgpt.functions.functions import load_all_function_sets +from memgpt.prompts import gpt_system +from memgpt.constants import DEFAULT_PRESET + + from .utils import wipe_config import uuid test_agent_name = f"test_client_{str(uuid.uuid4())}" +test_preset_name = "test_preset" test_agent_state = None client = None @@ -17,7 +23,7 @@ test_agent_state_post_message = None test_user_id = uuid.uuid4() -def test_create_agent(): +def test_create_preset(): wipe_config() global client if os.getenv("OPENAI_API_KEY"): @@ -25,6 +31,20 @@ def test_create_agent(): else: client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) + available_functions = load_all_function_sets(merge=True) + functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()] + preset = Preset( + name=test_preset_name, + user_id=test_user_id, + description="A preset for testing the MemGPT client", + system=gpt_system.get_system_text(DEFAULT_PRESET), + functions_schema=functions_schema, + ) + client.create_preset(preset) + + +def test_create_agent(): + wipe_config() config = MemGPTConfig.load() # ensure user exists @@ -36,8 +56,7 @@ def test_create_agent(): agent_config={ "user_id": test_user_id, "name": test_agent_name, - "persona": constants.DEFAULT_PERSONA, - "human": constants.DEFAULT_HUMAN, + "preset": test_preset_name, } ) print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}") @@ -109,5 +128,6 @@ def test_save_load(): if __name__ == "__main__": + test_create_preset() test_create_agent() test_user_message() diff --git a/tests/test_server.py b/tests/test_server.py index fd3b65f44..a855fc1fb 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,6 +9,7 @@ from memgpt.credentials import MemGPTCredentials from memgpt.server.server import SyncServer from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User from memgpt.embeddings import embedding_model +from memgpt.presets.presets import add_default_presets from .utils import wipe_config, wipe_memgpt_home @@ -86,6 +87,10 @@ def test_server(): except: raise + # create presets + add_default_presets(user.id, server.ms) + + # create agent agent_state = server.create_agent( user_id=user.id, agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"), diff --git a/tests/test_summarize.py b/tests/test_summarize.py index eb8692cc9..e4b5a03a6 100644 --- a/tests/test_summarize.py +++ b/tests/test_summarize.py @@ -20,17 +20,13 @@ def create_test_agent(): wipe_config() global client if os.getenv("OPENAI_API_KEY"): - client = MemGPT(quickstart="openai") + client = MemGPT(quickstart="openai", user_id=test_user_id) else: - client = MemGPT(quickstart="memgpt_hosted") - - user = client.server.create_user({"id": test_user_id}) - client.user_id = user.id - assert user is not None + client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id) agent_state = client.create_agent( agent_config={ - "user_id": user.id, + "user_id": test_user_id, "name": test_agent_name, "persona": constants.DEFAULT_PERSONA, "human": constants.DEFAULT_HUMAN, diff --git a/tests/test_websocket_interface.py b/tests/test_websocket_interface.py index 031f1d6f6..79b445a6b 100644 --- a/tests/test_websocket_interface.py +++ b/tests/test_websocket_interface.py @@ -79,6 +79,7 @@ async def test_websockets(): model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1", ) + # TODO: get preset to pass in here memgpt_agent = presets.create_agent_from_preset(agent_state, ws_interface) # Mock the user message packaging