support generating embeddings on the fly

Vivian Fang 2023-10-18 19:30:15 -07:00
parent 924be62eea
commit cf927b4e86
5 changed files with 138 additions and 31 deletions

View File

@@ -107,8 +107,10 @@ python main.py --human me.txt
  enables debugging output
--archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
  load in document database (backed by FAISS index)
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB>"
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
  pre-load files into archival memory
--archival_storage_files_compute_embeddings="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
  pre-load files into archival memory and also compute embeddings for embedding search
--archival_storage_sqldb=<SQLDB_PATH>
  load in SQL database
```
@@ -181,6 +183,25 @@ To run our example where you can search over the SEC 10-K filings of Uber, Lyft,
```
If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
#### Enhance with embeddings search
In the root `MemGPT` directory, run
```bash
python3 main.py --archival_storage_files_compute_embeddings="<GLOB_PATTERN>" --persona=memgpt_doc --human=basic
```
This will generate embeddings, store them in a FAISS index, write the index to a directory, and then print:
```
To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings=<GLOB_PATTERN> with
--archival_storage_faiss_path=<DIRECTORY_WITH_EMBEDDINGS> (if your files haven't changed).
```
If you want to reuse these embeddings, run
```bash
python3 main.py --archival_storage_faiss_path="<DIRECTORY_WITH_EMBEDDINGS>" --persona=memgpt_doc --human=basic
```
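For reference, the saved directory can also be queried directly, outside of MemGPT. The snippet below is only a minimal sketch: it assumes the `all_docs.index` / `all_docs.jsonl` layout written by `prepare_archival_index_from_files_compute_embeddings` in this commit, the pre-1.0 `openai` Python API (`openai.Embedding.create`), and a hypothetical output directory name.
```python
import json

import faiss
import numpy as np
import openai  # requires OPENAI_API_KEY in the environment

save_dir = "archival_index_from_files_..."  # hypothetical; use the directory printed by the command above

# load the FAISS index written by faiss.write_index(...)
index = faiss.read_index(f"{save_dir}/all_docs.index")

# all_docs.jsonl holds one JSON list of {'title', 'text'} chunks per file; flattened,
# the order should line up with the order the embeddings were added to the index
chunks = []
with open(f"{save_dir}/all_docs.jsonl") as f:
    for line in f:
        chunks.extend(json.loads(line))

query = "What risks did Uber list in its 10-K?"
resp = openai.Embedding.create(input=[query], model="text-embedding-ada-002")
query_vec = np.array([resp["data"][0]["embedding"]], dtype="float32")

distances, ids = index.search(query_vec, 3)  # 3 nearest chunks by L2 distance
for rank, idx in enumerate(ids[0]):
    print(rank, chunks[idx]["title"], chunks[idx]["text"][:100])
```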
</details>
<details>
<summary><h3>Talking to LlamaIndex API Docs</h3></summary>

View File

@@ -10,6 +10,9 @@ init(autoreset=True)
# DEBUG = True # puts full message outputs in the terminal
DEBUG = False # only dumps important messages in the terminal
def important_message(msg):
    print(f'{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}')
async def internal_monologue(msg):
    # ANSI escape code for italic is '\x1B[3m'
    print(f'\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}')

View File

@@ -27,6 +27,7 @@ flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
flags.DEFINE_string("archival_storage_files_compute_embeddings", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them")
flags.DEFINE_string("archival_storage_sqldb", default="", required=False, help="Specify SQL database to pre-load into archival memory")
@@ -54,6 +55,11 @@ async def main():
        archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
        print(f"Preloaded {len(archival_database)} chunks into archival memory.")
        persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
    elif FLAGS.archival_storage_files_compute_embeddings:
        faiss_save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(FLAGS.archival_storage_files_compute_embeddings)
        interface.important_message(f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed).")
        index, archival_database = utils.prepare_archival_index(faiss_save_dir)
        persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
    else:
        persistence_manager = InMemoryStateManager()
    memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)

View File

@@ -9,6 +9,8 @@ import faiss
import tiktoken
import glob
import sqlite3
from tqdm import tqdm
from memgpt.openai_tools import async_get_embedding_with_backoff
def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
@@ -97,41 +99,54 @@ def read_in_chunks(file_object, chunk_size):
def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
    encoding = tiktoken.encoding_for_model(model)
    files = glob.glob(glob_pattern)
    return chunk_files(files, tkns_per_chunk, model)
def total_bytes(pattern):
    total = 0
    for filename in glob.glob(pattern):
        if os.path.isfile(filename): # ensure it's a file and not a directory
            total += os.path.getsize(filename)
    return total
def chunk_file(file, tkns_per_chunk=300, model='gpt-4'):
    encoding = tiktoken.encoding_for_model(model)
    with open(file, 'r') as f:
        lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
        curr_chunk = []
        curr_token_ct = 0
        for i, line in enumerate(lines):
            line = line.rstrip()
            line = line.lstrip()
            try:
                line_token_ct = len(encoding.encode(line))
            except Exception as e:
                line_token_ct = len(line.split(' ')) / .75
                print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
                print(e)
            if line_token_ct > tkns_per_chunk:
                if len(curr_chunk) > 0:
                    yield ''.join(curr_chunk)
                    curr_chunk = []
                    curr_token_ct = 0
                yield line[:3200]
                continue
            curr_token_ct += line_token_ct
            curr_chunk.append(line)
            if curr_token_ct > tkns_per_chunk:
                yield ''.join(curr_chunk)
                curr_chunk = []
                curr_token_ct = 0
        if len(curr_chunk) > 0:
            yield ''.join(curr_chunk)
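# (Usage note, not part of this commit: chunk_file is a generator, so callers can stream
# ~tkns_per_chunk-token pieces of a file one at a time, e.g. `for chunk in chunk_file(path): ...`;
# chunk_files below simply materializes them into a list.)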
def chunk_files(files, tkns_per_chunk=300, model='gpt-4'):
    archival_database = []
    for file in files:
        timestamp = os.path.getmtime(file)
        formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
        with open(file, 'r') as f:
            lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
            chunks = []
            curr_chunk = []
            curr_token_ct = 0
            for line in lines:
                line = line.rstrip()
                line = line.lstrip()
                try:
                    line_token_ct = len(encoding.encode(line))
                except Exception as e:
                    line_token_ct = len(line.split(' ')) / .75
                    print(f"Could not encode line {line}, estimating it to be {line_token_ct} tokens")
                if line_token_ct > tkns_per_chunk:
                    if len(curr_chunk) > 0:
                        chunks.append(''.join(curr_chunk))
                        curr_chunk = []
                        curr_token_ct = 0
                    chunks.append(line[:3200])
                    continue
                curr_token_ct += line_token_ct
                curr_chunk.append(line)
                if curr_token_ct > tkns_per_chunk:
                    chunks.append(''.join(curr_chunk))
                    curr_chunk = []
                    curr_token_ct = 0
            if len(curr_chunk) > 0:
                chunks.append(''.join(curr_chunk))
        file_stem = file.split('/')[-1]
        chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
        for i, chunk in enumerate(chunks):
            archival_database.append({
                'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
@@ -139,6 +154,67 @@ def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='g
            })
    return archival_database
def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
    ret = []
    for file in files:
        file_stem = file.split('/')[-1]
        curr_file = []
        for chunk in chunk_file(file, tkns_per_chunk, model):
            curr_file.append({
                'title': file_stem,
                'text': chunk,
            })
        ret.append(curr_file)
    return ret
async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
    files = sorted(glob.glob(glob_pattern))
    save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
    os.makedirs(save_dir, exist_ok=True)
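    # rough cost estimate: assumes ~3 bytes of input per token, with
    # text-embedding-ada-002 billed at $0.0001 per 1K tokens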
    total_tokens = total_bytes(glob_pattern) / 3
    price_estimate = total_tokens / 1000 * .0001
    confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
    if confirm != 'y':
        raise Exception("embeddings were not computed")
    # chunk the files, make embeddings
    archival_database = chunk_files(files, tkns_per_chunk, model)
    embedding_data = []
    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
        # for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
        try:
            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
        except Exception as e:
            print(chunk)
            raise e
        embedding_data.append(embedding)
    embeddings_file = os.path.join(save_dir, "embeddings.json")
    with open(embeddings_file, 'w') as f:
        print(f"Saving embeddings to {embeddings_file}")
        json.dump(embedding_data, f)
    # make all_text.json
    archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
    chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
    with open(archival_storage_file, 'w') as f:
        print(f"Saving archival storage with preloaded files to {archival_storage_file}")
        for c in chunks_by_file:
            json.dump(c, f)
            f.write('\n')
    # make the faiss index
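    # 1536 is the dimensionality of text-embedding-ada-002 vectors; IndexFlatL2 does exact L2-distance search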
    index = faiss.IndexFlatL2(1536)
    data = np.array(embedding_data).astype('float32')
    try:
        index.add(data)
    except Exception as e:
        print(data)
        raise e
    index_file = os.path.join(save_dir, "all_docs.index")
    print(f"Saving faiss index {index_file}")
    faiss.write_index(index, index_file)
    return save_dir
def read_database_as_list(database_name):
    result_list = []

View File

@@ -11,3 +11,4 @@ pytz
rich
tiktoken
timezonefinder
tqdm