mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Lancedb storage integration (#455)
This commit is contained in:
parent
85d0e0dd17
commit
f957209c35
@ -18,5 +18,22 @@ pip install 'pymemgpt[postgres]'
|
||||
You will need to have a URI to a Postgres database which support [pgvector](https://github.com/pgvector/pgvector). You can either use a [hosted provider](https://github.com/pgvector/pgvector/issues/54) or [install pgvector](https://github.com/pgvector/pgvector#installation).
|
||||
|
||||
|
||||
## LanceDB
|
||||
In order to use the LanceDB backend, follow the steps below.
|
||||
|
||||
You have to enable the LanceDB backend by running
|
||||
|
||||
```
|
||||
memgpt configure
|
||||
```
|
||||
and selecting `lancedb` for archival storage, then entering a database URI (e.g. `./.lancedb`). If the archival URI is left empty, the default URI `./.lancedb` is used.
|
||||
|
||||
To enable the LanceDB backend, make sure to install the required dependencies with:
|
||||
```
|
||||
pip install 'pymemgpt[lancedb]'
|
||||
```
|
||||
For more information, check out the [LanceDB docs](https://lancedb.github.io/lancedb/).
|
||||
|
||||
|
||||
## Chroma
|
||||
(Coming soon)
|
||||
|
@ -38,7 +38,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import openai\n",
|
||||
"openai.api_key=\"YOUR_API_KEY\""
|
||||
"\n",
|
||||
"openai.api_key = \"YOUR_API_KEY\""
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -210,7 +210,7 @@ def configure_cli(config: MemGPTConfig):
|
||||
|
||||
def configure_archival_storage(config: MemGPTConfig):
|
||||
# Configure archival storage backend
|
||||
archival_storage_options = ["local", "postgres"]
|
||||
archival_storage_options = ["local", "lancedb", "postgres"]
|
||||
archival_storage_type = questionary.select(
|
||||
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
|
||||
).ask()
|
||||
@ -220,8 +220,17 @@ def configure_archival_storage(config: MemGPTConfig):
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
||||
).ask()
|
||||
|
||||
if archival_storage_type == "lancedb":
|
||||
archival_storage_uri = questionary.text(
|
||||
"Enter lanncedb connection string (e.g. ./.lancedb",
|
||||
default=config.archival_storage_uri if config.archival_storage_uri else "./.lancedb",
|
||||
).ask()
|
||||
|
||||
return archival_storage_type, archival_storage_uri
|
||||
|
||||
# TODO: allow configuring embedding model
|
||||
|
||||
|
||||
@app.command()
|
||||
def configure():
|
||||
|
@ -13,6 +13,7 @@ from tqdm import tqdm
|
||||
from typing import Optional, List, Iterator
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.connectors.storage import StorageConnector, Passage
|
||||
@ -181,3 +182,139 @@ class PostgresStorageConnector(StorageConnector):
|
||||
|
||||
def generate_table_name(self, name: str):
|
||||
return f"memgpt_{self.sanitize_table_name(name)}"
|
||||
|
||||
|
||||
class LanceDBConnector(StorageConnector):
    """Storage via LanceDB.

    Passages are stored as rows with columns ``doc_id``, ``text``,
    ``passage_id``, and ``vector`` (the embedding). The table is created
    lazily on the first insert, since LanceDB infers the schema from data.
    """

    # TODO: this should probably eventually be moved into a parent DB class

    def __init__(self, name: Optional[str] = None):
        config = MemGPTConfig.load()

        # determine table name
        if name:
            self.table_name = self.generate_table_name(name)
        else:
            self.table_name = "lancedb_tbl"

        printd(f"Using table name {self.table_name}")

        # connect to the LanceDB database at the configured URI
        self.uri = config.archival_storage_uri
        if config.archival_storage_uri is None:
            raise ValueError(f"Must specify archival_storage_uri in config {config.config_path}")
        import lancedb

        self.db = lancedb.connect(self.uri)
        # created lazily on first insert (LanceDB needs data to infer the schema)
        self.table = None

    def _row_to_passage(self, row) -> Passage:
        """Convert one LanceDB result row (a dict) into a Passage.

        Embeddings are stored under the ``vector`` key (see insert()).
        """
        return Passage(text=row["text"], embedding=row["vector"], doc_id=row["doc_id"], passage_id=row["passage_id"])

    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
        """Yield all stored passages in chunks of at most page_size.

        NOTE(fix): the previous implementation referenced a nonexistent
        self.Session() (copied from the Postgres connector) and re-fetched the
        first page forever because no offset was ever applied to the search.
        We now fetch all rows once and slice them into pages.
        """
        if self.table is None:
            return
        rows = self.table.search().to_list()
        for start in range(0, len(rows), page_size):
            yield [self._row_to_passage(row) for row in rows[start : start + page_size]]

    def get_all(self, limit=10) -> List[Passage]:
        """Return up to `limit` passages from the table."""
        db_passages = self.table.search().limit(limit).to_list()
        return [self._row_to_passage(p) for p in db_passages]

    def get(self, id: str) -> Optional[Passage]:
        """Return the passage with the given passage_id, or None if not found.

        NOTE(fix): filtering goes through search().where(); a bare LanceDB
        table has no .where(). The previous code also indexed the result
        *list* with string keys and read a nonexistent "embedding" column.
        """
        db_passage = self.table.search().where(f"passage_id={id}").to_list()
        if len(db_passage) == 0:
            return None
        return self._row_to_passage(db_passage[0])

    def size(self) -> int:
        """Return the number of rows in the table (0 if it does not exist yet)."""
        if self.table:
            return len(self.table.search().to_list())
        else:
            print(f"Table with name {self.table_name} not present")
            return 0

    def insert(self, passage: Passage):
        """Insert a single passage, creating the table on first use."""
        data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]

        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def insert_many(self, passages: List[Passage], show_progress=True):
        """Insert many passages in one call; shows a tqdm bar when show_progress is True."""
        data = []
        iterable = tqdm(passages) if show_progress else passages
        for passage in iterable:
            temp_dict = {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
            data.append(temp_dict)

        if self.table:
            self.table.add(data)
        else:
            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Return the top_k passages nearest to query_vec (the text query is unused).

        Assumes query_vec has the same length as the embeddings inside the table.
        NOTE(fix): results must be materialized with .to_list() before row
        access, and embeddings live under the "vector" key, not "embedding".
        """
        results = self.table.search(query_vec).limit(top_k).to_list()

        # Convert the result rows into Passage objects
        return [self._row_to_passage(result) for result in results]

    def delete(self):
        """Drop the passage table from the database."""
        self.db.drop_table(self.table_name)

    def save(self):
        # LanceDB persists writes immediately; nothing to flush.
        return

    @staticmethod
    def list_loaded_data():
        """List the names of data sources previously loaded into this database."""
        config = MemGPTConfig.load()
        import lancedb

        db = lancedb.connect(config.archival_storage_uri)

        # MemGPT tables are namespaced with a "memgpt_" prefix (see generate_table_name)
        tables = db.table_names()
        tables = [table for table in tables if table.startswith("memgpt_")]
        tables = [table.replace("memgpt_", "") for table in tables]
        return tables

    def sanitize_table_name(self, name: str) -> str:
        """Normalize a user-provided name into a safe table identifier."""
        # Remove leading and trailing whitespace
        name = name.strip()

        # Replace spaces and invalid characters with underscores
        name = re.sub(r"\s+|\W+", "_", name)

        # Truncate to the maximum identifier length
        max_length = 63
        if len(name) > max_length:
            name = name[:max_length].rstrip("_")

        # Convert to lowercase
        name = name.lower()

        return name

    def generate_table_name(self, name: str):
        """Prefix sanitized names with 'memgpt_' to namespace MemGPT tables."""
        return f"memgpt_{self.sanitize_table_name(name)}"
|
||||
|
@ -48,6 +48,11 @@ class StorageConnector:
|
||||
|
||||
return PostgresStorageConnector(name=name, agent_config=agent_config)
|
||||
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector(name=name)
|
||||
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
@ -62,6 +67,11 @@ class StorageConnector:
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
|
||||
return PostgresStorageConnector.list_loaded_data()
|
||||
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector.list_loaded_data()
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||
|
||||
|
147
poetry.lock
generated
147
poetry.lock
generated
@ -250,6 +250,17 @@ d = ["aiohttp (>=3.7.4)"]
|
||||
jupyter = ["ipython (>=7.8.0)", "tokenize-rt (>=3.2.0)"]
|
||||
uvloop = ["uvloop (>=0.15.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "cachetools"
|
||||
version = "5.3.2"
|
||||
description = "Extensible memoizing collections and decorators"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"},
|
||||
{file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2023.7.22"
|
||||
@ -482,6 +493,17 @@ tests = ["Pillow (>=6.2.1)", "absl-py", "apache-beam (>=2.26.0,<2.44.0)", "elast
|
||||
torch = ["torch"]
|
||||
vision = ["Pillow (>=6.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "decorator"
|
||||
version = "5.1.1"
|
||||
description = "Decorators for Humans"
|
||||
optional = false
|
||||
python-versions = ">=3.5"
|
||||
files = [
|
||||
{file = "decorator-5.1.1-py3-none-any.whl", hash = "sha256:b8c3f85900b9dc423225913c5aace94729fe1fa9763b38939a95226f02d37186"},
|
||||
{file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "demjson3"
|
||||
version = "3.0.6"
|
||||
@ -509,6 +531,20 @@ wrapt = ">=1.10,<2"
|
||||
[package.extras]
|
||||
dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"]
|
||||
|
||||
[[package]]
|
||||
name = "deprecation"
|
||||
version = "2.1.0"
|
||||
description = "A library to handle automated deprecations"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"},
|
||||
{file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
packaging = "*"
|
||||
|
||||
[[package]]
|
||||
name = "dill"
|
||||
version = "0.3.7"
|
||||
@ -910,6 +946,39 @@ files = [
|
||||
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lancedb"
|
||||
version = "0.3.3"
|
||||
description = "lancedb"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "lancedb-0.3.3-py3-none-any.whl", hash = "sha256:67ccea22a6cb39c688041f7469be778a2e64b141db80866f6f0dec25a3122aff"},
|
||||
{file = "lancedb-0.3.3.tar.gz", hash = "sha256:8d8a9c2b107154ee57f6f75957d215719a204cd64c9efbe7095eaf41b43c2a29"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = "*"
|
||||
attrs = ">=21.3.0"
|
||||
cachetools = "*"
|
||||
click = ">=8.1.7"
|
||||
deprecation = "*"
|
||||
pydantic = ">=1.10"
|
||||
pylance = "0.8.10"
|
||||
pyyaml = ">=6.0"
|
||||
ratelimiter = ">=1.0,<2.0"
|
||||
requests = ">=2.31.0"
|
||||
retry = ">=0.9.2"
|
||||
semver = ">=3.0"
|
||||
tqdm = ">=4.1.0"
|
||||
|
||||
[package.extras]
|
||||
clip = ["open-clip", "pillow", "torch"]
|
||||
dev = ["black", "pre-commit", "ruff"]
|
||||
docs = ["mkdocs", "mkdocs-jupyter", "mkdocs-material", "mkdocstrings[python]"]
|
||||
embeddings = ["cohere", "open-clip-torch", "openai", "pillow", "sentence-transformers", "torch"]
|
||||
tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock", "requests"]
|
||||
|
||||
[[package]]
|
||||
name = "langchain"
|
||||
version = "0.0.333"
|
||||
@ -1936,11 +2005,22 @@ files = [
|
||||
{file = "psycopg2_binary-2.9.9-cp39-cp39-win_amd64.whl", hash = "sha256:f7ae5d65ccfbebdfa761585228eb4d0df3a8b15cfb53bd953e713e09fbb12957"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "py"
|
||||
version = "1.11.0"
|
||||
description = "library with cross-python path, ini-parsing, io, code, log facilities"
|
||||
optional = false
|
||||
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*"
|
||||
files = [
|
||||
{file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"},
|
||||
{file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pyarrow"
|
||||
version = "14.0.1"
|
||||
description = "Python library for Apache Arrow"
|
||||
optional = true
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"},
|
||||
@ -2135,6 +2215,28 @@ files = [
|
||||
[package.extras]
|
||||
plugins = ["importlib-metadata"]
|
||||
|
||||
[[package]]
|
||||
name = "pylance"
|
||||
version = "0.8.10"
|
||||
description = "python wrapper for Lance columnar format"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "pylance-0.8.10-cp38-abi3-macosx_10_15_x86_64.whl", hash = "sha256:aecf053e12f13a1810a70c786c1e73bcf3ffe7287c0bfe2cc5df77a91f0a084c"},
|
||||
{file = "pylance-0.8.10-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:b778fbcfae2e9186053292b7bd3fcd28efc92bd0471f733f8dbf4a1f840c9ce4"},
|
||||
{file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ea617723593d4cc0d2faaaf4a861e31ae3c8657517b83e2fb99e5f68c0c1481"},
|
||||
{file = "pylance-0.8.10-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b1bf9cc33d7095196931f96588d733e80d69a4c312b5352d9dab9a0d5a84c8f"},
|
||||
{file = "pylance-0.8.10-cp38-abi3-win_amd64.whl", hash = "sha256:d700f874710c6f1a5567c6e4f98426c22aebf9937dcdbd7305573f519712b683"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.22"
|
||||
pyarrow = ">=10"
|
||||
|
||||
[package.extras]
|
||||
benchmarks = ["pytest-benchmark"]
|
||||
tests = ["duckdb", "ml_dtypes", "pandas (>=1.4,<2.1)", "polars[pandas,pyarrow]", "pytest", "semver", "tensorflow", "tqdm"]
|
||||
|
||||
[[package]]
|
||||
name = "pymupdf"
|
||||
version = "1.23.6"
|
||||
@ -2328,6 +2430,20 @@ files = [
|
||||
[package.dependencies]
|
||||
prompt_toolkit = ">=2.0,<=3.0.36"
|
||||
|
||||
[[package]]
|
||||
name = "ratelimiter"
|
||||
version = "1.2.0.post0"
|
||||
description = "Simple python rate limiting object"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "ratelimiter-1.2.0.post0-py3-none-any.whl", hash = "sha256:a52be07bc0bb0b3674b4b304550f10c769bbb00fead3072e035904474259809f"},
|
||||
{file = "ratelimiter-1.2.0.post0.tar.gz", hash = "sha256:5c395dcabdbbde2e5178ef3f89b568a3066454a6ddc223b76473dac22f89b4f7"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
test = ["pytest (>=3.0)", "pytest-asyncio"]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "2023.10.3"
|
||||
@ -2446,6 +2562,21 @@ urllib3 = ">=1.21.1,<3"
|
||||
socks = ["PySocks (>=1.5.6,!=1.5.7)"]
|
||||
use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"]
|
||||
|
||||
[[package]]
|
||||
name = "retry"
|
||||
version = "0.9.2"
|
||||
description = "Easy to use retry decorator."
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
files = [
|
||||
{file = "retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606"},
|
||||
{file = "retry-0.9.2.tar.gz", hash = "sha256:f8bfa8b99b69c4506d6f5bd3b0aabf77f98cdb17f3c9fc3f5ca820033336fba4"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
decorator = ">=3.4.2"
|
||||
py = ">=1.4.26,<2.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "rich"
|
||||
version = "13.6.0"
|
||||
@ -2597,6 +2728,17 @@ files = [
|
||||
[package.dependencies]
|
||||
asn1crypto = ">=1.5.1"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "3.0.2"
|
||||
description = "Python helper for Semantic Versioning (https://semver.org)"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "semver-3.0.2-py3-none-any.whl", hash = "sha256:b1ea4686fe70b981f85359eda33199d60c53964284e0cfb4977d243e37cf4bf4"},
|
||||
{file = "semver-3.0.2.tar.gz", hash = "sha256:6253adb39c70f6e51afed2fa7152bcd414c411286088fb4b9effb133885ab4cc"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "setuptools"
|
||||
version = "68.2.2"
|
||||
@ -3601,6 +3743,7 @@ multidict = ">=4.0"
|
||||
|
||||
[extras]
|
||||
dev = ["black", "datasets", "pre-commit", "pytest"]
|
||||
lancedb = []
|
||||
legacy = ["faiss-cpu", "numpy"]
|
||||
local = ["huggingface-hub", "torch", "transformers"]
|
||||
postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"]
|
||||
@ -3608,4 +3751,4 @@ postgres = ["pg8000", "pgvector", "psycopg", "psycopg-binary", "psycopg2-binary"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.12,>=3.9"
|
||||
content-hash = "0fa0b65ce00550c139abcf5b4134e9e5b19b277930782ffe8421afec9d2743e2"
|
||||
content-hash = "130c4da6c4b59aeb80aecf9549f75bed28123c275e30f159232e491d726034d5"
|
||||
|
@ -47,10 +47,12 @@ pg8000 = {version = "^1.30.3", optional = true}
|
||||
torch = {version = ">=2.0.0, !=2.0.1, !=2.1.0", optional = true}
|
||||
websockets = "^12.0"
|
||||
docstring-parser = "^0.15"
|
||||
lancedb = {version = "^0.3.3", optional = true}
|
||||
|
||||
[tool.poetry.extras]
|
||||
legacy = ["faiss-cpu", "numpy"]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
lancedb = ["lancedb"]
|
||||
postgres = ["pgvector", "psycopg", "psycopg-binary", "psycopg2-binary", "pg8000"]
|
||||
dev = ["pytest", "black", "pre-commit", "datasets"]
|
||||
|
||||
|
@ -46,6 +46,39 @@ def test_postgres():
|
||||
)
|
||||
|
||||
|
||||
def test_lancedb():
    # Disabled for now: bail out immediately (everything below never runs).
    return

    subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
    import lancedb  # try the import again after installing

    # Override the config path with an environment variable.
    # TODO: make into temporary file
    os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
    print("env", os.getenv("MEMGPT_CONFIG_PATH"))
    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
    print(config)
    config.save()

    # Load an example dataset from Hugging Face.
    name = "tmp_hf_dataset"
    dataset = load_dataset("MemGPT/example_short_stories")

    # Resolve the HF datasets cache dir, falling back to the default location
    # when the environment variable is not set.
    cache_dir = os.getenv("HF_DATASETS_CACHE")
    if cache_dir is None:
        cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")

    config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb")

    load_directory(
        name=name,
        input_dir=cache_dir,
        recursive=True,
    )
|
||||
|
||||
|
||||
def test_chroma():
|
||||
return
|
||||
|
||||
|
@ -6,10 +6,12 @@ import pytest
|
||||
subprocess.check_call(
|
||||
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
||||
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||
import pgvector # Try to import again after installing
|
||||
|
||||
from memgpt.connectors.storage import StorageConnector, Passage
|
||||
from memgpt.connectors.db import PostgresStorageConnector
|
||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
|
||||
@ -57,6 +59,38 @@ def test_postgres_openai():
|
||||
# print("...finished")
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not os.getenv("LANCEDB_TEST_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing LANCEDB URI and/or OpenAI API key"
)
def test_lancedb_openai():
    """End-to-end check of the LanceDB connector using OpenAI embeddings."""
    assert os.getenv("LANCEDB_TEST_URL") is not None
    if os.getenv("OPENAI_API_KEY") is None:
        return  # soft pass

    # Point the config at the LanceDB test database.
    config = MemGPTConfig(archival_storage_type="lancedb", archival_storage_uri=os.getenv("LANCEDB_TEST_URL"))
    print(config.config_path)
    assert config.archival_storage_uri is not None
    print(config)

    embed_model = embedding_model()

    texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-openai")

    # Embed and insert each passage.
    for text in texts:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))

    print(db.get_all())

    # A nearest-neighbor query should surface the "wept" passage first.
    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
|
||||
def test_postgres_local():
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
@ -101,4 +135,33 @@ def test_postgres_local():
|
||||
# print("...finished")
|
||||
|
||||
|
||||
# test_postgres()
|
||||
@pytest.mark.skipif(not os.getenv("LANCEDB_TEST_URL"), reason="Missing LanceDB URI")
def test_lancedb_local():
    """End-to-end check of the LanceDB connector using a local HF embedding model."""
    assert os.getenv("LANCEDB_TEST_URL") is not None

    config = MemGPTConfig(
        archival_storage_type="lancedb",
        archival_storage_uri=os.getenv("LANCEDB_TEST_URL"),
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None

    embed_model = embedding_model()

    texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = LanceDBConnector(name="test-local")

    # Embed and insert each passage.
    for text in texts:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))

    print(db.get_all())

    # A nearest-neighbor query should surface the "wept" passage first.
    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
Loading…
Reference in New Issue
Block a user