mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
140 lines
4.8 KiB
Python
140 lines
4.8 KiB
Python
import os
|
|
import threading
|
|
from contextlib import contextmanager
|
|
|
|
from rich.console import Console
|
|
from rich.panel import Panel
|
|
from rich.text import Text
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
from letta.config import LettaConfig
|
|
from letta.log import get_logger
|
|
from letta.orm import Base
|
|
from letta.settings import settings
|
|
|
|
# Use globals for the lock and initialization flag
|
|
_engine_lock = threading.Lock()
|
|
_engine_initialized = False
|
|
|
|
# Create variables in global scope but don't initialize them yet
|
|
config = LettaConfig.load()
|
|
logger = get_logger(__name__)
|
|
engine = None
|
|
SessionLocal = None
|
|
|
|
|
|
def print_sqlite_schema_error():
|
|
"""Print a formatted error message for SQLite schema issues"""
|
|
console = Console()
|
|
error_text = Text()
|
|
error_text.append("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", style="bold red")
|
|
error_text.append("To have migrations supported between Letta versions, please run Letta with Docker (", style="white")
|
|
error_text.append("https://docs.letta.com/server/docker", style="blue underline")
|
|
error_text.append(") or use Postgres by setting ", style="white")
|
|
error_text.append("LETTA_PG_URI", style="yellow")
|
|
error_text.append(".\n\n", style="white")
|
|
error_text.append("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", style="white")
|
|
error_text.append("rm ~/.letta/sqlite.db", style="yellow")
|
|
error_text.append(" or downgrade to your previous version of Letta.", style="white")
|
|
|
|
console.print(Panel(error_text, border_style="red"))
|
|
|
|
|
|
@contextmanager
|
|
def db_error_handler():
|
|
"""Context manager for handling database errors"""
|
|
try:
|
|
yield
|
|
except Exception as e:
|
|
# Handle other SQLAlchemy errors
|
|
print(e)
|
|
print_sqlite_schema_error()
|
|
# raise ValueError(f"SQLite DB error: {str(e)}")
|
|
exit(1)
|
|
|
|
|
|
def initialize_engine():
|
|
"""Initialize the database engine only when needed."""
|
|
global engine, SessionLocal, _engine_initialized
|
|
|
|
with _engine_lock:
|
|
# Check again inside the lock to prevent race conditions
|
|
if _engine_initialized:
|
|
return
|
|
|
|
if settings.letta_pg_uri_no_default:
|
|
logger.info("Creating postgres engine")
|
|
config.recall_storage_type = "postgres"
|
|
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
|
config.archival_storage_type = "postgres"
|
|
config.archival_storage_uri = settings.letta_pg_uri_no_default
|
|
|
|
# create engine
|
|
engine = create_engine(
|
|
settings.letta_pg_uri,
|
|
# f"{settings.letta_pg_uri}?options=-c%20client_encoding=UTF8",
|
|
pool_size=settings.pg_pool_size,
|
|
max_overflow=settings.pg_max_overflow,
|
|
pool_timeout=settings.pg_pool_timeout,
|
|
pool_recycle=settings.pg_pool_recycle,
|
|
echo=settings.pg_echo,
|
|
# connect_args={"client_encoding": "utf8"},
|
|
)
|
|
else:
|
|
# TODO: don't rely on config storage
|
|
engine_path = "sqlite:///" + os.path.join(config.recall_storage_path, "sqlite.db")
|
|
logger.info("Creating sqlite engine " + engine_path)
|
|
|
|
engine = create_engine(engine_path)
|
|
|
|
# Store the original connect method
|
|
original_connect = engine.connect
|
|
|
|
def wrapped_connect(*args, **kwargs):
|
|
with db_error_handler():
|
|
# Get the connection
|
|
connection = original_connect(*args, **kwargs)
|
|
|
|
# Store the original execution method
|
|
original_execute = connection.execute
|
|
|
|
# Wrap the execute method of the connection
|
|
def wrapped_execute(*args, **kwargs):
|
|
with db_error_handler():
|
|
return original_execute(*args, **kwargs)
|
|
|
|
# Replace the connection's execute method
|
|
connection.execute = wrapped_execute
|
|
|
|
return connection
|
|
|
|
# Replace the engine's connect method
|
|
engine.connect = wrapped_connect
|
|
|
|
Base.metadata.create_all(bind=engine)
|
|
|
|
# Create the session factory
|
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
|
_engine_initialized = True
|
|
|
|
|
|
def get_db():
|
|
"""Get a database session, initializing the engine if needed."""
|
|
global engine, SessionLocal
|
|
|
|
# Make sure engine is initialized
|
|
if not _engine_initialized:
|
|
initialize_engine()
|
|
|
|
# Now SessionLocal should be defined and callable
|
|
db = SessionLocal()
|
|
try:
|
|
yield db
|
|
finally:
|
|
db.close()
|
|
|
|
|
|
# Define db_context as a context manager that uses get_db
|
|
db_context = contextmanager(get_db)
|