MemGPT/paper_experiments/utils.py
Sarah Wooders 85faf5f474
chore: migrate package name to letta (#1775)
Co-authored-by: Charles Packer <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shubham.naik10@gmail.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
2024-09-23 09:15:18 -07:00

70 lines
2.4 KiB
Python

import gzip
import json
from typing import List
from letta.config import LettaConfig
from letta.constants import LLM_MAX_TOKENS
from letta.data_types import EmbeddingConfig, LLMConfig
def load_gzipped_file(file_path):
    """Stream a gzip-compressed JSONL file, yielding one parsed object per line.

    Args:
        file_path: Path to a ``.jsonl.gz`` (newline-delimited JSON) file.

    Yields:
        The decoded JSON value for each line, in file order.
    """
    # "rt" + explicit UTF-8 decodes the compressed stream as text lines.
    with gzip.open(file_path, "rt", encoding="utf-8") as gz:
        yield from map(json.loads, gz)
def read_jsonl(filename) -> List[dict]:
    """Read a newline-delimited JSON (JSONL) file fully into memory.

    Args:
        filename: Path to a JSONL file, one JSON object per line.

    Returns:
        A list with one parsed object per line, in file order.

    Raises:
        json.JSONDecodeError: If any line (including a blank line) is not
            valid JSON.
    """
    # Explicit UTF-8: JSONL is UTF-8 by convention; relying on the platform
    # default encoding can mis-decode non-ASCII content on some systems.
    with open(filename, "r", encoding="utf-8") as file:
        # json.loads tolerates surrounding whitespace, so strip() is unnecessary.
        return [json.loads(line) for line in file]
def get_experiment_config(postgres_uri, endpoint_type="openai", model="gpt-4"):
    """Build a LettaConfig for the paper experiments, backed by Postgres.

    Args:
        postgres_uri: Postgres URI used for archival, recall, and metadata storage.
        endpoint_type: "openai" selects an OpenAI-hosted model; any other value
            selects the self-hosted vLLM dolphin model (and asserts on `model`).
        model: OpenAI model name; must be a key of LLM_MAX_TOKENS when
            endpoint_type == "openai".

    Returns:
        A fully populated LettaConfig whose default LLM/embedding configs match
        the requested endpoint.
    """
    # The on-disk config is loaded only to preserve the user's existing
    # anon_clientid; every other field is set explicitly on the LettaConfig
    # constructed below. (The original code also mutated archival_storage_*
    # on this loaded object, but those writes were dead — the object is
    # discarded when the new config is built.)
    base_config = LettaConfig.load()

    if endpoint_type == "openai":
        llm_config = LLMConfig(
            model=model,
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            # Raises KeyError for models missing from LLM_MAX_TOKENS — fail fast.
            context_window=LLM_MAX_TOKENS[model],
        )
        embedding_config = EmbeddingConfig(
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=1536,
            embedding_model="text-embedding-ada-002",
            embedding_chunk_size=300,  # TODO: fix this
        )
    else:
        assert model == "ehartford/dolphin-2.5-mixtral-8x7b", "Only model supported is ehartford/dolphin-2.5-mixtral-8x7b"
        llm_config = LLMConfig(
            model="ehartford/dolphin-2.5-mixtral-8x7b",
            model_endpoint_type="vllm",
            model_endpoint="https://api.letta.ai",
            model_wrapper="chatml",
            context_window=16384,
        )
        embedding_config = EmbeddingConfig(
            embedding_endpoint_type="hugging-face",
            embedding_endpoint="https://embeddings.letta.ai",
            embedding_dim=1024,
            embedding_model="BAAI/bge-large-en-v1.5",
            embedding_chunk_size=300,
        )

    config = LettaConfig(
        anon_clientid=base_config.anon_clientid,
        archival_storage_type="postgres",
        archival_storage_uri=postgres_uri,
        recall_storage_type="postgres",
        recall_storage_uri=postgres_uri,
        metadata_storage_type="postgres",
        metadata_storage_uri=postgres_uri,
        default_llm_config=llm_config,
        default_embedding_config=embedding_config,
    )
    print("Config model", config.default_llm_config.model)
    return config