mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: rename block.name
to block.template_name
for clarity and add shared block tests (#1951)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
parent
8d6ec808b6
commit
edebfc129f
@ -17,9 +17,11 @@ config = context.config
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
config.set_main_option("sqlalchemy.url", settings.letta_pg_uri)
|
||||
print(f"Using database: ", settings.letta_pg_uri)
|
||||
else:
|
||||
config.set_main_option("sqlalchemy.url", "sqlite:///" + os.path.join(letta_config.recall_storage_path, "sqlite.db"))
|
||||
|
||||
print(f"Using database: ", settings.letta_pg_uri, settings.letta_pg_uri_no_default)
|
||||
# Interpret the config file for Python logging.
|
||||
# This line sets up loggers basically.
|
||||
if config.config_file_name is not None:
|
||||
|
@ -0,0 +1,36 @@
|
||||
"""Rename block.name to block.template_name
|
||||
|
||||
Revision ID: eff245f340f9
|
||||
Revises: 0c315956709d
|
||||
Create Date: 2024-10-31 18:09:08.819371
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "eff245f340f9"
|
||||
down_revision: Union[str, None] = "ee50a967e090"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("block", "name", new_column_name="template_name", existing_type=sa.String(), nullable=True)
|
||||
|
||||
# op.add_column('block', sa.Column('template_name', sa.String(), nullable=True))
|
||||
# op.drop_column('block', 'name')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column("block", "template_name", new_column_name="name", existing_type=sa.String(), nullable=True)
|
||||
# op.add_column('block', sa.Column('name', sa.VARCHAR(), autoincrement=False, nullable=True))
|
||||
# op.drop_column('block', 'template_name')
|
||||
# ### end Alembic commands ###
|
@ -15,7 +15,7 @@ class Swarm:
|
||||
self.client = create_client()
|
||||
|
||||
# shared memory block (shared section of context window accross agents)
|
||||
self.shared_memory = Block(name="human", label="human", value="")
|
||||
self.shared_memory = Block(label="human", value="")
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
@ -40,7 +40,7 @@ class Swarm:
|
||||
if len(instructions) > 0
|
||||
else f"You are agent with name {name}"
|
||||
)
|
||||
persona_block = Block(name="persona", label="persona", value=persona_value)
|
||||
persona_block = Block(label="persona", value=persona_value)
|
||||
memory = BasicBlockMemory(blocks=[persona_block, self.shared_memory])
|
||||
|
||||
agent = self.client.create_agent(
|
||||
|
@ -282,10 +282,10 @@ def run(
|
||||
system_prompt = system if system else None
|
||||
|
||||
memory = ChatMemory(human=human_obj.value, persona=persona_obj.value, limit=core_memory_limit)
|
||||
metadata = {"human": human_obj.name, "persona": persona_obj.name}
|
||||
metadata = {"human": human_obj.template_name, "persona": persona_obj.template_name}
|
||||
|
||||
typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> {ASSISTANT_MESSAGE_CLI_SYMBOL} Using persona profile: '{persona_obj.template_name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{human_obj.template_name}'", fg=typer.colors.WHITE)
|
||||
|
||||
# add tools
|
||||
agent_state = client.create_agent(
|
||||
|
@ -59,13 +59,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
"""List all humans"""
|
||||
table.field_names = ["Name", "Text"]
|
||||
for human in client.list_humans():
|
||||
table.add_row([human.name, human.value.replace("\n", "")[:100]])
|
||||
table.add_row([human.template_name, human.value.replace("\n", "")[:100]])
|
||||
print(table)
|
||||
elif arg == ListChoice.personas:
|
||||
"""List all personas"""
|
||||
table.field_names = ["Name", "Text"]
|
||||
for persona in client.list_personas():
|
||||
table.add_row([persona.name, persona.value.replace("\n", "")[:100]])
|
||||
table.add_row([persona.template_name, persona.value.replace("\n", "")[:100]])
|
||||
print(table)
|
||||
elif arg == ListChoice.sources:
|
||||
"""List all data sources"""
|
||||
|
@ -859,8 +859,8 @@ class RESTClient(AbstractClient):
|
||||
else:
|
||||
return [Block(**block) for block in response.json()]
|
||||
|
||||
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
|
||||
request = CreateBlock(label=label, value=text, template=template, name=name)
|
||||
def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: #
|
||||
request = CreateBlock(label=label, value=text, template=template, template_name=template_name)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create block: {response.text}")
|
||||
@ -872,7 +872,7 @@ class RESTClient(AbstractClient):
|
||||
return Block(**response.json())
|
||||
|
||||
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
|
||||
request = UpdateBlock(id=block_id, name=name, value=text)
|
||||
request = UpdateBlock(id=block_id, template_name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{block_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update block: {response.text}")
|
||||
@ -926,7 +926,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
human (Human): Human block
|
||||
"""
|
||||
return self.create_block(label="human", name=name, text=text, template=True)
|
||||
return self.create_block(label="human", template_name=name, text=text, template=True)
|
||||
|
||||
def update_human(self, human_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Human:
|
||||
"""
|
||||
@ -939,7 +939,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
human (Human): Updated human block
|
||||
"""
|
||||
request = UpdateHuman(id=human_id, name=name, value=text)
|
||||
request = UpdateHuman(id=human_id, template_name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{human_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update human: {response.text}")
|
||||
@ -966,7 +966,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
persona (Persona): Persona block
|
||||
"""
|
||||
return self.create_block(label="persona", name=name, text=text, template=True)
|
||||
return self.create_block(label="persona", template_name=name, text=text, template=True)
|
||||
|
||||
def update_persona(self, persona_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Persona:
|
||||
"""
|
||||
@ -979,7 +979,7 @@ class RESTClient(AbstractClient):
|
||||
Returns:
|
||||
persona (Persona): Updated persona block
|
||||
"""
|
||||
request = UpdatePersona(id=persona_id, name=name, value=text)
|
||||
request = UpdatePersona(id=persona_id, template_name=name, value=text)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/blocks/{persona_id}", json=request.model_dump(), headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to update persona: {response.text}")
|
||||
@ -2116,7 +2116,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
human (Human): Human block
|
||||
"""
|
||||
return self.server.create_block(CreateHuman(name=name, value=text, user_id=self.user_id), user_id=self.user_id)
|
||||
return self.server.create_block(CreateHuman(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id)
|
||||
|
||||
def create_persona(self, name: str, text: str):
|
||||
"""
|
||||
@ -2129,7 +2129,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
persona (Persona): Persona block
|
||||
"""
|
||||
return self.server.create_block(CreatePersona(name=name, value=text, user_id=self.user_id), user_id=self.user_id)
|
||||
return self.server.create_block(CreatePersona(template_name=name, value=text, user_id=self.user_id), user_id=self.user_id)
|
||||
|
||||
def list_humans(self):
|
||||
"""
|
||||
@ -2635,7 +2635,7 @@ class LocalClient(AbstractClient):
|
||||
"""
|
||||
return self.server.get_blocks(label=label, template=templates_only)
|
||||
|
||||
def create_block(self, label: str, text: str, name: Optional[str] = None, template: bool = False) -> Block: #
|
||||
def create_block(self, label: str, text: str, template_name: Optional[str] = None, template: bool = False) -> Block: #
|
||||
"""
|
||||
Create a block
|
||||
|
||||
@ -2648,7 +2648,7 @@ class LocalClient(AbstractClient):
|
||||
block (Block): Created block
|
||||
"""
|
||||
return self.server.create_block(
|
||||
CreateBlock(label=label, name=name, value=text, user_id=self.user_id, template=template), user_id=self.user_id
|
||||
CreateBlock(label=label, template_name=template_name, value=text, user_id=self.user_id, template=template), user_id=self.user_id
|
||||
)
|
||||
|
||||
def update_block(self, block_id: str, name: Optional[str] = None, text: Optional[str] = None) -> Block:
|
||||
@ -2663,7 +2663,7 @@ class LocalClient(AbstractClient):
|
||||
Returns:
|
||||
block (Block): Updated block
|
||||
"""
|
||||
return self.server.update_block(UpdateBlock(id=block_id, name=name, value=text))
|
||||
return self.server.update_block(UpdateBlock(id=block_id, template_name=name, value=text))
|
||||
|
||||
def get_block(self, block_id: str) -> Block:
|
||||
"""
|
||||
|
@ -139,32 +139,73 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
return schema
|
||||
|
||||
|
||||
def generate_schema_from_args_schema(
|
||||
def generate_schema_from_args_schema_v1(
|
||||
args_schema: Type[V1BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
properties = {}
|
||||
required = []
|
||||
for field_name, field in args_schema.__fields__.items():
|
||||
if field.type_.__name__ == "str":
|
||||
if field.type_ == str:
|
||||
field_type = "string"
|
||||
elif field.type_.__name__ == "int":
|
||||
elif field.type_ == int:
|
||||
field_type = "integer"
|
||||
elif field.type_.__name__ == "bool":
|
||||
elif field.type_ == bool:
|
||||
field_type = "boolean"
|
||||
else:
|
||||
field_type = field.type_.__name__
|
||||
properties[field_name] = {"type": field_type, "description": field.field_info.description}
|
||||
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field.field_info.description,
|
||||
}
|
||||
if field.required:
|
||||
required.append(field_name)
|
||||
|
||||
# Construct the OpenAI function call JSON object
|
||||
function_call_json = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {"type": "object", "properties": properties, "required": required},
|
||||
}
|
||||
|
||||
# append heartbeat (necessary for triggering another reasoning step after this tool call)
|
||||
if append_heartbeat:
|
||||
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
function_call_json["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return function_call_json
|
||||
|
||||
|
||||
def generate_schema_from_args_schema_v2(
|
||||
args_schema: Type[BaseModel], name: Optional[str] = None, description: Optional[str] = None, append_heartbeat: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
properties = {}
|
||||
required = []
|
||||
for field_name, field in args_schema.model_fields.items():
|
||||
field_type_annotation = field.annotation
|
||||
if field_type_annotation == str:
|
||||
field_type = "string"
|
||||
elif field_type_annotation == int:
|
||||
field_type = "integer"
|
||||
elif field_type_annotation == bool:
|
||||
field_type = "boolean"
|
||||
else:
|
||||
field_type = field_type_annotation.__name__
|
||||
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field.description,
|
||||
}
|
||||
if field.is_required():
|
||||
required.append(field_name)
|
||||
|
||||
function_call_json = {
|
||||
"name": name,
|
||||
"description": description,
|
||||
"parameters": {"type": "object", "properties": properties, "required": required},
|
||||
}
|
||||
|
||||
if append_heartbeat:
|
||||
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
|
@ -348,7 +348,7 @@ class BlockModel(Base):
|
||||
id = Column(String, primary_key=True, nullable=False)
|
||||
value = Column(String, nullable=False)
|
||||
limit = Column(BIGINT)
|
||||
name = Column(String)
|
||||
template_name = Column(String, nullable=True, default=None)
|
||||
template = Column(Boolean, default=False) # True: listed as possible human/persona
|
||||
label = Column(String, nullable=False)
|
||||
metadata_ = Column(JSON)
|
||||
@ -357,7 +357,7 @@ class BlockModel(Base):
|
||||
Index(__tablename__ + "_idx_user", user_id),
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Block(id='{self.id}', name='{self.name}', template='{self.template}', label='{self.label}', user_id='{self.user_id}')>"
|
||||
return f"<Block(id='{self.id}', template_name='{self.template_name}', template='{self.template_name}', label='{self.label}', user_id='{self.user_id}')>"
|
||||
|
||||
def to_record(self) -> Block:
|
||||
if self.label == "persona":
|
||||
@ -365,7 +365,7 @@ class BlockModel(Base):
|
||||
id=self.id,
|
||||
value=self.value,
|
||||
limit=self.limit,
|
||||
name=self.name,
|
||||
template_name=self.template_name,
|
||||
template=self.template,
|
||||
label=self.label,
|
||||
metadata_=self.metadata_,
|
||||
@ -377,7 +377,7 @@ class BlockModel(Base):
|
||||
id=self.id,
|
||||
value=self.value,
|
||||
limit=self.limit,
|
||||
name=self.name,
|
||||
template_name=self.template_name,
|
||||
template=self.template,
|
||||
label=self.label,
|
||||
metadata_=self.metadata_,
|
||||
@ -389,7 +389,7 @@ class BlockModel(Base):
|
||||
id=self.id,
|
||||
value=self.value,
|
||||
limit=self.limit,
|
||||
name=self.name,
|
||||
template_name=self.template_name,
|
||||
template=self.template,
|
||||
label=self.label,
|
||||
metadata_=self.metadata_,
|
||||
@ -512,7 +512,7 @@ class MetadataStore:
|
||||
# with a given name doesn't exist.
|
||||
if (
|
||||
session.query(BlockModel)
|
||||
.filter(BlockModel.name == block.name)
|
||||
.filter(BlockModel.template_name == block.template_name)
|
||||
.filter(BlockModel.user_id == block.user_id)
|
||||
.filter(BlockModel.template == True)
|
||||
.filter(BlockModel.label == block.label)
|
||||
@ -520,7 +520,7 @@ class MetadataStore:
|
||||
> 0
|
||||
):
|
||||
|
||||
raise ValueError(f"Block with name {block.name} already exists")
|
||||
raise ValueError(f"Block with name {block.template_name} already exists")
|
||||
session.add(BlockModel(**vars(block)))
|
||||
session.commit()
|
||||
|
||||
@ -658,7 +658,7 @@ class MetadataStore:
|
||||
user_id: Optional[str],
|
||||
label: Optional[str] = None,
|
||||
template: Optional[bool] = None,
|
||||
name: Optional[str] = None,
|
||||
template_name: Optional[str] = None,
|
||||
id: Optional[str] = None,
|
||||
) -> Optional[List[Block]]:
|
||||
"""List available blocks"""
|
||||
@ -671,8 +671,8 @@ class MetadataStore:
|
||||
if label:
|
||||
query = query.filter(BlockModel.label == label)
|
||||
|
||||
if name:
|
||||
query = query.filter(BlockModel.name == name)
|
||||
if template_name:
|
||||
query = query.filter(BlockModel.template_name == template_name)
|
||||
|
||||
if id:
|
||||
query = query.filter(BlockModel.id == id)
|
||||
|
@ -18,7 +18,7 @@ class BaseBlock(LettaBase, validate_assignment=True):
|
||||
limit: int = Field(2000, description="Character limit of the block.")
|
||||
|
||||
# template data (optional)
|
||||
name: Optional[str] = Field(None, description="Name of the block if it is a template.")
|
||||
template_name: Optional[str] = Field(None, description="Name of the block if it is a template.")
|
||||
template: bool = Field(False, description="Whether the block is a template (e.g. saved human/persona options).")
|
||||
|
||||
# context window label
|
||||
|
@ -8,7 +8,10 @@ from letta.functions.helpers import (
|
||||
generate_crewai_tool_wrapper,
|
||||
generate_langchain_tool_wrapper,
|
||||
)
|
||||
from letta.functions.schema_generator import generate_schema_from_args_schema
|
||||
from letta.functions.schema_generator import (
|
||||
generate_schema_from_args_schema_v1,
|
||||
generate_schema_from_args_schema_v2,
|
||||
)
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
|
||||
@ -97,7 +100,7 @@ class ToolCreate(LettaBase):
|
||||
source_type = "python"
|
||||
tags = ["composio"]
|
||||
wrapper_func_name, wrapper_function_str = generate_composio_tool_wrapper(action)
|
||||
json_schema = generate_schema_from_args_schema(composio_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
json_schema = generate_schema_from_args_schema_v2(composio_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
return cls(
|
||||
name=wrapper_func_name,
|
||||
@ -129,7 +132,7 @@ class ToolCreate(LettaBase):
|
||||
tags = ["langchain"]
|
||||
# NOTE: langchain tools may come from different packages
|
||||
wrapper_func_name, wrapper_function_str = generate_langchain_tool_wrapper(langchain_tool, additional_imports_module_attr_map)
|
||||
json_schema = generate_schema_from_args_schema(langchain_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
json_schema = generate_schema_from_args_schema_v1(langchain_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
return cls(
|
||||
name=wrapper_func_name,
|
||||
@ -159,7 +162,7 @@ class ToolCreate(LettaBase):
|
||||
source_type = "python"
|
||||
tags = ["crew-ai"]
|
||||
wrapper_func_name, wrapper_function_str = generate_crewai_tool_wrapper(crewai_tool, additional_imports_module_attr_map)
|
||||
json_schema = generate_schema_from_args_schema(crewai_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
json_schema = generate_schema_from_args_schema_v1(crewai_tool.args_schema, name=wrapper_func_name, description=description)
|
||||
|
||||
return cls(
|
||||
name=wrapper_func_name,
|
||||
|
@ -1084,7 +1084,7 @@ class SyncServer(Server):
|
||||
id: Optional[str] = None,
|
||||
) -> Optional[List[Block]]:
|
||||
|
||||
return self.ms.get_blocks(user_id=user_id, label=label, template=template, name=name, id=id)
|
||||
return self.ms.get_blocks(user_id=user_id, label=label, template=template, template_name=name, id=id)
|
||||
|
||||
def get_block(self, block_id: str):
|
||||
|
||||
@ -1096,14 +1096,18 @@ class SyncServer(Server):
|
||||
return blocks[0]
|
||||
|
||||
def create_block(self, request: CreateBlock, user_id: str, update: bool = False) -> Block:
|
||||
existing_blocks = self.ms.get_blocks(name=request.name, user_id=user_id, template=request.template, label=request.label)
|
||||
if existing_blocks is not None:
|
||||
existing_blocks = self.ms.get_blocks(
|
||||
template_name=request.template_name, user_id=user_id, template=request.template, label=request.label
|
||||
)
|
||||
|
||||
# for templates, update existing block template if exists
|
||||
if existing_blocks is not None and request.template:
|
||||
existing_block = existing_blocks[0]
|
||||
assert len(existing_blocks) == 1
|
||||
if update:
|
||||
return self.update_block(UpdateBlock(id=existing_block.id, **vars(request)))
|
||||
else:
|
||||
raise ValueError(f"Block with name {request.name} already exists")
|
||||
raise ValueError(f"Block with name {request.template_name} already exists")
|
||||
block = Block(**vars(request))
|
||||
self.ms.create_block(block)
|
||||
return block
|
||||
@ -1112,7 +1116,7 @@ class SyncServer(Server):
|
||||
block = self.get_block(request.id)
|
||||
block.limit = request.limit if request.limit is not None else block.limit
|
||||
block.value = request.value if request.value is not None else block.value
|
||||
block.name = request.name if request.name is not None else block.name
|
||||
block.template_name = request.template_name if request.template_name is not None else block.template_name
|
||||
self.ms.update_block(block=block)
|
||||
return self.ms.get_block(block_id=request.id)
|
||||
|
||||
@ -1773,12 +1777,12 @@ class SyncServer(Server):
|
||||
for persona_file in list_persona_files():
|
||||
text = open(persona_file, "r", encoding="utf-8").read()
|
||||
name = os.path.basename(persona_file).replace(".txt", "")
|
||||
self.create_block(CreatePersona(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)
|
||||
self.create_block(CreatePersona(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True)
|
||||
|
||||
for human_file in list_human_files():
|
||||
text = open(human_file, "r", encoding="utf-8").read()
|
||||
name = os.path.basename(human_file).replace(".txt", "")
|
||||
self.create_block(CreateHuman(user_id=user_id, name=name, value=text, template=True), user_id=user_id, update=True)
|
||||
self.create_block(CreateHuman(user_id=user_id, template_name=name, value=text, template=True), user_id=user_id, update=True)
|
||||
|
||||
def get_agent_message(self, agent_id: str, message_id: str) -> Optional[Message]:
|
||||
"""Get a single message from the agent's memory"""
|
||||
|
1270
poetry.lock
generated
1270
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -83,7 +83,7 @@ psycopg2-binary = "^2.9.10"
|
||||
|
||||
[tool.poetry.extras]
|
||||
#local = ["llama-index-embeddings-huggingface"]
|
||||
postgres = ["pgvector", "pg8000", "psycopg2-binary"]
|
||||
postgres = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2"]
|
||||
milvus = ["pymilvus"]
|
||||
dev = ["pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "datasets", "pyright", "pytest-order", "autoflake", "isort"]
|
||||
server = ["websockets", "fastapi", "uvicorn"]
|
||||
|
@ -288,7 +288,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
if persona_id:
|
||||
client.delete_persona(persona_id)
|
||||
persona = client.create_persona(name=persona_name, text="Persona text")
|
||||
assert persona.name == persona_name
|
||||
assert persona.template_name == persona_name
|
||||
assert persona.value == "Persona text", "Creating persona failed"
|
||||
|
||||
human_name = "TestHuman"
|
||||
@ -296,7 +296,7 @@ def test_humans_personas(client: Union[LocalClient, RESTClient], agent: AgentSta
|
||||
if human_id:
|
||||
client.delete_human(human_id)
|
||||
human = client.create_human(name=human_name, text="Human text")
|
||||
assert human.name == human_name
|
||||
assert human.template_name == human_name
|
||||
assert human.value == "Human text", "Creating human failed"
|
||||
|
||||
|
||||
@ -565,3 +565,36 @@ def test_list_llm_models(client: RESTClient):
|
||||
assert has_model_endpoint_type(models, "google_ai")
|
||||
if model_settings.anthropic_api_key:
|
||||
assert has_model_endpoint_type(models, "anthropic")
|
||||
|
||||
|
||||
def test_shared_blocks(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
# create a block
|
||||
block = client.create_block(label="human", text="username: sarah")
|
||||
|
||||
# create agents with shared block
|
||||
from letta.schemas.memory import BasicBlockMemory
|
||||
|
||||
persona1_block = client.create_block(label="persona", text="you are agent 1")
|
||||
persona2_block = client.create_block(label="persona", text="you are agent 2")
|
||||
|
||||
# create agnets
|
||||
agent_state1 = client.create_agent(name="agent1", memory=BasicBlockMemory(blocks=[block, persona1_block]))
|
||||
agent_state2 = client.create_agent(name="agent2", memory=BasicBlockMemory(blocks=[block, persona2_block]))
|
||||
|
||||
# update memory
|
||||
response = client.user_message(agent_id=agent_state1.id, message="my name is actually charles")
|
||||
|
||||
# check agent 2 memory
|
||||
assert "charles" in client.get_block(block.id).value.lower(), f"Shared block update failed {client.get_block(block.id).value}"
|
||||
|
||||
response = client.user_message(agent_id=agent_state2.id, message="whats my name?")
|
||||
assert (
|
||||
"charles" in client.get_core_memory(agent_state2.id).get_block("human").value.lower()
|
||||
), f"Shared block update failed {client.get_core_memory(agent_state2.id).get_block('human').value}"
|
||||
# assert "charles" in response.messages[1].text.lower(), f"Shared block update failed {response.messages[0].text}"
|
||||
|
||||
# cleanup
|
||||
client.delete_agent(agent_state1.id)
|
||||
client.delete_agent(agent_state2.id)
|
||||
|
@ -169,8 +169,8 @@ def test_agent_add_remove_tools(client: LocalClient, agent):
|
||||
|
||||
|
||||
def test_agent_with_shared_blocks(client: LocalClient):
|
||||
persona_block = Block(name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
|
||||
human_block = Block(name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
|
||||
persona_block = Block(template_name="persona", value="Here to test things!", label="persona", user_id=client.user_id)
|
||||
human_block = Block(template_name="human", value="Me Human, I swear. Beep boop.", label="human", user_id=client.user_id)
|
||||
existing_non_template_blocks = [persona_block, human_block]
|
||||
for block in existing_non_template_blocks:
|
||||
# ensure that previous chat blocks are persisted, as if another agent already produced them.
|
||||
|
Loading…
Reference in New Issue
Block a user