Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/9992.enhance.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor artifact registry action classes to RBAC-aware base classes (`ArtifactRegistrySingleEntityAction`) to enable RBAC validator wiring
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from dataclasses import dataclass
from typing import override

from ai.backend.common.data.permission.types import EntityType
Expand All @@ -11,15 +10,13 @@
from ai.backend.manager.actions.action.types import FieldData


@dataclass
class ArtifactRegistryAction(BaseAction):
@override
@classmethod
def entity_type(cls) -> EntityType:
return EntityType.ARTIFACT_REGISTRY


@dataclass
class ArtifactBatchRegistryAction(BaseBatchAction):
@override
@classmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,44 @@
from dataclasses import dataclass
from typing import override

from ai.backend.manager.actions.action import BaseActionResult
from ai.backend.common.data.permission.types import RBACElementType
from ai.backend.manager.actions.types import ActionOperationType
from ai.backend.manager.data.artifact_registries.types import ArtifactRegistryData
from ai.backend.manager.services.artifact_registry.actions.base import ArtifactRegistryAction
from ai.backend.manager.data.permission.types import RBACElementRef
from ai.backend.manager.errors.api import InvalidAPIParameters
from ai.backend.manager.services.artifact_registry.actions.base import (
ArtifactRegistrySingleEntityAction,
ArtifactRegistrySingleEntityActionResult,
)


@dataclass
class GetArtifactRegistryMetaAction(ArtifactRegistryAction):
class GetArtifactRegistryMetaAction(ArtifactRegistrySingleEntityAction):
registry_id: uuid.UUID | None = None
registry_name: str | None = None

@override
def entity_id(self) -> str | None:
return str(self.registry_id)

@override
@classmethod
def operation_type(cls) -> ActionOperationType:
return ActionOperationType.GET

@override
def target_entity_id(self) -> str:
if self.registry_id:
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think registry_name should not be included here

return str(self.registry_id)
if self.registry_name:
return self.registry_name
raise InvalidAPIParameters("Either registry_id or registry_name must be provided.")

@override
def target_element(self) -> RBACElementRef:
return RBACElementRef(RBACElementType.ARTIFACT_REGISTRY, self.target_entity_id())


@dataclass
class GetArtifactRegistryMetaActionResult(BaseActionResult):
class GetArtifactRegistryMetaActionResult(ArtifactRegistrySingleEntityActionResult):
result: ArtifactRegistryData

@override
def entity_id(self) -> str | None:
def target_entity_id(self) -> str:
return str(self.result.id)
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class ArtifactRegistryProcessors(AbstractProcessorPackage):
search_reservoir_registries: ActionProcessor[
SearchReservoirRegistriesAction, SearchReservoirRegistriesActionResult
]
get_registry_meta: ActionProcessor[
get_registry_meta: SingleEntityActionProcessor[
GetArtifactRegistryMetaAction, GetArtifactRegistryMetaActionResult
]
get_registry_metas: ActionProcessor[
Expand Down Expand Up @@ -192,14 +192,19 @@ def __init__(
validators=[validators.rbac.single_entity],
)

self.get_registry_meta = SingleEntityActionProcessor(
service.get_registry_meta,
action_monitors,
validators=[validators.rbac.single_entity],
)

# Internal/batch actions without RBAC
self.get_huggingface_registries = ActionProcessor(
service.get_huggingface_registries, action_monitors
)
self.get_reservoir_registries = ActionProcessor(
service.get_reservoir_registries, action_monitors
)
self.get_registry_meta = ActionProcessor(service.get_registry_meta, action_monitors)
self.get_registry_metas = ActionProcessor(service.get_registry_metas, action_monitors)

@override
Expand Down
Loading