feat: rename block.name to block.template_name for clarity and add shared block tests (#1951)

Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
Sarah Wooders 2024-11-04 11:49:16 -08:00 committed by GitHub
parent 8d6ec808b6
commit edebfc129f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 821 additions and 674 deletions

View File

@ -17,9 +17,11 @@ config = context.config
if settings.letta_pg_uri_no_default:
config.set_main_option("sqlalchemy.url", settings.letta_pg_uri)
print(f"Using database: ", settings.letta_pg_uri)
else:
config.set_main_option("sqlalchemy.url", "sqlite:///" + os.path.join(letta_config.recall_storage_path, "sqlite.db"))
print(f"Using database: ", settings.letta_pg_uri, settings.letta_pg_uri_no_default)
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:

View File

@ -0,0 +1,36 @@
"""Rename block.name to block.template_name
Revision ID: eff245f340f9
Revises: 0c315956709d
Create Date: 2024-10-31 18:09:08.819371
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "eff245f340f9"
down_revision: Union[str, None] = "ee50a967e090"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("block", "name", new_column_name="template_name", existing_type=sa.String(), nullable=True)
# op.add_column('block', sa.Column('template_name', sa.String(), nullable=True))
# op.drop_column('block', 'name')
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.alter_column("block", "template_name", new_column_name="name", existing_type=sa.String(), nullable=True)
# op.add_column('block', sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True))
# op.drop_column('block', 'template_name')
# ### end Alembic commands ###

View File

@ -15,7 +15,7 @@ class Swarm:
self.client = create_client()
# shared memory block (shared section of context window accross agents)
self.shared_memory = Block(name="human", label="human", value="")
self.shared_memory = Block(label="human", value="")
def create_agent(
self,
@ -40,7 +40,7 @@ class Swarm:
if len(instructions) > 0
else f"You are agent with name {name}"
)
persona_block = Block(name="persona", label="persona", value=persona_value)
persona_block = Block(label="persona", value=persona_value)
memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory])
agent = self.client.create_agent(

View File

@ -282,10 +282,10 @@ def run(
system_prompt = system if system else None
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
metadata = {"human": human_obj.template_name, "persona": persona_obj.template_name}
typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.template_name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.template_name}'", fg=typer.colors.WHITE)
# add tools
agent_state = client.create_agent(

View File

@ -59,13 +59,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all humans"""
table.field_names = ["Name", "Text"]
for human in client.list_humans():
table.add_row([human.name, human.value.replace("\n", "")[:100]])
table.add_row([human.template_name, human.value.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.personas:
"""List all personas"""
table.field_names = ["Name", "Text"]
for persona in client.list_personas():
table.add_row([persona.name, persona.value.replace("\n", "")[:100]])
table.add_row([persona.template_name, persona.value.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""

View File

@ -859,8 +859,8 @@ class RESTClient(AbstractClient):
else:
return [Block(**block) for block in response.json()]
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
request = CreateBlock(label=label, value=text, template=template, name=name)
def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: #
request = CreateBlock(label=label, value=text, template=template, template_name=template_name)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create block: {response.text}")
@ -872,7 +872,7 @@ class RESTClient(AbstractClient):
return Block(**response.json())
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
request = UpdateBlock(id=block_id, name=name, value=text)
request = UpdateBlock(id=block_id, template_name=name, value=text)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update block: {response.text}")
@ -926,7 +926,7 @@ class RESTClient(AbstractClient):
Returns:
human (Human): Human block
"""
return self.create_block(label="human", name=name, text=text, template=True)
return self.create_block(label="human", template_name=name, text=text, template=True)
def update_human(self, human_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Human:
"""
@ -939,7 +939,7 @@ class RESTClient(AbstractClient):
Returns:
human (Human): Updated human block
"""
request = UpdateHuman(id=human_id, name=name, value=text)
request = UpdateHuman(id=human_id, template_name=name, value=text)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update human: {response.text}")
@ -966,7 +966,7 @@ class RESTClient(AbstractClient):
Returns:
persona (Persona): Persona block
"""
return self.create_block(label="persona", name=name, text=text, template=True)
return self.create_block(label="persona", template_name=name, text=text, template=True)
def update_persona(self, persona_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Persona:
"""
@ -979,7 +979,7 @@ class RESTClient(AbstractClient):
Returns:
persona (Persona): Updated persona block
"""
request = UpdatePersona(id=persona_id, name=name, value=text)
request = UpdatePersona(id=persona_id, template_name=name, value=text)
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to update persona: {response.text}")
@ -2116,7 +2116,7 @@ class LocalClient(AbstractClient):
Returns:
human (Human): Human block
"""
return self.server.create_block(CreateHuman(name=name, value=text, user_id=self.user_id), user_id=self.user_id)
return self.server.create_block(CreateHuman(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id)
def create_persona(self, name: str, text: str):
"""
@ -2129,7 +2129,7 @@ class LocalClient(AbstractClient):
Returns:
persona (Persona): Persona block
"""
return self.server.create_block(CreatePersona(name=name, value=text, user_id=self.user_id), user_id=self.user_id)
return self.server.create_block(CreatePersona(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id)
def list_humans(self):
"""
@ -2635,7 +2635,7 @@ class LocalClient(AbstractClient):
"""
return self.server.get_blocks(label=label, template=templates_only)
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: #
"""
Create a block
@ -2648,7 +2648,7 @@ class LocalClient(AbstractClient):
block (Block): Created block
"""
return self.server.create_block(
CreateBlock(label=label, name=name, value=text, user_id=self.user_id, template=template), user_id=self.user_id
CreateBlock(label=label, template_name=template_name, value=text, user_id=self.user_id, template=template), user_id=self.user_id
)
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
@ -2663,7 +2663,7 @@ class LocalClient(AbstractClient):
Returns:
block (Block): Updated block
"""
return self.server.update_block(UpdateBlock(id=block_id, name=name, value=text))
return self.server.update_block(UpdateBlock(id=block_id, template_name=name, value=text))
def get_block(self, block_id: str) -> Block:
"""

View File

@ -139,32 +139,73 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
return schema
def generate_schema_from_args_schema(
def generate_schema_from_args_schema_v1(
args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
) -> Dict[str, Any]:
properties = {}
required = []
for field_name, field in args_schema.__fields__.items():
if field.type_.__name__ == "str":
if field.type_ == str:
field_type = "string"
elif field.type_.__name__ == "int":
elif field.type_ == int:
field_type = "integer"
elif field.type_.__name__ == "bool":
elif field.type_ == bool:
field_type = "boolean"
else:
field_type = field.type_.__name__
properties[field_name] = {"type": field_type, "description": field.field_info.description}
properties[field_name] = {
"type": field_type,
"description": field.field_info.description,
}
if field.required:
required.append(field_name)
# Construct the OpenAI function call JSON object
function_call_json = {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": properties, "required": required},
}
# append heartbeat (necessary for triggering another reasoning step after this tool call)
if append_heartbeat:
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
function_call_json["parameters"]["required"].append("request_heartbeat")
return function_call_json
def generate_schema_from_args_schema_v2(
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
) -> Dict[str, Any]:
properties = {}
required = []
for field_name, field in args_schema.model_fields.items():
field_type_annotation = field.annotation
if field_type_annotation == str:
field_type = "string"
elif field_type_annotation == int:
field_type = "integer"
elif field_type_annotation == bool:
field_type = "boolean"
else:
field_type = field_type_annotation.__name__
properties[field_name] = {
"type": field_type,
"description": field.description,
}
if field.is_required():
required.append(field_name)
function_call_json = {
"name": name,
"description": description,
"parameters": {"type": "object", "properties": properties, "required": required},
}
if append_heartbeat:
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",

View File

@ -348,7 +348,7 @@ class BlockModel(Base):
id = Column(String, primary_key=True, nullable=False)
value = Column(String, nullable=False)
limit = Column(BIGINT)
name = Column(String)
template_name = Column(String, nullable=True, default=None)
template = Column(Boolean, default=False) # True: listed as possible human/persona
label = Column(String, nullable=False)
metadata_ = Column(JSON)
@ -357,7 +357,7 @@ class BlockModel(Base):
Index(__tablename__ + "_idx_user", user_id),
def __repr__(self) -> str:
return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"
return f"<Block(id='{self.id}', template_name='{self.template_name}', template='{self.template_name}', label='{self.label}', user_id='{self.user_id}')>"
def to_record(self) -> Block:
if self.label == "persona":
@ -365,7 +365,7 @@ class BlockModel(Base):
id=self.id,
value=self.value,
limit=self.limit,
name=self.name,
template_name=self.template_name,
template=self.template,
label=self.label,
metadata_=self.metadata_,
@ -377,7 +377,7 @@ class BlockModel(Base):
id=self.id,
value=self.value,
limit=self.limit,
name=self.name,
template_name=self.template_name,
template=self.template,
label=self.label,
metadata_=self.metadata_,
@ -389,7 +389,7 @@ class BlockModel(Base):
id=self.id,
value=self.value,
limit=self.limit,
name=self.name,
template_name=self.template_name,
template=self.template,
label=self.label,
metadata_=self.metadata_,
@ -512,7 +512,7 @@ class MetadataStore:
# with a given name doesn't exist.
if (
session.query(BlockModel)
.filter(BlockModel.name == block.name)
.filter(BlockModel.template_name == block.template_name)
.filter(BlockModel.user_id == block.user_id)
.filter(BlockModel.template == True)
.filter(BlockModel.label == block.label)
@ -520,7 +520,7 @@ class MetadataStore:
> 0
):
raise ValueError(f"Block with name {block.name} already exists")
raise ValueError(f"Block with name {block.template_name} already exists")
session.add(BlockModel(**vars(block)))
session.commit()
@ -658,7 +658,7 @@ class MetadataStore:
user_id: Optional[str],
label: Optional[str] = None,
template: Optional[bool] = None,
name: Optional[str] = None,
template_name: Optional[str] = None,
id: Optional[str] = None,
) -> Optional[List[Block]]:
"""List available blocks"""
@ -671,8 +671,8 @@ class MetadataStore:
if label:
query = query.filter(BlockModel.label == label)
if name:
query = query.filter(BlockModel.name == name)
if template_name:
query = query.filter(BlockModel.template_name == template_name)
if id:
query = query.filter(BlockModel.id == id)

View File

@ -18,7 +18,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
limit: int = Field(2000, description="Character limit of the block.")
# template data (optional)
name: Optional[str] = Field(None, description="Name of the block if it is a template.")
template_name: Optional[str] = Field(None, description="Name of the block if it is a template.")
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
# context window label

View File

@ -8,7 +8,10 @@ from letta.functions.helpers import (
generate_crewai_tool_wrapper,
generate_langchain_tool_wrapper,
)
from letta.functions.schema_generator import generate_schema_from_args_schema
from letta.functions.schema_generator import (
generate_schema_from_args_schema_v1,
generate_schema_from_args_schema_v2,
)
from letta.schemas.letta_base import LettaBase
from letta.schemas.openai.chat_completions import ToolCall
@ -97,7 +100,7 @@ class ToolCreate(LettaBase):
source_type = "python"
tags = ["composio"]
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action)
json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description)
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
return cls(
name=wrapper_func_name,
@ -129,7 +132,7 @@ class ToolCreate(LettaBase):
tags = ["langchain"]
# NOTE: langchain tools may come from different packages
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
json_schema = generate_schema_from_args_schema_v1(langchain_tool.args_schema, name=wrapper_func_name, description=description)
return cls(
name=wrapper_func_name,
@ -159,7 +162,7 @@ class ToolCreate(LettaBase):
source_type = "python"
tags = ["crew-ai"]
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool, additional_imports_module_attr_map)
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
json_schema = generate_schema_from_args_schema_v1(crewai_tool.args_schema, name=wrapper_func_name, description=description)
return cls(
name=wrapper_func_name,

View File

@ -1084,7 +1084,7 @@ class SyncServer(Server):
id: Optional[str] = None,
) -> Optional[List[Block]]:
return self.ms.get_blocks(user_id=user_id, label=label, template=template, name=name, id=id)
return self.ms.get_blocks(user_id=user_id, label=label, template=template, template_name=name, id=id)
def get_block(self, block_id: str):
@ -1096,14 +1096,18 @@ class SyncServer(Server):
return blocks[0]
def create_block(self, request: CreateBlock, user_id: str, update: bool = False) -> Block:
existing_blocks = self.ms.get_blocks(name=request.name, user_id=user_id, template=request.template, label=request.label)
if existing_blocks is not None:
existing_blocks = self.ms.get_blocks(
template_name=request.template_name, user_id=user_id, template=request.template, label=request.label
)
# for templates, update existing block template if exists
if existing_blocks is not None and request.template:
existing_block = existing_blocks[0]
assert len(existing_blocks) == 1
if update:
return self.update_block(UpdateBlock(id=existing_block.id, **vars(request)))
else:
raise ValueError(f"Block with name {request.name} already exists")
raise ValueError(f"Block with name {request.template_name} already exists")
block = Block(**vars(request))
self.ms.create_block(block)
return block
@ -1112,7 +1116,7 @@ class SyncServer(Server):
block = self.get_block(request.id)
block.limit = request.limit if request.limit is not None else block.limit
block.value = request.value if request.value is not None else block.value
block.name = request.name if request.name is not None else block.name
block.template_name = request.template_name if request.template_name is not None else block.template_name
self.ms.update_block(block=block)
return self.ms.get_block(block_id=request.id)
@ -1773,12 +1777,12 @@ class SyncServer(Server):
for persona_file in list_persona_files():
text = open(persona_file, "r", encoding="utf-8").read()
name = os.path.basename(persona_file).replace(".txt", "")
self.create_block(CreatePersona(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)
self.create_block(CreatePersona(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True)
for human_file in list_human_files():
text = open(human_file, "r", encoding="utf-8").read()
name = os.path.basename(human_file).replace(".txt", "")
self.create_block(CreateHuman(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)
self.create_block(CreateHuman(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True)
def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]:
"""Get a single message from the agent's memory"""

1270
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -83,7 +83,7 @@ psycopg2-binary = "^2.9.10"
[tool.poetry.extras]
#local = ["llama-index-embeddings-huggingface"]
postgres = ["pgvector", "pg8000", "psycopg2-binary"]
postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2"]
milvus = ["pymilvus"]
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
server = ["websockets", "fastapi", "uvicorn"]

View File

@ -288,7 +288,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
if persona_id:
client.delete_persona(persona_id)
persona = client.create_persona(name=persona_name, text="Persona text")
assert persona.name == persona_name
assert persona.template_name == persona_name
assert persona.value == "Persona text", "Creating persona failed"
human_name = "TestHuman"
@ -296,7 +296,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
if human_id:
client.delete_human(human_id)
human = client.create_human(name=human_name, text="Human text")
assert human.name == human_name
assert human.template_name == human_name
assert human.value == "Human text", "Creating human failed"
@ -565,3 +565,36 @@ def test_list_llm_models(client: RESTClient):
assert has_model_endpoint_type(models, "google_ai")
if model_settings.anthropic_api_key:
assert has_model_endpoint_type(models, "anthropic")
def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
# create a block
block = client.create_block(label="human", text="username: sarah")
# create agents with shared block
from letta.schemas.memory import BasicBlockMemory
persona1_block = client.create_block(label="persona", text="you are agent 1")
persona2_block = client.create_block(label="persona", text="you are agent 2")
# create agnets
agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block]))
agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory(blocks=[block, persona2_block]))
# update memory
response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
# check agent 2 memory
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
response = client.user_message(agent_id=agent_state2.id, message="whats my name?")
assert (
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
# cleanup
client.delete_agent(agent_state1.id)
client.delete_agent(agent_state2.id)

View File

@ -169,8 +169,8 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
def test_agent_with_shared_blocks(client: LocalClient):
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
persona_block = Block(template_name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
human_block = Block(template_name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
existing_non_template_blocks = [persona_block, human_block]
for block in existing_non_template_blocks:
# ensure that previous chat blocks are persisted, as if another agent already produced them.