Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11647.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add `UserID` identifier type and `ResolveUserIDByAccessKey` auth action for resolving an access_key to its owning user UUID.
7 changes: 7 additions & 0 deletions src/ai/backend/common/identifier/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import NewType
from uuid import UUID

__all__ = ("UserID",)


UserID = NewType("UserID", UUID)
13 changes: 13 additions & 0 deletions src/ai/backend/manager/repositories/auth/db_source/db_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,18 @@
from sqlalchemy.orm import joinedload, selectinload

from ai.backend.common.exception import BackendAIError, UserNotFound
from ai.backend.common.identifier.user import UserID
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
from ai.backend.common.resilience.resilience import Resilience
from ai.backend.common.types import AccessKey
from ai.backend.manager.data.auth.login_session_types import LoginHistoryData, LoginSessionData
from ai.backend.manager.data.auth.types import GroupMembershipData, UserData
from ai.backend.manager.data.common.types import SearchResult
from ai.backend.manager.data.permission.types import EntityType, ScopeType
from ai.backend.manager.errors.auth import (
AccessKeyNotFound,
AuthorizationFailed,
GroupMembershipNotFoundError,
LoginSessionNotFoundError,
Expand Down Expand Up @@ -292,6 +295,16 @@ async def fetch_user_info_by_access_key(self, access_key: str) -> tuple[str, Use
raise ValueError("Unknown owner access key")
return row.domain_name, row.role

@auth_db_source_resilience.apply()
async def fetch_user_id_by_access_key(self, access_key: AccessKey) -> UserID:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a question, is there a reason this file uses the Core API and a direct connection instead of the ORM and db_session?

async with self._db.begin_readonly() as conn:
query = sa.select(keypairs.c.user).where(keypairs.c.access_key == access_key)
result = await conn.execute(query)
row = result.scalar()
if row is None:
raise AccessKeyNotFound("Unknown access key")
return UserID(cast(UUID, row))

@auth_db_source_resilience.apply()
async def fetch_user_info_by_email(self, email: str) -> tuple[UUID, UserRole, str]:
"""Fetch (uuid, role, domain_name) for a user identified by *email*.
Expand Down
6 changes: 6 additions & 0 deletions src/ai/backend/manager/repositories/auth/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

import sqlalchemy as sa

from ai.backend.common.identifier.user import UserID
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.resilience import Resilience
from ai.backend.common.types import AccessKey
from ai.backend.manager.data.auth.login_session_types import LoginHistoryData, LoginSessionData
from ai.backend.manager.data.auth.types import GroupMembershipData, UserData
from ai.backend.manager.data.common.types import SearchResult
Expand Down Expand Up @@ -82,6 +84,10 @@ async def update_ssh_keypair(self, access_key: str, public_key: str, private_key
async def get_delegation_target_by_access_key(self, access_key: str) -> tuple[str, UserRole]:
return await self._db_source.fetch_user_info_by_access_key(access_key)

@auth_repository_resilience.apply()
async def get_user_id_by_access_key(self, access_key: AccessKey) -> UserID:
return await self._db_source.fetch_user_id_by_access_key(access_key)

@auth_repository_resilience.apply()
async def get_delegation_target_by_email(self, email: str) -> tuple[UUID, UserRole, str]:
return await self._db_source.fetch_user_info_by_email(email)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.identifier.user import UserID
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.services.auth.actions.base import AuthAction


@dataclass
class ResolveUserIDByAccessKeyAction(AuthAction):
access_key: AccessKey

@override
def entity_id(self) -> str | None:
return str(self.access_key)
Comment thread
fregataa marked this conversation as resolved.

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.GET


@dataclass
class ResolveUserIDByAccessKeyResult(BaseActionResult):
user_id: UserID

@override
def entity_id(self) -> str | None:
return str(self.user_id)
11 changes: 11 additions & 0 deletions src/ai/backend/manager/services/auth/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
ResolveAccessKeyScopeAction,
ResolveAccessKeyScopeResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import (
ResolveUserIDByAccessKeyAction,
ResolveUserIDByAccessKeyResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_scope import (
ResolveUserScopeAction,
ResolveUserScopeResult,
Expand Down Expand Up @@ -88,6 +92,9 @@ class AuthProcessors(AbstractProcessorPackage):
ResolveAccessKeyScopeAction, ResolveAccessKeyScopeResult
]
resolve_user_scope: ActionProcessor[ResolveUserScopeAction, ResolveUserScopeResult]
resolve_user_id_by_access_key: ActionProcessor[
ResolveUserIDByAccessKeyAction, ResolveUserIDByAccessKeyResult
]
admin_search_login_sessions: ActionProcessor[
AdminSearchLoginSessionsAction, SearchLoginSessionsActionResult
]
Expand Down Expand Up @@ -135,6 +142,9 @@ def __init__(
service.resolve_access_key_scope, action_monitors
)
self.resolve_user_scope = ActionProcessor(service.resolve_user_scope, action_monitors)
self.resolve_user_id_by_access_key = ActionProcessor(
service.resolve_user_id_by_access_key, action_monitors
)
self.admin_search_login_sessions = ActionProcessor(
service.admin_search_login_sessions, action_monitors
)
Expand Down Expand Up @@ -167,6 +177,7 @@ def supported_actions(self) -> list[ActionSpec]:
UpdatePasswordNoAuthAction.spec(),
ResolveAccessKeyScopeAction.spec(),
ResolveUserScopeAction.spec(),
ResolveUserIDByAccessKeyAction.spec(),
AdminSearchLoginSessionsAction.spec(),
SearchLoginSessionsAction.spec(),
AdminSearchLoginHistoryAction.spec(),
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/manager/services/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
ResolveAccessKeyScopeAction,
ResolveAccessKeyScopeResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import (
ResolveUserIDByAccessKeyAction,
ResolveUserIDByAccessKeyResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_scope import (
ResolveUserScopeAction,
ResolveUserScopeResult,
Expand Down Expand Up @@ -758,6 +762,12 @@ async def resolve_access_key_scope(
owner_access_key=owner_ak,
)

async def resolve_user_id_by_access_key(
self, action: ResolveUserIDByAccessKeyAction
) -> ResolveUserIDByAccessKeyResult:
user_id = await self._auth_repository.get_user_id_by_access_key(action.access_key)
return ResolveUserIDByAccessKeyResult(user_id=user_id)

async def resolve_user_scope(self, action: ResolveUserScopeAction) -> ResolveUserScopeResult:
if action.owner_user_email is None:
return ResolveUserScopeResult(
Expand Down
21 changes: 19 additions & 2 deletions tests/unit/manager/repositories/auth/test_auth_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@

from ai.backend.common.data.permission.types import RelationType
from ai.backend.common.exception import UserNotFound
from ai.backend.common.types import ResourceSlot, VFolderHostPermissionMap
from ai.backend.common.types import AccessKey, ResourceSlot, VFolderHostPermissionMap
from ai.backend.manager.data.auth.hash import PasswordHashAlgorithm
from ai.backend.manager.data.auth.types import UserData
from ai.backend.manager.data.group.types import GroupData
from ai.backend.manager.data.permission.types import EntityType, ScopeType
from ai.backend.manager.errors.auth import GroupMembershipNotFoundError
from ai.backend.manager.errors.auth import AccessKeyNotFound, GroupMembershipNotFoundError
from ai.backend.manager.models.agent import AgentRow
from ai.backend.manager.models.deployment_auto_scaling_policy import DeploymentAutoScalingPolicyRow
from ai.backend.manager.models.deployment_policy import DeploymentPolicyRow
Expand Down Expand Up @@ -533,3 +533,20 @@ async def test_get_current_time(self, auth_repository: AuthRepository) -> None:
now_utc = datetime.now(UTC)
time_diff = abs((now_utc - result).total_seconds())
assert time_diff < 1.0

async def test_get_user_id_by_access_key_success(
self,
auth_repository: AuthRepository,
sample_user_data: UserTestData,
) -> None:
result = await auth_repository.get_user_id_by_access_key(
AccessKey(sample_user_data.access_key)
)

assert result == sample_user_data.uuid

async def test_get_user_id_by_access_key_not_found(
self, auth_repository: AuthRepository
) -> None:
with pytest.raises(AccessKeyNotFound):
await auth_repository.get_user_id_by_access_key(AccessKey("AKIANONEXISTENT"))
Loading