mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: add agent types (#1831)
This commit is contained in:
parent
4a01ca3a55
commit
90a1e3b438
@ -30,6 +30,6 @@ services:
|
||||
ports:
|
||||
- "8000:8000"
|
||||
command: >
|
||||
--model ${LETTA_LLM_MODEL} --max_model_len=8000
|
||||
--model ${LETTA_LLM_MODEL} --max_model_len=8000
|
||||
# Replace with your model
|
||||
ipc: host
|
||||
ipc: host
|
||||
|
@ -9,7 +9,7 @@ from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.memory import get_memory_functions
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.block import (
|
||||
Block,
|
||||
CreateBlock,
|
||||
@ -68,6 +68,7 @@ class AbstractClient(object):
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
memory: Memory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_persona_text(DEFAULT_PERSONA)),
|
||||
@ -319,6 +320,8 @@ class RESTClient(AbstractClient):
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
# agent config
|
||||
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||
# model configs
|
||||
embedding_config: EmbeddingConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
@ -381,6 +384,7 @@ class RESTClient(AbstractClient):
|
||||
memory=memory,
|
||||
tools=tool_names,
|
||||
system=system,
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
)
|
||||
@ -1462,6 +1466,8 @@ class LocalClient(AbstractClient):
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
# agent config
|
||||
agent_type: Optional[AgentType] = AgentType.memgpt_agent,
|
||||
# model configs
|
||||
embedding_config: EmbeddingConfig = None,
|
||||
llm_config: LLMConfig = None,
|
||||
@ -1524,6 +1530,7 @@ class LocalClient(AbstractClient):
|
||||
memory=memory,
|
||||
tools=tool_names,
|
||||
system=system,
|
||||
agent_type=agent_type,
|
||||
llm_config=llm_config if llm_config else self._default_llm_config,
|
||||
embedding_config=embedding_config if embedding_config else self._default_embedding_config,
|
||||
),
|
||||
|
@ -218,6 +218,7 @@ class AgentModel(Base):
|
||||
tools = Column(JSON)
|
||||
|
||||
# configs
|
||||
agent_type = Column(String)
|
||||
llm_config = Column(LLMConfigColumn)
|
||||
embedding_config = Column(EmbeddingConfigColumn)
|
||||
|
||||
@ -243,6 +244,7 @@ class AgentModel(Base):
|
||||
memory=Memory.load(self.memory), # load dictionary
|
||||
system=self.system,
|
||||
tools=self.tools,
|
||||
agent_type=self.agent_type,
|
||||
llm_config=self.llm_config,
|
||||
embedding_config=self.embedding_config,
|
||||
metadata_=self.metadata_,
|
||||
|
@ -1,5 +1,6 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
@ -21,6 +22,15 @@ class BaseAgent(LettaBase, validate_assignment=True):
|
||||
user_id: Optional[str] = Field(None, description="The user id of the agent.")
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
"""
|
||||
Enum to represent the type of agent.
|
||||
"""
|
||||
|
||||
memgpt_agent = "memgpt_agent"
|
||||
split_thread_agent = "split_thread_agent"
|
||||
|
||||
|
||||
class AgentState(BaseAgent):
|
||||
"""
|
||||
Representation of an agent's state. This is the state of the agent at a given time, and is persisted in the DB backend. The state has all the information needed to recreate a persisted agent.
|
||||
@ -52,6 +62,9 @@ class AgentState(BaseAgent):
|
||||
# system prompt
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
|
||||
# agent configuration
|
||||
agent_type: AgentType = Field(..., description="The type of agent.")
|
||||
|
||||
# llm information
|
||||
llm_config: LLMConfig = Field(..., description="The LLM configuration used by the agent.")
|
||||
embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the agent.")
|
||||
@ -64,6 +77,7 @@ class CreateAgent(BaseAgent):
|
||||
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
|
||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
|
||||
|
@ -51,7 +51,7 @@ from letta.providers import (
|
||||
OpenAIProvider,
|
||||
VLLMProvider,
|
||||
)
|
||||
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgentState
|
||||
from letta.schemas.api_key import APIKey, APIKeyCreate
|
||||
from letta.schemas.block import (
|
||||
Block,
|
||||
@ -335,7 +335,10 @@ class SyncServer(Server):
|
||||
# Make sure the memory is a memory object
|
||||
assert isinstance(agent_state.memory, Memory)
|
||||
|
||||
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
letta_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||
else:
|
||||
raise NotImplementedError("Only base agents are supported as of right now!")
|
||||
|
||||
# Add the agent to the in-memory store and return its reference
|
||||
logger.debug(f"Adding agent to the agent cache: user_id={user_id}, agent_id={agent_id}")
|
||||
@ -787,6 +790,7 @@ class SyncServer(Server):
|
||||
name=request.name,
|
||||
user_id=user_id,
|
||||
tools=request.tools if request.tools else [],
|
||||
agent_type=request.agent_type or AgentType.memgpt_agent,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
system=request.system,
|
||||
|
Loading…
Reference in New Issue
Block a user