feat: add Preset routes to API + patch for tool_call_id max length OpenAI error (#1165)

This commit is contained in:
Charles Packer 2024-03-20 17:05:06 -07:00 committed by GitHub
parent 091ee663de
commit 89cc4b98ed
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 272 additions and 18 deletions

View File

@ -5,7 +5,7 @@ import uuid
from typing import Dict, List, Union, Optional, Tuple
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
@ -30,6 +30,7 @@ from memgpt.server.rest_api.humans.index import ListHumansResponse
from memgpt.server.rest_api.personas.index import ListPersonasResponse
from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse
from memgpt.server.rest_api.models.index import ListModelsResponse
from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
@ -85,6 +86,12 @@ class AbstractClient(object):
def create_preset(self, preset: Preset):
raise NotImplementedError
def delete_preset(self, preset_id: uuid.UUID):
raise NotImplementedError
def list_presets(self):
raise NotImplementedError
# memory
def get_agent_memory(self, agent_id: str) -> Dict:
@ -300,11 +307,32 @@ class RESTClient(AbstractClient):
return self.get_agent_response_to_state(response_obj)
# presets
def create_preset(self, preset: Preset):
raise NotImplementedError
def create_preset(self, preset: Preset) -> CreatePresetResponse:
# TODO should the arg type here be PresetModel, not Preset?
payload = CreatePresetsRequest(
id=str(preset.id),
name=preset.name,
description=preset.description,
system=preset.system,
persona=preset.persona,
human=preset.human,
persona_name=preset.persona_name,
human_name=preset.human_name,
functions_schema=preset.functions_schema,
)
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
assert response.status_code == 200, f"Failed to create preset: {response.text}"
return CreatePresetResponse(**response.json())
def delete_preset(self, preset_id: uuid.UUID):
response = requests.delete(f"{self.base_url}/api/presets/{str(preset_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete preset: {response.text}"
def list_presets(self) -> List[PresetModel]:
response = requests.get(f"{self.base_url}/api/presets", headers=self.headers)
return ListPresetsResponse(**response.json()).presets
# memory
def get_agent_memory(self, agent_id: uuid.UUID) -> GetAgentMemoryResponse:
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/memory", headers=self.headers)
return GetAgentMemoryResponse(**response.json())
@ -542,10 +570,18 @@ class LocalClient(AbstractClient):
)
return agent_state
def create_preset(self, preset: Preset):
def create_preset(self, preset: Preset) -> Preset:
if preset.user_id is None:
preset.user_id = self.user_id
preset = self.server.create_preset(preset=preset)
return preset
def delete_preset(self, preset_id: uuid.UUID):
preset = self.server.delete_preset(preset_id=preset_id, user_id=self.user_id)
def list_presets(self) -> List[PresetModel]:
return self.server.list_presets(user_id=self.user_id)
def get_agent_config(self, agent_id: str) -> AgentState:
self.interface.clear()
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)

View File

@ -3,6 +3,9 @@ from logging import CRITICAL, ERROR, WARN, WARNING, INFO, DEBUG, NOTSET
MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")
# OpenAI error message: Invalid 'messages[1].tool_calls[0].id': string too long. Expected a string with maximum length 29, but got a string with length 36 instead.
TOOL_CALL_ID_MAX_LEN = 29
# embeddings
MAX_EMBEDDING_DIM = 4096 # maximum supported embeding size - do NOT change or else DBs will need to be reset

View File

@ -5,7 +5,15 @@ from datetime import datetime
from typing import Optional, List, Dict, TypeVar
import numpy as np
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
from memgpt.constants import (
DEFAULT_HUMAN,
DEFAULT_MEMGPT_MODEL,
DEFAULT_PERSONA,
DEFAULT_PRESET,
LLM_MAX_TOKENS,
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_uuid_from_string
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd
@ -229,7 +237,7 @@ class Message(Record):
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)
def to_openai_dict(self):
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN):
"""Go from Message class to ChatCompletion message object"""
# TODO change to pydantic casting, eg `return SystemMessageModel(self)`
@ -265,13 +273,16 @@ class Message(Record):
openai_message["name"] = self.name
if self.tool_calls is not None:
openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
if max_tool_id_length:
for tool_call_dict in openai_message["tool_calls"]:
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
elif self.role == "tool":
assert all([v is not None for v in [self.role, self.tool_call_id]]), vars(self)
openai_message = {
"content": self.text,
"role": self.role,
"tool_call_id": self.tool_call_id,
"tool_call_id": self.tool_call_id[:max_tool_id_length] if max_tool_id_length else self.tool_call_id,
}
else:
raise ValueError(self.role)
@ -540,7 +551,7 @@ class Token:
class Preset(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")

View File

@ -33,12 +33,14 @@ class EmbeddingConfigModel(BaseModel):
class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")

View File

@ -0,0 +1,121 @@
import uuid
from functools import partial
from typing import List, Optional, Dict, Union
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.data_types import Preset # TODO remove
from memgpt.models.pydantic_models import PresetModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text
router = APIRouter()
"""
Implement the following functions:
* List all available presets
* Create a new preset
* Delete a preset
* TODO update a preset
"""
class ListPresetsResponse(BaseModel):
presets: List[PresetModel] = Field(..., description="List of available presets.")
class CreatePresetsRequest(BaseModel):
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
name: str = Field(..., description="The name of the preset.")
id: Optional[Union[uuid.UUID, str]] = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
# created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# TODO
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
class CreatePresetResponse(BaseModel):
preset: PresetModel = Field(..., description="The newly created preset.")
def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
async def list_presets(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all presets created by a user."""
# Clear the interface
interface.clear()
try:
presets = server.list_presets(user_id=user_id)
return ListPresetsResponse(presets=presets)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/presets", tags=["presets"], response_model=CreatePresetResponse)
async def create_preset(
request: CreatePresetsRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Create a preset."""
try:
if isinstance(request.id, str):
request.id = uuid.UUID(request.id)
# new_preset = PresetModel(
new_preset = Preset(
user_id=user_id,
id=request.id,
name=request.name,
description=request.description,
system=request.system,
persona=request.persona,
human=request.human,
functions_schema=request.functions_schema,
persona_name=request.persona_name,
human_name=request.human_name,
)
preset = server.create_preset(preset=new_preset)
# TODO remove once we migrate from Preset to PresetModel
preset = PresetModel(**vars(preset))
return CreatePresetResponse(preset=preset)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.delete("/presets/{preset_id}", tags=["presets"])
async def delete_preset(
preset_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a preset."""
interface.clear()
try:
preset = server.delete_preset(user_id=user_id, preset_id=preset_id)
return JSONResponse(
status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router

View File

@ -25,6 +25,7 @@ from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_tools_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.presets.index import setup_presets_index_router
from memgpt.server.server import SyncServer
from memgpt.config import MemGPTConfig
from memgpt.server.constants import REST_DEFAULT_PORT
@ -102,6 +103,7 @@ app.include_router(setup_personas_index_router(server, interface, password), pre
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)

View File

@ -65,15 +65,20 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
async def list_source(
async def list_sources(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all data sources created by a user."""
# Clear the interface
interface.clear()
sources = server.list_all_sources(user_id=user_id)
return ListSourcesResponse(sources=sources)
try:
sources = server.list_all_sources(user_id=user_id)
return ListSourcesResponse(sources=sources)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources", tags=["sources"], response_model=SourceModel)
async def create_source(
@ -100,13 +105,13 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
@router.delete("/sources/{source_id}", tags=["sources"])
async def delete_source(
source_id,
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a data source."""
interface.clear()
try:
server.delete_source(source_id=uuid.UUID(source_id), user_id=user_id)
server.delete_source(source_id=source_id, user_id=user_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
except HTTPException:
raise

View File

@ -37,7 +37,7 @@ from memgpt.data_types import (
Preset,
)
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel
from memgpt.models.pydantic_models import SourceModel, PassageModel, DocumentModel, PresetModel
from memgpt.interface import AgentInterface # abstract
# TODO use custom interface
@ -751,13 +751,27 @@ class SyncServer(LockingServer):
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)
def delete_preset(self, user_id: uuid.UUID, preset_id: uuid.UUID) -> Preset:
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
# first get the preset by name
preset = self.get_preset(preset_id=preset_id, user_id=user_id)
if preset is None:
raise ValueError(f"Could not find preset_id {preset_id}")
# then delete via name
# TODO allow delete-by-id, eg via server.delete_preset function
self.ms.delete_preset(name=preset.name, user_id=user_id)
return preset
def initialize_default_presets(self, user_id: uuid.UUID):
"""Add default preset options into the metadata store"""
presets.add_default_presets(user_id, self.ms)
def create_preset(self, preset: Preset):
"""Create a new preset using a config"""
if self.ms.get_user(user_id=preset.user_id) is None:
if preset.user_id is not None and self.ms.get_user(user_id=preset.user_id) is None:
raise ValueError(f"User user_id={preset.user_id} does not exist")
self.ms.create_preset(preset)
@ -769,6 +783,13 @@ class SyncServer(LockingServer):
"""Get the preset"""
return self.ms.get_preset(preset_id=preset_id, name=preset_name, user_id=user_id)
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
# TODO update once we strip Preset in favor of PresetModel
presets = self.ms.list_presets(user_id=user_id)
presets = [PresetModel(**vars(p)) for p in presets]
return presets
def _agent_state_to_config(self, agent_state: AgentState) -> dict:
"""Convert AgentState to a dict for a JSON response"""
assert agent_state is not None

View File

@ -33,6 +33,7 @@ from memgpt.constants import (
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
JSON_ENSURE_ASCII,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.models.chat_completion_response import ChatCompletionResponse
@ -469,7 +470,7 @@ NOUN_BANK = [
def get_tool_call_id() -> str:
return str(uuid.uuid4())
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]
def assistant_function_to_tool(assistant_message: dict) -> dict:

View File

@ -7,6 +7,7 @@ from dotenv import load_dotenv
from memgpt.server.rest_api.server import start_server
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from memgpt.data_types import Preset # TODO move to PresetModel
from dotenv import load_dotenv
from tests.config import TestMGPTConfig
@ -287,3 +288,54 @@ def test_sources(client, agent):
# delete the source
client.delete_source(source.id)
def test_presets(client, agent):
new_preset = Preset(
# user_id=client.user_id,
name="pytest_test_preset",
description="DUMMY_DESCRIPTION",
system="DUMMY_SYSTEM",
persona="DUMMY_PERSONA",
persona_name="DUMMY_PERSONA_NAME",
human="DUMMY_HUMAN",
human_name="DUMMY_HUMAN_NAME",
functions_schema=[
{
"name": "send_message",
"json_schema": {
"name": "send_message",
"description": "Sends a message to the human user.",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
},
"required": ["message"],
},
},
"tags": ["memgpt-base"],
"source_type": "python",
"source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
}
],
)
# List all presets and make sure the preset is NOT in the list
all_presets = client.list_presets()
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# Create a preset
client.create_preset(preset=new_preset)
# List all presets and make sure the preset is in the list
all_presets = client.list_presets()
assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
# Delete the preset
client.delete_preset(preset_id=new_preset.id)
# List all presets and make sure the preset is NOT in the list
all_presets = client.list_presets()
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)