From 89b51a12cc0e72079b06dc29cfdad77f7cdc2589 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 3 Jan 2024 18:32:13 -0800 Subject: [PATCH] Fix bug with supporting paginated search for recall memory --- memgpt/connectors/db.py | 21 +++++---------------- memgpt/memory.py | 6 ++++-- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index a8d98b3c2..7ae77eed2 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -280,23 +280,12 @@ class SQLStorageConnector(StorageConnector): # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204 session = self.Session() filters = self.get_filters({}) + query = ( + session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).offset(offset) + ) if limit: - results = ( - session.query(self.db_model) - .filter(*filters) - .filter(func.lower(self.db_model.text).contains(func.lower(query))) - .offset(offset) - .all() - ) - else: - results = ( - session.query(self.db_model) - .filter(*filters) - .filter(func.lower(self.db_model.text).contains(func.lower(query))) - .offset(offset) - .limit(limit) - .all() - ) + query = query.limit(limit) + results = query.all() # return [self.type(**vars(result)) for result in results] return [result.to_record() for result in results] diff --git a/memgpt/memory.py b/memgpt/memory.py index 81b939723..59febe4a0 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -310,10 +310,12 @@ class BaseRecallMemory(RecallMemory): self.cache = {} def text_search(self, query_string, count=None, start=None): - self.storage.query_text(query_string, count, start) + results = self.storage.query_text(query_string, count, start) + return results, len(results) def date_search(self, start_date, end_date, count=None, start=None): - self.storage.query_date(start_date, end_date, count, start) + results = self.storage.query_date(start_date, end_date, count, start) + return results, len(results) def __repr__(self) -> str: total = self.storage.size()