mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Update REST API routes GET information for agents/humans/personas and store humans/personas in DB (#1074)
This commit is contained in:
parent
04c6d210d3
commit
15dbe34dfe
@ -26,6 +26,7 @@ from memgpt.server.utils import shorten_key_middle
|
||||
from memgpt.data_types import User, LLMConfig, EmbeddingConfig, Source
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.server.utils import shorten_key_middle
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@ -761,20 +762,15 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
"""List all humans"""
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Text"]
|
||||
for human_file in utils.list_human_files():
|
||||
text = open(human_file, "r").read()
|
||||
name = os.path.basename(human_file).replace("txt", "")
|
||||
table.add_row([name, text])
|
||||
for human in ms.list_humans(user_id=user_id):
|
||||
table.add_row([human.name, human.text])
|
||||
print(table)
|
||||
elif arg == ListChoice.personas:
|
||||
"""List all personas"""
|
||||
table = PrettyTable()
|
||||
table.field_names = ["Name", "Text"]
|
||||
for persona_file in utils.list_persona_files():
|
||||
print(persona_file)
|
||||
text = open(persona_file, "r").read()
|
||||
name = os.path.basename(persona_file).replace(".txt", "")
|
||||
table.add_row([name, text])
|
||||
for persona in ms.list_personas(user_id=user_id):
|
||||
table.add_row([persona.name, persona.text])
|
||||
print(table)
|
||||
elif arg == ListChoice.sources:
|
||||
"""List all data sources"""
|
||||
@ -826,24 +822,16 @@ def add(
|
||||
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
|
||||
):
|
||||
"""Add a person/human"""
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
ms = MetadataStore(config)
|
||||
if option == "persona":
|
||||
directory = os.path.join(MEMGPT_DIR, "personas")
|
||||
ms.add_persona(PersonaModel(name=name, text=text, user_id=user_id))
|
||||
elif option == "human":
|
||||
directory = os.path.join(MEMGPT_DIR, "humans")
|
||||
ms.add_human(HumanModel(name=name, text=text, user_id=user_id))
|
||||
else:
|
||||
raise ValueError(f"Unknown kind {option}")
|
||||
|
||||
if filename:
|
||||
assert text is None, f"Cannot provide both filename and text"
|
||||
# copy file to directory
|
||||
shutil.copyfile(filename, os.path.join(directory, name))
|
||||
if text:
|
||||
assert filename is None, f"Cannot provide both filename and text"
|
||||
# write text to file
|
||||
with open(os.path.join(directory, name), "w", encoding="utf-8") as f:
|
||||
f.write(text)
|
||||
|
||||
|
||||
@app.command()
|
||||
def delete(option: str, name: str):
|
||||
@ -886,6 +874,10 @@ def delete(option: str, name: str):
|
||||
# metadata
|
||||
ms.delete_agent(agent_id=agent.id)
|
||||
|
||||
elif option == "human":
|
||||
ms.delete_human(name=name, user_id=user_id)
|
||||
elif option == "persona":
|
||||
ms.delete_persona(name=name, user_id=user_id)
|
||||
else:
|
||||
raise ValueError(f"Option {option} not implemented")
|
||||
|
||||
|
@ -11,6 +11,8 @@ from memgpt.utils import get_local_time, enforce_types
|
||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
from memgpt.models.pydantic_models import PersonaModel, HumanModel
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
|
||||
@ -318,6 +320,8 @@ class MetadataStore:
|
||||
TokenModel.__table__,
|
||||
PresetModel.__table__,
|
||||
PresetSourceMapping.__table__,
|
||||
HumanModel.__table__,
|
||||
PersonaModel.__table__,
|
||||
],
|
||||
)
|
||||
self.session_maker = sessionmaker(bind=self.engine)
|
||||
@ -599,3 +603,58 @@ class MetadataStore:
|
||||
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
|
||||
).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def add_human(self, human: HumanModel):
|
||||
with self.session_maker() as session:
|
||||
session.add(human)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def add_persona(self, persona: PersonaModel):
|
||||
with self.session_maker() as session:
|
||||
session.add(persona)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def get_human(self, name: str, user_id: uuid.UUID) -> str:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0]
|
||||
|
||||
@enforce_types
|
||||
def get_persona(self, name: str, user_id: uuid.UUID) -> str:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0]
|
||||
|
||||
@enforce_types
|
||||
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
|
||||
with self.session_maker() as session:
|
||||
# if user_id matches provided user_id or if user_id is None
|
||||
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
|
||||
return results
|
||||
|
||||
@enforce_types
|
||||
def delete_human(self, name: str, user_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_persona(self, name: str, user_id: uuid.UUID):
|
||||
with self.session_maker() as session:
|
||||
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
|
||||
session.commit()
|
||||
|
@ -2,6 +2,11 @@ from typing import List, Union, Optional, Dict, Literal
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, Field, Json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from sqlmodel import Field, SQLModel
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
|
||||
|
||||
class LLMConfigModel(BaseModel):
|
||||
@ -20,14 +25,50 @@ class EmbeddingConfigModel(BaseModel):
|
||||
embedding_chunk_size: Optional[int] = 300
|
||||
|
||||
|
||||
class PresetModel(BaseModel):
|
||||
name: str = Field(..., description="The name of the preset.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
|
||||
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
|
||||
system: str = Field(..., description="The system prompt of the preset.")
|
||||
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||
|
||||
|
||||
class AgentStateModel(BaseModel):
|
||||
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
|
||||
name: str = Field(..., description="The name of the agent.")
|
||||
description: str = Field(None, description="The description of the agent.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
|
||||
|
||||
# timestamps
|
||||
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
|
||||
|
||||
# preset information
|
||||
preset: str = Field(..., description="The preset used by the agent.")
|
||||
persona: str = Field(..., description="The persona used by the agent.")
|
||||
human: str = Field(..., description="The human used by the agent.")
|
||||
functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
|
||||
|
||||
# llm information
|
||||
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
|
||||
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
|
||||
|
||||
# agent state
|
||||
state: Optional[Dict] = Field(None, description="The state of the agent.")
|
||||
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
|
||||
|
||||
|
||||
class HumanModel(SQLModel, table=True):
|
||||
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
|
||||
name: str = Field(..., description="The name of the human.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
|
||||
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.")
|
||||
|
||||
|
||||
class PersonaModel(SQLModel, table=True):
|
||||
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
|
||||
name: str = Field(..., description="The name of the persona.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
|
||||
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
|
||||
|
@ -1,12 +1,14 @@
|
||||
from typing import List
|
||||
import os
|
||||
from memgpt.data_types import AgentState, Preset
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||
from memgpt.utils import get_human_text, get_persona_text, printd, list_human_files, list_persona_files
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
|
||||
import uuid
|
||||
|
||||
@ -15,15 +17,37 @@ available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
|
||||
|
||||
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
|
||||
for persona_file in list_persona_files():
|
||||
text = open(persona_file, "r").read()
|
||||
name = os.path.basename(persona_file).replace(".txt", "")
|
||||
if ms.get_persona(user_id=user_id, name=name) is not None:
|
||||
printd(f"Persona '{name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
persona = PersonaModel(name=name, text=text, user_id=user_id)
|
||||
ms.add_persona(persona)
|
||||
for human_file in list_human_files():
|
||||
text = open(human_file, "r").read()
|
||||
name = os.path.basename(human_file).replace(".txt", "")
|
||||
if ms.get_human(user_id=user_id, name=name) is not None:
|
||||
printd(f"Human '{name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
human = HumanModel(name=name, text=text, user_id=user_id)
|
||||
ms.add_human(human)
|
||||
|
||||
|
||||
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
"""Add the default presets to the metadata store"""
|
||||
# make sure humans/personas added
|
||||
add_default_humans_and_personas(user_id=user_id, ms=ms)
|
||||
|
||||
# add default presets
|
||||
for preset_name in preset_options:
|
||||
preset_config = available_presets[preset_name]
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
print("PRESET", preset_name, user_id)
|
||||
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
|
@ -5,8 +5,9 @@ from functools import partial
|
||||
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
|
||||
from memgpt.models.pydantic_models import AgentStateModel
|
||||
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
@ -14,7 +15,7 @@ from memgpt.server.server import SyncServer
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class AgentConfigRequest(BaseModel):
|
||||
class GetAgentRequest(BaseModel):
|
||||
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
||||
|
||||
|
||||
@ -23,9 +24,11 @@ class AgentRenameRequest(BaseModel):
|
||||
agent_name: str = Field(..., description="New name for the agent.")
|
||||
|
||||
|
||||
class AgentConfigResponse(BaseModel):
|
||||
class GetAgentResponse(BaseModel):
|
||||
# config: dict = Field(..., description="The agent configuration object.")
|
||||
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
|
||||
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
|
||||
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")
|
||||
|
||||
|
||||
def validate_agent_name(name: str) -> str:
|
||||
@ -48,7 +51,7 @@ def validate_agent_name(name: str) -> str:
|
||||
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
|
||||
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
|
||||
def get_agent_config(
|
||||
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
@ -58,15 +61,36 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
|
||||
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||
"""
|
||||
request = AgentConfigRequest(agent_id=agent_id)
|
||||
request = GetAgentRequest(agent_id=agent_id)
|
||||
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
interface.clear()
|
||||
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
|
||||
return AgentConfigResponse(agent_state=agent_state)
|
||||
# return GetAgentResponse(agent_state=agent_state)
|
||||
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
|
||||
@router.patch("/agents/rename", tags=["agents"], response_model=AgentConfigResponse)
|
||||
return GetAgentResponse(
|
||||
agent_state=AgentStateModel(
|
||||
id=agent_state.id,
|
||||
name=agent_state.name,
|
||||
user_id=agent_state.user_id,
|
||||
preset=agent_state.preset,
|
||||
persona=agent_state.persona,
|
||||
human=agent_state.human,
|
||||
llm_config=agent_state.llm_config,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
),
|
||||
last_run_at=None, # TODO
|
||||
sources=attached_sources,
|
||||
)
|
||||
|
||||
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||
def update_agent_name(
|
||||
request: AgentRenameRequest = Body(...),
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
@ -87,7 +111,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return AgentConfigResponse(agent_state=agent_state)
|
||||
return GetAgentResponse(agent_state=agent_state)
|
||||
|
||||
@router.delete("/agents", tags=["agents"])
|
||||
def delete_agent(
|
||||
@ -97,7 +121,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
"""
|
||||
Delete an agent.
|
||||
"""
|
||||
request = AgentConfigRequest(agent_id=agent_id)
|
||||
request = GetAgentRequest(agent_id=agent_id)
|
||||
|
||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||
|
||||
|
@ -69,6 +69,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||
)
|
||||
)
|
||||
# return CreateAgentResponse(
|
||||
|
@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.models.pydantic_models import HumanModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListHumansResponse(BaseModel):
|
||||
humans: List[dict] = Field(..., description="List of human configurations.")
|
||||
humans: List[HumanModel] = Field(..., description="List of human configurations.")
|
||||
|
||||
|
||||
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
@ -25,14 +26,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
):
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
|
||||
# TODO: Replace with actual data fetching logic once available
|
||||
humans_data = [
|
||||
{"name": "Marco", "text": "About Me"},
|
||||
{"name": "Sam", "text": "About Me 2"},
|
||||
{"name": "Bruce", "text": "About Me 3"},
|
||||
]
|
||||
|
||||
return ListHumansResponse(humans=humans_data)
|
||||
humans = server.ms.list_humans(user_id=user_id)
|
||||
return ListHumansResponse(humans=humans)
|
||||
|
||||
return router
|
||||
|
@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.models.pydantic_models import PersonaModel
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ListPersonasResponse(BaseModel):
|
||||
personas: List[dict] = Field(..., description="List of persona configurations.")
|
||||
personas: List[PersonaModel] = Field(..., description="List of persona configurations.")
|
||||
|
||||
|
||||
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
@ -26,13 +27,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
|
||||
# Clear the interface
|
||||
interface.clear()
|
||||
|
||||
# TODO: Replace with actual data fetching logic once available
|
||||
personas_data = [
|
||||
{"name": "Persona 1", "text": "Details about Persona 1"},
|
||||
{"name": "Persona 2", "text": "Details about Persona 2"},
|
||||
{"name": "Persona 3", "text": "Details about Persona 3"},
|
||||
]
|
||||
|
||||
return ListPersonasResponse(personas=personas_data)
|
||||
personas = server.ms.list_personas(user_id=user_id)
|
||||
return ListPersonasResponse(personas=personas)
|
||||
|
||||
return router
|
||||
|
@ -1071,3 +1071,7 @@ class SyncServer(LockingServer):
|
||||
|
||||
# attach source to agent
|
||||
agent.attach_source(data_source.name, source_connector, self.ms)
|
||||
|
||||
def list_attached_sources(self, agent_id: uuid.UUID):
|
||||
# list all attached sources to an agent
|
||||
return self.ms.list_attached_sources(agent_id)
|
||||
|
20
poetry.lock
generated
20
poetry.lock
generated
@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohttp"
|
||||
@ -3822,7 +3822,6 @@ files = [
|
||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||
@ -4394,6 +4393,21 @@ sqlalchemy = ">=0.7"
|
||||
[package.extras]
|
||||
dev = ["pytest"]
|
||||
|
||||
[[package]]
|
||||
name = "sqlmodel"
|
||||
version = "0.0.16"
|
||||
description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness."
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "sqlmodel-0.0.16-py3-none-any.whl", hash = "sha256:b972f5d319580d6c37ecc417881f6ec4d1ad3ed3583d0ac0ed43234a28bf605a"},
|
||||
{file = "sqlmodel-0.0.16.tar.gz", hash = "sha256:966656f18a8e9a2d159eb215b07fb0cf5222acfae3362707ca611848a8a06bd1"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pydantic = ">=1.10.13,<3.0.0"
|
||||
SQLAlchemy = ">=2.0.0,<2.1.0"
|
||||
|
||||
[[package]]
|
||||
name = "stack-data"
|
||||
version = "0.6.3"
|
||||
@ -5533,4 +5547,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.10"
|
||||
content-hash = "a420d69943e7e1c2b6900ac9ac9d9fd51ff9e16ff245bcea8e590a276b736251"
|
||||
content-hash = "22ed7617b3b586152c81e359b366192a9f1708265933b8ed2dc3ad041b798264"
|
||||
|
@ -56,6 +56,7 @@ llama-index = "^0.10.6"
|
||||
llama-index-embeddings-openai = "^0.1.1"
|
||||
python-box = "^7.1.1"
|
||||
pytest-order = {version = "^1.2.0", optional = true}
|
||||
sqlmodel = "^0.0.16"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
|
@ -7,6 +7,9 @@ from memgpt.metadata import MetadataStore
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas
|
||||
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
|
||||
@ -26,9 +29,22 @@ def test_storage(storage_connector):
|
||||
|
||||
ms = MetadataStore(config)
|
||||
|
||||
# generate data
|
||||
# users
|
||||
user_1 = User()
|
||||
user_2 = User()
|
||||
ms.create_user(user_1)
|
||||
ms.create_user(user_2)
|
||||
|
||||
# test adding defaults
|
||||
# TODO: move below
|
||||
add_default_humans_and_personas(user_id=user_1.id, ms=ms)
|
||||
add_default_humans_and_personas(user_id=user_2.id, ms=ms)
|
||||
ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
|
||||
ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
|
||||
add_default_presets(user_id=user_1.id, ms=ms)
|
||||
add_default_presets(user_id=user_2.id, ms=ms)
|
||||
|
||||
# generate data
|
||||
agent_1 = AgentState(
|
||||
user_id=user_1.id,
|
||||
name="agent_1",
|
||||
@ -41,8 +57,6 @@ def test_storage(storage_connector):
|
||||
source_1 = Source(user_id=user_1.id, name="source_1")
|
||||
|
||||
# test creation
|
||||
ms.create_user(user_1)
|
||||
ms.create_user(user_2)
|
||||
ms.create_agent(agent_1)
|
||||
ms.create_source(source_1)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user