diff --git a/changes/10051.feature.md b/changes/10051.feature.md new file mode 100644 index 00000000000..9c423590006 --- /dev/null +++ b/changes/10051.feature.md @@ -0,0 +1 @@ +Apply RBAC validator for Keypair actions to enforce permission checks on create, get, update, delete, and purge operations diff --git a/src/ai/backend/manager/services/auth/actions/base.py b/src/ai/backend/manager/services/auth/actions/base.py index 4fb65575b42..8711d0976f3 100644 --- a/src/ai/backend/manager/services/auth/actions/base.py +++ b/src/ai/backend/manager/services/auth/actions/base.py @@ -1,13 +1,43 @@ -from dataclasses import dataclass from typing import override from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction +from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult +from ai.backend.manager.actions.action.single_entity import ( + BaseSingleEntityAction, + BaseSingleEntityActionResult, +) +from ai.backend.manager.actions.action.types import FieldData -@dataclass class AuthAction(BaseAction): - @classmethod @override + @classmethod def entity_type(cls) -> EntityType: return EntityType.AUTH + + +class KeypairScopeAction(BaseScopeAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.KEYPAIR + + +class KeypairScopeActionResult(BaseScopeActionResult): + pass + + +class KeypairSingleEntityAction(BaseSingleEntityAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.KEYPAIR + + @override + def field_data(self) -> FieldData | None: + return None + + +class KeypairSingleEntityActionResult(BaseSingleEntityActionResult): + pass diff --git a/src/ai/backend/manager/services/auth/actions/generate_ssh_keypair.py b/src/ai/backend/manager/services/auth/actions/generate_ssh_keypair.py index c165f83ebc9..76c1728c2a3 100644 --- a/src/ai/backend/manager/services/auth/actions/generate_ssh_keypair.py +++ b/src/ai/backend/manager/services/auth/actions/generate_ssh_keypair.py @@ -2,31 +2,48 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.auth.types import SSHKeypair -from ai.backend.manager.services.auth.actions.base import AuthAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.auth.actions.base import ( + KeypairScopeAction, + KeypairScopeActionResult, +) @dataclass -class GenerateSSHKeypairAction(AuthAction): +class GenerateSSHKeypairAction(KeypairScopeAction): user_id: uuid.UUID access_key: str - @override - def entity_id(self) -> str | None: - return str(self.user_id) - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.USER, str(self.user_id)) + @dataclass -class GenerateSSHKeypairActionResult(BaseActionResult): +class GenerateSSHKeypairActionResult(KeypairScopeActionResult): ssh_keypair: SSHKeypair + user_id: uuid.UUID @override - def entity_id(self) -> str | None: - return None + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) diff --git a/src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py b/src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py index 1b6dcea0dc7..c97293bfc08 100644 --- a/src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py +++ b/src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py @@ -2,30 +2,39 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.auth.actions.base import AuthAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.auth.actions.base import ( + KeypairSingleEntityAction, + KeypairSingleEntityActionResult, +) @dataclass -class GetSSHKeypairAction(AuthAction): +class GetSSHKeypairAction(KeypairSingleEntityAction): user_id: uuid.UUID access_key: str - @override - def entity_id(self) -> str | None: - return str(self.user_id) - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.GET + @override + def target_entity_id(self) -> str: + return self.access_key + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.KEYPAIR, self.access_key) + @dataclass -class GetSSHKeypairActionResult(BaseActionResult): +class GetSSHKeypairActionResult(KeypairSingleEntityActionResult): public_key: str + access_key: str @override - def entity_id(self) -> str | None: - return None + def target_entity_id(self) -> str: + return self.access_key diff --git a/src/ai/backend/manager/services/auth/actions/upload_ssh_keypair.py b/src/ai/backend/manager/services/auth/actions/upload_ssh_keypair.py index eeeefe5a035..c32c703c98e 100644 --- a/src/ai/backend/manager/services/auth/actions/upload_ssh_keypair.py +++ b/src/ai/backend/manager/services/auth/actions/upload_ssh_keypair.py @@ -2,33 +2,50 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.auth.types import SSHKeypair -from ai.backend.manager.services.auth.actions.base import AuthAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.auth.actions.base import ( + KeypairScopeAction, + KeypairScopeActionResult, +) @dataclass -class UploadSSHKeypairAction(AuthAction): +class UploadSSHKeypairAction(KeypairScopeAction): user_id: uuid.UUID public_key: str private_key: str access_key: str - @override - def entity_id(self) -> str | None: - return str(self.user_id) - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.USER, str(self.user_id)) + @dataclass -class UploadSSHKeypairActionResult(BaseActionResult): +class UploadSSHKeypairActionResult(KeypairScopeActionResult): ssh_keypair: SSHKeypair + user_id: uuid.UUID @override - def entity_id(self) -> str | None: - return None + def scope_type(self) -> ScopeType: + return ScopeType.USER + + @override + def scope_id(self) -> str: + return str(self.user_id) diff --git a/src/ai/backend/manager/services/auth/processors.py b/src/ai/backend/manager/services/auth/processors.py index b2425e649df..785b95c3720 100644 --- a/src/ai/backend/manager/services/auth/processors.py +++ b/src/ai/backend/manager/services/auth/processors.py @@ -2,6 +2,8 @@ from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor +from ai.backend.manager.actions.processor.scope import ScopeActionProcessor +from ai.backend.manager.actions.processor.single_entity import SingleEntityActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec from ai.backend.manager.actions.validators import ActionValidators from ai.backend.manager.services.auth.actions.authorize import ( @@ -49,9 +51,11 @@ class AuthProcessors(AbstractProcessorPackage): signout: ActionProcessor[SignoutAction, SignoutActionResult] update_full_name: ActionProcessor[UpdateFullNameAction, UpdateFullNameActionResult] - get_ssh_keypair: ActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult] - generate_ssh_keypair: ActionProcessor[GenerateSSHKeypairAction, GenerateSSHKeypairActionResult] - upload_ssh_keypair: ActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult] + get_ssh_keypair: SingleEntityActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult] + generate_ssh_keypair: ScopeActionProcessor[ + GenerateSSHKeypairAction, GenerateSSHKeypairActionResult + ] + upload_ssh_keypair: ScopeActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult] get_role: ActionProcessor[GetRoleAction, GetRoleActionResult] authorize: ActionProcessor[AuthorizeAction, AuthorizeActionResult] signup: ActionProcessor[SignupAction, SignupActionResult] @@ -72,9 +76,15 @@ def __init__( ) -> None: self.signout = ActionProcessor(service.signout, action_monitors) self.update_full_name = ActionProcessor(service.update_full_name, action_monitors) - self.get_ssh_keypair = ActionProcessor(service.get_ssh_keypair, action_monitors) - self.generate_ssh_keypair = ActionProcessor(service.generate_ssh_keypair, action_monitors) - self.upload_ssh_keypair = ActionProcessor(service.upload_ssh_keypair, action_monitors) + self.get_ssh_keypair = SingleEntityActionProcessor( + service.get_ssh_keypair, action_monitors, validators=[validators.rbac.single_entity] + ) + self.generate_ssh_keypair = ScopeActionProcessor( + service.generate_ssh_keypair, action_monitors, validators=[validators.rbac.scope] + ) + self.upload_ssh_keypair = ScopeActionProcessor( + service.upload_ssh_keypair, action_monitors, validators=[validators.rbac.scope] + ) self.get_role = ActionProcessor(service.get_role, action_monitors) self.authorize = ActionProcessor(service.authorize, action_monitors) self.signup = ActionProcessor(service.signup, action_monitors) diff --git a/src/ai/backend/manager/services/auth/service.py b/src/ai/backend/manager/services/auth/service.py index 53625ed5296..c675ef4ab90 100644 --- a/src/ai/backend/manager/services/auth/service.py +++ b/src/ai/backend/manager/services/auth/service.py @@ -407,7 +407,7 @@ async def update_password_no_auth( async def get_ssh_keypair(self, action: GetSSHKeypairAction) -> GetSSHKeypairActionResult: pubkey = await self._auth_repository.get_ssh_public_key(action.access_key) - return GetSSHKeypairActionResult(public_key=pubkey or "") + return GetSSHKeypairActionResult(public_key=pubkey or "", access_key=action.access_key) async def generate_ssh_keypair( self, action: GenerateSSHKeypairAction @@ -419,7 +419,8 @@ async def generate_ssh_keypair( ssh_keypair=SSHKeypair( ssh_public_key=pubkey, ssh_private_key=privkey, - ) + ), + user_id=action.user_id, ) async def upload_ssh_keypair( @@ -438,6 +439,7 @@ async def upload_ssh_keypair( ssh_public_key=pubkey, ssh_private_key=privkey, ), + user_id=action.user_id, ) async def resolve_access_key_scope( diff --git a/tests/component/conftest.py b/tests/component/conftest.py index 20f9a8ad01e..90e552c4f74 100644 --- a/tests/component/conftest.py +++ b/tests/component/conftest.py @@ -74,6 +74,9 @@ from ai.backend.logging.config import ConsoleConfig, LogDriver, LoggingConfig from ai.backend.logging.types import LogFormat from ai.backend.manager.actions.validators import ActionValidators +from ai.backend.manager.actions.validators.rbac import RBACValidators +from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator +from ai.backend.manager.actions.validators.rbac.single_entity import SingleEntityActionRBACValidator from ai.backend.manager.agent_cache import AgentRPCCache from ai.backend.manager.api import ManagerStatus from ai.backend.manager.api.rest.app import build_root_app, mount_registries @@ -1187,7 +1190,14 @@ def auth_processors( config_provider=config_provider, ) return AuthProcessors( - service=service, action_monitors=[], validators=MagicMock(spec=ActionValidators) + service=service, + action_monitors=[], + validators=ActionValidators( + rbac=RBACValidators( + scope=MagicMock(spec=ScopeActionRBACValidator), + single_entity=MagicMock(spec=SingleEntityActionRBACValidator), + ), + ), ) diff --git a/tests/unit/manager/api/auth/test_handlers.py b/tests/unit/manager/api/auth/test_handlers.py index c1cbb00dc3e..2d0e8ebd337 100644 --- a/tests/unit/manager/api/auth/test_handlers.py +++ b/tests/unit/manager/api/auth/test_handlers.py @@ -538,8 +538,12 @@ async def test_calls_processor_and_returns_public_key( ) -> None: """Verify processor is called and public key is returned.""" public_key = "ssh-rsa AAAAB3...\n" + access_key = "AKIAIOSFODNN7EXAMPLE" mock_processors.auth.get_ssh_keypair.wait_for_complete = AsyncMock( - return_value=GetSSHKeypairActionResult(public_key=public_key) + return_value=GetSSHKeypairActionResult( + public_key=public_key, + access_key=access_key, + ) ) response = await handler.get_ssh_keypair(user_context) @@ -564,12 +568,14 @@ async def test_calls_processor_and_returns_keypair( """Verify processor is called and keypair is returned.""" ssh_public_key = "ssh-rsa NEWPUB...\n" ssh_private_key = "-----BEGIN RSA PRIVATE KEY-----\n...\n" + user_id = user_context.user_uuid mock_processors.auth.generate_ssh_keypair.wait_for_complete = AsyncMock( return_value=GenerateSSHKeypairActionResult( ssh_keypair=SSHKeypair( ssh_public_key=ssh_public_key, ssh_private_key=ssh_private_key, - ) + ), + user_id=user_id, ) ) @@ -599,12 +605,14 @@ async def test_calls_processor_and_returns_keypair( "pubkey": "ssh-rsa AAAAB3...", "privkey": "-----BEGIN RSA PRIVATE KEY-----\n...", }) + user_id = user_context.user_uuid mock_processors.auth.upload_ssh_keypair.wait_for_complete = AsyncMock( return_value=UploadSSHKeypairActionResult( ssh_keypair=SSHKeypair( ssh_public_key="ssh-rsa AAAAB3...\n", ssh_private_key="-----BEGIN RSA PRIVATE KEY-----\n...\n", - ) + ), + user_id=user_id, ) )