mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: background multi-agent group for sleeptime agent (#1508)
This commit is contained in:
parent
d1b4f7c669
commit
ab710c5073
@ -0,0 +1,44 @@
|
||||
"""add background group support
|
||||
|
||||
Revision ID: 74f2ede29317
|
||||
Revises: bff040379479
|
||||
Create Date: 2025-04-01 07:45:31.735977
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "74f2ede29317"
|
||||
down_revision: Union[str, None] = "bff040379479"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("groups", sa.Column("background_agents_interval", sa.Integer(), nullable=True))
|
||||
op.add_column("groups", sa.Column("turns_counter", sa.Integer(), nullable=True))
|
||||
op.add_column("groups", sa.Column("last_processed_message_id", sa.String(), nullable=True))
|
||||
op.create_table(
|
||||
"groups_blocks",
|
||||
sa.Column("group_id", sa.String(), nullable=False),
|
||||
sa.Column("block_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(["block_id"], ["block.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["group_id"], ["groups.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("group_id", "block_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("groups_blocks")
|
||||
op.drop_column("groups", "last_processed_message_id")
|
||||
op.drop_column("groups", "turns_counter")
|
||||
op.drop_column("groups", "background_agents_interval")
|
||||
# ### end Alembic commands ###
|
254
letta/groups/background_multi_agent.py
Normal file
254
letta/groups/background_multi_agent.py
Normal file
@ -0,0 +1,254 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import JobUpdate
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
|
||||
|
||||
class BackgroundMultiAgent(Agent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
description: str = "",
|
||||
background_agents_interval: Optional[int] = None,
|
||||
):
|
||||
super().__init__(interface, agent_state, user)
|
||||
self.group_id = group_id
|
||||
self.agent_ids = agent_ids
|
||||
self.description = description
|
||||
self.background_agents_interval = background_agents_interval
|
||||
self.group_manager = GroupManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.job_manager = JobManager()
|
||||
|
||||
def _run_async_in_new_thread(self, coro):
|
||||
"""Run an async coroutine in a new thread with its own event loop"""
|
||||
result = None
|
||||
|
||||
def run_async():
|
||||
nonlocal result
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
result = loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
thread = threading.Thread(target=run_async)
|
||||
thread.start()
|
||||
thread.join()
|
||||
return result
|
||||
|
||||
async def _issue_background_task(
|
||||
self,
|
||||
participant_agent_id: str,
|
||||
messages: List[Message],
|
||||
chaining: bool,
|
||||
max_chaining_steps: Optional[int],
|
||||
token_streaming: bool,
|
||||
metadata: Optional[dict],
|
||||
put_inner_thoughts_first: bool,
|
||||
last_processed_message_id: str,
|
||||
) -> str:
|
||||
run = Run(
|
||||
user_id=self.user.id,
|
||||
status=JobStatus.created,
|
||||
metadata={
|
||||
"job_type": "background_agent_send_message_async",
|
||||
"agent_id": participant_agent_id,
|
||||
},
|
||||
)
|
||||
run = self.job_manager.create_job(pydantic_job=run, actor=self.user)
|
||||
|
||||
asyncio.create_task(
|
||||
self._perform_background_agent_step(
|
||||
participant_agent_id=participant_agent_id,
|
||||
messages=messages,
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
token_streaming=token_streaming,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
last_processed_message_id=last_processed_message_id,
|
||||
run_id=run.id,
|
||||
)
|
||||
)
|
||||
|
||||
return run.id
|
||||
|
||||
async def _perform_background_agent_step(
|
||||
self,
|
||||
participant_agent_id: str,
|
||||
messages: List[Message],
|
||||
chaining: bool,
|
||||
max_chaining_steps: Optional[int],
|
||||
token_streaming: bool,
|
||||
metadata: Optional[dict],
|
||||
put_inner_thoughts_first: bool,
|
||||
last_processed_message_id: str,
|
||||
run_id: str,
|
||||
) -> LettaUsageStatistics:
|
||||
try:
|
||||
participant_agent_state = self.agent_manager.get_agent_by_id(participant_agent_id, actor=self.user)
|
||||
participant_agent = Agent(
|
||||
agent_state=participant_agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
)
|
||||
|
||||
prior_messages = []
|
||||
if self.background_agents_interval:
|
||||
try:
|
||||
prior_messages = self.message_manager.list_messages_for_agent(
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
after=last_processed_message_id,
|
||||
before=messages[0].id,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error fetching prior messages: {str(e)}")
|
||||
# continue with just latest messages
|
||||
|
||||
transcript_summary = [stringify_message(message) for message in prior_messages + messages]
|
||||
transcript_summary = [summary for summary in transcript_summary if summary is not None]
|
||||
message_text = "\n".join(transcript_summary)
|
||||
|
||||
participant_agent_messages = [
|
||||
Message(
|
||||
id=Message.generate_id(),
|
||||
agent_id=participant_agent.agent_state.id,
|
||||
role="user",
|
||||
content=[TextContent(text=message_text)],
|
||||
group_id=self.group_id,
|
||||
)
|
||||
]
|
||||
result = participant_agent.step(
|
||||
messages=participant_agent_messages,
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.completed,
|
||||
completed_at=datetime.utcnow(),
|
||||
metadata={"result": result.model_dump(mode="json")}, # Store the result in metadata
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
|
||||
return result
|
||||
except Exception as e:
|
||||
job_update = JobUpdate(
|
||||
status=JobStatus.failed,
|
||||
completed_at=datetime.utcnow(),
|
||||
metadata={"error": str(e)},
|
||||
)
|
||||
self.job_manager.update_job_by_id(job_id=run_id, job_update=job_update, actor=self.user)
|
||||
raise
|
||||
|
||||
def step(
|
||||
self,
|
||||
messages: List[MessageCreate],
|
||||
chaining: bool = True,
|
||||
max_chaining_steps: Optional[int] = None,
|
||||
put_inner_thoughts_first: bool = True,
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
run_ids = []
|
||||
|
||||
token_streaming = self.interface.streaming_mode if hasattr(self.interface, "streaming_mode") else False
|
||||
metadata = self.interface.metadata if hasattr(self.interface, "metadata") else None
|
||||
|
||||
messages = [
|
||||
Message(
|
||||
id=Message.generate_id(),
|
||||
agent_id=self.agent_state.id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message.content)] if isinstance(message.content, str) else message.content,
|
||||
name=message.name,
|
||||
model=None,
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
group_id=self.group_id,
|
||||
)
|
||||
for message in messages
|
||||
]
|
||||
|
||||
try:
|
||||
main_agent = Agent(
|
||||
agent_state=self.agent_state,
|
||||
interface=self.interface,
|
||||
user=self.user,
|
||||
)
|
||||
usage_stats = main_agent.step(
|
||||
messages=messages,
|
||||
chaining=chaining,
|
||||
max_chaining_steps=max_chaining_steps,
|
||||
stream=token_streaming,
|
||||
skip_verify=True,
|
||||
metadata=metadata,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
)
|
||||
|
||||
turns_counter = None
|
||||
if self.background_agents_interval is not None and self.background_agents_interval > 0:
|
||||
turns_counter = self.group_manager.bump_turns_counter(group_id=self.group_id, actor=self.user)
|
||||
|
||||
if self.background_agents_interval is None or (
|
||||
turns_counter is not None and turns_counter % self.background_agents_interval == 0
|
||||
):
|
||||
last_response_messages = [message for sublist in usage_stats.steps_messages for message in sublist]
|
||||
last_processed_message_id = self.group_manager.get_last_processed_message_id_and_update(
|
||||
group_id=self.group_id, last_processed_message_id=last_response_messages[-1].id, actor=self.user
|
||||
)
|
||||
for participant_agent_id in self.agent_ids:
|
||||
try:
|
||||
run_id = self._run_async_in_new_thread(
|
||||
self._issue_background_task(
|
||||
participant_agent_id,
|
||||
last_response_messages,
|
||||
chaining,
|
||||
max_chaining_steps,
|
||||
token_streaming,
|
||||
metadata,
|
||||
put_inner_thoughts_first,
|
||||
last_processed_message_id,
|
||||
)
|
||||
)
|
||||
run_ids.append(run_id)
|
||||
|
||||
except Exception as e:
|
||||
# Handle individual task failures
|
||||
print(f"Agent processing failed: {str(e)}")
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
finally:
|
||||
self.interface.step_yield()
|
||||
|
||||
self.interface.step_complete()
|
||||
|
||||
usage_stats.run_ids = run_ids
|
||||
return LettaUsageStatistics(**usage_stats.model_dump())
|
104
letta/groups/helpers.py
Normal file
104
letta/groups/helpers.py
Normal file
@ -0,0 +1,104 @@
|
||||
import json
|
||||
from typing import Optional, Union
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.user import User
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
def load_multi_agent(
|
||||
group: Group,
|
||||
agent_state: Optional[AgentState],
|
||||
actor: User,
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> Agent:
|
||||
if len(group.agent_ids) == 0:
|
||||
raise ValueError("Empty group: group must have at least one agent")
|
||||
|
||||
if not agent_state:
|
||||
raise ValueError("Empty manager agent state: manager agent state must be provided")
|
||||
|
||||
match group.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
from letta.groups.round_robin_multi_agent import RoundRobinMultiAgent
|
||||
|
||||
return RoundRobinMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
)
|
||||
case ManagerType.dynamic:
|
||||
from letta.groups.dynamic_multi_agent import DynamicMultiAgent
|
||||
|
||||
return DynamicMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
termination_token=group.termination_token,
|
||||
)
|
||||
case ManagerType.supervisor:
|
||||
from letta.groups.supervisor_multi_agent import SupervisorMultiAgent
|
||||
|
||||
return SupervisorMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
)
|
||||
case ManagerType.background:
|
||||
from letta.groups.background_multi_agent import BackgroundMultiAgent
|
||||
|
||||
return BackgroundMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
background_agents_interval=group.background_agents_interval,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Type {group.manager_type} is not supported.")
|
||||
|
||||
|
||||
def stringify_message(message: Message) -> str | None:
|
||||
if message.role == "user":
|
||||
content = json.loads(message.content[0].text)
|
||||
if content["type"] == "user_message":
|
||||
return f"{message.name or 'user'}: {content['message']}"
|
||||
else:
|
||||
return None
|
||||
elif message.role == "assistant":
|
||||
messages = []
|
||||
if message.content:
|
||||
messages.append(f"{message.name or 'assistant'}: *thinking* {message.content[0].text}")
|
||||
if message.tool_calls:
|
||||
if message.tool_calls[0].function.name == "send_message":
|
||||
messages.append(f"{message.name or 'assistant'}: {json.loads(message.tool_calls[0].function.arguments)['message']}")
|
||||
else:
|
||||
messages.append(f"{message.name or 'assistant'}: Calling tool {message.tool_calls[0].function.name}")
|
||||
return "\n".join(messages)
|
||||
elif message.role == "tool":
|
||||
if message.content:
|
||||
content = json.loads(message.content[0].text)
|
||||
if content["message"] != "None" and content["message"] != None:
|
||||
return f"{message.name or 'assistant'}: Tool call returned {content['message']}"
|
||||
return None
|
||||
elif message.role == "system":
|
||||
return None
|
||||
|
||||
return f"{message.name or 'user'}: {message.content[0].text}"
|
@ -7,6 +7,7 @@ from letta.orm.blocks_agents import BlocksAgents
|
||||
from letta.orm.file import FileMetadata
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.groups_agents import GroupsAgents
|
||||
from letta.orm.groups_blocks import GroupsBlocks
|
||||
from letta.orm.identities_agents import IdentitiesAgents
|
||||
from letta.orm.identities_blocks import IdentitiesBlocks
|
||||
from letta.orm.identity import Identity
|
||||
|
@ -67,6 +67,13 @@ class Block(OrganizationMixin, SqlalchemyBase):
|
||||
back_populates="blocks",
|
||||
passive_deletes=True,
|
||||
)
|
||||
groups: Mapped[List["Group"]] = relationship(
|
||||
"Group",
|
||||
secondary="groups_blocks",
|
||||
lazy="selectin",
|
||||
back_populates="shared_blocks",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
def to_pydantic(self) -> Type:
|
||||
match self.label:
|
||||
|
@ -20,6 +20,9 @@ class Group(SqlalchemyBase, OrganizationMixin):
|
||||
manager_agent_id: Mapped[Optional[str]] = mapped_column(String, ForeignKey("agents.id", ondelete="RESTRICT"), nullable=True, doc="")
|
||||
termination_token: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
max_turns: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
background_agents_interval: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
turns_counter: Mapped[Optional[int]] = mapped_column(nullable=True, doc="")
|
||||
last_processed_message_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="groups")
|
||||
@ -27,4 +30,7 @@ class Group(SqlalchemyBase, OrganizationMixin):
|
||||
agents: Mapped[List["Agent"]] = relationship(
|
||||
"Agent", secondary="groups_agents", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
shared_blocks: Mapped[List["Block"]] = relationship(
|
||||
"Block", secondary="groups_blocks", lazy="selectin", passive_deletes=True, back_populates="groups"
|
||||
)
|
||||
manager_agent: Mapped["Agent"] = relationship("Agent", lazy="joined", back_populates="multi_agent_group")
|
||||
|
13
letta/orm/groups_blocks.py
Normal file
13
letta/orm/groups_blocks.py
Normal file
@ -0,0 +1,13 @@
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.base import Base
|
||||
|
||||
|
||||
class GroupsBlocks(Base):
|
||||
"""Groups may have one or many shared blocks associated with them."""
|
||||
|
||||
__tablename__ = "groups_blocks"
|
||||
|
||||
group_id: Mapped[str] = mapped_column(String, ForeignKey("groups.id", ondelete="CASCADE"), primary_key=True)
|
||||
block_id: Mapped[str] = mapped_column(String, ForeignKey("block.id", ondelete="CASCADE"), primary_key=True)
|
@ -174,6 +174,7 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
False,
|
||||
description="If set to True, the agent will not remember previous messages (though the agent will still retain state via core memory blocks and archival/recall memory). Not recommended unless you have an advanced use case.",
|
||||
)
|
||||
enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@ -252,6 +253,7 @@ class UpdateAgent(BaseModel):
|
||||
embedding: Optional[str] = Field(
|
||||
None, description="The embedding configuration handle used by the agent, specified in the format provider/model-name."
|
||||
)
|
||||
enable_sleeptime: Optional[bool] = Field(False, description="If set to True, memory management will move to a background agent thread.")
|
||||
|
||||
class Config:
|
||||
extra = "ignore" # Ignores extra fields
|
||||
|
@ -10,6 +10,7 @@ class ManagerType(str, Enum):
|
||||
round_robin = "round_robin"
|
||||
supervisor = "supervisor"
|
||||
dynamic = "dynamic"
|
||||
background = "background"
|
||||
swarm = "swarm"
|
||||
|
||||
|
||||
@ -22,10 +23,14 @@ class Group(GroupBase):
|
||||
manager_type: ManagerType = Field(..., description="")
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
shared_block_ids: List[str] = Field([], description="")
|
||||
# Pattern fields
|
||||
manager_agent_id: Optional[str] = Field(None, description="")
|
||||
termination_token: Optional[str] = Field(None, description="")
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
background_agents_interval: Optional[int] = Field(None, description="")
|
||||
turns_counter: Optional[int] = Field(None, description="")
|
||||
last_processed_message_id: Optional[str] = Field(None, description="")
|
||||
|
||||
|
||||
class ManagerConfig(BaseModel):
|
||||
@ -49,12 +54,18 @@ class DynamicManager(ManagerConfig):
|
||||
max_turns: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
class BackgroundManager(ManagerConfig):
|
||||
manager_type: Literal[ManagerType.background] = Field(ManagerType.background, description="")
|
||||
manager_agent_id: str = Field(..., description="")
|
||||
background_agents_interval: Optional[int] = Field(None, description="")
|
||||
|
||||
|
||||
# class SwarmGroup(ManagerConfig):
|
||||
# manager_type: Literal[ManagerType.swarm] = Field(ManagerType.swarm, description="")
|
||||
|
||||
|
||||
ManagerConfigUnion = Annotated[
|
||||
Union[RoundRobinManager, SupervisorManager, DynamicManager],
|
||||
Union[RoundRobinManager, SupervisorManager, DynamicManager, BackgroundManager],
|
||||
Field(discriminator="manager_type"),
|
||||
]
|
||||
|
||||
@ -63,9 +74,11 @@ class GroupCreate(BaseModel):
|
||||
agent_ids: List[str] = Field(..., description="")
|
||||
description: str = Field(..., description="")
|
||||
manager_config: ManagerConfigUnion = Field(RoundRobinManager(), description="")
|
||||
shared_block_ids: List[str] = Field([], description="")
|
||||
|
||||
|
||||
class GroupUpdate(BaseModel):
|
||||
agent_ids: Optional[List[str]] = Field(None, description="")
|
||||
description: Optional[str] = Field(None, description="")
|
||||
manager_config: Optional[ManagerConfigUnion] = Field(None, description="")
|
||||
shared_block_ids: Optional[List[str]] = Field(None, description="")
|
||||
|
@ -23,3 +23,4 @@ class LettaUsageStatistics(BaseModel):
|
||||
step_count: int = Field(0, description="The number of steps taken by the agent.")
|
||||
# TODO: Optional for now. This field makes everyone's lives easier
|
||||
steps_messages: Optional[List[List[Message]]] = Field(None, description="The messages generated per step")
|
||||
run_ids: Optional[List[str]] = Field(None, description="The background task run IDs associated with the agent interaction")
|
||||
|
@ -19,11 +19,11 @@ import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.config import LettaConfig
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.dynamic_multi_agent import DynamicMultiAgent
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient
|
||||
from letta.functions.mcp_client.stdio_client import StdioMCPClient
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.groups.helpers import load_multi_agent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
@ -34,7 +34,6 @@ from letta.interface import CLIInterface # for printing to terminal
|
||||
from letta.log import get_logger
|
||||
from letta.offline_memory_agent import OfflineMemoryAgent
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.round_robin_multi_agent import RoundRobinMultiAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@ -42,7 +41,6 @@ from letta.schemas.embedding_config import EmbeddingConfig
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import Group, ManagerType
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import LegacyLettaMessage, LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
@ -94,7 +92,6 @@ from letta.services.tool_executor.tool_execution_sandbox import ToolExecutionSan
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.supervisor_multi_agent import SupervisorMultiAgent
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@ -352,7 +349,7 @@ class SyncServer(Server):
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.multi_agent_group:
|
||||
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
|
||||
return load_multi_agent(group=agent_state.multi_agent_group, agent_state=agent_state, actor=actor, interface=interface)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
@ -364,49 +361,6 @@ class SyncServer(Server):
|
||||
|
||||
return agent
|
||||
|
||||
def load_multi_agent(
|
||||
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
|
||||
) -> Agent:
|
||||
if len(group.agent_ids) == 0:
|
||||
raise ValueError("Empty group: group must have at least one agent")
|
||||
|
||||
match group.manager_type:
|
||||
case ManagerType.round_robin:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.agent_ids[0], actor=actor)
|
||||
return RoundRobinMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
)
|
||||
case ManagerType.dynamic:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
|
||||
return DynamicMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
max_turns=group.max_turns,
|
||||
termination_token=group.termination_token,
|
||||
)
|
||||
case ManagerType.supervisor:
|
||||
agent_state = agent_state or self.agent_manager.get_agent_by_id(agent_id=group.manager_agent_id, actor=actor)
|
||||
return SupervisorMultiAgent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
)
|
||||
case _:
|
||||
raise ValueError(f"Type {group.manager_type} is not supported.")
|
||||
|
||||
def _step(
|
||||
self,
|
||||
actor: User,
|
||||
@ -1599,7 +1553,9 @@ class SyncServer(Server):
|
||||
raise ValueError("stream_steps must be 'true' if stream_tokens is 'true'")
|
||||
|
||||
group = self.group_manager.retrieve_group(group_id=group_id, actor=actor)
|
||||
letta_multi_agent = self.load_multi_agent(group=group, actor=actor)
|
||||
agent_state_id = group.manager_agent_id or (group.agent_ids[0] if len(group.agent_ids) > 0 else None)
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_state_id, actor=actor) if agent_state_id else None
|
||||
letta_multi_agent = load_multi_agent(group=group, agent_state=agent_state, actor=actor)
|
||||
|
||||
llm_config = letta_multi_agent.agent_state.llm_config
|
||||
supports_token_streaming = ["openai", "anthropic", "deepseek"]
|
||||
|
@ -71,11 +71,20 @@ class GroupManager:
|
||||
case ManagerType.supervisor:
|
||||
new_group.manager_type = ManagerType.supervisor
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
case ManagerType.background:
|
||||
new_group.manager_type = ManagerType.background
|
||||
new_group.manager_agent_id = group.manager_config.manager_agent_id
|
||||
new_group.background_agents_interval = group.manager_config.background_agents_interval
|
||||
if new_group.background_agents_interval:
|
||||
new_group.turns_counter = 0
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group.manager_config.manager_type}")
|
||||
|
||||
self._process_agent_relationship(session=session, group=new_group, agent_ids=group.agent_ids, allow_partial=False)
|
||||
|
||||
if group.shared_block_ids:
|
||||
self._process_shared_block_relationship(session=session, group=new_group, block_ids=group.shared_block_ids)
|
||||
|
||||
new_group.create(session, actor=actor)
|
||||
return new_group.to_pydantic()
|
||||
|
||||
@ -84,6 +93,7 @@ class GroupManager:
|
||||
with self.session_maker() as session:
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
background_agents_interval = None
|
||||
max_turns = None
|
||||
termination_token = None
|
||||
manager_agent_id = None
|
||||
@ -99,9 +109,16 @@ class GroupManager:
|
||||
termination_token = group_update.manager_config.termination_token
|
||||
case ManagerType.supervisor:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
case ManagerType.background:
|
||||
manager_agent_id = group_update.manager_config.manager_agent_id
|
||||
background_agents_interval = group_update.manager_config.background_agents_interval
|
||||
if background_agents_interval and group.turns_counter is None:
|
||||
group.turns_counter = 0
|
||||
case _:
|
||||
raise ValueError(f"Unsupported manager type: {group_update.manager_config.manager_type}")
|
||||
|
||||
if background_agents_interval:
|
||||
group.background_agents_interval = background_agents_interval
|
||||
if max_turns:
|
||||
group.max_turns = max_turns
|
||||
if termination_token:
|
||||
@ -174,6 +191,30 @@ class GroupManager:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def bump_turns_counter(self, group_id: str, actor: PydanticUser) -> int:
|
||||
with self.session_maker() as session:
|
||||
# Ensure group is loadable by user
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
# Update turns counter
|
||||
group.turns_counter = (group.turns_counter + 1) % group.background_agents_interval
|
||||
group.update(session, actor=actor)
|
||||
return group.turns_counter
|
||||
|
||||
@enforce_types
|
||||
def get_last_processed_message_id_and_update(self, group_id: str, last_processed_message_id: str, actor: PydanticUser) -> str:
|
||||
with self.session_maker() as session:
|
||||
# Ensure group is loadable by user
|
||||
group = GroupModel.read(db_session=session, identifier=group_id, actor=actor)
|
||||
|
||||
# Update last processed message id
|
||||
prev_last_processed_message_id = group.last_processed_message_id
|
||||
group.last_processed_message_id = last_processed_message_id
|
||||
group.update(session, actor=actor)
|
||||
|
||||
return prev_last_processed_message_id
|
||||
|
||||
def _process_agent_relationship(self, session: Session, group: GroupModel, agent_ids: List[str], allow_partial=False, replace=True):
|
||||
if not agent_ids:
|
||||
if replace:
|
||||
@ -203,3 +244,30 @@ class GroupManager:
|
||||
setattr(group, "agent_ids", agent_ids)
|
||||
else:
|
||||
raise ValueError("Extend relationship is not supported for groups.")
|
||||
|
||||
def _process_shared_block_relationship(
|
||||
self,
|
||||
session: Session,
|
||||
group: GroupModel,
|
||||
block_ids: List[str],
|
||||
):
|
||||
"""Process shared block relationships for a group and its agents."""
|
||||
from letta.orm import Agent, Block, BlocksAgents
|
||||
|
||||
# Add blocks to group
|
||||
blocks = session.query(Block).filter(Block.id.in_(block_ids)).all()
|
||||
group.shared_blocks = blocks
|
||||
|
||||
# Add blocks to all agents
|
||||
if group.agent_ids:
|
||||
agents = session.query(Agent).filter(Agent.id.in_(group.agent_ids)).all()
|
||||
for agent in agents:
|
||||
for block in blocks:
|
||||
session.add(BlocksAgents(agent_id=agent.id, block_id=block.id, block_label=block.label))
|
||||
|
||||
# Add blocks to manager agent if exists
|
||||
if group.manager_agent_id:
|
||||
manager_agent = session.query(Agent).filter(Agent.id == group.manager_agent_id).first()
|
||||
if manager_agent:
|
||||
for block in blocks:
|
||||
session.add(BlocksAgents(agent_id=manager_agent.id, block_id=block.id, block_label=block.label))
|
||||
|
@ -1,12 +1,28 @@
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
from letta.orm import Provider, Step
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.block import CreateBlock
|
||||
from letta.schemas.group import DynamicManager, GroupCreate, GroupUpdate, ManagerType, RoundRobinManager, SupervisorManager
|
||||
from letta.schemas.block import Block, CreateBlock
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.group import (
|
||||
BackgroundManager,
|
||||
DynamicManager,
|
||||
GroupCreate,
|
||||
GroupUpdate,
|
||||
ManagerType,
|
||||
RoundRobinManager,
|
||||
SupervisorManager,
|
||||
)
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
|
||||
@ -426,3 +442,129 @@ async def test_dynamic_group_chat(server, actor, manager_agent, participant_agen
|
||||
|
||||
finally:
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_background_group_chat(server, actor):
|
||||
# 1. create shared block
|
||||
shared_memory_block = server.block_manager.create_or_update_block(
|
||||
Block(
|
||||
label="human",
|
||||
value="",
|
||||
limit=1000,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# 2. create main agent
|
||||
main_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="main_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You are a personal assistant that helps users with requests.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=False,
|
||||
tools=BASE_TOOLS,
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# 3. create background memory agent
|
||||
def skip_memory_update():
|
||||
"""
|
||||
Perform no memory updates. This function is used when the transcript
|
||||
does not require any changes to the memory.
|
||||
"""
|
||||
|
||||
skip_memory_update = Tool(
|
||||
name=skip_memory_update.__name__,
|
||||
description="",
|
||||
source_type="python",
|
||||
tags=[],
|
||||
source_code=parse_source_code(skip_memory_update),
|
||||
json_schema=generate_schema(skip_memory_update, None),
|
||||
)
|
||||
skip_memory_update = server.tool_manager.create_or_update_tool(
|
||||
pydantic_tool=skip_memory_update,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
background_memory_agent = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="memory_agent",
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="persona",
|
||||
value="You manage memory for the main agent. You are a background agent and you are not expected to respond to messages. When you receive a conversation snippet from the main thread, perform memory updates only if there are meaningful changes, and otherwise call the skip_memory_update tool.",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
include_base_tools=False,
|
||||
include_base_tool_rules=False,
|
||||
tools=BASE_MEMORY_TOOLS + [skip_memory_update.name],
|
||||
tool_rules=[
|
||||
TerminalToolRule(tool_name="core_memory_append"),
|
||||
TerminalToolRule(tool_name="core_memory_replace"),
|
||||
TerminalToolRule(tool_name="skip_memory_update"),
|
||||
],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# 4. create group
|
||||
group = server.group_manager.create_group(
|
||||
group=GroupCreate(
|
||||
description="",
|
||||
agent_ids=[background_memory_agent.id],
|
||||
manager_config=BackgroundManager(
|
||||
manager_agent_id=main_agent.id,
|
||||
background_agents_interval=2,
|
||||
),
|
||||
shared_block_ids=[shared_memory_block.id],
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
agents = server.block_manager.get_agents_for_block(block_id=shared_memory_block.id, actor=actor)
|
||||
assert len(agents) == 2
|
||||
|
||||
message_text = [
|
||||
"my favorite color is orange",
|
||||
"not particularly. today is a good day",
|
||||
"actually my favorite color is coral",
|
||||
"sorry gotta run",
|
||||
]
|
||||
run_ids = []
|
||||
for i, text in enumerate(message_text):
|
||||
response = await server.send_message_to_agent(
|
||||
agent_id=main_agent.id,
|
||||
actor=actor,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=text,
|
||||
),
|
||||
],
|
||||
stream_steps=False,
|
||||
stream_tokens=False,
|
||||
)
|
||||
|
||||
assert len(response.messages) > 0
|
||||
assert len(response.usage.run_ids) == i % 2
|
||||
run_ids.extend(response.usage.run_ids)
|
||||
|
||||
time.sleep(5)
|
||||
|
||||
for run_id in run_ids:
|
||||
job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor)
|
||||
assert job.status == JobStatus.completed
|
||||
|
||||
server.group_manager.delete_group(group_id=group.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_id=background_memory_agent.id, actor=actor)
|
||||
server.agent_manager.delete_agent(agent_id=main_agent.id, actor=actor)
|
||||
|
Loading…
Reference in New Issue
Block a user