feat: return num_passages in Source.metadata_ from REST list sources endpoint (#1178)

Sarah Wooders 2024-03-21 15:41:31 -07:00 committed by GitHub
parent 1bc88d9a3a
commit 464bda4589
4 changed files with 46 additions and 26 deletions
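A minimal sketch of how a client might read the new counts after this change (model and method names come from the diff below; the import path, server URL, and setup are assumptions):

from memgpt.client.client import create_client  # import path assumed for this sketch

client = create_client(base_url="http://localhost:8283")  # local server URL is an assumption
sources = client.list_sources()  # now returns a ListSourcesResponse instead of a raw dict
for s in sources.sources:
    # metadata_ is populated server-side with the counts added by this commit
    print(s.name, s.metadata_["num_passages"], s.metadata_["num_documents"])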

View File

@@ -5,7 +5,7 @@ import uuid
from typing import Dict, List, Union, Optional, Tuple
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel, SourceModel
from memgpt.cli.cli import QuickstartChoice
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
from memgpt.config import MemGPTConfig
@@ -31,6 +31,7 @@ from memgpt.server.rest_api.personas.index import ListPersonasResponse
from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse
from memgpt.server.rest_api.models.index import ListModelsResponse
from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse
from memgpt.server.rest_api.sources.index import ListSourcesResponse, UploadFileToSourceResponse
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
@@ -437,7 +438,7 @@ class RESTClient(AbstractClient):
"""List loaded sources"""
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
response_json = response.json()
return response_json
return ListSourcesResponse(**response_json)
def delete_source(self, source_id: uuid.UUID):
"""Delete a source and associated data (including attached to agents)"""
@@ -448,7 +449,7 @@
"""Load {filename} and insert into source"""
files = {"file": open(filename, "rb")}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
return response.json()
return UploadFileToSourceResponse(**response.json())
def create_source(self, name: str) -> Source:
"""Create a new source"""
@@ -456,13 +457,14 @@
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
response_json = response.json()
print("CREATE SOURCE", response_json, response.text)
response_obj = SourceModel(**response_json)
return Source(
id=uuid.UUID(response_json["id"]),
name=response_json["name"],
user_id=uuid.UUID(response_json["user_id"]),
created_at=datetime.datetime.fromtimestamp(response_json["created_at"]),
embedding_dim=response_json["embedding_config"]["embedding_dim"],
embedding_model=response_json["embedding_config"]["embedding_model"],
id=uuid.UUID(response_obj.id),
name=response_obj.name,
user_id=uuid.UUID(response_obj.user_id),
created_at=response_obj.created_at,
embedding_dim=response_obj.embedding_config["embedding_dim"],
embedding_model=response_obj.embedding_config["embedding_model"],
)
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
@@ -470,14 +472,12 @@
params = {"agent_id": agent_id}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
return response.json()
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
"""Detach a source from an agent"""
params = {"agent_id": str(agent_id)}
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
return response.json()
# server configuration commands
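
For context, a hedged sketch chaining the client methods touched in this file (assumes a `client` and an existing `agent`, as in the tests below; the filename is the one used in the tests):

source = client.create_source(name="test_source")  # returns a Source with a uuid id
upload = client.load_file_into_source(filename="CONTRIBUTING.md", source_id=source.id)
# load_file_into_source now returns an UploadFileToSourceResponse rather than response.json()
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)  # no longer returns a body
client.detach_source(source_id=source.id, agent_id=agent.id)
client.delete_source(source_id=source.id)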

View File

@@ -39,10 +39,6 @@ class CreateSourceRequest(BaseModel):
description: Optional[str] = Field(None, description="The description of the source.")
class CreateSourceResponse(BaseModel):
source: SourceModel = Field(..., description="The newly created source.")
class UploadFileToSourceRequest(BaseModel):
file: UploadFile = Field(..., description="The file to upload.")
@@ -128,7 +124,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
interface.clear()
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
source = server.ms.get_source(source_id=source_id, user_id=user_id)
source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id)
return SourceModel(
name=source.name,
description=None, # TODO: actually store descriptions

View File

@@ -1360,8 +1360,21 @@ class SyncServer(LockingServer):
sources_with_metadata = []
for source in sources:
passages = self.list_data_source_passages(user_id=user_id, source_id=source.id)
documents = self.list_data_source_documents(user_id=user_id, source_id=source.id)
# count number of passages
passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
num_passages = passage_conn.size({"data_source": source.name})
print(passage_conn.get_all())
print(
"NUMBER PASSAGES",
num_passages,
)
# TODO: add when documents table implemented
## count number of documents
# document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
# num_documents = document_conn.size({"data_source": source.name})
num_documents = 0
agent_ids = self.ms.list_attached_agents(source_id=source.id)
# add the agent name information
attached_agents = [
@@ -1374,8 +1387,8 @@
# Overwrite metadata field, should be empty anyways
source.metadata_ = dict(
num_documents=len(passages),
num_passages=len(documents),
num_documents=num_documents,
num_passages=num_passages,
attached_agents=attached_agents,
)
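
In condensed form, the server-side metadata computation from this hunk looks roughly like this (debug prints omitted; assumes the StorageConnector.size filter semantics shown above):

passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
num_passages = passage_conn.size({"data_source": source.name})  # passages ingested from this source
num_documents = 0  # placeholder until a documents table is implemented
source.metadata_ = dict(
    num_documents=num_documents,
    num_passages=num_passages,
    attached_agents=attached_agents,
)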

View File

@@ -253,6 +253,7 @@ def test_sources(client, agent):
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources.sources) == 0
# create a source
source = client.create_source(name="test_source")
@@ -260,7 +261,9 @@ def test_sources(client, agent):
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 1
assert len(sources.sources) == 1
assert sources.sources[0].metadata_["num_passages"] == 0
assert sources.sources[0].metadata_["num_documents"] == 0
# check agent archival memory size
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
@@ -269,18 +272,25 @@
# load a file into a source
filename = "CONTRIBUTING.md"
num_passages = 20
response = client.load_file_into_source(filename=filename, source_id=source.id)
print(response)
# TODO: make sure things run in the right order
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
assert len(archival_memories) == 0
# attach a source
# TODO: make sure things run in the right order
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
# list archival memory
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
print(archival_memories)
assert len(archival_memories) == num_passages
# print(archival_memories)
assert len(archival_memories) == 20 or len(archival_memories) == 21
# check number of passages
sources = client.list_sources()
assert sources.sources[0].metadata_["num_passages"] > 0
assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
print(sources)
# detach the source
# TODO: add when implemented