feat: add agent types (#1831)

Vivek Verma 2024-10-08 11:18:36 -07:00 committed by GitHub
parent 4a01ca3a55
commit 90a1e3b438
5 changed files with 32 additions and 5 deletions

View File

@@ -30,6 +30,6 @@ services:
ports:
- "8000:8000"
command: >
--model ${LETTA_LLM_MODEL} --max_model_len=8000
--model ${LETTA_LLM_MODEL} --max_model_len=8000
# Replace with your model
ipc: host
ipc: host

View File

@@ -9,7 +9,7 @@ from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.memory import get_memory_functions
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.block import (
Block,
CreateBlock,
@@ -68,6 +68,7 @@ class AbstractClient(object):
def create_agent(
self,
name: Optional[str] = None,
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
@@ -319,6 +320,8 @@ class RESTClient(AbstractClient):
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
@@ -381,6 +384,7 @@ class RESTClient(AbstractClient):
memory=memory,
tools=tool_names,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
)
@@ -1462,6 +1466,8 @@ class LocalClient(AbstractClient):
def create_agent(
self,
name: Optional[str] = None,
# agent config
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
# model configs
embedding_config: EmbeddingConfig = None,
llm_config: LLMConfig = None,
@@ -1524,6 +1530,7 @@ class LocalClient(AbstractClient):
memory=memory,
tools=tool_names,
system=system,
agent_type=agent_type,
llm_config=llm_config if llm_config else self._default_llm_config,
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
),
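Illustrative usage sketch (not part of the diff): passing the new agent_type argument through the client API above. The create_client import is assumed from letta's public package; the create_agent parameters mirror the signatures in this file.

from letta import create_client
from letta.schemas.agent import AgentType

client = create_client()

# Omitting agent_type keeps the default shown above (AgentType.memgpt_agent)
default_agent = client.create_agent(name="default-agent")

# Passing it explicitly; split_thread_agent validates in the schema but the server
# currently rejects it (see the server diff below)
typed_agent = client.create_agent(name="typed-agent", agent_type=AgentType.memgpt_agent)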

View File

@@ -218,6 +218,7 @@ class AgentModel(Base):
tools = Column(JSON)
# configs
agent_type = Column(String)
llm_config = Column(LLMConfigColumn)
embedding_config = Column(EmbeddingConfigColumn)
@@ -243,6 +244,7 @@ class AgentModel(Base):
memory=Memory.load(self.memory), # load dictionary
system=self.system,
tools=self.tools,
agent_type=self.agent_type,
llm_config=self.llm_config,
embedding_config=self.embedding_config,
metadata_=self.metadata_,
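Note (illustrative, not part of the diff): agent_type is persisted as a plain String column, and because AgentType subclasses str (see the schema change below), raw column values and enum members convert back and forth without a custom column type. A minimal sketch:

from letta.schemas.agent import AgentType

# str-backed enum members compare equal to their stored string values...
assert AgentType.memgpt_agent == "memgpt_agent"
# ...and a value read back from the agent_type column rebuilds the same member
assert AgentType("split_thread_agent") is AgentType.split_thread_agent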

View File

@@ -1,5 +1,6 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union
from pydantic import BaseModel, Field, field_validator
@@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
user_id: Optional[str] = Field(None, description="The user id of the agent.")
class AgentType(str, Enum):
"""
Enum to represent the type of agent.
"""
memgpt_agent = "memgpt_agent"
split_thread_agent = "split_thread_agent"
class AgentState(BaseAgent):
"""
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
@@ -52,6 +62,9 @@ class AgentState(BaseAgent):
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")
# agent configuration
agent_type: AgentType = Field(..., description="The type of agent.")
# llm information
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
@@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
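Illustrative sketch (not part of the diff): agent_type is optional on CreateAgent, so existing requests keep working, and the str-backed enum compares equal to its plain string value. This assumes the remaining CreateAgent fields are optional, as the visible ones are.

from letta.schemas.agent import AgentType, CreateAgent

request = CreateAgent(agent_type=AgentType.split_thread_agent)
assert request.agent_type == "split_thread_agent"  # str Enum compares equal to its value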

View File

@@ -51,7 +51,7 @@ from letta.providers import (
OpenAIProvider,
VLLMProvider,
)
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
from letta.schemas.api_key import APIKey, APIKeyCreate
from letta.schemas.block import (
Block,
@@ -335,7 +335,10 @@ class SyncServer(Server):
# Make sure the memory is a memory object
assert isinstance(agent_state.memory, Memory)
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
if agent_state.agent_type == AgentType.memgpt_agent:
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
else:
raise NotImplementedError("Only base agents are supported as of right now!")
# Add the agent to the in-memory store and return its reference
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
@@ -787,6 +790,7 @@ class SyncServer(Server):
name=request.name,
user_id=user_id,
tools=request.tools if request.tools else [],
agent_type=request.agent_type or AgentType.memgpt_agent,
llm_config=llm_config,
embedding_config=embedding_config,
system=request.system,
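Illustrative sketch (not part of the diff): the server now branches on agent_state.agent_type, and only memgpt_agent is wired up. A second branch could slot in once a split-thread implementation exists; SplitThreadAgent below is hypothetical and does not exist in this commit, and the import paths are assumed.

from letta.agent import Agent
from letta.schemas.agent import AgentState, AgentType

def load_agent(agent_state: AgentState, interface, tool_objs):
    if agent_state.agent_type == AgentType.memgpt_agent:
        return Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
    # elif agent_state.agent_type == AgentType.split_thread_agent:
    #     return SplitThreadAgent(agent_state=agent_state, interface=interface, tools=tool_objs)
    raise NotImplementedError("Only base agents are supported as of right now!")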