Lancedb storage integration (#455)

This commit is contained in:
Prashant Dixit 2023-11-18 01:06:30 +05:30 committed by GitHub
parent 85d0e0dd17
commit f957209c35
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 421 additions and 6 deletions

View File

@ -18,5 +18,22 @@ pip install 'pymemgpt[postgres]'
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).
## LanceDB
In order to use the LanceDB backend, you have to enable it by running
```
memgpt configure
```
and selecting `lancedb` for archival storage, then entering a database URI (e.g. `./.lancedb`). An empty archival URI is also handled, and the default URI is set to `./.lancedb`.
To enable the LanceDB backend, make sure to install the required dependencies with:
```
pip install 'pymemgpt[lancedb]'
```
For more information, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).
## Chroma
(Coming soon)

View File

@ -38,7 +38,8 @@
"outputs": [],
"source": [
"import openai\n",
"openai.api_key=\"YOUR_API_KEY\""
"\n",
"openai.api_key = \"YOUR_API_KEY\""
]
},
{

View File

@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig):
def configure_archival_storage(config: MemGPTConfig):
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_options = ["local", "lancedb", "postgres"]
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
@ -220,8 +220,17 @@ def configure_archival_storage(config: MemGPTConfig):
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()
if archival_storage_type == "lancedb":
archival_storage_uri = questionary.text(
"Enter lanncedb connection string (e.g. ./.lancedb",
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
).ask()
return archival_storage_type, archival_storage_uri
# TODO: allow configuring embedding model
@app.command()
def configure():

View File

@ -13,6 +13,7 @@ from tqdm import tqdm
from typing import Optional, List, Iterator
import numpy as np
from tqdm import tqdm
import pandas as pd
from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
@ -181,3 +182,139 @@ class PostgresStorageConnector(StorageConnector):
def generate_table_name(self, name: str):
return f"memgpt_{self.sanitize_table_name(name)}"
class LanceDBConnector(StorageConnector):
    """Archival storage backend backed by LanceDB.

    Passages are stored in a single LanceDB table with columns
    ``doc_id``, ``text``, ``passage_id`` and ``vector`` (the embedding).
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None):
        config = MemGPTConfig.load()

        # determine table name
        if name:
            self.table_name = self.generate_table_name(name)
        else:
            self.table_name = "lancedb_tbl"

        printd(f"Using table name {self.table_name}")

        # create table
        self.uri = config.archival_storage_uri
        if config.archival_storage_uri is None:
            raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}")

        import lancedb

        self.db = lancedb.connect(self.uri)
        # table is created lazily on first insert (LanceDB infers the schema from data)
        self.table = None

    def _row_to_passage(self, row) -> Passage:
        """Convert one LanceDB result row (a dict) into a Passage.

        Embeddings are stored under the "vector" column — see insert().
        """
        return Passage(text=row["text"], embedding=row["vector"], doc_id=row["doc_id"], passage_id=row["passage_id"])

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all stored passages in chunks of `page_size`.

        Fix: the previous version called a nonexistent `self.Session()` (copied
        from the Postgres connector) and re-ran the same limited query every
        iteration without advancing, which looped forever on non-empty tables.
        """
        if self.table is None:
            return
        # materialize all rows once, then slice into pages
        rows = self.table.to_pandas().to_dict("records")
        for offset in range(0, len(rows), page_size):
            yield [self._row_to_passage(row) for row in rows[offset : offset + page_size]]

    def get_all(self, limit=10) -> List[Passage]:
        """Return up to `limit` passages from the table."""
        db_passages = self.table.search().limit(limit).to_list()
        return [self._row_to_passage(row) for row in db_passages]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage with the given passage_id, or None if not found."""
        # fix: filtering must go through search() — Table has no .where() of its own
        db_passage = self.table.search().where(f"passage_id = {id}").to_list()
        if len(db_passage) == 0:
            return None
        # fix: to_list() returns a list of row dicts; take the first match
        # (the old code indexed the list like a dict and read a nonexistent "embedding" key)
        return self._row_to_passage(db_passage[0])

    def size(self) -> int:
        """Return the number of rows in the table (0 if the table does not exist)."""
        if self.table:
            # count all rows; search().to_list() applies a default row limit
            # and therefore silently under-counted larger tables
            return len(self.table.to_pandas())
        else:
            print(f"Table with name {self.table_name} not present")
            return 0

    def insert(self, passage: Passage):
        """Insert a single passage, creating the table on first use."""
        data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]
        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Insert many passages in one batch; optionally show a tqdm progress bar."""
        iterable = tqdm(passages) if show_progress else passages
        data = [
            {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
            for passage in iterable
        ]
        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Vector-similarity search; `query` (the raw text) is unused by LanceDB.

        Assumes query_vec has the same length as the embeddings inside the table.
        """
        # fix: results must be materialized with .to_list(); the old code iterated
        # the query builder directly and read a nonexistent "embedding" key
        results = self.table.search(query_vec).limit(top_k).to_list()
        return [self._row_to_passage(result) for result in results]

    def delete(self):
        """Drop the passage table from the database."""
        self.db.drop_table(self.table_name)

    def save(self):
        # LanceDB persists writes as they happen; nothing to flush
        return

    @staticmethod
    def list_loaded_data():
        """Return names of previously loaded datasets (memgpt_* tables, prefix stripped)."""
        config = MemGPTConfig.load()
        import lancedb

        db = lancedb.connect(config.archival_storage_uri)
        tables = db.table_names()
        tables = [table for table in tables if table.startswith("memgpt_")]
        tables = [table.replace("memgpt_", "") for table in tables]
        return tables

    def sanitize_table_name(self, name: str) -> str:
        """Normalize `name` into a safe, lowercase table identifier."""
        # Remove leading and trailing whitespace
        name = name.strip()
        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)
        # Truncate to the maximum identifier length
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")
        # Convert to lowercase
        return name.lower()

    def generate_table_name(self, name: str):
        """Prefix sanitized names with memgpt_ so our tables are namespaced."""
        return f"memgpt_{self.sanitize_table_name(name)}"

View File

@ -48,6 +48,11 @@ class StorageConnector:
return PostgresStorageConnector(name=name, agent_config=agent_config)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector
return LanceDBConnector(name=name)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
@ -62,6 +67,11 @@ class StorageConnector:
from memgpt.connectors.db import PostgresStorageConnector
return PostgresStorageConnector.list_loaded_data()
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector
return LanceDBConnector.list_loaded_data()
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")

147
poetry.lock generated
View File

@ -250,6 +250,17 @@ d = ["aiohttp (>=3.7.4)"]
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
uvloop = ["uvloop (>=0.15.2)"]
[[package]]
name = "cachetools"
version = "5.3.2"
description = "Extensible memoizing collections and decorators"
optional = false
python-versions = ">=3.7"
files = [
{file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"},
{file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"},
]
[[package]]
name = "certifi"
version = "2023.7.22"
@ -482,6 +493,17 @@ tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elast
torch = ["torch"]
vision = ["Pillow (>=6.2.1)"]
[[package]]
name = "decorator"
version = "5.1.1"
description = "Decorators for Humans"
optional = false
python-versions = ">=3.5"
files = [
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
]
[[package]]
name = "demjson3"
version = "3.0.6"
@ -509,6 +531,20 @@ wrapt = ">=1.10,<2"
[package.extras]
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
[[package]]
name = "deprecation"
version = "2.1.0"
description = "A library to handle automated deprecations"
optional = false
python-versions = "*"
files = [
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
]
[package.dependencies]
packaging = "*"
[[package]]
name = "dill"
version = "0.3.7"
@ -910,6 +946,39 @@ files = [
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
]
[[package]]
name = "lancedb"
version = "0.3.3"
description = "lancedb"
optional = false
python-versions = ">=3.8"
files = [
{file = "lancedb-0.3.3-py3-none-any.whl", hash = "sha256:67ccea22a6cb39c688041f7469be778a2e64b141db80866f6f0dec25a3122aff"},
{file = "lancedb-0.3.3.tar.gz", hash = "sha256:8d8a9c2b107154ee57f6f75957d215719a204cd64c9efbe7095eaf41b43c2a29"},
]
[package.dependencies]
aiohttp = "*"
attrs = ">=21.3.0"
cachetools = "*"
click = ">=8.1.7"
deprecation = "*"
pydantic = ">=1.10"
pylance = "0.8.10"
pyyaml = ">=6.0"
ratelimiter = ">=1.0,<2.0"
requests = ">=2.31.0"
retry = ">=0.9.2"
semver = ">=3.0"
tqdm = ">=4.1.0"
[package.extras]
clip = ["open-clip", "pillow", "torch"]
dev = ["black", "pre-commit", "ruff"]
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
embeddings = ["cohere", "open-clip-torch", "openai", "pillow", "sentence-transformers", "torch"]
tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock", "requests"]
[[package]]
name = "langchain"
version = "0.0.333"
@ -1936,11 +2005,22 @@ files = [
{file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"},
]
[[package]]
name = "py"
version = "1.11.0"
description = "library with cross-python path, ini-parsing, io, code, log facilities"
optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
files = [
{file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
]
[[package]]
name = "pyarrow"
version = "14.0.1"
description = "Python library for Apache Arrow"
optional = true
optional = false
python-versions = ">=3.8"
files = [
{file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"},
@ -2135,6 +2215,28 @@ files = [
[package.extras]
plugins = ["importlib-metadata"]
[[package]]
name = "pylance"
version = "0.8.10"
description = "python wrapper for Lance columnar format"
optional = false
python-versions = ">=3.8"
files = [
{file = "pylance-0.8.10-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:aecf053e12f13a1810a70c786c1e73bcf3ffe7287c0bfe2cc5df77a91f0a084c"},
{file = "pylance-0.8.10-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:b778fbcfae2e9186053292b7bd3fcd28efc92bd0471f733f8dbf4a1f840c9ce4"},
{file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ea617723593d4cc0d2faaaf4a861e31ae3c8657517b83e2fb99e5f68c0c1481"},
{file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b1bf9cc33d7095196931f96588d733e80d69a4c312b5352d9dab9a0d5a84c8f"},
{file = "pylance-0.8.10-cp38-abi3-win_amd64.whl", hash = "sha256:d700f874710c6f1a5567c6e4f98426c22aebf9937dcdbd7305573f519712b683"},
]
[package.dependencies]
numpy = ">=1.22"
pyarrow = ">=10"
[package.extras]
benchmarks = ["pytest-benchmark"]
tests = ["duckdb", "ml_dtypes", "pandas (>=1.4,<2.1)", "polars[pandas,pyarrow]", "pytest", "semver", "tensorflow", "tqdm"]
[[package]]
name = "pymupdf"
version = "1.23.6"
@ -2328,6 +2430,20 @@ files = [
[package.dependencies]
prompt_toolkit = ">=2.0,<=3.0.36"
[[package]]
name = "ratelimiter"
version = "1.2.0.post0"
description = "Simple python rate limiting object"
optional = false
python-versions = "*"
files = [
{file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"},
{file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"},
]
[package.extras]
test = ["pytest (>=3.0)", "pytest-asyncio"]
[[package]]
name = "regex"
version = "2023.10.3"
@ -2446,6 +2562,21 @@ urllib3 = ">=1.21.1,<3"
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
[[package]]
name = "retry"
version = "0.9.2"
description = "Easy to use retry decorator."
optional = false
python-versions = "*"
files = [
{file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
{file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
]
[package.dependencies]
decorator = ">=3.4.2"
py = ">=1.4.26,<2.0.0"
[[package]]
name = "rich"
version = "13.6.0"
@ -2597,6 +2728,17 @@ files = [
[package.dependencies]
asn1crypto = ">=1.5.1"
[[package]]
name = "semver"
version = "3.0.2"
description = "Python helper for Semantic Versioning (https://semver.org)"
optional = false
python-versions = ">=3.7"
files = [
{file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"},
{file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"},
]
[[package]]
name = "setuptools"
version = "68.2.2"
@ -3601,6 +3743,7 @@ multidict = ">=4.0"
[extras]
dev = ["black", "datasets", "pre-commit", "pytest"]
lancedb = []
legacy = ["faiss-cpu", "numpy"]
local = ["huggingface-hub", "torch", "transformers"]
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"]
@ -3608,4 +3751,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"
[metadata]
lock-version = "2.0"
python-versions = "<3.12,>=3.9"
content-hash = "0fa0b65ce00550c139abcf5b4134e9e5b19b277930782ffe8421afec9d2743e2"
content-hash = "130c4da6c4b59aeb80aecf9549f75bed28123c275e30f159232e491d726034d5"

View File

@ -47,10 +47,12 @@ pg8000 = {version = "^1.30.3", optional = true}
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
websockets = "^12.0"
docstring-parser = "^0.15"
lancedb = {version = "^0.3.3", optional = true}
[tool.poetry.extras]
legacy = ["faiss-cpu", "numpy"]
local = ["torch", "huggingface-hub", "transformers"]
lancedb = ["lancedb"]
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
dev = ["pytest", "black", "pre-commit", "datasets"]

View File

@ -46,6 +46,39 @@ def test_postgres():
)
def test_lancedb():
    """LanceDB data-loading test — currently disabled (the early return below skips the body)."""
    return
    # NOTE(review): everything below is unreachable until the early return is removed
    subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
    import lancedb  # Try to import again after installing

    # override config path with environment variable
    # TODO: make into temporary file
    os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
    print("env", os.getenv("MEMGPT_CONFIG_PATH"))
    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
    print(config)
    config.save()

    # loading dataset from hugging face
    name = "tmp_hf_dataset"
    dataset = load_dataset("MemGPT/example_short_stories")

    cache_dir = os.getenv("HF_DATASETS_CACHE")
    if cache_dir is None:
        # Construct the default path if the environment variable is not set.
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")

    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb")
    load_directory(
        name=name,
        input_dir=cache_dir,
        recursive=True,
    )
def test_chroma():
return

View File

@ -6,10 +6,12 @@ import pytest
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
import pgvector # Try to import again after installing
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.db import PostgresStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.config import MemGPTConfig, AgentConfig
@ -57,6 +59,38 @@ def test_postgres_openai():
# print("...finished")
@pytest.mark.skipif(
    not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
)
def test_lancedb_openai():
    """End-to-end LanceDB test with OpenAI embeddings: insert passages, then vector-query them."""
    assert os.getenv("LANCEDB_TEST_URL") is not None
    if os.getenv("OPENAI_API_KEY") is None:
        return  # soft pass

    config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
    print(config.config_path)
    assert config.archival_storage_uri is not None
    print(config)

    embed_model = embedding_model()

    # fix: original wrote `for passage in passage`, shadowing the list with its own elements
    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-openai")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    # LanceDB ignores the raw query string, so passing None is acceptable here
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
if not os.getenv("PGVECTOR_TEST_DB_URL"):
@ -101,4 +135,33 @@ def test_postgres_local():
# print("...finished")
# test_postgres()
@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
def test_lancedb_local():
    """End-to-end LanceDB test with a local (HF) embedding model: insert passages, then vector-query them."""
    assert os.getenv("LANCEDB_TEST_URL") is not None

    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None

    embed_model = embedding_model()

    # fix: original wrote `for passage in passage`, shadowing the list with its own elements
    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-local")

    for passage in passages:
        db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    # LanceDB ignores the raw query string, so passing None is acceptable here
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"