Add performance improvements

This commit is contained in:
Matt Zhou 2025-05-28 14:44:14 -07:00
parent 2716d422da
commit 73d628a124
4 changed files with 7 additions and 44 deletions

View File

@ -480,5 +480,6 @@ class GoogleVertexClient(LLMClientBase):
"required": tool["parameters"]["required"],
},
},
"propertyOrdering": ["name", "args"],
"required": ["name", "args"],
}

View File

@ -120,7 +120,7 @@ class Agent(SqlalchemyBase, OrganizationMixin, AsyncAttrs):
)
multi_agent_group: Mapped["Group"] = relationship(
"Group",
lazy="joined",
lazy="selectin",
viewonly=True,
back_populates="manager_agent",
)

View File

@ -61,8 +61,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@ -86,8 +84,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Text to search for
query_embedding: Vector to search for similar embeddings
ascending: Sort direction
tags: List of tags to filter by
match_all_tags: If True, return items matching all tags. If False, match any tag.
**kwargs: Additional filters to apply
"""
if start_date and end_date and start_date > end_date:
@ -123,8 +119,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text=query_text,
query_embedding=query_embedding,
ascending=ascending,
tags=tags,
match_all_tags=match_all_tags,
actor=actor,
access=access,
access_type=access_type,
@ -162,8 +156,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@ -189,8 +181,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Text to search for
query_embedding: Vector to search for similar embeddings
ascending: Sort direction
tags: List of tags to filter by
match_all_tags: If True, return items matching all tags. If False, match any tag.
**kwargs: Additional filters to apply
"""
if start_date and end_date and start_date > end_date:
@ -226,8 +216,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text=query_text,
query_embedding=query_embedding,
ascending=ascending,
tags=tags,
match_all_tags=match_all_tags,
actor=actor,
access=access,
access_type=access_type,
@ -263,8 +251,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
query_text: Optional[str] = None,
query_embedding: Optional[List[float]] = None,
ascending: bool = True,
tags: Optional[List[str]] = None,
match_all_tags: bool = False,
actor: Optional["User"] = None,
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
access_type: AccessType = AccessType.ORGANIZATION,
@ -286,28 +272,6 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
if actor:
query = cls.apply_access_predicate(query, actor, access, access_type)
# Handle tag filtering if the model has tags
if tags and hasattr(cls, "tags"):
query = select(cls)
if match_all_tags:
# Match ALL tags - use subqueries
subquery = (
select(cls.tags.property.mapper.class_.agent_id)
.where(cls.tags.property.mapper.class_.tag.in_(tags))
.group_by(cls.tags.property.mapper.class_.agent_id)
.having(func.count() == len(tags))
)
query = query.filter(cls.id.in_(subquery))
else:
# Match ANY tag - use join and filter
query = (
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
) # Deduplicate results
# select distinct primary key
query = query.distinct(cls.id).order_by(cls.id)
if identifier_keys and hasattr(cls, "identities"):
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))

View File

@ -1,7 +1,7 @@
import datetime
from typing import List, Literal, Optional
from sqlalchemy import and_, asc, desc, func, literal, or_, select
from sqlalchemy import and_, asc, desc, exists, or_, select
from letta import system
from letta.constants import IN_CONTEXT_MEMORY_KEYWORD, STRUCTURED_OUTPUT_MODELS
@ -438,13 +438,11 @@ def _apply_tag_filter(query, tags: Optional[List[str]], match_all_tags: bool):
The modified query with tag filters applied.
"""
if tags:
# Build a subquery to select agent IDs that have the specified tags.
subquery = select(AgentsTags.agent_id).where(AgentsTags.tag.in_(tags)).group_by(AgentsTags.agent_id)
# If all tags must match, add a HAVING clause to ensure the count of tags equals the number provided.
if match_all_tags:
subquery = subquery.having(func.count(AgentsTags.tag) == literal(len(tags)))
# Filter the main query to include only agents present in the subquery.
query = query.where(AgentModel.id.in_(subquery))
for tag in tags:
query = query.filter(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag == tag)))
else:
query = query.where(exists().where((AgentsTags.agent_id == AgentModel.id) & (AgentsTags.tag.in_(tags))))
return query