feat: Cache model handle (#1568)

Matthew Zhou authored 2025-04-04 12:11:20 -07:00, committed by GitHub
parent e6908a412e
commit b56ced7336
3 changed files with 63 additions and 4 deletions


@@ -94,7 +94,7 @@ from letta.services.user_manager import UserManager
 from letta.settings import model_settings, settings, tool_settings
 from letta.sleeptime_agent import SleeptimeAgent
 from letta.tracing import trace_method
-from letta.utils import get_friendly_error_msg
+from letta.utils import get_friendly_error_msg, make_key

 config = LettaConfig.load()
 logger = get_logger(__name__)
@@ -346,6 +346,10 @@ class SyncServer(Server):
         logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
         logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")

+        # TODO: Remove these in-memory caches
+        self._llm_config_cache = {}
+        self._embedding_config_cache = {}
+
     def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
         """Updated method to load agents from persisted storage"""
         agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -696,6 +700,18 @@ class SyncServer(Server):
             command = command[1:]  # strip the prefix
         return self._command(user_id=user_id, agent_id=agent_id, command=command)

+    def get_cached_llm_config(self, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._llm_config_cache:
+            self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
+        return self._llm_config_cache[key]
+
+    def get_cached_embedding_config(self, **kwargs):
+        key = make_key(**kwargs)
+        if key not in self._embedding_config_cache:
+            self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
+        return self._embedding_config_cache[key]
+
     def create_agent(
         self,
         request: CreateAgent,
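The two get_cached_* helpers above memoize handle resolution in plain dicts that live as long as the SyncServer instance. A minimal, self-contained sketch of the same pattern, in which resolve_config is a hypothetical stand-in for the expensive get_llm_config_from_handle lookup (not shown in this diff):

import time

_config_cache = {}

def make_key(*args, **kwargs):
    # Same helper the commit adds to letta.utils: a string built from the
    # positional args plus the kwargs sorted by name.
    return str((args, tuple(sorted(kwargs.items()))))

def resolve_config(**kwargs):
    # Hypothetical stand-in for the handle-to-config lookup; pretend it is slow.
    time.sleep(0.1)
    return dict(kwargs)

def get_cached_config(**kwargs):
    key = make_key(**kwargs)
    if key not in _config_cache:
        _config_cache[key] = resolve_config(**kwargs)  # miss: resolve once
    return _config_cache[key]

get_cached_config(handle="anthropic/claude-3-5-sonnet-20241022")  # slow path
get_cached_config(handle="anthropic/claude-3-5-sonnet-20241022")  # cache hit

Note that the key covers every keyword argument, so two requests with the same handle but different context_window_limit or max_tokens overrides get separate cache entries rather than sharing one.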
@@ -706,7 +722,7 @@ class SyncServer(Server):
         if request.llm_config is None:
             if request.model is None:
                 raise ValueError("Must specify either model or llm_config in request")
-            request.llm_config = self.get_llm_config_from_handle(
+            request.llm_config = self.get_cached_llm_config(
                 handle=request.model,
                 context_window_limit=request.context_window_limit,
                 max_tokens=request.max_tokens,
@@ -717,8 +733,9 @@ class SyncServer(Server):
         if request.embedding_config is None:
             if request.embedding is None:
                 raise ValueError("Must specify either embedding or embedding_config in request")
-            request.embedding_config = self.get_embedding_config_from_handle(
-                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
+            request.embedding_config = self.get_cached_embedding_config(
+                handle=request.embedding,
+                embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
             )

         main_agent = self.agent_manager.create_agent(

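With both branches of create_agent now routed through the cache, creating many agents that share a model handle should resolve the handle once per server process instead of once per agent, which is the scenario the new test below measures. A small illustration of that intended effect (the counter and names are illustrative, not Letta code):

resolutions = 0
cache = {}

def resolve(handle):
    # Stand-in for the uncached handle lookup.
    global resolutions
    resolutions += 1
    return {"handle": handle}

def cached_resolve(handle):
    if handle not in cache:
        cache[handle] = resolve(handle)
    return cache[handle]

for _ in range(100):
    cached_resolve("anthropic/claude-3-5-sonnet-20241022")

assert resolutions == 1  # resolved once, not 100 times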

@@ -1070,3 +1070,7 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
     timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC")  # More readable timestamp
     extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
     logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")
+
+
+def make_key(*args, **kwargs):
+    return str((args, tuple(sorted(kwargs.items()))))

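A quick sanity check of make_key's behavior (the assertions are mine, not part of the commit): keyword order does not change the key, and since the key is a plain string, values only need a stable repr rather than being hashable.

from letta.utils import make_key

assert make_key(handle="x", max_tokens=100) == make_key(max_tokens=100, handle="x")
assert make_key(handle="x") != make_key(handle="x", max_tokens=None)  # explicit None still changes the key
print(make_key("a", handle="x"))  # (('a',), (('handle', 'x'),))

The flip side is that a value whose repr embeds a memory address would defeat the cache; the kwargs used here are strings and ints, so that is not an issue in this commit.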

@@ -465,3 +465,41 @@ def test_anthropic_streaming(client: Letta):
     )
     print(list(response))
+
+
+import time
+
+
+def test_create_agents_telemetry(client: Letta):
+    start_total = time.perf_counter()
+
+    # delete any existing worker agents
+    start_delete = time.perf_counter()
+    workers = client.agents.list(tags=["worker"])
+    for worker in workers:
+        client.agents.delete(agent_id=worker.id)
+    end_delete = time.perf_counter()
+    print(f"[telemetry] Deleted {len(workers)} existing worker agents in {end_delete - start_delete:.2f}s")
+
+    # create worker agents
+    num_workers = 100
+    agent_times = []
+
+    for idx in range(num_workers):
+        start = time.perf_counter()
+        client.agents.create(
+            name=f"worker_{idx}",
+            include_base_tools=True,
+            model="anthropic/claude-3-5-sonnet-20241022",
+            embedding="letta/letta-free",
+        )
+        end = time.perf_counter()
+        duration = end - start
+        agent_times.append(duration)
+        print(f"[telemetry] Created worker_{idx} in {duration:.2f}s")
+
+    total_duration = time.perf_counter() - start_total
+    avg_duration = sum(agent_times) / len(agent_times)
+    print(f"[telemetry] Total time to create {num_workers} agents: {total_duration:.2f}s")
+    print(f"[telemetry] Average agent creation time: {avg_duration:.2f}s")
+    print(f"[telemetry] Fastest agent: {min(agent_times):.2f}s, Slowest agent: {max(agent_times):.2f}s")
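The summary stats above are min/avg/max, which can hide a skewed distribution when a few creations stall. A possible follow-up, sketched with only the standard library (not part of this commit), is a percentile summary:

import statistics

def summarize(times):
    # statistics.quantiles with n=20 returns 19 cut points; index 18 approximates p95.
    p95 = statistics.quantiles(times, n=20)[18]
    return f"median={statistics.median(times):.2f}s p95={p95:.2f}s mean={statistics.fmean(times):.2f}s"

print("[telemetry]", summarize(agent_times))  # agent_times from the test above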