mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Update REST API routes GET information for agents/humans/personas and store humans/personas in DB (#1074)
This commit is contained in:
parent
04c6d210d3
commit
15dbe34dfe
@ -26,6 +26,7 @@ from memgpt.server.utils import shorten_key_middle
|
|||||||
from memgpt.data_types import User, LLMConfig, EmbeddingConfig, Source
|
from memgpt.data_types import User, LLMConfig, EmbeddingConfig, Source
|
||||||
from memgpt.metadata import MetadataStore
|
from memgpt.metadata import MetadataStore
|
||||||
from memgpt.server.utils import shorten_key_middle
|
from memgpt.server.utils import shorten_key_middle
|
||||||
|
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||||
|
|
||||||
app = typer.Typer()
|
app = typer.Typer()
|
||||||
|
|
||||||
@ -761,20 +762,15 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
|||||||
"""List all humans"""
|
"""List all humans"""
|
||||||
table = PrettyTable()
|
table = PrettyTable()
|
||||||
table.field_names = ["Name", "Text"]
|
table.field_names = ["Name", "Text"]
|
||||||
for human_file in utils.list_human_files():
|
for human in ms.list_humans(user_id=user_id):
|
||||||
text = open(human_file, "r").read()
|
table.add_row([human.name, human.text])
|
||||||
name = os.path.basename(human_file).replace("txt", "")
|
|
||||||
table.add_row([name, text])
|
|
||||||
print(table)
|
print(table)
|
||||||
elif arg == ListChoice.personas:
|
elif arg == ListChoice.personas:
|
||||||
"""List all personas"""
|
"""List all personas"""
|
||||||
table = PrettyTable()
|
table = PrettyTable()
|
||||||
table.field_names = ["Name", "Text"]
|
table.field_names = ["Name", "Text"]
|
||||||
for persona_file in utils.list_persona_files():
|
for persona in ms.list_personas(user_id=user_id):
|
||||||
print(persona_file)
|
table.add_row([persona.name, persona.text])
|
||||||
text = open(persona_file, "r").read()
|
|
||||||
name = os.path.basename(persona_file).replace(".txt", "")
|
|
||||||
table.add_row([name, text])
|
|
||||||
print(table)
|
print(table)
|
||||||
elif arg == ListChoice.sources:
|
elif arg == ListChoice.sources:
|
||||||
"""List all data sources"""
|
"""List all data sources"""
|
||||||
@ -826,24 +822,16 @@ def add(
|
|||||||
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
|
filename: Annotated[Optional[str], typer.Option("-f", help="Specify filename")] = None,
|
||||||
):
|
):
|
||||||
"""Add a person/human"""
|
"""Add a person/human"""
|
||||||
|
config = MemGPTConfig.load()
|
||||||
|
user_id = uuid.UUID(config.anon_clientid)
|
||||||
|
ms = MetadataStore(config)
|
||||||
if option == "persona":
|
if option == "persona":
|
||||||
directory = os.path.join(MEMGPT_DIR, "personas")
|
ms.add_persona(PersonaModel(name=name, text=text, user_id=user_id))
|
||||||
elif option == "human":
|
elif option == "human":
|
||||||
directory = os.path.join(MEMGPT_DIR, "humans")
|
ms.add_human(HumanModel(name=name, text=text, user_id=user_id))
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown kind {option}")
|
raise ValueError(f"Unknown kind {option}")
|
||||||
|
|
||||||
if filename:
|
|
||||||
assert text is None, f"Cannot provide both filename and text"
|
|
||||||
# copy file to directory
|
|
||||||
shutil.copyfile(filename, os.path.join(directory, name))
|
|
||||||
if text:
|
|
||||||
assert filename is None, f"Cannot provide both filename and text"
|
|
||||||
# write text to file
|
|
||||||
with open(os.path.join(directory, name), "w", encoding="utf-8") as f:
|
|
||||||
f.write(text)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def delete(option: str, name: str):
|
def delete(option: str, name: str):
|
||||||
@ -886,6 +874,10 @@ def delete(option: str, name: str):
|
|||||||
# metadata
|
# metadata
|
||||||
ms.delete_agent(agent_id=agent.id)
|
ms.delete_agent(agent_id=agent.id)
|
||||||
|
|
||||||
|
elif option == "human":
|
||||||
|
ms.delete_human(name=name, user_id=user_id)
|
||||||
|
elif option == "persona":
|
||||||
|
ms.delete_persona(name=name, user_id=user_id)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Option {option} not implemented")
|
raise ValueError(f"Option {option} not implemented")
|
||||||
|
|
||||||
|
@ -11,6 +11,8 @@ from memgpt.utils import get_local_time, enforce_types
|
|||||||
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
|
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
|
|
||||||
|
from memgpt.models.pydantic_models import PersonaModel, HumanModel
|
||||||
|
|
||||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
|
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
|
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
|
||||||
@ -318,6 +320,8 @@ class MetadataStore:
|
|||||||
TokenModel.__table__,
|
TokenModel.__table__,
|
||||||
PresetModel.__table__,
|
PresetModel.__table__,
|
||||||
PresetSourceMapping.__table__,
|
PresetSourceMapping.__table__,
|
||||||
|
HumanModel.__table__,
|
||||||
|
PersonaModel.__table__,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
self.session_maker = sessionmaker(bind=self.engine)
|
self.session_maker = sessionmaker(bind=self.engine)
|
||||||
@ -599,3 +603,58 @@ class MetadataStore:
|
|||||||
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
|
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
|
||||||
).delete()
|
).delete()
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def add_human(self, human: HumanModel):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.add(human)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def add_persona(self, persona: PersonaModel):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.add(persona)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_human(self, name: str, user_id: uuid.UUID) -> str:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_persona(self, name: str, user_id: uuid.UUID) -> str:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
|
return results[0]
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
|
||||||
|
return results
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
# if user_id matches provided user_id or if user_id is None
|
||||||
|
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
|
||||||
|
return results
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def delete_human(self, name: str, user_id: uuid.UUID):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def delete_persona(self, name: str, user_id: uuid.UUID):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
|
||||||
|
session.commit()
|
||||||
|
@ -2,6 +2,11 @@ from typing import List, Union, Optional, Dict, Literal
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, Field, Json
|
from pydantic import BaseModel, Field, Json
|
||||||
import uuid
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from sqlmodel import Field, SQLModel
|
||||||
|
|
||||||
|
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_PRESET, LLM_MAX_TOKENS, MAX_EMBEDDING_DIM
|
||||||
|
from memgpt.utils import get_human_text, get_persona_text, printd
|
||||||
|
|
||||||
|
|
||||||
class LLMConfigModel(BaseModel):
|
class LLMConfigModel(BaseModel):
|
||||||
@ -20,14 +25,50 @@ class EmbeddingConfigModel(BaseModel):
|
|||||||
embedding_chunk_size: Optional[int] = 300
|
embedding_chunk_size: Optional[int] = 300
|
||||||
|
|
||||||
|
|
||||||
|
class PresetModel(BaseModel):
|
||||||
|
name: str = Field(..., description="The name of the preset.")
|
||||||
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||||
|
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
|
||||||
|
description: Optional[str] = Field(None, description="The description of the preset.")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.now, description="The unix timestamp of when the preset was created.")
|
||||||
|
system: str = Field(..., description="The system prompt of the preset.")
|
||||||
|
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
|
||||||
|
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
|
||||||
|
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
|
||||||
|
|
||||||
|
|
||||||
class AgentStateModel(BaseModel):
|
class AgentStateModel(BaseModel):
|
||||||
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
|
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
|
||||||
name: str = Field(..., description="The name of the agent.")
|
name: str = Field(..., description="The name of the agent.")
|
||||||
|
description: str = Field(None, description="The description of the agent.")
|
||||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
|
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
|
||||||
|
|
||||||
|
# timestamps
|
||||||
|
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
|
||||||
|
|
||||||
|
# preset information
|
||||||
preset: str = Field(..., description="The preset used by the agent.")
|
preset: str = Field(..., description="The preset used by the agent.")
|
||||||
persona: str = Field(..., description="The persona used by the agent.")
|
persona: str = Field(..., description="The persona used by the agent.")
|
||||||
human: str = Field(..., description="The human used by the agent.")
|
human: str = Field(..., description="The human used by the agent.")
|
||||||
|
functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
|
||||||
|
|
||||||
|
# llm information
|
||||||
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
|
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
|
||||||
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
|
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
|
||||||
|
|
||||||
|
# agent state
|
||||||
state: Optional[Dict] = Field(None, description="The state of the agent.")
|
state: Optional[Dict] = Field(None, description="The state of the agent.")
|
||||||
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
|
|
||||||
|
|
||||||
|
class HumanModel(SQLModel, table=True):
|
||||||
|
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
|
||||||
|
name: str = Field(..., description="The name of the human.")
|
||||||
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
|
||||||
|
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.")
|
||||||
|
|
||||||
|
|
||||||
|
class PersonaModel(SQLModel, table=True):
|
||||||
|
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
|
||||||
|
name: str = Field(..., description="The name of the persona.")
|
||||||
|
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
|
||||||
|
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.")
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
import os
|
||||||
from memgpt.data_types import AgentState, Preset
|
from memgpt.data_types import AgentState, Preset
|
||||||
from memgpt.interface import AgentInterface
|
from memgpt.interface import AgentInterface
|
||||||
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
from memgpt.presets.utils import load_all_presets, is_valid_yaml_format
|
||||||
from memgpt.utils import get_human_text, get_persona_text, printd
|
from memgpt.utils import get_human_text, get_persona_text, printd, list_human_files, list_persona_files
|
||||||
from memgpt.prompts import gpt_system
|
from memgpt.prompts import gpt_system
|
||||||
from memgpt.functions.functions import load_all_function_sets
|
from memgpt.functions.functions import load_all_function_sets
|
||||||
from memgpt.metadata import MetadataStore
|
from memgpt.metadata import MetadataStore
|
||||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||||
|
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
@ -15,15 +17,37 @@ available_presets = load_all_presets()
|
|||||||
preset_options = list(available_presets.keys())
|
preset_options = list(available_presets.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
|
||||||
|
for persona_file in list_persona_files():
|
||||||
|
text = open(persona_file, "r").read()
|
||||||
|
name = os.path.basename(persona_file).replace(".txt", "")
|
||||||
|
if ms.get_persona(user_id=user_id, name=name) is not None:
|
||||||
|
printd(f"Persona '{name}' already exists for user '{user_id}'")
|
||||||
|
continue
|
||||||
|
persona = PersonaModel(name=name, text=text, user_id=user_id)
|
||||||
|
ms.add_persona(persona)
|
||||||
|
for human_file in list_human_files():
|
||||||
|
text = open(human_file, "r").read()
|
||||||
|
name = os.path.basename(human_file).replace(".txt", "")
|
||||||
|
if ms.get_human(user_id=user_id, name=name) is not None:
|
||||||
|
printd(f"Human '{name}' already exists for user '{user_id}'")
|
||||||
|
continue
|
||||||
|
human = HumanModel(name=name, text=text, user_id=user_id)
|
||||||
|
ms.add_human(human)
|
||||||
|
|
||||||
|
|
||||||
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||||
"""Add the default presets to the metadata store"""
|
"""Add the default presets to the metadata store"""
|
||||||
|
# make sure humans/personas added
|
||||||
|
add_default_humans_and_personas(user_id=user_id, ms=ms)
|
||||||
|
|
||||||
|
# add default presets
|
||||||
for preset_name in preset_options:
|
for preset_name in preset_options:
|
||||||
preset_config = available_presets[preset_name]
|
preset_config = available_presets[preset_name]
|
||||||
preset_system_prompt = preset_config["system_prompt"]
|
preset_system_prompt = preset_config["system_prompt"]
|
||||||
preset_function_set_names = preset_config["functions"]
|
preset_function_set_names = preset_config["functions"]
|
||||||
functions_schema = generate_functions_json(preset_function_set_names)
|
functions_schema = generate_functions_json(preset_function_set_names)
|
||||||
|
|
||||||
print("PRESET", preset_name, user_id)
|
|
||||||
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
|
||||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||||
continue
|
continue
|
||||||
|
@ -5,8 +5,9 @@ from functools import partial
|
|||||||
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
|
from fastapi import APIRouter, Body, Depends, Query, HTTPException, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
from memgpt.models.pydantic_models import AgentStateModel
|
from memgpt.models.pydantic_models import AgentStateModel, LLMConfigModel, EmbeddingConfigModel
|
||||||
from memgpt.server.rest_api.auth_token import get_current_user
|
from memgpt.server.rest_api.auth_token import get_current_user
|
||||||
from memgpt.server.rest_api.interface import QueuingInterface
|
from memgpt.server.rest_api.interface import QueuingInterface
|
||||||
from memgpt.server.server import SyncServer
|
from memgpt.server.server import SyncServer
|
||||||
@ -14,7 +15,7 @@ from memgpt.server.server import SyncServer
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigRequest(BaseModel):
|
class GetAgentRequest(BaseModel):
|
||||||
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
agent_id: str = Field(..., description="Unique identifier of the agent whose config is requested.")
|
||||||
|
|
||||||
|
|
||||||
@ -23,9 +24,11 @@ class AgentRenameRequest(BaseModel):
|
|||||||
agent_name: str = Field(..., description="New name for the agent.")
|
agent_name: str = Field(..., description="New name for the agent.")
|
||||||
|
|
||||||
|
|
||||||
class AgentConfigResponse(BaseModel):
|
class GetAgentResponse(BaseModel):
|
||||||
# config: dict = Field(..., description="The agent configuration object.")
|
# config: dict = Field(..., description="The agent configuration object.")
|
||||||
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
|
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
|
||||||
|
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
|
||||||
|
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")
|
||||||
|
|
||||||
|
|
||||||
def validate_agent_name(name: str) -> str:
|
def validate_agent_name(name: str) -> str:
|
||||||
@ -48,7 +51,7 @@ def validate_agent_name(name: str) -> str:
|
|||||||
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
|
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||||
|
|
||||||
@router.get("/agents/config", tags=["agents"], response_model=AgentConfigResponse)
|
@router.get("/agents", tags=["agents"], response_model=GetAgentResponse)
|
||||||
def get_agent_config(
|
def get_agent_config(
|
||||||
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
|
agent_id: str = Query(..., description="Unique identifier of the agent whose config is requested."),
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
@ -58,15 +61,36 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
|
|
||||||
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||||
"""
|
"""
|
||||||
request = AgentConfigRequest(agent_id=agent_id)
|
request = GetAgentRequest(agent_id=agent_id)
|
||||||
|
|
||||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||||
|
attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||||
|
|
||||||
interface.clear()
|
interface.clear()
|
||||||
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
|
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
|
||||||
return AgentConfigResponse(agent_state=agent_state)
|
# return GetAgentResponse(agent_state=agent_state)
|
||||||
|
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||||
|
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||||
|
|
||||||
@router.patch("/agents/rename", tags=["agents"], response_model=AgentConfigResponse)
|
return GetAgentResponse(
|
||||||
|
agent_state=AgentStateModel(
|
||||||
|
id=agent_state.id,
|
||||||
|
name=agent_state.name,
|
||||||
|
user_id=agent_state.user_id,
|
||||||
|
preset=agent_state.preset,
|
||||||
|
persona=agent_state.persona,
|
||||||
|
human=agent_state.human,
|
||||||
|
llm_config=agent_state.llm_config,
|
||||||
|
embedding_config=agent_state.embedding_config,
|
||||||
|
state=agent_state.state,
|
||||||
|
created_at=int(agent_state.created_at.timestamp()),
|
||||||
|
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||||
|
),
|
||||||
|
last_run_at=None, # TODO
|
||||||
|
sources=attached_sources,
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.patch("/agents/rename", tags=["agents"], response_model=GetAgentResponse)
|
||||||
def update_agent_name(
|
def update_agent_name(
|
||||||
request: AgentRenameRequest = Body(...),
|
request: AgentRenameRequest = Body(...),
|
||||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||||
@ -87,7 +111,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
raise
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"{e}")
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
return AgentConfigResponse(agent_state=agent_state)
|
return GetAgentResponse(agent_state=agent_state)
|
||||||
|
|
||||||
@router.delete("/agents", tags=["agents"])
|
@router.delete("/agents", tags=["agents"])
|
||||||
def delete_agent(
|
def delete_agent(
|
||||||
@ -97,7 +121,7 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
"""
|
"""
|
||||||
Delete an agent.
|
Delete an agent.
|
||||||
"""
|
"""
|
||||||
request = AgentConfigRequest(agent_id=agent_id)
|
request = GetAgentRequest(agent_id=agent_id)
|
||||||
|
|
||||||
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
|
||||||
|
|
||||||
|
@ -69,6 +69,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
|||||||
embedding_config=embedding_config,
|
embedding_config=embedding_config,
|
||||||
state=agent_state.state,
|
state=agent_state.state,
|
||||||
created_at=int(agent_state.created_at.timestamp()),
|
created_at=int(agent_state.created_at.timestamp()),
|
||||||
|
functions_schema=agent_state.state["functions"], # TODO: this is very error prone, jsut lookup the preset instead
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# return CreateAgentResponse(
|
# return CreateAgentResponse(
|
||||||
|
@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
|
|||||||
from memgpt.server.rest_api.auth_token import get_current_user
|
from memgpt.server.rest_api.auth_token import get_current_user
|
||||||
from memgpt.server.rest_api.interface import QueuingInterface
|
from memgpt.server.rest_api.interface import QueuingInterface
|
||||||
from memgpt.server.server import SyncServer
|
from memgpt.server.server import SyncServer
|
||||||
|
from memgpt.models.pydantic_models import HumanModel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class ListHumansResponse(BaseModel):
|
class ListHumansResponse(BaseModel):
|
||||||
humans: List[dict] = Field(..., description="List of human configurations.")
|
humans: List[HumanModel] = Field(..., description="List of human configurations.")
|
||||||
|
|
||||||
|
|
||||||
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||||
@ -25,14 +26,7 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
|
|||||||
):
|
):
|
||||||
# Clear the interface
|
# Clear the interface
|
||||||
interface.clear()
|
interface.clear()
|
||||||
|
humans = server.ms.list_humans(user_id=user_id)
|
||||||
# TODO: Replace with actual data fetching logic once available
|
return ListHumansResponse(humans=humans)
|
||||||
humans_data = [
|
|
||||||
{"name": "Marco", "text": "About Me"},
|
|
||||||
{"name": "Sam", "text": "About Me 2"},
|
|
||||||
{"name": "Bruce", "text": "About Me 3"},
|
|
||||||
]
|
|
||||||
|
|
||||||
return ListHumansResponse(humans=humans_data)
|
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
@ -8,12 +8,13 @@ from pydantic import BaseModel, Field
|
|||||||
from memgpt.server.rest_api.auth_token import get_current_user
|
from memgpt.server.rest_api.auth_token import get_current_user
|
||||||
from memgpt.server.rest_api.interface import QueuingInterface
|
from memgpt.server.rest_api.interface import QueuingInterface
|
||||||
from memgpt.server.server import SyncServer
|
from memgpt.server.server import SyncServer
|
||||||
|
from memgpt.models.pydantic_models import PersonaModel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
class ListPersonasResponse(BaseModel):
|
class ListPersonasResponse(BaseModel):
|
||||||
personas: List[dict] = Field(..., description="List of persona configurations.")
|
personas: List[PersonaModel] = Field(..., description="List of persona configurations.")
|
||||||
|
|
||||||
|
|
||||||
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
def setup_personas_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||||
@ -26,13 +27,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
# Clear the interface
|
# Clear the interface
|
||||||
interface.clear()
|
interface.clear()
|
||||||
|
|
||||||
# TODO: Replace with actual data fetching logic once available
|
personas = server.ms.list_personas(user_id=user_id)
|
||||||
personas_data = [
|
return ListPersonasResponse(personas=personas)
|
||||||
{"name": "Persona 1", "text": "Details about Persona 1"},
|
|
||||||
{"name": "Persona 2", "text": "Details about Persona 2"},
|
|
||||||
{"name": "Persona 3", "text": "Details about Persona 3"},
|
|
||||||
]
|
|
||||||
|
|
||||||
return ListPersonasResponse(personas=personas_data)
|
|
||||||
|
|
||||||
return router
|
return router
|
||||||
|
@ -1071,3 +1071,7 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
# attach source to agent
|
# attach source to agent
|
||||||
agent.attach_source(data_source.name, source_connector, self.ms)
|
agent.attach_source(data_source.name, source_connector, self.ms)
|
||||||
|
|
||||||
|
def list_attached_sources(self, agent_id: uuid.UUID):
|
||||||
|
# list all attached sources to an agent
|
||||||
|
return self.ms.list_attached_sources(agent_id)
|
||||||
|
20
poetry.lock
generated
20
poetry.lock
generated
@ -1,4 +1,4 @@
|
|||||||
# This file is automatically @generated by Poetry 1.8.1 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "aiohttp"
|
name = "aiohttp"
|
||||||
@ -3822,7 +3822,6 @@ files = [
|
|||||||
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
|
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
|
||||||
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
|
||||||
@ -4394,6 +4393,21 @@ sqlalchemy = ">=0.7"
|
|||||||
[package.extras]
|
[package.extras]
|
||||||
dev = ["pytest"]
|
dev = ["pytest"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlmodel"
|
||||||
|
version = "0.0.16"
|
||||||
|
description = "SQLModel, SQL databases in Python, designed for simplicity, compatibility, and robustness."
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7,<4.0"
|
||||||
|
files = [
|
||||||
|
{file = "sqlmodel-0.0.16-py3-none-any.whl", hash = "sha256:b972f5d319580d6c37ecc417881f6ec4d1ad3ed3583d0ac0ed43234a28bf605a"},
|
||||||
|
{file = "sqlmodel-0.0.16.tar.gz", hash = "sha256:966656f18a8e9a2d159eb215b07fb0cf5222acfae3362707ca611848a8a06bd1"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pydantic = ">=1.10.13,<3.0.0"
|
||||||
|
SQLAlchemy = ">=2.0.0,<2.1.0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "stack-data"
|
name = "stack-data"
|
||||||
version = "0.6.3"
|
version = "0.6.3"
|
||||||
@ -5533,4 +5547,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "<3.12,>=3.10"
|
python-versions = "<3.12,>=3.10"
|
||||||
content-hash = "a420d69943e7e1c2b6900ac9ac9d9fd51ff9e16ff245bcea8e590a276b736251"
|
content-hash = "22ed7617b3b586152c81e359b366192a9f1708265933b8ed2dc3ad041b798264"
|
||||||
|
@ -56,6 +56,7 @@ llama-index = "^0.10.6"
|
|||||||
llama-index-embeddings-openai = "^0.1.1"
|
llama-index-embeddings-openai = "^0.1.1"
|
||||||
python-box = "^7.1.1"
|
python-box = "^7.1.1"
|
||||||
pytest-order = {version = "^1.2.0", optional = true}
|
pytest-order = {version = "^1.2.0", optional = true}
|
||||||
|
sqlmodel = "^0.0.16"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
local = ["torch", "huggingface-hub", "transformers"]
|
local = ["torch", "huggingface-hub", "transformers"]
|
||||||
|
@ -7,6 +7,9 @@ from memgpt.metadata import MetadataStore
|
|||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
|
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
|
||||||
from memgpt.utils import get_human_text, get_persona_text
|
from memgpt.utils import get_human_text, get_persona_text
|
||||||
|
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas
|
||||||
|
|
||||||
|
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
|
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
|
||||||
@ -26,9 +29,22 @@ def test_storage(storage_connector):
|
|||||||
|
|
||||||
ms = MetadataStore(config)
|
ms = MetadataStore(config)
|
||||||
|
|
||||||
# generate data
|
# users
|
||||||
user_1 = User()
|
user_1 = User()
|
||||||
user_2 = User()
|
user_2 = User()
|
||||||
|
ms.create_user(user_1)
|
||||||
|
ms.create_user(user_2)
|
||||||
|
|
||||||
|
# test adding defaults
|
||||||
|
# TODO: move below
|
||||||
|
add_default_humans_and_personas(user_id=user_1.id, ms=ms)
|
||||||
|
add_default_humans_and_personas(user_id=user_2.id, ms=ms)
|
||||||
|
ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
|
||||||
|
ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
|
||||||
|
add_default_presets(user_id=user_1.id, ms=ms)
|
||||||
|
add_default_presets(user_id=user_2.id, ms=ms)
|
||||||
|
|
||||||
|
# generate data
|
||||||
agent_1 = AgentState(
|
agent_1 = AgentState(
|
||||||
user_id=user_1.id,
|
user_id=user_1.id,
|
||||||
name="agent_1",
|
name="agent_1",
|
||||||
@ -41,8 +57,6 @@ def test_storage(storage_connector):
|
|||||||
source_1 = Source(user_id=user_1.id, name="source_1")
|
source_1 = Source(user_id=user_1.id, name="source_1")
|
||||||
|
|
||||||
# test creation
|
# test creation
|
||||||
ms.create_user(user_1)
|
|
||||||
ms.create_user(user_2)
|
|
||||||
ms.create_agent(agent_1)
|
ms.create_agent(agent_1)
|
||||||
ms.create_source(source_1)
|
ms.create_source(source_1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user