diff --git a/changes/11647.feature.md b/changes/11647.feature.md new file mode 100644 index 00000000000..173459ba510 --- /dev/null +++ b/changes/11647.feature.md @@ -0,0 +1 @@ +Add `UserID` identifier type and `ResolveUserIDByAccessKey` auth action for resolving an access_key to its owning user UUID. diff --git a/src/ai/backend/common/identifier/user.py b/src/ai/backend/common/identifier/user.py new file mode 100644 index 00000000000..809300951d9 --- /dev/null +++ b/src/ai/backend/common/identifier/user.py @@ -0,0 +1,7 @@ +from typing import NewType +from uuid import UUID + +__all__ = ("UserID",) + + +UserID = NewType("UserID", UUID) diff --git a/src/ai/backend/manager/repositories/auth/db_source/db_source.py b/src/ai/backend/manager/repositories/auth/db_source/db_source.py index 56ebf53856c..75fd3d5c245 100644 --- a/src/ai/backend/manager/repositories/auth/db_source/db_source.py +++ b/src/ai/backend/manager/repositories/auth/db_source/db_source.py @@ -12,15 +12,18 @@ from sqlalchemy.orm import joinedload, selectinload from ai.backend.common.exception import BackendAIError, UserNotFound +from ai.backend.common.identifier.user import UserID from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy from ai.backend.common.resilience.resilience import Resilience +from ai.backend.common.types import AccessKey from ai.backend.manager.data.auth.login_session_types import LoginHistoryData, LoginSessionData from ai.backend.manager.data.auth.types import GroupMembershipData, UserData from ai.backend.manager.data.common.types import SearchResult from ai.backend.manager.data.permission.types import EntityType, ScopeType from ai.backend.manager.errors.auth import ( + AccessKeyNotFound, AuthorizationFailed, GroupMembershipNotFoundError, LoginSessionNotFoundError, @@ -292,6 +295,16 @@ async def fetch_user_info_by_access_key(self, access_key: str) -> tuple[str, Use raise ValueError("Unknown owner access key") return row.domain_name, row.role + @auth_db_source_resilience.apply() + async def fetch_user_id_by_access_key(self, access_key: AccessKey) -> UserID: + async with self._db.begin_readonly() as conn: + query = sa.select(keypairs.c.user).where(keypairs.c.access_key == access_key) + result = await conn.execute(query) + row = result.scalar() + if row is None: + raise AccessKeyNotFound("Unknown access key") + return UserID(cast(UUID, row)) + @auth_db_source_resilience.apply() async def fetch_user_info_by_email(self, email: str) -> tuple[UUID, UserRole, str]: """Fetch (uuid, role, domain_name) for a user identified by *email*. diff --git a/src/ai/backend/manager/repositories/auth/repository.py b/src/ai/backend/manager/repositories/auth/repository.py index 8c3e79c42b7..20da8f1ff0b 100644 --- a/src/ai/backend/manager/repositories/auth/repository.py +++ b/src/ai/backend/manager/repositories/auth/repository.py @@ -4,9 +4,11 @@ import sqlalchemy as sa +from ai.backend.common.identifier.user import UserID from ai.backend.common.metrics.metric import DomainType, LayerType from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy from ai.backend.common.resilience.resilience import Resilience +from ai.backend.common.types import AccessKey from ai.backend.manager.data.auth.login_session_types import LoginHistoryData, LoginSessionData from ai.backend.manager.data.auth.types import GroupMembershipData, UserData from ai.backend.manager.data.common.types import SearchResult @@ -82,6 +84,10 @@ async def update_ssh_keypair(self, access_key: str, public_key: str, private_key async def get_delegation_target_by_access_key(self, access_key: str) -> tuple[str, UserRole]: return await self._db_source.fetch_user_info_by_access_key(access_key) + @auth_repository_resilience.apply() + async def get_user_id_by_access_key(self, access_key: AccessKey) -> UserID: + return await self._db_source.fetch_user_id_by_access_key(access_key) + @auth_repository_resilience.apply() async def get_delegation_target_by_email(self, email: str) -> tuple[UUID, UserRole, str]: return await self._db_source.fetch_user_info_by_email(email) diff --git a/src/ai/backend/manager/services/auth/actions/resolve_user_id_by_access_key.py b/src/ai/backend/manager/services/auth/actions/resolve_user_id_by_access_key.py new file mode 100644 index 00000000000..0c8fe23038d --- /dev/null +++ b/src/ai/backend/manager/services/auth/actions/resolve_user_id_by_access_key.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass +from typing import override + +from ai.backend.common.identifier.user import UserID +from ai.backend.common.types import AccessKey +from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.manager.actions.types import ActionOperationType +from ai.backend.manager.services.auth.actions.base import AuthAction + + +@dataclass +class ResolveUserIDByAccessKeyAction(AuthAction): + access_key: AccessKey + + @override + def entity_id(self) -> str | None: + return str(self.access_key) + + @override + @classmethod + def operation_type(cls) -> ActionOperationType: + return ActionOperationType.GET + + +@dataclass +class ResolveUserIDByAccessKeyResult(BaseActionResult): + user_id: UserID + + @override + def entity_id(self) -> str | None: + return str(self.user_id) diff --git a/src/ai/backend/manager/services/auth/processors.py b/src/ai/backend/manager/services/auth/processors.py index 608f1a6b27c..5da695a4cb3 100644 --- a/src/ai/backend/manager/services/auth/processors.py +++ b/src/ai/backend/manager/services/auth/processors.py @@ -24,6 +24,10 @@ ResolveAccessKeyScopeAction, ResolveAccessKeyScopeResult, ) +from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import ( + ResolveUserIDByAccessKeyAction, + ResolveUserIDByAccessKeyResult, +) from ai.backend.manager.services.auth.actions.resolve_user_scope import ( ResolveUserScopeAction, ResolveUserScopeResult, @@ -88,6 +92,9 @@ class AuthProcessors(AbstractProcessorPackage): ResolveAccessKeyScopeAction, ResolveAccessKeyScopeResult ] resolve_user_scope: ActionProcessor[ResolveUserScopeAction, ResolveUserScopeResult] + resolve_user_id_by_access_key: ActionProcessor[ + ResolveUserIDByAccessKeyAction, ResolveUserIDByAccessKeyResult + ] admin_search_login_sessions: ActionProcessor[ AdminSearchLoginSessionsAction, SearchLoginSessionsActionResult ] @@ -135,6 +142,9 @@ def __init__( service.resolve_access_key_scope, action_monitors ) self.resolve_user_scope = ActionProcessor(service.resolve_user_scope, action_monitors) + self.resolve_user_id_by_access_key = ActionProcessor( + service.resolve_user_id_by_access_key, action_monitors + ) self.admin_search_login_sessions = ActionProcessor( service.admin_search_login_sessions, action_monitors ) @@ -167,6 +177,7 @@ def supported_actions(self) -> list[ActionSpec]: UpdatePasswordNoAuthAction.spec(), ResolveAccessKeyScopeAction.spec(), ResolveUserScopeAction.spec(), + ResolveUserIDByAccessKeyAction.spec(), AdminSearchLoginSessionsAction.spec(), SearchLoginSessionsAction.spec(), AdminSearchLoginHistoryAction.spec(), diff --git a/src/ai/backend/manager/services/auth/service.py b/src/ai/backend/manager/services/auth/service.py index df31d73b94d..54cb241e525 100644 --- a/src/ai/backend/manager/services/auth/service.py +++ b/src/ai/backend/manager/services/auth/service.py @@ -77,6 +77,10 @@ ResolveAccessKeyScopeAction, ResolveAccessKeyScopeResult, ) +from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import ( + ResolveUserIDByAccessKeyAction, + ResolveUserIDByAccessKeyResult, +) from ai.backend.manager.services.auth.actions.resolve_user_scope import ( ResolveUserScopeAction, ResolveUserScopeResult, @@ -758,6 +762,12 @@ async def resolve_access_key_scope( owner_access_key=owner_ak, ) + async def resolve_user_id_by_access_key( + self, action: ResolveUserIDByAccessKeyAction + ) -> ResolveUserIDByAccessKeyResult: + user_id = await self._auth_repository.get_user_id_by_access_key(action.access_key) + return ResolveUserIDByAccessKeyResult(user_id=user_id) + async def resolve_user_scope(self, action: ResolveUserScopeAction) -> ResolveUserScopeResult: if action.owner_user_email is None: return ResolveUserScopeResult( diff --git a/tests/unit/manager/repositories/auth/test_auth_repository.py b/tests/unit/manager/repositories/auth/test_auth_repository.py index 23937f65f57..1c30985af2b 100644 --- a/tests/unit/manager/repositories/auth/test_auth_repository.py +++ b/tests/unit/manager/repositories/auth/test_auth_repository.py @@ -15,12 +15,12 @@ from ai.backend.common.data.permission.types import RelationType from ai.backend.common.exception import UserNotFound -from ai.backend.common.types import ResourceSlot, VFolderHostPermissionMap +from ai.backend.common.types import AccessKey, ResourceSlot, VFolderHostPermissionMap from ai.backend.manager.data.auth.hash import PasswordHashAlgorithm from ai.backend.manager.data.auth.types import UserData from ai.backend.manager.data.group.types import GroupData from ai.backend.manager.data.permission.types import EntityType, ScopeType -from ai.backend.manager.errors.auth import GroupMembershipNotFoundError +from ai.backend.manager.errors.auth import AccessKeyNotFound, GroupMembershipNotFoundError from ai.backend.manager.models.agent import AgentRow from ai.backend.manager.models.deployment_auto_scaling_policy import DeploymentAutoScalingPolicyRow from ai.backend.manager.models.deployment_policy import DeploymentPolicyRow @@ -533,3 +533,20 @@ async def test_get_current_time(self, auth_repository: AuthRepository) -> None: now_utc = datetime.now(UTC) time_diff = abs((now_utc - result).total_seconds()) assert time_diff < 1.0 + + async def test_get_user_id_by_access_key_success( + self, + auth_repository: AuthRepository, + sample_user_data: UserTestData, + ) -> None: + result = await auth_repository.get_user_id_by_access_key( + AccessKey(sample_user_data.access_key) + ) + + assert result == sample_user_data.uuid + + async def test_get_user_id_by_access_key_not_found( + self, auth_repository: AuthRepository + ) -> None: + with pytest.raises(AccessKeyNotFound): + await auth_repository.get_user_id_by_access_key(AccessKey("AKIANONEXISTENT"))