Charles Packer 2024-08-16 19:52:47 -07:00 committed by GitHub
parent b39ad16a9d
commit 55a36a6e3d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
112 changed files with 8008 additions and 8901 deletions

View File

@@ -58,6 +58,7 @@ jobs:
 pipx install poetry==1.8.2
 poetry install -E dev -E postgres
 poetry run pytest -s tests/test_client.py
+poetry run pytest -s tests/test_concurrent_connections.py
 - name: Print docker logs if tests fail
 if: failure()

View File

@@ -2,7 +2,6 @@ name: Run All pytest Tests
 env:
 MEMGPT_PGURI: ${{ secrets.MEMGPT_PGURI }}
-OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
 on:
 push:
@@ -33,10 +32,10 @@ jobs:
 with:
 python-version: "3.12"
 poetry-version: "1.8.2"
-install-args: "-E dev -E postgres -E milvus -E crewai-tools"
+install-args: "-E dev -E postgres -E milvus"
 - name: Initialize credentials
-run: poetry run memgpt quickstart --backend openai
+run: poetry run memgpt quickstart --backend memgpt
 #- name: Run docker compose server
 # env:
@@ -70,3 +69,14 @@ jobs:
 PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
 run: |
 poetry run pytest -s -vv -k "not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
+- name: Run storage tests
+  env:
+    MEMGPT_PG_PORT: 8888
+    MEMGPT_PG_USER: memgpt
+    MEMGPT_PG_PASSWORD: memgpt
+    MEMGPT_PG_HOST: localhost
+    MEMGPT_PG_DB: memgpt
+    MEMGPT_SERVER_PASS: test_server_token
+  run: |
+    poetry run pytest -s -vv tests/test_storage.py

.gitignore vendored
View File

@@ -1012,7 +1012,6 @@ FodyWeavers.xsd
 ## cached db data
 pgdata/
 !pgdata/.gitkeep
-.persist/
 ## pytest mirrors
 memgpt/.pytest_cache/

View File

@@ -14,7 +14,7 @@ WORKDIR /app
 COPY pyproject.toml poetry.lock ./
 RUN poetry lock --no-update
 RUN if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then \
-    poetry install --no-root -E "postgres server dev" ; \
+    poetry install --no-root -E "postgres server dev autogen" ; \
 else \
     poetry install --no-root -E "postgres server" && \
     rm -rf $POETRY_CACHE_DIR ; \

View File

@@ -2,7 +2,8 @@ import datetime
 import inspect
 import json
 import traceback
-from typing import List, Literal, Optional, Tuple, Union
+import uuid
+from typing import List, Literal, Optional, Tuple, Union, cast
 from tqdm import tqdm
@@ -18,20 +19,14 @@ from memgpt.constants import (
 MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
 MESSAGE_SUMMARY_WARNING_FRAC,
 )
+from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
 from memgpt.interface import AgentInterface
 from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
-from memgpt.memory import ArchivalMemory, RecallMemory, summarize_messages
+from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
 from memgpt.metadata import MetadataStore
+from memgpt.models import chat_completion_response
+from memgpt.models.pydantic_models import OptionState, ToolModel
 from memgpt.persistence_manager import LocalStateManager
-from memgpt.schemas.agent import AgentState
-from memgpt.schemas.block import Block
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.enums import OptionState
-from memgpt.schemas.memory import Memory
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
-from memgpt.schemas.passage import Passage
-from memgpt.schemas.tool import Tool
 from memgpt.system import (
 get_initial_boot_messages,
 get_login_event,
@@ -40,6 +35,7 @@ from memgpt.system import (
 )
 from memgpt.utils import (
 count_tokens,
+create_uuid_from_string,
 get_local_time,
 get_tool_call_id,
 get_utc_time,
@@ -76,7 +72,7 @@ def compile_memory_metadata_block(
 def compile_system_message(
 system_prompt: str,
-in_context_memory: Memory,
+in_context_memory: BaseMemory,
 in_context_memory_last_edit: datetime.datetime,  # TODO move this inside of BaseMemory?
 archival_memory: Optional[ArchivalMemory] = None,
 recall_memory: Optional[RecallMemory] = None,
@@ -139,7 +135,7 @@ def compile_system_message(
 def initialize_message_sequence(
 model: str,
 system: str,
-memory: Memory,
+memory: BaseMemory,
 archival_memory: Optional[ArchivalMemory] = None,
 recall_memory: Optional[RecallMemory] = None,
 memory_edit_timestamp: Optional[datetime.datetime] = None,
@@ -192,21 +188,35 @@ class Agent(object):
 interface: AgentInterface,
 # agents can be created from providing agent_state
 agent_state: AgentState,
-tools: List[Tool],
+tools: List[ToolModel],
-# memory: Memory,
+# memory: BaseMemory,
 # extras
 messages_total: Optional[int] = None,  # TODO remove?
 first_message_verify_mono: bool = True,  # TODO move to config?
 ):
-assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
+# tools
+for tool in tools:
+    assert tool, f"Tool is None - must be error in querying tool from DB"
+    assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
+for tool_name in agent_state.tools:
+    assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
+# Store the functions schemas (this is passed as an argument to ChatCompletion)
+self.functions = []
+self.functions_python = {}
+env = {}
+env.update(globals())
+for tool in tools:
+    # WARNING: name may not be consistent?
+    if tool.module:  # execute the whole module
+        exec(tool.module, env)
+    else:
+        exec(tool.source_code, env)
+    self.functions_python[tool.name] = env[tool.name]
+    self.functions.append(tool.json_schema)
+assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
 # Hold a copy of the state that was used to init the agent
 self.agent_state = agent_state
-assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
-try:
-    self.link_tools(tools)
-except Exception as e:
-    raise ValueError(f"Encountered an error while trying to link agent tools during initialization:\n{str(e)}")
 # gpt-4, gpt-3.5-turbo, ...
 self.model = self.agent_state.llm_config.model
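The constructor now inlines what `link_tools` used to do: each tool's stored source is exec'd into a shared namespace, and the resulting callable is kept next to its JSON schema. A minimal standalone sketch of that mechanism, with a hypothetical `get_weather` tool:

```python
# Standalone sketch of the exec-based tool linking above; the tool source is made up.
source_code = '''
def get_weather(city: str) -> str:
    """Return a canned weather report for `city`."""
    return f"It is always sunny in {city}."
'''

env = {}
env.update(globals())
exec(source_code, env)  # defines get_weather inside env

functions_python = {"get_weather": env["get_weather"]}
assert callable(functions_python["get_weather"])
print(functions_python["get_weather"]("Berlin"))  # It is always sunny in Berlin.
```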
@@ -215,8 +225,7 @@ class Agent(object):
 self.system = self.agent_state.system
 # Initialize the memory object
-self.memory = self.agent_state.memory
-assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}"
+self.memory = BaseMemory.load(self.agent_state.state["memory"])
 printd("Initialized memory object", self.memory)
 # Interface must implement:
@@ -245,13 +254,28 @@ class Agent(object):
 self._messages: List[Message] = []
 # Once the memory object is initialized, use it to "bake" the system message
-if self.agent_state.message_ids is not None:
-    self.set_message_buffer(message_ids=self.agent_state.message_ids)
+if "messages" in self.agent_state.state and self.agent_state.state["messages"] is not None:
+    # print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}")
+    if not isinstance(self.agent_state.state["messages"], list):
+        raise ValueError(f"'messages' in AgentState was bad type: {type(self.agent_state.state['messages'])}")
+    assert all([isinstance(msg, str) for msg in self.agent_state.state["messages"]])
+    # Convert to IDs, and pull from the database
+    raw_messages = [
+        self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in self.agent_state.state["messages"]
+    ]
+    assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
+    self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])
+    for m in self._messages:
+        # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
+        # TODO eventually do casting via an edit_message function
+        if not is_utc_datetime(m.created_at):
+            printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
+            m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
 else:
-    printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
-    # Generate a sequence of initial messages to put in the buffer
+    printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
     init_messages = initialize_message_sequence(
     model=self.model,
     system=self.system,
@@ -261,8 +285,6 @@ class Agent(object):
 memory_edit_timestamp=get_utc_time(),
 include_initial_boot_message=True,
 )
-# Cast the messages to actual Message objects to be synced to the DB
 init_messages_objs = []
 for msg in init_messages:
 init_messages_objs.append(
@@ -271,12 +293,15 @@ class Agent(object):
 )
 )
 assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
-# Put the messages inside the message buffer
 self.messages_total = 0
-# self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
-self._append_to_messages(added_messages=init_messages_objs)
-self._validate_message_buffer_is_utc()
+self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
+for m in self._messages:
+    assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
+    # TODO eventually do casting via an edit_message function
+    if not is_utc_datetime(m.created_at):
+        printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
+        m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
 # Keep track of the total number of messages throughout all time
 self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1)  # (-system)
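Both branches normalize message timestamps the same way. A small sketch of the pattern, assuming (as the `.replace` call implies) that naive timestamps already hold UTC wall-clock time:

```python
import datetime

def force_utc(dt: datetime.datetime) -> datetime.datetime:
    # A naive timestamp is assumed to already be UTC wall-clock time,
    # so we attach the zone instead of converting with astimezone().
    if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None:
        return dt.replace(tzinfo=datetime.timezone.utc)
    return dt

naive = datetime.datetime(2024, 8, 16, 19, 52, 47)
assert force_utc(naive).tzinfo == datetime.timezone.utc
```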
@@ -295,65 +320,6 @@ class Agent(object):
 def messages(self, value):
 raise Exception("Modifying message list directly not allowed")
-def link_tools(self, tools: List[Tool]):
-    """Bind a tool object (schema + python function) to the agent object"""
-    # tools
-    for tool in tools:
-        assert tool, f"Tool is None - must be error in querying tool from DB"
-        assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools"
-    for tool_name in self.agent_state.tools:
-        assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
-    # Store the functions schemas (this is passed as an argument to ChatCompletion)
-    self.functions = []
-    self.functions_python = {}
-    env = {}
-    env.update(globals())
-    for tool in tools:
-        # WARNING: name may not be consistent?
-        if tool.module:  # execute the whole module
-            exec(tool.module, env)
-        else:
-            exec(tool.source_code, env)
-        self.functions_python[tool.name] = env[tool.name]
-        self.functions.append(tool.json_schema)
-    assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
-def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
-    """Load a list of messages from recall storage"""
-    # Pull the message objects from the database
-    message_objs = [self.persistence_manager.recall_memory.storage.get(msg_id) for msg_id in message_ids]
-    assert all([isinstance(msg, Message) for msg in message_objs])
-    return message_objs
-def _validate_message_buffer_is_utc(self):
-    """Iterate over the message buffer and force all messages to be UTC stamped"""
-    for m in self._messages:
-        # assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
-        # TODO eventually do casting via an edit_message function
-        if not is_utc_datetime(m.created_at):
-            printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
-            m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
-def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
-    """Set the messages in the buffer to the message IDs list"""
-    message_objs = self._load_messages_from_recall(message_ids=message_ids)
-    # set the objects in the buffer
-    self._messages = message_objs
-    # bugfix for old agents that may not have had UTC specified in their timestamps
-    if force_utc:
-        self._validate_message_buffer_is_utc()
-    # also sync the message IDs attribute
-    self.agent_state.message_ids = message_ids
 def _trim_messages(self, num):
 """Trim messages from the front, not including the system message"""
 self.persistence_manager.trim_messages(num)
@@ -406,7 +372,7 @@ class Agent(object):
 first_message: bool = False,  # hint
 stream: bool = False,  # TODO move to config?
 inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
-) -> ChatCompletionResponse:
+) -> chat_completion_response.ChatCompletionResponse:
 """Get response from LLM API"""
 try:
 response = create(
@@ -442,7 +408,9 @@ class Agent(object):
 except Exception as e:
 raise e
-def _handle_ai_response(self, response_message: Message, override_tool_call_id: bool = True) -> Tuple[List[Message], bool, bool]:
+def _handle_ai_response(
+    self, response_message: chat_completion_response.Message, override_tool_call_id: bool = True
+) -> Tuple[List[Message], bool, bool]:
 """Handles parsing and function execution"""
 messages = []  # append these to the history when done
@@ -647,7 +615,6 @@ class Agent(object):
 stream: bool = False,  # TODO move to config?
 timestamp: Optional[datetime.datetime] = None,
 inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
-ms: Optional[MetadataStore] = None,
 ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
 """Top-level event message handler for the MemGPT agent"""
@@ -677,20 +644,7 @@ class Agent(object):
 raise e
 try:
-# Step 0: update core memory
-# only pulling latest block data if shared memory is being used
-# TODO: ensure we're passing in metadata store from all surfaces
-if ms is not None:
-    should_update = False
-    for block in self.agent_state.memory.to_dict().values():
-        if not block.get("template", False):
-            should_update = True
-    if should_update:
-        # TODO: the force=True can be optimized away
-        # once we ensure we're correctly comparing whether in-memory core
-        # data is different than persisted core data.
-        self.rebuild_memory(force=True, ms=ms)
-# Step 1: add user message
+# Step 0: add user message
 if user_message is not None:
 if isinstance(user_message, Message):
 # Validate JSON via save/load
@@ -736,7 +690,7 @@ class Agent(object):
 if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
 printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
-# Step 2: send the conversation and available functions to GPT
+# Step 1: send the conversation and available functions to GPT
 if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
 printd(f"This is the first message. Running extra verifier on AI response.")
 counter = 0
@@ -761,9 +715,9 @@ class Agent(object):
 inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
 )
-# Step 3: check if LLM wanted to call a function
-# (if yes) Step 4: call the function
-# (if yes) Step 5: send the info on the function call and function response to LLM
+# Step 2: check if LLM wanted to call a function
+# (if yes) Step 3: call the function
+# (if yes) Step 4: send the info on the function call and function response to LLM
 response_message = response.choices[0].message
 response_message.model_copy()  # TODO why are we copying here?
 all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
@@ -779,7 +733,7 @@ class Agent(object):
 # "functions": self.functions,
 # }
-# Step 6: extend the message history
+# Step 4: extend the message history
 if user_message is not None:
 if isinstance(user_message, Message):
 all_new_messages = [user_message] + all_response_messages
@@ -839,7 +793,6 @@ class Agent(object):
 stream=stream,
 timestamp=timestamp,
 inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
-ms=ms,
 )
 else:
@@ -988,8 +941,8 @@ class Agent(object):
 new_messages = [new_system_message_obj] + self._messages[1:]  # swap index 0 (system)
 self._messages = new_messages
-def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
-    """Rebuilds the system message with the latest memory object and any shared memory block updates"""
+def rebuild_memory(self, force=False, update_timestamp=True):
+    """Rebuilds the system message with the latest memory object"""
 curr_system_message = self.messages[0]  # this is the system + memory bank, not just the system prompt
 # NOTE: This is a hacky way to check if the memory has changed
@@ -998,28 +951,6 @@ class Agent(object):
 printd(f"Memory has not changed, not rebuilding system")
 return
-if ms:
-    for block in self.memory.to_dict().values():
-        if block.get("templates", False):
-            # we don't expect to update shared memory blocks that
-            # are templates. this is something we could update in the
-            # future if we expect templates to change often.
-            continue
-        block_id = block.get("id")
-        db_block = ms.get_block(block_id=block_id)
-        if db_block is None:
-            # this case covers if someone has deleted a shared block by interacting
-            # with some other agent.
-            # in that case we should remove this shared block from the agent currently being
-            # evaluated.
-            printd(f"removing block: {block_id=}")
-            continue
-        if not isinstance(db_block.value, str):
-            printd(f"skipping block update, unexpected value: {block_id=}")
-            continue
-        # TODO: we may want to update which columns we're updating from shared memory e.g. the limit
-        self.memory.update_block_value(name=block.get("name", ""), value=db_block.value)
 # If the memory didn't update, we probably don't want to update the timestamp inside
 # For example, if we're doing a system prompt swap, this should probably be False
 if update_timestamp:
@@ -1117,14 +1048,25 @@ class Agent(object):
 # return msg
 def update_state(self) -> AgentState:
-    message_ids = [msg.id for msg in self._messages]
-    assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}"
-    # override any fields that may have been updated
-    self.agent_state.message_ids = message_ids
-    self.agent_state.memory = self.memory
-    self.agent_state.system = self.system
+    memory = {
+        "system": self.system,
+        "memory": self.memory.to_dict(),
+        "messages": [str(msg.id) for msg in self._messages],  # TODO: move out into AgentState.message_ids
+    }
+    self.agent_state = AgentState(
+        name=self.agent_state.name,
+        user_id=self.agent_state.user_id,
+        tools=self.agent_state.tools,
+        system=self.system,
+        ## "model_state"
+        llm_config=self.agent_state.llm_config,
+        embedding_config=self.agent_state.embedding_config,
+        id=self.agent_state.id,
+        created_at=self.agent_state.created_at,
+        ## "agent_state"
+        state=memory,
+        _metadata=self.agent_state._metadata,
+    )
     return self.agent_state
 def migrate_embedding(self, embedding_config: EmbeddingConfig):
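`update_state` now serializes the runtime state back into `AgentState.state` as a plain dict. An illustrative shape, with placeholder values and an assumed `BaseMemory.to_dict()` layout:

```python
# Hypothetical contents of AgentState.state after update_state(); every value is a placeholder.
state = {
    "system": "You are MemGPT, ...",                # full system prompt
    "memory": {"persona": "...", "human": "..."},   # BaseMemory.to_dict() (assumed shape)
    "messages": [                                    # message IDs as strings, not Message objects
        "7c9cb264-0ffe-4c3c-90b1-1f2f76a2c9a2",
    ],
}
```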
@@ -1134,12 +1076,13 @@ class Agent(object):
 # TODO: recall memory
 raise NotImplementedError()
-def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
+def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore):
 """Attach data with name `source_name` to the agent from source_connector."""
 # TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
-filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
+filters = {"user_id": self.agent_state.user_id, "data_source": source_name}
 size = source_connector.size(filters)
+# typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
 page_size = 100
 generator = source_connector.get_all_paginated(filters=filters, page_size=page_size)  # yields List[Passage]
 all_passages = []
@@ -1152,8 +1095,7 @@ class Agent(object):
 passage.agent_id = self.agent_state.id
 # regenerate passage ID (avoid duplicates)
-# TODO: need to find another solution to the text duplication issue
-# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
+passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}")
 # insert into agent archival memory
 self.persistence_manager.archival_memory.storage.insert_many(passages)
@@ -1165,14 +1107,15 @@ class Agent(object):
 self.persistence_manager.archival_memory.storage.save()
 # attach to agent
-source = ms.get_source(source_id=source_id)
-assert source is not None, f"Source {source_id} not found in metadata store"
+source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id)
+assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}"
+source_id = source.id
 ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
 total_agent_passages = self.persistence_manager.archival_memory.storage.size()
 printd(
-    f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
+    f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
 )
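The passage ID is regenerated deterministically, so re-ingesting the same text for the same source yields the same UUID and the upsert effectively dedupes. A sketch of how such a helper can be built; the real `create_uuid_from_string` lives in `memgpt.utils` and may differ in detail:

```python
import hashlib
import uuid

def create_uuid_from_string(val: str) -> uuid.UUID:
    # Hash the string and reinterpret the 128-bit digest as a UUID.
    hex_string = hashlib.md5(val.encode("utf-8")).hexdigest()
    return uuid.UUID(hex=hex_string)

# Re-ingesting the same passage produces the same ID, so inserts dedupe.
assert create_uuid_from_string("src_agent_text") == create_uuid_from_string("src_agent_text")
```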
@@ -1181,36 +1124,8 @@ def save_agent(agent: Agent, ms: MetadataStore):
 agent.update_state()
 agent_state = agent.agent_state
-agent_id = agent_state.id
-assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
-# NOTE: we're saving agent memory before persisting the agent to ensure
-# that allocated block_ids for each memory block are present in the agent model
-save_agent_memory(agent=agent, ms=ms)
-if ms.get_agent(agent_id=agent.agent_state.id):
+if ms.get_agent(agent_name=agent_state.name, user_id=agent_state.user_id):
     ms.update_agent(agent_state)
 else:
     ms.create_agent(agent_state)
-agent.agent_state = ms.get_agent(agent_id=agent_id)
-assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
-def save_agent_memory(agent: Agent, ms: MetadataStore):
-    """
-    Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table.
-    NOTE: we are assuming agent.update_state has already been called.
-    """
-    for block_dict in agent.memory.to_dict().values():
-        # TODO: block creation should happen in one place to enforce these sort of constraints consistently.
-        if block_dict.get("user_id", None) is None:
-            block_dict["user_id"] = agent.agent_state.user_id
-        block = Block(**block_dict)
-        # FIXME: should we expect for block values to be None? If not, we need to figure out why that is
-        # the case in some tests, if so we should relax the DB constraint.
-        if block.value is None:
-            block.value = ""
-        ms.update_or_create_block(block)

View File

@@ -1,12 +1,12 @@
-from typing import Dict, List, Optional, Tuple, cast
+import uuid
+from typing import Dict, Iterator, List, Optional, Tuple, cast
 import chromadb
 from chromadb.api.types import Include
 from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.config import MemGPTConfig
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.passage import Passage
+from memgpt.data_types import Passage, Record, RecordType
 from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime
@@ -34,6 +34,9 @@ class ChromaStorageConnector(StorageConnector):
 self.collection = self.client.get_or_create_collection(self.table_name)
 self.include: Include = ["documents", "embeddings", "metadatas"]
+# need to be converted to strings
+self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
 def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
 # get all filters for query
 if filters is not None:
@@ -51,7 +54,10 @@ class ChromaStorageConnector(StorageConnector):
 continue
 # filter by other keys
-chroma_filters.append({key: {"$eq": value}})
+if key in self.uuid_fields:
+    chroma_filters.append({key: {"$eq": str(value)}})
+else:
+    chroma_filters.append({key: {"$eq": value}})
 if len(chroma_filters) > 1:
 chroma_filters = {"$and": chroma_filters}
@@ -61,7 +67,7 @@ class ChromaStorageConnector(StorageConnector):
 chroma_filters = chroma_filters[0]
 return ids, chroma_filters
-def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0):
+def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0) -> Iterator[List[RecordType]]:
 ids, filters = self.get_filters(filters)
 while True:
 # Retrieve a chunk of records with the given page_size
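`get_filters` now stringifies UUID-valued fields because Chroma metadata only holds primitives. A condensed sketch of the where-clause it assembles, with made-up filter values:

```python
# Condensed sketch of the filter construction above; values are illustrative.
uuid_fields = {"id", "user_id", "agent_id", "source_id", "doc_id"}
filters = {"user_id": "1b671a64-40d5-491e-99b0-da01ff1f3341", "data_source": "wikipedia"}

chroma_filters = [
    {key: {"$eq": str(value) if key in uuid_fields else value}}
    for key, value in filters.items()
]
where = {"$and": chroma_filters} if len(chroma_filters) > 1 else chroma_filters[0]
# e.g. collection.get(where=where, ...)
```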
@@ -78,50 +84,29 @@ class ChromaStorageConnector(StorageConnector):
 # Increment the offset to get the next chunk in the next iteration
 offset += page_size
-def results_to_records(self, results):
+def results_to_records(self, results) -> List[RecordType]:
 # convert timestamps to datetime
 for metadata in results["metadatas"]:
 if "created_at" in metadata:
 metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
+for key, value in metadata.items():
+    if key in self.uuid_fields:
+        metadata[key] = uuid.UUID(value)
 if results["embeddings"]:  # may not be returned, depending on table type
-    passages = []
-    for text, record_id, embedding, metadata in zip(
-        results["documents"], results["ids"], results["embeddings"], results["metadatas"]
-    ):
-        args = {}
-        for field in EmbeddingConfig.__fields__.keys():
-            if field in metadata:
-                args[field] = metadata[field]
-                del metadata[field]
-        embedding_config = EmbeddingConfig(**args)
-        passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata))
-    # return [
-    #     Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas)
-    #     for (text, record_id, embedding, metadatas) in zip(
-    #         results["documents"], results["ids"], results["embeddings"], results["metadatas"]
-    #     )
-    # ]
-    return passages
+    return [
+        cast(RecordType, self.type(text=text, embedding=embedding, id=uuid.UUID(record_id), **metadatas))  # type: ignore
+        for (text, record_id, embedding, metadatas) in zip(
+            results["documents"], results["ids"], results["embeddings"], results["metadatas"]
+        )
+    ]
 else:
 # no embeddings
-    passages = []
-    for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]):
-        args = {}
-        for field in EmbeddingConfig.__fields__.keys():
-            if field in metadata:
-                args[field] = metadata[field]
-                del metadata[field]
-        embedding_config = EmbeddingConfig(**args)
-        passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata))
-    return passages
-    # return [
-    #     #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas))  # type: ignore
-    #     Passage(text=text, embedding=None, id=id, **metadatas)
-    #     for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
-    # ]
+    return [
+        cast(RecordType, self.type(text=text, id=uuid.UUID(id), **metadatas))  # type: ignore
+        for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
+    ]
-def get_all(self, filters: Optional[Dict] = {}, limit=None):
+def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
 ids, filters = self.get_filters(filters)
 if self.collection.count() == 0:
 return []
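For orientation, `results_to_records` consumes the dict that Chroma's `collection.get()` returns. An illustrative shape, with all values made up:

```python
# Hypothetical collection.get() result; every value is a placeholder.
results = {
    "ids": ["1b671a64-40d5-491e-99b0-da01ff1f3341"],  # stringified UUIDs
    "documents": ["the passage text"],
    "embeddings": [[0.1, 0.2, 0.3]],                  # omitted for some table types
    "metadatas": [{"user_id": "...", "created_at": 1723862400}],
}
```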
@@ -131,13 +116,13 @@ class ChromaStorageConnector(StorageConnector):
 results = self.collection.get(ids=ids, include=self.include, where=filters)
 return self.results_to_records(results)
-def get(self, id):
+def get(self, id: uuid.UUID) -> Optional[RecordType]:
 results = self.collection.get(ids=[str(id)])
 if len(results["ids"]) == 0:
 return None
 return self.results_to_records(results)[0]
-def format_records(self, records):
+def format_records(self, records: List[RecordType]):
 assert all([isinstance(r, Passage) for r in records])
 recs = []
@@ -160,13 +145,10 @@ class ChromaStorageConnector(StorageConnector):
 # collect/format record metadata
 metadatas = []
 for record in recs:
-embedding_config = vars(record.embedding_config)
 metadata = vars(record)
 metadata.pop("id")
 metadata.pop("text")
 metadata.pop("embedding")
-metadata.pop("embedding_config")
-metadata.pop("metadata_")
 if "created_at" in metadata:
 metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
 if "metadata_" in metadata and metadata["metadata_"] is not None:
@@ -174,22 +156,23 @@ class ChromaStorageConnector(StorageConnector):
 metadata.pop("metadata_")
 else:
 record_metadata = {}
-metadata = {**metadata, **record_metadata}  # merge with metadata
-metadata = {**metadata, **embedding_config}  # merge with embedding config
 metadata = {key: value for key, value in metadata.items() if value is not None}  # null values not allowed
+metadata = {**metadata, **record_metadata}  # merge with metadata
 # convert uuids to strings
+for key, value in metadata.items():
+    if key in self.uuid_fields:
+        metadata[key] = str(value)
 metadatas.append(metadata)
 return ids, documents, embeddings, metadatas
-def insert(self, record):
+def insert(self, record: Record):
 ids, documents, embeddings, metadatas = self.format_records([record])
 if any([e is None for e in embeddings]):
 raise ValueError("Embeddings must be provided to chroma")
 self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)
-def insert_many(self, records, show_progress=False):
+def insert_many(self, records: List[RecordType], show_progress=False):
 ids, documents, embeddings, metadatas = self.format_records(records)
 if any([e is None for e in embeddings]):
 raise ValueError("Embeddings must be provided to chroma")
@@ -215,7 +198,7 @@ class ChromaStorageConnector(StorageConnector):
 def list_data_sources(self):
 raise NotImplementedError
-def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
 ids, filters = self.get_filters(filters)
 results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)
@@ -256,40 +239,10 @@ class ChromaStorageConnector(StorageConnector):
 def get_all_cursor(
 self,
 filters: Optional[Dict] = {},
-after: str = None,
-before: str = None,
+after: uuid.UUID = None,
+before: uuid.UUID = None,
 limit: Optional[int] = 1000,
 order_by: str = "created_at",
 reverse: bool = False,
 ):
-records = self.get_all(filters=filters)
-# WARNING: very hacky and slow implementation
-def get_index(id, record_list):
-    for i in range(len(record_list)):
-        if record_list[i].id == id:
-            return i
-    assert False, f"Could not find id {id} in record list"
-# sort by custom field
-records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse)
-if after:
-    index = get_index(after, records)
-    if index + 1 >= len(records):
-        return None, []
-    records = records[index + 1 :]
-if before:
-    index = get_index(before, records)
-    if index == 0:
-        return None, []
-    # TODO: not sure if this is correct
-    records = records[:index]
-if len(records) == 0:
-    return None, []
-# enforce limit
-if limit:
-    records = records[:limit]
-return records[-1].id, records
+raise ValueError("Cannot run get_all_cursor with chroma")

View File

@@ -1,11 +1,14 @@
 import base64
 import os
+import uuid
 from datetime import datetime
-from typing import Dict, List, Optional
+from typing import Dict, Iterator, List, Optional
 import numpy as np
 from sqlalchemy import (
+    BIGINT,
     BINARY,
+    CHAR,
     JSON,
     Column,
     DateTime,
@@ -20,6 +23,7 @@ from sqlalchemy import (
 select,
 text,
 )
+from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.orm import declarative_base, mapped_column, sessionmaker
 from sqlalchemy.orm.session import close_all_sessions
 from sqlalchemy.sql import func
@@ -29,15 +33,34 @@ from tqdm import tqdm
 from memgpt.agent_store.storage import StorageConnector, TableType
 from memgpt.config import MemGPTConfig
 from memgpt.constants import MAX_EMBEDDING_DIM
-from memgpt.metadata import EmbeddingConfigColumn
-# from memgpt.schemas.message import Message, Passage, Record, RecordType, ToolCall
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_request import ToolCall, ToolCallFunction
-from memgpt.schemas.passage import Passage
+from memgpt.data_types import Message, Passage, Record, RecordType, ToolCall
 from memgpt.settings import settings
+# Custom UUID type
+class CommonUUID(TypeDecorator):
+    impl = CHAR
+    cache_ok = True
+    def load_dialect_impl(self, dialect):
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(UUID(as_uuid=True))
+        else:
+            return dialect.type_descriptor(CHAR())
+    def process_bind_param(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return str(value)  # Convert UUID to string for SQLite
+    def process_result_value(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return uuid.UUID(value)
 class CommonVector(TypeDecorator):
 """Common type for representing vectors in SQLite"""
@@ -70,6 +93,26 @@ class CommonVector(TypeDecorator):
 # Custom serialization / de-serialization for JSON columns
+class ToolCallColumn(TypeDecorator):
+    """Custom type for storing List[ToolCall] as JSON"""
+    impl = JSON
+    cache_ok = True
+    def load_dialect_impl(self, dialect):
+        return dialect.type_descriptor(JSON())
+    def process_bind_param(self, value, dialect):
+        if value:
+            return [vars(v) for v in value]
+        return value
+    def process_result_value(self, value, dialect):
+        if value:
+            return [ToolCall(**v) for v in value]
+        return value
 Base = declarative_base()
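ToolCallColumn round-trips a list of ToolCall objects through JSON: `vars()` on the way in, the constructor on the way out. A standalone sketch with a stand-in ToolCall class (the real one lives in `memgpt.data_types` and may carry more fields):

```python
# Stand-in for memgpt.data_types.ToolCall, for demonstration only.
class ToolCall:
    def __init__(self, id, type="function", function=None):
        self.id = id
        self.type = type
        self.function = function

calls = [ToolCall(id="call_1", function={"name": "send_message", "arguments": "{}"})]

as_json = [vars(v) for v in calls]           # what process_bind_param writes to the JSON column
restored = [ToolCall(**v) for v in as_json]  # what process_result_value hands back

assert restored[0].id == "call_1" and restored[0].function["name"] == "send_message"
```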
@@ -77,8 +120,8 @@ def get_db_model(
 config: MemGPTConfig,
 table_name: str,
 table_type: TableType,
-user_id: str,
-agent_id: Optional[str] = None,
+user_id: uuid.UUID,
+agent_id: Optional[uuid.UUID] = None,
 dialect="postgresql",
 ):
 # Define a helper function to create or get the model class
@@ -97,12 +140,14 @@ def get_db_model(
 __abstract__ = True  # this line is necessary
 # Assuming passage_id is the primary key
-id = Column(String, primary_key=True)
-user_id = Column(String, nullable=False)
+# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+user_id = Column(CommonUUID, nullable=False)
 text = Column(String)
-doc_id = Column(String)
-agent_id = Column(String)
-source_id = Column(String)
+doc_id = Column(CommonUUID)
+agent_id = Column(CommonUUID)
+data_source = Column(String)  # agent_name if agent, data_source name if from data source
 # vector storage
 if dialect == "sqlite":
@@ -111,8 +156,9 @@ def get_db_model(
 from pgvector.sqlalchemy import Vector
 embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
+embedding_dim = Column(BIGINT)
+embedding_model = Column(String)
-embedding_config = Column(EmbeddingConfigColumn)
 metadata_ = Column(MutableJson)
 # Add a datetime column, with default value as the current time
@@ -127,11 +173,12 @@ def get_db_model(
 return Passage(
 text=self.text,
 embedding=self.embedding,
-embedding_config=self.embedding_config,
+embedding_dim=self.embedding_dim,
+embedding_model=self.embedding_model,
 doc_id=self.doc_id,
 user_id=self.user_id,
 id=self.id,
-source_id=self.source_id,
+data_source=self.data_source,
 agent_id=self.agent_id,
 metadata_=self.metadata_,
 created_at=self.created_at,
@@ -149,9 +196,11 @@ def get_db_model(
 __abstract__ = True  # this line is necessary
 # Assuming message_id is the primary key
-id = Column(String, primary_key=True)
-user_id = Column(String, nullable=False)
-agent_id = Column(String, nullable=False)
+# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
+user_id = Column(CommonUUID, nullable=False)
+agent_id = Column(CommonUUID, nullable=False)
 # openai info
 role = Column(String, nullable=False)
@@ -163,29 +212,31 @@ def get_db_model(
 # if role == "assistant", this MAY be specified
 # if role != "assistant", this must be null
 # TODO align with OpenAI spec of multiple tool calls
-# tool_calls = Column(ToolCallColumn)
-tool_calls = Column(JSON)
+tool_calls = Column(ToolCallColumn)
 # tool call response info
 # if role == "tool", then this must be specified
 # if role != "tool", this must be null
 tool_call_id = Column(String)
+# vector storage
+if dialect == "sqlite":
+    embedding = Column(CommonVector)
+else:
+    from pgvector.sqlalchemy import Vector
+    embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
+embedding_dim = Column(BIGINT)
+embedding_model = Column(String)
 # Add a datetime column, with default value as the current time
 created_at = Column(DateTime(timezone=True))
 Index("message_idx_user", user_id, agent_id),
 def __repr__(self):
-    return f"<Message(message_id='{self.id}', text='{self.text}')>"
+    return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
 def to_record(self):
-    calls = (
-        [ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
-        if self.tool_calls
-        else None
-    )
-    if calls:
-        assert isinstance(calls[0], ToolCall)
 return Message(
 user_id=self.user_id,
 agent_id=self.agent_id,
@@ -193,9 +244,11 @@ def get_db_model(
 name=self.name,
 text=self.text,
 model=self.model,
+# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
 tool_calls=self.tool_calls,
 tool_call_id=self.tool_call_id,
+embedding=self.embedding,
+embedding_dim=self.embedding_dim,
+embedding_model=self.embedding_model,
 created_at=self.created_at,
 id=self.id,
 )
@@ -221,7 +274,7 @@ class SQLStorageConnector(StorageConnector):
 all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
 return all_filters
-def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
+def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[RecordType]]:
 filters = self.get_filters(filters)
 while True:
 # Retrieve a chunk of records with the given page_size
@@ -241,8 +294,8 @@ class SQLStorageConnector(StorageConnector):
 def get_all_cursor(
 self,
 filters: Optional[Dict] = {},
-after: str = None,
-before: str = None,
+after: uuid.UUID = None,
+before: uuid.UUID = None,
 limit: Optional[int] = 1000,
 order_by: str = "created_at",
 reverse: bool = False,
@@ -279,12 +332,12 @@ class SQLStorageConnector(StorageConnector):
 return (None, [])
 records = [record.to_record() for record in db_record_chunk]
 next_cursor = db_record_chunk[-1].id
-assert isinstance(next_cursor, str)
+assert isinstance(next_cursor, uuid.UUID)
 # return (cursor, list[records])
 return (next_cursor, records)
-def get_all(self, filters: Optional[Dict] = {}, limit=None):
+def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
 filters = self.get_filters(filters)
 with self.session_maker() as session:
 if limit:
@@ -293,7 +346,7 @@ class SQLStorageConnector(StorageConnector):
 db_records = session.query(self.db_model).filter(*filters).all()
 return [record.to_record() for record in db_records]
-def get(self, id: str):
+def get(self, id: uuid.UUID) -> Optional[Record]:
 with self.session_maker() as session:
 db_record = session.get(self.db_model, id)
 if db_record is None:
@@ -306,13 +359,13 @@ class SQLStorageConnector(StorageConnector):
 with self.session_maker() as session:
 return session.query(self.db_model).filter(*filters).count()
-def insert(self, record):
+def insert(self, record: Record):
 raise NotImplementedError
-def insert_many(self, records, show_progress=False):
+def insert_many(self, records: List[RecordType], show_progress=False):
 raise NotImplementedError
-def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
 raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
 def save(self):
@@ -417,7 +470,7 @@ class PostgresStorageConnector(SQLStorageConnector):
 # create table
 Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist
-def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
+def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
 filters = self.get_filters(filters)
 with self.session_maker() as session:
 results = session.scalars(
@@ -428,7 +481,7 @@ class PostgresStorageConnector(SQLStorageConnector):
 records = [result.to_record() for result in results]
 return records
-def insert_many(self, records, exists_ok=True, show_progress=False):
+def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
 from sqlalchemy.dialects.postgresql import insert
 # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
@@ -450,15 +503,14 @@ class PostgresStorageConnector(SQLStorageConnector):
 with self.session_maker() as session:
 iterable = tqdm(records) if show_progress else records
 for record in iterable:
-    # db_record = self.db_model(**vars(record))
-    db_record = self.db_model(**record.dict())
+    db_record = self.db_model(**vars(record))
 session.add(db_record)
 session.commit()
-def insert(self, record, exists_ok=True):
+def insert(self, record: Record, exists_ok=True):
 self.insert_many([record], exists_ok=exists_ok)
-def update(self, record):
+def update(self, record: RecordType):
 """
 Updates a record in the database based on the provided Record object.
 """
@@ -523,12 +575,12 @@ class SQLLiteStorageConnector(SQLStorageConnector):
 Base.metadata.create_all(self.engine, tables=[self.db_model.__table__])  # Create the table if it doesn't exist
 self.session_maker = sessionmaker(bind=self.engine)
-# import sqlite3
-# sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
-# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
+import sqlite3
+sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
+sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
-def insert_many(self, records, exists_ok=True, show_progress=False):
+def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
 from sqlalchemy.dialects.sqlite import insert
 # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
@@ -550,15 +602,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
 with self.session_maker() as session:
 iterable = tqdm(records) if show_progress else records
 for record in iterable:
-    # db_record = self.db_model(**vars(record))
-    db_record = self.db_model(**record.dict())
+    db_record = self.db_model(**vars(record))
 session.add(db_record)
 session.commit()
-def insert(self, record, exists_ok=True):
+def insert(self, record: Record, exists_ok=True):
 self.insert_many([record], exists_ok=exists_ok)
-def update(self, record):
+def update(self, record: Record):
 """
 Updates an existing record in the database with values from the provided record object.
 """

View File

@@ -8,7 +8,7 @@ from lancedb.pydantic import LanceModel, Vector
from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.schemas.message import Message, Passage, Record from memgpt.data_types import Message, Passage, Record
""" Initial implementation - not complete """ """ Initial implementation - not complete """

View File

@@ -5,14 +5,10 @@ We originally tried to use Llama Index VectorIndex, but their limited API was ex
import uuid import uuid
from abc import abstractmethod from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Type, Union from typing import Dict, Iterator, List, Optional, Tuple, Type, Union
from pydantic import BaseModel
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
from memgpt.schemas.document import Document from memgpt.data_types import Document, Message, Passage, Record, RecordType
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.utils import printd from memgpt.utils import printd
@@ -39,7 +35,7 @@ DOCUMENT_TABLE_NAME = "memgpt_documents" # original documents (from source)
class StorageConnector: class StorageConnector:
"""Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory""" """Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
type: Type[BaseModel] type: Type[Record]
def __init__( def __init__(
self, self,
@@ -140,15 +136,15 @@ class StorageConnector:
pass pass
@abstractmethod @abstractmethod
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000): def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
pass pass
@abstractmethod @abstractmethod
def get_all(self, filters: Optional[Dict] = {}, limit=10): def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
pass pass
@abstractmethod @abstractmethod
def get(self, id: uuid.UUID): def get(self, id: uuid.UUID) -> Optional[RecordType]:
pass pass
@abstractmethod @abstractmethod
@@ -156,15 +152,15 @@ class StorageConnector:
pass pass
@abstractmethod @abstractmethod
def insert(self, record): def insert(self, record: RecordType):
pass pass
@abstractmethod @abstractmethod
def insert_many(self, records, show_progress=False): def insert_many(self, records: List[RecordType], show_progress=False):
pass pass
@abstractmethod @abstractmethod
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}): def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
pass pass
@abstractmethod @abstractmethod
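The `query` contract above (a query vector, `top_k`, optional filters) can be satisfied in the simplest case by a brute-force cosine scan. A minimal in-memory sketch, not any of the project's actual connectors:

    import numpy as np

    def cosine_query(records, query_vec, top_k=10):
        # records: objects with an .embedding list; returns the top_k most similar
        q = np.asarray(query_vec, dtype=float)
        q = q / np.linalg.norm(q)

        def score(record):
            v = np.asarray(record.embedding, dtype=float)
            return float(np.dot(q, v / np.linalg.norm(v)))

        return sorted(records, key=score, reverse=True)[:top_k]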

View File

@@ -5,7 +5,7 @@ from typing import Optional
from colorama import Fore, Style, init from colorama import Fore, Style, init
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.schemas.message import Message from memgpt.data_types import Message
init(autoreset=True) init(autoreset=True)

View File

@@ -3,6 +3,7 @@ import logging
import os import os
import subprocess import subprocess
import sys import sys
import uuid
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from typing import Annotated, Optional from typing import Annotated, Optional
@@ -18,12 +19,12 @@ from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig, User
from memgpt.log import get_logger from memgpt.log import get_logger
from memgpt.memory import ChatMemory
from memgpt.metadata import MetadataStore from memgpt.metadata import MetadataStore
from memgpt.schemas.embedding_config import EmbeddingConfig from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.schemas.enums import OptionState from memgpt.models.pydantic_models import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import ChatMemory, Memory
from memgpt.server.constants import WS_DEFAULT_PORT from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger from memgpt.server.server import logger as server_logger
@@ -36,6 +37,14 @@ from memgpt.utils import open_folder_in_explorer, printd
logger = get_logger(__name__) logger = get_logger(__name__)
def migrate(
debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False,
):
"""Migrate old agents (pre 0.2.12) to the new database system"""
migrate_all_agents(debug=debug)
migrate_all_sources(debug=debug)
class QuickstartChoice(Enum): class QuickstartChoice(Enum):
openai = "openai" openai = "openai"
# azure = "azure" # azure = "azure"
@@ -171,10 +180,13 @@ def quickstart(
else: else:
# Load the file from the relative path # Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located script_dir = os.path.dirname(__file__) # Get the directory where the script is located
# print("SCRIPT", script_dir)
backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json") backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
# print("FILE PATH", backup_config_path)
try: try:
with open(backup_config_path, "r", encoding="utf-8") as file: with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file) backup_config = json.load(file)
# print(backup_config)
printd("Loaded config file successfully.") printd("Loaded config file successfully.")
new_config, config_was_modified = set_config_with_dict(backup_config) new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError: except FileNotFoundError:
@@ -201,6 +213,7 @@ def quickstart(
# Parse the response content as JSON # Parse the response content as JSON
config = response.json() config = response.json()
# Output a success message and the first few items in the dictionary as a sample # Output a success message and the first few items in the dictionary as a sample
print("JSON config file downloaded successfully.")
new_config, config_was_modified = set_config_with_dict(config) new_config, config_was_modified = set_config_with_dict(config)
else: else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED) typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
@@ -279,6 +292,21 @@ class ServerChoice(Enum):
ws_api = "websocket" ws_api = "websocket"
def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore):
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
else:
return user
else:
return user
def server( def server(
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest", type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None, port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
@@ -295,8 +323,8 @@ def server(
if MemGPTConfig.exists(): if MemGPTConfig.exists():
config = MemGPTConfig.load() config = MemGPTConfig.load()
MetadataStore(config) ms = MetadataStore(config)
client = create_client() # triggers user creation create_default_user_or_exit(config, ms)
else: else:
typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED) typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED)
sys.exit(1) sys.exit(1)
@@ -416,42 +444,42 @@ def run(
logger.setLevel(logging.CRITICAL) logger.setLevel(logging.CRITICAL)
server_logger.setLevel(logging.CRITICAL) server_logger.setLevel(logging.CRITICAL)
# from memgpt.migrate import ( from memgpt.migrate import (
# VERSION_CUTOFF, VERSION_CUTOFF,
# config_is_compatible, config_is_compatible,
# wipe_config_and_reconfigure, wipe_config_and_reconfigure,
# ) )
# if not config_is_compatible(allow_empty=True): if not config_is_compatible(allow_empty=True):
# typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED) typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
# choices = [ choices = [
# "Run the full config setup (recommended)", "Run the full config setup (recommended)",
# "Create a new config using defaults", "Create a new config using defaults",
# "Cancel", "Cancel",
# ] ]
# selection = questionary.select( selection = questionary.select(
# f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?", f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
# choices=choices, choices=choices,
# default=choices[0], default=choices[0],
# ).ask() ).ask()
# if selection == choices[0]: if selection == choices[0]:
# try: try:
# wipe_config_and_reconfigure() wipe_config_and_reconfigure()
# except Exception as e: except Exception as e:
# typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED) typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
# raise raise
# elif selection == choices[1]: elif selection == choices[1]:
# try: try:
# # Don't create a config, so that the next block of code asking about quickstart is run # Don't create a config, so that the next block of code asking about quickstart is run
# wipe_config_and_reconfigure(run_configure=False, create_config=False) wipe_config_and_reconfigure(run_configure=False, create_config=False)
# except Exception as e: except Exception as e:
# typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED) typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
# raise raise
# else: else:
# typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED) typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
# raise KeyboardInterrupt() raise KeyboardInterrupt()
# typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN) typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
if not MemGPTConfig.exists(): if not MemGPTConfig.exists():
# if no config, ask about quickstart # if no config, ask about quickstart
@@ -496,12 +524,11 @@ def run(
# read user id from config # read user id from config
ms = MetadataStore(config) ms = MetadataStore(config)
client = create_client() user = create_default_user_or_exit(config, ms)
client.user_id
# determine agent to use, if not provided # determine agent to use, if not provided
if not yes and not agent: if not yes and not agent:
agents = client.list_agents() agents = ms.list_agents(user_id=user.id)
agents = [a.name for a in agents] agents = [a.name for a in agents]
if len(agents) > 0: if len(agents) > 0:
@@ -513,11 +540,7 @@ def run(
agent = questionary.select("Select agent:", choices=agents).ask() agent = questionary.select("Select agent:", choices=agents).ask()
# create agent config # create agent config
if agent: agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None
agent_id = client.get_agent_id(agent)
agent_state = client.get_agent(agent_id)
else:
agent_state = None
human = human if human else config.human human = human if human else config.human
persona = persona if persona else config.persona persona = persona if persona else config.persona
if agent and agent_state: # use existing agent if agent and agent_state: # use existing agent
@@ -574,12 +597,13 @@ def run(
# agent_state.state["system"] = system # agent_state.state["system"] = system
# Update the agent with any overrides # Update the agent with any overrides
agent_state = client.update_agent( ms.update_agent(agent_state)
agent_id=agent_state.id, tools = []
name=agent_state.name, for tool_name in agent_state.tools:
llm_config=agent_state.llm_config, tool = ms.get_tool(tool_name, agent_state.user_id)
embedding_config=agent_state.embedding_config, if tool is None:
) typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED)
tools.append(tool)
# create agent # create agent
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools) memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
@@ -622,52 +646,55 @@ def run(
llm_config.model_endpoint_type = model_endpoint_type llm_config.model_endpoint_type = model_endpoint_type
# create agent # create agent
client = create_client() try:
human_obj = client.get_human(client.get_human_id(name=human)) client = create_client()
persona_obj = client.get_persona(client.get_persona_id(name=persona)) human_obj = ms.get_human(human, user.id)
if human_obj is None: persona_obj = ms.get_persona(persona, user.id)
typer.secho(f"Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED) # TODO pull system prompts from the metadata store
# NOTE: will be overriden later to a default
if system_file:
try:
with open(system_file, "r", encoding="utf-8") as file:
system = file.read().strip()
printd("Loaded system file successfully.")
except FileNotFoundError:
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
system_prompt = system if system else None
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
memory = ChatMemory(human=human_obj.text, persona=persona_obj.text, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(
name=agent_name,
system_prompt=system_prompt,
embedding_config=embedding_config,
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
memgpt_agent = Agent(
interface=interface(),
agent_state=agent_state,
tools=tools,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)
save_agent(agent=memgpt_agent, ms=ms)
except ValueError as e:
typer.secho(f"Failed to create agent from provided information:\n{e}", fg=typer.colors.RED)
sys.exit(1) sys.exit(1)
if persona_obj is None:
typer.secho(f"Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
sys.exit(1)
if system_file:
try:
with open(system_file, "r", encoding="utf-8") as file:
system = file.read().strip()
printd("Loaded system file successfully.")
except FileNotFoundError:
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
system_prompt = system if system else None
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(
name=agent_name,
system=system_prompt,
embedding_config=embedding_config,
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
memgpt_agent = Agent(
interface=interface(),
agent_state=agent_state,
tools=tools,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)
save_agent(agent=memgpt_agent, ms=ms)
typer.secho(f"🎉 Created new agent '{memgpt_agent.agent_state.name}' (id={memgpt_agent.agent_state.id})", fg=typer.colors.GREEN) typer.secho(f"🎉 Created new agent '{memgpt_agent.agent_state.name}' (id={memgpt_agent.agent_state.id})", fg=typer.colors.GREEN)
# start event loop # start event loop
@@ -692,10 +719,19 @@ def delete_agent(
"""Delete an agent from the database""" """Delete an agent from the database"""
# use client ID if no user_id provided # use client ID if no user_id provided
config = MemGPTConfig.load() config = MemGPTConfig.load()
MetadataStore(config) ms = MetadataStore(config)
client = create_client(user_id=user_id) if user_id is None:
agent = client.get_agent_by_name(agent_name) user = create_default_user_or_exit(config, ms)
if not agent: else:
user = ms.get_user(user_id=uuid.UUID(user_id))
try:
agent = ms.get_agent(agent_name=agent_name, user_id=user.id)
except Exception as e:
typer.secho(f"Failed to get agent {agent_name}\n{e}", fg=typer.colors.RED)
sys.exit(1)
if agent is None:
typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED) typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED)
sys.exit(1) sys.exit(1)
@@ -707,8 +743,7 @@ def delete_agent(
return return
try: try:
# delete the agent ms.delete_agent(agent_id=agent.id)
client.delete_agent(agent.id)
typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN) typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN)
except Exception: except Exception:
typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED) typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED)

View File

@@ -1,8 +1,8 @@
import ast
import builtins import builtins
import os import os
import uuid
from enum import Enum from enum import Enum
from typing import Annotated, List, Optional from typing import Annotated, Optional
import questionary import questionary
import typer import typer
@@ -13,6 +13,7 @@ from memgpt import utils
from memgpt.config import MemGPTConfig from memgpt.config import MemGPTConfig
from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR
from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig, Source, User
from memgpt.llm_api.anthropic import ( from memgpt.llm_api.anthropic import (
anthropic_get_model_list, anthropic_get_model_list,
antropic_get_model_context_window, antropic_get_model_context_window,
@@ -35,8 +36,7 @@ from memgpt.local_llm.constants import (
DEFAULT_WRAPPER_NAME, DEFAULT_WRAPPER_NAME,
) )
from memgpt.local_llm.utils import get_available_wrappers from memgpt.local_llm.utils import get_available_wrappers
from memgpt.schemas.embedding_config import EmbeddingConfig from memgpt.metadata import MetadataStore
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.utils import shorten_key_middle from memgpt.server.utils import shorten_key_middle
app = typer.Typer() app = typer.Typer()
@@ -1070,10 +1070,17 @@ def configure():
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN) typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
config.save() config.save()
from memgpt import create_client # create user records
ms = MetadataStore(config)
client = create_client() user_id = uuid.UUID(config.anon_clientid)
print("User ID:", client.user_id) user = User(
id=uuid.UUID(config.anon_clientid),
)
if ms.get_user(user_id):
# update user
ms.update_user(user)
else:
ms.create_user(user)
class ListChoice(str, Enum): class ListChoice(str, Enum):
@@ -1087,14 +1094,17 @@ class ListChoice(str, Enum):
def list(arg: Annotated[ListChoice, typer.Argument]): def list(arg: Annotated[ListChoice, typer.Argument]):
from memgpt.client.client import create_client from memgpt.client.client import create_client
client = create_client() client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
table = ColorTable(theme=Themes.OCEAN) table = ColorTable(theme=Themes.OCEAN)
if arg == ListChoice.agents: if arg == ListChoice.agents:
"""List all agents""" """List all agents"""
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"] table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
for agent in tqdm(client.list_agents()): for agent in tqdm(client.list_agents()):
# TODO: add this function # TODO: add this function
sources = client.list_attached_sources(agent_id=agent.id) source_ids = client.list_attached_sources(agent_id=agent.id)
assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids])
sources = [client.get_source(source_id=source_id) for source_id in source_ids]
assert all([source is not None and isinstance(source, Source) for source in sources])
source_names = [source.name for source in sources if source is not None] source_names = [source.name for source in sources if source is not None]
table.add_row( table.add_row(
[ [
@@ -1102,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
agent.llm_config.model, agent.llm_config.model,
agent.embedding_config.embedding_model, agent.embedding_config.embedding_model,
agent.embedding_config.embedding_dim, agent.embedding_config.embedding_dim,
agent.memory.get_block("persona").value[:100] + "...", agent._metadata.get("persona", ""),
agent.memory.get_block("human").value[:100] + "...", agent._metadata.get("human", ""),
",".join(source_names), ",".join(source_names),
utils.format_datetime(agent.created_at), utils.format_datetime(agent.created_at),
] ]
@@ -1113,13 +1123,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all humans""" """List all humans"""
table.field_names = ["Name", "Text"] table.field_names = ["Name", "Text"]
for human in client.list_humans(): for human in client.list_humans():
table.add_row([human.name, human.value.replace("\n", "")[:100]]) table.add_row([human.name, human.text.replace("\n", "")[:100]])
print(table) print(table)
elif arg == ListChoice.personas: elif arg == ListChoice.personas:
"""List all personas""" """List all personas"""
table.field_names = ["Name", "Text"] table.field_names = ["Name", "Text"]
for persona in client.list_personas(): for persona in client.list_personas():
table.add_row([persona.name, persona.value.replace("\n", "")[:100]]) table.add_row([persona.name, persona.text.replace("\n", "")[:100]])
print(table) print(table)
elif arg == ListChoice.sources: elif arg == ListChoice.sources:
"""List all data sources""" """List all data sources"""
@@ -1149,63 +1159,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
return table return table
@app.command()
def add_tool(
filename: str = typer.Option(..., help="Path to the Python file containing the function"),
name: Optional[str] = typer.Option(None, help="Name of the tool"),
update: bool = typer.Option(True, help="Update the tool if it already exists"),
tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"),
):
"""Add or update a tool from a Python file."""
from memgpt.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
# 1. Parse the Python file
with open(filename, "r", encoding="utf-8") as file:
source_code = file.read()
# 2. Parse the source code to extract the function
# Note: here we assume it is one function only in the file.
module = ast.parse(source_code)
func_def = None
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_def = node
break
if not func_def:
raise ValueError("No function found in the provided file")
# 3. Compile the function to make it callable
# Explanation courtesy of GPT-4:
# Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object
# ast.Module creates a module node containing the function definition (func_def)
# compile converts the AST into a code object that can be executed by the Python interpreter
# The exec function executes the compiled code object in the current context,
# effectively defining the function within the current namespace
exec(compile(ast.Module([func_def], []), filename, "exec"))
# Retrieve the function object by evaluating its name in the current namespace
# eval looks up the function name in the current scope and returns the function object
func = eval(func_def.name)
# 4. Add or update the tool
tool = client.create_tool(func=func, name=name, tags=tags, update=update)
print(f"Tool {tool.name} added successfully")
@app.command()
def list_tools():
"""List all available tools."""
from memgpt.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
tools = client.list_tools()
for tool in tools:
print(f"Tool: {tool.name}")
@app.command() @app.command()
def add( def add(
option: str, # [human, persona] option: str, # [human, persona]
@@ -1221,27 +1174,23 @@ def add(
assert text is None, "Cannot specify both text and filename" assert text is None, "Cannot specify both text and filename"
with open(filename, "r", encoding="utf-8") as f: with open(filename, "r", encoding="utf-8") as f:
text = f.read() text = f.read()
else:
assert text is not None, "Must specify either text or filename"
if option == "persona": if option == "persona":
persona_id = client.get_persona_id(name) persona = client.get_persona(name)
if persona_id: if persona:
client.get_persona(persona_id)
# config if user wants to overwrite # config if user wants to overwrite
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask(): if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
return return
client.update_persona(persona_id, text=text) client.update_persona(name=name, text=text)
else: else:
client.create_persona(name=name, text=text) client.create_persona(name=name, text=text)
elif option == "human": elif option == "human":
human_id = client.get_human_id(name) human = client.get_human(name=name)
if human_id: if human:
human = client.get_human(human_id)
# config if user wants to overwrite # config if user wants to overwrite
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask(): if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
return return
client.update_human(human_id, text=text) client.update_human(name=name, text=text)
else: else:
human = client.create_human(name=name, text=text) human = client.create_human(name=name, text=text)
else: else:
@@ -1258,21 +1207,21 @@ def delete(option: str, name: str):
# delete from metadata # delete from metadata
if option == "source": if option == "source":
# delete metadata # delete metadata
source_id = client.get_source_id(name) source = client.get_source(name)
assert source_id is not None, f"Source {name} does not exist" assert source is not None, f"Source {name} does not exist"
client.delete_source(source_id) client.delete_source(source_id=source.id)
elif option == "agent": elif option == "agent":
agent_id = client.get_agent_id(name) agent = client.get_agent(agent_name=name)
assert agent_id is not None, f"Agent {name} does not exist" assert agent is not None, f"Agent {name} does not exist"
client.delete_agent(agent_id=agent_id) client.delete_agent(agent_id=agent.id)
elif option == "human": elif option == "human":
human_id = client.get_human_id(name) human = client.get_human(name=name)
assert human_id is not None, f"Human {name} does not exist" assert human is not None, f"Human {name} does not exist"
client.delete_human(human_id) client.delete_human(name=name)
elif option == "persona": elif option == "persona":
persona_id = client.get_persona_id(name) persona = client.get_persona(name=name)
assert persona_id is not None, f"Persona {name} does not exist" assert persona is not None, f"Persona {name} does not exist"
client.delete_persona(persona_id) client.delete_persona(name=name)
else: else:
raise ValueError(f"Option {option} not implemented") raise ValueError(f"Option {option} not implemented")

View File

@@ -13,8 +13,15 @@ from typing import Annotated, List, Optional
import typer import typer
from memgpt import create_client from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.data_sources.connectors import DirectoryConnector from memgpt.config import MemGPTConfig
from memgpt.data_sources.connectors import (
DirectoryConnector,
VectorDBConnector,
load_data,
)
from memgpt.data_types import Source
from memgpt.metadata import MetadataStore
app = typer.Typer() app = typer.Typer()
@@ -82,20 +89,41 @@ def load_directory(
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None, description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
): ):
client = create_client()
# create connector
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
# create source
source = client.create_source(name=name)
# load data
try: try:
client.load_data(connector, source_name=name) connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
except Exception as e: config = MemGPTConfig.load()
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) if not user_id:
client.delete_source(source.id) user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
description=description,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
# ingest data into passage/document store
try:
num_passages, num_documents = load_data(
connector=connector,
source=source,
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
ms.delete_source(source_id=source.id)
except ValueError as e:
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
raise
# @app.command("webpage") # @app.command("webpage")
@@ -111,6 +139,56 @@ def load_directory(
# #
# except ValueError as e: # except ValueError as e:
# typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED) # typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
#
#
# @app.command("database")
# def load_database(
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
# query: Annotated[str, typer.Option(help="Database query.")],
# dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
# scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
# host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
# port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
# user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
# password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
# dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
# ):
# try:
# from llama_index.readers.database import DatabaseReader
#
# print(dump_path, scheme)
#
# if dump_path is not None:
# # read from database dump file
# from sqlalchemy import create_engine
#
# engine = create_engine(f"sqlite:///{dump_path}")
#
# db = DatabaseReader(engine=engine)
# else:
# assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
# assert scheme is not None, "Must provide database scheme."
# assert host is not None, "Must provide database host."
# assert port is not None, "Must provide database port."
# assert user is not None, "Must provide database user."
# assert password is not None, "Must provide database password."
# assert dbname is not None, "Must provide database name."
#
# db = DatabaseReader(
# scheme=scheme, # Database Scheme
# host=host, # Database Host
# port=str(port), # Database Port
# user=user, # Database User
# password=password, # Database Password
# dbname=dbname, # Database Name
# )
#
# # load data
# docs = db.load_data(query=query)
# store_docs(name, docs)
# except ValueError as e:
# typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
#
@app.command("vector-database") @app.command("vector-database")
@@ -123,44 +201,43 @@ def load_vector_database(
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
): ):
"""Load pre-computed embeddings into MemGPT from a database.""" """Load pre-computed embeddings into MemGPT from a database."""
raise NotImplementedError try:
# try: config = MemGPTConfig.load()
# config = MemGPTConfig.load() connector = VectorDBConnector(
# connector = VectorDBConnector( uri=uri,
# uri=uri, table_name=table_name,
# table_name=table_name, text_column=text_column,
# text_column=text_column, embedding_column=embedding_column,
# embedding_column=embedding_column, embedding_dim=config.default_embedding_config.embedding_dim,
# embedding_dim=config.default_embedding_config.embedding_dim, )
# ) if not user_id:
# if not user_id: user_id = uuid.UUID(config.anon_clientid)
# user_id = uuid.UUID(config.anon_clientid)
# ms = MetadataStore(config) ms = MetadataStore(config)
# source = Source( source = Source(
# name=name, name=name,
# user_id=user_id, user_id=user_id,
# embedding_model=config.default_embedding_config.embedding_model, embedding_model=config.default_embedding_config.embedding_model,
# embedding_dim=config.default_embedding_config.embedding_dim, embedding_dim=config.default_embedding_config.embedding_dim,
# ) )
# ms.create_source(source) ms.create_source(source)
# passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# # TODO: also get document store # TODO: also get document store
# # ingest data into passage/document store # ingest data into passage/document store
# try: try:
# num_passages, num_documents = load_data( num_passages, num_documents = load_data(
# connector=connector, connector=connector,
# source=source, source=source,
# embedding_config=config.default_embedding_config, embedding_config=config.default_embedding_config,
# document_store=None, document_store=None,
# passage_store=passage_storage, passage_store=passage_storage,
# ) )
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}") print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
# except Exception as e: except Exception as e:
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
# ms.delete_source(source_id=source.id) ms.delete_source(source_id=source.id)
# except ValueError as e: except ValueError as e:
# typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED) typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
# raise raise

View File

@@ -1,3 +1,4 @@
import uuid
from typing import List, Optional from typing import List, Optional
import requests import requests
@@ -5,8 +6,19 @@ from requests import HTTPError
from memgpt.functions.functions import parse_source_code from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema from memgpt.functions.schema_generator import generate_schema
from memgpt.schemas.api_key import APIKey, APIKeyCreate from memgpt.server.rest_api.admin.tools import (
from memgpt.schemas.user import User, UserCreate CreateToolRequest,
ListToolsResponse,
ToolModel,
)
from memgpt.server.rest_api.admin.users import (
CreateAPIKeyResponse,
CreateUserResponse,
DeleteAPIKeyResponse,
DeleteUserResponse,
GetAllUsersResponse,
GetAPIKeysResponse,
)
class Admin: class Admin:
@@ -21,7 +33,7 @@ class Admin:
self.token = token self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"} self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
def get_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[User]: def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
params = {} params = {}
if cursor: if cursor:
params["cursor"] = str(cursor) params["cursor"] = str(cursor)
@@ -30,54 +42,54 @@ class Admin:
response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers) response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
return [User(**user) for user in response.json()] return GetAllUsersResponse(**response.json())
def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey: def create_key(self, user_id: uuid.UUID, key_name: str):
request = APIKeyCreate(user_id=user_id, name=key_name) payload = {"user_id": str(user_id), "key_name": key_name}
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump()) response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
return APIKey(**response.json()) return CreateAPIKeyResponse(**response.json())
def get_keys(self, user_id: str) -> List[APIKey]: def get_keys(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)} params = {"user_id": str(user_id)}
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
return [APIKey(**key) for key in response.json()] return GetAPIKeysResponse(**response.json()).api_key_list
def delete_key(self, api_key: str) -> APIKey: def delete_key(self, api_key: str):
params = {"api_key": api_key} params = {"api_key": api_key}
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers) response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
return APIKey(**response.json()) return DeleteAPIKeyResponse(**response.json())
def create_user(self, name: Optional[str] = None) -> User: def create_user(self, user_id: Optional[uuid.UUID] = None):
request = UserCreate(name=name) payload = {"user_id": str(user_id) if user_id else None}
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump()) response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
response_json = response.json() response_json = response.json()
return User(**response_json) return CreateUserResponse(**response_json)
def delete_user(self, user_id: str) -> User: def delete_user(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)} params = {"user_id": str(user_id)}
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers) response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
return User(**response.json()) return DeleteUserResponse(**response.json())
def _reset_server(self): def _reset_server(self):
# DANGER: this will delete all users and keys # DANGER: this will delete all users and keys
# clear all state associated with users # clear all state associated with users
# TODO: clear out all agents, presets, etc. # TODO: clear out all agents, presets, etc.
users = self.get_users() users = self.get_users().user_list
for user in users: for user in users:
keys = self.get_keys(user.id) keys = self.get_keys(user["user_id"])
for key in keys: for key in keys:
self.delete_key(key.key) self.delete_key(key)
self.delete_user(user.id) self.delete_user(user["user_id"])
# tools # tools
def create_tool( def create_tool(
@@ -119,7 +131,7 @@ class Admin:
raise ValueError(f"Failed to create tool: {response.text}") raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json()) return ToolModel(**response.json())
def list_tools(self): def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers) response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
return ListToolsResponse(**response.json()).tools return ListToolsResponse(**response.json()).tools
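Taken together, these endpoints give the `Admin` client a full user/key lifecycle; `_reset_server` above already shows the shape. A hedged usage sketch against a local server (the URL and token are placeholders):

    admin = Admin(base_url="http://localhost:8283", token="test_server_token")
    admin.create_user()
    for user in admin.get_users().user_list:        # GetAllUsersResponse.user_list, per above
        for key in admin.get_keys(user["user_id"]):
            admin.delete_key(key)
        admin.delete_user(user["user_id"])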

File diff suppressed because it is too large

View File

@@ -4,7 +4,6 @@ import json
import os import os
import uuid import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import memgpt import memgpt
import memgpt.utils as utils import memgpt.utils as utils
@@ -16,10 +15,8 @@ from memgpt.constants import (
DEFAULT_PRESET, DEFAULT_PRESET,
MEMGPT_DIR, MEMGPT_DIR,
) )
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig
from memgpt.log import get_logger from memgpt.log import get_logger
from memgpt.schemas.agent import AgentState
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -102,19 +99,19 @@ class MemGPTConfig:
return uuid.UUID(int=uuid.getnode()).hex return uuid.UUID(int=uuid.getnode()).hex
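`generate_uuid` above derives the anonymous client id from the machine's MAC address via `uuid.getnode()`, so the id is stable across runs on the same host:

    import uuid

    anon = uuid.UUID(int=uuid.getnode()).hex
    # Deterministic on a given machine: regenerating yields the same value
    assert anon == uuid.UUID(int=uuid.getnode()).hex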
@classmethod @classmethod
def load(cls, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None) -> "MemGPTConfig": def load(cls) -> "MemGPTConfig":
# avoid circular import # avoid circular import
from memgpt.migrate import VERSION_CUTOFF, config_is_compatible
from memgpt.utils import printd from memgpt.utils import printd
# from memgpt.migrate import VERSION_CUTOFF, config_is_compatible if not config_is_compatible(allow_empty=True):
# if not config_is_compatible(allow_empty=True): error_message = " ".join(
# error_message = " ".join( [
# [ f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.",
# f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.", f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.",
# f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.", ]
# ] )
# ) raise ValueError(error_message)
# raise ValueError(error_message)
config = configparser.ConfigParser() config = configparser.ConfigParser()
@@ -192,9 +189,6 @@ class MemGPTConfig:
return cls(**config_dict) return cls(**config_dict)
# assert embedding_config is not None, "Embedding config must be provided if config does not exist"
# assert llm_config is not None, "LLM config must be provided if config does not exist"
# create new config # create new config
anon_clientid = MemGPTConfig.generate_uuid() anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path) config = cls(anon_clientid=anon_clientid, config_path=config_path)

View File

@@ -4,10 +4,8 @@ import typer
from llama_index.core import Document as LlamaIndexDocument from llama_index.core import Document as LlamaIndexDocument
from memgpt.agent_store.storage import StorageConnector from memgpt.agent_store.storage import StorageConnector
from memgpt.data_types import Document, EmbeddingConfig, Passage, Source
from memgpt.embeddings import embedding_model from memgpt.embeddings import embedding_model
from memgpt.schemas.document import Document
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source
from memgpt.utils import create_uuid_from_string from memgpt.utils import create_uuid_from_string
@@ -22,11 +20,17 @@ class DataConnector:
def load_data( def load_data(
connector: DataConnector, connector: DataConnector,
source: Source, source: Source,
embedding_config: EmbeddingConfig,
passage_store: StorageConnector, passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None, document_store: Optional[StorageConnector] = None,
): ):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id.""" """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
embedding_config = source.embedding_config assert (
source.embedding_model == embedding_config.embedding_model
), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}"
assert (
source.embedding_dim == embedding_config.embedding_dim
), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}."
# embedding model # embedding model
embed_model = embedding_model(embedding_config) embed_model = embedding_model(embedding_config)
@@ -39,9 +43,10 @@ def load_data(
for document_text, document_metadata in connector.generate_documents(): for document_text, document_metadata in connector.generate_documents():
# insert document into storage # insert document into storage
document = Document( document = Document(
id=create_uuid_from_string(f"{str(source.id)}_{document_text}"),
text=document_text, text=document_text,
metadata_=document_metadata, metadata=document_metadata,
source_id=source.id, data_source=source.name,
user_id=source.user_id, user_id=source.user_id,
) )
document_count += 1 document_count += 1
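Deriving document and passage ids with `create_uuid_from_string` makes loads idempotent: re-ingesting identical text under the same source maps to the same primary key. A standard-library stand-in (assuming the helper hashes deterministically; `uuid5` here is an illustration, not the actual implementation):

    import uuid

    def create_uuid_from_string_sketch(s: str) -> uuid.UUID:
        # Same input string -> same UUID, every time
        return uuid.uuid5(uuid.NAMESPACE_DNS, s)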
@@ -73,15 +78,16 @@ def load_data(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"), id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text, text=passage_text,
doc_id=document.id, doc_id=document.id,
source_id=source.id,
metadata_=passage_metadata, metadata_=passage_metadata,
user_id=source.user_id, user_id=source.user_id,
embedding_config=source.embedding_config, data_source=source.name,
embedding_dim=source.embedding_dim,
embedding_model=source.embedding_model,
embedding=embedding, embedding=embedding,
) )
hashable_embedding = tuple(passage.embedding) hashable_embedding = tuple(passage.embedding)
document_name = document.metadata_.get("file_path", document.id) document_name = document.metadata.get("file_path", document.id)
if hashable_embedding in embedding_to_document_name: if hashable_embedding in embedding_to_document_name:
typer.secho( typer.secho(
f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.", f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
@@ -144,7 +150,7 @@ class DirectoryConnector(DataConnector):
parser = TokenTextSplitter(chunk_size=chunk_size) parser = TokenTextSplitter(chunk_size=chunk_size)
for document in documents: for document in documents:
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)] llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata)]
nodes = parser.get_nodes_from_documents(llama_index_docs) nodes = parser.get_nodes_from_documents(llama_index_docs)
for node in nodes: for node in nodes:
# passage = Passage( # passage = Passage(

View File

@@ -1,18 +1,75 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import copy import copy
import json import json
import uuid
import warnings import warnings
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import List, Optional, Union from typing import Dict, List, Optional, TypeVar
from pydantic import Field, field_validator import numpy as np
from pydantic import BaseModel, Field
from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN from memgpt.constants import (
DEFAULT_HUMAN,
DEFAULT_PERSONA,
DEFAULT_PRESET,
JSON_ENSURE_ASCII,
LLM_MAX_TOKENS,
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole from memgpt.prompts import gpt_system
from memgpt.schemas.memgpt_base import MemGPTBase from memgpt.utils import (
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage create_uuid_from_string,
from memgpt.schemas.openai.chat_completions import ToolCall get_human_text,
from memgpt.utils import get_utc_time, is_utc_datetime get_persona_text,
get_utc_time,
is_utc_datetime,
)
class Record:
"""
Base class for an agent's memory unit. Each memory unit is represented in the database as a single row.
Memory units are searched over by functions defined in the memory classes
"""
def __init__(self, id: Optional[uuid.UUID] = None):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
# This allows type checking to work when you pass a Passage into a function expecting List[Record]
# (just use List[RecordType] instead)
RecordType = TypeVar("RecordType", bound="Record")
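The bound `TypeVar` lets a function accept any `Record` subclass and be typed as returning that same subclass, which is exactly what the storage connectors need. A small illustration:

    from typing import List, TypeVar

    class Record: ...
    class Passage(Record): ...

    RecordType = TypeVar("RecordType", bound=Record)

    def first_n(records: List[RecordType], n: int) -> List[RecordType]:
        # A type checker infers List[Passage] in -> List[Passage] out
        return records[:n]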
class ToolCall(object):
def __init__(
self,
id: str,
# TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
# NOTE: called ToolCall.type in official OpenAI schema
tool_call_type: str, # only 'function' is supported
# function: { 'name': ..., 'arguments': ...}
function: Dict[str, str],
):
self.id = id
self.tool_call_type = tool_call_type
self.function = function
def to_dict(self):
return {
"id": self.id,
"type": self.tool_call_type,
"function": self.function,
}
def add_inner_thoughts_to_tool_call( def add_inner_thoughts_to_tool_call(
@@ -24,34 +81,20 @@ def add_inner_thoughts_to_tool_call(
# because the kwargs are stored as strings, we need to load then write the JSON dicts # because the kwargs are stored as strings, we need to load then write the JSON dicts
try: try:
# load the args list # load the args list
func_args = json.loads(tool_call.function.arguments) func_args = json.loads(tool_call.function["arguments"])
# add the inner thoughts to the args list # add the inner thoughts to the args list
func_args[inner_thoughts_key] = inner_thoughts func_args[inner_thoughts_key] = inner_thoughts
# create the updated tool call (as a string) # create the updated tool call (as a string)
updated_tool_call = copy.deepcopy(tool_call) updated_tool_call = copy.deepcopy(tool_call)
updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII) updated_tool_call.function["arguments"] = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
return updated_tool_call return updated_tool_call
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
# TODO: change to logging
warnings.warn(f"Failed to put inner thoughts in kwargs: {e}") warnings.warn(f"Failed to put inner thoughts in kwargs: {e}")
raise e raise e
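Because OpenAI tool-call arguments travel as a JSON string, injecting the inner-thoughts kwarg is a load/modify/dump round trip, as the function above does. The mechanism with illustrative values:

    import json

    arguments = json.dumps({"city": "Berlin"})
    func_args = json.loads(arguments)                        # str -> dict
    func_args["inner_thoughts"] = "User asked about weather."
    arguments = json.dumps(func_args)                        # dict -> str, stored back on the call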
class BaseMessage(MemGPTBase): class Message(Record):
__id_prefix__ = "message" """Representation of a message sent.
class MessageCreate(BaseMessage):
"""Request to create a message"""
role: MessageRole = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
name: Optional[str] = Field(None, description="The name of the participant.")
class Message(BaseMessage):
"""
Representation of a message sent.
Messages can be: Messages can be:
- agent->user (role=='agent') - agent->user (role=='agent')
@@ -59,23 +102,65 @@ class Message(BaseMessage):
- or function/tool call returns (role=='function'/'tool'). - or function/tool call returns (role=='function'/'tool').
""" """
id: str = BaseMessage.generate_id_field() def __init__(
role: MessageRole = Field(..., description="The role of the participant.") self,
text: str = Field(..., description="The text of the message.") role: str,
user_id: str = Field(None, description="The unique identifier of the user.") text: str,
agent_id: str = Field(None, description="The unique identifier of the agent.") user_id: Optional[uuid.UUID] = None,
model: Optional[str] = Field(None, description="The model used to make the function call.") agent_id: Optional[uuid.UUID] = None,
name: Optional[str] = Field(None, description="The name of the participant.") model: Optional[str] = None, # model used to make function call
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.") name: Optional[str] = None, # optional participant name
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.") created_at: Optional[datetime] = None,
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.") tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
tool_call_id: Optional[str] = None,
# tool_call_name: Optional[str] = None, # not technically OpenAI spec, but it can be helpful to have on-hand
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
id: Optional[uuid.UUID] = None,
):
super().__init__(id)
self.user_id = user_id
self.agent_id = agent_id
self.text = text
self.model = model # model name (e.g. gpt-4)
self.created_at = created_at if created_at is not None else get_utc_time()
@field_validator("role") # openai info
@classmethod assert role in ["system", "assistant", "user", "tool"]
def validate_role(cls, v: str) -> str: self.role = role # role (agent/user/function)
roles = ["system", "assistant", "user", "tool"] self.name = name
assert v in roles, f"Role must be one of {roles}"
return v # pad and store embeddings
if isinstance(embedding, list):
embedding = np.array(embedding)
self.embedding = (
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
# tool (i.e. function) call info (optional)
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
assert tool_calls is None or isinstance(tool_calls, list)
if tool_calls is not None:
assert all([isinstance(tc, ToolCall) for tc in tool_calls]), f"Tool calls must be of type ToolCall, got {tool_calls}"
self.tool_calls = tool_calls
# if role == "tool", then this must be specified
# if role != "tool", this must be null
if role == "tool":
assert tool_call_id is not None
else:
assert tool_call_id is None
self.tool_call_id = tool_call_id
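The padding rule above stores every vector at a fixed width so mixed embedding models can share one storage column; a standalone sketch of the invariant (the MAX_EMBEDDING_DIM value here is illustrative, the real constant lives in memgpt.constants):

import numpy as np

MAX_EMBEDDING_DIM = 4096  # illustrative value

embedding = np.array([0.1, 0.2, 0.3])  # a toy 3-dim vector
padded = np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist()

assert len(padded) == MAX_EMBEDDING_DIM       # fixed storage width
assert padded[:3] == [0.1, 0.2, 0.3]          # original values preserved
assert all(v == 0.0 for v in padded[3:])      # zero padding on the right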
    def to_json(self):
        json_message = vars(self)

@@ -88,26 +173,16 @@ class Message(BaseMessage):
        json_message["created_at"] = self.created_at.isoformat()
        return json_message
-    def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
-        """Convert message object (in DB format) to the style used by the original MemGPT API
-
-        NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
-        """
-        raise NotImplementedError
-
     @staticmethod
     def dict_to_message(
-        user_id: str,
-        agent_id: str,
+        user_id: uuid.UUID,
+        agent_id: uuid.UUID,
         openai_message_dict: dict,
         model: Optional[str] = None,  # model used to make function call
         allow_functions_style: bool = False,  # allow deprecated functions style?
         created_at: Optional[datetime] = None,
     ):
         """Convert a ChatCompletion message object into a Message object (synced to DB)"""
-        if not created_at:
-            # timestamp for creation
-            created_at = get_utc_time()
-
         assert "role" in openai_message_dict, openai_message_dict
         assert "content" in openai_message_dict, openai_message_dict
@@ -121,6 +196,7 @@ class Message(BaseMessage):
             # Convert from 'function' response to a 'tool' response
             # NOTE: this does not conventionally include a tool_call_id, it's on the caster to provide it
             return Message(
+                created_at=created_at,
                 user_id=user_id,
                 agent_id=agent_id,
                 model=model,
@@ -130,7 +206,6 @@ class Message(BaseMessage):
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
-                created_at=created_at,
             )

         elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None:
@@ -153,6 +228,7 @@ class Message(BaseMessage):
             ]
             return Message(
+                created_at=created_at,
                 user_id=user_id,
                 agent_id=agent_id,
                 model=model,
@@ -162,7 +238,6 @@ class Message(BaseMessage):
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=None,  # NOTE: None, since this field is only non-null for role=='tool'
-                created_at=created_at,
             )

         else:
@@ -185,6 +260,7 @@ class Message(BaseMessage):
             # If we're going from tool-call style
             return Message(
+                created_at=created_at,
                 user_id=user_id,
                 agent_id=agent_id,
                 model=model,
@@ -194,7 +270,6 @@ class Message(BaseMessage):
                 name=openai_message_dict["name"] if "name" in openai_message_dict else None,
                 tool_calls=tool_calls,
                 tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
-                created_at=created_at,
             )
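For orientation, a usage sketch of dict_to_message (IDs and payload fabricated for illustration; the assistant branch converts the raw tool_calls dicts into ToolCall objects internally):

import uuid

openai_message = {
    "role": "assistant",
    "content": "User wants a greeting; calling send_message.",
    "tool_calls": [
        {
            "id": "call_abc123",
            "type": "function",
            "function": {"name": "send_message", "arguments": '{"message": "Hi!"}'},
        }
    ],
}

msg = Message.dict_to_message(
    user_id=uuid.uuid4(),
    agent_id=uuid.uuid4(),
    openai_message_dict=openai_message,
    model="gpt-4",
)
assert msg.role == "assistant" and msg.tool_call_id is None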
    def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:

@@ -248,11 +323,11 @@ class Message(BaseMessage):
                     tool_call,
                     inner_thoughts=self.text,
                     inner_thoughts_key=INNER_THOUGHTS_KWARG,
-                ).model_dump()
+                ).to_dict()
                 for tool_call in self.tool_calls
             ]
         else:
-            openai_message["tool_calls"] = [tool_call.model_dump() for tool_call in self.tool_calls]
+            openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]

         if max_tool_id_length:
             for tool_call_dict in openai_message["tool_calls"]:
                 tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]

@@ -548,3 +623,313 @@ class Message(BaseMessage):
             raise ValueError(self.role)

         return cohere_message
class Document(Record):
    """A document represents a document loaded into MemGPT, which is broken down into passages."""

    def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None, metadata: Optional[Dict] = {}):
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            self.id = create_uuid_from_string("".join([text, str(user_id)]))
        else:
            self.id = id
        super().__init__(self.id)
        self.user_id = user_id
        self.text = text
        self.data_source = data_source
        self.metadata = metadata
        # TODO: add optional embedding?


class Passage(Record):
    """A passage is a single unit of memory, and a standard format across all storage backends.

    It is a string of text with an associated embedding.
    """

    def __init__(
        self,
        text: str,
        user_id: Optional[uuid.UUID] = None,
        agent_id: Optional[uuid.UUID] = None,  # set if contained in agent memory
        embedding: Optional[np.ndarray] = None,
        embedding_dim: Optional[int] = None,
        embedding_model: Optional[str] = None,
        data_source: Optional[str] = None,  # None if created by agent
        doc_id: Optional[uuid.UUID] = None,
        id: Optional[uuid.UUID] = None,
        metadata_: Optional[dict] = {},
        created_at: Optional[datetime] = None,
    ):
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            # TODO: use source-id instead?
            if agent_id:
                self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
            else:
                self.id = create_uuid_from_string("".join([text, str(user_id)]))
        else:
            self.id = id
        super().__init__(self.id)
        self.user_id = user_id
        self.agent_id = agent_id
        self.text = text
        self.data_source = data_source
        self.doc_id = doc_id
        self.metadata_ = metadata_

        # pad and store embeddings
        if isinstance(embedding, list):
            embedding = np.array(embedding)
        self.embedding = (
            np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
        )
        self.embedding_dim = embedding_dim
        self.embedding_model = embedding_model

        self.created_at = created_at if created_at is not None else get_utc_time()

        if self.embedding is not None:
            assert self.embedding_dim, "Must specify embedding_dim if providing an embedding"
            assert self.embedding_model, "Must specify embedding_model if providing an embedding"
            assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"

        assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
        assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
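The content-hashed IDs make re-inserting identical text for the same user idempotent; a sketch of that property, using uuid5 as a stand-in for create_uuid_from_string (the real helper may hash differently):

import uuid

def create_uuid_from_string(val: str) -> uuid.UUID:
    # illustrative stand-in: any stable hash-to-UUID scheme gives the same dedup property
    return uuid.uuid5(uuid.NAMESPACE_OID, val)

user_id = uuid.uuid4()
id_a = create_uuid_from_string("".join(["the same passage text", str(user_id)]))
id_b = create_uuid_from_string("".join(["the same passage text", str(user_id)]))
assert id_a == id_b  # duplicate loads collide on ID instead of creating copies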
class LLMConfig:
    def __init__(
        self,
        model: Optional[str] = None,
        model_endpoint_type: Optional[str] = None,
        model_endpoint: Optional[str] = None,
        model_wrapper: Optional[str] = None,
        context_window: Optional[int] = None,
    ):
        self.model = model
        self.model_endpoint_type = model_endpoint_type
        self.model_endpoint = model_endpoint
        self.model_wrapper = model_wrapper

        if context_window is None:
            self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
        else:
            self.context_window = context_window
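The fallback means unknown model names silently inherit a default context window; a standalone sketch with an illustrative LLM_MAX_TOKENS table (the real table and values live in memgpt.constants):

LLM_MAX_TOKENS = {"DEFAULT": 8192, "gpt-4": 8192, "gpt-4-32k": 32768}  # illustrative values

def resolve_context_window(model: str, context_window: int = None) -> int:
    if context_window is None:
        return LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
    return context_window

assert resolve_context_window("gpt-4-32k") == 32768
assert resolve_context_window("some-unknown-model") == 8192   # falls back to DEFAULT
assert resolve_context_window("gpt-4", context_window=4096) == 4096  # explicit value wins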
class EmbeddingConfig:
    def __init__(
        self,
        embedding_endpoint_type: Optional[str] = None,
        embedding_endpoint: Optional[str] = None,
        embedding_model: Optional[str] = None,
        embedding_dim: Optional[int] = None,
        embedding_chunk_size: Optional[int] = 300,
    ):
        self.embedding_endpoint_type = embedding_endpoint_type
        self.embedding_endpoint = embedding_endpoint
        self.embedding_model = embedding_model
        self.embedding_dim = embedding_dim
        self.embedding_chunk_size = embedding_chunk_size

        # fields cannot be set to None
        assert self.embedding_endpoint_type
        assert self.embedding_dim
        assert self.embedding_chunk_size
class OpenAIEmbeddingConfig(EmbeddingConfig):
    def __init__(self, openai_key: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        self.openai_key = openai_key


class AzureEmbeddingConfig(EmbeddingConfig):
    def __init__(
        self,
        azure_key: Optional[str] = None,
        azure_endpoint: Optional[str] = None,
        azure_version: Optional[str] = None,
        azure_deployment: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.azure_key = azure_key
        self.azure_endpoint = azure_endpoint
        self.azure_version = azure_version
        self.azure_deployment = azure_deployment
class User:
    """Defines user and default configurations"""

    # TODO: make sure to encrypt/decrypt keys before storing in DB

    def __init__(
        self,
        # name: str,
        id: Optional[uuid.UUID] = None,
        default_agent=None,
        # other
        policies_accepted=False,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"

        self.default_agent = default_agent

        # misc
        self.policies_accepted = policies_accepted
class AgentState:
    def __init__(
        self,
        name: str,
        user_id: uuid.UUID,
        # tools
        tools: List[str],  # list of tools by name
        # system prompt
        system: str,
        # config
        llm_config: LLMConfig,
        embedding_config: EmbeddingConfig,
        # (in-context) state contains:
        id: Optional[uuid.UUID] = None,
        state: Optional[dict] = None,
        created_at: Optional[datetime] = None,
        # messages (TODO: implement this)
        _metadata: Optional[dict] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        # TODO(swooders) we need to handle the case where name is None here
        # in AgentConfig we autogenerate a name; not sure what the correct thing for DBs is -- what about NounAdjective combos, like Giphy does (BoredGiraffe etc.)?
        self.name = name
        assert self.name, "AgentState name must be a non-empty string"
        self.user_id = user_id

        # The INITIAL values of the persona and human
        # The values inside self.state['persona'], self.state['human'] are the CURRENT values
        self.llm_config = llm_config
        self.embedding_config = embedding_config

        self.created_at = created_at if created_at is not None else get_utc_time()

        # state
        self.state = {} if not state else state

        # tools
        self.tools = tools

        # system
        self.system = system
        assert self.system is not None, "Must provide system prompt, cannot be None"

        # metadata
        self._metadata = _metadata
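Putting the config objects together, constructing a minimal AgentState might look like this (endpoint URLs, prompt text, and tool names are placeholders):

import uuid

agent_state = AgentState(
    name="agent_demo",
    user_id=uuid.uuid4(),
    tools=["send_message"],
    system="You are a helpful assistant.",  # placeholder system prompt
    llm_config=LLMConfig(model="gpt-4", model_endpoint_type="openai", model_endpoint="https://api.openai.com/v1"),
    embedding_config=EmbeddingConfig(
        embedding_endpoint_type="openai",
        embedding_endpoint="https://api.openai.com/v1",
        embedding_model="text-embedding-ada-002",
        embedding_dim=1536,
    ),
)
assert agent_state.state == {}  # state starts empty until the agent is initialized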
class Source:
    def __init__(
        self,
        user_id: uuid.UUID,
        name: str,
        description: Optional[str] = None,
        created_at: Optional[datetime] = None,
        id: Optional[uuid.UUID] = None,
        # embedding info
        embedding_model: Optional[str] = None,
        embedding_dim: Optional[int] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        self.name = name
        self.user_id = user_id
        self.description = description
        self.created_at = created_at if created_at is not None else get_utc_time()

        # embedding info (optional)
        self.embedding_dim = embedding_dim
        self.embedding_model = embedding_model
class Token:
    def __init__(
        self,
        user_id: uuid.UUID,
        token: str,
        name: Optional[str] = None,
        id: Optional[uuid.UUID] = None,
    ):
        if id is None:
            self.id = uuid.uuid4()
        else:
            self.id = id
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"

        self.token = token
        self.user_id = user_id
        self.name = name
class Preset(BaseModel):
    # TODO: remove Preset
    name: str = Field(..., description="The name of the preset.")
    id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
    user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
    description: Optional[str] = Field(None, description="The description of the preset.")
    created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
    system: str = Field(
        gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset."
    )  # default system prompt is the same as the default preset name
    # system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
    persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
    persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
    human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
    human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
    functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
    # functions: List[str] = Field(..., description="The functions of the preset.")  # TODO: convert to ID
    # sources: List[str] = Field(..., description="The sources of the preset.")  # TODO: convert to ID

    @staticmethod
    def clone(preset_obj: "Preset", new_name_suffix: Optional[str] = None) -> "Preset":
        """
        Takes a Preset object and an optional new name suffix as input,
        creates a clone of the given Preset object with a new ID and an optional new name,
        and returns the new Preset object.
        """
        new_preset = preset_obj.model_copy()
        new_preset.id = uuid.uuid4()
        if new_name_suffix:
            new_preset.name = f"{preset_obj.name}_{new_name_suffix}"
        else:
            new_preset.name = f"{preset_obj.name}_{str(uuid.uuid4())[:8]}"
        return new_preset
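A usage sketch of clone (the preset name is fabricated; fields beyond the two required ones fall back to their defaults):

base = Preset(name="memgpt_chat", functions_schema=[])

copy_a = Preset.clone(base, new_name_suffix="experiment")
assert copy_a.id != base.id
assert copy_a.name == "memgpt_chat_experiment"

copy_b = Preset.clone(base)  # no suffix: falls back to a random 8-char tag
assert copy_b.name.startswith("memgpt_chat_")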
class Function(BaseModel):
    name: str = Field(..., description="The name of the function.")
    id: uuid.UUID = Field(..., description="The unique identifier of the function.")
    user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
    # TODO: figure out how to represent functions

View File

@@ -22,7 +22,7 @@ from memgpt.constants import (
     MAX_EMBEDDING_DIM,
 )
 from memgpt.credentials import MemGPTCredentials
-from memgpt.schemas.embedding_config import EmbeddingConfig
+from memgpt.data_types import EmbeddingConfig
 from memgpt.utils import is_valid_url, printd

View File

@@ -11,8 +11,8 @@ from memgpt.constants import (
     MESSAGE_CHATGPT_FUNCTION_MODEL,
     MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
 )
+from memgpt.data_types import Message
 from memgpt.llm_api.llm_api_tools import create
-from memgpt.schemas.message import Message


 def message_chatgpt(self, message: str):

View File

@@ -1,6 +1,6 @@
 import inspect
 import typing
-from typing import Any, Dict, Optional, Type, get_args, get_origin
+from typing import Optional, get_args, get_origin

 from docstring_parser import parse
 from pydantic import BaseModel

@@ -144,39 +144,3 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
     schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)

     return schema
def generate_schema_from_args_schema(
    args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None
) -> Dict[str, Any]:
    properties = {}
    required = []
    for field_name, field in args_schema.__fields__.items():
        properties[field_name] = {"type": field.type_.__name__, "description": field.field_info.description}
        if field.required:
            required.append(field_name)

    # Construct the OpenAI function call JSON object
    function_call_json = {
        "name": name,
        "description": description,
        "parameters": {"type": "object", "properties": properties, "required": required},
    }

    return function_call_json


def generate_tool_wrapper(tool_name: str) -> str:
    import_statement = f"from crewai_tools import {tool_name}"
    tool_instantiation = f"tool = {tool_name}()"
    run_call = "return tool._run(**kwargs)"
    func_name = f"run_{tool_name.lower()}"

    # Combine all parts into the wrapper function
    wrapper_function_str = f"""
def {func_name}(**kwargs):
    {import_statement}
    {tool_instantiation}
    {run_call}
"""
    return func_name, wrapper_function_str
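For intuition about the (removed) wrapper generator, here is what it would emit for an illustrative crewai-tools class name -- the tool choice is an example, not something this diff uses:

func_name, wrapper_src = generate_tool_wrapper("ScrapeWebsiteTool")
print(func_name)   # run_scrapewebsitetool
print(wrapper_src)
# def run_scrapewebsitetool(**kwargs):
#     from crewai_tools import ScrapeWebsiteTool
#     tool = ScrapeWebsiteTool()
#     return tool._run(**kwargs)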

View File

@@ -6,7 +6,7 @@ from typing import List, Optional

 from colorama import Fore, Style, init

 from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
-from memgpt.schemas.message import Message
+from memgpt.data_types import Message
 from memgpt.utils import printd

 init(autoreset=True)

View File

@@ -6,17 +6,17 @@ from typing import List, Optional, Union

 import requests

 from memgpt.constants import JSON_ENSURE_ASCII
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.data_types import Message
+from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
+from memgpt.models.chat_completion_response import (
     ChatCompletionResponse,
     Choice,
     FunctionCall,
 )
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_response import (
     Message as ChoiceMessage,  # NOTE: avoid conflict with our own MemGPT Message datatype
 )
-from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
+from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
 from memgpt.utils import get_utc_time, smart_urljoin

 BASE_URL = "https://api.anthropic.com/v1"

View File

@@ -2,8 +2,8 @@ from typing import Union

 import requests

-from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
-from memgpt.schemas.openai.embedding_response import EmbeddingResponse
+from memgpt.models.chat_completion_response import ChatCompletionResponse
+from memgpt.models.embedding_response import EmbeddingResponse
 from memgpt.utils import smart_urljoin

 MODEL_TO_AZURE_ENGINE = {

View File

@@ -5,18 +5,18 @@ from typing import List, Optional, Union

 import requests

 from memgpt.constants import JSON_ENSURE_ASCII
+from memgpt.data_types import Message
 from memgpt.local_llm.utils import count_tokens
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
+from memgpt.models.chat_completion_response import (
     ChatCompletionResponse,
     Choice,
     FunctionCall,
 )
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_response import (
     Message as ChoiceMessage,  # NOTE: avoid conflict with our own MemGPT Message datatype
 )
-from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
+from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
 from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin

 BASE_URL = "https://api.cohere.ai/v1"

View File

@@ -7,8 +7,8 @@ import requests

 from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX
 from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
 from memgpt.local_llm.utils import count_tokens
-from memgpt.schemas.openai.chat_completion_request import Tool
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_request import Tool
+from memgpt.models.chat_completion_response import (
     ChatCompletionResponse,
     Choice,
     FunctionCall,

View File

@@ -11,6 +11,7 @@ import requests

 from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
 from memgpt.credentials import MemGPTCredentials
+from memgpt.data_types import Message
 from memgpt.llm_api.anthropic import anthropic_chat_completions_request
 from memgpt.llm_api.azure_openai import (
     MODEL_TO_AZURE_ENGINE,

@@ -30,15 +31,13 @@ from memgpt.local_llm.constants import (
     INNER_THOUGHTS_KWARG,
     INNER_THOUGHTS_KWARG_DESCRIPTION,
 )
-from memgpt.schemas.enums import OptionState
-from memgpt.schemas.llm_config import LLMConfig
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_request import (
+from memgpt.models.chat_completion_request import (
     ChatCompletionRequest,
     Tool,
     cast_message_to_subtype,
 )
-from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
+from memgpt.models.chat_completion_response import ChatCompletionResponse
+from memgpt.models.pydantic_models import LLMConfigModel, OptionState
 from memgpt.streaming_interface import (
     AgentChunkStreamingInterface,
     AgentRefreshStreamingInterface,

@@ -229,7 +228,7 @@ def retry_with_exponential_backoff(
 @retry_with_exponential_backoff
 def create(
     # agent_state: AgentState,
-    llm_config: LLMConfig,
+    llm_config: LLMConfigModel,
     messages: List[Message],
     user_id: uuid.UUID = None,  # optional UUID to associate the request with
     functions: list = None,

@@ -260,6 +259,8 @@ def create(
         printd("unsetting function_call because functions is None")
         function_call = None

+    # print("HELLO")
+
     # openai
     if llm_config.model_endpoint_type == "openai":

View File

@@ -7,8 +7,8 @@ from httpx_sse import connect_sse
 from httpx_sse._exceptions import SSEError

 from memgpt.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
-from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_request import ChatCompletionRequest
+from memgpt.models.chat_completion_response import (
     ChatCompletionChunkResponse,
     ChatCompletionResponse,
     Choice,
@@ -17,7 +17,7 @@ from memgpt.schemas.openai.chat_completion_response import (
     ToolCall,
     UsageStatistics,
 )
-from memgpt.schemas.openai.embedding_response import EmbeddingResponse
+from memgpt.models.embedding_response import EmbeddingResponse
 from memgpt.streaming_interface import (
     AgentChunkStreamingInterface,
     AgentRefreshStreamingInterface,

@@ -89,7 +89,6 @@ def openai_chat_completions_process_stream(
     on the chunks received from the OpenAI-compatible server POST SSE response.
     """
     assert chat_completion_request.stream == True
-    assert stream_inferface is not None, "Required"

     # Count the prompt tokens
     # TODO move to post-request?

@@ -371,10 +370,7 @@ def openai_chat_completions_request(
     url = smart_urljoin(url, "chat/completions")
     headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
     data = chat_completion_request.model_dump(exclude_none=True)
-    data["parallel_tool_calls"] = False
+    # add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
+    if chat_completion_request.tools is not None:
+        data["parallel_tool_calls"] = False

     printd("Request:\n", json.dumps(data, indent=2))

@@ -390,7 +386,7 @@ def openai_chat_completions_request(
     printd(f"Sending request to {url}")
     try:
         response = requests.post(url, headers=headers, json=data)
-        printd(f"response = {response}, response.text = {response.text}")
+        # printd(f"response = {response}, response.text = {response.text}")
         response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
         response = response.json()  # convert to dict from string

View File

@@ -25,14 +25,14 @@ from memgpt.local_llm.webui.api import get_webui_completion
 from memgpt.local_llm.webui.legacy_api import (
     get_webui_completion as get_webui_completion_legacy,
 )
-from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_response import (
     ChatCompletionResponse,
     Choice,
     Message,
     ToolCall,
     UsageStatistics,
 )
+from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
 from memgpt.utils import get_tool_call_id, get_utc_time

 has_shown_warning = False

View File

@@ -11,12 +11,20 @@ from rich.console import Console

 import memgpt.agent as agent
 import memgpt.errors as errors
 import memgpt.system as system
+from memgpt.agent_store.storage import StorageConnector, TableType

 # import benchmark
-from memgpt import create_client
 from memgpt.benchmark.benchmark import bench
-from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, version
-from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools
+from memgpt.cli.cli import (
+    delete_agent,
+    migrate,
+    open_folder,
+    quickstart,
+    run,
+    server,
+    version,
+)
+from memgpt.cli.cli_config import add, configure, delete, list
 from memgpt.cli.cli_load import app as load_app
 from memgpt.config import MemGPTConfig
 from memgpt.constants import (

@@ -26,7 +34,7 @@ from memgpt.constants import (
     REQ_HEARTBEAT_MESSAGE,
 )
 from memgpt.metadata import MetadataStore
-from memgpt.schemas.enums import OptionState
+from memgpt.models.pydantic_models import OptionState

 # from memgpt.interface import CLIInterface as interface  # for printing to terminal
 from memgpt.streaming_interface import AgentRefreshStreamingInterface

@@ -39,14 +47,14 @@ app.command(name="version")(version)
 app.command(name="configure")(configure)
 app.command(name="list")(list)
 app.command(name="add")(add)
-app.command(name="add-tool")(add_tool)
-app.command(name="list-tools")(list_tools)
 app.command(name="delete")(delete)
 app.command(name="server")(server)
 app.command(name="folder")(open_folder)
 app.command(name="quickstart")(quickstart)

 # load data commands
 app.add_typer(load_app, name="load")
+# migration command
+app.command(name="migrate")(migrate)

 # benchmark command
 app.command(name="benchmark")(bench)

 # delete agents
@@ -95,12 +103,7 @@ def run_agent_loop(
     print()
     multiline_input = False

-    # create client
-    client = create_client()
-    ms = MetadataStore(config)  # TODO: remove
-
-    # run loops
+    ms = MetadataStore(config)
     while True:
         if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
             # Ask for user input

@@ -148,8 +151,8 @@ def run_agent_loop(
                 # TODO: check to ensure source embedding dimensions/model match the agent's, and disallow attachment if not
                 # TODO: alternatively, only list sources with compatible embeddings, and print a warning about non-compatible sources
-                sources = client.list_sources()
-                if len(sources) == 0:
+                data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id)
+                if len(data_source_options) == 0:
                     typer.secho(
                         'No sources available. You must load a source with "memgpt load ..." before running /attach.',
                         fg=typer.colors.RED,

@@ -160,8 +163,11 @@ def run_agent_loop(
                 # determine what sources are valid to be attached to this agent
                 valid_options = []
                 invalid_options = []
-                for source in sources:
-                    if source.embedding_config == memgpt_agent.agent_state.embedding_config:
+                for source in data_source_options:
+                    if (
+                        source.embedding_model == memgpt_agent.agent_state.embedding_config.embedding_model
+                        and source.embedding_dim == memgpt_agent.agent_state.embedding_config.embedding_dim
+                    ):
                         valid_options.append(source.name)
                     else:
                         # print warning about invalid sources

@@ -175,7 +181,11 @@ def run_agent_loop(
                 data_source = questionary.select("Select data source", choices=valid_options).ask()

                 # attach new data
-                client.attach_source_to_agent(agent_id=memgpt_agent.agent_state.id, source_name=data_source)
+                # attach(memgpt_agent.agent_state.name, data_source)
+                source_connector = StorageConnector.get_storage_connector(
+                    TableType.PASSAGES, config, user_id=memgpt_agent.agent_state.user_id
+                )
+                memgpt_agent.attach_source(data_source, source_connector, ms)

                 continue
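The compatibility gate above reduces to a small predicate; a sketch with illustrative model names and dimensions:

def source_is_compatible(source_model: str, source_dim: int, agent_model: str, agent_dim: int) -> bool:
    # a source can only be attached if its passages were embedded with the
    # same model at the same vector width as the agent's archival storage
    return source_model == agent_model and source_dim == agent_dim

assert source_is_compatible("text-embedding-ada-002", 1536, "text-embedding-ada-002", 1536)
assert not source_is_compatible("text-embedding-ada-002", 1536, "nomic-embed-text", 768)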
@@ -420,10 +430,8 @@ def run_agent_loop(
                 skip_verify=no_verify,
                 stream=stream,
                 inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
-                ms=ms,
             )
-            agent.save_agent(memgpt_agent, ms)
             skip_next_user_input = False
             if token_warning:
                 user_message = system.get_token_limit_warning()

View File

@@ -3,14 +3,13 @@ import uuid
 from abc import ABC, abstractmethod
 from typing import List, Optional, Tuple, Union

+from pydantic import BaseModel, validator
+
 from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
+from memgpt.data_types import AgentState, Message, Passage
 from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding
 from memgpt.llm_api.llm_api_tools import create
 from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
-from memgpt.schemas.agent import AgentState
-from memgpt.schemas.memory import Memory
-from memgpt.schemas.message import Message
-from memgpt.schemas.passage import Passage
 from memgpt.utils import (
     count_tokens,
     extract_date_from_timestamp,
@@ -19,135 +18,125 @@ from memgpt.utils import (
     validate_date_format,
 )
class MemoryModule(BaseModel):
    """Base class for memory modules"""

    description: Optional[str] = None
    limit: int = 2000
    value: Optional[Union[List[str], str]] = None

    def __setattr__(self, name, value):
        """Run validation if self.value is updated"""
        super().__setattr__(name, value)
        if name == "value":
            # run validation
            self.__class__.validate(self.dict(exclude_unset=True))

    @validator("value", always=True)
    def check_value_length(cls, v, values):
        if v is not None:
            # Fetching the limit from the values dictionary
            limit = values.get("limit", 2000)  # Default to 2000 if limit is not yet set

            # Check if the value exceeds the limit
            if isinstance(v, str):
                length = len(v)
            elif isinstance(v, list):
                length = sum(len(item) for item in v)
            else:
                raise ValueError("Value must be either a string or a list of strings.")

            if length > limit:
                error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
                # TODO: add archival memory error?
                raise ValueError(error_msg)
        return v

    def __len__(self):
        return len(str(self))

    def __str__(self) -> str:
        if isinstance(self.value, list):
            return ",".join(self.value)
        elif isinstance(self.value, str):
            return self.value
        else:
            return ""


class BaseMemory:
    def __init__(self):
        self.memory = {}

    @classmethod
    def load(cls, state: dict):
        """Load memory from dictionary object"""
        obj = cls()
        for key, value in state.items():
            obj.memory[key] = MemoryModule(**value)
        return obj

    def __str__(self) -> str:
        """Representation of the memory in-context"""
        section_strs = []
        for section, module in self.memory.items():
            section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
        return "\n".join(section_strs)

    def to_dict(self):
        """Convert to dictionary representation"""
        return {key: value.dict() for key, value in self.memory.items()}


class ChatMemory(BaseMemory):
    def __init__(self, persona: str, human: str, limit: int = 2000):
        self.memory = {
            "persona": MemoryModule(name="persona", value=persona, limit=limit),
            "human": MemoryModule(name="human", value=human, limit=limit),
        }

    def core_memory_append(self, name: str, content: str) -> Optional[str]:
        """
        Append to the contents of core memory.

        Args:
            name (str): Section of the memory to be edited (persona or human).
            content (str): Content to write to the memory. All unicode (including emojis) are supported.

        Returns:
            Optional[str]: None is always returned as this function does not produce a response.
        """
        self.memory[name].value += "\n" + content
        return None

    def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
        """
        Replace the contents of core memory. To delete memories, use an empty string for new_content.

        Args:
            name (str): Section of the memory to be edited (persona or human).
            old_content (str): String to replace. Must be an exact match.
            new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

        Returns:
            Optional[str]: None is always returned as this function does not produce a response.
        """
        self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
        return None


-def get_memory_functions(cls: Memory) -> List[callable]:
+def get_memory_functions(cls: BaseMemory) -> List[callable]:
     """Get memory functions for a memory class"""
     functions = {}
-    # collect base memory functions (should not be included)
-    base_functions = []
-    for func_name in dir(Memory):
-        funct = getattr(Memory, func_name)
-        if callable(funct):
-            base_functions.append(func_name)
     for func_name in dir(cls):
         if func_name.startswith("_") or func_name in ["load", "to_dict"]:  # skip base functions
             continue
-        if func_name in base_functions:  # dont use BaseMemory functions
-            continue
         func = getattr(cls, func_name)
-        if not callable(func):  # not a function
-            continue
-        functions[func_name] = func
+        if callable(func):
+            functions[func_name] = func
     return functions
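End-to-end, the memory classes behave roughly like this (persona/human strings are placeholders; the extra name kwarg passed to MemoryModule is tolerated because pydantic v1 ignores unknown fields by default):

memory = ChatMemory(persona="I am Sam, a helpful assistant.", human="The user's name is Chad.")

# only the agent-editable functions are discovered; load/to_dict and dunders are skipped
funcs = get_memory_functions(ChatMemory)
assert set(funcs.keys()) == {"core_memory_append", "core_memory_replace"}

memory.core_memory_append("human", "Chad lives in Berkeley.")
print(str(memory))  # compiled <persona>/<human> sections, as placed in the context window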
@@ -264,8 +253,8 @@ def summarize_messages(
         + message_sequence_to_summarize[cutoff:]
     )

-    dummy_user_id = agent_state.user_id
-    dummy_agent_id = agent_state.id
+    dummy_user_id = uuid.uuid4()
+    dummy_agent_id = uuid.uuid4()
     message_sequence = []
     message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
     if insert_acknowledgement_assistant_message:

@@ -528,7 +517,8 @@ class EmbeddingArchivalMemory(ArchivalMemory):
             agent_id=self.agent_state.id,
             text=text,
             embedding=embedding,
-            embedding_config=self.agent_state.embedding_config,
+            embedding_dim=self.agent_state.embedding_config.embedding_dim,
+            embedding_model=self.agent_state.embedding_config.embedding_model,
         )

     def save(self):

View File

@@ -3,10 +3,12 @@
 import os
 import secrets
 import traceback
+import uuid
 from typing import List, Optional

 from sqlalchemy import (
     BIGINT,
+    CHAR,
     JSON,
     Boolean,
     Column,
@@ -18,28 +20,58 @@ from sqlalchemy import (
     desc,
     func,
 )
+from sqlalchemy.dialects.postgresql import UUID
 from sqlalchemy.exc import InterfaceError, OperationalError
 from sqlalchemy.orm import declarative_base, sessionmaker
 from sqlalchemy.sql import func

 from memgpt.config import MemGPTConfig
-from memgpt.schemas.agent import AgentState
-from memgpt.schemas.api_key import APIKey
-from memgpt.schemas.block import Block, Human, Persona
-from memgpt.schemas.embedding_config import EmbeddingConfig
-from memgpt.schemas.enums import JobStatus
-from memgpt.schemas.job import Job
-from memgpt.schemas.llm_config import LLMConfig
-from memgpt.schemas.memory import Memory
-from memgpt.schemas.source import Source
-from memgpt.schemas.tool import Tool
-from memgpt.schemas.user import User
+from memgpt.data_types import (
+    AgentState,
+    EmbeddingConfig,
+    LLMConfig,
+    Preset,
+    Source,
+    Token,
+    User,
+)
+from memgpt.models.pydantic_models import (
+    HumanModel,
+    JobModel,
+    JobStatus,
+    PersonaModel,
+    ToolModel,
+)
 from memgpt.settings import settings
 from memgpt.utils import enforce_types, get_utc_time, printd

 Base = declarative_base()


# Custom UUID type
class CommonUUID(TypeDecorator):
    impl = CHAR
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "postgresql":
            return dialect.type_descriptor(UUID(as_uuid=True))
        else:
            return dialect.type_descriptor(CHAR())

    def process_bind_param(self, value, dialect):
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return str(value)  # Convert UUID to string for SQLite

    def process_result_value(self, value, dialect):
        if dialect.name == "postgresql" or value is None:
            return value
        else:
            return uuid.UUID(value)
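A round-trip sketch of the decorator on SQLite, where UUIDs are stored as CHAR and rehydrated on read (the table name and model are illustrative, not part of the commit):

import uuid
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

DemoBase = declarative_base()

class Item(DemoBase):
    __tablename__ = "items"
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)  # portable UUID column
    name = Column(String)

engine = create_engine("sqlite://")
DemoBase.metadata.create_all(engine)

with sessionmaker(bind=engine)() as session:
    session.add(Item(name="demo"))
    session.commit()
    item = session.query(Item).first()
    assert isinstance(item.id, uuid.UUID)  # process_result_value converted the stored string back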
class LLMConfigColumn(TypeDecorator):
    """Custom type for storing LLMConfig as JSON"""

@@ -84,44 +116,48 @@ class UserModel(Base):
     __tablename__ = "users"
     __table_args__ = {"extend_existing": True}

-    id = Column(String, primary_key=True)
-    name = Column(String, nullable=False)
-    created_at = Column(DateTime(timezone=True))
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    # name = Column(String, nullable=False)
+    default_agent = Column(String)

+    # TODO: what is this?
     policies_accepted = Column(Boolean, nullable=False, default=False)

     def __repr__(self) -> str:
-        return f"<User(id='{self.id}' name='{self.name}')>"
+        return f"<User(id='{self.id}')>"

     def to_record(self) -> User:
-        return User(id=self.id, name=self.name, created_at=self.created_at)
+        return User(
+            id=self.id,
+            # name=self.name
+            default_agent=self.default_agent,
+            policies_accepted=self.policies_accepted,
+        )


-class APIKeyModel(Base):
+class TokenModel(Base):
     """Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""

     __tablename__ = "tokens"

-    id = Column(String, primary_key=True)
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
     # each api key is tied to a user account (that it validates access for)
-    user_id = Column(String, nullable=False)
+    user_id = Column(CommonUUID, nullable=False)
     # the api key
-    key = Column(String, nullable=False)
+    token = Column(String, nullable=False)
     # extra (optional) metadata
     name = Column(String)

     Index(__tablename__ + "_idx_user", user_id),
-    Index(__tablename__ + "_idx_key", key),
+    Index(__tablename__ + "_idx_token", token),

     def __repr__(self) -> str:
-        return f"<APIKey(id='{self.id}', key='{self.key}', name='{self.name}')>"
+        return f"<Token(id='{self.id}', token='{self.token}', name='{self.name}')>"

     def to_record(self) -> User:
-        return APIKey(
+        return Token(
             id=self.id,
             user_id=self.user_id,
-            key=self.key,
+            token=self.token,
             name=self.name,
         )
@@ -140,24 +176,19 @@ class AgentModel(Base):
     __tablename__ = "agents"
     __table_args__ = {"extend_existing": True}

-    id = Column(String, primary_key=True)
-    user_id = Column(String, nullable=False)
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    user_id = Column(CommonUUID, nullable=False)
     name = Column(String, nullable=False)
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    description = Column(String)
-
-    # state (context compilation)
-    message_ids = Column(JSON)
-    memory = Column(JSON)
     system = Column(String)
-    tools = Column(JSON)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())

     # configs
     llm_config = Column(LLMConfigColumn)
     embedding_config = Column(EmbeddingConfigColumn)

     # state
-    metadata_ = Column(JSON)
+    state = Column(JSON)
+    _metadata = Column(JSON)

     # tools
     tools = Column(JSON)

@@ -173,14 +204,12 @@ class AgentModel(Base):
             user_id=self.user_id,
             name=self.name,
             created_at=self.created_at,
-            description=self.description,
-            message_ids=self.message_ids,
-            memory=Memory.load(self.memory),  # load dictionary
-            system=self.system,
-            tools=self.tools,
             llm_config=self.llm_config,
             embedding_config=self.embedding_config,
-            metadata_=self.metadata_,
+            state=self.state,
+            tools=self.tools,
+            system=self.system,
+            _metadata=self._metadata,
         )
@@ -192,13 +221,13 @@ class SourceModel(Base):
     # Assuming passage_id is the primary key
     # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
-    id = Column(String, primary_key=True)
-    user_id = Column(String, nullable=False)
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    user_id = Column(CommonUUID, nullable=False)
     name = Column(String, nullable=False)
     created_at = Column(DateTime(timezone=True), server_default=func.now())
-    embedding_config = Column(EmbeddingConfigColumn)
+    embedding_dim = Column(BIGINT)
+    embedding_model = Column(String)
     description = Column(String)
-    metadata_ = Column(JSON)

     Index(__tablename__ + "_idx_user", user_id),

     # TODO: add num passages

@@ -212,9 +241,9 @@ class SourceModel(Base):
             user_id=self.user_id,
             name=self.name,
             created_at=self.created_at,
-            embedding_config=self.embedding_config,
+            embedding_dim=self.embedding_dim,
+            embedding_model=self.embedding_model,
             description=self.description,
-            metadata_=self.metadata_,
         )
@@ -223,116 +252,80 @@ class AgentSourceMappingModel(Base):
     __tablename__ = "agent_source_mapping"

-    id = Column(String, primary_key=True)
-    user_id = Column(String, nullable=False)
-    agent_id = Column(String, nullable=False)
-    source_id = Column(String, nullable=False)
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    user_id = Column(CommonUUID, nullable=False)
+    agent_id = Column(CommonUUID, nullable=False)
+    source_id = Column(CommonUUID, nullable=False)

     Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),

     def __repr__(self) -> str:
         return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"


-class BlockModel(Base):
-    __tablename__ = "block"
-    __table_args__ = {"extend_existing": True}
-
-    id = Column(String, primary_key=True, nullable=False)
-    value = Column(String, nullable=False)
-    limit = Column(BIGINT)
-    name = Column(String, nullable=False)
-    template = Column(Boolean, default=False)  # True: listed as possible human/persona
-    label = Column(String)
-    metadata_ = Column(JSON)
-    description = Column(String)
-    user_id = Column(String)
-
-    Index(__tablename__ + "_idx_user", user_id),
-
-    def __repr__(self) -> str:
-        return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"
-
-    def to_record(self) -> Block:
-        if self.label == "persona":
-            return Persona(
-                id=self.id,
-                value=self.value,
-                limit=self.limit,
-                name=self.name,
-                template=self.template,
-                label=self.label,
-                metadata_=self.metadata_,
-                description=self.description,
-                user_id=self.user_id,
-            )
-        elif self.label == "human":
-            return Human(
-                id=self.id,
-                value=self.value,
-                limit=self.limit,
-                name=self.name,
-                template=self.template,
-                label=self.label,
-                metadata_=self.metadata_,
-                description=self.description,
-                user_id=self.user_id,
-            )
-        else:
-            raise ValueError(f"Block with label {self.label} is not supported")
-
-
-class ToolModel(Base):
-    __tablename__ = "tools"
-    __table_args__ = {"extend_existing": True}
-
-    id = Column(String, primary_key=True)
-    name = Column(String, nullable=False)
-    user_id = Column(String)
-    description = Column(String)
-    source_type = Column(String)
-    source_code = Column(String)
-    json_schema = Column(JSON)
-    module = Column(String)
-    tags = Column(JSON)
-
-    def __repr__(self) -> str:
-        return f"<Tool(id='{self.id}', name='{self.name}')>"
-
-    def to_record(self) -> Tool:
-        return Tool(
-            id=self.id,
-            name=self.name,
-            user_id=self.user_id,
-            description=self.description,
-            source_type=self.source_type,
-            source_code=self.source_code,
-            json_schema=self.json_schema,
-            module=self.module,
-            tags=self.tags,
-        )
-
-
-class JobModel(Base):
-    __tablename__ = "jobs"
-    __table_args__ = {"extend_existing": True}
-
-    id = Column(String, primary_key=True)
-    user_id = Column(String)
-    status = Column(String, default=JobStatus.pending)
-    created_at = Column(DateTime(timezone=True), server_default=func.now())
-    completed_at = Column(DateTime(timezone=True), onupdate=func.now())
-    metadata_ = Column(JSON)
-
-    def __repr__(self) -> str:
-        return f"<Job(id='{self.id}', status='{self.status}')>"
-
-    def to_record(self):
-        return Job(
-            id=self.id,
-            user_id=self.user_id,
-            status=self.status,
-            created_at=self.created_at,
-            completed_at=self.completed_at,
-            metadata_=self.metadata_,
-        )
+class PresetSourceMapping(Base):
+    __tablename__ = "preset_source_mapping"
+
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    user_id = Column(CommonUUID, nullable=False)
+    preset_id = Column(CommonUUID, nullable=False)
+    source_id = Column(CommonUUID, nullable=False)
+
+    Index(__tablename__ + "_idx_user", user_id, preset_id, source_id),
+
+    def __repr__(self) -> str:
+        return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
+
+
+# class PresetFunctionMapping(Base):
+#     __tablename__ = "preset_function_mapping"
+#
+#     id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+#     user_id = Column(CommonUUID, nullable=False)
+#     preset_id = Column(CommonUUID, nullable=False)
+#     # function_id = Column(CommonUUID, nullable=False)
+#     function = Column(String, nullable=False)  # TODO: convert to ID eventually
+#
+#     def __repr__(self) -> str:
+#         return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"
+
+
+class PresetModel(Base):
+    """Defines data model for storing Preset objects"""
+
+    __tablename__ = "presets"
+    __table_args__ = {"extend_existing": True}
+
+    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+    user_id = Column(CommonUUID, nullable=False)
+    name = Column(String, nullable=False)
+    description = Column(String)
+    system = Column(String)
+    human = Column(String)
+    human_name = Column(String, nullable=False)
+    persona = Column(String)
+    persona_name = Column(String, nullable=False)
+    preset = Column(String)
+    created_at = Column(DateTime(timezone=True), server_default=func.now())
+    functions_schema = Column(JSON)
+
+    Index(__tablename__ + "_idx_user", user_id),
+
+    def __repr__(self) -> str:
+        return f"<Preset(id='{self.id}', name='{self.name}')>"
+
+    def to_record(self) -> Preset:
+        return Preset(
+            id=self.id,
+            user_id=self.user_id,
+            name=self.name,
+            description=self.description,
+            system=self.system,
+            human=self.human,
+            persona=self.persona,
+            human_name=self.human_name,
+            persona_name=self.persona_name,
+            preset=self.preset,
+            created_at=self.created_at,
+            functions_schema=self.functions_schema,
+        )
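Editor's sketch (not part of the diff): a round-trip through the new PresetModel ORM table. The `session`, `user_id`, and field values below are illustrative assumptions.

```python
# Hypothetical usage of the PresetModel table defined above
# (assumes an open SQLAlchemy session bound to the MemGPT metadata database).
row = PresetModel(
    user_id=user_id,
    name="memgpt_chat",
    human_name="basic",
    persona_name="sam_pov",
    functions_schema=[],
)
session.add(row)
session.commit()
preset = row.to_record()  # converts the row back into a Preset record
```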
@@ -364,8 +357,11 @@ class MetadataStore:
                 AgentModel.__table__,
                 SourceModel.__table__,
                 AgentSourceMappingModel.__table__,
-                APIKeyModel.__table__,
-                BlockModel.__table__,
+                TokenModel.__table__,
+                PresetModel.__table__,
+                PresetSourceMapping.__table__,
+                HumanModel.__table__,
+                PersonaModel.__table__,
                 ToolModel.__table__,
                 JobModel.__table__,
             ],
@ -391,17 +387,16 @@ class MetadataStore:
self.session_maker = sessionmaker(bind=self.engine) self.session_maker = sessionmaker(bind=self.engine)
@enforce_types @enforce_types
def create_api_key(self, user_id: str, name: str) -> APIKey: def create_api_key(self, user_id: uuid.UUID, name: Optional[str] = None) -> Token:
"""Create an API key for a user""" """Create an API key for a user"""
new_api_key = generate_api_key() new_api_key = generate_api_key()
with self.session_maker() as session: with self.session_maker() as session:
if session.query(APIKeyModel).filter(APIKeyModel.key == new_api_key).count() > 0: if session.query(TokenModel).filter(TokenModel.token == new_api_key).count() > 0:
# NOTE duplicate API keys / tokens should never happen, but if it does don't allow it # NOTE duplicate API keys / tokens should never happen, but if it does don't allow it
raise ValueError(f"Token {new_api_key} already exists") raise ValueError(f"Token {new_api_key} already exists")
# TODO store the API keys as hashed # TODO store the API keys as hashed
assert user_id and name, "User ID and name must be provided" token = Token(user_id=user_id, token=new_api_key, name=name)
token = APIKey(user_id=user_id, key=new_api_key, name=name) session.add(TokenModel(**vars(token)))
session.add(APIKeyModel(**vars(token)))
session.commit() session.commit()
return self.get_api_key(api_key=new_api_key) return self.get_api_key(api_key=new_api_key)
@@ -409,22 +404,22 @@ class MetadataStore:
     def delete_api_key(self, api_key: str):
         """Delete an API key from the database"""
         with self.session_maker() as session:
-            session.query(APIKeyModel).filter(APIKeyModel.key == api_key).delete()
+            session.query(TokenModel).filter(TokenModel.token == api_key).delete()
             session.commit()

     @enforce_types
-    def get_api_key(self, api_key: str) -> Optional[APIKey]:
+    def get_api_key(self, api_key: str) -> Optional[Token]:
         with self.session_maker() as session:
-            results = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all()
+            results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
             if len(results) == 0:
                 return None
             assert len(results) == 1, f"Expected 1 result, got {len(results)}"  # should only be one result
             return results[0].to_record()

     @enforce_types
-    def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]:
+    def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
         with self.session_maker() as session:
-            results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
+            results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
             tokens = [r.to_record() for r in results]
             return tokens
@@ -441,20 +436,25 @@ class MetadataStore:
     def create_agent(self, agent: AgentState):
         # insert into agent table
         # make sure agent.name does not already exist for user user_id
-        assert agent.state is not None, "Agent state must be provided"
-        assert len(list(agent.state.keys())) > 0, "Agent state must not be empty"
         with self.session_maker() as session:
             if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
                 raise ValueError(f"Agent with name {agent.name} already exists")
-            fields = vars(agent)
-            fields["memory"] = agent.memory.to_dict()
-            session.add(AgentModel(**fields))
+            session.add(AgentModel(**vars(agent)))
             session.commit()

     @enforce_types
-    def create_source(self, source: Source):
+    def create_source(self, source: Source, exists_ok=False):
+        # make sure source.name does not already exist for user
         with self.session_maker() as session:
             if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
-                raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
-            session.add(SourceModel(**vars(source)))
+                if not exists_ok:
+                    raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
+                else:
+                    session.update(SourceModel(**vars(source)))
+            else:
+                session.add(SourceModel(**vars(source)))
             session.commit()

     @enforce_types
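One review note on the `exists_ok` branch above: SQLAlchemy's `Session` has no `update()` method, so `session.update(SourceModel(**vars(source)))` would raise `AttributeError` if it were ever reached. A hedged sketch of the usual upsert idiom, as a drop-in method body (assuming the table's primary key is `id`):

```python
# Sketch, not the committed code: merge() inserts or updates by primary key.
def create_source(self, source: Source, exists_ok=False):
    with self.session_maker() as session:
        exists = (
            session.query(SourceModel)
            .filter(SourceModel.name == source.name)
            .filter(SourceModel.user_id == source.user_id)
            .count()
            > 0
        )
        if exists and not exists_ok:
            raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
        session.merge(SourceModel(**vars(source)))
        session.commit()
```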
@@ -466,40 +466,67 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    def create_block(self, block: Block):
+    def create_preset(self, preset: Preset):
         with self.session_maker() as session:
-            # TODO: fix?
-            # we are only validating that more than one template block
-            # with a given name doesn't exist.
-            if (
-                session.query(BlockModel)
-                .filter(BlockModel.name == block.name)
-                .filter(BlockModel.user_id == block.user_id)
-                .filter(BlockModel.template == True)
-                .filter(BlockModel.label == block.label)
-                .count()
-                > 0
-            ):
-                raise ValueError(f"Block with name {block.name} already exists")
-            session.add(BlockModel(**vars(block)))
+            if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
+                raise ValueError(f"Preset with id {preset.id} already exists")
+            session.add(PresetModel(**vars(preset)))
             session.commit()

     @enforce_types
-    def create_tool(self, tool: Tool):
+    def get_preset(
+        self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
+    ) -> Optional[Preset]:
         with self.session_maker() as session:
-            if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None:
-                raise ValueError(f"Tool with name {tool.name} already exists")
-            session.add(ToolModel(**vars(tool)))
+            if preset_id:
+                results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
+            elif name and user_id:
+                results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all()
+            else:
+                raise ValueError("Must provide either preset_id or (preset_name and user_id)")
+            if len(results) == 0:
+                return None
+            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
+            return results[0].to_record()
+
+    # @enforce_types
+    # def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
+    #     preset = self.get_preset(preset_id)
+    #     if preset is None:
+    #         raise ValueError(f"Preset with id {preset_id} does not exist")
+    #     user_id = preset.user_id
+    #     with self.session_maker() as session:
+    #         for function in functions:
+    #             session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
+    #         session.commit()
+
+    @enforce_types
+    def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
+        preset = self.get_preset(preset_id)
+        if preset is None:
+            raise ValueError(f"Preset with id {preset_id} does not exist")
+        user_id = preset.user_id
+        with self.session_maker() as session:
+            for source_id in sources:
+                session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
             session.commit()

+    # @enforce_types
+    # def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
+    #     with self.session_maker() as session:
+    #         results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
+    #         return [r.function for r in results]
+
+    @enforce_types
+    def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
+        with self.session_maker() as session:
+            results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
+            return [r.source_id for r in results]
+
     @enforce_types
     def update_agent(self, agent: AgentState):
         with self.session_maker() as session:
-            fields = vars(agent)
-            if isinstance(agent.memory, Memory):
-                fields["memory"] = agent.memory.to_dict()
-            session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
+            session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))  # TODO: this is nasty but this whole class will soon be removed so whatever
             session.commit()

     @enforce_types
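Sketch of the new preset-to-source mapping (names taken from the diff; `ms`, `user_id`, and the two `Source` rows are assumed to exist):

```python
preset = ms.get_preset(name="memgpt_chat", user_id=user_id)
ms.set_preset_sources(preset.id, sources=[source_a.id, source_b.id])
assert set(ms.get_preset_sources(preset.id)) == {source_a.id, source_b.id}
```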
@@ -515,41 +542,28 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    def update_block(self, block: Block):
+    def update_human(self, human: HumanModel):
         with self.session_maker() as session:
-            session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
+            session.add(human)
             session.commit()
+            session.refresh(human)

     @enforce_types
-    def update_or_create_block(self, block: Block):
+    def update_persona(self, persona: PersonaModel):
         with self.session_maker() as session:
-            existing_block = session.query(BlockModel).filter(BlockModel.id == block.id).first()
-            if existing_block:
-                session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
-            else:
-                session.add(BlockModel(**vars(block)))
+            session.add(persona)
             session.commit()
+            session.refresh(persona)

     @enforce_types
-    def update_tool(self, tool: Tool):
+    def update_tool(self, tool: ToolModel):
         with self.session_maker() as session:
-            session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool))
+            session.add(tool)
             session.commit()
+            session.refresh(tool)

     @enforce_types
-    def delete_tool(self, tool_id: str):
-        with self.session_maker() as session:
-            session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
-            session.commit()
-
-    @enforce_types
-    def delete_block(self, block_id: str):
-        with self.session_maker() as session:
-            session.query(BlockModel).filter(BlockModel.id == block_id).delete()
-            session.commit()
-
-    @enforce_types
-    def delete_agent(self, agent_id: str):
+    def delete_agent(self, agent_id: uuid.UUID):
         with self.session_maker() as session:
             # delete agents
@@ -561,7 +575,7 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    def delete_source(self, source_id: str):
+    def delete_source(self, source_id: uuid.UUID):
         with self.session_maker() as session:
             # delete from sources table
             session.query(SourceModel).filter(SourceModel.id == source_id).delete()
@@ -572,7 +586,7 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    def delete_user(self, user_id: str):
+    def delete_user(self, user_id: uuid.UUID):
         with self.session_maker() as session:
             # delete from users table
             session.query(UserModel).filter(UserModel.id == user_id).delete()
@@ -589,30 +603,42 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    # def list_tools(self, user_id: str) -> List[ToolModel]:  # TODO: add when users can create tools
-    def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
+    def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
+        with self.session_maker() as session:
+            results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
+            return [r.to_record() for r in results]
+
+    @enforce_types
+    # def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]:  # TODO: add when users can create tools
+    def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
         with self.session_maker() as session:
             results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
             if user_id:
                 results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
-            res = [r.to_record() for r in results]
-            return res
+            return results

     @enforce_types
-    def list_agents(self, user_id: str) -> List[AgentState]:
+    def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
         with self.session_maker() as session:
             results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
             return [r.to_record() for r in results]

     @enforce_types
-    def list_sources(self, user_id: str) -> List[Source]:
+    def list_all_agents(self) -> List[AgentState]:
+        with self.session_maker() as session:
+            results = session.query(AgentModel).all()
+            return [r.to_record() for r in results]
+
+    @enforce_types
+    def list_sources(self, user_id: uuid.UUID) -> List[Source]:
         with self.session_maker() as session:
             results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
             return [r.to_record() for r in results]

     @enforce_types
     def get_agent(
-        self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
+        self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
     ) -> Optional[AgentState]:
         with self.session_maker() as session:
             if agent_id:
@@ -627,7 +653,7 @@ class MetadataStore:
             return results[0].to_record()

     @enforce_types
-    def get_user(self, user_id: str) -> Optional[User]:
+    def get_user(self, user_id: uuid.UUID) -> Optional[User]:
         with self.session_maker() as session:
             results = session.query(UserModel).filter(UserModel.id == user_id).all()
             if len(results) == 0:
@@ -636,7 +662,7 @@ class MetadataStore:
             return results[0].to_record()

     @enforce_types
-    def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
+    def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
         with self.session_maker() as session:
             query = session.query(UserModel).order_by(desc(UserModel.id))
             if cursor:
@@ -646,13 +672,13 @@ class MetadataStore:
                 return None, []
             user_records = [r.to_record() for r in results]
             next_cursor = user_records[-1].id
-            assert isinstance(next_cursor, str)
+            assert isinstance(next_cursor, uuid.UUID)

             return next_cursor, user_records

     @enforce_types
     def get_source(
-        self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
+        self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None
     ) -> Optional[Source]:
         with self.session_maker() as session:
             if source_id:
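The cursor-style pagination in `get_all_users` can be consumed like this (a sketch; `ms` as in the earlier example):

```python
# Each call returns (next_cursor, page); feed the cursor back until the page is empty.
cursor, page = ms.get_all_users(limit=50)
while page:
    for user in page:
        print(user.id)
    cursor, page = ms.get_all_users(cursor=cursor, limit=50)
```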
@@ -666,89 +692,42 @@ class MetadataStore:
             return results[0].to_record()

     @enforce_types
-    def get_tool(
-        self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
-    ) -> Optional[ToolModel]:
+    def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
+        # TODO: add user_id when tools can eventually be added by users
         with self.session_maker() as session:
-            if tool_id:
-                results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
-            else:
-                assert tool_name is not None
-                results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
-                if user_id:
-                    results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
-            if len(results) == 0:
-                return None
-            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
-            return results[0].to_record()
-
-    @enforce_types
-    def get_block(self, block_id: str) -> Optional[Block]:
-        with self.session_maker() as session:
-            results = session.query(BlockModel).filter(BlockModel.id == block_id).all()
-            if len(results) == 0:
-                return None
-            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
-            return results[0].to_record()
-
-    @enforce_types
-    def get_blocks(
-        self,
-        user_id: Optional[str],
-        label: Optional[str] = None,
-        template: bool = True,
-        name: Optional[str] = None,
-        id: Optional[str] = None,
-    ) -> List[Block]:
-        """List available blocks"""
-        with self.session_maker() as session:
-            query = session.query(BlockModel).filter(BlockModel.template == template)
-            if user_id:
-                query = query.filter(BlockModel.user_id == user_id)
-            if label:
-                query = query.filter(BlockModel.label == label)
-            if name:
-                query = query.filter(BlockModel.name == name)
-            if id:
-                query = query.filter(BlockModel.id == id)
-            results = query.all()
+            results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
+            if user_id:
+                results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
             if len(results) == 0:
                 return None
-            return [r.to_record() for r in results]
+            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
+            return results[0]

     # agent source metadata
     @enforce_types
-    def attach_source(self, user_id: str, agent_id: str, source_id: str):
+    def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
         with self.session_maker() as session:
-            # TODO: remove this (is a hack)
-            mapping_id = f"{user_id}-{agent_id}-{source_id}"
-            session.add(AgentSourceMappingModel(id=mapping_id, user_id=user_id, agent_id=agent_id, source_id=source_id))
+            session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
             session.commit()

     @enforce_types
-    def list_attached_sources(self, agent_id: str) -> List[Source]:
+    def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]:
         with self.session_maker() as session:
             results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()

-            sources = []
+            source_ids = []
             # make sure source exists
             for r in results:
                 source = self.get_source(source_id=r.source_id)
                 if source:
-                    sources.append(source)
+                    source_ids.append(r.source_id)
                 else:
                     printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.")
-            return sources
+            return source_ids

     @enforce_types
-    def list_attached_agents(self, source_id: str) -> List[str]:
+    def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]:
         with self.session_maker() as session:
             results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
@@ -763,7 +742,7 @@ class MetadataStore:
         return agent_ids

     @enforce_types
-    def detach_source(self, agent_id: str, source_id: str):
+    def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
         with self.session_maker() as session:
             session.query(AgentSourceMappingModel).filter(
                 AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
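Note that `list_attached_sources` now returns source IDs rather than `Source` records. A sketch of the attach/list/detach flow (assumes `agent` and `source` records already exist):

```python
ms.attach_source(user_id=user_id, agent_id=agent.id, source_id=source.id)
assert source.id in ms.list_attached_sources(agent_id=agent.id)
ms.detach_source(agent_id=agent.id, source_id=source.id)
```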
@@ -771,38 +750,120 @@ class MetadataStore:
             session.commit()

     @enforce_types
-    def create_job(self, job: Job):
+    def add_human(self, human: HumanModel):
         with self.session_maker() as session:
-            session.add(JobModel(**vars(job)))
+            if self.get_human(human.name, human.user_id):
+                raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}")
+            session.add(human)
             session.commit()

-    def delete_job(self, job_id: str):
+    @enforce_types
+    def add_persona(self, persona: PersonaModel):
         with self.session_maker() as session:
-            session.query(JobModel).filter(JobModel.id == job_id).delete()
+            if self.get_persona(persona.name, persona.user_id):
+                raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}")
+            session.add(persona)
             session.commit()

-    def get_job(self, job_id: str) -> Optional[Job]:
+    @enforce_types
+    def add_preset(self, preset: PresetModel):  # TODO: remove
         with self.session_maker() as session:
-            results = session.query(JobModel).filter(JobModel.id == job_id).all()
+            session.add(preset)
+            session.commit()
+
+    @enforce_types
+    def add_tool(self, tool: ToolModel):
+        with self.session_maker() as session:
+            if self.get_tool(tool.name, tool.user_id):
+                raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
+            session.add(tool)
+            session.commit()
+
+    @enforce_types
+    def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]:
+        with self.session_maker() as session:
+            results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
             if len(results) == 0:
                 return None
             assert len(results) == 1, f"Expected 1 result, got {len(results)}"
-            return results[0].to_record()
+            return results[0]

-    def list_jobs(self, user_id: str) -> List[Job]:
+    @enforce_types
+    def get_persona(self, name: str, user_id: uuid.UUID) -> Optional[PersonaModel]:
         with self.session_maker() as session:
-            results = session.query(JobModel).filter(JobModel.user_id == user_id).all()
-            return [r.to_record() for r in results]
+            results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
+            if len(results) == 0:
+                return None
+            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
+            return results[0]

-    def update_job(self, job: Job) -> Job:
+    @enforce_types
+    def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
         with self.session_maker() as session:
-            session.query(JobModel).filter(JobModel.id == job.id).update(vars(job))
+            results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
+            return results
+
+    @enforce_types
+    def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
+        with self.session_maker() as session:
+            # if user_id matches provided user_id or if user_id is None
+            results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
+            return results
+
+    @enforce_types
+    def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
+        with self.session_maker() as session:
+            results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
+            return results
+
+    @enforce_types
+    def delete_human(self, name: str, user_id: uuid.UUID):
+        with self.session_maker() as session:
+            session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
             session.commit()
-        return Job

-    def update_job_status(self, job_id: str, status: JobStatus):
+    @enforce_types
+    def delete_persona(self, name: str, user_id: uuid.UUID):
+        with self.session_maker() as session:
+            session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
+            session.commit()
+
+    @enforce_types
+    def delete_preset(self, name: str, user_id: uuid.UUID):
+        with self.session_maker() as session:
+            session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
+            session.commit()
+
+    @enforce_types
+    def delete_tool(self, name: str, user_id: uuid.UUID):
+        with self.session_maker() as session:
+            session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
+            session.commit()
+
+    # job related functions
+    def create_job(self, job: JobModel):
+        with self.session_maker() as session:
+            session.add(job)
+            session.commit()
+            session.expunge_all()
+
+    def update_job_status(self, job_id: uuid.UUID, status: JobStatus):
         with self.session_maker() as session:
             session.query(JobModel).filter(JobModel.id == job_id).update({"status": status})
             if status == JobStatus.COMPLETED:
                 session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()})
             session.commit()
+
+    def update_job(self, job: JobModel):
+        with self.session_maker() as session:
+            session.add(job)
+            session.commit()
+            session.refresh(job)
+
+    def get_job(self, job_id: uuid.UUID) -> Optional[JobModel]:
+        with self.session_maker() as session:
+            results = session.query(JobModel).filter(JobModel.id == job_id).all()
+            if len(results) == 0:
+                return None
+            assert len(results) == 1, f"Expected 1 result, got {len(results)}"
+            return results[0]
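The job helpers now operate directly on `JobModel` rows. A hedged lifecycle sketch (the import path and constructor fields for `JobModel` are assumptions):

```python
from memgpt.models.pydantic_models import JobModel, JobStatus  # assumed location

job = JobModel(user_id=user_id)
ms.create_job(job)
ms.update_job_status(job.id, JobStatus.COMPLETED)  # also stamps completed_at
assert ms.get_job(job.id).status == JobStatus.COMPLETED
```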

716
memgpt/migrate.py Normal file
View File

@@ -0,0 +1,716 @@
import configparser
import glob
import json
import os
import pickle
import shutil
import sys
import traceback
import uuid
from datetime import datetime
from typing import List, Optional
import pytz
import questionary
import typer
from tqdm import tqdm
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig
from memgpt.data_types import AgentState, Message, Passage, Source, User
from memgpt.metadata import MetadataStore
from memgpt.persistence_manager import LocalStateManager
from memgpt.utils import (
MEMGPT_DIR,
OpenAIBackcompatUnpickler,
annotate_message_json_list_with_tool_calls,
get_utc_time,
parse_formatted_time,
version_less_than,
)
# This is the version where the breaking change was made
VERSION_CUTOFF = "0.2.12"
# Migration backup dir (where we'll dump old agents that we successfully migrated)
MIGRATION_BACKUP_FOLDER = "migration_backups"
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True, create_config=True):
"""Wipe (backup) the config file, and launch `memgpt configure`"""
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
# Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS)
timestamp = get_utc_time().strftime("%Y%m%d_%H%M%S")
# Construct the new backup directory name with the timestamp
backup_filename = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
existing_filename = os.path.join(data_dir, "config")
# Check if the existing file exists before moving
if os.path.exists(existing_filename):
# shutil should work cross-platform
shutil.move(existing_filename, backup_filename)
typer.secho(f"Deleted config file ({existing_filename}) and saved as backup ({backup_filename})", fg=typer.colors.GREEN)
else:
typer.secho(f"Couldn't find an existing config file to delete", fg=typer.colors.RED)
if run_configure:
# Either run configure
configure()
elif create_config:
# Or create a new config with defaults
MemGPTConfig.load()
def config_is_compatible(data_dir: str = MEMGPT_DIR, allow_empty=False, echo=False) -> bool:
"""Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
# NOTE: don't use built-in load(), since that will apply defaults
# memgpt_config = MemGPTConfig.load()
memgpt_config_file = os.path.join(data_dir, "config")
if not os.path.exists(memgpt_config_file):
return True if allow_empty else False
parser = configparser.ConfigParser()
parser.read(memgpt_config_file)
if "version" in parser and "memgpt_version" in parser["version"]:
version = parser["version"]["memgpt_version"]
else:
version = None
if version is None:
# no version -- assume pre-determined config (does not need to be migrated)
return True
elif version_less_than(version, VERSION_CUTOFF):
if echo:
typer.secho(f"Current config version ({version}) is older than migration cutoff ({VERSION_CUTOFF})", fg=typer.colors.RED)
return False
else:
if echo:
typer.secho(f"Current config version {version} is compatible!", fg=typer.colors.GREEN)
return True
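Taken together, the two helpers above gate every migration run; the intended flow is roughly (a sketch mirroring how `migrate_all_agents` below wires them up):

```python
if not config_is_compatible(echo=True):
    # backs up ~/.memgpt/config into migration_backups/ and reruns `memgpt configure`
    wipe_config_and_reconfigure()
```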
def agent_is_migrateable(agent_name: str, data_dir: str = MEMGPT_DIR) -> bool:
"""Determine whether or not the agent folder is a migration target"""
agent_folder = os.path.join(data_dir, "agents", agent_name)
if not os.path.exists(agent_folder):
raise ValueError(f"Folder {agent_folder} does not exist")
agent_config_file = os.path.join(agent_folder, "config.json")
if not os.path.exists(agent_config_file):
raise ValueError(f"Agent folder {agent_folder} does not have a config file")
try:
with open(agent_config_file, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)
except Exception as e:
raise ValueError(f"Failed to load agent config file ({agent_config_file}), error = {e}")
if "memgpt_version" not in agent_config or version_less_than(agent_config["memgpt_version"], VERSION_CUTOFF):  # agent_config is a dict, so test key membership (hasattr on a dict is always False)
return True
else:
return False
def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
"""
Migrate an old source folder (`~/.memgpt/sources/{source_name}`).
"""
# 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
# TODO
source_path = os.path.join(data_dir, "archival", source_name, "nodes.pkl")
assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"
# load state from old checkpoint file
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# insert source into metadata store
source = Source(user_id=user.id, name=source_name)
ms.create_source(source)
try:
try:
nodes = pickle.load(open(source_path, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
# cannot load source at all, so throw error
raise ValueError(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
else:
raise e
passages = []
for node in nodes:
# print(len(node.embedding))
# TODO: make sure embedding config matches embedding size?
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(
f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`."
)
passages.append(
Passage(
user_id=user.id,
data_source=source_name,
text=node.text,
embedding=node.embedding,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert len(passages) > 0, f"Source {source_name} has no passages"
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
conn.insert_many(passages)
# print(f"Inserted {len(passages)} to {source_name}")
except Exception as e:
# delete from metadata store
ms.delete_source(source.id)
raise ValueError(f"Failed to migrate {source_name}: {str(e)}")
# basic checks
source = ms.get_source(user_id=user.id, source_name=source_name)
assert source is not None, f"Failed to load source {source_name} from database after migration"
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None) -> List[str]:
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
Steps:
1. Load the agent state JSON from the old folder
2. Create a new AgentState using the agent config + agent internal state
3. Instantiate a new Agent by passing AgentState to Agent.__init__
(This will automatically run into a new database)
If success, returns empty list
If warning, returns a list of strings (warning message)
If error, raises an Exception
"""
warnings = []
# 1. Load the agent state JSON from the old folder
# TODO
agent_folder = os.path.join(data_dir, "agents", agent_name)
# migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)
# load state from old checkpoint file
agent_ckpt_directory = os.path.join(agent_folder, "agent_state")
json_files = glob.glob(os.path.join(agent_ckpt_directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
# NOTE this is a soft fail, just allow it to pass
# return
# return [f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}"]
# Sort files based on modified timestamp, with the latest file being the first.
state_filename = max(json_files, key=os.path.getmtime)
state_dict = json.load(open(state_filename, "r"))
# print(state_dict.keys())
# print(state_dict["memory"])
# dict_keys(['model', 'system', 'functions', 'messages', 'messages_total', 'memory'])
# load old data from the persistence manager
persistence_filename = os.path.basename(state_filename).replace(".json", ".persistence.pickle")
persistence_filename = os.path.join(agent_folder, "persistence_manager", persistence_filename)
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
if not os.path.exists(persistence_filename):
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
# return [f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}"]
try:
with open(persistence_filename, "rb") as f:
data = pickle.load(f)
except ModuleNotFoundError:
# Patch for stripped openai package
# ModuleNotFoundError: No module named 'openai.openai_object'
with open(persistence_filename, "rb") as f:
unpickler = OpenAIBackcompatUnpickler(f)
data = unpickler.load()
from memgpt.openai_backcompat.openai_object import OpenAIObject
def convert_openai_objects_to_dict(obj):
if isinstance(obj, OpenAIObject):
# Convert to dict or handle as needed
# print(f"detected OpenAIObject on {obj}")
return obj.to_dict_recursive()
elif isinstance(obj, dict):
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_openai_objects_to_dict(v) for v in obj]
else:
return obj
data = convert_openai_objects_to_dict(data)
# data will contain:
# print("data.keys()", data.keys())
# manager.all_messages = data["all_messages"]
# manager.messages = data["messages"]
# manager.recall_memory = data["recall_memory"]
agent_config_filename = os.path.join(agent_folder, "config.json")
with open(agent_config_filename, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# ms.create_user(User(id=user_id))
# user = ms.get_user(user_id=user_id)
# if user is None:
# typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
# sys.exit(1)
# create an agent_id ahead of time
agent_id = uuid.uuid4()
# create all the Messages in the database
# message_objs = []
# for message_dict in annotate_message_json_list_with_tool_calls(state_dict["messages"]):
# message_obj = Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=message_dict,
# model=state_dict["model"] if "model" in state_dict else None,
# # allow_functions_style=False,
# allow_functions_style=True,
# )
# message_objs.append(message_obj)
agent_state = AgentState(
id=agent_id,
name=agent_config["name"],
user_id=user.id,
# persona_name=agent_config["persona"], # eg 'sam_pov'
# human_name=agent_config["human"], # eg 'basic'
persona=state_dict["memory"]["persona"], # NOTE: hacky (not init, but latest)
human=state_dict["memory"]["human"], # NOTE: hacky (not init, but latest)
preset=agent_config["preset"], # eg 'memgpt_chat'
state=dict(
human=state_dict["memory"]["human"],
persona=state_dict["memory"]["persona"],
system=state_dict["system"],
functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link
# messages=[str(m.id) for m in message_objs], # this is a list of uuids, not message dicts
),
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
)
persistence_manager = LocalStateManager(agent_state=agent_state)
# First clean up the recall message history to add tool call ids
# allow_tool_roles in case some of the old messages were actually already in tool call format (for whatever reason)
full_message_history_buffer = annotate_message_json_list_with_tool_calls(
[d["message"] for d in data["all_messages"]], allow_tool_roles=True
)
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]
# Figure out what messages in recall are in-context, and which are out-of-context
agent_message_cache = state_dict["messages"]
recall_message_full = data["all_messages"]
def messages_are_equal(msg1, msg2):
if msg1["role"] != msg2["role"]:
return False
if msg1["content"] != msg2["content"]:
return False
if "function_call" in msg1 and "function_call" in msg2 and msg1["function_call"] != msg2["function_call"]:
return False
if "name" in msg1 and "name" in msg2 and msg1["name"] != msg2["name"]:
return False
# otherwise checks pass, ~= equal
return True
in_context_messages = []
out_of_context_messages = []
assert len(agent_message_cache) <= len(recall_message_full), (len(agent_message_cache), len(recall_message_full))
for i, d in enumerate(recall_message_full):
# unpack into "timestamp" and "message"
recall_message = d["message"]
recall_timestamp = str(d["timestamp"])
try:
recall_datetime = parse_formatted_time(recall_timestamp.strip()).astimezone(pytz.utc)
except ValueError:
recall_datetime = datetime.strptime(recall_timestamp.strip(), "%Y-%m-%d %I:%M:%S %p").astimezone(pytz.utc)
# message object
message_obj = Message.dict_to_message(
created_at=recall_datetime,
user_id=user.id,
agent_id=agent_id,
openai_message_dict=recall_message,
allow_functions_style=True,
)
# message is either in-context, or out-of-context
if i >= (len(recall_message_full) - len(agent_message_cache)):
# there are len(agent_message_cache) total messages on the agent
# this will correspond to the last N messages in the recall memory (though possibly out-of-order)
message_is_in_context = [messages_are_equal(recall_message, cache_message) for cache_message in agent_message_cache]
# assert sum(message_is_in_context) <= 1, message_is_in_context
# if any(message_is_in_context):
# in_context_messages.append(message_obj)
# else:
# out_of_context_messages.append(message_obj)
if not any(message_is_in_context):
# typer.secho(
# f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
out_of_context_messages.append(message_obj)
else:
if sum(message_is_in_context) > 1:
# typer.secho(
# f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
in_context_messages.append(message_obj)
else:
# if we're not in the final portion of the recall memory buffer, then it's 100% out-of-context
out_of_context_messages.append(message_obj)
assert len(in_context_messages) > 0, f"Couldn't find any in-context messages (agent_cache = {len(agent_message_cache)})"
# assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
if len(in_context_messages) != len(agent_message_cache):
# typer.secho(
# f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
# fg=typer.colors.RED,
# )
warnings.append(
f"Uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})"
)
# assert (
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
# ), f"{len(in_context_messages)} + {len(out_of_context_messages)} != {state_dict['messages_total']}"
# Now we can insert the messages into the actual recall database
# So when we construct the agent from the state, they will be available
persistence_manager.recall_memory.insert_many(out_of_context_messages)
persistence_manager.recall_memory.insert_many(in_context_messages)
# Overwrite the agent_state message object
agent_state.state["messages"] = [str(m.id) for m in in_context_messages] # this is a list of uuids, not message dicts
## 4. Insert into recall
# TODO should this be 'messages', or 'all_messages'?
# all_messages in recall will have fields "timestamp" and "message"
# full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# We want to keep the timestamp
# for i in range(len(data["all_messages"])):
# data["all_messages"][i]["message"] = full_message_history_buffer[i]
# messages_to_insert = [
# Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=msg,
# allow_functions_style=True,
# )
# # for msg in data["all_messages"]
# for msg in full_message_history_buffer
# ]
# agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
# print("Finished migrating recall memory")
# 3. Instantiate a new Agent by passing AgentState to Agent.__init__
# NOTE: the Agent.__init__ will trigger a save, which will write to the DB
try:
agent = Agent(
agent_state=agent_state,
# messages_total=state_dict["messages_total"], # TODO: do we need this?
messages_total=len(in_context_messages) + len(out_of_context_messages),
interface=None,
)
save_agent(agent, ms=ms)
except Exception:
# if "Agent with name" in str(e):
# print(e)
# return
# elif "was specified in agent.state.functions":
# print(e)
# return
# else:
# raise
raise
# Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
try:
# TODO should we also assign data["messages"] to RecallMemory.messages?
# 5. Insert into archival
if os.path.exists(archival_filename):
try:
nodes = pickle.load(open(archival_filename, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
print(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
nodes = []
else:
raise e
passages = []
failed_inserts = []
for node in nodes:
if len(node.embedding) != config.default_embedding_config.embedding_dim:
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimensions.")
failed_inserts.append(
f"Cannot migrate passage due to incompatible embedding dimensions ({len(node.embedding)} != {config.default_embedding_config.embedding_dim}) - content = '{node.text}'."
)
passages.append(
Passage(
user_id=user.id,
agent_id=agent_state.id,
text=node.text,
embedding=node.embedding,
embedding_dim=agent_state.embedding_config.embedding_dim,
embedding_model=agent_state.embedding_config.embedding_model,
)
)
if len(passages) > 0:
agent.persistence_manager.archival_memory.storage.insert_many(passages)
# print(f"Inserted {len(passages)} passages into archival memory")
if len(failed_inserts) > 0:
warnings.append(
f"Failed to transfer {len(failed_inserts)}/{len(nodes)} passages from old archival memory: " + ", ".join(failed_inserts)
)
else:
warnings.append("No archival memory found at", archival_filename)
except:
ms.delete_agent(agent_state.id)
raise
try:
new_agent_folder = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
shutil.move(agent_folder, new_agent_folder)
except Exception:
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
raise
return warnings
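A minimal driver for a single agent, per the three-step docstring above (the agent name is hypothetical):

```python
from memgpt.config import MemGPTConfig
from memgpt.metadata import MetadataStore

migration_warnings = migrate_agent("agent_7", ms=MetadataStore(MemGPTConfig.load()))
for w in migration_warnings:
    print("warning:", w)
```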
# def migrate_all_agents(stop_on_fail=True):
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
if not config_is_compatible(data_dir, echo=True):
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
if questionary.confirm(
"To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
).ask():
try:
wipe_config_and_reconfigure(data_dir)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
else:
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
raise KeyboardInterrupt()
agents_dir = os.path.join(data_dir, "agents")
# Ensure the directory exists
if not os.path.exists(agents_dir):
raise ValueError(f"Directory {agents_dir} does not exist.")
# Get a list of all folders in agents_dir
agent_folders = [f for f in os.listdir(agents_dir) if os.path.isdir(os.path.join(agents_dir, f))]
# Iterate over each folder with a tqdm progress bar
count = 0
successes = [] # agents that migrated w/o warnings
warnings = [] # agents that migrated but had warnings
failures = [] # agents that failed to migrate (fatal error)
candidates = []
config = MemGPTConfig.load()
print(config)
ms = MetadataStore(config)
try:
for agent_name in tqdm(agent_folders, desc="Migrating agents"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
candidates.append(agent_name)
migration_warnings = migrate_agent(agent_name, data_dir=data_dir, ms=ms)
if len(migration_warnings) == 0:
successes.append(agent_name)
else:
warnings.append((agent_name, migration_warnings))
count += 1
else:
continue
except Exception as e:
failures.append({"name": agent_name, "reason": str(e)})
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
if debug:
traceback.print_exc()
if stop_on_fail:
raise
except KeyboardInterrupt:
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(agent_folders)} agent folders for migration")
if len(warnings) > 0:
typer.secho(f"Migration warnings:", fg=typer.colors.BRIGHT_YELLOW)
for warn in warnings:
typer.secho(f"{warn[0]}: {warn[1]}", fg=typer.colors.BRIGHT_YELLOW)
if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
if len(failures) > 0:
typer.secho(
f"🔴 {len(failures)}/{len(candidates)} agents failed to migrate (see reasons above)",
fg=typer.colors.RED,
)
typer.secho(f"{[d['name'] for d in failures]}", fg=typer.colors.RED)
if len(warnings) > 0:
typer.secho(
f"🟠 {len(warnings)}/{len(candidates)} agents successfully migrated with warnings (see reasons above)",
fg=typer.colors.BRIGHT_YELLOW,
)
typer.secho(f"{[t[0] for t in warnings]}", fg=typer.colors.BRIGHT_YELLOW)
if len(successes) > 0:
typer.secho(
f"🟢 {len(successes)}/{len(candidates)} agents successfully migrated with no warnings",
fg=typer.colors.GREEN,
)
typer.secho(f"{successes}", fg=typer.colors.GREEN)
del ms
return {
"agent_folders": len(agent_folders),
"migration_candidates": candidates,
"successful_migrations": len(successes) + len(warnings),
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}
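The returned summary dict can be consumed directly, for example (a sketch):

```python
summary = migrate_all_agents(stop_on_fail=False, debug=True)
print(f"{summary['successful_migrations']}/{len(summary['migration_candidates'])} agents migrated")
```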
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
sources_dir = os.path.join(data_dir, "archival")
# Ensure the directory exists
if not os.path.exists(sources_dir):
raise ValueError(f"Directory {sources_dir} does not exist.")
# Get a list of all folders in agents_dir
source_folders = [f for f in os.listdir(sources_dir) if os.path.isdir(os.path.join(sources_dir, f))]
# Iterate over each folder with a tqdm progress bar
count = 0
failures = []
candidates = []
config = MemGPTConfig.load()
ms = MetadataStore(config)
try:
for source_name in tqdm(source_folders, desc="Migrating data sources"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
candidates.append(source_name)
migrate_source(source_name, data_dir, ms=ms)
count += 1
except Exception as e:
failures.append({"name": source_name, "reason": str(e)})
if debug:
traceback.print_exc()
if stop_on_fail:
raise
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
except KeyboardInterrupt:
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(source_folders)} source folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(source_folders)} source folders")
if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
typer.secho(f"{len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
if count > 0:
typer.secho(
f"{count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN
)
del ms
return {
"source_folders": len(source_folders),
"migration_candidates": candidates,
"successful_migrations": count,
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}

View File

@@ -0,0 +1,197 @@
# tool imports
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, Column
from sqlalchemy_utils import ChoiceType
from sqlmodel import Field, SQLModel
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text, get_utc_time
class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""
YES = "yes"
NO = "no"
DEFAULT = "default"
class MemGPTUsageStatistics(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
step_count: int
class LLMConfigModel(BaseModel):
model: Optional[str] = "gpt-4"
model_endpoint_type: Optional[str] = "openai"
model_endpoint: Optional[str] = "https://api.openai.com/v1"
model_wrapper: Optional[str] = None
context_window: Optional[int] = None
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())
class EmbeddingConfigModel(BaseModel):
embedding_endpoint_type: Optional[str] = "openai"
embedding_endpoint: Optional[str] = "https://api.openai.com/v1"
embedding_model: Optional[str] = "text-embedding-ada-002"
embedding_dim: Optional[int] = 1536
embedding_chunk_size: Optional[int] = 300
class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
class ToolModel(SQLModel, table=True):
# TODO move into database
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
module: Optional[str] = Field(None, description="The module of the function.")
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.", index=True)
# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True
class AgentToolMap(SQLModel, table=True):
# mapping between agents and tools
agent_id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the agent-tool map.", primary_key=True)
class PresetToolMap(SQLModel, table=True):
# mapping between presets and tools
preset_id: uuid.UUID = Field(..., description="The unique identifier of the preset.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset-tool map.", primary_key=True)
class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
# timestamps
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the agent was created.")
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
# preset information
tools: List[str] = Field(..., description="The tools used by the agent.")
system: str = Field(..., description="The system prompt used by the agent.")
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
# llm information
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
class CoreMemory(BaseModel):
human: str = Field(..., description="Human element of the core memory.")
persona: str = Field(..., description="Persona element of the core memory.")
class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.", index=True)
class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.", index=True)
class SourceModel(SQLModel, table=True):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
# embedding info
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the source."
)
# NOTE: .metadata is a reserved attribute on SQLModel
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")
class JobStatus(str, Enum):
created = "created"
running = "running"
completed = "completed"
failed = "failed"
class JobModel(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the job.", primary_key=True)
# status: str = Field(default="created", description="The status of the job.")
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus)))
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.", index=True)
metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.")
class PassageModel(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
text: str = Field(..., description="The text of the passage.")
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)
data_source: Optional[str] = Field(None, description="The data source of the passage.")
doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
class DocumentModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
text: str = Field(..., description="The text of the document.")
data_source: str = Field(..., description="The data source of the document.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
class UserModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user.")
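A minimal construction sketch for the models above (illustrative values; assumes the imports at the top of this file):

preset = PresetModel(
    name="memgpt_chat",
    system="You are MemGPT.",
    persona="I am a helpful assistant.",
    human="Name: Bob",
    functions_schema=[{"name": "send_message", "parameters": {"type": "object", "properties": {}}}],
)
print(preset.id, preset.created_at)  # uuid4 and UTC timestamp from the default factories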

View File

@ -2,10 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from memgpt.data_types import AgentState, Message
from memgpt.memory import BaseRecallMemory, EmbeddingArchivalMemory from memgpt.memory import BaseRecallMemory, EmbeddingArchivalMemory
from memgpt.schemas.agent import AgentState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.utils import printd from memgpt.utils import printd
@ -47,7 +45,7 @@ class LocalStateManager(PersistenceManager):
def __init__(self, agent_state: AgentState): def __init__(self, agent_state: AgentState):
# Memory held in-state useful for debugging stateful versions # Memory held in-state useful for debugging stateful versions
self.memory = agent_state.memory self.memory = None
# self.messages = [] # current in-context messages # self.messages = [] # current in-context messages
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB) # self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
self.archival_memory = EmbeddingArchivalMemory(agent_state) self.archival_memory = EmbeddingArchivalMemory(agent_state)
@ -59,6 +57,15 @@ class LocalStateManager(PersistenceManager):
self.archival_memory.save() self.archival_memory.save()
self.recall_memory.save() self.recall_memory.save()
def init(self, agent):
"""Connect persistent state manager to agent"""
printd(f"Initializing {self.__class__.__name__} with agent object")
# self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
# self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
# printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
''' '''
def json_to_message(self, message_json) -> Message: def json_to_message(self, message_json) -> Message:
"""Convert agent message JSON into Message object""" """Convert agent message JSON into Message object"""
@ -121,6 +128,7 @@ class LocalStateManager(PersistenceManager):
# self.messages = [self.messages[0]] + added_messages + self.messages[1:] # self.messages = [self.messages[0]] + added_messages + self.messages[1:]
# add to recall memory # add to recall memory
self.recall_memory.insert_many([m for m in added_messages])
def append_to_messages(self, added_messages: List[Message]): def append_to_messages(self, added_messages: List[Message]):
# first tag with timestamps # first tag with timestamps
@ -142,7 +150,6 @@ class LocalStateManager(PersistenceManager):
# add to recall memory # add to recall memory
self.recall_memory.insert(new_system_message) self.recall_memory.insert(new_system_message)
def update_memory(self, new_memory: Memory): def update_memory(self, new_memory):
printd(f"{self.__class__.__name__}.update_memory") printd(f"{self.__class__.__name__}.update_memory")
assert isinstance(new_memory, Memory), type(new_memory)
self.memory = new_memory self.memory = new_memory

View File

@ -0,0 +1,10 @@
system_prompt: "memgpt_chat"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"

View File

@ -0,0 +1,10 @@
system_prompt: "memgpt_doc"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"

View File

@ -0,0 +1,15 @@
system_prompt: "memgpt_chat"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"
# extras for read/write to files
- "read_from_text_file"
- "append_to_text_file"
# internet access
- "http_request"

91
memgpt/presets/presets.py Normal file
View File

@ -0,0 +1,91 @@
import importlib
import inspect
import os
import uuid
from memgpt.data_types import AgentState, Preset
from memgpt.functions.functions import load_function_set
from memgpt.interface import AgentInterface
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
from memgpt.presets.utils import load_all_presets
from memgpt.utils import list_human_files, list_persona_files, printd
available_presets = load_all_presets()
preset_options = list(available_presets.keys())
def load_module_tools(module_name="base"):
# return List[ToolModel] from base.py tools
full_module_name = f"memgpt.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Surface which module failed instead of re-raising bare
raise ImportError(f"Failed to import function set module '{full_module_name}': {e}") from e
# function tags
try:
# Load the function set
functions_to_schema = load_function_set(module)
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
printd(err)
raise
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("memgpt-base")
tools.append(
ToolModel(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
)
)
return tools
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
module_name = "base"
for tool in load_module_tools(module_name=module_name):
existing_tool = ms.get_tool(tool.name)
if not existing_tool:
ms.add_tool(tool)
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
for persona_file in list_persona_files():
text = open(persona_file, "r", encoding="utf-8").read()
name = os.path.basename(persona_file).replace(".txt", "")
if ms.get_persona(user_id=user_id, name=name) is not None:
printd(f"Persona '{name}' already exists for user '{user_id}'")
continue
persona = PersonaModel(name=name, text=text, user_id=user_id)
ms.add_persona(persona)
for human_file in list_human_files():
text = open(human_file, "r", encoding="utf-8").read()
name = os.path.basename(human_file).replace(".txt", "")
if ms.get_human(user_id=user_id, name=name) is not None:
printd(f"Human '{name}' already exists for user '{user_id}'")
continue
human = HumanModel(name=name, text=text, user_id=user_id)
ms.add_human(human)
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
def create_agent_from_preset(
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
):
"""Initialize a new agent from a preset (combination of system + function)"""
raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")
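A sketch of seeding a fresh database with the helpers above (the memgpt.config.MemGPTConfig import path is an assumption; MetadataStore and uuid are already imported in this file):

from memgpt.config import MemGPTConfig  # assumed import path

config = MemGPTConfig.load()
ms = MetadataStore(config)
user_id = uuid.UUID(config.anon_clientid)
add_default_humans_and_personas(user_id=user_id, ms=ms)
add_default_tools(user_id=user_id, ms=ms)
print([tool.name for tool in load_module_tools("base")])  # tools tagged "memgpt-base"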

79
memgpt/presets/utils.py Normal file
View File

@ -0,0 +1,79 @@
import glob
import os
import yaml
from memgpt.constants import MEMGPT_DIR
def is_valid_yaml_format(yaml_data, function_set):
"""
Check if the given YAML data follows the specified format and if all functions in the yaml are part of the function_set.
Raises ValueError if any check fails.
:param yaml_data: The data loaded from a YAML file.
:param function_set: A set of valid function names.
"""
# Check for required keys
if not all(key in yaml_data for key in ["system_prompt", "functions"]):
raise ValueError("YAML data is missing one or more required keys: 'system_prompt', 'functions'.")
# Check if 'functions' is a list of strings
if not all(isinstance(item, str) for item in yaml_data.get("functions", [])):
raise ValueError("'functions' should be a list of strings.")
# Check if all functions in YAML are part of function_set
if not set(yaml_data["functions"]).issubset(function_set):
raise ValueError(
f"Some functions in YAML are not part of the provided function set: {set(yaml_data['functions']) - set(function_set)} "
)
# If all checks pass
return True
def load_yaml_file(file_path):
"""
Load a YAML file and return the data.
:param file_path: Path to the YAML file.
:return: Data from the YAML file.
"""
with open(file_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
def load_all_presets():
"""Load all the preset configs in the examples directory"""
## Load the examples
# Get the directory in which the script is located
script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the path pattern
example_path_pattern = os.path.join(script_directory, "examples", "*.yaml")
# Listing all YAML files
example_yaml_files = glob.glob(example_path_pattern)
## Load the user-provided presets
# ~/.memgpt/presets/*.yaml
user_presets_dir = os.path.join(MEMGPT_DIR, "presets")
# Create directory if it doesn't exist
if not os.path.exists(user_presets_dir):
os.makedirs(user_presets_dir)
# Construct the path pattern
user_path_pattern = os.path.join(user_presets_dir, "*.yaml")
# Listing all YAML files
user_yaml_files = glob.glob(user_path_pattern)
# Pull from both examples and user-provided presets
all_yaml_files = example_yaml_files + user_yaml_files
# Loading and creating a mapping from file name to YAML data
all_yaml_data = {}
for file_path in all_yaml_files:
# Extracting the base file name without the '.yaml' extension
base_name = os.path.splitext(os.path.basename(file_path))[0]
data = load_yaml_file(file_path)
all_yaml_data[base_name] = data
return all_yaml_data
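A sketch tying these helpers to the example presets earlier in this diff (the "memgpt_chat" key assumes the example file is named memgpt_chat.yaml, which this hunk does not show):

presets = load_all_presets()
function_set = {
    "send_message", "pause_heartbeats",
    "core_memory_append", "core_memory_replace",
    "conversation_search", "conversation_search_date",
    "archival_memory_insert", "archival_memory_search",
}
assert is_valid_yaml_format(presets["memgpt_chat"], function_set)  # key name assumed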

View File

@ -1,90 +0,0 @@
import uuid
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memory import Memory
class BaseAgent(MemGPTBase, validate_assignment=True):
__id_prefix__ = "agent"
description: Optional[str] = Field(None, description="The description of the agent.")
# metadata
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
user_id: Optional[str] = Field(None, description="The user id of the agent.")
class AgentState(BaseAgent):
"""Representation of an agent's state."""
id: str = BaseAgent.generate_id_field()
name: str = Field(..., description="The name of the agent.")
created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now)
# in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
memory: Memory = Field(default_factory=Memory, description="The in-context memory of the agent.")
# tools
tools: List[str] = Field(..., description="The tools used by the agent.")
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
class CreateAgent(BaseAgent):
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str:
"""Validate the requested new agent name (prevent bad inputs)"""
import re
if not name:
# don't check if not provided
return name
# TODO: this check should also be added to other model (e.g. User.name)
# Length check
if not (1 <= len(name) <= 50):
raise ValueError("Name length must be between 1 and 50 characters.")
# Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
if not re.match("^[A-Za-z0-9 _-]+$", name):
raise ValueError("Name contains invalid characters.")
# Further checks can be added here...
# TODO
return name
class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
name: Optional[str] = Field(None, description="The name of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# TODO: determine if these should be editable via this schema?
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
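A sketch of the CreateAgent name validator (pydantic surfaces the ValueError as a ValidationError, which subclasses ValueError):

CreateAgent(name="research-agent_01")  # passes: alphanumerics, hyphen, underscore
try:
    CreateAgent(name="bad/name!")      # slash and bang are outside the allowed set
except ValueError as e:
    print(e)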

View File

@ -1,21 +0,0 @@
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class BaseAPIKey(MemGPTBase):
__id_prefix__ = "sk" # secret key
class APIKey(BaseAPIKey):
id: str = BaseAPIKey.generate_id_field()
user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
key: str = Field(..., description="The key value.")
name: str = Field(..., description="Name of the token.")
class APIKeyCreate(BaseAPIKey):
user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
name: Optional[str] = Field(None, description="Name of the token.")

View File

@ -1,117 +0,0 @@
from typing import List, Optional, Union
from pydantic import Field, model_validator
from typing_extensions import Self
from memgpt.schemas.memgpt_base import MemGPTBase
# block of the LLM context
class BaseBlock(MemGPTBase, validate_assignment=True):
"""Base block of the LLM context"""
__id_prefix__ = "block"
# data value
value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.")
limit: int = Field(2000, description="Character limit of the block.")
name: Optional[str] = Field(None, description="Name of the block.")
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').")
# metadata
description: Optional[str] = Field(None, description="Description of the block.")
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
# associated user/agent
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the block.")
@model_validator(mode="after")
def verify_char_limit(self) -> Self:
try:
assert len(self) <= self.limit
except AssertionError:
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self)})."
raise ValueError(error_msg)
except Exception as e:
raise e
return self
def __len__(self):
return len(str(self))
def __str__(self) -> str:
if isinstance(self.value, list):
return ",".join(self.value)
elif isinstance(self.value, str):
return self.value
else:
return ""
def __setattr__(self, name, value):
"""Run validation if self.value is updated"""
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True))
class Block(BaseBlock):
"""Block of the LLM context"""
id: str = BaseBlock.generate_id_field()
value: str = Field(..., description="Value of the block.")
class Human(Block):
"""Human block of the LLM context"""
label: str = "human"
class Persona(Block):
"""Persona block of the LLM context"""
label: str = "persona"
class CreateBlock(BaseBlock):
"""Create a block"""
template: bool = True
label: str = Field(..., description="Label of the block.")
class CreatePersona(BaseBlock):
"""Create a persona block"""
template: bool = True
label: str = "persona"
class CreateHuman(BaseBlock):
"""Create a human block"""
template: bool = True
label: str = "human"
class UpdateBlock(BaseBlock):
"""Update a block"""
id: str = Field(..., description="The unique identifier of the block.")
limit: Optional[int] = Field(2000, description="Character limit of the block.")
class UpdatePersona(UpdateBlock):
"""Update a persona block"""
label: str = "persona"
class UpdateHuman(UpdateBlock):
"""Update a human block"""
label: str = "human"

View File

@ -1,21 +0,0 @@
from typing import Dict, Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class DocumentBase(MemGPTBase):
"""Base class for document schemas"""
__id_prefix__ = "doc"
class Document(DocumentBase):
"""Representation of a single document (broken up into `Passage` objects)"""
id: str = DocumentBase.generate_id_field()
text: str = Field(..., description="The text of the document.")
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
user_id: str = Field(description="The unique identifier of the user associated with the document.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")

View File

@ -1,18 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class EmbeddingConfig(BaseModel):
"""Embedding model configuration"""
embedding_endpoint_type: str = Field(..., description="The endpoint type for the model.")
embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
embedding_model: str = Field(..., description="The model for the embedding.")
embedding_dim: int = Field(..., description="The dimension of the embedding.")
embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
# azure only
azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
azure_version: Optional[str] = Field(None, description="The Azure version for the model.")
azure_deployment: Optional[str] = Field(None, description="The Azure deployment for the model.")

View File

@ -1,30 +0,0 @@
from enum import Enum
class MessageRole(str, Enum):
assistant = "assistant"
user = "user"
tool = "tool"
system = "system"
class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""
YES = "yes"
NO = "no"
DEFAULT = "default"
class JobStatus(str, Enum):
created = "created"
running = "running"
completed = "completed"
failed = "failed"
pending = "pending"
class MessageStreamStatus(str, Enum):
done_generation = "[DONE_GEN]"
done_step = "[DONE_STEP]"
done = "[DONE]"

View File

@ -1,28 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class JobBase(MemGPTBase):
__id_prefix__ = "job"
metadata_: Optional[dict] = Field({}, description="The metadata of the job.")
class Job(JobBase):
"""Representation of offline jobs."""
id: str = JobBase.generate_id_field()
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: str = Field(..., description="The unique identifier of the user associated with the job.")
class JobUpdate(JobBase):
id: str = Field(..., description="The unique identifier of the job.")
status: Optional[JobStatus] = Field(..., description="The status of the job.")

View File

@ -1,15 +0,0 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
class LLMConfig(BaseModel):
# TODO: 🤮 don't default to a vendor! bug city!
model: str = Field(..., description="LLM model name. ")
model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
model_endpoint: str = Field(..., description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())
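A construction sketch (every field except model_wrapper is required, so no vendor default can sneak in):

llm_config = LLMConfig(
    model="gpt-4",
    model_endpoint_type="openai",
    model_endpoint="https://api.openai.com/v1",
    context_window=8192,
)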

View File

@ -1,80 +0,0 @@
import uuid
from logging import getLogger
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
# from: https://gist.github.com/norton120/22242eadb80bf2cf1dd54a961b151c61
logger = getLogger(__name__)
class MemGPTBase(BaseModel):
"""Base schema for MemGPT schemas (does not include model provider schemas, e.g. OpenAI)"""
model_config = ConfigDict(
# allows you to use the snake or camelcase names in your code (ie user_id or userId)
populate_by_name=True,
# allows you do dump a sqlalchemy object directly (ie PersistedAddress.model_validate(SQLAdress)
from_attributes=True,
# throw errors if attributes are given that don't belong
extra="forbid",
)
# def __id_prefix__(self):
# raise NotImplementedError("All schemas must have an __id_prefix__ attribute!")
@classmethod
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
prefix = prefix or cls.__id_prefix__
# TODO: generate ID from regex pattern?
def _generate_id() -> str:
return f"{prefix}-{uuid.uuid4()}"
return Field(
...,
description=cls._id_description(prefix),
pattern=cls._id_regex_pattern(prefix),
examples=[cls._id_example(prefix)],
default_factory=_generate_id,
)
# def _generate_id(self) -> str:
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
@classmethod
def _id_regex_pattern(cls, prefix: str):
"""generates the regex pattern for a given id"""
return (
r"^" + prefix + r"-" # prefix string
r"[a-fA-F0-9]{8}" # 8 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{12}$" # 12 hexadecimal characters
)
@classmethod
def _id_example(cls, prefix: str):
"""generates an example id for a given prefix"""
return prefix + "-123e4567-e89b-12d3-a456-426614174000"
@classmethod
def _id_description(cls, prefix: str):
"""generates the id-field description for a given prefix"""
return f"The human-friendly ID of the {prefix.capitalize()}"
@field_validator("id", check_fields=False, mode="before")
@classmethod
def allow_bare_uuids(cls, v, values):
"""to ease the transition to stripe ids,
we allow bare uuids and convert them with a warning
"""
_ = values # for SCA
if isinstance(v, UUID):
logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!")
return f"{cls.__id_prefix__}-{v}"
return v
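A sketch of the id-field pattern above (the Order schema is hypothetical, for illustration only):

class OrderBase(MemGPTBase):
    __id_prefix__ = "order"

class Order(OrderBase):
    id: str = OrderBase.generate_id_field()

print(Order().id)  # e.g. "order-123e4567-e89b-12d3-a456-426614174000"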

View File

@ -1,78 +0,0 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import BaseModel, field_serializer
# MemGPT API style responses (intended to be easier to use vs getting true Message types)
class BaseMemGPTMessage(BaseModel):
id: str
date: datetime
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
return dt.isoformat()
class InternalMonologue(BaseMemGPTMessage):
"""
{
"internal_monologue": msg,
"date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
"id": str(msg_obj.id) if msg_obj is not None else None,
}
"""
internal_monologue: str
class FunctionCall(BaseModel):
name: str
arguments: str
class FunctionCallMessage(BaseMemGPTMessage):
"""
{
"function_call": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""
function_call: FunctionCall
class FunctionReturn(BaseMemGPTMessage):
"""
{
"function_return": msg,
"status": "success" or "error",
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""
function_return: str
status: Literal["success", "error"]
MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]
# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string
class AssistantMessage(BaseMemGPTMessage):
assistant_message: str
class LegacyFunctionCallMessage(BaseMemGPTMessage):
function_call: str
LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]
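A construction sketch for the message variants above (assumes the datetime/timezone imports at the top of this file):

msg = FunctionReturn(
    id="message-123e4567",
    date=datetime.now(timezone.utc),
    function_return="OK",
    status="success",
)
print(msg.model_dump_json())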

View File

@ -1,18 +0,0 @@
from typing import List
from pydantic import BaseModel, Field
from memgpt.schemas.message import MessageCreate
class MemGPTRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement
stream_steps: bool = Field(
default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
)
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)

View File

@ -1,17 +0,0 @@
from typing import List, Union
from pydantic import BaseModel, Field
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
# TODO: consider moving into own file
class MemGPTResponse(BaseModel):
# messages: List[Message] = Field(..., description="The messages returned by the agent.")
messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field(
..., description="The messages returned by the agent."
)
usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")

View File

@ -1,138 +0,0 @@
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from memgpt.schemas.block import Block
class Memory(BaseModel, validate_assignment=True):
"""Represents the in-context memory of the agent"""
# Private variable to avoid assignments with incorrect types
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
@classmethod
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
obj.memory[key] = Block(**value)
return obj
def __str__(self) -> str:
"""Representation of the memory in-context"""
section_strs = []
for section, module in self.memory.items():
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
return "\n".join(section_strs)
def to_dict(self):
"""Convert to dictionary representation"""
return {key: value.dict() for key, value in self.memory.items()}
def to_flat_dict(self):
"""Convert to a dictionary that maps directly from block names to values"""
return {k: v.value for k, v in self.memory.items() if v is not None}
def list_block_names(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""
return list(self.memory.keys())
def get_block(self, name: str) -> Block:
"""Correct way to index into the memory.memory field, returns a Block"""
if name not in self.memory:
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
else:
return self.memory[name]
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
"""Link a new block to the memory object"""
if not isinstance(block, Block):
raise ValueError(f"Param block must be type Block (not {type(block)})")
if not isinstance(name, str):
raise ValueError(f"Name must be str (not type {type(name)})")
if not override and name in self.memory:
raise ValueError(f"Block with name {name} already exists")
self.memory[name] = block
def update_block_value(self, name: str, value: Union[List[str], str]):
"""Update the value of a block"""
if name not in self.memory:
raise ValueError(f"Block with name {name} does not exist")
if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))):
raise ValueError(f"Provided value must be a string or list of strings")
self.memory[name].value = value
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BaseChatMemory(Memory):
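# NOTE: the editing methods below access self.memory.get_block(...), which fits being invoked as agent tools (self bound to an Agent whose .memory is a Memory) rather than called directly on this class, where self.memory is the raw block dict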
def core_memory_append(self, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
Args:
name (str): Section of the memory to be edited (persona or human).
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
new_value = current_value + "\n" + str(content)
self.memory.update_block_value(name=name, value=new_value)
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
name (str): Section of the memory to be edited (persona or human).
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
new_value = current_value.replace(str(old_content), str(new_content))
self.memory.update_block_value(name=name, value=new_value)
return None
class ChatMemory(BaseChatMemory):
"""
ChatMemory initializes a BaseChatMemory with two default blocks
"""
def __init__(self, persona: str, human: str, limit: int = 2000):
super().__init__()
self.link_block(name="persona", block=Block(name="persona", value=persona, limit=limit, label="persona"))
self.link_block(name="human", block=Block(name="human", value=human, limit=limit, label="human"))
class BlockChatMemory(BaseChatMemory):
"""
BlockChatMemory is a subclass of BaseChatMemory which uses shared memory blocks specified at initialization-time.
"""
def __init__(self, blocks: Optional[List[Block]] = None):
super().__init__()
for block in blocks or []:
# TODO: centralize these internal schema validations
assert block.name is not None and block.name != "", "each existing chat block must have a name"
self.link_block(name=block.name, block=block)
class UpdateMemory(BaseModel):
"""Update the memory of the agent"""
class ArchivalMemorySummary(BaseModel):
size: int = Field(..., description="Number of rows in archival memory")
class RecallMemorySummary(BaseModel):
size: int = Field(..., description="Number of rows in recall memory")

View File

@ -1,123 +0,0 @@
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
class SystemMessage(BaseModel):
content: str
role: str = "system"
name: Optional[str] = None
class UserMessage(BaseModel):
content: Union[str, List[str]]
role: str = "user"
name: Optional[str] = None
class ToolCallFunction(BaseModel):
name: str = Field(..., description="The name of the function to call")
arguments: str = Field(..., description="The arguments to pass to the function (JSON dump)")
class ToolCall(BaseModel):
id: str = Field(..., description="The ID of the tool call")
type: str = "function"
function: ToolCallFunction = Field(..., description="The arguments and name for the function")
class AssistantMessage(BaseModel):
content: Optional[str] = None
role: str = "assistant"
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
class ToolMessage(BaseModel):
content: str
role: str = "tool"
tool_call_id: str
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
# TODO: this might not be necessary with the validator
def cast_message_to_subtype(m_dict: dict) -> ChatMessage:
"""Cast a dictionary to one of the individual message types"""
role = m_dict.get("role")
if role == "system":
return SystemMessage(**m_dict)
elif role == "user":
return UserMessage(**m_dict)
elif role == "assistant":
return AssistantMessage(**m_dict)
elif role == "tool":
return ToolMessage(**m_dict)
else:
raise ValueError("Unknown message role")
class ResponseFormat(BaseModel):
type: str = Field(default="text", pattern="^(text|json_object)$")
## tool_choice ##
class FunctionCall(BaseModel):
name: str
class ToolFunctionChoice(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionCall
ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
## tools ##
class FunctionSchema(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters
class Tool(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionSchema
## function_call ##
FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall]
class ChatCompletionRequest(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/create"""
model: str
messages: List[ChatMessage]
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1
top_p: Optional[float] = 1
user: Optional[str] = None # unique ID of the end-user (for monitoring)
# function-calling related
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = "none"
# deprecated scheme
functions: Optional[List[FunctionSchema]] = None
function_call: Optional[FunctionCallChoice] = None
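A request-construction sketch using the tools scheme above:

request = ChatCompletionRequest(
    model="gpt-4",
    messages=[
        SystemMessage(content="You are MemGPT."),
        UserMessage(content="Remember that I like chess."),
    ],
    tools=[
        Tool(function=FunctionSchema(
            name="core_memory_append",
            description="Append to a section of core memory.",
            parameters={"type": "object", "properties": {"content": {"type": "string"}}},
        ))
    ],
    tool_choice="auto",
)
print(request.model_dump_json(exclude_none=True))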

View File

@ -1,66 +0,0 @@
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class PassageBase(MemGPTBase):
__id_prefix__ = "passage"
# associated user/agent
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")
# origin data source
source_id: Optional[str] = Field(None, description="The data source of the passage.")
# document association
doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
class Passage(PassageBase):
id: str = PassageBase.generate_id_field()
# passage text
text: str = Field(..., description="The text of the passage.")
# embeddings
embedding: Optional[List[float]] = Field(..., description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(..., description="The embedding configuration used by the passage.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")
@field_validator("embedding")
@classmethod
def pad_embeddings(cls, embedding: List[float]) -> List[float]:
"""Pad embeddings to MAX_EMBEDDING_SIZE. This is necessary to ensure all stored embeddings are the same size."""
import numpy as np
if embedding and len(embedding) != MAX_EMBEDDING_DIM:
np_embedding = np.array(embedding)
padded_embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant")
return padded_embedding.tolist()
return embedding
class PassageCreate(PassageBase):
text: str = Field(..., description="The text of the passage.")
# optionally provide embeddings
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
class PassageUpdate(PassageCreate):
id: str = Field(..., description="The unique identifier of the passage.")
text: Optional[str] = Field(None, description="The text of the passage.")
# optionally provide embeddings
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
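A sketch of the padding validator (MAX_EMBEDDING_DIM is imported at the top of this file):

passage = Passage(
    text="MemGPT stores archival facts as passages.",
    embedding=[0.1, 0.2, 0.3],  # deliberately shorter than MAX_EMBEDDING_DIM
    embedding_config=None,
)
assert len(passage.embedding) == MAX_EMBEDDING_DIM
assert passage.embedding[3:] == [0.0] * (MAX_EMBEDDING_DIM - 3)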

View File

@ -1,49 +0,0 @@
from datetime import datetime
from typing import Optional
from fastapi import UploadFile
from pydantic import BaseModel, Field
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class BaseSource(MemGPTBase):
"""
Shared attributes across all source schemas.
"""
__id_prefix__ = "source"
description: Optional[str] = Field(None, description="The description of the source.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
# NOTE: .metadata is a reserved attribute on SQLModel
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
class SourceCreate(BaseSource):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
class Source(BaseSource):
id: str = BaseSource.generate_id_field()
name: str = Field(..., description="The name of the source.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.")
user_id: str = Field(..., description="The ID of the user that created the source.")
class SourceUpdate(BaseSource):
id: str = Field(..., description="The ID of the source.")
name: Optional[str] = Field(None, description="The name of the source.")
class UploadFileToSourceRequest(BaseModel):
file: UploadFile = Field(..., description="The file to upload.")
class UploadFileToSourceResponse(BaseModel):
source: Source = Field(..., description="The source the file was uploaded to.")
added_passages: int = Field(..., description="The number of passages added to the source.")
added_documents: int = Field(..., description="The number of documents added to the source.")

View File

@ -1,86 +0,0 @@
from typing import Dict, List, Optional
from pydantic import Field
from memgpt.functions.schema_generator import (
generate_schema_from_args_schema,
generate_tool_wrapper,
)
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.openai.chat_completions import ToolCall
class BaseTool(MemGPTBase):
__id_prefix__ = "tool"
# optional fields
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.")
class Tool(BaseTool):
id: str = BaseTool.generate_id_field()
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
# code
source_code: str = Field(..., description="The source code of the function.")
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
def to_dict(self):
"""Convert into OpenAI representation"""
return vars(
ToolCall(
tool_id=self.id,
tool_call_type="function",
function=self.module,
)
)
@classmethod
def from_crewai(cls, crewai_tool) -> "Tool":
"""
Class method to create an instance of Tool from a crewAI BaseTool object.
Args:
crewai_tool (CrewAIBaseTool): An instance of a crewAI BaseTool (BaseTool from crewai)
Returns:
Tool: A memGPT Tool initialized with attributes derived from the provided crewAI BaseTool object.
"""
description = crewai_tool.description
source_type = "python"
tags = ["crew-ai"]
wrapper_func_name, wrapper_function_str = generate_tool_wrapper(crewai_tool.__class__.__name__)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
return cls(
name=wrapper_func_name,
description=description,
source_type=source_type,
tags=tags,
source_code=wrapper_function_str,
json_schema=json_schema,
)
class ToolCreate(BaseTool):
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
source_code: str = Field(..., description="The source code of the function.")
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
class ToolUpdate(ToolCreate):
id: str = Field(..., description="The unique identifier of the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
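A registration sketch using ToolCreate (the dice tool is hypothetical):

tool = ToolCreate(
    name="roll_d20",
    tags=["dice", "example"],
    source_code=(
        "def roll_d20() -> int:\n"
        "    import random\n"
        "    return random.randint(1, 20)\n"
    ),
    json_schema={
        "name": "roll_d20",
        "description": "Roll a twenty-sided die.",
        "parameters": {"type": "object", "properties": {}, "required": []},
    },
)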

View File

@ -1,8 +0,0 @@
from pydantic import BaseModel, Field
class MemGPTUsageStatistics(BaseModel):
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
step_count: int = Field(0, description="The number of steps taken by the agent.")

View File

@ -1,20 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class UserBase(MemGPTBase):
__id_prefix__ = "user"
class User(UserBase):
id: str = UserBase.generate_id_field()
name: str = Field(..., description="The name of the user.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class UserCreate(UserBase):
name: Optional[str] = Field(None, description="The name of the user.")

View File

@ -1,8 +1,6 @@
from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from memgpt.schemas.agent import AgentState from memgpt.server.rest_api.agents.index import ListAgentsResponse
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -10,12 +8,14 @@ router = APIRouter()
def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface): def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents", tags=["agents"], response_model=List[AgentState]) @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def get_all_agents(): def get_all_agents():
""" """
Get a list of all agents in the database Get a list of all agents in the database
""" """
interface.clear() interface.clear()
return server.list_agents() agents_data = server.list_agents_legacy()
return ListAgentsResponse(**agents_data)
return router return router

View File

@ -3,7 +3,7 @@ from typing import List, Literal, Optional
from fastapi import APIRouter, Body, HTTPException from fastapi import APIRouter, Body, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from memgpt.schemas.tool import Tool as ToolModel # TODO: modify from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer

View File

@ -1,46 +1,102 @@
import uuid
from typing import List, Optional from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Query from fastapi import APIRouter, Body, HTTPException, Query
from pydantic import BaseModel, Field
from memgpt.schemas.api_key import APIKey, APIKeyCreate from memgpt.data_types import User
from memgpt.schemas.user import User, UserCreate
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
router = APIRouter() router = APIRouter()
class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")
class CreateUserRequest(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).")
class CreateUserResponse(BaseModel):
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
api_key: str = Field(..., description="New API key generated for user.")
class CreateAPIKeyRequest(BaseModel):
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
name: Optional[str] = Field(None, description="Name for the API key (optional).")
class CreateAPIKeyResponse(BaseModel):
api_key: str = Field(..., description="New API key generated.")
class GetAPIKeysResponse(BaseModel):
api_key_list: List[str] = Field(..., description="A list of API keys for the user.")
class DeleteAPIKeyResponse(BaseModel):
message: str
api_key_deleted: str
class DeleteUserResponse(BaseModel):
message: str
user_id_deleted: uuid.UUID
def setup_admin_router(server: SyncServer, interface: QueuingInterface): def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=List[User]) @router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users(cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50)): def get_all_users(cursor: Optional[uuid.UUID] = Query(None), limit: Optional[int] = Query(50)):
""" """
Get a list of all users in the database Get a list of all users in the database
""" """
try: try:
# TODO: make this call a server function next_cursor, users = server.ms.get_all_users(cursor, limit)
_, users = server.ms.get_all_users(cursor=cursor, limit=limit) processed_users = [{"user_id": user.id} for user in users]
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return users return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)
@router.post("/users", tags=["admin"], response_model=User) @router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: UserCreate = Body(...)): def create_user(request: Optional[CreateUserRequest] = Body(None)):
""" """
Create a new user in the database Create a new user in the database
""" """
if request is None:
request = CreateUserRequest()
new_user = User(
id=None if not request.user_id else request.user_id,
# TODO can add more fields (name? metadata?)
)
try: try:
user = server.create_user(request) server.ms.create_user(new_user)
# make sure we can retrieve the user from the DB too
new_user_ret = server.ms.get_user(new_user.id)
if new_user_ret is None:
raise HTTPException(status_code=500, detail=f"Failed to verify user creation")
# create an API key for the user
token = server.ms.create_api_key(user_id=new_user.id, name=request.api_key_name)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return user return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token)
@router.delete("/users", tags=["admin"], response_model=User) @router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
def delete_user( def delete_user(
user_id: str = Query(..., description="The user_id key to be deleted."), user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."),
): ):
# TODO make a soft deletion, instead of a hard deletion # TODO make a soft deletion, instead of a hard deletion
try: try:
@ -52,24 +108,24 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return user return DeleteUserResponse(message="User successfully deleted.", user_id_deleted=user_id)
@router.post("/users/keys", tags=["admin"], response_model=APIKey) @router.post("/users/keys", tags=["admin"], response_model=CreateAPIKeyResponse)
def create_new_api_key(request: APIKeyCreate = Body(...)): def create_new_api_key(request: CreateAPIKeyRequest = Body(...)):
""" """
Create a new API key for a user Create a new API key for a user
""" """
try: try:
api_key = server.create_api_key(request) token = server.ms.create_api_key(user_id=request.user_id, name=request.name)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return api_key return CreateAPIKeyResponse(api_key=token.token)
@router.get("/users/keys", tags=["admin"], response_model=List[APIKey]) @router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse)
def get_api_keys( def get_api_keys(
user_id: str = Query(..., description="The unique identifier of the user."), user_id: uuid.UUID = Query(..., description="The unique identifier of the user."),
): ):
""" """
Get a list of all API keys for a user Get a list of all API keys for a user
@ -77,22 +133,28 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
try: try:
if server.ms.get_user(user_id=user_id) is None: if server.ms.get_user(user_id=user_id) is None:
raise HTTPException(status_code=404, detail=f"User does not exist") raise HTTPException(status_code=404, detail=f"User does not exist")
api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id) tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
processed_tokens = [t.token for t in tokens]
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return api_keys
return GetAPIKeysResponse(api_key_list=processed_tokens)
@router.delete("/users/keys", tags=["admin"], response_model=APIKey) @router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse)
def delete_api_key( def delete_api_key(
api_key: str = Query(..., description="The API key to be deleted."), api_key: str = Query(..., description="The API key to be deleted."),
): ):
try: try:
return server.delete_api_key(api_key) token = server.ms.get_api_key(api_key=api_key)
if token is None:
raise HTTPException(status_code=404, detail=f"API key does not exist")
server.ms.delete_api_key(api_key=api_key)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
return DeleteAPIKeyResponse(message="API key successfully deleted.", api_key_deleted=api_key)
return router return router
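For reference, a minimal sketch of exercising the new admin endpoints with the requests library. The base URL, router prefix, and Authorization header are illustrative assumptions, not part of this diff:

import requests

BASE = "http://localhost:8283/admin"  # assumed host and prefix
HEADERS = {"Authorization": "Bearer admin_password"}  # assumed auth scheme

# Create a user (user_id is generated server-side when omitted)
created = requests.post(f"{BASE}/users", headers=HEADERS, json={}).json()
print(created["user_id"], created["api_key"])

# List users, paginated via the returned cursor
page = requests.get(f"{BASE}/users", headers=HEADERS, params={"limit": 50}).json()
print(page["cursor"], page["user_list"])

# Create, list, and delete API keys for that user
key = requests.post(f"{BASE}/users/keys", headers=HEADERS,
                    json={"user_id": created["user_id"]}).json()
keys = requests.get(f"{BASE}/users/keys", headers=HEADERS,
                    params={"user_id": created["user_id"]}).json()
print(keys["api_key_list"])
requests.delete(f"{BASE}/users/keys", headers=HEADERS, params={"api_key": key["api_key"]})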
View File
@ -0,0 +1,48 @@
import uuid
from functools import partial
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class CommandRequest(BaseModel):
command: str = Field(..., description="The command to be executed by the agent.")
class CommandResponse(BaseModel):
response: str = Field(..., description="The result of the executed command.")
def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
def run_command(
agent_id: uuid.UUID,
request: CommandRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Execute a command on a specified agent.
This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
Raises an HTTPException for any processing errors.
"""
interface.clear()
try:
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return CommandResponse(response=response)
return router
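A hedged usage sketch for the new command endpoint; the URL prefix, bearer token, and the example command string are hypothetical:

import requests

agent_id = "00000000-0000-0000-0000-000000000000"  # hypothetical agent UUID
resp = requests.post(
    f"http://localhost:8283/api/agents/{agent_id}/command",  # assumed prefix
    headers={"Authorization": "Bearer sk-..."},  # assumed user API key
    json={"command": "/memory"},  # hypothetical command string
)
print(resp.json()["response"])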
View File
@ -0,0 +1,156 @@
import re
import uuid
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import (
AgentStateModel,
EmbeddingConfigModel,
LLMConfigModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class AgentRenameRequest(BaseModel):
agent_name: str = Field(..., description="New name for the agent.")
class GetAgentResponse(BaseModel):
# config: dict = Field(..., description="The agent configuration object.")
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")
def validate_agent_name(name: str) -> str:
"""Validate the requested new agent name (prevent bad inputs)"""
# Length check
if not (1 <= len(name) <= 50):
raise HTTPException(status_code=400, detail="Name length must be between 1 and 50 characters.")
# Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
if not re.match("^[A-Za-z0-9 _-]+$", name):
raise HTTPException(status_code=400, detail="Name contains invalid characters.")
# Further checks can be added here...
# TODO
return name
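A quick illustration of the validator's behavior, assuming validate_agent_name is imported from this module:

from fastapi import HTTPException

assert validate_agent_name("agent_7 demo") == "agent_7 demo"  # letters, digits, spaces, -, _ pass

try:
    validate_agent_name("bad/name!")  # slash and bang fail the allowed-characters regex
except HTTPException as e:
    assert e.status_code == 400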
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the configuration for a specific agent.
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
# configs
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
metadata=agent_state._metadata,
),
last_run_at=None, # TODO
sources=attached_sources,
)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Updates the name of a specific agent.
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
valid_name = validate_agent_name(request.agent_name)
interface.clear()
try:
agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
),
last_run_at=None, # TODO
sources=attached_sources,
)
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
# agent_id = uuid.UUID(agent_id)
interface.clear()
try:
server.delete_agent(user_id=user_id, agent_id=agent_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent agent_id={agent_id} successfully deleted"})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router
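As an illustration, fetching an agent's config and renaming it via the routes above; base URL, token, and IDs are assumptions:

import requests

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer sk-..."}  # assumed user API key
agent_id = "00000000-0000-0000-0000-000000000000"  # hypothetical

state = requests.get(f"{BASE}/agents/{agent_id}/config", headers=HEADERS).json()
print(state["agent_state"]["name"], state["sources"])

renamed = requests.patch(
    f"{BASE}/agents/{agent_id}/rename",
    headers=HEADERS,
    json={"agent_name": "my-renamed-agent"},  # must satisfy validate_agent_name
)
renamed.raise_for_status()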
View File
@ -1,22 +1,49 @@
import uuid
from functools import partial from functools import partial
from typing import List from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState from memgpt.constants import BASE_TOOLS
from memgpt.memory import ChatMemory
from memgpt.models.pydantic_models import (
AgentStateModel,
EmbeddingConfigModel,
LLMConfigModel,
PresetModel,
)
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
from memgpt.settings import settings
router = APIRouter() router = APIRouter()
class ListAgentsResponse(BaseModel):
num_agents: int = Field(..., description="The number of agents available to the user.")
# TODO make return type List[AgentStateModel]
# also return - presets: List[PresetModel]
agents: List[dict] = Field(..., description="List of agent configurations.")
class CreateAgentRequest(BaseModel):
# TODO: modify this (along with front end)
config: dict = Field(..., description="The agent configuration object.")
class CreateAgentResponse(BaseModel):
agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.")
preset: PresetModel = Field(..., description="The preset that the agent was created from.")
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents", tags=["agents"], response_model=List[AgentState]) @router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def list_agents( def list_agents(
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
List all agents associated with a given user. List all agents associated with a given user.
@ -24,71 +51,95 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
""" """
interface.clear() interface.clear()
agents_data = server.list_agents(user_id=user_id) agents_data = server.list_agents_legacy(user_id=user_id)
return agents_data return ListAgentsResponse(**agents_data)
@router.post("/agents", tags=["agents"], response_model=AgentState) @router.post("/agents", tags=["agents"], response_model=CreateAgentResponse)
def create_agent( def create_agent(
request: CreateAgent = Body(...), request: CreateAgentRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Create a new agent with the specified configuration. Create a new agent with the specified configuration.
""" """
interface.clear() interface.clear()
agent_state = server.create_agent(request, user_id=user_id) # Parse request
return agent_state # TODO: don't just use JSON in the future
human_name = request.config["human_name"] if "human_name" in request.config else None
human = request.config["human"] if "human" in request.config else None
persona_name = request.config["persona_name"] if "persona_name" in request.config else None
persona = request.config["persona"] if "persona" in request.config else None
request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
tool_names = request.config["function_names"] if ("function_names" in request.config and request.config["function_names"]) else None
metadata = request.config["metadata"] if "metadata" in request.config else {}
metadata["human"] = human_name
metadata["persona"] = persona_name
# TODO: remove this -- should be added based on create agent fields
if isinstance(tool_names, str): # TODO: fix this on client side?
tool_names = tool_names.split(",")
if tool_names is None or tool_names == "":
tool_names = []
for name in BASE_TOOLS: # TODO: remove this
if name not in tool_names:
tool_names.append(name)
assert isinstance(tool_names, list), "Tool names must be a list of strings."
# TODO: eventually remove this - should support general memory at the REST endpoint
# TODO: the REST server should add default memory tools at startup time
memory = ChatMemory(persona=persona, human=human)
@router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
def update_agent(
agent_id: str,
request: UpdateAgentState = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
"""Update an exsiting agent"""
interface.clear()
try: try:
# TODO: should id be moved out of UpdateAgentState? agent_state = server.create_agent(
agent_state = server.update_agent(request, user_id=user_id) user_id=user_id,
# **request.config
# TODO turn into a pydantic model
name=request.config["name"],
memory=memory,
system=request.config.get("system", None),
# persona_name=persona_name,
# human_name=human_name,
# persona=persona,
# human=human,
# llm_config=LLMConfigModel(
# model=request.config['model'],
# )
# tools
tools=tool_names,
metadata=metadata,
# function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return CreateAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
metadata=agent_state._metadata,
),
preset=PresetModel( # TODO: remove (placeholder to avoid breaking frontend)
name="dummy_preset",
id=agent_state.id,
user_id=agent_state.user_id,
description="",
created_at=agent_state.created_at,
system=agent_state.system,
persona="",
human="",
functions_schema=[],
),
)
except Exception as e: except Exception as e:
print(str(e)) print(str(e))
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return agent_state
@router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
def get_agent_state(
agent_id: str = None,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get the state of the agent.
"""
interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
return server.get_agent_state(user_id=user_id, agent_id=agent_id)
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
# agent_id = str(agent_id)
interface.clear()
try:
server.delete_agent(user_id=user_id, agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router return router
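To make the config-dict contract concrete, a hedged sketch of a create-agent request matching the parsing above (all field values are hypothetical):

import requests

resp = requests.post(
    "http://localhost:8283/api/agents",  # assumed prefix
    headers={"Authorization": "Bearer sk-..."},  # assumed user API key
    json={
        "config": {
            "name": "support-agent",      # required by the parser above
            "human_name": "cs_human",     # stored in metadata
            "persona_name": "cs_persona", # stored in metadata
            "human": "Name: Ada",         # seeds ChatMemory
            "persona": "You are terse.",  # seeds ChatMemory
            "function_names": "",         # empty -> BASE_TOOLS are added server-side
        }
    },
)
resp.raise_for_status()
print(resp.json()["agent_state"]["id"])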
View File
@ -1,12 +1,11 @@
import uuid
from functools import partial from functools import partial
from typing import Dict, List, Optional from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -14,24 +13,60 @@ from memgpt.server.server import SyncServer
router = APIRouter() router = APIRouter()
class CoreMemory(BaseModel):
human: str | None = Field(None, description="Human element of the core memory.")
persona: str | None = Field(None, description="Persona element of the core memory.")
class GetAgentMemoryResponse(BaseModel):
core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.")
recall_memory: int = Field(..., description="Size of the agent's recall memory.")
archival_memory: int = Field(..., description="Size of the agent's archival memory.")
# NOTE not subclassing CoreMemory since in the request both fields are optional
class UpdateAgentMemoryRequest(BaseModel):
human: str = Field(None, description="Human element of the core memory.")
persona: str = Field(None, description="Persona element of the core memory.")
class UpdateAgentMemoryResponse(BaseModel):
old_core_memory: CoreMemory = Field(..., description="The previous state of the agent's core memory.")
new_core_memory: CoreMemory = Field(..., description="The updated state of the agent's core memory.")
class ArchivalMemoryObject(BaseModel):
# TODO move to models/pydantic_models, or inherit from data_types Record
id: uuid.UUID = Field(..., description="Unique identifier for the memory object inside the archival memory store.")
contents: str = Field(..., description="The memory contents.")
class GetAgentArchivalMemoryResponse(BaseModel):
# TODO: make this List[Passage] instead
archival_memory: List[ArchivalMemoryObject] = Field(..., description="A list of all memory objects in archival memory.")
class InsertAgentArchivalMemoryRequest(BaseModel):
content: str = Field(..., description="The memory contents to insert into archival memory.")
class InsertAgentArchivalMemoryResponse(BaseModel):
ids: List[str] = Field(
..., description="Unique identifier for the new archival memory object. May return multiple ids if insert contents are chunked."
)
class DeleteAgentArchivalMemoryRequest(BaseModel):
id: str = Field(..., description="Unique identifier of the archival memory object to be deleted.")
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message]) @router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
def get_agent_in_context_messages(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve the messages in the context of a specific agent.
"""
interface.clear()
return server.get_in_context_messages(agent_id=agent_id)
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
def get_agent_memory( def get_agent_memory(
agent_id: str, agent_id: uuid.UUID,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Retrieve the memory state of a specific agent. Retrieve the memory state of a specific agent.
@ -39,13 +74,14 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
""" """
interface.clear() interface.clear()
return server.get_agent_memory(agent_id=agent_id) memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
return GetAgentMemoryResponse(**memory)
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory) @router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
def update_agent_memory( def update_agent_memory(
agent_id: str, agent_id: uuid.UUID,
request: Dict = Body(...), request: UpdateAgentMemoryRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Update the core memory of a specific agent. Update the core memory of a specific agent.
@ -53,86 +89,75 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID. This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
""" """
interface.clear() interface.clear()
memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request)
return memory
@router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary) new_memory_contents = {"persona": request.persona, "human": request.human}
def get_agent_recall_memory_summary( response = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=new_memory_contents)
agent_id: str, return UpdateAgentMemoryResponse(**response)
user_id: str = Depends(get_current_user_with_server),
@router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory_all(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Retrieve the summary of the recall memory of a specific agent. Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
""" """
interface.clear() interface.clear()
return server.get_recall_memory_summary(agent_id=agent_id) archival_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
print("archival_memories:", archival_memories)
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
@router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary) @router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory_summary(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve the summary of the archival memory of a specific agent.
"""
interface.clear()
return server.get_archival_memory_summary(agent_id=agent_id)
# @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage])
# def get_agent_archival_memory_all(
# agent_id: str,
# user_id: str = Depends(get_current_user_with_server),
# ):
# """
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
# """
# interface.clear()
# return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
def get_agent_archival_memory( def get_agent_archival_memory(
agent_id: str, agent_id: uuid.UUID,
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."), limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Retrieve the memories in an agent's archival memory store (paginated query). Retrieve the memories in an agent's archival memory store (paginated query).
""" """
interface.clear() interface.clear()
return server.get_agent_archival_cursor( # TODO need to add support for non-postgres here
# chroma will throw:
# raise ValueError("Cannot run get_all_cursor with chroma")
_, archival_json_records = server.get_agent_archival_cursor(
user_id=user_id, user_id=user_id,
agent_id=agent_id, agent_id=agent_id,
after=after, after=after,
before=before, before=before,
limit=limit, limit=limit,
) )
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
@router.post("/agents/{agent_id}/archival/{memory}", tags=["agents"], response_model=List[Passage]) @router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=InsertAgentArchivalMemoryResponse)
def insert_agent_archival_memory( def insert_agent_archival_memory(
agent_id: str, agent_id: uuid.UUID,
memory: str, request: InsertAgentArchivalMemoryRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Insert a memory into an agent's archival memory store. Insert a memory into an agent's archival memory store.
""" """
interface.clear() interface.clear()
return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=memory) memory_ids = server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.content)
return InsertAgentArchivalMemoryResponse(ids=memory_ids)
@router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"]) @router.delete("/agents/{agent_id}/archival", tags=["agents"])
def delete_agent_archival_memory( def delete_agent_archival_memory(
agent_id: str, agent_id: uuid.UUID,
memory_id: str, id: str = Query(..., description="Unique ID of the memory to be deleted."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Delete a memory from an agent's archival memory store. Delete a memory from an agent's archival memory store.
""" """
# TODO: should probably return a `Passage`
interface.clear() interface.clear()
try: try:
memory_id = uuid.UUID(id)
server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id) server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
except HTTPException: except HTTPException:
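Taken together, a hedged sketch of the core-memory and archival routes above; the URL prefix, token, and IDs are assumptions:

import requests

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer sk-..."}  # assumed user API key
agent_id = "00000000-0000-0000-0000-000000000000"  # hypothetical

# Read core memory sizes, then update the human block
memory = requests.get(f"{BASE}/agents/{agent_id}/memory", headers=HEADERS).json()
print(memory["recall_memory"], memory["archival_memory"])
requests.post(f"{BASE}/agents/{agent_id}/memory", headers=HEADERS,
              json={"human": "Name: Ada", "persona": None})

# Insert an archival memory, then delete it by id (query parameter)
ids = requests.post(f"{BASE}/agents/{agent_id}/archival", headers=HEADERS,
                    json={"content": "Ada prefers concise answers."}).json()["ids"]
requests.delete(f"{BASE}/agents/{agent_id}/archival", headers=HEADERS,
                params={"id": ids[0]})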
View File
@ -1,48 +1,116 @@
import asyncio import asyncio
import uuid
from datetime import datetime from datetime import datetime
from enum import Enum
from functools import partial from functools import partial
from typing import List, Optional, Union from typing import List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse from starlette.responses import StreamingResponse
from memgpt.schemas.enums import MessageRole, MessageStreamStatus from memgpt.models.pydantic_models import MemGPTUsageStatistics
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate
router = APIRouter() router = APIRouter()
# TODO: cpacker should check this file class MessageRoleType(str, Enum):
# TODO: move this into server.py? user = "user"
system = "system"
class UserMessageRequest(BaseModel):
message: str = Field(..., description="The message content to be processed by the agent.")
name: Optional[str] = Field(default=None, description="Name of the message request sender")
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
stream_steps: bool = Field(
default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
)
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)
timestamp: Optional[datetime] = Field(
None,
description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
)
stream: bool = Field(
default=False,
description="Legacy flag for old streaming API, will be deprecrated in the future.",
deprecated=True,
)
# @validator("timestamp", pre=True, always=True)
# def validate_timestamp(cls, value: Optional[datetime]) -> Optional[datetime]:
# if value is None:
# return value # If the timestamp is None, just return None, implying default handling to set server-side
# if not isinstance(value, datetime):
# raise TypeError("Timestamp must be a datetime object with timezone information.")
# if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
# raise ValueError("Timestamp must be timezone-aware.")
# # Convert timestamp to UTC if it's not already in UTC
# if value.tzinfo.utcoffset(value) != timezone.utc.utcoffset(value):
# value = value.astimezone(timezone.utc)
# return value
class UserMessageResponse(BaseModel):
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
usage: MemGPTUsageStatistics = Field(..., description="Usage statistics for the completion.")
class GetAgentMessagesRequest(BaseModel):
start: int = Field(..., description="Message index to start on (reverse chronological).")
count: int = Field(..., description="How many messages to retrieve.")
class GetAgentMessagesCursorRequest(BaseModel):
before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.")
limit: int = Field(..., description="Maximum number of messages to retrieve.")
class GetAgentMessagesResponse(BaseModel):
messages: list = Field(..., description="List of message objects.")
async def send_message_to_agent( async def send_message_to_agent(
server: SyncServer, server: SyncServer,
agent_id: str, agent_id: uuid.UUID,
user_id: str, user_id: uuid.UUID,
role: MessageRole, role: str,
message: str, message: str,
stream_legacy: bool, # legacy
stream_steps: bool, stream_steps: bool,
stream_tokens: bool, stream_tokens: bool,
chat_completion_mode: Optional[bool] = False, chat_completion_mode: Optional[bool] = False,
timestamp: Optional[datetime] = None, timestamp: Optional[datetime] = None,
# related to whether or not we return `MemGPTMessage`s or `Message`s ) -> Union[StreamingResponse, UserMessageResponse]:
return_message_object: bool = True, # Should be True for Python Client, False for REST API
) -> Union[StreamingResponse, MemGPTResponse]:
"""Split off into a separate function so that it can be imported in the /chat/completion proxy.""" """Split off into a separate function so that it can be imported in the /chat/completion proxy."""
# TODO: @charles is this the correct way to handle?
include_final_message = True
# determine role # TODO this is a total hack but is required until we move streaming into the model config
if role == MessageRole.user: if server.server_llm_config.model_endpoint != "https://api.openai.com/v1":
stream_tokens = False
# handle the legacy mode streaming
if stream_legacy:
# NOTE: override
stream_steps = True
stream_tokens = False
include_final_message = False
else:
include_final_message = True
if role == "user" or role is None:
message_func = server.user_message message_func = server.user_message
elif role == MessageRole.system: elif role == "system":
message_func = server.system_message message_func = server.system_message
else: else:
raise HTTPException(status_code=500, detail=f"Bad role {role}") raise HTTPException(status_code=500, detail=f"Bad role {role}")
@ -53,11 +121,9 @@ async def send_message_to_agent(
# For streaming response # For streaming response
try: try:
# TODO: move this logic into server.py
# Get the generator object off of the agent's streaming interface # Get the generator object off of the agent's streaming interface
# This will be attached to the POST SSE request used under-the-hood # This will be attached to the POST SSE request used under-the-hood
memgpt_agent = server._get_or_load_agent(agent_id=agent_id) memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
streaming_interface = memgpt_agent.interface streaming_interface = memgpt_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface): if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
@ -67,6 +133,8 @@ async def send_message_to_agent(
# "chatcompletion mode" does some remapping and ignores inner thoughts # "chatcompletion mode" does some remapping and ignores inner thoughts
streaming_interface.streaming_chat_completion_mode = chat_completion_mode streaming_interface.streaming_chat_completion_mode = chat_completion_mode
# NOTE: for legacy 'stream' flag
streaming_interface.nonstreaming_legacy_mode = stream_legacy
# streaming_interface.allow_assistant_message = stream # streaming_interface.allow_assistant_message = stream
# streaming_interface.function_call_legacy_mode = stream # streaming_interface.function_call_legacy_mode = stream
@ -77,44 +145,21 @@ async def send_message_to_agent(
) )
if stream_steps: if stream_steps:
if return_message_object:
# TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
raise NotImplementedError
# return a stream # return a stream
return StreamingResponse( return StreamingResponse(
sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message), sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
media_type="text/event-stream", media_type="text/event-stream",
) )
else: else:
# buffer the stream, then return the list # buffer the stream, then return the list
generated_stream = [] generated_stream = []
async for message in streaming_interface.get_generator(): async for message in streaming_interface.get_generator():
assert (
isinstance(message, MemGPTMessage)
or isinstance(message, LegacyMemGPTMessage)
or isinstance(message, MessageStreamStatus)
), type(message)
generated_stream.append(message) generated_stream.append(message)
if message == MessageStreamStatus.done: if "data" in message and message["data"] == "[DONE]":
break break
filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]
# Get rid of the stream status messages
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
usage = await task usage = await task
return UserMessageResponse(messages=filtered_stream, usage=usage)
# By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
# If we want to convert these to Message, we can use the attached IDs
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
if return_message_object:
message_ids = [m.id for m in filtered_stream]
message_ids = deduplicate(message_ids)
message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
return MemGPTResponse(messages=message_objs, usage=usage)
else:
return MemGPTResponse(messages=filtered_stream, usage=usage)
except HTTPException: except HTTPException:
raise raise
@ -129,39 +174,55 @@ async def send_message_to_agent(
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message]) @router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages_in_context( def get_agent_messages(
agent_id: str, agent_id: uuid.UUID,
start: int = Query(..., description="Message index to start on (reverse chronological)."), start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."), count: int = Query(..., description="How many messages to retrieve."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate. Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
""" """
interface.clear() # Validate with the Pydantic model (optional)
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=start, count=count) request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
return messages # agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message]) interface.clear()
def get_agent_messages( messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
agent_id: str, return GetAgentMessagesResponse(messages=messages)
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
@router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages_cursor(
agent_id: uuid.UUID,
before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."), limit: int = Query(10, description="Maximum number of messages to retrieve."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Retrieve message history for an agent. Retrieve message history for an agent (paginated via cursor; provide before and limit to iterate).
""" """
interface.clear() # Validate with the Pydantic model (optional)
return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True) request = GetAgentMessagesCursorRequest(agent_id=agent_id, before=before, limit=limit)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse) interface.clear()
[_, messages] = server.get_agent_recall_cursor(
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
)
# print("====> messages-cursor DEBUG")
# for i, msg in enumerate(messages):
# print(f"message {i+1}/{len(messages)}")
# print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
# print(f"ISO format string: {msg['created_at']}")
# print(msg)
return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
async def send_message( async def send_message(
# background_tasks: BackgroundTasks, # background_tasks: BackgroundTasks,
agent_id: str, agent_id: uuid.UUID,
request: MemGPTRequest = Body(...), request: UserMessageRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Process a user message and return the agent's response. Process a user message and return the agent's response.
@ -169,21 +230,17 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
This endpoint accepts a message from a user and processes it through the agent. This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream' is set to True. It can optionally stream the response if 'stream' is set to True.
""" """
# TODO: should this receive multiple messages? @cpacker
# TODO: revise to `MemGPTRequest`
# TODO: support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
message = request.messages[0]
# TODO: what to do with message.name?
return await send_message_to_agent( return await send_message_to_agent(
server=server, server=server,
agent_id=agent_id, agent_id=agent_id,
user_id=user_id, user_id=user_id,
role=message.role, role=request.role,
message=message.text, message=request.message,
stream_steps=request.stream_steps, stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens, stream_tokens=request.stream_tokens,
timestamp=request.timestamp,
# legacy
stream_legacy=request.stream,
) )
return router return router
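A hedged sketch of the send-message endpoint in both buffered and SSE modes; prefix, token, and IDs are assumptions:

import requests

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer sk-..."}  # assumed user API key
agent_id = "00000000-0000-0000-0000-000000000000"  # hypothetical

# Buffered: the server drains the stream and returns messages plus usage at once
resp = requests.post(f"{BASE}/agents/{agent_id}/messages", headers=HEADERS,
                     json={"message": "hello", "role": "user", "stream_steps": False})
print(resp.json()["usage"])

# Streaming: stream_steps=True returns server-sent events, one data line per chunk
with requests.post(f"{BASE}/agents/{agent_id}/messages", headers=HEADERS,
                   json={"message": "hello", "role": "user", "stream_steps": True},
                   stream=True) as r:
    for line in r.iter_lines():
        if line:
            print(line.decode())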
View File
@ -1,73 +0,0 @@
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.block import Block, CreateBlock
from memgpt.schemas.block import Human as HumanModel # TODO: modify
from memgpt.schemas.block import UpdateBlock
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/blocks", tags=["block"], response_model=List[Block])
async def list_blocks(
user_id: str = Depends(get_current_user_with_server),
# query parameters
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
):
# Clear the interface
interface.clear()
blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name)
if blocks is None:
return []
return blocks
@router.post("/blocks", tags=["block"], response_model=Block)
async def create_block(
request: CreateBlock = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
request.user_id = user_id # TODO: remove?
return server.create_block(user_id=user_id, request=request)
@router.post("/blocks/{block_id}", tags=["block"], response_model=Block)
async def update_block(
block_id: str,
request: UpdateBlock = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
# TODO: should this be in the param or the POST data?
assert block_id == request.id
return server.update_block(user_id=user_id, request=request)
@router.delete("/blocks/{block_id}", tags=["block"], response_model=Block)
async def delete_block(
block_id: str,
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
return server.delete_block(block_id=block_id)
@router.get("/blocks/{block_id}", tags=["block"], response_model=Block)
async def get_block(
block_id: str,
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
block = server.get_block(block_id=block_id)
if block is None:
raise HTTPException(status_code=404, detail="Block not found")
return block
return router
View File
@ -1,11 +1,9 @@
import uuid
from functools import partial from functools import partial
from typing import List
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -21,20 +19,13 @@ class ConfigResponse(BaseModel):
def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/config/llm", tags=["config"], response_model=List[LLMConfig]) @router.get("/config", tags=["config"], response_model=ConfigResponse)
def get_llm_configs(user_id: str = Depends(get_current_user_with_server)): def get_server_config(user_id: uuid.UUID = Depends(get_current_user_with_server)):
""" """
Retrieve the base configuration for the server. Retrieve the base configuration for the server.
""" """
interface.clear() interface.clear()
return [server.server_llm_config] response = server.get_server_config(include_defaults=True)
return ConfigResponse(config=response["config"], defaults=response["defaults"])
@router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig])
def get_embedding_configs(user_id: str = Depends(get_current_user_with_server)):
"""
Retrieve the base configuration for the server.
"""
interface.clear()
return [server.server_embedding_config]
return router return router
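For completeness, a one-call sketch of the consolidated /config endpoint; prefix and token are assumptions:

import requests

body = requests.get(
    "http://localhost:8283/api/config",  # assumed prefix
    headers={"Authorization": "Bearer sk-..."},  # assumed user API key
).json()
print(body["config"], body["defaults"])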
View File
@ -0,0 +1,69 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import HumanModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class ListHumansResponse(BaseModel):
humans: List[HumanModel] = Field(..., description="List of human configurations.")
class CreateHumanRequest(BaseModel):
text: str = Field(..., description="The human text.")
name: str = Field(..., description="The name of the human.")
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/humans", tags=["humans"], response_model=ListHumansResponse)
async def list_humans(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# Clear the interface
interface.clear()
humans = server.ms.list_humans(user_id=user_id)
return ListHumansResponse(humans=humans)
@router.post("/humans", tags=["humans"], response_model=HumanModel)
async def create_human(
request: CreateHumanRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# TODO: disallow duplicate names for humans
interface.clear()
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
human_id = new_human.id
server.ms.add_human(new_human)
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def delete_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.delete_human(human_name, user_id=user_id)
return human
@router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def get_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.get_human(human_name, user_id=user_id)
if human is None:
raise HTTPException(status_code=404, detail="Human not found")
return human
return router
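A brief sketch of the humans CRUD routes above; prefix, token, and the sample text are hypothetical:

import requests

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer sk-..."}  # assumed user API key

requests.post(f"{BASE}/humans", headers=HEADERS,
              json={"name": "ada", "text": "Name: Ada"})
print(requests.get(f"{BASE}/humans/ada", headers=HEADERS).json()["text"])
requests.delete(f"{BASE}/humans/ada", headers=HEADERS)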
View File
@ -2,24 +2,13 @@ import asyncio
import json import json
import queue import queue
from collections import deque from collections import deque
from typing import AsyncGenerator, Literal, Optional, Union from typing import AsyncGenerator, Optional
from memgpt.data_types import Message
from memgpt.interface import AgentInterface from memgpt.interface import AgentInterface
from memgpt.schemas.enums import MessageStreamStatus from memgpt.models.chat_completion_response import ChatCompletionChunkResponse
from memgpt.schemas.memgpt_message import (
AssistantMessage,
FunctionCall,
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
LegacyFunctionCallMessage,
LegacyMemGPTMessage,
MemGPTMessage,
)
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from memgpt.streaming_interface import AgentChunkStreamingInterface from memgpt.streaming_interface import AgentChunkStreamingInterface
from memgpt.utils import is_utc_datetime from memgpt.utils import get_utc_time, is_utc_datetime
class QueuingInterface(AgentInterface): class QueuingInterface(AgentInterface):
@ -29,66 +18,12 @@ class QueuingInterface(AgentInterface):
self.buffer = queue.Queue() self.buffer = queue.Queue()
self.debug = debug self.debug = debug
def _queue_push(self, message_api: Union[str, dict], message_obj: Union[Message, None]): def to_list(self):
"""Wrapper around self.buffer.queue.put() that ensures the types are safe
Data will be in the format: {
"message_obj": ...
"message_string": ...
}
"""
# Check the string first
if isinstance(message_api, str):
# check that it's the stop word
if message_api == "STOP":
assert message_obj is None
self.buffer.put(
{
"message_api": message_api,
"message_obj": None,
}
)
else:
raise ValueError(f"Unrecognized string pushed to buffer: {message_api}")
elif isinstance(message_api, dict):
# check if it's the error message style
if len(message_api.keys()) == 1 and "internal_error" in message_api:
assert message_obj is None
self.buffer.put(
{
"message_api": message_api,
"message_obj": None,
}
)
else:
assert message_obj is not None, message_api
self.buffer.put(
{
"message_api": message_api,
"message_obj": message_obj,
}
)
else:
raise ValueError(f"Unrecognized type pushed to buffer: {type(message_api)}")
def to_list(self, style: Literal["obj", "api"] = "obj"):
"""Convert queue to a list (empties it out at the same time)""" """Convert queue to a list (empties it out at the same time)"""
items = [] items = []
while not self.buffer.empty(): while not self.buffer.empty():
try: try:
# items.append(self.buffer.get_nowait()) items.append(self.buffer.get_nowait())
item_to_push = self.buffer.get_nowait()
if style == "obj":
if item_to_push["message_obj"] is not None:
items.append(item_to_push["message_obj"])
elif style == "api":
items.append(item_to_push["message_api"])
else:
raise ValueError(style)
except queue.Empty: except queue.Empty:
break break
if len(items) > 1 and items[-1] == "STOP": if len(items) > 1 and items[-1] == "STOP":
@ -101,30 +36,20 @@ class QueuingInterface(AgentInterface):
# Empty the queue # Empty the queue
self.buffer.queue.clear() self.buffer.queue.clear()
async def message_generator(self, style: Literal["obj", "api"] = "obj"): async def message_generator(self):
while True: while True:
if not self.buffer.empty(): if not self.buffer.empty():
message = self.buffer.get() message = self.buffer.get()
message_obj = message["message_obj"] if message == "STOP":
message_api = message["message_api"]
if message_api == "STOP":
break break
# yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
# yield message yield message
if style == "obj":
yield message_obj
elif style == "api":
yield message_api
else:
raise ValueError(style)
else: else:
await asyncio.sleep(0.1) # Small sleep to prevent a busy loop await asyncio.sleep(0.1) # Small sleep to prevent a busy loop
def step_yield(self): def step_yield(self):
"""Enqueue a special stop message""" """Enqueue a special stop message"""
self._queue_push(message_api="STOP", message_obj=None) self.buffer.put("STOP")
@staticmethod @staticmethod
def step_complete(): def step_complete():
@ -132,8 +57,8 @@ class QueuingInterface(AgentInterface):
def error(self, error: str): def error(self, error: str):
"""Enqueue a special stop message""" """Enqueue a special stop message"""
self._queue_push(message_api={"internal_error": error}, message_obj=None) self.buffer.put({"internal_error": error})
self._queue_push(message_api="STOP", message_obj=None) self.buffer.put("STOP")
def user_message(self, msg: str, msg_obj: Optional[Message] = None): def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Handle reception of a user message""" """Handle reception of a user message"""
@ -159,7 +84,7 @@ class QueuingInterface(AgentInterface):
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat() new_message["date"] = msg_obj.created_at.isoformat()
self._queue_push(message_api=new_message, message_obj=msg_obj) self.buffer.put(new_message)
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None: def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent sending a message""" """Handle the agent sending a message"""
@ -183,13 +108,11 @@ class QueuingInterface(AgentInterface):
assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message." assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
# TODO also should not be accessing protected member here # TODO also should not be accessing protected member here
new_message["id"] = self.buffer.queue[-1]["message_api"]["id"] new_message["id"] = self.buffer.queue[-1]["id"]
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = self.buffer.queue[-1]["message_api"]["date"] new_message["date"] = self.buffer.queue[-1]["date"]
msg_obj = self.buffer.queue[-1]["message_obj"] self.buffer.put(new_message)
self._queue_push(message_api=new_message, message_obj=msg_obj)
def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None: def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
"""Handle the agent calling a function""" """Handle the agent calling a function"""
@ -229,7 +152,7 @@ class QueuingInterface(AgentInterface):
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat() new_message["date"] = msg_obj.created_at.isoformat()
self._queue_push(message_api=new_message, message_obj=msg_obj) self.buffer.put(new_message)
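A minimal sketch (hypothetical MiniBuffer, not part of the commit) of the producer/consumer contract the refactored QueuingInterface enforces above: every queue entry pairs the API-style payload with its underlying Message object, the "STOP" sentinel included, so consumers can drain the buffer uniformly in either style:

import queue

class MiniBuffer:
    """Toy stand-in for QueuingInterface to illustrate the buffer contract."""

    def __init__(self):
        self.buffer = queue.Queue()

    def push(self, message_api, message_obj=None):
        # every entry carries both keys, even for the STOP sentinel
        self.buffer.put({"message_api": message_api, "message_obj": message_obj})

    def to_list(self, style="api"):
        # drain the queue, projecting out one of the two representations
        items = []
        while not self.buffer.empty():
            try:
                item = self.buffer.get_nowait()
            except queue.Empty:
                break
            items.append(item["message_api"] if style == "api" else item["message_obj"])
        return items

buf = MiniBuffer()
buf.push({"internal_monologue": "thinking..."}, message_obj=object())
buf.push("STOP")
assert buf.to_list(style="api") == [{"internal_monologue": "thinking..."}, "STOP"]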
class FunctionArgumentsStreamHandler: class FunctionArgumentsStreamHandler:
@ -316,21 +239,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if multi_step = True, the stream ends when the agent yields # if multi_step = True, the stream ends when the agent yields
# if multi_step = False, the stream ends when the step ends # if multi_step = False, the stream ends when the step ends
self.multi_step = multi_step self.multi_step = multi_step
self.multi_step_indicator = MessageStreamStatus.done_step self.multi_step_indicator = "[DONE_STEP]"
self.multi_step_gen_indicator = MessageStreamStatus.done_generation self.multi_step_gen_indicator = "[DONE_GEN]"
# extra prints async def _create_generator(self) -> AsyncGenerator:
self.debug = False
self.timeout = 30
async def _create_generator(self) -> AsyncGenerator[Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus], None]:
"""An asynchronous generator that yields chunks as they become available.""" """An asynchronous generator that yields chunks as they become available."""
while self._active: while self._active:
try: # Wait until there is an item in the deque or the stream is deactivated
# Wait until there is an item in the deque or the stream is deactivated await self._event.wait()
await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout
except asyncio.TimeoutError:
break # Exit the loop if we timeout
while self._chunks: while self._chunks:
yield self._chunks.popleft() yield self._chunks.popleft()
@ -338,33 +254,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Reset the event until a new item is pushed # Reset the event until a new item is pushed
self._event.clear() self._event.clear()
# while self._active:
# # Wait until there is an item in the deque or the stream is deactivated
# await self._event.wait()
# while self._chunks:
# yield self._chunks.popleft()
# # Reset the event until a new item is pushed
# self._event.clear()
def get_generator(self) -> AsyncGenerator:
"""Get the generator that yields processed chunks."""
if not self._active:
# If the stream is not active, don't return a generator that would produce values
raise StopIteration("The stream has not been started or has been ended.")
return self._create_generator()
def _push_to_buffer(self, item: Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus]):
"""Add an item to the deque"""
assert self._active, "Generator is inactive"
# assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}"
assert (
isinstance(item, MemGPTMessage) or isinstance(item, LegacyMemGPTMessage) or isinstance(item, MessageStreamStatus)
), f"Wrong type: {type(item)}"
self._chunks.append(item)
self._event.set() # Signal that new data is available
def stream_start(self): def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks.""" """Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None self.streaming_chat_completion_mode_function_name = None
@ -379,10 +268,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_mode_function_name = None self.streaming_chat_completion_mode_function_name = None
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
self._push_to_buffer(self.multi_step_gen_indicator) self._chunks.append(self.multi_step_gen_indicator)
self._event.set() # Signal that new data is available
# self._active = False
# self._event.set() # Unblock the generator if it's waiting to allow it to complete
# if not self.multi_step: # if not self.multi_step:
# # end the stream # # end the stream
@ -393,27 +280,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# self._chunks.append(self.multi_step_indicator) # self._chunks.append(self.multi_step_indicator)
# self._event.set() # Signal that new data is available # self._event.set() # Signal that new data is available
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
# signal that a new step has started in the stream
self._push_to_buffer(self.multi_step_indicator)
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
# if self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
@staticmethod
def clear():
return
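The timeout added to _create_generator above guards against a producer that stops signaling: instead of blocking forever on the event, the generator exits after self.timeout seconds of silence. A self-contained sketch of the same pattern, under assumed names and using only asyncio:

import asyncio
from collections import deque

async def chunk_stream(chunks: deque, event: asyncio.Event, timeout: float = 0.2):
    """Yield items as they arrive; stop if no producer signals within `timeout` seconds."""
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            break  # producer went quiet; end the stream
        while chunks:
            yield chunks.popleft()
        event.clear()  # wait for the next push

async def demo():
    chunks, event = deque(), asyncio.Event()

    async def producer():
        for i in range(3):
            chunks.append(f"chunk-{i}")
            event.set()  # signal that new data is available
            await asyncio.sleep(0.01)

    task = asyncio.create_task(producer())
    async for item in chunk_stream(chunks, event):
        print(item)  # chunk-0, chunk-1, chunk-2, then the stream times out
    await task

asyncio.run(demo())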
def _process_chunk_to_memgpt_style(self, chunk: ChatCompletionChunkResponse) -> Optional[dict]: def _process_chunk_to_memgpt_style(self, chunk: ChatCompletionChunkResponse) -> Optional[dict]:
""" """
Example data from non-streaming response looks like: Example data from non-streaming response looks like:
@ -539,7 +405,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if msg_obj: if msg_obj:
processed_chunk["id"] = str(msg_obj.id) processed_chunk["id"] = str(msg_obj.id)
self._push_to_buffer(processed_chunk) self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
def get_generator(self) -> AsyncGenerator:
"""Get the generator that yields processed chunks."""
if not self._active:
# If the stream is not active, don't return a generator that would produce values
raise StopIteration("The stream has not been started or has been ended.")
return self._create_generator()
def user_message(self, msg: str, msg_obj: Optional[Message] = None): def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT receives a user message""" """MemGPT receives a user message"""
@ -550,18 +424,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if not self.streaming_mode: if not self.streaming_mode:
# create a fake "chunk" of a stream # create a fake "chunk" of a stream
# processed_chunk = { processed_chunk = {
# "internal_monologue": msg, "internal_monologue": msg,
# "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
# "id": str(msg_obj.id) if msg_obj is not None else None, "id": str(msg_obj.id) if msg_obj is not None else None,
# } }
processed_chunk = InternalMonologue(
id=msg_obj.id,
date=msg_obj.created_at,
internal_monologue=msg,
)
self._push_to_buffer(processed_chunk) self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
return return
@ -603,56 +473,42 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# "date": "2024-06-22T23:04:32.141923+00:00" # "date": "2024-06-22T23:04:32.141923+00:00"
# } # }
try: try:
func_args = json.loads(function_call.function.arguments) func_args = json.loads(function_call.function["arguments"])
except: except:
func_args = function_call.function.arguments func_args = function_call.function["arguments"]
# processed_chunk = { processed_chunk = {
# "function_call": f"{function_call.function.name}({func_args})", "function_call": f"{function_call.function['name']}({func_args})",
# "id": str(msg_obj.id), "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(), "date": msg_obj.created_at.isoformat(),
# } }
processed_chunk = LegacyFunctionCallMessage( self._chunks.append(processed_chunk)
id=msg_obj.id, self._event.set() # Signal that new data is available
date=msg_obj.created_at,
function_call=f"{function_call.function.name}({func_args})",
)
self._push_to_buffer(processed_chunk)
if function_call.function.name == "send_message": if function_call.function["name"] == "send_message":
try: try:
# processed_chunk = { processed_chunk = {
# "assistant_message": func_args["message"], "assistant_message": func_args["message"],
# "id": str(msg_obj.id), "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(), "date": msg_obj.created_at.isoformat(),
# } }
processed_chunk = AssistantMessage( self._chunks.append(processed_chunk)
id=msg_obj.id, self._event.set() # Signal that new data is available
date=msg_obj.created_at,
assistant_message=func_args["message"],
)
self._push_to_buffer(processed_chunk)
except Exception as e: except Exception as e:
print(f"Failed to parse function message: {e}") print(f"Failed to parse function message: {e}")
else: else:
processed_chunk = FunctionCallMessage( processed_chunk = {
id=msg_obj.id, "function_call": {
date=msg_obj.created_at, "id": function_call.id,
function_call=FunctionCall( "name": function_call.function["name"],
name=function_call.function.name, "arguments": function_call.function["arguments"],
arguments=function_call.function.arguments, },
), "id": str(msg_obj.id),
) "date": msg_obj.created_at.isoformat(),
# processed_chunk = { }
# "function_call": { self._chunks.append(processed_chunk)
# "name": function_call.function.name, self._event.set() # Signal that new data is available
# "arguments": function_call.function.arguments,
# },
# "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(),
# }
self._push_to_buffer(processed_chunk)
return return
else: else:
@ -667,33 +523,43 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
elif msg.startswith("Success: "): elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "") msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"} new_message = {"function_return": msg, "status": "success"}
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="success",
)
elif msg.startswith("Error: "): elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "") msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"} new_message = {"function_return": msg, "status": "error"}
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="error",
)
else: else:
# NOTE: generic, should not happen # NOTE: generic, should not happen
raise ValueError(msg)
new_message = {"function_message": msg} new_message = {"function_message": msg}
# add extra metadata # add extra metadata
# if msg_obj is not None: if msg_obj is not None:
# new_message["id"] = str(msg_obj.id) new_message["id"] = str(msg_obj.id)
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
# new_message["date"] = msg_obj.created_at.isoformat() new_message["date"] = msg_obj.created_at.isoformat()
self._push_to_buffer(new_message) self._chunks.append(new_message)
self._event.set() # Signal that new data is available
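The function_message branch above dispatches on the "Success: "/"Error: " prefix convention to build a FunctionReturn with the right status. A small sketch of that parsing rule in isolation (hypothetical helper name):

def parse_function_result(msg: str):
    """Map a raw 'Success: ...' / 'Error: ...' string to (text, status)."""
    if msg.startswith("Success: "):
        return msg[len("Success: "):], "success"
    if msg.startswith("Error: "):
        return msg[len("Error: "):], "error"
    # generic function messages should not reach this point
    raise ValueError(msg)

assert parse_function_result("Success: 42") == ("42", "success")
assert parse_function_result("Error: boom") == ("boom", "error")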
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
# signal that a new step has started in the stream
self._chunks.append(self.multi_step_indicator)
self._event.set() # Signal that new data is available
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
if self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
@staticmethod
def clear():
return
View File
@ -4,7 +4,7 @@ from typing import List
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from memgpt.schemas.llm_config import LLMConfig from memgpt.models.pydantic_models import LLMConfigModel
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -13,7 +13,7 @@ router = APIRouter()
class ListModelsResponse(BaseModel): class ListModelsResponse(BaseModel):
models: List[LLMConfig] = Field(..., description="List of model configurations.") models: List[LLMConfigModel] = Field(..., description="List of model configurations.")
def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@ -25,7 +25,7 @@ def setup_models_index_router(server: SyncServer, interface: QueuingInterface, p
interface.clear() interface.clear()
# currently, the server only supports one model, however this may change in the future # currently, the server only supports one model, however this may change in the future
llm_config = LLMConfig( llm_config = LLMConfigModel(
model=server.server_llm_config.model, model=server.server_llm_config.model,
model_endpoint=server.server_llm_config.model_endpoint, model_endpoint=server.server_llm_config.model_endpoint,
model_endpoint_type=server.server_llm_config.model_endpoint_type, model_endpoint_type=server.server_llm_config.model_endpoint_type,
View File
@ -4,9 +4,10 @@ from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Path, Query from fastapi import APIRouter, Body, HTTPException, Path, Query
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message from memgpt.data_types import Message
from memgpt.schemas.openai.openai import ( from memgpt.models.openai import (
AssistantFile, AssistantFile,
MessageFile, MessageFile,
MessageRoleType, MessageRoleType,
@ -138,6 +139,10 @@ class SubmitToolOutputsToRunRequest(BaseModel):
# TODO: implement mechanism for creating/authenticating users associated with a bearer token # TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface): def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
# TODO: remove this (when we have user auth)
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
# create assistant (MemGPT agent) # create assistant (MemGPT agent)
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant) @router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)): def create_assistant(request: CreateAssistantRequest = Body(...)):
View File
@ -4,9 +4,9 @@ from functools import partial
from fastapi import APIRouter, Body, Depends, HTTPException from fastapi import APIRouter, Body, Depends, HTTPException
# from memgpt.schemas.message import Message # from memgpt.data_types import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import ( from memgpt.models.chat_completion_response import (
ChatCompletionResponse, ChatCompletionResponse,
Choice, Choice,
Message, Message,
View File
@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from memgpt.schemas.block import Persona as PersonaModel # TODO: modify from memgpt.models.pydantic_models import PersonaModel
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -44,7 +44,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear() interface.clear()
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id) new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
persona_id = new_persona.id persona_id = new_persona.id
server.ms.create_persona(new_persona) server.ms.add_persona(new_persona)
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id) return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel) @router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
View File
@ -0,0 +1,171 @@
import uuid
from functools import partial
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import Preset # TODO remove
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.prompts import gpt_system
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text, get_persona_text
router = APIRouter()
"""
Implement the following functions:
* List all available presets
* Create a new preset
* Delete a preset
* TODO update a preset
"""
class ListPresetsResponse(BaseModel):
presets: List[PresetModel] = Field(..., description="List of available presets.")
class CreatePresetsRequest(BaseModel):
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
name: str = Field(..., description="The name of the preset.")
id: Optional[str] = Field(None, description="The unique identifier of the preset.")
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: Optional[str] = Field(None, description="The system prompt of the preset.") # TODO: make optional and allow defaults
persona: Optional[str] = Field(default=None, description="The persona of the preset.")
human: Optional[str] = Field(default=None, description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# TODO
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
class CreatePresetResponse(BaseModel):
preset: PresetModel = Field(..., description="The newly created preset.")
def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel)
async def get_preset(
preset_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Get a preset."""
try:
preset = server.get_preset(user_id=user_id, preset_name=preset_name)
return preset
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
async def list_presets(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all presets created by a user."""
# Clear the interface
interface.clear()
try:
presets = server.list_presets(user_id=user_id)
return ListPresetsResponse(presets=presets)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/presets", tags=["presets"], response_model=CreatePresetResponse)
async def create_preset(
request: CreatePresetsRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Create a preset."""
try:
if isinstance(request.id, str):
request.id = uuid.UUID(request.id)
# check if preset already exists
# TODO: move this into a server function to create a preset
if server.ms.get_preset(name=request.name, user_id=user_id):
raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.")
# For system/human/persona - if {system/human/persona}_name is None but the text is provided, then create a new data entry
if not request.system_name and request.system:
# new system provided without name identity
system_name = f"system_{request.name}_{str(uuid.uuid4())}"
system = request.system
# TODO: insert into system table
else:
system_name = request.system_name if request.system_name else DEFAULT_PRESET
system = request.system if request.system else gpt_system.get_system_text(system_name)
if not request.human_name and request.human:
# new human provided without name identity
human_name = f"human_{request.name}_{str(uuid.uuid4())}"
human = request.human
server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
else:
human_name = request.human_name if request.human_name else DEFAULT_HUMAN
human = request.human if request.human else get_human_text(human_name)
if not request.persona_name and request.persona:
# new persona provided without name identity
persona_name = f"persona_{request.name}_{str(uuid.uuid4())}"
persona = request.persona
server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
else:
persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
persona = request.persona if request.persona else get_persona_text(persona_name)
# create preset
new_preset = Preset(
user_id=user_id,
id=request.id if request.id else uuid.uuid4(),
name=request.name,
description=request.description,
system=system,
persona=persona,
persona_name=persona_name,
human=human,
human_name=human_name,
functions_schema=request.functions_schema,
)
preset = server.create_preset(preset=new_preset)
# TODO remove once we migrate from Preset to PresetModel
preset = PresetModel(**vars(preset))
return CreatePresetResponse(preset=preset)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.delete("/presets/{preset_id}", tags=["presets"])
async def delete_preset(
preset_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a preset."""
interface.clear()
try:
preset = server.delete_preset(user_id=user_id, preset_id=preset_id)
return JSONResponse(
status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router
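The create_preset handler above repeats one pattern three times (system, human, persona): use the named template unless raw text was supplied without a name, in which case mint a unique name for the anonymous text. A hedged sketch of that fallback as a standalone helper (hypothetical function, not part of the commit):

import uuid
from typing import Callable, Optional, Tuple

def resolve_field(kind: str, preset_name: str, name: Optional[str], text: Optional[str],
                  default_name: str, default_text: Callable[[str], str]) -> Tuple[str, str]:
    """Return (name, text) for a preset field, minting a name for anonymous text."""
    if not name and text:
        # raw text supplied without a name: generate a unique identity for it
        return f"{kind}_{preset_name}_{uuid.uuid4()}", text
    resolved_name = name or default_name
    return resolved_name, text or default_text(resolved_name)

# e.g. human_name, human = resolve_field("human", request.name, request.human_name,
#                                        request.human, DEFAULT_HUMAN, get_human_text)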
View File
@ -15,12 +15,14 @@ from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
from memgpt.server.rest_api.admin.tools import setup_tools_index_router from memgpt.server.rest_api.admin.tools import setup_tools_index_router
from memgpt.server.rest_api.admin.users import setup_admin_router from memgpt.server.rest_api.admin.users import setup_admin_router
from memgpt.server.rest_api.agents.command import setup_agents_command_router
from memgpt.server.rest_api.agents.config import setup_agents_config_router
from memgpt.server.rest_api.agents.index import setup_agents_index_router from memgpt.server.rest_api.agents.index import setup_agents_index_router
from memgpt.server.rest_api.agents.memory import setup_agents_memory_router from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
from memgpt.server.rest_api.agents.message import setup_agents_message_router from memgpt.server.rest_api.agents.message import setup_agents_message_router
from memgpt.server.rest_api.auth.index import setup_auth_router from memgpt.server.rest_api.auth.index import setup_auth_router
from memgpt.server.rest_api.block.index import setup_block_index_router
from memgpt.server.rest_api.config.index import setup_config_index_router from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.interface import StreamingServerInterface from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.models.index import setup_models_index_router from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import ( from memgpt.server.rest_api.openai_assistants.assistants import (
@ -29,6 +31,8 @@ from memgpt.server.rest_api.openai_assistants.assistants import (
from memgpt.server.rest_api.openai_chat_completions.chat_completions import ( from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
setup_openai_chat_completions_router, setup_openai_chat_completions_router,
) )
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.presets.index import setup_presets_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.static_files import mount_static_files from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_user_tools_index_router from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
@ -91,13 +95,17 @@ app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PRE
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)]) app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])
# /api/agents endpoints # /api/agents endpoints
app.include_router(setup_agents_command_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_config_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
# /api/config endpoints # /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX) app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
@ -145,8 +153,7 @@ def on_startup():
@app.on_event("shutdown") @app.on_event("shutdown")
def on_shutdown(): def on_shutdown():
global server global server
if server: server.save_agents()
server.save_agents()
server = None server = None
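For orientation, the registration pattern used throughout this file is plain FastAPI composition: each feature module returns an APIRouter, which app.include_router mounts under a shared prefix, with password verification attached as a dependency where needed. A minimal sketch of that shape (illustrative names only, not the project's actual wiring):

from fastapi import APIRouter, Depends, FastAPI, Header, HTTPException

app = FastAPI()
router = APIRouter()

@router.get("/ping")
def ping():
    return {"ok": True}

def verify_password(authorization: str = Header(None)):
    # reject requests that don't carry the expected bearer token
    if authorization != "Bearer test_server_token":
        raise HTTPException(status_code=401, detail="unauthorized")

# one router per feature, mounted under a shared prefix with auth as a dependency
app.include_router(router, prefix="/api", dependencies=[Depends(verify_password)])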
View File
@ -1,7 +1,8 @@
import os import os
import tempfile import tempfile
import uuid
from functools import partial from functools import partial
from typing import List from typing import List, Optional
from fastapi import ( from fastapi import (
APIRouter, APIRouter,
@ -11,14 +12,20 @@ from fastapi import (
HTTPException, HTTPException,
Query, Query,
UploadFile, UploadFile,
status,
) )
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.schemas.document import Document from memgpt.data_sources.connectors import DirectoryConnector
from memgpt.schemas.job import Job from memgpt.data_types import Source
from memgpt.schemas.passage import Passage from memgpt.models.pydantic_models import (
DocumentModel,
# schemas JobModel,
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile JobStatus,
PassageModel,
SourceModel,
)
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -37,73 +44,77 @@ Implement the following functions:
""" """
# class ListSourcesResponse(BaseModel): class ListSourcesResponse(BaseModel):
# sources: List[SourceModel] = Field(..., description="List of available sources.") sources: List[SourceModel] = Field(..., description="List of available sources.")
#
#
# class CreateSourceRequest(BaseModel):
# name: str = Field(..., description="The name of the source.")
# description: Optional[str] = Field(None, description="The description of the source.")
#
#
# class UploadFileToSourceRequest(BaseModel):
# file: UploadFile = Field(..., description="The file to upload.")
#
#
# class UploadFileToSourceResponse(BaseModel):
# source: SourceModel = Field(..., description="The source the file was uploaded to.")
# added_passages: int = Field(..., description="The number of passages added to the source.")
# added_documents: int = Field(..., description="The number of documents added to the source.")
#
#
# class GetSourcePassagesResponse(BaseModel):
# passages: List[PassageModel] = Field(..., description="List of passages from the source.")
#
#
# class GetSourceDocumentsResponse(BaseModel):
# documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes): class CreateSourceRequest(BaseModel):
# write the file to a temporary directory (deleted after the context manager exits) name: str = Field(..., description="The name of the source.")
with tempfile.TemporaryDirectory() as tmpdirname: description: Optional[str] = Field(None, description="The description of the source.")
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(bytes)
server.load_file_to_source(source_id, file_path, job_id)
class UploadFileToSourceRequest(BaseModel):
file: UploadFile = Field(..., description="The file to upload.")
class UploadFileToSourceResponse(BaseModel):
source: SourceModel = Field(..., description="The source the file was uploaded to.")
added_passages: int = Field(..., description="The number of passages added to the source.")
added_documents: int = Field(..., description="The number of documents added to the source.")
class GetSourcePassagesResponse(BaseModel):
passages: List[PassageModel] = Field(..., description="List of passages from the source.")
class GetSourceDocumentsResponse(BaseModel):
documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
def load_file_to_source(server: SyncServer, user_id: uuid.UUID, source: Source, job_id: uuid.UUID, file: UploadFile, bytes: bytes):
# update job status
job = server.ms.get_job(job_id=job_id)
job.status = JobStatus.running
server.ms.update_job(job)
try:
# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(bytes)
# read the file
connector = DirectoryConnector(input_files=[file_path])
# TODO: pre-compute total number of passages?
# load the data into the source via the connector
num_passages, num_documents = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
except Exception as e:
# job failed with error
error = str(e)
print(error)
job.status = JobStatus.failed
job.metadata_["error"] = error
server.ms.update_job(job)
# TODO: delete any associated passages/documents?
return 0, 0
# update job status
job.status = JobStatus.completed
job.metadata_["num_passages"] = num_passages
job.metadata_["num_documents"] = num_documents
print("job completed", job.metadata_, job.id)
server.ms.update_job(job)
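load_file_to_source above walks the job record through a created -> running -> completed/failed lifecycle, stashing progress counters or the error string in metadata_. A minimal, self-contained sketch of that state machine (hypothetical Job/Status stand-ins for JobModel/JobStatus):

import enum

class Status(str, enum.Enum):
    created = "created"
    running = "running"
    completed = "completed"
    failed = "failed"

class Job:
    def __init__(self):
        self.status = Status.created
        self.metadata_ = {}

def run_job(job: Job, work):
    """Run `work`, recording the outcome on the job record."""
    job.status = Status.running
    try:
        result = work()
    except Exception as e:
        job.status = Status.failed
        job.metadata_["error"] = str(e)
        return
    job.status = Status.completed
    job.metadata_.update(result)

job = Job()
run_job(job, lambda: {"num_passages": 10, "num_documents": 2})
assert job.status is Status.completed and job.metadata_["num_passages"] == 10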
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/sources/{source_id}", tags=["sources"], response_model=Source) @router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
async def get_source(
source_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get a source by ID
"""
interface.clear()
source = server.get_source(source_id=source_id, user_id=user_id)
return source
@router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
async def get_source_id_by_name(
source_name: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get the ID of a source by name
"""
interface.clear()
source_id = server.get_source_id(source_name=source_name, user_id=user_id)
return source_id
@router.get("/sources", tags=["sources"], response_model=List[Source])
async def list_sources( async def list_sources(
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
List all data sources created by a user. List all data sources created by a user.
@ -113,40 +124,58 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
try: try:
sources = server.list_all_sources(user_id=user_id) sources = server.list_all_sources(user_id=user_id)
return sources return ListSourcesResponse(sources=sources)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources", tags=["sources"], response_model=Source) @router.post("/sources", tags=["sources"], response_model=SourceModel)
async def create_source( async def create_source(
request: SourceCreate = Body(...), request: CreateSourceRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Create a new data source. Create a new data source.
""" """
interface.clear() interface.clear()
try: try:
return server.create_source(request=request, user_id=user_id) # TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.create_source(name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources/{source_id}", tags=["sources"], response_model=Source) @router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel)
async def update_source( async def update_source(
source_id: str, source_id: uuid.UUID,
request: SourceUpdate = Body(...), request: CreateSourceRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Update the name or documentation of an existing data source. Update the name or documentation of an existing data source.
""" """
interface.clear() interface.clear()
try: try:
return server.update_source(request=request, user_id=user_id) # TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@ -154,8 +183,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
@router.delete("/sources/{source_id}", tags=["sources"]) @router.delete("/sources/{source_id}", tags=["sources"])
async def delete_source( async def delete_source(
source_id: str, source_id: uuid.UUID,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Delete a data source. Delete a data source.
@ -163,58 +192,66 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear() interface.clear()
try: try:
server.delete_source(source_id=source_id, user_id=user_id) server.delete_source(source_id=source_id, user_id=user_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}") raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source) @router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
async def attach_source_to_agent( async def attach_source_to_agent(
source_id: str, source_id: uuid.UUID,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."), agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Attach a data source to an existing agent. Attach a data source to an existing agent.
""" """
interface.clear() interface.clear()
assert isinstance(agent_id, str), f"Expected agent_id to be a str, got {agent_id}" assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, str), f"Expected user_id to be a str, got {user_id}" assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
source = server.ms.get_source(source_id=source_id, user_id=user_id) source = server.ms.get_source(source_id=source_id, user_id=user_id)
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id) source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id)
return source return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at,
)
@router.post("/sources/{source_id}/detach", tags=["sources"]) @router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
async def detach_source_from_agent( async def detach_source_from_agent(
source_id: str, source_id: uuid.UUID,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."), agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Detach a data source from an existing agent. Detach a data source from an existing agent.
""" """
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id) server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
@router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job) @router.get("/sources/status/{job_id}", tags=["sources"], response_model=JobModel)
async def get_job( async def get_job_status(
job_id: str, job_id: uuid.UUID,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Get the status of a job. Get the status of a job.
""" """
job = server.get_job(job_id=job_id) job = server.ms.get_job(job_id=job_id)
if job is None: if job is None:
raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.") raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
return job return job
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job) @router.post("/sources/{source_id}/upload", tags=["sources"], response_model=JobModel)
async def upload_file_to_source( async def upload_file_to_source(
# file: UploadFile = UploadFile(..., description="The file to upload."), # file: UploadFile = UploadFile(..., description="The file to upload."),
file: UploadFile, file: UploadFile,
source_id: str, source_id: uuid.UUID,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Upload a file to a data source. Upload a file to a data source.
@ -224,39 +261,37 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
bytes = file.file.read() bytes = file.file.read()
# create job # create job
# TODO: create server function job = JobModel(user_id=user_id, metadata={"type": "embedding", "filename": file.filename, "source_id": source_id})
job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id})
job_id = job.id job_id = job.id
server.ms.create_job(job) server.ms.create_job(job)
# create background task # create background task
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes) background_tasks.add_task(load_file_to_source, server, user_id, source, job_id, file, bytes)
# return job information # return job information
job = server.ms.get_job(job_id=job_id) job = server.ms.get_job(job_id=job_id)
return job return job
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage]) @router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages( async def list_passages(
source_id: str, source_id: uuid.UUID,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
List all passages associated with a data source. List all passages associated with a data source.
""" """
# TODO: check if paginated?
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id) passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return passages return GetSourcePassagesResponse(passages=passages)
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document]) @router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
async def list_documents( async def list_documents(
source_id: str, source_id: uuid.UUID,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
List all documents associated with a data source. List all documents associated with a data source.
""" """
documents = server.list_data_source_documents(user_id=user_id, source_id=source_id) documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
return documents return GetSourceDocumentsResponse(documents=documents)
return router return router
View File
@ -1,9 +1,11 @@
import uuid
from functools import partial from functools import partial
from typing import List from typing import List, Literal, Optional
from fastapi import APIRouter, Body, Depends, HTTPException from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer from memgpt.server.server import SyncServer
@ -11,92 +13,121 @@ from memgpt.server.server import SyncServer
router = APIRouter() router = APIRouter()
class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
class CreateToolRequest(BaseModel):
json_schema: dict = Field(..., description="JSON schema of the tool.") # NOT OpenAI - just has `name`
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
update: Optional[bool] = Field(False, description="Update the tool if it already exists.")
class CreateToolResponse(BaseModel):
tool: ToolModel = Field(..., description="Information about the newly created tool.")
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str): def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password) get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.delete("/tools/{tool_id}", tags=["tools"]) @router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool( async def delete_tool(
tool_id: str, tool_name: str,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Delete a tool by name Delete a tool by name
""" """
# Clear the interface # Clear the interface
interface.clear() interface.clear()
server.delete_tool(tool_id) # tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name, user_id=user_id)
@router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool) @router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool( async def get_tool(
tool_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get a tool by ID
"""
# Clear the interface
interface.clear()
tool = server.get_tool(tool_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
return tool
@router.get("/tools/name/{tool_name}", tags=["tools"], response_model=str)
async def get_tool_id(
tool_name: str, tool_name: str,
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Get a tool by name Get a tool by name
""" """
# Clear the interface # Clear the interface
interface.clear() interface.clear()
tool = server.get_tool_id(tool_name, user_id=user_id) tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
if tool is None: if tool is None:
# return 404 error # return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.") raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool return tool
@router.get("/tools", tags=["tools"], response_model=List[Tool]) @router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools( async def list_all_tools(
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Get a list of all tools available to agents created by a user Get a list of all tools available to agents created by a user
""" """
# Clear the interface # Clear the interface
interface.clear() interface.clear()
return server.list_tools(user_id) tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=Tool) @router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool( async def create_tool(
request: ToolCreate = Body(...), request: CreateToolRequest = Body(...),
user_id: str = Depends(get_current_user_with_server), user_id: uuid.UUID = Depends(get_current_user_with_server),
): ):
""" """
Create a new tool Create a new tool
""" """
try: # NOTE: horrifying code, should be replaced when we migrate dev portal
return server.create_tool(request, user_id=user_id) from memgpt.agent import Agent # nasty: need agent to be defined
except Exception as e: from memgpt.functions.schema_generator import generate_schema
print(e)
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}") name = request.json_schema["name"]
import ast
parsed_code = ast.parse(request.source_code)
function_names = []
# Recursively collect the names of all functions defined in the source
def find_function_names(node):
for child in ast.iter_child_nodes(node):
if isinstance(child, ast.FunctionDef):
# Record the name of the function
function_names.append(child.name)
# Recurse into child nodes
find_function_names(child)
# Collect the function names
find_function_names(parsed_code)
assert len(function_names) == 1, f"Expected 1 function, found {len(function_names)}: {function_names}"
# generate JSON schema
env = {}
env.update(globals())
exec(request.source_code, env)
func = env.get(function_names[0])
json_schema = generate_schema(func, name=name)
from pprint import pprint
pprint(json_schema)
@router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool)
async def update_tool(
tool_id: str,
request: ToolUpdate = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
"""
Update an existing tool
"""
try: try:
# TODO: check that the user has access to this tool?
return server.update_tool(request) return server.create_tool(
# json_schema=request.json_schema, # TODO: add back
json_schema=json_schema,
source_code=request.source_code,
source_type=request.source_type,
tags=request.tags,
user_id=user_id,
exists_ok=request.update,
)
except Exception as e: except Exception as e:
print(e) print(e)
raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}") raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}")
return router return router
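The create_tool handler above locates the single function defined in the submitted source, exec's the source, and introspects the resulting callable to build a schema. A standard-library-only sketch of the AST step (generate_schema is MemGPT-specific and omitted; ast.walk covers nested defs just as the recursive helper above does, and the source name here is illustrative):

import ast

def single_function_name(source_code: str) -> str:
    """Return the name of the one function defined in source_code (nested defs included)."""
    names = [node.name for node in ast.walk(ast.parse(source_code))
             if isinstance(node, ast.FunctionDef)]
    assert len(names) == 1, f"Expected 1 function, found {len(names)}: {names}"
    return names[0]

source = "def roll_d20() -> int:\n    import random\n    return random.randint(1, 20)\n"
name = single_function_name(source)
env = {}
exec(source, env)  # trusted input only: exec runs arbitrary code
func = env[name]
print(name, func())  # roll_d20 followed by a value in 1..20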
View File
@@ -1,14 +1,10 @@
+import asyncio
 import json
 import traceback
-from enum import Enum
-from typing import AsyncGenerator, Union
-
-from pydantic import BaseModel
+from typing import AsyncGenerator, Generator, Union
 
 from memgpt.constants import JSON_ENSURE_ASCII
 
-SSE_PREFIX = "data: "
-SSE_SUFFIX = "\n\n"
 SSE_FINISH_MSG = "[DONE]"  # mimic openai
 SSE_ARTIFICIAL_DELAY = 0.1
@@ -17,7 +13,27 @@ def sse_formatter(data: Union[dict, str]) -> str:
     """Prefix with 'data: ', and always include double newlines"""
     assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
     data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data
-    return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
+    return f"data: {data_str}\n\n"
+
+
+async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]:
+    """Generator that returns 'data: dict' formatted items, e.g.:
+
+    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}
+
+    data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
+
+    data: [DONE]
+    """
+    try:
+        for msg in generator:
+            yield sse_formatter(msg)
+            if SSE_ARTIFICIAL_DELAY:
+                await asyncio.sleep(SSE_ARTIFICIAL_DELAY)  # Sleep to prevent a tight loop, adjust time as needed
+    except Exception as e:
+        yield sse_formatter({"error": f"{str(e)}"})
+    yield sse_formatter(SSE_FINISH_MSG)  # Signal that the stream is complete
@@ -33,20 +49,12 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
     try:
         async for chunk in generator:
             # yield f"data: {json.dumps(chunk)}\n\n"
-            if isinstance(chunk, BaseModel):
-                chunk = chunk.model_dump()
-            elif isinstance(chunk, Enum):
-                chunk = str(chunk.value)
-            elif not isinstance(chunk, dict):
-                chunk = str(chunk)
             yield sse_formatter(chunk)
     except Exception as e:
         print("stream decoder hit error:", e)
         print(traceback.print_stack())
         yield sse_formatter({"error": "stream decoder encountered an error"})
     finally:
-        # yield "data: [DONE]\n\n"
         if finish_message:
-            # Signal that the stream is complete
-            yield sse_formatter(SSE_FINISH_MSG)
+            yield sse_formatter(SSE_FINISH_MSG)  # Signal that the stream is complete
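To make the wire format concrete: each event is one 'data: <json>' line followed by a blank line, and the stream terminates with 'data: [DONE]'. A small sketch that drives sse_async_generator directly, assuming the helpers above are in scope (the chunk payloads are invented):

import asyncio

async def fake_completion_stream():
    # stand-in for a real chat-completion chunk generator
    yield {"choices": [{"delta": {"content": "Hel"}}]}
    yield {"choices": [{"delta": {"content": "lo"}}]}

async def main():
    async for event in sse_async_generator(fake_completion_stream()):
        print(repr(event))  # 'data: {...}\n\n' twice, then 'data: [DONE]\n\n'

asyncio.run(main())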

File diff suppressed because it is too large

View File

@@ -43,12 +43,5 @@ class Settings(BaseSettings):
         return None
 
 
-class TestSettings(Settings):
-    model_config = SettingsConfigDict(env_prefix="memgpt_test_")
-
-    memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt/test", env="MEMGPT_TEST_DIR")
-
-
 # singleton
 settings = Settings()
-test_settings = TestSettings()
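The deleted TestSettings follows a standard pydantic-settings pattern: subclass the settings model and swap env_prefix so identical fields read from a separate family of environment variables. A generic sketch (names are illustrative, not MemGPT's):

from pathlib import Path
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class AppSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="app_")

    data_dir: Optional[Path] = Field(default=Path.home() / ".app")

class AppTestSettings(AppSettings):
    # identical fields, but populated from APP_TEST_* instead of APP_*
    model_config = SettingsConfigDict(env_prefix="app_test_")

settings = AppSettings()           # reads APP_DATA_DIR
test_settings = AppTestSettings()  # reads APP_TEST_DATA_DIR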

View File

@@ -7,9 +7,9 @@ from rich.console import Console
 from rich.live import Live
 from rich.markup import escape
 
+from memgpt.data_types import Message
 from memgpt.interface import CLIInterface
-from memgpt.schemas.message import Message
-from memgpt.schemas.openai.chat_completion_response import (
+from memgpt.models.chat_completion_response import (
     ChatCompletionChunkResponse,
     ChatCompletionResponse,
 )

View File

@@ -33,8 +33,8 @@ from memgpt.constants import (
     MEMGPT_DIR,
     TOOL_CALL_ID_MAX_LEN,
 )
+from memgpt.models.chat_completion_response import ChatCompletionResponse
 from memgpt.openai_backcompat.openai_object import OpenAIObject
-from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
 
 DEBUG = False
 if "LOG_LEVEL" in os.environ:
@@ -468,17 +468,6 @@ NOUN_BANK = [
 ]
 
 
-def deduplicate(target_list: list) -> list:
-    seen = set()
-    dedup_list = []
-    for i in target_list:
-        if i not in seen:
-            seen.add(i)
-            dedup_list.append(i)
-
-    return dedup_list
-
-
 def smart_urljoin(base_url: str, relative_url: str) -> str:
     """urljoin is stupid and wants a trailing / at the end of the endpoint address, or it will chop the suffix off"""
     if not base_url.endswith("/"):
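The docstring on smart_urljoin is accurate about the standard library: without a trailing slash, urljoin treats the last path segment as a file and replaces it. A quick demonstration:

from urllib.parse import urljoin

urljoin("http://localhost:8283/v1", "chat/completions")   # -> 'http://localhost:8283/chat/completions'
urljoin("http://localhost:8283/v1/", "chat/completions")  # -> 'http://localhost:8283/v1/chat/completions'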

poetry.lock generated (4360 lines changed)

File diff suppressed because it is too large

View File

@@ -28,13 +28,14 @@ pgvector = { version = "^0.2.3", optional = true }
 pre-commit = {version = "^3.5.0", optional = true }
 pg8000 = {version = "^1.30.3", optional = true}
 websockets = {version = "^12.0", optional = true}
-docstring-parser = ">=0.16,<0.17"
+docstring-parser = "^0.15"
+lancedb = "^0.3.3"
 httpx = "^0.25.2"
 numpy = "^1.26.2"
 demjson3 = "^3.0.6"
-#tiktoken = ">=0.7.0,<0.8.0"
+tiktoken = "^0.5.1"
 pyyaml = "^6.0.1"
-chromadb = ">=0.4.24,<0.5.0"
+chromadb = "^0.5.0"
 sqlalchemy-json = "^0.7.0"
 fastapi = {version = "^0.104.1", optional = true}
 uvicorn = {version = "^0.24.0.post1", optional = true}
@@ -50,7 +51,7 @@ pymilvus = {version ="^2.4.3", optional = true}
 python-box = "^7.1.1"
 sqlmodel = "^0.0.16"
 autoflake = {version = "^2.3.0", optional = true}
-llama-index = "^0.10.65"
+llama-index = "^0.10.27"
 llama-index-embeddings-openai = "^0.1.1"
 llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
 llama-index-embeddings-azure-openai = "^0.1.6"
@@ -58,15 +59,12 @@ python-multipart = "^0.0.9"
 sqlalchemy-utils = "^0.41.2"
 pytest-order = {version = "^1.2.0", optional = true}
 pytest-asyncio = {version = "^0.23.2", optional = true}
-pytest = { version = "^7.4.4", optional = true }
 pydantic-settings = "^2.2.1"
 httpx-sse = "^0.4.0"
 isort = { version = "^5.13.2", optional = true }
 llama-index-embeddings-ollama = {version = "^0.1.2", optional = true}
-crewai = {version = "^0.41.1", optional = true}
-crewai-tools = {version = "^0.8.3", optional = true}
-docker = {version = "^7.1.0", optional = true}
-tiktoken = "^0.7.0"
-nltk = "^3.8.1"
+protobuf = "3.20.0"
 
 [tool.poetry.extras]
 local = ["llama-index-embeddings-huggingface"]
@@ -77,7 +75,6 @@ server = ["websockets", "fastapi", "uvicorn"]
 autogen = ["pyautogen"]
 qdrant = ["qdrant-client"]
 ollama = ["llama-index-embeddings-ollama"]
-crewai-tools = ["crewai", "docker", "crewai-tools"]
 
 [tool.poetry.group.dev.dependencies]
 black = "^24.4.2"

View File

@@ -1,3 +1,3 @@
-# from tests.config import TestMGPTConfig
-#
-# TEST_MEMGPT_CONFIG = TestMGPTConfig()
+from tests.config import TestMGPTConfig
+
+TEST_MEMGPT_CONFIG = TestMGPTConfig()

View File

@@ -3,4 +3,4 @@ pythonpath = /memgpt
 testpaths = /tests
 asyncio_mode = auto
 filterwarnings =
-    ignore::pytest.PytestRemovedIn9Warning
+    ignore::pytest.PytestRemovedIn8Warning

View File

@@ -1,9 +1,11 @@
 import threading
 import time
+import uuid
 
 import pytest
 
 from memgpt import Admin
+from tests.test_client import _reset_config, run_server
 
 test_base_url = "http://localhost:8283"
@@ -11,13 +13,6 @@ test_base_url = "http://localhost:8283"
 test_server_token = "test_server_token"
 
 
-def run_server():
-    from memgpt.server.rest_api.server import start_server
-
-    print("Starting server...")
-    start_server(debug=True)
-
-
 @pytest.fixture(scope="session", autouse=True)
 def start_uvicorn_server():
     """Starts Uvicorn server in a background thread."""
@@ -39,85 +34,91 @@ def admin_client():
 
 def test_admin_client(admin_client):
+    _reset_config()
+
     # create a user
-    user_name = "test_user"
-    user1 = admin_client.create_user(user_name)
-    assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
+    user_id = uuid.uuid4()
+    create_user1_response = admin_client.create_user(user_id)
+    assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}"
 
     # create another user
-    user2 = admin_client.create_user()
+    create_user_2_response = admin_client.create_user()
 
     # create keys
     key1_name = "test_key1"
     key2_name = "test_key2"
-    api_key1 = admin_client.create_key(user1.id, key1_name)
-    admin_client.create_key(user2.id, key2_name)
+    create_key1_response = admin_client.create_key(user_id, key1_name)
+    create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name)
 
     # list users
     users = admin_client.get_users()
-    assert len(users) == 2
-    assert user1.id in [user.id for user in users]
-    assert user2.id in [user.id for user in users]
+    assert len(users.user_list) == 2
+    print(users.user_list)
+    assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list]
 
     # list keys
-    user1_keys = admin_client.get_keys(user1.id)
-    assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}"
-    assert api_key1.key == user1_keys[0].key
+    user1_keys = admin_client.get_keys(user_id)
+    assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}"
+    assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}"
+    assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}"
 
     # delete key
-    deleted_key1 = admin_client.delete_key(api_key1.key)
-    assert deleted_key1.key == api_key1.key
-    assert len(admin_client.get_keys(user1.id)) == 0
+    delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
+    assert delete_key1_response.api_key_deleted == create_key1_response.api_key
+    assert len(admin_client.get_keys(user_id)) == 1
+    delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
+    assert delete_key2_response.api_key_deleted == create_key2_response.api_key
+    assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1
 
     # delete users
-    deleted_user1 = admin_client.delete_user(user1.id)
-    assert deleted_user1.id == user1.id
-    deleted_user2 = admin_client.delete_user(user2.id)
-    assert deleted_user2.id == user2.id
+    delete_user1_response = admin_client.delete_user(user_id)
+    assert delete_user1_response.user_id_deleted == user_id
+    delete_user2_response = admin_client.delete_user(create_user_2_response.user_id)
+    assert delete_user2_response.user_id_deleted == create_user_2_response.user_id
 
     # list users
     users = admin_client.get_users()
-    assert len(users) == 0, f"Expected 0 users, got {users}"
+    assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
 
 
-# def test_get_users_pagination(admin_client):
-#     _reset_config()
-#
-#     page_size = 5
-#     num_users = 7
-#     expected_users_remainder = num_users - page_size
-#
-#     # create users
-#     all_user_ids = []
-#     for i in range(num_users):
-#
-#         user_id = uuid.uuid4()
-#         all_user_ids.append(user_id)
-#         key_name = "test_key" + f"{i}"
-#
-#         create_user_response = admin_client.create_user(user_id)
-#         admin_client.create_key(create_user_response.user_id, key_name)
-#
-#     # list users in page 1
-#     get_all_users_response1 = admin_client.get_users(limit=page_size)
-#     cursor1 = get_all_users_response1.cursor
-#     user_list1 = get_all_users_response1.user_list
-#     assert len(user_list1) == page_size
-#
-#     # list users in page 2 using cursor
-#     get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
-#     cursor2 = get_all_users_response2.cursor
-#     user_list2 = get_all_users_response2.user_list
-#
-#     assert len(user_list2) == expected_users_remainder
-#     assert cursor1 != cursor2
-#
-#     # delete users
-#     clean_up_users_and_keys(all_user_ids)
-#
-#     # list users to check pagination with no users
-#     users = admin_client.get_users()
-#     assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
+def test_get_users_pagination(admin_client):
+
+    page_size = 5
+    num_users = 7
+    expected_users_remainder = num_users - page_size
+
+    # create users
+    all_user_ids = []
+    for i in range(num_users):
+
+        user_id = uuid.uuid4()
+        all_user_ids.append(user_id)
+        key_name = "test_key" + f"{i}"
+
+        create_user_response = admin_client.create_user(user_id)
+        admin_client.create_key(create_user_response.user_id, key_name)
+
+    # list users in page 1
+    get_all_users_response1 = admin_client.get_users(limit=page_size)
+    cursor1 = get_all_users_response1.cursor
+    user_list1 = get_all_users_response1.user_list
+    assert len(user_list1) == page_size
+
+    # list users in page 2 using cursor
+    get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
+    cursor2 = get_all_users_response2.cursor
+    user_list2 = get_all_users_response2.user_list
+
+    assert len(user_list2) == expected_users_remainder
+    assert cursor1 != cursor2
+
+    # delete users
+    clean_up_users_and_keys(all_user_ids)
+
+    # list users to check pagination with no users
+    users = admin_client.get_users()
+    assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
 
 
 def clean_up_users_and_keys(user_id_list):
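The lifecycle these assertions exercise (mint a user, mint a key, authenticate a normal client with that key, then tear everything down) can be sketched end to end. A sketch against a local server; the URL and token mirror the test constants above, and the attribute names follow the response objects used in these assertions:

import uuid

from memgpt import Admin, create_client

admin = Admin("http://localhost:8283", "test_server_token")

user_id = uuid.uuid4()
user_response = admin.create_user(user_id)            # also returns an initial api_key
key_response = admin.create_key(user_id, "demo_key")  # mint a second key

# a regular (non-admin) client authenticates with either key
client = create_client(base_url="http://localhost:8283", token=key_response.api_key)

# ... exercise the client ...

admin.delete_key(key_response.api_key)
admin.delete_user(user_id)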

View File

@@ -1,41 +1,38 @@
-# TODO: add back
-
-# import os
-# import subprocess
-#
-# import pytest
-#
-#
-# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
-# def test_agent_groupchat():
-#
-#     # Define the path to the script you want to test
-#     script_path = "memgpt/autogen/examples/agent_groupchat.py"
-#
-#     # Dynamically get the project's root directory (assuming this script is run from the root)
-#     # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
-#     # print(project_root)
-#     # project_root = os.path.join(project_root, "MemGPT")
-#     # print(project_root)
-#     # sys.exit(1)
-#
-#     project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
-#     project_root = os.path.join(project_root, "memgpt")
-#     print(f"Adding the following to PATH: {project_root}")
-#
-#     # Prepare the environment, adding the project root to PYTHONPATH
-#     env = os.environ.copy()
-#     env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
-#
-#     # Run the script using subprocess.run
-#     # Capture the output (stdout) and the exit code
-#     # result = subprocess.run(["python", script_path], capture_output=True, text=True)
-#     result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
-#
-#     # Check the exit code (0 indicates success)
-#     assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
-#
-#     # Optionally, check the output for expected content
-#     # For example, if you expect a specific line in the output, uncomment and adapt the following line:
-#     # assert "expected output" in result.stdout, "Expected output not found in script's output"
-#
+import os
+import subprocess
+
+import pytest
+
+
+@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
+def test_agent_groupchat():
+
+    # Define the path to the script you want to test
+    script_path = "memgpt/autogen/examples/agent_groupchat.py"
+
+    # Dynamically get the project's root directory (assuming this script is run from the root)
+    # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
+    # print(project_root)
+    # project_root = os.path.join(project_root, "MemGPT")
+    # print(project_root)
+    # sys.exit(1)
+
+    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+    project_root = os.path.join(project_root, "memgpt")
+    print(f"Adding the following to PATH: {project_root}")
+
+    # Prepare the environment, adding the project root to PYTHONPATH
+    env = os.environ.copy()
+    env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
+
+    # Run the script using subprocess.run
+    # Capture the output (stdout) and the exit code
+    # result = subprocess.run(["python", script_path], capture_output=True, text=True)
+    result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
+
+    # Check the exit code (0 indicates success)
+    assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
+
+    # Optionally, check the output for expected content
+    # For example, if you expect a specific line in the output, uncomment and adapt the following line:
+    # assert "expected output" in result.stdout, "Expected output not found in script's output"
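Note that the restored test builds an env dict with PYTHONPATH but, as written, never passes it to subprocess.run, so the child inherits the unmodified environment. If the import path actually matters, the mapping has to be handed over explicitly (a sketch; the paths are hypothetical):

import os
import subprocess

env = os.environ.copy()
env["PYTHONPATH"] = f"/path/to/project:{env.get('PYTHONPATH', '')}"  # hypothetical project root

# pass env= explicitly, otherwise the child ignores the PYTHONPATH prepared above
result = subprocess.run(
    ["poetry", "run", "python", "script.py"],
    capture_output=True,
    text=True,
    env=env,
)
print(result.returncode, result.stderr)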

View File

@@ -26,7 +26,7 @@ def agent_obj():
     agent_state = client.create_agent()
 
     global agent_obj
-    agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
+    agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
     yield agent_obj
 
     client.delete_agent(agent_obj.agent_state.id)

View File

@@ -1,12 +1,17 @@
+import os
 import subprocess
 import sys
 
 subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])
+import pexpect
 from prettytable.colortable import ColorTable
 
 from memgpt.cli.cli_config import ListChoice, add, delete
 from memgpt.cli.cli_config import list as list_command
+
+from .constants import TIMEOUT
+from .utils import create_config
 
 # def test_configure_memgpt():
 #     configure_memgpt()
@@ -42,3 +47,41 @@ def test_cli_config():
         assert "test data" in row
 
     # delete
     delete(option=option, name="test")
+
+
+def test_save_load():
+    # configure_memgpt() # rely on configure running first^
+    if os.getenv("OPENAI_API_KEY"):
+        create_config("openai")
+    else:
+        create_config("memgpt_hosted")
+
+    child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")
+    child.expect("Enter your message:", timeout=TIMEOUT)
+    child.sendline()
+    child.expect("Empty input received. Try again!", timeout=TIMEOUT)
+    child.sendline("/save")
+    child.expect("Enter your message:", timeout=TIMEOUT)
+    child.sendline("/exit")
+    child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
+    child.close()
+    assert child.isalive() is False, "CLI should have terminated."
+    assert child.exitstatus == 0, "CLI did not exit cleanly."
+
+    child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")
+    child.expect("Using existing agent test_save_load", timeout=TIMEOUT)
+    child.expect("Enter your message:", timeout=TIMEOUT)
+    child.sendline("/exit")
+    child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
+    child.close()
+    assert child.isalive() is False, "CLI should have terminated."
+    assert child.exitstatus == 0, "CLI did not exit cleanly."
+
+
+if __name__ == "__main__":
+    # test_configure_memgpt()
+    test_save_load()

View File

@@ -7,11 +7,12 @@ import pytest
 from dotenv import load_dotenv
 
 from memgpt import Admin, create_client
+from memgpt.config import MemGPTConfig
 from memgpt.constants import DEFAULT_PRESET
-from memgpt.schemas.message import Message
-from memgpt.schemas.usage import MemGPTUsageStatistics
-
-# from tests.utils import create_config
+from memgpt.credentials import MemGPTCredentials
+from memgpt.data_types import Preset  # TODO move to PresetModel
+from memgpt.settings import settings
+from tests.utils import create_config
 
 test_agent_name = f"test_client_{str(uuid.uuid4())}"
 # test_preset_name = "test_preset"
@@ -20,16 +21,44 @@ test_agent_state = None
 client = None
 
 test_agent_state_post_message = None
+test_user_id = uuid.uuid4()
 
 # admin credentials
 test_server_token = "test_server_token"
 
 
+def _reset_config():
+
+    # Use os.getenv with a fallback to os.environ.get
+    db_url = settings.memgpt_pg_uri
+
+    if os.getenv("OPENAI_API_KEY"):
+        create_config("openai")
+        credentials = MemGPTCredentials(
+            openai_key=os.getenv("OPENAI_API_KEY"),
+        )
+    else:  # hosted
+        create_config("memgpt_hosted")
+        credentials = MemGPTCredentials()
+
+    config = MemGPTConfig.load()
+
+    # set to use postgres
+    config.archival_storage_uri = db_url
+    config.recall_storage_uri = db_url
+    config.metadata_storage_uri = db_url
+    config.archival_storage_type = "postgres"
+    config.recall_storage_type = "postgres"
+    config.metadata_storage_type = "postgres"
+
+    config.save()
+    credentials.save()
+    print("_reset_config :: ", config.config_path)
+
+
 def run_server():
     load_dotenv()
 
-    # _reset_config()
+    _reset_config()
 
     from memgpt.server.rest_api.server import start_server
@@ -39,8 +68,7 @@ def run_server():
 
 # Fixture to create clients with different configurations
 @pytest.fixture(
-    # params=[{"server": True}, {"server": False}], # whether to use REST API server
-    params=[{"server": True}],  # whether to use REST API server
+    params=[{"server": True}, {"server": False}],  # whether to use REST API server
     scope="module",
 )
 def client(request):
@@ -58,20 +86,21 @@ def client(request):
         print("Running client tests with server:", server_url)
         # create user via admin client
         admin = Admin(server_url, test_server_token)
-        user = admin.create_user()  # Adjust as per your client's method
-        api_key = admin.create_key(user.id)
-        client = create_client(base_url=server_url, token=api_key.key)
+        response = admin.create_user(test_user_id)  # Adjust as per your client's method
+        token = response.api_key
     else:
         # use local client (no server)
+        token = None
         server_url = None
-        client = create_client()
 
+    client = create_client(base_url=server_url, token=token)  # This yields control back to the test function
     try:
         yield client
     finally:
         # cleanup user
         if server_url:
-            admin.delete_user(user.id)
+            admin.delete_user(test_user_id)  # Adjust as per your client's method
 
 
 # Fixture for test agent
@@ -86,6 +115,7 @@ def agent(client):
 
 def test_agent(client, agent):
+    _reset_config()
 
     # test client.rename_agent
     new_name = "RenamedTestAgent"
@@ -101,84 +131,61 @@ def test_agent(client, agent):
 
 def test_memory(client, agent):
-    # _reset_config()
+    _reset_config()
 
-    memory_response = client.get_in_context_memory(agent_id=agent.id)
+    memory_response = client.get_agent_memory(agent_id=agent.id)
     print("MEMORY", memory_response)
 
     updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
-    client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"])
-    client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"])
-    updated_memory_response = client.get_in_context_memory(agent_id=agent.id)
+    client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory)
+    updated_memory_response = client.get_agent_memory(agent_id=agent.id)
     assert (
-        updated_memory_response.get_block("human").value == updated_memory["human"]
-        and updated_memory_response.get_block("persona").value == updated_memory["persona"]
+        updated_memory_response.core_memory.human == updated_memory["human"]
+        and updated_memory_response.core_memory.persona == updated_memory["persona"]
     ), "Memory update failed"
 
 
 def test_agent_interactions(client, agent):
-    # _reset_config()
+    _reset_config()
 
     message = "Hello, agent!"
-    print("Sending message", message)
-    response = client.user_message(agent_id=agent.id, message=message)
-    print("Response", response)
-    assert isinstance(response.usage, MemGPTUsageStatistics)
-    assert response.usage.step_count == 1
-    assert response.usage.total_tokens > 0
-    assert response.usage.completion_tokens > 0
-    assert isinstance(response.messages[0], Message)
-    print(response.messages)
-
-    # TODO: add streaming tests
+    message_response = client.user_message(agent_id=agent.id, message=message)
+
+    command = "/memory"
+    command_response = client.run_command(agent_id=agent.id, command=command)
+    print("command", command_response)
 
 
 def test_archival_memory(client, agent):
-    # _reset_config()
+    _reset_config()
 
     memory_content = "Archival memory content"
-    insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)[0]
-    print("Inserted memory", insert_response.text, insert_response.id)
+    insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
    assert insert_response, "Inserting archival memory failed"
 
-    archival_memory_response = client.get_archival_memory(agent_id=agent.id, limit=1)
-    archival_memories = [memory.text for memory in archival_memory_response]
+    archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1)
+    print("MEMORY")
+    archival_memories = [memory.contents for memory in archival_memory_response.archival_memory]
     assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}"
 
-    memory_id_to_delete = archival_memory_response[0].id
+    memory_id_to_delete = archival_memory_response.archival_memory[0].id
     client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete)
 
-    # add archival memory
-    memory_str = "I love chats"
-    passage = client.insert_archival_memory(agent.id, memory=memory_str)[0]
-
-    # list archival memory
-    passages = client.get_archival_memory(agent.id)
-    assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"
-
-    # get archival memory summary
-    archival_summary = client.get_archival_memory_summary(agent.id)
-    assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
-
-    # delete archival memory
-    client.delete_archival_memory(agent.id, passage.id)
-
     # TODO: check deletion
-    client.get_archival_memory(agent.id)
 
 
 def test_messages(client, agent):
-    # _reset_config()
+    _reset_config()
 
     send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
     assert send_message_response, "Sending message failed"
 
     messages_response = client.get_messages(agent_id=agent.id, limit=1)
-    assert len(messages_response) > 0, "Retrieving messages failed"
 
+    assert len(messages_response.messages) > 0, "Retrieving messages failed"
 
 def test_humans_personas(client, agent):
-    # _reset_config()
+    _reset_config()
 
     humans_response = client.list_humans()
     print("HUMANS", humans_response)
@@ -187,20 +194,18 @@ def test_humans_personas(client, agent):
     print("PERSONAS", personas_response)
 
     persona_name = "TestPersona"
-    persona_id = client.get_persona_id(persona_name)
-    if persona_id:
-        client.delete_persona(persona_id)
+    if client.get_persona(persona_name):
+        client.delete_persona(persona_name)
     persona = client.create_persona(name=persona_name, text="Persona text")
     assert persona.name == persona_name
-    assert persona.value == "Persona text", "Creating persona failed"
+    assert persona.text == "Persona text", "Creating persona failed"
 
     human_name = "TestHuman"
-    human_id = client.get_human_id(human_name)
-    if human_id:
-        client.delete_human(human_id)
+    if client.get_human(human_name):
+        client.delete_human(human_name)
     human = client.create_human(name=human_name, text="Human text")
     assert human.name == human_name
-    assert human.value == "Human text", "Creating human failed"
+    assert human.text == "Human text", "Creating human failed"
 
 
 # def test_tools(client, agent):
@@ -213,14 +218,11 @@ def test_humans_personas(client, agent):
 
 def test_config(client, agent):
-    # _reset_config()
+    _reset_config()
 
     models_response = client.list_models()
     print("MODELS", models_response)
 
-    embeddings_response = client.list_embedding_models()
-    print("EMBEDDINGS", embeddings_response)
-
     # TODO: add back
     # config_response = client.get_config()
     # TODO: ensure config is the same as the one in the server
@@ -228,7 +230,7 @@ def test_config(client, agent):
 
 def test_sources(client, agent):
-    # _reset_config()
+    _reset_config()
 
     if not hasattr(client, "base_url"):
         pytest.skip("Skipping test_sources because base_url is None")
@@ -236,7 +238,7 @@ def test_sources(client, agent):
     # list sources
     sources = client.list_sources()
     print("listed sources", sources)
-    assert len(sources) == 0
+    assert len(sources.sources) == 0
 
     # create a source
     source = client.create_source(name="test_source")
@@ -244,53 +246,36 @@ def test_sources(client, agent):
     # list sources
     sources = client.list_sources()
     print("listed sources", sources)
-    assert len(sources) == 1
-
-    # TODO: add back?
-    assert sources[0].metadata_["num_passages"] == 0
-    assert sources[0].metadata_["num_documents"] == 0
-
-    # update the source
-    original_id = source.id
-    original_name = source.name
-    new_name = original_name + "_new"
-    client.update_source(source_id=source.id, name=new_name)
-
-    # get the source name (check that it's been updated)
-    source = client.get_source(source_id=source.id)
-    assert source.name == new_name
-    assert source.id == original_id
-
-    # get the source id (make sure that it's the same)
-    assert str(original_id) == client.get_source_id(source_name=new_name)
+    assert len(sources.sources) == 1
+    assert sources.sources[0].metadata_["num_passages"] == 0
+    assert sources.sources[0].metadata_["num_documents"] == 0
 
     # check agent archival memory size
-    archival_memories = client.get_archival_memory(agent_id=agent.id)
+    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
     print(archival_memories)
     assert len(archival_memories) == 0
 
     # load a file into a source
     filename = "CONTRIBUTING.md"
     upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
-    print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
+    print("Upload job", upload_job, upload_job.status, upload_job.metadata)
 
     # TODO: make sure things run in the right order
-    archival_memories = client.get_archival_memory(agent_id=agent.id)
+    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
     assert len(archival_memories) == 0
 
     # attach a source
     client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
 
     # list archival memory
-    archival_memories = client.get_archival_memory(agent_id=agent.id)
+    archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
     # print(archival_memories)
     assert len(archival_memories) == 20 or len(archival_memories) == 21
 
     # check number of passages
     sources = client.list_sources()
-    # TODO: add back?
-    # assert sources.sources[0].metadata_["num_passages"] > 0
-    # assert sources.sources[0].metadata_["num_documents"] == 0  # TODO: fix this once document store added
+    assert sources.sources[0].metadata_["num_passages"] > 0
+    assert sources.sources[0].metadata_["num_documents"] == 0  # TODO: fix this once document store added
     print(sources)
 
     # detach the source
@@ -299,3 +284,80 @@ def test_sources(client, agent):
 
     # delete the source
     client.delete_source(source.id)
+
+
+# def test_presets(client, agent):
+#     _reset_config()
+#
+#     # new_preset = Preset(
+#     #     # user_id=client.user_id,
+#     #     name="pytest_test_preset",
+#     #     description="DUMMY_DESCRIPTION",
+#     #     system="DUMMY_SYSTEM",
+#     #     persona="DUMMY_PERSONA",
+#     #     persona_name="DUMMY_PERSONA_NAME",
+#     #     human="DUMMY_HUMAN",
+#     #     human_name="DUMMY_HUMAN_NAME",
+#     #     functions_schema=[
+#     #         {
+#     #             "name": "send_message",
+#     #             "json_schema": {
+#     #                 "name": "send_message",
+#     #                 "description": "Sends a message to the human user.",
+#     #                 "parameters": {
+#     #                     "type": "object",
+#     #                     "properties": {
+#     #                         "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
+#     #                     },
+#     #                     "required": ["message"],
+#     #                 },
+#     #             },
+#     #             "tags": ["memgpt-base"],
+#     #             "source_type": "python",
+#     #             "source_code": 'def send_message(self, message: str) -> Optional[str]:\n    """\n    Sends a message to the human user.\n\n    Args:\n        message (str): Message contents. All unicode (including emojis) are supported.\n\n    Returns:\n        Optional[str]: None is always returned as this function does not produce a response.\n    """\n    self.interface.assistant_message(message)\n    return None\n',
+#     #         }
+#     #     ],
+#     # )
+#
+#     ## List all presets and make sure the preset is NOT in the list
+#     # all_presets = client.list_presets()
+#     # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
+#
+#     # Create a preset
+#     new_preset = client.create_preset(name="pytest_test_preset")
+#
+#     # List all presets and make sure the preset is in the list
+#     all_presets = client.list_presets()
+#     assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
+#
+#     # Delete the preset
+#     client.delete_preset(preset_id=new_preset.id)
+#
+#     # List all presets and make sure the preset is NOT in the list
+#     all_presets = client.list_presets()
+#     assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
+
+
+# def test_tools(client, agent):
+#
+#     # load a function
+#     file_path = "tests/data/functions/dump_json.py"
+#     module_name = "dump_json"
+#
+#     # list functions
+#     response = client.list_tools()
+#     orig_tools = response.tools
+#     print(orig_tools)
+#
+#     # add the tool
+#     create_tool_response = client.create_tool(name=module_name, file_path=file_path)
+#     print(create_tool_response)
+#
+#     # list functions
+#     response = client.list_tools()
+#     new_tools = response.tools
+#     assert module_name in [tool.name for tool in new_tools]
+#     # assert len(new_tools) == len(orig_tools) + 1
+#
+#     # TODO: add a function to a preset
+#
+#     # TODO: add a function to an agent

View File

@@ -1,142 +1,142 @@
-# TODO: add back when messaging works
-
-# import os
-# import threading
-# import time
-# import uuid
-#
-# import pytest
-# from dotenv import load_dotenv
-#
-# from memgpt import Admin, create_client
-# from memgpt.config import MemGPTConfig
-# from memgpt.credentials import MemGPTCredentials
-# from memgpt.settings import settings
-# from tests.utils import create_config
-#
-# test_agent_name = f"test_client_{str(uuid.uuid4())}"
-## test_preset_name = "test_preset"
-# test_agent_state = None
-# client = None
-#
-# test_agent_state_post_message = None
-# test_user_id = uuid.uuid4()
-#
-## admin credentials
-# test_server_token = "test_server_token"
-#
-#
-# def _reset_config():
-#
-#     # Use os.getenv with a fallback to os.environ.get
-#     db_url = settings.memgpt_pg_uri
-#
-#     if os.getenv("OPENAI_API_KEY"):
-#         create_config("openai")
-#         credentials = MemGPTCredentials(
-#             openai_key=os.getenv("OPENAI_API_KEY"),
-#         )
-#     else:  # hosted
-#         create_config("memgpt_hosted")
-#         credentials = MemGPTCredentials()
-#
-#     config = MemGPTConfig.load()
-#
-#     # set to use postgres
-#     config.archival_storage_uri = db_url
-#     config.recall_storage_uri = db_url
-#     config.metadata_storage_uri = db_url
-#     config.archival_storage_type = "postgres"
-#     config.recall_storage_type = "postgres"
-#     config.metadata_storage_type = "postgres"
-#
-#     config.save()
-#     credentials.save()
-#     print("_reset_config :: ", config.config_path)
-#
-#
-# def run_server():
-#
-#     load_dotenv()
-#
-#     _reset_config()
-#
-#     from memgpt.server.rest_api.server import start_server
-#
-#     print("Starting server...")
-#     start_server(debug=True)
-#
-#
-## Fixture to create clients with different configurations
-# @pytest.fixture(
-#     params=[  # whether to use REST API server
-#         {"server": True},
-#         # {"server": False} # TODO: add when implemented
-#     ],
-#     scope="module",
-# )
-# def admin_client(request):
-#     if request.param["server"]:
-#         # get URL from enviornment
-#         server_url = os.getenv("MEMGPT_SERVER_URL")
-#         if server_url is None:
-#             # run server in thread
-#             # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
-#             server_url = "http://localhost:8283"
-#             print("Starting server thread")
-#             thread = threading.Thread(target=run_server, daemon=True)
-#             thread.start()
-#             time.sleep(5)
-#         print("Running client tests with server:", server_url)
-#         # create user via admin client
-#         admin = Admin(server_url, test_server_token)
-#         response = admin.create_user(test_user_id)  # Adjust as per your client's method
-#
-#         yield admin
-#
-#
-# def test_concurrent_messages(admin_client):
-#     # test concurrent messages
-#
-#     # create three
-#
-#     results = []
-#
-#     def _send_message():
-#         try:
-#             print("START SEND MESSAGE")
-#             response = admin_client.create_user()
-#             token = response.api_key
-#             client = create_client(base_url=admin_client.base_url, token=token)
-#             agent = client.create_agent()
-#
-#             print("Agent created", agent.id)
-#
-#             st = time.time()
-#             message = "Hello, how are you?"
-#             response = client.send_message(agent_id=agent.id, message=message, role="user")
-#             et = time.time()
-#             print(f"Message sent from {st} to {et}")
-#             print(response.messages)
-#             results.append((st, et))
-#         except Exception as e:
-#             print("ERROR", e)
-#
-#     threads = []
-#     print("Starting threads...")
-#     for i in range(5):
-#         thread = threading.Thread(target=_send_message)
-#         threads.append(thread)
-#         thread.start()
-#         print("CREATED THREAD")
-#
-#     print("waiting for threads to finish...")
-#     for thread in threads:
-#         print(thread.join())
-#
-#     # make sure runtime are overlapping
-#     assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
-#         results[1][0] < results[0][0] and results[1][1] > results[0][0]
-#     ), f"Threads should have overlapping runtimes {results}"
-#
+import os
+import threading
+import time
+import uuid
+
+import pytest
+from dotenv import load_dotenv
+
+from memgpt import Admin, create_client
+from memgpt.config import MemGPTConfig
+from memgpt.constants import DEFAULT_PRESET
+from memgpt.credentials import MemGPTCredentials
+from memgpt.data_types import Preset  # TODO move to PresetModel
+from memgpt.settings import settings
+from tests.utils import create_config
+
+test_agent_name = f"test_client_{str(uuid.uuid4())}"
+# test_preset_name = "test_preset"
+test_preset_name = DEFAULT_PRESET
+test_agent_state = None
+client = None
+
+test_agent_state_post_message = None
+test_user_id = uuid.uuid4()
+
+# admin credentials
+test_server_token = "test_server_token"
+
+
+def _reset_config():
+
+    # Use os.getenv with a fallback to os.environ.get
+    db_url = settings.memgpt_pg_uri
+
+    if os.getenv("OPENAI_API_KEY"):
+        create_config("openai")
+        credentials = MemGPTCredentials(
+            openai_key=os.getenv("OPENAI_API_KEY"),
+        )
+    else:  # hosted
+        create_config("memgpt_hosted")
+        credentials = MemGPTCredentials()
+
+    config = MemGPTConfig.load()
+
+    # set to use postgres
+    config.archival_storage_uri = db_url
+    config.recall_storage_uri = db_url
+    config.metadata_storage_uri = db_url
+    config.archival_storage_type = "postgres"
+    config.recall_storage_type = "postgres"
+    config.metadata_storage_type = "postgres"
+
+    config.save()
+    credentials.save()
+    print("_reset_config :: ", config.config_path)
+
+
+def run_server():
+
+    load_dotenv()
+
+    _reset_config()
+
+    from memgpt.server.rest_api.server import start_server
+
+    print("Starting server...")
+    start_server(debug=True)
+
+
+# Fixture to create clients with different configurations
+@pytest.fixture(
+    params=[  # whether to use REST API server
+        {"server": True},
+        # {"server": False} # TODO: add when implemented
+    ],
+    scope="module",
+)
+def admin_client(request):
+    if request.param["server"]:
+        # get URL from enviornment
+        server_url = os.getenv("MEMGPT_SERVER_URL")
+        if server_url is None:
+            # run server in thread
+            # NOTE: must set MEMGPT_SERVER_PASS enviornment variable
+            server_url = "http://localhost:8283"
+            print("Starting server thread")
+            thread = threading.Thread(target=run_server, daemon=True)
+            thread.start()
+            time.sleep(5)
+        print("Running client tests with server:", server_url)
+        # create user via admin client
+        admin = Admin(server_url, test_server_token)
+        response = admin.create_user(test_user_id)  # Adjust as per your client's method
+
+        yield admin
+
+
+def test_concurrent_messages(admin_client):
+    # test concurrent messages
+
+    # create three
+
+    results = []
+
+    def _send_message():
+        try:
+            print("START SEND MESSAGE")
+            response = admin_client.create_user()
+            token = response.api_key
+            client = create_client(base_url=admin_client.base_url, token=token)
+            agent = client.create_agent()
+
+            print("Agent created", agent.id)
+
+            st = time.time()
+            message = "Hello, how are you?"
+            response = client.send_message(agent_id=agent.id, message=message, role="user")
+            et = time.time()
+            print(f"Message sent from {st} to {et}")
+            print(response.messages)
+            results.append((st, et))
+        except Exception as e:
+            print("ERROR", e)
+
+    threads = []
+    print("Starting threads...")
+    for i in range(5):
+        thread = threading.Thread(target=_send_message)
+        threads.append(thread)
+        thread.start()
+        print("CREATED THREAD")
+
+    print("waiting for threads to finish...")
+    for thread in threads:
+        print(thread.join())
+
+    # make sure runtime are overlapping
+    assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
+        results[1][0] < results[0][0] and results[1][1] > results[0][0]
+    ), f"Threads should have overlapping runtimes {results}"
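The closing assertion is the interval-overlap predicate specialized to the first two results: two (start, end) spans overlap exactly when each starts before the other ends. In general form (a sketch):

def overlaps(a, b):
    # (start, end) intervals overlap iff each one starts before the other ends
    return a[0] < b[1] and b[0] < a[1]

assert overlaps((0.0, 2.0), (1.0, 3.0))      # concurrent requests
assert not overlaps((0.0, 1.0), (2.0, 3.0))  # serialized requests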

Some files were not shown because too many files have changed in this diff.