MemGPT/letta/server/db.py
2025-04-15 12:01:46 -07:00

140 lines
4.8 KiB
Python

import os
import threading
from contextlib import contextmanager
from rich.console import Console
from rich.panel import Panel
from rich.text import Text
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.orm import Base
from letta.settings import settings
# Module-level lock and flag guarding one-time engine initialization
_engine_lock = threading.Lock()
_engine_initialized = False
# config and logger are created eagerly at import time; engine and
# SessionLocal are only placeholders here and are assigned later by
# initialize_engine()
config = LettaConfig.load()
logger = get_logger(__name__)
engine = None
SessionLocal = None
def print_sqlite_schema_error():
    """Show a red-bordered console panel explaining an invalid SQLite schema.

    The panel tells the user that migrations are not supported for SQLite and
    lists the ways out: run with Docker, switch to Postgres via LETTA_PG_URI,
    delete the SQLite file, or downgrade Letta.
    """
    # (text, style) fragments, appended in order to build the styled message.
    fragments = (
        ("Existing SQLite DB schema is invalid, and schema migrations are not supported for SQLite. ", "bold red"),
        ("To have migrations supported between Letta versions, please run Letta with Docker (", "white"),
        ("https://docs.letta.com/server/docker", "blue underline"),
        (") or use Postgres by setting ", "white"),
        ("LETTA_PG_URI", "yellow"),
        (".\n\n", "white"),
        ("If you wish to keep using SQLite, you can reset your database by removing the DB file with ", "white"),
        ("rm ~/.letta/sqlite.db", "yellow"),
        (" or downgrade to your previous version of Letta.", "white"),
    )

    out = Console()
    message = Text()
    for fragment, fragment_style in fragments:
        message.append(fragment, style=fragment_style)

    out.print(Panel(message, border_style="red"))
@contextmanager
def db_error_handler():
    """Run the wrapped block; on any error, print SQLite guidance and exit.

    Every ``Exception`` raised inside the ``with`` body is caught, not only
    SQLAlchemy errors: the exception is printed, the SQLite schema help panel
    is shown via ``print_sqlite_schema_error()``, and the process is
    terminated with status 1.

    Raises:
        SystemExit: with code 1 whenever the wrapped block raises.
    """
    try:
        yield
    except Exception as e:
        print(e)
        print_sqlite_schema_error()
        # Raise SystemExit directly instead of calling the ``exit`` builtin:
        # ``exit`` is injected by the ``site`` module for interactive use and
        # is missing under ``python -S`` or in frozen/embedded interpreters.
        raise SystemExit(1) from e
def initialize_engine():
    """Build the module-level engine and session factory, at most once.

    Assigns the globals ``engine`` and ``SessionLocal`` and sets
    ``_engine_initialized``. A Postgres engine is created when
    ``settings.letta_pg_uri_no_default`` is set; otherwise a SQLite engine is
    created at ``<config.recall_storage_path>/sqlite.db``.

    NOTE(review): the source this was restored from had lost its indentation.
    The connect/execute wrapping and ``create_all`` are placed in the SQLite
    branch because the error handler prints SQLite-specific guidance --
    confirm against the canonical file.
    """
    global engine, SessionLocal, _engine_initialized

    with _engine_lock:
        # Re-check while holding the lock: another thread may have finished
        # initialization while this one was waiting.
        if _engine_initialized:
            return

        pg_uri = settings.letta_pg_uri_no_default
        if pg_uri:
            logger.info("Creating postgres engine")

            # Point both recall and archival storage at Postgres.
            config.recall_storage_type = "postgres"
            config.recall_storage_uri = pg_uri
            config.archival_storage_type = "postgres"
            config.archival_storage_uri = pg_uri

            pool_options = dict(
                pool_size=settings.pg_pool_size,
                max_overflow=settings.pg_max_overflow,
                pool_timeout=settings.pg_pool_timeout,
                pool_recycle=settings.pg_pool_recycle,
                echo=settings.pg_echo,
            )
            engine = create_engine(settings.letta_pg_uri, **pool_options)
        else:
            # TODO: don't rely on config storage
            sqlite_file = os.path.join(config.recall_storage_path, "sqlite.db")
            engine_path = "sqlite:///" + sqlite_file
            logger.info(f"Creating sqlite engine {engine_path}")

            engine = create_engine(engine_path)

            # Keep a handle on the unwrapped connect so the wrapper can delegate.
            raw_connect = engine.connect

            def wrapped_connect(*args, **kwargs):
                # Errors while connecting go through the SQLite error handler.
                with db_error_handler():
                    conn = raw_connect(*args, **kwargs)
                    raw_execute = conn.execute

                    def wrapped_execute(*exec_args, **exec_kwargs):
                        # Errors while executing go through the same handler.
                        with db_error_handler():
                            return raw_execute(*exec_args, **exec_kwargs)

                    # Patch this connection so execute() is guarded too.
                    conn.execute = wrapped_execute
                    return conn

            # Patch the engine so every connect() is guarded.
            engine.connect = wrapped_connect

            Base.metadata.create_all(bind=engine)

        # Session factory bound to whichever engine was created above.
        SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        _engine_initialized = True
def get_db():
    """Yield a database session, initializing the engine on first use.

    This is a generator: it calls ``initialize_engine()`` when the
    initialization flag is not yet set, yields a session created from
    ``SessionLocal``, and closes that session when the generator is resumed
    or torn down, including when the consumer raised.
    """
    # Only reads module globals, so no ``global`` declaration is required.
    if not _engine_initialized:
        initialize_engine()

    # After initialization SessionLocal is the configured session factory.
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
# Context-manager form of get_db: `with db_context() as db:` yields the
# session produced by get_db and closes it when the block exits
db_context = contextmanager(get_db)