From cf927b4e863c4523e58c1cc59ccb7abe28c62516 Mon Sep 17 00:00:00 2001
From: Vivian Fang
Date: Wed, 18 Oct 2023 19:30:15 -0700
Subject: [PATCH] support generating embeddings on the fly

---
 README.md        |  23 +++++++-
 interface.py     |   3 ++
 main.py          |   6 +++
 memgpt/utils.py  | 136 ++++++++++++++++++++++++++++++++++++-----------
 requirements.txt |   1 +
 5 files changed, 138 insertions(+), 31 deletions(-)

diff --git a/README.md b/README.md
index 130e9da72..52049bb7d 100644
--- a/README.md
+++ b/README.md
@@ -107,8 +107,10 @@ python main.py --human me.txt
   enables debugging output
 --archival_storage_faiss_path=<folder with .index and .json>
   load in document database (backed by FAISS index)
---archival_storage_files="<glob pattern>"
+--archival_storage_files="<glob pattern>"
   pre-load files into archival memory
+--archival_storage_files_compute_embeddings="<glob pattern>"
+  pre-load files into archival memory and also compute embeddings for embedding search
 --archival_storage_sqldb=<path to database>
   load in SQL database
 ```
@@ -181,6 +183,25 @@ To run our example where you can search over the SEC 10-K filings of Uber, Lyft,
 ```
 
 If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
+
+#### Enhance with embeddings search
+In the root `MemGPT` directory, run
+```bash
+python3 main.py --archival_storage_files_compute_embeddings="<glob pattern>" --persona=memgpt_doc --human=basic
+```
+
+This will generate embeddings, store them in a FAISS index, write the index to a directory, and then print:
+```
+To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings=<glob pattern> with
+  --archival_storage_faiss_path=<directory> (if your files haven't changed).
+```
+
+If you want to reuse these embeddings, run
+```bash
+python3 main.py --archival_storage_faiss_path="<directory>" --persona=memgpt_doc --human=basic
+```
+
 [demo video: Talking to LlamaIndex API Docs]
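The directory written by the new flag contains three artifacts referenced throughout this patch: `embeddings.json`, `all_docs.jsonl`, and a FAISS index `all_docs.index`. A minimal sketch for inspecting them (the directory name below is a made-up example of the timestamped name the code generates):

```python
# Inspect the artifacts produced by --archival_storage_files_compute_embeddings.
# (Sketch only; the save_dir name is a hypothetical example.)
import json
import faiss

save_dir = "archival_index_from_files_2023-10-18_07_30_15_PM"
index = faiss.read_index(f"{save_dir}/all_docs.index")
with open(f"{save_dir}/embeddings.json") as f:
    embeddings = json.load(f)

print(index.ntotal, len(embeddings))  # one stored vector per text chunk
print(index.d)                        # 1536, ada-002's embedding width
```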
diff --git a/interface.py b/interface.py
index b8729b1e9..af5cfc466 100644
--- a/interface.py
+++ b/interface.py
@@ -10,6 +10,9 @@ init(autoreset=True)
 # DEBUG = True  # puts full message outputs in the terminal
 DEBUG = False  # only dumps important messages in the terminal
 
+def important_message(msg):
+    print(f'{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}')
+
 async def internal_monologue(msg):
     # ANSI escape code for italic is '\x1B[3m'
     print(f'\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}')
diff --git a/main.py b/main.py
index 797cdadba..76cc6ed9d 100644
--- a/main.py
+++ b/main.py
@@ -27,6 +27,7 @@ flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to
 flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
 flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
 flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
+flags.DEFINE_string("archival_storage_files_compute_embeddings", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them")
 flags.DEFINE_string("archival_storage_sqldb", default="", required=False, help="Specify SQL database to pre-load into archival memory")
@@ -54,6 +55,11 @@ async def main():
         archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
         print(f"Preloaded {len(archival_database)} chunks into archival memory.")
         persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
+    elif FLAGS.archival_storage_files_compute_embeddings:
+        faiss_save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(FLAGS.archival_storage_files_compute_embeddings)
+        interface.important_message(f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed).")
+        index, archival_database = utils.prepare_archival_index(faiss_save_dir)
+        persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
     else:
         persistence_manager = InMemoryStateManager()
     memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
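After computing embeddings, the new branch immediately reloads the saved index through `prepare_archival_index` and hands it to `InMemoryStateManagerWithFaiss`, the same manager that already backs `--archival_storage_faiss_path`. Conceptually, the lookup this enables is a nearest-neighbor search along these lines (a hedged sketch, not the actual class, which lives elsewhere in the codebase):

```python
# Conceptual sketch of a FAISS-backed archival search; the real
# InMemoryStateManagerWithFaiss implementation may differ.
import numpy as np

def archival_search(index, archival_database, query_embedding, k=5):
    """Return the k chunks whose embeddings are closest to the query."""
    query = np.array([query_embedding], dtype='float32')  # shape (1, 1536) for ada-002
    distances, ids = index.search(query, k)  # exact L2 search over an IndexFlatL2
    return [archival_database[i]['content'] for i in ids[0] if i != -1]
```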
diff --git a/memgpt/utils.py b/memgpt/utils.py
index 76905d398..37746a153 100644
--- a/memgpt/utils.py
+++ b/memgpt/utils.py
@@ -9,6 +9,8 @@ import faiss
 import tiktoken
 import glob
 import sqlite3
+from tqdm import tqdm
+from memgpt.openai_tools import async_get_embedding_with_backoff
 
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
@@ -97,41 +99,54 @@ def read_in_chunks(file_object, chunk_size):
 def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
     encoding = tiktoken.encoding_for_model(model)
     files = glob.glob(glob_pattern)
+    return chunk_files(files, tkns_per_chunk, model)
+
+def total_bytes(pattern):
+    total = 0
+    for filename in glob.glob(pattern):
+        if os.path.isfile(filename):  # ensure it's a file and not a directory
+            total += os.path.getsize(filename)
+    return total
+
+def chunk_file(file, tkns_per_chunk=300, model='gpt-4'):
+    encoding = tiktoken.encoding_for_model(model)
+    with open(file, 'r') as f:
+        lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
+    curr_chunk = []
+    curr_token_ct = 0
+    for i, line in enumerate(lines):
+        line = line.rstrip()
+        line = line.lstrip()
+        try:
+            line_token_ct = len(encoding.encode(line))
+        except Exception as e:
+            line_token_ct = len(line.split(' ')) / .75
+            print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
+            print(e)
+        if line_token_ct > tkns_per_chunk:
+            if len(curr_chunk) > 0:
+                yield ''.join(curr_chunk)
+                curr_chunk = []
+                curr_token_ct = 0
+            yield line[:3200]
+            continue
+        curr_token_ct += line_token_ct
+        curr_chunk.append(line)
+        if curr_token_ct > tkns_per_chunk:
+            yield ''.join(curr_chunk)
+            curr_chunk = []
+            curr_token_ct = 0
+
+    if len(curr_chunk) > 0:
+        yield ''.join(curr_chunk)
+
+def chunk_files(files, tkns_per_chunk=300, model='gpt-4'):
     archival_database = []
     for file in files:
         timestamp = os.path.getmtime(file)
         formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
-        with open(file, 'r') as f:
-            lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
-        chunks = []
-        curr_chunk = []
-        curr_token_ct = 0
-        for line in lines:
-            line = line.rstrip()
-            line = line.lstrip()
-            try:
-                line_token_ct = len(encoding.encode(line))
-            except Exception as e:
-                line_token_ct = len(line.split(' ')) / .75
-                print(f"Could not encode line {line}, estimating it to be {line_token_ct} tokens")
-            if line_token_ct > tkns_per_chunk:
-                if len(curr_chunk) > 0:
-                    chunks.append(''.join(curr_chunk))
-                    curr_chunk = []
-                    curr_token_ct = 0
-                chunks.append(line[:3200])
-                continue
-            curr_token_ct += line_token_ct
-            curr_chunk.append(line)
-            if curr_token_ct > tkns_per_chunk:
-                chunks.append(''.join(curr_chunk))
-                curr_chunk = []
-                curr_token_ct = 0
-
-        if len(curr_chunk) > 0:
-            chunks.append(''.join(curr_chunk))
         file_stem = file.split('/')[-1]
+        chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
         for i, chunk in enumerate(chunks):
             archival_database.append({
                 'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
@@ -139,6 +154,67 @@ def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='g
             })
     return archival_database
 
+def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
+    ret = []
+    for file in files:
+        file_stem = file.split('/')[-1]
+        curr_file = []
+        for chunk in chunk_file(file, tkns_per_chunk, model):
+            curr_file.append({
+                'title': file_stem,
+                'text': chunk,
+            })
+        ret.append(curr_file)
+    return ret
+
+async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
+    files = sorted(glob.glob(glob_pattern))
+    save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
+    os.makedirs(save_dir, exist_ok=True)
+    total_tokens = total_bytes(glob_pattern) / 3
+    price_estimate = total_tokens / 1000 * .0001
+    confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
+    if confirm != 'y':
+        raise Exception("embeddings were not computed")
+
+    # chunk the files, make embeddings
+    archival_database = chunk_files(files, tkns_per_chunk, model)
+    embedding_data = []
+    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
+        try:
+            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
+        except Exception as e:
+            print(chunk)
+            raise e
+        embedding_data.append(embedding)
+    embeddings_file = os.path.join(save_dir, "embeddings.json")
+    with open(embeddings_file, 'w') as f:
+        print(f"Saving embeddings to {embeddings_file}")
+        json.dump(embedding_data, f)
+
+    # make all_docs.jsonl
+    archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
+    chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
+    with open(archival_storage_file, 'w') as f:
+        print(f"Saving archival storage with preloaded files to {archival_storage_file}")
+        for c in chunks_by_file:
+            json.dump(c, f)
+            f.write('\n')
+
+    # make the faiss index
+    index = faiss.IndexFlatL2(1536)
+    data = np.array(embedding_data).astype('float32')
+    try:
+        index.add(data)
+    except Exception as e:
+        print(data)
+        raise e
+    index_file = os.path.join(save_dir, "all_docs.index")
+    print(f"Saving faiss index {index_file}")
+    faiss.write_index(index, index_file)
+    return save_dir
+
 def read_database_as_list(database_name):
     result_list = []
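A quick way to sanity-check the new generator, run from the repo root with this patch applied (`README.md` stands in for any text file):

```python
# Smoke test for the new chunk_file generator: chunks should land near the
# 300-token budget (an oversized single line is truncated to 3200 chars).
import tiktoken
from memgpt.utils import chunk_file

enc = tiktoken.encoding_for_model('gpt-4')
for i, chunk in enumerate(chunk_file('README.md', tkns_per_chunk=300)):
    print(f"chunk {i}: {len(enc.encode(chunk))} tokens")
```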
diff --git a/requirements.txt b/requirements.txt
index 6a18f0225..6258b8563 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,3 +11,4 @@ pytz
 rich
 tiktoken
 timezonefinder
+tqdm
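The confirmation prompt's cost estimate converts bytes on disk to tokens at roughly 3 bytes per token, then applies ada-002's October 2023 price of $0.0001 per 1K tokens. A worked example with an assumed 30 MiB corpus:

```python
# Worked example of the price_estimate arithmetic from
# prepare_archival_index_from_files_compute_embeddings (corpus size assumed).
total = 30 * 1024 * 1024          # 30 MiB of input files
total_tokens = total / 3          # ~3 bytes per token heuristic
price_estimate = total_tokens / 1000 * .0001
print(f"This will cost ~${price_estimate:.2f}")  # -> This will cost ~$1.05
```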