Charles Packer 2024-08-16 19:52:47 -07:00 committed by GitHub
parent b39ad16a9d
commit 55a36a6e3d
112 changed files with 8008 additions and 8901 deletions

View File

@ -58,6 +58,7 @@ jobs:
pipx install poetry==1.8.2
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client.py
poetry run pytest -s tests/test_concurrent_connections.py
- name: Print docker logs if tests fail
if: failure()

View File

@ -2,7 +2,6 @@ name: Run All pytest Tests
env:
MEMGPT_PGURI: ${{ secrets.MEMGPT_PGURI }}
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
on:
push:
@ -33,10 +32,10 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E crewai-tools"
install-args: "-E dev -E postgres -E milvus"
- name: Initialize credentials
run: poetry run memgpt quickstart --backend openai
run: poetry run memgpt quickstart --backend memgpt
#- name: Run docker compose server
# env:
@ -70,3 +69,14 @@ jobs:
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not test_concurrent_connections.py and not test_quickstart and not test_endpoints and not test_storage and not test_server and not test_openai_client" tests
- name: Run storage tests
env:
MEMGPT_PG_PORT: 8888
MEMGPT_PG_USER: memgpt
MEMGPT_PG_PASSWORD: memgpt
MEMGPT_PG_HOST: localhost
MEMGPT_PG_DB: memgpt
MEMGPT_SERVER_PASS: test_server_token
run: |
poetry run pytest -s -vv tests/test_storage.py

.gitignore vendored
View File

@ -1012,7 +1012,6 @@ FodyWeavers.xsd
## cached db data
pgdata/
!pgdata/.gitkeep
.persist/
## pytest mirrors
memgpt/.pytest_cache/

View File

@ -14,7 +14,7 @@ WORKDIR /app
COPY pyproject.toml poetry.lock ./
RUN poetry lock --no-update
RUN if [ "$MEMGPT_ENVIRONMENT" = "DEVELOPMENT" ] ; then \
poetry install --no-root -E "postgres server dev" ; \
poetry install --no-root -E "postgres server dev autogen" ; \
else \
poetry install --no-root -E "postgres server" && \
rm -rf $POETRY_CACHE_DIR ; \

View File

@ -2,7 +2,8 @@ import datetime
import inspect
import json
import traceback
from typing import List, Literal, Optional, Tuple, Union
import uuid
from typing import List, Literal, Optional, Tuple, Union, cast
from tqdm import tqdm
@ -18,20 +19,14 @@ from memgpt.constants import (
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
)
from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
from memgpt.interface import AgentInterface
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
from memgpt.memory import ArchivalMemory, RecallMemory, summarize_messages
from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
from memgpt.metadata import MetadataStore
from memgpt.models import chat_completion_response
from memgpt.models.pydantic_models import OptionState, ToolModel
from memgpt.persistence_manager import LocalStateManager
from memgpt.schemas.agent import AgentState
from memgpt.schemas.block import Block
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import OptionState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.schemas.passage import Passage
from memgpt.schemas.tool import Tool
from memgpt.system import (
get_initial_boot_messages,
get_login_event,
@ -40,6 +35,7 @@ from memgpt.system import (
)
from memgpt.utils import (
count_tokens,
create_uuid_from_string,
get_local_time,
get_tool_call_id,
get_utc_time,
@ -76,7 +72,7 @@ def compile_memory_metadata_block(
def compile_system_message(
system_prompt: str,
in_context_memory: Memory,
in_context_memory: BaseMemory,
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
@ -139,7 +135,7 @@ def compile_system_message(
def initialize_message_sequence(
model: str,
system: str,
memory: Memory,
memory: BaseMemory,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
memory_edit_timestamp: Optional[datetime.datetime] = None,
@ -192,21 +188,35 @@ class Agent(object):
interface: AgentInterface,
# agents can be created from providing agent_state
agent_state: AgentState,
tools: List[Tool],
# memory: Memory,
tools: List[ToolModel],
# memory: BaseMemory,
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
):
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
# tools
for tool in tools:
assert tool, "Tool is None - must be an error in querying tool from DB"
assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
for tool_name in agent_state.tools:
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
# Store the functions schemas (this is passed as an argument to ChatCompletion)
self.functions = []
self.functions_python = {}
env = {}
env.update(globals())
for tool in tools:
# WARNING: name may not be consistent?
if tool.module: # execute the whole module
exec(tool.module, env)
else:
exec(tool.source_code, env)
self.functions_python[tool.name] = env[tool.name]
self.functions.append(tool.json_schema)
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
# Hold a copy of the state that was used to init the agent
self.agent_state = agent_state
assert isinstance(self.agent_state.memory, Memory), f"Memory object is not of type Memory: {type(self.agent_state.memory)}"
try:
self.link_tools(tools)
except Exception as e:
raise ValueError(f"Encountered an error while trying to link agent tools during initialization:\n{str(e)}")
# gpt-4, gpt-3.5-turbo, ...
self.model = self.agent_state.llm_config.model
@ -215,8 +225,7 @@ class Agent(object):
self.system = self.agent_state.system
# Initialize the memory object
self.memory = self.agent_state.memory
assert isinstance(self.memory, Memory), f"Memory object is not of type Memory: {type(self.memory)}"
self.memory = BaseMemory.load(self.agent_state.state["memory"])
printd("Initialized memory object", self.memory)
# Interface must implement:
@ -245,13 +254,28 @@ class Agent(object):
self._messages: List[Message] = []
# Once the memory object is initialized, use it to "bake" the system message
if self.agent_state.message_ids is not None:
self.set_message_buffer(message_ids=self.agent_state.message_ids)
if "messages" in self.agent_state.state and self.agent_state.state["messages"] is not None:
# print(f"Agent.__init__ :: loading, state={agent_state.state['messages']}")
if not isinstance(self.agent_state.state["messages"], list):
raise ValueError(f"'messages' in AgentState was bad type: {type(self.agent_state.state['messages'])}")
assert all([isinstance(msg, str) for msg in self.agent_state.state["messages"]])
# Convert to IDs, and pull from the database
raw_messages = [
self.persistence_manager.recall_memory.storage.get(id=uuid.UUID(msg_id)) for msg_id in self.agent_state.state["messages"]
]
assert all([isinstance(msg, Message) for msg in raw_messages]), (raw_messages, self.agent_state.state["messages"])
self._messages.extend([cast(Message, msg) for msg in raw_messages if msg is not None])
for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
else:
printd(f"Agent.__init__ :: creating, state={agent_state.message_ids}")
# Generate a sequence of initial messages to put in the buffer
printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
init_messages = initialize_message_sequence(
model=self.model,
system=self.system,
@ -261,8 +285,6 @@ class Agent(object):
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
)
# Cast the messages to actual Message objects to be synced to the DB
init_messages_objs = []
for msg in init_messages:
init_messages_objs.append(
@ -271,12 +293,15 @@ class Agent(object):
)
)
assert all([isinstance(msg, Message) for msg in init_messages_objs]), (init_messages_objs, init_messages)
# Put the messages inside the message buffer
self.messages_total = 0
# self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
self._append_to_messages(added_messages=init_messages_objs)
self._validate_message_buffer_is_utc()
self._append_to_messages(added_messages=[cast(Message, msg) for msg in init_messages_objs if msg is not None])
for m in self._messages:
assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
@ -295,65 +320,6 @@ class Agent(object):
def messages(self, value):
raise Exception("Modifying message list directly not allowed")
def link_tools(self, tools: List[Tool]):
"""Bind a tool object (schema + python function) to the agent object"""
# tools
for tool in tools:
assert tool, "Tool is None - must be an error in querying tool from DB"
assert tool.name in self.agent_state.tools, f"Tool {tool} not found in agent_state.tools"
for tool_name in self.agent_state.tools:
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
# Store the functions schemas (this is passed as an argument to ChatCompletion)
self.functions = []
self.functions_python = {}
env = {}
env.update(globals())
for tool in tools:
# WARNING: name may not be consistent?
if tool.module: # execute the whole module
exec(tool.module, env)
else:
exec(tool.source_code, env)
self.functions_python[tool.name] = env[tool.name]
self.functions.append(tool.json_schema)
assert all([callable(f) for k, f in self.functions_python.items()]), self.functions_python
def _load_messages_from_recall(self, message_ids: List[str]) -> List[Message]:
"""Load a list of messages from recall storage"""
# Pull the message objects from the database
message_objs = [self.persistence_manager.recall_memory.storage.get(msg_id) for msg_id in message_ids]
assert all([isinstance(msg, Message) for msg in message_objs])
return message_objs
def _validate_message_buffer_is_utc(self):
"""Iterate over the message buffer and force all messages to be UTC stamped"""
for m in self._messages:
# assert is_utc_datetime(m.created_at), f"created_at on message for agent {self.agent_state.name} isn't UTC:\n{vars(m)}"
# TODO eventually do casting via an edit_message function
if not is_utc_datetime(m.created_at):
printd(f"Warning - created_at on message for agent {self.agent_state.name} isn't UTC (text='{m.text}')")
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
def set_message_buffer(self, message_ids: List[str], force_utc: bool = True):
"""Set the messages in the buffer to the message IDs list"""
message_objs = self._load_messages_from_recall(message_ids=message_ids)
# set the objects in the buffer
self._messages = message_objs
# bugfix for old agents that may not have had UTC specified in their timestamps
if force_utc:
self._validate_message_buffer_is_utc()
# also sync the message IDs attribute
self.agent_state.message_ids = message_ids
def _trim_messages(self, num):
"""Trim messages from the front, not including the system message"""
self.persistence_manager.trim_messages(num)
@ -406,7 +372,7 @@ class Agent(object):
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
) -> ChatCompletionResponse:
) -> chat_completion_response.ChatCompletionResponse:
"""Get response from LLM API"""
try:
response = create(
@ -442,7 +408,9 @@ class Agent(object):
except Exception as e:
raise e
def _handle_ai_response(self, response_message: Message, override_tool_call_id: bool = True) -> Tuple[List[Message], bool, bool]:
def _handle_ai_response(
self, response_message: chat_completion_response.Message, override_tool_call_id: bool = True
) -> Tuple[List[Message], bool, bool]:
"""Handles parsing and function execution"""
messages = [] # append these to the history when done
@ -647,7 +615,6 @@ class Agent(object):
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> Tuple[List[Union[dict, Message]], bool, bool, bool]:
"""Top-level event message handler for the MemGPT agent"""
@ -677,20 +644,7 @@ class Agent(object):
raise e
try:
# Step 0: update core memory
# only pulling latest block data if shared memory is being used
# TODO: ensure we're passing in metadata store from all surfaces
if ms is not None:
should_update = False
for block in self.agent_state.memory.to_dict().values():
if not block.get("template", False):
should_update = True
if should_update:
# TODO: the force=True can be optimized away
# once we ensure we're correctly comparing whether in-memory core
# data is different than persisted core data.
self.rebuild_memory(force=True, ms=ms)
# Step 1: add user message
# Step 0: add user message
if user_message is not None:
if isinstance(user_message, Message):
# Validate JSON via save/load
@ -736,7 +690,7 @@ class Agent(object):
if len(input_message_sequence) > 1 and input_message_sequence[-1].role != "user":
printd(f"{CLI_WARNING_PREFIX}Attempting to run ChatCompletion without user as the last message in the queue")
# Step 2: send the conversation and available functions to GPT
# Step 1: send the conversation and available functions to GPT
if not skip_verify and (first_message or self.messages_total == self.messages_total_init):
printd(f"This is the first message. Running extra verifier on AI response.")
counter = 0
@ -761,9 +715,9 @@ class Agent(object):
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
)
# Step 3: check if LLM wanted to call a function
# (if yes) Step 4: call the function
# (if yes) Step 5: send the info on the function call and function response to LLM
# Step 2: check if LLM wanted to call a function
# (if yes) Step 3: call the function
# (if yes) Step 4: send the info on the function call and function response to LLM
response_message = response.choices[0].message
response_message.model_copy() # TODO why are we copying here?
all_response_messages, heartbeat_request, function_failed = self._handle_ai_response(response_message)
@ -779,7 +733,7 @@ class Agent(object):
# "functions": self.functions,
# }
# Step 6: extend the message history
# Step 4: extend the message history
if user_message is not None:
if isinstance(user_message, Message):
all_new_messages = [user_message] + all_response_messages
@ -839,7 +793,6 @@ class Agent(object):
stream=stream,
timestamp=timestamp,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
ms=ms,
)
else:
@ -988,8 +941,8 @@ class Agent(object):
new_messages = [new_system_message_obj] + self._messages[1:] # swap index 0 (system)
self._messages = new_messages
def rebuild_memory(self, force=False, update_timestamp=True, ms: Optional[MetadataStore] = None):
"""Rebuilds the system message with the latest memory object and any shared memory block updates"""
def rebuild_memory(self, force=False, update_timestamp=True):
"""Rebuilds the system message with the latest memory object"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
# NOTE: This is a hacky way to check if the memory has changed
@ -998,28 +951,6 @@ class Agent(object):
printd(f"Memory has not changed, not rebuilding system")
return
if ms:
for block in self.memory.to_dict().values():
if block.get("templates", False):
# we don't expect to update shared memory blocks that
# are templates. this is something we could update in the
# future if we expect templates to change often.
continue
block_id = block.get("id")
db_block = ms.get_block(block_id=block_id)
if db_block is None:
# this case covers if someone has deleted a shared block by interacting
# with some other agent.
# in that case we should remove this shared block from the agent currently being
# evaluated.
printd(f"removing block: {block_id=}")
continue
if not isinstance(db_block.value, str):
printd(f"skipping block update, unexpected value: {block_id=}")
continue
# TODO: we may want to update which columns we're updating from shared memory e.g. the limit
self.memory.update_block_value(name=block.get("name", ""), value=db_block.value)
# If the memory didn't update, we probably don't want to update the timestamp inside
# For example, if we're doing a system prompt swap, this should probably be False
if update_timestamp:
@ -1117,14 +1048,25 @@ class Agent(object):
# return msg
def update_state(self) -> AgentState:
message_ids = [msg.id for msg in self._messages]
assert isinstance(self.memory, Memory), f"Memory is not a Memory object: {type(self.memory)}"
# override any fields that may have been updated
self.agent_state.message_ids = message_ids
self.agent_state.memory = self.memory
self.agent_state.system = self.system
memory = {
"system": self.system,
"memory": self.memory.to_dict(),
"messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids
}
self.agent_state = AgentState(
name=self.agent_state.name,
user_id=self.agent_state.user_id,
tools=self.agent_state.tools,
system=self.system,
## "model_state"
llm_config=self.agent_state.llm_config,
embedding_config=self.agent_state.embedding_config,
id=self.agent_state.id,
created_at=self.agent_state.created_at,
## "agent_state"
state=memory,
_metadata=self.agent_state._metadata,
)
return self.agent_state
def migrate_embedding(self, embedding_config: EmbeddingConfig):
@ -1134,12 +1076,13 @@ class Agent(object):
# TODO: recall memory
raise NotImplementedError()
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
filters = {"user_id": self.agent_state.user_id, "data_source": source_name}
size = source_connector.size(filters)
# typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
page_size = 100
generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
all_passages = []
@ -1152,8 +1095,7 @@ class Agent(object):
passage.agent_id = self.agent_state.id
# regenerate passage ID (avoid duplicates)
# TODO: need to find another solution to the text duplication issue
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}")
# insert into agent archival memory
self.persistence_manager.archival_memory.storage.insert_many(passages)
@ -1165,14 +1107,15 @@ class Agent(object):
self.persistence_manager.archival_memory.storage.save()
# attach to agent
source = ms.get_source(source_id=source_id)
assert source is not None, f"Source {source_id} not found in metadata store"
source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id)
assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}"
source_id = source.id
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
total_agent_passages = self.persistence_manager.archival_memory.storage.size()
printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
)
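
For reference: create_uuid_from_string (imported near the top of this file's diff) is what makes passage IDs deterministic here, so re-ingesting the same text does not create duplicate rows. Its implementation is not shown in this diff; a hypothetical uuid5-based stand-in illustrates the idea:

# Hypothetical stand-in for create_uuid_from_string; the real helper may hash
# differently, but the property that matters is determinism.
import uuid

def create_uuid_from_string(val: str) -> uuid.UUID:
    return uuid.uuid5(uuid.NAMESPACE_DNS, val)

a = create_uuid_from_string("my_source_agent-1_some passage text")
b = create_uuid_from_string("my_source_agent-1_some passage text")
assert a == b  # same (source, agent, text) triple -> same passage ID
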
@ -1181,36 +1124,8 @@ def save_agent(agent: Agent, ms: MetadataStore):
agent.update_state()
agent_state = agent.agent_state
agent_id = agent_state.id
assert isinstance(agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
# NOTE: we're saving agent memory before persisting the agent to ensure
# that allocated block_ids for each memory block are present in the agent model
save_agent_memory(agent=agent, ms=ms)
if ms.get_agent(agent_id=agent.agent_state.id):
if ms.get_agent(agent_name=agent_state.name, user_id=agent_state.user_id):
ms.update_agent(agent_state)
else:
ms.create_agent(agent_state)
agent.agent_state = ms.get_agent(agent_id=agent_id)
assert isinstance(agent.agent_state.memory, Memory), f"Memory is not a Memory object: {type(agent_state.memory)}"
def save_agent_memory(agent: Agent, ms: MetadataStore):
"""
Save agent memory to metadata store. Memory is a collection of blocks and each block is persisted to the block table.
NOTE: we are assuming agent.update_state has already been called.
"""
for block_dict in agent.memory.to_dict().values():
# TODO: block creation should happen in one place to enforce these sort of constraints consistently.
if block_dict.get("user_id", None) is None:
block_dict["user_id"] = agent.agent_state.user_id
block = Block(**block_dict)
# FIXME: should we expect for block values to be None? If not, we need to figure out why that is
# the case in some tests, if so we should relax the DB constraint.
if block.value is None:
block.value = ""
ms.update_or_create_block(block)
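
For reference, a minimal self-contained sketch of the exec-based tool linking that Agent.__init__ performs above. SimpleTool and link_tool_functions are hypothetical stand-ins for ToolModel and the agent internals; the pattern assumes the function defined in tool.source_code has the same name as tool.name:

# Minimal sketch of the exec-based tool binding shown in Agent.__init__ above.
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class SimpleTool:
    name: str
    source_code: str
    module: Optional[str] = None

def link_tool_functions(tools: List[SimpleTool]) -> Dict[str, object]:
    """Exec each tool's source into a shared env and bind the callable by name."""
    env = {}
    env.update(globals())
    functions_python = {}
    for tool in tools:
        # execute the whole module if present, otherwise just the function source
        exec(tool.module if tool.module else tool.source_code, env)
        functions_python[tool.name] = env[tool.name]
    assert all(callable(f) for f in functions_python.values())
    return functions_python

# usage: the function name must match tool.name, as in the diff above
tools = [SimpleTool(name="echo", source_code="def echo(msg: str) -> str:\n    return msg")]
funcs = link_tool_functions(tools)
print(funcs["echo"]("hello"))  # -> hello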

View File

@ -1,12 +1,12 @@
from typing import Dict, List, Optional, Tuple, cast
import uuid
from typing import Dict, Iterator, List, Optional, Tuple, cast
import chromadb
from chromadb.api.types import Include
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.passage import Passage
from memgpt.data_types import Passage, Record, RecordType
from memgpt.utils import datetime_to_timestamp, printd, timestamp_to_datetime
@ -34,6 +34,9 @@ class ChromaStorageConnector(StorageConnector):
self.collection = self.client.get_or_create_collection(self.table_name)
self.include: Include = ["documents", "embeddings", "metadatas"]
# need to be converted to strings
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
# get all filters for query
if filters is not None:
@ -51,7 +54,10 @@ class ChromaStorageConnector(StorageConnector):
continue
# filter by other keys
chroma_filters.append({key: {"$eq": value}})
if key in self.uuid_fields:
chroma_filters.append({key: {"$eq": str(value)}})
else:
chroma_filters.append({key: {"$eq": value}})
if len(chroma_filters) > 1:
chroma_filters = {"$and": chroma_filters}
@ -61,7 +67,7 @@ class ChromaStorageConnector(StorageConnector):
chroma_filters = chroma_filters[0]
return ids, chroma_filters
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0):
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000, offset: int = 0) -> Iterator[List[RecordType]]:
ids, filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
@ -78,50 +84,29 @@ class ChromaStorageConnector(StorageConnector):
# Increment the offset to get the next chunk in the next iteration
offset += page_size
def results_to_records(self, results):
def results_to_records(self, results) -> List[RecordType]:
# convert timestamps to datetime
for metadata in results["metadatas"]:
if "created_at" in metadata:
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
for key, value in metadata.items():
if key in self.uuid_fields:
metadata[key] = uuid.UUID(value)
if results["embeddings"]: # may not be returned, depending on table type
passages = []
for text, record_id, embedding, metadata in zip(
results["documents"], results["ids"], results["embeddings"], results["metadatas"]
):
args = {}
for field in EmbeddingConfig.__fields__.keys():
if field in metadata:
args[field] = metadata[field]
del metadata[field]
embedding_config = EmbeddingConfig(**args)
passages.append(Passage(text=text, embedding=embedding, id=record_id, embedding_config=embedding_config, **metadata))
# return [
# Passage(text=text, embedding=embedding, id=record_id, embedding_config=EmbeddingConfig(), **metadatas)
# for (text, record_id, embedding, metadatas) in zip(
# results["documents"], results["ids"], results["embeddings"], results["metadatas"]
# )
# ]
return passages
return [
cast(RecordType, self.type(text=text, embedding=embedding, id=uuid.UUID(record_id), **metadatas)) # type: ignore
for (text, record_id, embedding, metadatas) in zip(
results["documents"], results["ids"], results["embeddings"], results["metadatas"]
)
]
else:
# no embeddings
passages = []
for text, id, metadata in zip(results["documents"], results["ids"], results["metadatas"]):
args = {}
for field in EmbeddingConfig.__fields__.keys():
if field in metadata:
args[field] = metadata[field]
del metadata[field]
embedding_config = EmbeddingConfig(**args)
passages.append(Passage(text=text, embedding=None, id=id, embedding_config=embedding_config, **metadata))
return passages
return [
cast(RecordType, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
]
# return [
# #cast(Passage, self.type(text=text, id=uuid.UUID(id), **metadatas)) # type: ignore
# Passage(text=text, embedding=None, id=id, **metadatas)
# for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
# ]
def get_all(self, filters: Optional[Dict] = {}, limit=None):
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
ids, filters = self.get_filters(filters)
if self.collection.count() == 0:
return []
@ -131,13 +116,13 @@ class ChromaStorageConnector(StorageConnector):
results = self.collection.get(ids=ids, include=self.include, where=filters)
return self.results_to_records(results)
def get(self, id):
def get(self, id: uuid.UUID) -> Optional[RecordType]:
results = self.collection.get(ids=[str(id)])
if len(results["ids"]) == 0:
return None
return self.results_to_records(results)[0]
def format_records(self, records):
def format_records(self, records: List[RecordType]):
assert all([isinstance(r, Passage) for r in records])
recs = []
@ -160,13 +145,10 @@ class ChromaStorageConnector(StorageConnector):
# collect/format record metadata
metadatas = []
for record in recs:
embedding_config = vars(record.embedding_config)
metadata = vars(record)
metadata.pop("id")
metadata.pop("text")
metadata.pop("embedding")
metadata.pop("embedding_config")
metadata.pop("metadata_")
if "created_at" in metadata:
metadata["created_at"] = datetime_to_timestamp(metadata["created_at"])
if "metadata_" in metadata and metadata["metadata_"] is not None:
@ -174,22 +156,23 @@ class ChromaStorageConnector(StorageConnector):
metadata.pop("metadata_")
else:
record_metadata = {}
metadata = {**metadata, **record_metadata} # merge with metadata
metadata = {**metadata, **embedding_config} # merge with embedding config
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
metadata = {**metadata, **record_metadata} # merge with metadata
# convert uuids to strings
for key, value in metadata.items():
if key in self.uuid_fields:
metadata[key] = str(value)
metadatas.append(metadata)
return ids, documents, embeddings, metadatas
def insert(self, record):
def insert(self, record: Record):
ids, documents, embeddings, metadatas = self.format_records([record])
if any([e is None for e in embeddings]):
raise ValueError("Embeddings must be provided to chroma")
self.collection.upsert(documents=documents, embeddings=[e for e in embeddings if e is not None], ids=ids, metadatas=metadatas)
def insert_many(self, records, show_progress=False):
def insert_many(self, records: List[RecordType], show_progress=False):
ids, documents, embeddings, metadatas = self.format_records(records)
if any([e is None for e in embeddings]):
raise ValueError("Embeddings must be provided to chroma")
@ -215,7 +198,7 @@ class ChromaStorageConnector(StorageConnector):
def list_data_sources(self):
raise NotImplementedError
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
ids, filters = self.get_filters(filters)
results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)
@ -256,40 +239,10 @@ class ChromaStorageConnector(StorageConnector):
def get_all_cursor(
self,
filters: Optional[Dict] = {},
after: str = None,
before: str = None,
after: uuid.UUID = None,
before: uuid.UUID = None,
limit: Optional[int] = 1000,
order_by: str = "created_at",
reverse: bool = False,
):
records = self.get_all(filters=filters)
# WARNING: very hacky and slow implementation
def get_index(id, record_list):
for i in range(len(record_list)):
if record_list[i].id == id:
return i
assert False, f"Could not find id {id} in record list"
# sort by custom field
records = sorted(records, key=lambda x: getattr(x, order_by), reverse=reverse)
if after:
index = get_index(after, records)
if index + 1 >= len(records):
return None, []
records = records[index + 1 :]
if before:
index = get_index(before, records)
if index == 0:
return None, []
# TODO: not sure if this is correct
records = records[:index]
if len(records) == 0:
return None, []
# enforce limit
if limit:
records = records[:limit]
return records[-1].id, records
raise ValueError("Cannot run get_all_cursor with chroma")

View File

@ -1,11 +1,14 @@
import base64
import os
import uuid
from datetime import datetime
from typing import Dict, List, Optional
from typing import Dict, Iterator, List, Optional
import numpy as np
from sqlalchemy import (
BIGINT,
BINARY,
CHAR,
JSON,
Column,
DateTime,
@ -20,6 +23,7 @@ from sqlalchemy import (
select,
text,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, mapped_column, sessionmaker
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
@ -29,15 +33,34 @@ from tqdm import tqdm
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.metadata import EmbeddingConfigColumn
# from memgpt.schemas.message import Message, Passage, Record, RecordType, ToolCall
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ToolCall, ToolCallFunction
from memgpt.schemas.passage import Passage
from memgpt.data_types import Message, Passage, Record, RecordType, ToolCall
from memgpt.settings import settings
# Custom UUID type
class CommonUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID(as_uuid=True))
else:
return dialect.type_descriptor(CHAR())
def process_bind_param(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return str(value) # Convert UUID to string for SQLite
def process_result_value(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return uuid.UUID(value)
class CommonVector(TypeDecorator):
"""Common type for representing vectors in SQLite"""
@ -70,6 +93,26 @@ class CommonVector(TypeDecorator):
# Custom serialization / de-serialization for JSON columns
class ToolCallColumn(TypeDecorator):
"""Custom type for storing List[ToolCall] as JSON"""
impl = JSON
cache_ok = True
def load_dialect_impl(self, dialect):
return dialect.type_descriptor(JSON())
def process_bind_param(self, value, dialect):
if value:
return [vars(v) for v in value]
return value
def process_result_value(self, value, dialect):
if value:
return [ToolCall(**v) for v in value]
return value
Base = declarative_base()
@ -77,8 +120,8 @@ def get_db_model(
config: MemGPTConfig,
table_name: str,
table_type: TableType,
user_id: str,
agent_id: Optional[str] = None,
user_id: uuid.UUID,
agent_id: Optional[uuid.UUID] = None,
dialect="postgresql",
):
# Define a helper function to create or get the model class
@ -97,12 +140,14 @@ def get_db_model(
__abstract__ = True # this line is necessary
# Assuming passage_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(CommonUUID, nullable=False)
text = Column(String)
doc_id = Column(String)
agent_id = Column(String)
source_id = Column(String)
doc_id = Column(CommonUUID)
agent_id = Column(CommonUUID)
data_source = Column(String) # agent_name if agent, data_source name if from data source
# vector storage
if dialect == "sqlite":
@ -111,8 +156,9 @@ def get_db_model(
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
embedding_config = Column(EmbeddingConfigColumn)
metadata_ = Column(MutableJson)
# Add a datetime column, with default value as the current time
@ -127,11 +173,12 @@ def get_db_model(
return Passage(
text=self.text,
embedding=self.embedding,
embedding_config=self.embedding_config,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
doc_id=self.doc_id,
user_id=self.user_id,
id=self.id,
source_id=self.source_id,
data_source=self.data_source,
agent_id=self.agent_id,
metadata_=self.metadata_,
created_at=self.created_at,
@ -149,9 +196,11 @@ def get_db_model(
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(CommonUUID, nullable=False)
agent_id = Column(CommonUUID, nullable=False)
# openai info
role = Column(String, nullable=False)
@ -163,29 +212,31 @@ def get_db_model(
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
# TODO align with OpenAI spec of multiple tool calls
# tool_calls = Column(ToolCallColumn)
tool_calls = Column(JSON)
tool_calls = Column(ToolCallColumn)
# tool call response info
# if role == "tool", then this must be specified
# if role != "tool", this must be null
tool_call_id = Column(String)
# vector storage
if dialect == "sqlite":
embedding = Column(CommonVector)
else:
from pgvector.sqlalchemy import Vector
embedding = mapped_column(Vector(MAX_EMBEDDING_DIM))
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
# Add a datetime column, with default value as the current time
created_at = Column(DateTime(timezone=True))
Index("message_idx_user", user_id, agent_id),
def __repr__(self):
return f"<Message(message_id='{self.id}', text='{self.text}')>"
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
def to_record(self):
calls = (
[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls]
if self.tool_calls
else None
)
if calls:
assert isinstance(calls[0], ToolCall)
return Message(
user_id=self.user_id,
agent_id=self.agent_id,
@ -193,9 +244,11 @@ def get_db_model(
name=self.name,
text=self.text,
model=self.model,
# tool_calls=[ToolCall(id=tool_call["id"], function=ToolCallFunction(**tool_call["function"])) for tool_call in self.tool_calls] if self.tool_calls else None,
tool_calls=self.tool_calls,
tool_call_id=self.tool_call_id,
embedding=self.embedding,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
created_at=self.created_at,
id=self.id,
)
@ -221,7 +274,7 @@ class SQLStorageConnector(StorageConnector):
all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
return all_filters
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0) -> Iterator[List[RecordType]]:
filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
@ -241,8 +294,8 @@ class SQLStorageConnector(StorageConnector):
def get_all_cursor(
self,
filters: Optional[Dict] = {},
after: str = None,
before: str = None,
after: uuid.UUID = None,
before: uuid.UUID = None,
limit: Optional[int] = 1000,
order_by: str = "created_at",
reverse: bool = False,
@ -279,12 +332,12 @@ class SQLStorageConnector(StorageConnector):
return (None, [])
records = [record.to_record() for record in db_record_chunk]
next_cursor = db_record_chunk[-1].id
assert isinstance(next_cursor, str)
assert isinstance(next_cursor, uuid.UUID)
# return (cursor, list[records])
return (next_cursor, records)
def get_all(self, filters: Optional[Dict] = {}, limit=None):
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
filters = self.get_filters(filters)
with self.session_maker() as session:
if limit:
@ -293,7 +346,7 @@ class SQLStorageConnector(StorageConnector):
db_records = session.query(self.db_model).filter(*filters).all()
return [record.to_record() for record in db_records]
def get(self, id: str):
def get(self, id: uuid.UUID) -> Optional[Record]:
with self.session_maker() as session:
db_record = session.get(self.db_model, id)
if db_record is None:
@ -306,13 +359,13 @@ class SQLStorageConnector(StorageConnector):
with self.session_maker() as session:
return session.query(self.db_model).filter(*filters).count()
def insert(self, record):
def insert(self, record: Record):
raise NotImplementedError
def insert_many(self, records, show_progress=False):
def insert_many(self, records: List[RecordType], show_progress=False):
raise NotImplementedError
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
def save(self):
@ -417,7 +470,7 @@ class PostgresStorageConnector(SQLStorageConnector):
# create table
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
filters = self.get_filters(filters)
with self.session_maker() as session:
results = session.scalars(
@ -428,7 +481,7 @@ class PostgresStorageConnector(SQLStorageConnector):
records = [result.to_record() for result in results]
return records
def insert_many(self, records, exists_ok=True, show_progress=False):
def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
from sqlalchemy.dialects.postgresql import insert
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
@ -450,15 +503,14 @@ class PostgresStorageConnector(SQLStorageConnector):
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
def insert(self, record, exists_ok=True):
def insert(self, record: Record, exists_ok=True):
self.insert_many([record], exists_ok=exists_ok)
def update(self, record):
def update(self, record: RecordType):
"""
Updates a record in the database based on the provided Record object.
"""
@ -523,12 +575,12 @@ class SQLLiteStorageConnector(SQLStorageConnector):
Base.metadata.create_all(self.engine, tables=[self.db_model.__table__]) # Create the table if it doesn't exist
self.session_maker = sessionmaker(bind=self.engine)
# import sqlite3
import sqlite3
# sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
# sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
def insert_many(self, records, exists_ok=True, show_progress=False):
def insert_many(self, records: List[RecordType], exists_ok=True, show_progress=False):
from sqlalchemy.dialects.sqlite import insert
# TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
@ -550,15 +602,14 @@ class SQLLiteStorageConnector(SQLStorageConnector):
with self.session_maker() as session:
iterable = tqdm(records) if show_progress else records
for record in iterable:
# db_record = self.db_model(**vars(record))
db_record = self.db_model(**record.dict())
db_record = self.db_model(**vars(record))
session.add(db_record)
session.commit()
def insert(self, record, exists_ok=True):
def insert(self, record: Record, exists_ok=True):
self.insert_many([record], exists_ok=exists_ok)
def update(self, record):
def update(self, record: Record):
"""
Updates an existing record in the database with values from the provided record object.
"""

View File

@ -8,7 +8,7 @@ from lancedb.pydantic import LanceModel, Vector
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.schemas.message import Message, Passage, Record
from memgpt.data_types import Message, Passage, Record
""" Initial implementation - not complete """

View File

@ -5,14 +5,10 @@ We originally tried to use Llama Index VectorIndex, but their limited API was ex
import uuid
from abc import abstractmethod
from typing import Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel
from typing import Dict, Iterator, List, Optional, Tuple, Type, Union
from memgpt.config import MemGPTConfig
from memgpt.schemas.document import Document
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.data_types import Document, Message, Passage, Record, RecordType
from memgpt.utils import printd
@ -39,7 +35,7 @@ DOCUMENT_TABLE_NAME = "memgpt_documents" # original documents (from source)
class StorageConnector:
"""Defines a DB connection that is user-specific to access data: Documents, Passages, Archival/Recall Memory"""
type: Type[BaseModel]
type: Type[Record]
def __init__(
self,
@ -140,15 +136,15 @@ class StorageConnector:
pass
@abstractmethod
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000):
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
pass
@abstractmethod
def get_all(self, filters: Optional[Dict] = {}, limit=10):
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[RecordType]:
pass
@abstractmethod
def get(self, id: uuid.UUID):
def get(self, id: uuid.UUID) -> Optional[RecordType]:
pass
@abstractmethod
@ -156,15 +152,15 @@ class StorageConnector:
pass
@abstractmethod
def insert(self, record):
def insert(self, record: RecordType):
pass
@abstractmethod
def insert_many(self, records, show_progress=False):
def insert_many(self, records: List[RecordType], show_progress=False):
pass
@abstractmethod
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
pass
@abstractmethod

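The connectors above implement get_all_paginated as an offset-based generator that yields one page of records per iteration. A minimal sketch of the pattern over an in-memory list, standing in for a real storage backend:

# Offset-pagination generator in the style of get_all_paginated above.
from typing import Iterator, List, TypeVar

T = TypeVar("T")

def get_all_paginated(records: List[T], page_size: int = 1000) -> Iterator[List[T]]:
    offset = 0
    while True:
        chunk = records[offset : offset + page_size]  # one page per iteration
        if not chunk:
            break
        yield chunk
        offset += page_size  # advance to the next page

# usage: callers iterate page by page, as Agent.attach_source does in the diff
for page in get_all_paginated(list(range(10)), page_size=4):
    print(len(page))  # 4, 4, 2
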
View File

@ -5,7 +5,7 @@ from typing import Optional
from colorama import Fore, Style, init
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.schemas.message import Message
from memgpt.data_types import Message
init(autoreset=True)

View File

@ -3,6 +3,7 @@ import logging
import os
import subprocess
import sys
import uuid
from enum import Enum
from pathlib import Path
from typing import Annotated, Optional
@ -18,12 +19,12 @@ from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig
from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig, User
from memgpt.log import get_logger
from memgpt.memory import ChatMemory
from memgpt.metadata import MetadataStore
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import ChatMemory, Memory
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.models.pydantic_models import OptionState
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger
@ -36,6 +37,14 @@ from memgpt.utils import open_folder_in_explorer, printd
logger = get_logger(__name__)
def migrate(
debug: Annotated[bool, typer.Option(help="Print extra tracebacks for failed migrations")] = False,
):
"""Migrate old agents (pre 0.2.12) to the new database system"""
migrate_all_agents(debug=debug)
migrate_all_sources(debug=debug)
class QuickstartChoice(Enum):
openai = "openai"
# azure = "azure"
@ -171,10 +180,13 @@ def quickstart(
else:
# Load the file from the relative path
script_dir = os.path.dirname(__file__) # Get the directory where the script is located
# print("SCRIPT", script_dir)
backup_config_path = os.path.join(script_dir, "..", "configs", "memgpt_hosted.json")
# print("FILE PATH", backup_config_path)
try:
with open(backup_config_path, "r", encoding="utf-8") as file:
backup_config = json.load(file)
# print(backup_config)
printd("Loaded config file successfully.")
new_config, config_was_modified = set_config_with_dict(backup_config)
except FileNotFoundError:
@ -201,6 +213,7 @@ def quickstart(
# Parse the response content as JSON
config = response.json()
# Output a success message and the first few items in the dictionary as a sample
print("JSON config file downloaded successfully.")
new_config, config_was_modified = set_config_with_dict(config)
else:
typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED)
@ -279,6 +292,21 @@ class ServerChoice(Enum):
ws_api = "websocket"
def create_default_user_or_exit(config: MemGPTConfig, ms: MetadataStore):
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
else:
return user
else:
return user
def server(
type: Annotated[ServerChoice, typer.Option(help="Server to run")] = "rest",
port: Annotated[Optional[int], typer.Option(help="Port to run the server on")] = None,
@ -295,8 +323,8 @@ def server(
if MemGPTConfig.exists():
config = MemGPTConfig.load()
MetadataStore(config)
client = create_client() # triggers user creation
ms = MetadataStore(config)
create_default_user_or_exit(config, ms)
else:
typer.secho(f"No configuration exists. Run memgpt configure before starting the server.", fg=typer.colors.RED)
sys.exit(1)
@ -416,42 +444,42 @@ def run(
logger.setLevel(logging.CRITICAL)
server_logger.setLevel(logging.CRITICAL)
# from memgpt.migrate import (
# VERSION_CUTOFF,
# config_is_compatible,
# wipe_config_and_reconfigure,
# )
from memgpt.migrate import (
VERSION_CUTOFF,
config_is_compatible,
wipe_config_and_reconfigure,
)
# if not config_is_compatible(allow_empty=True):
# typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
# choices = [
# "Run the full config setup (recommended)",
# "Create a new config using defaults",
# "Cancel",
# ]
# selection = questionary.select(
# f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
# choices=choices,
# default=choices[0],
# ).ask()
# if selection == choices[0]:
# try:
# wipe_config_and_reconfigure()
# except Exception as e:
# typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
# raise
# elif selection == choices[1]:
# try:
# # Don't create a config, so that the next block of code asking about quickstart is run
# wipe_config_and_reconfigure(run_configure=False, create_config=False)
# except Exception as e:
# typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
# raise
# else:
# typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
# raise KeyboardInterrupt()
if not config_is_compatible(allow_empty=True):
typer.secho(f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}\n", fg=typer.colors.RED)
choices = [
"Run the full config setup (recommended)",
"Create a new config using defaults",
"Cancel",
]
selection = questionary.select(
f"To use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}), or regenerate your config. Would you like to proceed?",
choices=choices,
default=choices[0],
).ask()
if selection == choices[0]:
try:
wipe_config_and_reconfigure()
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
elif selection == choices[1]:
try:
# Don't create a config, so that the next block of code asking about quickstart is run
wipe_config_and_reconfigure(run_configure=False, create_config=False)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
else:
typer.secho("MemGPT config regeneration cancelled", fg=typer.colors.RED)
raise KeyboardInterrupt()
# typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
typer.secho("Note: if you would like to migrate old agents to the new release, please run `memgpt migrate`!", fg=typer.colors.GREEN)
if not MemGPTConfig.exists():
# if no config, ask about quickstart
@ -496,12 +524,11 @@ def run(
# read user id from config
ms = MetadataStore(config)
client = create_client()
client.user_id
user = create_default_user_or_exit(config, ms)
# determine agent to use, if not provided
if not yes and not agent:
agents = client.list_agents()
agents = ms.list_agents(user_id=user.id)
agents = [a.name for a in agents]
if len(agents) > 0:
@ -513,11 +540,7 @@ def run(
agent = questionary.select("Select agent:", choices=agents).ask()
# create agent config
if agent:
agent_id = client.get_agent_id(agent)
agent_state = client.get_agent(agent_id)
else:
agent_state = None
agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None
human = human if human else config.human
persona = persona if persona else config.persona
if agent and agent_state: # use existing agent
@ -574,12 +597,13 @@ def run(
# agent_state.state["system"] = system
# Update the agent with any overrides
agent_state = client.update_agent(
agent_id=agent_state.id,
name=agent_state.name,
llm_config=agent_state.llm_config,
embedding_config=agent_state.embedding_config,
)
ms.update_agent(agent_state)
tools = []
for tool_name in agent_state.tools:
tool = ms.get_tool(tool_name, agent_state.user_id)
if tool is None:
typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED)
tools.append(tool)
# create agent
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
@ -622,52 +646,55 @@ def run(
llm_config.model_endpoint_type = model_endpoint_type
# create agent
client = create_client()
human_obj = client.get_human(client.get_human_id(name=human))
persona_obj = client.get_persona(client.get_persona_id(name=persona))
if human_obj is None:
typer.secho(f"Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
try:
client = create_client()
human_obj = ms.get_human(human, user.id)
persona_obj = ms.get_persona(persona, user.id)
# TODO pull system prompts from the metadata store
# NOTE: will be overriden later to a default
if system_file:
try:
with open(system_file, "r", encoding="utf-8") as file:
system = file.read().strip()
printd("Loaded system file successfully.")
except FileNotFoundError:
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
system_prompt = system if system else None
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
memory = ChatMemory(human=human_obj.text, persona=persona_obj.text, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(
name=agent_name,
system_prompt=system_prompt,
embedding_config=embedding_config,
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
memgpt_agent = Agent(
interface=interface(),
agent_state=agent_state,
tools=tools,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)
save_agent(agent=memgpt_agent, ms=ms)
except ValueError as e:
typer.secho(f"Failed to create agent from provided information:\n{e}", fg=typer.colors.RED)
sys.exit(1)
if persona_obj is None:
typer.secho(f"Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
sys.exit(1)
if system_file:
try:
with open(system_file, "r", encoding="utf-8") as file:
system = file.read().strip()
printd("Loaded system file successfully.")
except FileNotFoundError:
typer.secho(f"System file not found at {system_file}", fg=typer.colors.RED)
system_prompt = system if system else None
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(
name=agent_name,
system=system_prompt,
embedding_config=embedding_config,
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
assert isinstance(agent_state.memory, Memory), f"Expected Memory, got {type(agent_state.memory)}"
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
memgpt_agent = Agent(
interface=interface(),
agent_state=agent_state,
tools=tools,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)
save_agent(agent=memgpt_agent, ms=ms)
typer.secho(f"🎉 Created new agent '{memgpt_agent.agent_state.name}' (id={memgpt_agent.agent_state.id})", fg=typer.colors.GREEN)
# start event loop
@ -692,10 +719,19 @@ def delete_agent(
"""Delete an agent from the database"""
# use client ID if no user_id provided
config = MemGPTConfig.load()
MetadataStore(config)
client = create_client(user_id=user_id)
agent = client.get_agent_by_name(agent_name)
if not agent:
ms = MetadataStore(config)
if user_id is None:
user = create_default_user_or_exit(config, ms)
else:
user = ms.get_user(user_id=uuid.UUID(user_id))
try:
agent = ms.get_agent(agent_name=agent_name, user_id=user.id)
except Exception as e:
typer.secho(f"Failed to get agent {agent_name}\n{e}", fg=typer.colors.RED)
sys.exit(1)
if agent is None:
typer.secho(f"Couldn't find agent named '{agent_name}' to delete", fg=typer.colors.RED)
sys.exit(1)
@ -707,8 +743,7 @@ def delete_agent(
return
try:
# delete the agent
client.delete_agent(agent.id)
ms.delete_agent(agent_id=agent.id)
typer.secho(f"🕊️ Successfully deleted agent '{agent_name}' (id={agent.id})", fg=typer.colors.GREEN)
except Exception:
typer.secho(f"Failed to delete agent '{agent_name}' (id={agent.id})", fg=typer.colors.RED)

View File

@ -1,8 +1,8 @@
import ast
import builtins
import os
import uuid
from enum import Enum
from typing import Annotated, List, Optional
from typing import Annotated, Optional
import questionary
import typer
@ -13,6 +13,7 @@ from memgpt import utils
from memgpt.config import MemGPTConfig
from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR
from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials
from memgpt.data_types import EmbeddingConfig, LLMConfig, Source, User
from memgpt.llm_api.anthropic import (
anthropic_get_model_list,
antropic_get_model_context_window,
@ -35,8 +36,7 @@ from memgpt.local_llm.constants import (
DEFAULT_WRAPPER_NAME,
)
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.metadata import MetadataStore
from memgpt.server.utils import shorten_key_middle
app = typer.Typer()
@ -1070,10 +1070,17 @@ def configure():
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
config.save()
from memgpt import create_client
client = create_client()
print("User ID:", client.user_id)
# create user records
ms = MetadataStore(config)
user_id = uuid.UUID(config.anon_clientid)
user = User(
id=uuid.UUID(config.anon_clientid),
)
if ms.get_user(user_id):
# update user
ms.update_user(user)
else:
ms.create_user(user)
class ListChoice(str, Enum):
@ -1087,14 +1094,17 @@ class ListChoice(str, Enum):
def list(arg: Annotated[ListChoice, typer.Argument]):
from memgpt.client.client import create_client
client = create_client()
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
table = ColorTable(theme=Themes.OCEAN)
if arg == ListChoice.agents:
"""List all agents"""
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
for agent in tqdm(client.list_agents()):
# TODO: add this function
sources = client.list_attached_sources(agent_id=agent.id)
source_ids = client.list_attached_sources(agent_id=agent.id)
assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids])
sources = [client.get_source(source_id=source_id) for source_id in source_ids]
assert all(source is not None and isinstance(source, Source) for source in sources)
source_names = [source.name for source in sources if source is not None]
table.add_row(
[
@ -1102,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
agent.llm_config.model,
agent.embedding_config.embedding_model,
agent.embedding_config.embedding_dim,
agent.memory.get_block("persona").value[:100] + "...",
agent.memory.get_block("human").value[:100] + "...",
agent._metadata.get("persona", ""),
agent._metadata.get("human", ""),
",".join(source_names),
utils.format_datetime(agent.created_at),
]
@ -1113,13 +1123,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all humans"""
table.field_names = ["Name", "Text"]
for human in client.list_humans():
table.add_row([human.name, human.value.replace("\n", "")[:100]])
table.add_row([human.name, human.text.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.personas:
"""List all personas"""
table.field_names = ["Name", "Text"]
for persona in client.list_personas():
table.add_row([persona.name, persona.value.replace("\n", "")[:100]])
table.add_row([persona.name, persona.text.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""
@ -1149,63 +1159,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
return table
@app.command()
def add_tool(
filename: str = typer.Option(..., help="Path to the Python file containing the function"),
name: Optional[str] = typer.Option(None, help="Name of the tool"),
update: bool = typer.Option(True, help="Update the tool if it already exists"),
tags: Optional[List[str]] = typer.Option(None, help="Tags for the tool"),
):
"""Add or update a tool from a Python file."""
from memgpt.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
# 1. Parse the Python file
with open(filename, "r", encoding="utf-8") as file:
source_code = file.read()
# 2. Parse the source code to extract the function
# Note: we assume the file contains exactly one function.
module = ast.parse(source_code)
func_def = None
for node in module.body:
if isinstance(node, ast.FunctionDef):
func_def = node
break
if not func_def:
raise ValueError("No function found in the provided file")
# 3. Compile the function to make it callable
# Explanation courtesy of GPT-4:
# Compile the AST (Abstract Syntax Tree) node representing the function definition into a code object
# ast.Module creates a module node containing the function definition (func_def)
# compile converts the AST into a code object that can be executed by the Python interpreter
# The exec function executes the compiled code object in the current context,
# effectively defining the function within the current namespace
exec(compile(ast.Module([func_def], []), filename, "exec"))
# Retrieve the function object by evaluating its name in the current namespace
# eval looks up the function name in the current scope and returns the function object
func = eval(func_def.name)
# 4. Add or update the tool
tool = client.create_tool(func=func, name=name, tags=tags, update=update)
print(f"Tool {tool.name} added successfully")
@app.command()
def list_tools():
"""List all available tools."""
from memgpt.client.client import create_client
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
tools = client.list_tools()
for tool in tools:
print(f"Tool: {tool.name}")
@app.command()
def add(
option: str, # [human, persona]
@ -1221,27 +1174,23 @@ def add(
assert text is None, "Cannot specify both text and filename"
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
else:
assert text is not None, "Must specify either text or filename"
if option == "persona":
persona_id = client.get_persona_id(name)
if persona_id:
client.get_persona(persona_id)
persona = client.get_persona(name)
if persona:
# config if user wants to overwrite
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
return
client.update_persona(persona_id, text=text)
client.update_persona(name=name, text=text)
else:
client.create_persona(name=name, text=text)
elif option == "human":
human_id = client.get_human_id(name)
if human_id:
human = client.get_human(human_id)
human = client.get_human(name=name)
if human:
# confirm if the user wants to overwrite
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
return
client.update_human(human_id, text=text)
client.update_human(name=name, text=text)
else:
human = client.create_human(name=name, text=text)
else:
@ -1258,21 +1207,21 @@ def delete(option: str, name: str):
# delete from metadata
if option == "source":
# delete metadata
source_id = client.get_source_id(name)
assert source_id is not None, f"Source {name} does not exist"
client.delete_source(source_id)
source = client.get_source(name)
assert source is not None, f"Source {name} does not exist"
client.delete_source(source_id=source.id)
elif option == "agent":
agent_id = client.get_agent_id(name)
assert agent_id is not None, f"Agent {name} does not exist"
client.delete_agent(agent_id=agent_id)
agent = client.get_agent(agent_name=name)
assert agent is not None, f"Agent {name} does not exist"
client.delete_agent(agent_id=agent.id)
elif option == "human":
human_id = client.get_human_id(name)
assert human_id is not None, f"Human {name} does not exist"
client.delete_human(human_id)
human = client.get_human(name=name)
assert human is not None, f"Human {name} does not exist"
client.delete_human(name=name)
elif option == "persona":
persona_id = client.get_persona_id(name)
assert persona_id is not None, f"Persona {name} does not exist"
client.delete_persona(persona_id)
persona = client.get_persona(name=name)
assert persona is not None, f"Persona {name} does not exist"
client.delete_persona(name=name)
else:
raise ValueError(f"Option {option} not implemented")

View File

@ -13,8 +13,15 @@ from typing import Annotated, List, Optional
import typer
from memgpt import create_client
from memgpt.data_sources.connectors import DirectoryConnector
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.data_sources.connectors import (
DirectoryConnector,
VectorDBConnector,
load_data,
)
from memgpt.data_types import Source
from memgpt.metadata import MetadataStore
app = typer.Typer()
@ -82,20 +89,41 @@ def load_directory(
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
):
client = create_client()
# create connector
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
# create source
source = client.create_source(name=name)
# load data
try:
client.load_data(connector, source_name=name)
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
client.delete_source(source.id)
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
config = MemGPTConfig.load()
if not user_id:
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
description=description,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
# ingest data into passage/document store
try:
num_passages, num_documents = load_data(
connector=connector,
source=source,
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
ms.delete_source(source_id=source.id)
except ValueError as e:
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
raise
# @app.command("webpage")
@ -111,6 +139,56 @@ def load_directory(
#
# except ValueError as e:
# typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
#
#
# @app.command("database")
# def load_database(
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
# query: Annotated[str, typer.Option(help="Database query.")],
# dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
# scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
# host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
# port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
# user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
# password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
# dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
# ):
# try:
# from llama_index.readers.database import DatabaseReader
#
# print(dump_path, scheme)
#
# if dump_path is not None:
# # read from database dump file
# from sqlalchemy import create_engine
#
# engine = create_engine(f"sqlite:///{dump_path}")
#
# db = DatabaseReader(engine=engine)
# else:
# assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
# assert scheme is not None, "Must provide database scheme."
# assert host is not None, "Must provide database host."
# assert port is not None, "Must provide database port."
# assert user is not None, "Must provide database user."
# assert password is not None, "Must provide database password."
# assert dbname is not None, "Must provide database name."
#
# db = DatabaseReader(
# scheme=scheme, # Database Scheme
# host=host, # Database Host
# port=str(port), # Database Port
# user=user, # Database User
# password=password, # Database Password
# dbname=dbname, # Database Name
# )
#
# # load data
# docs = db.load_data(query=query)
# store_docs(name, docs)
# except ValueError as e:
# typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
#
@app.command("vector-database")
@ -123,44 +201,43 @@ def load_vector_database(
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
"""Load pre-computed embeddings into MemGPT from a database."""
raise NotImplementedError
# try:
# config = MemGPTConfig.load()
# connector = VectorDBConnector(
# uri=uri,
# table_name=table_name,
# text_column=text_column,
# embedding_column=embedding_column,
# embedding_dim=config.default_embedding_config.embedding_dim,
# )
# if not user_id:
# user_id = uuid.UUID(config.anon_clientid)
try:
config = MemGPTConfig.load()
connector = VectorDBConnector(
uri=uri,
table_name=table_name,
text_column=text_column,
embedding_column=embedding_column,
embedding_dim=config.default_embedding_config.embedding_dim,
)
if not user_id:
user_id = uuid.UUID(config.anon_clientid)
# ms = MetadataStore(config)
# source = Source(
# name=name,
# user_id=user_id,
# embedding_model=config.default_embedding_config.embedding_model,
# embedding_dim=config.default_embedding_config.embedding_dim,
# )
# ms.create_source(source)
# passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# # TODO: also get document store
ms = MetadataStore(config)
source = Source(
name=name,
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
# # ingest data into passage/document store
# try:
# num_passages, num_documents = load_data(
# connector=connector,
# source=source,
# embedding_config=config.default_embedding_config,
# document_store=None,
# passage_store=passage_storage,
# )
# print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
# except Exception as e:
# typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
# ms.delete_source(source_id=source.id)
# ingest data into passage/document store
try:
num_passages, num_documents = load_data(
connector=connector,
source=source,
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except Exception as e:
typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED)
ms.delete_source(source_id=source.id)
# except ValueError as e:
# typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
# raise
except ValueError as e:
typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
raise

View File

@ -1,3 +1,4 @@
import uuid
from typing import List, Optional
import requests
@ -5,8 +6,19 @@ from requests import HTTPError
from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.user import User, UserCreate
from memgpt.server.rest_api.admin.tools import (
CreateToolRequest,
ListToolsResponse,
ToolModel,
)
from memgpt.server.rest_api.admin.users import (
CreateAPIKeyResponse,
CreateUserResponse,
DeleteAPIKeyResponse,
DeleteUserResponse,
GetAllUsersResponse,
GetAPIKeysResponse,
)
class Admin:
@ -21,7 +33,7 @@ class Admin:
self.token = token
self.headers = {"accept": "application/json", "content-type": "application/json", "authorization": f"Bearer {token}"}
def get_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[User]:
def get_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50):
params = {}
if cursor:
params["cursor"] = str(cursor)
@ -30,54 +42,54 @@ class Admin:
response = requests.get(f"{self.base_url}/admin/users", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [User(**user) for user in response.json()]
return GetAllUsersResponse(**response.json())
def create_key(self, user_id: str, key_name: Optional[str] = None) -> APIKey:
request = APIKeyCreate(user_id=user_id, name=key_name)
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=request.model_dump())
def create_key(self, user_id: uuid.UUID, key_name: str):
payload = {"user_id": str(user_id), "key_name": key_name}
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
if response.status_code != 200:
raise HTTPError(response.json())
return APIKey(**response.json())
return CreateAPIKeyResponse(**response.json())
def get_keys(self, user_id: str) -> List[APIKey]:
def get_keys(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)}
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [APIKey(**key) for key in response.json()]
return GetAPIKeysResponse(**response.json()).api_key_list
def delete_key(self, api_key: str) -> APIKey:
def delete_key(self, api_key: str):
params = {"api_key": api_key}
response = requests.delete(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return APIKey(**response.json())
return DeleteAPIKeyResponse(**response.json())
def create_user(self, name: Optional[str] = None) -> User:
request = UserCreate(name=name)
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=request.model_dump())
def create_user(self, user_id: Optional[uuid.UUID] = None):
payload = {"user_id": str(user_id) if user_id else None}
response = requests.post(f"{self.base_url}/admin/users", headers=self.headers, json=payload)
if response.status_code != 200:
raise HTTPError(response.json())
response_json = response.json()
return User(**response_json)
return CreateUserResponse(**response_json)
def delete_user(self, user_id: str) -> User:
def delete_user(self, user_id: uuid.UUID):
params = {"user_id": str(user_id)}
response = requests.delete(f"{self.base_url}/admin/users", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return User(**response.json())
return DeleteUserResponse(**response.json())
def _reset_server(self):
# DANGER: this will delete all users and keys
# clear all state associated with users
# TODO: clear out all agents, presets, etc.
users = self.get_users()
users = self.get_users().user_list
for user in users:
keys = self.get_keys(user.id)
keys = self.get_keys(user["user_id"])
for key in keys:
self.delete_key(key.key)
self.delete_user(user.id)
self.delete_key(key)
self.delete_user(user["user_id"])
# tools
def create_tool(
@ -119,7 +131,7 @@ class Admin:
raise ValueError(f"Failed to create tool: {response.text}")
return ToolModel(**response.json())
def list_tools(self):
def list_tools(self) -> ListToolsResponse:
response = requests.get(f"{self.base_url}/admin/tools", headers=self.headers)
return ListToolsResponse(**response.json()).tools

File diff suppressed because it is too large Load Diff

View File

@ -4,7 +4,6 @@ import json
import os
import uuid
from dataclasses import dataclass
from typing import Optional
import memgpt
import memgpt.utils as utils
@ -16,10 +15,8 @@ from memgpt.constants import (
DEFAULT_PRESET,
MEMGPT_DIR,
)
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig
from memgpt.log import get_logger
from memgpt.schemas.agent import AgentState
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
logger = get_logger(__name__)
@ -102,19 +99,19 @@ class MemGPTConfig:
return uuid.UUID(int=uuid.getnode()).hex
@classmethod
def load(cls, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None) -> "MemGPTConfig":
def load(cls) -> "MemGPTConfig":
# avoid circular import
from memgpt.migrate import VERSION_CUTOFF, config_is_compatible
from memgpt.utils import printd
# from memgpt.migrate import VERSION_CUTOFF, config_is_compatible
# if not config_is_compatible(allow_empty=True):
# error_message = " ".join(
# [
# f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.",
# f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.",
# ]
# )
# raise ValueError(error_message)
if not config_is_compatible(allow_empty=True):
error_message = " ".join(
[
f"\nYour current config file is incompatible with MemGPT versions later than {VERSION_CUTOFF}.",
f"\nTo use MemGPT, you must either downgrade your MemGPT version (<= {VERSION_CUTOFF}) or regenerate your config using `memgpt configure`, or `memgpt migrate` if you would like to migrate old agents.",
]
)
raise ValueError(error_message)
config = configparser.ConfigParser()
@ -192,9 +189,6 @@ class MemGPTConfig:
return cls(**config_dict)
# assert embedding_config is not None, "Embedding config must be provided if config does not exist"
# assert llm_config is not None, "LLM config must be provided if config does not exist"
# create new config
anon_clientid = MemGPTConfig.generate_uuid()
config = cls(anon_clientid=anon_clientid, config_path=config_path)

View File

@ -4,10 +4,8 @@ import typer
from llama_index.core import Document as LlamaIndexDocument
from memgpt.agent_store.storage import StorageConnector
from memgpt.data_types import Document, EmbeddingConfig, Passage, Source
from memgpt.embeddings import embedding_model
from memgpt.schemas.document import Document
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source
from memgpt.utils import create_uuid_from_string
@ -22,11 +20,17 @@ class DataConnector:
def load_data(
connector: DataConnector,
source: Source,
embedding_config: EmbeddingConfig,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
embedding_config = source.embedding_config
assert (
source.embedding_model == embedding_config.embedding_model
), f"Source and embedding config models must match, got: {source.embedding_model} and {embedding_config.embedding_model}"
assert (
source.embedding_dim == embedding_config.embedding_dim
), f"Source and embedding config dimensions must match, got: {source.embedding_dim} and {embedding_config.embedding_dim}."
# embedding model
embed_model = embedding_model(embedding_config)
@ -39,9 +43,10 @@ def load_data(
for document_text, document_metadata in connector.generate_documents():
# insert document into storage
document = Document(
id=create_uuid_from_string(f"{str(source.id)}_{document_text}"),
text=document_text,
metadata_=document_metadata,
source_id=source.id,
metadata=document_metadata,
data_source=source.name,
user_id=source.user_id,
)
document_count += 1
@ -73,15 +78,16 @@ def load_data(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text,
doc_id=document.id,
source_id=source.id,
metadata_=passage_metadata,
user_id=source.user_id,
embedding_config=source.embedding_config,
data_source=source.name,
embedding_dim=source.embedding_dim,
embedding_model=source.embedding_model,
embedding=embedding,
)
hashable_embedding = tuple(passage.embedding)
document_name = document.metadata_.get("file_path", document.id)
document_name = document.metadata.get("file_path", document.id)
if hashable_embedding in embedding_to_document_name:
typer.secho(
f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
@ -144,7 +150,7 @@ class DirectoryConnector(DataConnector):
parser = TokenTextSplitter(chunk_size=chunk_size)
for document in documents:
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata)]
nodes = parser.get_nodes_from_documents(llama_index_docs)
for node in nodes:
# passage = Passage(

View File

@ -1,18 +1,75 @@
""" This module contains the data types used by MemGPT. Each data type must include a function to create a DB model. """
import copy
import json
import uuid
import warnings
from datetime import datetime, timezone
from typing import List, Optional, Union
from typing import Dict, List, Optional, TypeVar
from pydantic import Field, field_validator
import numpy as np
from pydantic import BaseModel, Field
from memgpt.constants import JSON_ENSURE_ASCII, TOOL_CALL_ID_MAX_LEN
from memgpt.constants import (
DEFAULT_HUMAN,
DEFAULT_PERSONA,
DEFAULT_PRESET,
JSON_ENSURE_ASCII,
LLM_MAX_TOKENS,
MAX_EMBEDDING_DIM,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.local_llm.constants import INNER_THOUGHTS_KWARG
from memgpt.schemas.enums import MessageRole
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.openai.chat_completions import ToolCall
from memgpt.utils import get_utc_time, is_utc_datetime
from memgpt.prompts import gpt_system
from memgpt.utils import (
create_uuid_from_string,
get_human_text,
get_persona_text,
get_utc_time,
is_utc_datetime,
)
class Record:
"""
Base class for an agent's memory unit. Each memory unit is represented in the database as a single row.
Memory units are searched over by functions defined in the memory classes
"""
def __init__(self, id: Optional[uuid.UUID] = None):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
# This allows type checking to work when you pass a Passage into a function expecting List[Record]
# (just use List[RecordType] instead)
RecordType = TypeVar("RecordType", bound="Record")
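A short illustration of why the bound TypeVar matters for type checking (the class and function names here are illustrative):

from typing import List, TypeVar

class Record: ...
class Passage(Record): ...

RecordType = TypeVar("RecordType", bound=Record)

def newest(records: List[RecordType]) -> RecordType:
    return records[-1]

# A type checker infers Passage here rather than the erased base type Record,
# which is what a List[Record] signature would have produced.
p: Passage = newest([Passage(), Passage()])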
class ToolCall(object):
def __init__(
self,
id: str,
# TODO should we include this? it's fixed to 'function' only (for now) in OAI schema
# NOTE: called ToolCall.type in official OpenAI schema
tool_call_type: str, # only 'function' is supported
# function: { 'name': ..., 'arguments': ...}
function: Dict[str, str],
):
self.id = id
self.tool_call_type = tool_call_type
self.function = function
def to_dict(self):
return {
"id": self.id,
"type": self.tool_call_type,
"function": self.function,
}
def add_inner_thoughts_to_tool_call(
@ -24,34 +81,20 @@ def add_inner_thoughts_to_tool_call(
# because the kwargs are stored as strings, we need to load then write the JSON dicts
try:
# load the args list
func_args = json.loads(tool_call.function.arguments)
func_args = json.loads(tool_call.function["arguments"])
# add the inner thoughts to the args list
func_args[inner_thoughts_key] = inner_thoughts
# create the updated tool call (as a string)
updated_tool_call = copy.deepcopy(tool_call)
updated_tool_call.function.arguments = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
updated_tool_call.function["arguments"] = json.dumps(func_args, ensure_ascii=JSON_ENSURE_ASCII)
return updated_tool_call
except json.JSONDecodeError as e:
# TODO: change to logging
warnings.warn(f"Failed to put inner thoughts in kwargs: {e}")
raise e
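A self-contained sketch of that kwargs round-trip, using a stand-in for the dict-based ToolCall above (the tool name and arguments are made up):

import copy
import json

class ToolCall:  # minimal stand-in for the class defined above
    def __init__(self, id, tool_call_type, function):
        self.id, self.tool_call_type, self.function = id, tool_call_type, function

def add_inner_thoughts(tool_call, inner_thoughts, key="inner_thoughts"):
    func_args = json.loads(tool_call.function["arguments"])  # arguments are stored as a JSON string
    func_args[key] = inner_thoughts  # splice the inner monologue into the kwargs
    updated = copy.deepcopy(tool_call)
    updated.function["arguments"] = json.dumps(func_args)
    return updated

tc = ToolCall("call_1", "function", {"name": "send_message", "arguments": '{"message": "hi"}'})
print(add_inner_thoughts(tc, "Thinking...").function["arguments"])
# -> {"message": "hi", "inner_thoughts": "Thinking..."}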
class BaseMessage(MemGPTBase):
__id_prefix__ = "message"
class MessageCreate(BaseMessage):
"""Request to create a message"""
role: MessageRole = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
name: Optional[str] = Field(None, description="The name of the participant.")
class Message(BaseMessage):
"""
Representation of a message sent.
class Message(Record):
"""Representation of a message sent.
Messages can be:
- agent->user (role=='assistant')
@ -59,23 +102,65 @@ class Message(BaseMessage):
- or function/tool call returns (role=='function'/'tool').
"""
id: str = BaseMessage.generate_id_field()
role: MessageRole = Field(..., description="The role of the participant.")
text: str = Field(..., description="The text of the message.")
user_id: str = Field(None, description="The unique identifier of the user.")
agent_id: str = Field(None, description="The unique identifier of the agent.")
model: Optional[str] = Field(None, description="The model used to make the function call.")
name: Optional[str] = Field(None, description="The name of the participant.")
created_at: datetime = Field(default_factory=get_utc_time, description="The time the message was created.")
tool_calls: Optional[List[ToolCall]] = Field(None, description="The list of tool calls requested.")
tool_call_id: Optional[str] = Field(None, description="The id of the tool call.")
def __init__(
self,
role: str,
text: str,
user_id: Optional[uuid.UUID] = None,
agent_id: Optional[uuid.UUID] = None,
model: Optional[str] = None, # model used to make function call
name: Optional[str] = None, # optional participant name
created_at: Optional[datetime] = None,
tool_calls: Optional[List[ToolCall]] = None, # list of tool calls requested
tool_call_id: Optional[str] = None,
# tool_call_name: Optional[str] = None, # not technically OpenAI spec, but it can be helpful to have on-hand
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
id: Optional[uuid.UUID] = None,
):
super().__init__(id)
self.user_id = user_id
self.agent_id = agent_id
self.text = text
self.model = model # model name (e.g. gpt-4)
self.created_at = created_at if created_at is not None else get_utc_time()
@field_validator("role")
@classmethod
def validate_role(cls, v: str) -> str:
roles = ["system", "assistant", "user", "tool"]
assert v in roles, f"Role must be one of {roles}"
return v
# openai info
assert role in ["system", "assistant", "user", "tool"]
self.role = role # role (agent/user/function)
self.name = name
# pad and store embeddings
if isinstance(embedding, list):
embedding = np.array(embedding)
self.embedding = (
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
# tool (i.e. function) call info (optional)
# if role == "assistant", this MAY be specified
# if role != "assistant", this must be null
assert tool_calls is None or isinstance(tool_calls, list)
if tool_calls is not None:
assert all([isinstance(tc, ToolCall) for tc in tool_calls]), f"Tool calls must be of type ToolCall, got {tool_calls}"
self.tool_calls = tool_calls
# if role == "tool", then this must be specified
# if role != "tool", this must be null
if role == "tool":
assert tool_call_id is not None
else:
assert tool_call_id is None
self.tool_call_id = tool_call_id
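The zero-padding scheme above can be factored out as a standalone helper; a sketch, assuming an illustrative MAX_EMBEDDING_DIM (the real constant lives in memgpt.constants):

import numpy as np

MAX_EMBEDDING_DIM = 4096  # illustrative value

def pad_embedding(embedding, max_dim=MAX_EMBEDDING_DIM):
    """Zero-pad a variable-length embedding so every stored row has a fixed width."""
    embedding = np.asarray(embedding)
    return np.pad(embedding, (0, max_dim - embedding.shape[0]), mode="constant").tolist()

vec = pad_embedding([0.1, 0.2, 0.3])
assert len(vec) == MAX_EMBEDDING_DIM and vec[3:] == [0.0] * (MAX_EMBEDDING_DIM - 3)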
def to_json(self):
json_message = vars(self)
@ -88,26 +173,16 @@ class Message(BaseMessage):
json_message["created_at"] = self.created_at.isoformat()
return json_message
def to_memgpt_message(self) -> Union[List[MemGPTMessage], List[LegacyMemGPTMessage]]:
"""Convert message object (in DB format) to the style used by the original MemGPT API
NOTE: this may split the message into two pieces (e.g. if the assistant has inner thoughts + function call)
"""
raise NotImplementedError
@staticmethod
def dict_to_message(
user_id: str,
agent_id: str,
user_id: uuid.UUID,
agent_id: uuid.UUID,
openai_message_dict: dict,
model: Optional[str] = None, # model used to make function call
allow_functions_style: bool = False, # allow deprecated functions style?
created_at: Optional[datetime] = None,
):
"""Convert a ChatCompletion message object into a Message object (synced to DB)"""
if not created_at:
# timestamp for creation
created_at = get_utc_time()
assert "role" in openai_message_dict, openai_message_dict
assert "content" in openai_message_dict, openai_message_dict
@ -121,6 +196,7 @@ class Message(BaseMessage):
# Convert from 'function' response to a 'tool' response
# NOTE: this does not conventionally include a tool_call_id, it's on the caster to provide it
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,
@ -130,7 +206,6 @@ class Message(BaseMessage):
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=openai_message_dict["tool_calls"] if "tool_calls" in openai_message_dict else None,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
)
elif "function_call" in openai_message_dict and openai_message_dict["function_call"] is not None:
@ -153,6 +228,7 @@ class Message(BaseMessage):
]
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,
@ -162,7 +238,6 @@ class Message(BaseMessage):
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=None, # NOTE: None, since this field is only non-null for role=='tool'
created_at=created_at,
)
else:
@ -185,6 +260,7 @@ class Message(BaseMessage):
# If we're going from tool-call style
return Message(
created_at=created_at,
user_id=user_id,
agent_id=agent_id,
model=model,
@ -194,7 +270,6 @@ class Message(BaseMessage):
name=openai_message_dict["name"] if "name" in openai_message_dict else None,
tool_calls=tool_calls,
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
created_at=created_at,
)
def to_openai_dict_search_results(self, max_tool_id_length: int = TOOL_CALL_ID_MAX_LEN) -> dict:
@ -248,11 +323,11 @@ class Message(BaseMessage):
tool_call,
inner_thoughts=self.text,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
).model_dump()
).to_dict()
for tool_call in self.tool_calls
]
else:
openai_message["tool_calls"] = [tool_call.model_dump() for tool_call in self.tool_calls]
openai_message["tool_calls"] = [tool_call.to_dict() for tool_call in self.tool_calls]
if max_tool_id_length:
for tool_call_dict in openai_message["tool_calls"]:
tool_call_dict["id"] = tool_call_dict["id"][:max_tool_id_length]
@ -548,3 +623,313 @@ class Message(BaseMessage):
raise ValueError(self.role)
return cohere_message
class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None, metadata: Optional[Dict] = None):
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
self.id = create_uuid_from_string("".join([text, str(user_id)]))
else:
self.id = id
super().__init__(id)
self.user_id = user_id
self.text = text
self.data_source = data_source
self.metadata = metadata if metadata is not None else {}  # avoid a shared mutable default
# TODO: add optional embedding?
class Passage(Record):
"""A passage is a single unit of memory, and a standard format accross all storage backends.
It is a string of text with an assoidciated embedding.
"""
def __init__(
self,
text: str,
user_id: Optional[uuid.UUID] = None,
agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
embedding_model: Optional[str] = None,
data_source: Optional[str] = None, # None if created by agent
doc_id: Optional[uuid.UUID] = None,
id: Optional[uuid.UUID] = None,
metadata_: Optional[dict] = None,
created_at: Optional[datetime] = None,
):
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
# TODO: use source-id instead?
if agent_id:
self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
else:
self.id = create_uuid_from_string("".join([text, str(user_id)]))
else:
self.id = id
super().__init__(self.id)
self.user_id = user_id
self.agent_id = agent_id
self.text = text
self.data_source = data_source
self.doc_id = doc_id
self.metadata_ = metadata_ if metadata_ is not None else {}  # avoid a shared mutable default
# pad and store embeddings
if isinstance(embedding, list):
embedding = np.array(embedding)
self.embedding = (
np.pad(embedding, (0, MAX_EMBEDDING_DIM - embedding.shape[0]), mode="constant").tolist() if embedding is not None else None
)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
self.created_at = created_at if created_at is not None else get_utc_time()
if self.embedding is not None:
assert self.embedding_dim, f"Must specify embedding_dim if providing an embedding"
assert self.embedding_model, f"Must specify embedding_model if providing an embedding"
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
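The deterministic IDs above come from memgpt.utils.create_uuid_from_string; one plausible implementation (an assumption, not necessarily the repo's code) hashes the string into the 16 bytes of a UUID:

import hashlib
import uuid

def create_uuid_from_string(val: str) -> uuid.UUID:
    # assumed implementation: an MD5 digest is exactly 16 bytes, the size of a UUID
    return uuid.UUID(bytes=hashlib.md5(val.encode("utf-8")).digest())

# Re-ingesting the same text for the same user yields the same ID,
# so duplicate passages collide instead of piling up.
a = create_uuid_from_string("some passage text" + "user-1")
b = create_uuid_from_string("some passage text" + "user-1")
assert a == b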
class LLMConfig:
def __init__(
self,
model: Optional[str] = None,
model_endpoint_type: Optional[str] = None,
model_endpoint: Optional[str] = None,
model_wrapper: Optional[str] = None,
context_window: Optional[int] = None,
):
self.model = model
self.model_endpoint_type = model_endpoint_type
self.model_endpoint = model_endpoint
self.model_wrapper = model_wrapper
if context_window is None:
self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"]
else:
self.context_window = context_window
class EmbeddingConfig:
def __init__(
self,
embedding_endpoint_type: Optional[str] = None,
embedding_endpoint: Optional[str] = None,
embedding_model: Optional[str] = None,
embedding_dim: Optional[int] = None,
embedding_chunk_size: Optional[int] = 300,
):
self.embedding_endpoint_type = embedding_endpoint_type
self.embedding_endpoint = embedding_endpoint
self.embedding_model = embedding_model
self.embedding_dim = embedding_dim
self.embedding_chunk_size = embedding_chunk_size
# fields cannot be set to None
assert self.embedding_endpoint_type
assert self.embedding_dim
assert self.embedding_chunk_size
class OpenAIEmbeddingConfig(EmbeddingConfig):
def __init__(self, openai_key: Optional[str] = None, **kwargs):
super().__init__(**kwargs)
self.openai_key = openai_key
class AzureEmbeddingConfig(EmbeddingConfig):
def __init__(
self,
azure_key: Optional[str] = None,
azure_endpoint: Optional[str] = None,
azure_version: Optional[str] = None,
azure_deployment: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
self.azure_key = azure_key
self.azure_endpoint = azure_endpoint
self.azure_version = azure_version
self.azure_deployment = azure_deployment
class User:
"""Defines user and default configurations"""
# TODO: make sure to encrypt/decrypt keys before storing in DB
def __init__(
self,
# name: str,
id: Optional[uuid.UUID] = None,
default_agent=None,
# other
policies_accepted=False,
):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
self.default_agent = default_agent
# misc
self.policies_accepted = policies_accepted
class AgentState:
def __init__(
self,
name: str,
user_id: uuid.UUID,
# tools
tools: List[str], # list of tools by name
# system prompt
system: str,
# config
llm_config: LLMConfig,
embedding_config: EmbeddingConfig,
# (in-context) state contains:
id: Optional[uuid.UUID] = None,
state: Optional[dict] = None,
created_at: Optional[datetime] = None,
# messages (TODO: implement this)
_metadata: Optional[dict] = None,
):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"
# TODO(swooders): we need to handle the case where name is None here
# in AgentConfig we autogenerate a name; not sure what the right approach is with DBs. Maybe NounAdjective combos like Giphy uses (BoredGiraffe, etc.)?
self.name = name
assert self.name, f"AgentState name must be a non-empty string"
self.user_id = user_id
# The INITIAL values of the persona and human
# The values inside self.state['persona'], self.state['human'] are the CURRENT values
self.llm_config = llm_config
self.embedding_config = embedding_config
self.created_at = created_at if created_at is not None else get_utc_time()
# state
self.state = {} if not state else state
# tools
self.tools = tools
# system
self.system = system
assert self.system is not None, f"Must provide system prompt, cannot be None"
# metadata
self._metadata = _metadata
class Source:
def __init__(
self,
user_id: uuid.UUID,
name: str,
description: Optional[str] = None,
created_at: Optional[datetime] = None,
id: Optional[uuid.UUID] = None,
# embedding info
embedding_model: Optional[str] = None,
embedding_dim: Optional[int] = None,
):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"
self.name = name
self.user_id = user_id
self.description = description
self.created_at = created_at if created_at is not None else get_utc_time()
# embedding info (optional)
self.embedding_dim = embedding_dim
self.embedding_model = embedding_model
class Token:
def __init__(
self,
user_id: uuid.UUID,
token: str,
name: Optional[str] = None,
id: Optional[uuid.UUID] = None,
):
if id is None:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
assert isinstance(user_id, uuid.UUID), f"UUID {user_id} must be a UUID type"
self.token = token
self.user_id = user_id
self.name = name
class Preset(BaseModel):
# TODO: remove Preset
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(
gpt_system.get_system_text(DEFAULT_PRESET), description="The system prompt of the preset."
) # default system prompt is same as default preset name
# system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# functions: List[str] = Field(..., description="The functions of the preset.") # TODO: convert to ID
# sources: List[str] = Field(..., description="The sources of the preset.") # TODO: convert to ID
@staticmethod
def clone(preset_obj: "Preset", new_name_suffix: Optional[str] = None) -> "Preset":
"""
Takes a Preset object and an optional new name suffix as input,
creates a clone of the given Preset object with a new ID and an optional new name,
and returns the new Preset object.
"""
new_preset = preset_obj.model_copy()
new_preset.id = uuid.uuid4()
if new_name_suffix:
new_preset.name = f"{preset_obj.name}_{new_name_suffix}"
else:
new_preset.name = f"{preset_obj.name}_{str(uuid.uuid4())[:8]}"
return new_preset
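Hypothetical usage of clone (assuming the field defaults above resolve in your environment):

base = Preset(name="memgpt_chat", functions_schema=[])
copy_a = Preset.clone(base, new_name_suffix="experiment")
assert copy_a.id != base.id
assert copy_a.name == "memgpt_chat_experiment"

copy_b = Preset.clone(base)  # no suffix: an 8-char slice of a fresh UUID is appended
assert copy_b.name.startswith("memgpt_chat_")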
class Function(BaseModel):
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(..., description="The unique identifier of the function.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the function.")
# TODO: figure out how represent functions

View File

@ -22,7 +22,7 @@ from memgpt.constants import (
MAX_EMBEDDING_DIM,
)
from memgpt.credentials import MemGPTCredentials
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.data_types import EmbeddingConfig
from memgpt.utils import is_valid_url, printd

View File

@ -11,8 +11,8 @@ from memgpt.constants import (
MESSAGE_CHATGPT_FUNCTION_MODEL,
MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE,
)
from memgpt.data_types import Message
from memgpt.llm_api.llm_api_tools import create
from memgpt.schemas.message import Message
def message_chatgpt(self, message: str):

View File

@ -1,6 +1,6 @@
import inspect
import typing
from typing import Any, Dict, Optional, Type, get_args, get_origin
from typing import Optional, get_args, get_origin
from docstring_parser import parse
from pydantic import BaseModel
@ -144,39 +144,3 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
schema["parameters"]["required"].append(FUNCTION_PARAM_NAME_REQ_HEARTBEAT)
return schema
def generate_schema_from_args_schema(
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None
) -> Dict[str, Any]:
properties = {}
required = []
for field_name, field in args_schema.__fields__.items():
properties[field_name] = {"type": field.type_.__name__, "description": field.field_info.description}
if field.required:
required.append(field_name)
# Construct the OpenAI function call JSON object
function_call_json = {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": properties, "required": required},
}
return function_call_json
def generate_tool_wrapper(tool_name: str) -> str:
import_statement = f"from crewai_tools import {tool_name}"
tool_instantiation = f"tool = {tool_name}()"
run_call = f"return tool._run(**kwargs)"
func_name = f"run_{tool_name.lower()}"
# Combine all parts into the wrapper function
wrapper_function_str = f"""
def {func_name}(**kwargs):
{import_statement}
{tool_instantiation}
{run_call}
"""
return func_name, wrapper_function_str

View File

@ -6,7 +6,7 @@ from typing import List, Optional
from colorama import Fore, Style, init
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.schemas.message import Message
from memgpt.data_types import Message
from memgpt.utils import printd
init(autoreset=True)

View File

@ -6,17 +6,17 @@ from typing import List, Optional, Union
import requests
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.data_types import Message
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,
)
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own MemGPT Message datatype
)
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
from memgpt.utils import get_utc_time, smart_urljoin
BASE_URL = "https://api.anthropic.com/v1"

View File

@ -2,8 +2,8 @@ from typing import Union
import requests
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.schemas.openai.embedding_response import EmbeddingResponse
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.utils import smart_urljoin
MODEL_TO_AZURE_ENGINE = {

View File

@ -5,18 +5,18 @@ from typing import List, Optional, Union
import requests
from memgpt.constants import JSON_ENSURE_ASCII
from memgpt.data_types import Message
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import ChatCompletionRequest, Tool
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,
)
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
Message as ChoiceMessage, # NOTE: avoid conflict with our own MemGPT Message datatype
)
from memgpt.schemas.openai.chat_completion_response import ToolCall, UsageStatistics
from memgpt.models.chat_completion_response import ToolCall, UsageStatistics
from memgpt.utils import get_tool_call_id, get_utc_time, smart_urljoin
BASE_URL = "https://api.cohere.ai/v1"

View File

@ -7,8 +7,8 @@ import requests
from memgpt.constants import JSON_ENSURE_ASCII, NON_USER_MSG_PREFIX
from memgpt.local_llm.json_parser import clean_json_string_extra_backslash
from memgpt.local_llm.utils import count_tokens
from memgpt.schemas.openai.chat_completion_request import Tool
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import Tool
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
FunctionCall,

View File

@ -11,6 +11,7 @@ import requests
from memgpt.constants import CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Message
from memgpt.llm_api.anthropic import anthropic_chat_completions_request
from memgpt.llm_api.azure_openai import (
MODEL_TO_AZURE_ENGINE,
@ -30,15 +31,13 @@ from memgpt.local_llm.constants import (
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
)
from memgpt.schemas.enums import OptionState
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import (
from memgpt.models.chat_completion_request import (
ChatCompletionRequest,
Tool,
cast_message_to_subtype,
)
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.models.pydantic_models import LLMConfigModel, OptionState
from memgpt.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
@ -229,7 +228,7 @@ def retry_with_exponential_backoff(
@retry_with_exponential_backoff
def create(
# agent_state: AgentState,
llm_config: LLMConfig,
llm_config: LLMConfigModel,
messages: List[Message],
user_id: uuid.UUID = None, # optional UUID to associate request with
functions: list = None,
@ -260,6 +259,8 @@ def create(
printd("unsetting function_call because functions is None")
function_call = None
# print("HELLO")
# openai
if llm_config.model_endpoint_type == "openai":
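The create function above is wrapped in retry_with_exponential_backoff; a hedged sketch of what such a decorator typically looks like (parameter names and defaults here are assumptions, not the repo's):

import random
import time
from functools import wraps

def retry_with_exponential_backoff(func, max_retries=5, base_delay=1.0):
    @wraps(func)
    def wrapper(*args, **kwargs):
        delay = base_delay
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except Exception:
                if attempt == max_retries - 1:
                    raise  # out of retries, surface the error
                time.sleep(delay + random.random())  # jitter avoids synchronized retries
                delay *= 2  # double the wait between attempts
    return wrapper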

View File

@ -7,8 +7,8 @@ from httpx_sse import connect_sse
from httpx_sse._exceptions import SSEError
from memgpt.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.chat_completion_response import (
ChatCompletionChunkResponse,
ChatCompletionResponse,
Choice,
@ -17,7 +17,7 @@ from memgpt.schemas.openai.chat_completion_response import (
ToolCall,
UsageStatistics,
)
from memgpt.schemas.openai.embedding_response import EmbeddingResponse
from memgpt.models.embedding_response import EmbeddingResponse
from memgpt.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
@ -89,7 +89,6 @@ def openai_chat_completions_process_stream(
on the chunks received from the OpenAI-compatible server POST SSE response.
"""
assert chat_completion_request.stream == True
assert stream_inferface is not None, "Required"
# Count the prompt tokens
# TODO move to post-request?
@ -371,10 +370,7 @@ def openai_chat_completions_request(
url = smart_urljoin(url, "chat/completions")
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
data = chat_completion_request.model_dump(exclude_none=True)
# add check otherwise will cause error: "Invalid value for 'parallel_tool_calls': 'parallel_tool_calls' is only allowed when 'tools' are specified."
if chat_completion_request.tools is not None:
data["parallel_tool_calls"] = False
data["parallel_tool_calls"] = False
printd("Request:\n", json.dumps(data, indent=2))
@ -390,7 +386,7 @@ def openai_chat_completions_request(
printd(f"Sending request to {url}")
try:
response = requests.post(url, headers=headers, json=data)
printd(f"response = {response}, response.text = {response.text}")
# printd(f"response = {response}, response.text = {response.text}")
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string

View File

@ -25,14 +25,14 @@ from memgpt.local_llm.webui.api import get_webui_completion
from memgpt.local_llm.webui.legacy_api import (
get_webui_completion as get_webui_completion_legacy,
)
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
Message,
ToolCall,
UsageStatistics,
)
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.utils import get_tool_call_id, get_utc_time
has_shown_warning = False

View File

@ -11,12 +11,20 @@ from rich.console import Console
import memgpt.agent as agent
import memgpt.errors as errors
import memgpt.system as system
from memgpt.agent_store.storage import StorageConnector, TableType
# import benchmark
from memgpt import create_client
from memgpt.benchmark.benchmark import bench
from memgpt.cli.cli import delete_agent, open_folder, quickstart, run, server, version
from memgpt.cli.cli_config import add, add_tool, configure, delete, list, list_tools
from memgpt.cli.cli import (
delete_agent,
migrate,
open_folder,
quickstart,
run,
server,
version,
)
from memgpt.cli.cli_config import add, configure, delete, list
from memgpt.cli.cli_load import app as load_app
from memgpt.config import MemGPTConfig
from memgpt.constants import (
@ -26,7 +34,7 @@ from memgpt.constants import (
REQ_HEARTBEAT_MESSAGE,
)
from memgpt.metadata import MetadataStore
from memgpt.schemas.enums import OptionState
from memgpt.models.pydantic_models import OptionState
# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import AgentRefreshStreamingInterface
@ -39,14 +47,14 @@ app.command(name="version")(version)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
app.command(name="add-tool")(add_tool)
app.command(name="list-tools")(list_tools)
app.command(name="delete")(delete)
app.command(name="server")(server)
app.command(name="folder")(open_folder)
app.command(name="quickstart")(quickstart)
# load data commands
app.add_typer(load_app, name="load")
# migration command
app.command(name="migrate")(migrate)
# benchmark command
app.command(name="benchmark")(bench)
# delete agents
@ -95,12 +103,7 @@ def run_agent_loop(
print()
multiline_input = False
# create client
client = create_client()
ms = MetadataStore(config) # TODO: remove
# run loops
ms = MetadataStore(config)
while True:
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
# Ask for user input
@ -148,8 +151,8 @@ def run_agent_loop(
# TODO: check to ensure source embedding dimensions/model match agents, and disallow attachment if not
# TODO: alternatively, only list sources with compatible embeddings, and print warning about non-compatible sources
sources = client.list_sources()
if len(sources) == 0:
data_source_options = ms.list_sources(user_id=memgpt_agent.agent_state.user_id)
if len(data_source_options) == 0:
typer.secho(
'No sources available. You must load a source with "memgpt load ..." before running /attach.',
fg=typer.colors.RED,
@ -160,8 +163,11 @@ def run_agent_loop(
# determine what sources are valid to be attached to this agent
valid_options = []
invalid_options = []
for source in sources:
if source.embedding_config == memgpt_agent.agent_state.embedding_config:
for source in data_source_options:
if (
source.embedding_model == memgpt_agent.agent_state.embedding_config.embedding_model
and source.embedding_dim == memgpt_agent.agent_state.embedding_config.embedding_dim
):
valid_options.append(source.name)
else:
# print warning about invalid sources
@ -175,7 +181,11 @@ def run_agent_loop(
data_source = questionary.select("Select data source", choices=valid_options).ask()
# attach new data
client.attach_source_to_agent(agent_id=memgpt_agent.agent_state.id, source_name=data_source)
# attach(memgpt_agent.agent_state.name, data_source)
source_connector = StorageConnector.get_storage_connector(
TableType.PASSAGES, config, user_id=memgpt_agent.agent_state.user_id
)
memgpt_agent.attach_source(data_source, source_connector, ms)
continue
@ -420,10 +430,8 @@ def run_agent_loop(
skip_verify=no_verify,
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
ms=ms,
)
agent.save_agent(memgpt_agent, ms)
skip_next_user_input = False
if token_warning:
user_message = system.get_token_limit_warning()
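One side of this hunk routes /attach through the client API; a minimal sketch of that path, using only names that appear in this diff (agent_id is assumed to come from the running agent):

client = create_client()
sources = client.list_sources()
if sources:
    # only sources whose embedding config matches the agent's are valid to attach
    client.attach_source_to_agent(agent_id=agent_id, source_name=sources[0].name)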

View File

@ -3,14 +3,13 @@ import uuid
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel, validator
from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.data_types import AgentState, Message, Passage
from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding
from memgpt.llm_api.llm_api_tools import create
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.schemas.agent import AgentState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.utils import (
count_tokens,
extract_date_from_timestamp,
@ -19,135 +18,125 @@ from memgpt.utils import (
validate_date_format,
)
# class MemoryModule(BaseModel):
# """Base class for memory modules"""
#
# description: Optional[str] = None
# limit: int = 2000
# value: Optional[Union[List[str], str]] = None
#
# def __setattr__(self, name, value):
# """Run validation if self.value is updated"""
# super().__setattr__(name, value)
# if name == "value":
# # run validation
# self.__class__.validate(self.dict(exclude_unset=True))
#
# @validator("value", always=True)
# def check_value_length(cls, v, values):
# if v is not None:
# # Fetching the limit from the values dictionary
# limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
#
# # Check if the value exceeds the limit
# if isinstance(v, str):
# length = len(v)
# elif isinstance(v, list):
# length = sum(len(item) for item in v)
# else:
# raise ValueError("Value must be either a string or a list of strings.")
#
# if length > limit:
# error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# # TODO: add archival memory error?
# raise ValueError(error_msg)
# return v
#
# def __len__(self):
# return len(str(self))
#
# def __str__(self) -> str:
# if isinstance(self.value, list):
# return ",".join(self.value)
# elif isinstance(self.value, str):
# return self.value
# else:
# return ""
#
#
# class BaseMemory:
#
# def __init__(self):
# self.memory = {}
#
# @classmethod
# def load(cls, state: dict):
# """Load memory from dictionary object"""
# obj = cls()
# for key, value in state.items():
# obj.memory[key] = MemoryModule(**value)
# return obj
#
# def __str__(self) -> str:
# """Representation of the memory in-context"""
# section_strs = []
# for section, module in self.memory.items():
# section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
# return "\n".join(section_strs)
#
# def to_dict(self):
# """Convert to dictionary representation"""
# return {key: value.dict() for key, value in self.memory.items()}
#
#
# class ChatMemory(BaseMemory):
#
# def __init__(self, persona: str, human: str, limit: int = 2000):
# self.memory = {
# "persona": MemoryModule(name="persona", value=persona, limit=limit),
# "human": MemoryModule(name="human", value=human, limit=limit),
# }
#
# def core_memory_append(self, name: str, content: str) -> Optional[str]:
# """
# Append to the contents of core memory.
#
# Args:
# name (str): Section of the memory to be edited (persona or human).
# content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
# Returns:
# Optional[str]: None is always returned as this function does not produce a response.
# """
# self.memory[name].value += "\n" + content
# return None
#
# def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
# """
# Replace the contents of core memory. To delete memories, use an empty string for new_content.
#
# Args:
# name (str): Section of the memory to be edited (persona or human).
# old_content (str): String to replace. Must be an exact match.
# new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
#
# Returns:
# Optional[str]: None is always returned as this function does not produce a response.
# """
# self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
# return None
class MemoryModule(BaseModel):
"""Base class for memory modules"""
description: Optional[str] = None
limit: int = 2000
value: Optional[Union[List[str], str]] = None
def __setattr__(self, name, value):
"""Run validation if self.value is updated"""
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True))
@validator("value", always=True)
def check_value_length(cls, v, values):
if v is not None:
# Fetching the limit from the values dictionary
limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
# Check if the value exceeds the limit
if isinstance(v, str):
length = len(v)
elif isinstance(v, list):
length = sum(len(item) for item in v)
else:
raise ValueError("Value must be either a string or a list of strings.")
if length > limit:
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
# TODO: add archival memory error?
raise ValueError(error_msg)
return v
def __len__(self):
return len(str(self))
def __str__(self) -> str:
if isinstance(self.value, list):
return ",".join(self.value)
elif isinstance(self.value, str):
return self.value
else:
return ""
def get_memory_functions(cls: Memory) -> List[callable]:
class BaseMemory:
def __init__(self):
self.memory = {}
@classmethod
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
obj.memory[key] = MemoryModule(**value)
return obj
def __str__(self) -> str:
"""Representation of the memory in-context"""
section_strs = []
for section, module in self.memory.items():
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
return "\n".join(section_strs)
def to_dict(self):
"""Convert to dictionary representation"""
return {key: value.dict() for key, value in self.memory.items()}
class ChatMemory(BaseMemory):
def __init__(self, persona: str, human: str, limit: int = 2000):
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
}
def core_memory_append(self, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
Args:
name (str): Section of the memory to be edited (persona or human).
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.memory[name].value += "\n" + content
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
name (str): Section of the memory to be edited (persona or human).
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
return None
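As a quick illustration of the memory classes above, a usage sketch (constructor arguments as shown in this diff; the section tags come from BaseMemory.__str__):

memory = ChatMemory(persona="I am a helpful assistant.", human="Name: Chad", limit=2000)
memory.core_memory_append("human", "Favorite color: blue")
memory.core_memory_replace("human", "Chad", "Charles")
print(memory)             # renders <persona ...> and <human ...> sections with character counts
state = memory.to_dict()  # JSON-serializable; BaseMemory.load(state) rebuilds the modules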
def get_memory_functions(cls: BaseMemory) -> List[callable]:
"""Get memory functions for a memory class"""
functions = {}
# collect base memory functions (should not be included)
base_functions = []
for func_name in dir(Memory):
funct = getattr(Memory, func_name)
if callable(funct):
base_functions.append(func_name)
for func_name in dir(cls):
if func_name.startswith("_") or func_name in ["load", "to_dict"]: # skip base functions
continue
if func_name in base_functions: # dont use BaseMemory functions
continue
func = getattr(cls, func_name)
if not callable(func): # not a function
continue
functions[func_name] = func
if callable(func):
functions[func_name] = func
return functions
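Because get_memory_functions skips underscore-prefixed names, load, to_dict, and anything defined on the base class, for ChatMemory it surfaces exactly the two editing functions; a sketch (assuming the BaseMemory variant shown here):

functions = get_memory_functions(ChatMemory)
assert set(functions.keys()) == {"core_memory_append", "core_memory_replace"}
# each value is the plain function object, e.g. for JSON-schema generation when linking tools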
@ -264,8 +253,8 @@ def summarize_messages(
+ message_sequence_to_summarize[cutoff:]
)
dummy_user_id = agent_state.user_id
dummy_agent_id = agent_state.id
dummy_user_id = uuid.uuid4()
dummy_agent_id = uuid.uuid4()
message_sequence = []
message_sequence.append(Message(user_id=dummy_user_id, agent_id=dummy_agent_id, role="system", text=summary_prompt))
if insert_acknowledgement_assistant_message:
@ -528,7 +517,8 @@ class EmbeddingArchivalMemory(ArchivalMemory):
agent_id=self.agent_state.id,
text=text,
embedding=embedding,
embedding_config=self.agent_state.embedding_config,
embedding_dim=self.agent_state.embedding_config.embedding_dim,
embedding_model=self.agent_state.embedding_config.embedding_model,
)
def save(self):

View File

@ -3,10 +3,12 @@
import os
import secrets
import traceback
import uuid
from typing import List, Optional
from sqlalchemy import (
BIGINT,
CHAR,
JSON,
Boolean,
Column,
@ -18,28 +20,58 @@ from sqlalchemy import (
desc,
func,
)
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.exc import InterfaceError, OperationalError
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.sql import func
from memgpt.config import MemGPTConfig
from memgpt.schemas.agent import AgentState
from memgpt.schemas.api_key import APIKey
from memgpt.schemas.block import Block, Human, Persona
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import Memory
from memgpt.schemas.source import Source
from memgpt.schemas.tool import Tool
from memgpt.schemas.user import User
from memgpt.data_types import (
AgentState,
EmbeddingConfig,
LLMConfig,
Preset,
Source,
Token,
User,
)
from memgpt.models.pydantic_models import (
HumanModel,
JobModel,
JobStatus,
PersonaModel,
ToolModel,
)
from memgpt.settings import settings
from memgpt.utils import enforce_types, get_utc_time, printd
Base = declarative_base()
# Custom UUID type
class CommonUUID(TypeDecorator):
impl = CHAR
cache_ok = True
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID(as_uuid=True))
else:
return dialect.type_descriptor(CHAR())
def process_bind_param(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return str(value) # Convert UUID to string for SQLite
def process_result_value(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return uuid.UUID(value)
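CommonUUID keeps a single model definition portable across backends; a sketch of a table using it (table name hypothetical):

class ExampleModel(Base):
    __tablename__ = "example"
    # native UUID on Postgres, CHAR-backed str(uuid) on SQLite;
    # process_result_value converts reads back to uuid.UUID either way
    id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)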
class LLMConfigColumn(TypeDecorator):
"""Custom type for storing LLMConfig as JSON"""
@ -84,44 +116,48 @@ class UserModel(Base):
__tablename__ = "users"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# name = Column(String, nullable=False)
default_agent = Column(String)
# TODO: what is this?
policies_accepted = Column(Boolean, nullable=False, default=False)
def __repr__(self) -> str:
return f"<User(id='{self.id}' name='{self.name}')>"
return f"<User(id='{self.id}')>"
def to_record(self) -> User:
return User(id=self.id, name=self.name, created_at=self.created_at)
return User(
id=self.id,
# name=self.name
default_agent=self.default_agent,
policies_accepted=self.policies_accepted,
)
class APIKeyModel(Base):
class TokenModel(Base):
"""Data model for authentication tokens. One-to-many relationship with UserModel (1 User - N tokens)."""
__tablename__ = "tokens"
id = Column(String, primary_key=True)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# each api key is tied to a user account (that it validates access for)
user_id = Column(String, nullable=False)
user_id = Column(CommonUUID, nullable=False)
# the api key
key = Column(String, nullable=False)
token = Column(String, nullable=False)
# extra (optional) metadata
name = Column(String)
Index(__tablename__ + "_idx_user", user_id),
Index(__tablename__ + "_idx_key", key),
Index(__tablename__ + "_idx_token", token),
def __repr__(self) -> str:
return f"<APIKey(id='{self.id}', key='{self.key}', name='{self.name}')>"
return f"<Token(id='{self.id}', token='{self.token}', name='{self.name}')>"
def to_record(self) -> User:
return APIKey(
return Token(
id=self.id,
user_id=self.user_id,
key=self.key,
token=self.token,
name=self.name,
)
@ -140,24 +176,19 @@ class AgentModel(Base):
__tablename__ = "agents"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
description = Column(String)
# state (context compilation)
message_ids = Column(JSON)
memory = Column(JSON)
system = Column(String)
tools = Column(JSON)
created_at = Column(DateTime(timezone=True), server_default=func.now())
# configs
llm_config = Column(LLMConfigColumn)
embedding_config = Column(EmbeddingConfigColumn)
# state
metadata_ = Column(JSON)
state = Column(JSON)
_metadata = Column(JSON)
# tools
tools = Column(JSON)
@ -173,14 +204,12 @@ class AgentModel(Base):
user_id=self.user_id,
name=self.name,
created_at=self.created_at,
description=self.description,
message_ids=self.message_ids,
memory=Memory.load(self.memory), # load dictionary
system=self.system,
tools=self.tools,
llm_config=self.llm_config,
embedding_config=self.embedding_config,
metadata_=self.metadata_,
state=self.state,
tools=self.tools,
system=self.system,
_metadata=self._metadata,
)
@ -192,13 +221,13 @@ class SourceModel(Base):
# Assuming passage_id is the primary key
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True), server_default=func.now())
embedding_config = Column(EmbeddingConfigColumn)
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
description = Column(String)
metadata_ = Column(JSON)
Index(__tablename__ + "_idx_user", user_id),
# TODO: add num passages
@ -212,9 +241,9 @@ class SourceModel(Base):
user_id=self.user_id,
name=self.name,
created_at=self.created_at,
embedding_config=self.embedding_config,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
description=self.description,
metadata_=self.metadata_,
)
@ -223,116 +252,80 @@ class AgentSourceMappingModel(Base):
__tablename__ = "agent_source_mapping"
id = Column(String, primary_key=True)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
source_id = Column(String, nullable=False)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
agent_id = Column(CommonUUID, nullable=False)
source_id = Column(CommonUUID, nullable=False)
Index(__tablename__ + "_idx_user", user_id, agent_id, source_id),
def __repr__(self) -> str:
return f"<AgentSourceMapping(user_id='{self.user_id}', agent_id='{self.agent_id}', source_id='{self.source_id}')>"
class BlockModel(Base):
__tablename__ = "block"
class PresetSourceMapping(Base):
__tablename__ = "preset_source_mapping"
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
preset_id = Column(CommonUUID, nullable=False)
source_id = Column(CommonUUID, nullable=False)
Index(__tablename__ + "_idx_user", user_id, preset_id, source_id),
def __repr__(self) -> str:
return f"<PresetSourceMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', source_id='{self.source_id}')>"
# class PresetFunctionMapping(Base):
# __tablename__ = "preset_function_mapping"
#
# id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# user_id = Column(CommonUUID, nullable=False)
# preset_id = Column(CommonUUID, nullable=False)
# #function_id = Column(CommonUUID, nullable=False)
# function = Column(String, nullable=False) # TODO: convert to ID eventually
#
# def __repr__(self) -> str:
# return f"<PresetFunctionMapping(user_id='{self.user_id}', preset_id='{self.preset_id}', function_id='{self.function_id}')>"
class PresetModel(Base):
"""Defines data model for storing Preset objects"""
__tablename__ = "presets"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True, nullable=False)
value = Column(String, nullable=False)
limit = Column(BIGINT)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
template = Column(Boolean, default=False) # True: listed as possible human/persona
label = Column(String)
metadata_ = Column(JSON)
description = Column(String)
user_id = Column(String)
system = Column(String)
human = Column(String)
human_name = Column(String, nullable=False)
persona = Column(String)
persona_name = Column(String, nullable=False)
preset = Column(String)
created_at = Column(DateTime(timezone=True), server_default=func.now())
functions_schema = Column(JSON)
Index(__tablename__ + "_idx_user", user_id),
def __repr__(self) -> str:
return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"
return f"<Preset(id='{self.id}', name='{self.name}')>"
def to_record(self) -> Block:
if self.label == "persona":
return Persona(
id=self.id,
value=self.value,
limit=self.limit,
name=self.name,
template=self.template,
label=self.label,
metadata_=self.metadata_,
description=self.description,
user_id=self.user_id,
)
elif self.label == "human":
return Human(
id=self.id,
value=self.value,
limit=self.limit,
name=self.name,
template=self.template,
label=self.label,
metadata_=self.metadata_,
description=self.description,
user_id=self.user_id,
)
else:
raise ValueError(f"Block with label {self.label} is not supported")
class ToolModel(Base):
__tablename__ = "tools"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
user_id = Column(String)
description = Column(String)
source_type = Column(String)
source_code = Column(String)
json_schema = Column(JSON)
module = Column(String)
tags = Column(JSON)
def __repr__(self) -> str:
return f"<Tool(id='{self.id}', name='{self.name}')>"
def to_record(self) -> Tool:
return Tool(
def to_record(self) -> Preset:
return Preset(
id=self.id,
user_id=self.user_id,
name=self.name,
user_id=self.user_id,
description=self.description,
source_type=self.source_type,
source_code=self.source_code,
json_schema=self.json_schema,
module=self.module,
tags=self.tags,
)
class JobModel(Base):
__tablename__ = "jobs"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
user_id = Column(String)
status = Column(String, default=JobStatus.pending)
created_at = Column(DateTime(timezone=True), server_default=func.now())
completed_at = Column(DateTime(timezone=True), onupdate=func.now())
metadata_ = Column(JSON)
def __repr__(self) -> str:
return f"<Job(id='{self.id}', status='{self.status}')>"
def to_record(self):
return Job(
id=self.id,
user_id=self.user_id,
status=self.status,
system=self.system,
human=self.human,
persona=self.persona,
human_name=self.human_name,
persona_name=self.persona_name,
preset=self.preset,
created_at=self.created_at,
completed_at=self.completed_at,
metadata_=self.metadata_,
functions_schema=self.functions_schema,
)
@ -364,8 +357,11 @@ class MetadataStore:
AgentModel.__table__,
SourceModel.__table__,
AgentSourceMappingModel.__table__,
APIKeyModel.__table__,
BlockModel.__table__,
TokenModel.__table__,
PresetModel.__table__,
PresetSourceMapping.__table__,
HumanModel.__table__,
PersonaModel.__table__,
ToolModel.__table__,
JobModel.__table__,
],
@ -391,17 +387,16 @@ class MetadataStore:
self.session_maker = sessionmaker(bind=self.engine)
@enforce_types
def create_api_key(self, user_id: str, name: str) -> APIKey:
def create_api_key(self, user_id: uuid.UUID, name: Optional[str] = None) -> Token:
"""Create an API key for a user"""
new_api_key = generate_api_key()
with self.session_maker() as session:
if session.query(APIKeyModel).filter(APIKeyModel.key == new_api_key).count() > 0:
if session.query(TokenModel).filter(TokenModel.token == new_api_key).count() > 0:
# NOTE: duplicate API keys / tokens should never happen, but if one does, don't allow it
raise ValueError(f"Token {new_api_key} already exists")
# TODO store the API keys as hashed
assert user_id and name, "User ID and name must be provided"
token = APIKey(user_id=user_id, key=new_api_key, name=name)
session.add(APIKeyModel(**vars(token)))
token = Token(user_id=user_id, token=new_api_key, name=name)
session.add(TokenModel(**vars(token)))
session.commit()
return self.get_api_key(api_key=new_api_key)
@ -409,22 +404,22 @@ class MetadataStore:
def delete_api_key(self, api_key: str):
"""Delete an API key from the database"""
with self.session_maker() as session:
session.query(APIKeyModel).filter(APIKeyModel.key == api_key).delete()
session.query(TokenModel).filter(TokenModel.token == api_key).delete()
session.commit()
@enforce_types
def get_api_key(self, api_key: str) -> Optional[APIKey]:
def get_api_key(self, api_key: str) -> Optional[Token]:
with self.session_maker() as session:
results = session.query(APIKeyModel).filter(APIKeyModel.key == api_key).all()
results = session.query(TokenModel).filter(TokenModel.token == api_key).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}" # should only be one result
return results[0].to_record()
@enforce_types
def get_all_api_keys_for_user(self, user_id: str) -> List[APIKey]:
def get_all_api_keys_for_user(self, user_id: uuid.UUID) -> List[Token]:
with self.session_maker() as session:
results = session.query(APIKeyModel).filter(APIKeyModel.user_id == user_id).all()
results = session.query(TokenModel).filter(TokenModel.user_id == user_id).all()
tokens = [r.to_record() for r in results]
return tokens
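Taken together, the token methods implement a simple API-key lifecycle; a sketch (ms is a configured MetadataStore, user_id an existing user's UUID):

token = ms.create_api_key(user_id=user_id, name="cli-key")  # returns the stored Token record
assert ms.get_api_key(api_key=token.token) is not None
assert len(ms.get_all_api_keys_for_user(user_id=user_id)) >= 1
ms.delete_api_key(token.token)
assert ms.get_api_key(api_key=token.token) is None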
@ -441,20 +436,25 @@ class MetadataStore:
def create_agent(self, agent: AgentState):
# insert into agent table
# make sure agent.name does not already exist for user user_id
assert agent.state is not None, "Agent state must be provided"
assert len(list(agent.state.keys())) > 0, "Agent state must not be empty"
with self.session_maker() as session:
if session.query(AgentModel).filter(AgentModel.name == agent.name).filter(AgentModel.user_id == agent.user_id).count() > 0:
raise ValueError(f"Agent with name {agent.name} already exists")
fields = vars(agent)
fields["memory"] = agent.memory.to_dict()
session.add(AgentModel(**fields))
session.add(AgentModel(**vars(agent)))
session.commit()
@enforce_types
def create_source(self, source: Source):
def create_source(self, source: Source, exists_ok=False):
# make sure source.name does not already exist for user
with self.session_maker() as session:
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
session.add(SourceModel(**vars(source)))
if not exists_ok:
raise ValueError(f"Source with name {source.name} already exists for user {source.user_id}")
else:
session.merge(SourceModel(**vars(source)))  # upsert by primary key (Session has no .update() method)
else:
session.add(SourceModel(**vars(source)))
session.commit()
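A sketch of the new exists_ok path (Source construction mirrors its use in migrate.py further down):

source = Source(user_id=user_id, name="docs")
ms.create_source(source)                  # first call inserts
ms.create_source(source, exists_ok=True)  # re-run takes the update branch instead of raising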
@enforce_types
@ -466,40 +466,67 @@ class MetadataStore:
session.commit()
@enforce_types
def create_block(self, block: Block):
def create_preset(self, preset: Preset):
with self.session_maker() as session:
# TODO: fix?
# we are only validating that more than one template block
# with a given name doesn't exist.
if (
session.query(BlockModel)
.filter(BlockModel.name == block.name)
.filter(BlockModel.user_id == block.user_id)
.filter(BlockModel.template == True)
.filter(BlockModel.label == block.label)
.count()
> 0
):
raise ValueError(f"Block with name {block.name} already exists")
session.add(BlockModel(**vars(block)))
if session.query(PresetModel).filter(PresetModel.id == preset.id).count() > 0:
raise ValueError(f"User with id {preset.id} already exists")
session.add(PresetModel(**vars(preset)))
session.commit()
@enforce_types
def create_tool(self, tool: Tool):
def get_preset(
self, preset_id: Optional[uuid.UUID] = None, name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[Preset]:
with self.session_maker() as session:
if self.get_tool(tool_name=tool.name, user_id=tool.user_id) is not None:
raise ValueError(f"Tool with name {tool.name} already exists")
session.add(ToolModel(**vars(tool)))
if preset_id:
results = session.query(PresetModel).filter(PresetModel.id == preset_id).all()
elif name and user_id:
results = session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).all()
else:
raise ValueError("Must provide either preset_id or (preset_name and user_id)")
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
# @enforce_types
# def set_preset_functions(self, preset_id: uuid.UUID, functions: List[str]):
# preset = self.get_preset(preset_id)
# if preset is None:
# raise ValueError(f"Preset with id {preset_id} does not exist")
# user_id = preset.user_id
# with self.session_maker() as session:
# for function in functions:
# session.add(PresetFunctionMapping(user_id=user_id, preset_id=preset_id, function=function))
# session.commit()
@enforce_types
def set_preset_sources(self, preset_id: uuid.UUID, sources: List[uuid.UUID]):
preset = self.get_preset(preset_id)
if preset is None:
raise ValueError(f"Preset with id {preset_id} does not exist")
user_id = preset.user_id
with self.session_maker() as session:
for source_id in sources:
session.add(PresetSourceMapping(user_id=user_id, preset_id=preset_id, source_id=source_id))
session.commit()
# @enforce_types
# def get_preset_functions(self, preset_id: uuid.UUID) -> List[str]:
# with self.session_maker() as session:
# results = session.query(PresetFunctionMapping).filter(PresetFunctionMapping.preset_id == preset_id).all()
# return [r.function for r in results]
@enforce_types
def get_preset_sources(self, preset_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(PresetSourceMapping).filter(PresetSourceMapping.preset_id == preset_id).all()
return [r.source_id for r in results]
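A sketch of the preset-source mapping round trip defined above (preset and source records assumed to exist):

ms.set_preset_sources(preset.id, sources=[source_a.id, source_b.id])
assert set(ms.get_preset_sources(preset.id)) == {source_a.id, source_b.id}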
@enforce_types
def update_agent(self, agent: AgentState):
with self.session_maker() as session:
fields = vars(agent)
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
fields["memory"] = agent.memory.to_dict()
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.query(AgentModel).filter(AgentModel.id == agent.id).update(vars(agent))
session.commit()
@enforce_types
@ -515,41 +542,28 @@ class MetadataStore:
session.commit()
@enforce_types
def update_block(self, block: Block):
def update_human(self, human: HumanModel):
with self.session_maker() as session:
session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
session.add(human)
session.commit()
session.refresh(human)
@enforce_types
def update_or_create_block(self, block: Block):
def update_persona(self, persona: PersonaModel):
with self.session_maker() as session:
existing_block = session.query(BlockModel).filter(BlockModel.id == block.id).first()
if existing_block:
session.query(BlockModel).filter(BlockModel.id == block.id).update(vars(block))
else:
session.add(BlockModel(**vars(block)))
session.add(persona)
session.commit()
session.refresh(persona)
@enforce_types
def update_tool(self, tool: Tool):
def update_tool(self, tool: ToolModel):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool.id).update(vars(tool))
session.add(tool)
session.commit()
session.refresh(tool)
@enforce_types
def delete_tool(self, tool_id: str):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.id == tool_id).delete()
session.commit()
@enforce_types
def delete_block(self, block_id: str):
with self.session_maker() as session:
session.query(BlockModel).filter(BlockModel.id == block_id).delete()
session.commit()
@enforce_types
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: uuid.UUID):
with self.session_maker() as session:
# delete agents
@ -561,7 +575,7 @@ class MetadataStore:
session.commit()
@enforce_types
def delete_source(self, source_id: str):
def delete_source(self, source_id: uuid.UUID):
with self.session_maker() as session:
# delete from sources table
session.query(SourceModel).filter(SourceModel.id == source_id).delete()
@ -572,7 +586,7 @@ class MetadataStore:
session.commit()
@enforce_types
def delete_user(self, user_id: str):
def delete_user(self, user_id: uuid.UUID):
with self.session_maker() as session:
# delete from users table
session.query(UserModel).filter(UserModel.id == user_id).delete()
@ -589,30 +603,42 @@ class MetadataStore:
session.commit()
@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can create tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
def list_presets(self, user_id: uuid.UUID) -> List[Preset]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return [r.to_record() for r in results]
@enforce_types
# def list_tools(self, user_id: uuid.UUID) -> List[ToolModel]: # TODO: add when users can create tools
def list_tools(self, user_id: Optional[uuid.UUID] = None) -> List[ToolModel]:
with self.session_maker() as session:
results = session.query(ToolModel).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.user_id == user_id).all()
res = [r.to_record() for r in results]
return res
return results
@enforce_types
def list_agents(self, user_id: str) -> List[AgentState]:
def list_agents(self, user_id: uuid.UUID) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).filter(AgentModel.user_id == user_id).all()
return [r.to_record() for r in results]
@enforce_types
def list_sources(self, user_id: str) -> List[Source]:
def list_all_agents(self) -> List[AgentState]:
with self.session_maker() as session:
results = session.query(AgentModel).all()
return [r.to_record() for r in results]
@enforce_types
def list_sources(self, user_id: uuid.UUID) -> List[Source]:
with self.session_maker() as session:
results = session.query(SourceModel).filter(SourceModel.user_id == user_id).all()
return [r.to_record() for r in results]
@enforce_types
def get_agent(
self, agent_id: Optional[str] = None, agent_name: Optional[str] = None, user_id: Optional[str] = None
self, agent_id: Optional[uuid.UUID] = None, agent_name: Optional[str] = None, user_id: Optional[uuid.UUID] = None
) -> Optional[AgentState]:
with self.session_maker() as session:
if agent_id:
@ -627,7 +653,7 @@ class MetadataStore:
return results[0].to_record()
@enforce_types
def get_user(self, user_id: str) -> Optional[User]:
def get_user(self, user_id: uuid.UUID) -> Optional[User]:
with self.session_maker() as session:
results = session.query(UserModel).filter(UserModel.id == user_id).all()
if len(results) == 0:
@ -636,7 +662,7 @@ class MetadataStore:
return results[0].to_record()
@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
def get_all_users(self, cursor: Optional[uuid.UUID] = None, limit: Optional[int] = 50) -> (Optional[uuid.UUID], List[User]):
with self.session_maker() as session:
query = session.query(UserModel).order_by(desc(UserModel.id))
if cursor:
@ -646,13 +672,13 @@ class MetadataStore:
return None, []
user_records = [r.to_record() for r in results]
next_cursor = user_records[-1].id
assert isinstance(next_cursor, str)
assert isinstance(next_cursor, uuid.UUID)
return next_cursor, user_records
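get_all_users pages by descending id, handing back the last id as the cursor; a sketch of draining every page:

cursor, page = ms.get_all_users(limit=50)
while page:
    for user in page:
        print(user.id)
    cursor, page = ms.get_all_users(cursor=cursor, limit=50)  # (None, []) ends the loop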
@enforce_types
def get_source(
self, source_id: Optional[str] = None, user_id: Optional[str] = None, source_name: Optional[str] = None
self, source_id: Optional[uuid.UUID] = None, user_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None
) -> Optional[Source]:
with self.session_maker() as session:
if source_id:
@ -666,89 +692,42 @@ class MetadataStore:
return results[0].to_record()
@enforce_types
def get_tool(
self, tool_name: Optional[str] = None, tool_id: Optional[str] = None, user_id: Optional[str] = None
) -> Optional[ToolModel]:
def get_tool(self, tool_name: str, user_id: Optional[uuid.UUID] = None) -> Optional[ToolModel]:
# TODO: add user_id when tools can eventually be added by users
with self.session_maker() as session:
if tool_id:
results = session.query(ToolModel).filter(ToolModel.id == tool_id).all()
else:
assert tool_name is not None
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_block(self, block_id: str) -> Optional[Block]:
with self.session_maker() as session:
results = session.query(BlockModel).filter(BlockModel.id == block_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_blocks(
self,
user_id: Optional[str],
label: Optional[str] = None,
template: bool = True,
name: Optional[str] = None,
id: Optional[str] = None,
) -> List[Block]:
"""List available blocks"""
with self.session_maker() as session:
query = session.query(BlockModel).filter(BlockModel.template == template)
results = session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == None).all()
if user_id:
query = query.filter(BlockModel.user_id == user_id)
if label:
query = query.filter(BlockModel.label == label)
if name:
query = query.filter(BlockModel.name == name)
if id:
query = query.filter(BlockModel.id == id)
results = query.all()
results += session.query(ToolModel).filter(ToolModel.name == tool_name).filter(ToolModel.user_id == user_id).all()
if len(results) == 0:
return None
return [r.to_record() for r in results]
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
# agent source metadata
@enforce_types
def attach_source(self, user_id: str, agent_id: str, source_id: str):
def attach_source(self, user_id: uuid.UUID, agent_id: uuid.UUID, source_id: uuid.UUID):
with self.session_maker() as session:
# TODO: remove this (is a hack)
mapping_id = f"{user_id}-{agent_id}-{source_id}"
session.add(AgentSourceMappingModel(id=mapping_id, user_id=user_id, agent_id=agent_id, source_id=source_id))
session.add(AgentSourceMappingModel(user_id=user_id, agent_id=agent_id, source_id=source_id))
session.commit()
@enforce_types
def list_attached_sources(self, agent_id: str) -> List[Source]:
def list_attached_sources(self, agent_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.agent_id == agent_id).all()
sources = []
source_ids = []
# make sure source exists
for r in results:
source = self.get_source(source_id=r.source_id)
if source:
sources.append(source)
source_ids.append(r.source_id)
else:
printd(f"Warning: source {r.source_id} does not exist but exists in mapping database. This should never happen.")
return sources
return source_ids
@enforce_types
def list_attached_agents(self, source_id: str) -> List[str]:
def list_attached_agents(self, source_id: uuid.UUID) -> List[uuid.UUID]:
with self.session_maker() as session:
results = session.query(AgentSourceMappingModel).filter(AgentSourceMappingModel.source_id == source_id).all()
@ -763,7 +742,7 @@ class MetadataStore:
return agent_ids
@enforce_types
def detach_source(self, agent_id: str, source_id: str):
def detach_source(self, agent_id: uuid.UUID, source_id: uuid.UUID):
with self.session_maker() as session:
session.query(AgentSourceMappingModel).filter(
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
@ -771,38 +750,120 @@ class MetadataStore:
session.commit()
@enforce_types
def create_job(self, job: Job):
def add_human(self, human: HumanModel):
with self.session_maker() as session:
session.add(JobModel(**vars(job)))
if self.get_human(human.name, human.user_id):
raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}")
session.add(human)
session.commit()
def delete_job(self, job_id: str):
@enforce_types
def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).delete()
if self.get_persona(persona.name, persona.user_id):
raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}")
session.add(persona)
session.commit()
def get_job(self, job_id: str) -> Optional[Job]:
@enforce_types
def add_preset(self, preset: PresetModel): # TODO: remove
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.id == job_id).all()
session.add(preset)
session.commit()
@enforce_types
def add_tool(self, tool: ToolModel):
with self.session_maker() as session:
if self.get_tool(tool.name, tool.user_id):
raise ValueError(f"Tool with name {tool.name} already exists for user_id {tool.user_id}")
session.add(tool)
session.commit()
@enforce_types
def get_human(self, name: str, user_id: uuid.UUID) -> Optional[HumanModel]:
with self.session_maker() as session:
results = session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
return results[0]
def list_jobs(self, user_id: str) -> List[Job]:
@enforce_types
def get_persona(self, name: str, user_id: uuid.UUID) -> Optional[PersonaModel]:
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.user_id == user_id).all()
return [r.to_record() for r in results]
results = session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
def update_job(self, job: Job) -> Job:
@enforce_types
def list_personas(self, user_id: uuid.UUID) -> List[PersonaModel]:
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job.id).update(vars(job))
results = session.query(PersonaModel).filter(PersonaModel.user_id == user_id).all()
return results
@enforce_types
def list_humans(self, user_id: uuid.UUID) -> List[HumanModel]:
with self.session_maker() as session:
# if user_id matches provided user_id or if user_id is None
results = session.query(HumanModel).filter(HumanModel.user_id == user_id).all()
return results
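The human/persona helpers are all keyed on (name, user_id); a sketch of the guarded-add pattern they encourage (human is a HumanModel built elsewhere):

if ms.get_human(name="basic", user_id=user_id) is None:
    ms.add_human(human)  # add_human itself raises ValueError on a duplicate (name, user_id)
for persona in ms.list_personas(user_id=user_id):
    print(persona.name)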
@enforce_types
def list_presets(self, user_id: uuid.UUID) -> List[PresetModel]:
with self.session_maker() as session:
results = session.query(PresetModel).filter(PresetModel.user_id == user_id).all()
return results
@enforce_types
def delete_human(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(HumanModel).filter(HumanModel.name == name).filter(HumanModel.user_id == user_id).delete()
session.commit()
return job
def update_job_status(self, job_id: str, status: JobStatus):
@enforce_types
def delete_persona(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PersonaModel).filter(PersonaModel.name == name).filter(PersonaModel.user_id == user_id).delete()
session.commit()
@enforce_types
def delete_preset(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(PresetModel).filter(PresetModel.name == name).filter(PresetModel.user_id == user_id).delete()
session.commit()
@enforce_types
def delete_tool(self, name: str, user_id: uuid.UUID):
with self.session_maker() as session:
session.query(ToolModel).filter(ToolModel.name == name).filter(ToolModel.user_id == user_id).delete()
session.commit()
# job related functions
def create_job(self, job: JobModel):
with self.session_maker() as session:
session.add(job)
session.commit()
session.expunge_all()
def update_job_status(self, job_id: uuid.UUID, status: JobStatus):
with self.session_maker() as session:
session.query(JobModel).filter(JobModel.id == job_id).update({"status": status})
if status == JobStatus.COMPLETED:
session.query(JobModel).filter(JobModel.id == job_id).update({"completed_at": get_utc_time()})
session.commit()
def update_job(self, job: JobModel):
with self.session_maker() as session:
session.add(job)
session.commit()
session.refresh(job)
def get_job(self, job_id: uuid.UUID) -> Optional[JobModel]:
with self.session_maker() as session:
results = session.query(JobModel).filter(JobModel.id == job_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0]
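A sketch of the job lifecycle these methods implement (JobModel field names assumed; update_job_status also stamps completed_at on completion, as shown above):

job = JobModel(user_id=user_id)  # status defaults to pending
ms.create_job(job)
ms.update_job_status(job.id, JobStatus.COMPLETED)
done = ms.get_job(job.id)
assert done is not None and done.status == JobStatus.COMPLETED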

716
memgpt/migrate.py Normal file
View File

@ -0,0 +1,716 @@
import configparser
import glob
import json
import os
import pickle
import shutil
import sys
import traceback
import uuid
from datetime import datetime
from typing import List, Optional
import pytz
import questionary
import typer
from tqdm import tqdm
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig
from memgpt.data_types import AgentState, Message, Passage, Source, User
from memgpt.metadata import MetadataStore
from memgpt.persistence_manager import LocalStateManager
from memgpt.utils import (
MEMGPT_DIR,
OpenAIBackcompatUnpickler,
annotate_message_json_list_with_tool_calls,
get_utc_time,
parse_formatted_time,
version_less_than,
)
# This is the version where the breaking change was made
VERSION_CUTOFF = "0.2.12"
# Migration backup dir (where we'll dump old agents that we successfully migrated)
MIGRATION_BACKUP_FOLDER = "migration_backups"
def wipe_config_and_reconfigure(data_dir: str = MEMGPT_DIR, run_configure=True, create_config=True):
"""Wipe (backup) the config file, and launch `memgpt configure`"""
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
# Get the current timestamp in a readable format (e.g., YYYYMMDD_HHMMSS)
timestamp = get_utc_time().strftime("%Y%m%d_%H%M%S")
# Construct the new backup directory name with the timestamp
backup_filename = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, f"config_backup_{timestamp}")
existing_filename = os.path.join(data_dir, "config")
# Check if the existing file exists before moving
if os.path.exists(existing_filename):
# shutil should work cross-platform
shutil.move(existing_filename, backup_filename)
typer.secho(f"Deleted config file ({existing_filename}) and saved as backup ({backup_filename})", fg=typer.colors.GREEN)
else:
typer.secho(f"Couldn't find an existing config file to delete", fg=typer.colors.RED)
if run_configure:
# Either run configure
configure()
elif create_config:
# Or create a new config with defaults
MemGPTConfig.load()
def config_is_compatible(data_dir: str = MEMGPT_DIR, allow_empty=False, echo=False) -> bool:
"""Check if the config is OK to use with 0.2.12, or if it needs to be deleted"""
# NOTE: don't use built-in load(), since that will apply defaults
# memgpt_config = MemGPTConfig.load()
memgpt_config_file = os.path.join(data_dir, "config")
if not os.path.exists(memgpt_config_file):
return True if allow_empty else False
parser = configparser.ConfigParser()
parser.read(memgpt_config_file)
if "version" in parser and "memgpt_version" in parser["version"]:
version = parser["version"]["memgpt_version"]
else:
version = None
if version is None:
# no version -- assume pre-determined config (does not need to be migrated)
return True
elif version_less_than(version, VERSION_CUTOFF):
if echo:
typer.secho(f"Current config version ({version}) is older than migration cutoff ({VERSION_CUTOFF})", fg=typer.colors.RED)
return False
else:
if echo:
typer.secho(f"Current config version {version} is compatible!", fg=typer.colors.GREEN)
return True
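A sketch of how the two config helpers compose at migration time (flags as defined above):

if not config_is_compatible(echo=True):
    # back up the old config file, then recreate defaults without prompting
    wipe_config_and_reconfigure(run_configure=False, create_config=True)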
def agent_is_migrateable(agent_name: str, data_dir: str = MEMGPT_DIR) -> bool:
"""Determine whether or not the agent folder is a migration target"""
agent_folder = os.path.join(data_dir, "agents", agent_name)
if not os.path.exists(agent_folder):
raise ValueError(f"Folder {agent_folder} does not exist")
agent_config_file = os.path.join(agent_folder, "config.json")
if not os.path.exists(agent_config_file):
raise ValueError(f"Agent folder {agent_folder} does not have a config file")
try:
with open(agent_config_file, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)
except Exception as e:
raise ValueError(f"Failed to load agent config file ({agent_config_file}), error = {e}")
if "memgpt_version" not in agent_config or version_less_than(agent_config["memgpt_version"], VERSION_CUTOFF):
return True
else:
return False
def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None):
"""
Migrate an old source folder (`~/.memgpt/sources/{source_name}`).
"""
# 1. Load the VectorIndex from ~/.memgpt/sources/{source_name}/index
# TODO
source_path = os.path.join(data_dir, "archival", source_name, "nodes.pkl")
assert os.path.exists(source_path), f"Source {source_name} does not exist at {source_path}"
# load state from old checkpoint file
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# insert source into metadata store
source = Source(user_id=user.id, name=source_name)
ms.create_source(source)
try:
try:
nodes = pickle.load(open(source_path, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
# cannot load source at all, so throw error
raise ValueError(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
else:
raise e
passages = []
for node in nodes:
# print(len(node.embedding))
# TODO: make sure embedding config matches embedding size?
if len(node.embedding) != config.default_embedding_config.embedding_dim:
raise ValueError(
f"Cannot migrate source {source_name} due to incompatible embedding dimentions. Please re-load this source with `memgpt load`."
)
passages.append(
Passage(
user_id=user.id,
data_source=source_name,
text=node.text,
embedding=node.embedding,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert len(passages) > 0, f"Source {source_name} has no passages"
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config=config, user_id=user_id)
conn.insert_many(passages)
# print(f"Inserted {len(passages)} to {source_name}")
except Exception as e:
# delete from metadata store
ms.delete_source(source.id)
raise ValueError(f"Failed to migrate {source_name}: {str(e)}")
# basic checks
source = ms.get_source(user_id=user.id, source_name=source_name)
assert source is not None, f"Failed to load source {source_name} from database after migration"
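migrate_source is self-contained per source name; a sketch of driving it over every legacy source folder (the archival/ layout matches source_path above):

config = MemGPTConfig.load()
ms = MetadataStore(config)
for name in os.listdir(os.path.join(MEMGPT_DIR, "archival")):
    migrate_source(name, ms=ms)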
def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[MetadataStore] = None) -> List[str]:
"""Migrate an old agent folder (`~/.memgpt/agents/{agent_name}`)
Steps:
1. Load the agent state JSON from the old folder
2. Create a new AgentState using the agent config + agent internal state
3. Instantiate a new Agent by passing AgentState to Agent.__init__
(This will automatically run into a new database)
If success, returns empty list
If warning, returns a list of strings (warning message)
If error, raises an Exception
"""
warnings = []
# 1. Load the agent state JSON from the old folder
# TODO
agent_folder = os.path.join(data_dir, "agents", agent_name)
# migration_file = os.path.join(agent_folder, MIGRATION_FILE_NAME)
# load state from old checkpoint file
agent_ckpt_directory = os.path.join(agent_folder, "agent_state")
json_files = glob.glob(os.path.join(agent_ckpt_directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}")
# NOTE this is a soft fail, just allow it to pass
# return
# return [f"Cannot load {agent_name} - no saved checkpoints found in {agent_ckpt_directory}"]
# Sort files based on modified timestamp, with the latest file being the first.
state_filename = max(json_files, key=os.path.getmtime)
state_dict = json.load(open(state_filename, "r"))
# print(state_dict.keys())
# print(state_dict["memory"])
# dict_keys(['model', 'system', 'functions', 'messages', 'messages_total', 'memory'])
# load old data from the persistence manager
persistence_filename = os.path.basename(state_filename).replace(".json", ".persistence.pickle")
persistence_filename = os.path.join(agent_folder, "persistence_manager", persistence_filename)
archival_filename = os.path.join(agent_folder, "persistence_manager", "index", "nodes.pkl")
if not os.path.exists(persistence_filename):
raise ValueError(f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}")
# return [f"Cannot load {agent_name} - no saved persistence pickle found at {persistence_filename}"]
try:
with open(persistence_filename, "rb") as f:
data = pickle.load(f)
except ModuleNotFoundError:
# Patch for stripped openai package
# ModuleNotFoundError: No module named 'openai.openai_object'
with open(persistence_filename, "rb") as f:
unpickler = OpenAIBackcompatUnpickler(f)
data = unpickler.load()
from memgpt.openai_backcompat.openai_object import OpenAIObject
def convert_openai_objects_to_dict(obj):
if isinstance(obj, OpenAIObject):
# Convert to dict or handle as needed
# print(f"detected OpenAIObject on {obj}")
return obj.to_dict_recursive()
elif isinstance(obj, dict):
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [convert_openai_objects_to_dict(v) for v in obj]
else:
return obj
data = convert_openai_objects_to_dict(data)
# data will contain:
# print("data.keys()", data.keys())
# manager.all_messages = data["all_messages"]
# manager.messages = data["messages"]
# manager.recall_memory = data["recall_memory"]
agent_config_filename = os.path.join(agent_folder, "config.json")
with open(agent_config_filename, "r", encoding="utf-8") as fh:
agent_config = json.load(fh)
# 2. Create a new AgentState using the agent config + agent internal state
config = MemGPTConfig.load()
if ms is None:
ms = MetadataStore(config)
# gets default user
user_id = uuid.UUID(config.anon_clientid)
user = ms.get_user(user_id=user_id)
if user is None:
ms.create_user(User(id=user_id))
user = ms.get_user(user_id=user_id)
if user is None:
typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
sys.exit(1)
# raise ValueError(
# f"Failed to load user {str(user_id)} from database. Please make sure to migrate your config before migrating agents."
# )
# ms.create_user(User(id=user_id))
# user = ms.get_user(user_id=user_id)
# if user is None:
# typer.secho(f"Failed to create default user in database.", fg=typer.colors.RED)
# sys.exit(1)
# create an agent_id ahead of time
agent_id = uuid.uuid4()
# create all the Messages in the database
# message_objs = []
# for message_dict in annotate_message_json_list_with_tool_calls(state_dict["messages"]):
# message_obj = Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=message_dict,
# model=state_dict["model"] if "model" in state_dict else None,
# # allow_functions_style=False,
# allow_functions_style=True,
# )
# message_objs.append(message_obj)
agent_state = AgentState(
id=agent_id,
name=agent_config["name"],
user_id=user.id,
# persona_name=agent_config["persona"], # eg 'sam_pov'
# human_name=agent_config["human"], # eg 'basic'
persona=state_dict["memory"]["persona"], # NOTE: hacky (not init, but latest)
human=state_dict["memory"]["human"], # NOTE: hacky (not init, but latest)
preset=agent_config["preset"], # eg 'memgpt_chat'
state=dict(
human=state_dict["memory"]["human"],
persona=state_dict["memory"]["persona"],
system=state_dict["system"],
functions=state_dict["functions"], # this shouldn't matter, since Agent.__init__ will re-link
# messages=[str(m.id) for m in message_objs], # this is a list of uuids, not message dicts
),
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
)
persistence_manager = LocalStateManager(agent_state=agent_state)
# First clean up the recall message history to add tool call ids
# allow_tool_roles in case some of the old messages were actually already in tool call format (for whatever reason)
full_message_history_buffer = annotate_message_json_list_with_tool_calls(
[d["message"] for d in data["all_messages"]], allow_tool_roles=True
)
for i in range(len(data["all_messages"])):
data["all_messages"][i]["message"] = full_message_history_buffer[i]
# Figure out what messages in recall are in-context, and which are out-of-context
agent_message_cache = state_dict["messages"]
recall_message_full = data["all_messages"]
def messages_are_equal(msg1, msg2):
if msg1["role"] != msg2["role"]:
return False
if msg1["content"] != msg2["content"]:
return False
if "function_call" in msg1 and "function_call" in msg2 and msg1["function_call"] != msg2["function_call"]:
return False
if "name" in msg1 and "name" in msg2 and msg1["name"] != msg2["name"]:
return False
# otherwise checks pass, ~= equal
return True
in_context_messages = []
out_of_context_messages = []
assert len(agent_message_cache) <= len(recall_message_full), (len(agent_message_cache), len(recall_message_full))
for i, d in enumerate(recall_message_full):
# unpack into "timestamp" and "message"
recall_message = d["message"]
recall_timestamp = str(d["timestamp"])
try:
recall_datetime = parse_formatted_time(recall_timestamp.strip()).astimezone(pytz.utc)
except ValueError:
recall_datetime = datetime.strptime(recall_timestamp.strip(), "%Y-%m-%d %I:%M:%S %p").astimezone(pytz.utc)
# message object
message_obj = Message.dict_to_message(
created_at=recall_datetime,
user_id=user.id,
agent_id=agent_id,
openai_message_dict=recall_message,
allow_functions_style=True,
)
# message is either in-context, or out-of-context
if i >= (len(recall_message_full) - len(agent_message_cache)):
# there are len(agent_message_cache) total messages on the agent
# this will correspond to the last N messages in the recall memory (though possibly out-of-order)
message_is_in_context = [messages_are_equal(recall_message, cache_message) for cache_message in agent_message_cache]
# assert sum(message_is_in_context) <= 1, message_is_in_context
# if any(message_is_in_context):
# in_context_messages.append(message_obj)
# else:
# out_of_context_messages.append(message_obj)
if not any(message_is_in_context):
# typer.secho(
# f"Warning: didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Didn't find late buffer recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
out_of_context_messages.append(message_obj)
else:
if sum(message_is_in_context) > 1:
# typer.secho(
# f"Warning: found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}",
# fg=typer.colors.RED,
# )
warnings.append(
f"Found multiple occurences of recall message (i={i}/{len(recall_message_full)-1}) inside agent context\n{recall_message}"
)
in_context_messages.append(message_obj)
else:
# if we're not in the final portion of the recall memory buffer, then it's 100% out-of-context
out_of_context_messages.append(message_obj)
assert len(in_context_messages) > 0, f"Couldn't find any in-context messages (agent_cache = {len(agent_message_cache)})"
# assert len(in_context_messages) == len(agent_message_cache), (len(in_context_messages), len(agent_message_cache))
if len(in_context_messages) != len(agent_message_cache):
# typer.secho(
# f"Warning: uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})",
# fg=typer.colors.RED,
# )
warnings.append(
f"Uneven match of new in-context messages vs loaded cache ({len(in_context_messages)} != {len(agent_message_cache)})"
)
# assert (
# len(in_context_messages) + len(out_of_context_messages) == state_dict["messages_total"]
# ), f"{len(in_context_messages)} + {len(out_of_context_messages)} != {state_dict['messages_total']}"
# Now we can insert the messages into the actual recall database
# So when we construct the agent from the state, they will be available
persistence_manager.recall_memory.insert_many(out_of_context_messages)
persistence_manager.recall_memory.insert_many(in_context_messages)
# Overwrite the agent_state message object
agent_state.state["messages"] = [str(m.id) for m in in_context_messages] # this is a list of uuids, not message dicts
## 4. Insert into recall
# TODO should this be 'messages', or 'all_messages'?
# all_messages in recall will have fields "timestamp" and "message"
# full_message_history_buffer = annotate_message_json_list_with_tool_calls([d["message"] for d in data["all_messages"]])
# We want to keep the timestamp
# for i in range(len(data["all_messages"])):
# data["all_messages"][i]["message"] = full_message_history_buffer[i]
# messages_to_insert = [
# Message.dict_to_message(
# user_id=user.id,
# agent_id=agent_id,
# openai_message_dict=msg,
# allow_functions_style=True,
# )
# # for msg in data["all_messages"]
# for msg in full_message_history_buffer
# ]
# agent.persistence_manager.recall_memory.insert_many(messages_to_insert)
# print("Finished migrating recall memory")
# 3. Instantiate a new Agent by passing AgentState to Agent.__init__
# NOTE: the Agent.__init__ will trigger a save, which will write to the DB
try:
agent = Agent(
agent_state=agent_state,
# messages_total=state_dict["messages_total"], # TODO: do we need this?
messages_total=len(in_context_messages) + len(out_of_context_messages),
interface=None,
)
save_agent(agent, ms=ms)
except Exception:
# if "Agent with name" in str(e):
# print(e)
# return
# elif "was specified in agent.state.functions":
# print(e)
# return
# else:
# raise
raise
# Wrap the rest in a try-except so that we can cleanup by deleting the agent if we fail
try:
# TODO should we also assign data["messages"] to RecallMemory.messages?
# 5. Insert into archival
if os.path.exists(archival_filename):
try:
with open(archival_filename, "rb") as f:
nodes = pickle.load(f)
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
print(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
nodes = []
else:
raise e
passages = []
failed_inserts = []
for node in nodes:
if len(node.embedding) != config.default_embedding_config.embedding_dim:
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
# raise ValueError(f"Cannot migrate agent {agent_state.name} due to incompatible embedding dimentions.")
failed_inserts.append(
f"Cannot migrate passage due to incompatible embedding dimentions ({len(node.embedding)} != {config.default_embedding_config.embedding_dim}) - content = '{node.text}'."
)
passages.append(
Passage(
user_id=user.id,
agent_id=agent_state.id,
text=node.text,
embedding=node.embedding,
embedding_dim=agent_state.embedding_config.embedding_dim,
embedding_model=agent_state.embedding_config.embedding_model,
)
)
if len(passages) > 0:
agent.persistence_manager.archival_memory.storage.insert_many(passages)
# print(f"Inserted {len(passages)} passages into archival memory")
if len(failed_inserts) > 0:
warnings.append(
f"Failed to transfer {len(failed_inserts)}/{len(nodes)} passages from old archival memory: " + ", ".join(failed_inserts)
)
else:
warnings.append("No archival memory found at", archival_filename)
except:
ms.delete_agent(agent_state.id)
raise
new_agent_folder = os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents", agent_name)
try:
shutil.move(agent_folder, new_agent_folder)
except Exception:
print(f"Failed to move agent folder from {agent_folder} to {new_agent_folder}")
raise
return warnings
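As a usage sketch, a single legacy agent could be migrated the same way `migrate_all_agents` below drives this function (the agent name here is a placeholder):

config = MemGPTConfig.load()
ms = MetadataStore(config)
migration_warnings = migrate_agent("my_old_agent", data_dir=MEMGPT_DIR, ms=ms)  # "my_old_agent" is a hypothetical folder name
for warning in migration_warnings:
    print(f"[migration warning] {warning}")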
# def migrate_all_agents(stop_on_fail=True):
def migrate_all_agents(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
if not os.path.exists(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER)):
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER))
os.makedirs(os.path.join(data_dir, MIGRATION_BACKUP_FOLDER, "agents"))
if not config_is_compatible(data_dir, echo=True):
typer.secho(f"Your current config file is incompatible with MemGPT versions >= {VERSION_CUTOFF}", fg=typer.colors.RED)
if questionary.confirm(
"To migrate old MemGPT agents, you must delete your config file and run `memgpt configure`. Would you like to proceed?"
).ask():
try:
wipe_config_and_reconfigure(data_dir)
except Exception as e:
typer.secho(f"Fresh config generation failed - error:\n{e}", fg=typer.colors.RED)
raise
else:
typer.secho("Migration cancelled (to migrate old agents, run `memgpt migrate`)", fg=typer.colors.RED)
raise KeyboardInterrupt()
agents_dir = os.path.join(data_dir, "agents")
# Ensure the directory exists
if not os.path.exists(agents_dir):
raise ValueError(f"Directory {agents_dir} does not exist.")
# Get a list of all folders in agents_dir
agent_folders = [f for f in os.listdir(agents_dir) if os.path.isdir(os.path.join(agents_dir, f))]
# Iterate over each folder with a tqdm progress bar
count = 0
successes = [] # agents that migrated w/o warnings
warnings = [] # agents that migrated but had warnings
failures = [] # agents that failed to migrate (fatal error)
candidates = []
config = MemGPTConfig.load()
print(config)
ms = MetadataStore(config)
try:
for agent_name in tqdm(agent_folders, desc="Migrating agents"):
# Assuming migrate_agent is a function that takes the agent name and performs migration
try:
if agent_is_migrateable(agent_name=agent_name, data_dir=data_dir):
candidates.append(agent_name)
migration_warnings = migrate_agent(agent_name, data_dir=data_dir, ms=ms)
if len(migration_warnings) == 0:
successes.append(agent_name)
else:
warnings.append((agent_name, migration_warnings))
count += 1
else:
continue
except Exception as e:
failures.append({"name": agent_name, "reason": str(e)})
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
if debug:
traceback.print_exc()
if stop_on_fail:
raise
except KeyboardInterrupt:
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(agent_folders)} agent folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(agent_folders)} agent folders for migration")
if len(warnings) > 0:
typer.secho(f"Migration warnings:", fg=typer.colors.BRIGHT_YELLOW)
for warn in warnings:
typer.secho(f"{warn[0]}: {warn[1]}", fg=typer.colors.BRIGHT_YELLOW)
if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
if len(failures) > 0:
typer.secho(
f"🔴 {len(failures)}/{len(candidates)} agents failed to migrate (see reasons above)",
fg=typer.colors.RED,
)
typer.secho(f"{[d['name'] for d in failures]}", fg=typer.colors.RED)
if len(warnings) > 0:
typer.secho(
f"🟠 {len(warnings)}/{len(candidates)} agents successfully migrated with warnings (see reasons above)",
fg=typer.colors.BRIGHT_YELLOW,
)
typer.secho(f"{[t[0] for t in warnings]}", fg=typer.colors.BRIGHT_YELLOW)
if len(successes) > 0:
typer.secho(
f"🟢 {len(successes)}/{len(candidates)} agents successfully migrated with no warnings",
fg=typer.colors.GREEN,
)
typer.secho(f"{successes}", fg=typer.colors.GREEN)
del ms
return {
"agent_folders": len(agent_folders),
"migration_candidates": candidates,
"successful_migrations": len(successes) + len(warnings),
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}
def migrate_all_sources(data_dir: str = MEMGPT_DIR, stop_on_fail: bool = False, debug: bool = False) -> dict:
"""Scan over all agent folders in data_dir and migrate each agent."""
sources_dir = os.path.join(data_dir, "archival")
# Ensure the directory exists
if not os.path.exists(sources_dir):
raise ValueError(f"Directory {sources_dir} does not exist.")
# Get a list of all folders in sources_dir
source_folders = [f for f in os.listdir(sources_dir) if os.path.isdir(os.path.join(sources_dir, f))]
# Iterate over each folder with a tqdm progress bar
count = 0
failures = []
candidates = []
config = MemGPTConfig.load()
ms = MetadataStore(config)
try:
for source_name in tqdm(source_folders, desc="Migrating data sources"):
# Assuming migrate_source is a function that takes the source name and performs migration
try:
candidates.append(source_name)
migrate_source(source_name, data_dir, ms=ms)
count += 1
except Exception as e:
failures.append({"name": source_name, "reason": str(e)})
if debug:
traceback.print_exc()
if stop_on_fail:
raise
# typer.secho(f"Migrating {agent_name} failed with: {str(e)}", fg=typer.colors.RED)
except KeyboardInterrupt:
typer.secho(f"User cancelled operation", fg=typer.colors.RED)
if len(candidates) == 0:
typer.secho(f"No migration candidates found ({len(source_folders)} source folders total)", fg=typer.colors.GREEN)
else:
typer.secho(f"Inspected {len(source_folders)} source folders")
if len(failures) > 0:
typer.secho(f"Failed migrations:", fg=typer.colors.RED)
for fail in failures:
typer.secho(f"{fail['name']}: {fail['reason']}", fg=typer.colors.RED)
typer.secho(f"{len(failures)}/{len(candidates)} migration targets failed (see reasons above)", fg=typer.colors.RED)
if count > 0:
typer.secho(
f"{count}/{len(candidates)} sources were successfully migrated to the new database format", fg=typer.colors.GREEN
)
del ms
return {
"source_folders": len(source_folders),
"migration_candidates": candidates,
"successful_migrations": count,
"failed_migrations": failures,
"user_id": uuid.UUID(MemGPTConfig.load().anon_clientid),
}
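Taken together, a full legacy migration could be driven as in this sketch (the summary keys are taken from the dicts returned by the two functions above; error handling omitted):

agents_summary = migrate_all_agents(stop_on_fail=False, debug=True)
sources_summary = migrate_all_sources(stop_on_fail=False, debug=True)
print(f"agents migrated: {agents_summary['successful_migrations']}, failed: {len(agents_summary['failed_migrations'])}")
print(f"sources migrated: {sources_summary['successful_migrations']}, failed: {len(sources_summary['failed_migrations'])}")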

View File

@ -0,0 +1,197 @@
# tool imports
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import JSON, Column
from sqlalchemy_utils import ChoiceType
from sqlmodel import Field, SQLModel
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.utils import get_human_text, get_persona_text, get_utc_time
class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""
YES = "yes"
NO = "no"
DEFAULT = "default"
class MemGPTUsageStatistics(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
step_count: int
class LLMConfigModel(BaseModel):
model: Optional[str] = "gpt-4"
model_endpoint_type: Optional[str] = "openai"
model_endpoint: Optional[str] = "https://api.openai.com/v1"
model_wrapper: Optional[str] = None
context_window: Optional[int] = None
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())
class EmbeddingConfigModel(BaseModel):
embedding_endpoint_type: Optional[str] = "openai"
embedding_endpoint: Optional[str] = "https://api.openai.com/v1"
embedding_model: Optional[str] = "text-embedding-ada-002"
embedding_dim: Optional[int] = 1536
embedding_chunk_size: Optional[int] = 300
class PresetModel(BaseModel):
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: str = Field(..., description="The system prompt of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
persona: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona of the preset.")
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
class ToolModel(SQLModel, table=True):
# TODO move into database
name: str = Field(..., description="The name of the function.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the function.", primary_key=True)
tags: List[str] = Field(sa_column=Column(JSON), description="Metadata tags.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
source_code: Optional[str] = Field(..., description="The source code of the function.")
module: Optional[str] = Field(None, description="The module of the function.")
json_schema: Dict = Field(default_factory=dict, sa_column=Column(JSON), description="The JSON schema of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the function.", index=True)
# Needed for Column(JSON)
class Config:
arbitrary_types_allowed = True
class AgentToolMap(SQLModel, table=True):
# mapping between agents and tools
agent_id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the agent-tool map.", primary_key=True)
class PresetToolMap(SQLModel, table=True):
# mapping between presets and tools
preset_id: uuid.UUID = Field(..., description="The unique identifier of the preset.")
tool_id: uuid.UUID = Field(..., description="The unique identifier of the tool.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset-tool map.", primary_key=True)
class AgentStateModel(BaseModel):
id: uuid.UUID = Field(..., description="The unique identifier of the agent.")
name: str = Field(..., description="The name of the agent.")
description: Optional[str] = Field(None, description="The description of the agent.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the agent.")
# timestamps
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the agent was created.")
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
# preset information
tools: List[str] = Field(..., description="The tools used by the agent.")
system: str = Field(..., description="The system prompt used by the agent.")
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
# llm information
llm_config: LLMConfigModel = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the agent.")
# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
class CoreMemory(BaseModel):
human: str = Field(..., description="Human element of the core memory.")
persona: str = Field(..., description="Persona element of the core memory.")
class HumanModel(SQLModel, table=True):
text: str = Field(default=get_human_text(DEFAULT_HUMAN), description="The human text.")
name: str = Field(..., description="The name of the human.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the human.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the human.", index=True)
class PersonaModel(SQLModel, table=True):
text: str = Field(default=get_persona_text(DEFAULT_PERSONA), description="The persona text.")
name: str = Field(..., description="The name of the persona.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the persona.", primary_key=True)
user_id: Optional[uuid.UUID] = Field(..., description="The unique identifier of the user associated with the persona.", index=True)
class SourceModel(SQLModel, table=True):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
# embedding info
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)
# NOTE: .metadata is a reserved attribute on SQLModel
metadata_: Optional[dict] = Field(None, sa_column=Column(JSON), description="Metadata associated with the source.")
class JobStatus(str, Enum):
created = "created"
running = "running"
completed = "completed"
failed = "failed"
class JobModel(SQLModel, table=True):
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the job.", primary_key=True)
# status: str = Field(default="created", description="The status of the job.")
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.", sa_column=Column(ChoiceType(JobStatus)))
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the job.", index=True)
metadata_: Optional[dict] = Field({}, sa_column=Column(JSON), description="The metadata of the job.")
class PassageModel(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the agent associated with the passage.")
text: str = Field(..., description="The text of the passage.")
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfigModel] = Field(
None, sa_column=Column(JSON), description="The embedding configuration used by the passage."
)
data_source: Optional[str] = Field(None, description="The data source of the passage.")
doc_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the document associated with the passage.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the passage.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the passage.")
class DocumentModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the document.")
text: str = Field(..., description="The text of the document.")
data_source: str = Field(..., description="The data source of the document.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
class UserModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user.")

View File

@ -2,10 +2,8 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import List
from memgpt.data_types import AgentState, Message
from memgpt.memory import BaseRecallMemory, EmbeddingArchivalMemory
from memgpt.schemas.agent import AgentState
from memgpt.schemas.memory import Memory
from memgpt.schemas.message import Message
from memgpt.utils import printd
@ -47,7 +45,7 @@ class LocalStateManager(PersistenceManager):
def __init__(self, agent_state: AgentState):
# Memory held in-state useful for debugging stateful versions
self.memory = agent_state.memory
self.memory = None
# self.messages = [] # current in-context messages
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
self.archival_memory = EmbeddingArchivalMemory(agent_state)
@ -59,6 +57,15 @@ class LocalStateManager(PersistenceManager):
self.archival_memory.save()
self.recall_memory.save()
def init(self, agent):
"""Connect persistent state manager to agent"""
printd(f"Initializing {self.__class__.__name__} with agent object")
# self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
# self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
self.memory = agent.memory
# printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
'''
def json_to_message(self, message_json) -> Message:
"""Convert agent message JSON into Message object"""
@ -121,6 +128,7 @@ class LocalStateManager(PersistenceManager):
# self.messages = [self.messages[0]] + added_messages + self.messages[1:]
# add to recall memory
self.recall_memory.insert_many([m for m in added_messages])
def append_to_messages(self, added_messages: List[Message]):
# first tag with timestamps
@ -142,7 +150,6 @@ class LocalStateManager(PersistenceManager):
# add to recall memory
self.recall_memory.insert(new_system_message)
def update_memory(self, new_memory: Memory):
def update_memory(self, new_memory):
printd(f"{self.__class__.__name__}.update_memory")
assert isinstance(new_memory, Memory), type(new_memory)
self.memory = new_memory

View File

@ -0,0 +1,10 @@
system_prompt: "memgpt_chat"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"

View File

@ -0,0 +1,10 @@
system_prompt: "memgpt_doc"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"

View File

@ -0,0 +1,15 @@
system_prompt: "memgpt_chat"
functions:
- "send_message"
- "pause_heartbeats"
- "core_memory_append"
- "core_memory_replace"
- "conversation_search"
- "conversation_search_date"
- "archival_memory_insert"
- "archival_memory_search"
# extras for read/write to files
- "read_from_text_file"
- "append_to_text_file"
# internet access
- "http_request"

91
memgpt/presets/presets.py Normal file
View File

@ -0,0 +1,91 @@
import importlib
import inspect
import os
import uuid
from memgpt.data_types import AgentState, Preset
from memgpt.functions.functions import load_function_set
from memgpt.interface import AgentInterface
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
from memgpt.presets.utils import load_all_presets
from memgpt.utils import list_human_files, list_persona_files, printd
available_presets = load_all_presets()
preset_options = list(available_presets.keys())
def load_module_tools(module_name="base"):
# return List[ToolModel] from base.py tools
full_module_name = f"memgpt.functions.function_sets.{module_name}"
try:
module = importlib.import_module(full_module_name)
except Exception as e:
# Handle other general exceptions
raise e
# function tags
try:
# Load the function set
functions_to_schema = load_function_set(module)
except ValueError as e:
err = f"Error loading function set '{module_name}': {e}"
printd(err)
raise  # functions_to_schema would be unbound below, so re-raise here
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
if module_name == "base":
tags.append("memgpt-base")
tools.append(
ToolModel(
name=name,
tags=tags,
source_type="python",
module=schema["module"],
source_code=source_code,
json_schema=schema["json_schema"],
)
)
return tools
def add_default_tools(user_id: uuid.UUID, ms: MetadataStore):
module_name = "base"
for tool in load_module_tools(module_name=module_name):
existing_tool = ms.get_tool(tool.name)
if not existing_tool:
ms.add_tool(tool)
def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
for persona_file in list_persona_files():
text = open(persona_file, "r", encoding="utf-8").read()
name = os.path.basename(persona_file).replace(".txt", "")
if ms.get_persona(user_id=user_id, name=name) is not None:
printd(f"Persona '{name}' already exists for user '{user_id}'")
continue
persona = PersonaModel(name=name, text=text, user_id=user_id)
ms.add_persona(persona)
for human_file in list_human_files():
text = open(human_file, "r", encoding="utf-8").read()
name = os.path.basename(human_file).replace(".txt", "")
if ms.get_human(user_id=user_id, name=name) is not None:
printd(f"Human '{name}' already exists for user '{user_id}'")
continue
human = HumanModel(name=name, text=text, user_id=user_id)
printd(f"Adding human '{name}' for user '{user_id}'")
ms.add_human(human)
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
def create_agent_from_preset(
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
):
"""Initialize a new agent from a preset (combination of system + function)"""
raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")

79
memgpt/presets/utils.py Normal file
View File

@ -0,0 +1,79 @@
import glob
import os
import yaml
from memgpt.constants import MEMGPT_DIR
def is_valid_yaml_format(yaml_data, function_set):
"""
Check if the given YAML data follows the specified format and if all functions in the yaml are part of the function_set.
Raises ValueError if any check fails.
:param yaml_data: The data loaded from a YAML file.
:param function_set: A set of valid function names.
"""
# Check for required keys
if not all(key in yaml_data for key in ["system_prompt", "functions"]):
raise ValueError("YAML data is missing one or more required keys: 'system_prompt', 'functions'.")
# Check if 'functions' is a list of strings
if not all(isinstance(item, str) for item in yaml_data.get("functions", [])):
raise ValueError("'functions' should be a list of strings.")
# Check if all functions in YAML are part of function_set
if not set(yaml_data["functions"]).issubset(function_set):
raise ValueError(
f"Some functions in YAML are not part of the provided function set: {set(yaml_data['functions']) - set(function_set)} "
)
# If all checks pass
return True
def load_yaml_file(file_path):
"""
Load a YAML file and return the data.
:param file_path: Path to the YAML file.
:return: Data from the YAML file.
"""
with open(file_path, "r", encoding="utf-8") as file:
return yaml.safe_load(file)
def load_all_presets():
"""Load all the preset configs in the examples directory"""
## Load the examples
# Get the directory in which the script is located
script_directory = os.path.dirname(os.path.abspath(__file__))
# Construct the path pattern
example_path_pattern = os.path.join(script_directory, "examples", "*.yaml")
# Listing all YAML files
example_yaml_files = glob.glob(example_path_pattern)
## Load the user-provided presets
# ~/.memgpt/presets/*.yaml
user_presets_dir = os.path.join(MEMGPT_DIR, "presets")
# Create directory if it doesn't exist
if not os.path.exists(user_presets_dir):
os.makedirs(user_presets_dir)
# Construct the path pattern
user_path_pattern = os.path.join(user_presets_dir, "*.yaml")
# Listing all YAML files
user_yaml_files = glob.glob(user_path_pattern)
# Pull from both examples and user-provided presets
all_yaml_files = example_yaml_files + user_yaml_files
# Loading and creating a mapping from file name to YAML data
all_yaml_data = {}
for file_path in all_yaml_files:
# Extracting the base file name without the '.yaml' extension
base_name = os.path.splitext(os.path.basename(file_path))[0]
data = load_yaml_file(file_path)
all_yaml_data[base_name] = data
return all_yaml_data
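A sketch exercising the two helpers above together, validating every discovered preset against an assumed function set (the names are copied from the example preset YAMLs in this commit):

from memgpt.presets.utils import is_valid_yaml_format, load_all_presets

function_set = {  # assumed set, taken from the example presets above
    "send_message", "pause_heartbeats", "core_memory_append", "core_memory_replace",
    "conversation_search", "conversation_search_date", "archival_memory_insert",
    "archival_memory_search", "read_from_text_file", "append_to_text_file", "http_request",
}
for preset_name, yaml_data in load_all_presets().items():
    try:
        is_valid_yaml_format(yaml_data, function_set)
    except ValueError as e:
        print(f"preset '{preset_name}' failed validation: {e}")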

View File

@ -1,90 +0,0 @@
import uuid
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.memory import Memory
class BaseAgent(MemGPTBase, validate_assignment=True):
__id_prefix__ = "agent"
description: Optional[str] = Field(None, description="The description of the agent.")
# metadata
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
user_id: Optional[str] = Field(None, description="The user id of the agent.")
class AgentState(BaseAgent):
"""Representation of an agent's state."""
id: str = BaseAgent.generate_id_field()
name: str = Field(..., description="The name of the agent.")
created_at: datetime = Field(..., description="The datetime the agent was created.", default_factory=datetime.now)
# in-context memory
message_ids: Optional[List[str]] = Field(default=None, description="The ids of the messages in the agent's in-context memory.")
memory: Memory = Field(default_factory=Memory, description="The in-context memory of the agent.")
# tools
tools: List[str] = Field(..., description="The tools used by the agent.")
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
class CreateAgent(BaseAgent):
# all optional as server can generate defaults
name: Optional[str] = Field(None, description="The name of the agent.")
message_ids: Optional[List[uuid.UUID]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
@field_validator("name")
@classmethod
def validate_name(cls, name: str) -> str:
"""Validate the requested new agent name (prevent bad inputs)"""
import re
if not name:
# don't check if not provided
return name
# TODO: this check should also be added to other model (e.g. User.name)
# Length check
if not (1 <= len(name) <= 50):
raise ValueError("Name length must be between 1 and 50 characters.")
# Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
if not re.match("^[A-Za-z0-9 _-]+$", name):
raise ValueError("Name contains invalid characters.")
# Further checks can be added here...
# TODO
return name
class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
name: Optional[str] = Field(None, description="The name of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
# TODO: determine if these should be editable via this schema?
message_ids: Optional[List[str]] = Field(None, description="The ids of the messages in the agent's in-context memory.")
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
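For reference, the name validator in this now-removed schema behaved roughly as sketched (import path as it was before the deletion):

from memgpt.schemas.agent import CreateAgent

CreateAgent(name="research assistant_01")  # passes the length and character checks
try:
    CreateAgent(name="bad/name!")
except ValueError as e:
    print(e)  # pydantic surfaces "Name contains invalid characters."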

View File

@ -1,21 +0,0 @@
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class BaseAPIKey(MemGPTBase):
__id_prefix__ = "sk" # secret key
class APIKey(BaseAPIKey):
id: str = BaseAPIKey.generate_id_field()
user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
key: str = Field(..., description="The key value.")
name: str = Field(..., description="Name of the token.")
class APIKeyCreate(BaseAPIKey):
user_id: str = Field(..., description="The unique identifier of the user associated with the token.")
name: Optional[str] = Field(None, description="Name of the token.")

View File

@ -1,117 +0,0 @@
from typing import List, Optional, Union
from pydantic import Field, model_validator
from typing_extensions import Self
from memgpt.schemas.memgpt_base import MemGPTBase
# block of the LLM context
class BaseBlock(MemGPTBase, validate_assignment=True):
"""Base block of the LLM context"""
__id_prefix__ = "block"
# data value
value: Optional[Union[List[str], str]] = Field(None, description="Value of the block.")
limit: int = Field(2000, description="Character limit of the block.")
name: Optional[str] = Field(None, description="Name of the block.")
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
label: Optional[str] = Field(None, description="Label of the block (e.g. 'human', 'persona').")
# metadata
description: Optional[str] = Field(None, description="Description of the block.")
metadata_: Optional[dict] = Field({}, description="Metadata of the block.")
# associated user/agent
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the block.")
@model_validator(mode="after")
def verify_char_limit(self) -> Self:
try:
assert len(self) <= self.limit
except AssertionError:
error_msg = f"Edit failed: Exceeds {self.limit} character limit (requested {len(self)})."
raise ValueError(error_msg)
except Exception as e:
raise e
return self
def __len__(self):
return len(str(self))
def __str__(self) -> str:
if isinstance(self.value, list):
return ",".join(self.value)
elif isinstance(self.value, str):
return self.value
else:
return ""
def __setattr__(self, name, value):
"""Run validation if self.value is updated"""
super().__setattr__(name, value)
if name == "value":
# run validation
self.__class__.validate(self.dict(exclude_unset=True))
class Block(BaseBlock):
"""Block of the LLM context"""
id: str = BaseBlock.generate_id_field()
value: str = Field(..., description="Value of the block.")
class Human(Block):
"""Human block of the LLM context"""
label: str = "human"
class Persona(Block):
"""Persona block of the LLM context"""
label: str = "persona"
class CreateBlock(BaseBlock):
"""Create a block"""
template: bool = True
label: str = Field(..., description="Label of the block.")
class CreatePersona(BaseBlock):
"""Create a persona block"""
template: bool = True
label: str = "persona"
class CreateHuman(BaseBlock):
"""Create a human block"""
template: bool = True
label: str = "human"
class UpdateBlock(BaseBlock):
"""Update a block"""
id: str = Field(..., description="The unique identifier of the block.")
limit: Optional[int] = Field(2000, description="Character limit of the block.")
class UpdatePersona(UpdateBlock):
"""Update a persona block"""
label: str = "persona"
class UpdateHuman(UpdateBlock):
"""Update a human block"""
label: str = "human"

View File

@ -1,21 +0,0 @@
from typing import Dict, Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class DocumentBase(MemGPTBase):
"""Base class for document schemas"""
__id_prefix__ = "doc"
class Document(DocumentBase):
"""Representation of a single document (broken up into `Passage` objects)"""
id: str = DocumentBase.generate_id_field()
text: str = Field(..., description="The text of the document.")
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
user_id: str = Field(description="The unique identifier of the user associated with the document.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the document.")

View File

@ -1,18 +0,0 @@
from typing import Optional
from pydantic import BaseModel, Field
class EmbeddingConfig(BaseModel):
"""Embedding model configuration"""
embedding_endpoint_type: str = Field(..., description="The endpoint type for the model.")
embedding_endpoint: Optional[str] = Field(None, description="The endpoint for the model (`None` if local).")
embedding_model: str = Field(..., description="The model for the embedding.")
embedding_dim: int = Field(..., description="The dimension of the embedding.")
embedding_chunk_size: Optional[int] = Field(300, description="The chunk size of the embedding.")
# azure only
azure_endpoint: Optional[str] = Field(None, description="The Azure endpoint for the model.")
azure_version: Optional[str] = Field(None, description="The Azure version for the model.")
azure_deployment: Optional[str] = Field(None, description="The Azure deployment for the model.")

View File

@ -1,30 +0,0 @@
from enum import Enum
class MessageRole(str, Enum):
assistant = "assistant"
user = "user"
tool = "tool"
system = "system"
class OptionState(str, Enum):
"""Useful for kwargs that are bool + default option"""
YES = "yes"
NO = "no"
DEFAULT = "default"
class JobStatus(str, Enum):
created = "created"
running = "running"
completed = "completed"
failed = "failed"
pending = "pending"
class MessageStreamStatus(str, Enum):
done_generation = "[DONE_GEN]"
done_step = "[DONE_STEP]"
done = "[DONE]"

View File

@ -1,28 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class JobBase(MemGPTBase):
__id_prefix__ = "job"
metadata_: Optional[dict] = Field({}, description="The metadata of the job.")
class Job(JobBase):
"""Representation of offline jobs."""
id: str = JobBase.generate_id_field()
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
user_id: str = Field(..., description="The unique identifier of the user associated with the job.")
class JobUpdate(JobBase):
id: str = Field(..., description="The unique identifier of the job.")
status: Optional[JobStatus] = Field(..., description="The status of the job.")

View File

@ -1,15 +0,0 @@
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
class LLMConfig(BaseModel):
# TODO: 🤮 don't default to a vendor! bug city!
model: str = Field(..., description="LLM model name. ")
model_endpoint_type: str = Field(..., description="The endpoint type for the model.")
model_endpoint: str = Field(..., description="The endpoint for the model.")
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
context_window: int = Field(..., description="The context window size for the model.")
# FIXME hack to silence pydantic protected namespace warning
model_config = ConfigDict(protected_namespaces=())

View File

@ -1,80 +0,0 @@
import uuid
from logging import getLogger
from typing import Optional
from uuid import UUID
from pydantic import BaseModel, ConfigDict, Field, field_validator
# from: https://gist.github.com/norton120/22242eadb80bf2cf1dd54a961b151c61
logger = getLogger(__name__)
class MemGPTBase(BaseModel):
"""Base schema for MemGPT schemas (does not include model provider schemas, e.g. OpenAI)"""
model_config = ConfigDict(
# allows you to use the snake or camelcase names in your code (ie user_id or userId)
populate_by_name=True,
# allows you do dump a sqlalchemy object directly (ie PersistedAddress.model_validate(SQLAdress)
from_attributes=True,
# throw errors if attributes are given that don't belong
extra="forbid",
)
# def __id_prefix__(self):
# raise NotImplementedError("All schemas must have an __id_prefix__ attribute!")
@classmethod
def generate_id_field(cls, prefix: Optional[str] = None) -> "Field":
prefix = prefix or cls.__id_prefix__
# TODO: generate ID from regex pattern?
def _generate_id() -> str:
return f"{prefix}-{uuid.uuid4()}"
return Field(
...,
description=cls._id_description(prefix),
pattern=cls._id_regex_pattern(prefix),
examples=[cls._id_example(prefix)],
default_factory=_generate_id,
)
# def _generate_id(self) -> str:
# return f"{self.__id_prefix__}-{uuid.uuid4()}"
@classmethod
def _id_regex_pattern(cls, prefix: str):
"""generates the regex pattern for a given id"""
return (
r"^" + prefix + r"-" # prefix string
r"[a-fA-F0-9]{8}" # 8 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
# r"[a-fA-F0-9]{12}$" # 12 hexadecimal characters
)
@classmethod
def _id_example(cls, prefix: str):
"""generates an example id for a given prefix"""
return [prefix + "-123e4567-e89b-12d3-a456-426614174000"]
@classmethod
def _id_description(cls, prefix: str):
"""generates a factory function for a given prefix"""
return f"The human-friendly ID of the {prefix.capitalize()}"
@field_validator("id", check_fields=False, mode="before")
@classmethod
def allow_bare_uuids(cls, v, values):
"""to ease the transition to stripe ids,
we allow bare uuids and convert them with a warning
"""
_ = values # for SCA
if isinstance(v, UUID):
logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!")
return f"{cls.__id_prefix__}-{v}"
return v
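A sketch of how the prefixed-id pattern above was consumed; the subclass here is hypothetical, while the real ones appear in the other removed schema files:

from memgpt.schemas.memgpt_base import MemGPTBase

class ExampleBase(MemGPTBase):  # hypothetical schema for illustration
    __id_prefix__ = "example"

class Example(ExampleBase):
    id: str = ExampleBase.generate_id_field()
    name: str

print(Example(name="demo").id)  # e.g. example-1b9d6bcd-... (prefix + uuid4)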

View File

@ -1,78 +0,0 @@
from datetime import datetime, timezone
from typing import Literal, Union
from pydantic import BaseModel, field_serializer
# MemGPT API style responses (intended to be easier to use vs getting true Message types)
class BaseMemGPTMessage(BaseModel):
id: str
date: datetime
@field_serializer("date")
def serialize_datetime(self, dt: datetime, _info):
return dt.astimezone(timezone.utc).isoformat()  # serialize the message date, not the current time
class InternalMonologue(BaseMemGPTMessage):
"""
{
"internal_monologue": msg,
"date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
"id": str(msg_obj.id) if msg_obj is not None else None,
}
"""
internal_monologue: str
class FunctionCall(BaseModel):
name: str
arguments: str
class FunctionCallMessage(BaseMemGPTMessage):
"""
{
"function_call": {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
},
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""
function_call: FunctionCall
class FunctionReturn(BaseMemGPTMessage):
"""
{
"function_return": msg,
"status": "success" or "error",
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
"""
function_return: str
status: Literal["success", "error"]
MemGPTMessage = Union[InternalMonologue, FunctionCallMessage, FunctionReturn]
# Legacy MemGPT API had an additional type "assistant_message" and the "function_call" was a formatted string
class AssistantMessage(BaseMemGPTMessage):
assistant_message: str
class LegacyFunctionCallMessage(BaseMemGPTMessage):
function_call: str
LegacyMemGPTMessage = Union[InternalMonologue, AssistantMessage, LegacyFunctionCallMessage, FunctionReturn]

View File

@ -1,18 +0,0 @@
from typing import List
from pydantic import BaseModel, Field
from memgpt.schemas.message import MessageCreate
class MemGPTRequest(BaseModel):
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement
stream_steps: bool = Field(
default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
)
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)

View File

@ -1,17 +0,0 @@
from typing import List, Union
from pydantic import BaseModel, Field
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
# TODO: consider moving into own file
class MemGPTResponse(BaseModel):
# messages: List[Message] = Field(..., description="The messages returned by the agent.")
messages: Union[List[Message], List[MemGPTMessage], List[LegacyMemGPTMessage]] = Field(
..., description="The messages returned by the agent."
)
usage: MemGPTUsageStatistics = Field(..., description="The usage statistics of the agent.")

View File

@ -1,138 +0,0 @@
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field
from memgpt.schemas.block import Block
class Memory(BaseModel, validate_assignment=True):
"""Represents the in-context memory of the agent"""
# Private variable to avoid assignments with incorrect types
memory: Dict[str, Block] = Field(default_factory=dict, description="Mapping from memory block section to memory block.")
@classmethod
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
obj.memory[key] = Block(**value)
return obj
def __str__(self) -> str:
"""Representation of the memory in-context"""
section_strs = []
for section, module in self.memory.items():
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
return "\n".join(section_strs)
def to_dict(self):
"""Convert to dictionary representation"""
return {key: value.dict() for key, value in self.memory.items()}
def to_flat_dict(self):
"""Convert to a dictionary that maps directly from block names to values"""
return {k: v.value for k, v in self.memory.items() if v is not None}
def list_block_names(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""
return list(self.memory.keys())
def get_block(self, name: str) -> Block:
"""Correct way to index into the memory.memory field, returns a Block"""
if name not in self.memory:
return KeyError(f"Block field {name} does not exist (available sections = {', '.join(list(self.memory.keys()))})")
else:
return self.memory[name]
def link_block(self, name: str, block: Block, override: Optional[bool] = False):
"""Link a new block to the memory object"""
if not isinstance(block, Block):
raise ValueError(f"Param block must be type Block (not {type(block)})")
if not isinstance(name, str):
raise ValueError(f"Name must be str (not type {type(name)})")
if not override and name in self.memory:
raise ValueError(f"Block with name {name} already exists")
self.memory[name] = block
def update_block_value(self, name: str, value: Union[List[str], str]):
"""Update the value of a block"""
if name not in self.memory:
raise ValueError(f"Block with name {name} does not exist")
if not (isinstance(value, str) or (isinstance(value, list) and all(isinstance(v, str) for v in value))):
raise ValueError(f"Provided value must be a string or list of strings")
self.memory[name].value = value
# TODO: ideally this is refactored into ChatMemory and the subclasses are given more specific names.
class BaseChatMemory(Memory):
def core_memory_append(self, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
Args:
name (str): Section of the memory to be edited (persona or human).
content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
new_value = current_value + "\n" + str(content)
self.memory.update_block_value(name=name, value=new_value)
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
name (str): Section of the memory to be edited (persona or human).
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
current_value = str(self.memory.get_block(name).value)
new_value = current_value.replace(str(old_content), str(new_content))
self.memory.update_block_value(name=name, value=new_value)
return None
class ChatMemory(BaseChatMemory):
"""
ChatMemory initializes a BaseChatMemory with two default blocks
"""
def __init__(self, persona: str, human: str, limit: int = 2000):
super().__init__()
self.link_block(name="persona", block=Block(name="persona", value=persona, limit=limit, label="persona"))
self.link_block(name="human", block=Block(name="human", value=human, limit=limit, label="human"))
class BlockChatMemory(BaseChatMemory):
"""
BlockChatMemory is a subclass of BaseChatMemory which uses shared memory blocks specified at initialization-time.
"""
def __init__(self, blocks: List[Block] = []):
super().__init__()
for block in blocks:
# TODO: centralize these internal schema validations
assert block.name is not None and block.name != "", "each existing chat block must have a name"
self.link_block(name=block.name, block=block)
class UpdateMemory(BaseModel):
"""Update the memory of the agent"""
class ArchivalMemorySummary(BaseModel):
size: int = Field(..., description="Number of rows in archival memory")
class RecallMemorySummary(BaseModel):
size: int = Field(..., description="Number of rows in recall memory")

View File

@ -1,123 +0,0 @@
from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, Field
class SystemMessage(BaseModel):
content: str
role: str = "system"
name: Optional[str] = None
class UserMessage(BaseModel):
content: Union[str, List[str]]
role: str = "user"
name: Optional[str] = None
class ToolCallFunction(BaseModel):
name: str = Field(..., description="The name of the function to call")
arguments: str = Field(..., description="The arguments to pass to the function (JSON dump)")
class ToolCall(BaseModel):
id: str = Field(..., description="The ID of the tool call")
type: str = "function"
function: ToolCallFunction = Field(..., description="The arguments and name for the function")
class AssistantMessage(BaseModel):
content: Optional[str] = None
role: str = "assistant"
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
class ToolMessage(BaseModel):
content: str
role: str = "tool"
tool_call_id: str
ChatMessage = Union[SystemMessage, UserMessage, AssistantMessage, ToolMessage]
# TODO: this might not be necessary with the validator
def cast_message_to_subtype(m_dict: dict) -> ChatMessage:
"""Cast a dictionary to one of the individual message types"""
role = m_dict.get("role")
if role == "system":
return SystemMessage(**m_dict)
elif role == "user":
return UserMessage(**m_dict)
elif role == "assistant":
return AssistantMessage(**m_dict)
elif role == "tool":
return ToolMessage(**m_dict)
else:
raise ValueError("Unknown message role")
class ResponseFormat(BaseModel):
type: str = Field(default="text", pattern="^(text|json_object)$")
## tool_choice ##
class FunctionCall(BaseModel):
name: str
class ToolFunctionChoice(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionCall
ToolChoice = Union[Literal["none", "auto"], ToolFunctionChoice]
## tools ##
class FunctionSchema(BaseModel):
name: str
description: Optional[str] = None
parameters: Optional[Dict[str, Any]] = None # JSON Schema for the parameters
class Tool(BaseModel):
# The type of the tool. Currently, only function is supported
type: Literal["function"] = "function"
# type: str = Field(default="function", const=True)
function: FunctionSchema
## function_call ##
FunctionCallChoice = Union[Literal["none", "auto"], FunctionCall]
class ChatCompletionRequest(BaseModel):
"""https://platform.openai.com/docs/api-reference/chat/create"""
model: str
messages: List[ChatMessage]
frequency_penalty: Optional[float] = 0
logit_bias: Optional[Dict[str, int]] = None
logprobs: Optional[bool] = False
top_logprobs: Optional[int] = None
max_tokens: Optional[int] = None
n: Optional[int] = 1
presence_penalty: Optional[float] = 0
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream: Optional[bool] = False
temperature: Optional[float] = 1
top_p: Optional[float] = 1
user: Optional[str] = None # unique ID of the end-user (for monitoring)
# function-calling related
tools: Optional[List[Tool]] = None
tool_choice: Optional[ToolChoice] = "none"
# deprecated scheme
functions: Optional[List[FunctionSchema]] = None
function_call: Optional[FunctionCallChoice] = None
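For illustration, a request with one tool attached might be built like this (the model name is a placeholder, and `model_dump` assumes pydantic v2, consistent with the `field_validator`/`pattern=` usage elsewhere in this commit):

request = ChatCompletionRequest(
    model="gpt-4",  # assumed model name
    messages=[
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What time is it?"),
    ],
    # Tool/FunctionSchema here are the OpenAI-shaped classes defined just above
    tools=[Tool(function=FunctionSchema(name="get_time", description="Return the current time"))],
    tool_choice="auto",
)
payload = request.model_dump(exclude_none=True)  # dict ready to POST to a /chat/completions endpoint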


@ -1,66 +0,0 @@
from datetime import datetime
from typing import Dict, List, Optional
from pydantic import Field, field_validator
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class PassageBase(MemGPTBase):
__id_prefix__ = "passage"
# associated user/agent
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the passage.")
agent_id: Optional[str] = Field(None, description="The unique identifier of the agent associated with the passage.")
# origin data source
source_id: Optional[str] = Field(None, description="The data source of the passage.")
# document association
doc_id: Optional[str] = Field(None, description="The unique identifier of the document associated with the passage.")
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
class Passage(PassageBase):
id: str = PassageBase.generate_id_field()
# passage text
text: str = Field(..., description="The text of the passage.")
# embeddings
embedding: Optional[List[float]] = Field(..., description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(..., description="The embedding configuration used by the passage.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the passage.")
@field_validator("embedding")
@classmethod
def pad_embeddings(cls, embedding: List[float]) -> List[float]:
"""Pad embeddings to MAX_EMBEDDING_SIZE. This is necessary to ensure all stored embeddings are the same size."""
import numpy as np
if embedding and len(embedding) != MAX_EMBEDDING_DIM:
np_embedding = np.array(embedding)
padded_embedding = np.pad(np_embedding, (0, MAX_EMBEDDING_DIM - np_embedding.shape[0]), mode="constant")
return padded_embedding.tolist()
return embedding
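The validator above lets a caller pass a model-native embedding size while the store keeps fixed-width rows. A sketch of the padding behavior (assuming `Passage` and `MAX_EMBEDDING_DIM` are imported as shown at the top of this file):

short = [0.1] * 512  # e.g. a 512-dim model embedding
p = Passage(text="hello world", embedding=short, embedding_config=None)
assert len(p.embedding) == MAX_EMBEDDING_DIM  # zero-padded up to the fixed width
assert p.embedding[:512] == short             # original values are preserved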
class PassageCreate(PassageBase):
text: str = Field(..., description="The text of the passage.")
# optionally provide embeddings
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
class PassageUpdate(PassageCreate):
id: str = Field(..., description="The unique identifier of the passage.")
text: Optional[str] = Field(None, description="The text of the passage.")
# optionally provide embeddings
embedding: Optional[List[float]] = Field(None, description="The embedding of the passage.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")


@ -1,49 +0,0 @@
from datetime import datetime
from typing import Optional
from fastapi import UploadFile
from pydantic import BaseModel, Field
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.utils import get_utc_time
class BaseSource(MemGPTBase):
"""
Shared attributes across all source schemas.
"""
__id_prefix__ = "source"
description: Optional[str] = Field(None, description="The description of the source.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the passage.")
# NOTE: .metadata is a reserved attribute on SQLModel
metadata_: Optional[dict] = Field(None, description="Metadata associated with the source.")
class SourceCreate(BaseSource):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
class Source(BaseSource):
id: str = BaseSource.generate_id_field()
name: str = Field(..., description="The name of the source.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of the source.")
user_id: str = Field(..., description="The ID of the user that created the source.")
class SourceUpdate(BaseSource):
id: str = Field(..., description="The ID of the source.")
name: Optional[str] = Field(None, description="The name of the source.")
class UploadFileToSourceRequest(BaseModel):
file: UploadFile = Field(..., description="The file to upload.")
class UploadFileToSourceResponse(BaseModel):
source: Source = Field(..., description="The source the file was uploaded to.")
added_passages: int = Field(..., description="The number of passages added to the source.")
added_documents: int = Field(..., description="The number of documents added to the source.")


@ -1,86 +0,0 @@
from typing import Dict, List, Optional
from pydantic import Field
from memgpt.functions.schema_generator import (
generate_schema_from_args_schema,
generate_tool_wrapper,
)
from memgpt.schemas.memgpt_base import MemGPTBase
from memgpt.schemas.openai.chat_completions import ToolCall
class BaseTool(MemGPTBase):
__id_prefix__ = "tool"
# optional fields
description: Optional[str] = Field(None, description="The description of the tool.")
source_type: Optional[str] = Field(None, description="The type of the source code.")
module: Optional[str] = Field(None, description="The module of the function.")
# optional: user_id (user-specific tools)
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the function.")
class Tool(BaseTool):
id: str = BaseTool.generate_id_field()
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
# code
source_code: str = Field(..., description="The source code of the function.")
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
def to_dict(self):
"""Convert into OpenAI representation"""
return vars(
ToolCall(
tool_id=self.id,
tool_call_type="function",
function=self.module,
)
)
@classmethod
def from_crewai(cls, crewai_tool) -> "Tool":
"""
Class method to create an instance of Tool from a crewAI BaseTool object.
Args:
crewai_tool (CrewAIBaseTool): An instance of a crewAI BaseTool (BaseTool from crewai)
Returns:
Tool: A memGPT Tool initialized with attributes derived from the provided crewAI BaseTool object.
"""
crewai_tool.name
description = crewai_tool.description
source_type = "python"
tags = ["crew-ai"]
wrapper_func_name, wrapper_function_str = generate_tool_wrapper(crewai_tool.__class__.__name__)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
return cls(
name=wrapper_func_name,
description=description,
source_type=source_type,
tags=tags,
source_code=wrapper_function_str,
json_schema=json_schema,
)
class ToolCreate(BaseTool):
name: str = Field(..., description="The name of the function.")
tags: List[str] = Field(..., description="Metadata tags.")
source_code: str = Field(..., description="The source code of the function.")
json_schema: Dict = Field(default_factory=dict, description="The JSON schema of the function.")
class ToolUpdate(ToolCreate):
id: str = Field(..., description="The unique identifier of the tool.")
name: Optional[str] = Field(None, description="The name of the function.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
source_code: Optional[str] = Field(None, description="The source code of the function.")
json_schema: Optional[Dict] = Field(None, description="The JSON schema of the function.")
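A hypothetical `ToolCreate` payload for a hand-written function (the function body and its JSON schema are illustrative, not taken from this commit):

source = '''
def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b
'''
new_tool = ToolCreate(
    name="add",
    tags=["math"],
    source_code=source,
    json_schema={
        "name": "add",
        "description": "Add two integers.",
        "parameters": {
            "type": "object",
            "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
            "required": ["a", "b"],
        },
    },
)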


@ -1,8 +0,0 @@
from pydantic import BaseModel, Field
class MemGPTUsageStatistics(BaseModel):
completion_tokens: int = Field(0, description="The number of tokens generated by the agent.")
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
step_count: int = Field(0, description="The number of steps taken by the agent.")


@ -1,20 +0,0 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class UserBase(MemGPTBase):
__id_prefix__ = "user"
class User(UserBase):
id: str = UserBase.generate_id_field()
name: str = Field(..., description="The name of the user.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class UserCreate(UserBase):
name: Optional[str] = Field(None, description="The name of the user.")


@ -1,8 +1,6 @@
from typing import List
from fastapi import APIRouter
from memgpt.schemas.agent import AgentState
from memgpt.server.rest_api.agents.index import ListAgentsResponse
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -10,12 +8,14 @@ router = APIRouter()
def setup_agents_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/agents", tags=["agents"], response_model=List[AgentState])
@router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def get_all_agents():
"""
Get a list of all agents in the database
"""
interface.clear()
return server.list_agents()
agents_data = server.list_agents_legacy()
return ListAgentsResponse(**agents_data)
return router


@ -3,7 +3,7 @@ from typing import List, Literal, Optional
from fastapi import APIRouter, Body, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.tool import Tool as ToolModel # TODO: modify
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer


@ -1,46 +1,102 @@
import uuid
from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Query
from pydantic import BaseModel, Field
from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.user import User, UserCreate
from memgpt.data_types import User
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class GetAllUsersResponse(BaseModel):
cursor: Optional[uuid.UUID] = Field(None, description="Cursor for the next page in the response.")
user_list: List[dict] = Field(..., description="A list of users.")
class CreateUserRequest(BaseModel):
user_id: Optional[uuid.UUID] = Field(None, description="Identifier of the user (optional, generated automatically if null).")
api_key_name: Optional[str] = Field(None, description="Name for API key autogenerated on user creation (optional).")
class CreateUserResponse(BaseModel):
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
api_key: str = Field(..., description="New API key generated for user.")
class CreateAPIKeyRequest(BaseModel):
user_id: uuid.UUID = Field(..., description="Identifier of the user (UUID).")
name: Optional[str] = Field(None, description="Name for the API key (optional).")
class CreateAPIKeyResponse(BaseModel):
api_key: str = Field(..., description="New API key generated.")
class GetAPIKeysResponse(BaseModel):
api_key_list: List[str] = Field(..., description="A list of API keys for the user.")
class DeleteAPIKeyResponse(BaseModel):
message: str
api_key_deleted: str
class DeleteUserResponse(BaseModel):
message: str
user_id_deleted: uuid.UUID
def setup_admin_router(server: SyncServer, interface: QueuingInterface):
@router.get("/users", tags=["admin"], response_model=List[User])
def get_all_users(cursor: Optional[str] = Query(None), limit: Optional[int] = Query(50)):
@router.get("/users", tags=["admin"], response_model=GetAllUsersResponse)
def get_all_users(cursor: Optional[uuid.UUID] = Query(None), limit: Optional[int] = Query(50)):
"""
Get a list of all users in the database
"""
try:
# TODO: make this call a server function
_, users = server.ms.get_all_users(cursor=cursor, limit=limit)
next_cursor, users = server.ms.get_all_users(cursor, limit)
processed_users = [{"user_id": user.id} for user in users]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return users
return GetAllUsersResponse(cursor=next_cursor, user_list=processed_users)
@router.post("/users", tags=["admin"], response_model=User)
def create_user(request: UserCreate = Body(...)):
@router.post("/users", tags=["admin"], response_model=CreateUserResponse)
def create_user(request: Optional[CreateUserRequest] = Body(None)):
"""
Create a new user in the database
"""
if request is None:
request = CreateUserRequest()
new_user = User(
id=None if not request.user_id else request.user_id,
# TODO can add more fields (name? metadata?)
)
try:
user = server.create_user(request)
server.ms.create_user(new_user)
# make sure we can retrieve the user from the DB too
new_user_ret = server.ms.get_user(new_user.id)
if new_user_ret is None:
raise HTTPException(status_code=500, detail=f"Failed to verify user creation")
# create an API key for the user
token = server.ms.create_api_key(user_id=new_user.id, name=request.api_key_name)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return user
return CreateUserResponse(user_id=new_user_ret.id, api_key=token.token)
@router.delete("/users", tags=["admin"], response_model=User)
@router.delete("/users", tags=["admin"], response_model=DeleteUserResponse)
def delete_user(
user_id: str = Query(..., description="The user_id key to be deleted."),
user_id: uuid.UUID = Query(..., description="The user_id key to be deleted."),
):
# TODO make a soft deletion, instead of a hard deletion
try:
@ -52,24 +108,24 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return user
return DeleteUserResponse(message="User successfully deleted.", user_id_deleted=user_id)
@router.post("/users/keys", tags=["admin"], response_model=APIKey)
def create_new_api_key(request: APIKeyCreate = Body(...)):
@router.post("/users/keys", tags=["admin"], response_model=CreateAPIKeyResponse)
def create_new_api_key(request: CreateAPIKeyRequest = Body(...)):
"""
Create a new API key for a user
"""
try:
api_key = server.create_api_key(request)
token = server.ms.create_api_key(user_id=request.user_id, name=request.name)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return api_key
return CreateAPIKeyResponse(api_key=token.token)
@router.get("/users/keys", tags=["admin"], response_model=List[APIKey])
@router.get("/users/keys", tags=["admin"], response_model=GetAPIKeysResponse)
def get_api_keys(
user_id: str = Query(..., description="The unique identifier of the user."),
user_id: uuid.UUID = Query(..., description="The unique identifier of the user."),
):
"""
Get a list of all API keys for a user
@ -77,22 +133,28 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
try:
if server.ms.get_user(user_id=user_id) is None:
raise HTTPException(status_code=404, detail=f"User does not exist")
api_keys = server.ms.get_all_api_keys_for_user(user_id=user_id)
tokens = server.ms.get_all_api_keys_for_user(user_id=user_id)
processed_tokens = [t.token for t in tokens]
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return api_keys
print("TOKENS", processed_tokens)
return GetAPIKeysResponse(api_key_list=processed_tokens)
@router.delete("/users/keys", tags=["admin"], response_model=APIKey)
@router.delete("/users/keys", tags=["admin"], response_model=DeleteAPIKeyResponse)
def delete_api_key(
api_key: str = Query(..., description="The API key to be deleted."),
):
try:
return server.delete_api_key(api_key)
token = server.ms.get_api_key(api_key=api_key)
if token is None:
raise HTTPException(status_code=404, detail=f"API key does not exist")
server.ms.delete_api_key(api_key=api_key)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return DeleteAPIKeyResponse(message="API key successfully deleted.", api_key_deleted=api_key)
return router
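A client-side sketch of the admin flow above (the /admin prefix, port, and bearer token are assumptions about how this router is mounted, not guarantees):

import requests

BASE = "http://localhost:8283/admin"                     # assumed mount point
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed server password

# Create a user; the server generates the id and an initial API key.
created = requests.post(f"{BASE}/users", headers=HEADERS, json={}).json()
user_id, first_key = created["user_id"], created["api_key"]

# Mint a second key, list all keys, then delete the first one.
requests.post(f"{BASE}/users/keys", headers=HEADERS, json={"user_id": user_id})
keys = requests.get(f"{BASE}/users/keys", headers=HEADERS, params={"user_id": user_id}).json()["api_key_list"]
requests.delete(f"{BASE}/users/keys", headers=HEADERS, params={"api_key": first_key})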


@ -0,0 +1,48 @@
import uuid
from functools import partial
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class CommandRequest(BaseModel):
command: str = Field(..., description="The command to be executed by the agent.")
class CommandResponse(BaseModel):
response: str = Field(..., description="The result of the executed command.")
def setup_agents_command_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.post("/agents/{agent_id}/command", tags=["agents"], response_model=CommandResponse)
def run_command(
agent_id: uuid.UUID,
request: CommandRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Execute a command on a specified agent.
This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
Raises an HTTPException for any processing errors.
"""
interface.clear()
try:
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
response = server.run_command(user_id=user_id, agent_id=agent_id, command=request.command)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return CommandResponse(response=response)
return router
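A hypothetical call against this endpoint (the mount point, token, and command string are placeholders for illustration):

import requests

agent_id = "00000000-0000-0000-0000-000000000000"  # placeholder agent UUID
resp = requests.post(
    f"http://localhost:8283/api/agents/{agent_id}/command",  # assumed mount point
    headers={"Authorization": "Bearer <api_key>"},           # placeholder user API key
    json={"command": "/memory"},                             # e.g. a CLI-style command
)
print(resp.json()["response"])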


@ -0,0 +1,156 @@
import re
import uuid
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import (
AgentStateModel,
EmbeddingConfigModel,
LLMConfigModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class AgentRenameRequest(BaseModel):
agent_name: str = Field(..., description="New name for the agent.")
class GetAgentResponse(BaseModel):
# config: dict = Field(..., description="The agent configuration object.")
agent_state: AgentStateModel = Field(..., description="The state of the agent.")
sources: List[str] = Field(..., description="The list of data sources associated with the agent.")
last_run_at: Optional[int] = Field(None, description="The unix timestamp of when the agent was last run.")
def validate_agent_name(name: str) -> str:
"""Validate the requested new agent name (prevent bad inputs)"""
# Length check
if not (1 <= len(name) <= 50):
raise HTTPException(status_code=400, detail="Name length must be between 1 and 50 characters.")
# Regex for allowed characters (alphanumeric, spaces, hyphens, underscores)
if not re.match("^[A-Za-z0-9 _-]+$", name):
raise HTTPException(status_code=400, detail="Name contains invalid characters.")
# Further checks can be added here...
# TODO
return name
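The validator's behavior, sketched directly from the two checks above:

from fastapi import HTTPException

assert validate_agent_name("my agent_01") == "my agent_01"  # allowed characters pass through
try:
    validate_agent_name("bad/name!")                        # '/' and '!' fall outside the regex
except HTTPException as e:
    assert e.status_code == 400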
def setup_agents_config_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/config", tags=["agents"], response_model=GetAgentResponse)
def get_agent_config(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the configuration for a specific agent.
This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
"""
interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
agent_state = server.get_agent_config(user_id=user_id, agent_id=agent_id)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
# configs
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
metadata=agent_state._metadata,
),
last_run_at=None, # TODO
sources=attached_sources,
)
@router.patch("/agents/{agent_id}/rename", tags=["agents"], response_model=GetAgentResponse)
def update_agent_name(
agent_id: uuid.UUID,
request: AgentRenameRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Updates the name of a specific agent.
This changes the name of the agent in the database but does NOT edit the agent's persona.
"""
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
valid_name = validate_agent_name(request.agent_name)
interface.clear()
try:
agent_state = server.rename_agent(user_id=user_id, agent_id=agent_id, new_agent_name=valid_name)
# get sources
attached_sources = server.list_attached_sources(agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return GetAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
),
last_run_at=None, # TODO
sources=attached_sources,
)
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
# agent_id = uuid.UUID(agent_id)
interface.clear()
try:
server.delete_agent(user_id=user_id, agent_id=agent_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent agent_id={agent_id} successfully deleted"})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router


@ -1,22 +1,49 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.constants import BASE_TOOLS
from memgpt.memory import ChatMemory
from memgpt.models.pydantic_models import (
AgentStateModel,
EmbeddingConfigModel,
LLMConfigModel,
PresetModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.settings import settings
router = APIRouter()
class ListAgentsResponse(BaseModel):
num_agents: int = Field(..., description="The number of agents available to the user.")
# TODO make return type List[AgentStateModel]
# also return - presets: List[PresetModel]
agents: List[dict] = Field(..., description="List of agent configurations.")
class CreateAgentRequest(BaseModel):
# TODO: modify this (along with front end)
config: dict = Field(..., description="The agent configuration object.")
class CreateAgentResponse(BaseModel):
agent_state: AgentStateModel = Field(..., description="The state of the newly created agent.")
preset: PresetModel = Field(..., description="The preset that the agent was created from.")
def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents", tags=["agents"], response_model=List[AgentState])
@router.get("/agents", tags=["agents"], response_model=ListAgentsResponse)
def list_agents(
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
List all agents associated with a given user.
@ -24,71 +51,95 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
interface.clear()
agents_data = server.list_agents(user_id=user_id)
return agents_data
agents_data = server.list_agents_legacy(user_id=user_id)
return ListAgentsResponse(**agents_data)
@router.post("/agents", tags=["agents"], response_model=AgentState)
@router.post("/agents", tags=["agents"], response_model=CreateAgentResponse)
def create_agent(
request: CreateAgent = Body(...),
user_id: str = Depends(get_current_user_with_server),
request: CreateAgentRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Create a new agent with the specified configuration.
"""
interface.clear()
agent_state = server.create_agent(request, user_id=user_id)
return agent_state
# Parse request
# TODO: don't just use JSON in the future
human_name = request.config["human_name"] if "human_name" in request.config else None
human = request.config["human"] if "human" in request.config else None
persona_name = request.config["persona_name"] if "persona_name" in request.config else None
persona = request.config["persona"] if "persona" in request.config else None
request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
tool_names = request.config["function_names"] if ("function_names" in request.config and request.config["function_names"]) else None
metadata = request.config["metadata"] if "metadata" in request.config else {}
metadata["human"] = human_name
metadata["persona"] = persona_name
# TODO: remove this -- should be added based on create agent fields
if isinstance(tool_names, str): # TODO: fix this on client side?
tool_names = tool_names.split(",")
if tool_names is None or tool_names == "":
tool_names = []
for name in BASE_TOOLS: # TODO: remove this
if name not in tool_names:
tool_names.append(name)
assert isinstance(tool_names, list), "Tool names must be a list of strings."
# TODO: eventually remove this - should support general memory at the REST endpoint
# TODO: the REST server should add default memory tools at startup time
memory = ChatMemory(persona=persona, human=human)
@router.post("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
def update_agent(
agent_id: str,
request: UpdateAgentState = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
"""Update an exsiting agent"""
interface.clear()
try:
# TODO: should id be moved out of UpdateAgentState?
agent_state = server.update_agent(request, user_id=user_id)
agent_state = server.create_agent(
user_id=user_id,
# **request.config
# TODO turn into a pydantic model
name=request.config["name"],
memory=memory,
system=request.config.get("system", None),
# persona_name=persona_name,
# human_name=human_name,
# persona=persona,
# human=human,
# llm_config=LLMConfigModel(
# model=request.config['model'],
# )
# tools
tools=tool_names,
metadata=metadata,
# function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
return CreateAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
metadata=agent_state._metadata,
),
preset=PresetModel( # TODO: remove (placeholder to avoid breaking frontend)
name="dummy_preset",
id=agent_state.id,
user_id=agent_state.user_id,
description="",
created_at=agent_state.created_at,
system=agent_state.system,
persona="",
human="",
functions_schema=[],
),
)
except Exception as e:
print(str(e))
raise HTTPException(status_code=500, detail=str(e))
return agent_state
@router.get("/agents/{agent_id}", tags=["agents"], response_model=AgentState)
def get_agent_state(
agent_id: str = None,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get the state of the agent.
"""
interface.clear()
if not server.ms.get_agent(user_id=user_id, agent_id=agent_id):
# agent does not exist
raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
return server.get_agent_state(user_id=user_id, agent_id=agent_id)
@router.delete("/agents/{agent_id}", tags=["agents"])
def delete_agent(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Delete an agent.
"""
# agent_id = str(agent_id)
interface.clear()
try:
server.delete_agent(user_id=user_id, agent_id=agent_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router
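A client-side sketch of creating an agent through this legacy config-dict endpoint (base URL, token, and field values are illustrative assumptions):

import requests

resp = requests.post(
    "http://localhost:8283/api/agents",             # assumed mount point
    headers={"Authorization": "Bearer <api_key>"},  # placeholder user API key
    json={"config": {
        "name": "demo-agent",
        "human_name": "basic",
        "human": "Name: Sam",
        "persona_name": "sam_pov",
        "persona": "I am Sam.",
        # "function_names" omitted: BASE_TOOLS are attached by default
    }},
)
agent_id = resp.json()["agent_state"]["id"]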


@ -1,12 +1,11 @@
import uuid
from functools import partial
from typing import Dict, List, Optional
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message
from memgpt.schemas.passage import Passage
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -14,24 +13,60 @@ from memgpt.server.server import SyncServer
router = APIRouter()
class CoreMemory(BaseModel):
human: str | None = Field(None, description="Human element of the core memory.")
persona: str | None = Field(None, description="Persona element of the core memory.")
class GetAgentMemoryResponse(BaseModel):
core_memory: CoreMemory = Field(..., description="The state of the agent's core memory.")
recall_memory: int = Field(..., description="Size of the agent's recall memory.")
archival_memory: int = Field(..., description="Size of the agent's archival memory.")
# NOTE not subclassing CoreMemory since in the request both fields are optional
class UpdateAgentMemoryRequest(BaseModel):
human: Optional[str] = Field(None, description="Human element of the core memory.")
persona: Optional[str] = Field(None, description="Persona element of the core memory.")
class UpdateAgentMemoryResponse(BaseModel):
old_core_memory: CoreMemory = Field(..., description="The previous state of the agent's core memory.")
new_core_memory: CoreMemory = Field(..., description="The updated state of the agent's core memory.")
class ArchivalMemoryObject(BaseModel):
# TODO move to models/pydantic_models, or inherent from data_types Record
id: uuid.UUID = Field(..., description="Unique identifier for the memory object inside the archival memory store.")
contents: str = Field(..., description="The memory contents.")
class GetAgentArchivalMemoryResponse(BaseModel):
# TODO: make this List[Passage] instead
archival_memory: List[ArchivalMemoryObject] = Field(..., description="A list of all memory objects in archival memory.")
class InsertAgentArchivalMemoryRequest(BaseModel):
content: str = Field(..., description="The memory contents to insert into archival memory.")
class InsertAgentArchivalMemoryResponse(BaseModel):
ids: List[str] = Field(
..., description="Unique identifier for the new archival memory object. May return multiple ids if insert contents are chunked."
)
class DeleteAgentArchivalMemoryRequest(BaseModel):
id: str = Field(..., description="Unique identifier for the new archival memory object.")
def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/memory/messages", tags=["agents"], response_model=List[Message])
def get_agent_in_context_messages(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve the messages in the context of a specific agent.
"""
interface.clear()
return server.get_in_context_messages(agent_id=agent_id)
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
@router.get("/agents/{agent_id}/memory", tags=["agents"], response_model=GetAgentMemoryResponse)
def get_agent_memory(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the memory state of a specific agent.
@ -39,13 +74,14 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
This endpoint fetches the current memory state of the agent identified by the user ID and agent ID.
"""
interface.clear()
return server.get_agent_memory(agent_id=agent_id)
memory = server.get_agent_memory(user_id=user_id, agent_id=agent_id)
return GetAgentMemoryResponse(**memory)
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=Memory)
@router.post("/agents/{agent_id}/memory", tags=["agents"], response_model=UpdateAgentMemoryResponse)
def update_agent_memory(
agent_id: str,
request: Dict = Body(...),
user_id: str = Depends(get_current_user_with_server),
agent_id: uuid.UUID,
request: UpdateAgentMemoryRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Update the core memory of a specific agent.
@ -53,86 +89,75 @@ def setup_agents_memory_router(server: SyncServer, interface: QueuingInterface,
This endpoint accepts new memory contents (human and persona) and updates the core memory of the agent identified by the user ID and agent ID.
"""
interface.clear()
memory = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=request)
return memory
@router.get("/agents/{agent_id}/memory/recall", tags=["agents"], response_model=RecallMemorySummary)
def get_agent_recall_memory_summary(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
new_memory_contents = {"persona": request.persona, "human": request.human}
response = server.update_agent_core_memory(user_id=user_id, agent_id=agent_id, new_memory_contents=new_memory_contents)
return UpdateAgentMemoryResponse(**response)
@router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory_all(
agent_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the summary of the recall memory of a specific agent.
Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
"""
interface.clear()
return server.get_recall_memory_summary(agent_id=agent_id)
archival_memories = server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
print("archival_memories:", archival_memories)
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
@router.get("/agents/{agent_id}/memory/archival", tags=["agents"], response_model=ArchivalMemorySummary)
def get_agent_archival_memory_summary(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Retrieve the summary of the archival memory of a specific agent.
"""
interface.clear()
return server.get_archival_memory_summary(agent_id=agent_id)
# @router.get("/agents/{agent_id}/archival/all", tags=["agents"], response_model=List[Passage])
# def get_agent_archival_memory_all(
# agent_id: str,
# user_id: str = Depends(get_current_user_with_server),
# ):
# """
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
# """
# interface.clear()
# return server.get_all_archival_memories(user_id=user_id, agent_id=agent_id)
@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=List[Passage])
@router.get("/agents/{agent_id}/archival", tags=["agents"], response_model=GetAgentArchivalMemoryResponse)
def get_agent_archival_memory(
agent_id: str,
agent_id: uuid.UUID,
after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."),
before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."),
limit: Optional[int] = Query(None, description="How many results to include in the response."),
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the memories in an agent's archival memory store (paginated query).
"""
interface.clear()
return server.get_agent_archival_cursor(
# TODO need to add support for non-postgres here
# chroma will throw:
# raise ValueError("Cannot run get_all_cursor with chroma")
_, archival_json_records = server.get_agent_archival_cursor(
user_id=user_id,
agent_id=agent_id,
after=after,
before=before,
limit=limit,
)
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records]
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
@router.post("/agents/{agent_id}/archival/{memory}", tags=["agents"], response_model=List[Passage])
@router.post("/agents/{agent_id}/archival", tags=["agents"], response_model=InsertAgentArchivalMemoryResponse)
def insert_agent_archival_memory(
agent_id: str,
memory: str,
user_id: str = Depends(get_current_user_with_server),
agent_id: uuid.UUID,
request: InsertAgentArchivalMemoryRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Insert a memory into an agent's archival memory store.
"""
interface.clear()
return server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=memory)
memory_ids = server.insert_archival_memory(user_id=user_id, agent_id=agent_id, memory_contents=request.content)
return InsertAgentArchivalMemoryResponse(ids=memory_ids)
@router.delete("/agents/{agent_id}/archival/{memory_id}", tags=["agents"])
@router.delete("/agents/{agent_id}/archival", tags=["agents"])
def delete_agent_archival_memory(
agent_id: str,
memory_id: str,
user_id: str = Depends(get_current_user_with_server),
agent_id: uuid.UUID,
id: str = Query(..., description="Unique ID of the memory to be deleted."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a memory from an agent's archival memory store.
"""
# TODO: should probably return a `Passage`
interface.clear()
try:
memory_id = uuid.UUID(id)
server.delete_archival_memory(user_id=user_id, agent_id=agent_id, memory_id=memory_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"})
except HTTPException:


@ -1,48 +1,116 @@
import asyncio
import uuid
from datetime import datetime
from enum import Enum
from functools import partial
from typing import List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from memgpt.schemas.enums import MessageRole, MessageStreamStatus
from memgpt.schemas.memgpt_message import LegacyMemGPTMessage, MemGPTMessage
from memgpt.schemas.memgpt_request import MemGPTRequest
from memgpt.schemas.memgpt_response import MemGPTResponse
from memgpt.schemas.message import Message
from memgpt.models.pydantic_models import MemGPTUsageStatistics
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface, StreamingServerInterface
from memgpt.server.rest_api.utils import sse_async_generator
from memgpt.server.server import SyncServer
from memgpt.utils import deduplicate
router = APIRouter()
# TODO: cpacker should check this file
# TODO: move this into server.py?
class MessageRoleType(str, Enum):
user = "user"
system = "system"
class UserMessageRequest(BaseModel):
message: str = Field(..., description="The message content to be processed by the agent.")
name: Optional[str] = Field(default=None, description="Name of the message request sender")
role: MessageRoleType = Field(default=MessageRoleType.user, description="Role of the message sender (either 'user' or 'system')")
stream_steps: bool = Field(
default=False, description="Flag to determine if the response should be streamed. Set to True for streaming agent steps."
)
stream_tokens: bool = Field(
default=False,
description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
)
timestamp: Optional[datetime] = Field(
None,
description="Timestamp to tag the message with (in ISO format). If null, timestamp will be created server-side on receipt of message.",
)
stream: bool = Field(
default=False,
description="Legacy flag for old streaming API, will be deprecrated in the future.",
deprecated=True,
)
# @validator("timestamp", pre=True, always=True)
# def validate_timestamp(cls, value: Optional[datetime]) -> Optional[datetime]:
# if value is None:
# return value # If the timestamp is None, just return None, implying default handling to set server-side
# if not isinstance(value, datetime):
# raise TypeError("Timestamp must be a datetime object with timezone information.")
# if value.tzinfo is None or value.tzinfo.utcoffset(value) is None:
# raise ValueError("Timestamp must be timezone-aware.")
# # Convert timestamp to UTC if it's not already in UTC
# if value.tzinfo.utcoffset(value) != timezone.utc.utcoffset(value):
# value = value.astimezone(timezone.utc)
# return value
class UserMessageResponse(BaseModel):
messages: List[dict] = Field(..., description="List of messages generated by the agent in response to the received message.")
usage: MemGPTUsageStatistics = Field(..., description="Usage statistics for the completion.")
class GetAgentMessagesRequest(BaseModel):
start: int = Field(..., description="Message index to start on (reverse chronological).")
count: int = Field(..., description="How many messages to retrieve.")
class GetAgentMessagesCursorRequest(BaseModel):
before: Optional[uuid.UUID] = Field(..., description="Message before which to retrieve the returned messages.")
limit: int = Field(..., description="Maximum number of messages to retrieve.")
class GetAgentMessagesResponse(BaseModel):
messages: list = Field(..., description="List of message objects.")
async def send_message_to_agent(
server: SyncServer,
agent_id: str,
user_id: str,
role: MessageRole,
agent_id: uuid.UUID,
user_id: uuid.UUID,
role: str,
message: str,
stream_legacy: bool, # legacy
stream_steps: bool,
stream_tokens: bool,
chat_completion_mode: Optional[bool] = False,
timestamp: Optional[datetime] = None,
# related to whether or not we return `MemGPTMessage`s or `Message`s
return_message_object: bool = True, # Should be True for Python Client, False for REST API
) -> Union[StreamingResponse, MemGPTResponse]:
) -> Union[StreamingResponse, UserMessageResponse]:
"""Split off into a separate function so that it can be imported in the /chat/completion proxy."""
# TODO: @charles is this the correct way to handle?
include_final_message = True
# determine role
if role == MessageRole.user:
# TODO this is a total hack but is required until we move streaming into the model config
if server.server_llm_config.model_endpoint != "https://api.openai.com/v1":
stream_tokens = False
# handle the legacy mode streaming
if stream_legacy:
# NOTE: override
stream_steps = True
stream_tokens = False
include_final_message = False
else:
include_final_message = True
if role == "user" or role is None:
message_func = server.user_message
elif role == MessageRole.system:
elif role == "system":
message_func = server.system_message
else:
raise HTTPException(status_code=500, detail=f"Bad role {role}")
@ -53,11 +121,9 @@ async def send_message_to_agent(
# For streaming response
try:
# TODO: move this logic into server.py
# Get the generator object off of the agent's streaming interface
# This will be attached to the POST SSE request used under-the-hood
memgpt_agent = server._get_or_load_agent(agent_id=agent_id)
memgpt_agent = server._get_or_load_agent(user_id=user_id, agent_id=agent_id)
streaming_interface = memgpt_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
@ -67,6 +133,8 @@ async def send_message_to_agent(
# "chatcompletion mode" does some remapping and ignores inner thoughts
streaming_interface.streaming_chat_completion_mode = chat_completion_mode
# NOTE: for legacy 'stream' flag
streaming_interface.nonstreaming_legacy_mode = stream_legacy
# streaming_interface.allow_assistant_message = stream
# streaming_interface.function_call_legacy_mode = stream
@ -77,44 +145,21 @@ async def send_message_to_agent(
)
if stream_steps:
if return_message_object:
# TODO implement returning `Message`s in a stream, not just `MemGPTMessage` format
raise NotImplementedError
# return a stream
return StreamingResponse(
sse_async_generator(streaming_interface.get_generator(), finish_message=include_final_message),
media_type="text/event-stream",
)
else:
# buffer the stream, then return the list
generated_stream = []
async for message in streaming_interface.get_generator():
assert (
isinstance(message, MemGPTMessage)
or isinstance(message, LegacyMemGPTMessage)
or isinstance(message, MessageStreamStatus)
), type(message)
generated_stream.append(message)
if message == MessageStreamStatus.done:
if "data" in message and message["data"] == "[DONE]":
break
# Get rid of the stream status messages
filtered_stream = [d for d in generated_stream if not isinstance(d, MessageStreamStatus)]
filtered_stream = [d for d in generated_stream if d not in ["[DONE_GEN]", "[DONE_STEP]", "[DONE]"]]
usage = await task
# By default the stream will be messages of type MemGPTMessage or MemGPTLegacyMessage
# If we want to convert these to Message, we can use the attached IDs
# NOTE: we will need to de-duplicate the Message IDs though (since Assistant->Inner+Func_Call)
# TODO: eventually update the interface to use `Message` and `MessageChunk` (new) inside the deque instead
if return_message_object:
message_ids = [m.id for m in filtered_stream]
message_ids = deduplicate(message_ids)
message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
return MemGPTResponse(messages=message_objs, usage=usage)
else:
return MemGPTResponse(messages=filtered_stream, usage=usage)
return UserMessageResponse(messages=filtered_stream, usage=usage)
except HTTPException:
raise
@ -129,39 +174,55 @@ async def send_message_to_agent(
def setup_agents_message_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/agents/{agent_id}/messages/context/", tags=["agents"], response_model=List[Message])
def get_agent_messages_in_context(
agent_id: str,
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages(
agent_id: uuid.UUID,
start: int = Query(..., description="Message index to start on (reverse chronological)."),
count: int = Query(..., description="How many messages to retrieve."),
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
"""
interface.clear()
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=start, count=count)
return messages
# Validate with the Pydantic model (optional)
request = GetAgentMessagesRequest(agent_id=agent_id, start=start, count=count)
# agent_id = uuid.UUID(request.agent_id) if request.agent_id else None
@router.get("/agents/{agent_id}/messages", tags=["agents"], response_model=List[Message])
def get_agent_messages(
agent_id: str,
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
interface.clear()
messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=request.start, count=request.count)
return GetAgentMessagesResponse(messages=messages)
@router.get("/agents/{agent_id}/messages-cursor", tags=["agents"], response_model=GetAgentMessagesResponse)
def get_agent_messages_cursor(
agent_id: uuid.UUID,
before: Optional[uuid.UUID] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(10, description="Maximum number of messages to retrieve."),
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Retrieve message history for an agent.
Retrieve the in-context messages of a specific agent. Paginated, provide start and count to iterate.
"""
interface.clear()
return server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, before=before, limit=limit, reverse=True)
# Validate with the Pydantic model (optional)
request = GetAgentMessagesCursorRequest(agent_id=agent_id, before=before, limit=limit)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=MemGPTResponse)
interface.clear()
[_, messages] = server.get_agent_recall_cursor(
user_id=user_id, agent_id=agent_id, before=request.before, limit=request.limit, reverse=True
)
# print("====> messages-cursor DEBUG")
# for i, msg in enumerate(messages):
# print(f"message {i+1}/{len(messages)}")
# print(f"UTC created-at: {msg.created_at.strftime('%Y-%m-%dT%H:%M:%S.%f')[:-3] + 'Z'}")
# print(f"ISO format string: {msg['created_at']}")
# print(msg)
return GetAgentMessagesResponse(messages=messages)
@router.post("/agents/{agent_id}/messages", tags=["agents"], response_model=UserMessageResponse)
async def send_message(
# background_tasks: BackgroundTasks,
agent_id: str,
request: MemGPTRequest = Body(...),
user_id: str = Depends(get_current_user_with_server),
agent_id: uuid.UUID,
request: UserMessageRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Process a user message and return the agent's response.
@ -169,21 +230,17 @@ def setup_agents_message_router(server: SyncServer, interface: QueuingInterface,
This endpoint accepts a message from a user and processes it through the agent.
It can optionally stream the response if 'stream' is set to True.
"""
# TODO: should this receive multiple messages? @cpacker
# TODO: revise to `MemGPTRequest`
# TODO: support sending multiple messages
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
message = request.messages[0]
# TODO: what to do with message.name?
return await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=user_id,
role=message.role,
message=message.text,
role=request.role,
message=request.message,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
timestamp=request.timestamp,
# legacy
stream_legacy=request.stream,
)
return router
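A client-side sketch of both response modes of this endpoint (the URL, token, and sentinel handling are assumptions derived from the handler and the SSE "[DONE]" markers referenced above):

import json
import requests

agent_id = "00000000-0000-0000-0000-000000000000"              # placeholder agent UUID
url = f"http://localhost:8283/api/agents/{agent_id}/messages"  # assumed mount point
headers = {"Authorization": "Bearer <api_key>"}                # placeholder user API key

# Buffered mode: returns a UserMessageResponse with the full message list.
body = {"message": "What did I tell you earlier?", "role": "user"}
print(requests.post(url, headers=headers, json=body).json()["messages"])

# Streaming mode: server-sent events, one JSON payload per "data:" line.
with requests.post(url, headers=headers, json={**body, "stream_steps": True}, stream=True) as r:
    for line in r.iter_lines():
        if not line.startswith(b"data:"):
            continue
        payload = line[len(b"data:"):].strip()
        if payload.startswith(b"[DONE"):  # skip stream-status sentinels like [DONE_STEP]/[DONE]
            continue
        print(json.loads(payload))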


@ -1,73 +0,0 @@
from functools import partial
from typing import List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.block import Block, CreateBlock
from memgpt.schemas.block import Human as HumanModel # TODO: modify
from memgpt.schemas.block import UpdateBlock
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
def setup_block_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/blocks", tags=["block"], response_model=List[Block])
async def list_blocks(
user_id: str = Depends(get_current_user_with_server),
# query parameters
label: Optional[str] = Query(None, description="Labels to include (e.g. human, persona)"),
templates_only: bool = Query(True, description="Whether to include only templates"),
name: Optional[str] = Query(None, description="Name of the block"),
):
# Clear the interface
interface.clear()
blocks = server.get_blocks(user_id=user_id, label=label, template=templates_only, name=name)
if blocks is None:
return []
return blocks
@router.post("/blocks", tags=["block"], response_model=Block)
async def create_block(
request: CreateBlock = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
request.user_id = user_id # TODO: remove?
return server.create_block(user_id=user_id, request=request)
@router.post("/blocks/{block_id}", tags=["block"], response_model=Block)
async def update_block(
block_id: str,
request: UpdateBlock = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
# TODO: should this be in the param or the POST data?
assert block_id == request.id
return server.update_block(user_id=user_id, request=request)
@router.delete("/blocks/{block_id}", tags=["block"], response_model=Block)
async def delete_block(
block_id: str,
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
return server.delete_block(block_id=block_id)
@router.get("/blocks/{block_id}", tags=["block"], response_model=Block)
async def get_block(
block_id: str,
user_id: str = Depends(get_current_user_with_server),
):
interface.clear()
block = server.get_block(block_id=block_id)
if block is None:
raise HTTPException(status_code=404, detail="Block not found")
return block
return router


@ -1,11 +1,9 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field
from memgpt.schemas.embedding_config import EmbeddingConfig
from memgpt.schemas.llm_config import LLMConfig
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -21,20 +19,13 @@ class ConfigResponse(BaseModel):
def setup_config_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/config/llm", tags=["config"], response_model=List[LLMConfig])
def get_llm_configs(user_id: str = Depends(get_current_user_with_server)):
@router.get("/config", tags=["config"], response_model=ConfigResponse)
def get_server_config(user_id: uuid.UUID = Depends(get_current_user_with_server)):
"""
Retrieve the base configuration for the server.
"""
interface.clear()
return [server.server_llm_config]
@router.get("/config/embedding", tags=["config"], response_model=List[EmbeddingConfig])
def get_embedding_configs(user_id: str = Depends(get_current_user_with_server)):
"""
Retrieve the base configuration for the server.
"""
interface.clear()
return [server.server_embedding_config]
response = server.get_server_config(include_defaults=True)
return ConfigResponse(config=response["config"], defaults=response["defaults"])
return router


@ -0,0 +1,69 @@
import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import HumanModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()
class ListHumansResponse(BaseModel):
humans: List[HumanModel] = Field(..., description="List of human configurations.")
class CreateHumanRequest(BaseModel):
text: str = Field(..., description="The human text.")
name: str = Field(..., description="The name of the human.")
def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/humans", tags=["humans"], response_model=ListHumansResponse)
async def list_humans(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# Clear the interface
interface.clear()
humans = server.ms.list_humans(user_id=user_id)
return ListHumansResponse(humans=humans)
@router.post("/humans", tags=["humans"], response_model=HumanModel)
async def create_human(
request: CreateHumanRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
# TODO: disallow duplicate names for humans
interface.clear()
new_human = HumanModel(text=request.text, name=request.name, user_id=user_id)
human_id = new_human.id
server.ms.add_human(new_human)
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def delete_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.delete_human(human_name, user_id=user_id)
return human
@router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def get_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.get_human(human_name, user_id=user_id)
if human is None:
raise HTTPException(status_code=404, detail="Human not found")
return human
return router
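
A quick sketch of the human routes above in use (illustrative only; base URL and auth are assumptions, while the text/name fields come straight from CreateHumanRequest):

import requests  # illustrative sketch

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed auth

requests.post(f"{BASE}/humans", headers=HEADERS,
              json={"text": "The human's name is Sam.", "name": "sam"})
humans = requests.get(f"{BASE}/humans", headers=HEADERS).json()["humans"]
assert any(h["name"] == "sam" for h in humans)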

View File

@ -2,24 +2,13 @@ import asyncio
import json
import queue
from collections import deque
from typing import AsyncGenerator, Literal, Optional, Union
from typing import AsyncGenerator, Optional
from memgpt.data_types import Message
from memgpt.interface import AgentInterface
from memgpt.schemas.enums import MessageStreamStatus
from memgpt.schemas.memgpt_message import (
AssistantMessage,
FunctionCall,
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
LegacyFunctionCallMessage,
LegacyMemGPTMessage,
MemGPTMessage,
)
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse
from memgpt.streaming_interface import AgentChunkStreamingInterface
from memgpt.utils import is_utc_datetime
from memgpt.utils import get_utc_time, is_utc_datetime
class QueuingInterface(AgentInterface):
@ -29,66 +18,12 @@ class QueuingInterface(AgentInterface):
self.buffer = queue.Queue()
self.debug = debug
def _queue_push(self, message_api: Union[str, dict], message_obj: Union[Message, None]):
"""Wrapper around self.buffer.queue.put() that ensures the types are safe
Data will be in the format: {
"message_obj": ...
"message_string": ...
}
"""
# Check the string first
if isinstance(message_api, str):
# check that it's the stop word
if message_api == "STOP":
assert message_obj is None
self.buffer.put(
{
"message_api": message_api,
"message_obj": None,
}
)
else:
raise ValueError(f"Unrecognized string pushed to buffer: {message_api}")
elif isinstance(message_api, dict):
# check if it's the error message style
if len(message_api.keys()) == 1 and "internal_error" in message_api:
assert message_obj is None
self.buffer.put(
{
"message_api": message_api,
"message_obj": None,
}
)
else:
assert message_obj is not None, message_api
self.buffer.put(
{
"message_api": message_api,
"message_obj": message_obj,
}
)
else:
raise ValueError(f"Unrecognized type pushed to buffer: {type(message_api)}")
def to_list(self, style: Literal["obj", "api"] = "obj"):
def to_list(self):
"""Convert queue to a list (empties it out at the same time)"""
items = []
while not self.buffer.empty():
try:
# items.append(self.buffer.get_nowait())
item_to_push = self.buffer.get_nowait()
if style == "obj":
if item_to_push["message_obj"] is not None:
items.append(item_to_push["message_obj"])
elif style == "api":
items.append(item_to_push["message_api"])
else:
raise ValueError(style)
items.append(self.buffer.get_nowait())
except queue.Empty:
break
if len(items) > 1 and items[-1] == "STOP":
@ -101,30 +36,20 @@ class QueuingInterface(AgentInterface):
# Empty the queue
self.buffer.queue.clear()
async def message_generator(self, style: Literal["obj", "api"] = "obj"):
async def message_generator(self):
while True:
if not self.buffer.empty():
message = self.buffer.get()
message_obj = message["message_obj"]
message_api = message["message_api"]
if message_api == "STOP":
if message == "STOP":
break
# yield message
if style == "obj":
yield message_obj
elif style == "api":
yield message_api
else:
raise ValueError(style)
# yield message | {"date": datetime.now(tz=pytz.utc).isoformat()}
yield message
else:
await asyncio.sleep(0.1) # Small sleep to prevent a busy loop
def step_yield(self):
"""Enqueue a special stop message"""
self._queue_push(message_api="STOP", message_obj=None)
self.buffer.put("STOP")
@staticmethod
def step_complete():
@ -132,8 +57,8 @@ class QueuingInterface(AgentInterface):
def error(self, error: str):
"""Enqueue a special stop message"""
self._queue_push(message_api={"internal_error": error}, message_obj=None)
self._queue_push(message_api="STOP", message_obj=None)
self.buffer.put({"internal_error": error})
self.buffer.put("STOP")
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Handle reception of a user message"""
@ -159,7 +84,7 @@ class QueuingInterface(AgentInterface):
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
self._queue_push(message_api=new_message, message_obj=msg_obj)
self.buffer.put(new_message)
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""Handle the agent sending a message"""
@ -183,13 +108,11 @@ class QueuingInterface(AgentInterface):
assert self.buffer.qsize() > 1, "Tried to reach back to grab function call data, but couldn't find a buffer message."
# TODO also should not be accessing protected member here
new_message["id"] = self.buffer.queue[-1]["message_api"]["id"]
new_message["id"] = self.buffer.queue[-1]["id"]
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = self.buffer.queue[-1]["message_api"]["date"]
new_message["date"] = self.buffer.queue[-1]["date"]
msg_obj = self.buffer.queue[-1]["message_obj"]
self._queue_push(message_api=new_message, message_obj=msg_obj)
self.buffer.put(new_message)
def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None:
"""Handle the agent calling a function"""
@ -229,7 +152,7 @@ class QueuingInterface(AgentInterface):
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
self._queue_push(message_api=new_message, message_obj=msg_obj)
self.buffer.put(new_message)
class FunctionArgumentsStreamHandler:
@ -316,21 +239,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if multi_step = True, the stream ends when the agent yields
# if multi_step = False, the stream ends when the step ends
self.multi_step = multi_step
self.multi_step_indicator = MessageStreamStatus.done_step
self.multi_step_gen_indicator = MessageStreamStatus.done_generation
self.multi_step_indicator = "[DONE_STEP]"
self.multi_step_gen_indicator = "[DONE_GEN]"
# extra prints
self.debug = False
self.timeout = 30
async def _create_generator(self) -> AsyncGenerator[Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus], None]:
async def _create_generator(self) -> AsyncGenerator:
"""An asynchronous generator that yields chunks as they become available."""
while self._active:
try:
# Wait until there is an item in the deque or the stream is deactivated
await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout
except asyncio.TimeoutError:
break # Exit the loop if we timeout
# Wait until there is an item in the deque or the stream is deactivated
await self._event.wait()
while self._chunks:
yield self._chunks.popleft()
@ -338,33 +254,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Reset the event until a new item is pushed
self._event.clear()
# while self._active:
# # Wait until there is an item in the deque or the stream is deactivated
# await self._event.wait()
# while self._chunks:
# yield self._chunks.popleft()
# # Reset the event until a new item is pushed
# self._event.clear()
def get_generator(self) -> AsyncGenerator:
"""Get the generator that yields processed chunks."""
if not self._active:
# If the stream is not active, don't return a generator that would produce values
raise StopIteration("The stream has not been started or has been ended.")
return self._create_generator()
def _push_to_buffer(self, item: Union[MemGPTMessage, LegacyMemGPTMessage, MessageStreamStatus]):
"""Add an item to the deque"""
assert self._active, "Generator is inactive"
# assert isinstance(item, dict) or isinstance(item, MessageStreamStatus), f"Wrong type: {type(item)}"
assert (
isinstance(item, MemGPTMessage) or isinstance(item, LegacyMemGPTMessage) or isinstance(item, MessageStreamStatus)
), f"Wrong type: {type(item)}"
self._chunks.append(item)
self._event.set() # Signal that new data is available
def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None
@ -379,10 +268,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_mode_function_name = None
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
self._push_to_buffer(self.multi_step_gen_indicator)
# self._active = False
# self._event.set() # Unblock the generator if it's waiting to allow it to complete
self._chunks.append(self.multi_step_gen_indicator)
self._event.set() # Signal that new data is available
# if not self.multi_step:
# # end the stream
@ -393,27 +280,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# self._chunks.append(self.multi_step_indicator)
# self._event.set() # Signal that new data is available
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
# signal that a new step has started in the stream
self._push_to_buffer(self.multi_step_indicator)
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
# if self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
@staticmethod
def clear():
return
def _process_chunk_to_memgpt_style(self, chunk: ChatCompletionChunkResponse) -> Optional[dict]:
"""
Example data from non-streaming response looks like:
@ -539,7 +405,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if msg_obj:
processed_chunk["id"] = str(msg_obj.id)
self._push_to_buffer(processed_chunk)
self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
def get_generator(self) -> AsyncGenerator:
"""Get the generator that yields processed chunks."""
if not self._active:
# If the stream is not active, don't return a generator that would produce values
raise StopIteration("The stream has not been started or has been ended.")
return self._create_generator()
def user_message(self, msg: str, msg_obj: Optional[Message] = None):
"""MemGPT receives a user message"""
@ -550,18 +424,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if not self.streaming_mode:
# create a fake "chunk" of a stream
# processed_chunk = {
# "internal_monologue": msg,
# "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
# "id": str(msg_obj.id) if msg_obj is not None else None,
# }
processed_chunk = InternalMonologue(
id=msg_obj.id,
date=msg_obj.created_at,
internal_monologue=msg,
)
processed_chunk = {
"internal_monologue": msg,
"date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
"id": str(msg_obj.id) if msg_obj is not None else None,
}
self._push_to_buffer(processed_chunk)
self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
return
@ -603,56 +473,42 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# "date": "2024-06-22T23:04:32.141923+00:00"
# }
try:
func_args = json.loads(function_call.function.arguments)
func_args = json.loads(function_call.function["arguments"])
except:
func_args = function_call.function.arguments
# processed_chunk = {
# "function_call": f"{function_call.function.name}({func_args})",
# "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(),
# }
processed_chunk = LegacyFunctionCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=f"{function_call.function.name}({func_args})",
)
self._push_to_buffer(processed_chunk)
func_args = function_call.function["arguments"]
processed_chunk = {
"function_call": f"{function_call.function['name']}({func_args})",
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
if function_call.function.name == "send_message":
if function_call.function["name"] == "send_message":
try:
# processed_chunk = {
# "assistant_message": func_args["message"],
# "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(),
# }
processed_chunk = AssistantMessage(
id=msg_obj.id,
date=msg_obj.created_at,
assistant_message=func_args["message"],
)
self._push_to_buffer(processed_chunk)
processed_chunk = {
"assistant_message": func_args["message"],
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
except Exception as e:
print(f"Failed to parse function message: {e}")
else:
processed_chunk = FunctionCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
),
)
# processed_chunk = {
# "function_call": {
# "name": function_call.function.name,
# "arguments": function_call.function.arguments,
# },
# "id": str(msg_obj.id),
# "date": msg_obj.created_at.isoformat(),
# }
self._push_to_buffer(processed_chunk)
processed_chunk = {
"function_call": {
"id": function_call.id,
"name": function_call.function["name"],
"arguments": function_call.function["arguments"],
},
"id": str(msg_obj.id),
"date": msg_obj.created_at.isoformat(),
}
self._chunks.append(processed_chunk)
self._event.set() # Signal that new data is available
return
else:
@ -667,33 +523,43 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="success",
)
new_message = {"function_return": msg, "status": "success"}
elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="error",
)
new_message = {"function_return": msg, "status": "error"}
else:
# NOTE: generic, should not happen
raise ValueError(msg)
new_message = {"function_message": msg}
# add extra metadata
# if msg_obj is not None:
# new_message["id"] = str(msg_obj.id)
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
# new_message["date"] = msg_obj.created_at.isoformat()
if msg_obj is not None:
new_message["id"] = str(msg_obj.id)
assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
new_message["date"] = msg_obj.created_at.isoformat()
self._push_to_buffer(new_message)
self._chunks.append(new_message)
self._event.set() # Signal that new data is available
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
elif not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
# signal that a new step has started in the stream
self._chunks.append(self.multi_step_indicator)
self._event.set() # Signal that new data is available
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
if self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
@staticmethod
def clear():
return
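
To make the revert in this file concrete: the queue now holds raw API-style dicts plus a bare "STOP" string, instead of {"message_api": ..., "message_obj": ...} wrapper dicts. A standalone sketch of the consumer contract this implies (plain stdlib, no MemGPT imports; illustrative only):

import queue

buffer = queue.Queue()
buffer.put({"internal_monologue": "thinking..."})
buffer.put({"assistant_message": "hello"})
buffer.put("STOP")  # bare string sentinel, no wrapper dict

items = []
while True:
    item = buffer.get()
    if item == "STOP":  # consumers compare against the raw string again
        break
    items.append(item)
assert len(items) == 2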

View File

@ -4,7 +4,7 @@ from typing import List
from fastapi import APIRouter
from pydantic import BaseModel, Field
from memgpt.schemas.llm_config import LLMConfig
from memgpt.models.pydantic_models import LLMConfigModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -13,7 +13,7 @@ router = APIRouter()
class ListModelsResponse(BaseModel):
models: List[LLMConfig] = Field(..., description="List of model configurations.")
models: List[LLMConfigModel] = Field(..., description="List of model configurations.")
def setup_models_index_router(server: SyncServer, interface: QueuingInterface, password: str):
@ -25,7 +25,7 @@ def setup_models_index_router(server: SyncServer, interface: QueuingInterface, p
interface.clear()
# currently, the server only supports one model, however this may change in the future
llm_config = LLMConfig(
llm_config = LLMConfigModel(
model=server.server_llm_config.model,
model_endpoint=server.server_llm_config.model_endpoint,
model_endpoint_type=server.server_llm_config.model_endpoint_type,
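
A client-side sketch of the models route this file serves (illustrative; URL and auth are assumptions, while the models/model/model_endpoint keys come from the response model above):

import requests  # illustrative sketch

resp = requests.get("http://localhost:8283/api/models",
                    headers={"Authorization": "Bearer test_server_token"})  # assumed auth
for cfg in resp.json()["models"]:
    print(cfg["model"], cfg["model_endpoint"])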

View File

@ -4,9 +4,10 @@ from typing import List, Optional
from fastapi import APIRouter, Body, HTTPException, Path, Query
from pydantic import BaseModel, Field
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message
from memgpt.schemas.openai.openai import (
from memgpt.data_types import Message
from memgpt.models.openai import (
AssistantFile,
MessageFile,
MessageRoleType,
@ -138,6 +139,10 @@ class SubmitToolOutputsToRunRequest(BaseModel):
# TODO: implement mechanism for creating/authenticating users associated with a bearer token
def setup_openai_assistant_router(server: SyncServer, interface: QueuingInterface):
# TODO: remove this (when we have user auth)
user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
print(f"User ID: {user_id}")
# create assistant (MemGPT agent)
@router.post("/assistants", tags=["assistants"], response_model=OpenAIAssistant)
def create_assistant(request: CreateAssistantRequest = Body(...)):

View File

@ -4,9 +4,9 @@ from functools import partial
from fastapi import APIRouter, Body, Depends, HTTPException
# from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_request import ChatCompletionRequest
from memgpt.schemas.openai.chat_completion_response import (
# from memgpt.data_types import Message
from memgpt.models.chat_completion_request import ChatCompletionRequest
from memgpt.models.chat_completion_response import (
ChatCompletionResponse,
Choice,
Message,

View File

@ -5,7 +5,7 @@ from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.block import Persona as PersonaModel # TODO: modify
from memgpt.models.pydantic_models import PersonaModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -44,7 +44,7 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
new_persona = PersonaModel(text=request.text, name=request.name, user_id=user_id)
persona_id = new_persona.id
server.ms.create_persona(new_persona)
server.ms.add_persona(new_persona)
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)

View File

@ -0,0 +1,171 @@
import uuid
from functools import partial
from typing import Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import Preset # TODO remove
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.prompts import gpt_system
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text, get_persona_text
router = APIRouter()
"""
Implement the following functions:
* List all available presets
* Create a new preset
* Delete a preset
* TODO update a preset
"""
class ListPresetsResponse(BaseModel):
presets: List[PresetModel] = Field(..., description="List of available presets.")
class CreatePresetsRequest(BaseModel):
# TODO is there a cleaner way to create the request from the PresetModel (need to drop fields though)?
name: str = Field(..., description="The name of the preset.")
id: Optional[str] = Field(None, description="The unique identifier of the preset.")
# user_id: uuid.UUID = Field(..., description="The unique identifier of the user who created the preset.")
description: Optional[str] = Field(None, description="The description of the preset.")
# created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the preset was created.")
system: Optional[str] = Field(None, description="The system prompt of the preset.") # TODO: make optional and allow defaults
persona: Optional[str] = Field(default=None, description="The persona of the preset.")
human: Optional[str] = Field(default=None, description="The human of the preset.")
functions_schema: List[Dict] = Field(..., description="The functions schema of the preset.")
# TODO
persona_name: Optional[str] = Field(None, description="The name of the persona of the preset.")
human_name: Optional[str] = Field(None, description="The name of the human of the preset.")
system_name: Optional[str] = Field(None, description="The name of the system prompt of the preset.")
class CreatePresetResponse(BaseModel):
preset: PresetModel = Field(..., description="The newly created preset.")
def setup_presets_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/presets/{preset_name}", tags=["presets"], response_model=PresetModel)
async def get_preset(
preset_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Get a preset."""
try:
preset = server.get_preset(user_id=user_id, preset_name=preset_name)
return preset
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/presets", tags=["presets"], response_model=ListPresetsResponse)
async def list_presets(
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""List all presets created by a user."""
# Clear the interface
interface.clear()
try:
presets = server.list_presets(user_id=user_id)
return ListPresetsResponse(presets=presets)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/presets", tags=["presets"], response_model=CreatePresetResponse)
async def create_preset(
request: CreatePresetsRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Create a preset."""
try:
if isinstance(request.id, str):
request.id = uuid.UUID(request.id)
# check if preset already exists
# TODO: move this into a server function to create a preset
if server.ms.get_preset(name=request.name, user_id=user_id):
raise HTTPException(status_code=400, detail=f"Preset with name {request.name} already exists.")
# For system/human/persona - if {system/human/persona}_name is None but the text is provided, then create a new data entry
if not request.system_name and request.system:
# new system provided without name identity
system_name = f"system_{request.name}_{str(uuid.uuid4())}"
system = request.system
# TODO: insert into system table
else:
system_name = request.system_name if request.system_name else DEFAULT_PRESET
system = request.system if request.system else gpt_system.get_system_text(system_name)
if not request.human_name and request.human:
# new human provided without name identity
human_name = f"human_{request.name}_{str(uuid.uuid4())}"
human = request.human
server.ms.add_human(HumanModel(text=human, name=human_name, user_id=user_id))
else:
human_name = request.human_name if request.human_name else DEFAULT_HUMAN
human = request.human if request.human else get_human_text(human_name)
if not request.persona_name and request.persona:
# new persona provided without name identity
persona_name = f"persona_{request.name}_{str(uuid.uuid4())}"
persona = request.persona
server.ms.add_persona(PersonaModel(text=persona, name=persona_name, user_id=user_id))
else:
persona_name = request.persona_name if request.persona_name else DEFAULT_PERSONA
persona = request.persona if request.persona else get_persona_text(persona_name)
# create preset
new_preset = Preset(
user_id=user_id,
id=request.id if request.id else uuid.uuid4(),
name=request.name,
description=request.description,
system=system,
persona=persona,
persona_name=persona_name,
human=human,
human_name=human_name,
functions_schema=request.functions_schema,
)
preset = server.create_preset(preset=new_preset)
# TODO remove once we migrate from Preset to PresetModel
preset = PresetModel(**vars(preset))
return CreatePresetResponse(preset=preset)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.delete("/presets/{preset_id}", tags=["presets"])
async def delete_preset(
preset_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""Delete a preset."""
interface.clear()
try:
preset = server.delete_preset(user_id=user_id, preset_id=preset_id)
return JSONResponse(
status_code=status.HTTP_200_OK, content={"message": f"Preset preset_id={str(preset.id)} successfully deleted"}
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return router
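
A sketch of creating a preset through the route above, passing inline persona/human text without *_name fields, which exercises the auto-generated-name branches (illustrative only; URL and auth are assumptions, the payload keys come from CreatePresetsRequest):

import requests  # illustrative sketch

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed auth

payload = {
    "name": "my_preset",
    "persona": "I am a helpful assistant.",  # no persona_name -> persona_<name>_<uuid>
    "human": "The human's name is Sam.",     # no human_name -> human_<name>_<uuid>
    "functions_schema": [],  # assumption: an empty schema list is accepted
}
preset = requests.post(f"{BASE}/presets", headers=HEADERS, json=payload).json()["preset"]
print(preset["name"])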

View File

@ -15,12 +15,14 @@ from memgpt.server.constants import REST_DEFAULT_PORT
from memgpt.server.rest_api.admin.agents import setup_agents_admin_router
from memgpt.server.rest_api.admin.tools import setup_tools_index_router
from memgpt.server.rest_api.admin.users import setup_admin_router
from memgpt.server.rest_api.agents.command import setup_agents_command_router
from memgpt.server.rest_api.agents.config import setup_agents_config_router
from memgpt.server.rest_api.agents.index import setup_agents_index_router
from memgpt.server.rest_api.agents.memory import setup_agents_memory_router
from memgpt.server.rest_api.agents.message import setup_agents_message_router
from memgpt.server.rest_api.auth.index import setup_auth_router
from memgpt.server.rest_api.block.index import setup_block_index_router
from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.humans.index import setup_humans_index_router
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import (
@ -29,6 +31,8 @@ from memgpt.server.rest_api.openai_assistants.assistants import (
from memgpt.server.rest_api.openai_chat_completions.chat_completions import (
setup_openai_chat_completions_router,
)
from memgpt.server.rest_api.personas.index import setup_personas_index_router
from memgpt.server.rest_api.presets.index import setup_presets_index_router
from memgpt.server.rest_api.sources.index import setup_sources_index_router
from memgpt.server.rest_api.static_files import mount_static_files
from memgpt.server.rest_api.tools.index import setup_user_tools_index_router
@ -91,13 +95,17 @@ app.include_router(setup_tools_index_router(server, interface), prefix=ADMIN_PRE
app.include_router(setup_agents_admin_router(server, interface), prefix=ADMIN_API_PREFIX, dependencies=[Depends(verify_password)])
# /api/agents endpoints
app.include_router(setup_agents_command_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_config_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_humans_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_personas_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_presets_index_router(server, interface, password), prefix=API_PREFIX)
# /api/config endpoints
app.include_router(setup_config_index_router(server, interface, password), prefix=API_PREFIX)
@ -145,8 +153,7 @@ def on_startup():
@app.on_event("shutdown")
def on_shutdown():
global server
if server:
server.save_agents()
server.save_agents()
server = None
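
A side note on the shutdown hunk: with the `if server:` guard removed, save_agents() now runs unconditionally, so a None server at shutdown would raise. A minimal standalone sketch of the pattern (FastAPI's on_event hook is real API; the server object is a stand-in, not the actual SyncServer):

from fastapi import FastAPI

app = FastAPI()

class FakeServer:  # stand-in for the global SyncServer
    def save_agents(self) -> None:
        print("agents saved")

server = FakeServer()

@app.on_event("shutdown")
def on_shutdown():
    global server
    server.save_agents()  # unconditional after this change
    server = None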

View File

@ -1,7 +1,8 @@
import os
import tempfile
import uuid
from functools import partial
from typing import List
from typing import List, Optional
from fastapi import (
APIRouter,
@ -11,14 +12,20 @@ from fastapi import (
HTTPException,
Query,
UploadFile,
status,
)
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field
from memgpt.schemas.document import Document
from memgpt.schemas.job import Job
from memgpt.schemas.passage import Passage
# schemas
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate, UploadFile
from memgpt.data_sources.connectors import DirectoryConnector
from memgpt.data_types import Source
from memgpt.models.pydantic_models import (
DocumentModel,
JobModel,
JobStatus,
PassageModel,
SourceModel,
)
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -37,73 +44,77 @@ Implement the following functions:
"""
# class ListSourcesResponse(BaseModel):
# sources: List[SourceModel] = Field(..., description="List of available sources.")
#
#
# class CreateSourceRequest(BaseModel):
# name: str = Field(..., description="The name of the source.")
# description: Optional[str] = Field(None, description="The description of the source.")
#
#
# class UploadFileToSourceRequest(BaseModel):
# file: UploadFile = Field(..., description="The file to upload.")
#
#
# class UploadFileToSourceResponse(BaseModel):
# source: SourceModel = Field(..., description="The source the file was uploaded to.")
# added_passages: int = Field(..., description="The number of passages added to the source.")
# added_documents: int = Field(..., description="The number of documents added to the source.")
#
#
# class GetSourcePassagesResponse(BaseModel):
# passages: List[PassageModel] = Field(..., description="List of passages from the source.")
#
#
# class GetSourceDocumentsResponse(BaseModel):
# documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
class ListSourcesResponse(BaseModel):
sources: List[SourceModel] = Field(..., description="List of available sources.")
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(bytes)
class CreateSourceRequest(BaseModel):
name: str = Field(..., description="The name of the source.")
description: Optional[str] = Field(None, description="The description of the source.")
server.load_file_to_source(source_id, file_path, job_id)
class UploadFileToSourceRequest(BaseModel):
file: UploadFile = Field(..., description="The file to upload.")
class UploadFileToSourceResponse(BaseModel):
source: SourceModel = Field(..., description="The source the file was uploaded to.")
added_passages: int = Field(..., description="The number of passages added to the source.")
added_documents: int = Field(..., description="The number of documents added to the source.")
class GetSourcePassagesResponse(BaseModel):
passages: List[PassageModel] = Field(..., description="List of passages from the source.")
class GetSourceDocumentsResponse(BaseModel):
documents: List[DocumentModel] = Field(..., description="List of documents from the source.")
def load_file_to_source(server: SyncServer, user_id: uuid.UUID, source: Source, job_id: uuid.UUID, file: UploadFile, bytes: bytes):
# update job status
job = server.ms.get_job(job_id=job_id)
job.status = JobStatus.running
server.ms.update_job(job)
try:
# write the file to a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
file_path = os.path.join(tmpdirname, file.filename)
with open(file_path, "wb") as buffer:
buffer.write(bytes)
# read the file
connector = DirectoryConnector(input_files=[file_path])
# TODO: pre-compute total number of passages?
# load the data into the source via the connector
num_passages, num_documents = server.load_data(user_id=user_id, source_name=source.name, connector=connector)
except Exception as e:
# job failed with error
error = str(e)
print(error)
job.status = JobStatus.failed
job.metadata_["error"] = error
server.ms.update_job(job)
# TODO: delete any associated passages/documents?
return 0, 0
# update job status
job.status = JobStatus.completed
job.metadata_["num_passages"] = num_passages
job.metadata_["num_documents"] = num_documents
print("job completed", job.metadata_, job.id)
server.ms.update_job(job)
def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.get("/sources/{source_id}", tags=["sources"], response_model=Source)
async def get_source(
source_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get a source by ID
"""
interface.clear()
source = server.get_source(source_id=source_id, user_id=user_id)
return source
@router.get("/sources/name/{source_name}", tags=["sources"], response_model=str)
async def get_source_id_by_name(
source_name: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get the ID of a source by name
"""
interface.clear()
source = server.get_source_id(source_name=source_name, user_id=user_id)
return source
@router.get("/sources", tags=["sources"], response_model=List[Source])
@router.get("/sources", tags=["sources"], response_model=ListSourcesResponse)
async def list_sources(
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
List all data sources created by a user.
@ -113,40 +124,58 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
try:
sources = server.list_all_sources(user_id=user_id)
return sources
return ListSourcesResponse(sources=sources)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources", tags=["sources"], response_model=Source)
@router.post("/sources", tags=["sources"], response_model=SourceModel)
async def create_source(
request: SourceCreate = Body(...),
user_id: str = Depends(get_current_user_with_server),
request: CreateSourceRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Create a new data source.
"""
interface.clear()
try:
return server.create_source(request=request, user_id=user_id)
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.create_source(name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources/{source_id}", tags=["sources"], response_model=Source)
@router.post("/sources/{source_id}", tags=["sources"], response_model=SourceModel)
async def update_source(
source_id: str,
request: SourceUpdate = Body(...),
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
request: CreateSourceRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Update the name or documentation of an existing data source.
"""
interface.clear()
try:
return server.update_source(request=request, user_id=user_id)
# TODO: don't use Source and just use SourceModel once pydantic migration is complete
source = server.update_source(source_id=source_id, name=request.name, user_id=user_id, description=request.description)
return SourceModel(
name=source.name,
description=source.description,
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at.timestamp(),
)
except HTTPException:
raise
except Exception as e:
@ -154,8 +183,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
@router.delete("/sources/{source_id}", tags=["sources"])
async def delete_source(
source_id: str,
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a data source.
@ -163,58 +192,66 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
try:
server.delete_source(source_id=source_id, user_id=user_id)
return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Source source_id={source_id} successfully deleted"})
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=Source)
@router.post("/sources/{source_id}/attach", tags=["sources"], response_model=SourceModel)
async def attach_source_to_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to attach the source to."),
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to attach the source to."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Attach a data source to an existing agent.
"""
interface.clear()
assert isinstance(agent_id, str), f"Expected agent_id to be a str, got {agent_id}"
assert isinstance(user_id, str), f"Expected user_id to be a str, got {user_id}"
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
source = server.ms.get_source(source_id=source_id, user_id=user_id)
source = server.attach_source_to_agent(source_id=source.id, agent_id=agent_id, user_id=user_id)
return source
source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions
user_id=source.user_id,
id=source.id,
embedding_config=server.server_embedding_config,
created_at=source.created_at,
)
@router.post("/sources/{source_id}/detach", tags=["sources"])
@router.post("/sources/{source_id}/detach", tags=["sources"], response_model=SourceModel)
async def detach_source_from_agent(
source_id: str,
agent_id: str = Query(..., description="The unique identifier of the agent to detach the source from."),
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
agent_id: uuid.UUID = Query(..., description="The unique identifier of the agent to detach the source from."),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Detach a data source from an existing agent.
"""
server.detach_source_from_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
@router.get("/sources/status/{job_id}", tags=["sources"], response_model=Job)
async def get_job(
job_id: str,
user_id: str = Depends(get_current_user_with_server),
@router.get("/sources/status/{job_id}", tags=["sources"], response_model=JobModel)
async def get_job_status(
job_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get the status of a job.
"""
job = server.get_job(job_id=job_id)
job = server.ms.get_job(job_id=job_id)
if job is None:
raise HTTPException(status_code=404, detail=f"Job with id={job_id} not found.")
return job
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=Job)
@router.post("/sources/{source_id}/upload", tags=["sources"], response_model=JobModel)
async def upload_file_to_source(
# file: UploadFile = UploadFile(..., description="The file to upload."),
file: UploadFile,
source_id: str,
source_id: uuid.UUID,
background_tasks: BackgroundTasks,
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Upload a file to a data source.
@ -224,39 +261,37 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
bytes = file.file.read()
# create job
# TODO: create server function
job = Job(user_id=user_id, metadata_={"type": "embedding", "filename": file.filename, "source_id": source_id})
job = JobModel(user_id=user_id, metadata={"type": "embedding", "filename": file.filename, "source_id": source_id})
job_id = job.id
server.ms.create_job(job)
# create background task
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)
background_tasks.add_task(load_file_to_source, server, user_id, source, job_id, file, bytes)
# return job information
job = server.ms.get_job(job_id=job_id)
return job
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=List[Passage])
@router.get("/sources/{source_id}/passages ", tags=["sources"], response_model=GetSourcePassagesResponse)
async def list_passages(
source_id: str,
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
List all passages associated with a data source.
"""
# TODO: check if paginated?
passages = server.list_data_source_passages(user_id=user_id, source_id=source_id)
return passages
return GetSourcePassagesResponse(passages=passages)
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=List[Document])
@router.get("/sources/{source_id}/documents", tags=["sources"], response_model=GetSourceDocumentsResponse)
async def list_documents(
source_id: str,
user_id: str = Depends(get_current_user_with_server),
source_id: uuid.UUID,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
List all documents associated with a data source.
"""
documents = server.list_data_source_documents(user_id=user_id, source_id=source_id)
return documents
return GetSourceDocumentsResponse(documents=documents)
return router
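
Putting the source routes together, a sketch of the create, upload, then poll flow (illustrative only; URL, auth, and response field names are assumptions, while the endpoints and job statuses come from the code above):

import time

import requests  # illustrative sketch

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed auth

source = requests.post(f"{BASE}/sources", headers=HEADERS,
                       json={"name": "docs"}).json()
with open("notes.txt", "rb") as f:  # hypothetical local file
    job = requests.post(f"{BASE}/sources/{source['id']}/upload",
                        headers=HEADERS, files={"file": f}).json()
while job["status"] not in ("completed", "failed"):
    time.sleep(1)
    job = requests.get(f"{BASE}/sources/status/{job['id']}", headers=HEADERS).json()
print(job["metadata"])  # num_passages / num_documents on success; field name assumed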

View File

@ -1,9 +1,11 @@
import uuid
from functools import partial
from typing import List
from typing import List, Literal, Optional
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
from memgpt.models.pydantic_models import ToolModel
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@ -11,92 +13,121 @@ from memgpt.server.server import SyncServer
router = APIRouter()
class ListToolsResponse(BaseModel):
tools: List[ToolModel] = Field(..., description="List of tools (functions).")
class CreateToolRequest(BaseModel):
json_schema: dict = Field(..., description="JSON schema of the tool.") # NOT OpenAI - just has `name`
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
update: Optional[bool] = Field(False, description="Update the tool if it already exists.")
class CreateToolResponse(BaseModel):
tool: ToolModel = Field(..., description="Information about the newly created tool.")
def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterface, password: str):
get_current_user_with_server = partial(partial(get_current_user, server), password)
@router.delete("/tools/{tool_id}", tags=["tools"])
@router.delete("/tools/{tool_name}", tags=["tools"])
async def delete_tool(
tool_id: str,
user_id: str = Depends(get_current_user_with_server),
tool_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Delete a tool by name
"""
# Clear the interface
interface.clear()
server.delete_tool(id)
# tool = server.ms.delete_tool(user_id=user_id, tool_name=tool_name) TODO: add back when user-specific
server.ms.delete_tool(name=tool_name, user_id=user_id)
@router.get("/tools/{tool_id}", tags=["tools"], response_model=Tool)
@router.get("/tools/{tool_name}", tags=["tools"], response_model=ToolModel)
async def get_tool(
tool_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
tool = server.get_tool(tool_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with id {tool_id} not found.")
return tool
@router.get("/tools/name/{tool_name}", tags=["tools"], response_model=str)
async def get_tool_id(
tool_name: str,
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a tool by name
"""
# Clear the interface
interface.clear()
tool = server.get_tool_id(tool_name, user_id=user_id)
tool = server.ms.get_tool(tool_name=tool_name, user_id=user_id)
if tool is None:
# return 404 error
raise HTTPException(status_code=404, detail=f"Tool with name {tool_name} not found.")
return tool
@router.get("/tools", tags=["tools"], response_model=List[Tool])
@router.get("/tools", tags=["tools"], response_model=ListToolsResponse)
async def list_all_tools(
user_id: str = Depends(get_current_user_with_server),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Get a list of all tools available to agents created by a user
"""
# Clear the interface
interface.clear()
return server.list_tools(user_id)
tools = server.ms.list_tools(user_id=user_id)
return ListToolsResponse(tools=tools)
@router.post("/tools", tags=["tools"], response_model=Tool)
@router.post("/tools", tags=["tools"], response_model=ToolModel)
async def create_tool(
request: ToolCreate = Body(...),
user_id: str = Depends(get_current_user_with_server),
request: CreateToolRequest = Body(...),
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
"""
Create a new tool
"""
try:
return server.create_tool(request, user_id=user_id)
except Exception as e:
print(e)
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
# NOTE: horrifying code, should be replaced when we migrate dev portal
from memgpt.agent import Agent # nasty: need agent to be defined
from memgpt.functions.schema_generator import generate_schema
name = request.json_schema["name"]
import ast
parsed_code = ast.parse(request.source_code)
function_names = []
# Walk the AST and collect the names of all function definitions
def find_function_names(node):
for child in ast.iter_child_nodes(node):
if isinstance(child, ast.FunctionDef):
# Record the name of the function
function_names.append(child.name)
# Recurse into child nodes
find_function_names(child)
# Collect the function names defined in the submitted source
find_function_names(parsed_code)
assert len(function_names) == 1, f"Expected 1 function, found {len(function_names)}: {function_names}"
# generate JSON schema
env = {}
env.update(globals())
exec(request.source_code, env)
func = env.get(function_names[0])
json_schema = generate_schema(func, name=name)
from pprint import pprint
pprint(json_schema)
@router.post("/tools/{tool_id}", tags=["tools"], response_model=Tool)
async def update_tool(
tool_id: str,
request: ToolUpdate = Body(...),
user_id: str = Depends(get_current_user_with_server),
):
"""
Update an existing tool
"""
try:
# TODO: check that the user has access to this tool?
return server.update_tool(request)
return server.create_tool(
# json_schema=request.json_schema, # TODO: add back
json_schema=json_schema,
source_code=request.source_code,
source_type=request.source_type,
tags=request.tags,
user_id=user_id,
exists_ok=request.update,
)
except Exception as e:
print(e)
raise HTTPException(status_code=500, detail=f"Failed to update tool: {e}")
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}")
return router
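
A sketch of feeding the create-tool route above (illustrative; URL and auth are assumptions, the payload keys come from CreateToolRequest, and the handler re-derives the JSON schema from source_code, reading only json_schema["name"] from the request):

import requests  # illustrative sketch

BASE = "http://localhost:8283/api"  # assumed prefix
HEADERS = {"Authorization": "Bearer test_server_token"}  # assumed auth

source_code = '''
def roll_d20() -> str:
    """Roll a 20-sided die and return the result."""
    import random
    return str(random.randint(1, 20))
'''

payload = {
    "json_schema": {"name": "roll_d20"},  # exactly one function, matching the assert above
    "source_code": source_code,
    "source_type": "python",
    "tags": ["dice"],
    "update": True,  # maps to exists_ok on the server side
}
tool = requests.post(f"{BASE}/tools", headers=HEADERS, json=payload).json()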

View File

@ -1,14 +1,10 @@
import asyncio
import json
import traceback
from enum import Enum
from typing import AsyncGenerator, Union
from pydantic import BaseModel
from typing import AsyncGenerator, Generator, Union
from memgpt.constants import JSON_ENSURE_ASCII
SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
SSE_FINISH_MSG = "[DONE]" # mimic openai
SSE_ARTIFICIAL_DELAY = 0.1
@ -17,7 +13,27 @@ def sse_formatter(data: Union[dict, str]) -> str:
"""Prefix with 'data: ', and always include double newlines"""
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII) if isinstance(data, dict) else data
return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
return f"data: {data_str}\n\n"
async def sse_generator(generator: Generator[dict, None, None]) -> Generator[str, None, None]:
"""Generator that returns 'data: dict' formatted items, e.g.:
data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-9E0PdSZ2IBzAGlQ3SEWHJ5YwzucSP","object":"chat.completion.chunk","created":1713125205,"model":"gpt-4-0613","system_fingerprint":null,"choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
data: [DONE]
"""
try:
for msg in generator:
yield sse_formatter(msg)
if SSE_ARTIFICIAL_DELAY:
await asyncio.sleep(SSE_ARTIFICIAL_DELAY) # Sleep to prevent a tight loop, adjust time as needed
except Exception as e:
yield sse_formatter({"error": f"{str(e)}"})
yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete
async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
@ -33,20 +49,12 @@ async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
try:
async for chunk in generator:
# yield f"data: {json.dumps(chunk)}\n\n"
if isinstance(chunk, BaseModel):
chunk = chunk.model_dump()
elif isinstance(chunk, Enum):
chunk = str(chunk.value)
elif not isinstance(chunk, dict):
chunk = str(chunk)
yield sse_formatter(chunk)
except Exception as e:
print("stream decoder hit error:", e)
print(traceback.print_stack())
yield sse_formatter({"error": "stream decoder encountered an error"})
finally:
# yield "data: [DONE]\n\n"
if finish_message:
# Signal that the stream is complete
yield sse_formatter(SSE_FINISH_MSG)
yield sse_formatter(SSE_FINISH_MSG) # Signal that the stream is complete
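
For reference, the wire format the formatter above produces, restated as a standalone snippet (simplified: the JSON_ENSURE_ASCII flag is dropped; runnable without MemGPT imports):

import json

def sse_formatter(data) -> str:
    data_str = json.dumps(data) if isinstance(data, dict) else data
    return f"data: {data_str}\n\n"

print(repr(sse_formatter({"function_return": "ok", "status": "success"})))
# 'data: {"function_return": "ok", "status": "success"}\n\n'
print(repr(sse_formatter("[DONE]")))  # the SSE_FINISH_MSG sentinel
# 'data: [DONE]\n\n'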

File diff suppressed because it is too large

View File

@ -43,12 +43,5 @@ class Settings(BaseSettings):
return None
class TestSettings(Settings):
model_config = SettingsConfigDict(env_prefix="memgpt_test_")
memgpt_dir: Optional[Path] = Field(Path.home() / ".memgpt/test", env="MEMGPT_TEST_DIR")
# singleton
settings = Settings()
test_settings = TestSettings()
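
The env_prefix set here means TestSettings fields resolve from MEMGPT_TEST_*-prefixed environment variables. A minimal pydantic-settings sketch of that resolution (standalone, not MemGPT's actual Settings class):

import os
from pathlib import Path
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

class TestSettings(BaseSettings):
    model_config = SettingsConfigDict(env_prefix="memgpt_test_")
    memgpt_dir: Optional[Path] = Field(default=Path.home() / ".memgpt/test")

os.environ["MEMGPT_TEST_MEMGPT_DIR"] = "/tmp/memgpt_test"
print(TestSettings().memgpt_dir)  # prefix + field name -> /tmp/memgpt_test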

View File

@ -7,9 +7,9 @@ from rich.console import Console
from rich.live import Live
from rich.markup import escape
from memgpt.data_types import Message
from memgpt.interface import CLIInterface
from memgpt.schemas.message import Message
from memgpt.schemas.openai.chat_completion_response import (
from memgpt.models.chat_completion_response import (
ChatCompletionChunkResponse,
ChatCompletionResponse,
)

View File

@ -33,8 +33,8 @@ from memgpt.constants import (
MEMGPT_DIR,
TOOL_CALL_ID_MAX_LEN,
)
from memgpt.models.chat_completion_response import ChatCompletionResponse
from memgpt.openai_backcompat.openai_object import OpenAIObject
from memgpt.schemas.openai.chat_completion_response import ChatCompletionResponse
DEBUG = False
if "LOG_LEVEL" in os.environ:
@ -468,17 +468,6 @@ NOUN_BANK = [
]
def deduplicate(target_list: list) -> list:
seen = set()
dedup_list = []
for i in target_list:
if i not in seen:
seen.add(i)
dedup_list.append(i)
return dedup_list
def smart_urljoin(base_url: str, relative_url: str) -> str:
"""urljoin is stupid and wants a trailing / at the end of the endpoint address, or it will chop the suffix off"""
if not base_url.endswith("/"):

poetry.lock generated

File diff suppressed because it is too large

View File

@ -28,13 +28,14 @@ pgvector = { version = "^0.2.3", optional = true }
pre-commit = {version = "^3.5.0", optional = true }
pg8000 = {version = "^1.30.3", optional = true}
websockets = {version = "^12.0", optional = true}
docstring-parser = ">=0.16,<0.17"
docstring-parser = "^0.15"
lancedb = "^0.3.3"
httpx = "^0.25.2"
numpy = "^1.26.2"
demjson3 = "^3.0.6"
#tiktoken = ">=0.7.0,<0.8.0"
tiktoken = "^0.5.1"
pyyaml = "^6.0.1"
chromadb = ">=0.4.24,<0.5.0"
chromadb = "^0.5.0"
sqlalchemy-json = "^0.7.0"
fastapi = {version = "^0.104.1", optional = true}
uvicorn = {version = "^0.24.0.post1", optional = true}
@ -50,7 +51,7 @@ pymilvus = {version ="^2.4.3", optional = true}
python-box = "^7.1.1"
sqlmodel = "^0.0.16"
autoflake = {version = "^2.3.0", optional = true}
llama-index = "^0.10.65"
llama-index = "^0.10.27"
llama-index-embeddings-openai = "^0.1.1"
llama-index-embeddings-huggingface = {version = "^0.2.0", optional = true}
llama-index-embeddings-azure-openai = "^0.1.6"
@ -58,15 +59,12 @@ python-multipart = "^0.0.9"
sqlalchemy-utils = "^0.41.2"
pytest-order = {version = "^1.2.0", optional = true}
pytest-asyncio = {version = "^0.23.2", optional = true}
pytest = { version = "^7.4.4", optional = true }
pydantic-settings = "^2.2.1"
httpx-sse = "^0.4.0"
isort = { version = "^5.13.2", optional = true }
llama-index-embeddings-ollama = {version = "^0.1.2", optional = true}
crewai = {version = "^0.41.1", optional = true}
crewai-tools = {version = "^0.8.3", optional = true}
docker = {version = "^7.1.0", optional = true}
tiktoken = "^0.7.0"
nltk = "^3.8.1"
protobuf = "3.20.0"
[tool.poetry.extras]
local = ["llama-index-embeddings-huggingface"]
@ -77,7 +75,6 @@ server = ["websockets", "fastapi", "uvicorn"]
autogen = ["pyautogen"]
qdrant = ["qdrant-client"]
ollama = ["llama-index-embeddings-ollama"]
crewai-tools = ["crewai", "docker", "crewai-tools"]
[tool.poetry.group.dev.dependencies]
black = "^24.4.2"

View File

@ -1,3 +1,3 @@
# from tests.config import TestMGPTConfig
#
# TEST_MEMGPT_CONFIG = TestMGPTConfig()
from tests.config import TestMGPTConfig
TEST_MEMGPT_CONFIG = TestMGPTConfig()

View File

@ -3,4 +3,4 @@ pythonpath = /memgpt
testpaths = /tests
asyncio_mode = auto
filterwarnings =
ignore::pytest.PytestRemovedIn9Warning
ignore::pytest.PytestRemovedIn8Warning

View File

@ -1,9 +1,11 @@
import threading
import time
import uuid
import pytest
from memgpt import Admin
from tests.test_client import _reset_config, run_server
test_base_url = "http://localhost:8283"
@ -11,13 +13,6 @@ test_base_url = "http://localhost:8283"
test_server_token = "test_server_token"
def run_server():
from memgpt.server.rest_api.server import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(scope="session", autouse=True)
def start_uvicorn_server():
"""Starts Uvicorn server in a background thread."""
@ -39,85 +34,91 @@ def admin_client():
def test_admin_client(admin_client):
_reset_config()
# create a user
user_name = "test_user"
user1 = admin_client.create_user(user_name)
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
user_id = uuid.uuid4()
create_user1_response = admin_client.create_user(user_id)
assert user_id == create_user1_response.user_id, f"Expected {user_id}, got {create_user1_response.user_id}"
# create another user
user2 = admin_client.create_user()
create_user_2_response = admin_client.create_user()
# create keys
key1_name = "test_key1"
key2_name = "test_key2"
api_key1 = admin_client.create_key(user1.id, key1_name)
admin_client.create_key(user2.id, key2_name)
create_key1_response = admin_client.create_key(user_id, key1_name)
create_key2_response = admin_client.create_key(create_user_2_response.user_id, key2_name)
# list users
users = admin_client.get_users()
assert len(users) == 2
assert user1.id in [user.id for user in users]
assert user2.id in [user.id for user in users]
assert len(users.user_list) == 2
print(users.user_list)
assert user_id in [uuid.UUID(u["user_id"]) for u in users.user_list]
# list keys
user1_keys = admin_client.get_keys(user1.id)
assert len(user1_keys) == 1, f"Expected 1 keys, got {user1_keys}"
assert api_key1.key == user1_keys[0].key
user1_keys = admin_client.get_keys(user_id)
assert len(user1_keys) == 2, f"Expected 2 keys, got {user1_keys}"
assert create_key1_response.api_key in user1_keys, f"Expected {create_key1_response.api_key} in {user1_keys}"
assert create_user1_response.api_key in user1_keys, f"Expected {create_user1_response.api_key} in {user1_keys}"
# delete key
deleted_key1 = admin_client.delete_key(api_key1.key)
assert deleted_key1.key == api_key1.key
assert len(admin_client.get_keys(user1.id)) == 0
delete_key1_response = admin_client.delete_key(create_key1_response.api_key)
assert delete_key1_response.api_key_deleted == create_key1_response.api_key
assert len(admin_client.get_keys(user_id)) == 1
delete_key2_response = admin_client.delete_key(create_key2_response.api_key)
assert delete_key2_response.api_key_deleted == create_key2_response.api_key
assert len(admin_client.get_keys(create_user_2_response.user_id)) == 1
# delete users
deleted_user1 = admin_client.delete_user(user1.id)
assert deleted_user1.id == user1.id
deleted_user2 = admin_client.delete_user(user2.id)
assert deleted_user2.id == user2.id
delete_user1_response = admin_client.delete_user(user_id)
assert delete_user1_response.user_id_deleted == user_id
delete_user2_response = admin_client.delete_user(create_user_2_response.user_id)
assert delete_user2_response.user_id_deleted == create_user_2_response.user_id
# list users
users = admin_client.get_users()
assert len(users) == 0, f"Expected 0 users, got {users}"
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
# def test_get_users_pagination(admin_client):
#
# page_size = 5
# num_users = 7
# expected_users_remainder = num_users - page_size
#
# # create users
# all_user_ids = []
# for i in range(num_users):
#
# user_id = uuid.uuid4()
# all_user_ids.append(user_id)
# key_name = "test_key" + f"{i}"
#
# create_user_response = admin_client.create_user(user_id)
# admin_client.create_key(create_user_response.user_id, key_name)
#
# # list users in page 1
# get_all_users_response1 = admin_client.get_users(limit=page_size)
# cursor1 = get_all_users_response1.cursor
# user_list1 = get_all_users_response1.user_list
# assert len(user_list1) == page_size
#
# # list users in page 2 using cursor
# get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
# cursor2 = get_all_users_response2.cursor
# user_list2 = get_all_users_response2.user_list
#
# assert len(user_list2) == expected_users_remainder
# assert cursor1 != cursor2
#
# # delete users
# clean_up_users_and_keys(all_user_ids)
#
# # list users to check pagination with no users
# users = admin_client.get_users()
# assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
def test_get_users_pagination(admin_client):
_reset_config()
page_size = 5
num_users = 7
expected_users_remainder = num_users - page_size
# create users
all_user_ids = []
for i in range(num_users):
user_id = uuid.uuid4()
all_user_ids.append(user_id)
key_name = "test_key" + f"{i}"
create_user_response = admin_client.create_user(user_id)
admin_client.create_key(create_user_response.user_id, key_name)
# list users in page 1
get_all_users_response1 = admin_client.get_users(limit=page_size)
cursor1 = get_all_users_response1.cursor
user_list1 = get_all_users_response1.user_list
assert len(user_list1) == page_size
# list users in page 2 using cursor
get_all_users_response2 = admin_client.get_users(cursor1, limit=page_size)
cursor2 = get_all_users_response2.cursor
user_list2 = get_all_users_response2.user_list
assert len(user_list2) == expected_users_remainder
assert cursor1 != cursor2
# delete users
clean_up_users_and_keys(all_user_ids)
# list users to check pagination with no users
users = admin_client.get_users()
assert len(users.user_list) == 0, f"Expected 0 users, got {users}"
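(The two-page walk above generalizes to a drain loop. A hedged sketch, assuming get_users accepts cursor=None for the first page and returns an empty user_list past the last one; fetch_all_users is a hypothetical helper:)

def fetch_all_users(admin_client, page_size=5):
    # Follow cursors until a page comes back empty.
    all_users, cursor = [], None
    while True:
        response = admin_client.get_users(cursor, limit=page_size)
        if not response.user_list:
            return all_users
        all_users.extend(response.user_list)
        cursor = response.cursor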
def clean_up_users_and_keys(user_id_list):

View File

@ -1,41 +1,38 @@
# TODO: add back
import os
import subprocess
# import os
# import subprocess
#
# import pytest
#
#
# @pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
# def test_agent_groupchat():
#
# # Define the path to the script you want to test
# script_path = "memgpt/autogen/examples/agent_groupchat.py"
#
# # Dynamically get the project's root directory (assuming this script is run from the root)
# # project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# # print(project_root)
# # project_root = os.path.join(project_root, "MemGPT")
# # print(project_root)
# # sys.exit(1)
#
# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
# project_root = os.path.join(project_root, "memgpt")
# print(f"Adding the following to PATH: {project_root}")
#
# # Prepare the environment, adding the project root to PYTHONPATH
# env = os.environ.copy()
# env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
#
# # Run the script using subprocess.run
# # Capture the output (stdout) and the exit code
# # result = subprocess.run(["python", script_path], capture_output=True, text=True)
# result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
#
# # Check the exit code (0 indicates success)
# assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
#
# # Optionally, check the output for expected content
# # For example, if you expect a specific line in the output, uncomment and adapt the following line:
# # assert "expected output" in result.stdout, "Expected output not found in script's output"
#
import pytest
@pytest.mark.skipif(not os.getenv("OPENAI_API_KEY"), reason="Missing OpenAI API key")
def test_agent_groupchat():
# Define the path to the script you want to test
script_path = "memgpt/autogen/examples/agent_groupchat.py"
# Dynamically get the project's root directory (assuming this script is run from the root)
# project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# print(project_root)
# project_root = os.path.join(project_root, "MemGPT")
# print(project_root)
# sys.exit(1)
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
project_root = os.path.join(project_root, "memgpt")
print(f"Adding the following to PATH: {project_root}")
# Prepare the environment, adding the project root to PYTHONPATH
env = os.environ.copy()
env["PYTHONPATH"] = f"{project_root}:{env.get('PYTHONPATH', '')}"
# Run the script using subprocess.run
# Capture the output (stdout) and the exit code
# result = subprocess.run(["python", script_path], capture_output=True, text=True)
result = subprocess.run(["poetry", "run", "python", script_path], capture_output=True, text=True)
# Check the exit code (0 indicates success)
assert result.returncode == 0, f"Script exited with code {result.returncode}: {result.stderr}"
# Optionally, check the output for expected content
# For example, if you expect a specific line in the output, uncomment and adapt the following line:
# assert "expected output" in result.stdout, "Expected output not found in script's output"

View File

@ -26,7 +26,7 @@ def agent_obj():
agent_state = client.create_agent()
global agent_obj
agent_obj = client.server._get_or_load_agent(agent_id=agent_state.id)
agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
yield agent_obj
client.delete_agent(agent_obj.agent_state.id)

View File

@ -1,12 +1,17 @@
import os
import subprocess
import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])  # install pexpect at import time so the interactive CLI tests below can run
import pexpect
from prettytable.colortable import ColorTable
from memgpt.cli.cli_config import ListChoice, add, delete
from memgpt.cli.cli_config import list as list_command
from .constants import TIMEOUT
from .utils import create_config
# def test_configure_memgpt():
# configure_memgpt()
@ -42,3 +47,41 @@ def test_cli_config():
assert "test data" in row
# delete
delete(option=option, name="test")
def test_save_load():
# configure_memgpt()  # rely on the configure test above running first
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")
child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")
child.expect("Enter your message:", timeout=TIMEOUT)
child.sendline()
child.expect("Empty input received. Try again!", timeout=TIMEOUT)
child.sendline("/save")
child.expect("Enter your message:", timeout=TIMEOUT)
child.sendline("/exit")
child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
child.close()
assert child.isalive() is False, "CLI should have terminated."
assert child.exitstatus == 0, "CLI did not exit cleanly."
child = pexpect.spawn("poetry run memgpt run --agent test_save_load --first --strip-ui")
child.expect("Using existing agent test_save_load", timeout=TIMEOUT)
child.expect("Enter your message:", timeout=TIMEOUT)
child.sendline("/exit")
child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
child.close()
assert child.isalive() is False, "CLI should have terminated."
assert child.exitstatus == 0, "CLI did not exit cleanly."
if __name__ == "__main__":
# test_configure_memgpt()
test_save_load()
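(For readers new to the spawn/expect/sendline pattern used above, here is a self-contained pexpect round-trip against a throwaway child process; it is illustrative only and assumes python3 is on PATH:)

import pexpect

child = pexpect.spawn('python3 -c "print(input())"', timeout=10)
child.sendline("hello")        # write a line to the child's stdin
child.expect("hello")          # block until the echoed line appears
child.expect(pexpect.EOF)      # wait for the child to exit
child.close()
assert child.exitstatus == 0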

View File

@ -7,11 +7,12 @@ import pytest
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
# from tests.utils import create_config
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Preset # TODO move to PresetModel
from memgpt.settings import settings
from tests.utils import create_config
test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
@ -20,16 +21,44 @@ test_agent_state = None
client = None
test_agent_state_post_message = None
test_user_id = uuid.uuid4()
# admin credentials
test_server_token = "test_server_token"
def _reset_config():
# pull the Postgres URI from settings
db_url = settings.memgpt_pg_uri
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
credentials = MemGPTCredentials(
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
create_config("memgpt_hosted")
credentials = MemGPTCredentials()
config = MemGPTConfig.load()
# set to use postgres
config.archival_storage_uri = db_url
config.recall_storage_uri = db_url
config.metadata_storage_uri = db_url
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
config.metadata_storage_type = "postgres"
config.save()
credentials.save()
print("_reset_config :: ", config.config_path)
def run_server():
load_dotenv()
# _reset_config()
_reset_config()
from memgpt.server.rest_api.server import start_server
@ -39,8 +68,7 @@ def run_server():
# Fixture to create clients with different configurations
@pytest.fixture(
# params=[{"server": True}, {"server": False}], # whether to use REST API server
params=[{"server": True}], # whether to use REST API server
params=[{"server": True}, {"server": False}], # whether to use REST API server
scope="module",
)
def client(request):
@ -58,20 +86,21 @@ def client(request):
print("Running client tests with server:", server_url)
# create user via admin client
admin = Admin(server_url, test_server_token)
user = admin.create_user() # Adjust as per your client's method
api_key = admin.create_key(user.id)
client = create_client(base_url=server_url, token=api_key.key) # This yields control back to the test function
response = admin.create_user(test_user_id) # Adjust as per your client's method
token = response.api_key
else:
# use local client (no server)
token = None
server_url = None
client = create_client()
client = create_client(base_url=server_url, token=token) # This yields control back to the test function
try:
yield client
finally:
# cleanup user
if server_url:
admin.delete_user(user.id)
admin.delete_user(test_user_id) # Adjust as per your client's method
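(Context for the params change above: pytest instantiates the fixture once per params entry, so re-enabling {"server": False} doubles every dependent test. A minimal sketch of the mechanism, with hypothetical names:)

import pytest

@pytest.fixture(params=[{"server": True}, {"server": False}], scope="module")
def client_mode(request):
    # request.param is one dict from params; dependent tests run once per entry
    return request.param["server"]

def test_runs_in_both_modes(client_mode):
    assert client_mode in (True, False)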
# Fixture for test agent
@ -86,6 +115,7 @@ def agent(client):
def test_agent(client, agent):
_reset_config()
# test client.rename_agent
new_name = "RenamedTestAgent"
@ -101,84 +131,61 @@ def test_agent(client, agent):
def test_memory(client, agent):
# _reset_config()
_reset_config()
memory_response = client.get_in_context_memory(agent_id=agent.id)
memory_response = client.get_agent_memory(agent_id=agent.id)
print("MEMORY", memory_response)
updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
client.update_in_context_memory(agent_id=agent.id, section="human", value=updated_memory["human"])
client.update_in_context_memory(agent_id=agent.id, section="persona", value=updated_memory["persona"])
updated_memory_response = client.get_in_context_memory(agent_id=agent.id)
client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory)
updated_memory_response = client.get_agent_memory(agent_id=agent.id)
assert (
updated_memory_response.get_block("human").value == updated_memory["human"]
and updated_memory_response.get_block("persona").value == updated_memory["persona"]
updated_memory_response.core_memory.human == updated_memory["human"]
and updated_memory_response.core_memory.persona == updated_memory["persona"]
), "Memory update failed"
def test_agent_interactions(client, agent):
# _reset_config()
_reset_config()
message = "Hello, agent!"
print("Sending message", message)
response = client.user_message(agent_id=agent.id, message=message)
print("Response", response)
assert isinstance(response.usage, MemGPTUsageStatistics)
assert response.usage.step_count == 1
assert response.usage.total_tokens > 0
assert response.usage.completion_tokens > 0
assert isinstance(response.messages[0], Message)
print(response.messages)
message_response = client.user_message(agent_id=agent.id, message=message)
# TODO: add streaming tests
command = "/memory"
command_response = client.run_command(agent_id=agent.id, command=command)
print("command", command_response)
def test_archival_memory(client, agent):
# _reset_config()
_reset_config()
memory_content = "Archival memory content"
insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)[0]
print("Inserted memory", insert_response.text, insert_response.id)
insert_response = client.insert_archival_memory(agent_id=agent.id, memory=memory_content)
assert insert_response, "Inserting archival memory failed"
archival_memory_response = client.get_archival_memory(agent_id=agent.id, limit=1)
archival_memories = [memory.text for memory in archival_memory_response]
archival_memory_response = client.get_agent_archival_memory(agent_id=agent.id, limit=1)
print("MEMORY")
archival_memories = [memory.contents for memory in archival_memory_response.archival_memory]
assert memory_content in archival_memories, f"Retrieving archival memory failed: {archival_memories}"
memory_id_to_delete = archival_memory_response[0].id
memory_id_to_delete = archival_memory_response.archival_memory[0].id
client.delete_archival_memory(agent_id=agent.id, memory_id=memory_id_to_delete)
# add archival memory
memory_str = "I love chats"
passage = client.insert_archival_memory(agent.id, memory=memory_str)[0]
# list archival memory
passages = client.get_archival_memory(agent.id)
assert passage.text in [p.text for p in passages], f"Missing passage {passage.text} in {passages}"
# get archival memory summary
archival_summary = client.get_archival_memory_summary(agent.id)
assert archival_summary.size == 1, f"Archival memory summary size is {archival_summary.size}"
# delete archival memory
client.delete_archival_memory(agent.id, passage.id)
# TODO: check deletion
client.get_archival_memory(agent.id)
def test_messages(client, agent):
# _reset_config()
_reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
assert len(messages_response.messages) > 0, "Retrieving messages failed"
def test_humans_personas(client, agent):
# _reset_config()
_reset_config()
humans_response = client.list_humans()
print("HUMANS", humans_response)
@ -187,20 +194,18 @@ def test_humans_personas(client, agent):
print("PERSONAS", personas_response)
persona_name = "TestPersona"
persona_id = client.get_persona_id(persona_name)
if persona_id:
client.delete_persona(persona_id)
if client.get_persona(persona_name):
client.delete_persona(persona_name)
persona = client.create_persona(name=persona_name, text="Persona text")
assert persona.name == persona_name
assert persona.value == "Persona text", "Creating persona failed"
assert persona.text == "Persona text", "Creating persona failed"
human_name = "TestHuman"
human_id = client.get_human_id(human_name)
if human_id:
client.delete_human(human_id)
if client.get_human(human_name):
client.delete_human(human_name)
human = client.create_human(name=human_name, text="Human text")
assert human.name == human_name
assert human.value == "Human text", "Creating human failed"
assert human.text == "Human text", "Creating human failed"
# def test_tools(client, agent):
@ -213,14 +218,11 @@ def test_humans_personas(client, agent):
def test_config(client, agent):
# _reset_config()
_reset_config()
models_response = client.list_models()
print("MODELS", models_response)
embeddings_response = client.list_embedding_models()
print("EMBEDDINGS", embeddings_response)
# TODO: add back
# config_response = client.get_config()
# TODO: ensure config is the same as the one in the server
@ -228,7 +230,7 @@ def test_config(client, agent):
def test_sources(client, agent):
# _reset_config()
_reset_config()
if not hasattr(client, "base_url"):
pytest.skip("Skipping test_sources because base_url is None")
@ -236,7 +238,7 @@ def test_sources(client, agent):
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 0
assert len(sources.sources) == 0
# create a source
source = client.create_source(name="test_source")
@ -244,53 +246,36 @@ def test_sources(client, agent):
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 1
# TODO: add back?
assert sources[0].metadata_["num_passages"] == 0
assert sources[0].metadata_["num_documents"] == 0
# update the source
original_id = source.id
original_name = source.name
new_name = original_name + "_new"
client.update_source(source_id=source.id, name=new_name)
# get the source name (check that it's been updated)
source = client.get_source(source_id=source.id)
assert source.name == new_name
assert source.id == original_id
# get the source id (make sure that it's the same)
assert str(original_id) == client.get_source_id(source_name=new_name)
assert len(sources.sources) == 1
assert sources.sources[0].metadata_["num_passages"] == 0
assert sources.sources[0].metadata_["num_documents"] == 0
# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == 0
# load a file into a source
filename = "CONTRIBUTING.md"
upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
print("Upload job", upload_job, upload_job.status, upload_job.metadata)
# TODO: make sure things run in the right order
archival_memories = client.get_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
assert len(archival_memories) == 0
# attach a source
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
# list archival memory
archival_memories = client.get_archival_memory(agent_id=agent.id)
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
# print(archival_memories)
assert len(archival_memories) == 20 or len(archival_memories) == 21
# check number of passages
sources = client.list_sources()
# TODO: add back?
# assert sources.sources[0].metadata_["num_passages"] > 0
# assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
assert sources.sources[0].metadata_["num_passages"] > 0
assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
print(sources)
# detach the source
@ -299,3 +284,80 @@ def test_sources(client, agent):
# delete the source
client.delete_source(source.id)
# def test_presets(client, agent):
# _reset_config()
#
# # new_preset = Preset(
# # # user_id=client.user_id,
# # name="pytest_test_preset",
# # description="DUMMY_DESCRIPTION",
# # system="DUMMY_SYSTEM",
# # persona="DUMMY_PERSONA",
# # persona_name="DUMMY_PERSONA_NAME",
# # human="DUMMY_HUMAN",
# # human_name="DUMMY_HUMAN_NAME",
# # functions_schema=[
# # {
# # "name": "send_message",
# # "json_schema": {
# # "name": "send_message",
# # "description": "Sends a message to the human user.",
# # "parameters": {
# # "type": "object",
# # "properties": {
# # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
# # },
# # "required": ["message"],
# # },
# # },
# # "tags": ["memgpt-base"],
# # "source_type": "python",
# # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
# # }
# # ],
# # )
#
# ## List all presets and make sure the preset is NOT in the list
# # all_presets = client.list_presets()
# # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# # Create a preset
# new_preset = client.create_preset(name="pytest_test_preset")
#
# # List all presets and make sure the preset is in the list
# all_presets = client.list_presets()
# assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
#
# # Delete the preset
# client.delete_preset(preset_id=new_preset.id)
#
# # List all presets and make sure the preset is NOT in the list
# all_presets = client.list_presets()
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# def test_tools(client, agent):
#
# # load a function
# file_path = "tests/data/functions/dump_json.py"
# module_name = "dump_json"
#
# # list functions
# response = client.list_tools()
# orig_tools = response.tools
# print(orig_tools)
#
# # add the tool
# create_tool_response = client.create_tool(name=module_name, file_path=file_path)
# print(create_tool_response)
#
# # list functions
# response = client.list_tools()
# new_tools = response.tools
# assert module_name in [tool.name for tool in new_tools]
# # assert len(new_tools) == len(orig_tools) + 1
#
# # TODO: add a function to a preset
#
# # TODO: add a function to an agent

View File

@ -1,142 +1,142 @@
# TODO: add back when messaging works
import os
import threading
import time
import uuid
# import os
# import threading
# import time
# import uuid
#
# import pytest
# from dotenv import load_dotenv
#
# from memgpt import Admin, create_client
# from memgpt.config import MemGPTConfig
# from memgpt.credentials import MemGPTCredentials
# from memgpt.settings import settings
# from tests.utils import create_config
#
# test_agent_name = f"test_client_{str(uuid.uuid4())}"
## test_preset_name = "test_preset"
# test_agent_state = None
# client = None
#
# test_agent_state_post_message = None
# test_user_id = uuid.uuid4()
#
#
## admin credentials
# test_server_token = "test_server_token"
#
#
# def _reset_config():
#
# # pull the Postgres URI from settings
# db_url = settings.memgpt_pg_uri
#
# if os.getenv("OPENAI_API_KEY"):
# create_config("openai")
# credentials = MemGPTCredentials(
# openai_key=os.getenv("OPENAI_API_KEY"),
# )
# else: # hosted
# create_config("memgpt_hosted")
# credentials = MemGPTCredentials()
#
# config = MemGPTConfig.load()
#
# # set to use postgres
# config.archival_storage_uri = db_url
# config.recall_storage_uri = db_url
# config.metadata_storage_uri = db_url
# config.archival_storage_type = "postgres"
# config.recall_storage_type = "postgres"
# config.metadata_storage_type = "postgres"
#
# config.save()
# credentials.save()
# print("_reset_config :: ", config.config_path)
#
#
# def run_server():
#
# load_dotenv()
#
# _reset_config()
#
# from memgpt.server.rest_api.server import start_server
#
# print("Starting server...")
# start_server(debug=True)
#
#
## Fixture to create clients with different configurations
# @pytest.fixture(
# params=[ # whether to use REST API server
# {"server": True},
# # {"server": False} # TODO: add when implemented
# ],
# scope="module",
# )
# def admin_client(request):
# if request.param["server"]:
# # get URL from environment
# server_url = os.getenv("MEMGPT_SERVER_URL")
# if server_url is None:
# # run server in thread
# # NOTE: must set MEMGPT_SERVER_PASS environment variable
# server_url = "http://localhost:8283"
# print("Starting server thread")
# thread = threading.Thread(target=run_server, daemon=True)
# thread.start()
# time.sleep(5)
# print("Running client tests with server:", server_url)
# # create user via admin client
# admin = Admin(server_url, test_server_token)
# response = admin.create_user(test_user_id) # Adjust as per your client's method
#
# yield admin
#
#
# def test_concurrent_messages(admin_client):
# # test concurrent messages
#
# # create several concurrent senders
#
# results = []
#
# def _send_message():
# try:
# print("START SEND MESSAGE")
# response = admin_client.create_user()
# token = response.api_key
# client = create_client(base_url=admin_client.base_url, token=token)
# agent = client.create_agent()
#
# print("Agent created", agent.id)
#
# st = time.time()
# message = "Hello, how are you?"
# response = client.send_message(agent_id=agent.id, message=message, role="user")
# et = time.time()
# print(f"Message sent from {st} to {et}")
# print(response.messages)
# results.append((st, et))
# except Exception as e:
# print("ERROR", e)
#
# threads = []
# print("Starting threads...")
# for i in range(5):
# thread = threading.Thread(target=_send_message)
# threads.append(thread)
# thread.start()
# print("CREATED THREAD")
#
# print("waiting for threads to finish...")
# for thread in threads:
# thread.join()
#
# # make sure runtimes are overlapping
# assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
# results[1][0] < results[0][0] and results[1][1] > results[0][0]
# ), f"Threads should have overlapping runtimes {results}"
#
import pytest
from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import Preset # TODO move to PresetModel
from memgpt.settings import settings
from tests.utils import create_config
test_agent_name = f"test_client_{str(uuid.uuid4())}"
# test_preset_name = "test_preset"
test_preset_name = DEFAULT_PRESET
test_agent_state = None
client = None
test_agent_state_post_message = None
test_user_id = uuid.uuid4()
# admin credentials
test_server_token = "test_server_token"
def _reset_config():
# pull the Postgres URI from settings
db_url = settings.memgpt_pg_uri
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
credentials = MemGPTCredentials(
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
create_config("memgpt_hosted")
credentials = MemGPTCredentials()
config = MemGPTConfig.load()
# set to use postgres
config.archival_storage_uri = db_url
config.recall_storage_uri = db_url
config.metadata_storage_uri = db_url
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
config.metadata_storage_type = "postgres"
config.save()
credentials.save()
print("_reset_config :: ", config.config_path)
def run_server():
load_dotenv()
_reset_config()
from memgpt.server.rest_api.server import start_server
print("Starting server...")
start_server(debug=True)
# Fixture to create clients with different configurations
@pytest.fixture(
params=[ # whether to use REST API server
{"server": True},
# {"server": False} # TODO: add when implemented
],
scope="module",
)
def admin_client(request):
if request.param["server"]:
# get URL from environment
server_url = os.getenv("MEMGPT_SERVER_URL")
if server_url is None:
# run server in thread
# NOTE: must set MEMGPT_SERVER_PASS environment variable
server_url = "http://localhost:8283"
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create user via admin client
admin = Admin(server_url, test_server_token)
response = admin.create_user(test_user_id) # Adjust as per your client's method
yield admin
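(The fixed time.sleep(5) above is a classic source of flaky CI. A readiness poll is more robust; this sketch assumes only that the server answers HTTP on server_url, and uses httpx, which is already a project dependency per the pyproject above. wait_until_up is a hypothetical helper:)

import time

import httpx

def wait_until_up(server_url: str, timeout: float = 30.0) -> None:
    # Poll until the socket accepts connections instead of sleeping blindly.
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            httpx.get(server_url, timeout=1.0)  # any HTTP response means it is up
            return
        except httpx.TransportError:
            time.sleep(0.2)
    raise TimeoutError(f"server at {server_url} did not start within {timeout}s")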
def test_concurrent_messages(admin_client):
# test concurrent messages
# create several concurrent senders
results = []
def _send_message():
try:
print("START SEND MESSAGE")
response = admin_client.create_user()
token = response.api_key
client = create_client(base_url=admin_client.base_url, token=token)
agent = client.create_agent()
print("Agent created", agent.id)
st = time.time()
message = "Hello, how are you?"
response = client.send_message(agent_id=agent.id, message=message, role="user")
et = time.time()
print(f"Message sent from {st} to {et}")
print(response.messages)
results.append((st, et))
except Exception as e:
print("ERROR", e)
threads = []
print("Starting threads...")
for i in range(5):
thread = threading.Thread(target=_send_message)
threads.append(thread)
thread.start()
print("CREATED THREAD")
print("waiting for threads to finish...")
for thread in threads:
thread.join()  # block until the worker finishes
# make sure runtimes are overlapping
assert (results[0][0] < results[1][0] and results[0][1] > results[1][0]) or (
results[1][0] < results[0][0] and results[1][1] > results[0][0]
), f"Threads should have overlapping runtimes {results}"

Some files were not shown because too many files have changed in this diff.