feat: refactor loading and attaching data sources, and upgrade to llama-index==0.10.6 (#1016)

Sarah Wooders 2024-02-18 16:57:01 -08:00 committed by GitHub
parent 508679e2cb
commit 38c184caf8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 1586 additions and 1131 deletions
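In outline, this refactor replaces the llama-index ingestion pipeline with MemGPT-native DataConnector classes (new file memgpt/data_sources/connectors.py) and moves attach logic from the CLI into Agent.attach_source. A sketch of the new end-to-end flow; the directory path and source name are placeholder assumptions, not part of this diff:

import uuid

from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.data_types import Source
from memgpt.metadata import MetadataStore

config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)  # default single-user ID
ms = MetadataStore(config)

# 1. register a source and ingest documents/passages into it
source = Source(name="my_docs", user_id=user_id)
ms.create_source(source)
connector = DirectoryConnector(input_directory="/path/to/docs", extensions=".txt,.md,.pdf")
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
load_data(
    connector=connector,
    source=source,
    embedding_config=config.default_embedding_config,
    passage_store=passage_store,
    document_store=None,  # document store is not wired up yet (see TODOs below)
)

# 2. later, copy the source's passages into an agent's archival memory
# agent.attach_source("my_docs", passage_store, ms)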

View File: memgpt/agent.py

@ -5,9 +5,11 @@ import json
from pathlib import Path
import traceback
from typing import List, Tuple, Optional, cast, Union
from tqdm import tqdm
from memgpt.data_types import AgentState, Message, EmbeddingConfig
from memgpt.metadata import MetadataStore
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.data_types import AgentState, Message, EmbeddingConfig, Passage
from memgpt.models import chat_completion_response
from memgpt.interface import AgentInterface
from memgpt.persistence_manager import LocalStateManager
@ -24,6 +26,7 @@ from memgpt.utils import (
get_schema_diff,
validate_function_response,
verify_first_message_correctness,
create_uuid_from_string,
)
from memgpt.constants import (
FIRST_MESSAGE_ATTEMPTS,
@ -951,3 +954,57 @@ class Agent(object):
# TODO: recall memory
raise NotImplementedError()
def attach_source(self, source_name, source_connector: StorageConnector, ms: MetadataStore):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give the agent's retriever access to the source table, rather than modifying archival memory
filters = {"user_id": self.agent_state.user_id, "data_source": source_name}
size = source_connector.size(filters)
# typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
page_size = 100
generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
all_passages = []
for i in tqdm(range(0, size, page_size)):
passages = next(generator)
# need to associate the passage with the agent (for filtering)
for passage in passages:
assert isinstance(passage, Passage), f"Generator yielded a non-Passage type: {type(passage)}"
passage.agent_id = self.agent_state.id
# regenerate passage ID (avoid duplicates)
passage.id = create_uuid_from_string(f"{source_name}_{str(passage.agent_id)}_{passage.text}")
# insert into agent archival memory
self.persistence_manager.archival_memory.storage.insert_many(passages)
all_passages += passages
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
# save destination storage
self.persistence_manager.archival_memory.storage.save()
# attach to agent
source = ms.get_source(source_name=source_name, user_id=self.agent_state.user_id)
assert source is not None, f"source does not exist for source_name={source_name}, user_id={self.agent_state.user_id}"
source_id = source.id
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
total_agent_passages = self.persistence_manager.archival_memory.storage.size()
printd(
f"Attached data source {source_name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
)
def save_agent(agent: Agent, ms: MetadataStore):
"""Save agent to metadata store"""
agent.update_state()
agent_state = agent.agent_state
if ms.get_agent(agent_id=agent_state.id):
ms.update_agent(agent_state)
else:
ms.create_agent(agent_state)
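attach_source absorbs the CLI attach command deleted from memgpt/cli/cli.py below: callers now build a PASSAGES storage connector themselves and hand it to the agent, as main.py and server.py do later in this diff. Roughly, with config, ms, and a loaded agent assumed in scope:

from memgpt.agent_store.storage import StorageConnector, TableType

# source passages live in the shared PASSAGES table, scoped by user_id
source_connector = StorageConnector.get_storage_connector(
    TableType.PASSAGES, config, user_id=agent.agent_state.user_id
)
agent.attach_source("my_docs", source_connector, ms)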

View File: memgpt/agent_store/chroma.py

@ -37,7 +37,7 @@ class ChromaStorageConnector(StorageConnector):
self.include: Include = ["documents", "embeddings", "metadatas"]
# need to be converted to strings
self.uuid_fields = ["id", "user_id", "agent_id", "source_id"]
self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "doc_id"]
def get_filters(self, filters: Optional[Dict] = {}) -> Tuple[list, dict]:
# get all filters for query

View File: memgpt/autogen/memgpt_agent.py

@ -4,6 +4,7 @@ from typing import Callable, Optional, List, Dict, Union, Any, Tuple
from autogen.agentchat import Agent, ConversableAgent, UserProxyAgent, GroupChat, GroupChatManager
from memgpt.agent import Agent as MemGPTAgent
from memgpt.agent import save_agent
from memgpt.autogen.interface import AutoGenInterface
import memgpt.system as system
import memgpt.constants as constants
@ -14,7 +15,6 @@ from memgpt.credentials import MemGPTCredentials
from memgpt.cli.cli import attach
from memgpt.cli.cli_load import load_directory, load_webpage, load_index, load_database, load_vector_database
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore, save_agent
from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig

View File: memgpt/cli/cli.py

@ -13,8 +13,6 @@ from typing import Annotated, Optional
import typer
import questionary
from llama_index import set_global_service_context
from llama_index import ServiceContext
from memgpt.log import logger
from memgpt.interface import CLIInterface as interface # for printing to terminal
@ -25,11 +23,11 @@ from memgpt.utils import printd, open_folder_in_explorer, suppress_stdout
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX, JSON_ENSURE_ASCII
from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
from memgpt.embeddings import embedding_model
from memgpt.server.constants import WS_DEFAULT_PORT, REST_DEFAULT_PORT
from memgpt.data_types import AgentState, LLMConfig, EmbeddingConfig, User, Passage
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
from memgpt.migrate import migrate_all_agents, migrate_all_sources
@ -648,16 +646,6 @@ def run(
# printd(json.dumps(vars(agent_config), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))
# printd(json.dumps(agent_init_state), indent=4, sort_keys=True, ensure_ascii=JSON_ENSURE_ASCII))
# configure llama index
original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index
sys.stdout = io.StringIO()
embed_model = embedding_model(config=agent_state.embedding_config, user_id=user.id)
service_context = ServiceContext.from_defaults(
llm=None, embed_model=embed_model, chunk_size=agent_state.embedding_config.embedding_chunk_size
)
set_global_service_context(service_context)
sys.stdout = original_stdout
# start event loop
from memgpt.main import run_agent_loop
@ -703,69 +691,6 @@ def delete_agent(
sys.exit(1)
def attach(
agent_name: Annotated[str, typer.Option(help="Specify agent to attach data to")],
data_source: Annotated[str, typer.Option(help="Data source to attach to agent")],
user_id: uuid.UUID = None,
):
# use client ID if no user_id provided
config = MemGPTConfig.load()
if user_id is None:
user_id = uuid.UUID(config.anon_clientid)
try:
# loads the data contained in data source into the agent's memory
from memgpt.agent_store.storage import StorageConnector, TableType
from tqdm import tqdm
ms = MetadataStore(config)
agent = ms.get_agent(agent_name=agent_name, user_id=user_id)
assert agent is not None, f"No agent found under agent_name={agent_name}, user_id={user_id}"
source = ms.get_source(source_name=data_source, user_id=user_id)
assert source is not None, f"Source {data_source} does not exist for user {user_id}"
# get storage connectors
with suppress_stdout():
source_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
dest_storage = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id)
size = source_storage.size({"data_source": data_source})
typer.secho(f"Ingesting {size} passages into {agent.name}", fg=typer.colors.GREEN)
page_size = 100
generator = source_storage.get_all_paginated(filters={"data_source": data_source}, page_size=page_size) # yields List[Passage]
all_passages = []
for i in tqdm(range(0, size, page_size)):
passages = next(generator)
# need to associate the passage with the agent (for filtering)
for passage in passages:
assert isinstance(passage, Passage), f"Generator yielded a non-Passage type: {type(passage)}"
passage.agent_id = agent.id
# insert into agent archival memory
dest_storage.insert_many(passages)
all_passages += passages
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
# save destination storage
dest_storage.save()
# attach to agent
source = ms.get_source(source_name=data_source, user_id=user_id)
assert source is not None, f"source does not exist for source_name={data_source}, user_id={user_id}"
source_id = source.id
ms.attach_source(agent_id=agent.id, source_id=source_id, user_id=user_id)
total_agent_passages = dest_storage.size()
typer.secho(
f"Attached data source {data_source} to agent {agent_name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
fg=typer.colors.GREEN,
)
except KeyboardInterrupt:
typer.secho("Operation interrupted by KeyboardInterrupt.", fg=typer.colors.YELLOW)
def version():
import memgpt
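Two deletions above are worth noting: the attach command (its logic moved into Agent.attach_source) and the global ServiceContext setup in run(), since llama-index 0.10 deprecates ServiceContext in favor of a global Settings object. MemGPT now calls the embedding model directly instead of configuring llama-index globally, but for reference the 0.10-style equivalent of the deleted block would look roughly like:

from llama_index.core import Settings

# llama-index >= 0.10: Settings replaces the deprecated ServiceContext
Settings.embed_model = embed_model  # from embedding_model(config=..., user_id=...)
Settings.chunk_size = agent_state.embedding_config.embedding_chunk_size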

View File: memgpt/cli/cli_load.py

@ -14,6 +14,7 @@ import numpy as np
import typer
import uuid
from memgpt.data_sources.connectors import load_data, DirectoryConnector, VectorDBConnector
from memgpt.embeddings import embedding_model, check_and_split_text
from memgpt.agent_store.storage import StorageConnector
from memgpt.config import MemGPTConfig
@ -24,180 +25,57 @@ from memgpt.agent_store.storage import StorageConnector, TableType
from datetime import datetime
from llama_index import (
VectorStoreIndex,
ServiceContext,
StorageContext,
load_index_from_storage,
)
app = typer.Typer()
def insert_passages_into_source(passages: List[Passage], source_name: str, user_id: uuid.UUID, config: MemGPTConfig):
"""Insert a list of passages into a source by updating storage connectors and metadata store"""
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
orig_size = storage.size()
# insert metadata store
ms = MetadataStore(config)
source = ms.get_source(user_id=user_id, source_name=source_name)
if not source:
# create new
source = Source(user_id=user_id, name=source_name)
ms.create_source(source)
# make sure user_id is set for passages
passage_chunk = []
insert_chunk_size = 1000
for passage in passages:
# TODO: attach source IDs
# passage.source_id = source.id
passage.user_id = user_id
passage.data_source = source_name
# add and save all passages
passage_chunk.append(passage)
if len(passage_chunk) >= insert_chunk_size:
storage.insert_many(passage_chunk)
storage.save()
passage_chunk = []
if len(passage_chunk) > 0:
storage.insert_many(passage_chunk)
storage.save()
# print info
num_new_passages = storage.size() - orig_size
print(f"Updated {len(passages)}, inserted {num_new_passages} new passages into {source_name}")
print("Total passages in source:", storage.size())
def store_docs(name, docs, user_id=None, show_progress=True):
"""Common function for embedding and storing documents"""
config = MemGPTConfig.load()
if user_id is None: # assume running local with single user
user_id = uuid.UUID(config.anon_clientid)
# record data source metadata
ms = MetadataStore(config)
user = ms.get_user(user_id)
if user is None:
raise ValueError(f"Cannot find user {user_id} in metadata store. Please run 'memgpt configure'.")
# create data source record
data_source = Source(
user_id=user.id,
name=name,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
)
existing_source = ms.get_source(user_id=user.id, source_name=name)
if not existing_source:
ms.create_source(data_source)
else:
print(f"Source {name} for user {user.id} already exists.")
if existing_source.embedding_model != data_source.embedding_model:
print(
f"Warning: embedding model for existing source {existing_source.embedding_model} does not match default {data_source.embedding_model}"
)
print("Cannot import data into this source without a compatible embedding endpoint.")
print("Please run 'memgpt configure' to update the default embedding settings.")
return False
if existing_source.embedding_dim != data_source.embedding_dim:
print(
f"Warning: embedding dimension for existing source {existing_source.embedding_dim} does not match default {data_source.embedding_dim}"
)
print("Cannot import data into this source without a compatible embedding endpoint.")
print("Please run 'memgpt configure' to update the default embedding settings.")
return False
# compute and record passages
embed_model = embedding_model(config.default_embedding_config)
# use llama index to run embeddings code
with suppress_stdout():
service_context = ServiceContext.from_defaults(
llm=None, embed_model=embed_model, chunk_size=config.default_embedding_config.embedding_chunk_size
)
index = VectorStoreIndex.from_documents(docs, service_context=service_context, show_progress=True)
embed_dict = index._vector_store._data.embedding_dict
node_dict = index._docstore.docs
# TODO: add document store
# gather passages
passages = []
for node_id, node in tqdm(node_dict.items()):
vector = embed_dict[node_id]
node.embedding = vector
text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters
assert (
len(node.embedding) == config.default_embedding_config.embedding_dim
), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}"
passages.append(
Passage(
user_id=user.id,
text=text,
data_source=name,
embedding=node.embedding,
metadata_=None,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
insert_passages_into_source(passages, name, user_id, config)
@app.command("index")
def load_index(
name: Annotated[str, typer.Option(help="Name of dataset to load.")],
dir: Annotated[Optional[str], typer.Option(help="Path to directory containing index.")] = None,
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
"""Load a LlamaIndex saved VectorIndex into MemGPT"""
if user_id is None:
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
try:
# load index data
storage_context = StorageContext.from_defaults(persist_dir=dir)
loaded_index = load_index_from_storage(storage_context)
# hacky code to extract out passages/embeddings (thanks a lot, llama index)
embed_dict = loaded_index._vector_store._data.embedding_dict
node_dict = loaded_index._docstore.docs
# create storage connector
config = MemGPTConfig.load()
if user_id is None:
user_id = uuid.UUID(config.anon_clientid)
passages = []
for node_id, node in node_dict.items():
vector = embed_dict[node_id]
node.embedding = vector
# assume embeddings are the same model as in config
passages.append(
Passage(
text=node.text,
embedding=np.array(vector),
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert config.default_embedding_config.embedding_dim == len(
vector
), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
if len(passages) == 0:
raise ValueError(f"No passages found in index {dir}")
insert_passages_into_source(passages, name, user_id, config)
except ValueError as e:
typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)
# NOTE: not supported due to llama-index breaking things (please reach out if you still need it)
# @app.command("index")
# def load_index(
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
# dir: Annotated[Optional[str], typer.Option(help="Path to directory containing index.")] = None,
# user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
# ):
# """Load a LlamaIndex saved VectorIndex into MemGPT"""
# if user_id is None:
# config = MemGPTConfig.load()
# user_id = uuid.UUID(config.anon_clientid)
#
# try:
# # load index data
# storage_context = StorageContext.from_defaults(persist_dir=dir)
# loaded_index = load_index_from_storage(storage_context)
#
# # hacky code to extract out passages/embeddings (thanks a lot, llama index)
# embed_dict = loaded_index._vector_store._data.embedding_dict
# node_dict = loaded_index._docstore.docs
#
# # create storage connector
# config = MemGPTConfig.load()
# if user_id is None:
# user_id = uuid.UUID(config.anon_clientid)
#
# passages = []
# for node_id, node in node_dict.items():
# vector = embed_dict[node_id]
# node.embedding = vector
# # assume embedding are the same as config
# passages.append(
# Passage(
# text=node.text,
# embedding=np.array(vector),
# embedding_dim=config.default_embedding_config.embedding_dim,
# embedding_model=config.default_embedding_config.embedding_model,
# )
# )
# assert config.default_embedding_config.embedding_dim == len(
# vector
# ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(vector)}"
#
# if len(passages) == 0:
# raise ValueError(f"No passages found in index {dir}")
#
# insert_passages_into_source(passages, name, user_id, config)
# except ValueError as e:
# typer.secho(f"Failed to load index from provided information.\n{e}", fg=typer.colors.RED)
default_extensions = ".txt,.md,.pdf"
@ -210,95 +88,99 @@ def load_directory(
input_files: Annotated[List[str], typer.Option(help="List of paths to files containing dataset.")] = [],
recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
):
try:
from llama_index import SimpleDirectoryReader
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
config = MemGPTConfig.load()
if not user_id:
user_id = uuid.UUID(config.anon_clientid)
if recursive == True:
assert input_dir is not None, "Must provide input directory if recursive is True."
ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
if input_dir is not None:
reader = SimpleDirectoryReader(
input_dir=str(input_dir),
recursive=recursive,
required_exts=[ext.strip() for ext in str(extensions).split(",")],
# ingest data into passage/document store
num_passages, num_documents = load_data(
connector=connector,
source=source,
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
chunk_size=1000,
)
else:
assert input_files is not None, "Must provide input files if input_dir is None"
reader = SimpleDirectoryReader(input_files=[str(f) for f in input_files])
# load docs
docs = reader.load_data()
store_docs(str(name), docs, user_id)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except ValueError as e:
typer.secho(f"Failed to load directory from provided information.\n{e}", fg=typer.colors.RED)
raise
@app.command("webpage")
def load_webpage(
name: Annotated[str, typer.Option(help="Name of dataset to load.")],
urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
):
try:
from llama_index.readers.web import SimpleWebPageReader
docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
store_docs(name, docs)
except ValueError as e:
typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
@app.command("database")
def load_database(
name: Annotated[str, typer.Option(help="Name of dataset to load.")],
query: Annotated[str, typer.Option(help="Database query.")],
dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
):
try:
from llama_index.readers.database import DatabaseReader
print(dump_path, scheme)
if dump_path is not None:
# read from database dump file
from sqlalchemy import create_engine
engine = create_engine(f"sqlite:///{dump_path}")
db = DatabaseReader(engine=engine)
else:
assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
assert scheme is not None, "Must provide database scheme."
assert host is not None, "Must provide database host."
assert port is not None, "Must provide database port."
assert user is not None, "Must provide database user."
assert password is not None, "Must provide database password."
assert dbname is not None, "Must provide database name."
db = DatabaseReader(
scheme=scheme, # Database Scheme
host=host, # Database Host
port=str(port), # Database Port
user=user, # Database User
password=password, # Database Password
dbname=dbname, # Database Name
)
# load data
docs = db.load_data(query=query)
store_docs(name, docs)
except ValueError as e:
typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
# @app.command("webpage")
# def load_webpage(
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
# urls: Annotated[List[str], typer.Option(help="List of urls to load.")],
# ):
# try:
# from llama_index.readers.web import SimpleWebPageReader
#
# docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
# store_docs(name, docs)
#
# except ValueError as e:
# typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED)
#
#
# @app.command("database")
# def load_database(
# name: Annotated[str, typer.Option(help="Name of dataset to load.")],
# query: Annotated[str, typer.Option(help="Database query.")],
# dump_path: Annotated[Optional[str], typer.Option(help="Path to dump file.")] = None,
# scheme: Annotated[Optional[str], typer.Option(help="Database scheme.")] = None,
# host: Annotated[Optional[str], typer.Option(help="Database host.")] = None,
# port: Annotated[Optional[int], typer.Option(help="Database port.")] = None,
# user: Annotated[Optional[str], typer.Option(help="Database user.")] = None,
# password: Annotated[Optional[str], typer.Option(help="Database password.")] = None,
# dbname: Annotated[Optional[str], typer.Option(help="Database name.")] = None,
# ):
# try:
# from llama_index.readers.database import DatabaseReader
#
# print(dump_path, scheme)
#
# if dump_path is not None:
# # read from database dump file
# from sqlalchemy import create_engine
#
# engine = create_engine(f"sqlite:///{dump_path}")
#
# db = DatabaseReader(engine=engine)
# else:
# assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
# assert scheme is not None, "Must provide database scheme."
# assert host is not None, "Must provide database host."
# assert port is not None, "Must provide database port."
# assert user is not None, "Must provide database user."
# assert password is not None, "Must provide database password."
# assert dbname is not None, "Must provide database name."
#
# db = DatabaseReader(
# scheme=scheme, # Database Scheme
# host=host, # Database Host
# port=str(port), # Database Port
# user=user, # Database User
# password=password, # Database Password
# dbname=dbname, # Database Name
# )
#
# # load data
# docs = db.load_data(query=query)
# store_docs(name, docs)
# except ValueError as e:
# typer.secho(f"Failed to load database from provided information.\n{e}", fg=typer.colors.RED)
#
@app.command("vector-database")
@ -311,58 +193,35 @@ def load_vector_database(
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None,
):
"""Load pre-computed embeddings into MemGPT from a database."""
if user_id is None:
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
try:
from sqlalchemy import create_engine, select, MetaData, Table, Inspector
from pgvector.sqlalchemy import Vector
# connect to db table
engine = create_engine(uri)
metadata = MetaData()
# Create an inspector to inspect the database
inspector = Inspector.from_engine(engine)
table_names = inspector.get_table_names()
assert table_name in table_names, f"Table {table_name} not found in database: tables that exist {table_names}."
table = Table(table_name, metadata, autoload_with=engine)
config = MemGPTConfig.load()
# Prepare a select statement
select_statement = select(
table.c[text_column], table.c[embedding_column].cast(Vector(config.default_embedding_config.embedding_dim))
)
# Execute the query and fetch the results
with engine.connect() as connection:
result = connection.execute(select_statement).fetchall()
# Convert to a list of tuples (text, embedding)
passages = []
for text, embedding in result:
# assume that embeddings are the same model as in config
passages.append(
Passage(
text=text,
embedding=embedding,
user_id=user_id,
connector = VectorDBConnector(
uri=uri,
table_name=table_name,
text_column=text_column,
embedding_column=embedding_column,
embedding_dim=config.default_embedding_config.embedding_dim,
embedding_model=config.default_embedding_config.embedding_model,
)
)
assert config.default_embedding_config.embedding_dim == len(
embedding
), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(embedding)}"
# create storage connector
config = MemGPTConfig.load()
if user_id is None:
if not user_id:
user_id = uuid.UUID(config.anon_clientid)
insert_passages_into_source(passages, name, user_id, config)
ms = MetadataStore(config)
source = Source(name=name, user_id=user_id)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
# TODO: also get document store
# ingest data into passage/document store
num_passages, num_documents = load_data(
connector=connector,
source=source,
embedding_config=config.default_embedding_config,
document_store=None,
passage_store=passage_storage,
chunk_size=1000,
)
print(f"Loaded {num_passages} passages and {num_documents} documents from {name}")
except ValueError as e:
typer.secho(f"Failed to load vector database from provided information.\n{e}", fg=typer.colors.RED)
typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED)
raise
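The rewritten loaders are now thin wrappers over load_data plus a connector. The directory loader can also be driven programmatically; argument names come from the signature above, values are illustrative:

from memgpt.cli.cli_load import load_directory

load_directory(
    name="my_docs",
    input_dir="/path/to/docs",
    input_files=[],
    recursive=False,
    extensions=".txt,.md,.pdf",
    user_id=None,  # defaults to config.anon_clientid
)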

View File: memgpt/data_sources/connectors.py

@ -0,0 +1,194 @@
from memgpt.data_types import Passage, Document, EmbeddingConfig, Source
from memgpt.utils import create_uuid_from_string
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.embeddings import embedding_model
from memgpt.data_types import Document, Passage
import uuid
from typing import List, Iterator, Dict, Tuple, Optional
from llama_index.core import Document as LlamaIndexDocument
class DataConnector:
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
pass
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
pass
def load_data(
connector: DataConnector,
source: Source,
embedding_config: EmbeddingConfig,
passage_store: StorageConnector,
document_store: Optional[StorageConnector] = None,
chunk_size: int = 1000,
):
"""Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
# embedding model
embed_model = embedding_model(embedding_config)
# insert passages/documents
passages = []
passage_count = 0
document_count = 0
for document_text, document_metadata in connector.generate_documents():
# insert document into storage
document = Document(
id=create_uuid_from_string(f"{str(source.id)}_{document_text}"),
text=document_text,
metadata=document_metadata,
data_source=source.name,
user_id=source.user_id,
)
document_count += 1
if document_store:
document_store.insert(document)
# generate passages
for passage_text, passage_metadata in connector.generate_passages([document]):
print("passage", passage_text, passage_metadata)
embedding = embed_model.get_text_embedding(passage_text)
passage = Passage(
id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
text=passage_text,
doc_id=document.id,
metadata_=passage_metadata,
user_id=source.user_id,
data_source=source.name,
embedding_dim=embedding_config.embedding_dim,
embedding_model=embedding_config.embedding_model,
embedding=embedding,
)
passages.append(passage)
if len(passages) >= chunk_size:
# insert passages into passage store
passage_store.insert_many(passages)
passage_count += len(passages)
passages = []
if len(passages) > 0:
# insert passages into passage store
passage_store.insert_many(passages)
passage_count += len(passages)
return passage_count, document_count
class DirectoryConnector(DataConnector):
def __init__(self, input_files: List[str] = None, input_directory: str = None, recursive: bool = False, extensions: List[str] = None):
self.connector_type = "directory"
self.input_files = input_files
self.input_directory = input_directory
self.recursive = recursive
self.extensions = extensions
if self.recursive:
assert self.input_directory is not None, "Must provide input directory if recursive is True."
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
from llama_index.core import SimpleDirectoryReader
if self.input_directory is not None:
reader = SimpleDirectoryReader(
input_dir=self.input_directory,
recursive=self.recursive,
required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
)
else:
assert self.input_files is not None, "Must provide input files if input_dir is None"
reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])
llama_index_docs = reader.load_data()
docs = []
for llama_index_doc in llama_index_docs:
# TODO: add additional metadata?
# doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
# docs.append(doc)
yield llama_index_doc.text, llama_index_doc.metadata
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
# use llama index to run embeddings code
from llama_index.core.node_parser import SentenceSplitter
parser = SentenceSplitter(chunk_size=chunk_size)
for document in documents:
llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata)]
nodes = parser.get_nodes_from_documents(llama_index_docs)
for node in nodes:
# passage = Passage(
# text=node.text,
# doc_id=document.id,
# )
yield node.text, None
class WebConnector(DataConnector):
# TODO
def __init__(self):
pass
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
pass
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
pass
class VectorDBConnector(DataConnector):
# NOTE: this class has not been properly tested, so is unlikely to work
# TODO: allow loading multiple tables (1:1 mapping between Document and Table)
def __init__(
self,
name: str,
uri: str,
table_name: str,
text_column: str,
embedding_column: str,
embedding_dim: int,
):
self.name = name
self.uri = uri
self.table_name = table_name
self.text_column = text_column
self.embedding_column = embedding_column
self.embedding_dim = embedding_dim
# connect to db table
from sqlalchemy import create_engine
self.engine = create_engine(uri)
def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
yield self.table_name, None
def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
from sqlalchemy import select, MetaData, Table, Inspector
from pgvector.sqlalchemy import Vector
metadata = MetaData()
# Create an inspector to inspect the database
inspector = Inspector.from_engine(self.engine)
table_names = inspector.get_table_names()
assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
table = Table(self.table_name, metadata, autoload_with=self.engine)
# Prepare a select statement
select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
# Execute the query and fetch the results
# TODO: paginate results
with self.engine.connect() as connection:
result = connection.execute(select_statement).fetchall()
for text, embedding in result:
# assume that embeddings are the same model as in config
# TODO: don't re-compute embedding
yield text, {"embedding": embedding}
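The DataConnector interface above is the extension point this refactor introduces: any object whose two generators yield (text, metadata) tuples can feed load_data. A minimal hypothetical connector over in-memory strings (not part of this commit):

from typing import Dict, Iterator, List, Tuple

from memgpt.data_sources.connectors import DataConnector
from memgpt.data_types import Document


class StringConnector(DataConnector):
    """Hypothetical connector that serves a list of raw strings."""

    def __init__(self, strings: List[str]):
        self.strings = strings

    def generate_documents(self) -> Iterator[Tuple[str, Dict]]:
        for s in self.strings:
            yield s, {"connector": "string"}

    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
        # no chunking: treat each document as a single passage
        for document in documents:
            yield document.text, {}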

View File: memgpt/data_types.py

@ -11,6 +11,9 @@ from memgpt.utils import get_local_time, format_datetime, get_utc_time, create_u
from memgpt.models import chat_completion_response
from memgpt.utils import get_human_text, get_persona_text, printd
from pydantic import BaseModel, Field, Json
@ -274,7 +277,7 @@ class Message(Record):
class Document(Record):
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None):
def __init__(self, user_id: uuid.UUID, text: str, data_source: str, id: Optional[uuid.UUID] = None, metadata: Optional[Dict] = {}):
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
self.id = create_uuid_from_string("".join([text, str(user_id)]))
@ -284,6 +287,7 @@ class Document(Record):
self.user_id = user_id
self.text = text
self.data_source = data_source
self.metadata = metadata
# TODO: add optional embedding?
@ -295,8 +299,8 @@ class Passage(Record):
def __init__(
self,
user_id: uuid.UUID,
text: str,
user_id: Optional[uuid.UUID] = None,
agent_id: Optional[uuid.UUID] = None, # set if contained in agent memory
embedding: Optional[np.ndarray] = None,
embedding_dim: Optional[int] = None,
@ -308,7 +312,11 @@ class Passage(Record):
):
if id is None:
# by default, generate ID as a hash of the text (avoid duplicates)
# TODO: use source-id instead?
if agent_id:
self.id = create_uuid_from_string("".join([text, str(agent_id), str(user_id)]))
else:
self.id = create_uuid_from_string("".join([text, str(user_id)]))
else:
self.id = id
super().__init__(self.id)
@ -334,6 +342,7 @@ class Passage(Record):
assert len(self.embedding) == MAX_EMBEDDING_DIM, f"Embedding must be of length {MAX_EMBEDDING_DIM}"
assert isinstance(self.user_id, uuid.UUID), f"UUID {self.user_id} must be a UUID type"
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type"
assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type"
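Passage and Document IDs are now derived deterministically from their content via create_uuid_from_string, so re-ingesting identical text produces the same ID and overwrites rather than duplicates. The real helper lives in memgpt.utils and may hash differently; any stable name-based UUID gives the same property, e.g.:

import uuid

def create_uuid_from_string(val: str) -> uuid.UUID:
    # sketch only: the ID must be a pure function of the input string
    return uuid.uuid5(uuid.NAMESPACE_DNS, val)

a = create_uuid_from_string("my_docs_agent-1_some passage text")
b = create_uuid_from_string("my_docs_agent-1_some passage text")
assert a == b  # identical input -> identical ID -> no duplicate rows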

View File: memgpt/embeddings.py

@ -1,6 +1,6 @@
import typer
import uuid
from typing import Optional, List
from typing import Optional, List, Any
import os
import numpy as np
@ -9,13 +9,26 @@ from memgpt.data_types import EmbeddingConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import MAX_EMBEDDING_DIM, EMBEDDING_TO_TOKENIZER_MAP, EMBEDDING_TO_TOKENIZER_DEFAULT
from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding
from llama_index.bridge.pydantic import PrivateAttr
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.huggingface_utils import format_text
# from llama_index.core.base.embeddings import BaseEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core import Document as LlamaIndexDocument
# from llama_index.core.base.embeddings import BaseEmbedding
# from llama_index.core.embeddings import BaseEmbedding
# from llama_index.core.base.embeddings.base import BaseEmbedding
# from llama_index.bridge.pydantic import PrivateAttr
# from llama_index.embeddings.base import BaseEmbedding
# from llama_index.embeddings.huggingface_utils import format_text
import tiktoken
def parse_and_chunk_text(text: str, chunk_size: int) -> List[str]:
parser = SentenceSplitter(chunk_size=chunk_size)
llama_index_docs = [LlamaIndexDocument(text=text)]
nodes = parser.get_nodes_from_documents(llama_index_docs)
return [n.text for n in nodes]
def truncate_text(text: str, max_length: int, encoding) -> str:
# truncate the text based on max_length and encoding
encoded_text = encoding.encode(text)[:max_length]
@ -53,15 +66,15 @@ def check_and_split_text(text: str, embedding_model: str) -> List[str]:
return [text]
class EmbeddingEndpoint(BaseEmbedding):
class EmbeddingEndpoint:
"""Implementation for OpenAI compatible endpoint"""
""" Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """
# """ Based off llama index https://github.com/run-llama/llama_index/blob/a98bdb8ecee513dc2e880f56674e7fd157d1dc3a/llama_index/embeddings/text_embeddings_inference.py """
_user: str = PrivateAttr()
_timeout: float = PrivateAttr()
_base_url: str = PrivateAttr()
# _user: str = PrivateAttr()
# _timeout: float = PrivateAttr()
# _base_url: str = PrivateAttr()
def __init__(
self,
@ -69,21 +82,16 @@ class EmbeddingEndpoint(BaseEmbedding):
base_url: str,
user: str,
timeout: float = 60.0,
**kwargs: Any,
):
if not is_valid_url(base_url):
raise ValueError(
f"Embeddings endpoint was provided an invalid URL (set to: '{base_url}'). Make sure embedding_endpoint is set correctly in your MemGPT config."
)
self.model_name = model
self._user = user
self._base_url = base_url
self._timeout = timeout
super().__init__(
model_name=model,
)
@classmethod
def class_name(cls) -> str:
return "EmbeddingEndpoint"
def _call_api(self, text: str) -> List[float]:
if not is_valid_url(self._base_url):
@ -120,59 +128,8 @@ class EmbeddingEndpoint(BaseEmbedding):
return embedding
async def _acall_api(self, text: str) -> List[float]:
if not is_valid_url(self._base_url):
raise ValueError(
f"Embeddings endpoint does not have a valid URL (set to: '{self._base_url}'). Make sure embedding_endpoint is set correctly in your MemGPT config."
)
import httpx
headers = {"Content-Type": "application/json"}
json_data = {"input": text, "model": self.model_name, "user": self._user}
async with httpx.AsyncClient() as client:
response = await client.post(
f"{self._base_url}/embeddings",
headers=headers,
json=json_data,
timeout=self._timeout,
)
response_json = response.json()
if isinstance(response_json, list):
# embedding directly in response
embedding = response_json
elif isinstance(response_json, dict):
# TEI embedding packaged inside openai-style response
try:
embedding = response_json["data"][0]["embedding"]
except (KeyError, IndexError):
raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")
else:
# unknown response, can't parse
raise TypeError(f"Got back an unexpected payload from text embedding function, response=\n{response_json}")
return embedding
def _get_query_embedding(self, query: str) -> list[float]:
"""get query embedding."""
embedding = self._call_api(query)
return embedding
def _get_text_embedding(self, text: str) -> list[float]:
"""get text embedding."""
embedding = self._call_api(text)
return embedding
def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
embeddings = [self._get_text_embedding(text) for text in texts]
return embeddings
async def _aget_query_embedding(self, query: str) -> List[float]:
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> List[float]:
return self._get_text_embedding(text)
def get_text_embedding(self, text: str) -> List[float]:
return self._call_api(text)
def default_embedding_model():
@ -202,10 +159,14 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
credentials = MemGPTCredentials.load()
if endpoint_type == "openai":
from llama_index.embeddings.openai import OpenAIEmbedding
additional_kwargs = {"user_id": user_id} if user_id else {}
model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=credentials.openai_key, additional_kwargs=additional_kwargs)
return model
elif endpoint_type == "azure":
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
# https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings
model = "text-embedding-ada-002"
deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model
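With the BaseEmbedding subclassing stripped out, EmbeddingEndpoint is a plain class whose public surface is get_text_embedding, backed by a POST to an OpenAI-compatible /embeddings route. A usage sketch; the URL and model name are placeholders:

endpoint = EmbeddingEndpoint(
    model="text-embedding-model",
    base_url="http://localhost:8080",  # any OpenAI-compatible embeddings server
    user="anon",
)
vector = endpoint.get_text_embedding("hello world")  # -> List[float]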

View File: memgpt/main.py

@ -11,15 +11,16 @@ from memgpt.constants import FUNC_FAILED_HEARTBEAT_MESSAGE, JSON_ENSURE_ASCII, J
console = Console()
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.config import MemGPTConfig
import memgpt.agent as agent
import memgpt.system as system
import memgpt.errors as errors
from memgpt.cli.cli import run, attach, version, server, open_folder, quickstart, migrate, delete_agent
from memgpt.cli.cli import run, version, server, open_folder, quickstart, migrate, delete_agent
from memgpt.cli.cli_config import configure, list, add, delete
from memgpt.cli.cli_load import app as load_app
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
# import benchmark
from memgpt.benchmark.benchmark import bench
@ -27,7 +28,6 @@ from memgpt.benchmark.benchmark import bench
app = typer.Typer(pretty_exceptions_enable=False)
app.command(name="run")(run)
app.command(name="version")(version)
app.command(name="attach")(attach)
app.command(name="configure")(configure)
app.command(name="list")(list)
app.command(name="add")(add)
@ -100,11 +100,11 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
# updated agent save functions
if user_input.lower() == "/exit":
# memgpt_agent.save()
save_agent(memgpt_agent, ms)
agent.save_agent(memgpt_agent, ms)
break
elif user_input.lower() == "/save" or user_input.lower() == "/savechat":
# memgpt_agent.save()
save_agent(memgpt_agent, ms)
agent.save_agent(memgpt_agent, ms)
continue
elif user_input.lower() == "/attach":
# TODO: check if agent already has it
@ -143,7 +143,11 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore,
data_source = questionary.select("Select data source", choices=valid_options).ask()
# attach new data
attach(memgpt_agent.agent_state.name, data_source)
# attach(memgpt_agent.agent_state.name, data_source)
source_connector = StorageConnector.get_storage_connector(
TableType.PASSAGES, config, user_id=memgpt_agent.agent_state.user_id
)
memgpt_agent.attach_source(data_source, source_connector, ms)
continue

View File: memgpt/memory.py

@ -7,9 +7,10 @@ from memgpt.utils import get_local_time, printd, count_tokens, validate_date_for
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from memgpt.llm_api_tools import create
from memgpt.data_types import Message, Passage, AgentState
from memgpt.embeddings import embedding_model, query_embedding
from llama_index import Document
from llama_index.node_parser import SimpleNodeParser
from memgpt.embeddings import embedding_model, query_embedding, parse_and_chunk_text
# from llama_index import Document
# from llama_index.node_parser import SimpleNodeParser
class CoreMemory(object):
@ -402,12 +403,9 @@ class EmbeddingArchivalMemory(ArchivalMemory):
try:
passages = []
# create parser
parser = SimpleNodeParser.from_defaults(chunk_size=self.embedding_chunk_size)
# breakup string into passages
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
embedding = self.embed_model.get_text_embedding(node.text)
for text in parse_and_chunk_text(memory_string, self.embedding_chunk_size):
embedding = self.embed_model.get_text_embedding(text)
# fixing weird bug where type returned isn't a list, but instead is an object
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
if isinstance(embedding, dict):
@ -418,7 +416,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passages.append(self.create_passage(node.text, embedding))
passages.append(self.create_passage(text, embedding))
# insert passages
self.storage.insert_many(passages)

View File: memgpt/metadata.py

@ -10,7 +10,6 @@ from memgpt.constants import DEFAULT_HUMAN, DEFAULT_MEMGPT_MODEL, DEFAULT_PERSON
from memgpt.utils import get_local_time, enforce_types
from memgpt.data_types import AgentState, Source, User, LLMConfig, EmbeddingConfig, Token, Preset
from memgpt.config import MemGPTConfig
from memgpt.agent import Agent
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, Boolean
from sqlalchemy import func
@ -379,11 +378,15 @@ class MetadataStore:
session.commit()
@enforce_types
def create_source(self, source: Source):
def create_source(self, source: Source, exists_ok=False):
# make sure source.name does not already exist for user
with self.session_maker() as session:
if session.query(SourceModel).filter(SourceModel.name == source.name).filter(SourceModel.user_id == source.user_id).count() > 0:
if not exists_ok:
raise ValueError(f"Source with name {source.name} already exists")
else:
session.merge(SourceModel(**vars(source)))  # Session has no update(); merge() writes over the existing row
else:
session.add(SourceModel(**vars(source)))
session.commit()
@ -596,15 +599,3 @@ class MetadataStore:
AgentSourceMappingModel.agent_id == agent_id, AgentSourceMappingModel.source_id == source_id
).delete()
session.commit()
def save_agent(agent: Agent, ms: MetadataStore):
"""Save agent to metadata store"""
agent.update_state()
agent_state = agent.agent_state
if ms.get_agent(agent_id=agent_state.id):
ms.update_agent(agent_state)
else:
ms.create_agent(agent_state)
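The new exists_ok flag makes create_source idempotent for repeated loads into the same source name; roughly:

ms = MetadataStore(config)
source = Source(name="my_docs", user_id=user_id)
ms.create_source(source)                  # first call: inserts a new row
ms.create_source(source, exists_ok=True)  # repeat call: updates in place
# ms.create_source(source)                # repeat without exists_ok: ValueError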

View File: memgpt/migrate.py

@ -15,14 +15,10 @@ import typer
from tqdm import tqdm
import questionary
from llama_index import (
StorageContext,
load_index_from_storage,
)
from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
from memgpt.data_types import AgentState, User, Passage, Source, Message
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
from memgpt.utils import (
MEMGPT_DIR,
version_less_than,
@ -159,8 +155,18 @@ def migrate_source(source_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Me
source = Source(user_id=user.id, name=source_name)
ms.create_source(source)
try:
try:
nodes = pickle.load(open(source_path, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
# cannot load source at all, so throw error
raise ValueError(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
else:
raise e
passages = []
for node in nodes:
# print(len(node.embedding))
@ -485,7 +491,17 @@ def migrate_agent(agent_name: str, data_dir: str = MEMGPT_DIR, ms: Optional[Meta
# 5. Insert into archival
if os.path.exists(archival_filename):
try:
nodes = pickle.load(open(archival_filename, "rb"))
except ModuleNotFoundError as e:
if "No module named 'llama_index.schema'" in str(e):
print(
"Failed to load archival memory due thanks to llama_index's breaking changes. Please downgrade to MemGPT version 0.3.3 or earlier to migrate this agent."
)
nodes = []
else:
raise e
passages = []
failed_inserts = []
for node in nodes:

View File: memgpt/presets/presets.py

@ -23,6 +23,7 @@ def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
print("PRESET", preset_name, user_id)
if ms.get_preset(user_id=user_id, preset_name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue

View File: memgpt/server/server.py

@ -11,21 +11,21 @@ from fastapi import HTTPException
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.constants import JSON_LOADS_STRICT, JSON_ENSURE_ASCII
from memgpt.agent import Agent
from memgpt.agent import Agent, save_agent
import memgpt.system as system
import memgpt.constants as constants
from memgpt.cli.cli import attach
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options
# from memgpt.agent_store.storage import StorageConnector
from memgpt.metadata import MetadataStore, save_agent
from memgpt.metadata import MetadataStore
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.data_types import (
User,
Passage,
AgentState,
LLMConfig,
EmbeddingConfig,
@ -394,7 +394,9 @@ class SyncServer(LockingServer):
except:
raise ValueError(command)
attach(agent_name=memgpt_agent.agent_state.name, data_source=data_source, user_id=user_id)
# attach data to agent from source
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
memgpt_agent.attach_source(data_source, source_connector, self.ms)
elif command.lower() == "dump" or command.lower().startswith("dump "):
# Check if there's an additional argument that's an integer
@ -642,6 +644,10 @@ class SyncServer(LockingServer):
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)
def initialize_default_presets(self, user_id: uuid.UUID):
"""Add default preset options into the metadata store"""
presets.add_default_presets(user_id, self.ms)
def create_preset(self, preset: Preset):
"""Create a new preset using a config"""
if self.ms.get_user(user_id=preset.user_id) is None:
@ -1009,3 +1015,16 @@ class SyncServer(LockingServer):
"""Create a new API key for a user"""
token = self.ms.create_api_key(user_id=user_id)
return token
def create_source(self, name: str): # TODO: add other fields
# create a data source
pass
def load_passages(self, source_id: uuid.UUID, passages: List[Passage]):
# load a list of passages into a data source
pass
def attach_source_to_agent(self, agent_id: uuid.UUID, source_id: uuid.UUID):
# attach a data source to an agent
# TODO: insert passages into agent archival memory
pass

poetry.lock (generated, 1616 changed lines)

File diff suppressed because it is too large.

View File: pyproject.toml

@ -22,7 +22,6 @@ pytz = "^2023.3.post1"
tqdm = "^4.66.1"
black = { version = "^23.10.1", optional = true }
pytest = { version = "^7.4.4", optional = true }
llama-index = "^0.9.13"
setuptools = "^68.2.2"
datasets = { version = "^2.14.6", optional = true}
prettytable = "^3.9.0"
@ -53,6 +52,9 @@ docx2txt = "^0.8"
sqlalchemy = "^2.0.25"
pexpect = {version = "^4.9.0", optional = true}
pyright = {version = "^1.1.347", optional = true}
llama-index = "^0.10.6"
llama-index-embeddings-openai = "^0.1.1"
python-box = "^7.1.1"
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]

View File: tests/test_load.py

@ -6,8 +6,9 @@ from sqlalchemy.ext.declarative import declarative_base
# import memgpt
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
from memgpt.cli.cli import attach
from memgpt.cli.cli_load import load_directory
# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
@ -153,28 +154,29 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
sources = ms.list_sources(user_id=user_id)
print("All sources", [s.name for s in sources])
# test loading into an agent
# create agent
agent_id = agent.id
# create storage connector
print("Creating agent archival storage connector...")
conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
print("Deleting agent archival table...")
conn.delete_table()
conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
# TODO: add back once agent attachment fully supported from server
## test loading into an agent
## create agent
# agent_id = agent.id
## create storage connector
# print("Creating agent archival storage connector...")
# conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
# print("Deleting agent archival table...")
# conn.delete_table()
# conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config=config, user_id=user_id, agent_id=agent_id)
# assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
# attach data
print("Attaching data...")
attach(agent_name=agent.name, data_source=name, user_id=user_id)
## attach data
# print("Attaching data...")
# attach(agent_name=agent.name, data_source=name, user_id=user_id)
# test to see if contained in storage
assert len(passages) == conn.size()
assert len(passages) == len(conn.get_all({"data_source": name}))
## test to see if contained in storage
# assert len(passages) == conn.size()
# assert len(passages) == len(conn.get_all({"data_source": name}))
# test: delete source
passages_conn.delete({"data_source": name})
assert len(passages_conn.get_all({"data_source": name})) == 0
## test: delete source
# passages_conn.delete({"data_source": name})
# assert len(passages_conn.get_all({"data_source": name})) == 0
# cleanup
ms.delete_user(user.id)

View File: tests/test_migrate.py

@ -18,10 +18,12 @@ def test_migrate_0211():
# os.environ["MEMGPT_CONFIG_PATH"] = os.path.join(data_dir, "config")
# print(f"MEMGPT_CONFIG_PATH={os.environ['MEMGPT_CONFIG_PATH']}")
try:
agent_res = migrate_all_agents(tmp_dir)
agent_res = migrate_all_agents(tmp_dir, debug=True)
assert len(agent_res["failed_migrations"]) == 0, f"Failed migrations: {agent_res}"
source_res = migrate_all_sources(tmp_dir)
assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}"
# NOTE: source tests had to be removed since it is no longer possible to migrate llama index vector indices
# source_res = migrate_all_sources(tmp_dir)
# assert len(source_res["failed_migrations"]) == 0, f"Failed migrations: {source_res}"
# TODO: assert everything is in the DB
@ -40,11 +42,11 @@ def test_migrate_0211():
messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000)
assert len(messages) > 0
for source_name in source_res["migration_candidates"]:
if source_name not in source_res["failed_migrations"]:
# assert source data exists
source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"])
assert source is not None
# for source_name in source_res["migration_candidates"]:
# if source_name not in source_res["failed_migrations"]:
# # assert source data exists
# source = server.ms.get_source(source_name=source_name, user_id=source_res["user_id"])
# assert source is not None
except Exception as e:
raise e
finally:

View File

@ -5,17 +5,22 @@ import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET
from memgpt.config import MemGPTConfig
def test_list_messages():
client = TestClient(app)
test_user_id = uuid.uuid4()
test_user_id = uuid.UUID(MemGPTConfig.load().anon_clientid)
# create user
server = SyncServer()
if not server.get_user(test_user_id):
server.create_user({"id": test_user_id})
# write default presets to DB
server.initialize_default_presets(test_user_id)
# test: create agent
request_body = {
"user_id": str(test_user_id),