mirror of https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00

feat: refactor loading and attaching data sources, and upgrade to llama-index==0.10.6 (#1016)

This commit is contained in:
parent 508679e2cb
commit 38c184caf8
@@ -5,9 +5,11 @@ import json

from pathlib import Path
import traceback
from typing import List, Tuple, Optional, cast, Union

from tqdm import tqdm

from memgpt.data_types import AgentState, Message, EmbeddingConfig
from memgpt.metadata import MetadataStore
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.data_types import AgentState, Message, EmbeddingConfig, Passage
from memgpt.models import chat_completion_response
from memgpt.interface import AgentInterface
from memgpt.persistence_manager import LocalStateManager

@@ -24,6 +26,7 @@ from memgpt.utils import (
    get_schema_diff,
    validate_function_response,
    verify_first_message_correctness,
    create_uuid_from_string,
)
from memgpt.constants import (
    FIRST_MESSAGE_ATTEMPTS,
@@ -951,3 +954,57 @@ class Agent(object):

        # TODO: recall memory
        raise NotImplementedError()

    def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore):
        """Attach data with name `source_name` to the agent from source_connector."""
        # TODO: eventually, adding a data source should just give the retriever access to the source table, rather than modifying archival memory

        filters = {"user_id": self.agent_state.user_id, "data_source": source_name}
        size = source_connector.size(filters)
        # typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
        page_size = 100
        generator = source_connector.get_all_paginated(filters=filters, page_size=page_size)  # yields List[Passage]
        all_passages = []
        for i in tqdm(range(0, size, page_size)):
            passages = next(generator)

            # need to associate passage with agent (for filtering)
            for passage in passages:
                assert isinstance(passage, Passage), f"Generator yielded bad non-Passage type: {type(passage)}"
                passage.agent_id = self.agent_state.id

                # regenerate passage ID (avoid duplicates)
                passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}")

            # insert into agent archival memory
            self.persistence_manager.archival_memory.storage.insert_many(passages)
            all_passages += passages

        assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"

        # save destination storage
        self.persistence_manager.archival_memory.storage.save()

        # attach to agent
        source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id)
        assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}"
        source_id = source.id
        ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)

        total_agent_passages = self.persistence_manager.archival_memory.storage.size()

        printd(
            f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)} passages. Agent now has {total_agent_passages} embeddings in archival memory.",
        )


def save_agent(agent: Agent, ms: MetadataStore):
    """Save agent to metadata store"""

    agent.update_state()
    agent_state = agent.agent_state

    if ms.get_agent(agent_id=agent_state.id):
        ms.update_agent(agent_state)
    else:
        ms.create_agent(agent_state)
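For orientation, a minimal sketch of how callers drive the new attach flow after this change (hypothetical wiring; assumes a loaded MemGPTConfig `config`, a MetadataStore `ms`, and an Agent `agent`, mirroring what main.py and server.py do further down in this diff):

    from memgpt.agent import save_agent
    from memgpt.agent_store.storage import StorageConnector, TableType

    # the source table is shared; the connector scopes reads to this user
    source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=agent.agent_state.user_id)
    agent.attach_source("my_docs", source_connector, ms)  # "my_docs" is a placeholder source name
    save_agent(agent, ms)  # persist the updated agent state to the metadata store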
@@ -37,7 +37,7 @@ class ChromaStorageConnector(StorageConnector):
        self.include: Include = ["documents", "embeddings", "metadatas"]

        # need to be converted to strings
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id"]
        self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]

    def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
        # get all filters for query
@@ -4,6 +4,7 @@ from typing import Callable, Optional, List, Dict, Union, Any, Tuple
from autogen.agentchat import Agent, ConversableAgent, UserProxyAgent, GroupChat, GroupChatManager

from memgpt.agent import Agent as MemGPTAgent
from memgpt.agent import save_agent
from memgpt.autogen.interface import AutoGenInterface
import memgpt.system as system
import memgpt.constants as constants

@@ -14,7 +15,6 @@ from memgpt.credentials import MemGPTCredentials
from memgpt.cli.cli import attach
from memgpt.cli.cli_load import load_directory, load_webpage, load_index, load_database, load_vector_database
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore, save_agent
from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig
@@ -13,8 +13,6 @@ from typing import Annotated, Optional

import typer
import questionary
from llama_index import set_global_service_context
from llama_index import ServiceContext

from memgpt.log import logger
from memgpt.interface import CLIInterface as interface  # for printing to terminal

@@ -25,11 +23,11 @@ from memgpt.utils import printd, open_folder_in_explorer, suppress_stdout
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
from memgpt.embeddings import embedding_model
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User, Passage
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
from memgpt.migrate import migrate_all_agents, migrate_all_sources
@@ -648,16 +646,6 @@ def run(
    # printd(json.dumps(vars(agent_config), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))
    # printd(json.dumps(agent_init_state), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))

    # configure llama index
    original_stdout = sys.stdout  # unfortunate hack required to suppress confusing print statements from llama index
    sys.stdout = io.StringIO()
    embed_model = embedding_model(config=agent_state.embedding_config, user_id=user.id)
    service_context = ServiceContext.from_defaults(
        llm=None, embed_model=embed_model, chunk_size=agent_state.embedding_config.embedding_chunk_size
    )
    set_global_service_context(service_context)
    sys.stdout = original_stdout

    # start event loop
    from memgpt.main import run_agent_loop
@@ -703,69 +691,6 @@ def delete_agent(
    sys.exit(1)


def attach(
    agent_name: Annotated[str, typer.Option(help="Specify agent to attach data to")],
    data_source: Annotated[str, typer.Option(help="Data source to attach to agent")],
    user_id: uuid.UUID = None,
):
    # use client ID if no user_id provided
    config = MemGPTConfig.load()
    if user_id is None:
        user_id = uuid.UUID(config.anon_clientid)
    try:
        # loads the data contained in data source into the agent's memory
        from memgpt.agent_store.storage import StorageConnector, TableType
        from tqdm import tqdm

        ms = MetadataStore(config)
        agent = ms.get_agent(agent_name=agent_name, user_id=user_id)
        assert agent is not None, f"No agent found under agent_name={agent_name}, user_id={user_id}"
        source = ms.get_source(source_name=data_source, user_id=user_id)
        assert source is not None, f"Source {data_source} does not exist for user {user_id}"

        # get storage connectors
        with suppress_stdout():
            source_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
            dest_storage = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id)

        size = source_storage.size({"data_source": data_source})
        typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
        page_size = 100
        generator = source_storage.get_all_paginated(filters={"data_source": data_source}, page_size=page_size)  # yields List[Passage]
        all_passages = []
        for i in tqdm(range(0, size, page_size)):
            passages = next(generator)

            # need to associate passage with agent (for filtering)
            for passage in passages:
                assert isinstance(passage, Passage), f"Generator yielded bad non-Passage type: {type(passage)}"
                passage.agent_id = agent.id

            # insert into agent archival memory
            dest_storage.insert_many(passages)
            all_passages += passages

        assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"

        # save destination storage
        dest_storage.save()

        # attach to agent
        source = ms.get_source(source_name=data_source, user_id=user_id)
        assert source is not None, f"source does not exist for source_name={data_source}, user_id={user_id}"
        source_id = source.id
        ms.attach_source(agent_id=agent.id, source_id=source_id, user_id=user_id)

        total_agent_passages = dest_storage.size()

        typer.secho(
            f"Attached data source {data_source} to agent {agent_name}, consisting of {len(all_passages)} passages. Agent now has {total_agent_passages} embeddings in archival memory.",
            fg=typer.colors.GREEN,
        )
    except KeyboardInterrupt:
        typer.secho("Operation interrupted by KeyboardInterrupt.", fg=typer.colors.YELLOW)


def version():
    import memgpt
@@ -14,6 +14,7 @@ import numpy as np
import typer
import uuid

from memgpt.data_sources.connectors import load_data, DirectoryConnector, VectorDBConnector
from memgpt.embeddings import embedding_model, check_and_split_text
from memgpt.agent_store.storage import StorageConnector
from memgpt.config import MemGPTConfig

@@ -24,180 +25,57 @@ from memgpt.agent_store.storage import StorageConnector, TableType

from datetime import datetime

from llama_index import (
    VectorStoreIndex,
    ServiceContext,
    StorageContext,
    load_index_from_storage,
)

app = typer.Typer()
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
    """Insert a list of passages into a source by updating storage connectors and metadata store"""
    storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    orig_size = storage.size()

    # insert metadata store
    ms = MetadataStore(config)
    source = ms.get_source(user_id=user_id, source_name=source_name)
    if not source:
        # create new
        source = Source(user_id=user_id, name=source_name)
        ms.create_source(source)

    # make sure user_id is set for passages
    passage_chunk = []
    insert_chunk_size = 1000
    for passage in passages:
        # TODO: attach source IDs
        # passage.source_id = source.id
        passage.user_id = user_id
        passage.data_source = source_name

        # add and save all passages
        passage_chunk.append(passage)
        if len(passage_chunk) >= insert_chunk_size:
            storage.insert_many(passage_chunk)
            storage.save()
            passage_chunk = []

    if len(passage_chunk) > 0:
        storage.insert_many(passage_chunk)
        storage.save()

    # print info
    num_new_passages = storage.size() - orig_size
    print(f"Updated {len(passages)} passages, inserted {num_new_passages} new passages into {source_name}")
    print("Total passages in source:", storage.size())


def store_docs(name, docs, user_id=None, show_progress=True):
    """Common function for embedding and storing documents"""

    config = MemGPTConfig.load()
    if user_id is None:  # assume running local with single user
        user_id = uuid.UUID(config.anon_clientid)

    # record data source metadata
    ms = MetadataStore(config)
    user = ms.get_user(user_id)
    if user is None:
        raise ValueError(f"Cannot find user {user_id} in metadata store. Please run 'memgpt configure'.")
    # create data source record
    data_source = Source(
        user_id=user.id,
        name=name,
        embedding_model=config.default_embedding_config.embedding_model,
        embedding_dim=config.default_embedding_config.embedding_dim,
    )
    existing_source = ms.get_source(user_id=user.id, source_name=name)
    if not existing_source:
        ms.create_source(data_source)
    else:
        print(f"Source {name} for user {user.id} already exists.")
        if existing_source.embedding_model != data_source.embedding_model:
            print(
                f"Warning: embedding model for existing source {existing_source.embedding_model} does not match default {data_source.embedding_model}"
            )
            print("Cannot import data into this source without a compatible embedding endpoint.")
            print("Please run 'memgpt configure' to update the default embedding settings.")
            return False
        if existing_source.embedding_dim != data_source.embedding_dim:
            print(
                f"Warning: embedding dimension for existing source {existing_source.embedding_dim} does not match default {data_source.embedding_dim}"
            )
            print("Cannot import data into this source without a compatible embedding endpoint.")
            print("Please run 'memgpt configure' to update the default embedding settings.")
            return False

    # compute and record passages
    embed_model = embedding_model(config.default_embedding_config)

    # use llama index to run embeddings code
    with suppress_stdout():
        service_context = ServiceContext.from_defaults(
            llm=None, embed_model=embed_model, chunk_size=config.default_embedding_config.embedding_chunk_size
        )
    index = VectorStoreIndex.from_documents(docs, service_context=service_context, show_progress=True)
    embed_dict = index._vector_store._data.embedding_dict
    node_dict = index._docstore.docs

    # TODO: add document store

    # gather passages
    passages = []
    for node_id, node in tqdm(node_dict.items()):
        vector = embed_dict[node_id]
        node.embedding = vector
        text = node.text.replace("\x00", "\uFFFD")  # hacky fix for error on null characters
        assert (
            len(node.embedding) == config.default_embedding_config.embedding_dim
        ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
        passages.append(
            Passage(
                user_id=user.id,
                text=text,
                data_source=name,
                embedding=node.embedding,
                metadata_=None,
                embedding_dim=config.default_embedding_config.embedding_dim,
                embedding_model=config.default_embedding_config.embedding_model,
            )
        )

    insert_passages_into_source(passages, name, user_id, config)
@app.command("index")
def load_index(
    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
    dir: Annotated[Optional[str], typer.Option(help="Path to directory containing index.")] = None,
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
    """Load a LlamaIndex saved VectorIndex into MemGPT"""
    if user_id is None:
        config = MemGPTConfig.load()
        user_id = uuid.UUID(config.anon_clientid)

    try:
        # load index data
        storage_context = StorageContext.from_defaults(persist_dir=dir)
        loaded_index = load_index_from_storage(storage_context)

        # hacky code to extract out passages/embeddings (thanks a lot, llama index)
        embed_dict = loaded_index._vector_store._data.embedding_dict
        node_dict = loaded_index._docstore.docs

        # create storage connector
        config = MemGPTConfig.load()
        if user_id is None:
            user_id = uuid.UUID(config.anon_clientid)

        passages = []
        for node_id, node in node_dict.items():
            vector = embed_dict[node_id]
            node.embedding = vector
            # assume embedding are the same as config
            passages.append(
                Passage(
                    text=node.text,
                    embedding=np.array(vector),
                    embedding_dim=config.default_embedding_config.embedding_dim,
                    embedding_model=config.default_embedding_config.embedding_model,
                )
            )
            assert config.default_embedding_config.embedding_dim == len(
                vector
            ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"

        if len(passages) == 0:
            raise ValueError(f"No passages found in index {dir}")

        insert_passages_into_source(passages, name, user_id, config)
    except ValueError as e:
        typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)
# NOTE: not supported due to llama-index breaking things (please reach out if you still need it)
# @app.command("index")
# def load_index(
#     name: Annotated[str, typer.Option(help="Name of dataset to load.")],
#     dir: Annotated[Optional[str], typer.Option(help="Path to directory containing index.")] = None,
#     user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
# ):
#     """Load a LlamaIndex saved VectorIndex into MemGPT"""
#     if user_id is None:
#         config = MemGPTConfig.load()
#         user_id = uuid.UUID(config.anon_clientid)
#
#     try:
#         # load index data
#         storage_context = StorageContext.from_defaults(persist_dir=dir)
#         loaded_index = load_index_from_storage(storage_context)
#
#         # hacky code to extract out passages/embeddings (thanks a lot, llama index)
#         embed_dict = loaded_index._vector_store._data.embedding_dict
#         node_dict = loaded_index._docstore.docs
#
#         # create storage connector
#         config = MemGPTConfig.load()
#         if user_id is None:
#             user_id = uuid.UUID(config.anon_clientid)
#
#         passages = []
#         for node_id, node in node_dict.items():
#             vector = embed_dict[node_id]
#             node.embedding = vector
#             # assume embedding are the same as config
#             passages.append(
#                 Passage(
#                     text=node.text,
#                     embedding=np.array(vector),
#                     embedding_dim=config.default_embedding_config.embedding_dim,
#                     embedding_model=config.default_embedding_config.embedding_model,
#                 )
#             )
#             assert config.default_embedding_config.embedding_dim == len(
#                 vector
#             ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
#
#         if len(passages) == 0:
#             raise ValueError(f"No passages found in index {dir}")
#
#         insert_passages_into_source(passages, name, user_id, config)
#     except ValueError as e:
#         typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)


default_extensions = ".txt,.md,.pdf"
@@ -210,95 +88,99 @@ def load_directory(
    input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [],
    recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
    extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,  # TODO: remove
):
    try:
        from llama_index import SimpleDirectoryReader
        connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
        config = MemGPTConfig.load()
        if not user_id:
            user_id = uuid.UUID(config.anon_clientid)

        if recursive == True:
            assert input_dir is not None, "Must provide input directory if recursive is True."
        ms = MetadataStore(config)
        source = Source(name=name, user_id=user_id)
        ms.create_source(source)
        passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
        # TODO: also get document store

        if input_dir is not None:
            reader = SimpleDirectoryReader(
                input_dir=str(input_dir),
                recursive=recursive,
                required_exts=[ext.strip() for ext in str(extensions).split(",")],
            )
        else:
            assert input_files is not None, "Must provide input files if input_dir is None"
            reader = SimpleDirectoryReader(input_files=[str(f) for f in input_files])

        # load docs
        docs = reader.load_data()
        store_docs(str(name), docs, user_id)
        # ingest data into passage/document store
        num_passages, num_documents = load_data(
            connector=connector,
            source=source,
            embedding_config=config.default_embedding_config,
            document_store=None,
            passage_store=passage_storage,
            chunk_size=1000,
        )
        print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

    except ValueError as e:
        typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
        raise
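After this change, directory loading delegates chunking, embedding, and storage to DirectoryConnector plus load_data instead of store_docs. A hedged sketch of calling it programmatically (placeholder dataset name and path):

    from memgpt.cli.cli_load import load_directory

    # roughly equivalent to running `memgpt load directory --name my_docs --input-dir ~/docs`
    load_directory(name="my_docs", input_dir="~/docs", input_files=[], recursive=False)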
@app.command("webpage")
def load_webpage(
    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
    urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
):
    try:
        from llama_index.readers.web import SimpleWebPageReader

        docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
        store_docs(name, docs)

    except ValueError as e:
        typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)


@app.command("database")
def load_database(
    name: Annotated[str, typer.Option(help="Name of dataset to load.")],
    query: Annotated[str, typer.Option(help="Database query.")],
    dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
    scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
    host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
    port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
    user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
    password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
    dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
):
    try:
        from llama_index.readers.database import DatabaseReader

        print(dump_path, scheme)

        if dump_path is not None:
            # read from database dump file
            from sqlalchemy import create_engine

            engine = create_engine(f"sqlite:///{dump_path}")

            db = DatabaseReader(engine=engine)
        else:
            assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
            assert scheme is not None, "Must provide database scheme."
            assert host is not None, "Must provide database host."
            assert port is not None, "Must provide database port."
            assert user is not None, "Must provide database user."
            assert password is not None, "Must provide database password."
            assert dbname is not None, "Must provide database name."

            db = DatabaseReader(
                scheme=scheme,  # Database Scheme
                host=host,  # Database Host
                port=str(port),  # Database Port
                user=user,  # Database User
                password=password,  # Database Password
                dbname=dbname,  # Database Name
            )

        # load data
        docs = db.load_data(query=query)
        store_docs(name, docs)
    except ValueError as e:
        typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
# @app.command("webpage")
# def load_webpage(
#     name: Annotated[str, typer.Option(help="Name of dataset to load.")],
#     urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
# ):
#     try:
#         from llama_index.readers.web import SimpleWebPageReader
#
#         docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
#         store_docs(name, docs)
#
#     except ValueError as e:
#         typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
#
#
# @app.command("database")
# def load_database(
#     name: Annotated[str, typer.Option(help="Name of dataset to load.")],
#     query: Annotated[str, typer.Option(help="Database query.")],
#     dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
#     scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
#     host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
#     port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
#     user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
#     password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
#     dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
# ):
#     try:
#         from llama_index.readers.database import DatabaseReader
#
#         print(dump_path, scheme)
#
#         if dump_path is not None:
#             # read from database dump file
#             from sqlalchemy import create_engine
#
#             engine = create_engine(f"sqlite:///{dump_path}")
#
#             db = DatabaseReader(engine=engine)
#         else:
#             assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
#             assert scheme is not None, "Must provide database scheme."
#             assert host is not None, "Must provide database host."
#             assert port is not None, "Must provide database port."
#             assert user is not None, "Must provide database user."
#             assert password is not None, "Must provide database password."
#             assert dbname is not None, "Must provide database name."
#
#             db = DatabaseReader(
#                 scheme=scheme,  # Database Scheme
#                 host=host,  # Database Host
#                 port=str(port),  # Database Port
#                 user=user,  # Database User
#                 password=password,  # Database Password
#                 dbname=dbname,  # Database Name
#             )
#
#         # load data
#         docs = db.load_data(query=query)
#         store_docs(name, docs)
#     except ValueError as e:
#         typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
#

@app.command("vector-database")
@@ -311,58 +193,35 @@ def load_vector_database(
    user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
    """Load pre-computed embeddings into MemGPT from a database."""
    if user_id is None:
        config = MemGPTConfig.load()
        user_id = uuid.UUID(config.anon_clientid)

    try:
        from sqlalchemy import create_engine, select, MetaData, Table, Inspector
        from pgvector.sqlalchemy import Vector

        # connect to db table
        engine = create_engine(uri)
        metadata = MetaData()
        # Create an inspector to inspect the database
        inspector = Inspector.from_engine(engine)
        table_names = inspector.get_table_names()
        assert table_name in table_names, f"Table {table_name} not found in database: tables that exist {table_names}."

        table = Table(table_name, metadata, autoload_with=engine)

        config = MemGPTConfig.load()

        # Prepare a select statement
        select_statement = select(
            table.c[text_column], table.c[embedding_column].cast(Vector(config.default_embedding_config.embedding_dim))
        connector = VectorDBConnector(
            name=name,  # required by VectorDBConnector.__init__
            uri=uri,
            table_name=table_name,
            text_column=text_column,
            embedding_column=embedding_column,
            embedding_dim=config.default_embedding_config.embedding_dim,
        )

        # Execute the query and fetch the results
        with engine.connect() as connection:
            result = connection.execute(select_statement).fetchall()

        # Convert to a list of tuples (text, embedding)
        passages = []
        for text, embedding in result:
            # assume that embeddings are the same model as in config
            passages.append(
                Passage(
                    text=text,
                    embedding=embedding,
                    user_id=user_id,
                    embedding_dim=config.default_embedding_config.embedding_dim,
                    embedding_model=config.default_embedding_config.embedding_model,
                )
            )
            assert config.default_embedding_config.embedding_dim == len(
                embedding
            ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(embedding)}"

        # create storage connector
        config = MemGPTConfig.load()
        if user_id is None:
        if not user_id:
            user_id = uuid.UUID(config.anon_clientid)

        insert_passages_into_source(passages, name, user_id, config)
        ms = MetadataStore(config)
        source = Source(name=name, user_id=user_id)
        ms.create_source(source)
        passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
        # TODO: also get document store

        # ingest data into passage/document store
        num_passages, num_documents = load_data(
            connector=connector,
            source=source,
            embedding_config=config.default_embedding_config,
            document_store=None,
            passage_store=passage_storage,
            chunk_size=1000,
        )
        print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")

    except ValueError as e:
        typer.secho(f"Failed to load vector database from provided information.\n{e}", fg=typer.colors.RED)
        typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
        raise
memgpt/data_sources/connectors.py (new file, 194 lines)
@@ -0,0 +1,194 @@
from memgpt.data_types import Passage, Document, EmbeddingConfig, Source
from memgpt.utils import create_uuid_from_string
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.embeddings import embedding_model
from memgpt.data_types import Document, Passage

import uuid
from typing import List, Iterator, Dict, Tuple, Optional
from llama_index.core import Document as LlamaIndexDocument


class DataConnector:
    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
        pass

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Passage]:
        pass


def load_data(
    connector: DataConnector,
    source: Source,
    embedding_config: EmbeddingConfig,
    passage_store: StorageConnector,
    document_store: Optional[StorageConnector] = None,
    chunk_size: int = 1000,
):
    """Load data from a connector (generates documents and passages) into a specified source, associated with a user_id."""

    # embedding model
    embed_model = embedding_model(embedding_config)

    # insert passages/documents
    passages = []
    passage_count = 0
    document_count = 0
    for document_text, document_metadata in connector.generate_documents():
        # insert document into storage
        document = Document(
            id=create_uuid_from_string(f"{str(source.id)}_{document_text}"),
            text=document_text,
            metadata=document_metadata,
            data_source=source.name,
            user_id=source.user_id,
        )
        document_count += 1
        if document_store:
            document_store.insert(document)

        # generate passages
        for passage_text, passage_metadata in connector.generate_passages([document]):
            print("passage", passage_text, passage_metadata)
            embedding = embed_model.get_text_embedding(passage_text)
            passage = Passage(
                id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
                text=passage_text,
                doc_id=document.id,
                metadata_=passage_metadata,
                user_id=source.user_id,
                data_source=source.name,
                embedding_dim=embedding_config.embedding_dim,
                embedding_model=embedding_config.embedding_model,
                embedding=embedding,
            )

            passages.append(passage)
            if len(passages) >= chunk_size:
                # insert passages into passage store
                passage_store.insert_many(passages)

                passage_count += len(passages)
                passages = []

    if len(passages) > 0:
        # insert passages into passage store
        passage_store.insert_many(passages)
        passage_count += len(passages)

    return passage_count, document_count
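For reference, a minimal sketch of driving load_data directly (hypothetical names; assumes a loaded config, a MetadataStore ms, and an existing user_id, mirroring what cli_load.py does above):

    from memgpt.agent_store.storage import StorageConnector, TableType
    from memgpt.data_types import Source

    connector = DirectoryConnector(input_directory="/tmp/docs", extensions=".txt,.md")  # placeholder directory
    source = Source(name="my_docs", user_id=user_id)  # placeholder source name
    ms.create_source(source)
    passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    num_passages, num_documents = load_data(
        connector=connector,
        source=source,
        embedding_config=config.default_embedding_config,
        passage_store=passage_store,
    )
    print(f"ingested {num_passages} passages from {num_documents} documents")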
class DirectoryConnector(DataConnector):
    def __init__(self, input_files: List[str] = None, input_directory: str = None, recursive: bool = False, extensions: List[str] = None):
        self.connector_type = "directory"
        self.input_files = input_files
        self.input_directory = input_directory
        self.recursive = recursive
        self.extensions = extensions

        if self.recursive == True:
            assert self.input_directory is not None, "Must provide input directory if recursive is True."

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
        from llama_index.core import SimpleDirectoryReader

        if self.input_directory is not None:
            reader = SimpleDirectoryReader(
                input_dir=self.input_directory,
                recursive=self.recursive,
                required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
            )
        else:
            assert self.input_files is not None, "Must provide input files if input_dir is None"
            reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])

        llama_index_docs = reader.load_data()
        docs = []
        for llama_index_doc in llama_index_docs:
            # TODO: add additional metadata?
            # doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
            # docs.append(doc)
            yield llama_index_doc.text, llama_index_doc.metadata

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Passage]:
        # use llama index to run embeddings code
        from llama_index.core.node_parser import SentenceSplitter

        parser = SentenceSplitter(chunk_size=chunk_size)
        for document in documents:
            llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata)]
            nodes = parser.get_nodes_from_documents(llama_index_docs)
            for node in nodes:
                # passage = Passage(
                #     text=node.text,
                #     doc_id=document.id,
                # )
                yield node.text, None
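To make the generator contract concrete, a small illustrative loop over the connector (placeholder file name; the metadata dict comes from llama-index's SimpleDirectoryReader):

    connector = DirectoryConnector(input_files=["notes.txt"])  # placeholder file
    for doc_text, doc_metadata in connector.generate_documents():
        print(len(doc_text), doc_metadata)  # raw document text plus per-file metadata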
class WebConnector(DataConnector):
    # TODO

    def __init__(self):
        pass

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
        pass

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Passage]:
        pass


class VectorDBConnector(DataConnector):
    # NOTE: this class has not been properly tested, so is unlikely to work
    # TODO: allow loading multiple tables (1:1 mapping between Document and Table)

    def __init__(
        self,
        name: str,
        uri: str,
        table_name: str,
        text_column: str,
        embedding_column: str,
        embedding_dim: int,
    ):
        self.name = name
        self.uri = uri
        self.table_name = table_name
        self.text_column = text_column
        self.embedding_column = embedding_column
        self.embedding_dim = embedding_dim

        # connect to db table
        from sqlalchemy import create_engine

        self.engine = create_engine(uri)

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Document]:
        yield self.table_name, None

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:  # -> Iterator[Passage]:
        from sqlalchemy import select, MetaData, Table, Inspector
        from pgvector.sqlalchemy import Vector

        metadata = MetaData()
        # Create an inspector to inspect the database
        inspector = Inspector.from_engine(self.engine)
        table_names = inspector.get_table_names()
        assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."

        table = Table(self.table_name, metadata, autoload_with=self.engine)

        # Prepare a select statement
        select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))

        # Execute the query and fetch the results
        # TODO: paginate results
        with self.engine.connect() as connection:
            result = connection.execute(select_statement).fetchall()

        for text, embedding in result:
            # assume that embeddings are the same model as in config
            # TODO: don't re-compute embedding
            yield text, {"embedding": embedding}
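A hedged construction example for this pgvector-backed connector (untested per the NOTE above; the DSN and column names are placeholders):

    connector = VectorDBConnector(
        name="legacy_embeddings",  # placeholder source name
        uri="postgresql://memgpt:memgpt@localhost:5432/memgpt",  # placeholder DSN
        table_name="passages",
        text_column="text",
        embedding_column="embedding",
        embedding_dim=1536,
    )
    for text, meta in connector.generate_passages(documents=[]):  # documents is ignored by this connector
        print(text[:60], len(meta["embedding"]))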
@@ -11,6 +11,9 @@ from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_u
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd

from pydantic import BaseModel, Field, Json
from memgpt.utils import get_human_text, get_persona_text, printd

from pydantic import BaseModel, Field, Json

@@ -274,7 +277,7 @@ class Message(Record):
class Document(Record):
    """A document represents a document loaded into MemGPT, which is broken down into passages."""

    def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None):
    def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None, metadata: Optional[Dict] = {}):
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            self.id = create_uuid_from_string("".join([text, str(user_id)]))

@@ -284,6 +287,7 @@ class Document(Record):
        self.user_id = user_id
        self.text = text
        self.data_source = data_source
        self.metadata = metadata
        # TODO: add optional embedding?

@@ -295,8 +299,8 @@ class Passage(Record):

    def __init__(
        self,
        user_id: uuid.UUID,
        text: str,
        user_id: Optional[uuid.UUID] = None,
        agent_id: Optional[uuid.UUID] = None,  # set if contained in agent memory
        embedding: Optional[np.ndarray] = None,
        embedding_dim: Optional[int] = None,

@@ -308,7 +312,11 @@ class Passage(Record):
    ):
        if id is None:
            # by default, generate ID as a hash of the text (avoid duplicates)
            self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
            # TODO: use source-id instead?
            if agent_id:
                self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
            else:
                self.id = create_uuid_from_string("".join([text, str(user_id)]))
        else:
            self.id = id
        super().__init__(self.id)

@@ -334,6 +342,7 @@ class Passage(Record):
        assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"

        assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
        assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
        assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
@@ -1,6 +1,6 @@
import typer
import uuid
from typing import Optional, List
from typing import Optional, List, Any
import os
import numpy as np

@@ -9,13 +9,26 @@ from memgpt.data_types import EmbeddingConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import MAX_EMBEDDING_DIM, EMBEDDING_TO_TOKENIZER_MAP, EMBEDDING_TO_TOKENIZER_DEFAULT

from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface_utils import format_text
# from llama_index.core.base.embeddings import BaseEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import Document as LlamaIndexDocument

# from llama_index.core.base.embeddings import BaseEmbedding
# from llama_index.core.embeddings import BaseEmbedding
# from llama_index.core.base.embeddings.base import BaseEmbedding
# from llama_index.bridge.pydantic import PrivateAttr
# from llama_index.embeddings.base import BaseEmbedding
# from llama_index.embeddings.huggingface_utils import format_text
import tiktoken


def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]:
    parser = SentenceSplitter(chunk_size=chunk_size)
    llama_index_docs = [LlamaIndexDocument(text=text)]
    nodes = parser.get_nodes_from_documents(llama_index_docs)
    return [n.text for n in nodes]
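A quick illustration of the new shared chunking helper (hypothetical input; chunk_size is in tokens, per llama-index's SentenceSplitter):

    chunks = parse_and_chunk_text("MemGPT stores long documents as passages. " * 200, chunk_size=300)
    print(len(chunks), chunks[0][:80])  # several chunks, each at most ~300 tokens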
def truncate_text(text: str, max_length: int, encoding) -> str:
    # truncate the text based on max_length and encoding
    encoded_text = encoding.encode(text)[:max_length]

@@ -53,15 +66,15 @@ def check_and_split_text(text: str, embedding_model: str) -> List[str]:
    return [text]


class EmbeddingEndpoint(BaseEmbedding):
class EmbeddingEndpoint:

    """Implementation for OpenAI compatible endpoint"""

    """ Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """
    # """ Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """

    _user: str = PrivateAttr()
    _timeout: float = PrivateAttr()
    _base_url: str = PrivateAttr()
    # _user: str = PrivateAttr()
    # _timeout: float = PrivateAttr()
    # _base_url: str = PrivateAttr()

    def __init__(
        self,

@@ -69,21 +82,16 @@ class EmbeddingEndpoint(BaseEmbedding):
        base_url: str,
        user: str,
        timeout: float = 60.0,
        **kwargs: Any,
    ):
        if not is_valid_url(base_url):
            raise ValueError(
                f"Embeddings endpoint was provided an invalid URL (set to: '{base_url}'). Make sure embedding_endpoint is set correctly in your MemGPT config."
            )
        self.model_name = model
        self._user = user
        self._base_url = base_url
        self._timeout = timeout
        super().__init__(
            model_name=model,
        )

    @classmethod
    def class_name(cls) -> str:
        return "EmbeddingEndpoint"

    def _call_api(self, text: str) -> List[float]:
        if not is_valid_url(self._base_url):

@@ -120,59 +128,8 @@ class EmbeddingEndpoint(BaseEmbedding):

        return embedding

    async def _acall_api(self, text: str) -> List[float]:
        if not is_valid_url(self._base_url):
            raise ValueError(
                f"Embeddings endpoint does not have a valid URL (set to: '{self._base_url}'). Make sure embedding_endpoint is set correctly in your MemGPT config."
            )
        import httpx

        headers = {"Content-Type": "application/json"}
        json_data = {"input": text, "model": self.model_name, "user": self._user}

        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self._base_url}/embeddings",
                headers=headers,
                json=json_data,
                timeout=self._timeout,
            )
        response_json = response.json()

        if isinstance(response_json, list):
            # embedding directly in response
            embedding = response_json
        elif isinstance(response_json, dict):
            # TEI embedding packaged inside openai-style response
            try:
                embedding = response_json["data"][0]["embedding"]
            except (KeyError, IndexError):
                raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")
        else:
            # unknown response, can't parse
            raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")

        return embedding

    def _get_query_embedding(self, query: str) -> list[float]:
        """get query embedding."""
        embedding = self._call_api(query)
        return embedding

    def _get_text_embedding(self, text: str) -> list[float]:
        """get text embedding."""
        embedding = self._call_api(text)
        return embedding

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        embeddings = [self._get_text_embedding(text) for text in texts]
        return embeddings

    async def _aget_query_embedding(self, query: str) -> List[float]:
        return self._get_query_embedding(query)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        return self._get_text_embedding(text)
    def get_text_embedding(self, text: str) -> List[float]:
        return self._call_api(text)
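With BaseEmbedding dropped, the class is now called directly; a minimal usage sketch (endpoint URL, model, and user are placeholders):

    endpoint = EmbeddingEndpoint(model="text-embedding-ada-002", base_url="http://localhost:8080/v1", user="anon")
    vector = endpoint.get_text_embedding("hello world")  # POSTs to {base_url}/embeddings, returns List[float]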
def default_embedding_model():

@@ -202,10 +159,14 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
    credentials = MemGPTCredentials.load()

    if endpoint_type == "openai":
        from llama_index.embeddings.openai import OpenAIEmbedding

        additional_kwargs = {"user_id": user_id} if user_id else {}
        model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=credentials.openai_key, additional_kwargs=additional_kwargs)
        return model
    elif endpoint_type == "azure":
        from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding

        # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
        model = "text-embedding-ada-002"
        deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
@@ -11,15 +11,16 @@ from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, J

console = Console()

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.interface import CLIInterface as interface  # for printing to terminal
from memgpt.config import MemGPTConfig
import memgpt.agent as agent
import memgpt.system as system
import memgpt.errors as errors
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, migrate, delete_agent
from memgpt.cli.cli import run, version, server, open_folder, quickstart, migrate, delete_agent
from memgpt.cli.cli_config import configure, list, add, delete
from memgpt.cli.cli_load import app as load_app
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore

# import benchmark
from memgpt.benchmark.benchmark import bench

@@ -27,7 +28,6 @@ from memgpt.benchmark.benchmark import bench
app = typer.Typer(pretty_exceptions_enable=False)
app.command(name="run")(run)
app.command(name="version")(version)
app.command(name="attach")(attach)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)

@@ -100,11 +100,11 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
        # updated agent save functions
        if user_input.lower() == "/exit":
            # memgpt_agent.save()
            save_agent(memgpt_agent, ms)
            agent.save_agent(memgpt_agent, ms)
            break
        elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
            # memgpt_agent.save()
            save_agent(memgpt_agent, ms)
            agent.save_agent(memgpt_agent, ms)
            continue
        elif user_input.lower() == "/attach":
            # TODO: check if agent already has it

@@ -143,7 +143,11 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
            data_source = questionary.select("Select data source", choices=valid_options).ask()

            # attach new data
            attach(memgpt_agent.agent_state.name, data_source)
            # attach(memgpt_agent.agent_state.name, data_source)
            source_connector = StorageConnector.get_storage_connector(
                TableType.PASSAGES, config, user_id=memgpt_agent.agent_state.user_id
            )
            memgpt_agent.attach_source(data_source, source_connector, ms)

            continue
@@ -7,9 +7,10 @@ from memgpt.utils import get_local_time, printd, count_tokens, validate_date_for
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.llm_api_tools import create
from memgpt.data_types import Message, Passage, AgentState
from memgpt.embeddings import embedding_model, query_embedding
from llama_index import Document
from llama_index.node_parser import SimpleNodeParser
from memgpt.embeddings import embedding_model, query_embedding, parse_and_chunk_text

# from llama_index import Document
# from llama_index.node_parser import SimpleNodeParser


class CoreMemory(object):

@@ -402,12 +403,9 @@ class EmbeddingArchivalMemory(ArchivalMemory):
        try:
            passages = []

            # create parser
            parser = SimpleNodeParser.from_defaults(chunk_size=self.embedding_chunk_size)

            # break up string into passages
            for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
                embedding = self.embed_model.get_text_embedding(node.text)
            for text in parse_and_chunk_text(memory_string, self.embedding_chunk_size):
                embedding = self.embed_model.get_text_embedding(text)
                # fixing weird bug where type returned isn't a list, but instead is an object
                # eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
                if isinstance(embedding, dict):

@@ -418,7 +416,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
                    raise TypeError(
                        f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
                    )
                passages.append(self.create_passage(node.text, embedding))
                passages.append(self.create_passage(text, embedding))

            # insert passages
            self.storage.insert_many(passages)
@@ -10,7 +10,6 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.agent import Agent

from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func

@@ -379,12 +378,16 @@ class MetadataStore:
            session.commit()

    @enforce_types
    def create_source(self, source: Source):
    def create_source(self, source: Source, exists_ok=False):
        # make sure source.name does not already exist for user
        with self.session_maker() as session:
            if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
                raise ValueError(f"Source with name {source.name} already exists")
            session.add(SourceModel(**vars(source)))
                if not exists_ok:
                    raise ValueError(f"Source with name {source.name} already exists")
                else:
                    session.merge(SourceModel(**vars(source)))  # Session has no .update(); merge performs the upsert
            else:
                session.add(SourceModel(**vars(source)))
            session.commit()

    @enforce_types
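The new exists_ok flag lets loaders re-run without failing on an existing source; a hedged sketch of the intended call pattern (assumes a loaded config and user_id):

    ms = MetadataStore(config)
    source = Source(name="my_docs", user_id=user_id)  # placeholder source name
    ms.create_source(source, exists_ok=True)  # a duplicate name now updates the row instead of raising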
@@ -596,15 +599,3 @@ class MetadataStore:
                AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
            ).delete()
            session.commit()


def save_agent(agent: Agent, ms: MetadataStore):
    """Save agent to metadata store"""

    agent.update_state()
    agent_state = agent.agent_state

    if ms.get_agent(agent_id=agent_state.id):
        ms.update_agent(agent_state)
    else:
        ms.create_agent(agent_state)
@@ -15,14 +15,10 @@ import typer
from tqdm import tqdm
import questionary

from llama_index import (
    StorageContext,
    load_index_from_storage,
)

from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
from memgpt.data_types import AgentState, User, Passage, Source, Message
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
from memgpt.utils import (
    MEMGPT_DIR,
    version_less_than,

@@ -160,7 +156,17 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
    ms.create_source(source)

    try:
        nodes = pickle.load(open(source_path, "rb"))
        try:
            nodes = pickle.load(open(source_path, "rb"))
        except ModuleNotFoundError as e:
            if "No module named 'llama_index.schema'" in str(e):
                # cannot load source at all, so throw error
                raise ValueError(
                    "Failed to load archival memory due to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
                )
            else:
                raise e

        passages = []
        for node in nodes:
            # print(len(node.embedding))

@@ -485,7 +491,17 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta

    # 5. Insert into archival
    if os.path.exists(archival_filename):
        nodes = pickle.load(open(archival_filename, "rb"))
        try:
            nodes = pickle.load(open(archival_filename, "rb"))
        except ModuleNotFoundError as e:
            if "No module named 'llama_index.schema'" in str(e):
                print(
                    "Failed to load archival memory due to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
                )
                nodes = []
            else:
                raise e

        passages = []
        failed_inserts = []
        for node in nodes:
@ -23,6 +23,7 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
        preset_function_set_names = preset_config["functions"]
        functions_schema = generate_functions_json(preset_function_set_names)

        print("PRESET", preset_name, user_id)
        if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
            printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
            continue
@ -11,21 +11,21 @@ from fastapi import HTTPException
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import JSON_LOADS_STRICT, JSON_ENSURE_ASCII
from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
import memgpt.system as system
import memgpt.constants as constants
from memgpt.cli.cli import attach

# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options

# from memgpt.agent_store.storage import StorageConnector
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.data_types import (
    User,
    Passage,
    AgentState,
    LLMConfig,
    EmbeddingConfig,
@ -394,7 +394,9 @@ class SyncServer(LockingServer):
            except:
                raise ValueError(command)

            attach(agent_name=memgpt_agent.agent_state.name, data_source=data_source, user_id=user_id)
            # attach data to agent from source
            source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
            memgpt_agent.attach_source(data_source, source_connector, self.ms)

        elif command.lower() == "dump" or command.lower().startswith("dump "):
            # Check if there's an additional argument that's an integer
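The attach command now performs the ingestion in-process instead of delegating to the CLI attach helper. A rough sketch of the same call sequence from a standalone caller's point of view, where config, user_id, agent, and ms are assumed to be constructed elsewhere:

from memgpt.agent_store.storage import StorageConnector, TableType

source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
agent.attach_source("my_docs", source_connector, ms)  # "my_docs" is a placeholder source name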
@ -642,6 +644,10 @@ class SyncServer(LockingServer):
        if agent is not None:
            self.ms.delete_agent(agent_id=agent_id)

    def initialize_default_presets(self, user_id: uuid.UUID):
        """Add default preset options into the metadata store"""
        presets.add_default_presets(user_id, self.ms)

    def create_preset(self, preset: Preset):
        """Create a new preset using a config"""
        if self.ms.get_user(user_id=preset.user_id) is None:
@ -1009,3 +1015,16 @@ class SyncServer(LockingServer):
        """Create a new API key for a user"""
        token = self.ms.create_api_key(user_id=user_id)
        return token

    def create_source(self, name: str):  # TODO: add other fields
        # create a data source
        pass

    def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
        # load a list of passages into a data source
        pass

    def attach_source_to_agent(self, agent_id: uuid.UUID, source_id: uuid.UUID):
        # attach a data source to an agent
        # TODO: insert passages into agent archival memory
        pass
poetry.lock (generated, 1616 lines changed)
File diff suppressed because it is too large.
@ -22,7 +22,6 @@ pytz = "^2023.3.post1"
tqdm = "^4.66.1"
black = { version = "^23.10.1", optional = true }
pytest = { version = "^7.4.4", optional = true }
llama-index = "^0.9.13"
setuptools = "^68.2.2"
datasets = { version = "^2.14.6", optional = true}
prettytable = "^3.9.0"
@ -53,6 +52,9 @@ docx2txt = "^0.8"
sqlalchemy = "^2.0.25"
pexpect = {version = "^4.9.0", optional = true}
pyright = {version = "^1.1.347", optional = true}
llama-index = "^0.10.6"
llama-index-embeddings-openai = "^0.1.1"
python-box = "^7.1.1"

[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]
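For reference, the version bump pairs with llama-index 0.10's switch to namespaced integration packages, which is why llama-index-embeddings-openai now appears as a separate requirement. A sketch of the import change this implies for embedding code (illustrative; exact call sites vary):

# llama-index < 0.10 (monolithic package):
# from llama_index.embeddings import OpenAIEmbedding

# llama-index >= 0.10 (integration ships in llama-index-embeddings-openai):
from llama_index.embeddings.openai import OpenAIEmbedding

embed_model = OpenAIEmbedding(model="text-embedding-ada-002")  # model name illustrative; needs OPENAI_API_KEY at runtime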
@ -6,8 +6,9 @@ from sqlalchemy.ext.declarative import declarative_base

# import memgpt
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
from memgpt.cli.cli import attach
from memgpt.cli.cli_load import load_directory

# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
@ -153,28 +154,29 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, config):
    sources = ms.list_sources(user_id=user_id)
    print("All sources", [s.name for s in sources])

    # test loading into an agent
    # create agent
    agent_id = agent.id
    # create storage connector
    print("Creating agent archival storage connector...")
    conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
    print("Deleting agent archival table...")
    conn.delete_table()
    conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
    assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
    # TODO: add back once agent attachment fully supported from server
    ## test loading into an agent
    ## create agent
    # agent_id = agent.id
    ## create storage connector
    # print("Creating agent archival storage connector...")
    # conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
    # print("Deleting agent archival table...")
    # conn.delete_table()
    # conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
    # assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"

    # attach data
    print("Attaching data...")
    attach(agent_name=agent.name, data_source=name, user_id=user_id)
    ## attach data
    # print("Attaching data...")
    # attach(agent_name=agent.name, data_source=name, user_id=user_id)

    # test to see if contained in storage
    assert len(passages) == conn.size()
    assert len(passages) == len(conn.get_all({"data_source": name}))
    ## test to see if contained in storage
    # assert len(passages) == conn.size()
    # assert len(passages) == len(conn.get_all({"data_source": name}))

    # test: delete source
    passages_conn.delete({"data_source": name})
    assert len(passages_conn.get_all({"data_source": name})) == 0
    ## test: delete source
    # passages_conn.delete({"data_source": name})
    # assert len(passages_conn.get_all({"data_source": name})) == 0

    # cleanup
    ms.delete_user(user.id)
@ -18,10 +18,12 @@ def test_migrate_0211():
    # os.environ["MEMGPT_CONFIG_PATH"] = os.path.join(data_dir, "config")
    # print(f"MEMGPT_CONFIG_PATH={os.environ['MEMGPT_CONFIG_PATH']}")
    try:
        agent_res = migrate_all_agents(tmp_dir)
        agent_res = migrate_all_agents(tmp_dir, debug=True)
        assert len(agent_res["failed_migrations"]) == 0, f"Failed migrations: {agent_res}"
        source_res = migrate_all_sources(tmp_dir)
        assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}"

        # NOTE: source tests had to be removed since it is no longer possible to migrate llama index vector indices
        # source_res = migrate_all_sources(tmp_dir)
        # assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}"

        # TODO: assert everything is in the DB
@ -40,11 +42,11 @@ def test_migrate_0211():
        messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000)
        assert len(messages) > 0

        for source_name in source_res["migration_candidates"]:
            if source_name not in source_res["failed_migrations"]:
                # assert source data exists
                source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"])
                assert source is not None
        # for source_name in source_res["migration_candidates"]:
        #     if source_name not in source_res["failed_migrations"]:
        #         # assert source data exists
        #         source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"])
        #         assert source is not None
    except Exception as e:
        raise e
    finally:
@ -5,16 +5,21 @@ import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET
from memgpt.config import MemGPTConfig


def test_list_messages():
    client = TestClient(app)

    test_user_id = uuid.uuid4()
    test_user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)

    # create user
    server = SyncServer()
    server.create_user({"id": test_user_id})
    if not server.get_user(test_user_id):
        server.create_user({"id": test_user_id})

    # write default presets to DB
    server.initialize_default_presets(test_user_id)

    # test: create agent
    request_body = {
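The test now derives its user id from the local config rather than generating a random UUID, so repeated runs reuse the same user row, and guarding create_user keeps the setup idempotent. The same pattern in isolation (the config field is taken from the diff; the rest is assumed scaffolding):

import uuid

from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer

server = SyncServer()
test_user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)  # stable across runs
if not server.get_user(test_user_id):  # create only on the first run
    server.create_user({"id": test_user_id})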