Skip to content

Commit c952075

Browse files
fregataaclaude
andauthored
feat(BA-5035): Apply RBAC validator for Keypair actions (#10051)
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent bf6696b commit c952075

9 files changed

Lines changed: 149 additions & 45 deletions

File tree

changes/10051.feature.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Apply RBAC validator for Keypair actions to enforce permission checks on create, get, update, delete, and purge operations
Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,43 @@
1-
from dataclasses import dataclass
21
from typing import override
32

43
from ai.backend.common.data.permission.types import EntityType
54
from ai.backend.manager.actions.action import BaseAction
5+
from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult
6+
from ai.backend.manager.actions.action.single_entity import (
7+
BaseSingleEntityAction,
8+
BaseSingleEntityActionResult,
9+
)
10+
from ai.backend.manager.actions.action.types import FieldData
611

712

8-
@dataclass
913
class AuthAction(BaseAction):
10-
@classmethod
1114
@override
15+
@classmethod
1216
def entity_type(cls) -> EntityType:
1317
return EntityType.AUTH
18+
19+
20+
class KeypairScopeAction(BaseScopeAction):
21+
@override
22+
@classmethod
23+
def entity_type(cls) -> EntityType:
24+
return EntityType.KEYPAIR
25+
26+
27+
class KeypairScopeActionResult(BaseScopeActionResult):
28+
pass
29+
30+
31+
class KeypairSingleEntityAction(BaseSingleEntityAction):
32+
@override
33+
@classmethod
34+
def entity_type(cls) -> EntityType:
35+
return EntityType.KEYPAIR
36+
37+
@override
38+
def field_data(self) -> FieldData | None:
39+
return None
40+
41+
42+
class KeypairSingleEntityActionResult(BaseSingleEntityActionResult):
43+
pass

src/ai/backend/manager/services/auth/actions/generate_ssh_keypair.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,48 @@
22
from dataclasses import dataclass
33
from typing import override
44

5-
from ai.backend.manager.actions.action import BaseActionResult
5+
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
66
from ai.backend.manager.actions.types import ActionOperationType
77
from ai.backend.manager.data.auth.types import SSHKeypair
8-
from ai.backend.manager.services.auth.actions.base import AuthAction
8+
from ai.backend.manager.data.permission.types import RBACElementRef
9+
from ai.backend.manager.services.auth.actions.base import (
10+
KeypairScopeAction,
11+
KeypairScopeActionResult,
12+
)
913

1014

1115
@dataclass
12-
class GenerateSSHKeypairAction(AuthAction):
16+
class GenerateSSHKeypairAction(KeypairScopeAction):
1317
user_id: uuid.UUID
1418
access_key: str
1519

16-
@override
17-
def entity_id(self) -> str | None:
18-
return str(self.user_id)
19-
2020
@override
2121
@classmethod
2222
def operation_type(cls) -> ActionOperationType:
2323
return ActionOperationType.CREATE
2424

25+
@override
26+
def scope_type(self) -> ScopeType:
27+
return ScopeType.USER
28+
29+
@override
30+
def scope_id(self) -> str:
31+
return str(self.user_id)
32+
33+
@override
34+
def target_element(self) -> RBACElementRef:
35+
return RBACElementRef(RBACElementType.USER, str(self.user_id))
36+
2537

2638
@dataclass
27-
class GenerateSSHKeypairActionResult(BaseActionResult):
39+
class GenerateSSHKeypairActionResult(KeypairScopeActionResult):
2840
ssh_keypair: SSHKeypair
41+
user_id: uuid.UUID
2942

3043
@override
31-
def entity_id(self) -> str | None:
32-
return None
44+
def scope_type(self) -> ScopeType:
45+
return ScopeType.USER
46+
47+
@override
48+
def scope_id(self) -> str:
49+
return str(self.user_id)

src/ai/backend/manager/services/auth/actions/get_ssh_keypair.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,39 @@
22
from dataclasses import dataclass
33
from typing import override
44

5-
from ai.backend.manager.actions.action import BaseActionResult
5+
from ai.backend.common.data.permission.types import RBACElementType
66
from ai.backend.manager.actions.types import ActionOperationType
7-
from ai.backend.manager.services.auth.actions.base import AuthAction
7+
from ai.backend.manager.data.permission.types import RBACElementRef
8+
from ai.backend.manager.services.auth.actions.base import (
9+
KeypairSingleEntityAction,
10+
KeypairSingleEntityActionResult,
11+
)
812

913

1014
@dataclass
11-
class GetSSHKeypairAction(AuthAction):
15+
class GetSSHKeypairAction(KeypairSingleEntityAction):
1216
user_id: uuid.UUID
1317
access_key: str
1418

15-
@override
16-
def entity_id(self) -> str | None:
17-
return str(self.user_id)
18-
1919
@override
2020
@classmethod
2121
def operation_type(cls) -> ActionOperationType:
2222
return ActionOperationType.GET
2323

24+
@override
25+
def target_entity_id(self) -> str:
26+
return self.access_key
27+
28+
@override
29+
def target_element(self) -> RBACElementRef:
30+
return RBACElementRef(RBACElementType.KEYPAIR, self.access_key)
31+
2432

2533
@dataclass
26-
class GetSSHKeypairActionResult(BaseActionResult):
34+
class GetSSHKeypairActionResult(KeypairSingleEntityActionResult):
2735
public_key: str
36+
access_key: str
2837

2938
@override
30-
def entity_id(self) -> str | None:
31-
return None
39+
def target_entity_id(self) -> str:
40+
return self.access_key

src/ai/backend/manager/services/auth/actions/upload_ssh_keypair.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,50 @@
22
from dataclasses import dataclass
33
from typing import override
44

5-
from ai.backend.manager.actions.action import BaseActionResult
5+
from ai.backend.common.data.permission.types import RBACElementType, ScopeType
66
from ai.backend.manager.actions.types import ActionOperationType
77
from ai.backend.manager.data.auth.types import SSHKeypair
8-
from ai.backend.manager.services.auth.actions.base import AuthAction
8+
from ai.backend.manager.data.permission.types import RBACElementRef
9+
from ai.backend.manager.services.auth.actions.base import (
10+
KeypairScopeAction,
11+
KeypairScopeActionResult,
12+
)
913

1014

1115
@dataclass
12-
class UploadSSHKeypairAction(AuthAction):
16+
class UploadSSHKeypairAction(KeypairScopeAction):
1317
user_id: uuid.UUID
1418
public_key: str
1519
private_key: str
1620
access_key: str
1721

18-
@override
19-
def entity_id(self) -> str | None:
20-
return str(self.user_id)
21-
2222
@override
2323
@classmethod
2424
def operation_type(cls) -> ActionOperationType:
2525
return ActionOperationType.CREATE
2626

27+
@override
28+
def scope_type(self) -> ScopeType:
29+
return ScopeType.USER
30+
31+
@override
32+
def scope_id(self) -> str:
33+
return str(self.user_id)
34+
35+
@override
36+
def target_element(self) -> RBACElementRef:
37+
return RBACElementRef(RBACElementType.USER, str(self.user_id))
38+
2739

2840
@dataclass
29-
class UploadSSHKeypairActionResult(BaseActionResult):
41+
class UploadSSHKeypairActionResult(KeypairScopeActionResult):
3042
ssh_keypair: SSHKeypair
43+
user_id: uuid.UUID
3144

3245
@override
33-
def entity_id(self) -> str | None:
34-
return None
46+
def scope_type(self) -> ScopeType:
47+
return ScopeType.USER
48+
49+
@override
50+
def scope_id(self) -> str:
51+
return str(self.user_id)

src/ai/backend/manager/services/auth/processors.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from ai.backend.manager.actions.monitors.monitor import ActionMonitor
44
from ai.backend.manager.actions.processor import ActionProcessor
5+
from ai.backend.manager.actions.processor.scope import ScopeActionProcessor
6+
from ai.backend.manager.actions.processor.single_entity import SingleEntityActionProcessor
57
from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec
68
from ai.backend.manager.actions.validators import ActionValidators
79
from ai.backend.manager.services.auth.actions.authorize import (
@@ -49,9 +51,11 @@
4951
class AuthProcessors(AbstractProcessorPackage):
5052
signout: ActionProcessor[SignoutAction, SignoutActionResult]
5153
update_full_name: ActionProcessor[UpdateFullNameAction, UpdateFullNameActionResult]
52-
get_ssh_keypair: ActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult]
53-
generate_ssh_keypair: ActionProcessor[GenerateSSHKeypairAction, GenerateSSHKeypairActionResult]
54-
upload_ssh_keypair: ActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult]
54+
get_ssh_keypair: SingleEntityActionProcessor[GetSSHKeypairAction, GetSSHKeypairActionResult]
55+
generate_ssh_keypair: ScopeActionProcessor[
56+
GenerateSSHKeypairAction, GenerateSSHKeypairActionResult
57+
]
58+
upload_ssh_keypair: ScopeActionProcessor[UploadSSHKeypairAction, UploadSSHKeypairActionResult]
5559
get_role: ActionProcessor[GetRoleAction, GetRoleActionResult]
5660
authorize: ActionProcessor[AuthorizeAction, AuthorizeActionResult]
5761
signup: ActionProcessor[SignupAction, SignupActionResult]
@@ -72,9 +76,15 @@ def __init__(
7276
) -> None:
7377
self.signout = ActionProcessor(service.signout, action_monitors)
7478
self.update_full_name = ActionProcessor(service.update_full_name, action_monitors)
75-
self.get_ssh_keypair = ActionProcessor(service.get_ssh_keypair, action_monitors)
76-
self.generate_ssh_keypair = ActionProcessor(service.generate_ssh_keypair, action_monitors)
77-
self.upload_ssh_keypair = ActionProcessor(service.upload_ssh_keypair, action_monitors)
79+
self.get_ssh_keypair = SingleEntityActionProcessor(
80+
service.get_ssh_keypair, action_monitors, validators=[validators.rbac.single_entity]
81+
)
82+
self.generate_ssh_keypair = ScopeActionProcessor(
83+
service.generate_ssh_keypair, action_monitors, validators=[validators.rbac.scope]
84+
)
85+
self.upload_ssh_keypair = ScopeActionProcessor(
86+
service.upload_ssh_keypair, action_monitors, validators=[validators.rbac.scope]
87+
)
7888
self.get_role = ActionProcessor(service.get_role, action_monitors)
7989
self.authorize = ActionProcessor(service.authorize, action_monitors)
8090
self.signup = ActionProcessor(service.signup, action_monitors)

src/ai/backend/manager/services/auth/service.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ async def update_password_no_auth(
407407

408408
async def get_ssh_keypair(self, action: GetSSHKeypairAction) -> GetSSHKeypairActionResult:
409409
pubkey = await self._auth_repository.get_ssh_public_key(action.access_key)
410-
return GetSSHKeypairActionResult(public_key=pubkey or "")
410+
return GetSSHKeypairActionResult(public_key=pubkey or "", access_key=action.access_key)
411411

412412
async def generate_ssh_keypair(
413413
self, action: GenerateSSHKeypairAction
@@ -419,7 +419,8 @@ async def generate_ssh_keypair(
419419
ssh_keypair=SSHKeypair(
420420
ssh_public_key=pubkey,
421421
ssh_private_key=privkey,
422-
)
422+
),
423+
user_id=action.user_id,
423424
)
424425

425426
async def upload_ssh_keypair(
@@ -438,6 +439,7 @@ async def upload_ssh_keypair(
438439
ssh_public_key=pubkey,
439440
ssh_private_key=privkey,
440441
),
442+
user_id=action.user_id,
441443
)
442444

443445
async def resolve_access_key_scope(

tests/component/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@
7474
from ai.backend.logging.config import ConsoleConfig, LogDriver, LoggingConfig
7575
from ai.backend.logging.types import LogFormat
7676
from ai.backend.manager.actions.validators import ActionValidators
77+
from ai.backend.manager.actions.validators.rbac import RBACValidators
78+
from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator
79+
from ai.backend.manager.actions.validators.rbac.single_entity import SingleEntityActionRBACValidator
7780
from ai.backend.manager.agent_cache import AgentRPCCache
7881
from ai.backend.manager.api import ManagerStatus
7982
from ai.backend.manager.api.rest.app import build_root_app, mount_registries
@@ -1187,7 +1190,14 @@ def auth_processors(
11871190
config_provider=config_provider,
11881191
)
11891192
return AuthProcessors(
1190-
service=service, action_monitors=[], validators=MagicMock(spec=ActionValidators)
1193+
service=service,
1194+
action_monitors=[],
1195+
validators=ActionValidators(
1196+
rbac=RBACValidators(
1197+
scope=MagicMock(spec=ScopeActionRBACValidator),
1198+
single_entity=MagicMock(spec=SingleEntityActionRBACValidator),
1199+
),
1200+
),
11911201
)
11921202

11931203

tests/unit/manager/api/auth/test_handlers.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -538,8 +538,12 @@ async def test_calls_processor_and_returns_public_key(
538538
) -> None:
539539
"""Verify processor is called and public key is returned."""
540540
public_key = "ssh-rsa AAAAB3...\n"
541+
access_key = "AKIAIOSFODNN7EXAMPLE"
541542
mock_processors.auth.get_ssh_keypair.wait_for_complete = AsyncMock(
542-
return_value=GetSSHKeypairActionResult(public_key=public_key)
543+
return_value=GetSSHKeypairActionResult(
544+
public_key=public_key,
545+
access_key=access_key,
546+
)
543547
)
544548

545549
response = await handler.get_ssh_keypair(user_context)
@@ -564,12 +568,14 @@ async def test_calls_processor_and_returns_keypair(
564568
"""Verify processor is called and keypair is returned."""
565569
ssh_public_key = "ssh-rsa NEWPUB...\n"
566570
ssh_private_key = "-----BEGIN RSA PRIVATE KEY-----\n...\n"
571+
user_id = user_context.user_uuid
567572
mock_processors.auth.generate_ssh_keypair.wait_for_complete = AsyncMock(
568573
return_value=GenerateSSHKeypairActionResult(
569574
ssh_keypair=SSHKeypair(
570575
ssh_public_key=ssh_public_key,
571576
ssh_private_key=ssh_private_key,
572-
)
577+
),
578+
user_id=user_id,
573579
)
574580
)
575581

@@ -599,12 +605,14 @@ async def test_calls_processor_and_returns_keypair(
599605
"pubkey": "ssh-rsa AAAAB3...",
600606
"privkey": "-----BEGIN RSA PRIVATE KEY-----\n...",
601607
})
608+
user_id = user_context.user_uuid
602609
mock_processors.auth.upload_ssh_keypair.wait_for_complete = AsyncMock(
603610
return_value=UploadSSHKeypairActionResult(
604611
ssh_keypair=SSHKeypair(
605612
ssh_public_key="ssh-rsa AAAAB3...\n",
606613
ssh_private_key="-----BEGIN RSA PRIVATE KEY-----\n...\n",
607-
)
614+
),
615+
user_id=user_id,
608616
)
609617
)
610618

0 commit comments

Comments
 (0)