Add SQLite integration for recall memory

This commit is contained in:
Sarah Wooders 2023-12-22 11:33:37 +04:00
parent 1caf07db14
commit 65b99a7a59
5 changed files with 100 additions and 39 deletions

View File

@ -70,12 +70,12 @@ class MemGPTConfig:
# database configs: archival
archival_storage_type: str = "local" # local, db
archival_storage_path: str = None # TODO: set to memgpt dir
archival_storage_path: str = MEMGPT_DIR # TODO: set to memgpt dir
archival_storage_uri: str = None # TODO: eventually allow external vector DB
# database configs: recall
recall_storage_type: str = "local" # local, db
recall_storage_path: str = None # TODO: set to memgpt dir
recall_storage_path: str = MEMGPT_DIR
recall_storage_uri: str = None # TODO: eventually allow external vector DB
# database configs: agent state

View File

@ -1,9 +1,10 @@
from pgvector.psycopg import register_vector
import os
from pgvector.sqlalchemy import Vector
import psycopg
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column
from sqlalchemy.ext.declarative import declarative_base
@ -11,6 +12,7 @@ from sqlalchemy.sql import func
from sqlalchemy import Column, BIGINT, String, DateTime
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy_json import mutable_json_type
from sqlalchemy import TypeDecorator, CHAR
import uuid
import re
@ -29,6 +31,31 @@ from memgpt.data_types import Record, Message, Passage
from datetime import datetime
# Custom UUID type
class CommonUUID(TypeDecorator):
impl = CHAR
def load_dialect_impl(self, dialect):
if dialect.name == "postgresql":
return dialect.type_descriptor(UUID(as_uuid=True))
else:
return dialect.type_descriptor(CHAR())
def process_bind_param(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return str(value) # Convert UUID to string for SQLite
def process_result_value(self, value, dialect):
if dialect.name == "postgresql" or value is None:
return value
else:
return uuid.UUID(value)
Base = declarative_base()
@ -43,7 +70,9 @@ def get_db_model(table_name: str, table_type: TableType):
__abstract__ = True # this line is necessary
# Assuming passage_id is the primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String, nullable=False)
text = Column(String, nullable=False)
doc_id = Column(String)
@ -79,7 +108,9 @@ def get_db_model(table_name: str, table_type: TableType):
__abstract__ = True # this line is necessary
# Assuming message_id is the primary key
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
# id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
@ -127,34 +158,9 @@ def get_db_model(table_name: str, table_type: TableType):
class SQLStorageConnector(StorageConnector):
"""Storage via Postgres"""
# TODO: this should probably eventually be moved into a parent DB class
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# TODO: only support recall memory (need postgres for archival)
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY:
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = config.recall_storage_uri
if config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create table
self.db_model = get_db_model(self.table_name, table_type)
self.engine = create_engine(self.uri)
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
self.config = MemGPTConfig.load()
def get_filters(self, filters: Optional[Dict] = {}):
if filters is not None:
@ -279,6 +285,23 @@ class PostgresStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY:
self.uri = self.config.archival_storage_uri
if self.config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
elif table_type == TableType.RECALL_MEMORY:
self.uri = self.config.recall_storage_uri
if self.config.recall_storage_uri is None:
raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
# create table
self.db_model = get_db_model(self.table_name, table_type)
self.engine = create_engine(self.uri)
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
@ -320,6 +343,36 @@ class PostgresStorageConnector(SQLStorageConnector):
return records
class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY:
raise ValueError(f"Table type {table_type} not implemented")
elif table_type == TableType.RECALL_MEMORY:
# TODO: eventually implement URI option
self.path = self.config.recall_storage_path
if self.path is None:
raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}")
else:
raise ValueError(f"Table type {table_type} not implemented")
self.path = os.path.join(self.path, f"{self.table_name}.db")
self.db_model = get_db_model(self.table_name, table_type)
# Create the SQLAlchemy engine
self.db_model = get_db_model(self.table_name, table_type)
self.engine = create_engine(f"sqlite:///{self.path}")
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
self.Session = sessionmaker(bind=self.engine)
import sqlite3
sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""

View File

@ -122,6 +122,11 @@ class StorageConnector:
return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "sqllite":
from memgpt.connectors.db import SQLLiteStorageConnector
return SQLLiteStorageConnector(agent_config=agent_config, table_type=table_type)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@ -144,6 +149,7 @@ class StorageConnector:
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector
# TODO: remove
return VectorIndexStorageConnector.list_loaded_data()
elif storage_type == "postgres":
from memgpt.connectors.db import PostgresStorageConnector

View File

@ -23,6 +23,8 @@ class Record:
self.id = uuid.uuid4()
else:
self.id = id
assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
# todo: generate unique uuid
# todo: self.role = role (?)
@ -78,8 +80,8 @@ class Document(Record):
self.data_source = data_source
# TODO: add optional embedding?
def __repr__(self) -> str:
pass
# def __repr__(self) -> str:
# pass
class Passage(Record):
@ -106,5 +108,5 @@ class Passage(Record):
self.doc_id = doc_id
self.metadata = metadata
def __repr__(self):
return str(vars(self))
# def __repr__(self):
# pass

View File

@ -56,8 +56,8 @@ def generate_messages():
return messages
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
# @pytest.mark.parametrize("storage_connector", ["postgres"])
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"])
# @pytest.mark.parametrize("storage_connector", ["sqllite"])
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
def test_storage(storage_connector, table_type):
@ -86,9 +86,9 @@ def test_storage(storage_connector, table_type):
return
config.archival_storage_type = "chroma"
config.archival_storage_path = "./test_chroma"
if storage_connector == "local":
if storage_connector == "sqllite":
if table_type == TableType.ARCHIVAL_MEMORY:
print("Skipping test, local only supported for recall memory")
print("Skipping test, sqllite only supported for recall memory")
return
config.recall_storage_type = "local"