diff --git a/changes/11041.enhance.md b/changes/11041.enhance.md new file mode 100644 index 00000000000..705fac3a399 --- /dev/null +++ b/changes/11041.enhance.md @@ -0,0 +1 @@ +Add `UserRepository.get_main_access_key_by_id` helper and rewrite `SessionConditions.by_access_key_*` filters to resolve the owner's `main_access_key` via a subquery. Groundwork for dropping the redundant `sessions.access_key` / `kernels.access_key` columns. diff --git a/src/ai/backend/manager/models/session/conditions.py b/src/ai/backend/manager/models/session/conditions.py index 0298889188f..913a818c10d 100644 --- a/src/ai/backend/manager/models/session/conditions.py +++ b/src/ai/backend/manager/models/session/conditions.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from ai.backend.common.data.filter_specs import ( + StringInMatchSpec, StringMatchSpec, UUIDEqualMatchSpec, UUIDInMatchSpec, @@ -20,6 +21,7 @@ from ai.backend.manager.data.session.types import KernelMatchType, SessionStatus from ai.backend.manager.models.condition_utils import make_string_in_factory from ai.backend.manager.models.kernel import KernelRow +from ai.backend.manager.models.user import UserRow from ai.backend.manager.repositories.base import QueryCondition from .row import SessionRow @@ -28,6 +30,19 @@ class SessionConditions: """Query conditions for sessions.""" + @staticmethod + def _owners_where_main_access_key( + condition: sa.sql.expression.ColumnElement[bool], + ) -> sa.sql.expression.ColumnElement[bool]: + """Return a predicate matching ``SessionRow.user_uuid`` against users whose ``main_access_key`` satisfies ``condition``. + + The subquery selects ``users.uuid`` (non-null PK) so ``NOT IN`` is + well-defined. NULL ``main_access_key`` fails ``condition`` (evaluates + to NULL, not TRUE), so such users are excluded from the subquery + without needing an explicit ``IS NOT NULL`` guard. + """ + return SessionRow.user_uuid.in_(sa.select(UserRow.uuid).where(condition)) + @staticmethod def by_ids(session_ids: Collection[SessionId]) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: @@ -107,9 +122,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_contains(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}%") + match = UserRow.main_access_key.ilike(f"%{spec.value}%") else: - condition = SessionRow.access_key.like(f"%{spec.value}%") + match = UserRow.main_access_key.like(f"%{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -120,9 +136,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_equals(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = sa.func.lower(SessionRow.access_key) == spec.value.lower() + match = sa.func.lower(UserRow.main_access_key) == spec.value.lower() else: - condition = SessionRow.access_key == spec.value + match = UserRow.main_access_key == spec.value + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -133,9 +150,10 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_starts_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"{spec.value}%") + match = UserRow.main_access_key.ilike(f"{spec.value}%") else: - condition = SessionRow.access_key.like(f"{spec.value}%") + match = UserRow.main_access_key.like(f"{spec.value}%") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition @@ -146,16 +164,29 @@ def inner() -> sa.sql.expression.ColumnElement[bool]: def by_access_key_ends_with(spec: StringMatchSpec) -> QueryCondition: def inner() -> sa.sql.expression.ColumnElement[bool]: if spec.case_insensitive: - condition = SessionRow.access_key.ilike(f"%{spec.value}") + match = UserRow.main_access_key.ilike(f"%{spec.value}") else: - condition = SessionRow.access_key.like(f"%{spec.value}") + match = UserRow.main_access_key.like(f"%{spec.value}") + condition = SessionConditions._owners_where_main_access_key(match) if spec.negated: condition = sa.not_(condition) return condition return inner - by_access_key_in = staticmethod(make_string_in_factory(SessionRow.access_key)) + @staticmethod + def by_access_key_in(spec: StringInMatchSpec) -> QueryCondition: + def inner() -> sa.sql.expression.ColumnElement[bool]: + if spec.case_insensitive: + match = sa.func.lower(UserRow.main_access_key).in_([v.lower() for v in spec.values]) + else: + match = UserRow.main_access_key.in_(spec.values) + condition = SessionConditions._owners_where_main_access_key(match) + if spec.negated: + condition = sa.not_(condition) + return condition + + return inner @staticmethod def by_domain_name_contains(spec: StringMatchSpec) -> QueryCondition: diff --git a/src/ai/backend/manager/repositories/user/db_source/db_source.py b/src/ai/backend/manager/repositories/user/db_source/db_source.py index f3ff8eef41b..bc7621fe6d6 100644 --- a/src/ai/backend/manager/repositories/user/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/user/db_source/db_source.py @@ -137,6 +137,13 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: user_row = await self._get_user_by_uuid(db_session, user_uuid) return user_row.to_data() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + async with self._db.begin_readonly_session() as db_session: + return await db_session.scalar( + sa.select(UserRow.main_access_key).where(UserRow.uuid == user_uuid) + ) + async def get_by_email_validated( self, email: str, diff --git a/src/ai/backend/manager/repositories/user/repository.py b/src/ai/backend/manager/repositories/user/repository.py index 9df169da1f9..9907b828595 100644 --- a/src/ai/backend/manager/repositories/user/repository.py +++ b/src/ai/backend/manager/repositories/user/repository.py @@ -77,6 +77,11 @@ async def get_user_by_uuid(self, user_uuid: UUID) -> UserData: """ return await self._db_source.get_user_by_uuid(user_uuid) + @user_repository_resilience.apply() + async def get_main_access_key_by_id(self, user_uuid: UUID) -> str | None: + """Return the user's ``main_access_key`` or ``None`` if unset/missing.""" + return await self._db_source.get_main_access_key_by_id(user_uuid) + @user_repository_resilience.apply() async def get_by_email_validated( self,