mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
Support recall and archival memory for postgres
working test
This commit is contained in:
parent
9f3806dfcb
commit
9b3d59e016
@ -7,7 +7,7 @@ import traceback
|
|||||||
from memgpt.persistence_manager import LocalStateManager
|
from memgpt.persistence_manager import LocalStateManager
|
||||||
from memgpt.config import AgentConfig, MemGPTConfig
|
from memgpt.config import AgentConfig, MemGPTConfig
|
||||||
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||||
from memgpt.memory import CoreMemory as Memory, summarize_messages
|
from memgpt.memory import CoreMemory as InContextMemory, summarize_messages
|
||||||
from memgpt.openai_tools import create, is_context_overflow_error
|
from memgpt.openai_tools import create, is_context_overflow_error
|
||||||
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response
|
from memgpt.utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff, validate_function_response
|
||||||
from memgpt.constants import (
|
from memgpt.constants import (
|
||||||
@ -29,7 +29,7 @@ def initialize_memory(ai_notes, human_notes):
|
|||||||
raise ValueError(ai_notes)
|
raise ValueError(ai_notes)
|
||||||
if human_notes is None:
|
if human_notes is None:
|
||||||
raise ValueError(human_notes)
|
raise ValueError(human_notes)
|
||||||
memory = Memory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
|
memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
|
||||||
memory.edit_persona(ai_notes)
|
memory.edit_persona(ai_notes)
|
||||||
memory.edit_human(human_notes)
|
memory.edit_human(human_notes)
|
||||||
return memory
|
return memory
|
||||||
@ -240,6 +240,7 @@ class Agent(object):
|
|||||||
|
|
||||||
### Local state management
|
### Local state management
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
|
# TODO: select specific variables for the saves state (to eventually move to a DB) rather than checkpointing everything in the class
|
||||||
return {
|
return {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"system": self.system,
|
"system": self.system,
|
||||||
@ -249,32 +250,29 @@ class Agent(object):
|
|||||||
"memory": self.memory.to_dict(),
|
"memory": self.memory.to_dict(),
|
||||||
}
|
}
|
||||||
|
|
||||||
def save_to_json_file(self, filename):
|
def save_agent_state_json(self, filename):
|
||||||
|
"""Save agent state to JSON"""
|
||||||
with open(filename, "w") as file:
|
with open(filename, "w") as file:
|
||||||
json.dump(self.to_dict(), file)
|
json.dump(self.to_dict(), file)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Save agent state locally"""
|
"""Save agent state locally"""
|
||||||
|
|
||||||
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
|
|
||||||
agent_name = self.config.name # TODO: fix
|
|
||||||
|
|
||||||
# save config
|
# save config
|
||||||
self.config.save()
|
self.config.save()
|
||||||
|
|
||||||
# save agent state
|
# save agent state to timestamped file
|
||||||
|
timestamp = get_local_time().replace(" ", "_").replace(":", "_")
|
||||||
filename = f"{timestamp}.json"
|
filename = f"{timestamp}.json"
|
||||||
os.makedirs(self.config.save_state_dir(), exist_ok=True)
|
os.makedirs(self.config.save_state_dir(), exist_ok=True)
|
||||||
self.save_to_json_file(os.path.join(self.config.save_state_dir(), filename))
|
self.save_agent_state_json(os.path.join(self.config.save_state_dir(), filename))
|
||||||
|
|
||||||
# save the persistence manager too
|
# save the persistence manager too (recall/archival memory)
|
||||||
filename = f"{timestamp}.persistence.pickle"
|
self.persistence_manager.save()
|
||||||
os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
|
|
||||||
self.persistence_manager.save(os.path.join(self.config.save_persistence_manager_dir(), filename))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load_agent(cls, interface, agent_config: AgentConfig):
|
def load_agent(cls, interface, agent_config: AgentConfig):
|
||||||
"""Load saved agent state"""
|
"""Load saved agent state based on agent_config"""
|
||||||
# TODO: support loading from specific file
|
# TODO: support loading from specific file
|
||||||
agent_name = agent_config.name
|
agent_name = agent_config.name
|
||||||
|
|
||||||
@ -290,10 +288,7 @@ class Agent(object):
|
|||||||
state = json.load(open(filename, "r"))
|
state = json.load(open(filename, "r"))
|
||||||
|
|
||||||
# load persistence manager
|
# load persistence manager
|
||||||
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
|
persistence_manager = LocalStateManager.load(agent_config)
|
||||||
directory = agent_config.save_persistence_manager_dir()
|
|
||||||
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
|
|
||||||
persistence_manager = LocalStateManager.load(os.path.join(directory, filename), agent_config)
|
|
||||||
|
|
||||||
# need to dynamically link the functions
|
# need to dynamically link the functions
|
||||||
# the saved agent.functions will just have the schemas, but we need to
|
# the saved agent.functions will just have the schemas, but we need to
|
||||||
@ -354,70 +349,6 @@ class Agent(object):
|
|||||||
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
|
agent.memory = initialize_memory(state["memory"]["persona"], state["memory"]["human"])
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, state, interface, persistence_manager):
|
|
||||||
model = state["model"]
|
|
||||||
system = state["system"]
|
|
||||||
functions = state["functions"]
|
|
||||||
messages = state["messages"]
|
|
||||||
try:
|
|
||||||
messages_total = state["messages_total"]
|
|
||||||
except KeyError:
|
|
||||||
messages_total = len(messages) - 1
|
|
||||||
# memory requires a nested load
|
|
||||||
memory_dict = state["memory"]
|
|
||||||
persona_notes = memory_dict["persona"]
|
|
||||||
human_notes = memory_dict["human"]
|
|
||||||
|
|
||||||
# Two-part load
|
|
||||||
new_agent = cls(
|
|
||||||
model=model,
|
|
||||||
system=system,
|
|
||||||
functions=functions,
|
|
||||||
interface=interface,
|
|
||||||
persistence_manager=persistence_manager,
|
|
||||||
persistence_manager_init=False,
|
|
||||||
persona_notes=persona_notes,
|
|
||||||
human_notes=human_notes,
|
|
||||||
messages_total=messages_total,
|
|
||||||
)
|
|
||||||
new_agent._messages = messages
|
|
||||||
return new_agent
|
|
||||||
|
|
||||||
def load_inplace(self, state):
|
|
||||||
self.model = state["model"]
|
|
||||||
self.system = state["system"]
|
|
||||||
self.functions = state["functions"]
|
|
||||||
# memory requires a nested load
|
|
||||||
memory_dict = state["memory"]
|
|
||||||
persona_notes = memory_dict["persona"]
|
|
||||||
human_notes = memory_dict["human"]
|
|
||||||
self.memory = initialize_memory(persona_notes, human_notes)
|
|
||||||
# messages also
|
|
||||||
self._messages = state["messages"]
|
|
||||||
try:
|
|
||||||
self.messages_total = state["messages_total"]
|
|
||||||
except KeyError:
|
|
||||||
self.messages_total = len(self.messages) - 1 # -system
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_from_json(cls, json_state, interface, persistence_manager):
|
|
||||||
state = json.loads(json_state)
|
|
||||||
return cls.load(state, interface, persistence_manager)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_from_json_file(cls, json_file, interface, persistence_manager):
|
|
||||||
with open(json_file, "r") as file:
|
|
||||||
state = json.load(file)
|
|
||||||
return cls.load(state, interface, persistence_manager)
|
|
||||||
|
|
||||||
def load_from_json_file_inplace(self, json_file):
|
|
||||||
# Load in-place
|
|
||||||
# No interface arg needed, we can use the current one
|
|
||||||
with open(json_file, "r") as file:
|
|
||||||
state = json.load(file)
|
|
||||||
self.load_inplace(state)
|
|
||||||
|
|
||||||
def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
|
def verify_first_message_correctness(self, response, require_send_message=True, require_monologue=False):
|
||||||
"""Can be used to enforce that the first message always uses send_message"""
|
"""Can be used to enforce that the first message always uses send_message"""
|
||||||
response_message = response.choices[0].message
|
response_message = response.choices[0].message
|
||||||
|
@ -418,6 +418,22 @@ def configure_archival_storage(config: MemGPTConfig):
|
|||||||
# TODO: allow configuring embedding model
|
# TODO: allow configuring embedding model
|
||||||
|
|
||||||
|
|
||||||
|
def configure_recall_storage(config: MemGPTConfig):
|
||||||
|
# Configure recall storage backend
|
||||||
|
recall_storage_options = ["local", "postgres"]
|
||||||
|
recall_storage_type = questionary.select(
|
||||||
|
"Select storage backend for recall data:", recall_storage_options, default=config.recall_storage_type
|
||||||
|
).ask()
|
||||||
|
recall_storage_uri, recall_storage_path = None, None
|
||||||
|
# configure postgres
|
||||||
|
if recall_storage_type == "postgres":
|
||||||
|
recall_storage_uri = questionary.text(
|
||||||
|
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||||
|
default=config.recall_storage_uri if config.recall_storage_uri else "",
|
||||||
|
).ask()
|
||||||
|
return recall_storage_type, recall_storage_uri, recall_storage_path
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
@app.command()
|
||||||
def configure():
|
def configure():
|
||||||
"""Updates default MemGPT configurations"""
|
"""Updates default MemGPT configurations"""
|
||||||
@ -430,17 +446,30 @@ def configure():
|
|||||||
|
|
||||||
# Will pre-populate with defaults, or what the user previously set
|
# Will pre-populate with defaults, or what the user previously set
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
try:
|
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
||||||
model_endpoint_type, model_endpoint = configure_llm_endpoint(config)
|
model, model_wrapper, context_window = configure_model(config, model_endpoint_type)
|
||||||
model, model_wrapper, context_window = configure_model(
|
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
|
||||||
config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint
|
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
||||||
)
|
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
|
||||||
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config)
|
recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage(config)
|
||||||
default_preset, default_persona, default_human, default_agent = configure_cli(config)
|
|
||||||
archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config)
|
# check credentials
|
||||||
except ValueError as e:
|
azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials()
|
||||||
typer.secho(str(e), fg=typer.colors.RED)
|
openai_key = get_openai_credentials()
|
||||||
return
|
if model_endpoint_type == "azure" or embedding_endpoint_type == "azure":
|
||||||
|
if all([azure_key, azure_endpoint, azure_version]):
|
||||||
|
print(f"Using Microsoft endpoint {azure_endpoint}.")
|
||||||
|
if all([azure_deployment, azure_embedding_deployment]):
|
||||||
|
print(f"Using deployment id {azure_deployment}")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again."
|
||||||
|
)
|
||||||
|
if model_endpoint_type == "openai" or embedding_endpoint_type == "openai":
|
||||||
|
if not openai_key:
|
||||||
|
raise ValueError(
|
||||||
|
"Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again."
|
||||||
|
)
|
||||||
|
|
||||||
config = MemGPTConfig(
|
config = MemGPTConfig(
|
||||||
# model configs
|
# model configs
|
||||||
@ -470,6 +499,10 @@ def configure():
|
|||||||
archival_storage_type=archival_storage_type,
|
archival_storage_type=archival_storage_type,
|
||||||
archival_storage_uri=archival_storage_uri,
|
archival_storage_uri=archival_storage_uri,
|
||||||
archival_storage_path=archival_storage_path,
|
archival_storage_path=archival_storage_path,
|
||||||
|
# recall storage
|
||||||
|
recall_storage_type=recall_storage_type,
|
||||||
|
recall_storage_uri=recall_storage_uri,
|
||||||
|
recall_storage_path=recall_storage_path,
|
||||||
)
|
)
|
||||||
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
|
typer.secho(f"📖 Saving config to {config.config_path}", fg=typer.colors.GREEN)
|
||||||
config.save()
|
config.save()
|
||||||
|
@ -134,6 +134,9 @@ class MemGPTConfig:
|
|||||||
"archival_storage_type": get_field(config, "archival_storage", "type"),
|
"archival_storage_type": get_field(config, "archival_storage", "type"),
|
||||||
"archival_storage_path": get_field(config, "archival_storage", "path"),
|
"archival_storage_path": get_field(config, "archival_storage", "path"),
|
||||||
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
|
"archival_storage_uri": get_field(config, "archival_storage", "uri"),
|
||||||
|
"recall_storage_type": get_field(config, "recall_storage", "type"),
|
||||||
|
"recall_storage_path": get_field(config, "recall_storage", "path"),
|
||||||
|
"recall_storage_uri": get_field(config, "recall_storage", "uri"),
|
||||||
"anon_clientid": get_field(config, "client", "anon_clientid"),
|
"anon_clientid": get_field(config, "client", "anon_clientid"),
|
||||||
"config_path": config_path,
|
"config_path": config_path,
|
||||||
"memgpt_version": get_field(config, "version", "memgpt_version"),
|
"memgpt_version": get_field(config, "version", "memgpt_version"),
|
||||||
@ -187,6 +190,11 @@ class MemGPTConfig:
|
|||||||
set_field(config, "archival_storage", "path", self.archival_storage_path)
|
set_field(config, "archival_storage", "path", self.archival_storage_path)
|
||||||
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
|
set_field(config, "archival_storage", "uri", self.archival_storage_uri)
|
||||||
|
|
||||||
|
# recall storage
|
||||||
|
set_field(config, "recall_storage", "type", self.recall_storage_type)
|
||||||
|
set_field(config, "recall_storage", "path", self.recall_storage_path)
|
||||||
|
set_field(config, "recall_storage", "uri", self.recall_storage_uri)
|
||||||
|
|
||||||
# set version
|
# set version
|
||||||
set_field(config, "version", "memgpt_version", memgpt.__version__)
|
set_field(config, "version", "memgpt_version", memgpt.__version__)
|
||||||
|
|
||||||
|
@ -2,22 +2,10 @@ import chromadb
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Optional, List, Iterator
|
from typing import Optional, List, Iterator
|
||||||
from memgpt.connectors.storage import StorageConnector, Passage
|
from memgpt.connectors.storage import StorageConnector, TableType
|
||||||
from memgpt.utils import printd
|
from memgpt.utils import printd
|
||||||
from memgpt.config import AgentConfig, MemGPTConfig
|
from memgpt.config import AgentConfig, MemGPTConfig
|
||||||
|
from memgpt.data_types import Record, Message, Passage
|
||||||
|
|
||||||
def create_chroma_client():
|
|
||||||
config = MemGPTConfig.load()
|
|
||||||
# create chroma client
|
|
||||||
if config.archival_storage_path:
|
|
||||||
client = chromadb.PersistentClient(config.archival_storage_path)
|
|
||||||
else:
|
|
||||||
# assume uri={ip}:{port}
|
|
||||||
ip = config.archival_storage_uri.split(":")[0]
|
|
||||||
port = config.archival_storage_uri.split(":")[1]
|
|
||||||
client = chromadb.HttpClient(host=ip, port=port)
|
|
||||||
return client
|
|
||||||
|
|
||||||
|
|
||||||
class ChromaStorageConnector(StorageConnector):
|
class ChromaStorageConnector(StorageConnector):
|
||||||
@ -25,21 +13,24 @@ class ChromaStorageConnector(StorageConnector):
|
|||||||
|
|
||||||
# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
|
# WARNING: This is not thread safe. Do NOT do concurrent access to the same collection.
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||||
# determine table name
|
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||||
if agent_config:
|
config = MemGPTConfig.load()
|
||||||
assert name is None, f"Cannot specify both agent config and name {name}"
|
|
||||||
self.table_name = self.generate_table_name_agent(agent_config)
|
# supported table types
|
||||||
elif name:
|
self.supported_types = [TableType.ARCHIVAL_MEMORY]
|
||||||
assert agent_config is None, f"Cannot specify both agent config and name {name}"
|
|
||||||
self.table_name = self.generate_table_name(name)
|
if table_type not in self.supported_types:
|
||||||
|
raise ValueError(f"Table type {table_type} not supported by Chroma")
|
||||||
|
|
||||||
|
# create chroma client
|
||||||
|
if config.archival_storage_path:
|
||||||
|
self.client = chromadb.PersistentClient(config.archival_storage_path)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Must specify either agent config or name")
|
# assume uri={ip}:{port}
|
||||||
|
ip = config.archival_storage_uri.split(":")[0]
|
||||||
printd(f"Using table name {self.table_name}")
|
port = config.archival_storage_uri.split(":")[1]
|
||||||
|
self.client = chromadb.HttpClient(host=ip, port=port)
|
||||||
# create client
|
|
||||||
self.client = create_chroma_client()
|
|
||||||
|
|
||||||
# get a collection or create if it doesn't exist already
|
# get a collection or create if it doesn't exist already
|
||||||
self.collection = self.client.get_or_create_collection(self.table_name)
|
self.collection = self.client.get_or_create_collection(self.table_name)
|
||||||
@ -98,28 +89,5 @@ class ChromaStorageConnector(StorageConnector):
|
|||||||
collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
|
collections = [c.name for c in collections if c.name.startswith("memgpt_") and not c.name.startswith("memgpt_agent_")]
|
||||||
return collections
|
return collections
|
||||||
|
|
||||||
def sanitize_table_name(self, name: str) -> str:
|
|
||||||
# Remove leading and trailing whitespace
|
|
||||||
name = name.strip()
|
|
||||||
|
|
||||||
# Replace spaces and invalid characters with underscores
|
|
||||||
name = re.sub(r"\s+|\W+", "_", name)
|
|
||||||
|
|
||||||
# Truncate to the maximum identifier length (e.g., 63 for PostgreSQL)
|
|
||||||
max_length = 63
|
|
||||||
if len(name) > max_length:
|
|
||||||
name = name[:max_length].rstrip("_")
|
|
||||||
|
|
||||||
# Convert to lowercase
|
|
||||||
name = name.lower()
|
|
||||||
|
|
||||||
return name
|
|
||||||
|
|
||||||
def generate_table_name_agent(self, agent_config: AgentConfig):
|
|
||||||
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
|
|
||||||
|
|
||||||
def generate_table_name(self, name: str):
|
|
||||||
return f"memgpt_{self.sanitize_table_name(name)}"
|
|
||||||
|
|
||||||
def size(self) -> int:
|
def size(self) -> int:
|
||||||
return self.collection.count()
|
return self.collection.count()
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from pgvector.psycopg import register_vector
|
from pgvector.psycopg import register_vector
|
||||||
from pgvector.sqlalchemy import Vector, JSON, Text
|
from pgvector.sqlalchemy import Vector
|
||||||
import psycopg
|
import psycopg
|
||||||
|
|
||||||
|
|
||||||
@ -8,6 +8,8 @@ from sqlalchemy.orm import sessionmaker, mapped_column
|
|||||||
from sqlalchemy.ext.declarative import declarative_base
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
from sqlalchemy import Column, BIGINT, String, DateTime
|
from sqlalchemy import Column, BIGINT, String, DateTime
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy_json import mutable_json_type
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -28,15 +30,10 @@ from datetime import datetime
|
|||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
def parse_formatted_time(formatted_time):
|
|
||||||
# parse times returned by memgpt.utils.get_formatted_time()
|
|
||||||
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_model(table_name: str, table_type: TableType):
|
def get_db_model(table_name: str, table_type: TableType):
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
|
|
||||||
if table_name == TableType.ARCHIVAL_MEMORY:
|
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||||
# create schema for archival memory
|
# create schema for archival memory
|
||||||
class PassageModel(Base):
|
class PassageModel(Base):
|
||||||
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
"""Defines data model for storing Passages (consisting of text, embedding)"""
|
||||||
@ -45,16 +42,29 @@ def get_db_model(table_name: str, table_type: TableType):
|
|||||||
|
|
||||||
# Assuming passage_id is the primary key
|
# Assuming passage_id is the primary key
|
||||||
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
id = Column(BIGINT, primary_key=True, nullable=False, autoincrement=True)
|
||||||
|
user_id = Column(String, nullable=False)
|
||||||
|
text = Column(String, nullable=False)
|
||||||
doc_id = Column(String)
|
doc_id = Column(String)
|
||||||
agent_id = Column(String)
|
agent_id = Column(String)
|
||||||
data_source = Column(String) # agent_name if agent, data_source name if from data source
|
data_source = Column(String) # agent_name if agent, data_source name if from data source
|
||||||
text = Column(String, nullable=False)
|
|
||||||
embedding = mapped_column(Vector(config.embedding_dim))
|
embedding = mapped_column(Vector(config.embedding_dim))
|
||||||
metadata_ = Column(JSON(astext_type=Text()))
|
metadata_ = Column(mutable_json_type(dbtype=JSONB, nested=True))
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
return f"<Passage(passage_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||||
|
|
||||||
|
def to_record(self):
|
||||||
|
return Passage(
|
||||||
|
text=self.text,
|
||||||
|
embedding=self.embedding,
|
||||||
|
doc_id=self.doc_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
id=self.id,
|
||||||
|
data_source=self.data_source,
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
metadata=self.metadata_,
|
||||||
|
)
|
||||||
|
|
||||||
"""Create database model for table_name"""
|
"""Create database model for table_name"""
|
||||||
class_name = f"{table_name.capitalize()}Model"
|
class_name = f"{table_name.capitalize()}Model"
|
||||||
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
Model = type(class_name, (PassageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}})
|
||||||
@ -71,7 +81,7 @@ def get_db_model(table_name: str, table_type: TableType):
|
|||||||
user_id = Column(String, nullable=False)
|
user_id = Column(String, nullable=False)
|
||||||
agent_id = Column(String, nullable=False)
|
agent_id = Column(String, nullable=False)
|
||||||
role = Column(String, nullable=False)
|
role = Column(String, nullable=False)
|
||||||
content = Column(String, nullable=False)
|
text = Column(String, nullable=False)
|
||||||
model = Column(String, nullable=False)
|
model = Column(String, nullable=False)
|
||||||
function_name = Column(String)
|
function_name = Column(String)
|
||||||
function_args = Column(String)
|
function_args = Column(String)
|
||||||
@ -82,7 +92,22 @@ def get_db_model(table_name: str, table_type: TableType):
|
|||||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Message(message_id='{self.id}', content='{self.content}', embedding='{self.embedding})>"
|
return f"<Message(message_id='{self.id}', text='{self.text}', embedding='{self.embedding})>"
|
||||||
|
|
||||||
|
def to_record(self):
|
||||||
|
return Message(
|
||||||
|
user_id=self.user_id,
|
||||||
|
agent_id=self.agent_id,
|
||||||
|
role=self.role,
|
||||||
|
text=self.text,
|
||||||
|
model=self.model,
|
||||||
|
function_name=self.function_name,
|
||||||
|
function_args=self.function_args,
|
||||||
|
function_response=self.function_response,
|
||||||
|
embedding=self.embedding,
|
||||||
|
created_at=self.created_at,
|
||||||
|
id=self.id,
|
||||||
|
)
|
||||||
|
|
||||||
"""Create database model for table_name"""
|
"""Create database model for table_name"""
|
||||||
class_name = f"{table_name.capitalize()}Model"
|
class_name = f"{table_name.capitalize()}Model"
|
||||||
@ -101,11 +126,20 @@ class PostgresStorageConnector(StorageConnector):
|
|||||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
|
|
||||||
|
# get storage URI
|
||||||
|
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||||
|
self.uri = config.archival_storage_uri
|
||||||
|
if config.archival_storage_uri is None:
|
||||||
|
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
||||||
|
elif table_type == TableType.RECALL_MEMORY:
|
||||||
|
self.uri = config.recall_storage_uri
|
||||||
|
if config.recall_storage_uri is None:
|
||||||
|
raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Table type {table_type} not implemented")
|
||||||
|
|
||||||
# create table
|
# create table
|
||||||
self.uri = config.archival_storage_uri
|
self.db_model = get_db_model(self.table_name, table_type)
|
||||||
if config.archival_storage_uri is None:
|
|
||||||
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
|
||||||
self.db_model = get_db_model(self.table_name)
|
|
||||||
self.engine = create_engine(self.uri)
|
self.engine = create_engine(self.uri)
|
||||||
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
|
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
|
||||||
self.Session = sessionmaker(bind=self.engine)
|
self.Session = sessionmaker(bind=self.engine)
|
||||||
@ -132,47 +166,47 @@ class PostgresStorageConnector(StorageConnector):
|
|||||||
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
|
def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
filters = self.get_filters(filters)
|
filters = self.get_filters(filters)
|
||||||
db_passages = session.query(self.db_model).filter(*filters).limit(limit).all()
|
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
|
||||||
return [self.type(**p.to_dict()) for p in db_passages]
|
return [record.to_record() for record in db_records]
|
||||||
|
|
||||||
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Passage]:
|
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
filters = self.get_filters(filters)
|
filters = self.get_filters(filters)
|
||||||
db_passage = session.query(self.db_model).filter(*filters).get(id)
|
db_record = session.query(self.db_model).filter(*filters).get(id)
|
||||||
if db_passage is None:
|
if db_record is None:
|
||||||
return None
|
return None
|
||||||
return Passage(text=db_passage.text, embedding=db_passage.embedding, doc_id=db_passage.doc_id, passage_id=db_passage.passage_id)
|
return db_record.to_record()
|
||||||
|
|
||||||
def size(self, filters: Optional[Dict] = {}) -> int:
|
def size(self, filters: Optional[Dict] = {}) -> int:
|
||||||
# return size of table
|
# return size of table
|
||||||
|
print("size")
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
filters = self.get_filters(filters)
|
filters = self.get_filters(filters)
|
||||||
return session.query(self.db_model).filter(*filters).count()
|
return session.query(self.db_model).filter(*filters).count()
|
||||||
|
|
||||||
def insert(self, passage: Passage):
|
def insert(self, record: Record):
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
db_record = self.db_model(**vars(record))
|
||||||
session.add(db_passage)
|
session.add(db_record)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def insert_many(self, records: List[Record], show_progress=True):
|
def insert_many(self, records: List[Record], show_progress=True):
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
iterable = tqdm(records) if show_progress else records
|
iterable = tqdm(records) if show_progress else records
|
||||||
for passage in iterable:
|
for record in iterable:
|
||||||
db_passage = self.db_model(doc_id=passage.doc_id, text=passage.text, embedding=passage.embedding)
|
db_record = self.db_model(**vars(record))
|
||||||
session.add(db_passage)
|
session.add(db_record)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||||
session = self.Session()
|
session = self.Session()
|
||||||
# Assuming PassageModel.embedding has the capability of computing l2_distance
|
|
||||||
filters = self.get_filters(filters)
|
filters = self.get_filters(filters)
|
||||||
results = session.scalars(
|
results = session.scalars(
|
||||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||||
).all()
|
).all()
|
||||||
|
|
||||||
# Convert the results into Passage objects
|
# Convert the results into Passage objects
|
||||||
records = [self.type(**vars(result)) for result in results]
|
records = [result.to_record() for result in results]
|
||||||
return records
|
return records
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
@ -195,6 +229,26 @@ class PostgresStorageConnector(StorageConnector):
|
|||||||
tables = [table[start_chars:] for table in tables]
|
tables = [table[start_chars:] for table in tables]
|
||||||
return tables
|
return tables
|
||||||
|
|
||||||
|
def query_date(self, start_date, end_date):
|
||||||
|
session = self.Session()
|
||||||
|
filters = self.get_filters({})
|
||||||
|
results = (
|
||||||
|
session.query(self.db_model)
|
||||||
|
.filter(*filters)
|
||||||
|
.filter(self.db_model.created_at >= start_date)
|
||||||
|
.filter(self.db_model.created_at <= end_date)
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
return [result.to_record() for result in results]
|
||||||
|
|
||||||
|
def query_text(self, query):
|
||||||
|
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
|
||||||
|
session = self.Session()
|
||||||
|
filters = self.get_filters({})
|
||||||
|
results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
|
||||||
|
print(results)
|
||||||
|
# return [self.type(**vars(result)) for result in results]
|
||||||
|
return [result.to_record() for result in results]
|
||||||
|
|
||||||
class LanceDBConnector(StorageConnector):
|
class LanceDBConnector(StorageConnector):
|
||||||
"""Storage via LanceDB"""
|
"""Storage via LanceDB"""
|
||||||
@ -251,7 +305,7 @@ class LanceDBConnector(StorageConnector):
|
|||||||
|
|
||||||
def get(self, id: str) -> Optional[Passage]:
|
def get(self, id: str) -> Optional[Passage]:
|
||||||
db_passage = self.table.where(f"passage_id={id}").to_list()
|
db_passage = self.table.where(f"passage_id={id}").to_list()
|
||||||
if len(db_passage) == 0:
|
if len(db_passage) == 0:
|
||||||
return None
|
return None
|
||||||
return Passage(
|
return Passage(
|
||||||
text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]
|
text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]
|
||||||
|
@ -6,7 +6,8 @@ import pickle
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Dict
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context
|
from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context
|
||||||
from llama_index.indices.empty.base import EmptyIndex
|
from llama_index.indices.empty.base import EmptyIndex
|
||||||
@ -145,11 +146,9 @@ class InMemoryStorageConnector(StorageConnector):
|
|||||||
|
|
||||||
# TODO: maybae replace this with sqllite?
|
# TODO: maybae replace this with sqllite?
|
||||||
|
|
||||||
def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
|
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||||
from memgpt.embeddings import embedding_model
|
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||||
|
|
||||||
config = MemGPTConfig.load()
|
config = MemGPTConfig.load()
|
||||||
# TODO: figure out save location
|
|
||||||
|
|
||||||
self.rows = []
|
self.rows = []
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Optional, List, Iterator
|
|||||||
import re
|
import re
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
|
from abc import abstractmethod
|
||||||
|
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -14,6 +14,7 @@ from tqdm import tqdm
|
|||||||
|
|
||||||
from memgpt.config import AgentConfig, MemGPTConfig
|
from memgpt.config import AgentConfig, MemGPTConfig
|
||||||
from memgpt.data_types import Record, Passage, Document, Message
|
from memgpt.data_types import Record, Passage, Document, Message
|
||||||
|
from memgpt.utils import printd
|
||||||
|
|
||||||
|
|
||||||
# ENUM representing table types in MemGPT
|
# ENUM representing table types in MemGPT
|
||||||
@ -28,9 +29,9 @@ class TableType:
|
|||||||
|
|
||||||
|
|
||||||
# table names used by MemGPT
|
# table names used by MemGPT
|
||||||
RECALL_TABLE_NAME = "memgpt_recall_memory"
|
RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory
|
||||||
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory"
|
ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory
|
||||||
PASSAGE_TABLE_NAME = "memgpt_passages"
|
PASSAGE_TABLE_NAME = "memgpt_passages" # loads data sources
|
||||||
DOCUMENT_TABLE_NAME = "memgpt_documents"
|
DOCUMENT_TABLE_NAME = "memgpt_documents"
|
||||||
|
|
||||||
|
|
||||||
@ -65,9 +66,10 @@ class StorageConnector:
|
|||||||
# get all filters for query
|
# get all filters for query
|
||||||
if filters is not None:
|
if filters is not None:
|
||||||
filter_conditions = {**self.filters, **filters}
|
filter_conditions = {**self.filters, **filters}
|
||||||
return self.filters + [self.db_model[key] == value for key, value in filter_conditions.items()]
|
|
||||||
else:
|
else:
|
||||||
return self.filters
|
filter_conditions = self.filters
|
||||||
|
print("FILTERS", filter_conditions)
|
||||||
|
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||||
|
|
||||||
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
|
def generate_table_name(self, agent_config: AgentConfig, table_type: TableType):
|
||||||
|
|
||||||
@ -102,18 +104,20 @@ class StorageConnector:
|
|||||||
if storage_type == "local":
|
if storage_type == "local":
|
||||||
from memgpt.connectors.local import VectorIndexStorageConnector
|
from memgpt.connectors.local import VectorIndexStorageConnector
|
||||||
|
|
||||||
return VectorIndexStorageConnector(agent_config=agent_config)
|
return VectorIndexStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||||
|
|
||||||
elif storage_type == "postgres":
|
elif storage_type == "postgres":
|
||||||
from memgpt.connectors.db import PostgresStorageConnector
|
from memgpt.connectors.db import PostgresStorageConnector
|
||||||
|
|
||||||
return PostgresStorageConnector(agent_config=agent_config)
|
return PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||||
|
elif storage_type == "chroma":
|
||||||
|
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||||
|
|
||||||
return ChromaStorageConnector(name=name, agent_config=agent_config)
|
return ChromaStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||||
elif storage_type == "lancedb":
|
elif storage_type == "lancedb":
|
||||||
from memgpt.connectors.db import LanceDBConnector
|
from memgpt.connectors.db import LanceDBConnector
|
||||||
|
|
||||||
return LanceDBConnector(agent_config=agent_config)
|
return LanceDBConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
raise NotImplementedError(f"Storage type {storage_type} not implemented")
|
||||||
@ -122,6 +126,8 @@ class StorageConnector:
|
|||||||
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
|
def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None):
|
||||||
storage_type = MemGPTConfig.load().recall_storage_type
|
storage_type = MemGPTConfig.load().recall_storage_type
|
||||||
|
|
||||||
|
print("Recall storage type", storage_type)
|
||||||
|
|
||||||
if storage_type == "local":
|
if storage_type == "local":
|
||||||
from memgpt.connectors.local import InMemoryStorageConnector
|
from memgpt.connectors.local import InMemoryStorageConnector
|
||||||
|
|
||||||
|
@ -20,11 +20,10 @@ class Record:
|
|||||||
self.text = text
|
self.text = text
|
||||||
self.id = id
|
self.id = id
|
||||||
# todo: generate unique uuid
|
# todo: generate unique uuid
|
||||||
# todo: timestamp
|
|
||||||
# todo: self.role = role (?)
|
# todo: self.role = role (?)
|
||||||
|
|
||||||
def __repr__(self):
|
# def __repr__(self):
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
|
|
||||||
class Message(Record):
|
class Message(Record):
|
||||||
@ -35,17 +34,19 @@ class Message(Record):
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
agent_id: str,
|
agent_id: str,
|
||||||
role: str,
|
role: str,
|
||||||
content: str,
|
text: str,
|
||||||
model: str, # model used to make function call
|
model: str, # model used to make function call
|
||||||
|
created_at: Optional[str] = None,
|
||||||
function_name: Optional[str] = None, # name of function called
|
function_name: Optional[str] = None, # name of function called
|
||||||
function_args: Optional[str] = None, # args of function called
|
function_args: Optional[str] = None, # args of function called
|
||||||
function_response: Optional[str] = None, # response of function called
|
function_response: Optional[str] = None, # response of function called
|
||||||
embedding: Optional[np.ndarray] = None,
|
embedding: Optional[np.ndarray] = None,
|
||||||
id: Optional[str] = None,
|
id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
super().__init__(user_id, agent_id, content, id)
|
super().__init__(user_id, agent_id, text, id)
|
||||||
self.role = role # role (agent/user/function)
|
self.role = role # role (agent/user/function)
|
||||||
self.model = model # model name (e.g. gpt-4)
|
self.model = model # model name (e.g. gpt-4)
|
||||||
|
self.created_at = created_at
|
||||||
|
|
||||||
# function call info (optional)
|
# function call info (optional)
|
||||||
self.function_name = function_name
|
self.function_name = function_name
|
||||||
@ -55,15 +56,15 @@ class Message(Record):
|
|||||||
# embedding (optional)
|
# embedding (optional)
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
|
|
||||||
def __repr__(self):
|
# def __repr__(self):
|
||||||
pass
|
# pass
|
||||||
|
|
||||||
|
|
||||||
class Document(Record):
|
class Document(Record):
|
||||||
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
|
"""A document represent a document loaded into MemGPT, which is broken down into passages."""
|
||||||
|
|
||||||
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
|
def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None):
|
||||||
super().__init__(user_id)
|
super().__init__(user_id, agent_id, text, id)
|
||||||
self.text = text
|
self.text = text
|
||||||
self.document_id = document_id
|
self.document_id = document_id
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
@ -76,25 +77,26 @@ class Document(Record):
|
|||||||
class Passage(Record):
|
class Passage(Record):
|
||||||
"""A passage is a single unit of memory, and a standard format accross all storage backends.
|
"""A passage is a single unit of memory, and a standard format accross all storage backends.
|
||||||
|
|
||||||
It is a string of text with an associated embedding.
|
It is a string of text with an assoidciated embedding.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
text: str,
|
text: str,
|
||||||
data_source: str,
|
agent_id: Optional[str] = None, # set if contained in agent memory
|
||||||
embedding: np.ndarray,
|
embedding: Optional[np.ndarray] = None,
|
||||||
|
data_source: Optional[str] = None, # None if created by agent
|
||||||
doc_id: Optional[str] = None,
|
doc_id: Optional[str] = None,
|
||||||
passage_id: Optional[str] = None,
|
id: Optional[str] = None,
|
||||||
|
metadata: Optional[dict] = {},
|
||||||
):
|
):
|
||||||
super().__init__(user_id)
|
super().__init__(user_id, agent_id, text, id)
|
||||||
self.text = text
|
self.text = text
|
||||||
self.data_source = data_source
|
self.data_source = data_source
|
||||||
self.embedding = embedding
|
self.embedding = embedding
|
||||||
self.doc_id = doc_id
|
self.doc_id = doc_id
|
||||||
self.passage_id = passage_id
|
self.metadata = metadata
|
||||||
self.metadata = {}
|
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"Passage(text={self.text}, embedding={self.embedding})"
|
return f"Passage(text={self.text}, embedding={self.embedding})"
|
||||||
|
@ -284,7 +284,10 @@ class DummyRecallMemory(RecallMemory):
|
|||||||
return matches, len(matches)
|
return matches, len(matches)
|
||||||
|
|
||||||
|
|
||||||
class RecallMemorySQL(RecallMemory):
|
class BaseRecallMemory(RecallMemory):
|
||||||
|
|
||||||
|
"""Recall memory based on base functions implemented by storage connectors"""
|
||||||
|
|
||||||
def __init__(self, agent_config, restrict_search_to_summaries=False):
|
def __init__(self, agent_config, restrict_search_to_summaries=False):
|
||||||
|
|
||||||
# If true, the pool of messages that can be queried are the automated summaries only
|
# If true, the pool of messages that can be queried are the automated summaries only
|
||||||
@ -304,20 +307,39 @@ class RecallMemorySQL(RecallMemory):
|
|||||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def text_search(self, query_string, count=None, start=None):
|
def text_search(self, query_string, count=None, start=None):
|
||||||
pass
|
self.storage.query_text(query_string, count, start)
|
||||||
|
|
||||||
@abstractmethod
|
def date_search(self, start_date, end_date, count=None, start=None):
|
||||||
def date_search(self, query_string, count=None, start=None):
|
self.storage.query_date(start_date, end_date, count, start)
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
pass
|
total = self.storage.size()
|
||||||
|
system_count = self.storage.size(filters={"role": "system"})
|
||||||
|
user_count = self.storage.size(filters={"role": "user"})
|
||||||
|
assistant_count = self.storage.size(filters={"role": "assistant"})
|
||||||
|
function_count = self.storage.size(filters={"role": "function"})
|
||||||
|
other_count = total - (system_count + user_count + assistant_count + function_count)
|
||||||
|
|
||||||
|
memory_str = (
|
||||||
|
f"Statistics:"
|
||||||
|
+ f"\n{total} total messages"
|
||||||
|
+ f"\n{system_count} system"
|
||||||
|
+ f"\n{user_count} user"
|
||||||
|
+ f"\n{assistant_count} assistant"
|
||||||
|
+ f"\n{function_count} function"
|
||||||
|
+ f"\n{other_count} other"
|
||||||
|
)
|
||||||
|
return f"\n### RECALL MEMORY ###" + f"\n{memory_str}"
|
||||||
|
|
||||||
def insert(self, message: Message):
|
def insert(self, message: Message):
|
||||||
pass
|
self.storage.insert(message)
|
||||||
|
|
||||||
|
def insert_many(self, messages: List[Message]):
|
||||||
|
self.storage.insert_many(messages)
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
self.storage.save()
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingArchivalMemory(ArchivalMemory):
|
class EmbeddingArchivalMemory(ArchivalMemory):
|
||||||
@ -333,24 +355,31 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
|||||||
|
|
||||||
self.top_k = top_k
|
self.top_k = top_k
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
config = MemGPTConfig.load()
|
self.config = MemGPTConfig.load()
|
||||||
|
|
||||||
# create embedding model
|
# create embedding model
|
||||||
self.embed_model = embedding_model()
|
self.embed_model = embedding_model()
|
||||||
self.embedding_chunk_size = config.embedding_chunk_size
|
self.embedding_chunk_size = self.config.embedding_chunk_size
|
||||||
|
|
||||||
# create storage backend
|
# create storage backend
|
||||||
self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
|
self.storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
|
||||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
|
||||||
|
def create_passage(self, text, embedding):
|
||||||
|
return Passage(
|
||||||
|
user_id=self.config.anon_clientid,
|
||||||
|
agent_id=self.agent_config.name,
|
||||||
|
text=text,
|
||||||
|
embedding=embedding,
|
||||||
|
)
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
"""Save the index to disk"""
|
"""Save the index to disk"""
|
||||||
self.storage.save()
|
self.storage.save()
|
||||||
|
|
||||||
def insert(self, memory_string):
|
def insert(self, memory_string):
|
||||||
"""Embed and save memory string"""
|
"""Embed and save memory string"""
|
||||||
from memgpt.connectors.storage import Passage
|
|
||||||
|
|
||||||
if not isinstance(memory_string, str):
|
if not isinstance(memory_string, str):
|
||||||
return TypeError("memory must be a string")
|
return TypeError("memory must be a string")
|
||||||
@ -364,17 +393,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
|||||||
# breakup string into passages
|
# breakup string into passages
|
||||||
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
|
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
|
||||||
embedding = self.embed_model.get_text_embedding(node.text)
|
embedding = self.embed_model.get_text_embedding(node.text)
|
||||||
# fixing weird bug where type returned isn't a list, but instead is an object
|
passages.append(self.create_passage(node.text, embedding))
|
||||||
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
|
|
||||||
if isinstance(embedding, dict):
|
|
||||||
try:
|
|
||||||
embedding = embedding["data"][0]["embedding"]
|
|
||||||
except (KeyError, IndexError):
|
|
||||||
# TODO as a fallback, see if we can find any lists in the payload
|
|
||||||
raise TypeError(
|
|
||||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
|
||||||
)
|
|
||||||
passages.append(Passage(text=node.text, embedding=embedding, doc_id=f"agent_{self.agent_config.name}_memory"))
|
|
||||||
|
|
||||||
# insert passages
|
# insert passages
|
||||||
self.storage.insert_many(passages)
|
self.storage.insert_many(passages)
|
||||||
|
@ -3,9 +3,19 @@ import pickle
|
|||||||
from memgpt.config import AgentConfig
|
from memgpt.config import AgentConfig
|
||||||
from memgpt.memory import (
|
from memgpt.memory import (
|
||||||
DummyRecallMemory,
|
DummyRecallMemory,
|
||||||
|
BaseRecallMemory,
|
||||||
EmbeddingArchivalMemory,
|
EmbeddingArchivalMemory,
|
||||||
)
|
)
|
||||||
from memgpt.utils import get_local_time, printd, OpenAIBackcompatUnpickler
|
from memgpt.utils import get_local_time, printd
|
||||||
|
from memgpt.data_types import Message
|
||||||
|
from memgpt.config import MemGPTConfig
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def parse_formatted_time(formatted_time):
|
||||||
|
# parse times returned by memgpt.utils.get_formatted_time()
|
||||||
|
return datetime.strptime(formatted_time, "%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||||
|
|
||||||
|
|
||||||
class PersistenceManager(ABC):
|
class PersistenceManager(ABC):
|
||||||
@ -33,83 +43,60 @@ class PersistenceManager(ABC):
|
|||||||
class LocalStateManager(PersistenceManager):
|
class LocalStateManager(PersistenceManager):
|
||||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||||
|
|
||||||
recall_memory_cls = DummyRecallMemory
|
recall_memory_cls = BaseRecallMemory
|
||||||
archival_memory_cls = EmbeddingArchivalMemory
|
archival_memory_cls = EmbeddingArchivalMemory
|
||||||
|
|
||||||
def __init__(self, agent_config: AgentConfig):
|
def __init__(self, agent_config: AgentConfig):
|
||||||
# Memory held in-state useful for debugging stateful versions
|
# Memory held in-state useful for debugging stateful versions
|
||||||
self.memory = None
|
self.memory = None
|
||||||
self.messages = []
|
self.messages = [] # current in-context messages
|
||||||
self.all_messages = []
|
# self.all_messages = [] # all messages seen in current session (needed if lazily synchronizing state with DB)
|
||||||
self.archival_memory = EmbeddingArchivalMemory(agent_config)
|
self.archival_memory = EmbeddingArchivalMemory(agent_config)
|
||||||
self.recall_memory = None
|
self.recall_memory = BaseRecallMemory(agent_config)
|
||||||
self.agent_config = agent_config
|
self.agent_config = agent_config
|
||||||
|
self.config = MemGPTConfig.load()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, filename, agent_config: AgentConfig):
|
def load(cls, agent_config: AgentConfig):
|
||||||
""" Load a LocalStateManager from a file. """ ""
|
""" Load a LocalStateManager from a file. """ ""
|
||||||
try:
|
# TODO: remove this class and just init the class
|
||||||
with open(filename, "rb") as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
except ModuleNotFoundError as e:
|
|
||||||
# Patch for stripped openai package
|
|
||||||
# ModuleNotFoundError: No module named 'openai.openai_object'
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
unpickler = OpenAIBackcompatUnpickler(f)
|
|
||||||
data = unpickler.load()
|
|
||||||
# print(f"Unpickled data:\n{data.keys()}")
|
|
||||||
|
|
||||||
from memgpt.openai_backcompat.openai_object import OpenAIObject
|
|
||||||
|
|
||||||
def convert_openai_objects_to_dict(obj):
|
|
||||||
if isinstance(obj, OpenAIObject):
|
|
||||||
# Convert to dict or handle as needed
|
|
||||||
# print(f"detected OpenAIObject on {obj}")
|
|
||||||
return obj.to_dict_recursive()
|
|
||||||
elif isinstance(obj, dict):
|
|
||||||
return {k: convert_openai_objects_to_dict(v) for k, v in obj.items()}
|
|
||||||
elif isinstance(obj, list):
|
|
||||||
return [convert_openai_objects_to_dict(v) for v in obj]
|
|
||||||
else:
|
|
||||||
return obj
|
|
||||||
|
|
||||||
data = convert_openai_objects_to_dict(data)
|
|
||||||
# print(f"Converted data:\n{data.keys()}")
|
|
||||||
|
|
||||||
manager = cls(agent_config)
|
manager = cls(agent_config)
|
||||||
manager.all_messages = data["all_messages"]
|
|
||||||
manager.messages = data["messages"]
|
|
||||||
manager.recall_memory = data["recall_memory"]
|
|
||||||
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
|
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
def save(self, filename):
|
def save(self):
|
||||||
with open(filename, "wb") as fh:
|
"""Ensure storage connectors save data"""
|
||||||
## TODO: fix this hacky solution to pickle the retriever
|
self.archival_memory.save()
|
||||||
self.archival_memory.save()
|
self.recall_memory.save()
|
||||||
pickle.dump(
|
|
||||||
{
|
|
||||||
"recall_memory": self.recall_memory,
|
|
||||||
"messages": self.messages,
|
|
||||||
"all_messages": self.all_messages,
|
|
||||||
},
|
|
||||||
fh,
|
|
||||||
protocol=pickle.HIGHEST_PROTOCOL,
|
|
||||||
)
|
|
||||||
printd(f"Saved state to {fh}")
|
|
||||||
|
|
||||||
def init(self, agent):
|
def init(self, agent):
|
||||||
|
"""Connect persistent state manager to agent"""
|
||||||
printd(f"Initializing {self.__class__.__name__} with agent object")
|
printd(f"Initializing {self.__class__.__name__} with agent object")
|
||||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
# self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||||
self.memory = agent.memory
|
self.memory = agent.memory
|
||||||
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
# printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
||||||
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
|
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
|
||||||
|
|
||||||
# Persistence manager also handles DB-related state
|
# Persistence manager also handles DB-related state
|
||||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
# self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||||
|
|
||||||
# TODO: init archival memory here?
|
def json_to_message(self, message_json) -> Message:
|
||||||
|
"""Convert agent message JSON into Message object"""
|
||||||
|
timestamp = message_json["timestamp"]
|
||||||
|
message = message_json["message"]
|
||||||
|
|
||||||
|
return Message(
|
||||||
|
user_id=self.config.anon_clientid,
|
||||||
|
agent_id=self.agent_config.name,
|
||||||
|
role=message["role"],
|
||||||
|
text=message["content"],
|
||||||
|
model=self.agent_config.model,
|
||||||
|
created_at=parse_formatted_time(timestamp),
|
||||||
|
function_name=message["function_name"] if "function_name" in message else None,
|
||||||
|
function_args=message["function_args"] if "function_args" in message else None,
|
||||||
|
function_response=message["function_response"] if "function_response" in message else None,
|
||||||
|
id=message["id"] if "id" in message else None,
|
||||||
|
)
|
||||||
|
|
||||||
def trim_messages(self, num):
|
def trim_messages(self, num):
|
||||||
# printd(f"InMemoryStateManager.trim_messages")
|
# printd(f"InMemoryStateManager.trim_messages")
|
||||||
@ -121,7 +108,9 @@ class LocalStateManager(PersistenceManager):
|
|||||||
|
|
||||||
printd(f"{self.__class__.__name__}.prepend_to_message")
|
printd(f"{self.__class__.__name__}.prepend_to_message")
|
||||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||||
self.all_messages.extend(added_messages)
|
|
||||||
|
# add to recall memory
|
||||||
|
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
|
||||||
|
|
||||||
def append_to_messages(self, added_messages):
|
def append_to_messages(self, added_messages):
|
||||||
# first tag with timestamps
|
# first tag with timestamps
|
||||||
@ -129,7 +118,9 @@ class LocalStateManager(PersistenceManager):
|
|||||||
|
|
||||||
printd(f"{self.__class__.__name__}.append_to_messages")
|
printd(f"{self.__class__.__name__}.append_to_messages")
|
||||||
self.messages = self.messages + added_messages
|
self.messages = self.messages + added_messages
|
||||||
self.all_messages.extend(added_messages)
|
|
||||||
|
# add to recall memory
|
||||||
|
self.recall_memory.insert_many([self.json_to_message(m) for m in added_messages])
|
||||||
|
|
||||||
def swap_system_message(self, new_system_message):
|
def swap_system_message(self, new_system_message):
|
||||||
# first tag with timestamps
|
# first tag with timestamps
|
||||||
@ -137,96 +128,9 @@ class LocalStateManager(PersistenceManager):
|
|||||||
|
|
||||||
printd(f"{self.__class__.__name__}.swap_system_message")
|
printd(f"{self.__class__.__name__}.swap_system_message")
|
||||||
self.messages[0] = new_system_message
|
self.messages[0] = new_system_message
|
||||||
self.all_messages.append(new_system_message)
|
|
||||||
|
# add to recall memory
|
||||||
def update_memory(self, new_memory):
|
self.recall_memory.insert(self.json_to_message(new_system_message))
|
||||||
printd(f"{self.__class__.__name__}.update_memory")
|
|
||||||
self.memory = new_memory
|
|
||||||
|
|
||||||
|
|
||||||
class StateManager(PersistenceManager):
|
|
||||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
|
||||||
|
|
||||||
recall_memory_cls = RecallMemory
|
|
||||||
archival_memory_cls = EmbeddingArchivalMemory
|
|
||||||
|
|
||||||
def __init__(self, agent_config: AgentConfig):
|
|
||||||
# Memory held in-state useful for debugging stateful versions
|
|
||||||
self.memory = None
|
|
||||||
self.messages = []
|
|
||||||
self.all_messages = []
|
|
||||||
self.archival_memory = EmbeddingArchivalMemory(agent_config)
|
|
||||||
self.recall_memory = None
|
|
||||||
self.agent_config = agent_config
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, filename, agent_config: AgentConfig):
|
|
||||||
""" Load a LocalStateManager from a file. """ ""
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
data = pickle.load(f)
|
|
||||||
|
|
||||||
manager = cls(agent_config)
|
|
||||||
manager.all_messages = data["all_messages"]
|
|
||||||
manager.messages = data["messages"]
|
|
||||||
manager.recall_memory = data["recall_memory"]
|
|
||||||
manager.archival_memory = EmbeddingArchivalMemory(agent_config)
|
|
||||||
return manager
|
|
||||||
|
|
||||||
def save(self, filename):
|
|
||||||
with open(filename, "wb") as fh:
|
|
||||||
## TODO: fix this hacky solution to pickle the retriever
|
|
||||||
self.archival_memory.save()
|
|
||||||
pickle.dump(
|
|
||||||
{
|
|
||||||
"recall_memory": self.recall_memory,
|
|
||||||
"messages": self.messages,
|
|
||||||
"all_messages": self.all_messages,
|
|
||||||
},
|
|
||||||
fh,
|
|
||||||
protocol=pickle.HIGHEST_PROTOCOL,
|
|
||||||
)
|
|
||||||
printd(f"Saved state to {fh}")
|
|
||||||
|
|
||||||
def init(self, agent):
|
|
||||||
printd(f"Initializing {self.__class__.__name__} with agent object")
|
|
||||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
|
||||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
|
||||||
self.memory = agent.memory
|
|
||||||
printd(f"{self.__class__.__name__}.all_messages.len = {len(self.all_messages)}")
|
|
||||||
printd(f"{self.__class__.__name__}.messages.len = {len(self.messages)}")
|
|
||||||
|
|
||||||
# Persistence manager also handles DB-related state
|
|
||||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
|
||||||
|
|
||||||
# TODO: init archival memory here?
|
|
||||||
|
|
||||||
def trim_messages(self, num):
|
|
||||||
# printd(f"InMemoryStateManager.trim_messages")
|
|
||||||
self.messages = [self.messages[0]] + self.messages[num:]
|
|
||||||
|
|
||||||
def prepend_to_messages(self, added_messages):
|
|
||||||
# first tag with timestamps
|
|
||||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
|
||||||
|
|
||||||
printd(f"{self.__class__.__name__}.prepend_to_message")
|
|
||||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
|
||||||
self.all_messages.extend(added_messages)
|
|
||||||
|
|
||||||
def append_to_messages(self, added_messages):
|
|
||||||
# first tag with timestamps
|
|
||||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
|
||||||
|
|
||||||
printd(f"{self.__class__.__name__}.append_to_messages")
|
|
||||||
self.messages = self.messages + added_messages
|
|
||||||
self.all_messages.extend(added_messages)
|
|
||||||
|
|
||||||
def swap_system_message(self, new_system_message):
|
|
||||||
# first tag with timestamps
|
|
||||||
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
|
|
||||||
|
|
||||||
printd(f"{self.__class__.__name__}.swap_system_message")
|
|
||||||
self.messages[0] = new_system_message
|
|
||||||
self.all_messages.append(new_system_message)
|
|
||||||
|
|
||||||
def update_memory(self, new_memory):
|
def update_memory(self, new_memory):
|
||||||
printd(f"{self.__class__.__name__}.update_memory")
|
printd(f"{self.__class__.__name__}.update_memory")
|
||||||
|
@ -117,10 +117,11 @@ def get_local_time(timezone=None):
|
|||||||
time_str = get_local_time_timezone(timezone)
|
time_str = get_local_time_timezone(timezone)
|
||||||
else:
|
else:
|
||||||
# Get the current time, which will be in the local timezone of the computer
|
# Get the current time, which will be in the local timezone of the computer
|
||||||
local_time = datetime.now()
|
local_time = datetime.now().astimezone()
|
||||||
|
|
||||||
# You may format it as you desire, including AM/PM
|
# You may format it as you desire, including AM/PM
|
||||||
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||||
|
print("formatted_time", formatted_time)
|
||||||
|
|
||||||
return time_str.strip()
|
return time_str.strip()
|
||||||
|
|
||||||
|
19
poetry.lock
generated
19
poetry.lock
generated
@ -3822,6 +3822,23 @@ postgresql-psycopgbinary = ["psycopg[binary] (>=3.0.7)"]
|
|||||||
pymysql = ["pymysql"]
|
pymysql = ["pymysql"]
|
||||||
sqlcipher = ["sqlcipher3-binary"]
|
sqlcipher = ["sqlcipher3-binary"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sqlalchemy-json"
|
||||||
|
version = "0.7.0"
|
||||||
|
description = "JSON type with nested change tracking for SQLAlchemy"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">= 3.6"
|
||||||
|
files = [
|
||||||
|
{file = "sqlalchemy-json-0.7.0.tar.gz", hash = "sha256:620d0b26f648f21a8fa9127df66f55f83a5ab4ae010e5397a5c6989a08238561"},
|
||||||
|
{file = "sqlalchemy_json-0.7.0-py3-none-any.whl", hash = "sha256:27881d662ca18363a4ac28175cc47ea2a6f2bef997ae1159c151026b741818e6"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
sqlalchemy = ">=0.7"
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["pytest"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "starlette"
|
name = "starlette"
|
||||||
version = "0.27.0"
|
version = "0.27.0"
|
||||||
@ -4913,4 +4930,4 @@ server = ["fastapi", "uvicorn", "websockets"]
|
|||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "<3.12,>=3.9"
|
python-versions = "<3.12,>=3.9"
|
||||||
content-hash = "12010863b2b9c1e26dceace00ea4e1ea7cc95932ab77b1ef37a5473c2e375575"
|
content-hash = "a6b8cdeda433007b6d33441b82417e435df336cea458af5dedce9a94f07f1225"
|
||||||
|
@ -49,14 +49,8 @@ tiktoken = "^0.5.1"
|
|||||||
python-box = "^7.1.1"
|
python-box = "^7.1.1"
|
||||||
pypdf = "^3.17.1"
|
pypdf = "^3.17.1"
|
||||||
pyyaml = "^6.0.1"
|
pyyaml = "^6.0.1"
|
||||||
fastapi = {version = "^0.104.1", optional = true}
|
chromadb = {version = "^0.4.18", optional = true}
|
||||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
sqlalchemy-json = "^0.7.0"
|
||||||
chromadb = "^0.4.18"
|
|
||||||
pytest-asyncio = {version = "^0.23.2", optional = true}
|
|
||||||
pydantic = "^2.5.2"
|
|
||||||
pyautogen = {version = "0.2.0", optional = true}
|
|
||||||
html2text = "^2020.1.16"
|
|
||||||
docx2txt = "^0.8"
|
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
local = ["torch", "huggingface-hub", "transformers"]
|
local = ["torch", "huggingface-hub", "transformers"]
|
||||||
|
@ -3,63 +3,90 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
subprocess.check_call(
|
# subprocess.check_call(
|
||||||
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
# [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
|
||||||
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
||||||
|
#
|
||||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||||
import pgvector # Try to import again after installing
|
import pgvector # Try to import again after installing
|
||||||
|
|
||||||
from memgpt.connectors.storage import StorageConnector, Passage
|
from memgpt.connectors.storage import StorageConnector, TableType
|
||||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||||
from memgpt.embeddings import embedding_model
|
from memgpt.embeddings import embedding_model
|
||||||
from memgpt.data_types import Message, Passage
|
from memgpt.data_types import Message, Passage
|
||||||
from memgpt.config import MemGPTConfig, AgentConfig
|
from memgpt.config import MemGPTConfig, AgentConfig
|
||||||
|
from memgpt.utils import get_local_time
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
|
||||||
def test_recall_db() -> None:
|
def test_recall_db():
|
||||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||||
|
|
||||||
storage_type = "postgres"
|
storage_type = "postgres"
|
||||||
storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||||
config = MemGPTConfig(recall_storage_type=storage_type, recall_storage_uri=storage_uri)
|
config = MemGPTConfig(
|
||||||
|
recall_storage_type=storage_type,
|
||||||
|
recall_storage_uri=storage_uri,
|
||||||
|
model_endpoint_type="openai",
|
||||||
|
model_endpoint="https://api.openai.com/v1",
|
||||||
|
model="gpt4",
|
||||||
|
)
|
||||||
print(config.config_path)
|
print(config.config_path)
|
||||||
assert config.recall_storage_uri is not None
|
assert config.recall_storage_uri is not None
|
||||||
config.save()
|
config.save()
|
||||||
print(config)
|
print(config)
|
||||||
|
|
||||||
conn = StorageConnector.get_recall_storage_connector()
|
agent_config = AgentConfig(
|
||||||
|
persona=config.persona,
|
||||||
|
human=config.human,
|
||||||
|
model=config.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
conn = StorageConnector.get_recall_storage_connector(agent_config)
|
||||||
|
|
||||||
# construct recall memory messages
|
# construct recall memory messages
|
||||||
message1 = Message(
|
message1 = Message(
|
||||||
agent_id="test_agent1",
|
agent_id=agent_config.name,
|
||||||
role="agent",
|
role="agent",
|
||||||
content="This is a test message",
|
text="This is a test message",
|
||||||
id="test_id1",
|
user_id=config.anon_clientid,
|
||||||
|
model=agent_config.model,
|
||||||
|
created_at=datetime.now(),
|
||||||
)
|
)
|
||||||
message2 = Message(
|
message2 = Message(
|
||||||
agent_id="test_agent2",
|
agent_id=agent_config.name,
|
||||||
role="user",
|
role="user",
|
||||||
content="This is a test message",
|
text="This is a test message",
|
||||||
id="test_id2",
|
user_id=config.anon_clientid,
|
||||||
|
model=agent_config.model,
|
||||||
|
created_at=datetime.now(),
|
||||||
)
|
)
|
||||||
|
print(vars(message1))
|
||||||
|
|
||||||
# test insert
|
# test insert
|
||||||
conn.insert(message1)
|
conn.insert(message1)
|
||||||
conn.insert_many([message2])
|
conn.insert_many([message2])
|
||||||
|
|
||||||
# test size
|
# test size
|
||||||
assert conn.size() == 2, f"Expected 2 messages, got {conn.size()}"
|
assert conn.size() >= 2, f"Expected 2 messages, got {conn.size()}"
|
||||||
assert conn.size(filters={"agent_id": "test_agent2"}) == 1, f"Expected 2 messages, got {conn.size()}"
|
assert conn.size(filters={"role": "user"}) >= 1, f'Expected 2 messages, got {conn.size(filters={"role": "user"})}'
|
||||||
|
|
||||||
# test get
|
# test text query
|
||||||
assert conn.get("test_id1") == message1, f"Expected {message1}, got {conn.get('test_id1')}"
|
res = conn.query_text("test")
|
||||||
assert (
|
print(res)
|
||||||
len(conn.get_all(limit=10, filters={"agent_id": "test_agent2"})) == 1
|
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||||
), f"Expected 1 message, got {len(conn.get_all(limit=10, filters={'agent_id': 'test_agent2'}))}"
|
|
||||||
|
# test date query
|
||||||
|
current_time = datetime.now()
|
||||||
|
ten_weeks_ago = current_time - timedelta(weeks=1)
|
||||||
|
res = conn.query_date(start_date=ten_weeks_ago, end_date=current_time)
|
||||||
|
print(res)
|
||||||
|
assert len(res) >= 2, f"Expected 2 messages, got {len(res)}"
|
||||||
|
|
||||||
|
print(conn.get_all())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
|
||||||
@ -83,10 +110,28 @@ def test_postgres_openai():
|
|||||||
|
|
||||||
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
passage = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||||
|
|
||||||
db = PostgresStorageConnector(name="test-openai")
|
agent_config = AgentConfig(
|
||||||
|
name="test_agent",
|
||||||
|
persona=config.persona,
|
||||||
|
human=config.human,
|
||||||
|
model=config.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
db = PostgresStorageConnector(agent_config=agent_config, table_type=TableType.ARCHIVAL_MEMORY)
|
||||||
|
|
||||||
|
# db.delete()
|
||||||
|
# return
|
||||||
for passage in passage:
|
for passage in passage:
|
||||||
db.insert(Passage(text=passage, embedding=embed_model.get_text_embedding(passage)))
|
db.insert(
|
||||||
|
Passage(
|
||||||
|
text=passage,
|
||||||
|
embedding=embed_model.get_text_embedding(passage),
|
||||||
|
user_id=config.anon_clientid,
|
||||||
|
agent_id="test_agent",
|
||||||
|
data_source="test",
|
||||||
|
metadata={"test_metadata_key": "test_metadata_value"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
print(db.get_all())
|
print(db.get_all())
|
||||||
|
|
||||||
@ -246,3 +291,6 @@ def test_lancedb_local():
|
|||||||
|
|
||||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||||
|
|
||||||
|
|
||||||
|
test_recall_db()
|
||||||
|
Loading…
Reference in New Issue
Block a user