mirror of https://github.com/cpacker/MemGPT.git, synced 2025-06-03 04:30:22 +00:00
support generating embeddings on the fly
This commit is contained in:
parent 924be62eea
commit cf927b4e86
README.md (23 changed lines)

@@ -107,8 +107,10 @@ python main.py --human me.txt
   enables debugging output
 --archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
   load in document database (backed by FAISS index)
---archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB>"
+--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
   pre-load files into archival memory
+--archival_storage_files_compute_embeddings="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
+  pre-load files into archival memory and also compute embeddings for embedding search
 --archival_storage_sqldb=<SQLDB_PATH>
   load in SQL database
 ```
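Note (editorial, not part of the commit): the new flag takes a glob pattern, and the helper it triggers (see the memgpt/utils.py changes below) asks for confirmation based on a rough estimate of one token per ~3 bytes at $0.0001 per 1K tokens. A small hypothetical sketch for previewing what a pattern will match, and roughly what embedding it would cost, before running the real command:

```python
import glob
import os

# Example glob taken from the README below; substitute your own pattern.
pattern = "memgpt/personas/examples/preload_archival/*.txt"
files = [f for f in glob.glob(pattern) if os.path.isfile(f)]
approx_tokens = sum(os.path.getsize(f) for f in files) / 3  # same ~3 bytes/token heuristic as the new helper
print(f"{len(files)} files, ~{approx_tokens:,.0f} tokens, ~${approx_tokens / 1000 * 0.0001:.4f} to embed")
```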
@@ -181,6 +183,25 @@ To run our example where you can search over the SEC 10-K filings of Uber, Lyft,
 ```
 
 If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
+
+#### Enhance with embeddings search
+In the root `MemGPT` directory, run
+```bash
+python3 main.py --archival_storage_files_compute_embeddings="<GLOB_PATTERN>" --persona=memgpt_doc --human=basic
+```
+
+This will generate embeddings, stick them into a FAISS index, and write the index to a directory, and then output:
+```
+To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings=<GLOB_PATTERN> with
+  --archival_storage_faiss_path=<DIRECTORY_WITH_EMBEDDINGS> (if your files haven't changed).
+```
+
+If you want to reuse these embeddings, run
+```bash
+python3 main.py --archival_storage_faiss_path="<DIRECTORY_WITH_EMBEDDINGS>" --persona=memgpt_doc --human=basic
+```
+
+
 </details>
 <details>
 <summary><h3>Talking to LlamaIndex API Docs</h3></summary>
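Note (editorial, not part of the commit): the directory that --archival_storage_files_compute_embeddings writes, and that --archival_storage_faiss_path later consumes, is laid out by the new helper in memgpt/utils.py further down: raw vectors in embeddings.json, chunk text in all_docs.jsonl, and the FAISS index in all_docs.index. A minimal sketch for checking that a saved directory looks as expected (the path below is a placeholder; use whatever directory the run printed):

```python
import os

faiss_dir = "archival_index_from_files_<timestamp>"  # placeholder; use the directory printed by the run
print(sorted(os.listdir(faiss_dir)))
# Expected, per the memgpt/utils.py changes in this commit:
# ['all_docs.index', 'all_docs.jsonl', 'embeddings.json']
```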
@@ -10,6 +10,9 @@ init(autoreset=True)
 # DEBUG = True # puts full message outputs in the terminal
 DEBUG = False # only dumps important messages in the terminal
 
+def important_message(msg):
+    print(f'{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}')
+
 async def internal_monologue(msg):
     # ANSI escape code for italic is '\x1B[3m'
     print(f'\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}')
main.py (6 changed lines)

@@ -27,6 +27,7 @@ flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to
 flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
 flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
 flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
+flags.DEFINE_string("archival_storage_files_compute_embeddings", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them")
 flags.DEFINE_string("archival_storage_sqldb", default="", required=False, help="Specify SQL database to pre-load into archival memory")
 
 
@@ -54,6 +55,11 @@ async def main():
         archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
         print(f"Preloaded {len(archival_database)} chunks into archival memory.")
         persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
+    elif FLAGS.archival_storage_files_compute_embeddings:
+        faiss_save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(FLAGS.archival_storage_files_compute_embeddings)
+        interface.important_message(f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed).")
+        index, archival_database = utils.prepare_archival_index(faiss_save_dir)
+        persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
     else:
         persistence_manager = InMemoryStateManager()
     memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
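Note (a hedged sketch, not part of the commit): main() awaits the new coroutine and then reloads its output through utils.prepare_archival_index. If you only wanted to precompute the index without starting the chat loop, the same coroutine could in principle be driven from a short script run from the repo root (assumes the memgpt package imports cleanly and an OpenAI key is configured, since it calls the embeddings API):

```python
import asyncio

from memgpt import utils

# Chunk the files, compute embeddings, build the FAISS index, and return the
# save directory that can later be passed back via --archival_storage_faiss_path.
save_dir = asyncio.run(
    utils.prepare_archival_index_from_files_compute_embeddings(
        "memgpt/personas/examples/preload_archival/*.txt"  # example glob from the README
    )
)
print(f"Reuse this index with --archival_storage_faiss_path={save_dir}")
```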
memgpt/utils.py (136 changed lines)

@@ -9,6 +9,8 @@ import faiss
 import tiktoken
 import glob
 import sqlite3
+from tqdm import tqdm
+from memgpt.openai_tools import async_get_embedding_with_backoff
 
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
@@ -97,41 +99,54 @@ def read_in_chunks(file_object, chunk_size):
 def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
     encoding = tiktoken.encoding_for_model(model)
     files = glob.glob(glob_pattern)
+    return chunk_files(files, tkns_per_chunk, model)
+
+def total_bytes(pattern):
+    total = 0
+    for filename in glob.glob(pattern):
+        if os.path.isfile(filename):  # ensure it's a file and not a directory
+            total += os.path.getsize(filename)
+    return total
+
+def chunk_file(file, tkns_per_chunk=300, model='gpt-4'):
+    encoding = tiktoken.encoding_for_model(model)
+    with open(file, 'r') as f:
+        lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
+    curr_chunk = []
+    curr_token_ct = 0
+    for i, line in enumerate(lines):
+        line = line.rstrip()
+        line = line.lstrip()
+        try:
+            line_token_ct = len(encoding.encode(line))
+        except Exception as e:
+            line_token_ct = len(line.split(' ')) / .75
+            print(f"Could not encode line {i}, estimating it to be {line_token_ct} tokens")
+            print(e)
+        if line_token_ct > tkns_per_chunk:
+            if len(curr_chunk) > 0:
+                yield ''.join(curr_chunk)
+                curr_chunk = []
+                curr_token_ct = 0
+            yield line[:3200]
+            continue
+        curr_token_ct += line_token_ct
+        curr_chunk.append(line)
+        if curr_token_ct > tkns_per_chunk:
+            yield ''.join(curr_chunk)
+            curr_chunk = []
+            curr_token_ct = 0
+
+    if len(curr_chunk) > 0:
+        yield ''.join(curr_chunk)
+
+def chunk_files(files, tkns_per_chunk=300, model='gpt-4'):
     archival_database = []
     for file in files:
         timestamp = os.path.getmtime(file)
         formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
-        with open(file, 'r') as f:
-            lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
-            chunks = []
-            curr_chunk = []
-            curr_token_ct = 0
-            for line in lines:
-                line = line.rstrip()
-                line = line.lstrip()
-                try:
-                    line_token_ct = len(encoding.encode(line))
-                except Exception as e:
-                    line_token_ct = len(line.split(' ')) / .75
-                    print(f"Could not encode line {line}, estimating it to be {line_token_ct} tokens")
-                if line_token_ct > tkns_per_chunk:
-                    if len(curr_chunk) > 0:
-                        chunks.append(''.join(curr_chunk))
-                        curr_chunk = []
-                        curr_token_ct = 0
-                    chunks.append(line[:3200])
-                    continue
-                curr_token_ct += line_token_ct
-                curr_chunk.append(line)
-                if curr_token_ct > tkns_per_chunk:
-                    chunks.append(''.join(curr_chunk))
-                    curr_chunk = []
-                    curr_token_ct = 0
-
-            if len(curr_chunk) > 0:
-                chunks.append(''.join(curr_chunk))
-
         file_stem = file.split('/')[-1]
+        chunks = [c for c in chunk_file(file, tkns_per_chunk, model)]
         for i, chunk in enumerate(chunks):
             archival_database.append({
                 'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
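Note (editorial): the refactor above extracts the old inline loop into a chunk_file generator that yields strings of roughly tkns_per_chunk tokens (over-long lines are truncated to 3,200 characters) plus a chunk_files wrapper that tags each chunk with its source file. A hypothetical usage sketch, assuming the repo's requirements are installed and it is run from the repo root:

```python
from memgpt.utils import chunk_file

# Inspect how a single file gets split; each yielded item is a plain string.
for i, chunk in enumerate(chunk_file("README.md", tkns_per_chunk=300)):
    print(f"chunk {i}: {len(chunk)} characters")
```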
@@ -139,6 +154,67 @@ def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='g
             })
     return archival_database
 
+
+def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
+    ret = []
+    for file in files:
+        file_stem = file.split('/')[-1]
+        curr_file = []
+        for chunk in chunk_file(file, tkns_per_chunk, model):
+            curr_file.append({
+                'title': file_stem,
+                'text': chunk,
+            })
+        ret.append(curr_file)
+    return ret
+
+async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
+    files = sorted(glob.glob(glob_pattern))
+    save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
+    os.makedirs(save_dir, exist_ok=True)
+    total_tokens = total_bytes(glob_pattern) / 3
+    price_estimate = total_tokens / 1000 * .0001
+    confirm = input(f"Computing embeddings over {len(files)} files. This will cost ~${price_estimate:.2f}. Continue? [y/n] ")
+    if confirm != 'y':
+        raise Exception("embeddings were not computed")
+
+    # chunk the files, make embeddings
+    archival_database = chunk_files(files, tkns_per_chunk, model)
+    embedding_data = []
+    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
+        # for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
+        try:
+            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
+        except Exception as e:
+            print(chunk)
+            raise e
+        embedding_data.append(embedding)
+    embeddings_file = os.path.join(save_dir, "embeddings.json")
+    with open(embeddings_file, 'w') as f:
+        print(f"Saving embeddings to {embeddings_file}")
+        json.dump(embedding_data, f)
+
+    # make all_text.json
+    archival_storage_file = os.path.join(save_dir, "all_docs.jsonl")
+    chunks_by_file = chunk_files_for_jsonl(files, tkns_per_chunk, model)
+    with open(archival_storage_file, 'w') as f:
+        print(f"Saving archival storage with preloaded files to {archival_storage_file}")
+        for c in chunks_by_file:
+            json.dump(c, f)
+            f.write('\n')
+
+    # make the faiss index
+    index = faiss.IndexFlatL2(1536)
+    data = np.array(embedding_data).astype('float32')
+    try:
+        index.add(data)
+    except Exception as e:
+        print(data)
+        raise e
+    index_file = os.path.join(save_dir, "all_docs.index")
+    print(f"Saving faiss index {index_file}")
+    faiss.write_index(index, index_file)
+    return save_dir
 
 def read_database_as_list(database_name):
     result_list = []
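Note (hypothetical sketch, not part of the commit): once prepare_archival_index_from_files_compute_embeddings has written its save directory, the index can be queried offline with standard FAISS calls. Mapping FAISS row ids back to text assumes the flattened all_docs.jsonl order matches the embedding order, which should hold here because both passes walk the same sorted file list through chunk_file:

```python
import json
import os

import faiss
import numpy as np

save_dir = "archival_index_from_files_<timestamp>"  # placeholder for the directory returned by the helper

# Load the FAISS index and flatten the per-file chunk lists (one JSON array per line).
index = faiss.read_index(os.path.join(save_dir, "all_docs.index"))
chunks = []
with open(os.path.join(save_dir, "all_docs.jsonl")) as f:
    for line in f:
        chunks.extend(json.loads(line))

# Sanity check: use the first stored embedding as the query; its own chunk should be the top hit.
with open(os.path.join(save_dir, "embeddings.json")) as f:
    query = np.array(json.load(f)[:1], dtype="float32")
distances, neighbors = index.search(query, 3)
for rank, idx in enumerate(neighbors[0]):
    print(rank, float(distances[0][rank]), chunks[idx]["text"][:80])
```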
requirements.txt

@@ -11,3 +11,4 @@ pytz
 rich
 tiktoken
 timezonefinder
+tqdm