mirror of
https://github.com/cpacker/MemGPT.git
synced 2025-06-03 04:30:22 +00:00
feat: Cache model handle (#1568)
This commit is contained in:
parent
e6908a412e
commit
b56ced7336
@ -94,7 +94,7 @@ from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings, settings, tool_settings
|
||||
from letta.sleeptime_agent import SleeptimeAgent
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import get_friendly_error_msg
|
||||
from letta.utils import get_friendly_error_msg, make_key
|
||||
|
||||
config = LettaConfig.load()
|
||||
logger = get_logger(__name__)
|
||||
@ -346,6 +346,10 @@ class SyncServer(Server):
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
|
||||
|
||||
# TODO: Remove these in memory caches
|
||||
self._llm_config_cache = {}
|
||||
self._embedding_config_cache = {}
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@ -696,6 +700,18 @@ class SyncServer(Server):
|
||||
command = command[1:] # strip the prefix
|
||||
return self._command(user_id=user_id, agent_id=agent_id, command=command)
|
||||
|
||||
def get_cached_llm_config(self, **kwargs):
    """Return the LLM config for the given handle kwargs, memoized in memory.

    The cache key is derived from every keyword argument via ``make_key``;
    on a miss the config is resolved through ``get_llm_config_from_handle``
    and stored for subsequent calls with the same arguments.
    """
    cache_key = make_key(**kwargs)
    try:
        # Fast path: an identical lookup was already resolved.
        return self._llm_config_cache[cache_key]
    except KeyError:
        resolved = self.get_llm_config_from_handle(**kwargs)
        self._llm_config_cache[cache_key] = resolved
        return resolved
|
||||
|
||||
def get_cached_embedding_config(self, **kwargs):
    """Return the embedding config for the given handle kwargs, memoized in memory.

    Mirrors ``get_cached_llm_config``: the key covers all keyword arguments,
    and a miss falls through to ``get_embedding_config_from_handle`` before
    being stored in the per-server cache.
    """
    cache_key = make_key(**kwargs)
    try:
        # Fast path: an identical lookup was already resolved.
        return self._embedding_config_cache[cache_key]
    except KeyError:
        resolved = self.get_embedding_config_from_handle(**kwargs)
        self._embedding_config_cache[cache_key] = resolved
        return resolved
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
@ -706,7 +722,7 @@ class SyncServer(Server):
|
||||
if request.llm_config is None:
|
||||
if request.model is None:
|
||||
raise ValueError("Must specify either model or llm_config in request")
|
||||
request.llm_config = self.get_llm_config_from_handle(
|
||||
request.llm_config = self.get_cached_llm_config(
|
||||
handle=request.model,
|
||||
context_window_limit=request.context_window_limit,
|
||||
max_tokens=request.max_tokens,
|
||||
@ -717,8 +733,9 @@ class SyncServer(Server):
|
||||
if request.embedding_config is None:
|
||||
if request.embedding is None:
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
request.embedding_config = self.get_embedding_config_from_handle(
|
||||
handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
request.embedding_config = self.get_cached_embedding_config(
|
||||
handle=request.embedding,
|
||||
embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
)
|
||||
|
||||
main_agent = self.agent_manager.create_agent(
|
||||
|
@ -1070,3 +1070,7 @@ def log_telemetry(logger: Logger, event: str, **kwargs):
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S,%f UTC") # More readable timestamp
|
||||
extra_data = " | ".join(f"{key}={value}" for key, value in kwargs.items() if value is not None)
|
||||
logger.info(f"[{timestamp}] EVENT: {event} | {extra_data}")
|
||||
|
||||
|
||||
def make_key(*args, **kwargs):
    """Build a deterministic string cache key from arbitrary call arguments.

    Keyword arguments are sorted by name so that calls differing only in
    keyword order produce the same key.
    """
    normalized_kwargs = tuple(sorted(kwargs.items()))
    return str((args, normalized_kwargs))
|
||||
|
@ -465,3 +465,41 @@ def test_anthropic_streaming(client: Letta):
|
||||
)
|
||||
|
||||
print(list(response))
|
||||
|
||||
|
||||
import time
|
||||
|
||||
|
||||
def test_create_agents_telemetry(client: Letta):
    """Benchmark agent creation end to end, printing per-agent and aggregate timings."""
    overall_start = time.perf_counter()

    # Clean slate: remove any worker agents left over from earlier runs.
    delete_start = time.perf_counter()
    existing = client.agents.list(tags=["worker"])
    for agent in existing:
        client.agents.delete(agent_id=agent.id)
    delete_elapsed = time.perf_counter() - delete_start
    print(f"[telemetry] Deleted {len(existing)} existing worker agents in {delete_elapsed:.2f}s")

    # Create a fixed batch of worker agents, timing each creation call.
    num_workers = 100
    durations = []
    for idx in range(num_workers):
        created_at = time.perf_counter()
        client.agents.create(
            name=f"worker_{idx}",
            include_base_tools=True,
            model="anthropic/claude-3-5-sonnet-20241022",
            embedding="letta/letta-free",
        )
        elapsed = time.perf_counter() - created_at
        durations.append(elapsed)
        print(f"[telemetry] Created worker_{idx} in {elapsed:.2f}s")

    total_elapsed = time.perf_counter() - overall_start
    average = sum(durations) / len(durations)

    print(f"[telemetry] Total time to create {num_workers} agents: {total_elapsed:.2f}s")
    print(f"[telemetry] Average agent creation time: {average:.2f}s")
    print(f"[telemetry] Fastest agent: {min(durations):.2f}s, Slowest agent: {max(durations):.2f}s")
|
||||
|
Loading…
Reference in New Issue
Block a user