Mirror of https://github.com/cpacker/MemGPT.git (synced 2025-06-03 04:30:22 +00:00)

support generating embeddings on the fly

This commit is contained in:
parent 924be62eea
commit cf927b4e86

README.md (23 changed lines)
@@ -107,8 +107,10 @@ python main.py --human me.txt
enables debugging output
--archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
load in document database (backed by FAISS index)
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB>"
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
pre-load files into archival memory
--archival_storage_files_compute_embeddings="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
pre-load files into archival memory and also compute embeddings for embedding search
--archival_storage_sqldb=<SQLDB_PATH>
load in SQL database
```
@@ -181,6 +183,25 @@ To run our example where you can search over the SEC 10-K filings of Uber, Lyft,
```

If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).

#### Enhance with embeddings search
In the root `MemGPT` directory, run
```bash
python3 main.py --archival_storage_files_compute_embeddings="<GLOB_PATTERN>" --persona=memgpt_doc --human=basic
```

This will generate embeddings, add them to a FAISS index, write the index to a directory, and then output:
```
To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings=<GLOB_PATTERN> with
--archival_storage_faiss_path=<DIRECTORY_WITH_EMBEDDINGS> (if your files haven't changed).
```

If you want to reuse these embeddings, run
```bash
python3 main.py --archival_storage_faiss_path="<DIRECTORY_WITH_EMBEDDINGS>" --persona=memgpt_doc --human=basic
```
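
For reference, here is a minimal sketch of what the saved directory contains and how it could be read back by hand, assuming the layout produced by the new `utils.py` code in this commit; the `<TIMESTAMP>` placeholder stands in for the generated directory name:

```python
import json
import os

import faiss

save_dir = "archival_index_from_files_<TIMESTAMP>"  # the directory printed by the command above

# all_docs.index: a flat L2 FAISS index over 1536-dim text-embedding-ada-002 vectors
index = faiss.read_index(os.path.join(save_dir, "all_docs.index"))

# all_docs.jsonl: one JSON list of {'title', 'text'} chunks per source file
docs_by_file = []
with open(os.path.join(save_dir, "all_docs.jsonl"), "r") as f:
    for line in f:
        docs_by_file.append(json.loads(line))
```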

</details>
<details>
<summary><h3>Talking to LlamaIndex API Docs</h3></summary>

@@ -10,6 +10,9 @@ init(autoreset=True)
# DEBUG = True # puts full message outputs in the terminal
DEBUG = False # only dumps important messages in the terminal

def important_message(msg):
    print(f'{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}')

async def internal_monologue(msg):
    # ANSI escape code for italic is '\x1B[3m'
    print(f'\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}')

main.py (6 changed lines)
@@ -27,6 +27,7 @@ flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
flags.DEFINE_string("archival_storage_files_compute_embeddings", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them")
flags.DEFINE_string("archival_storage_sqldb", default="", required=False, help="Specify SQL database to pre-load into archival memory")


@@ -54,6 +55,11 @@ async def main():
        archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
        print(f"Preloaded {len(archival_database)} chunks into archival memory.")
        persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
    elif FLAGS.archival_storage_files_compute_embeddings:
        faiss_save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(FLAGS.archival_storage_files_compute_embeddings)
        interface.important_message(f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed).")
        index, archival_database = utils.prepare_archival_index(faiss_save_dir)
        persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
    else:
        persistence_manager = InMemoryStateManager()
    memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
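
The new `elif` branch above builds the index on the fly and hands it, together with the chunk list, to `InMemoryStateManagerWithFaiss`. As a rough sketch (not part of this commit), the same async helper could also be driven from a standalone script; the `build_index` wrapper and the glob path below are illustrative, and it assumes the `memgpt` package is importable with `OPENAI_API_KEY` set:

```python
import asyncio

from memgpt import utils

async def build_index():
    # The helper chunks the files, computes ada-002 embeddings, and writes a
    # FAISS index plus the chunked docs into a timestamped directory.
    save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(
        "memgpt/personas/examples/preload_archival/*.txt"
    )
    print(f"Index and chunked docs written to {save_dir}")

if __name__ == "__main__":
    asyncio.run(build_index())
```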

memgpt/utils.py (136 changed lines)
@@ -9,6 +9,8 @@ import faiss
import tiktoken
import glob
import sqlite3
from tqdm import tqdm
from memgpt.openai_tools import async_get_embedding_with_backoff

def count_tokens(s: str, model: str = "gpt-4") -> int:
    encoding = tiktoken.encoding_for_model(model)
@@ -97,41 +99,54 @@ def read_in_chunks(file_object, chunk_size)
def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
    encoding = tiktoken.encoding_for_model(model)
    files = glob.glob(glob_pattern)
    return chunk_files(files, tkns_per_chunk, model)

def total_bytes(pattern):
    total = 0
    for filename in glob.glob(pattern):
        if os.path.isfile(filename): # ensure it's a file and not a directory
            total += os.path.getsize(filename)
    return total

def chunk_file(file, tkns_per_chunk=300, model='gpt-4'):
    encoding = tiktoken.encoding_for_model(model)
    with open(file, 'r') as f:
        lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
    curr_chunk = []
    curr_token_ct = 0
    for i, line in enumerate(lines):
        line = line.rstrip()
        line = line.lstrip()
        try:
            line_token_ct = len(encoding.encode(line))
        except Exception as e:
            line_token_ct = len(line.split(' ')) / .75
            print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
            print(e)
        if line_token_ct > tkns_per_chunk:
            if len(curr_chunk) > 0:
                yield ''.join(curr_chunk)
                curr_chunk = []
                curr_token_ct = 0
            yield line[:3200]
            continue
        curr_token_ct += line_token_ct
        curr_chunk.append(line)
        if curr_token_ct > tkns_per_chunk:
            yield ''.join(curr_chunk)
            curr_chunk = []
            curr_token_ct = 0

    if len(curr_chunk) > 0:
        yield ''.join(curr_chunk)

def chunk_files(files, tkns_per_chunk=300, model='gpt-4'):
    archival_database = []
    for file in files:
        timestamp = os.path.getmtime(file)
        formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
        with open(file, 'r') as f:
            lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
            chunks = []
            curr_chunk = []
            curr_token_ct = 0
            for line in lines:
                line = line.rstrip()
                line = line.lstrip()
                try:
                    line_token_ct = len(encoding.encode(line))
                except Exception as e:
                    line_token_ct = len(line.split(' ')) / .75
                    print(f"Could not encode line {line}, estimating it to be {line_token_ct} tokens")
                if line_token_ct > tkns_per_chunk:
                    if len(curr_chunk) > 0:
                        chunks.append(''.join(curr_chunk))
                        curr_chunk = []
                        curr_token_ct = 0
                    chunks.append(line[:3200])
                    continue
                curr_token_ct += line_token_ct
                curr_chunk.append(line)
                if curr_token_ct > tkns_per_chunk:
                    chunks.append(''.join(curr_chunk))
                    curr_chunk = []
                    curr_token_ct = 0

            if len(curr_chunk) > 0:
                chunks.append(''.join(curr_chunk))

        file_stem = file.split('/')[-1]
        chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
        for i, chunk in enumerate(chunks):
            archival_database.append({
                'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
@@ -139,6 +154,67 @@ def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='g
            })
    return archival_database

def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
    ret = []
    for file in files:
        file_stem = file.split('/')[-1]
        curr_file = []
        for chunk in chunk_file(file, tkns_per_chunk, model):
            curr_file.append({
                'title': file_stem,
                'text': chunk,
            })
        ret.append(curr_file)
    return ret

async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
    files = sorted(glob.glob(glob_pattern))
    save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
    os.makedirs(save_dir, exist_ok=True)
    total_tokens = total_bytes(glob_pattern) / 3
    price_estimate = total_tokens / 1000 * .0001
    confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
    if confirm != 'y':
        raise Exception("embeddings were not computed")

    # chunk the files, make embeddings
    archival_database = chunk_files(files, tkns_per_chunk, model)
    embedding_data = []
    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
        # for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
        try:
            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
        except Exception as e:
            print(chunk)
            raise e
        embedding_data.append(embedding)
    embeddings_file = os.path.join(save_dir, "embeddings.json")
    with open(embeddings_file, 'w') as f:
        print(f"Saving embeddings to {embeddings_file}")
        json.dump(embedding_data, f)

    # make all_text.json
    archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
    chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
    with open(archival_storage_file, 'w') as f:
        print(f"Saving archival storage with preloaded files to {archival_storage_file}")
        for c in chunks_by_file:
            json.dump(c, f)
            f.write('\n')

    # make the faiss index
    index = faiss.IndexFlatL2(1536)
    data = np.array(embedding_data).astype('float32')
    try:
        index.add(data)
    except Exception as e:
        print(data)
        raise e
    index_file = os.path.join(save_dir, "all_docs.index")
    print(f"Saving faiss index {index_file}")
    faiss.write_index(index, index_file)
    return save_dir

def read_database_as_list(database_name):
    result_list = []
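
To show how the new index and chunk list fit together at query time, here is a rough, illustrative sketch; MemGPT's actual retrieval goes through `InMemoryStateManagerWithFaiss`, whose internals are not part of this diff, and the `search_archival` helper below is an assumption rather than repo code:

```python
import numpy as np

from memgpt.openai_tools import async_get_embedding_with_backoff

async def search_archival(index, archival_database, query, k=5):
    # Embed the query with the same model used for the document chunks above.
    query_embedding = await async_get_embedding_with_backoff(query, model='text-embedding-ada-002')
    query_arr = np.array([query_embedding]).astype('float32')  # shape (1, 1536)
    # Nearest-neighbor search in the IndexFlatL2 built by
    # prepare_archival_index_from_files_compute_embeddings.
    distances, indices = index.search(query_arr, k)
    return [archival_database[i]['content'] for i in indices[0]]
```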

@@ -11,3 +11,4 @@ pytz
rich
tiktoken
timezonefinder
tqdm