feat: Update REST API routes GET information for agents/humans/personas and store humans/personas in DB (#1074)

This commit is contained in:
Sarah Wooders 2024-03-02 13:07:24 -08:00 committed by GitHub
parent 04c6d210d3
commit 15dbe34dfe
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 222 additions and 59 deletions

View File

@ -26,6 +26,7 @@ from memgpt.server.utils import shorten_key_middle
from memgpt.data_types import User, LLMConfig, EmbeddingConfig, Source
from memgpt.metadata import MetadataStore
from memgpt.server.utils import shorten_key_middle
from memgpt.models.pydantic_models import HumanModel, PersonaModel
app = typer.Typer()
@ -761,20 +762,15 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all humans"""
table = PrettyTable()
table.field_names = ["Name", "Text"]
for human_file in utils.list_human_files():
text = open(human_file, "r").read()
name = os.path.basename(human_file).replace("txt", "")
table.add_row([name, text])
for human in ms.list_humans(user_id=user_id):
table.add_row([human.name, human.text])
print(table)
elif arg == ListChoice.personas:
"""List all personas"""
table = PrettyTable()
table.field_names = ["Name", "Text"]
for persona_file in utils.list_persona_files():
print(persona_file)
text = open(persona_file, "r").read()
name = os.path.basename(persona_file).replace(".txt", "")
table.add_row([name, text])
for persona in ms.list_personas(user_id=user_id):
table.add_row([persona.name, persona.text])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""
@ -826,24 +822,16 @@ def add(
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
):
"""Add a person/human"""
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
if option == "persona":
directory = os.path.join(MEMGPT_DIR, "personas")
ms.add_persona(PersonaModel(name=name, text=text, user_id=user_id))
elif option == "human":
directory = os.path.join(MEMGPT_DIR, "humans")
ms.add_human(HumanModel(name=name, text=text, user_id=user_id))
else:
raise ValueError(f"Unknown kind {option}")
if filename:
assert text is None, f"Cannot provide both filename and text"
# copy file to directory
shutil.copyfile(filename, os.path.join(directory, name))
if text:
assert filename is None, f"Cannot provide both filename and text"
# write text to file
with open(os.path.join(directory, name), "w", encoding="utf-8") as f:
f.write(text)
@app.command()
def delete(option: str, name: str):
@ -886,6 +874,10 @@ def delete(option: str, name: str):
# metadata
ms.delete_agent(agent_id=agent.id)
elif option == "human":
ms.delete_human(name=name, user_id=user_id)
elif option == "persona":
ms.delete_persona(name=name, user_id=user_id)
else:
raise ValueError(f"Option {option} not implemented")

View File

@ -11,6 +11,8 @@ from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.models.pydantic_models import PersonaModel, HumanModel
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
@ -318,6 +320,8 @@ class MetadataStore:
TokenModel.__table__,
PresetModel.__table__,
PresetSourceMapping.__table__,
HumanModel.__table__,
PersonaModel.__table__,
],
)
self.session_maker = sessionmaker(bind=self.engine)
@ -599,3 +603,58 @@ class MetadataStore:
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
).delete()
session.commit()
@enforce_types
def add_human(self, human: HumanModel):
with self.session_maker() as session:
session.add(human)
session.commit()
@enforce_types
def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
session.add(persona)
session.commit()
@enforce_types
def get_human(self, name: str, user_id: uuid.UUID) -> str:
with self.session_maker() as session:
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
@enforce_types
def get_persona(self, name: str, user_id: uuid.UUID) -> str:
with self.session_maker() as session:
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
@enforce_types
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
with self.session_maker() as session:
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
return results
@enforce_types
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
with self.session_maker() as session:
# if user_id matches provided user_id or if user_id is None
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
return results
@enforce_types
def delete_human(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
session.commit()
@enforce_types
def delete_persona(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
session.commit()

View File

@ -2,6 +2,11 @@ from typing import List, Union, Optional, Dict, Literal
from enum import Enum
from pydantic import BaseModel, Field, Json
import uuid
from datetime import datetime
from sqlmodel import Field, SQLModel
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.utils import get_human_text, get_persona_text, printd
class LLMConfigModel(BaseModel):
@ -20,14 +25,50 @@ class EmbeddingConfigModel(BaseModel):
embedding_chunk_size: Optional[int] = 300
class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: str = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
# timestamps
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
# preset information
preset: str = Field(..., description="The preset used by the agent.")
persona: str = Field(..., description="The persona used by the agent.")
human: str = Field(..., description="The human used by the agent.")
functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
# llm information
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.")
class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")

View File

@ -1,12 +1,14 @@
from typing import List
import os
from memgpt.data_types import AgentState, Preset
from memgpt.interface import AgentInterface
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
from memgpt.utils import get_human_text, get_persona_text, printd
from memgpt.utils import get_human_text, get_persona_text, printd, list_human_files, list_persona_files
from memgpt.prompts import gpt_system
from memgpt.functions.functions import load_all_function_sets
from memgpt.metadata import MetadataStore
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.models.pydantic_models import HumanModel, PersonaModel
import uuid
@ -15,15 +17,37 @@ available_presets = load_all_presets()
preset_options = list(available_presets.keys())
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
for persona_file in list_persona_files():
text = open(persona_file, "r").read()
name = os.path.basename(persona_file).replace(".txt", "")
if ms.get_persona(user_id=user_id, name=name) is not None:
printd(f"Persona '{name}' already exists for user '{user_id}'")
continue
persona = PersonaModel(name=name, text=text, user_id=user_id)
ms.add_persona(persona)
for human_file in list_human_files():
text = open(human_file, "r").read()
name = os.path.basename(human_file).replace(".txt", "")
if ms.get_human(user_id=user_id, name=name) is not None:
printd(f"Human '{name}' already exists for user '{user_id}'")
continue
human = HumanModel(name=name, text=text, user_id=user_id)
ms.add_human(human)
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
"""Add the default presets to the metadata store"""
# make sure humans/personas added
add_default_humans_and_personas(user_id=user_id, ms=ms)
# add default presets
for preset_name in preset_options:
preset_config = available_presets[preset_name]
preset_system_prompt = preset_config["system_prompt"]
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
print("PRESET", preset_name, user_id)
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue

View File

@ -5,8 +5,9 @@ from functools import partial
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from typing import List, Optional
from memgpt.models.pydantic_models import AgentStateModel
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -14,7 +15,7 @@ from memgpt.server.server import SyncServer
router = APIRouter()
class AgentConfigRequest(BaseModel):
class GetAgentRequest(BaseModel):
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
@ -23,9 +24,11 @@ class AgentRenameRequest(BaseModel):
agent_name: str = Field(..., description="New name for the agent.")
class AgentConfigResponse(BaseModel):
class GetAgentResponse(BaseModel):
# config: dict = Field(..., description="The agent configuration object.")
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")
def validate_agent_name(name: str) -> str:
@ -48,7 +51,7 @@ def validate_agent_name(name: str) -> str:
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
@ -58,15 +61,36 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
request = AgentConfigRequest(agent_id=agent_id)
request = GetAgentRequest(agent_id=agent_id)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
attached_sources = server.list_attached_sources(agent_id=agent_id)
interface.clear()
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
return AgentConfigResponse(agent_state=agent_state)
# return GetAgentResponse(agent_state=agent_state)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
@router.patch("/agents/rename", tags=["agents"], response_model=AgentConfigResponse)
return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
),
last_run_at=None, # TODO
sources=attached_sources,
)
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
@ -87,7 +111,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return AgentConfigResponse(agent_state=agent_state)
return GetAgentResponse(agent_state=agent_state)
@router.delete("/agents", tags=["agents"])
def delete_agent(
@ -97,7 +121,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
"""
Delete an agent.
"""
request = AgentConfigRequest(agent_id=agent_id)
request = GetAgentRequest(agent_id=agent_id)
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

View File

@ -69,6 +69,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
)
)
# return CreateAgentResponse(

View File

@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.models.pydantic_models import HumanModel
router = APIRouter()
class ListHumansResponse(BaseModel):
humans: List[dict] = Field(..., description="List of human configurations.")
humans: List[HumanModel] = Field(..., description="List of human configurations.")
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@ -25,14 +26,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
):
# Clear the interface
interface.clear()
# TODO: Replace with actual data fetching logic once available
humans_data = [
{"name": "Marco", "text": "About Me"},
{"name": "Sam", "text": "About Me 2"},
{"name": "Bruce", "text": "About Me 3"},
]
return ListHumansResponse(humans=humans_data)
humans = server.ms.list_humans(user_id=user_id)
return ListHumansResponse(humans=humans)
return router

View File

@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.models.pydantic_models import PersonaModel
router = APIRouter()
class ListPersonasResponse(BaseModel):
personas: List[dict] = Field(..., description="List of persona configurations.")
personas: List[PersonaModel] = Field(..., description="List of persona configurations.")
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@ -26,13 +27,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
# Clear the interface
interface.clear()
# TODO: Replace with actual data fetching logic once available
personas_data = [
{"name": "Persona 1", "text": "Details about Persona 1"},
{"name": "Persona 2", "text": "Details about Persona 2"},
{"name": "Persona 3", "text": "Details about Persona 3"},
]
return ListPersonasResponse(personas=personas_data)
personas = server.ms.list_personas(user_id=user_id)
return ListPersonasResponse(personas=personas)
return router

View File

@ -1071,3 +1071,7 @@ class SyncServer(LockingServer):
# attach source to agent
agent.attach_source(data_source.name, source_connector, self.ms)
def list_attached_sources(self, agent_id: uuid.UUID):
# list all attached sources to an agent
return self.ms.list_attached_sources(agent_id)

20
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "aiohttp"
@ -3822,7 +3822,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -4394,6 +4393,21 @@ sqlalchemy = ">=0.7"
[package.extras]
dev = ["pytest"]
[[package]]
name = "sqlmodel"
version = "0.0.16"
description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness."
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "sqlmodel-0.0.16-py3-none-any.whl", hash = "sha256:b972f5d319580d6c37ecc417881f6ec4d1ad3ed3583d0ac0ed43234a28bf605a"},
{file = "sqlmodel-0.0.16.tar.gz", hash = "sha256:966656f18a8e9a2d159eb215b07fb0cf5222acfae3362707ca611848a8a06bd1"},
]
[package.dependencies]
pydantic = ">=1.10.13,<3.0.0"
SQLAlchemy = ">=2.0.0,<2.1.0"
[[package]]
name = "stack-data"
version = "0.6.3"
@ -5533,4 +5547,4 @@ server = ["fastapi", "uvicorn", "websockets"]
[metadata]
lock-version = "2.0"
python-versions = "<3.12,>=3.10"
content-hash = "a420d69943e7e1c2b6900ac9ac9d9fd51ff9e16ff245bcea8e590a276b736251"
content-hash = "22ed7617b3b586152c81e359b366192a9f1708265933b8ed2dc3ad041b798264"

View File

@ -56,6 +56,7 @@ llama-index = "^0.10.6"
llama-index-embeddings-openai = "^0.1.1"
python-box = "^7.1.1"
pytest-order = {version = "^1.2.0", optional = true}
sqlmodel = "^0.0.16"
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]

View File

@ -7,6 +7,9 @@ from memgpt.metadata import MetadataStore
from memgpt.config import MemGPTConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
from memgpt.utils import get_human_text, get_persona_text
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas
from memgpt.models.pydantic_models import HumanModel, PersonaModel
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@ -26,9 +29,22 @@ def test_storage(storage_connector):
ms = MetadataStore(config)
# generate data
# users
user_1 = User()
user_2 = User()
ms.create_user(user_1)
ms.create_user(user_2)
# test adding defaults
# TODO: move below
add_default_humans_and_personas(user_id=user_1.id, ms=ms)
add_default_humans_and_personas(user_id=user_2.id, ms=ms)
ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
add_default_presets(user_id=user_1.id, ms=ms)
add_default_presets(user_id=user_2.id, ms=ms)
# generate data
agent_1 = AgentState(
user_id=user_1.id,
name="agent_1",
@ -41,8 +57,6 @@ def test_storage(storage_connector):
source_1 = Source(user_id=user_1.id, name="source_1")
# test creation
ms.create_user(user_1)
ms.create_user(user_2)
ms.create_agent(agent_1)
ms.create_source(source_1)