Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/ai/backend/common/identifier/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from typing import NewType
from uuid import UUID

__all__ = ("UserID",)


UserID = NewType("UserID", UUID)
12 changes: 12 additions & 0 deletions src/ai/backend/manager/repositories/auth/db_source/db_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from sqlalchemy.orm import joinedload, selectinload

from ai.backend.common.exception import BackendAIError, UserNotFound
from ai.backend.common.identifier.user import UserID
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.policies.retry import BackoffStrategy, RetryArgs, RetryPolicy
Expand All @@ -21,6 +22,7 @@
from ai.backend.manager.data.common.types import SearchResult
from ai.backend.manager.data.permission.types import EntityType, ScopeType
from ai.backend.manager.errors.auth import (
AccessKeyNotFound,
AuthorizationFailed,
GroupMembershipNotFoundError,
LoginSessionNotFoundError,
Expand Down Expand Up @@ -292,6 +294,16 @@ async def fetch_user_info_by_access_key(self, access_key: str) -> tuple[str, Use
raise ValueError("Unknown owner access key")
return row.domain_name, row.role

@auth_db_source_resilience.apply()
async def fetch_user_id_by_access_key(self, access_key: str) -> UserID:
Comment thread
fregataa marked this conversation as resolved.
Outdated
Comment thread
jopemachine marked this conversation as resolved.
Outdated
async with self._db.begin_readonly() as conn:
query = sa.select(keypairs.c.user).where(keypairs.c.access_key == access_key)
result = await conn.execute(query)
row = result.scalar()
if row is None:
raise AccessKeyNotFound("Unknown access key")
return UserID(UUID(str(row)))
Comment thread
fregataa marked this conversation as resolved.
Outdated
Comment thread
fregataa marked this conversation as resolved.
Outdated

@auth_db_source_resilience.apply()
async def fetch_user_info_by_email(self, email: str) -> tuple[UUID, UserRole, str]:
"""Fetch (uuid, role, domain_name) for a user identified by *email*.
Expand Down
5 changes: 5 additions & 0 deletions src/ai/backend/manager/repositories/auth/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import sqlalchemy as sa

from ai.backend.common.identifier.user import UserID
from ai.backend.common.metrics.metric import DomainType, LayerType
from ai.backend.common.resilience.policies.metrics import MetricArgs, MetricPolicy
from ai.backend.common.resilience.resilience import Resilience
Expand Down Expand Up @@ -82,6 +83,10 @@ async def update_ssh_keypair(self, access_key: str, public_key: str, private_key
async def get_delegation_target_by_access_key(self, access_key: str) -> tuple[str, UserRole]:
return await self._db_source.fetch_user_info_by_access_key(access_key)

@auth_repository_resilience.apply()
async def get_user_id_by_access_key(self, access_key: str) -> UserID:
return await self._db_source.fetch_user_id_by_access_key(access_key)

@auth_repository_resilience.apply()
async def get_delegation_target_by_email(self, email: str) -> tuple[UUID, UserRole, str]:
return await self._db_source.fetch_user_info_by_email(email)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.identifier.user import UserID
from ai.backend.common.types import AccessKey
from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.services.auth.actions.base import AuthAction


@dataclass
class ResolveUserIDByAccessKeyAction(AuthAction):
access_key: AccessKey

@override
def entity_id(self) -> str | None:
return str(self.access_key)
Comment thread
fregataa marked this conversation as resolved.

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.GET


@dataclass
class ResolveUserIDByAccessKeyResult(BaseActionResult):
user_id: UserID

@override
def entity_id(self) -> str | None:
return str(self.user_id)
11 changes: 11 additions & 0 deletions src/ai/backend/manager/services/auth/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@
ResolveAccessKeyScopeAction,
ResolveAccessKeyScopeResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import (
ResolveUserIDByAccessKeyAction,
ResolveUserIDByAccessKeyResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_scope import (
ResolveUserScopeAction,
ResolveUserScopeResult,
Expand Down Expand Up @@ -88,6 +92,9 @@ class AuthProcessors(AbstractProcessorPackage):
ResolveAccessKeyScopeAction, ResolveAccessKeyScopeResult
]
resolve_user_scope: ActionProcessor[ResolveUserScopeAction, ResolveUserScopeResult]
resolve_user_id_by_access_key: ActionProcessor[
ResolveUserIDByAccessKeyAction, ResolveUserIDByAccessKeyResult
]
admin_search_login_sessions: ActionProcessor[
AdminSearchLoginSessionsAction, SearchLoginSessionsActionResult
]
Expand Down Expand Up @@ -135,6 +142,9 @@ def __init__(
service.resolve_access_key_scope, action_monitors
)
self.resolve_user_scope = ActionProcessor(service.resolve_user_scope, action_monitors)
self.resolve_user_id_by_access_key = ActionProcessor(
service.resolve_user_id_by_access_key, action_monitors
)
self.admin_search_login_sessions = ActionProcessor(
service.admin_search_login_sessions, action_monitors
)
Expand Down Expand Up @@ -167,6 +177,7 @@ def supported_actions(self) -> list[ActionSpec]:
UpdatePasswordNoAuthAction.spec(),
ResolveAccessKeyScopeAction.spec(),
ResolveUserScopeAction.spec(),
ResolveUserIDByAccessKeyAction.spec(),
AdminSearchLoginSessionsAction.spec(),
SearchLoginSessionsAction.spec(),
AdminSearchLoginHistoryAction.spec(),
Expand Down
10 changes: 10 additions & 0 deletions src/ai/backend/manager/services/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
ResolveAccessKeyScopeAction,
ResolveAccessKeyScopeResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_id_by_access_key import (
ResolveUserIDByAccessKeyAction,
ResolveUserIDByAccessKeyResult,
)
from ai.backend.manager.services.auth.actions.resolve_user_scope import (
ResolveUserScopeAction,
ResolveUserScopeResult,
Expand Down Expand Up @@ -758,6 +762,12 @@ async def resolve_access_key_scope(
owner_access_key=owner_ak,
)

async def resolve_user_id_by_access_key(
self, action: ResolveUserIDByAccessKeyAction
) -> ResolveUserIDByAccessKeyResult:
user_id = await self._auth_repository.get_user_id_by_access_key(str(action.access_key))
return ResolveUserIDByAccessKeyResult(user_id=user_id)

async def resolve_user_scope(self, action: ResolveUserScopeAction) -> ResolveUserScopeResult:
if action.owner_user_email is None:
return ResolveUserScopeResult(
Expand Down
Loading