mirror of https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
parent b39ad16a9d
commit 55a36a6e3d
@@ -58,6 +58,7 @@ jobs:
           pipx install poetry==1.8.2
           poetry install -E dev -E postgres
           poetry run pytest -s tests/test_client.py
+          poetry run pytest -s tests/test_concurrent_connections.py

       - name: Print docker logs if tests fail
         if: failure()
.github/workflows/tests.yml (vendored, 16 changed lines)

@@ -2,7 +2,6 @@ name: Run All pytest Tests

 env:
   MEMGPT_PGURI: ${{ secrets.MEMGPT_PGURI }}
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

 on:
   push:

@@ -33,10 +32,10 @@ jobs:
       with:
         python-version: "3.12"
         poetry-version: "1.8.2"
-        install-args: "-E dev -E postgres -E milvus -E crewai-tools"
+        install-args: "-E dev -E postgres -E milvus"

     - name: Initialize credentials
-      run: poetry run memgpt quickstart --backend openai
+      run: poetry run memgpt quickstart --backend memgpt

     #- name: Run docker compose server
     #  env:

@@ -70,3 +69,14 @@ jobs:
         PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
       run: |
         poetry run pytest -s -vv -k "not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
+
+    - name: Run storage tests
+      env:
+        MEMGPT_PG_PORT: 8888
+        MEMGPT_PG_USER: memgpt
+        MEMGPT_PG_PASSWORD: memgpt
+        MEMGPT_PG_HOST: localhost
+        MEMGPT_PG_DB: memgpt
+        MEMGPT_SERVER_PASS: test_server_token
+      run: |
+        poetry run pytest -s -vv tests/test_storage.py
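The storage-test job above configures Postgres entirely through MEMGPT_PG_* variables. A minimal sketch of the libpq-style connection URI those variables imply (a hypothetical helper; the real URI assembly happens inside MemGPT's config/settings code, and the exact driver prefix may differ):

import os

def pg_uri_from_env() -> str:
    # Mirrors the env block in the workflow: memgpt:memgpt@localhost:8888/memgpt
    user = os.environ["MEMGPT_PG_USER"]
    password = os.environ["MEMGPT_PG_PASSWORD"]
    host = os.environ["MEMGPT_PG_HOST"]
    port = os.environ["MEMGPT_PG_PORT"]
    db = os.environ["MEMGPT_PG_DB"]
    return f"postgresql://{user}:{password}@{host}:{port}/{db}"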
.gitignore (vendored, 1 changed line)

@@ -1012,7 +1012,6 @@ FodyWeavers.xsd

 ## cached db data
 pgdata/
 !pgdata/.gitkeep
 .persist/

 ## pytest mirrors
 memgpt/.pytest_cache/
Dockerfile

@@ -14,7 +14,7 @@ WORKDIR /app

 COPY pyproject.toml poetry.lock ./
 RUN poetry lock --no-update
 RUN if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then \
-        poetry install --no-root -E "postgres server dev" ; \
+        poetry install --no-root -E "postgres server dev autogen" ; \
     else \
         poetry install --no-root -E "postgres server" && \
         rm -rf $POETRY_CACHE_DIR ; \
memgpt/agent.py (287 changed lines)
@@ -2,7 +2,8 @@ import datetime
 import inspect
 import json
 import traceback
-from typing import List, Literal, Optional, Tuple, Union
+import uuid
+from typing import List, Literal, Optional, Tuple, Union, cast

 from tqdm import tqdm

@@ -18,20 +19,14 @@ from memgpt.constants import (
     MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
     MESSAGE_SUMMARY_WARNING_FRAC,
 )
+from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
 from memgpt.interface import AgentInterface
 from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
-from memgpt.memory import ArchivalMemory, RecallMemory, summarize_messages
+from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
 from memgpt.metadata import MetadataStore
+from memgpt.models import chat_completion_response
+from memgpt.models.pydantic_models import OptionState, ToolModel
 from memgpt.persistence_manager import LocalStateManager
-from memgpt.schemas.agent import AgentState
-from memgpt.schemas.block import Block
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.enums import OptionState
-from memgpt.schemas.memory import Memory
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
-from memgpt.schemas.passage import Passage
-from memgpt.schemas.tool import Tool
 from memgpt.system import (
     get_initial_boot_messages,
     get_login_event,

@@ -40,6 +35,7 @@ from memgpt.system import (
 )
 from memgpt.utils import (
     count_tokens,
+    create_uuid_from_string,
     get_local_time,
     get_tool_call_id,
     get_utc_time,
@@ -76,7 +72,7 @@ def compile_memory_metadata_block(

 def compile_system_message(
     system_prompt: str,
-    in_context_memory: Memory,
+    in_context_memory: BaseMemory,
     in_context_memory_last_edit: datetime.datetime,  # TODO move this inside of BaseMemory?
     archival_memory: Optional[ArchivalMemory] = None,
     recall_memory: Optional[RecallMemory] = None,

@@ -139,7 +135,7 @@ def compile_system_message(
 def initialize_message_sequence(
     model: str,
     system: str,
-    memory: Memory,
+    memory: BaseMemory,
     archival_memory: Optional[ArchivalMemory] = None,
     recall_memory: Optional[RecallMemory] = None,
     memory_edit_timestamp: Optional[datetime.datetime] = None,

@@ -192,21 +188,35 @@ class Agent(object):
         interface: AgentInterface,
         # agents can be created from providing agent_state
         agent_state: AgentState,
-        tools: List[Tool],
-        # memory: Memory,
+        tools: List[ToolModel],
+        # memory: BaseMemory,
         # extras
         messages_total: Optional[int] = None,  # TODO remove?
         first_message_verify_mono: bool = True,  # TODO move to config?
     ):
-        assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
+        # tools
+        for tool in tools:
+            assert tool, f"Tool is None - must be error in querying tool from DB"
+            assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
+        for tool_name in agent_state.tools:
+            assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
+        # Store the functions schemas (this is passed as an argument to ChatCompletion)
+        self.functions = []
+        self.functions_python = {}
+        env = {}
+        env.update(globals())
+        for tool in tools:
+            # WARNING: name may not be consistent?
+            if tool.module:  # execute the whole module
+                exec(tool.module, env)
+            else:
+                exec(tool.source_code, env)
+            self.functions_python[tool.name] = env[tool.name]
+            self.functions.append(tool.json_schema)
+        assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python

         # Hold a copy of the state that was used to init the agent
         self.agent_state = agent_state
-        assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
-
-        try:
-            self.link_tools(tools)
-        except Exception as e:
-            raise ValueError(f"Encountered an error while trying to link agent tools during initialization:\n{str(e)}")

         # gpt-4, gpt-3.5-turbo, ...
         self.model = self.agent_state.llm_config.model
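The inlined tool binding above works by exec-ing each tool's source (or whole module) into a shared namespace, then keeping a name-to-callable map plus the JSON schemas handed to ChatCompletion. A minimal standalone sketch of that mechanism, using a hypothetical stand-in for ToolModel:

from dataclasses import dataclass

@dataclass
class SimpleTool:  # hypothetical stand-in; the real ToolModel comes from the DB
    name: str
    source_code: str
    json_schema: dict

tool = SimpleTool(
    name="add_numbers",
    source_code="def add_numbers(a: int, b: int) -> int:\n    return a + b",
    json_schema={"name": "add_numbers", "parameters": {"a": "int", "b": "int"}},
)

env = {}
env.update(globals())        # tool source may reference module-level names
exec(tool.source_code, env)  # defines add_numbers inside env
functions_python = {tool.name: env[tool.name]}
functions = [tool.json_schema]

assert callable(functions_python["add_numbers"])
print(functions_python["add_numbers"](2, 3))  # 5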
@@ -215,8 +225,7 @@ class Agent(object):
         self.system = self.agent_state.system

         # Initialize the memory object
-        self.memory = self.agent_state.memory
-        assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}"
+        self.memory = BaseMemory.load(self.agent_state.state["memory"])
         printd("Initialized memory object", self.memory)

         # Interface must implement:

@@ -245,13 +254,28 @@ class Agent(object):
         self._messages: List[Message] = []

         # Once the memory object is initialized, use it to "bake" the system message
-        if self.agent_state.message_ids is not None:
-            self.set_message_buffer(message_ids=self.agent_state.message_ids)
+        if "messages" in self.agent_state.state and self.agent_state.state["messages"] is not None:
+            # print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}")
+            if not isinstance(self.agent_state.state["messages"], list):
+                raise ValueError(f"'messages' in AgentState was bad type: {type(self.agent_state.state['messages'])}")
+            assert all([isinstance(msg, str) for msg in self.agent_state.state["messages"]])

+            # Convert to IDs, and pull from the database
+            raw_messages = [
+                self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in self.agent_state.state["messages"]
+            ]
+            assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
+            self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])

+            for m in self._messages:
+                # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
+                # TODO eventually do casting via an edit_message function
+                if not is_utc_datetime(m.created_at):
+                    printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
+                    m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)

         else:
-            printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")

             # Generate a sequence of initial messages to put in the buffer
+            printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
             init_messages = initialize_message_sequence(
                 model=self.model,
                 system=self.system,

@@ -261,8 +285,6 @@ class Agent(object):
                 memory_edit_timestamp=get_utc_time(),
                 include_initial_boot_message=True,
             )

-            # Cast the messages to actual Message objects to be synced to the DB
             init_messages_objs = []
             for msg in init_messages:
                 init_messages_objs.append(

@@ -271,12 +293,15 @@ class Agent(object):
                     )
                 )
             assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)

             # Put the messages inside the message buffer
             self.messages_total = 0
             # self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
-            self._append_to_messages(added_messages=init_messages_objs)
-            self._validate_message_buffer_is_utc()
+            self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])

         for m in self._messages:
             assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
             # TODO eventually do casting via an edit_message function
             if not is_utc_datetime(m.created_at):
                 printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
                 m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)

         # Keep track of the total number of messages throughout all time
         self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1)  # (-system)
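The UTC back-fill above guards against naive timestamps written by old agents. A minimal standalone sketch of the same normalization (is_utc_datetime here is a stand-in for MemGPT's helper of the same name):

import datetime

def is_utc_datetime(dt: datetime.datetime) -> bool:
    # Stand-in for memgpt.utils.is_utc_datetime: true only for aware UTC datetimes
    return dt.tzinfo is not None and dt.utcoffset() == datetime.timedelta(0)

created_at = datetime.datetime(2024, 7, 1, 12, 30)  # naive, as stored by old agents
if not is_utc_datetime(created_at):
    # Assume the naive timestamp was already UTC and just attach the tzinfo
    created_at = created_at.replace(tzinfo=datetime.timezone.utc)

assert is_utc_datetime(created_at)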
@@ -295,65 +320,6 @@ class Agent(object):
     def messages(self, value):
         raise Exception("Modifying message list directly not allowed")

-    def link_tools(self, tools: List[Tool]):
-        """Bind a tool object (schema + python function) to the agent object"""
-
-        # tools
-        for tool in tools:
-            assert tool, f"Tool is None - must be error in querying tool from DB"
-            assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools"
-        for tool_name in self.agent_state.tools:
-            assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
-
-        # Store the functions schemas (this is passed as an argument to ChatCompletion)
-        self.functions = []
-        self.functions_python = {}
-        env = {}
-        env.update(globals())
-        for tool in tools:
-            # WARNING: name may not be consistent?
-            if tool.module:  # execute the whole module
-                exec(tool.module, env)
-            else:
-                exec(tool.source_code, env)
-            self.functions_python[tool.name] = env[tool.name]
-            self.functions.append(tool.json_schema)
-        assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
-
-    def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
-        """Load a list of messages from recall storage"""
-
-        # Pull the message objects from the database
-        message_objs = [self.persistence_manager.recall_memory.storage.get(msg_id) for msg_id in message_ids]
-        assert all([isinstance(msg, Message) for msg in message_objs])
-
-        return message_objs
-
-    def _validate_message_buffer_is_utc(self):
-        """Iterate over the message buffer and force all messages to be UTC stamped"""
-
-        for m in self._messages:
-            # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
-            # TODO eventually do casting via an edit_message function
-            if not is_utc_datetime(m.created_at):
-                printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
-                m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
-
-    def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
-        """Set the messages in the buffer to the message IDs list"""
-
-        message_objs = self._load_messages_from_recall(message_ids=message_ids)
-
-        # set the objects in the buffer
-        self._messages = message_objs
-
-        # bugfix for old agents that may not have had UTC specified in their timestamps
-        if force_utc:
-            self._validate_message_buffer_is_utc()
-
-        # also sync the message IDs attribute
-        self.agent_state.message_ids = message_ids
-
     def _trim_messages(self, num):
         """Trim messages from the front, not including the system message"""
         self.persistence_manager.trim_messages(num)
@@ -406,7 +372,7 @@ class Agent(object):
         first_message: bool = False,  # hint
         stream: bool = False,  # TODO move to config?
         inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
-    ) -> ChatCompletionResponse:
+    ) -> chat_completion_response.ChatCompletionResponse:
         """Get response from LLM API"""
         try:
             response = create(

@@ -442,7 +408,9 @@ class Agent(object):
         except Exception as e:
             raise e

-    def _handle_ai_response(self, response_message: Message, override_tool_call_id: bool = True) -> Tuple[List[Message], bool, bool]:
+    def _handle_ai_response(
+        self, response_message: chat_completion_response.Message, override_tool_call_id: bool = True
+    ) -> Tuple[List[Message], bool, bool]:
         """Handles parsing and function execution"""

         messages = []  # append these to the history when done

@@ -647,7 +615,6 @@ class Agent(object):
         stream: bool = False,  # TODO move to config?
         timestamp: Optional[datetime.datetime] = None,
         inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
-        ms: Optional[MetadataStore] = None,
     ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
         """Top-level event message handler for the MemGPT agent"""

@@ -677,20 +644,7 @@ class Agent(object):
             raise e

         try:
-            # Step 0: update core memory
-            # only pulling latest block data if shared memory is being used
-            # TODO: ensure we're passing in metadata store from all surfaces
-            if ms is not None:
-                should_update = False
-                for block in self.agent_state.memory.to_dict().values():
-                    if not block.get("template", False):
-                        should_update = True
-                if should_update:
-                    # TODO: the force=True can be optimized away
-                    # once we ensure we're correctly comparing whether in-memory core
-                    # data is different than persisted core data.
-                    self.rebuild_memory(force=True, ms=ms)
-            # Step 1: add user message
+            # Step 0: add user message
             if user_message is not None:
                 if isinstance(user_message, Message):
                     # Validate JSON via save/load

@@ -736,7 +690,7 @@ class Agent(object):
             if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
                 printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")

-            # Step 2: send the conversation and available functions to GPT
+            # Step 1: send the conversation and available functions to GPT
             if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
                 printd(f"This is the first message. Running extra verifier on AI response.")
                 counter = 0

@@ -761,9 +715,9 @@ class Agent(object):
                     inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                 )

-            # Step 3: check if LLM wanted to call a function
-            # (if yes) Step 4: call the function
-            # (if yes) Step 5: send the info on the function call and function response to LLM
+            # Step 2: check if LLM wanted to call a function
+            # (if yes) Step 3: call the function
+            # (if yes) Step 4: send the info on the function call and function response to LLM
             response_message = response.choices[0].message
             response_message.model_copy()  # TODO why are we copying here?
             all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)

@@ -779,7 +733,7 @@ class Agent(object):
             # "functions": self.functions,
             # }

-            # Step 6: extend the message history
+            # Step 4: extend the message history
             if user_message is not None:
                 if isinstance(user_message, Message):
                     all_new_messages = [user_message] + all_response_messages

@@ -839,7 +793,6 @@ class Agent(object):
                 stream=stream,
                 timestamp=timestamp,
                 inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
-                ms=ms,
             )

         else:

@@ -988,8 +941,8 @@ class Agent(object):
         new_messages = [new_system_message_obj] + self._messages[1:]  # swap index 0 (system)
         self._messages = new_messages

-    def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
-        """Rebuilds the system message with the latest memory object and any shared memory block updates"""
+    def rebuild_memory(self, force=False, update_timestamp=True):
+        """Rebuilds the system message with the latest memory object"""
         curr_system_message = self.messages[0]  # this is the system + memory bank, not just the system prompt

         # NOTE: This is a hacky way to check if the memory has changed
@@ -998,28 +951,6 @@ class Agent(object):
             printd(f"Memory has not changed, not rebuilding system")
             return

-        if ms:
-            for block in self.memory.to_dict().values():
-                if block.get("templates", False):
-                    # we don't expect to update shared memory blocks that
-                    # are templates. this is something we could update in the
-                    # future if we expect templates to change often.
-                    continue
-                block_id = block.get("id")
-                db_block = ms.get_block(block_id=block_id)
-                if db_block is None:
-                    # this case covers if someone has deleted a shared block by interacting
-                    # with some other agent.
-                    # in that case we should remove this shared block from the agent currently being
-                    # evaluated.
-                    printd(f"removing block: {block_id=}")
-                    continue
-                if not isinstance(db_block.value, str):
-                    printd(f"skipping block update, unexpected value: {block_id=}")
-                    continue
-                # TODO: we may want to update which columns we're updating from shared memory e.g. the limit
-                self.memory.update_block_value(name=block.get("name", ""), value=db_block.value)
-
         # If the memory didn't update, we probably don't want to update the timestamp inside
         # For example, if we're doing a system prompt swap, this should probably be False
         if update_timestamp:
@@ -1117,14 +1048,25 @@ class Agent(object):
         # return msg

     def update_state(self) -> AgentState:
-        message_ids = [msg.id for msg in self._messages]
-        assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}"
-
-        # override any fields that may have been updated
-        self.agent_state.message_ids = message_ids
-        self.agent_state.memory = self.memory
-        self.agent_state.system = self.system
-
+        memory = {
+            "system": self.system,
+            "memory": self.memory.to_dict(),
+            "messages": [str(msg.id) for msg in self._messages],  # TODO: move out into AgentState.message_ids
+        }
+        self.agent_state = AgentState(
+            name=self.agent_state.name,
+            user_id=self.agent_state.user_id,
+            tools=self.agent_state.tools,
+            system=self.system,
+            ## "model_state"
+            llm_config=self.agent_state.llm_config,
+            embedding_config=self.agent_state.embedding_config,
+            id=self.agent_state.id,
+            created_at=self.agent_state.created_at,
+            ## "agent_state"
+            state=memory,
+            _metadata=self.agent_state._metadata,
+        )
         return self.agent_state

     def migrate_embedding(self, embedding_config: EmbeddingConfig):
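update_state above packs the system prompt, memory blocks, and message IDs into a plain state dict; on the next Agent.__init__ that dict is unpacked again via BaseMemory.load and recall storage. A sketch of that round trip with a hypothetical two-block memory stand-in (the real BaseMemory also tracks per-block limits):

class TinyMemory:  # hypothetical stand-in for BaseMemory's to_dict/load contract
    def __init__(self, persona: str, human: str):
        self.persona, self.human = persona, human

    def to_dict(self) -> dict:
        return {"persona": self.persona, "human": self.human}

    @classmethod
    def load(cls, state: dict) -> "TinyMemory":
        return cls(persona=state["persona"], human=state["human"])

mem = TinyMemory(persona="You are a helpful agent.", human="Name: Chad")
state = {"system": "base prompt", "memory": mem.to_dict(), "messages": []}
restored = TinyMemory.load(state["memory"])  # what Agent.__init__ does on load
assert restored.to_dict() == mem.to_dict()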
@@ -1134,12 +1076,13 @@ class Agent(object):
         # TODO: recall memory
         raise NotImplementedError()

-    def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
+    def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore):
         """Attach data with name `source_name` to the agent from source_connector."""
         # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory

-        filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
+        filters = {"user_id": self.agent_state.user_id, "data_source": source_name}
         size = source_connector.size(filters)
         # typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
         page_size = 100
         generator = source_connector.get_all_paginated(filters=filters, page_size=page_size)  # yields List[Passage]
         all_passages = []

@@ -1152,8 +1095,7 @@ class Agent(object):
             passage.agent_id = self.agent_state.id

             # regenerate passage ID (avoid duplicates)
-            # TODO: need to find another solution to the text duplication issue
-            # passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
+            passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}")

         # insert into agent archival memory
         self.persistence_manager.archival_memory.storage.insert_many(passages)

@@ -1165,14 +1107,15 @@ class Agent(object):
         self.persistence_manager.archival_memory.storage.save()

         # attach to agent
-        source = ms.get_source(source_id=source_id)
-        assert source is not None, f"Source {source_id} not found in metadata store"
+        source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id)
+        assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}"
+        source_id = source.id
         ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)

         total_agent_passages = self.persistence_manager.archival_memory.storage.size()

         printd(
-            f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
+            f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
         )
@@ -1181,36 +1124,8 @@ def save_agent(agent: Agent, ms: MetadataStore):

     agent.update_state()
     agent_state = agent.agent_state
-    agent_id = agent_state.id
-    assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
-
-    # NOTE: we're saving agent memory before persisting the agent to ensure
-    # that allocated block_ids for each memory block are present in the agent model
-    save_agent_memory(agent=agent, ms=ms)
-
-    if ms.get_agent(agent_id=agent.agent_state.id):
+    if ms.get_agent(agent_name=agent_state.name, user_id=agent_state.user_id):
         ms.update_agent(agent_state)
     else:
         ms.create_agent(agent_state)
-
-    agent.agent_state = ms.get_agent(agent_id=agent_id)
-    assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
-
-
-def save_agent_memory(agent: Agent, ms: MetadataStore):
-    """
-    Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table.
-
-    NOTE: we are assuming agent.update_state has already been called.
-    """
-
-    for block_dict in agent.memory.to_dict().values():
-        # TODO: block creation should happen in one place to enforce these sort of constraints consistently.
-        if block_dict.get("user_id", None) is None:
-            block_dict["user_id"] = agent.agent_state.user_id
-        block = Block(**block_dict)
-        # FIXME: should we expect for block values to be None? If not, we need to figure out why that is
-        # the case in some tests, if so we should relax the DB constraint.
-        if block.value is None:
-            block.value = ""
-        ms.update_or_create_block(block)
memgpt/agent_store/chroma.py

@@ -1,12 +1,12 @@
-from typing import Dict, List, Optional, Tuple, cast
+import uuid
+from typing import Dict, Iterator, List, Optional, Tuple, cast

 import chromadb
 from chromadb.api.types import Include

 from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.config import MemGPTConfig
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.passage import Passage
+from memgpt.data_types import Passage, Record, RecordType
 from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime

@@ -34,6 +34,9 @@ class ChromaStorageConnector(StorageConnector):
         self.collection = self.client.get_or_create_collection(self.table_name)
         self.include: Include = ["documents", "embeddings", "metadatas"]

+        # need to be converted to strings
+        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
+
     def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
         # get all filters for query
         if filters is not None:

@@ -51,7 +54,10 @@ class ChromaStorageConnector(StorageConnector):
                 continue

             # filter by other keys
-            chroma_filters.append({key: {"$eq": value}})
+            if key in self.uuid_fields:
+                chroma_filters.append({key: {"$eq": str(value)}})
+            else:
+                chroma_filters.append({key: {"$eq": value}})

         if len(chroma_filters) > 1:
             chroma_filters = {"$and": chroma_filters}

@@ -61,7 +67,7 @@ class ChromaStorageConnector(StorageConnector):
             chroma_filters = chroma_filters[0]
         return ids, chroma_filters

-    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0):
+    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0) -> Iterator[List[RecordType]]:
         ids, filters = self.get_filters(filters)
         while True:
             # Retrieve a chunk of records with the given page_size
@@ -78,50 +84,29 @@ class ChromaStorageConnector(StorageConnector):
             # Increment the offset to get the next chunk in the next iteration
             offset += page_size

-    def results_to_records(self, results):
+    def results_to_records(self, results) -> List[RecordType]:
         # convert timestamps to datetime
         for metadata in results["metadatas"]:
             if "created_at" in metadata:
                 metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
+            for key, value in metadata.items():
+                if key in self.uuid_fields:
+                    metadata[key] = uuid.UUID(value)
         if results["embeddings"]:  # may not be returned, depending on table type
-            passages = []
-            for text, record_id, embedding, metadata in zip(
-                results["documents"], results["ids"], results["embeddings"], results["metadatas"]
-            ):
-                args = {}
-                for field in EmbeddingConfig.__fields__.keys():
-                    if field in metadata:
-                        args[field] = metadata[field]
-                        del metadata[field]
-                embedding_config = EmbeddingConfig(**args)
-                passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata))
-            # return [
-            #     Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas)
-            #     for (text, record_id, embedding, metadatas) in zip(
-            #         results["documents"], results["ids"], results["embeddings"], results["metadatas"]
-            #     )
-            # ]
-            return passages
+            return [
+                cast(RecordType, self.type(text=text, embedding=embedding, id=uuid.UUID(record_id), **metadatas))  # type: ignore
+                for (text, record_id, embedding, metadatas) in zip(
+                    results["documents"], results["ids"], results["embeddings"], results["metadatas"]
+                )
+            ]
         else:
             # no embeddings
-            passages = []
-            for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]):
-                args = {}
-                for field in EmbeddingConfig.__fields__.keys():
-                    if field in metadata:
-                        args[field] = metadata[field]
-                        del metadata[field]
-                embedding_config = EmbeddingConfig(**args)
-                passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata))
-            return passages
+            return [
+                cast(RecordType, self.type(text=text, id=uuid.UUID(id), **metadatas))  # type: ignore
+                for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
+            ]

+        # return [
+        #     #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
+        #     Passage(text=text, embedding=None, id=id, **metadatas)
+        #     for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
+        # ]

-    def get_all(self, filters: Optional[Dict] = {}, limit=None):
+    def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
         ids, filters = self.get_filters(filters)
         if self.collection.count() == 0:
             return []
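The removed branch above flattened EmbeddingConfig fields into each Chroma metadata dict on write and popped them back out on read. A sketch of that flatten/unflatten trick in isolation (field names are illustrative, not the full EmbeddingConfig schema):

from pydantic import BaseModel

class EmbeddingConfig(BaseModel):  # reduced stand-in for the real model
    embedding_model: str
    embedding_dim: int

metadata = {"embedding_model": "text-embedding-ada-002", "embedding_dim": 1536, "user_id": "u1"}

args = {}
for field in EmbeddingConfig.__fields__.keys():
    if field in metadata:
        args[field] = metadata.pop(field)  # pop = extract and remove in one step

config = EmbeddingConfig(**args)
print(config.embedding_model, metadata)  # remaining keys are plain passage metadata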
@@ -131,13 +116,13 @@ class ChromaStorageConnector(StorageConnector):
         results = self.collection.get(ids=ids, include=self.include, where=filters)
         return self.results_to_records(results)

-    def get(self, id):
+    def get(self, id: uuid.UUID) -> Optional[RecordType]:
         results = self.collection.get(ids=[str(id)])
         if len(results["ids"]) == 0:
             return None
         return self.results_to_records(results)[0]

-    def format_records(self, records):
+    def format_records(self, records: List[RecordType]):
         assert all([isinstance(r, Passage) for r in records])

         recs = []

@@ -160,13 +145,10 @@ class ChromaStorageConnector(StorageConnector):
         # collect/format record metadata
         metadatas = []
         for record in recs:
-            embedding_config = vars(record.embedding_config)
             metadata = vars(record)
             metadata.pop("id")
             metadata.pop("text")
             metadata.pop("embedding")
-            metadata.pop("embedding_config")
-            metadata.pop("metadata_")
             if "created_at" in metadata:
                 metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
             if "metadata_" in metadata and metadata["metadata_"] is not None:

@@ -174,22 +156,23 @@ class ChromaStorageConnector(StorageConnector):
                 metadata.pop("metadata_")
             else:
                 record_metadata = {}

-            metadata = {**metadata, **record_metadata}  # merge with metadata
-            metadata = {**metadata, **embedding_config}  # merge with embedding config
             metadata = {key: value for key, value in metadata.items() if value is not None}  # null values not allowed
+            metadata = {**metadata, **record_metadata}  # merge with metadata

+            # convert uuids to strings
+            for key, value in metadata.items():
+                if key in self.uuid_fields:
+                    metadata[key] = str(value)
             metadatas.append(metadata)
         return ids, documents, embeddings, metadatas

-    def insert(self, record):
+    def insert(self, record: Record):
         ids, documents, embeddings, metadatas = self.format_records([record])
         if any([e is None for e in embeddings]):
             raise ValueError("Embeddings must be provided to chroma")
         self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)

-    def insert_many(self, records, show_progress=False):
+    def insert_many(self, records: List[RecordType], show_progress=False):
         ids, documents, embeddings, metadatas = self.format_records(records)
         if any([e is None for e in embeddings]):
             raise ValueError("Embeddings must be provided to chroma")

@@ -215,7 +198,7 @@ class ChromaStorageConnector(StorageConnector):
     def list_data_sources(self):
         raise NotImplementedError

-    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
         ids, filters = self.get_filters(filters)
         results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)
@@ -256,40 +239,10 @@ class ChromaStorageConnector(StorageConnector):
     def get_all_cursor(
         self,
         filters: Optional[Dict] = {},
-        after: str = None,
-        before: str = None,
+        after: uuid.UUID = None,
+        before: uuid.UUID = None,
         limit: Optional[int] = 1000,
         order_by: str = "created_at",
         reverse: bool = False,
     ):
-        records = self.get_all(filters=filters)
-
-        # WARNING: very hacky and slow implementation
-        def get_index(id, record_list):
-            for i in range(len(record_list)):
-                if record_list[i].id == id:
-                    return i
-            assert False, f"Could not find id {id} in record list"
-
-        # sort by custom field
-        records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse)
-        if after:
-            index = get_index(after, records)
-            if index + 1 >= len(records):
-                return None, []
-            records = records[index + 1 :]
-        if before:
-            index = get_index(before, records)
-            if index == 0:
-                return None, []
-
-            # TODO: not sure if this is correct
-            records = records[:index]
-
-        if len(records) == 0:
-            return None, []
-
-        # enforce limit
-        if limit:
-            records = records[:limit]
-        return records[-1].id, records
+        raise ValueError("Cannot run get_all_cursor with chroma")
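The removed implementation paged by sorting everything in memory and slicing just past the cursor id, which is why it was flagged as hacky and slow. A minimal self-contained sketch of that cursor scheme, with hypothetical records:

import uuid
from dataclasses import dataclass, field

@dataclass
class Rec:  # hypothetical stand-in for a stored record
    created_at: int
    id: uuid.UUID = field(default_factory=uuid.uuid4)

def get_all_cursor(records, after=None, limit=2, order_by="created_at", reverse=False):
    # Sort in memory, then slice just past the cursor id - simple but O(n) per page
    records = sorted(records, key=lambda r: getattr(r, order_by), reverse=reverse)
    if after is not None:
        index = next(i for i, r in enumerate(records) if r.id == after)
        records = records[index + 1:]
    records = records[:limit]
    return (records[-1].id, records) if records else (None, [])

data = [Rec(created_at=t) for t in (3, 1, 2)]
cursor, page = get_all_cursor(data)                 # first page: created_at 1, 2
cursor, page = get_all_cursor(data, after=cursor)   # next page resumes after cursor
print([r.created_at for r in page])                 # [3]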
memgpt/agent_store/db.py

@@ -1,11 +1,14 @@
 import base64
 import os
 import uuid
 from datetime import datetime
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Optional

 import numpy as np
 from sqlalchemy import (
     BIGINT,
     BINARY,
     CHAR,
     JSON,
     Column,
     DateTime,

@@ -20,6 +23,7 @@ from sqlalchemy import (
     select,
     text,
 )
+from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import declarative_base, mapped_column, sessionmaker
 from sqlalchemy.orm.session import close_all_sessions
 from sqlalchemy.sql import func

@@ -29,15 +33,34 @@ from tqdm import tqdm
 from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.config import MemGPTConfig
 from memgpt.constants import MAX_EMBEDDING_DIM
-from memgpt.metadata import EmbeddingConfigColumn
-
-# from memgpt.schemas.message import Message, Passage, Record, RecordType, ToolCall
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_request import ToolCall, ToolCallFunction
-from memgpt.schemas.passage import Passage
+from memgpt.data_types import Message, Passage, Record, RecordType, ToolCall
 from memgpt.settings import settings


+# Custom UUID type
+class CommonUUID(TypeDecorator):
+    impl = CHAR
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect):
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(UUID(as_uuid=True))
+        else:
+            return dialect.type_descriptor(CHAR())
+
+    def process_bind_param(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return str(value)  # Convert UUID to string for SQLite
+
+    def process_result_value(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return uuid.UUID(value)
 class CommonVector(TypeDecorator):
     """Common type for representing vectors in SQLite"""

@@ -70,6 +93,26 @@ class CommonVector(TypeDecorator):
 # Custom serialization / de-serialization for JSON columns


+class ToolCallColumn(TypeDecorator):
+    """Custom type for storing List[ToolCall] as JSON"""
+
+    impl = JSON
+    cache_ok = True
+
+    def load_dialect_impl(self, dialect):
+        return dialect.type_descriptor(JSON())
+
+    def process_bind_param(self, value, dialect):
+        if value:
+            return [vars(v) for v in value]
+        return value
+
+    def process_result_value(self, value, dialect):
+        if value:
+            return [ToolCall(**v) for v in value]
+        return value
+
+
 Base = declarative_base()
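ToolCallColumn serializes tool calls to plain dicts on the way into the JSON column and rehydrates them on the way out. A sketch of that round trip outside SQLAlchemy, using a hypothetical minimal ToolCall:

from dataclasses import dataclass

@dataclass
class ToolCall:  # hypothetical minimal stand-in for memgpt's ToolCall
    id: str
    function: dict

calls = [ToolCall(id="call_1", function={"name": "send_message", "arguments": "{}"})]

# What process_bind_param does on write:
stored = [vars(c) for c in calls]           # [{'id': 'call_1', 'function': {...}}]
# What process_result_value does on read:
restored = [ToolCall(**v) for v in stored]
assert restored == calls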
@@ -77,8 +120,8 @@ def get_db_model(
     config: MemGPTConfig,
     table_name: str,
     table_type: TableType,
-    user_id: str,
-    agent_id: Optional[str] = None,
+    user_id: uuid.UUID,
+    agent_id: Optional[uuid.UUID] = None,
     dialect="postgresql",
 ):
     # Define a helper function to create or get the model class

@@ -97,12 +140,14 @@ def get_db_model(
         __abstract__ = True  # this line is necessary

         # Assuming passage_id is the primary key
-        id = Column(String, primary_key=True)
-        user_id = Column(String, nullable=False)
+        # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+        id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+        # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+        user_id = Column(CommonUUID, nullable=False)
         text = Column(String)
-        doc_id = Column(String)
-        agent_id = Column(String)
-        source_id = Column(String)
+        doc_id = Column(CommonUUID)
+        agent_id = Column(CommonUUID)
+        data_source = Column(String)  # agent_name if agent, data_source name if from data source

         # vector storage
         if dialect == "sqlite":

@@ -111,8 +156,9 @@ def get_db_model(
             from pgvector.sqlalchemy import Vector

             embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
+        embedding_dim = Column(BIGINT)
+        embedding_model = Column(String)

-        embedding_config = Column(EmbeddingConfigColumn)
         metadata_ = Column(MutableJson)

         # Add a datetime column, with default value as the current time

@@ -127,11 +173,12 @@ def get_db_model(
             return Passage(
                 text=self.text,
                 embedding=self.embedding,
-                embedding_config=self.embedding_config,
+                embedding_dim=self.embedding_dim,
+                embedding_model=self.embedding_model,
                 doc_id=self.doc_id,
                 user_id=self.user_id,
                 id=self.id,
-                source_id=self.source_id,
+                data_source=self.data_source,
                 agent_id=self.agent_id,
                 metadata_=self.metadata_,
                 created_at=self.created_at,

@@ -149,9 +196,11 @@ def get_db_model(
         __abstract__ = True  # this line is necessary

         # Assuming message_id is the primary key
-        id = Column(String, primary_key=True)
-        user_id = Column(String, nullable=False)
-        agent_id = Column(String, nullable=False)
+        # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+        id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+        # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+        user_id = Column(CommonUUID, nullable=False)
+        agent_id = Column(CommonUUID, nullable=False)

         # openai info
         role = Column(String, nullable=False)
@@ -163,29 +212,31 @@ def get_db_model(
         # if role == "assistant", this MAY be specified
         # if role != "assistant", this must be null
         # TODO align with OpenAI spec of multiple tool calls
-        # tool_calls = Column(ToolCallColumn)
-        tool_calls = Column(JSON)
+        tool_calls = Column(ToolCallColumn)

         # tool call response info
         # if role == "tool", then this must be specified
         # if role != "tool", this must be null
         tool_call_id = Column(String)

         # vector storage
         if dialect == "sqlite":
             embedding = Column(CommonVector)
         else:
             from pgvector.sqlalchemy import Vector

             embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
+        embedding_dim = Column(BIGINT)
+        embedding_model = Column(String)

         # Add a datetime column, with default value as the current time
         created_at = Column(DateTime(timezone=True))
         Index("message_idx_user", user_id, agent_id),

         def __repr__(self):
-            return f"<Message(message_id='{self.id}', text='{self.text}')>"
+            return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"

         def to_record(self):
-            calls = (
-                [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
-                if self.tool_calls
-                else None
-            )
-            if calls:
-                assert isinstance(calls[0], ToolCall)
             return Message(
                 user_id=self.user_id,
                 agent_id=self.agent_id,

@@ -193,9 +244,11 @@ def get_db_model(
                 name=self.name,
                 text=self.text,
                 model=self.model,
-                # tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
+                tool_calls=self.tool_calls,
                 tool_call_id=self.tool_call_id,
                 embedding=self.embedding,
+                embedding_dim=self.embedding_dim,
+                embedding_model=self.embedding_model,
                 created_at=self.created_at,
                 id=self.id,
             )
@@ -221,7 +274,7 @@ class SQLStorageConnector(StorageConnector):
         all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
         return all_filters

-    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
+    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[RecordType]]:
         filters = self.get_filters(filters)
         while True:
             # Retrieve a chunk of records with the given page_size

@@ -241,8 +294,8 @@ class SQLStorageConnector(StorageConnector):
     def get_all_cursor(
         self,
         filters: Optional[Dict] = {},
-        after: str = None,
-        before: str = None,
+        after: uuid.UUID = None,
+        before: uuid.UUID = None,
         limit: Optional[int] = 1000,
         order_by: str = "created_at",
         reverse: bool = False,

@@ -279,12 +332,12 @@ class SQLStorageConnector(StorageConnector):
                 return (None, [])
             records = [record.to_record() for record in db_record_chunk]
             next_cursor = db_record_chunk[-1].id
-            assert isinstance(next_cursor, str)
+            assert isinstance(next_cursor, uuid.UUID)

             # return (cursor, list[records])
             return (next_cursor, records)

-    def get_all(self, filters: Optional[Dict] = {}, limit=None):
+    def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
         filters = self.get_filters(filters)
         with self.session_maker() as session:
             if limit:

@@ -293,7 +346,7 @@ class SQLStorageConnector(StorageConnector):
             db_records = session.query(self.db_model).filter(*filters).all()
             return [record.to_record() for record in db_records]

-    def get(self, id: str):
+    def get(self, id: uuid.UUID) -> Optional[Record]:
         with self.session_maker() as session:
             db_record = session.get(self.db_model, id)
             if db_record is None:

@@ -306,13 +359,13 @@ class SQLStorageConnector(StorageConnector):
         with self.session_maker() as session:
             return session.query(self.db_model).filter(*filters).count()

-    def insert(self, record):
+    def insert(self, record: Record):
         raise NotImplementedError

-    def insert_many(self, records, show_progress=False):
+    def insert_many(self, records: List[RecordType], show_progress=False):
         raise NotImplementedError

-    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
         raise NotImplementedError("Vector query not implemented for SQLStorageConnector")

     def save(self):
@@ -417,7 +470,7 @@ class PostgresStorageConnector(SQLStorageConnector):
         # create table
         Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist

-    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
         filters = self.get_filters(filters)
         with self.session_maker() as session:
             results = session.scalars(

@@ -428,7 +481,7 @@ class PostgresStorageConnector(SQLStorageConnector):
         records = [result.to_record() for result in results]
         return records

-    def insert_many(self, records, exists_ok=True, show_progress=False):
+    def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
         from sqlalchemy.dialects.postgresql import insert

         # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)

@@ -450,15 +503,14 @@ class PostgresStorageConnector(SQLStorageConnector):
         with self.session_maker() as session:
             iterable = tqdm(records) if show_progress else records
             for record in iterable:
-                # db_record = self.db_model(**vars(record))
-                db_record = self.db_model(**record.dict())
+                db_record = self.db_model(**vars(record))
                 session.add(db_record)
             session.commit()

-    def insert(self, record, exists_ok=True):
+    def insert(self, record: Record, exists_ok=True):
         self.insert_many([record], exists_ok=exists_ok)

-    def update(self, record):
+    def update(self, record: RecordType):
         """
         Updates a record in the database based on the provided Record object.
         """
@@ -523,12 +575,12 @@ class SQLLiteStorageConnector(SQLStorageConnector):
         Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist
         self.session_maker = sessionmaker(bind=self.engine)

-        # import sqlite3
+        import sqlite3

-        # sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
-        # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
+        sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
+        sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

-    def insert_many(self, records, exists_ok=True, show_progress=False):
+    def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
         from sqlalchemy.dialects.sqlite import insert

         # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
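The re-enabled register_adapter/register_converter pair teaches the sqlite3 driver to store UUIDs as little-endian bytes and read them back as uuid.UUID. A quick standalone check of that round trip (the converter only fires when detect_types is enabled and the column is declared as UUID):

import sqlite3
import uuid

sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

conn = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
conn.execute("CREATE TABLE items (id UUID PRIMARY KEY)")

original = uuid.uuid4()
conn.execute("INSERT INTO items VALUES (?)", (original,))
(restored,) = conn.execute("SELECT id FROM items").fetchone()
assert restored == original  # round trip preserves the UUID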
@@ -550,15 +602,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
         with self.session_maker() as session:
             iterable = tqdm(records) if show_progress else records
             for record in iterable:
-                # db_record = self.db_model(**vars(record))
-                db_record = self.db_model(**record.dict())
+                db_record = self.db_model(**vars(record))
                 session.add(db_record)
             session.commit()

-    def insert(self, record, exists_ok=True):
+    def insert(self, record: Record, exists_ok=True):
         self.insert_many([record], exists_ok=exists_ok)

-    def update(self, record):
+    def update(self, record: Record):
         """
         Updates an existing record in the database with values from the provided record object.
         """
memgpt/agent_store/lancedb.py

@@ -8,7 +8,7 @@ from lancedb.pydantic import LanceModel, Vector

 from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.config import AgentConfig, MemGPTConfig
-from memgpt.schemas.message import Message, Passage, Record
+from memgpt.data_types import Message, Passage, Record

 """ Initial implementation - not complete """
memgpt/agent_store/storage.py

@@ -5,14 +5,10 @@ We originally tried to use Llama Index VectorIndex, but their limited API was ex

 import uuid
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple, Type, Union
-
-from pydantic import BaseModel
+from typing import Dict, Iterator, List, Optional, Tuple, Type, Union

 from memgpt.config import MemGPTConfig
-from memgpt.schemas.document import Document
-from memgpt.schemas.message import Message
-from memgpt.schemas.passage import Passage
+from memgpt.data_types import Document, Message, Passage, Record, RecordType
 from memgpt.utils import printd

@@ -39,7 +35,7 @@ DOCUMENT_TABLE_NAME = "memgpt_documents"  # original documents (from source)
 class StorageConnector:
     """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""

-    type: Type[BaseModel]
+    type: Type[Record]

     def __init__(
         self,

@@ -140,15 +136,15 @@ class StorageConnector:
         pass

     @abstractmethod
-    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000):
+    def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
         pass

     @abstractmethod
-    def get_all(self, filters: Optional[Dict] = {}, limit=10):
+    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
         pass

     @abstractmethod
-    def get(self, id: uuid.UUID):
+    def get(self, id: uuid.UUID) -> Optional[RecordType]:
         pass

     @abstractmethod

@@ -156,15 +152,15 @@ class StorageConnector:
         pass

     @abstractmethod
-    def insert(self, record):
+    def insert(self, record: RecordType):
         pass

     @abstractmethod
-    def insert_many(self, records, show_progress=False):
+    def insert_many(self, records: List[RecordType], show_progress=False):
         pass

     @abstractmethod
-    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
         pass

     @abstractmethod
memgpt/interface.py

@@ -5,7 +5,7 @@ from typing import Optional

 from colorama import Fore, Style, init

 from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
-from memgpt.schemas.message import Message
+from memgpt.data_types import Message

 init(autoreset=True)
memgpt/cli/cli.py

@@ -3,6 +3,7 @@ import logging
 import os
 import subprocess
 import sys
+import uuid
 from enum import Enum
 from pathlib import Path
 from typing import Annotated, Optional

@@ -18,12 +19,12 @@ from memgpt.cli.cli_config import configure
 from memgpt.config import MemGPTConfig
 from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
 from memgpt.credentials import MemGPTCredentials
+from memgpt.data_types import EmbeddingConfig, LLMConfig, User
 from memgpt.log import get_logger
+from memgpt.memory import ChatMemory
 from memgpt.metadata import MetadataStore
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.enums import OptionState
-from memgpt.schemas.llm_config import LLMConfig
-from memgpt.schemas.memory import ChatMemory, Memory
+from memgpt.migrate import migrate_all_agents, migrate_all_sources
+from memgpt.models.pydantic_models import OptionState
 from memgpt.server.constants import WS_DEFAULT_PORT
 from memgpt.server.server import logger as server_logger

@@ -36,6 +37,14 @@ from memgpt.utils import open_folder_in_explorer, printd
 logger = get_logger(__name__)


+def migrate(
+    debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False,
+):
+    """Migrate old agents (pre 0.2.12) to the new database system"""
+    migrate_all_agents(debug=debug)
+    migrate_all_sources(debug=debug)
+
+
 class QuickstartChoice(Enum):
     openai = "openai"
     # azure = "azure"
@@ -171,10 +180,13 @@ def quickstart(
     else:
         # Load the file from the relative path
         script_dir = os.path.dirname(__file__)  # Get the directory where the script is located
+        # print("SCRIPT", script_dir)
         backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
+        # print("FILE PATH", backup_config_path)
         try:
             with open(backup_config_path, "r", encoding="utf-8") as file:
                 backup_config = json.load(file)
+            # print(backup_config)
             printd("Loaded config file successfully.")
             new_config, config_was_modified = set_config_with_dict(backup_config)
         except FileNotFoundError:

@@ -201,6 +213,7 @@ def quickstart(
             # Parse the response content as JSON
             config = response.json()
             # Output a success message and the first few items in the dictionary as a sample
+            print("JSON config file downloaded successfully.")
             new_config, config_was_modified = set_config_with_dict(config)
         else:
             typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
@@ -279,6 +292,21 @@ class ServerChoice(Enum):
     ws_api = "websocket"


+def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore):
+    user_id = uuid.UUID(config.anon_clientid)
+    user = ms.get_user(user_id=user_id)
+    if user is None:
+        ms.create_user(User(id=user_id))
+        user = ms.get_user(user_id=user_id)
+        if user is None:
+            typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
+            sys.exit(1)
+        else:
+            return user
+    else:
+        return user
+
+
 def server(
     type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
     port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
@ -295,8 +323,8 @@ def server(
|
||||
|
||||
if MemGPTConfig.exists():
|
||||
config = MemGPTConfig.load()
|
||||
MetadataStore(config)
|
||||
client = create_client() # triggers user creation
|
||||
ms = MetadataStore(config)
|
||||
create_default_user_or_exit(config, ms)
|
||||
else:
|
||||
typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
@ -416,42 +444,42 @@ def run(
    logger.setLevel(logging.CRITICAL)
    server_logger.setLevel(logging.CRITICAL)

    # from memgpt.migrate import (
    #     VERSION_CUTOFF,
    #     config_is_compatible,
    #     wipe_config_and_reconfigure,
    # )
    from memgpt.migrate import (
        VERSION_CUTOFF,
        config_is_compatible,
        wipe_config_and_reconfigure,
    )

    # if not config_is_compatible(allow_empty=True):
    #     typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
    #     choices = [
    #         "Run the full config setup (recommended)",
    #         "Create a new config using defaults",
    #         "Cancel",
    #     ]
    #     selection = questionary.select(
    #         f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
    #         choices=choices,
    #         default=choices[0],
    #     ).ask()
    #     if selection == choices[0]:
    #         try:
    #             wipe_config_and_reconfigure()
    #         except Exception as e:
    #             typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
    #             raise
    #     elif selection == choices[1]:
    #         try:
    #             # Don't create a config, so that the next block of code asking about quickstart is run
    #             wipe_config_and_reconfigure(run_configure=False, create_config=False)
    #         except Exception as e:
    #             typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
    #             raise
    #     else:
    #         typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
    #         raise KeyboardInterrupt()
    if not config_is_compatible(allow_empty=True):
        typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
        choices = [
            "Run the full config setup (recommended)",
            "Create a new config using defaults",
            "Cancel",
        ]
        selection = questionary.select(
            f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
            choices=choices,
            default=choices[0],
        ).ask()
        if selection == choices[0]:
            try:
                wipe_config_and_reconfigure()
            except Exception as e:
                typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
                raise
        elif selection == choices[1]:
            try:
                # Don't create a config, so that the next block of code asking about quickstart is run
                wipe_config_and_reconfigure(run_configure=False, create_config=False)
            except Exception as e:
                typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
                raise
        else:
            typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
            raise KeyboardInterrupt()

    # typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
    typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)

    if not MemGPTConfig.exists():
        # if no config, ask about quickstart
|
||||
|
||||
# read user id from config
|
||||
ms = MetadataStore(config)
|
||||
client = create_client()
|
||||
client.user_id
|
||||
user = create_default_user_or_exit(config, ms)
|
||||
|
||||
# determine agent to use, if not provided
|
||||
if not yes and not agent:
|
||||
agents = client.list_agents()
|
||||
agents = ms.list_agents(user_id=user.id)
|
||||
agents = [a.name for a in agents]
|
||||
|
||||
if len(agents) > 0:
|
||||
@ -513,11 +540,7 @@ def run(
|
||||
agent = questionary.select("Select agent:", choices=agents).ask()
|
||||
|
||||
# create agent config
|
||||
if agent:
|
||||
agent_id = client.get_agent_id(agent)
|
||||
agent_state = client.get_agent(agent_id)
|
||||
else:
|
||||
agent_state = None
|
||||
agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None
|
||||
human = human if human else config.human
|
||||
persona = persona if persona else config.persona
|
||||
if agent and agent_state: # use existing agent
|
||||
@ -574,12 +597,13 @@ def run(
|
||||
# agent_state.state["system"] = system
|
||||
|
||||
# Update the agent with any overrides
|
||||
agent_state = client.update_agent(
|
||||
agent_id=agent_state.id,
|
||||
name=agent_state.name,
|
||||
llm_config=agent_state.llm_config,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
)
|
||||
ms.update_agent(agent_state)
|
||||
tools = []
|
||||
for tool_name in agent_state.tools:
|
||||
tool = ms.get_tool(tool_name, agent_state.user_id)
|
||||
if tool is None:
|
||||
typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED)
|
||||
tools.append(tool)
|
||||
|
||||
# create agent
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
|
||||
@ -622,52 +646,55 @@ def run(
|
||||
llm_config.model_endpoint_type = model_endpoint_type
|
||||
|
||||
# create agent
|
||||
client = create_client()
|
||||
human_obj = client.get_human(client.get_human_id(name=human))
|
||||
persona_obj = client.get_persona(client.get_persona_id(name=persona))
|
||||
if human_obj is None:
|
||||
typer.secho(f"Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
|
||||
try:
|
||||
client = create_client()
|
||||
human_obj = ms.get_human(human, user.id)
|
||||
persona_obj = ms.get_persona(persona, user.id)
|
||||
# TODO pull system prompts from the metadata store
|
||||
# NOTE: will be overriden later to a default
|
||||
if system_file:
|
||||
try:
|
||||
with open(system_file, "r", encoding="utf-8") as file:
|
||||
system = file.read().strip()
|
||||
printd("Loaded system file successfully.")
|
||||
except FileNotFoundError:
|
||||
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
|
||||
system_prompt = system if system else None
|
||||
if human_obj is None:
|
||||
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
|
||||
if persona_obj is None:
|
||||
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
|
||||
|
||||
memory = ChatMemory(human=human_obj.text, persona=persona_obj.text, limit=core_memory_limit)
|
||||
metadata = {"human": human_obj.name, "persona": persona_obj.name}
|
||||
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
|
||||
|
||||
# add tools
|
||||
agent_state = client.create_agent(
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
embedding_config=embedding_config,
|
||||
llm_config=llm_config,
|
||||
memory=memory,
|
||||
metadata=metadata,
|
||||
)
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
|
||||
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
|
||||
memgpt_agent = Agent(
|
||||
interface=interface(),
|
||||
agent_state=agent_state,
|
||||
tools=tools,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
)
|
||||
save_agent(agent=memgpt_agent, ms=ms)
|
||||
|
||||
except ValueError as e:
|
||||
typer.secho(f"Failed to create agent from provided information:\n{e}", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
if persona_obj is None:
|
||||
typer.secho(f"Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
|
||||
if system_file:
|
||||
try:
|
||||
with open(system_file, "r", encoding="utf-8") as file:
|
||||
system = file.read().strip()
|
||||
printd("Loaded system file successfully.")
|
||||
except FileNotFoundError:
|
||||
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
|
||||
system_prompt = system if system else None
|
||||
|
||||
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
|
||||
metadata = {"human": human_obj.name, "persona": persona_obj.name}
|
||||
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
|
||||
|
||||
# add tools
|
||||
agent_state = client.create_agent(
|
||||
name=agent_name,
|
||||
system=system_prompt,
|
||||
embedding_config=embedding_config,
|
||||
llm_config=llm_config,
|
||||
memory=memory,
|
||||
metadata=metadata,
|
||||
)
|
||||
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
|
||||
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
|
||||
memgpt_agent = Agent(
|
||||
interface=interface(),
|
||||
agent_state=agent_state,
|
||||
tools=tools,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
)
|
||||
save_agent(agent=memgpt_agent, ms=ms)
|
||||
typer.secho(f"🎉 Created new agent '{memgpt_agent.agent_state.name}' (id={memgpt_agent.agent_state.id})", fg=typer.colors.GREEN)
|
||||
|
||||
# start event loop
|
||||
@ -692,10 +719,19 @@ def delete_agent(
    """Delete an agent from the database"""
    # use client ID if no user_id provided
    config = MemGPTConfig.load()
    MetadataStore(config)
    client = create_client(user_id=user_id)
    agent = client.get_agent_by_name(agent_name)
    if not agent:
    ms = MetadataStore(config)
    if user_id is None:
        user = create_default_user_or_exit(config, ms)
    else:
        user = ms.get_user(user_id=uuid.UUID(user_id))

    try:
        agent = ms.get_agent(agent_name=agent_name, user_id=user.id)
    except Exception as e:
        typer.secho(f"Failed to get agent {agent_name}\n{e}", fg=typer.colors.RED)
        sys.exit(1)

    if agent is None:
        typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED)
        sys.exit(1)

@ -707,8 +743,7 @@ def delete_agent(
        return

    try:
        # delete the agent
        client.delete_agent(agent.id)
        ms.delete_agent(agent_id=agent.id)
        typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN)
    except Exception:
        typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED)
@ -1,8 +1,8 @@
import ast
import builtins
import os
import uuid
from enum import Enum
from typing import Annotated, List, Optional
from typing import Annotated, Optional

import questionary
import typer
@ -13,6 +13,7 @@ from memgpt import utils
from memgpt.config import MemGPTConfig
from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR
from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig, Source, User
from memgpt.llm_api.anthropic import (
    anthropic_get_model_list,
    antropic_get_model_context_window,
@ -35,8 +36,7 @@ from memgpt.local_llm.constants import (
    DEFAULT_WRAPPER_NAME,
)
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.metadata import MetadataStore
from memgpt.server.utils import shorten_key_middle

app = typer.Typer()
@ -1070,10 +1070,17 @@ def configure():
    typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
    config.save()

    from memgpt import create_client

    client = create_client()
    print("User ID:", client.user_id)
    # create user records
    ms = MetadataStore(config)
    user_id = uuid.UUID(config.anon_clientid)
    user = User(
        id=uuid.UUID(config.anon_clientid),
    )
    if ms.get_user(user_id):
        # update user
        ms.update_user(user)
    else:
        ms.create_user(user)


class ListChoice(str, Enum):
@ -1087,14 +1094,17 @@ class ListChoice(str, Enum):
def list(arg: Annotated[ListChoice, typer.Argument]):
    from memgpt.client.client import create_client

    client = create_client()
    client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
    table = ColorTable(theme=Themes.OCEAN)
    if arg == ListChoice.agents:
        """List all agents"""
        table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
        for agent in tqdm(client.list_agents()):
            # TODO: add this function
            sources = client.list_attached_sources(agent_id=agent.id)
            source_ids = client.list_attached_sources(agent_id=agent.id)
            assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids])
            sources = [client.get_source(source_id=source_id) for source_id in source_ids]
            assert all([source is not None and isinstance(source, Source) for source in sources])
            source_names = [source.name for source in sources if source is not None]
            table.add_row(
                [
@ -1102,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
                    agent.llm_config.model,
                    agent.embedding_config.embedding_model,
                    agent.embedding_config.embedding_dim,
                    agent.memory.get_block("persona").value[:100] + "...",
                    agent.memory.get_block("human").value[:100] + "...",
                    agent._metadata.get("persona", ""),
                    agent._metadata.get("human", ""),
                    ",".join(source_names),
                    utils.format_datetime(agent.created_at),
                ]
@ -1113,13 +1123,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
        """List all humans"""
        table.field_names = ["Name", "Text"]
        for human in client.list_humans():
            table.add_row([human.name, human.value.replace("\n", "")[:100]])
            table.add_row([human.name, human.text.replace("\n", "")[:100]])
        print(table)
    elif arg == ListChoice.personas:
        """List all personas"""
        table.field_names = ["Name", "Text"]
        for persona in client.list_personas():
            table.add_row([persona.name, persona.value.replace("\n", "")[:100]])
            table.add_row([persona.name, persona.text.replace("\n", "")[:100]])
        print(table)
    elif arg == ListChoice.sources:
        """List all data sources"""
@ -1149,63 +1159,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
    return table


@app.command()
def add_tool(
    filename: str = typer.Option(..., help="Path to the Python file containing the function"),
    name: Optional[str] = typer.Option(None, help="Name of the tool"),
    update: bool = typer.Option(True, help="Update the tool if it already exists"),
    tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"),
):
    """Add or update a tool from a Python file."""
    from memgpt.client.client import create_client

    client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))

    # 1. Parse the Python file
    with open(filename, "r", encoding="utf-8") as file:
        source_code = file.read()

    # 2. Parse the source code to extract the function
    # Note: here we assume it is one function only in the file.
    module = ast.parse(source_code)
    func_def = None
    for node in module.body:
        if isinstance(node, ast.FunctionDef):
            func_def = node
            break

    if not func_def:
        raise ValueError("No function found in the provided file")

    # 3. Compile the function to make it callable
    # Explanation courtesy of GPT-4:
    # Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object
    # ast.Module creates a module node containing the function definition (func_def)
    # compile converts the AST into a code object that can be executed by the Python interpreter
    # The exec function executes the compiled code object in the current context,
    # effectively defining the function within the current namespace
    exec(compile(ast.Module([func_def], []), filename, "exec"))
    # Retrieve the function object by evaluating its name in the current namespace
    # eval looks up the function name in the current scope and returns the function object
    func = eval(func_def.name)

    # 4. Add or update the tool
    tool = client.create_tool(func=func, name=name, tags=tags, update=update)
    print(f"Tool {tool.name} added successfully")


@app.command()
def list_tools():
    """List all available tools."""
    from memgpt.client.client import create_client

    client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))

    tools = client.list_tools()
    for tool in tools:
        print(f"Tool: {tool.name}")


@app.command()
def add(
    option: str,  # [human, persona]
@ -1221,27 +1174,23 @@ def add(
        assert text is None, "Cannot specify both text and filename"
        with open(filename, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        assert text is not None, "Must specify either text or filename"
    if option == "persona":
        persona_id = client.get_persona_id(name)
        if persona_id:
            client.get_persona(persona_id)
        persona = client.get_persona(name)
        if persona:
            # confirm if user wants to overwrite
            if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
                return
            client.update_persona(persona_id, text=text)
            client.update_persona(name=name, text=text)
        else:
            client.create_persona(name=name, text=text)

    elif option == "human":
        human_id = client.get_human_id(name)
        if human_id:
            human = client.get_human(human_id)
        human = client.get_human(name=name)
        if human:
            # confirm if user wants to overwrite
            if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
                return
            client.update_human(human_id, text=text)
            client.update_human(name=name, text=text)
        else:
            human = client.create_human(name=name, text=text)
    else:
@ -1258,21 +1207,21 @@ def delete(option: str, name: str):
        # delete from metadata
        if option == "source":
            # delete metadata
            source_id = client.get_source_id(name)
            assert source_id is not None, f"Source {name} does not exist"
            client.delete_source(source_id)
            source = client.get_source(name)
            assert source is not None, f"Source {name} does not exist"
            client.delete_source(source_id=source.id)
        elif option == "agent":
            agent_id = client.get_agent_id(name)
            assert agent_id is not None, f"Agent {name} does not exist"
            client.delete_agent(agent_id=agent_id)
            agent = client.get_agent(agent_name=name)
            assert agent is not None, f"Agent {name} does not exist"
            client.delete_agent(agent_id=agent.id)
        elif option == "human":
            human_id = client.get_human_id(name)
            assert human_id is not None, f"Human {name} does not exist"
            client.delete_human(human_id)
            human = client.get_human(name=name)
            assert human is not None, f"Human {name} does not exist"
            client.delete_human(name=name)
        elif option == "persona":
            persona_id = client.get_persona_id(name)
            assert persona_id is not None, f"Persona {name} does not exist"
            client.delete_persona(persona_id)
            persona = client.get_persona(name=name)
            assert persona is not None, f"Persona {name} does not exist"
            client.delete_persona(name=name)
        else:
            raise ValueError(f"Option {option} not implemented")


@ -13,8 +13,15 @@ from typing import Annotated, List, Optional

import typer

from memgpt import create_client
from memgpt.data_sources.connectors import DirectoryConnector
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.data_sources.connectors import (
    DirectoryConnector,
    VectorDBConnector,
    load_data,
)
from memgpt.data_types import Source
from memgpt.metadata import MetadataStore

app = typer.Typer()

@ -82,20 +89,41 @@ def load_directory(
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,  # TODO: remove
    description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
):
    client = create_client()

    # create connector
    connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)

    # create source
    source = client.create_source(name=name)

    # load data
    try:
        client.load_data(connector, source_name=name)
    except Exception as e:
        typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
        client.delete_source(source.id)
    connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
    config = MemGPTConfig.load()
    if not user_id:
        user_id = uuid.UUID(config.anon_clientid)

    ms = MetadataStore(config)
    source = Source(
        name=name,
        user_id=user_id,
        embedding_model=config.default_embedding_config.embedding_model,
        embedding_dim=config.default_embedding_config.embedding_dim,
        description=description,
    )
    ms.create_source(source)
    passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    # TODO: also get document store

    # ingest data into passage/document store
    try:
        num_passages, num_documents = load_data(
            connector=connector,
            source=source,
            embedding_config=config.default_embedding_config,
            document_store=None,
            passage_store=passage_storage,
        )
        print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
    except Exception as e:
        typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
        ms.delete_source(source_id=source.id)

    except ValueError as e:
        typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
        raise


# @app.command("webpage")
|
||||
@ -111,6 +139,56 @@ def load_directory(
|
||||
#
|
||||
# except ValueError as e:
|
||||
# typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
|
||||
#
|
||||
#
|
||||
# @app.command("database")
|
||||
# def load_database(
|
||||
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
|
||||
# query: Annotated[str, typer.Option(help="Database query.")],
|
||||
# dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
|
||||
# scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
|
||||
# host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
|
||||
# port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
|
||||
# user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
|
||||
# password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
|
||||
# dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
|
||||
# ):
|
||||
# try:
|
||||
# from llama_index.readers.database import DatabaseReader
|
||||
#
|
||||
# print(dump_path, scheme)
|
||||
#
|
||||
# if dump_path is not None:
|
||||
# # read from database dump file
|
||||
# from sqlalchemy import create_engine
|
||||
#
|
||||
# engine = create_engine(f"sqlite:///{dump_path}")
|
||||
#
|
||||
# db = DatabaseReader(engine=engine)
|
||||
# else:
|
||||
# assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
|
||||
# assert scheme is not None, "Must provide database scheme."
|
||||
# assert host is not None, "Must provide database host."
|
||||
# assert port is not None, "Must provide database port."
|
||||
# assert user is not None, "Must provide database user."
|
||||
# assert password is not None, "Must provide database password."
|
||||
# assert dbname is not None, "Must provide database name."
|
||||
#
|
||||
# db = DatabaseReader(
|
||||
# scheme=scheme, # Database Scheme
|
||||
# host=host, # Database Host
|
||||
# port=str(port), # Database Port
|
||||
# user=user, # Database User
|
||||
# password=password, # Database Password
|
||||
# dbname=dbname, # Database Name
|
||||
# )
|
||||
#
|
||||
# # load data
|
||||
# docs = db.load_data(query=query)
|
||||
# store_docs(name, docs)
|
||||
# except ValueError as e:
|
||||
# typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
|
||||
#
|
||||
|
||||
|
||||
@app.command("vector-database")
|
||||
@ -123,44 +201,43 @@ def load_vector_database(
|
||||
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
|
||||
):
|
||||
"""Load pre-computed embeddings into MemGPT from a database."""
|
||||
raise NotImplementedError
|
||||
# try:
|
||||
# config = MemGPTConfig.load()
|
||||
# connector = VectorDBConnector(
|
||||
# uri=uri,
|
||||
# table_name=table_name,
|
||||
# text_column=text_column,
|
||||
# embedding_column=embedding_column,
|
||||
# embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
# )
|
||||
# if not user_id:
|
||||
# user_id = uuid.UUID(config.anon_clientid)
|
||||
try:
|
||||
config = MemGPTConfig.load()
|
||||
connector = VectorDBConnector(
|
||||
uri=uri,
|
||||
table_name=table_name,
|
||||
text_column=text_column,
|
||||
embedding_column=embedding_column,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
)
|
||||
if not user_id:
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
|
||||
# ms = MetadataStore(config)
|
||||
# source = Source(
|
||||
# name=name,
|
||||
# user_id=user_id,
|
||||
# embedding_model=config.default_embedding_config.embedding_model,
|
||||
# embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
# )
|
||||
# ms.create_source(source)
|
||||
# passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
# # TODO: also get document store
|
||||
ms = MetadataStore(config)
|
||||
source = Source(
|
||||
name=name,
|
||||
user_id=user_id,
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
)
|
||||
ms.create_source(source)
|
||||
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
# TODO: also get document store
|
||||
|
||||
# # ingest data into passage/document store
|
||||
# try:
|
||||
# num_passages, num_documents = load_data(
|
||||
# connector=connector,
|
||||
# source=source,
|
||||
# embedding_config=config.default_embedding_config,
|
||||
# document_store=None,
|
||||
# passage_store=passage_storage,
|
||||
# )
|
||||
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
# except Exception as e:
|
||||
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
|
||||
# ms.delete_source(source_id=source.id)
|
||||
# ingest data into passage/document store
|
||||
try:
|
||||
num_passages, num_documents = load_data(
|
||||
connector=connector,
|
||||
source=source,
|
||||
embedding_config=config.default_embedding_config,
|
||||
document_store=None,
|
||||
passage_store=passage_storage,
|
||||
)
|
||||
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
|
||||
except Exception as e:
|
||||
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
|
||||
ms.delete_source(source_id=source.id)
|
||||
|
||||
# except ValueError as e:
|
||||
# typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
|
||||
# raise
|
||||
except ValueError as e:
|
||||
typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
|
||||
raise
|
||||
|
@ -1,3 +1,4 @@
import uuid
from typing import List, Optional

import requests
@ -5,8 +6,19 @@ from requests import HTTPError

from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.user import User, UserCreate
from memgpt.server.rest_api.admin.tools import (
    CreateToolRequest,
    ListToolsResponse,
    ToolModel,
)
from memgpt.server.rest_api.admin.users import (
    CreateAPIKeyResponse,
    CreateUserResponse,
    DeleteAPIKeyResponse,
    DeleteUserResponse,
    GetAllUsersResponse,
    GetAPIKeysResponse,
)


class Admin:
@ -21,7 +33,7 @@ class Admin:
        self.token = token
        self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}

    def get_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[User]:
    def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
        params = {}
        if cursor:
            params["cursor"] = str(cursor)
@ -30,54 +42,54 @@ class Admin:
        response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers)
        if response.status_code != 200:
            raise HTTPError(response.json())
        return [User(**user) for user in response.json()]
        return GetAllUsersResponse(**response.json())

    def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey:
        request = APIKeyCreate(user_id=user_id, name=key_name)
        response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump())
    def create_key(self, user_id: uuid.UUID, key_name: str):
        payload = {"user_id": str(user_id), "key_name": key_name}
        response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
        if response.status_code != 200:
            raise HTTPError(response.json())
        return APIKey(**response.json())
        return CreateAPIKeyResponse(**response.json())

    def get_keys(self, user_id: str) -> List[APIKey]:
    def get_keys(self, user_id: uuid.UUID):
        params = {"user_id": str(user_id)}
        response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
        if response.status_code != 200:
            raise HTTPError(response.json())
        return [APIKey(**key) for key in response.json()]
        return GetAPIKeysResponse(**response.json()).api_key_list

    def delete_key(self, api_key: str) -> APIKey:
    def delete_key(self, api_key: str):
        params = {"api_key": api_key}
        response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
        if response.status_code != 200:
            raise HTTPError(response.json())
        return APIKey(**response.json())
        return DeleteAPIKeyResponse(**response.json())

    def create_user(self, name: Optional[str] = None) -> User:
        request = UserCreate(name=name)
        response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump())
    def create_user(self, user_id: Optional[uuid.UUID] = None):
        payload = {"user_id": str(user_id) if user_id else None}
        response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
        if response.status_code != 200:
            raise HTTPError(response.json())
        response_json = response.json()
        return User(**response_json)
        return CreateUserResponse(**response_json)

    def delete_user(self, user_id: str) -> User:
    def delete_user(self, user_id: uuid.UUID):
        params = {"user_id": str(user_id)}
        response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
        if response.status_code != 200:
            raise HTTPError(response.json())
        return User(**response.json())
        return DeleteUserResponse(**response.json())

    def _reset_server(self):
        # DANGER: this will delete all users and keys
        # clear all state associated with users
        # TODO: clear out all agents, presets, etc.
        users = self.get_users()
        users = self.get_users().user_list
        for user in users:
            keys = self.get_keys(user.id)
            keys = self.get_keys(user["user_id"])
            for key in keys:
                self.delete_key(key.key)
            self.delete_user(user.id)
                self.delete_key(key)
            self.delete_user(user["user_id"])

    # tools
    def create_tool(
@ -119,7 +131,7 @@ class Admin:
            raise ValueError(f"Failed to create tool: {response.text}")
        return ToolModel(**response.json())

    def list_tools(self):
    def list_tools(self) -> ListToolsResponse:
        response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
        return ListToolsResponse(**response.json()).tools
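
A hedged usage sketch of the Admin client as updated above; the base URL, token, and the attribute names on the response wrappers (e.g. user_id on CreateUserResponse) are assumptions for illustration, not confirmed by this diff:

    # assumes a running MemGPT server; the URL and token are placeholders
    admin = Admin(base_url="http://localhost:8283", token="test_server_token")
    created = admin.create_user()                        # CreateUserResponse
    key = admin.create_key(created.user_id, "demo-key")  # CreateAPIKeyResponse
    print(admin.get_keys(created.user_id))               # list parsed from GetAPIKeysResponse
    admin.delete_user(created.user_id)                   # DeleteUserResponse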

File diff suppressed because it is too large
@ -4,7 +4,6 @@ import json
import os
import uuid
from dataclasses import dataclass
from typing import Optional

import memgpt
import memgpt.utils as utils
@ -16,10 +15,8 @@ from memgpt.constants import (
    DEFAULT_PRESET,
    MEMGPT_DIR,
)
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig
from memgpt.log import get_logger
from memgpt.schemas.agent import AgentState
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig

logger = get_logger(__name__)

@ -102,19 +99,19 @@ class MemGPTConfig:
        return uuid.UUID(int=uuid.getnode()).hex

    @classmethod
    def load(cls, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None) -> "MemGPTConfig":
    def load(cls) -> "MemGPTConfig":
        # avoid circular import
        from memgpt.migrate import VERSION_CUTOFF, config_is_compatible
        from memgpt.utils import printd

        # from memgpt.migrate import VERSION_CUTOFF, config_is_compatible
        # if not config_is_compatible(allow_empty=True):
        #     error_message = " ".join(
        #         [
        #             f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.",
        #             f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.",
        #         ]
        #     )
        #     raise ValueError(error_message)
        if not config_is_compatible(allow_empty=True):
            error_message = " ".join(
                [
                    f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.",
                    f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.",
                ]
            )
            raise ValueError(error_message)

        config = configparser.ConfigParser()

@ -192,9 +189,6 @@ class MemGPTConfig:

        return cls(**config_dict)

        # assert embedding_config is not None, "Embedding config must be provided if config does not exist"
        # assert llm_config is not None, "LLM config must be provided if config does not exist"

        # create new config
        anon_clientid = MemGPTConfig.generate_uuid()
        config = cls(anon_clientid=anon_clientid, config_path=config_path)

@ -4,10 +4,8 @@ import typer
from llama_index.core import Document as LlamaIndexDocument

from memgpt.agent_store.storage import StorageConnector
from memgpt.data_types import Document, EmbeddingConfig, Passage, Source
from memgpt.embeddings import embedding_model
from memgpt.schemas.document import Document
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source
from memgpt.utils import create_uuid_from_string


@ -22,11 +20,17 @@ class DataConnector:
def load_data(
    connector: DataConnector,
    source: Source,
    embedding_config: EmbeddingConfig,
    passage_store: StorageConnector,
    document_store: Optional[StorageConnector] = None,
):
    """Load data from a connector (generates documents and passages) into a specified source_id, associated with a user_id."""
    embedding_config = source.embedding_config
    assert (
        source.embedding_model == embedding_config.embedding_model
    ), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}"
    assert (
        source.embedding_dim == embedding_config.embedding_dim
    ), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}."

    # embedding model
    embed_model = embedding_model(embedding_config)
@ -39,9 +43,10 @@ def load_data(
    for document_text, document_metadata in connector.generate_documents():
        # insert document into storage
        document = Document(
            id=create_uuid_from_string(f"{str(source.id)}_{document_text}"),
            text=document_text,
            metadata_=document_metadata,
            source_id=source.id,
            metadata=document_metadata,
            data_source=source.name,
            user_id=source.user_id,
        )
        document_count += 1
@ -73,15 +78,16 @@ def load_data(
                id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
                text=passage_text,
                doc_id=document.id,
                source_id=source.id,
                metadata_=passage_metadata,
                user_id=source.user_id,
                embedding_config=source.embedding_config,
                data_source=source.name,
                embedding_dim=source.embedding_dim,
                embedding_model=source.embedding_model,
                embedding=embedding,
            )

            hashable_embedding = tuple(passage.embedding)
            document_name = document.metadata_.get("file_path", document.id)
            document_name = document.metadata.get("file_path", document.id)
            if hashable_embedding in embedding_to_document_name:
                typer.secho(
                    f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
@ -144,7 +150,7 @@ class DirectoryConnector(DataConnector):

        parser = TokenTextSplitter(chunk_size=chunk_size)
        for document in documents:
            llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
            llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata)]
            nodes = parser.get_nodes_from_documents(llama_index_docs)
            for node in nodes:
                # passage = Passage(

@ -1,18 +1,75 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """

import copy
import json
import uuid
import warnings
from datetime import datetime, timezone
from typing import List, Optional, Union
from typing import Dict, List, Optional, TypeVar

from pydantic import Field, field_validator
import numpy as np
from pydantic import BaseModel, Field

from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
from memgpt.constants import (
    DEFAULT_HUMAN,
    DEFAULT_PERSONA,
    DEFAULT_PRESET,
    JSON_ENSURE_ASCII,
    LLM_MAX_TOKENS,
    MAX_EMBEDDING_DIM,
    TOOL_CALL_ID_MAX_LEN,
)
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.utils import get_utc_time, is_utc_datetime
from memgpt.prompts import gpt_system
from memgpt.utils import (
    create_uuid_from_string,
    get_human_text,
    get_persona_text,
    get_utc_time,
    is_utc_datetime,
)


class Record:
    """
    Base class for an agent's memory unit. Each memory unit is represented in the database as a single row.
    Memory units are searched over by functions defined in the memory classes
    """

    def __init__(self, id: Optional[uuid.UUID] = None):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id

        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"


# This allows type checking to work when you pass a Passage into a function expecting List[Record]
# (just use List[RecordType] instead)
RecordType = TypeVar("RecordType", bound="Record")


class ToolCall(object):
    def __init__(
        self,
        id: str,
        # TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
        # NOTE: called ToolCall.type in official OpenAI schema
        tool_call_type: str,  # only 'function' is supported
        # function: { 'name': ..., 'arguments': ...}
        function: Dict[str, str],
    ):
        self.id = id
        self.tool_call_type = tool_call_type
        self.function = function

    def to_dict(self):
        return {
            "id": self.id,
            "type": self.tool_call_type,
            "function": self.function,
        }


def add_inner_thoughts_to_tool_call(
@ -24,34 +81,20 @@ def add_inner_thoughts_to_tool_call(
    # because the kwargs are stored as strings, we need to load then write the JSON dicts
    try:
        # load the args list
        func_args = json.loads(tool_call.function.arguments)
        func_args = json.loads(tool_call.function["arguments"])
        # add the inner thoughts to the args list
        func_args[inner_thoughts_key] = inner_thoughts
        # create the updated tool call (as a string)
        updated_tool_call = copy.deepcopy(tool_call)
        updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
        updated_tool_call.function["arguments"] = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
        return updated_tool_call
    except json.JSONDecodeError as e:
        # TODO: change to logging
        warnings.warn(f"Failed to put inner thoughts in kwargs: {e}")
        raise e
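
A small worked example of the injection this helper performs on the dict-style ToolCall defined above; the call id, function name, and inner-thoughts text are made-up placeholders:

    import copy
    import json

    tool_call = ToolCall(
        id="call_123",  # hypothetical id
        tool_call_type="function",
        function={"name": "send_message", "arguments": json.dumps({"message": "hi"})},
    )
    updated = copy.deepcopy(tool_call)
    args = json.loads(updated.function["arguments"])
    args["inner_thoughts"] = "User greeted me; respond warmly."
    updated.function["arguments"] = json.dumps(args)
    # updated.function["arguments"] now carries both the original args and the injected inner thoughts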


class BaseMessage(MemGPTBase):
    __id_prefix__ = "message"


class MessageCreate(BaseMessage):
    """Request to create a message"""

    role: MessageRole = Field(..., description="The role of the participant.")
    text: str = Field(..., description="The text of the message.")
    name: Optional[str] = Field(None, description="The name of the participant.")


class Message(BaseMessage):
    """
    Representation of a message sent.
class Message(Record):
    """Representation of a message sent.

    Messages can be:
    - agent->user (role=='agent')
@ -59,23 +102,65 @@ class Message(BaseMessage):
    - or function/tool call returns (role=='function'/'tool').
    """

    id: str = BaseMessage.generate_id_field()
    role: MessageRole = Field(..., description="The role of the participant.")
    text: str = Field(..., description="The text of the message.")
    user_id: str = Field(None, description="The unique identifier of the user.")
    agent_id: str = Field(None, description="The unique identifier of the agent.")
    model: Optional[str] = Field(None, description="The model used to make the function call.")
    name: Optional[str] = Field(None, description="The name of the participant.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
    tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
    tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
    def __init__(
        self,
        role: str,
        text: str,
        user_id: Optional[uuid.UUID] = None,
        agent_id: Optional[uuid.UUID] = None,
        model: Optional[str] = None,  # model used to make function call
        name: Optional[str] = None,  # optional participant name
        created_at: Optional[datetime] = None,
        tool_calls: Optional[List[ToolCall]] = None,  # list of tool calls requested
        tool_call_id: Optional[str] = None,
        # tool_call_name: Optional[str] = None,  # not technically OpenAI spec, but it can be helpful to have on-hand
        embedding: Optional[np.ndarray] = None,
        embedding_dim: Optional[int] = None,
        embedding_model: Optional[str] = None,
        id: Optional[uuid.UUID] = None,
    ):
        super().__init__(id)
        self.user_id = user_id
        self.agent_id = agent_id
        self.text = text
        self.model = model  # model name (e.g. gpt-4)
        self.created_at = created_at if created_at is not None else get_utc_time()

    @field_validator("role")
    @classmethod
    def validate_role(cls, v: str) -> str:
        roles = ["system", "assistant", "user", "tool"]
        assert v in roles, f"Role must be one of {roles}"
        return v
        # openai info
        assert role in ["system", "assistant", "user", "tool"]
        self.role = role  # role (agent/user/function)
        self.name = name

        # pad and store embeddings
        if isinstance(embedding, list):
            embedding = np.array(embedding)
        self.embedding = (
            np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
        )
        self.embedding_dim = embedding_dim
        self.embedding_model = embedding_model

        if self.embedding is not None:
            assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
            assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
            assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"

        # tool (i.e. function) call info (optional)

        # if role == "assistant", this MAY be specified
        # if role != "assistant", this must be null
        assert tool_calls is None or isinstance(tool_calls, list)
        if tool_calls is not None:
            assert all([isinstance(tc, ToolCall) for tc in tool_calls]), f"Tool calls must be of type ToolCall, got {tool_calls}"
        self.tool_calls = tool_calls

        # if role == "tool", then this must be specified
        # if role != "tool", this must be null
        if role == "tool":
            assert tool_call_id is not None
        else:
            assert tool_call_id is None
        self.tool_call_id = tool_call_id

    def to_json(self):
        json_message = vars(self)
@ -88,26 +173,16 @@ class Message(BaseMessage):
        json_message["created_at"] = self.created_at.isoformat()
        return json_message

    def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
        """Convert message object (in DB format) to the style used by the original MemGPT API

        NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
        """
        raise NotImplementedError

    @staticmethod
    def dict_to_message(
        user_id: str,
        agent_id: str,
        user_id: uuid.UUID,
        agent_id: uuid.UUID,
        openai_message_dict: dict,
        model: Optional[str] = None,  # model used to make function call
        allow_functions_style: bool = False,  # allow deprecated functions style?
        created_at: Optional[datetime] = None,
    ):
        """Convert a ChatCompletion message object into a Message object (synced to DB)"""
        if not created_at:
            # timestamp for creation
            created_at = get_utc_time()

        assert "role" in openai_message_dict, openai_message_dict
        assert "content" in openai_message_dict, openai_message_dict
@ -121,6 +196,7 @@ class Message(BaseMessage):
            # Convert from 'function' response to a 'tool' response
            # NOTE: this does not conventionally include a tool_call_id, it's on the caster to provide it
            return Message(
                created_at=created_at,
                user_id=user_id,
                agent_id=agent_id,
                model=model,
@ -130,7 +206,6 @@ class Message(BaseMessage):
                name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
                tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
                created_at=created_at,
            )

        elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None:
@ -153,6 +228,7 @@ class Message(BaseMessage):
            ]

            return Message(
                created_at=created_at,
                user_id=user_id,
                agent_id=agent_id,
                model=model,
@ -162,7 +238,6 @@ class Message(BaseMessage):
                name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                tool_calls=tool_calls,
                tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
                created_at=created_at,
            )

        else:
@ -185,6 +260,7 @@ class Message(BaseMessage):

            # If we're going from tool-call style
            return Message(
                created_at=created_at,
                user_id=user_id,
                agent_id=agent_id,
                model=model,
@ -194,7 +270,6 @@ class Message(BaseMessage):
                name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                tool_calls=tool_calls,
                tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
                created_at=created_at,
            )

    def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:
@ -248,11 +323,11 @@ class Message(BaseMessage):
                        tool_call,
                        inner_thoughts=self.text,
                        inner_thoughts_key=INNER_THOUGHTS_KWARG,
                    ).model_dump()
                    ).to_dict()
                    for tool_call in self.tool_calls
                ]
            else:
                openai_message["tool_calls"] = [tool_call.model_dump() for tool_call in self.tool_calls]
                openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
            if max_tool_id_length:
                for tool_call_dict in openai_message["tool_calls"]:
                    tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
@ -548,3 +623,313 @@ class Message(BaseMessage):
            raise ValueError(self.role)

        return cohere_message


class Document(Record):
    """A Document represents a document loaded into MemGPT, which is broken down into passages."""

    def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None, metadata: Optional[Dict] = {}):
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            self.id = create_uuid_from_string("".join([text, str(user_id)]))
        else:
            self.id = id
        super().__init__(id)
        self.user_id = user_id
        self.text = text
        self.data_source = data_source
        self.metadata = metadata
        # TODO: add optional embedding?


class Passage(Record):
|
||||
"""A passage is a single unit of memory, and a standard format accross all storage backends.
|
||||
|
||||
It is a string of text with an assoidciated embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
text: str,
|
||||
user_id: Optional[uuid.UUID] = None,
|
||||
agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory
|
||||
embedding: Optional[np.ndarray] = None,
|
||||
embedding_dim: Optional[int] = None,
|
||||
embedding_model: Optional[str] = None,
|
||||
data_source: Optional[str] = None, # None if created by agent
|
||||
doc_id: Optional[uuid.UUID] = None,
|
||||
id: Optional[uuid.UUID] = None,
|
||||
metadata_: Optional[dict] = {},
|
||||
created_at: Optional[datetime] = None,
|
||||
):
|
||||
if id is None:
|
||||
# by default, generate ID as a hash of the text (avoid duplicates)
|
||||
# TODO: use source-id instead?
|
||||
if agent_id:
|
||||
self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
|
||||
else:
|
||||
self.id = create_uuid_from_string("".join([text, str(user_id)]))
|
||||
else:
|
||||
self.id = id
|
||||
super().__init__(self.id)
|
||||
self.user_id = user_id
|
||||
self.agent_id = agent_id
|
||||
self.text = text
|
||||
self.data_source = data_source
|
||||
self.doc_id = doc_id
|
||||
self.metadata_ = metadata_
|
||||
|
||||
# pad and store embeddings
|
||||
if isinstance(embedding, list):
|
||||
embedding = np.array(embedding)
|
||||
self.embedding = (
|
||||
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
|
||||
)
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embedding_model = embedding_model
|
||||
|
||||
self.created_at = created_at if created_at is not None else get_utc_time()
|
||||
|
||||
if self.embedding is not None:
|
||||
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
|
||||
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
|
||||
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
|
||||
|
||||
assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
|
||||
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
|
||||
assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
|
||||
assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
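
The deterministic IDs above are what make Document and Passage deduplicate on re-insert: hashing the text (plus the owning IDs) into a UUID means loading identical content twice yields the same primary key. A minimal sketch of that idea, assuming a uuid5-based helper (the real create_uuid_from_string lives in memgpt.utils and may differ):

import uuid

def create_uuid_from_string(s: str) -> uuid.UUID:
    # Hypothetical stand-in for memgpt.utils.create_uuid_from_string:
    # derive a stable UUID from the input string, so identical text
    # always maps to the same ID (uuid5 is one way to do this).
    return uuid.uuid5(uuid.NAMESPACE_DNS, s)

# Same text + same user -> same ID, so duplicates collapse on upsert.
assert create_uuid_from_string("hello" + "user-123") == create_uuid_from_string("hello" + "user-123")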


class LLMConfig:
    def __init__(
        self,
        model: Optional[str] = None,
        model_endpoint_type: Optional[str] = None,
        model_endpoint: Optional[str] = None,
        model_wrapper: Optional[str] = None,
        context_window: Optional[int] = None,
    ):
        self.model = model
        self.model_endpoint_type = model_endpoint_type
        self.model_endpoint = model_endpoint
        self.model_wrapper = model_wrapper
        self.context_window = context_window

        if context_window is None:
            self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
        else:
            self.context_window = context_window
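
Note the fallback above: when context_window is not passed, the constructor looks the model up in LLM_MAX_TOKENS and falls back to the "DEFAULT" entry. A self-contained sketch of that resolution rule (the token table is an assumed subset; the real values live in memgpt.constants):

LLM_MAX_TOKENS = {"gpt-4": 8192, "gpt-3.5-turbo": 4096, "DEFAULT": 8192}  # assumed subset

def resolve_context_window(model, context_window):
    # Mirrors the fallback logic in LLMConfig.__init__ above.
    if context_window is None:
        return LLM_MAX_TOKENS.get(model, LLM_MAX_TOKENS["DEFAULT"])
    return context_window

assert resolve_context_window("gpt-4", None) == 8192          # known model
assert resolve_context_window("my-local-llm", None) == 8192   # falls back to DEFAULT
assert resolve_context_window("gpt-4", 2048) == 2048          # explicit value wins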


class EmbeddingConfig:
    def __init__(
        self,
        embedding_endpoint_type: Optional[str] = None,
        embedding_endpoint: Optional[str] = None,
        embedding_model: Optional[str] = None,
        embedding_dim: Optional[int] = None,
        embedding_chunk_size: Optional[int] = 300,
    ):
        self.embedding_endpoint_type = embedding_endpoint_type
        self.embedding_endpoint = embedding_endpoint
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.embedding_chunk_size = embedding_chunk_size

        # fields cannot be set to None
        assert self.embedding_endpoint_type
        assert self.embedding_dim
        assert self.embedding_chunk_size


class OpenAIEmbeddingConfig(EmbeddingConfig):
    def __init__(self, openai_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = openai_key


class AzureEmbeddingConfig(EmbeddingConfig):
    def __init__(
        self,
        azure_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        azure_version: Optional[str] = None,
        azure_deployment: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.azure_key = azure_key
        self.azure_endpoint = azure_endpoint
        self.azure_version = azure_version
        self.azure_deployment = azure_deployment


class User:
    """Defines user and default configurations"""

    # TODO: make sure to encrypt/decrypt keys before storing in DB

    def __init__(
        self,
        # name: str,
        id: Optional[uuid.UUID] = None,
        default_agent=None,
        # other
        policies_accepted=False,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"

        self.default_agent = default_agent

        # misc
        self.policies_accepted = policies_accepted


class AgentState:
    def __init__(
        self,
        name: str,
        user_id: uuid.UUID,
        # tools
        tools: List[str],  # list of tools by name
        # system prompt
        system: str,
        # config
        llm_config: LLMConfig,
        embedding_config: EmbeddingConfig,
        # (in-context) state contains:
        id: Optional[uuid.UUID] = None,
        state: Optional[dict] = None,
        created_at: Optional[datetime] = None,
        # messages (TODO: implement this)
        _metadata: Optional[dict] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        # TODO(swooders) we need to handle the case where name is None here
        # in AgentConfig we autogenerate a name; not sure what the correct thing to do with DBs is — maybe NounAdjective combos like Giphy uses? BoredGiraffe etc.
        self.name = name
        assert self.name, f"AgentState name must be a non-empty string"
        self.user_id = user_id
        # The INITIAL values of the persona and human
        # The values inside self.state['persona'], self.state['human'] are the CURRENT values

        self.llm_config = llm_config
        self.embedding_config = embedding_config

        self.created_at = created_at if created_at is not None else get_utc_time()

        # state
        self.state = {} if not state else state

        # tools
        self.tools = tools

        # system
        self.system = system
        assert self.system is not None, f"Must provide system prompt, cannot be None"

        # metadata
        self._metadata = _metadata


class Source:
    def __init__(
        self,
        user_id: uuid.UUID,
        name: str,
        description: Optional[str] = None,
        created_at: Optional[datetime] = None,
        id: Optional[uuid.UUID] = None,
        # embedding info
        embedding_model: Optional[str] = None,
        embedding_dim: Optional[int] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        self.name = name
        self.user_id = user_id
        self.description = description
        self.created_at = created_at if created_at is not None else get_utc_time()

        # embedding info (optional)
        self.embedding_dim = embedding_dim
        self.embedding_model = embedding_model


class Token:
    def __init__(
        self,
        user_id: uuid.UUID,
        token: str,
        name: Optional[str] = None,
        id: Optional[uuid.UUID] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        self.token = token
        self.user_id = user_id
        self.name = name


class Preset(BaseModel):
    # TODO: remove Preset
    name: str = Field(..., description="The name of the preset.")
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
    user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
    description: Optional[str] = Field(None, description="The description of the preset.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
    system: str = Field(
        gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset."
    )  # default system prompt is same as default preset name
    # system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
    persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
    persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
    human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
    human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
    functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
    # functions: List[str] = Field(..., description="The functions of the preset.")  # TODO: convert to ID
    # sources: List[str] = Field(..., description="The sources of the preset.")  # TODO: convert to ID

    @staticmethod
    def clone(preset_obj: "Preset", new_name_suffix: str = None) -> "Preset":
        """
        Takes a Preset object and an optional new name suffix as input,
        creates a clone of the given Preset object with a new ID and an optional new name,
        and returns the new Preset object.
        """
        new_preset = preset_obj.model_copy()
        new_preset.id = uuid.uuid4()
        if new_name_suffix:
            new_preset.name = f"{preset_obj.name}_{new_name_suffix}"
        else:
            new_preset.name = f"{preset_obj.name}_{str(uuid.uuid4())[:8]}"
        return new_preset
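
clone gives the copy a fresh ID and, when no suffix is supplied, appends a short random slug so the clone's name stays unique per user. A quick usage sketch (field values are placeholders):

base = Preset(name="memgpt_chat", functions_schema=[])
copy = Preset.clone(base)                  # name like "memgpt_chat_3f9a1c2b"
named = Preset.clone(base, "experiment")   # name "memgpt_chat_experiment"
assert copy.id != base.id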


class Function(BaseModel):
    name: str = Field(..., description="The name of the function.")
    id: uuid.UUID = Field(..., description="The unique identifier of the function.")
    user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
    # TODO: figure out how to represent functions

@@ -22,7 +22,7 @@ from memgpt.constants import (
    MAX_EMBEDDING_DIM,
)
from memgpt.credentials import MemGPTCredentials
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.data_types import EmbeddingConfig
from memgpt.utils import is_valid_url, printd


@@ -11,8 +11,8 @@ from memgpt.constants import (
    MESSAGE_CHATGPT_FUNCTION_MODEL,
    MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
)
from memgpt.data_types import Message
from memgpt.llm_api.llm_api_tools import create
from memgpt.schemas.message import Message


def message_chatgpt(self, message: str):

@@ -1,6 +1,6 @@
import inspect
import typing
from typing import Any, Dict, Optional, Type, get_args, get_origin
from typing import Optional, get_args, get_origin

from docstring_parser import parse
from pydantic import BaseModel
@@ -144,39 +144,3 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
        schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)

    return schema


def generate_schema_from_args_schema(
    args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None
) -> Dict[str, Any]:
    properties = {}
    required = []
    for field_name, field in args_schema.__fields__.items():
        properties[field_name] = {"type": field.type_.__name__, "description": field.field_info.description}
        if field.required:
            required.append(field_name)

    # Construct the OpenAI function call JSON object
    function_call_json = {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": properties, "required": required},
    }

    return function_call_json


def generate_tool_wrapper(tool_name: str) -> str:
    import_statement = f"from crewai_tools import {tool_name}"
    tool_instantiation = f"tool = {tool_name}()"
    run_call = f"return tool._run(**kwargs)"
    func_name = f"run_{tool_name.lower()}"

    # Combine all parts into the wrapper function
    wrapper_function_str = f"""
def {func_name}(**kwargs):
    {import_statement}
    {tool_instantiation}
    {run_call}
"""
    return func_name, wrapper_function_str
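
generate_tool_wrapper emits Python source for a thin wrapper that imports and runs a crewai tool by name (note the annotation says -> str, but the function actually returns a (name, source) pair). The string is presumably meant to be exec'd and the resulting function registered as a tool; a hedged sketch of that consumption, assuming crewai_tools is installed and exposes a ScrapeWebsiteTool:

func_name, source = generate_tool_wrapper("ScrapeWebsiteTool")
namespace = {}
exec(source, namespace)          # defines run_scrapewebsitetool(**kwargs)
wrapper = namespace[func_name]
# result = wrapper(website_url="https://example.com")  # would call tool._run(...)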

@@ -6,7 +6,7 @@ from typing import List, Optional

from colorama import Fore, Style, init

from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.schemas.message import Message
from memgpt.data_types import Message
from memgpt.utils import printd

init(autoreset=True)

@@ -6,17 +6,17 @@ from typing import List, Optional, Union
import requests

from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.data_types import Message
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    FunctionCall,
)
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
    Message as ChoiceMessage,  # NOTE: avoid conflict with our own MemGPT Message datatype
)
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
from memgpt.utils import get_utc_time, smart_urljoin

BASE_URL = "https://api.anthropic.com/v1"

@@ -2,8 +2,8 @@ from typing import Union

import requests

from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.schemas.openai.embedding_response import EmbeddingResponse
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin

MODEL_TO_AZURE_ENGINE = {

@@ -5,18 +5,18 @@ from typing import List, Optional, Union
import requests

from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.data_types import Message
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    FunctionCall,
)
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
    Message as ChoiceMessage,  # NOTE: avoid conflict with our own MemGPT Message datatype
)
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin

BASE_URL = "https://api.cohere.ai/v1"

@@ -7,8 +7,8 @@ import requests
from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX
from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.openai.chat_completion_request import Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import Tool
from memgpt.models.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    FunctionCall,

@@ -11,6 +11,7 @@ import requests

from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Message
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
    MODEL_TO_AZURE_ENGINE,
@@ -30,15 +31,13 @@ from memgpt.local_llm.constants import (
    INNER_THOUGHTS_KWARG,
    INNER_THOUGHTS_KWARG_DESCRIPTION,
)
from memgpt.schemas.enums import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import (
from memgpt.models.chat_completion_request import (
    ChatCompletionRequest,
    Tool,
    cast_message_to_subtype,
)
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.pydantic_models import LLMConfigModel, OptionState
from memgpt.streaming_interface import (
    AgentChunkStreamingInterface,
    AgentRefreshStreamingInterface,
@@ -229,7 +228,7 @@ def retry_with_exponential_backoff(
@retry_with_exponential_backoff
def create(
    # agent_state: AgentState,
    llm_config: LLMConfig,
    llm_config: LLMConfigModel,
    messages: List[Message],
    user_id: uuid.UUID = None,  # optional UUID to associate the request with
    functions: list = None,
@@ -260,6 +259,8 @@ def create(
        printd("unsetting function_call because functions is None")
        function_call = None

    # print("HELLO")

    # openai
    if llm_config.model_endpoint_type == "openai":

@@ -7,8 +7,8 @@ from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError

from memgpt.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.chat_completion_response import (
    ChatCompletionChunkResponse,
    ChatCompletionResponse,
    Choice,
@@ -17,7 +17,7 @@ from memgpt.schemas.openai.chat_completion_response import (
    ToolCall,
    UsageStatistics,
)
from memgpt.schemas.openai.embedding_response import EmbeddingResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.streaming_interface import (
    AgentChunkStreamingInterface,
    AgentRefreshStreamingInterface,
@@ -89,7 +89,6 @@ def openai_chat_completions_process_stream(
    on the chunks received from the OpenAI-compatible server POST SSE response.
    """
    assert chat_completion_request.stream == True
    assert stream_inferface is not None, "Required"

    # Count the prompt tokens
    # TODO move to post-request?
@@ -371,10 +370,7 @@ def openai_chat_completions_request(
    url = smart_urljoin(url, "chat/completions")
    headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
    data = chat_completion_request.model_dump(exclude_none=True)

    # add check, otherwise this will cause the error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
    if chat_completion_request.tools is not None:
        data["parallel_tool_calls"] = False
    data["parallel_tool_calls"] = False

    printd("Request:\n", json.dumps(data, indent=2))

@@ -390,7 +386,7 @@ def openai_chat_completions_request(
    printd(f"Sending request to {url}")
    try:
        response = requests.post(url, headers=headers, json=data)
        printd(f"response = {response}, response.text = {response.text}")
        # printd(f"response = {response}, response.text = {response.text}")
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status

        response = response.json()  # convert to dict from string

@@ -25,14 +25,14 @@ from memgpt.local_llm.webui.api import get_webui_completion
from memgpt.local_llm.webui.legacy_api import (
    get_webui_completion as get_webui_completion_legacy,
)
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    Message,
    ToolCall,
    UsageStatistics,
)
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.utils import get_tool_call_id, get_utc_time

has_shown_warning = False

@@ -11,12 +11,20 @@ from rich.console import Console
import memgpt.agent as agent
import memgpt.errors as errors
import memgpt.system as system
from memgpt.agent_store.storage import StorageConnector, TableType

# import benchmark
from memgpt import create_client
from memgpt.benchmark.benchmark import bench
from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, version
from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools
from memgpt.cli.cli import (
    delete_agent,
    migrate,
    open_folder,
    quickstart,
    run,
    server,
    version,
)
from memgpt.cli.cli_config import add, configure, delete, list
from memgpt.cli.cli_load import app as load_app
from memgpt.config import MemGPTConfig
from memgpt.constants import (
@@ -26,7 +34,7 @@ from memgpt.constants import (
    REQ_HEARTBEAT_MESSAGE,
)
from memgpt.metadata import MetadataStore
from memgpt.schemas.enums import OptionState
from memgpt.models.pydantic_models import OptionState

# from memgpt.interface import CLIInterface as interface  # for printing to terminal
from memgpt.streaming_interface import AgentRefreshStreamingInterface
@@ -39,14 +47,14 @@ app.command(name="version")(version)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
app.command(name="add-tool")(add_tool)
app.command(name="list-tools")(list_tools)
app.command(name="delete")(delete)
app.command(name="server")(server)
app.command(name="folder")(open_folder)
app.command(name="quickstart")(quickstart)
# load data commands
app.add_typer(load_app, name="load")
# migration command
app.command(name="migrate")(migrate)
# benchmark command
app.command(name="benchmark")(bench)
# delete agents
@@ -95,12 +103,7 @@ def run_agent_loop(
        print()

    multiline_input = False

    # create client
    client = create_client()
    ms = MetadataStore(config)  # TODO: remove

    # run loops
    ms = MetadataStore(config)
    while True:
        if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
            # Ask for user input
@@ -148,8 +151,8 @@ def run_agent_loop(
                # TODO: check to ensure source embedding dimensions/model match the agent's, and disallow attachment if not
                # TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources

                sources = client.list_sources()
                if len(sources) == 0:
                data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id)
                if len(data_source_options) == 0:
                    typer.secho(
                        'No sources available. You must load a source with "memgpt load ..." before running /attach.',
                        fg=typer.colors.RED,
@@ -160,8 +163,11 @@ def run_agent_loop(
                # determine what sources are valid to be attached to this agent
                valid_options = []
                invalid_options = []
                for source in sources:
                    if source.embedding_config == memgpt_agent.agent_state.embedding_config:
                for source in data_source_options:
                    if (
                        source.embedding_model == memgpt_agent.agent_state.embedding_config.embedding_model
                        and source.embedding_dim == memgpt_agent.agent_state.embedding_config.embedding_dim
                    ):
                        valid_options.append(source.name)
                    else:
                        # print warning about invalid sources
@@ -175,7 +181,11 @@ def run_agent_loop(
                data_source = questionary.select("Select data source", choices=valid_options).ask()

                # attach new data
                client.attach_source_to_agent(agent_id=memgpt_agent.agent_state.id, source_name=data_source)
                # attach(memgpt_agent.agent_state.name, data_source)
                source_connector = StorageConnector.get_storage_connector(
                    TableType.PASSAGES, config, user_id=memgpt_agent.agent_state.user_id
                )
                memgpt_agent.attach_source(data_source, source_connector, ms)

                continue

@@ -420,10 +430,8 @@ def run_agent_loop(
                    skip_verify=no_verify,
                    stream=stream,
                    inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
                    ms=ms,
                )

                agent.save_agent(memgpt_agent, ms)
                skip_next_user_input = False
                if token_warning:
                    user_message = system.get_token_limit_warning()

246
memgpt/memory.py
@@ -3,14 +3,13 @@ import uuid
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel, validator

from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.data_types import AgentState, Message, Passage
from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding
from memgpt.llm_api.llm_api_tools import create
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.schemas.agent import AgentState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.utils import (
    count_tokens,
    extract_date_from_timestamp,
@@ -19,135 +18,125 @@ from memgpt.utils import (
    validate_date_format,
)

# class MemoryModule(BaseModel):
#     """Base class for memory modules"""
#
#     description: Optional[str] = None
#     limit: int = 2000
#     value: Optional[Union[List[str], str]] = None
#
#     def __setattr__(self, name, value):
#         """Run validation if self.value is updated"""
#         super().__setattr__(name, value)
#         if name == "value":
#             # run validation
#             self.__class__.validate(self.dict(exclude_unset=True))
#
#     @validator("value", always=True)
#     def check_value_length(cls, v, values):
#         if v is not None:
#             # Fetching the limit from the values dictionary
#             limit = values.get("limit", 2000)  # Default to 2000 if limit is not yet set
#
#             # Check if the value exceeds the limit
#             if isinstance(v, str):
#                 length = len(v)
#             elif isinstance(v, list):
#                 length = sum(len(item) for item in v)
#             else:
#                 raise ValueError("Value must be either a string or a list of strings.")
#
#             if length > limit:
#                 error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
#                 # TODO: add archival memory error?
#                 raise ValueError(error_msg)
#             return v
#
#     def __len__(self):
#         return len(str(self))
#
#     def __str__(self) -> str:
#         if isinstance(self.value, list):
#             return ",".join(self.value)
#         elif isinstance(self.value, str):
#             return self.value
#         else:
#             return ""
#
#
# class BaseMemory:
#
#     def __init__(self):
#         self.memory = {}
#
#     @classmethod
#     def load(cls, state: dict):
#         """Load memory from dictionary object"""
#         obj = cls()
#         for key, value in state.items():
#             obj.memory[key] = MemoryModule(**value)
#         return obj
#
#     def __str__(self) -> str:
#         """Representation of the memory in-context"""
#         section_strs = []
#         for section, module in self.memory.items():
#             section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
#         return "\n".join(section_strs)
#
#     def to_dict(self):
#         """Convert to dictionary representation"""
#         return {key: value.dict() for key, value in self.memory.items()}
#
#
# class ChatMemory(BaseMemory):
#
#     def __init__(self, persona: str, human: str, limit: int = 2000):
#         self.memory = {
#             "persona": MemoryModule(name="persona", value=persona, limit=limit),
#             "human": MemoryModule(name="human", value=human, limit=limit),
#         }
#
#     def core_memory_append(self, name: str, content: str) -> Optional[str]:
#         """
#         Append to the contents of core memory.
#
#         Args:
#             name (str): Section of the memory to be edited (persona or human).
#             content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
#         Returns:
#             Optional[str]: None is always returned as this function does not produce a response.
#         """
#         self.memory[name].value += "\n" + content
#         return None
#
#     def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
#         """
#         Replace the contents of core memory. To delete memories, use an empty string for new_content.
#
#         Args:
#             name (str): Section of the memory to be edited (persona or human).
#             old_content (str): String to replace. Must be an exact match.
#             new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
#         Returns:
#             Optional[str]: None is always returned as this function does not produce a response.
#         """
#         self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
#         return None

class MemoryModule(BaseModel):
    """Base class for memory modules"""

    description: Optional[str] = None
    limit: int = 2000
    value: Optional[Union[List[str], str]] = None

    def __setattr__(self, name, value):
        """Run validation if self.value is updated"""
        super().__setattr__(name, value)
        if name == "value":
            # run validation
            self.__class__.validate(self.dict(exclude_unset=True))

    @validator("value", always=True)
    def check_value_length(cls, v, values):
        if v is not None:
            # Fetching the limit from the values dictionary
            limit = values.get("limit", 2000)  # Default to 2000 if limit is not yet set

            # Check if the value exceeds the limit
            if isinstance(v, str):
                length = len(v)
            elif isinstance(v, list):
                length = sum(len(item) for item in v)
            else:
                raise ValueError("Value must be either a string or a list of strings.")

            if length > limit:
                error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
                # TODO: add archival memory error?
                raise ValueError(error_msg)
            return v

    def __len__(self):
        return len(str(self))

    def __str__(self) -> str:
        if isinstance(self.value, list):
            return ",".join(self.value)
        elif isinstance(self.value, str):
            return self.value
        else:
            return ""


def get_memory_functions(cls: Memory) -> List[callable]:
class BaseMemory:

    def __init__(self):
        self.memory = {}

    @classmethod
    def load(cls, state: dict):
        """Load memory from dictionary object"""
        obj = cls()
        for key, value in state.items():
            obj.memory[key] = MemoryModule(**value)
        return obj

    def __str__(self) -> str:
        """Representation of the memory in-context"""
        section_strs = []
        for section, module in self.memory.items():
            section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
        return "\n".join(section_strs)

    def to_dict(self):
        """Convert to dictionary representation"""
        return {key: value.dict() for key, value in self.memory.items()}


class ChatMemory(BaseMemory):

    def __init__(self, persona: str, human: str, limit: int = 2000):
        self.memory = {
            "persona": MemoryModule(name="persona", value=persona, limit=limit),
            "human": MemoryModule(name="human", value=human, limit=limit),
        }

    def core_memory_append(self, name: str, content: str) -> Optional[str]:
        """
        Append to the contents of core memory.

        Args:
            name (str): Section of the memory to be edited (persona or human).
            content (str): Content to write to the memory. All unicode (including emojis) are supported.

        Returns:
            Optional[str]: None is always returned as this function does not produce a response.
        """
        self.memory[name].value += "\n" + content
        return None

    def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
        """
        Replace the contents of core memory. To delete memories, use an empty string for new_content.

        Args:
            name (str): Section of the memory to be edited (persona or human).
            old_content (str): String to replace. Must be an exact match.
            new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

        Returns:
            Optional[str]: None is always returned as this function does not produce a response.
        """
        self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
        return None


def get_memory_functions(cls: BaseMemory) -> List[callable]:
    """Get memory functions for a memory class"""
    functions = {}

    # collect base memory functions (should not be included)
    base_functions = []
    for func_name in dir(Memory):
        funct = getattr(Memory, func_name)
        if callable(funct):
            base_functions.append(func_name)

    for func_name in dir(cls):
        if func_name.startswith("_") or func_name in ["load", "to_dict"]:  # skip base functions
            continue
        if func_name in base_functions:  # don't use BaseMemory functions
            continue
        func = getattr(cls, func_name)
        if not callable(func):  # not a function
            continue
        functions[func_name] = func
        if callable(func):
            functions[func_name] = func
    return functions
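
In effect, get_memory_functions reflects over a memory class and keeps every public callable that is not inherited machinery, which is how core_memory_append and core_memory_replace become agent-callable tools. A usage sketch, assuming the filtering works as the code above intends:

funcs = get_memory_functions(ChatMemory)
assert "core_memory_append" in funcs     # editing methods are exposed
assert "core_memory_replace" in funcs
assert "to_dict" not in funcs            # base machinery is filtered out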


@@ -264,8 +253,8 @@ def summarize_messages(
        + message_sequence_to_summarize[cutoff:]
    )

    dummy_user_id = agent_state.user_id
    dummy_agent_id = agent_state.id
    dummy_user_id = uuid.uuid4()
    dummy_agent_id = uuid.uuid4()
    message_sequence = []
    message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
    if insert_acknowledgement_assistant_message:
@@ -528,7 +517,8 @@ class EmbeddingArchivalMemory(ArchivalMemory):
            agent_id=self.agent_state.id,
            text=text,
            embedding=embedding,
            embedding_config=self.agent_state.embedding_config,
            embedding_dim=self.agent_state.embedding_config.embedding_dim,
            embedding_model=self.agent_state.embedding_config.embedding_model,
        )

    def save(self):

@@ -3,10 +3,12 @@
import os
import secrets
import traceback
import uuid
from typing import List, Optional

from sqlalchemy import (
    BIGINT,
    CHAR,
    JSON,
    Boolean,
    Column,
@@ -18,28 +20,58 @@ from sqlalchemy import (
    desc,
    func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.exc import InterfaceError, OperationalError
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import func

from memgpt.config import MemGPTConfig
from memgpt.schemas.agent import AgentState
from memgpt.schemas.api_key import APIKey
from memgpt.schemas.block import Block, Human, Persona
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import Memory
from memgpt.schemas.source import Source
from memgpt.schemas.tool import Tool
from memgpt.schemas.user import User
from memgpt.data_types import (
    AgentState,
    EmbeddingConfig,
    LLMConfig,
    Preset,
    Source,
    Token,
    User,
)
from memgpt.models.pydantic_models import (
    HumanModel,
    JobModel,
    JobStatus,
    PersonaModel,
    ToolModel,
)
from memgpt.settings import settings
from memgpt.utils import enforce_types, get_utc_time, printd

Base = declarative_base()


# Custom UUID type
class CommonUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID(as_uuid=True))
        else:
            return dialect.type_descriptor(CHAR())

    def process_bind_param(self, value, dialect):
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return str(value)  # Convert UUID to string for SQLite

    def process_result_value(self, value, dialect):
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return uuid.UUID(value)
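
CommonUUID papers over a real dialect gap: Postgres has a native UUID column type, SQLite does not, so there the value is stored as a string and converted back to uuid.UUID on read. A hedged sketch of declaring a model with it (table and column names are illustrative):

class ExampleModel(Base):
    __tablename__ = "example"
    # One declaration works on both backends: native UUID on Postgres,
    # CHAR-backed strings round-tripped through uuid.UUID on SQLite.
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    name = Column(String)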


class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON"""

@@ -84,44 +116,48 @@ class UserModel(Base):
    __tablename__ = "users"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True))
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    # name = Column(String, nullable=False)
    default_agent = Column(String)

    # TODO: what is this?
    policies_accepted = Column(Boolean, nullable=False, default=False)

    def __repr__(self) -> str:
        return f"<User(id='{self.id}' name='{self.name}')>"
        return f"<User(id='{self.id}')>"

    def to_record(self) -> User:
        return User(id=self.id, name=self.name, created_at=self.created_at)
        return User(
            id=self.id,
            # name=self.name
            default_agent=self.default_agent,
            policies_accepted=self.policies_accepted,
        )


class APIKeyModel(Base):
class TokenModel(Base):
    """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""

    __tablename__ = "tokens"

    id = Column(String, primary_key=True)
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    # each api key is tied to a user account (that it validates access for)
    user_id = Column(String, nullable=False)
    user_id = Column(CommonUUID, nullable=False)
    # the api key
    key = Column(String, nullable=False)
    token = Column(String, nullable=False)
    # extra (optional) metadata
    name = Column(String)

    Index(__tablename__ + "_idx_user", user_id),
    Index(__tablename__ + "_idx_key", key),
    Index(__tablename__ + "_idx_token", token),

    def __repr__(self) -> str:
        return f"<APIKey(id='{self.id}', key='{self.key}', name='{self.name}')>"
        return f"<Token(id='{self.id}', token='{self.token}', name='{self.name}')>"

    def to_record(self) -> User:
        return APIKey(
        return Token(
            id=self.id,
            user_id=self.user_id,
            key=self.key,
            token=self.token,
            name=self.name,
        )

@@ -140,24 +176,19 @@ class AgentModel(Base):
    __tablename__ = "agents"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    description = Column(String)

    # state (context compilation)
    message_ids = Column(JSON)
    memory = Column(JSON)
    system = Column(String)
    tools = Column(JSON)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    # configs
    llm_config = Column(LLMConfigColumn)
    embedding_config = Column(EmbeddingConfigColumn)

    # state
    metadata_ = Column(JSON)
    state = Column(JSON)
    _metadata = Column(JSON)

    # tools
    tools = Column(JSON)
@@ -173,14 +204,12 @@ class AgentModel(Base):
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            description=self.description,
            message_ids=self.message_ids,
            memory=Memory.load(self.memory),  # load dictionary
            system=self.system,
            tools=self.tools,
            llm_config=self.llm_config,
            embedding_config=self.embedding_config,
            metadata_=self.metadata_,
            state=self.state,
            tools=self.tools,
            system=self.system,
            _metadata=self._metadata,
        )


@@ -192,13 +221,13 @@ class SourceModel(Base):

    # Assuming passage_id is the primary key
    # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    embedding_config = Column(EmbeddingConfigColumn)
    embedding_dim = Column(BIGINT)
    embedding_model = Column(String)
    description = Column(String)
    metadata_ = Column(JSON)
    Index(__tablename__ + "_idx_user", user_id),

    # TODO: add num passages
@@ -212,9 +241,9 @@ class SourceModel(Base):
            user_id=self.user_id,
            name=self.name,
            created_at=self.created_at,
            embedding_config=self.embedding_config,
            embedding_dim=self.embedding_dim,
            embedding_model=self.embedding_model,
            description=self.description,
            metadata_=self.metadata_,
        )


@@ -223,116 +252,80 @@ class AgentSourceMappingModel(Base):

    __tablename__ = "agent_source_mapping"

    id = Column(String, primary_key=True)
    user_id = Column(String, nullable=False)
    agent_id = Column(String, nullable=False)
    source_id = Column(String, nullable=False)
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    agent_id = Column(CommonUUID, nullable=False)
    source_id = Column(CommonUUID, nullable=False)
    Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),

    def __repr__(self) -> str:
        return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"


class BlockModel(Base):
    __tablename__ = "block"
class PresetSourceMapping(Base):
    __tablename__ = "preset_source_mapping"

    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    preset_id = Column(CommonUUID, nullable=False)
    source_id = Column(CommonUUID, nullable=False)
    Index(__tablename__ + "_idx_user", user_id, preset_id, source_id),

    def __repr__(self) -> str:
        return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"


# class PresetFunctionMapping(Base):
#     __tablename__ = "preset_function_mapping"
#
#     id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
#     user_id = Column(CommonUUID, nullable=False)
#     preset_id = Column(CommonUUID, nullable=False)
#     # function_id = Column(CommonUUID, nullable=False)
#     function = Column(String, nullable=False)  # TODO: convert to ID eventually
#
#     def __repr__(self) -> str:
#         return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"


class PresetModel(Base):
    """Defines data model for storing Preset objects"""

    __tablename__ = "presets"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True, nullable=False)
    value = Column(String, nullable=False)
    limit = Column(BIGINT)
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
    user_id = Column(CommonUUID, nullable=False)
    name = Column(String, nullable=False)
    template = Column(Boolean, default=False)  # True: listed as possible human/persona
    label = Column(String)
    metadata_ = Column(JSON)
    description = Column(String)
    user_id = Column(String)
    system = Column(String)
    human = Column(String)
    human_name = Column(String, nullable=False)
    persona = Column(String)
    persona_name = Column(String, nullable=False)
    preset = Column(String)
    created_at = Column(DateTime(timezone=True), server_default=func.now())

    functions_schema = Column(JSON)
    Index(__tablename__ + "_idx_user", user_id),

    def __repr__(self) -> str:
        return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"
        return f"<Preset(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Block:
        if self.label == "persona":
            return Persona(
                id=self.id,
                value=self.value,
                limit=self.limit,
                name=self.name,
                template=self.template,
                label=self.label,
                metadata_=self.metadata_,
                description=self.description,
                user_id=self.user_id,
            )
        elif self.label == "human":
            return Human(
                id=self.id,
                value=self.value,
                limit=self.limit,
                name=self.name,
                template=self.template,
                label=self.label,
                metadata_=self.metadata_,
                description=self.description,
                user_id=self.user_id,
            )
        else:
            raise ValueError(f"Block with label {self.label} is not supported")


class ToolModel(Base):
    __tablename__ = "tools"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)
    user_id = Column(String)
    description = Column(String)
    source_type = Column(String)
    source_code = Column(String)
    json_schema = Column(JSON)
    module = Column(String)
    tags = Column(JSON)

    def __repr__(self) -> str:
        return f"<Tool(id='{self.id}', name='{self.name}')>"

    def to_record(self) -> Tool:
        return Tool(
    def to_record(self) -> Preset:
        return Preset(
            id=self.id,
            user_id=self.user_id,
            name=self.name,
            user_id=self.user_id,
            description=self.description,
            source_type=self.source_type,
            source_code=self.source_code,
            json_schema=self.json_schema,
            module=self.module,
            tags=self.tags,
        )


class JobModel(Base):
    __tablename__ = "jobs"
    __table_args__ = {"extend_existing": True}

    id = Column(String, primary_key=True)
    user_id = Column(String)
    status = Column(String, default=JobStatus.pending)
    created_at = Column(DateTime(timezone=True), server_default=func.now())
    completed_at = Column(DateTime(timezone=True), onupdate=func.now())
    metadata_ = Column(JSON)

    def __repr__(self) -> str:
        return f"<Job(id='{self.id}', status='{self.status}')>"

    def to_record(self):
        return Job(
            id=self.id,
            user_id=self.user_id,
            status=self.status,
            system=self.system,
            human=self.human,
            persona=self.persona,
            human_name=self.human_name,
            persona_name=self.persona_name,
            preset=self.preset,
            created_at=self.created_at,
            completed_at=self.completed_at,
            metadata_=self.metadata_,
            functions_schema=self.functions_schema,
        )


@@ -364,8 +357,11 @@ class MetadataStore:
                    AgentModel.__table__,
                    SourceModel.__table__,
                    AgentSourceMappingModel.__table__,
                    APIKeyModel.__table__,
                    BlockModel.__table__,
                    TokenModel.__table__,
                    PresetModel.__table__,
                    PresetSourceMapping.__table__,
                    HumanModel.__table__,
                    PersonaModel.__table__,
                    ToolModel.__table__,
                    JobModel.__table__,
                ],
@@ -391,17 +387,16 @@ class MetadataStore:
        self.session_maker = sessionmaker(bind=self.engine)

    @enforce_types
    def create_api_key(self, user_id: str, name: str) -> APIKey:
    def create_api_key(self, user_id: uuid.UUID, name: Optional[str] = None) -> Token:
        """Create an API key for a user"""
        new_api_key = generate_api_key()
        with self.session_maker() as session:
            if session.query(APIKeyModel).filter(APIKeyModel.key == new_api_key).count() > 0:
            if session.query(TokenModel).filter(TokenModel.token == new_api_key).count() > 0:
                # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it
                raise ValueError(f"Token {new_api_key} already exists")
            # TODO store the API keys as hashed
            assert user_id and name, "User ID and name must be provided"
            token = APIKey(user_id=user_id, key=new_api_key, name=name)
            session.add(APIKeyModel(**vars(token)))
            token = Token(user_id=user_id, token=new_api_key, name=name)
            session.add(TokenModel(**vars(token)))
            session.commit()
        return self.get_api_key(api_key=new_api_key)
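
generate_api_key itself is not shown in this hunk; given the secrets import added at the top of the file, a plausible sketch (the prefix and length here are assumptions, not the actual implementation):

import secrets

def generate_api_key(prefix: str = "sk-", length: int = 32) -> str:
    # Hypothetical implementation: a cryptographically random hex token.
    return prefix + secrets.token_hex(length)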
|
||||
|
||||
@ -409,22 +404,22 @@ class MetadataStore:
|
||||
def delete_api_key(self, api_key: str):
|
||||
"""Delete an API key from the database"""
|
||||
with self.session_maker() as session:
|
||||
session.query(APIKeyModel).filter(APIKeyModel.key == api_key).delete()
|
||||
session.query(TokenModel).filter(TokenModel.token == api_key).delete()
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def get_api_key(self, api_key: str) -> Optional[APIKey]:
|
||||
def get_api_key(self, api_key: str) -> Optional[Token]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all()
|
||||
results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]:
|
||||
def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
|
||||
results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
|
||||
tokens = [r.to_record() for r in results]
|
||||
return tokens
|
||||
|
||||
@ -441,20 +436,25 @@ class MetadataStore:
|
||||
def create_agent(self, agent: AgentState):
|
||||
# insert into agent table
|
||||
# make sure agent.name does not already exist for user user_id
|
||||
assert agent.state is not None, "Agent state must be provided"
|
||||
assert len(list(agent.state.keys())) > 0, "Agent state must not be empty"
|
||||
with self.session_maker() as session:
|
||||
if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
|
||||
raise ValueError(f"Agent with name {agent.name} already exists")
|
||||
fields = vars(agent)
|
||||
fields["memory"] = agent.memory.to_dict()
|
||||
session.add(AgentModel(**fields))
|
||||
session.add(AgentModel(**vars(agent)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_source(self, source: Source):
|
||||
def create_source(self, source: Source, exists_ok=False):
|
||||
# make sure source.name does not already exist for user
|
||||
with self.session_maker() as session:
|
||||
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
|
||||
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
|
||||
session.add(SourceModel(**vars(source)))
|
||||
if not exists_ok:
|
||||
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
|
||||
else:
|
||||
session.update(SourceModel(**vars(source)))
|
||||
else:
|
||||
session.add(SourceModel(**vars(source)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
@ -466,40 +466,67 @@ class MetadataStore:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_block(self, block: Block):
|
||||
def create_preset(self, preset: Preset):
|
||||
with self.session_maker() as session:
|
||||
# TODO: fix?
|
||||
# we are only validating that more than one template block
|
||||
# with a given name doesn't exist.
|
||||
if (
|
||||
session.query(BlockModel)
|
||||
.filter(BlockModel.name == block.name)
|
||||
.filter(BlockModel.user_id == block.user_id)
|
||||
.filter(BlockModel.template == True)
|
||||
.filter(BlockModel.label == block.label)
|
||||
.count()
|
||||
> 0
|
||||
):
|
||||
|
||||
raise ValueError(f"Block with name {block.name} already exists")
|
||||
session.add(BlockModel(**vars(block)))
|
||||
if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
|
||||
raise ValueError(f"User with id {preset.id} already exists")
|
||||
session.add(PresetModel(**vars(preset)))
session.commit()

@enforce_types
def create_tool(self, tool: Tool):
def get_preset(
self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[Preset]:
with self.session_maker() as session:
if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None:
raise ValueError(f"Tool with name {tool.name} already exists")
session.add(ToolModel(**vars(tool)))
if preset_id:
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
elif name and user_id:
results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all()
else:
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

# @enforce_types
# def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
# preset = self.get_preset(preset_id)
# if preset is None:
# raise ValueError(f"Preset with id {preset_id} does not exist")
# user_id = preset.user_id
# with self.session_maker() as session:
# for function in functions:
# session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
# session.commit()

@enforce_types
def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
preset = self.get_preset(preset_id)
if preset is None:
raise ValueError(f"Preset with id {preset_id} does not exist")
user_id = preset.user_id
with self.session_maker() as session:
for source_id in sources:
session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
session.commit()

# @enforce_types
# def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
# with self.session_maker() as session:
# results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
# return [r.function for r in results]

@enforce_types
def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
return [r.source_id for r in results]

@enforce_types
def update_agent(self, agent: AgentState):
with self.session_maker() as session:
fields = vars(agent)
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
fields["memory"] = agent.memory.to_dict()
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
session.commit()

@enforce_types
@ -515,41 +542,28 @@ class MetadataStore:
session.commit()

@enforce_types
def update_block(self, block: Block):
def update_human(self, human: HumanModel):
with self.session_maker() as session:
session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
session.add(human)
session.commit()
session.refresh(human)

@enforce_types
def update_or_create_block(self, block: Block):
def update_persona(self, persona: PersonaModel):
with self.session_maker() as session:
existing_block = session.query(BlockModel).filter(BlockModel.id == block.id).first()
if existing_block:
session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
else:
session.add(BlockModel(**vars(block)))
session.add(persona)
session.commit()
session.refresh(persona)

@enforce_types
def update_tool(self, tool: Tool):
def update_tool(self, tool: ToolModel):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool))
session.add(tool)
session.commit()
session.refresh(tool)

@enforce_types
def delete_tool(self, tool_id: str):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
session.commit()

@enforce_types
def delete_block(self, block_id: str):
with self.session_maker() as session:
session.query(BlockModel).filter(BlockModel.id == block_id).delete()
session.commit()

@enforce_types
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: uuid.UUID):
with self.session_maker() as session:

# delete agents
@ -561,7 +575,7 @@ class MetadataStore:
session.commit()

@enforce_types
def delete_source(self, source_id: str):
def delete_source(self, source_id: uuid.UUID):
with self.session_maker() as session:
# delete from sources table
session.query(SourceModel).filter(SourceModel.id == source_id).delete()
@ -572,7 +586,7 @@ class MetadataStore:
session.commit()

@enforce_types
def delete_user(self, user_id: str):
def delete_user(self, user_id: uuid.UUID):
with self.session_maker() as session:
# delete from users table
session.query(UserModel).filter(UserModel.id == user_id).delete()
@ -589,30 +603,42 @@ class MetadataStore:
session.commit()

@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can create tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can create tools
def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
res = [r.to_record() for r in results]
return res
return results

@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
def list_all_agents(self) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).all()

return [r.to_record() for r in results]

@enforce_types
def list_sources(self, user_id: uuid.UUID) -> List[Source]:
with self.session_maker() as session:
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
return [r.to_record() for r in results]

@enforce_types
def get_agent(
self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[AgentState]:
with self.session_maker() as session:
if agent_id:
@ -627,7 +653,7 @@ class MetadataStore:
return results[0].to_record()

@enforce_types
def get_user(self, user_id: str) -> Optional[User]:
def get_user(self, user_id: uuid.UUID) -> Optional[User]:
with self.session_maker() as session:
results = session.query(UserModel).filter(UserModel.id == user_id).all()
if len(results) == 0:
@ -636,7 +662,7 @@ class MetadataStore:
return results[0].to_record()

@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
query = session.query(UserModel).order_by(desc(UserModel.id))
if cursor:
@ -646,13 +672,13 @@ class MetadataStore:
return None, []
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, str)
assert isinstance(next_cursor, uuid.UUID)

return next_cursor, user_records

@enforce_types
def get_source(
self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None
) -> Optional[Source]:
with self.session_maker() as session:
if source_id:
@ -666,89 +692,42 @@ class MetadataStore:
return results[0].to_record()

@enforce_types
def get_tool(
self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[ToolModel]:
def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
if tool_id:
results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
else:
assert tool_name is not None
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
with self.session_maker() as session:
results = session.query(BlockModel).filter(BlockModel.id == block_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()

@enforce_types
def get_blocks(
self,
user_id: Optional[str],
label: Optional[str] = None,
template: bool = True,
name: Optional[str] = None,
id: Optional[str] = None,
) -> List[Block]:
"""List available blocks"""
with self.session_maker() as session:
query = session.query(BlockModel).filter(BlockModel.template == template)

results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
query = query.filter(BlockModel.user_id == user_id)

if label:
query = query.filter(BlockModel.label == label)

if name:
query = query.filter(BlockModel.name == name)

if id:
query = query.filter(BlockModel.id == id)

results = query.all()
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()

if len(results) == 0:
return None

return [r.to_record() for r in results]
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]

# agent source metadata
@enforce_types
def attach_source(self, user_id: str, agent_id: str, source_id: str):
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
with self.session_maker() as session:
# TODO: remove this (is a hack)
mapping_id = f"{user_id}-{agent_id}-{source_id}"
session.add(AgentSourceMappingModel(id=mapping_id, user_id=user_id, agent_id=agent_id, source_id=source_id))
session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
session.commit()

@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

sources = []
source_ids = []
# make sure source exists
for r in results:
source = self.get_source(source_id=r.source_id)
if source:
sources.append(source)
source_ids.append(r.source_id)
else:
printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.")
return sources
return source_ids

@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()

@ -763,7 +742,7 @@ class MetadataStore:
return agent_ids

@enforce_types
def detach_source(self, agent_id: str, source_id: str):
def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
with self.session_maker() as session:
session.query(AgentSourceMappingModel).filter(
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
@ -771,38 +750,120 @@ class MetadataStore:
session.commit()

@enforce_types
def create_job(self, job: Job):
def add_human(self, human: HumanModel):
with self.session_maker() as session:
session.add(JobModel(**vars(job)))
if self.get_human(human.name, human.user_id):
raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}")
session.add(human)
session.commit()

def delete_job(self, job_id: str):
@enforce_types
def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).delete()
if self.get_persona(persona.name, persona.user_id):
raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}")
session.add(persona)
session.commit()

def get_job(self, job_id: str) -> Optional[Job]:
@enforce_types
def add_preset(self, preset: PresetModel): # TODO: remove
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.id == job_id).all()
session.add(preset)
session.commit()

@enforce_types
def add_tool(self, tool: ToolModel):
with self.session_maker() as session:
if self.get_tool(tool.name, tool.user_id):
raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
session.add(tool)
session.commit()

@enforce_types
def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]:
with self.session_maker() as session:
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
return results[0]

def list_jobs(self, user_id: str) -> List[Job]:
@enforce_types
def get_persona(self, name: str, user_id: uuid.UUID) -> Optional[PersonaModel]:
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.user_id == user_id).all()
return [r.to_record() for r in results]
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]

def update_job(self, job: Job) -> Job:
@enforce_types
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job.id).update(vars(job))
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
return results

@enforce_types
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
with self.session_maker() as session:
# if user_id matches provided user_id or if user_id is None
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
return results

@enforce_types
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return results

@enforce_types
def delete_human(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
session.commit()
return Job

def update_job_status(self, job_id: str, status: JobStatus):
@enforce_types
def delete_persona(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
session.commit()

@enforce_types
def delete_preset(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
session.commit()

@enforce_types
def delete_tool(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
session.commit()

# job related functions
def create_job(self, job: JobModel):
with self.session_maker() as session:
session.add(job)
session.commit()
session.expunge_all()

def update_job_status(self, job_id: uuid.UUID, status: JobStatus):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).update({"status": status})
if status == JobStatus.completed:  # JobStatus members are lowercase; JobStatus.COMPLETED would raise AttributeError
session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()})
|
||||
session.commit()
|
||||
|
||||
def update_job(self, job: JobModel):
|
||||
with self.session_maker() as session:
|
||||
session.add(job)
|
||||
session.commit()
|
||||
session.refresh(job)
|
||||
|
||||
def get_job(self, job_id: uuid.UUID) -> Optional[JobModel]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(JobModel).filter(JobModel.id == job_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0]
|
||||
|
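
A quick illustration of the new uuid-based cursor pagination in get_all_users above — a minimal sketch, assuming a configured local MemGPT install (this harness is illustrative, not part of the diff):

import uuid
from typing import Optional

from memgpt.config import MemGPTConfig
from memgpt.metadata import MetadataStore

ms = MetadataStore(MemGPTConfig.load())

cursor: Optional[uuid.UUID] = None
while True:
    # each call returns (next_cursor, user_records); next_cursor is the last user's uuid
    cursor, users = ms.get_all_users(cursor=cursor, limit=50)
    for user in users:
        print(user.id)
    if cursor is None:  # get_all_users returns (None, []) once the table is exhausted
        break
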
716
memgpt/migrate.py
Normal file
@ -0,0 +1,716 @@
import configparser
import glob
import json
import os
import pickle
import shutil
import sys
import traceback
import uuid
from datetime import datetime
from typing import List, Optional

import pytz
import questionary
import typer
from tqdm import tqdm

from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig
from memgpt.data_types import AgentState, Message, Passage, Source, User
from memgpt.metadata import MetadataStore
from memgpt.persistence_manager import LocalStateManager
from memgpt.utils import (
MEMGPT_DIR,
OpenAIBackcompatUnpickler,
annotate_message_json_list_with_tool_calls,
get_utc_time,
parse_formatted_time,
version_less_than,
)

# This is the version where the breaking change was made
VERSION_CUTOFF = "0.2.12"

# Migration backup dir (where we'll dump old agents that we successfully migrated)
MIGRATION_BACKUP_FOLDER = "migration_backups"


def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True, create_config=True):
"""Wipe (backup) the config file, and launch `memgpt configure`"""

if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))

# Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS)
timestamp = get_utc_time().strftime("%Y%m%d_%H%M%S")

# Construct the new backup directory name with the timestamp
backup_filename = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
existing_filename = os.path.join(data_dir, "config")

# Check if the existing file exists before moving
if os.path.exists(existing_filename):
# shutil should work cross-platform
shutil.move(existing_filename, backup_filename)
typer.secho(f"Deleted config file ({existing_filename}) and saved as backup ({backup_filename})", fg=typer.colors.GREEN)
else:
typer.secho(f"Couldn't find an existing config file to delete", fg=typer.colors.RED)

if run_configure:
# Either run configure
configure()
elif create_config:
# Or create a new config with defaults
MemGPTConfig.load()


def config_is_compatible(data_dir: str = MEMGPT_DIR, allow_empty=False, echo=False) -> bool:
"""Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
# NOTE: don't use built-in load(), since that will apply defaults
# memgpt_config = MemGPTConfig.load()
memgpt_config_file = os.path.join(data_dir, "config")
if not os.path.exists(memgpt_config_file):
return True if allow_empty else False
parser = configparser.ConfigParser()
parser.read(memgpt_config_file)

if "version" in parser and "memgpt_version" in parser["version"]:
version = parser["version"]["memgpt_version"]
else:
version = None

if version is None:
# no version -- assume pre-determined config (does not need to be migrated)
return True
elif version_less_than(version, VERSION_CUTOFF):
if echo:
typer.secho(f"Current config version ({version}) is older than migration cutoff ({VERSION_CUTOFF})", fg=typer.colors.RED)
return False
else:
if echo:
typer.secho(f"Current config version {version} is compatible!", fg=typer.colors.GREEN)
return True


def agent_is_migrateable(agent_name: str, data_dir: str = MEMGPT_DIR) -> bool:
"""Determine whether or not the agent folder is a migration target"""
agent_folder = os.path.join(data_dir, "agents", agent_name)

if not os.path.exists(agent_folder):
raise ValueError(f"Folder {agent_folder} does not exist")

agent_config_file = os.path.join(agent_folder, "config.json")
if not os.path.exists(agent_config_file):
raise ValueError(f"Agent folder {agent_folder} does not have a config file")

try:
with open(agent_config_file, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)
except Exception as e:
raise ValueError(f"Failed to load agent config file ({agent_config_file}), error = {e}")

if "memgpt_version" not in agent_config or version_less_than(agent_config["memgpt_version"], VERSION_CUTOFF):  # agent_config is a dict from json.load, so hasattr/attribute access would never work
return True
else:
return False


def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
"""
Migrate an old source folder (`~/.memgpt/sources/{source_name}`).
"""

# 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
# TODO
source_path = os.path.join(data_dir, "archival", source_name, "nodes.pkl")
assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"

# load state from old checkpoint file

# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)

# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )

# insert source into metadata store
source = Source(user_id=user.id, name=source_name)
ms.create_source(source)

try:
try:
nodes = pickle.load(open(source_path, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
# cannot load source at all, so throw error
raise ValueError(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
|
||||
)
else:
raise e

passages = []
for node in nodes:
# print(len(node.embedding))
# TODO: make sure embedding config matches embedding size?
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(
f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`."
|
||||
)
passages.append(
Passage(
user_id=user.id,
data_source=source_name,
text=node.text,
embedding=node.embedding,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)

assert len(passages) > 0, f"Source {source_name} has no passages"
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
conn.insert_many(passages)
# print(f"Inserted {len(passages)} to {source_name}")
except Exception as e:
# delete from metadata store
ms.delete_source(source.id)
raise ValueError(f"Failed to migrate {source_name}: {str(e)}")

# basic checks
source = ms.get_source(user_id=user.id, source_name=source_name)
assert source is not None, f"Failed to load source {source_name} from database after migration"


def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None) -> List[str]:
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)

Steps:
1. Load the agent state JSON from the old folder
2. Create a new AgentState using the agent config + agent internal state
3. Instantiate a new Agent by passing AgentState to Agent.__init__
(This will automatically run into a new database)

If success, returns empty list
If warning, returns a list of strings (warning message)
If error, raises an Exception
"""
warnings = []

# 1. Load the agent state JSON from the old folder
# TODO
agent_folder = os.path.join(data_dir, "agents", agent_name)
# migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)

# load state from old checkpoint file
agent_ckpt_directory = os.path.join(agent_folder, "agent_state")
json_files = glob.glob(os.path.join(agent_ckpt_directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
# NOTE this is a soft fail, just allow it to pass
# return
# return [f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}"]

# Sort files based on modified timestamp, with the latest file being the first.
state_filename = max(json_files, key=os.path.getmtime)
state_dict = json.load(open(state_filename, "r"))

# print(state_dict.keys())
# print(state_dict["memory"])
# dict_keys(['model', 'system', 'functions', 'messages', 'messages_total', 'memory'])

# load old data from the persistence manager
persistence_filename = os.path.basename(state_filename).replace(".json", ".persistence.pickle")
persistence_filename = os.path.join(agent_folder, "persistence_manager", persistence_filename)
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
if not os.path.exists(persistence_filename):
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
# return [f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}"]

try:
with open(persistence_filename, "rb") as f:
data = pickle.load(f)
except ModuleNotFoundError:
# Patch for stripped openai package
# ModuleNotFoundError: No module named 'openai.openai_object'
with open(persistence_filename, "rb") as f:
unpickler = OpenAIBackcompatUnpickler(f)
data = unpickler.load()

from memgpt.openai_backcompat.openai_object import OpenAIObject

def convert_openai_objects_to_dict(obj):
if isinstance(obj, OpenAIObject):
# Convert to dict or handle as needed
# print(f"detected OpenAIObject on {obj}")
return obj.to_dict_recursive()
elif isinstance(obj, dict):
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_openai_objects_to_dict(v) for v in obj]
else:
return obj

data = convert_openai_objects_to_dict(data)

# data will contain:
# print("data.keys()", data.keys())
# manager.all_messages = data["all_messages"]
# manager.messages = data["messages"]
# manager.recall_memory = data["recall_memory"]

agent_config_filename = os.path.join(agent_folder, "config.json")
with open(agent_config_filename, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)

# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)

# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# ms.create_user(User(id=user_id))
# user = ms.get_user(user_id=user_id)
# if user is None:
# typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
# sys.exit(1)

# create an agent_id ahead of time
agent_id = uuid.uuid4()

# create all the Messages in the database
# message_objs = []
# for message_dict in annotate_message_json_list_with_tool_calls(state_dict["messages"]):
# message_obj = Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=message_dict,
# model=state_dict["model"] if "model" in state_dict else None,
# # allow_functions_style=False,
# allow_functions_style=True,
# )
# message_objs.append(message_obj)

agent_state = AgentState(
id=agent_id,
name=agent_config["name"],
user_id=user.id,
# persona_name=agent_config["persona"], # eg 'sam_pov'
# human_name=agent_config["human"], # eg 'basic'
persona=state_dict["memory"]["persona"], # NOTE: hacky (not init, but latest)
human=state_dict["memory"]["human"], # NOTE: hacky (not init, but latest)
preset=agent_config["preset"], # eg 'memgpt_chat'
state=dict(
human=state_dict["memory"]["human"],
persona=state_dict["memory"]["persona"],
system=state_dict["system"],
functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link
# messages=[str(m.id) for m in message_objs], # this is a list of uuids, not message dicts
),
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
)

persistence_manager = LocalStateManager(agent_state=agent_state)

# First clean up the recall message history to add tool call ids
# allow_tool_roles in case some of the old messages were actually already in tool call format (for whatever reason)
full_message_history_buffer = annotate_message_json_list_with_tool_calls(
[d["message"] for d in data["all_messages"]], allow_tool_roles=True
)
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]

# Figure out what messages in recall are in-context, and which are out-of-context
agent_message_cache = state_dict["messages"]
recall_message_full = data["all_messages"]

def messages_are_equal(msg1, msg2):
if msg1["role"] != msg2["role"]:
return False
if msg1["content"] != msg2["content"]:
return False
if "function_call" in msg1 and "function_call" in msg2 and msg1["function_call"] != msg2["function_call"]:
return False
if "name" in msg1 and "name" in msg2 and msg1["name"] != msg2["name"]:
return False

# otherwise checks pass, ~= equal
return True

in_context_messages = []
out_of_context_messages = []
assert len(agent_message_cache) <= len(recall_message_full), (len(agent_message_cache), len(recall_message_full))
for i, d in enumerate(recall_message_full):
# unpack into "timestamp" and "message"
recall_message = d["message"]
recall_timestamp = str(d["timestamp"])
try:
recall_datetime = parse_formatted_time(recall_timestamp.strip()).astimezone(pytz.utc)
except ValueError:
recall_datetime = datetime.strptime(recall_timestamp.strip(), "%Y-%m-%d %I:%M:%S %p").astimezone(pytz.utc)

# message object
message_obj = Message.dict_to_message(
created_at=recall_datetime,
user_id=user.id,
agent_id=agent_id,
openai_message_dict=recall_message,
allow_functions_style=True,
)

# message is either in-context, or out-of-context

if i >= (len(recall_message_full) - len(agent_message_cache)):
# there are len(agent_message_cache) total messages on the agent
# this will correspond to the last N messages in the recall memory (though possibly out-of-order)
message_is_in_context = [messages_are_equal(recall_message, cache_message) for cache_message in agent_message_cache]
# assert sum(message_is_in_context) <= 1, message_is_in_context
# if any(message_is_in_context):
# in_context_messages.append(message_obj)
# else:
# out_of_context_messages.append(message_obj)

if not any(message_is_in_context):
# typer.secho(
# f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
out_of_context_messages.append(message_obj)
else:
if sum(message_is_in_context) > 1:
# typer.secho(
# f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
|
||||
# fg=typer.colors.RED,
|
||||
# )
|
||||
warnings.append(
|
||||
f"Found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
|
||||
)
in_context_messages.append(message_obj)

else:
# if we're not in the final portion of the recall memory buffer, then it's 100% out-of-context
out_of_context_messages.append(message_obj)

assert len(in_context_messages) > 0, f"Couldn't find any in-context messages (agent_cache = {len(agent_message_cache)})"
# assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
if len(in_context_messages) != len(agent_message_cache):
# typer.secho(
# f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
# fg=typer.colors.RED,
# )
warnings.append(
f"Uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})"
)
# assert (
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
# ), f"{len(in_context_messages)} + {len(out_of_context_messages)} != {state_dict['messages_total']}"

# Now we can insert the messages into the actual recall database
# So when we construct the agent from the state, they will be available
persistence_manager.recall_memory.insert_many(out_of_context_messages)
persistence_manager.recall_memory.insert_many(in_context_messages)

# Overwrite the agent_state message object
agent_state.state["messages"] = [str(m.id) for m in in_context_messages] # this is a list of uuids, not message dicts

## 4. Insert into recall
# TODO should this be 'messages', or 'all_messages'?
# all_messages in recall will have fields "timestamp" and "message"
# full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# We want to keep the timestamp
# for i in range(len(data["all_messages"])):
# data["all_messages"][i]["message"] = full_message_history_buffer[i]
# messages_to_insert = [
# Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=msg,
# allow_functions_style=True,
# )
# # for msg in data["all_messages"]
# for msg in full_message_history_buffer
# ]
# agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
# print("Finished migrating recall memory")

# 3. Instantiate a new Agent by passing AgentState to Agent.__init__
# NOTE: the Agent.__init__ will trigger a save, which will write to the DB
try:
agent = Agent(
agent_state=agent_state,
# messages_total=state_dict["messages_total"], # TODO: do we need this?
messages_total=len(in_context_messages) + len(out_of_context_messages),
interface=None,
)
save_agent(agent, ms=ms)
except Exception:
# if "Agent with name" in str(e):
# print(e)
# return
# elif "was specified in agent.state.functions":
# print(e)
# return
# else:
# raise
raise

# Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
try:
# TODO should we also assign data["messages"] to RecallMemory.messages?

# 5. Insert into archival
if os.path.exists(archival_filename):
try:
nodes = pickle.load(open(archival_filename, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
print(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
|
||||
)
nodes = []
else:
raise e

passages = []
failed_inserts = []
for node in nodes:
if len(node.embedding) != config.default_embedding_config.embedding_dim:
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
|
||||
failed_inserts.append(
|
||||
f"Cannot migrate passage due to incompatible embedding dimentions ({len(node.embedding)} != {config.default_embedding_config.embedding_dim}) - content = '{node.text}'."
|
||||
)
passages.append(
Passage(
user_id=user.id,
agent_id=agent_state.id,
text=node.text,
embedding=node.embedding,
embedding_dim=agent_state.embedding_config.embedding_dim,
embedding_model=agent_state.embedding_config.embedding_model,
)
)
if len(passages) > 0:
agent.persistence_manager.archival_memory.storage.insert_many(passages)
# print(f"Inserted {len(passages)} passages into archival memory")

if len(failed_inserts) > 0:
warnings.append(
f"Failed to transfer {len(failed_inserts)}/{len(nodes)} passages from old archival memory: " + ", ".join(failed_inserts)
)

else:
warnings.append("No archival memory found at", archival_filename)
|
||||

except:
ms.delete_agent(agent_state.id)
raise

try:
new_agent_folder = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
shutil.move(agent_folder, new_agent_folder)
except Exception:
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
raise

return warnings


# def migrate_all_agents(stop_on_fail=True):
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""

if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))

if not config_is_compatible(data_dir, echo=True):
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
if questionary.confirm(
"To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
).ask():
try:
wipe_config_and_reconfigure(data_dir)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
else:
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
raise KeyboardInterrupt()

agents_dir = os.path.join(data_dir, "agents")

# Ensure the directory exists
if not os.path.exists(agents_dir):
raise ValueError(f"Directory {agents_dir} does not exist.")

# Get a list of all folders in agents_dir
agent_folders = [f for f in os.listdir(agents_dir) if os.path.isdir(os.path.join(agents_dir, f))]

# Iterate over each folder with a tqdm progress bar
count = 0
successes = [] # agents that migrated w/o warnings
warnings = [] # agents that migrated but had warnings
failures = [] # agents that failed to migrate (fatal error)
candidates = []
config = MemGPTConfig.load()
print(config)
ms = MetadataStore(config)
try:
for agent_name in tqdm(agent_folders, desc="Migrating agents"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
candidates.append(agent_name)
migration_warnings = migrate_agent(agent_name, data_dir=data_dir, ms=ms)
if len(migration_warnings) == 0:
successes.append(agent_name)
else:
warnings.append((agent_name, migration_warnings))
count += 1
else:
continue
except Exception as e:
failures.append({"name": agent_name, "reason": str(e)})
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
if debug:
traceback.print_exc()
if stop_on_fail:
raise
except KeyboardInterrupt:
typer.secho(f"User cancelled operation", fg=typer.colors.RED)

if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(agent_folders)} agent folders for migration")

if len(warnings) > 0:
typer.secho(f"Migration warnings:", fg=typer.colors.BRIGHT_YELLOW)
for warn in warnings:
typer.secho(f"{warn[0]}: {warn[1]}", fg=typer.colors.BRIGHT_YELLOW)

if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)

if len(failures) > 0:
typer.secho(
f"🔴 {len(failures)}/{len(candidates)} agents failed to migrate (see reasons above)",
fg=typer.colors.RED,
)
typer.secho(f"{[d['name'] for d in failures]}", fg=typer.colors.RED)

if len(warnings) > 0:
typer.secho(
f"🟠 {len(warnings)}/{len(candidates)} agents successfully migrated with warnings (see reasons above)",
fg=typer.colors.BRIGHT_YELLOW,
)
typer.secho(f"{[t[0] for t in warnings]}", fg=typer.colors.BRIGHT_YELLOW)

if len(successes) > 0:
typer.secho(
f"🟢 {len(successes)}/{len(candidates)} agents successfully migrated with no warnings",
fg=typer.colors.GREEN,
)
typer.secho(f"{successes}", fg=typer.colors.GREEN)

del ms
return {
"agent_folders": len(agent_folders),
"migration_candidates": candidates,
"successful_migrations": len(successes) + len(warnings),
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}


def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
|
||||
|
||||
sources_dir = os.path.join(data_dir, "archival")
|
||||
|
||||
# Ensure the directory exists
|
||||
if not os.path.exists(sources_dir):
|
||||
raise ValueError(f"Directory {sources_dir} does not exist.")
|
||||
|
||||
# Get a list of all folders in agents_dir
|
||||
source_folders = [f for f in os.listdir(sources_dir) if os.path.isdir(os.path.join(sources_dir, f))]
|
||||
|
||||
# Iterate over each folder with a tqdm progress bar
|
||||
count = 0
|
||||
failures = []
|
||||
candidates = []
|
||||
config = MemGPTConfig.load()
|
||||
ms = MetadataStore(config)
|
||||
try:
|
||||
for source_name in tqdm(source_folders, desc="Migrating data sources"):
|
||||
# Assuming migrate_agent is a function that takes the agent name and performs migration
|
||||
try:
|
||||
candidates.append(source_name)
|
||||
migrate_source(source_name, data_dir, ms=ms)
|
||||
count += 1
|
||||
except Exception as e:
|
||||
failures.append({"name": source_name, "reason": str(e)})
|
||||
if debug:
|
||||
traceback.print_exc()
|
||||
if stop_on_fail:
|
||||
raise
|
||||
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
|
||||
except KeyboardInterrupt:
|
||||
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
|
||||
|
||||
if len(candidates) == 0:
|
||||
typer.secho(f"No migration candidates found ({len(source_folders)} source folders total)", fg=typer.colors.GREEN)
|
||||
else:
|
||||
typer.secho(f"Inspected {len(source_folders)} source folders")
|
||||
if len(failures) > 0:
|
||||
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
|
||||
for fail in failures:
|
||||
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
|
||||
typer.secho(f"❌ {len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
|
||||
if count > 0:
|
||||
typer.secho(
|
||||
f"✅ {count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN
|
||||
)
|
||||
|
||||
del ms
|
||||
return {
|
||||
"source_folders": len(source_folders),
|
||||
"migration_candidates": candidates,
|
||||
"successful_migrations": count,
|
||||
"failed_migrations": failures,
|
||||
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
|
||||
}
|
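
For reference, a minimal driver sketch for the two entrypoints above — assuming an existing pre-0.2.12 ~/.memgpt directory; this harness is illustrative and not part of the diff:

from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.utils import MEMGPT_DIR

# both functions return a summary dict with candidate/success/failure counts
agent_report = migrate_all_agents(data_dir=MEMGPT_DIR, stop_on_fail=False, debug=True)
source_report = migrate_all_sources(data_dir=MEMGPT_DIR, stop_on_fail=False, debug=True)

print(f"agents migrated: {agent_report['successful_migrations']}/{len(agent_report['migration_candidates'])}")
print(f"sources migrated: {source_report['successful_migrations']}/{len(source_report['migration_candidates'])}")
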
197
memgpt/models/pydantic_models.py
Normal file
@ -0,0 +1,197 @@
# tool imports
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, Column
from sqlalchemy_utils import ChoiceType
from sqlmodel import Field, SQLModel

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text, get_utc_time


class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""

YES = "yes"
NO = "no"
DEFAULT = "default"


class MemGPTUsageStatistics(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
step_count: int


class LLMConfigModel(BaseModel):
model: Optional[str] = "gpt-4"
model_endpoint_type: Optional[str] = "openai"
model_endpoint: Optional[str] = "https://api.openai.com/v1"
model_wrapper: Optional[str] = None
context_window: Optional[int] = None

# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())


class EmbeddingConfigModel(BaseModel):
embedding_endpoint_type: Optional[str] = "openai"
embedding_endpoint: Optional[str] = "https://api.openai.com/v1"
embedding_model: Optional[str] = "text-embedding-ada-002"
embedding_dim: Optional[int] = 1536
embedding_chunk_size: Optional[int] = 300


class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")


class ToolModel(SQLModel, table=True):
# TODO move into database
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
module: Optional[str] = Field(None, description="The module of the function.")

json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")

# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.", index=True)

# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True


class AgentToolMap(SQLModel, table=True):
# mapping between agents and tools
agent_id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the agent-tool map.", primary_key=True)


class PresetToolMap(SQLModel, table=True):
# mapping between presets and tools
preset_id: uuid.UUID = Field(..., description="The unique identifier of the preset.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset-tool map.", primary_key=True)


class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")

# timestamps
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the agent was created.")
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")

# preset information
tools: List[str] = Field(..., description="The tools used by the agent.")
system: str = Field(..., description="The system prompt used by the agent.")
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")

# llm information
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")

# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")


class CoreMemory(BaseModel):
human: str = Field(..., description="Human element of the core memory.")
persona: str = Field(..., description="Persona element of the core memory.")


class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.", index=True)


class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.", index=True)


class SourceModel(SQLModel, table=True):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
# embedding info
|
||||
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
|
||||
embedding_config: Optional[EmbeddingConfigModel] = Field(
|
||||
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
|
||||
)
|
||||
# NOTE: .metadata is a reserved attribute on SQLModel
|
||||
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")
|
||||
|
||||
|
||||
class JobStatus(str, Enum):
|
||||
created = "created"
|
||||
running = "running"
|
||||
completed = "completed"
|
||||
failed = "failed"
|
||||
|
||||
|
||||
class JobModel(SQLModel, table=True):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the job.", primary_key=True)
|
||||
# status: str = Field(default="created", description="The status of the job.")
|
||||
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus)))
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
|
||||
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.", index=True)
|
||||
metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.")
|
||||
|
||||
|
||||
class PassageModel(BaseModel):
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
|
||||
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
|
||||
text: str = Field(..., description="The text of the passage.")
|
||||
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
|
||||
embedding_config: Optional[EmbeddingConfigModel] = Field(
|
||||
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
|
||||
)
|
||||
data_source: Optional[str] = Field(None, description="The data source of the passage.")
|
||||
doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
||||
|
||||
|
||||
class DocumentModel(BaseModel):
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
|
||||
text: str = Field(..., description="The text of the document.")
|
||||
data_source: str = Field(..., description="The data source of the document.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user.")
|
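
For orientation, a minimal sketch (not part of the diff) of the job lifecycle these models encode; it assumes the imports already present at the top of pydantic_models.py and that the SQLModel tables exist:

import uuid

from memgpt.models.pydantic_models import JobModel, JobStatus
from memgpt.utils import get_utc_time

job = JobModel(user_id=uuid.uuid4())  # status defaults to JobStatus.created
job.status = JobStatus.running        # a worker picks the job up
job.status = JobStatus.completed      # ...and marks it finished
job.completed_at = get_utc_time()
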
@@ -2,10 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import List

from memgpt.data_types import AgentState, Message
from memgpt.memory import BaseRecallMemory, EmbeddingArchivalMemory
from memgpt.schemas.agent import AgentState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.utils import printd


@@ -47,7 +45,7 @@ class LocalStateManager(PersistenceManager):

    def __init__(self, agent_state: AgentState):
        # Memory held in-state useful for debugging stateful versions
        self.memory = agent_state.memory
        self.memory = None
        # self.messages = [] # current in-context messages
        # self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
        self.archival_memory = EmbeddingArchivalMemory(agent_state)
@@ -59,6 +57,15 @@ class LocalStateManager(PersistenceManager):
        self.archival_memory.save()
        self.recall_memory.save()

    def init(self, agent):
        """Connect persistent state manager to agent"""
        printd(f"Initializing {self.__class__.__name__} with agent object")
        # self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        # self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        # printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
        printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")

    '''
    def json_to_message(self, message_json) -> Message:
        """Convert agent message JSON into Message object"""
@@ -121,6 +128,7 @@ class LocalStateManager(PersistenceManager):
        # self.messages = [self.messages[0]] + added_messages + self.messages[1:]

        # add to recall memory
        self.recall_memory.insert_many([m for m in added_messages])

    def append_to_messages(self, added_messages: List[Message]):
        # first tag with timestamps
@@ -142,7 +150,6 @@ class LocalStateManager(PersistenceManager):
        # add to recall memory
        self.recall_memory.insert(new_system_message)

    def update_memory(self, new_memory: Memory):
    def update_memory(self, new_memory):
        printd(f"{self.__class__.__name__}.update_memory")
        assert isinstance(new_memory, Memory), type(new_memory)
        self.memory = new_memory
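
A rough usage sketch for the manager above (assumptions: `agent` is an already-constructed Agent whose state lives in `agent.agent_state`):

from memgpt.persistence_manager import LocalStateManager

pm = LocalStateManager(agent.agent_state)  # builds archival + recall memory for the agent
pm.init(agent)                             # attaches the agent's in-context memory (see init() above)
# ... agent steps append messages via extend_messages()/append_to_messages() ...
pm.save()                                  # flushes archival and recall memory
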
memgpt/presets/examples/memgpt_chat.yaml | 10 (new file)
@@ -0,0 +1,10 @@
system_prompt: "memgpt_chat"
functions:
  - "send_message"
  - "pause_heartbeats"
  - "core_memory_append"
  - "core_memory_replace"
  - "conversation_search"
  - "conversation_search_date"
  - "archival_memory_insert"
  - "archival_memory_search"
memgpt/presets/examples/memgpt_docs.yaml | 10 (new file)
@@ -0,0 +1,10 @@
system_prompt: "memgpt_doc"
functions:
  - "send_message"
  - "pause_heartbeats"
  - "core_memory_append"
  - "core_memory_replace"
  - "conversation_search"
  - "conversation_search_date"
  - "archival_memory_insert"
  - "archival_memory_search"
memgpt/presets/examples/memgpt_extras.yaml | 15 (new file)
@@ -0,0 +1,15 @@
system_prompt: "memgpt_chat"
functions:
  - "send_message"
  - "pause_heartbeats"
  - "core_memory_append"
  - "core_memory_replace"
  - "conversation_search"
  - "conversation_search_date"
  - "archival_memory_insert"
  - "archival_memory_search"
  # extras for read/write to files
  - "read_from_text_file"
  - "append_to_text_file"
  # internet access
  - "http_request"
memgpt/presets/presets.py | 91 (new file)
@@ -0,0 +1,91 @@
import importlib
import inspect
import os
import uuid

from memgpt.data_types import AgentState, Preset
from memgpt.functions.functions import load_function_set
from memgpt.interface import AgentInterface
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
from memgpt.presets.utils import load_all_presets
from memgpt.utils import list_human_files, list_persona_files, printd

available_presets = load_all_presets()
preset_options = list(available_presets.keys())


def load_module_tools(module_name="base"):
    # return List[ToolModel] from base.py tools
    full_module_name = f"memgpt.functions.function_sets.{module_name}"
    try:
        module = importlib.import_module(full_module_name)
    except Exception as e:
        # Handle other general exceptions
        raise e

    # function tags

    try:
        # Load the function set
        functions_to_schema = load_function_set(module)
    except ValueError as e:
        err = f"Error loading function set '{module_name}': {e}"
        printd(err)

    # create tool in db
    tools = []
    for name, schema in functions_to_schema.items():
        # print([str(inspect.getsource(line)) for line in schema["imports"]])
        source_code = inspect.getsource(schema["python_function"])
        tags = [module_name]
        if module_name == "base":
            tags.append("memgpt-base")

        tools.append(
            ToolModel(
                name=name,
                tags=tags,
                source_type="python",
                module=schema["module"],
                source_code=source_code,
                json_schema=schema["json_schema"],
            )
        )
    return tools


def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
    module_name = "base"
    for tool in load_module_tools(module_name=module_name):
        existing_tool = ms.get_tool(tool.name)
        if not existing_tool:
            ms.add_tool(tool)


def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
    for persona_file in list_persona_files():
        text = open(persona_file, "r", encoding="utf-8").read()
        name = os.path.basename(persona_file).replace(".txt", "")
        if ms.get_persona(user_id=user_id, name=name) is not None:
            printd(f"Persona '{name}' already exists for user '{user_id}'")
            continue
        persona = PersonaModel(name=name, text=text, user_id=user_id)
        ms.add_persona(persona)
    for human_file in list_human_files():
        text = open(human_file, "r", encoding="utf-8").read()
        name = os.path.basename(human_file).replace(".txt", "")
        if ms.get_human(user_id=user_id, name=name) is not None:
            printd(f"Human '{name}' already exists for user '{user_id}'")
            continue
        human = HumanModel(name=name, text=text, user_id=user_id)
        print(human, user_id)
        ms.add_human(human)


# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
def create_agent_from_preset(
    agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
):
    """Initialize a new agent from a preset (combination of system + function)"""
    raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")
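
A hedged sketch of how the helpers in this new file are meant to be called when provisioning a user (`config` standing in for a real MemGPT config object is an assumption):

import uuid

from memgpt.metadata import MetadataStore
from memgpt.presets.presets import add_default_humans_and_personas, add_default_tools

ms = MetadataStore(config)  # assumption: a config-backed metadata store
user_id = uuid.uuid4()
add_default_tools(user_id, ms)                # registers the "base" function set as ToolModel rows
add_default_humans_and_personas(user_id, ms)  # seeds the bundled human/persona text files
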
memgpt/presets/utils.py | 79 (new file)
@@ -0,0 +1,79 @@
import glob
import os

import yaml

from memgpt.constants import MEMGPT_DIR


def is_valid_yaml_format(yaml_data, function_set):
    """
    Check if the given YAML data follows the specified format and if all functions in the yaml are part of the function_set.
    Raises ValueError if any check fails.

    :param yaml_data: The data loaded from a YAML file.
    :param function_set: A set of valid function names.
    """
    # Check for required keys
    if not all(key in yaml_data for key in ["system_prompt", "functions"]):
        raise ValueError("YAML data is missing one or more required keys: 'system_prompt', 'functions'.")

    # Check if 'functions' is a list of strings
    if not all(isinstance(item, str) for item in yaml_data.get("functions", [])):
        raise ValueError("'functions' should be a list of strings.")

    # Check if all functions in YAML are part of function_set
    if not set(yaml_data["functions"]).issubset(function_set):
        raise ValueError(
            f"Some functions in YAML are not part of the provided function set: {set(yaml_data['functions']) - set(function_set)} "
        )

    # If all checks pass
    return True


def load_yaml_file(file_path):
    """
    Load a YAML file and return the data.

    :param file_path: Path to the YAML file.
    :return: Data from the YAML file.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def load_all_presets():
    """Load all the preset configs in the examples directory"""

    ## Load the examples
    # Get the directory in which the script is located
    script_directory = os.path.dirname(os.path.abspath(__file__))
    # Construct the path pattern
    example_path_pattern = os.path.join(script_directory, "examples", "*.yaml")
    # Listing all YAML files
    example_yaml_files = glob.glob(example_path_pattern)

    ## Load the user-provided presets
    # ~/.memgpt/presets/*.yaml
    user_presets_dir = os.path.join(MEMGPT_DIR, "presets")
    # Create directory if it doesn't exist
    if not os.path.exists(user_presets_dir):
        os.makedirs(user_presets_dir)
    # Construct the path pattern
    user_path_pattern = os.path.join(user_presets_dir, "*.yaml")
    # Listing all YAML files
    user_yaml_files = glob.glob(user_path_pattern)

    # Pull from both examples and user-provided
    all_yaml_files = example_yaml_files + user_yaml_files

    # Loading and creating a mapping from file name to YAML data
    all_yaml_data = {}
    for file_path in all_yaml_files:
        # Extracting the base file name without the '.yaml' extension
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        data = load_yaml_file(file_path)
        all_yaml_data[base_name] = data

    return all_yaml_data
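
Example (not in the diff) of the validation path these helpers implement; the function set is truncated for brevity:

import yaml

from memgpt.presets.utils import is_valid_yaml_format

preset_yaml = """
system_prompt: "memgpt_chat"
functions:
  - "send_message"
  - "core_memory_append"
"""

data = yaml.safe_load(preset_yaml)
# returns True, or raises ValueError if a key is missing or a function is unknown
is_valid_yaml_format(data, function_set={"send_message", "core_memory_append", "core_memory_replace"})
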
@@ -1,90 +0,0 @@
import uuid
from datetime import datetime
from typing import Dict, List, Optional

from pydantic import Field, field_validator

from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memory import Memory


class BaseAgent(MemGPTBase, validate_assignment=True):
    __id_prefix__ = "agent"
    description: Optional[str] = Field(None, description="The description of the agent.")

    # metadata
    metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
    user_id: Optional[str] = Field(None, description="The user id of the agent.")


class AgentState(BaseAgent):
    """Representation of an agent's state."""

    id: str = BaseAgent.generate_id_field()
    name: str = Field(..., description="The name of the agent.")
    created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now)

    # in-context memory
    message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
    memory: Memory = Field(default_factory=Memory, description="The in-context memory of the agent.")

    # tools
    tools: List[str] = Field(..., description="The tools used by the agent.")

    # system prompt
    system: str = Field(..., description="The system prompt used by the agent.")

    # llm information
    llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
    embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")


class CreateAgent(BaseAgent):
    # all optional as server can generate defaults
    name: Optional[str] = Field(None, description="The name of the agent.")
    message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
    memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
    tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
    system: Optional[str] = Field(None, description="The system prompt used by the agent.")
    llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

    @field_validator("name")
    @classmethod
    def validate_name(cls, name: str) -> str:
        """Validate the requested new agent name (prevent bad inputs)"""

        import re

        if not name:
            # don't check if not provided
            return name

        # TODO: this check should also be added to other model (e.g. User.name)
        # Length check
        if not (1 <= len(name) <= 50):
            raise ValueError("Name length must be between 1 and 50 characters.")

        # Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
        if not re.match("^[A-Za-z0-9 _-]+$", name):
            raise ValueError("Name contains invalid characters.")

        # Further checks can be added here...
        # TODO

        return name


class UpdateAgentState(BaseAgent):
    id: str = Field(..., description="The id of the agent.")
    name: Optional[str] = Field(None, description="The name of the agent.")
    tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
    system: Optional[str] = Field(None, description="The system prompt used by the agent.")
    llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

    # TODO: determine if these should be editable via this schema?
    message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
    memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
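
For reference, a short sketch of what the removed name validator enforced (pydantic v2 surfaces the ValueError as a ValidationError, which subclasses ValueError):

from memgpt.schemas.agent import CreateAgent

CreateAgent(name="my agent_01")    # accepted: alphanumerics, spaces, hyphens, underscores
try:
    CreateAgent(name="bad/name!")  # rejected by the regex check
except ValueError as e:
    print(e)                       # mentions "Name contains invalid characters."
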
@@ -1,21 +0,0 @@
from typing import Optional

from pydantic import Field

from memgpt.schemas.memgpt_base import MemGPTBase


class BaseAPIKey(MemGPTBase):
    __id_prefix__ = "sk"  # secret key


class APIKey(BaseAPIKey):
    id: str = BaseAPIKey.generate_id_field()
    user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
    key: str = Field(..., description="The key value.")
    name: str = Field(..., description="Name of the token.")


class APIKeyCreate(BaseAPIKey):
    user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
    name: Optional[str] = Field(None, description="Name of the token.")
@@ -1,117 +0,0 @@
from typing import List, Optional, Union

from pydantic import Field, model_validator
from typing_extensions import Self

from memgpt.schemas.memgpt_base import MemGPTBase

# block of the LLM context


class BaseBlock(MemGPTBase, validate_assignment=True):
    """Base block of the LLM context"""

    __id_prefix__ = "block"

    # data value
    value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.")
    limit: int = Field(2000, description="Character limit of the block.")

    name: Optional[str] = Field(None, description="Name of the block.")
    template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
    label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').")

    # metadata
    description: Optional[str] = Field(None, description="Description of the block.")
    metadata_: Optional[dict] = Field({}, description="Metadata of the block.")

    # associated user/agent
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the block.")

    @model_validator(mode="after")
    def verify_char_limit(self) -> Self:
        try:
            assert len(self) <= self.limit
        except AssertionError:
            error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self)})."
            raise ValueError(error_msg)
        except Exception as e:
            raise e
        return self

    def __len__(self):
        return len(str(self))

    def __str__(self) -> str:
        if isinstance(self.value, list):
            return ",".join(self.value)
        elif isinstance(self.value, str):
            return self.value
        else:
            return ""

    def __setattr__(self, name, value):
        """Run validation if self.value is updated"""
        super().__setattr__(name, value)
        if name == "value":
            # run validation
            self.__class__.validate(self.dict(exclude_unset=True))


class Block(BaseBlock):
    """Block of the LLM context"""

    id: str = BaseBlock.generate_id_field()
    value: str = Field(..., description="Value of the block.")


class Human(Block):
    """Human block of the LLM context"""

    label: str = "human"


class Persona(Block):
    """Persona block of the LLM context"""

    label: str = "persona"


class CreateBlock(BaseBlock):
    """Create a block"""

    template: bool = True
    label: str = Field(..., description="Label of the block.")


class CreatePersona(BaseBlock):
    """Create a persona block"""

    template: bool = True
    label: str = "persona"


class CreateHuman(BaseBlock):
    """Create a human block"""

    template: bool = True
    label: str = "human"


class UpdateBlock(BaseBlock):
    """Update a block"""

    id: str = Field(..., description="The unique identifier of the block.")
    limit: Optional[int] = Field(2000, description="Character limit of the block.")


class UpdatePersona(UpdateBlock):
    """Update a persona block"""

    label: str = "persona"


class UpdateHuman(UpdateBlock):
    """Update a human block"""

    label: str = "human"
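
A quick sketch of the character-limit guard this removed module provided (validate_assignment re-runs the validator on mutation):

from memgpt.schemas.block import Block

block = Block(name="persona", label="persona", value="I am a helpful assistant.", limit=50)
try:
    block.value = "x" * 100  # exceeds the 50-character limit
except ValueError as e:
    print(e)  # reports: Edit failed: Exceeds 50 character limit (requested 100).
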
@@ -1,21 +0,0 @@
from typing import Dict, Optional

from pydantic import Field

from memgpt.schemas.memgpt_base import MemGPTBase


class DocumentBase(MemGPTBase):
    """Base class for document schemas"""

    __id_prefix__ = "doc"


class Document(DocumentBase):
    """Representation of a single document (broken up into `Passage` objects)"""

    id: str = DocumentBase.generate_id_field()
    text: str = Field(..., description="The text of the document.")
    source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
    user_id: str = Field(description="The unique identifier of the user associated with the document.")
    metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")
@@ -1,18 +0,0 @@
from typing import Optional

from pydantic import BaseModel, Field


class EmbeddingConfig(BaseModel):
    """Embedding model configuration"""

    embedding_endpoint_type: str = Field(..., description="The endpoint type for the model.")
    embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
    embedding_model: str = Field(..., description="The model for the embedding.")
    embedding_dim: int = Field(..., description="The dimension of the embedding.")
    embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")

    # azure only
    azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
    azure_version: Optional[str] = Field(None, description="The Azure version for the model.")
    azure_deployment: Optional[str] = Field(None, description="The Azure deployment for the model.")
@@ -1,30 +0,0 @@
from enum import Enum


class MessageRole(str, Enum):
    assistant = "assistant"
    user = "user"
    tool = "tool"
    system = "system"


class OptionState(str, Enum):
    """Useful for kwargs that are bool + default option"""

    YES = "yes"
    NO = "no"
    DEFAULT = "default"


class JobStatus(str, Enum):
    created = "created"
    running = "running"
    completed = "completed"
    failed = "failed"
    pending = "pending"


class MessageStreamStatus(str, Enum):
    done_generation = "[DONE_GEN]"
    done_step = "[DONE_STEP]"
    done = "[DONE]"
@@ -1,28 +0,0 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from memgpt.schemas.enums import JobStatus
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time


class JobBase(MemGPTBase):
    __id_prefix__ = "job"
    metadata_: Optional[dict] = Field({}, description="The metadata of the job.")


class Job(JobBase):
    """Representation of offline jobs."""

    id: str = JobBase.generate_id_field()
    status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
    completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
    user_id: str = Field(..., description="The unique identifier of the user associated with the job.")


class JobUpdate(JobBase):
    id: str = Field(..., description="The unique identifier of the job.")
    status: Optional[JobStatus] = Field(..., description="The status of the job.")
@@ -1,15 +0,0 @@
from typing import Optional

from pydantic import BaseModel, ConfigDict, Field


class LLMConfig(BaseModel):
    # TODO: 🤮 don't default to a vendor! bug city!
    model: str = Field(..., description="LLM model name.")
    model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
    model_endpoint: str = Field(..., description="The endpoint for the model.")
    model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
    context_window: int = Field(..., description="The context window size for the model.")

    # FIXME hack to silence pydantic protected namespace warning
    model_config = ConfigDict(protected_namespaces=())
@@ -1,80 +0,0 @@
import uuid
from logging import getLogger
from typing import Optional
from uuid import UUID

from pydantic import BaseModel, ConfigDict, Field, field_validator

# from: https://gist.github.com/norton120/22242eadb80bf2cf1dd54a961b151c61


logger = getLogger(__name__)


class MemGPTBase(BaseModel):
    """Base schema for MemGPT schemas (does not include model provider schemas, e.g. OpenAI)"""

    model_config = ConfigDict(
        # allows you to use the snake or camelcase names in your code (ie user_id or userId)
        populate_by_name=True,
        # allows you to dump a sqlalchemy object directly (ie PersistedAddress.model_validate(SQLAddress))
        from_attributes=True,
        # throw errors if attributes are given that don't belong
        extra="forbid",
    )

    # def __id_prefix__(self):
    #     raise NotImplementedError("All schemas must have an __id_prefix__ attribute!")

    @classmethod
    def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
        prefix = prefix or cls.__id_prefix__

        # TODO: generate ID from regex pattern?
        def _generate_id() -> str:
            return f"{prefix}-{uuid.uuid4()}"

        return Field(
            ...,
            description=cls._id_description(prefix),
            pattern=cls._id_regex_pattern(prefix),
            examples=[cls._id_example(prefix)],
            default_factory=_generate_id,
        )

    # def _generate_id(self) -> str:
    #     return f"{self.__id_prefix__}-{uuid.uuid4()}"

    @classmethod
    def _id_regex_pattern(cls, prefix: str):
        """generates the regex pattern for a given id"""
        return (
            r"^" + prefix + r"-"  # prefix string
            r"[a-fA-F0-9]{8}"  # 8 hexadecimal characters
            # r"[a-fA-F0-9]{4}-"  # 4 hexadecimal characters
            # r"[a-fA-F0-9]{4}-"  # 4 hexadecimal characters
            # r"[a-fA-F0-9]{4}-"  # 4 hexadecimal characters
            # r"[a-fA-F0-9]{12}$"  # 12 hexadecimal characters
        )

    @classmethod
    def _id_example(cls, prefix: str):
        """generates an example id for a given prefix"""
        return [prefix + "-123e4567-e89b-12d3-a456-426614174000"]

    @classmethod
    def _id_description(cls, prefix: str):
        """generates a field description for a given prefix"""
        return f"The human-friendly ID of the {prefix.capitalize()}"

    @field_validator("id", check_fields=False, mode="before")
    @classmethod
    def allow_bare_uuids(cls, v, values):
        """to ease the transition to stripe-style ids,
        we allow bare uuids and convert them with a warning
        """
        _ = values  # for SCA
        if isinstance(v, UUID):
            logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!")
            return f"{cls.__id_prefix__}-{v}"
        return v
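
Illustration of the id machinery above with a hypothetical schema (the Ticket classes are invented for the example):

import uuid

from memgpt.schemas.memgpt_base import MemGPTBase

class TicketBase(MemGPTBase):
    __id_prefix__ = "ticket"

class Ticket(TicketBase):
    id: str = TicketBase.generate_id_field()

print(Ticket().id)                 # e.g. "ticket-123e4567-e89b-12d3-a456-426614174000"
print(Ticket(id=uuid.uuid4()).id)  # bare UUIDs are coerced, with a deprecation warning
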
@@ -1,78 +0,0 @@
from datetime import datetime, timezone
from typing import Literal, Union

from pydantic import BaseModel, field_serializer

# MemGPT API style responses (intended to be easier to use vs getting true Message types)


class BaseMemGPTMessage(BaseModel):
    id: str
    date: datetime

    @field_serializer("date")
    def serialize_datetime(self, dt: datetime, _info):
        return dt.now(timezone.utc).isoformat()


class InternalMonologue(BaseMemGPTMessage):
    """
    {
        "internal_monologue": msg,
        "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
        "id": str(msg_obj.id) if msg_obj is not None else None,
    }
    """

    internal_monologue: str


class FunctionCall(BaseModel):
    name: str
    arguments: str


class FunctionCallMessage(BaseMemGPTMessage):
    """
    {
        "function_call": {
            "name": function_call.function.name,
            "arguments": function_call.function.arguments,
        },
        "id": str(msg_obj.id),
        "date": msg_obj.created_at.isoformat(),
    }
    """

    function_call: FunctionCall


class FunctionReturn(BaseMemGPTMessage):
    """
    {
        "function_return": msg,
        "status": "success" or "error",
        "id": str(msg_obj.id),
        "date": msg_obj.created_at.isoformat(),
    }
    """

    function_return: str
    status: Literal["success", "error"]


MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]


# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string


class AssistantMessage(BaseMemGPTMessage):
    assistant_message: str


class LegacyFunctionCallMessage(BaseMemGPTMessage):
    function_call: str


LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
@@ -1,18 +0,0 @@
from typing import List

from pydantic import BaseModel, Field

from memgpt.schemas.message import MessageCreate


class MemGPTRequest(BaseModel):
    messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
    run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.")  # TODO: implement

    stream_steps: bool = Field(
        default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
    )
@@ -1,17 +0,0 @@
from typing import List, Union

from pydantic import BaseModel, Field

from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics

# TODO: consider moving into own file


class MemGPTResponse(BaseModel):
    # messages: List[Message] = Field(..., description="The messages returned by the agent.")
    messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field(
        ..., description="The messages returned by the agent."
    )
    usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.schemas.block import Block
|
||||
|
||||
|
||||
class Memory(BaseModel, validate_assignment=True):
|
||||
"""Represents the in-context memory of the agent"""
|
||||
|
||||
# Private variable to avoid assignments with incorrect types
|
||||
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
|
||||
|
||||
@classmethod
|
||||
def load(cls, state: dict):
|
||||
"""Load memory from dictionary object"""
|
||||
obj = cls()
|
||||
for key, value in state.items():
|
||||
obj.memory[key] = Block(**value)
|
||||
return obj
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Representation of the memory in-context"""
|
||||
section_strs = []
|
||||
for section, module in self.memory.items():
|
||||
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
|
||||
return "\n".join(section_strs)
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary representation"""
|
||||
return {key: value.dict() for key, value in self.memory.items()}
|
||||
|
||||
def to_flat_dict(self):
|
||||
"""Convert to a dictionary that maps directly from block names to values"""
|
||||
return {k: v.value for k, v in self.memory.items() if v is not None}
|
||||
|
||||
def list_block_names(self) -> List[str]:
|
||||
"""Return a list of the block names held inside the memory object"""
|
||||
return list(self.memory.keys())
|
||||
|
||||
def get_block(self, name: str) -> Block:
|
||||
"""Correct way to index into the memory.memory field, returns a Block"""
|
||||
if name not in self.memory:
|
||||
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
|
||||
else:
|
||||
return self.memory[name]
|
||||
|
||||
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
|
||||
"""Link a new block to the memory object"""
|
||||
if not isinstance(block, Block):
|
||||
raise ValueError(f"Param block must be type Block (not {type(block)})")
|
||||
if not isinstance(name, str):
|
||||
raise ValueError(f"Name must be str (not type {type(name)})")
|
||||
if not override and name in self.memory:
|
||||
raise ValueError(f"Block with name {name} already exists")
|
||||
|
||||
self.memory[name] = block
|
||||
|
||||
def update_block_value(self, name: str, value: Union[List[str], str]):
|
||||
"""Update the value of a block"""
|
||||
if name not in self.memory:
|
||||
raise ValueError(f"Block with name {name} does not exist")
|
||||
if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))):
|
||||
raise ValueError(f"Provided value must be a string or list of strings")
|
||||
|
||||
self.memory[name].value = value
|
||||
|
||||
|
||||
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
|
||||
class BaseChatMemory(Memory):
|
||||
def core_memory_append(self, name: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
current_value = str(self.memory.get_block(name).value)
|
||||
new_value = current_value + "\n" + str(content)
|
||||
self.memory.update_block_value(name=name, value=new_value)
|
||||
return None
|
||||
|
||||
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
current_value = str(self.memory.get_block(name).value)
|
||||
new_value = current_value.replace(str(old_content), str(new_content))
|
||||
self.memory.update_block_value(name=name, value=new_value)
|
||||
return None
|
||||
|
||||
|
||||
class ChatMemory(BaseChatMemory):
|
||||
"""
|
||||
ChatMemory initializes a BaseChatMemory with two default blocks
|
||||
"""
|
||||
|
||||
def __init__(self, persona: str, human: str, limit: int = 2000):
|
||||
super().__init__()
|
||||
self.link_block(name="persona", block=Block(name="persona", value=persona, limit=limit, label="persona"))
|
||||
self.link_block(name="human", block=Block(name="human", value=human, limit=limit, label="human"))
|
||||
|
||||
|
||||
class BlockChatMemory(BaseChatMemory):
|
||||
"""
|
||||
BlockChatMemory is a subclass of BaseChatMemory which uses shared memory blocks specified at initialization-time.
|
||||
"""
|
||||
|
||||
def __init__(self, blocks: List[Block] = []):
|
||||
super().__init__()
|
||||
for block in blocks:
|
||||
# TODO: centralize these internal schema validations
|
||||
assert block.name is not None and block.name != "", "each existing chat block must have a name"
|
||||
self.link_block(name=block.name, block=block)
|
||||
|
||||
|
||||
class UpdateMemory(BaseModel):
|
||||
"""Update the memory of the agent"""
|
||||
|
||||
|
||||
class ArchivalMemorySummary(BaseModel):
|
||||
size: int = Field(..., description="Number of rows in archival memory")
|
||||
|
||||
|
||||
class RecallMemorySummary(BaseModel):
|
||||
size: int = Field(..., description="Number of rows in recall memory")
|
@ -1,123 +0,0 @@
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SystemMessage(BaseModel):
|
||||
content: str
|
||||
role: str = "system"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class UserMessage(BaseModel):
|
||||
content: Union[str, List[str]]
|
||||
role: str = "user"
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCallFunction(BaseModel):
|
||||
name: str = Field(..., description="The name of the function to call")
|
||||
arguments: str = Field(..., description="The arguments to pass to the function (JSON dump)")
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str = Field(..., description="The ID of the tool call")
|
||||
type: str = "function"
|
||||
function: ToolCallFunction = Field(..., description="The arguments and name for the function")
|
||||
|
||||
|
||||
class AssistantMessage(BaseModel):
|
||||
content: Optional[str] = None
|
||||
role: str = "assistant"
|
||||
name: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
|
||||
|
||||
class ToolMessage(BaseModel):
|
||||
content: str
|
||||
role: str = "tool"
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
|
||||
|
||||
|
||||
# TODO: this might not be necessary with the validator
|
||||
def cast_message_to_subtype(m_dict: dict) -> ChatMessage:
|
||||
"""Cast a dictionary to one of the individual message types"""
|
||||
role = m_dict.get("role")
|
||||
if role == "system":
|
||||
return SystemMessage(**m_dict)
|
||||
elif role == "user":
|
||||
return UserMessage(**m_dict)
|
||||
elif role == "assistant":
|
||||
return AssistantMessage(**m_dict)
|
||||
elif role == "tool":
|
||||
return ToolMessage(**m_dict)
|
||||
else:
|
||||
raise ValueError("Unknown message role")
|
||||
|
||||
|
||||
class ResponseFormat(BaseModel):
|
||||
type: str = Field(default="text", pattern="^(text|json_object)$")
|
||||
|
||||
|
||||
## tool_choice ##
|
||||
class FunctionCall(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ToolFunctionChoice(BaseModel):
|
||||
# The type of the tool. Currently, only function is supported
|
||||
type: Literal["function"] = "function"
|
||||
# type: str = Field(default="function", const=True)
|
||||
function: FunctionCall
|
||||
|
||||
|
||||
ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
|
||||
|
||||
|
||||
## tools ##
|
||||
class FunctionSchema(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters
|
||||
|
||||
|
||||
class Tool(BaseModel):
|
||||
# The type of the tool. Currently, only function is supported
|
||||
type: Literal["function"] = "function"
|
||||
# type: str = Field(default="function", const=True)
|
||||
function: FunctionSchema
|
||||
|
||||
|
||||
## function_call ##
|
||||
FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall]
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
"""https://platform.openai.com/docs/api-reference/chat/create"""
|
||||
|
||||
model: str
|
||||
messages: List[ChatMessage]
|
||||
frequency_penalty: Optional[float] = 0
|
||||
logit_bias: Optional[Dict[str, int]] = None
|
||||
logprobs: Optional[bool] = False
|
||||
top_logprobs: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = 1
|
||||
presence_penalty: Optional[float] = 0
|
||||
response_format: Optional[ResponseFormat] = None
|
||||
seed: Optional[int] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stream: Optional[bool] = False
|
||||
temperature: Optional[float] = 1
|
||||
top_p: Optional[float] = 1
|
||||
user: Optional[str] = None # unique ID of the end-user (for monitoring)
|
||||
|
||||
# function-calling related
|
||||
tools: Optional[List[Tool]] = None
|
||||
tool_choice: Optional[ToolChoice] = "none"
|
||||
# deprecated scheme
|
||||
functions: Optional[List[FunctionSchema]] = None
|
||||
function_call: Optional[FunctionCallChoice] = None
|
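
Sketch of assembling a request with the removed types above ("gpt-4" and the get_time tool are placeholders, not part of the commit):

req = ChatCompletionRequest(
    model="gpt-4",
    messages=[cast_message_to_subtype({"role": "user", "content": "What time is it?"})],
    tools=[Tool(function=FunctionSchema(name="get_time", parameters={"type": "object", "properties": {}}))],
    tool_choice="auto",
)
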
@@ -1,66 +0,0 @@
from datetime import datetime
from typing import Dict, List, Optional

from pydantic import Field, field_validator

from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time


class PassageBase(MemGPTBase):
    __id_prefix__ = "passage"

    # associated user/agent
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
    agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")

    # origin data source
    source_id: Optional[str] = Field(None, description="The data source of the passage.")

    # document association
    doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
    metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")


class Passage(PassageBase):
    id: str = PassageBase.generate_id_field()

    # passage text
    text: str = Field(..., description="The text of the passage.")

    # embeddings
    embedding: Optional[List[float]] = Field(..., description="The embedding of the passage.")
    embedding_config: Optional[EmbeddingConfig] = Field(..., description="The embedding configuration used by the passage.")

    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")

    @field_validator("embedding")
    @classmethod
    def pad_embeddings(cls, embedding: List[float]) -> List[float]:
        """Pad embeddings to MAX_EMBEDDING_DIM. This is necessary to ensure all stored embeddings are the same size."""
        import numpy as np

        if embedding and len(embedding) != MAX_EMBEDDING_DIM:
            np_embedding = np.array(embedding)
            padded_embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant")
            return padded_embedding.tolist()
        return embedding


class PassageCreate(PassageBase):
    text: str = Field(..., description="The text of the passage.")

    # optionally provide embeddings
    embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")


class PassageUpdate(PassageCreate):
    id: str = Field(..., description="The unique identifier of the passage.")
    text: Optional[str] = Field(None, description="The text of the passage.")

    # optionally provide embeddings
    embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
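
Worked example of the padding rule (4096 is assumed for MAX_EMBEDDING_DIM here; the real value lives in memgpt.constants):

import numpy as np

MAX_EMBEDDING_DIM = 4096
embedding = [0.1] * 1536  # e.g. a text-embedding-ada-002 vector

padded = np.pad(np.array(embedding), (0, MAX_EMBEDDING_DIM - len(embedding)), mode="constant")
assert padded.shape == (4096,)
assert (padded[1536:] == 0.0).all()  # the tail is zero-filled
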
@@ -1,49 +0,0 @@
from datetime import datetime
from typing import Optional

from fastapi import UploadFile
from pydantic import BaseModel, Field

from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time


class BaseSource(MemGPTBase):
    """
    Shared attributes across all source schemas.
    """

    __id_prefix__ = "source"
    description: Optional[str] = Field(None, description="The description of the source.")
    embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
    # NOTE: .metadata is a reserved attribute on SQLModel
    metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")


class SourceCreate(BaseSource):
    name: str = Field(..., description="The name of the source.")
    description: Optional[str] = Field(None, description="The description of the source.")


class Source(BaseSource):
    id: str = BaseSource.generate_id_field()
    name: str = Field(..., description="The name of the source.")
    embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.")
    user_id: str = Field(..., description="The ID of the user that created the source.")


class SourceUpdate(BaseSource):
    id: str = Field(..., description="The ID of the source.")
    name: Optional[str] = Field(None, description="The name of the source.")


class UploadFileToSourceRequest(BaseModel):
    file: UploadFile = Field(..., description="The file to upload.")


class UploadFileToSourceResponse(BaseModel):
    source: Source = Field(..., description="The source the file was uploaded to.")
    added_passages: int = Field(..., description="The number of passages added to the source.")
    added_documents: int = Field(..., description="The number of documents added to the source.")
@@ -1,86 +0,0 @@
from typing import Dict, List, Optional

from pydantic import Field

from memgpt.functions.schema_generator import (
    generate_schema_from_args_schema,
    generate_tool_wrapper,
)
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.openai.chat_completions import ToolCall


class BaseTool(MemGPTBase):
    __id_prefix__ = "tool"

    # optional fields
    description: Optional[str] = Field(None, description="The description of the tool.")
    source_type: Optional[str] = Field(None, description="The type of the source code.")
    module: Optional[str] = Field(None, description="The module of the function.")

    # optional: user_id (user-specific tools)
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.")


class Tool(BaseTool):

    id: str = BaseTool.generate_id_field()

    name: str = Field(..., description="The name of the function.")
    tags: List[str] = Field(..., description="Metadata tags.")

    # code
    source_code: str = Field(..., description="The source code of the function.")
    json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")

    def to_dict(self):
        """Convert into OpenAI representation"""
        return vars(
            ToolCall(
                tool_id=self.id,
                tool_call_type="function",
                function=self.module,
            )
        )

    @classmethod
    def from_crewai(cls, crewai_tool) -> "Tool":
        """
        Class method to create an instance of Tool from a crewAI BaseTool object.

        Args:
            crewai_tool (CrewAIBaseTool): An instance of a crewAI BaseTool (BaseTool from crewai)

        Returns:
            Tool: A MemGPT Tool initialized with attributes derived from the provided crewAI BaseTool object.
        """
        name = crewai_tool.name
        description = crewai_tool.description
        source_type = "python"
        tags = ["crew-ai"]
        wrapper_func_name, wrapper_function_str = generate_tool_wrapper(crewai_tool.__class__.__name__)
        json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)

        return cls(
            name=wrapper_func_name,
            description=description,
            source_type=source_type,
            tags=tags,
            source_code=wrapper_function_str,
            json_schema=json_schema,
        )


class ToolCreate(BaseTool):
    name: str = Field(..., description="The name of the function.")
    tags: List[str] = Field(..., description="Metadata tags.")
    source_code: str = Field(..., description="The source code of the function.")
    json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")


class ToolUpdate(ToolCreate):
    id: str = Field(..., description="The unique identifier of the tool.")
    name: Optional[str] = Field(None, description="The name of the function.")
    tags: Optional[List[str]] = Field(None, description="Metadata tags.")
    source_code: Optional[str] = Field(None, description="The source code of the function.")
    json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
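
Sketch of registering a hand-written function through these schemas (the roll_d20 tool is hypothetical and its JSON schema abbreviated):

from memgpt.schemas.tool import ToolCreate

tool = ToolCreate(
    name="roll_d20",
    tags=["example"],
    source_code="def roll_d20() -> str:\n    import random\n    return str(random.randint(1, 20))",
    json_schema={
        "name": "roll_d20",
        "description": "Roll a 20-sided die.",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
)
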
@@ -1,8 +0,0 @@
from pydantic import BaseModel, Field


class MemGPTUsageStatistics(BaseModel):
    completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
    prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
    total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
    step_count: int = Field(0, description="The number of steps taken by the agent.")
@@ -1,20 +0,0 @@
from datetime import datetime
from typing import Optional

from pydantic import Field

from memgpt.schemas.memgpt_base import MemGPTBase


class UserBase(MemGPTBase):
    __id_prefix__ = "user"


class User(UserBase):
    id: str = UserBase.generate_id_field()
    name: str = Field(..., description="The name of the user.")
    created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")


class UserCreate(UserBase):
    name: Optional[str] = Field(None, description="The name of the user.")
@@ -1,8 +1,6 @@
from typing import List

from fastapi import APIRouter

from memgpt.schemas.agent import AgentState
from memgpt.server.rest_api.agents.index import ListAgentsResponse
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

@@ -10,12 +8,14 @@ router = APIRouter()


def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
    @router.get("/agents", tags=["agents"], response_model=List[AgentState])
    @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
    def get_all_agents():
        """
        Get a list of all agents in the database
        """
        interface.clear()
        return server.list_agents()
        agents_data = server.list_agents_legacy()

        return ListAgentsResponse(**agents_data)

    return router
@ -3,7 +3,7 @@ from typing import List, Literal, Optional
from fastapi import APIRouter, Body, HTTPException
from pydantic import BaseModel, Field

from memgpt.schemas.tool import Tool as ToolModel  # TODO: modify
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

@ -1,46 +1,102 @@
import uuid
from typing import List, Optional

from fastapi import APIRouter, Body, HTTPException, Query
from pydantic import BaseModel, Field

from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.user import User, UserCreate
from memgpt.data_types import User
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class GetAllUsersResponse(BaseModel):
    cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
    user_list: List[dict] = Field(..., description="A list of users.")


class CreateUserRequest(BaseModel):
    user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
    api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).")


class CreateUserResponse(BaseModel):
    user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
    api_key: str = Field(..., description="New API key generated for user.")


class CreateAPIKeyRequest(BaseModel):
    user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
    name: Optional[str] = Field(None, description="Name for the API key (optional).")


class CreateAPIKeyResponse(BaseModel):
    api_key: str = Field(..., description="New API key generated.")


class GetAPIKeysResponse(BaseModel):
    api_key_list: List[str] = Field(..., description="A list of API keys for the user.")


class DeleteAPIKeyResponse(BaseModel):
    message: str
    api_key_deleted: str


class DeleteUserResponse(BaseModel):
    message: str
    user_id_deleted: uuid.UUID


def setup_admin_router(server: SyncServer, interface: QueuingInterface):
    @router.get("/users", tags=["admin"], response_model=List[User])
    def get_all_users(cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50)):
    @router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
    def get_all_users(cursor: Optional[uuid.UUID] = Query(None), limit: Optional[int] = Query(50)):
        """
        Get a list of all users in the database
        """
        try:
            # TODO: make this call a server function
            _, users = server.ms.get_all_users(cursor=cursor, limit=limit)
            next_cursor, users = server.ms.get_all_users(cursor, limit)
            processed_users = [{"user_id": user.id} for user in users]
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return users
        return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)

    @router.post("/users", tags=["admin"], response_model=User)
    def create_user(request: UserCreate = Body(...)):
    @router.post("/users", tags=["admin"], response_model=CreateUserResponse)
    def create_user(request: Optional[CreateUserRequest] = Body(None)):
        """
        Create a new user in the database
        """
        if request is None:
            request = CreateUserRequest()

        new_user = User(
            id=None if not request.user_id else request.user_id,
            # TODO can add more fields (name? metadata?)
        )

        try:
            user = server.create_user(request)
            server.ms.create_user(new_user)

            # make sure we can retrieve the user from the DB too
            new_user_ret = server.ms.get_user(new_user.id)
            if new_user_ret is None:
                raise HTTPException(status_code=500, detail=f"Failed to verify user creation")

            # create an API key for the user
            token = server.ms.create_api_key(user_id=new_user.id, name=request.api_key_name)

        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return user
        return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token)

    @router.delete("/users", tags=["admin"], response_model=User)
    @router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
    def delete_user(
        user_id: str = Query(..., description="The user_id key to be deleted."),
        user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."),
    ):
        # TODO make a soft deletion, instead of a hard deletion
        try:

@ -52,24 +108,24 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return user
        return DeleteUserResponse(message="User successfully deleted.", user_id_deleted=user_id)

    @router.post("/users/keys", tags=["admin"], response_model=APIKey)
    def create_new_api_key(request: APIKeyCreate = Body(...)):
    @router.post("/users/keys", tags=["admin"], response_model=CreateAPIKeyResponse)
    def create_new_api_key(request: CreateAPIKeyRequest = Body(...)):
        """
        Create a new API key for a user
        """
        try:
            api_key = server.create_api_key(request)
            token = server.ms.create_api_key(user_id=request.user_id, name=request.name)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return api_key
        return CreateAPIKeyResponse(api_key=token.token)

    @router.get("/users/keys", tags=["admin"], response_model=List[APIKey])
    @router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse)
    def get_api_keys(
        user_id: str = Query(..., description="The unique identifier of the user."),
        user_id: uuid.UUID = Query(..., description="The unique identifier of the user."),
    ):
        """
        Get a list of all API keys for a user

@ -77,22 +133,28 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
        try:
            if server.ms.get_user(user_id=user_id) is None:
                raise HTTPException(status_code=404, detail=f"User does not exist")
            api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
            tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
            processed_tokens = [t.token for t in tokens]
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return api_keys
        print("TOKENS", processed_tokens)
        return GetAPIKeysResponse(api_key_list=processed_tokens)

    @router.delete("/users/keys", tags=["admin"], response_model=APIKey)
    @router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse)
    def delete_api_key(
        api_key: str = Query(..., description="The API key to be deleted."),
    ):
        try:
            return server.delete_api_key(api_key)
            token = server.ms.get_api_key(api_key=api_key)
            if token is None:
                raise HTTPException(status_code=404, detail=f"API key does not exist")
            server.ms.delete_api_key(api_key=api_key)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return DeleteAPIKeyResponse(message="API key successfully deleted.", api_key_deleted=api_key)

    return router
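A hedged usage sketch for the admin endpoints above. The base URL, the /admin mount prefix, and the bearer token are assumptions (the diff itself does not fix them); only the routes and response fields come from the code above.

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed server password

# Create a user (id generated server-side) and capture the auto-generated API key
resp = requests.post(f"{BASE_URL}/admin/users", headers=HEADERS, json={})
resp.raise_for_status()
user_id, api_key = resp.json()["user_id"], resp.json()["api_key"]

# List the user's API keys, then delete one
resp = requests.get(f"{BASE_URL}/admin/users/keys", headers=HEADERS, params={"user_id": user_id})
print(resp.json()["api_key_list"])
requests.delete(f"{BASE_URL}/admin/users/keys", headers=HEADERS, params={"api_key": api_key})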
48
memgpt/server/rest_api/agents/command.py
Normal file
@ -0,0 +1,48 @@
import uuid
from functools import partial

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class CommandRequest(BaseModel):
    command: str = Field(..., description="The command to be executed by the agent.")


class CommandResponse(BaseModel):
    response: str = Field(..., description="The result of the executed command.")


def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
    def run_command(
        agent_id: uuid.UUID,
        request: CommandRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Execute a command on a specified agent.

        This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.

        Raises an HTTPException for any processing errors.
        """
        interface.clear()
        try:
            # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
            response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        return CommandResponse(response=response)

    return router
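A small sketch of calling the new command endpoint. The /api mount prefix, server address, agent id, and API key are placeholders/assumptions; the path, request body, and response field come from the router above.

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
AGENT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
HEADERS = {"Authorization": "Bearer <user_api_key>"}  # token resolved by get_current_user

resp = requests.post(
    f"{BASE_URL}/api/agents/{AGENT_ID}/command",
    headers=HEADERS,
    json={"command": "/memory"},  # CommandRequest body
)
print(resp.json()["response"])    # CommandResponse.response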
156
memgpt/server/rest_api/agents/config.py
Normal file
@ -0,0 +1,156 @@
import re
import uuid
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.models.pydantic_models import (
    AgentStateModel,
    EmbeddingConfigModel,
    LLMConfigModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class AgentRenameRequest(BaseModel):
    agent_name: str = Field(..., description="New name for the agent.")


class GetAgentResponse(BaseModel):
    # config: dict = Field(..., description="The agent configuration object.")
    agent_state: AgentStateModel = Field(..., description="The state of the agent.")
    sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
    last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")


def validate_agent_name(name: str) -> str:
    """Validate the requested new agent name (prevent bad inputs)"""

    # Length check
    if not (1 <= len(name) <= 50):
        raise HTTPException(status_code=400, detail="Name length must be between 1 and 50 characters.")

    # Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
    if not re.match("^[A-Za-z0-9 _-]+$", name):
        raise HTTPException(status_code=400, detail="Name contains invalid characters.")

    # Further checks can be added here...
    # TODO

    return name


def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
    def get_agent_config(
        agent_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the configuration for a specific agent.

        This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
        """

        interface.clear()
        if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
            # agent does not exist
            raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")

        agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
        # get sources
        attached_sources = server.list_attached_sources(agent_id=agent_id)

        # configs
        llm_config = LLMConfigModel(**vars(agent_state.llm_config))
        embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

        return GetAgentResponse(
            agent_state=AgentStateModel(
                id=agent_state.id,
                name=agent_state.name,
                user_id=agent_state.user_id,
                llm_config=llm_config,
                embedding_config=embedding_config,
                state=agent_state.state,
                created_at=int(agent_state.created_at.timestamp()),
                tools=agent_state.tools,
                system=agent_state.system,
                metadata=agent_state._metadata,
            ),
            last_run_at=None,  # TODO
            sources=attached_sources,
        )

    @router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
    def update_agent_name(
        agent_id: uuid.UUID,
        request: AgentRenameRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Updates the name of a specific agent.

        This changes the name of the agent in the database but does NOT edit the agent's persona.
        """
        # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

        valid_name = validate_agent_name(request.agent_name)

        interface.clear()
        try:
            agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name)
            # get sources
            attached_sources = server.list_attached_sources(agent_id=agent_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")
        llm_config = LLMConfigModel(**vars(agent_state.llm_config))
        embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

        return GetAgentResponse(
            agent_state=AgentStateModel(
                id=agent_state.id,
                name=agent_state.name,
                user_id=agent_state.user_id,
                llm_config=llm_config,
                embedding_config=embedding_config,
                state=agent_state.state,
                created_at=int(agent_state.created_at.timestamp()),
                tools=agent_state.tools,
                system=agent_state.system,
            ),
            last_run_at=None,  # TODO
            sources=attached_sources,
        )

    @router.delete("/agents/{agent_id}", tags=["agents"])
    def delete_agent(
        agent_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Delete an agent.
        """
        # agent_id = uuid.UUID(agent_id)

        interface.clear()
        try:
            server.delete_agent(user_id=user_id, agent_id=agent_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent agent_id={agent_id} successfully deleted"})
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    return router
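For the rename route, note that validate_agent_name restricts names to 1-50 characters of letters, digits, spaces, hyphens, and underscores. A hedged sketch, under the same assumed /api prefix, base URL, and placeholder credentials as the earlier examples:

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
AGENT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
HEADERS = {"Authorization": "Bearer <user_api_key>"}

resp = requests.patch(
    f"{BASE_URL}/api/agents/{AGENT_ID}/rename",
    headers=HEADERS,
    json={"agent_name": "my renamed-agent_2"},  # passes the regex check above
)
print(resp.json()["agent_state"]["name"])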
@ -1,22 +1,49 @@
import uuid
from functools import partial
from typing import List

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.constants import BASE_TOOLS
from memgpt.memory import ChatMemory
from memgpt.models.pydantic_models import (
    AgentStateModel,
    EmbeddingConfigModel,
    LLMConfigModel,
    PresetModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.settings import settings

router = APIRouter()


class ListAgentsResponse(BaseModel):
    num_agents: int = Field(..., description="The number of agents available to the user.")
    # TODO make return type List[AgentStateModel]
    # also return - presets: List[PresetModel]
    agents: List[dict] = Field(..., description="List of agent configurations.")


class CreateAgentRequest(BaseModel):
    # TODO: modify this (along with front end)
    config: dict = Field(..., description="The agent configuration object.")


class CreateAgentResponse(BaseModel):
    agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.")
    preset: PresetModel = Field(..., description="The preset that the agent was created from.")


def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents", tags=["agents"], response_model=List[AgentState])
    @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
    def list_agents(
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        List all agents associated with a given user.

@ -24,71 +51,95 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
        This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
        """
        interface.clear()
        agents_data = server.list_agents(user_id=user_id)
        return agents_data
        agents_data = server.list_agents_legacy(user_id=user_id)
        return ListAgentsResponse(**agents_data)

    @router.post("/agents", tags=["agents"], response_model=AgentState)
    @router.post("/agents", tags=["agents"], response_model=CreateAgentResponse)
    def create_agent(
        request: CreateAgent = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        request: CreateAgentRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Create a new agent with the specified configuration.
        """
        interface.clear()

        agent_state = server.create_agent(request, user_id=user_id)
        return agent_state
        # Parse request
        # TODO: don't just use JSON in the future
        human_name = request.config["human_name"] if "human_name" in request.config else None
        human = request.config["human"] if "human" in request.config else None
        persona_name = request.config["persona_name"] if "persona_name" in request.config else None
        persona = request.config["persona"] if "persona" in request.config else None
        request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
        tool_names = request.config["function_names"] if ("function_names" in request.config and request.config["function_names"]) else None
        metadata = request.config["metadata"] if "metadata" in request.config else {}
        metadata["human"] = human_name
        metadata["persona"] = persona_name

        # TODO: remove this -- should be added based on create agent fields
        if isinstance(tool_names, str):  # TODO: fix this on client side?
            tool_names = tool_names.split(",")
        if tool_names is None or tool_names == "":
            tool_names = []
        for name in BASE_TOOLS:  # TODO: remove this
            if name not in tool_names:
                tool_names.append(name)
        assert isinstance(tool_names, list), "Tool names must be a list of strings."

        # TODO: eventually remove this - should support general memory at the REST endpoint
        # TODO: the REST server should add default memory tools at startup time
        memory = ChatMemory(persona=persona, human=human)

    @router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
    def update_agent(
        agent_id: str,
        request: UpdateAgentState = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """Update an existing agent"""
        interface.clear()
        try:
            # TODO: should id be moved out of UpdateAgentState?
            agent_state = server.update_agent(request, user_id=user_id)
            agent_state = server.create_agent(
                user_id=user_id,
                # **request.config
                # TODO turn into a pydantic model
                name=request.config["name"],
                memory=memory,
                system=request.config.get("system", None),
                # persona_name=persona_name,
                # human_name=human_name,
                # persona=persona,
                # human=human,
                # llm_config=LLMConfigModel(
                #     model=request.config['model'],
                # )
                # tools
                tools=tool_names,
                metadata=metadata,
                # function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
            )
            llm_config = LLMConfigModel(**vars(agent_state.llm_config))
            embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))

            return CreateAgentResponse(
                agent_state=AgentStateModel(
                    id=agent_state.id,
                    name=agent_state.name,
                    user_id=agent_state.user_id,
                    llm_config=llm_config,
                    embedding_config=embedding_config,
                    state=agent_state.state,
                    created_at=int(agent_state.created_at.timestamp()),
                    tools=agent_state.tools,
                    system=agent_state.system,
                    metadata=agent_state._metadata,
                ),
                preset=PresetModel(  # TODO: remove (placeholder to avoid breaking frontend)
                    name="dummy_preset",
                    id=agent_state.id,
                    user_id=agent_state.user_id,
                    description="",
                    created_at=agent_state.created_at,
                    system=agent_state.system,
                    persona="",
                    human="",
                    functions_schema=[],
                ),
            )
        except Exception as e:
            print(str(e))
            raise HTTPException(status_code=500, detail=str(e))

        return agent_state

    @router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
    def get_agent_state(
        agent_id: str = None,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get the state of the agent.
        """

        interface.clear()
        if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
            # agent does not exist
            raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")

        return server.get_agent_state(user_id=user_id, agent_id=agent_id)

    @router.delete("/agents/{agent_id}", tags=["agents"])
    def delete_agent(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Delete an agent.
        """
        # agent_id = str(agent_id)

        interface.clear()
        try:
            server.delete_agent(user_id=user_id, agent_id=agent_id)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    return router
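The create_agent handler above reads loosely-typed keys out of request.config. A sketch of a payload it would accept, under the same assumed /api prefix, address, and placeholder credentials (the human/persona text is illustrative):

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
HEADERS = {"Authorization": "Bearer <user_api_key>"}

payload = {
    "config": {
        "name": "demo-agent",                        # required by the handler
        "human": "The human's name is Bob.",         # free-form human text
        "persona": "You are a helpful assistant.",   # free-form persona text
        "function_names": "",                        # BASE_TOOLS are appended server-side
    }
}
resp = requests.post(f"{BASE_URL}/api/agents", headers=HEADERS, json=payload)
agent_state = resp.json()["agent_state"]
print(agent_state["id"], agent_state["tools"])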
@ -1,12 +1,11 @@
import uuid
from functools import partial
from typing import Dict, List, Optional
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

@ -14,24 +13,60 @@ from memgpt.server.server import SyncServer
router = APIRouter()


class CoreMemory(BaseModel):
    human: str | None = Field(None, description="Human element of the core memory.")
    persona: str | None = Field(None, description="Persona element of the core memory.")


class GetAgentMemoryResponse(BaseModel):
    core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.")
    recall_memory: int = Field(..., description="Size of the agent's recall memory.")
    archival_memory: int = Field(..., description="Size of the agent's archival memory.")


# NOTE not subclassing CoreMemory since in the request both fields are optional
class UpdateAgentMemoryRequest(BaseModel):
    human: str = Field(None, description="Human element of the core memory.")
    persona: str = Field(None, description="Persona element of the core memory.")


class UpdateAgentMemoryResponse(BaseModel):
    old_core_memory: CoreMemory = Field(..., description="The previous state of the agent's core memory.")
    new_core_memory: CoreMemory = Field(..., description="The updated state of the agent's core memory.")


class ArchivalMemoryObject(BaseModel):
    # TODO move to models/pydantic_models, or inherit from data_types Record
    id: uuid.UUID = Field(..., description="Unique identifier for the memory object inside the archival memory store.")
    contents: str = Field(..., description="The memory contents.")


class GetAgentArchivalMemoryResponse(BaseModel):
    # TODO: make this List[Passage] instead
    archival_memory: List[ArchivalMemoryObject] = Field(..., description="A list of all memory objects in archival memory.")


class InsertAgentArchivalMemoryRequest(BaseModel):
    content: str = Field(..., description="The memory contents to insert into archival memory.")


class InsertAgentArchivalMemoryResponse(BaseModel):
    ids: List[str] = Field(
        ..., description="Unique identifier for the new archival memory object. May return multiple ids if insert contents are chunked."
    )


class DeleteAgentArchivalMemoryRequest(BaseModel):
    id: str = Field(..., description="Unique identifier of the archival memory object to delete.")


def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message])
    def get_agent_in_context_messages(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the messages in the context of a specific agent.
        """
        interface.clear()
        return server.get_in_context_messages(agent_id=agent_id)

    @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
    @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
    def get_agent_memory(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
        agent_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the memory state of a specific agent.

@ -39,13 +74,14 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
        This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
        """
        interface.clear()
        return server.get_agent_memory(agent_id=agent_id)
        memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
        return GetAgentMemoryResponse(**memory)

    @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
    @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
    def update_agent_memory(
        agent_id: str,
        request: Dict = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        agent_id: uuid.UUID,
        request: UpdateAgentMemoryRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Update the core memory of a specific agent.

@ -53,86 +89,75 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
        This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
        """
        interface.clear()
        memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request)
        return memory

    @router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary)
    def get_agent_recall_memory_summary(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
        new_memory_contents = {"persona": request.persona, "human": request.human}
        response = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=new_memory_contents)
        return UpdateAgentMemoryResponse(**response)

    @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
    def get_agent_archival_memory_all(
        agent_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the summary of the recall memory of a specific agent.
        Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
        """
        interface.clear()
        return server.get_recall_memory_summary(agent_id=agent_id)
        archival_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
        print("archival_memories:", archival_memories)
        archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
        return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

    @router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary)
    def get_agent_archival_memory_summary(
        agent_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the summary of the archival memory of a specific agent.
        """
        interface.clear()
        return server.get_archival_memory_summary(agent_id=agent_id)

    # @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage])
    # def get_agent_archival_memory_all(
    #     agent_id: str,
    #     user_id: str = Depends(get_current_user_with_server),
    # ):
    #     """
    #     Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
    #     """
    #     interface.clear()
    #     return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)

    @router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
    @router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
    def get_agent_archival_memory(
        agent_id: str,
        agent_id: uuid.UUID,
        after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
        before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
        limit: Optional[int] = Query(None, description="How many results to include in the response."),
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the memories in an agent's archival memory store (paginated query).
        """
        interface.clear()
        return server.get_agent_archival_cursor(
        # TODO need to add support for non-postgres here
        # chroma will throw:
        # raise ValueError("Cannot run get_all_cursor with chroma")
        _, archival_json_records = server.get_agent_archival_cursor(
            user_id=user_id,
            agent_id=agent_id,
            after=after,
            before=before,
            limit=limit,
        )
        archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records]
        return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

    @router.post("/agents/{agent_id}/archival/{memory}", tags=["agents"], response_model=List[Passage])
    @router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=InsertAgentArchivalMemoryResponse)
    def insert_agent_archival_memory(
        agent_id: str,
        memory: str,
        user_id: str = Depends(get_current_user_with_server),
        agent_id: uuid.UUID,
        request: InsertAgentArchivalMemoryRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Insert a memory into an agent's archival memory store.
        """
        interface.clear()
        return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=memory)
        memory_ids = server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.content)
        return InsertAgentArchivalMemoryResponse(ids=memory_ids)

    @router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"])
    @router.delete("/agents/{agent_id}/archival", tags=["agents"])
    def delete_agent_archival_memory(
        agent_id: str,
        memory_id: str,
        user_id: str = Depends(get_current_user_with_server),
        agent_id: uuid.UUID,
        id: str = Query(..., description="Unique ID of the memory to be deleted."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Delete a memory from an agent's archival memory store.
        """
        # TODO: should probably return a `Passage`
        interface.clear()
        try:
            memory_id = uuid.UUID(id)
            server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
        except HTTPException:
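A sketch of exercising the reworked archival routes: insert one memory, then page through the store. Same assumed base URL, /api prefix, and placeholder credentials as the earlier sketches; the request/response shapes come from the models above.

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
AGENT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
HEADERS = {"Authorization": "Bearer <user_api_key>"}

resp = requests.post(
    f"{BASE_URL}/api/agents/{AGENT_ID}/archival",
    headers=HEADERS,
    json={"content": "Bob's favorite color is blue."},  # InsertAgentArchivalMemoryRequest
)
print(resp.json()["ids"])  # several ids if the content was chunked

resp = requests.get(
    f"{BASE_URL}/api/agents/{AGENT_ID}/archival",
    headers=HEADERS,
    params={"limit": 10},  # paginated query; after/before are optional
)
for passage in resp.json()["archival_memory"]:
    print(passage["id"], passage["contents"])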
@ -1,48 +1,116 @@
import asyncio
import uuid
from datetime import datetime
from enum import Enum
from functools import partial
from typing import List, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse

from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message
from memgpt.models.pydantic_models import MemGPTUsageStatistics
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate

router = APIRouter()


# TODO: cpacker should check this file
# TODO: move this into server.py?
class MessageRoleType(str, Enum):
    user = "user"
    system = "system"


class UserMessageRequest(BaseModel):
    message: str = Field(..., description="The message content to be processed by the agent.")
    name: Optional[str] = Field(default=None, description="Name of the message request sender")
    role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
    stream_steps: bool = Field(
        default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
    )
    stream_tokens: bool = Field(
        default=False,
        description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
    )
    timestamp: Optional[datetime] = Field(
        None,
        description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
    )
    stream: bool = Field(
        default=False,
        description="Legacy flag for old streaming API, will be deprecated in the future.",
        deprecated=True,
    )

    # @validator("timestamp", pre=True, always=True)
    # def validate_timestamp(cls, value: Optional[datetime]) -> Optional[datetime]:
    #     if value is None:
    #         return value  # If the timestamp is None, just return None, implying default handling to set server-side

    #     if not isinstance(value, datetime):
    #         raise TypeError("Timestamp must be a datetime object with timezone information.")

    #     if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
    #         raise ValueError("Timestamp must be timezone-aware.")

    #     # Convert timestamp to UTC if it's not already in UTC
    #     if value.tzinfo.utcoffset(value) != timezone.utc.utcoffset(value):
    #         value = value.astimezone(timezone.utc)

    #     return value


class UserMessageResponse(BaseModel):
    messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
    usage: MemGPTUsageStatistics = Field(..., description="Usage statistics for the completion.")


class GetAgentMessagesRequest(BaseModel):
    start: int = Field(..., description="Message index to start on (reverse chronological).")
    count: int = Field(..., description="How many messages to retrieve.")


class GetAgentMessagesCursorRequest(BaseModel):
    before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.")
    limit: int = Field(..., description="Maximum number of messages to retrieve.")


class GetAgentMessagesResponse(BaseModel):
    messages: list = Field(..., description="List of message objects.")


async def send_message_to_agent(
    server: SyncServer,
    agent_id: str,
    user_id: str,
    role: MessageRole,
    agent_id: uuid.UUID,
    user_id: uuid.UUID,
    role: str,
    message: str,
    stream_legacy: bool,  # legacy
    stream_steps: bool,
    stream_tokens: bool,
    chat_completion_mode: Optional[bool] = False,
    timestamp: Optional[datetime] = None,
    # related to whether or not we return `MemGPTMessage`s or `Message`s
    return_message_object: bool = True,  # Should be True for Python Client, False for REST API
) -> Union[StreamingResponse, MemGPTResponse]:
) -> Union[StreamingResponse, UserMessageResponse]:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy."""
    # TODO: @charles is this the correct way to handle?
    include_final_message = True

    # determine role
    if role == MessageRole.user:
    # TODO this is a total hack but is required until we move streaming into the model config
    if server.server_llm_config.model_endpoint != "https://api.openai.com/v1":
        stream_tokens = False

    # handle the legacy mode streaming
    if stream_legacy:
        # NOTE: override
        stream_steps = True
        stream_tokens = False
        include_final_message = False
    else:
        include_final_message = True

    if role == "user" or role is None:
        message_func = server.user_message
    elif role == MessageRole.system:
    elif role == "system":
        message_func = server.system_message
    else:
        raise HTTPException(status_code=500, detail=f"Bad role {role}")

@ -53,11 +121,9 @@ async def send_message_to_agent(
    # For streaming response
    try:

        # TODO: move this logic into server.py

        # Get the generator object off of the agent's streaming interface
        # This will be attached to the POST SSE request used under-the-hood
        memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
        memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
        streaming_interface = memgpt_agent.interface
        if not isinstance(streaming_interface, StreamingServerInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

@ -67,6 +133,8 @@ async def send_message_to_agent(
        # "chatcompletion mode" does some remapping and ignores inner thoughts
        streaming_interface.streaming_chat_completion_mode = chat_completion_mode

        # NOTE: for legacy 'stream' flag
        streaming_interface.nonstreaming_legacy_mode = stream_legacy
        # streaming_interface.allow_assistant_message = stream
        # streaming_interface.function_call_legacy_mode = stream

@ -77,44 +145,21 @@ async def send_message_to_agent(
        )

        if stream_steps:
            if return_message_object:
                # TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
                raise NotImplementedError

            # return a stream
            return StreamingResponse(
                sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
                media_type="text/event-stream",
            )

        else:
            # buffer the stream, then return the list
            generated_stream = []
            async for message in streaming_interface.get_generator():
                assert (
                    isinstance(message, MemGPTMessage)
                    or isinstance(message, LegacyMemGPTMessage)
                    or isinstance(message, MessageStreamStatus)
                ), type(message)
                generated_stream.append(message)
                if message == MessageStreamStatus.done:
                if "data" in message and message["data"] == "[DONE]":
                    break

            # Get rid of the stream status messages
            filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
            filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]
            usage = await task

            # By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
            # If we want to convert these to Message, we can use the attached IDs
            # NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
            # TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
            if return_message_object:
                message_ids = [m.id for m in filtered_stream]
                message_ids = deduplicate(message_ids)
                message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
                return MemGPTResponse(messages=message_objs, usage=usage)
            else:
                return MemGPTResponse(messages=filtered_stream, usage=usage)
            return UserMessageResponse(messages=filtered_stream, usage=usage)

    except HTTPException:
        raise

@ -129,39 +174,55 @@ async def send_message_to_agent(
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message])
    def get_agent_messages_in_context(
        agent_id: str,
    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages(
        agent_id: uuid.UUID,
        start: int = Query(..., description="Message index to start on (reverse chronological)."),
        count: int = Query(..., description="How many messages to retrieve."),
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
        """
        interface.clear()
        messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=start, count=count)
        return messages
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
        # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None

    @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
    def get_agent_messages(
        agent_id: str,
        before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
        interface.clear()
        messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
        return GetAgentMessagesResponse(messages=messages)

    @router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
    def get_agent_messages_cursor(
        agent_id: uuid.UUID,
        before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
        limit: int = Query(10, description="Maximum number of messages to retrieve."),
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Retrieve message history for an agent.
        Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
        """
        interface.clear()
        return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True)
        # Validate with the Pydantic model (optional)
        request = GetAgentMessagesCursorRequest(agent_id=agent_id, before=before, limit=limit)

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
        interface.clear()
        [_, messages] = server.get_agent_recall_cursor(
            user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
        )
        # print("====> messages-cursor DEBUG")
        # for i, msg in enumerate(messages):
        #     print(f"message {i+1}/{len(messages)}")
        #     print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
        #     print(f"ISO format string: {msg['created_at']}")
        #     print(msg)
        return GetAgentMessagesResponse(messages=messages)

    @router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
    async def send_message(
        # background_tasks: BackgroundTasks,
        agent_id: str,
        request: MemGPTRequest = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        agent_id: uuid.UUID,
        request: UserMessageRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Process a user message and return the agent's response.

@ -169,21 +230,17 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
        This endpoint accepts a message from a user and processes it through the agent.
        It can optionally stream the response if 'stream' is set to True.
        """
        # TODO: should this receive multiple messages? @cpacker
        # TODO: revise to `MemGPTRequest`
        # TODO: support sending multiple messages
        assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
        message = request.messages[0]

        # TODO: what to do with message.name?
        return await send_message_to_agent(
            server=server,
            agent_id=agent_id,
            user_id=user_id,
            role=message.role,
            message=message.text,
            role=request.role,
            message=request.message,
            stream_steps=request.stream_steps,
            stream_tokens=request.stream_tokens,
            timestamp=request.timestamp,
            # legacy
            stream_legacy=request.stream,
        )

    return router
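A non-streaming sketch of the reworked message endpoint (stream_steps and stream_tokens left at their False defaults), with the usual assumed address, /api prefix, and placeholder credentials:

import requests

BASE_URL = "http://localhost:8283"  # assumed server address
AGENT_ID = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
HEADERS = {"Authorization": "Bearer <user_api_key>"}

resp = requests.post(
    f"{BASE_URL}/api/agents/{AGENT_ID}/messages",
    headers=HEADERS,
    json={"message": "Hello, agent!", "role": "user"},  # UserMessageRequest
)
data = resp.json()
print(data["usage"])      # MemGPTUsageStatistics
for message in data["messages"]:
    print(message)        # agent steps (inner monologue, function calls, ...)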
@ -1,73 +0,0 @@
from functools import partial
from typing import List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, Query

from memgpt.schemas.block import Block, CreateBlock
from memgpt.schemas.block import Human as HumanModel  # TODO: modify
from memgpt.schemas.block import UpdateBlock
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/blocks", tags=["block"], response_model=List[Block])
    async def list_blocks(
        user_id: str = Depends(get_current_user_with_server),
        # query parameters
        label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
        templates_only: bool = Query(True, description="Whether to include only templates"),
        name: Optional[str] = Query(None, description="Name of the block"),
    ):
        # Clear the interface
        interface.clear()
        blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name)
        if blocks is None:
            return []
        return blocks

    @router.post("/blocks", tags=["block"], response_model=Block)
    async def create_block(
        request: CreateBlock = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        interface.clear()
        request.user_id = user_id  # TODO: remove?
        return server.create_block(user_id=user_id, request=request)

    @router.post("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def update_block(
        block_id: str,
        request: UpdateBlock = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        interface.clear()
        # TODO: should this be in the param or the POST data?
        assert block_id == request.id
        return server.update_block(user_id=user_id, request=request)

    @router.delete("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def delete_block(
        block_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        interface.clear()
        return server.delete_block(block_id=block_id)

    @router.get("/blocks/{block_id}", tags=["block"], response_model=Block)
    async def get_block(
        block_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        interface.clear()
        block = server.get_block(block_id=block_id)
        if block is None:
            raise HTTPException(status_code=404, detail="Block not found")
        return block

    return router
@ -1,11 +1,9 @@
import uuid
from functools import partial
from typing import List

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field

from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

@ -21,20 +19,13 @@ class ConfigResponse(BaseModel):
def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/config/llm", tags=["config"], response_model=List[LLMConfig])
    def get_llm_configs(user_id: str = Depends(get_current_user_with_server)):
    @router.get("/config", tags=["config"], response_model=ConfigResponse)
    def get_server_config(user_id: uuid.UUID = Depends(get_current_user_with_server)):
        """
        Retrieve the base configuration for the server.
        """
        interface.clear()
        return [server.server_llm_config]

    @router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig])
    def get_embedding_configs(user_id: str = Depends(get_current_user_with_server)):
        """
        Retrieve the base configuration for the server.
        """
        interface.clear()
        return [server.server_embedding_config]
        response = server.get_server_config(include_defaults=True)
        return ConfigResponse(config=response["config"], defaults=response["defaults"])

    return router
69
memgpt/server/rest_api/humans/index.py
Normal file
@ -0,0 +1,69 @@
import uuid
from functools import partial
from typing import List

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.models.pydantic_models import HumanModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer

router = APIRouter()


class ListHumansResponse(BaseModel):
    humans: List[HumanModel] = Field(..., description="List of human configurations.")


class CreateHumanRequest(BaseModel):
    text: str = Field(..., description="The human text.")
    name: str = Field(..., description="The name of the human.")


def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/humans", tags=["humans"], response_model=ListHumansResponse)
    async def list_humans(
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        # Clear the interface
        interface.clear()
        humans = server.ms.list_humans(user_id=user_id)
        return ListHumansResponse(humans=humans)

    @router.post("/humans", tags=["humans"], response_model=HumanModel)
    async def create_human(
        request: CreateHumanRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        # TODO: disallow duplicate names for humans
        interface.clear()
        new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
        human_id = new_human.id
        server.ms.add_human(new_human)
        return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)

    @router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
    async def delete_human(
        human_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        interface.clear()
        human = server.ms.delete_human(human_name, user_id=user_id)
        return human

    @router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
    async def get_human(
        human_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        interface.clear()
        human = server.ms.get_human(human_name, user_id=user_id)
        if human is None:
            raise HTTPException(status_code=404, detail="Human not found")
        return human

    return router
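
The new humans router above can be exercised with a plain HTTP client; a rough sketch (URL and token are assumptions, not part of this file):

# a minimal sketch of the /humans CRUD surface
import requests

BASE = "http://localhost:8283/api"  # hypothetical base URL
HEADERS = {"Authorization": "Bearer test_server_token"}  # hypothetical token

# create a human block (CreateHumanRequest: text + name)
human = requests.post(
    f"{BASE}/humans",
    json={"text": "Name: Chad", "name": "chad"},
    headers=HEADERS,
).json()
print(human)  # HumanModel payload: id, text, name, user_id

# list, fetch by name, then delete
print(requests.get(f"{BASE}/humans", headers=HEADERS).json()["humans"])
print(requests.get(f"{BASE}/humans/chad", headers=HEADERS).json())
requests.delete(f"{BASE}/humans/chad", headers=HEADERS)
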
@ -2,24 +2,13 @@ import asyncio
import json
import queue
from collections import deque
from typing import AsyncGenerator, Literal, Optional, Union
from typing import AsyncGenerator, Optional

from memgpt.data_types import Message
from memgpt.interface import AgentInterface
from memgpt.schemas.enums import MessageStreamStatus
from memgpt.schemas.memgpt_message import (
    AssistantMessage,
    FunctionCall,
    FunctionCallMessage,
    FunctionReturn,
    InternalMonologue,
    LegacyFunctionCallMessage,
    LegacyMemGPTMessage,
    MemGPTMessage,
)
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse
from memgpt.streaming_interface import AgentChunkStreamingInterface
from memgpt.utils import is_utc_datetime
from memgpt.utils import get_utc_time, is_utc_datetime


class QueuingInterface(AgentInterface):
@ -29,66 +18,12 @@ class QueuingInterface(AgentInterface):
        self.buffer = queue.Queue()
        self.debug = debug

    def _queue_push(self, message_api: Union[str, dict], message_obj: Union[Message, None]):
        """Wrapper around self.buffer.queue.put() that ensures the types are safe

        Data will be in the format: {
            "message_obj": ...
            "message_string": ...
        }
        """

        # Check the string first

        if isinstance(message_api, str):
            # check that it's the stop word
            if message_api == "STOP":
                assert message_obj is None
                self.buffer.put(
                    {
                        "message_api": message_api,
                        "message_obj": None,
                    }
                )
            else:
                raise ValueError(f"Unrecognized string pushed to buffer: {message_api}")

        elif isinstance(message_api, dict):
            # check if it's the error message style
            if len(message_api.keys()) == 1 and "internal_error" in message_api:
                assert message_obj is None
                self.buffer.put(
                    {
                        "message_api": message_api,
                        "message_obj": None,
                    }
                )
            else:
                assert message_obj is not None, message_api
                self.buffer.put(
                    {
                        "message_api": message_api,
                        "message_obj": message_obj,
                    }
                )

        else:
            raise ValueError(f"Unrecognized type pushed to buffer: {type(message_api)}")

    def to_list(self, style: Literal["obj", "api"] = "obj"):
    def to_list(self):
        """Convert queue to a list (empties it out at the same time)"""
        items = []
        while not self.buffer.empty():
            try:
                # items.append(self.buffer.get_nowait())
                item_to_push = self.buffer.get_nowait()
                if style == "obj":
                    if item_to_push["message_obj"] is not None:
                        items.append(item_to_push["message_obj"])
                elif style == "api":
                    items.append(item_to_push["message_api"])
                else:
                    raise ValueError(style)
                items.append(self.buffer.get_nowait())
            except queue.Empty:
                break
        if len(items) > 1 and items[-1] == "STOP":
@ -101,30 +36,20 @@ class QueuingInterface(AgentInterface):
        # Empty the queue
        self.buffer.queue.clear()

    async def message_generator(self, style: Literal["obj", "api"] = "obj"):
    async def message_generator(self):
        while True:
            if not self.buffer.empty():
                message = self.buffer.get()
                message_obj = message["message_obj"]
                message_api = message["message_api"]

                if message_api == "STOP":
                if message == "STOP":
                    break

                # yield message
                if style == "obj":
                    yield message_obj
                elif style == "api":
                    yield message_api
                else:
                    raise ValueError(style)

                # yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
                yield message
            else:
                await asyncio.sleep(0.1)  # Small sleep to prevent a busy loop

    def step_yield(self):
        """Enqueue a special stop message"""
        self._queue_push(message_api="STOP", message_obj=None)
        self.buffer.put("STOP")

    @staticmethod
    def step_complete():
@ -132,8 +57,8 @@ class QueuingInterface(AgentInterface):

    def error(self, error: str):
        """Enqueue a special stop message"""
        self._queue_push(message_api={"internal_error": error}, message_obj=None)
        self._queue_push(message_api="STOP", message_obj=None)
        self.buffer.put({"internal_error": error})
        self.buffer.put("STOP")

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Handle reception of a user message"""
@ -159,7 +84,7 @@ class QueuingInterface(AgentInterface):
            assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
            new_message["date"] = msg_obj.created_at.isoformat()

        self._queue_push(message_api=new_message, message_obj=msg_obj)
        self.buffer.put(new_message)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """Handle the agent sending a message"""
@ -183,13 +108,11 @@ class QueuingInterface(AgentInterface):
            assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
            # TODO also should not be accessing protected member here

            new_message["id"] = self.buffer.queue[-1]["message_api"]["id"]
            new_message["id"] = self.buffer.queue[-1]["id"]
            # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
            new_message["date"] = self.buffer.queue[-1]["message_api"]["date"]
            new_message["date"] = self.buffer.queue[-1]["date"]

            msg_obj = self.buffer.queue[-1]["message_obj"]

        self._queue_push(message_api=new_message, message_obj=msg_obj)
        self.buffer.put(new_message)

    def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
        """Handle the agent calling a function"""
@ -229,7 +152,7 @@ class QueuingInterface(AgentInterface):
            assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
            new_message["date"] = msg_obj.created_at.isoformat()

        self._queue_push(message_api=new_message, message_obj=msg_obj)
        self.buffer.put(new_message)


class FunctionArgumentsStreamHandler:
@ -316,21 +239,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        # if multi_step = True, the stream ends when the agent yields
        # if multi_step = False, the stream ends when the step ends
        self.multi_step = multi_step
        self.multi_step_indicator = MessageStreamStatus.done_step
        self.multi_step_gen_indicator = MessageStreamStatus.done_generation
        self.multi_step_indicator = "[DONE_STEP]"
        self.multi_step_gen_indicator = "[DONE_GEN]"

        # extra prints
        self.debug = False
        self.timeout = 30

    async def _create_generator(self) -> AsyncGenerator[Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus], None]:
    async def _create_generator(self) -> AsyncGenerator:
        """An asynchronous generator that yields chunks as they become available."""
        while self._active:
            try:
                # Wait until there is an item in the deque or the stream is deactivated
                await asyncio.wait_for(self._event.wait(), timeout=self.timeout)  # 30 second timeout
            except asyncio.TimeoutError:
                break  # Exit the loop if we timeout
            # Wait until there is an item in the deque or the stream is deactivated
            await self._event.wait()

            while self._chunks:
                yield self._chunks.popleft()
@ -338,33 +254,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
            # Reset the event until a new item is pushed
            self._event.clear()

        # while self._active:
        #     # Wait until there is an item in the deque or the stream is deactivated
        #     await self._event.wait()

        #     while self._chunks:
        #         yield self._chunks.popleft()

        #     # Reset the event until a new item is pushed
        #     self._event.clear()

    def get_generator(self) -> AsyncGenerator:
        """Get the generator that yields processed chunks."""
        if not self._active:
            # If the stream is not active, don't return a generator that would produce values
            raise StopIteration("The stream has not been started or has been ended.")
        return self._create_generator()

    def _push_to_buffer(self, item: Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus]):
        """Add an item to the deque"""
        assert self._active, "Generator is inactive"
        # assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}"
        assert (
            isinstance(item, MemGPTMessage) or isinstance(item, LegacyMemGPTMessage) or isinstance(item, MessageStreamStatus)
        ), f"Wrong type: {type(item)}"
        self._chunks.append(item)
        self._event.set()  # Signal that new data is available

    def stream_start(self):
        """Initialize streaming by activating the generator and clearing any old chunks."""
        self.streaming_chat_completion_mode_function_name = None
@ -379,10 +268,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        self.streaming_chat_completion_mode_function_name = None

        if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
            self._push_to_buffer(self.multi_step_gen_indicator)

        # self._active = False
        # self._event.set()  # Unblock the generator if it's waiting to allow it to complete
            self._chunks.append(self.multi_step_gen_indicator)
            self._event.set()  # Signal that new data is available

        # if not self.multi_step:
        #     # end the stream
@ -393,27 +280,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        #     self._chunks.append(self.multi_step_indicator)
        #     self._event.set()  # Signal that new data is available

    def step_complete(self):
        """Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
        if not self.multi_step:
            # end the stream
            self._active = False
            self._event.set()  # Unblock the generator if it's waiting to allow it to complete
        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
            # signal that a new step has started in the stream
            self._push_to_buffer(self.multi_step_indicator)

    def step_yield(self):
        """If multi_step, this is the true 'stream_end' function."""
        # if self.multi_step:
        # end the stream
        self._active = False
        self._event.set()  # Unblock the generator if it's waiting to allow it to complete

    @staticmethod
    def clear():
        return

    def _process_chunk_to_memgpt_style(self, chunk: ChatCompletionChunkResponse) -> Optional[dict]:
        """
        Example data from non-streaming response looks like:
@ -539,7 +405,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        if msg_obj:
            processed_chunk["id"] = str(msg_obj.id)

        self._push_to_buffer(processed_chunk)
        self._chunks.append(processed_chunk)
        self._event.set()  # Signal that new data is available

    def get_generator(self) -> AsyncGenerator:
        """Get the generator that yields processed chunks."""
        if not self._active:
            # If the stream is not active, don't return a generator that would produce values
            raise StopIteration("The stream has not been started or has been ended.")
        return self._create_generator()

    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT receives a user message"""
@ -550,18 +424,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        if not self.streaming_mode:

            # create a fake "chunk" of a stream
            # processed_chunk = {
            #     "internal_monologue": msg,
            #     "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
            #     "id": str(msg_obj.id) if msg_obj is not None else None,
            # }
            processed_chunk = InternalMonologue(
                id=msg_obj.id,
                date=msg_obj.created_at,
                internal_monologue=msg,
            )
            processed_chunk = {
                "internal_monologue": msg,
                "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
                "id": str(msg_obj.id) if msg_obj is not None else None,
            }

            self._push_to_buffer(processed_chunk)
            self._chunks.append(processed_chunk)
            self._event.set()  # Signal that new data is available

            return

@ -603,56 +473,42 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
            #     "date": "2024-06-22T23:04:32.141923+00:00"
            # }
            try:
                func_args = json.loads(function_call.function.arguments)
                func_args = json.loads(function_call.function["arguments"])
            except:
                func_args = function_call.function.arguments
            # processed_chunk = {
            #     "function_call": f"{function_call.function.name}({func_args})",
            #     "id": str(msg_obj.id),
            #     "date": msg_obj.created_at.isoformat(),
            # }
            processed_chunk = LegacyFunctionCallMessage(
                id=msg_obj.id,
                date=msg_obj.created_at,
                function_call=f"{function_call.function.name}({func_args})",
            )
            self._push_to_buffer(processed_chunk)
                func_args = function_call.function["arguments"]
            processed_chunk = {
                "function_call": f"{function_call.function['name']}({func_args})",
                "id": str(msg_obj.id),
                "date": msg_obj.created_at.isoformat(),
            }
            self._chunks.append(processed_chunk)
            self._event.set()  # Signal that new data is available

            if function_call.function.name == "send_message":
            if function_call.function["name"] == "send_message":
                try:
                    # processed_chunk = {
                    #     "assistant_message": func_args["message"],
                    #     "id": str(msg_obj.id),
                    #     "date": msg_obj.created_at.isoformat(),
                    # }
                    processed_chunk = AssistantMessage(
                        id=msg_obj.id,
                        date=msg_obj.created_at,
                        assistant_message=func_args["message"],
                    )
                    self._push_to_buffer(processed_chunk)
                    processed_chunk = {
                        "assistant_message": func_args["message"],
                        "id": str(msg_obj.id),
                        "date": msg_obj.created_at.isoformat(),
                    }
                    self._chunks.append(processed_chunk)
                    self._event.set()  # Signal that new data is available
                except Exception as e:
                    print(f"Failed to parse function message: {e}")

            else:

                processed_chunk = FunctionCallMessage(
                    id=msg_obj.id,
                    date=msg_obj.created_at,
                    function_call=FunctionCall(
                        name=function_call.function.name,
                        arguments=function_call.function.arguments,
                    ),
                )
                # processed_chunk = {
                #     "function_call": {
                #         "name": function_call.function.name,
                #         "arguments": function_call.function.arguments,
                #     },
                #     "id": str(msg_obj.id),
                #     "date": msg_obj.created_at.isoformat(),
                # }
                self._push_to_buffer(processed_chunk)
                processed_chunk = {
                    "function_call": {
                        "id": function_call.id,
                        "name": function_call.function["name"],
                        "arguments": function_call.function["arguments"],
                    },
                    "id": str(msg_obj.id),
                    "date": msg_obj.created_at.isoformat(),
                }
                self._chunks.append(processed_chunk)
                self._event.set()  # Signal that new data is available

                return
        else:
@ -667,33 +523,43 @@ class StreamingServerInterface(AgentChunkStreamingInterface):

        elif msg.startswith("Success: "):
            msg = msg.replace("Success: ", "")
            # new_message = {"function_return": msg, "status": "success"}
            new_message = FunctionReturn(
                id=msg_obj.id,
                date=msg_obj.created_at,
                function_return=msg,
                status="success",
            )
            new_message = {"function_return": msg, "status": "success"}

        elif msg.startswith("Error: "):
            msg = msg.replace("Error: ", "")
            # new_message = {"function_return": msg, "status": "error"}
            new_message = FunctionReturn(
                id=msg_obj.id,
                date=msg_obj.created_at,
                function_return=msg,
                status="error",
            )
            new_message = {"function_return": msg, "status": "error"}

        else:
            # NOTE: generic, should not happen
            raise ValueError(msg)
            new_message = {"function_message": msg}

        # add extra metadata
        # if msg_obj is not None:
        #     new_message["id"] = str(msg_obj.id)
        #     assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
        #     new_message["date"] = msg_obj.created_at.isoformat()
        if msg_obj is not None:
            new_message["id"] = str(msg_obj.id)
            assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
            new_message["date"] = msg_obj.created_at.isoformat()

        self._push_to_buffer(new_message)
        self._chunks.append(new_message)
        self._event.set()  # Signal that new data is available

    def step_complete(self):
        """Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
        if not self.multi_step:
            # end the stream
            self._active = False
            self._event.set()  # Unblock the generator if it's waiting to allow it to complete
        elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
            # signal that a new step has started in the stream
            self._chunks.append(self.multi_step_indicator)
            self._event.set()  # Signal that new data is available

    def step_yield(self):
        """If multi_step, this is the true 'stream_end' function."""
        if self.multi_step:
            # end the stream
            self._active = False
            self._event.set()  # Unblock the generator if it's waiting to allow it to complete

    @staticmethod
    def clear():
        return
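
A rough sketch of how the reverted queue-based interface above is driven; the constructor signature and the manual buffer put are assumptions for illustration (in the server, the agent loop enqueues these chunks via user_message/function_message):

# a minimal sketch, not taken from this diff
import asyncio

from memgpt.server.rest_api.interface import QueuingInterface

interface = QueuingInterface(debug=False)  # assumed constructor signature

# the agent loop normally enqueues plain dict chunks like this
interface.buffer.put({"internal_monologue": "thinking...", "id": None, "date": None})
interface.step_yield()  # enqueues the "STOP" sentinel


async def drain():
    # message_generator yields raw chunks until it sees "STOP"
    async for chunk in interface.message_generator():
        print(chunk)


asyncio.run(drain())
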
@ -4,7 +4,7 @@ from typing import List
from fastapi import APIRouter
from pydantic import BaseModel, Field

from memgpt.schemas.llm_config import LLMConfig
from memgpt.models.pydantic_models import LLMConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -13,7 +13,7 @@ router = APIRouter()


class ListModelsResponse(BaseModel):
    models: List[LLMConfig] = Field(..., description="List of model configurations.")
    models: List[LLMConfigModel] = Field(..., description="List of model configurations.")


def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@ -25,7 +25,7 @@ def setup_models_index_router(server: SyncServer, interface: QueuingInterface, p
        interface.clear()

        # currently, the server only supports one model, however this may change in the future
        llm_config = LLMConfig(
        llm_config = LLMConfigModel(
            model=server.server_llm_config.model,
            model_endpoint=server.server_llm_config.model_endpoint,
            model_endpoint_type=server.server_llm_config.model_endpoint_type,

|
||||
from fastapi import APIRouter, Body, HTTPException, Path, Query
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.openai.openai import (
|
||||
from memgpt.data_types import Message
|
||||
from memgpt.models.openai import (
|
||||
AssistantFile,
|
||||
MessageFile,
|
||||
MessageRoleType,
|
||||
@ -138,6 +139,10 @@ class SubmitToolOutputsToRunRequest(BaseModel):
|
||||
|
||||
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
|
||||
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
|
||||
# TODO: remove this (when we have user auth)
|
||||
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
|
||||
print(f"User ID: {user_id}")
|
||||
|
||||
# create assistant (MemGPT agent)
|
||||
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
|
||||
def create_assistant(request: CreateAssistantRequest = Body(...)):
|
||||
|
@ -4,9 +4,9 @@ from functools import partial

from fastapi import APIRouter, Body, Depends, HTTPException

# from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
# from memgpt.data_types import Message
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.chat_completion_response import (
    ChatCompletionResponse,
    Choice,
    Message,

@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.schemas.block import Persona as PersonaModel  # TODO: modify
from memgpt.models.pydantic_models import PersonaModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -44,7 +44,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
        interface.clear()
        new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
        persona_id = new_persona.id
        server.ms.create_persona(new_persona)
        server.ms.add_persona(new_persona)
        return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)

    @router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
0
memgpt/server/rest_api/presets/__init__.py
Normal file
171
memgpt/server/rest_api/presets/index.py
Normal file
@ -0,0 +1,171 @@
import uuid
from functools import partial
from typing import Dict, List, Optional

from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import Preset  # TODO remove
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.prompts import gpt_system
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text, get_persona_text

router = APIRouter()

"""
Implement the following functions:
* List all available presets
* Create a new preset
* Delete a preset
* TODO update a preset
"""


class ListPresetsResponse(BaseModel):
    presets: List[PresetModel] = Field(..., description="List of available presets.")


class CreatePresetsRequest(BaseModel):
    # TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
    name: str = Field(..., description="The name of the preset.")
    id: Optional[str] = Field(None, description="The unique identifier of the preset.")
    # user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
    description: Optional[str] = Field(None, description="The description of the preset.")
    # created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
    system: Optional[str] = Field(None, description="The system prompt of the preset.")  # TODO: make optional and allow defaults
    persona: Optional[str] = Field(default=None, description="The persona of the preset.")
    human: Optional[str] = Field(default=None, description="The human of the preset.")
    functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
    # TODO
    persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
    human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
    system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")


class CreatePresetResponse(BaseModel):
    preset: PresetModel = Field(..., description="The newly created preset.")


def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel)
    async def get_preset(
        preset_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Get a preset."""
        try:
            preset = server.get_preset(user_id=user_id, preset_name=preset_name)
            return preset
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
    async def list_presets(
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """List all presets created by a user."""
        # Clear the interface
        interface.clear()

        try:
            presets = server.list_presets(user_id=user_id)
            return ListPresetsResponse(presets=presets)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/presets", tags=["presets"], response_model=CreatePresetResponse)
    async def create_preset(
        request: CreatePresetsRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Create a preset."""
        try:
            if isinstance(request.id, str):
                request.id = uuid.UUID(request.id)

            # check if preset already exists
            # TODO: move this into a server function to create a preset
            if server.ms.get_preset(name=request.name, user_id=user_id):
                raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.")

            # For system/human/persona - if {system/human/persona}_name is None but the text is provided, then create a new data entry
            if not request.system_name and request.system:
                # new system provided without name identity
                system_name = f"system_{request.name}_{str(uuid.uuid4())}"
                system = request.system
                # TODO: insert into system table
            else:
                system_name = request.system_name if request.system_name else DEFAULT_PRESET
                system = request.system if request.system else gpt_system.get_system_text(system_name)

            if not request.human_name and request.human:
                # new human provided without name identity
                human_name = f"human_{request.name}_{str(uuid.uuid4())}"
                human = request.human
                server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
            else:
                human_name = request.human_name if request.human_name else DEFAULT_HUMAN
                human = request.human if request.human else get_human_text(human_name)

            if not request.persona_name and request.persona:
                # new persona provided without name identity
                persona_name = f"persona_{request.name}_{str(uuid.uuid4())}"
                persona = request.persona
                server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
            else:
                persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
                persona = request.persona if request.persona else get_persona_text(persona_name)

            # create preset
            new_preset = Preset(
                user_id=user_id,
                id=request.id if request.id else uuid.uuid4(),
                name=request.name,
                description=request.description,
                system=system,
                persona=persona,
                persona_name=persona_name,
                human=human,
                human_name=human_name,
                functions_schema=request.functions_schema,
            )
            preset = server.create_preset(preset=new_preset)

            # TODO remove once we migrate from Preset to PresetModel
            preset = PresetModel(**vars(preset))

            return CreatePresetResponse(preset=preset)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.delete("/presets/{preset_id}", tags=["presets"])
    async def delete_preset(
        preset_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """Delete a preset."""
        interface.clear()
        try:
            preset = server.delete_preset(user_id=user_id, preset_id=preset_id)
            return JSONResponse(
                status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"}
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    return router
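
A minimal sketch of creating a preset through the router above; the URL, token, and the one-entry functions_schema stub are assumptions for illustration:

import requests

payload = {
    "name": "my_preset",
    "description": "example preset",
    "human": "Name: Chad",  # no human_name, so the server stores it under a generated name
    "persona": "I am a helpful assistant.",
    "functions_schema": [{"name": "send_message"}],  # hypothetical minimal schema
}
resp = requests.post(
    "http://localhost:8283/api/presets",  # hypothetical base URL
    json=payload,
    headers={"Authorization": "Bearer test_server_token"},  # hypothetical token
)
print(resp.json()["preset"])  # CreatePresetResponse.preset
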
@ -15,12 +15,14 @@ from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
from memgpt.server.rest_api.admin.tools import setup_tools_index_router
from memgpt.server.rest_api.admin.users import setup_admin_router
from memgpt.server.rest_api.agents.command import setup_agents_command_router
from memgpt.server.rest_api.agents.config import setup_agents_config_router
from memgpt.server.rest_api.agents.index import setup_agents_index_router
from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
from memgpt.server.rest_api.agents.message import setup_agents_message_router
from memgpt.server.rest_api.auth.index import setup_auth_router
from memgpt.server.rest_api.block.index import setup_block_index_router
from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import (
@ -29,6 +31,8 @@ from memgpt.server.rest_api.openai_assistants.assistants import (
from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
    setup_openai_chat_completions_router,
)
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.presets.index import setup_presets_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
@ -91,13 +95,17 @@ app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PRE
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])

# /api/agents endpoints
app.include_router(setup_agents_command_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_config_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)

# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
@ -145,8 +153,7 @@ def on_startup():
@app.on_event("shutdown")
def on_shutdown():
    global server
    if server:
        server.save_agents()
    server.save_agents()
    server = None


@ -1,7 +1,8 @@
import os
import tempfile
import uuid
from functools import partial
from typing import List
from typing import List, Optional

from fastapi import (
    APIRouter,
@ -11,14 +12,20 @@ from fastapi import (
    HTTPException,
    Query,
    UploadFile,
    status,
)
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from memgpt.schemas.document import Document
from memgpt.schemas.job import Job
from memgpt.schemas.passage import Passage

# schemas
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile
from memgpt.data_sources.connectors import DirectoryConnector
from memgpt.data_types import Source
from memgpt.models.pydantic_models import (
    DocumentModel,
    JobModel,
    JobStatus,
    PassageModel,
    SourceModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -37,73 +44,77 @@ Implement the following functions:
"""


# class ListSourcesResponse(BaseModel):
#     sources: List[SourceModel] = Field(..., description="List of available sources.")
#
#
# class CreateSourceRequest(BaseModel):
#     name: str = Field(..., description="The name of the source.")
#     description: Optional[str] = Field(None, description="The description of the source.")
#
#
# class UploadFileToSourceRequest(BaseModel):
#     file: UploadFile = Field(..., description="The file to upload.")
#
#
# class UploadFileToSourceResponse(BaseModel):
#     source: SourceModel = Field(..., description="The source the file was uploaded to.")
#     added_passages: int = Field(..., description="The number of passages added to the source.")
#     added_documents: int = Field(..., description="The number of documents added to the source.")
#
#
# class GetSourcePassagesResponse(BaseModel):
#     passages: List[PassageModel] = Field(..., description="List of passages from the source.")
#
#
# class GetSourceDocumentsResponse(BaseModel):
#     documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
class ListSourcesResponse(BaseModel):
    sources: List[SourceModel] = Field(..., description="List of available sources.")


def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    # write the file to a temporary directory (deleted after the context manager exits)
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_path = os.path.join(tmpdirname, file.filename)
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)
class CreateSourceRequest(BaseModel):
    name: str = Field(..., description="The name of the source.")
    description: Optional[str] = Field(None, description="The description of the source.")

        server.load_file_to_source(source_id, file_path, job_id)

class UploadFileToSourceRequest(BaseModel):
    file: UploadFile = Field(..., description="The file to upload.")


class UploadFileToSourceResponse(BaseModel):
    source: SourceModel = Field(..., description="The source the file was uploaded to.")
    added_passages: int = Field(..., description="The number of passages added to the source.")
    added_documents: int = Field(..., description="The number of documents added to the source.")


class GetSourcePassagesResponse(BaseModel):
    passages: List[PassageModel] = Field(..., description="List of passages from the source.")


class GetSourceDocumentsResponse(BaseModel):
    documents: List[DocumentModel] = Field(..., description="List of documents from the source.")


def load_file_to_source(server: SyncServer, user_id: uuid.UUID, source: Source, job_id: uuid.UUID, file: UploadFile, bytes: bytes):
    # update job status
    job = server.ms.get_job(job_id=job_id)
    job.status = JobStatus.running
    server.ms.update_job(job)

    try:
        # write the file to a temporary directory (deleted after the context manager exits)
        with tempfile.TemporaryDirectory() as tmpdirname:
            file_path = os.path.join(tmpdirname, file.filename)
            with open(file_path, "wb") as buffer:
                buffer.write(bytes)

            # read the file
            connector = DirectoryConnector(input_files=[file_path])

            # TODO: pre-compute total number of passages?

            # load the data into the source via the connector
            num_passages, num_documents = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
    except Exception as e:
        # job failed with error
        error = str(e)
        print(error)
        job.status = JobStatus.failed
        job.metadata_["error"] = error
        server.ms.update_job(job)
        # TODO: delete any associated passages/documents?
        return 0, 0

    # update job status
    job.status = JobStatus.completed
    job.metadata_["num_passages"] = num_passages
    job.metadata_["num_documents"] = num_documents
    print("job completed", job.metadata_, job.id)
    server.ms.update_job(job)


def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.get("/sources/{source_id}", tags=["sources"], response_model=Source)
    async def get_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get all sources
        """
        interface.clear()
        source = server.get_source(source_id=source_id, user_id=user_id)
        return source

    @router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
    async def get_source_id_by_name(
        source_name: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a source by name
        """
        interface.clear()
        source = server.get_source_id(source_name=source_name, user_id=user_id)
        return source

    @router.get("/sources", tags=["sources"], response_model=List[Source])
    @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
    async def list_sources(
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        List all data sources created by a user.
@ -113,40 +124,58 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,

        try:
            sources = server.list_all_sources(user_id=user_id)
            return sources
            return ListSourcesResponse(sources=sources)
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources", tags=["sources"], response_model=Source)
    @router.post("/sources", tags=["sources"], response_model=SourceModel)
    async def create_source(
        request: SourceCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        request: CreateSourceRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Create a new data source.
        """
        interface.clear()
        try:
            return server.create_source(request=request, user_id=user_id)
            # TODO: don't use Source and just use SourceModel once pydantic migration is complete
            source = server.create_source(name=request.name, user_id=user_id, description=request.description)
            return SourceModel(
                name=source.name,
                description=source.description,
                user_id=source.user_id,
                id=source.id,
                embedding_config=server.server_embedding_config,
                created_at=source.created_at.timestamp(),
            )
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}", tags=["sources"], response_model=Source)
    @router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel)
    async def update_source(
        source_id: str,
        request: SourceUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        request: CreateSourceRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Update the name or documentation of an existing data source.
        """
        interface.clear()
        try:
            return server.update_source(request=request, user_id=user_id)
            # TODO: don't use Source and just use SourceModel once pydantic migration is complete
            source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description)
            return SourceModel(
                name=source.name,
                description=source.description,
                user_id=source.user_id,
                id=source.id,
                embedding_config=server.server_embedding_config,
                created_at=source.created_at.timestamp(),
            )
        except HTTPException:
            raise
        except Exception as e:
@ -154,8 +183,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,

    @router.delete("/sources/{source_id}", tags=["sources"])
    async def delete_source(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Delete a data source.
@ -163,58 +192,66 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
        interface.clear()
        try:
            server.delete_source(source_id=source_id, user_id=user_id)
            return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"{e}")

    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source)
    @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
    async def attach_source_to_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Attach a data source to an existing agent.
        """
        interface.clear()
        assert isinstance(agent_id, str), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, str), f"Expected user_id to be a UUID, got {user_id}"
        assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
        assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
        source = server.ms.get_source(source_id=source_id, user_id=user_id)
        source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id)
        return source
        source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id)
        return SourceModel(
            name=source.name,
            description=None,  # TODO: actually store descriptions
            user_id=source.user_id,
            id=source.id,
            embedding_config=server.server_embedding_config,
            created_at=source.created_at,
        )

    @router.post("/sources/{source_id}/detach", tags=["sources"])
    @router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
    async def detach_source_from_agent(
        source_id: str,
        agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Detach a data source from an existing agent.
        """
        server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)

    @router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job)
    async def get_job(
        job_id: str,
        user_id: str = Depends(get_current_user_with_server),
    @router.get("/sources/status/{job_id}", tags=["sources"], response_model=JobModel)
    async def get_job_status(
        job_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Get the status of a job.
        """
        job = server.get_job(job_id=job_id)
        job = server.ms.get_job(job_id=job_id)
        if job is None:
            raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
        return job

    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job)
    @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=JobModel)
    async def upload_file_to_source(
        # file: UploadFile = UploadFile(..., description="The file to upload."),
        file: UploadFile,
        source_id: str,
        source_id: uuid.UUID,
        background_tasks: BackgroundTasks,
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Upload a file to a data source.
@ -224,39 +261,37 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
        bytes = file.file.read()

        # create job
        # TODO: create server function
        job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id})
        job = JobModel(user_id=user_id, metadata={"type": "embedding", "filename": file.filename, "source_id": source_id})
        job_id = job.id
        server.ms.create_job(job)

        # create background task
        background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)
        background_tasks.add_task(load_file_to_source, server, user_id, source, job_id, file, bytes)

        # return job information
        job = server.ms.get_job(job_id=job_id)
        return job

    @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage])
    @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
    async def list_passages(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        List all passages associated with a data source.
        """
        # TODO: check if paginated?
        passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
        return passages
        return GetSourcePassagesResponse(passages=passages)

    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document])
    @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
    async def list_documents(
        source_id: str,
        user_id: str = Depends(get_current_user_with_server),
        source_id: uuid.UUID,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        List all documents associated with a data source.
        """
        documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
        return documents
        return GetSourceDocumentsResponse(documents=documents)

    return router
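
Since uploads now run in a background task that updates a JobModel, a client polls the job rather than blocking on the request; a rough sketch (URL, token, and source id are placeholders):

import time

import requests

BASE = "http://localhost:8283/api"  # hypothetical base URL
HEADERS = {"Authorization": "Bearer test_server_token"}  # hypothetical token
SOURCE_ID = "..."  # an existing source's UUID

with open("doc.txt", "rb") as f:
    job = requests.post(
        f"{BASE}/sources/{SOURCE_ID}/upload",
        files={"file": f},
        headers=HEADERS,
    ).json()

# poll until the background embedding job finishes
while job["status"] not in ("completed", "failed"):
    time.sleep(1)
    job = requests.get(f"{BASE}/sources/status/{job['id']}", headers=HEADERS).json()
print(job["metadata"])  # num_passages / num_documents on success, error on failure
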
@ -1,9 +1,11 @@
import uuid
from functools import partial
from typing import List
from typing import List, Literal, Optional

from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field

from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -11,92 +13,121 @@ from memgpt.server.server import SyncServer
|
||||
router = APIRouter()


class ListToolsResponse(BaseModel):
    tools: List[ToolModel] = Field(..., description="List of tools (functions).")


class CreateToolRequest(BaseModel):
    json_schema: dict = Field(..., description="JSON schema of the tool.")  # NOT OpenAI - just has `name`
    source_code: str = Field(..., description="The source code of the function.")
    source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
    tags: Optional[List[str]] = Field(None, description="Metadata tags.")
    update: Optional[bool] = Field(False, description="Update the tool if it already exists.")


class CreateToolResponse(BaseModel):
    tool: ToolModel = Field(..., description="Information about the newly created tool.")


def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    get_current_user_with_server = partial(partial(get_current_user, server), password)

    @router.delete("/tools/{tool_id}", tags=["tools"])
    @router.delete("/tools/{tool_name}", tags=["tools"])
    async def delete_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
        tool_name: str,
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Delete a tool by name
        """
        # Clear the interface
        interface.clear()
        server.delete_tool(id)
        # tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
        server.ms.delete_tool(name=tool_name, user_id=user_id)

    @router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    @router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
    async def get_tool(
        tool_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Get a tool by name
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool(tool_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
        return tool

    @router.get("/tools/name/{tool_name}", tags=["tools"], response_model=str)
    async def get_tool_id(
        tool_name: str,
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Get a tool by name
        """
        # Clear the interface
        interface.clear()
        tool = server.get_tool_id(tool_name, user_id=user_id)
        tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
        if tool is None:
            # return 404 error
            raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
        return tool

    @router.get("/tools", tags=["tools"], response_model=List[Tool])
    @router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
    async def list_all_tools(
        user_id: str = Depends(get_current_user_with_server),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Get a list of all tools available to agents created by a user
        """
        # Clear the interface
        interface.clear()
        return server.list_tools(user_id)
        tools = server.ms.list_tools(user_id=user_id)
        return ListToolsResponse(tools=tools)

    @router.post("/tools", tags=["tools"], response_model=Tool)
    @router.post("/tools", tags=["tools"], response_model=ToolModel)
    async def create_tool(
        request: ToolCreate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
        request: CreateToolRequest = Body(...),
        user_id: uuid.UUID = Depends(get_current_user_with_server),
    ):
        """
        Create a new tool
        """
        try:
            return server.create_tool(request, user_id=user_id)
        except Exception as e:
            print(e)
            raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
        # NOTE: horrifying code, should be replaced when we migrate dev portal
        from memgpt.agent import Agent  # nasty: need agent to be defined
        from memgpt.functions.schema_generator import generate_schema

        name = request.json_schema["name"]

        import ast

        parsed_code = ast.parse(request.source_code)
        function_names = []

        # Function to find and print function names
        def find_function_names(node):
            for child in ast.iter_child_nodes(node):
                if isinstance(child, ast.FunctionDef):
                    # Print the name of the function
                    function_names.append(child.name)
                # Recurse into child nodes
                find_function_names(child)

        # Find and print function names
        find_function_names(parsed_code)
        assert len(function_names) == 1, f"Expected 1 function, found {len(function_names)}: {function_names}"

        # generate JSON schema
        env = {}
        env.update(globals())
        exec(request.source_code, env)
        func = env.get(function_names[0])
        json_schema = generate_schema(func, name=name)
        from pprint import pprint

        pprint(json_schema)

    @router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool)
    async def update_tool(
        tool_id: str,
        request: ToolUpdate = Body(...),
        user_id: str = Depends(get_current_user_with_server),
    ):
        """
        Update an existing tool
        """
        try:
            # TODO: check that the user has access to this tool?
            return server.update_tool(request)

            return server.create_tool(
                # json_schema=request.json_schema,  # TODO: add back
                json_schema=json_schema,
                source_code=request.source_code,
                source_type=request.source_type,
                tags=request.tags,
                user_id=user_id,
                exists_ok=request.update,
            )
        except Exception as e:
            print(e)
            raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}")
            raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}")

    return router
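For reference, a tool-creation call against this router might look like the sketch below. This is illustrative only: the host, route prefix, and bearer token are placeholders, and the exact response shape depends on which side of this diff is deployed.

import httpx

# Placeholder values - adjust the base URL, prefix, and API key for your deployment.
payload = {
    "json_schema": {"name": "roll_d20"},  # not OpenAI format, just needs `name`
    "source_code": "def roll_d20() -> int:\n    import random\n    return random.randint(1, 20)\n",
    "source_type": "python",
    "tags": ["example"],
    "update": False,
}
response = httpx.post(
    "http://localhost:8283/tools",
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},
)
print(response.status_code, response.json())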
@ -1,14 +1,10 @@
import asyncio
import json
import traceback
from enum import Enum
from typing import AsyncGenerator, Union

from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Union

from memgpt.constants import JSON_ENSURE_ASCII

SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
SSE_FINISH_MSG = "[DONE]"  # mimic openai
SSE_ARTIFICIAL_DELAY = 0.1

@ -17,7 +13,27 @@ def sse_formatter(data: Union[dict, str]) -> str:
    """Prefix with 'data: ', and always include double newlines"""
    assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
    data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data
    return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
    return f"data: {data_str}\n\n"


async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]:
    """Generator that returns 'data: dict' formatted items, e.g.:

    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}

    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

    data: [DONE]

    """
    try:
        for msg in generator:
            yield sse_formatter(msg)
            if SSE_ARTIFICIAL_DELAY:
                await asyncio.sleep(SSE_ARTIFICIAL_DELAY)  # Sleep to prevent a tight loop, adjust time as needed
    except Exception as e:
        yield sse_formatter({"error": f"{str(e)}"})
    yield sse_formatter(SSE_FINISH_MSG)  # Signal that the stream is complete


async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
@ -33,20 +49,12 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
    try:
        async for chunk in generator:
            # yield f"data: {json.dumps(chunk)}\n\n"
            if isinstance(chunk, BaseModel):
                chunk = chunk.model_dump()
            elif isinstance(chunk, Enum):
                chunk = str(chunk.value)
            elif not isinstance(chunk, dict):
                chunk = str(chunk)
            yield sse_formatter(chunk)

    except Exception as e:
        print("stream decoder hit error:", e)
        print(traceback.print_stack())
        yield sse_formatter({"error": "stream decoder encountered an error"})

    finally:
        # yield "data: [DONE]\n\n"
        if finish_message:
            # Signal that the stream is complete
            yield sse_formatter(SSE_FINISH_MSG)
        yield sse_formatter(SSE_FINISH_MSG)  # Signal that the stream is complete
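As a quick illustration of the wire format these helpers emit (derived directly from sse_formatter above, assuming JSON_ENSURE_ASCII leaves plain ASCII untouched):

sse_formatter({"function_call": "send_message"})
# -> 'data: {"function_call": "send_message"}\n\n'
sse_formatter(SSE_FINISH_MSG)
# -> 'data: [DONE]\n\n'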
File diff suppressed because it is too large
@ -43,12 +43,5 @@ class Settings(BaseSettings):
        return None


class TestSettings(Settings):
    model_config = SettingsConfigDict(env_prefix="memgpt_test_")

    memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt/test", env="MEMGPT_TEST_DIR")


# singleton
settings = Settings()
test_settings = TestSettings()
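A note on the env_prefix above: with env_prefix="memgpt_test_", TestSettings resolves its fields from MEMGPT_TEST_*-prefixed environment variables, separately from the main Settings object. A minimal sketch (the variable name follows the explicit env="MEMGPT_TEST_DIR" override shown above; pydantic-settings matching is case-insensitive, and the import location is assumed):

import os

os.environ["MEMGPT_TEST_DIR"] = "/tmp/memgpt_test"  # hypothetical path

from memgpt.settings import test_settings  # assumed import location

print(test_settings.memgpt_dir)  # expected: /tmp/memgpt_test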
@ -7,9 +7,9 @@ from rich.console import Console
from rich.live import Live
from rich.markup import escape

from memgpt.data_types import Message
from memgpt.interface import CLIInterface
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
    ChatCompletionChunkResponse,
    ChatCompletionResponse,
)
@ -33,8 +33,8 @@ from memgpt.constants import (
    MEMGPT_DIR,
    TOOL_CALL_ID_MAX_LEN,
)
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.openai_backcompat.openai_object import OpenAIObject
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse

DEBUG = False
if "LOG_LEVEL" in os.environ:

@ -468,17 +468,6 @@ NOUN_BANK = [
]


def deduplicate(target_list: list) -> list:
    seen = set()
    dedup_list = []
    for i in target_list:
        if i not in seen:
            seen.add(i)
            dedup_list.append(i)

    return dedup_list


def smart_urljoin(base_url: str, relative_url: str) -> str:
    """urljoin is stupid and wants a trailing / at the end of the endpoint address, or it will chop the suffix off"""
    if not base_url.endswith("/"):
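The deduplicate helper removed in this hunk kept the first occurrence of each item while preserving input order, e.g.:

deduplicate([3, 1, 3, 2, 1])
# -> [3, 1, 2]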
4360
poetry.lock
generated
File diff suppressed because it is too large
@ -28,13 +28,14 @@ pgvector = { version = "^0.2.3", optional = true }
|
||||
pre-commit = {version = "^3.5.0", optional = true }
|
||||
pg8000 = {version = "^1.30.3", optional = true}
|
||||
websockets = {version = "^12.0", optional = true}
|
||||
docstring-parser = ">=0.16,<0.17"
|
||||
docstring-parser = "^0.15"
|
||||
lancedb = "^0.3.3"
|
||||
httpx = "^0.25.2"
|
||||
numpy = "^1.26.2"
|
||||
demjson3 = "^3.0.6"
|
||||
#tiktoken = ">=0.7.0,<0.8.0"
|
||||
tiktoken = "^0.5.1"
|
||||
pyyaml = "^6.0.1"
|
||||
chromadb = ">=0.4.24,<0.5.0"
|
||||
chromadb = "^0.5.0"
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
@ -50,7 +51,7 @@ pymilvus = {version ="^2.4.3", optional = true}
|
||||
python-box = "^7.1.1"
|
||||
sqlmodel = "^0.0.16"
|
||||
autoflake = {version = "^2.3.0", optional = true}
|
||||
llama-index = "^0.10.65"
|
||||
llama-index = "^0.10.27"
|
||||
llama-index-embeddings-openai = "^0.1.1"
|
||||
llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
|
||||
llama-index-embeddings-azure-openai = "^0.1.6"
|
||||
@ -58,15 +59,12 @@ python-multipart = "^0.0.9"
|
||||
sqlalchemy-utils = "^0.41.2"
|
||||
pytest-order = {version = "^1.2.0", optional = true}
|
||||
pytest-asyncio = {version = "^0.23.2", optional = true}
|
||||
pytest = { version = "^7.4.4", optional = true }
|
||||
pydantic-settings = "^2.2.1"
|
||||
httpx-sse = "^0.4.0"
|
||||
isort = { version = "^5.13.2", optional = true }
|
||||
llama-index-embeddings-ollama = {version = "^0.1.2", optional = true}
|
||||
crewai = {version = "^0.41.1", optional = true}
|
||||
crewai-tools = {version = "^0.8.3", optional = true}
|
||||
docker = {version = "^7.1.0", optional = true}
|
||||
tiktoken = "^0.7.0"
|
||||
nltk = "^3.8.1"
|
||||
protobuf = "3.20.0"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["llama-index-embeddings-huggingface"]
|
||||
@ -77,7 +75,6 @@ server = ["websockets", "fastapi", "uvicorn"]
|
||||
autogen = ["pyautogen"]
|
||||
qdrant = ["qdrant-client"]
|
||||
ollama = ["llama-index-embeddings-ollama"]
|
||||
crewai-tools = ["crewai", "docker", "crewai-tools"]
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
black = "^24.4.2"
|
||||
|
@ -1,3 +1,3 @@
# from tests.config import TestMGPTConfig
#
# TEST_MEMGPT_CONFIG = TestMGPTConfig()
from tests.config import TestMGPTConfig

TEST_MEMGPT_CONFIG = TestMGPTConfig()
@ -3,4 +3,4 @@ pythonpath = /memgpt
testpaths = /tests
asyncio_mode = auto
filterwarnings =
    ignore::pytest.PytestRemovedIn9Warning
    ignore::pytest.PytestRemovedIn8Warning
@ -1,9 +1,11 @@
import threading
import time
import uuid

import pytest

from memgpt import Admin
from tests.test_client import _reset_config, run_server

test_base_url = "http://localhost:8283"

@ -11,13 +13,6 @@ test_base_url = "http://localhost:8283"
test_server_token = "test_server_token"


def run_server():
    from memgpt.server.rest_api.server import start_server

    print("Starting server...")
    start_server(debug=True)


@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
    """Starts Uvicorn server in a background thread."""

@ -39,85 +34,91 @@ def admin_client():


def test_admin_client(admin_client):
    _reset_config()

    # create a user
    user_name = "test_user"
    user1 = admin_client.create_user(user_name)
    assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
    user_id = uuid.uuid4()
    create_user1_response = admin_client.create_user(user_id)
    assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}"

    # create another user
    user2 = admin_client.create_user()
    create_user_2_response = admin_client.create_user()

    # create keys
    key1_name = "test_key1"
    key2_name = "test_key2"
    api_key1 = admin_client.create_key(user1.id, key1_name)
    admin_client.create_key(user2.id, key2_name)
    create_key1_response = admin_client.create_key(user_id, key1_name)
    create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name)

    # list users
    users = admin_client.get_users()
    assert len(users) == 2
    assert user1.id in [user.id for user in users]
    assert user2.id in [user.id for user in users]
    assert len(users.user_list) == 2
    print(users.user_list)
    assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list]

    # list keys
    user1_keys = admin_client.get_keys(user1.id)
    assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}"
    assert api_key1.key == user1_keys[0].key
    user1_keys = admin_client.get_keys(user_id)
    assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}"
    assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}"
    assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}"

    # delete key
    deleted_key1 = admin_client.delete_key(api_key1.key)
    assert deleted_key1.key == api_key1.key
    assert len(admin_client.get_keys(user1.id)) == 0
    delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
    assert delete_key1_response.api_key_deleted == create_key1_response.api_key
    assert len(admin_client.get_keys(user_id)) == 1
    delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
    assert delete_key2_response.api_key_deleted == create_key2_response.api_key
    assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1

    # delete users
    deleted_user1 = admin_client.delete_user(user1.id)
    assert deleted_user1.id == user1.id
    deleted_user2 = admin_client.delete_user(user2.id)
    assert deleted_user2.id == user2.id
    delete_user1_response = admin_client.delete_user(user_id)
    assert delete_user1_response.user_id_deleted == user_id
    delete_user2_response = admin_client.delete_user(create_user_2_response.user_id)
    assert delete_user2_response.user_id_deleted == create_user_2_response.user_id

    # list users
    users = admin_client.get_users()
    assert len(users) == 0, f"Expected 0 users, got {users}"
    assert len(users.user_list) == 0, f"Expected 0 users, got {users}"


# def test_get_users_pagination(admin_client):
#
#     page_size = 5
#     num_users = 7
#     expected_users_remainder = num_users - page_size
#
#     # create users
#     all_user_ids = []
#     for i in range(num_users):
#
#         user_id = uuid.uuid4()
#         all_user_ids.append(user_id)
#         key_name = "test_key" + f"{i}"
#
#         create_user_response = admin_client.create_user(user_id)
#         admin_client.create_key(create_user_response.user_id, key_name)
#
#     # list users in page 1
#     get_all_users_response1 = admin_client.get_users(limit=page_size)
#     cursor1 = get_all_users_response1.cursor
#     user_list1 = get_all_users_response1.user_list
#     assert len(user_list1) == page_size
#
#     # list users in page 2 using cursor
#     get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
#     cursor2 = get_all_users_response2.cursor
#     user_list2 = get_all_users_response2.user_list
#
#     assert len(user_list2) == expected_users_remainder
#     assert cursor1 != cursor2
#
#     # delete users
#     clean_up_users_and_keys(all_user_ids)
#
#     # list users to check pagination with no users
#     users = admin_client.get_users()
#     assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
def test_get_users_pagination(admin_client):
    _reset_config()

    page_size = 5
    num_users = 7
    expected_users_remainder = num_users - page_size

    # create users
    all_user_ids = []
    for i in range(num_users):

        user_id = uuid.uuid4()
        all_user_ids.append(user_id)
        key_name = "test_key" + f"{i}"

        create_user_response = admin_client.create_user(user_id)
        admin_client.create_key(create_user_response.user_id, key_name)

    # list users in page 1
    get_all_users_response1 = admin_client.get_users(limit=page_size)
    cursor1 = get_all_users_response1.cursor
    user_list1 = get_all_users_response1.user_list
    assert len(user_list1) == page_size

    # list users in page 2 using cursor
    get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
    cursor2 = get_all_users_response2.cursor
    user_list2 = get_all_users_response2.user_list

    assert len(user_list2) == expected_users_remainder
    assert cursor1 != cursor2

    # delete users
    clean_up_users_and_keys(all_user_ids)

    # list users to check pagination with no users
    users = admin_client.get_users()
    assert len(users.user_list) == 0, f"Expected 0 users, got {users}"


def clean_up_users_and_keys(user_id_list):
@ -1,41 +1,38 @@
# TODO: add back
import os
import subprocess

# import os
# import subprocess
#
# import pytest
#
#
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
# def test_agent_groupchat():
#
#     # Define the path to the script you want to test
#     script_path = "memgpt/autogen/examples/agent_groupchat.py"
#
#     # Dynamically get the project's root directory (assuming this script is run from the root)
#     # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
#     # print(project_root)
#     # project_root = os.path.join(project_root, "MemGPT")
#     # print(project_root)
#     # sys.exit(1)
#
#     project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
#     project_root = os.path.join(project_root, "memgpt")
#     print(f"Adding the following to PATH: {project_root}")
#
#     # Prepare the environment, adding the project root to PYTHONPATH
#     env = os.environ.copy()
#     env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
#
#     # Run the script using subprocess.run
#     # Capture the output (stdout) and the exit code
#     # result = subprocess.run(["python", script_path], capture_output=True, text=True)
#     result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
#
#     # Check the exit code (0 indicates success)
#     assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
#
#     # Optionally, check the output for expected content
#     # For example, if you expect a specific line in the output, uncomment and adapt the following line:
#     # assert "expected output" in result.stdout, "Expected output not found in script's output"
#
import pytest


@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
def test_agent_groupchat():

    # Define the path to the script you want to test
    script_path = "memgpt/autogen/examples/agent_groupchat.py"

    # Dynamically get the project's root directory (assuming this script is run from the root)
    # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    # print(project_root)
    # project_root = os.path.join(project_root, "MemGPT")
    # print(project_root)
    # sys.exit(1)

    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    project_root = os.path.join(project_root, "memgpt")
    print(f"Adding the following to PATH: {project_root}")

    # Prepare the environment, adding the project root to PYTHONPATH
    env = os.environ.copy()
    env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"

    # Run the script using subprocess.run
    # Capture the output (stdout) and the exit code
    # result = subprocess.run(["python", script_path], capture_output=True, text=True)
    result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)

    # Check the exit code (0 indicates success)
    assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"

    # Optionally, check the output for expected content
    # For example, if you expect a specific line in the output, uncomment and adapt the following line:
    # assert "expected output" in result.stdout, "Expected output not found in script's output"
@ -26,7 +26,7 @@ def agent_obj():
    agent_state = client.create_agent()

    global agent_obj
    agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
    agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
    yield agent_obj

    client.delete_agent(agent_obj.agent_state.id)
@ -1,12 +1,17 @@
import os
import subprocess
import sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])
import pexpect
from prettytable.colortable import ColorTable

from memgpt.cli.cli_config import ListChoice, add, delete
from memgpt.cli.cli_config import list as list_command

from .constants import TIMEOUT
from .utils import create_config

# def test_configure_memgpt():
#     configure_memgpt()

@ -42,3 +47,41 @@ def test_cli_config():
    assert "test data" in row
    # delete
    delete(option=option, name="test")


def test_save_load():
    # configure_memgpt()  # rely on configure running first^
    if os.getenv("OPENAI_API_KEY"):
        create_config("openai")
    else:
        create_config("memgpt_hosted")

    child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")

    child.expect("Enter your message:", timeout=TIMEOUT)
    child.sendline()

    child.expect("Empty input received. Try again!", timeout=TIMEOUT)
    child.sendline("/save")

    child.expect("Enter your message:", timeout=TIMEOUT)
    child.sendline("/exit")

    child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
    child.close()
    assert child.isalive() is False, "CLI should have terminated."
    assert child.exitstatus == 0, "CLI did not exit cleanly."

    child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")
    child.expect("Using existing agent test_save_load", timeout=TIMEOUT)
    child.expect("Enter your message:", timeout=TIMEOUT)
    child.sendline("/exit")
    child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
    child.close()
    assert child.isalive() is False, "CLI should have terminated."
    assert child.exitstatus == 0, "CLI did not exit cleanly."


if __name__ == "__main__":
    # test_configure_memgpt()
    test_save_load()
@ -7,11 +7,12 @@ import pytest
from dotenv import load_dotenv

from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics

# from tests.utils import create_config
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Preset  # TODO move to PresetModel
from memgpt.settings import settings
from tests.utils import create_config

test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
@ -20,16 +21,44 @@ test_agent_state = None
client = None

test_agent_state_post_message = None
test_user_id = uuid.uuid4()


# admin credentials
test_server_token = "test_server_token"


def _reset_config():
    # Use os.getenv with a fallback to os.environ.get
    db_url = settings.memgpt_pg_uri

    if os.getenv("OPENAI_API_KEY"):
        create_config("openai")
        credentials = MemGPTCredentials(
            openai_key=os.getenv("OPENAI_API_KEY"),
        )
    else:  # hosted
        create_config("memgpt_hosted")
        credentials = MemGPTCredentials()

    config = MemGPTConfig.load()

    # set to use postgres
    config.archival_storage_uri = db_url
    config.recall_storage_uri = db_url
    config.metadata_storage_uri = db_url
    config.archival_storage_type = "postgres"
    config.recall_storage_type = "postgres"
    config.metadata_storage_type = "postgres"

    config.save()
    credentials.save()
    print("_reset_config :: ", config.config_path)


def run_server():
    load_dotenv()

    # _reset_config()
    _reset_config()

    from memgpt.server.rest_api.server import start_server

@ -39,8 +68,7 @@ def run_server():

# Fixture to create clients with different configurations
@pytest.fixture(
    # params=[{"server": True}, {"server": False}],  # whether to use REST API server
    params=[{"server": True}],  # whether to use REST API server
    params=[{"server": True}, {"server": False}],  # whether to use REST API server
    scope="module",
)
def client(request):
@ -58,20 +86,21 @@ def client(request):
        print("Running client tests with server:", server_url)
        # create user via admin client
        admin = Admin(server_url, test_server_token)
        user = admin.create_user()  # Adjust as per your client's method
        api_key = admin.create_key(user.id)
        client = create_client(base_url=server_url, token=api_key.key)  # This yields control back to the test function
        response = admin.create_user(test_user_id)  # Adjust as per your client's method
        token = response.api_key

    else:
        # use local client (no server)
        token = None
        server_url = None
        client = create_client()

    client = create_client(base_url=server_url, token=token)  # This yields control back to the test function
    try:
        yield client
    finally:
        # cleanup user
        if server_url:
            admin.delete_user(user.id)
            admin.delete_user(test_user_id)  # Adjust as per your client's method


# Fixture for test agent
@ -86,6 +115,7 @@ def agent(client):


def test_agent(client, agent):
    _reset_config()

    # test client.rename_agent
    new_name = "RenamedTestAgent"
@ -101,84 +131,61 @@ def test_agent(client, agent):


def test_memory(client, agent):
    # _reset_config()
    _reset_config()

    memory_response = client.get_in_context_memory(agent_id=agent.id)
    memory_response = client.get_agent_memory(agent_id=agent.id)
    print("MEMORY", memory_response)

    updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
    client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"])
    client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"])
    updated_memory_response = client.get_in_context_memory(agent_id=agent.id)
    client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory)
    updated_memory_response = client.get_agent_memory(agent_id=agent.id)
    assert (
        updated_memory_response.get_block("human").value == updated_memory["human"]
        and updated_memory_response.get_block("persona").value == updated_memory["persona"]
        updated_memory_response.core_memory.human == updated_memory["human"]
        and updated_memory_response.core_memory.persona == updated_memory["persona"]
    ), "Memory update failed"


def test_agent_interactions(client, agent):
    # _reset_config()
    _reset_config()

    message = "Hello, agent!"
    print("Sending message", message)
    response = client.user_message(agent_id=agent.id, message=message)
    print("Response", response)
    assert isinstance(response.usage, MemGPTUsageStatistics)
    assert response.usage.step_count == 1
    assert response.usage.total_tokens > 0
    assert response.usage.completion_tokens > 0
    assert isinstance(response.messages[0], Message)
    print(response.messages)
    message_response = client.user_message(agent_id=agent.id, message=message)

    # TODO: add streaming tests
    command = "/memory"
    command_response = client.run_command(agent_id=agent.id, command=command)
    print("command", command_response)


def test_archival_memory(client, agent):
    # _reset_config()
    _reset_config()

    memory_content = "Archival memory content"
    insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)[0]
    print("Inserted memory", insert_response.text, insert_response.id)
    insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
    assert insert_response, "Inserting archival memory failed"

    archival_memory_response = client.get_archival_memory(agent_id=agent.id, limit=1)
    archival_memories = [memory.text for memory in archival_memory_response]
    archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1)
    print("MEMORY")
    archival_memories = [memory.contents for memory in archival_memory_response.archival_memory]
    assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}"

    memory_id_to_delete = archival_memory_response[0].id
    memory_id_to_delete = archival_memory_response.archival_memory[0].id
    client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete)

    # add archival memory
    memory_str = "I love chats"
    passage = client.insert_archival_memory(agent.id, memory=memory_str)[0]

    # list archival memory
    passages = client.get_archival_memory(agent.id)
    assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"

    # get archival memory summary
    archival_summary = client.get_archival_memory_summary(agent.id)
    assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"

    # delete archival memory
    client.delete_archival_memory(agent.id, passage.id)

    # TODO: check deletion
    client.get_archival_memory(agent.id)


def test_messages(client, agent):
    # _reset_config()
    _reset_config()

    send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
    assert send_message_response, "Sending message failed"

    messages_response = client.get_messages(agent_id=agent.id, limit=1)
    assert len(messages_response) > 0, "Retrieving messages failed"
    assert len(messages_response.messages) > 0, "Retrieving messages failed"


def test_humans_personas(client, agent):
    # _reset_config()
    _reset_config()

    humans_response = client.list_humans()
    print("HUMANS", humans_response)
@ -187,20 +194,18 @@ def test_humans_personas(client, agent):
    print("PERSONAS", personas_response)

    persona_name = "TestPersona"
    persona_id = client.get_persona_id(persona_name)
    if persona_id:
        client.delete_persona(persona_id)
    if client.get_persona(persona_name):
        client.delete_persona(persona_name)
    persona = client.create_persona(name=persona_name, text="Persona text")
    assert persona.name == persona_name
    assert persona.value == "Persona text", "Creating persona failed"
    assert persona.text == "Persona text", "Creating persona failed"

    human_name = "TestHuman"
    human_id = client.get_human_id(human_name)
    if human_id:
        client.delete_human(human_id)
    if client.get_human(human_name):
        client.delete_human(human_name)
    human = client.create_human(name=human_name, text="Human text")
    assert human.name == human_name
    assert human.value == "Human text", "Creating human failed"
    assert human.text == "Human text", "Creating human failed"


# def test_tools(client, agent):
@ -213,14 +218,11 @@ def test_humans_personas(client, agent):


def test_config(client, agent):
    # _reset_config()
    _reset_config()

    models_response = client.list_models()
    print("MODELS", models_response)

    embeddings_response = client.list_embedding_models()
    print("EMBEDDINGS", embeddings_response)

    # TODO: add back
    # config_response = client.get_config()
    # TODO: ensure config is the same as the one in the server
@ -228,7 +230,7 @@ def test_config(client, agent):


def test_sources(client, agent):
    # _reset_config()
    _reset_config()

    if not hasattr(client, "base_url"):
        pytest.skip("Skipping test_sources because base_url is None")
@ -236,7 +238,7 @@ def test_sources(client, agent):
    # list sources
    sources = client.list_sources()
    print("listed sources", sources)
    assert len(sources) == 0
    assert len(sources.sources) == 0

    # create a source
    source = client.create_source(name="test_source")
@ -244,53 +246,36 @@ def test_sources(client, agent):
    # list sources
    sources = client.list_sources()
    print("listed sources", sources)
    assert len(sources) == 1

    # TODO: add back?
    assert sources[0].metadata_["num_passages"] == 0
    assert sources[0].metadata_["num_documents"] == 0

    # update the source
    original_id = source.id
    original_name = source.name
    new_name = original_name + "_new"
    client.update_source(source_id=source.id, name=new_name)

    # get the source name (check that it's been updated)
    source = client.get_source(source_id=source.id)
    assert source.name == new_name
    assert source.id == original_id

    # get the source id (make sure that it's the same)
    assert str(original_id) == client.get_source_id(source_name=new_name)
    assert len(sources.sources) == 1
    assert sources.sources[0].metadata_["num_passages"] == 0
    assert sources.sources[0].metadata_["num_documents"] == 0

    # check agent archival memory size
    archival_memories = client.get_archival_memory(agent_id=agent.id)
    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
    print(archival_memories)
    assert len(archival_memories) == 0

    # load a file into a source
    filename = "CONTRIBUTING.md"
    upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
    print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
    print("Upload job", upload_job, upload_job.status, upload_job.metadata)

    # TODO: make sure things run in the right order
    archival_memories = client.get_archival_memory(agent_id=agent.id)
    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
    assert len(archival_memories) == 0

    # attach a source
    client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)

    # list archival memory
    archival_memories = client.get_archival_memory(agent_id=agent.id)
    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
    # print(archival_memories)
    assert len(archival_memories) == 20 or len(archival_memories) == 21

    # check number of passages
    sources = client.list_sources()
    # TODO: add back?
    # assert sources.sources[0].metadata_["num_passages"] > 0
    # assert sources.sources[0].metadata_["num_documents"] == 0  # TODO: fix this once document store added
    assert sources.sources[0].metadata_["num_passages"] > 0
    assert sources.sources[0].metadata_["num_documents"] == 0  # TODO: fix this once document store added
    print(sources)

    # detach the source
@ -299,3 +284,80 @@ def test_sources(client, agent):

    # delete the source
    client.delete_source(source.id)


# def test_presets(client, agent):
#     _reset_config()
#
#     # new_preset = Preset(
#     #     # user_id=client.user_id,
#     #     name="pytest_test_preset",
#     #     description="DUMMY_DESCRIPTION",
#     #     system="DUMMY_SYSTEM",
#     #     persona="DUMMY_PERSONA",
#     #     persona_name="DUMMY_PERSONA_NAME",
#     #     human="DUMMY_HUMAN",
#     #     human_name="DUMMY_HUMAN_NAME",
#     #     functions_schema=[
#     #         {
#     #             "name": "send_message",
#     #             "json_schema": {
#     #                 "name": "send_message",
#     #                 "description": "Sends a message to the human user.",
#     #                 "parameters": {
#     #                     "type": "object",
#     #                     "properties": {
#     #                         "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
#     #                     },
#     #                     "required": ["message"],
#     #                 },
#     #             },
#     #             "tags": ["memgpt-base"],
#     #             "source_type": "python",
#     #             "source_code": 'def send_message(self, message: str) -> Optional[str]:\n    """\n    Sends a message to the human user.\n\n    Args:\n        message (str): Message contents. All unicode (including emojis) are supported.\n\n    Returns:\n        Optional[str]: None is always returned as this function does not produce a response.\n    """\n    self.interface.assistant_message(message)\n    return None\n',
#     #         }
#     #     ],
#     # )
#
#     ## List all presets and make sure the preset is NOT in the list
#     # all_presets = client.list_presets()
#     # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
#     # Create a preset
#     new_preset = client.create_preset(name="pytest_test_preset")
#
#     # List all presets and make sure the preset is in the list
#     all_presets = client.list_presets()
#     assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
#
#     # Delete the preset
#     client.delete_preset(preset_id=new_preset.id)
#
#     # List all presets and make sure the preset is NOT in the list
#     all_presets = client.list_presets()
#     assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)


# def test_tools(client, agent):
#
#     # load a function
#     file_path = "tests/data/functions/dump_json.py"
#     module_name = "dump_json"
#
#     # list functions
#     response = client.list_tools()
#     orig_tools = response.tools
#     print(orig_tools)
#
#     # add the tool
#     create_tool_response = client.create_tool(name=module_name, file_path=file_path)
#     print(create_tool_response)
#
#     # list functions
#     response = client.list_tools()
#     new_tools = response.tools
#     assert module_name in [tool.name for tool in new_tools]
#     # assert len(new_tools) == len(orig_tools) + 1
#
#     # TODO: add a function to a preset
#
#     # TODO: add a function to an agent
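Taken together, the calls exercised in this test file reduce to a short client workflow. A minimal sketch using only calls that appear above (the server URL and token are placeholders, and response attribute names vary between the two sides of this diff):

from memgpt import Admin, create_client

admin = Admin("http://localhost:8283", "test_server_token")
response = admin.create_user()  # response carries an api_key for the new user

client = create_client(base_url="http://localhost:8283", token=response.api_key)
agent = client.create_agent()
response = client.send_message(agent_id=agent.id, message="Hello, agent!", role="user")
print(response.messages)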
@ -1,142 +1,142 @@
# TODO: add back when messaging works
import os
import threading
import time
import uuid

# import os
# import threading
# import time
# import uuid
#
# import pytest
# from dotenv import load_dotenv
#
# from memgpt import Admin, create_client
# from memgpt.config import MemGPTConfig
# from memgpt.credentials import MemGPTCredentials
# from memgpt.settings import settings
# from tests.utils import create_config
#
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
## test_preset_name = "test_preset"
# test_agent_state = None
# client = None
#
# test_agent_state_post_message = None
# test_user_id = uuid.uuid4()
#
#
## admin credentials
# test_server_token = "test_server_token"
#
#
# def _reset_config():
#
#     # Use os.getenv with a fallback to os.environ.get
#     db_url = settings.memgpt_pg_uri
#
#     if os.getenv("OPENAI_API_KEY"):
#         create_config("openai")
#         credentials = MemGPTCredentials(
#             openai_key=os.getenv("OPENAI_API_KEY"),
#         )
#     else:  # hosted
#         create_config("memgpt_hosted")
#         credentials = MemGPTCredentials()
#
#     config = MemGPTConfig.load()
#
#     # set to use postgres
#     config.archival_storage_uri = db_url
#     config.recall_storage_uri = db_url
#     config.metadata_storage_uri = db_url
#     config.archival_storage_type = "postgres"
#     config.recall_storage_type = "postgres"
#     config.metadata_storage_type = "postgres"
#
#     config.save()
#     credentials.save()
#     print("_reset_config :: ", config.config_path)
#
#
# def run_server():
#
#     load_dotenv()
#
#     _reset_config()
#
#     from memgpt.server.rest_api.server import start_server
#
#     print("Starting server...")
#     start_server(debug=True)
#
#
## Fixture to create clients with different configurations
# @pytest.fixture(
#     params=[  # whether to use REST API server
#         {"server": True},
#         # {"server": False}  # TODO: add when implemented
#     ],
#     scope="module",
# )
# def admin_client(request):
#     if request.param["server"]:
#         # get URL from enviornment
#         server_url = os.getenv("MEMGPT_SERVER_URL")
#         if server_url is None:
#             # run server in thread
#             # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
#             server_url = "http://localhost:8283"
#             print("Starting server thread")
#             thread = threading.Thread(target=run_server, daemon=True)
#             thread.start()
#             time.sleep(5)
#         print("Running client tests with server:", server_url)
#         # create user via admin client
#         admin = Admin(server_url, test_server_token)
#         response = admin.create_user(test_user_id)  # Adjust as per your client's method
#
#     yield admin
#
#
# def test_concurrent_messages(admin_client):
#     # test concurrent messages
#
#     # create three
#
#     results = []
#
#     def _send_message():
#         try:
#             print("START SEND MESSAGE")
#             response = admin_client.create_user()
#             token = response.api_key
#             client = create_client(base_url=admin_client.base_url, token=token)
#             agent = client.create_agent()
#
#             print("Agent created", agent.id)
#
#             st = time.time()
#             message = "Hello, how are you?"
#             response = client.send_message(agent_id=agent.id, message=message, role="user")
#             et = time.time()
#             print(f"Message sent from {st} to {et}")
#             print(response.messages)
#             results.append((st, et))
#         except Exception as e:
#             print("ERROR", e)
#
#     threads = []
#     print("Starting threads...")
#     for i in range(5):
#         thread = threading.Thread(target=_send_message)
#         threads.append(thread)
#         thread.start()
#         print("CREATED THREAD")
#
#     print("waiting for threads to finish...")
#     for thread in threads:
#         print(thread.join())
#
#     # make sure runtime are overlapping
#     assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
#         results[1][0] < results[0][0] and results[1][1] > results[0][0]
#     ), f"Threads should have overlapping runtimes {results}"
#
import pytest
from dotenv import load_dotenv

from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Preset  # TODO move to PresetModel
from memgpt.settings import settings
from tests.utils import create_config

test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
test_preset_name = DEFAULT_PRESET
test_agent_state = None
client = None

test_agent_state_post_message = None
test_user_id = uuid.uuid4()


# admin credentials
test_server_token = "test_server_token"


def _reset_config():

    # Use os.getenv with a fallback to os.environ.get
    db_url = settings.memgpt_pg_uri

    if os.getenv("OPENAI_API_KEY"):
        create_config("openai")
        credentials = MemGPTCredentials(
            openai_key=os.getenv("OPENAI_API_KEY"),
        )
    else:  # hosted
        create_config("memgpt_hosted")
        credentials = MemGPTCredentials()

    config = MemGPTConfig.load()

    # set to use postgres
    config.archival_storage_uri = db_url
    config.recall_storage_uri = db_url
    config.metadata_storage_uri = db_url
    config.archival_storage_type = "postgres"
    config.recall_storage_type = "postgres"
    config.metadata_storage_type = "postgres"

    config.save()
    credentials.save()
    print("_reset_config :: ", config.config_path)


def run_server():

    load_dotenv()

    _reset_config()

    from memgpt.server.rest_api.server import start_server

    print("Starting server...")
    start_server(debug=True)


# Fixture to create clients with different configurations
@pytest.fixture(
    params=[  # whether to use REST API server
        {"server": True},
        # {"server": False}  # TODO: add when implemented
    ],
    scope="module",
)
def admin_client(request):
    if request.param["server"]:
        # get URL from enviornment
        server_url = os.getenv("MEMGPT_SERVER_URL")
        if server_url is None:
            # run server in thread
            # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
            server_url = "http://localhost:8283"
            print("Starting server thread")
            thread = threading.Thread(target=run_server, daemon=True)
            thread.start()
            time.sleep(5)
        print("Running client tests with server:", server_url)
        # create user via admin client
        admin = Admin(server_url, test_server_token)
        response = admin.create_user(test_user_id)  # Adjust as per your client's method

    yield admin


def test_concurrent_messages(admin_client):
    # test concurrent messages

    # create three

    results = []

    def _send_message():
        try:
            print("START SEND MESSAGE")
            response = admin_client.create_user()
            token = response.api_key
            client = create_client(base_url=admin_client.base_url, token=token)
            agent = client.create_agent()

            print("Agent created", agent.id)

            st = time.time()
            message = "Hello, how are you?"
            response = client.send_message(agent_id=agent.id, message=message, role="user")
            et = time.time()
            print(f"Message sent from {st} to {et}")
            print(response.messages)
            results.append((st, et))
        except Exception as e:
            print("ERROR", e)

    threads = []
    print("Starting threads...")
    for i in range(5):
        thread = threading.Thread(target=_send_message)
        threads.append(thread)
        thread.start()
        print("CREATED THREAD")

    print("waiting for threads to finish...")
    for thread in threads:
        print(thread.join())

    # make sure runtime are overlapping
    assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
        results[1][0] < results[0][0] and results[1][1] > results[0][0]
    ), f"Threads should have overlapping runtimes {results}"
Some files were not shown because too many files have changed in this diff