Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11041.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `UserRepository.get_main_access_key_by_id` helper and rewrite `SessionConditions.by_access_key_*` filters to resolve the owner's `main_access_key` via a subquery. Groundwork for dropping the redundant `sessions.access_key` / `kernels.access_key` columns.
49 changes: 40 additions & 9 deletions src/ai/backend/manager/models/session/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

if TYPE_CHECKING:
from ai.backend.common.data.filter_specs import (
StringInMatchSpec,
StringMatchSpec,
UUIDEqualMatchSpec,
UUIDInMatchSpec,
Comment thread
jopemachine marked this conversation as resolved.
Expand All @@ -20,6 +21,7 @@
from ai.backend.manager.data.session.types import KernelMatchType, SessionStatus
from ai.backend.manager.models.condition_utils import make_string_in_factory
from ai.backend.manager.models.kernel import KernelRow
from ai.backend.manager.models.user import UserRow
from ai.backend.manager.repositories.base import QueryCondition

from .row import SessionRow
Expand All @@ -28,6 +30,19 @@
class SessionConditions:
"""Query conditions for sessions."""

@staticmethod
def _owners_where_main_access_key(
condition: sa.sql.expression.ColumnElement[bool],
) -> sa.sql.expression.ColumnElement[bool]:
"""Return a predicate matching ``SessionRow.user_uuid`` against users whose ``main_access_key`` satisfies ``condition``.

The subquery selects ``users.uuid`` (non-null PK) so ``NOT IN`` is
well-defined. NULL ``main_access_key`` fails ``condition`` (evaluates
to NULL, not TRUE), so such users are excluded from the subquery
without needing an explicit ``IS NOT NULL`` guard.
"""
return SessionRow.user_uuid.in_(sa.select(UserRow.uuid).where(condition))

@staticmethod
def by_ids(session_ids: Collection[SessionId]) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
Expand Down Expand Up @@ -107,9 +122,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"%{spec.value}%")
match = UserRow.main_access_key.ilike(f"%{spec.value}%")
else:
condition = SessionRow.access_key.like(f"%{spec.value}%")
match = UserRow.main_access_key.like(f"%{spec.value}%")
condition = SessionConditions._owners_where_main_access_key(match)
if spec.negated:
condition = sa.not_(condition)
return condition
Expand All @@ -120,9 +136,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = sa.func.lower(SessionRow.access_key) == spec.value.lower()
match = sa.func.lower(UserRow.main_access_key) == spec.value.lower()
else:
condition = SessionRow.access_key == spec.value
match = UserRow.main_access_key == spec.value
condition = SessionConditions._owners_where_main_access_key(match)
if spec.negated:
condition = sa.not_(condition)
return condition
Expand All @@ -133,9 +150,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"{spec.value}%")
match = UserRow.main_access_key.ilike(f"{spec.value}%")
else:
condition = SessionRow.access_key.like(f"{spec.value}%")
match = UserRow.main_access_key.like(f"{spec.value}%")
condition = SessionConditions._owners_where_main_access_key(match)
if spec.negated:
condition = sa.not_(condition)
return condition
Expand All @@ -146,16 +164,29 @@ def inner() -> sa.sql.expression.ColumnElement[bool]:
def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
condition = SessionRow.access_key.ilike(f"%{spec.value}")
match = UserRow.main_access_key.ilike(f"%{spec.value}")
else:
condition = SessionRow.access_key.like(f"%{spec.value}")
match = UserRow.main_access_key.like(f"%{spec.value}")
condition = SessionConditions._owners_where_main_access_key(match)
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

by_access_key_in = staticmethod(make_string_in_factory(SessionRow.access_key))
@staticmethod
def by_access_key_in(spec: StringInMatchSpec) -> QueryCondition:
def inner() -> sa.sql.expression.ColumnElement[bool]:
if spec.case_insensitive:
match = sa.func.lower(UserRow.main_access_key).in_([v.lower() for v in spec.values])
else:
match = UserRow.main_access_key.in_(spec.values)
condition = SessionConditions._owners_where_main_access_key(match)
if spec.negated:
condition = sa.not_(condition)
return condition

return inner

@staticmethod
def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,13 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData:
user_row = await self._get_user_by_uuid(db_session, user_uuid)
return user_row.to_data()

async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None:
"""Return the user's ``main_access_key`` or ``None`` if unset/missing."""
async with self._db.begin_readonly_session() as db_session:
return await db_session.scalar(
sa.select(UserRow.main_access_key).where(UserRow.uuid == user_uuid)
)

async def get_by_email_validated(
self,
email: str,
Expand Down
5 changes: 5 additions & 0 deletions src/ai/backend/manager/repositories/user/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData:
"""
return await self._db_source.get_user_by_uuid(user_uuid)

@user_repository_resilience.apply()
async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None:
"""Return the user's ``main_access_key`` or ``None`` if unset/missing."""
return await self._db_source.get_main_access_key_by_id(user_uuid)
Comment thread
jopemachine marked this conversation as resolved.

@user_repository_resilience.apply()
async def get_by_email_validated(
self,
Expand Down
Loading