Skip to content
1 change: 1 addition & 0 deletions changes/10051.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Apply RBAC validator for Keypair actions to enforce permission checks on create, get, update, delete, and purge operations
36 changes: 33 additions & 3 deletions src/ai/backend/manager/services/auth/actions/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,43 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.data.permission.types import EntityType
from ai.backend.manager.actions.action import BaseAction
from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult
from ai.backend.manager.actions.action.single_entity import (
BaseSingleEntityAction,
BaseSingleEntityActionResult,
)
from ai.backend.manager.actions.action.types import FieldData


@dataclass
class AuthAction(BaseAction):
@classmethod
@override
@classmethod
def entity_type(cls) -> EntityType:
return EntityType.AUTH


class KeypairScopeAction(BaseScopeAction):
@override
@classmethod
def entity_type(cls) -> EntityType:
return EntityType.KEYPAIR


class KeypairScopeActionResult(BaseScopeActionResult):
pass


class KeypairSingleEntityAction(BaseSingleEntityAction):
@override
@classmethod
def entity_type(cls) -> EntityType:
return EntityType.KEYPAIR

@override
def field_data(self) -> FieldData | None:
return None


class KeypairSingleEntityActionResult(BaseSingleEntityActionResult):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,48 @@
from dataclasses import dataclass
from typing import override

from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.auth.types import SSHKeypair
from ai.backend.manager.services.auth.actions.base import AuthAction
from ai.backend.manager.data.permission.types import RBACElementRef
Comment on lines +5 to +8
from ai.backend.manager.services.auth.actions.base import (
KeypairScopeAction,
KeypairScopeActionResult,
)


@dataclass
class GenerateSSHKeypairAction(AuthAction):
class GenerateSSHKeypairAction(KeypairScopeAction):
user_id: uuid.UUID
access_key: str

@override
def entity_id(self) -> str | None:
return str(self.user_id)

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.CREATE

@override
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(RBACElementType.USER, str(self.user_id))
Comment on lines +34 to +35


@dataclass
class GenerateSSHKeypairActionResult(BaseActionResult):
class GenerateSSHKeypairActionResult(KeypairScopeActionResult):
ssh_keypair: SSHKeypair
user_id: uuid.UUID
Comment on lines +39 to +41

@override
def entity_id(self) -> str | None:
return None
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)
29 changes: 19 additions & 10 deletions src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,39 @@
from dataclasses import dataclass
from typing import override

from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.common.data.permission.types import RBACElementType
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.services.auth.actions.base import AuthAction
from ai.backend.manager.data.permission.types import RBACElementRef
Comment on lines +5 to +7
from ai.backend.manager.services.auth.actions.base import (
KeypairSingleEntityAction,
KeypairSingleEntityActionResult,
)


@dataclass
class GetSSHKeypairAction(AuthAction):
class GetSSHKeypairAction(KeypairSingleEntityAction):
user_id: uuid.UUID
access_key: str

@override
def entity_id(self) -> str | None:
return str(self.user_id)

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.GET

@override
def target_entity_id(self) -> str:
return self.access_key

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(RBACElementType.KEYPAIR, self.access_key)
Comment on lines +29 to +30


@dataclass
class GetSSHKeypairActionResult(BaseActionResult):
class GetSSHKeypairActionResult(KeypairSingleEntityActionResult):
public_key: str
access_key: str
Comment on lines +34 to +36

@override
def entity_id(self) -> str | None:
return None
def target_entity_id(self) -> str:
return self.access_key
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,50 @@
from dataclasses import dataclass
from typing import override

from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.auth.types import SSHKeypair
from ai.backend.manager.services.auth.actions.base import AuthAction
from ai.backend.manager.data.permission.types import RBACElementRef
Comment on lines +5 to +8
from ai.backend.manager.services.auth.actions.base import (
KeypairScopeAction,
KeypairScopeActionResult,
)


@dataclass
class UploadSSHKeypairAction(AuthAction):
class UploadSSHKeypairAction(KeypairScopeAction):
user_id: uuid.UUID
public_key: str
private_key: str
access_key: str

@override
def entity_id(self) -> str | None:
return str(self.user_id)

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.CREATE

@override
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(RBACElementType.USER, str(self.user_id))
Comment on lines +36 to +37


@dataclass
class UploadSSHKeypairActionResult(BaseActionResult):
class UploadSSHKeypairActionResult(KeypairScopeActionResult):
ssh_keypair: SSHKeypair
user_id: uuid.UUID
Comment on lines +41 to +43

@override
def entity_id(self) -> str | None:
return None
def scope_type(self) -> ScopeType:
return ScopeType.USER

@override
def scope_id(self) -> str:
return str(self.user_id)
22 changes: 16 additions & 6 deletions src/ai/backend/manager/services/auth/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from ai.backend.manager.actions.monitors.monitor import ActionMonitor
from ai.backend.manager.actions.processor import ActionProcessor
from ai.backend.manager.actions.processor.scope import ScopeActionProcessor
from ai.backend.manager.actions.processor.single_entity import SingleEntityActionProcessor
from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec
from ai.backend.manager.actions.validators import ActionValidators
from ai.backend.manager.services.auth.actions.authorize import (
Expand Down Expand Up @@ -49,9 +51,11 @@
class AuthProcessors(AbstractProcessorPackage):
signout: ActionProcessor[SignoutAction, SignoutActionResult]
update_full_name: ActionProcessor[UpdateFullNameAction, UpdateFullNameActionResult]
get_ssh_keypair: ActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult]
generate_ssh_keypair: ActionProcessor[GenerateSSHKeypairAction, GenerateSSHKeypairActionResult]
upload_ssh_keypair: ActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult]
get_ssh_keypair: SingleEntityActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult]
generate_ssh_keypair: ScopeActionProcessor[
GenerateSSHKeypairAction, GenerateSSHKeypairActionResult
]
upload_ssh_keypair: ScopeActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult]
get_role: ActionProcessor[GetRoleAction, GetRoleActionResult]
authorize: ActionProcessor[AuthorizeAction, AuthorizeActionResult]
signup: ActionProcessor[SignupAction, SignupActionResult]
Expand All @@ -72,9 +76,15 @@ def __init__(
) -> None:
self.signout = ActionProcessor(service.signout, action_monitors)
self.update_full_name = ActionProcessor(service.update_full_name, action_monitors)
self.get_ssh_keypair = ActionProcessor(service.get_ssh_keypair, action_monitors)
self.generate_ssh_keypair = ActionProcessor(service.generate_ssh_keypair, action_monitors)
self.upload_ssh_keypair = ActionProcessor(service.upload_ssh_keypair, action_monitors)
self.get_ssh_keypair = SingleEntityActionProcessor(
service.get_ssh_keypair, action_monitors, validators=[validators.rbac.single_entity]
)
self.generate_ssh_keypair = ScopeActionProcessor(
service.generate_ssh_keypair, action_monitors, validators=[validators.rbac.scope]
)
self.upload_ssh_keypair = ScopeActionProcessor(
service.upload_ssh_keypair, action_monitors, validators=[validators.rbac.scope]
)
self.get_role = ActionProcessor(service.get_role, action_monitors)
self.authorize = ActionProcessor(service.authorize, action_monitors)
self.signup = ActionProcessor(service.signup, action_monitors)
Expand Down
6 changes: 4 additions & 2 deletions src/ai/backend/manager/services/auth/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ async def update_password_no_auth(

async def get_ssh_keypair(self, action: GetSSHKeypairAction) -> GetSSHKeypairActionResult:
pubkey = await self._auth_repository.get_ssh_public_key(action.access_key)
return GetSSHKeypairActionResult(public_key=pubkey or "")
return GetSSHKeypairActionResult(public_key=pubkey or "", access_key=action.access_key)

async def generate_ssh_keypair(
self, action: GenerateSSHKeypairAction
Expand All @@ -419,7 +419,8 @@ async def generate_ssh_keypair(
ssh_keypair=SSHKeypair(
ssh_public_key=pubkey,
ssh_private_key=privkey,
)
),
user_id=action.user_id,
)

async def upload_ssh_keypair(
Expand All @@ -438,6 +439,7 @@ async def upload_ssh_keypair(
ssh_public_key=pubkey,
ssh_private_key=privkey,
),
user_id=action.user_id,
)

async def resolve_access_key_scope(
Expand Down
12 changes: 11 additions & 1 deletion tests/component/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@
from ai.backend.logging.config import ConsoleConfig, LogDriver, LoggingConfig
from ai.backend.logging.types import LogFormat
from ai.backend.manager.actions.validators import ActionValidators
from ai.backend.manager.actions.validators.rbac import RBACValidators
from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator
from ai.backend.manager.actions.validators.rbac.single_entity import SingleEntityActionRBACValidator
from ai.backend.manager.agent_cache import AgentRPCCache
from ai.backend.manager.api import ManagerStatus
from ai.backend.manager.api.rest.app import build_root_app, mount_registries
Expand Down Expand Up @@ -1187,7 +1190,14 @@ def auth_processors(
config_provider=config_provider,
)
return AuthProcessors(
service=service, action_monitors=[], validators=MagicMock(spec=ActionValidators)
service=service,
action_monitors=[],
validators=ActionValidators(
rbac=RBACValidators(
scope=MagicMock(spec=ScopeActionRBACValidator),
single_entity=MagicMock(spec=SingleEntityActionRBACValidator),
),
),
)


Expand Down
14 changes: 11 additions & 3 deletions tests/unit/manager/api/auth/test_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,12 @@ async def test_calls_processor_and_returns_public_key(
) -> None:
"""Verify processor is called and public key is returned."""
public_key = "ssh-rsa AAAAB3...\n"
access_key = "AKIAIOSFODNN7EXAMPLE"
mock_processors.auth.get_ssh_keypair.wait_for_complete = AsyncMock(
return_value=GetSSHKeypairActionResult(public_key=public_key)
return_value=GetSSHKeypairActionResult(
public_key=public_key,
access_key=access_key,
)
)

response = await handler.get_ssh_keypair(user_context)
Expand All @@ -564,12 +568,14 @@ async def test_calls_processor_and_returns_keypair(
"""Verify processor is called and keypair is returned."""
ssh_public_key = "ssh-rsa NEWPUB...\n"
ssh_private_key = "-----BEGIN RSA PRIVATE KEY-----\n...\n"
user_id = user_context.user_uuid
mock_processors.auth.generate_ssh_keypair.wait_for_complete = AsyncMock(
return_value=GenerateSSHKeypairActionResult(
ssh_keypair=SSHKeypair(
ssh_public_key=ssh_public_key,
ssh_private_key=ssh_private_key,
)
),
user_id=user_id,
)
)

Expand Down Expand Up @@ -599,12 +605,14 @@ async def test_calls_processor_and_returns_keypair(
"pubkey": "ssh-rsa AAAAB3...",
"privkey": "-----BEGIN RSA PRIVATE KEY-----\n...",
})
user_id = user_context.user_uuid
mock_processors.auth.upload_ssh_keypair.wait_for_complete = AsyncMock(
return_value=UploadSSHKeypairActionResult(
ssh_keypair=SSHKeypair(
ssh_public_key="ssh-rsa AAAAB3...\n",
ssh_private_key="-----BEGIN RSA PRIVATE KEY-----\n...\n",
)
),
user_id=user_id,
)
)

Expand Down
Loading