mirror of https://github.com/cpacker/MemGPT.git, synced 2025-06-03 04:30:22 +00:00

Add SQLite integration for recall memory

This commit is contained in:
parent 1caf07db14
commit 65b99a7a59

@@ -70,12 +70,12 @@ class MemGPTConfig:

     # database configs: archival
     archival_storage_type: str = "local"  # local, db
-    archival_storage_path: str = None  # TODO: set to memgpt dir
+    archival_storage_path: str = MEMGPT_DIR  # TODO: set to memgpt dir
     archival_storage_uri: str = None  # TODO: eventually allow external vector DB

     # database configs: recall
     recall_storage_type: str = "local"  # local, db
-    recall_storage_path: str = None  # TODO: set to memgpt dir
+    recall_storage_path: str = MEMGPT_DIR
     recall_storage_uri: str = None  # TODO: eventually allow external vector DB

     # database configs: agent state
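
To make the intent of these defaults concrete, here is a small illustrative sketch (not code from the commit) of how a recall-memory SQLite file could be derived from them. MEMGPT_DIR and the table name below are stand-ins; the `{table_name}.db` join mirrors the SQLLiteStorageConnector added further down in this diff.

    import os

    # Stand-in for the package constant this config references (commonly ~/.memgpt).
    MEMGPT_DIR = os.path.join(os.path.expanduser("~"), ".memgpt")

    recall_storage_type = "sqllite"   # spelling used throughout this commit
    recall_storage_path = MEMGPT_DIR  # new default introduced above

    # Hypothetical table name; the connector below appends "<table_name>.db".
    table_name = "recall_memory_agent_demo"
    sqlite_file = os.path.join(recall_storage_path, f"{table_name}.db")
    print(sqlite_file)  # e.g. /home/user/.memgpt/recall_memory_agent_demo.db
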
@@ -1,9 +1,10 @@
 from pgvector.psycopg import register_vector
+import os
 from pgvector.sqlalchemy import Vector
 import psycopg


-from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text
+from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON
 from sqlalchemy import func
 from sqlalchemy.orm import sessionmaker, mapped_column
 from sqlalchemy.ext.declarative import declarative_base
@@ -11,6 +12,7 @@ from sqlalchemy.sql import func
 from sqlalchemy import Column, BIGINT, String, DateTime
 from sqlalchemy.dialects.postgresql import JSONB, UUID
 from sqlalchemy_json import mutable_json_type
+from sqlalchemy import TypeDecorator, CHAR
 import uuid

 import re
@@ -29,6 +31,31 @@ from memgpt.data_types import Record, Message, Passage

 from datetime import datetime

+
+# Custom UUID type
+class CommonUUID(TypeDecorator):
+
+    impl = CHAR
+
+    def load_dialect_impl(self, dialect):
+        if dialect.name == "postgresql":
+            return dialect.type_descriptor(UUID(as_uuid=True))
+        else:
+            return dialect.type_descriptor(CHAR())
+
+    def process_bind_param(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return str(value)  # Convert UUID to string for SQLite
+
+    def process_result_value(self, value, dialect):
+        if dialect.name == "postgresql" or value is None:
+            return value
+        else:
+            return uuid.UUID(value)
+
+
 Base = declarative_base()

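
As a self-contained sketch of what this TypeDecorator buys, the snippet below round-trips a uuid.UUID through an in-memory SQLite database using the same load_dialect_impl / process_bind_param / process_result_value pattern. The class, model, and table names are invented for this note, and SQLAlchemy 1.4+ is assumed; this is an illustration, not code from the commit.

    import uuid

    from sqlalchemy import CHAR, Column, String, TypeDecorator, create_engine
    from sqlalchemy.orm import declarative_base, sessionmaker

    class PortableUUID(TypeDecorator):
        """Same idea as CommonUUID: native UUID on Postgres, plain text on SQLite."""

        impl = CHAR
        cache_ok = True

        def load_dialect_impl(self, dialect):
            if dialect.name == "postgresql":
                from sqlalchemy.dialects.postgresql import UUID

                return dialect.type_descriptor(UUID(as_uuid=True))
            return dialect.type_descriptor(CHAR(36))

        def process_bind_param(self, value, dialect):
            if value is None or dialect.name == "postgresql":
                return value
            return str(value)  # store as 36-char text on SQLite

        def process_result_value(self, value, dialect):
            if value is None or dialect.name == "postgresql":
                return value
            return uuid.UUID(value)  # hand back a real UUID object

    DemoBase = declarative_base()

    class DemoRow(DemoBase):
        __tablename__ = "demo_rows"

        id = Column(PortableUUID, primary_key=True, default=uuid.uuid4)
        text = Column(String, nullable=False)

    engine = create_engine("sqlite:///:memory:")
    DemoBase.metadata.create_all(engine)

    session = sessionmaker(bind=engine)()
    session.add(DemoRow(text="hello"))
    session.commit()
    assert isinstance(session.query(DemoRow).first().id, uuid.UUID)

On Postgres the same column maps to the native UUID type, so neither hook performs any conversion there.
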
@@ -43,7 +70,9 @@ def get_db_model(table_name: str, table_type: TableType):
             __abstract__ = True  # this line is necessary

             # Assuming passage_id is the primary key
-            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+            # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+            id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+            # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
             user_id = Column(String, nullable=False)
             text = Column(String, nullable=False)
             doc_id = Column(String)
@@ -79,7 +108,9 @@ def get_db_model(table_name: str, table_type: TableType):
             __abstract__ = True  # this line is necessary

             # Assuming message_id is the primary key
-            id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+            # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
+            id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
+            # id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
             user_id = Column(String, nullable=False)
             agent_id = Column(String, nullable=False)

@@ -127,34 +158,9 @@ def get_db_model(table_name: str, table_type: TableType):


 class SQLStorageConnector(StorageConnector):
-    """Storage via Postgres"""
-
-    # TODO: this should probably eventually be moved into a parent DB class
-
     def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
         super().__init__(table_type=table_type, agent_config=agent_config)
-        config = MemGPTConfig.load()
-
-        # TODO: only support recall memory (need postgres for archival)
-
-        # get storage URI
-        if table_type == TableType.ARCHIVAL_MEMORY:
-            self.uri = config.archival_storage_uri
-            if config.archival_storage_uri is None:
-                raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
-        elif table_type == TableType.RECALL_MEMORY:
-            self.uri = config.recall_storage_uri
-            if config.recall_storage_uri is None:
-                raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
-        else:
-            raise ValueError(f"Table type {table_type} not implemented")
-
-        # create table
-        self.db_model = get_db_model(self.table_name, table_type)
-        self.engine = create_engine(self.uri)
-        Base.metadata.create_all(self.engine)  # Create the table if it doesn't exist
-        self.Session = sessionmaker(bind=self.engine)
-        self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension
+        self.config = MemGPTConfig.load()

     def get_filters(self, filters: Optional[Dict] = {}):
         if filters is not None:
@@ -279,6 +285,23 @@ class PostgresStorageConnector(SQLStorageConnector):

     def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
         super().__init__(table_type=table_type, agent_config=agent_config)

+        # get storage URI
+        if table_type == TableType.ARCHIVAL_MEMORY:
+            self.uri = self.config.archival_storage_uri
+            if self.config.archival_storage_uri is None:
+                raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
+        elif table_type == TableType.RECALL_MEMORY:
+            self.uri = self.config.recall_storage_uri
+            if self.config.recall_storage_uri is None:
+                raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}")
+        else:
+            raise ValueError(f"Table type {table_type} not implemented")
+        # create table
+        self.db_model = get_db_model(self.table_name, table_type)
+        self.engine = create_engine(self.uri)
+        Base.metadata.create_all(self.engine)  # Create the table if it doesn't exist
+        self.Session = sessionmaker(bind=self.engine)
+        self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension

     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
@@ -320,6 +343,36 @@ class PostgresStorageConnector(SQLStorageConnector):
         return records


+class SQLLiteStorageConnector(SQLStorageConnector):
+    def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
+        super().__init__(table_type=table_type, agent_config=agent_config)
+
+        # get storage URI
+        if table_type == TableType.ARCHIVAL_MEMORY:
+            raise ValueError(f"Table type {table_type} not implemented")
+        elif table_type == TableType.RECALL_MEMORY:
+            # TODO: eventually implement URI option
+            self.path = self.config.recall_storage_path
+            if self.path is None:
+                raise ValueError(f"Must specify recall_storage_path in config {self.config.config_path}")
+        else:
+            raise ValueError(f"Table type {table_type} not implemented")
+
+        self.path = os.path.join(self.path, f"{self.table_name}.db")
+        self.db_model = get_db_model(self.table_name, table_type)
+
+        # Create the SQLAlchemy engine
+        self.db_model = get_db_model(self.table_name, table_type)
+        self.engine = create_engine(f"sqlite:///{self.path}")
+        Base.metadata.create_all(self.engine)  # Create the table if it doesn't exist
+        self.Session = sessionmaker(bind=self.engine)
+
+
+import sqlite3
+
+sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
+sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
+
+
 class LanceDBConnector(StorageConnector):
     """Storage via LanceDB"""

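
The two sqlite3 registrations above teach the raw driver to store uuid.UUID values as 16-byte little-endian blobs and to rebuild them on the way out. Here is a standalone sketch of that mechanism, independent of the SQLAlchemy path; the table and column names are invented, and note that the converter only fires when the connection is opened with detect_types.

    import sqlite3
    import uuid

    sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
    sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))

    # PARSE_DECLTYPES makes the declared column type "UUID" trigger the converter.
    conn = sqlite3.connect(":memory:", detect_types=sqlite3.PARSE_DECLTYPES)
    conn.execute("CREATE TABLE demo (id UUID PRIMARY KEY, note TEXT)")

    row_id = uuid.uuid4()
    conn.execute("INSERT INTO demo VALUES (?, ?)", (row_id, "hello"))
    fetched_id, note = conn.execute("SELECT id, note FROM demo").fetchone()
    assert fetched_id == row_id and isinstance(fetched_id, uuid.UUID)

As far as this diff shows, the SQLAlchemy path goes through CommonUUID (which stores UUIDs as strings), so these module-level registrations mainly matter for code that talks to sqlite3 directly.
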
@@ -122,6 +122,11 @@ class StorageConnector:

             return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)

+        elif storage_type == "sqllite":
+            from memgpt.connectors.db import SQLLiteStorageConnector
+
+            return SQLLiteStorageConnector(agent_config=agent_config, table_type=table_type)
+
         else:
             raise NotImplementedError(f"Storage type {storage_type} not implemented")

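
Below is a minimal stand-in for the dispatch this branch adds, written for this note rather than taken from MemGPT's API: it maps the configured storage_type string to a connector class, deferring the imports so unknown types fail fast without touching any database backend. The returned class would then be constructed with agent_config and table_type exactly as in the branch above.

    def pick_connector_class(storage_type: str):
        # Hypothetical helper; MemGPT's real factory is the method shown in the diff.
        if storage_type == "postgres":
            from memgpt.connectors.db import PostgresStorageConnector as connector_cls
        elif storage_type == "sqllite":  # spelling used by this commit
            from memgpt.connectors.db import SQLLiteStorageConnector as connector_cls
        else:
            raise NotImplementedError(f"Storage type {storage_type} not implemented")
        return connector_cls
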
@@ -144,6 +149,7 @@ class StorageConnector:
         if storage_type == "local":
             from memgpt.connectors.local import VectorIndexStorageConnector

+            # TODO: remove
             return VectorIndexStorageConnector.list_loaded_data()
         elif storage_type == "postgres":
             from memgpt.connectors.db import PostgresStorageConnector
@@ -23,6 +23,8 @@ class Record:
             self.id = uuid.uuid4()
         else:
             self.id = id
+
+        assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type"
         # todo: generate unique uuid
         # todo: self.role = role (?)

@@ -78,8 +80,8 @@ class Document(Record):
         self.data_source = data_source
         # TODO: add optional embedding?

-    def __repr__(self) -> str:
-        pass
+    # def __repr__(self) -> str:
+    # pass


 class Passage(Record):
@@ -106,5 +108,5 @@ class Passage(Record):
         self.doc_id = doc_id
         self.metadata = metadata

-    def __repr__(self):
-        return str(vars(self))
+    # def __repr__(self):
+    # pass
@@ -56,8 +56,8 @@ def generate_messages():
     return messages


-@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
-# @pytest.mark.parametrize("storage_connector", ["postgres"])
+@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"])
+# @pytest.mark.parametrize("storage_connector", ["sqllite"])
 @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
 def test_storage(storage_connector, table_type):

@@ -86,9 +86,9 @@ def test_storage(storage_connector, table_type):
             return
         config.archival_storage_type = "chroma"
         config.archival_storage_path = "./test_chroma"
-    if storage_connector == "local":
+    if storage_connector == "sqllite":
         if table_type == TableType.ARCHIVAL_MEMORY:
-            print("Skipping test, local only supported for recall memory")
+            print("Skipping test, sqllite only supported for recall memory")
             return
         config.recall_storage_type = "local"

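
Assuming the parametrized test above lives in the repository's storage test module (the file path is not visible in this diff), the new SQLite cases could be selected on their own with pytest's -k filter, for example:

    # Hypothetical runner; adjust the path to wherever test_storage is defined.
    import pytest

    raise SystemExit(pytest.main(["-k", "sqllite", "tests/test_storage.py"]))

As the in-test guard above makes explicit, only the TableType.RECALL_MEMORY combination actually runs for sqllite; the archival case is skipped.
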