feat: background multi-agent group for sleeptime agent (#1508)

This commit is contained in:
cthomas 2025-04-01 15:00:45 -07:00 committed by GitHub
parent d1b4f7c669
commit ab710c5073
16 changed files with 663 additions and 52 deletions

View File

@ -0,0 +1,44 @@
"""add background group support
Revision ID: 74f2ede29317
Revises: bff040379479
Create Date: 2025-04-01 07:45:31.735977
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "74f2ede29317"
down_revision: Union[str, None] = "bff040379479"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply background-group support: new ``groups`` columns plus the
    ``groups_blocks`` association table."""
    # Per-group bookkeeping used by the background (sleeptime) agent loop.
    background_columns = (
        sa.Column("background_agents_interval", sa.Integer(), nullable=True),
        sa.Column("turns_counter", sa.Integer(), nullable=True),
        sa.Column("last_processed_message_id", sa.String(), nullable=True),
    )
    for column in background_columns:
        op.add_column("groups", column)

    # Join table linking groups to their shared memory blocks; rows disappear
    # automatically when either side is deleted.
    op.create_table(
        "groups_blocks",
        sa.Column("group_id", sa.String(), nullable=False),
        sa.Column("block_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("group_id", "block_id"),
    )
def downgrade() -> None:
    """Reverse :func:`upgrade`: drop the join table, then the added columns."""
    op.drop_table("groups_blocks")
    # Remove the columns in reverse order of their addition.
    for column_name in ("last_processed_message_id", "turns_counter", "background_agents_interval"):
        op.drop_column("groups", column_name)

View File

@ -0,0 +1,254 @@
import asyncio
import threading
from datetime import datetime
from typing import List, Optional
from letta.agent import Agent, AgentState
from letta.groups.helpers import stringify_message
from letta.interface import AgentInterface
from letta.orm import User
from letta.schemas.enums import JobStatus
from letta.schemas.job import JobUpdate
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.run import Run
from letta.schemas.usage import LettaUsageStatistics
from letta.services.group_manager import GroupManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
class BackgroundMultiAgent(Agent):
    """Multi-agent group where a main agent converses in the foreground and
    participant agents process the transcript in the background.

    Each call to :meth:`step` runs the main agent synchronously. Every
    ``background_agents_interval`` turns (or on every turn when no interval is
    configured), the new portion of the conversation is forwarded to each
    participant agent as a tracked background task; the resulting run IDs are
    attached to the returned usage statistics so callers can poll job status.
    """

    def __init__(
        self,
        interface: AgentInterface,
        agent_state: AgentState,
        user: User,
        # custom
        group_id: str = "",
        agent_ids: Optional[List[str]] = None,
        description: str = "",
        background_agents_interval: Optional[int] = None,
    ):
        super().__init__(interface, agent_state, user)
        self.group_id = group_id
        # Fix: the previous default was a shared mutable list literal ([]);
        # a None sentinel guarantees each instance gets its own list.
        self.agent_ids = agent_ids if agent_ids is not None else []
        self.description = description
        self.background_agents_interval = background_agents_interval
        self.group_manager = GroupManager()
        self.message_manager = MessageManager()
        self.job_manager = JobManager()

    def _run_async_in_new_thread(self, coro):
        """Run an async coroutine on a fresh event loop in a dedicated thread,
        block until the thread finishes, and return the coroutine's result.

        Fix: exceptions raised by the coroutine are captured and re-raised in
        the calling thread; previously the default thread excepthook swallowed
        them and the caller silently received ``None``.
        """
        result = None
        error = None

        def run_async():
            nonlocal result, error
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                result = loop.run_until_complete(coro)
            except BaseException as exc:  # re-raised in the caller below
                error = exc
            finally:
                loop.close()
                asyncio.set_event_loop(None)

        thread = threading.Thread(target=run_async)
        thread.start()
        thread.join()
        if error is not None:
            raise error
        return result

    async def _issue_background_task(
        self,
        participant_agent_id: str,
        messages: List[Message],
        chaining: bool,
        max_chaining_steps: Optional[int],
        token_streaming: bool,
        metadata: Optional[dict],
        put_inner_thoughts_first: bool,
        last_processed_message_id: str,
    ) -> str:
        """Create a Run record for a participant agent and schedule its
        background step on the current event loop; returns the run ID.

        NOTE(review): the scheduled coroutine contains no awaits, so when this
        is driven via ``_run_async_in_new_thread`` it appears to run to
        completion before that thread's loop closes — confirm whether truly
        fire-and-forget behavior is ever expected here.
        """
        run = Run(
            user_id=self.user.id,
            status=JobStatus.created,
            metadata={
                "job_type": "background_agent_send_message_async",
                "agent_id": participant_agent_id,
            },
        )
        run = self.job_manager.create_job(pydantic_job=run, actor=self.user)
        asyncio.create_task(
            self._perform_background_agent_step(
                participant_agent_id=participant_agent_id,
                messages=messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                token_streaming=token_streaming,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
                last_processed_message_id=last_processed_message_id,
                run_id=run.id,
            )
        )
        return run.id

    async def _perform_background_agent_step(
        self,
        participant_agent_id: str,
        messages: List[Message],
        chaining: bool,
        max_chaining_steps: Optional[int],
        token_streaming: bool,
        metadata: Optional[dict],
        put_inner_thoughts_first: bool,
        last_processed_message_id: str,
        run_id: str,
    ) -> LettaUsageStatistics:
        """Feed the conversation transcript to one participant agent and record
        the outcome on the job identified by ``run_id``.

        When an interval is configured, the transcript also includes the main
        agent's messages after ``last_processed_message_id`` and before the
        current turn. Marks the job completed on success, or failed (with the
        error string) before re-raising on failure.
        """
        try:
            participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
            participant_agent = Agent(
                agent_state=participant_agent_state,
                interface=self.interface,
                user=self.user,
            )
            prior_messages = []
            if self.background_agents_interval:
                try:
                    prior_messages = self.message_manager.list_messages_for_agent(
                        agent_id=self.agent_state.id,
                        actor=self.user,
                        after=last_processed_message_id,
                        before=messages[0].id,
                    )
                except Exception as e:
                    print(f"Error fetching prior messages: {str(e)}")
                    # continue with just latest messages

            # Flatten the transcript to plain text, dropping entries that
            # stringify to None (system messages, heartbeats, etc.).
            transcript_summary = [stringify_message(message) for message in prior_messages + messages]
            transcript_summary = [summary for summary in transcript_summary if summary is not None]
            message_text = "\n".join(transcript_summary)

            participant_agent_messages = [
                Message(
                    id=Message.generate_id(),
                    agent_id=participant_agent.agent_state.id,
                    role="user",
                    content=[TextContent(text=message_text)],
                    group_id=self.group_id,
                )
            ]
            result = participant_agent.step(
                messages=participant_agent_messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )
            job_update = JobUpdate(
                status=JobStatus.completed,
                completed_at=datetime.utcnow(),
                metadata={"result": result.model_dump(mode="json")},  # Store the result in metadata
            )
            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
            return result
        except Exception as e:
            job_update = JobUpdate(
                status=JobStatus.failed,
                completed_at=datetime.utcnow(),
                metadata={"error": str(e)},
            )
            self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
            raise

    def step(
        self,
        messages: List[MessageCreate],
        chaining: bool = True,
        max_chaining_steps: Optional[int] = None,
        put_inner_thoughts_first: bool = True,
        **kwargs,
    ) -> LettaUsageStatistics:
        """Run the main agent on the incoming messages, then (on the configured
        interval) fan the transcript out to the background participant agents.

        Returns the main agent's usage statistics with ``run_ids`` set to the
        run IDs of any background tasks issued this turn.
        """
        run_ids = []
        token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
        metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None

        # Convert incoming MessageCreate payloads into persisted-format
        # Message objects tagged with this group.
        messages = [
            Message(
                id=Message.generate_id(),
                agent_id=self.agent_state.id,
                role=message.role,
                content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
                name=message.name,
                model=None,
                tool_calls=None,
                tool_call_id=None,
                group_id=self.group_id,
            )
            for message in messages
        ]
        try:
            main_agent = Agent(
                agent_state=self.agent_state,
                interface=self.interface,
                user=self.user,
            )
            usage_stats = main_agent.step(
                messages=messages,
                chaining=chaining,
                max_chaining_steps=max_chaining_steps,
                stream=token_streaming,
                skip_verify=True,
                metadata=metadata,
                put_inner_thoughts_first=put_inner_thoughts_first,
            )

            turns_counter = None
            if self.background_agents_interval is not None and self.background_agents_interval > 0:
                turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)

            # Issue background tasks either on every turn (no interval set) or
            # when the wrapped counter lands on the interval boundary.
            if self.background_agents_interval is None or (
                turns_counter is not None and turns_counter % self.background_agents_interval == 0
            ):
                last_response_messages = [message for sublist in usage_stats.steps_messages for message in sublist]
                last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
                    group_id=self.group_id, last_processed_message_id=last_response_messages[-1].id, actor=self.user
                )
                for participant_agent_id in self.agent_ids:
                    try:
                        run_id = self._run_async_in_new_thread(
                            self._issue_background_task(
                                participant_agent_id,
                                last_response_messages,
                                chaining,
                                max_chaining_steps,
                                token_streaming,
                                metadata,
                                put_inner_thoughts_first,
                                last_processed_message_id,
                            )
                        )
                        run_ids.append(run_id)
                    except Exception as e:
                        # Log individual task failures, then surface them.
                        # (Fix: bare `raise` preserves the original traceback.)
                        print(f"Agent processing failed: {str(e)}")
                        raise
        finally:
            # Always release the interface, even when the step raised.
            self.interface.step_yield()
            self.interface.step_complete()

        usage_stats.run_ids = run_ids
        return LettaUsageStatistics(**usage_stats.model_dump())

104
letta/groups/helpers.py Normal file
View File

@ -0,0 +1,104 @@
import json
from typing import Optional, Union
from letta.agent import Agent
from letta.interface import AgentInterface
from letta.orm.group import Group
from letta.orm.user import User
from letta.schemas.agent import AgentState
from letta.schemas.group import ManagerType
from letta.schemas.message import Message
def load_multi_agent(
group: Group,
agent_state: Optional[AgentState],
actor: User,
interface: Union[AgentInterface, None] = None,
) -> Agent:
if len(group.agent_ids) == 0:
raise ValueError("Empty group: group must have at least one agent")
if not agent_state:
raise ValueError("Empty manager agent state: manager agent state must be provided")
match group.manager_type:
case ManagerType.round_robin:
from letta.groups.round_robin_multi_agent import RoundRobinMultiAgent
return RoundRobinMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
max_turns=group.max_turns,
)
case ManagerType.dynamic:
from letta.groups.dynamic_multi_agent import DynamicMultiAgent
return DynamicMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
max_turns=group.max_turns,
termination_token=group.termination_token,
)
case ManagerType.supervisor:
from letta.groups.supervisor_multi_agent import SupervisorMultiAgent
return SupervisorMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
)
case ManagerType.background:
from letta.groups.background_multi_agent import BackgroundMultiAgent
return BackgroundMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
background_agents_interval=group.background_agents_interval,
)
case _:
raise ValueError(f"Type {group.manager_type} is not supported.")
def stringify_message(message: "Message") -> Optional[str]:
    """Render a message as a human-readable transcript line.

    Returns None for messages that should be omitted from a transcript:
    system messages, user payloads that are not actual user messages
    (e.g. heartbeats), and empty tool returns.
    """
    if message.role == "user":
        # User payloads are JSON-packed by the server; only true user
        # messages are kept. (Assumes content[0] holds JSON text — a
        # malformed payload will raise json.JSONDecodeError.)
        content = json.loads(message.content[0].text)
        if content["type"] == "user_message":
            return f"{message.name or 'user'}: {content['message']}"
        else:
            return None
    elif message.role == "assistant":
        messages = []
        if message.content:
            messages.append(f"{message.name or 'assistant'}: *thinking* {message.content[0].text}")
        if message.tool_calls:
            if message.tool_calls[0].function.name == "send_message":
                messages.append(f"{message.name or 'assistant'}: {json.loads(message.tool_calls[0].function.arguments)['message']}")
            else:
                messages.append(f"{message.name or 'assistant'}: Calling tool {message.tool_calls[0].function.name}")
        return "\n".join(messages)
    elif message.role == "tool":
        if message.content:
            content = json.loads(message.content[0].text)
            # Fix: compare against None with `is not None` (was `!= None`),
            # and test None before the "None" string comparison.
            if content["message"] is not None and content["message"] != "None":
                return f"{message.name or 'assistant'}: Tool call returned {content['message']}"
        return None
    elif message.role == "system":
        return None
    return f"{message.name or 'user'}: {message.content[0].text}"

View File

@ -7,6 +7,7 @@ from letta.orm.blocks_agents import BlocksAgents
from letta.orm.file import FileMetadata
from letta.orm.group import Group
from letta.orm.groups_agents import GroupsAgents
from letta.orm.groups_blocks import GroupsBlocks
from letta.orm.identities_agents import IdentitiesAgents
from letta.orm.identities_blocks import IdentitiesBlocks
from letta.orm.identity import Identity

View File

@ -67,6 +67,13 @@ class Block(OrganizationMixin, SqlalchemyBase):
back_populates="blocks",
passive_deletes=True,
)
groups: Mapped[List["Group"]] = relationship(
"Group",
secondary="groups_blocks",
lazy="selectin",
back_populates="shared_blocks",
passive_deletes=True,
)
def to_pydantic(self) -> Type:
match self.label:

View File

@ -20,6 +20,9 @@ class Group(SqlalchemyBase, OrganizationMixin):
manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="")
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
background_agents_interval: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
@ -27,4 +30,7 @@ class Group(SqlalchemyBase, OrganizationMixin):
agents: Mapped[List["Agent"]] = relationship(
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
)
shared_blocks: Mapped[List["Block"]] = relationship(
"Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups"
)
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")

View File

@ -0,0 +1,13 @@
from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.base import Base
class GroupsBlocks(Base):
    """Groups may have one or many shared blocks associated with them."""

    __tablename__ = "groups_blocks"

    # Composite primary key: one row per (group, block) pairing. Both foreign
    # keys cascade on delete, so removing a group or a block removes the link.
    group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True)
    block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True)

View File

@ -174,6 +174,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
False,
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
)
enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.")
@field_validator("name")
@classmethod
@ -252,6 +253,7 @@ class UpdateAgent(BaseModel):
embedding: Optional[str] = Field(
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
)
enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.")
class Config:
extra = "ignore" # Ignores extra fields

View File

@ -10,6 +10,7 @@ class ManagerType(str, Enum):
round_robin = "round_robin"
supervisor = "supervisor"
dynamic = "dynamic"
background = "background"
swarm = "swarm"
@ -22,10 +23,14 @@ class Group(GroupBase):
manager_type: ManagerType = Field(..., description="")
agent_ids: List[str] = Field(..., description="")
description: str = Field(..., description="")
shared_block_ids: List[str] = Field([], description="")
# Pattern fields
manager_agent_id: Optional[str] = Field(None, description="")
termination_token: Optional[str] = Field(None, description="")
max_turns: Optional[int] = Field(None, description="")
background_agents_interval: Optional[int] = Field(None, description="")
turns_counter: Optional[int] = Field(None, description="")
last_processed_message_id: Optional[str] = Field(None, description="")
class ManagerConfig(BaseModel):
@ -49,12 +54,18 @@ class DynamicManager(ManagerConfig):
max_turns: Optional[int] = Field(None, description="")
class BackgroundManager(ManagerConfig):
    """Manager config for background (sleeptime) groups."""

    # Discriminator value for the ManagerConfigUnion.
    manager_type: Literal[ManagerType.background] = Field(ManagerType.background, description="")
    # Agent that handles the foreground conversation for the group.
    manager_agent_id: str = Field(..., description="")
    # Forward the transcript to background agents every N turns; the group's
    # turns counter is only initialized when this is set (see GroupManager).
    background_agents_interval: Optional[int] = Field(None, description="")
# class SwarmGroup(ManagerConfig):
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")
ManagerConfigUnion = Annotated[
Union[RoundRobinManager, SupervisorManager, DynamicManager],
Union[RoundRobinManager, SupervisorManager, DynamicManager, BackgroundManager],
Field(discriminator="manager_type"),
]
@ -63,9 +74,11 @@ class GroupCreate(BaseModel):
agent_ids: List[str] = Field(..., description="")
description: str = Field(..., description="")
manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
shared_block_ids: List[str] = Field([], description="")
class GroupUpdate(BaseModel):
agent_ids: Optional[List[str]] = Field(None, description="")
description: Optional[str] = Field(None, description="")
manager_config: Optional[ManagerConfigUnion] = Field(None, description="")
shared_block_ids: Optional[List[str]] = Field(None, description="")

View File

@ -23,3 +23,4 @@ class LettaUsageStatistics(BaseModel):
step_count: int = Field(0, description="The number of steps taken by the agent.")
# TODO: Optional for now. This field makes everyone's lives easier
steps_messages: Optional[List[List[Message]]] = Field(None, description="The messages generated per step")
run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction")

View File

@ -19,11 +19,11 @@ import letta.system as system
from letta.agent import Agent, save_agent
from letta.config import LettaConfig
from letta.data_sources.connectors import DataConnector, load_data
from letta.dynamic_multi_agent import DynamicMultiAgent
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient
from letta.functions.mcp_client.stdio_client import StdioMCPClient
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
from letta.groups.helpers import load_multi_agent
from letta.helpers.datetime_helpers import get_utc_time
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.helpers.message_helper import prepare_input_message_create
@ -34,7 +34,6 @@ from letta.interface import CLIInterface # for printing to terminal
from letta.log import get_logger
from letta.offline_memory_agent import OfflineMemoryAgent
from letta.orm.errors import NoResultFound
from letta.round_robin_multi_agent import RoundRobinMultiAgent
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
@ -42,7 +41,6 @@ from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.group import Group, ManagerType
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
from letta.schemas.letta_message_content import TextContent
@ -94,7 +92,6 @@ from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSan
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings, settings, tool_settings
from letta.supervisor_multi_agent import SupervisorMultiAgent
from letta.tracing import trace_method
from letta.utils import get_friendly_error_msg
@ -352,7 +349,7 @@ class SyncServer(Server):
"""Updated method to load agents from persisted storage"""
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
if agent_state.multi_agent_group:
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
return load_multi_agent(group=agent_state.multi_agent_group, agent_state=agent_state, actor=actor, interface=interface)
interface = interface or self.default_interface_factory()
if agent_state.agent_type == AgentType.memgpt_agent:
@ -364,49 +361,6 @@ class SyncServer(Server):
return agent
def load_multi_agent(
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
) -> Agent:
if len(group.agent_ids) == 0:
raise ValueError("Empty group: group must have at least one agent")
match group.manager_type:
case ManagerType.round_robin:
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
return RoundRobinMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
max_turns=group.max_turns,
)
case ManagerType.dynamic:
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
return DynamicMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
max_turns=group.max_turns,
termination_token=group.termination_token,
)
case ManagerType.supervisor:
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
return SupervisorMultiAgent(
agent_state=agent_state,
interface=interface,
user=actor,
group_id=group.id,
agent_ids=group.agent_ids,
description=group.description,
)
case _:
raise ValueError(f"Type {group.manager_type} is not supported.")
def _step(
self,
actor: User,
@ -1599,7 +1553,9 @@ class SyncServer(Server):
raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'")
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
agent_state_id = group.manager_agent_id or (group.agent_ids[0] if len(group.agent_ids) > 0 else None)
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_state_id, actor=actor) if agent_state_id else None
letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor)
llm_config = letta_multi_agent.agent_state.llm_config
supports_token_streaming = ["openai", "anthropic", "deepseek"]

View File

@ -71,11 +71,20 @@ class GroupManager:
case ManagerType.supervisor:
new_group.manager_type = ManagerType.supervisor
new_group.manager_agent_id = group.manager_config.manager_agent_id
case ManagerType.background:
new_group.manager_type = ManagerType.background
new_group.manager_agent_id = group.manager_config.manager_agent_id
new_group.background_agents_interval = group.manager_config.background_agents_interval
if new_group.background_agents_interval:
new_group.turns_counter = 0
case _:
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
if group.shared_block_ids:
self._process_shared_block_relationship(session=session, group=new_group, block_ids=group.shared_block_ids)
new_group.create(session, actor=actor)
return new_group.to_pydantic()
@ -84,6 +93,7 @@ class GroupManager:
with self.session_maker() as session:
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
background_agents_interval = None
max_turns = None
termination_token = None
manager_agent_id = None
@ -99,9 +109,16 @@ class GroupManager:
termination_token = group_update.manager_config.termination_token
case ManagerType.supervisor:
manager_agent_id = group_update.manager_config.manager_agent_id
case ManagerType.background:
manager_agent_id = group_update.manager_config.manager_agent_id
background_agents_interval = group_update.manager_config.background_agents_interval
if background_agents_interval and group.turns_counter is None:
group.turns_counter = 0
case _:
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
if background_agents_interval:
group.background_agents_interval = background_agents_interval
if max_turns:
group.max_turns = max_turns
if termination_token:
@ -174,6 +191,30 @@ class GroupManager:
session.commit()
@enforce_types
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
    """Advance the group's turn counter and return its new value.

    The counter wraps modulo ``background_agents_interval``; callers are
    expected to invoke this only when the interval is a positive integer
    (a None interval would raise a TypeError on the modulo).
    """
    with self.session_maker() as session:
        # Ensure group is loadable by user
        group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
        # Fix: treat an unset counter as 0 so groups whose counter was never
        # initialized don't crash with `TypeError: NoneType + int`.
        group.turns_counter = ((group.turns_counter or 0) + 1) % group.background_agents_interval
        group.update(session, actor=actor)
        return group.turns_counter
@enforce_types
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
    """Swap in a new last-processed-message id and return the previous one."""
    with self.session_maker() as session:
        # Read through the ORM so actor-based access control is enforced.
        group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
        previous_id = group.last_processed_message_id
        group.last_processed_message_id = last_processed_message_id
        group.update(session, actor=actor)
        return previous_id
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
if not agent_ids:
if replace:
@ -203,3 +244,30 @@ class GroupManager:
setattr(group, "agent_ids", agent_ids)
else:
raise ValueError("Extend relationship is not supported for groups.")
def _process_shared_block_relationship(
    self,
    session: Session,
    group: GroupModel,
    block_ids: List[str],
):
    """Process shared block relationships for a group and its agents."""
    from letta.orm import Agent, Block, BlocksAgents

    # Attach the blocks to the group itself.
    blocks = session.query(Block).filter(Block.id.in_(block_ids)).all()
    group.shared_blocks = blocks

    def _link_blocks_to_agent(agent):
        # One BlocksAgents row per (agent, block) pairing.
        for block in blocks:
            session.add(BlocksAgents(agent_id=agent.id, block_id=block.id, block_label=block.label))

    # Mirror the blocks onto each member agent of the group.
    if group.agent_ids:
        for agent in session.query(Agent).filter(Agent.id.in_(group.agent_ids)).all():
            _link_blocks_to_agent(agent)

    # And onto the manager agent, when one is configured and found.
    if group.manager_agent_id:
        manager_agent = session.query(Agent).filter(Agent.id == group.manager_agent_id).first()
        if manager_agent:
            _link_blocks_to_agent(manager_agent)

View File

@ -1,12 +1,28 @@
import time
import pytest
from sqlalchemy import delete
from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.orm import Provider, Step
from letta.schemas.agent import CreateAgent
from letta.schemas.block import CreateBlock
from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager
from letta.schemas.block import Block, CreateBlock
from letta.schemas.enums import JobStatus
from letta.schemas.group import (
BackgroundManager,
DynamicManager,
GroupCreate,
GroupUpdate,
ManagerType,
RoundRobinManager,
SupervisorManager,
)
from letta.schemas.message import MessageCreate
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.server.server import SyncServer
@ -426,3 +442,129 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen
finally:
server.group_manager.delete_group(group_id=group.id, actor=actor)
@pytest.mark.asyncio
async def test_background_group_chat(server, actor):
    """End-to-end test of a background (sleeptime) group: a main agent chats
    with the user while a background memory agent receives the transcript every
    second turn and updates shared memory, tracked via background run jobs."""
    # 1. create shared block
    shared_memory_block = server.block_manager.create_or_update_block(
        Block(
            label="human",
            value="",
            limit=1000,
        ),
        actor=actor,
    )

    # 2. create main agent
    main_agent = server.create_agent(
        request=CreateAgent(
            name="main_agent",
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You are a personal assistant that helps users with requests.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
            include_base_tools=False,
            tools=BASE_TOOLS,
        ),
        actor=actor,
    )

    # 3. create background memory agent
    def skip_memory_update():
        """
        Perform no memory updates. This function is used when the transcript
        does not require any changes to the memory.
        """

    # NOTE: `skip_memory_update` is deliberately rebound: function -> Tool
    # schema -> persisted tool record.
    skip_memory_update = Tool(
        name=skip_memory_update.__name__,
        description="",
        source_type="python",
        tags=[],
        source_code=parse_source_code(skip_memory_update),
        json_schema=generate_schema(skip_memory_update, None),
    )
    skip_memory_update = server.tool_manager.create_or_update_tool(
        pydantic_tool=skip_memory_update,
        actor=actor,
    )
    background_memory_agent = server.create_agent(
        request=CreateAgent(
            name="memory_agent",
            memory_blocks=[
                CreateBlock(
                    label="persona",
                    value="You manage memory for the main agent. You are a background agent and you are not expected to respond to messages. When you receive a conversation snippet from the main thread, perform memory updates only if there are meaningful changes, and otherwise call the skip_memory_update tool.",
                ),
            ],
            model="openai/gpt-4o-mini",
            embedding="openai/text-embedding-ada-002",
            include_base_tools=False,
            include_base_tool_rules=False,
            tools=BASE_MEMORY_TOOLS + [skip_memory_update.name],
            # Each memory tool ends the background agent's turn immediately.
            tool_rules=[
                TerminalToolRule(tool_name="core_memory_append"),
                TerminalToolRule(tool_name="core_memory_replace"),
                TerminalToolRule(tool_name="skip_memory_update"),
            ],
        ),
        actor=actor,
    )

    # 4. create group
    group = server.group_manager.create_group(
        group=GroupCreate(
            description="",
            agent_ids=[background_memory_agent.id],
            manager_config=BackgroundManager(
                manager_agent_id=main_agent.id,
                background_agents_interval=2,
            ),
            shared_block_ids=[shared_memory_block.id],
        ),
        actor=actor,
    )
    # The shared block should be attached to both the main agent and the
    # background memory agent.
    agents = server.block_manager.get_agents_for_block(block_id=shared_memory_block.id, actor=actor)
    assert len(agents) == 2

    message_text = [
        "my favorite color is orange",
        "not particularly. today is a good day",
        "actually my favorite color is coral",
        "sorry gotta run",
    ]
    run_ids = []
    for i, text in enumerate(message_text):
        response = await server.send_message_to_agent(
            agent_id=main_agent.id,
            actor=actor,
            messages=[
                MessageCreate(
                    role="user",
                    content=text,
                ),
            ],
            stream_steps=False,
            stream_tokens=False,
        )
        assert len(response.messages) > 0
        # With an interval of 2, a background run is issued on every second
        # user turn (odd i), so `i % 2` is the expected run count per turn.
        assert len(response.usage.run_ids) == i % 2
        run_ids.extend(response.usage.run_ids)

    # Give the background jobs time to finish before checking their status.
    time.sleep(5)
    for run_id in run_ids:
        job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
        assert job.status == JobStatus.completed

    # Cleanup: remove the group and both agents.
    server.group_manager.delete_group(group_id=group.id, actor=actor)
    server.agent_manager.delete_agent(agent_id=background_memory_agent.id, actor=actor)
    server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)