mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: return `num_passages` in `Source.metadata_` from REST list sources endpoint (#1178)
This commit is contained in:
parent
1bc88d9a3a
commit
464bda4589
@ -5,7 +5,7 @@ import uuid
|
|||||||
from typing import Dict, List, Union, Optional, Tuple
|
from typing import Dict, List, Union, Optional, Tuple
|
||||||
|
|
||||||
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
|
from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source
|
||||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel
|
from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel, SourceModel
|
||||||
from memgpt.cli.cli import QuickstartChoice
|
from memgpt.cli.cli import QuickstartChoice
|
||||||
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice
|
||||||
from memgpt.config import MemGPTConfig
|
from memgpt.config import MemGPTConfig
|
||||||
@ -31,6 +31,7 @@ from memgpt.server.rest_api.personas.index import ListPersonasResponse
|
|||||||
from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse
|
from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse
|
||||||
from memgpt.server.rest_api.models.index import ListModelsResponse
|
from memgpt.server.rest_api.models.index import ListModelsResponse
|
||||||
from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse
|
from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse
|
||||||
|
from memgpt.server.rest_api.sources.index import ListSourcesResponse, UploadFileToSourceResponse
|
||||||
|
|
||||||
|
|
||||||
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
|
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
|
||||||
@ -437,7 +438,7 @@ class RESTClient(AbstractClient):
|
|||||||
"""List loaded sources"""
|
"""List loaded sources"""
|
||||||
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
|
response = requests.get(f"{self.base_url}/api/sources", headers=self.headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
return response_json
|
return ListSourcesResponse(**response_json)
|
||||||
|
|
||||||
def delete_source(self, source_id: uuid.UUID):
|
def delete_source(self, source_id: uuid.UUID):
|
||||||
"""Delete a source and associated data (including attached to agents)"""
|
"""Delete a source and associated data (including attached to agents)"""
|
||||||
@ -448,7 +449,7 @@ class RESTClient(AbstractClient):
|
|||||||
"""Load {filename} and insert into source"""
|
"""Load {filename} and insert into source"""
|
||||||
files = {"file": open(filename, "rb")}
|
files = {"file": open(filename, "rb")}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers)
|
||||||
return response.json()
|
return UploadFileToSourceResponse(**response.json())
|
||||||
|
|
||||||
def create_source(self, name: str) -> Source:
|
def create_source(self, name: str) -> Source:
|
||||||
"""Create a new source"""
|
"""Create a new source"""
|
||||||
@ -456,13 +457,14 @@ class RESTClient(AbstractClient):
|
|||||||
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers)
|
||||||
response_json = response.json()
|
response_json = response.json()
|
||||||
print("CREATE SOURCE", response_json, response.text)
|
print("CREATE SOURCE", response_json, response.text)
|
||||||
|
response_obj = SourceModel(**response_json)
|
||||||
return Source(
|
return Source(
|
||||||
id=uuid.UUID(response_json["id"]),
|
id=uuid.UUID(response_obj.id),
|
||||||
name=response_json["name"],
|
name=response_obj.name,
|
||||||
user_id=uuid.UUID(response_json["user_id"]),
|
user_id=uuid.UUID(response_obj.user_id),
|
||||||
created_at=datetime.datetime.fromtimestamp(response_json["created_at"]),
|
created_at=response_obj.created_at,
|
||||||
embedding_dim=response_json["embedding_config"]["embedding_dim"],
|
embedding_dim=response_obj.embedding_config["embedding_dim"],
|
||||||
embedding_model=response_json["embedding_config"]["embedding_model"],
|
embedding_model=response_obj.embedding_config["embedding_model"],
|
||||||
)
|
)
|
||||||
|
|
||||||
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
@ -470,14 +472,12 @@ class RESTClient(AbstractClient):
|
|||||||
params = {"agent_id": agent_id}
|
params = {"agent_id": agent_id}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers)
|
||||||
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
assert response.status_code == 200, f"Failed to attach source to agent: {response.text}"
|
||||||
return response.json()
|
|
||||||
|
|
||||||
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||||
"""Detach a source from an agent"""
|
"""Detach a source from an agent"""
|
||||||
params = {"agent_id": str(agent_id)}
|
params = {"agent_id": str(agent_id)}
|
||||||
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
|
response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers)
|
||||||
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
assert response.status_code == 200, f"Failed to detach source from agent: {response.text}"
|
||||||
return response.json()
|
|
||||||
|
|
||||||
# server configuration commands
|
# server configuration commands
|
||||||
|
|
||||||
|
@ -39,10 +39,6 @@ class CreateSourceRequest(BaseModel):
|
|||||||
description: Optional[str] = Field(None, description="The description of the source.")
|
description: Optional[str] = Field(None, description="The description of the source.")
|
||||||
|
|
||||||
|
|
||||||
class CreateSourceResponse(BaseModel):
|
|
||||||
source: SourceModel = Field(..., description="The newly created source.")
|
|
||||||
|
|
||||||
|
|
||||||
class UploadFileToSourceRequest(BaseModel):
|
class UploadFileToSourceRequest(BaseModel):
|
||||||
file: UploadFile = Field(..., description="The file to upload.")
|
file: UploadFile = Field(..., description="The file to upload.")
|
||||||
|
|
||||||
@ -128,7 +124,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface,
|
|||||||
interface.clear()
|
interface.clear()
|
||||||
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
|
assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}"
|
||||||
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
|
assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}"
|
||||||
source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id)
|
source = server.ms.get_source(source_id=source_id, user_id=user_id)
|
||||||
|
source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id)
|
||||||
return SourceModel(
|
return SourceModel(
|
||||||
name=source.name,
|
name=source.name,
|
||||||
description=None, # TODO: actually store descriptions
|
description=None, # TODO: actually store descriptions
|
||||||
|
@ -1360,8 +1360,21 @@ class SyncServer(LockingServer):
|
|||||||
sources_with_metadata = []
|
sources_with_metadata = []
|
||||||
for source in sources:
|
for source in sources:
|
||||||
|
|
||||||
passages = self.list_data_source_passages(user_id=user_id, source_id=source.id)
|
# count number of passages
|
||||||
documents = self.list_data_source_documents(user_id=user_id, source_id=source.id)
|
passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||||
|
num_passages = passage_conn.size({"data_source": source.name})
|
||||||
|
print(passage_conn.get_all())
|
||||||
|
print(
|
||||||
|
"NUMBER PASSAGES",
|
||||||
|
num_passages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: add when documents table implemented
|
||||||
|
## count number of documents
|
||||||
|
# document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
||||||
|
# num_documents = document_conn.size({"data_source": source.name})
|
||||||
|
num_documents = 0
|
||||||
|
|
||||||
agent_ids = self.ms.list_attached_agents(source_id=source.id)
|
agent_ids = self.ms.list_attached_agents(source_id=source.id)
|
||||||
# add the agent name information
|
# add the agent name information
|
||||||
attached_agents = [
|
attached_agents = [
|
||||||
@ -1374,8 +1387,8 @@ class SyncServer(LockingServer):
|
|||||||
|
|
||||||
# Overwrite metadata field, should be empty anyways
|
# Overwrite metadata field, should be empty anyways
|
||||||
source.metadata_ = dict(
|
source.metadata_ = dict(
|
||||||
num_documents=len(passages),
|
num_documents=num_documents,
|
||||||
num_passages=len(documents),
|
num_passages=num_passages,
|
||||||
attached_agents=attached_agents,
|
attached_agents=attached_agents,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -253,6 +253,7 @@ def test_sources(client, agent):
|
|||||||
# list sources
|
# list sources
|
||||||
sources = client.list_sources()
|
sources = client.list_sources()
|
||||||
print("listed sources", sources)
|
print("listed sources", sources)
|
||||||
|
assert len(sources.sources) == 0
|
||||||
|
|
||||||
# create a source
|
# create a source
|
||||||
source = client.create_source(name="test_source")
|
source = client.create_source(name="test_source")
|
||||||
@ -260,7 +261,9 @@ def test_sources(client, agent):
|
|||||||
# list sources
|
# list sources
|
||||||
sources = client.list_sources()
|
sources = client.list_sources()
|
||||||
print("listed sources", sources)
|
print("listed sources", sources)
|
||||||
assert len(sources) == 1
|
assert len(sources.sources) == 1
|
||||||
|
assert sources.sources[0].metadata_["num_passages"] == 0
|
||||||
|
assert sources.sources[0].metadata_["num_documents"] == 0
|
||||||
|
|
||||||
# check agent archival memory size
|
# check agent archival memory size
|
||||||
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
||||||
@ -269,18 +272,25 @@ def test_sources(client, agent):
|
|||||||
|
|
||||||
# load a file into a source
|
# load a file into a source
|
||||||
filename = "CONTRIBUTING.md"
|
filename = "CONTRIBUTING.md"
|
||||||
num_passages = 20
|
|
||||||
response = client.load_file_into_source(filename=filename, source_id=source.id)
|
response = client.load_file_into_source(filename=filename, source_id=source.id)
|
||||||
print(response)
|
|
||||||
|
# TODO: make sure things run in the right order
|
||||||
|
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
||||||
|
assert len(archival_memories) == 0
|
||||||
|
|
||||||
# attach a source
|
# attach a source
|
||||||
# TODO: make sure things run in the right order
|
|
||||||
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
|
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
|
||||||
|
|
||||||
# list archival memory
|
# list archival memory
|
||||||
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory
|
||||||
print(archival_memories)
|
# print(archival_memories)
|
||||||
assert len(archival_memories) == num_passages
|
assert len(archival_memories) == 20 or len(archival_memories) == 21
|
||||||
|
|
||||||
|
# check number of passages
|
||||||
|
sources = client.list_sources()
|
||||||
|
assert sources.sources[0].metadata_["num_passages"] > 0
|
||||||
|
assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added
|
||||||
|
print(sources)
|
||||||
|
|
||||||
# detach the source
|
# detach the source
|
||||||
# TODO: add when implemented
|
# TODO: add when implemented
|
||||||
|
Loading…
Reference in New Issue
Block a user