diff --git a/changes/10033.feature.md b/changes/10033.feature.md new file mode 100644 index 00000000000..f6094860920 --- /dev/null +++ b/changes/10033.feature.md @@ -0,0 +1 @@ +Apply RBAC permission validators to model deployment service actions diff --git a/src/ai/backend/manager/api/rest/service/handler.py b/src/ai/backend/manager/api/rest/service/handler.py index d60bb354a25..db45174dfb6 100644 --- a/src/ai/backend/manager/api/rest/service/handler.py +++ b/src/ai/backend/manager/api/rest/service/handler.py @@ -469,9 +469,9 @@ async def update_route( traffic_ratio=params.traffic_ratio, ) - result = await self._model_serving.update_route.wait_for_complete(action) + await self._model_serving.update_route.wait_for_complete(action) - resp = SuccessResponseModel(success=result.success) + resp = SuccessResponseModel(success=True) return APIResponse.build(HTTPStatus.OK, resp) # ------------------------------------------------------------------ @@ -494,9 +494,9 @@ async def delete_route( route_id=path_params.route_id, ) - result = await self._model_serving.delete_route.wait_for_complete(action) + await self._model_serving.delete_route.wait_for_complete(action) - resp = SuccessResponseModel(success=result.success) + resp = SuccessResponseModel(success=True) return APIResponse.build(HTTPStatus.OK, resp) # ------------------------------------------------------------------ diff --git a/src/ai/backend/manager/services/model_serving/actions/base.py b/src/ai/backend/manager/services/model_serving/actions/base.py index 46e25e78432..62da759b28d 100644 --- a/src/ai/backend/manager/services/model_serving/actions/base.py +++ b/src/ai/backend/manager/services/model_serving/actions/base.py @@ -2,10 +2,42 @@ from ai.backend.common.data.permission.types import EntityType from ai.backend.manager.actions.action import BaseAction +from ai.backend.manager.actions.action.scope import BaseScopeAction, BaseScopeActionResult +from ai.backend.manager.actions.action.single_entity import ( + BaseSingleEntityAction, + BaseSingleEntityActionResult, +) +from ai.backend.manager.actions.action.types import FieldData class ModelServiceAction(BaseAction): @override @classmethod def entity_type(cls) -> EntityType: - return EntityType.MODEL_SERVICE + return EntityType.MODEL_DEPLOYMENT + + +class ModelServiceScopeAction(BaseScopeAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.MODEL_DEPLOYMENT + + +class ModelServiceScopeActionResult(BaseScopeActionResult): + pass + + +class ModelServiceSingleEntityAction(BaseSingleEntityAction): + @override + @classmethod + def entity_type(cls) -> EntityType: + return EntityType.MODEL_DEPLOYMENT + + @override + def field_data(self) -> FieldData | None: + return None + + +class ModelServiceSingleEntityActionResult(BaseSingleEntityActionResult): + pass diff --git a/src/ai/backend/manager/services/model_serving/actions/create_model_service.py b/src/ai/backend/manager/services/model_serving/actions/create_model_service.py index aa486356a72..a27391eda36 100644 --- a/src/ai/backend/manager/services/model_serving/actions/create_model_service.py +++ b/src/ai/backend/manager/services/model_serving/actions/create_model_service.py @@ -4,32 +4,50 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType, ScopeType from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.model_serving.creator import ModelServiceCreator from ai.backend.manager.data.model_serving.types import ServiceInfo -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceScopeAction, + ModelServiceScopeActionResult, +) @dataclass -class CreateModelServiceAction(ModelServiceAction): +class CreateModelServiceAction(ModelServiceScopeAction): request_user_id: uuid.UUID creator: ModelServiceCreator - - @override - def entity_id(self) -> str | None: - return None + _project_id: uuid.UUID @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.CREATE + @override + def scope_type(self) -> ScopeType: + return ScopeType.PROJECT + + @override + def scope_id(self) -> str: + return str(self._project_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.PROJECT, str(self._project_id)) + @dataclass -class CreateModelServiceActionResult(BaseActionResult): +class CreateModelServiceActionResult(ModelServiceScopeActionResult): data: ServiceInfo + _project_id: uuid.UUID + + @override + def scope_type(self) -> ScopeType: + return ScopeType.PROJECT @override - def entity_id(self) -> str | None: - return str(self.data.endpoint_id) + def scope_id(self) -> str: + return str(self._project_id) diff --git a/src/ai/backend/manager/services/model_serving/actions/delete_model_service.py b/src/ai/backend/manager/services/model_serving/actions/delete_model_service.py index 4d34171a432..ec14b6b5d34 100644 --- a/src/ai/backend/manager/services/model_serving/actions/delete_model_service.py +++ b/src/ai/backend/manager/services/model_serving/actions/delete_model_service.py @@ -2,29 +2,37 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceSingleEntityAction, + ModelServiceSingleEntityActionResult, +) @dataclass -class DeleteModelServiceAction(ModelServiceAction): +class DeleteModelServiceAction(ModelServiceSingleEntityAction): service_id: uuid.UUID - @override - def entity_id(self) -> str | None: - return None - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.DELETE + @override + def target_entity_id(self) -> str: + return str(self.service_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.MODEL_DEPLOYMENT, str(self.service_id)) + @dataclass -class DeleteModelServiceActionResult(BaseActionResult): - success: bool +class DeleteModelServiceActionResult(ModelServiceSingleEntityActionResult): + service_id: uuid.UUID @override - def entity_id(self) -> str | None: - return None + def target_entity_id(self) -> str: + return str(self.service_id) diff --git a/src/ai/backend/manager/services/model_serving/actions/delete_route.py b/src/ai/backend/manager/services/model_serving/actions/delete_route.py index 65e84bcd076..0f3bc696d0b 100644 --- a/src/ai/backend/manager/services/model_serving/actions/delete_route.py +++ b/src/ai/backend/manager/services/model_serving/actions/delete_route.py @@ -2,36 +2,43 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import EntityType, RBACElementType from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceSingleEntityAction, + ModelServiceSingleEntityActionResult, +) @dataclass -class DeleteRouteAction(ModelServiceAction): +class DeleteRouteAction(ModelServiceSingleEntityAction): service_id: uuid.UUID route_id: uuid.UUID @override @classmethod def entity_type(cls) -> EntityType: - return EntityType.DEPLOYMENT_ROUTE - - @override - def entity_id(self) -> str | None: - return None + return EntityType.MODEL_DEPLOYMENT @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.DELETE + @override + def target_entity_id(self) -> str: + return str(self.route_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.MODEL_DEPLOYMENT, str(self.service_id)) + @dataclass -class DeleteRouteActionResult(BaseActionResult): - success: bool +class DeleteRouteActionResult(ModelServiceSingleEntityActionResult): + route_id: uuid.UUID @override - def entity_id(self) -> str | None: - return None + def target_entity_id(self) -> str: + return str(self.route_id) diff --git a/src/ai/backend/manager/services/model_serving/actions/get_model_service_info.py b/src/ai/backend/manager/services/model_serving/actions/get_model_service_info.py index 0ddca5393dd..b8c18b1d4ef 100644 --- a/src/ai/backend/manager/services/model_serving/actions/get_model_service_info.py +++ b/src/ai/backend/manager/services/model_serving/actions/get_model_service_info.py @@ -2,30 +2,38 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.model_serving.types import ServiceInfo -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceSingleEntityAction, + ModelServiceSingleEntityActionResult, +) @dataclass -class GetModelServiceInfoAction(ModelServiceAction): +class GetModelServiceInfoAction(ModelServiceSingleEntityAction): service_id: uuid.UUID - @override - def entity_id(self) -> str | None: - return None - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.GET + @override + def target_entity_id(self) -> str: + return str(self.service_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.MODEL_DEPLOYMENT, str(self.service_id)) + @dataclass -class GetModelServiceInfoActionResult(BaseActionResult): +class GetModelServiceInfoActionResult(ModelServiceSingleEntityActionResult): data: ServiceInfo @override - def entity_id(self) -> str | None: + def target_entity_id(self) -> str: return str(self.data.endpoint_id) diff --git a/src/ai/backend/manager/services/model_serving/actions/modify_endpoint.py b/src/ai/backend/manager/services/model_serving/actions/modify_endpoint.py index 02669292fbd..cb66341a95a 100644 --- a/src/ai/backend/manager/services/model_serving/actions/modify_endpoint.py +++ b/src/ai/backend/manager/services/model_serving/actions/modify_endpoint.py @@ -2,34 +2,43 @@ from dataclasses import dataclass from typing import override -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import RBACElementType from ai.backend.manager.actions.types import ActionOperationType from ai.backend.manager.data.model_serving.types import EndpointData +from ai.backend.manager.data.permission.types import RBACElementRef from ai.backend.manager.models.endpoint import EndpointRow from ai.backend.manager.repositories.base.updater import Updater -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceSingleEntityAction, + ModelServiceSingleEntityActionResult, +) @dataclass -class ModifyEndpointAction(ModelServiceAction): +class ModifyEndpointAction(ModelServiceSingleEntityAction): endpoint_id: uuid.UUID updater: Updater[EndpointRow] - @override - def entity_id(self) -> str | None: - return None - @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE + @override + def target_entity_id(self) -> str: + return str(self.endpoint_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.MODEL_DEPLOYMENT, str(self.endpoint_id)) + @dataclass -class ModifyEndpointActionResult(BaseActionResult): +class ModifyEndpointActionResult(ModelServiceSingleEntityActionResult): + endpoint_id: uuid.UUID success: bool data: EndpointData | None @override - def entity_id(self) -> str | None: - return str(self.data.id) if self.data is not None else None + def target_entity_id(self) -> str: + return str(self.endpoint_id) diff --git a/src/ai/backend/manager/services/model_serving/actions/update_route.py b/src/ai/backend/manager/services/model_serving/actions/update_route.py index de19dbc08f4..cde4d864304 100644 --- a/src/ai/backend/manager/services/model_serving/actions/update_route.py +++ b/src/ai/backend/manager/services/model_serving/actions/update_route.py @@ -2,14 +2,17 @@ from dataclasses import dataclass from typing import override -from ai.backend.common.data.permission.types import EntityType -from ai.backend.manager.actions.action import BaseActionResult +from ai.backend.common.data.permission.types import EntityType, RBACElementType from ai.backend.manager.actions.types import ActionOperationType -from ai.backend.manager.services.model_serving.actions.base import ModelServiceAction +from ai.backend.manager.data.permission.types import RBACElementRef +from ai.backend.manager.services.model_serving.actions.base import ( + ModelServiceSingleEntityAction, + ModelServiceSingleEntityActionResult, +) @dataclass -class UpdateRouteAction(ModelServiceAction): +class UpdateRouteAction(ModelServiceSingleEntityAction): service_id: uuid.UUID route_id: uuid.UUID traffic_ratio: float @@ -17,22 +20,26 @@ class UpdateRouteAction(ModelServiceAction): @override @classmethod def entity_type(cls) -> EntityType: - return EntityType.DEPLOYMENT_ROUTE - - @override - def entity_id(self) -> str | None: - return None + return EntityType.MODEL_DEPLOYMENT @override @classmethod def operation_type(cls) -> ActionOperationType: return ActionOperationType.UPDATE + @override + def target_entity_id(self) -> str: + return str(self.route_id) + + @override + def target_element(self) -> RBACElementRef: + return RBACElementRef(RBACElementType.MODEL_DEPLOYMENT, str(self.service_id)) + @dataclass -class UpdateRouteActionResult(BaseActionResult): - success: bool +class UpdateRouteActionResult(ModelServiceSingleEntityActionResult): + route_id: uuid.UUID @override - def entity_id(self) -> str | None: - return None + def target_entity_id(self) -> str: + return str(self.route_id) diff --git a/src/ai/backend/manager/services/model_serving/processors/model_serving.py b/src/ai/backend/manager/services/model_serving/processors/model_serving.py index 1a382e7ff8f..b312643e990 100644 --- a/src/ai/backend/manager/services/model_serving/processors/model_serving.py +++ b/src/ai/backend/manager/services/model_serving/processors/model_serving.py @@ -2,6 +2,8 @@ from ai.backend.manager.actions.monitors.monitor import ActionMonitor from ai.backend.manager.actions.processor import ActionProcessor +from ai.backend.manager.actions.processor.scope import ScopeActionProcessor +from ai.backend.manager.actions.processor.single_entity import SingleEntityActionProcessor from ai.backend.manager.actions.types import AbstractProcessorPackage, ActionSpec from ai.backend.manager.actions.validators import ActionValidators from ai.backend.manager.services.model_serving.actions.clear_error import ( @@ -66,21 +68,30 @@ class ModelServingProcessors(AbstractProcessorPackage): - create_model_service: ActionProcessor[CreateModelServiceAction, CreateModelServiceActionResult] + # Scope actions (with RBAC) + create_model_service: ScopeActionProcessor[ + CreateModelServiceAction, CreateModelServiceActionResult + ] list_model_service: ActionProcessor[ListModelServiceAction, ListModelServiceActionResult] - delete_model_service: ActionProcessor[DeleteModelServiceAction, DeleteModelServiceActionResult] - dry_run_model_service: ActionProcessor[DryRunModelServiceAction, DryRunModelServiceActionResult] - get_model_service_info: ActionProcessor[ + search_services: ActionProcessor[SearchServicesAction, SearchServicesActionResult] + + # Single entity actions (with RBAC) + get_model_service_info: SingleEntityActionProcessor[ GetModelServiceInfoAction, GetModelServiceInfoActionResult ] + delete_model_service: SingleEntityActionProcessor[ + DeleteModelServiceAction, DeleteModelServiceActionResult + ] + modify_endpoint: SingleEntityActionProcessor[ModifyEndpointAction, ModifyEndpointActionResult] + update_route: SingleEntityActionProcessor[UpdateRouteAction, UpdateRouteActionResult] + delete_route: SingleEntityActionProcessor[DeleteRouteAction, DeleteRouteActionResult] + + # Internal/system actions (no RBAC) + dry_run_model_service: ActionProcessor[DryRunModelServiceAction, DryRunModelServiceActionResult] list_errors: ActionProcessor[ListErrorsAction, ListErrorsActionResult] clear_error: ActionProcessor[ClearErrorAction, ClearErrorActionResult] force_sync: ActionProcessor[ForceSyncAction, ForceSyncActionResult] - update_route: ActionProcessor[UpdateRouteAction, UpdateRouteActionResult] - delete_route: ActionProcessor[DeleteRouteAction, DeleteRouteActionResult] generate_token: ActionProcessor[GenerateTokenAction, GenerateTokenActionResult] - modify_endpoint: ActionProcessor[ModifyEndpointAction, ModifyEndpointActionResult] - search_services: ActionProcessor[SearchServicesAction, SearchServicesActionResult] validate_model_service: ActionProcessor[ ValidateModelServiceAction, ValidateModelServiceActionResult ] @@ -91,21 +102,38 @@ def __init__( action_monitors: list[ActionMonitor], validators: ActionValidators, ) -> None: - self.create_model_service = ActionProcessor(service.create, action_monitors) + # Scope actions with RBAC validator + self.create_model_service = ScopeActionProcessor( + service.create, action_monitors, validators=[validators.rbac.scope] + ) self.list_model_service = ActionProcessor(service.list_serve, action_monitors) - self.delete_model_service = ActionProcessor(service.delete, action_monitors) - self.dry_run_model_service = ActionProcessor(service.dry_run, action_monitors) - self.get_model_service_info = ActionProcessor( - service.get_model_service_info, action_monitors + self.search_services = ActionProcessor(service.search_services, action_monitors) + + # Single entity actions with RBAC validator + self.get_model_service_info = SingleEntityActionProcessor( + service.get_model_service_info, + action_monitors, + validators=[validators.rbac.single_entity], + ) + self.delete_model_service = SingleEntityActionProcessor( + service.delete, action_monitors, validators=[validators.rbac.single_entity] + ) + self.modify_endpoint = SingleEntityActionProcessor( + service.modify_endpoint, action_monitors, validators=[validators.rbac.single_entity] + ) + self.update_route = SingleEntityActionProcessor( + service.update_route, action_monitors, validators=[validators.rbac.single_entity] ) + self.delete_route = SingleEntityActionProcessor( + service.delete_route, action_monitors, validators=[validators.rbac.single_entity] + ) + + # Internal/system actions without RBAC + self.dry_run_model_service = ActionProcessor(service.dry_run, action_monitors) self.list_errors = ActionProcessor(service.list_errors, action_monitors) self.clear_error = ActionProcessor(service.clear_error, action_monitors) self.force_sync = ActionProcessor(service.force_sync_with_app_proxy, action_monitors) - self.update_route = ActionProcessor(service.update_route, action_monitors) - self.delete_route = ActionProcessor(service.delete_route, action_monitors) self.generate_token = ActionProcessor(service.generate_token, action_monitors) - self.modify_endpoint = ActionProcessor(service.modify_endpoint, action_monitors) - self.search_services = ActionProcessor(service.search_services, action_monitors) self.validate_model_service = ActionProcessor( service.validate_model_service, action_monitors ) diff --git a/src/ai/backend/manager/services/model_serving/services/model_serving.py b/src/ai/backend/manager/services/model_serving/services/model_serving.py index 6806cae025d..d1c8e3be1d2 100644 --- a/src/ai/backend/manager/services/model_serving/services/model_serving.py +++ b/src/ai/backend/manager/services/model_serving/services/model_serving.py @@ -377,7 +377,7 @@ async def create(self, action: CreateModelServiceAction) -> CreateModelServiceAc endpoint_id = endpoint_data.id return CreateModelServiceActionResult( - ServiceInfo( + data=ServiceInfo( endpoint_id=endpoint_id, model_id=endpoint_spec.model, extra_mounts=[m.vfid.folder_id for m in endpoint_spec.extra_mounts], @@ -389,7 +389,8 @@ async def create(self, action: CreateModelServiceAction) -> CreateModelServiceAc service_endpoint=None, is_public=action.creator.open_to_public, runtime_variant=action.creator.runtime_variant, - ) + ), + _project_id=action._project_id, ) async def list_serve(self, action: ListModelServiceAction) -> ListModelServiceActionResult: @@ -413,7 +414,7 @@ async def list_serve(self, action: ListModelServiceAction) -> ListModelServiceAc is_public=endpoint.open_to_public, ) for endpoint in endpoints - ] + ], ) async def search_services(self, action: SearchServicesAction) -> SearchServicesActionResult: @@ -462,7 +463,7 @@ async def delete(self, action: DeleteModelServiceAction) -> DeleteModelServiceAc # Update endpoint lifecycle await self._repository.update_endpoint_lifecycle(service_id, lifecycle_stage, replicas) - return DeleteModelServiceActionResult(success=True) + return DeleteModelServiceActionResult(service_id=service_id) async def dry_run(self, action: DryRunModelServiceAction) -> DryRunModelServiceActionResult: # TODO: Seperate background task definition and trigger into different layer @@ -746,7 +747,7 @@ async def update_route(self, action: UpdateRouteAction) -> UpdateRouteActionResu updated_endpoint_data.id ) - return UpdateRouteActionResult(success=True) + return UpdateRouteActionResult(route_id=action.route_id) async def delete_route(self, action: DeleteRouteAction) -> DeleteRouteActionResult: # Validate access @@ -783,7 +784,7 @@ async def delete_route(self, action: DeleteRouteAction) -> DeleteRouteActionResu # Decrease endpoint replicas await self._repository.decrease_endpoint_replicas(action.service_id) - return DeleteRouteActionResult(success=True) + return DeleteRouteActionResult(route_id=action.route_id) async def generate_token(self, action: GenerateTokenAction) -> GenerateTokenActionResult: # Validate access @@ -886,7 +887,9 @@ async def modify_endpoint(self, action: ModifyEndpointAction) -> ModifyEndpointA await self._deployment_controller.mark_lifecycle_needed( DeploymentLifecycleType.CHECK_REPLICA, ) - return ModifyEndpointActionResult(success=result.success, data=result.data) + return ModifyEndpointActionResult( + endpoint_id=action.endpoint_id, success=result.success, data=result.data + ) async def validate_model_service( self, action: ValidateModelServiceAction diff --git a/tests/component/model_serving/conftest.py b/tests/component/model_serving/conftest.py index f1b54ad93cd..a04c5369058 100644 --- a/tests/component/model_serving/conftest.py +++ b/tests/component/model_serving/conftest.py @@ -7,6 +7,11 @@ from ai.backend.common.bgtask.bgtask import BackgroundTaskManager from ai.backend.common.events.hub.hub import EventHub from ai.backend.manager.actions.validators import ActionValidators +from ai.backend.manager.actions.validators.rbac import RBACValidators +from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator +from ai.backend.manager.actions.validators.rbac.single_entity import ( + SingleEntityActionRBACValidator, +) from ai.backend.manager.api.rest.routing import RouteRegistry from ai.backend.manager.api.rest.service.handler import ServiceHandler from ai.backend.manager.api.rest.service.registry import register_service_routes @@ -68,7 +73,14 @@ def model_serving_processors( revision_generator_registry=revision_gen, ) return ModelServingProcessors( - service=service, action_monitors=[], validators=MagicMock(spec=ActionValidators) + service=service, + action_monitors=[], + validators=ActionValidators( + rbac=RBACValidators( + scope=MagicMock(spec=ScopeActionRBACValidator), + single_entity=MagicMock(spec=SingleEntityActionRBACValidator), + ), + ), ) diff --git a/tests/unit/manager/services/model_serving/actions/BUILD b/tests/unit/manager/services/model_serving/actions/BUILD index 57341b1358b..9ae5c8877ec 100644 --- a/tests/unit/manager/services/model_serving/actions/BUILD +++ b/tests/unit/manager/services/model_serving/actions/BUILD @@ -1,3 +1,7 @@ python_tests( name="tests", ) + +python_test_utils( + name="testutils", +) diff --git a/tests/unit/manager/services/model_serving/actions/conftest.py b/tests/unit/manager/services/model_serving/actions/conftest.py new file mode 100644 index 00000000000..93c6a7d8b02 --- /dev/null +++ b/tests/unit/manager/services/model_serving/actions/conftest.py @@ -0,0 +1,20 @@ +from unittest.mock import MagicMock + +import pytest + +from ai.backend.manager.actions.validators import ActionValidators +from ai.backend.manager.actions.validators.rbac import RBACValidators +from ai.backend.manager.actions.validators.rbac.scope import ScopeActionRBACValidator +from ai.backend.manager.actions.validators.rbac.single_entity import ( + SingleEntityActionRBACValidator, +) + + +@pytest.fixture +def mock_action_validators() -> ActionValidators: + return ActionValidators( + rbac=RBACValidators( + scope=MagicMock(spec=ScopeActionRBACValidator), + single_entity=MagicMock(spec=SingleEntityActionRBACValidator), + ), + ) diff --git a/tests/unit/manager/services/model_serving/actions/test_create_auto_scaling_rule.py b/tests/unit/manager/services/model_serving/actions/test_create_auto_scaling_rule.py index a7a723a1edc..f2a50a86af7 100644 --- a/tests/unit/manager/services/model_serving/actions/test_create_auto_scaling_rule.py +++ b/tests/unit/manager/services/model_serving/actions/test_create_auto_scaling_rule.py @@ -74,11 +74,12 @@ def auto_scaling_processors( self, mock_action_monitor: MagicMock, auto_scaling_service: AutoScalingService, + mock_action_validators: ActionValidators, ) -> ModelServingAutoScalingProcessors: return ModelServingAutoScalingProcessors( service=auto_scaling_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_create_model_service.py b/tests/unit/manager/services/model_serving/actions/test_create_model_service.py index 6d05a109795..56dc9bbdb72 100644 --- a/tests/unit/manager/services/model_serving/actions/test_create_model_service.py +++ b/tests/unit/manager/services/model_serving/actions/test_create_model_service.py @@ -234,11 +234,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -379,6 +380,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> extra_mounts=[], ), ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), CreateModelServiceActionResult( data=ServiceInfo( @@ -394,6 +396,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> is_public=False, runtime_variant=RuntimeVariant.CUSTOM, ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), ), ScenarioBase.failure( @@ -436,6 +439,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> extra_mounts=[], ), ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), Exception, # insufficient resources ), @@ -479,6 +483,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> extra_mounts=[], ), ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), InvalidAPIParameters, ), @@ -522,6 +527,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> extra_mounts=[], ), ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), CreateModelServiceActionResult( data=ServiceInfo( @@ -537,6 +543,7 @@ def mock_create_endpoint_validated(self, mocker: Any, mock_repositories: Any) -> is_public=True, runtime_variant=RuntimeVariant.CUSTOM, ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ), ), ], @@ -735,11 +742,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -841,6 +849,7 @@ def action_with_api_request_values(self) -> CreateModelServiceAction: extra_mounts=[], ), ), + _project_id=uuid.UUID("00000000-0000-0000-0000-000000000002"), ) async def test_service_definition_overrides_applied( diff --git a/tests/unit/manager/services/model_serving/actions/test_delete_auto_scaling_rule.py b/tests/unit/manager/services/model_serving/actions/test_delete_auto_scaling_rule.py index 07614bb8da0..26783cd3878 100644 --- a/tests/unit/manager/services/model_serving/actions/test_delete_auto_scaling_rule.py +++ b/tests/unit/manager/services/model_serving/actions/test_delete_auto_scaling_rule.py @@ -69,11 +69,12 @@ def auto_scaling_processors( self, mock_action_monitor: MagicMock, auto_scaling_service: AutoScalingService, + mock_action_validators: ActionValidators, ) -> ModelServingAutoScalingProcessors: return ModelServingAutoScalingProcessors( service=auto_scaling_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_delete_model_service.py b/tests/unit/manager/services/model_serving/actions/test_delete_model_service.py index e30d67ee20d..e3e3851f038 100644 --- a/tests/unit/manager/services/model_serving/actions/test_delete_model_service.py +++ b/tests/unit/manager/services/model_serving/actions/test_delete_model_service.py @@ -158,11 +158,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -222,7 +223,7 @@ def mock_check_user_access(self, mocker: Any, model_serving_service: Any) -> Asy service_id=uuid.UUID("cccccccc-dddd-eeee-ffff-111111111111"), ), DeleteModelServiceActionResult( - success=True, + service_id=uuid.UUID("cccccccc-dddd-eeee-ffff-111111111111"), ), ), ScenarioBase.failure( diff --git a/tests/unit/manager/services/model_serving/actions/test_dry_run_model_service.py b/tests/unit/manager/services/model_serving/actions/test_dry_run_model_service.py index aad82f76b9d..ae1741b094a 100644 --- a/tests/unit/manager/services/model_serving/actions/test_dry_run_model_service.py +++ b/tests/unit/manager/services/model_serving/actions/test_dry_run_model_service.py @@ -204,11 +204,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -694,11 +695,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -998,11 +1000,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_generate_token.py b/tests/unit/manager/services/model_serving/actions/test_generate_token.py index dbbe32fd8a6..68fa163a223 100644 --- a/tests/unit/manager/services/model_serving/actions/test_generate_token.py +++ b/tests/unit/manager/services/model_serving/actions/test_generate_token.py @@ -161,11 +161,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_get_model_service_info.py b/tests/unit/manager/services/model_serving/actions/test_get_model_service_info.py index 47be4444294..0b1f549a03a 100644 --- a/tests/unit/manager/services/model_serving/actions/test_get_model_service_info.py +++ b/tests/unit/manager/services/model_serving/actions/test_get_model_service_info.py @@ -160,11 +160,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_list_errors.py b/tests/unit/manager/services/model_serving/actions/test_list_errors.py index 0655a9f8e91..f4e4138e6f2 100644 --- a/tests/unit/manager/services/model_serving/actions/test_list_errors.py +++ b/tests/unit/manager/services/model_serving/actions/test_list_errors.py @@ -160,11 +160,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_list_model_service.py b/tests/unit/manager/services/model_serving/actions/test_list_model_service.py index b191f088c84..7c8532e834f 100644 --- a/tests/unit/manager/services/model_serving/actions/test_list_model_service.py +++ b/tests/unit/manager/services/model_serving/actions/test_list_model_service.py @@ -160,11 +160,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py b/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py index 4d8bd839243..5866acf4af5 100644 --- a/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py +++ b/tests/unit/manager/services/model_serving/actions/test_model_serving_crud_actions.py @@ -180,11 +180,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -488,7 +489,7 @@ async def test_healthy_route_deletion_success( action = DeleteRouteAction(service_id=service_id, route_id=route_id) result = await model_serving_processors.delete_route.wait_for_complete(action) - assert result.success is True + assert result.route_id == route_id mock_destroy_session.assert_called_once_with( mock_session_row, forced=False, @@ -539,7 +540,7 @@ async def test_sessionless_route_deletes_without_session_destruction( action = DeleteRouteAction(service_id=service_id, route_id=route_id) result = await model_serving_processors.delete_route.wait_for_complete(action) - assert result.success is True + assert result.route_id == route_id mock_destroy_session.assert_not_called() mock_decrease_endpoint_replicas.assert_called_once_with(service_id) diff --git a/tests/unit/manager/services/model_serving/actions/test_modify_auto_scaling_rule.py b/tests/unit/manager/services/model_serving/actions/test_modify_auto_scaling_rule.py index e202a71a859..b64d3eb7a67 100644 --- a/tests/unit/manager/services/model_serving/actions/test_modify_auto_scaling_rule.py +++ b/tests/unit/manager/services/model_serving/actions/test_modify_auto_scaling_rule.py @@ -80,11 +80,12 @@ def auto_scaling_processors( self, mock_action_monitor: MagicMock, auto_scaling_service: AutoScalingService, + mock_action_validators: ActionValidators, ) -> ModelServingAutoScalingProcessors: return ModelServingAutoScalingProcessors( service=auto_scaling_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_scale_service_replicas.py b/tests/unit/manager/services/model_serving/actions/test_scale_service_replicas.py index a0edf660a70..0e67a29c1b7 100644 --- a/tests/unit/manager/services/model_serving/actions/test_scale_service_replicas.py +++ b/tests/unit/manager/services/model_serving/actions/test_scale_service_replicas.py @@ -68,11 +68,12 @@ def auto_scaling_processors( self, mock_action_monitor: MagicMock, auto_scaling_service: AutoScalingService, + mock_action_validators: ActionValidators, ) -> ModelServingAutoScalingProcessors: return ModelServingAutoScalingProcessors( service=auto_scaling_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_search_services.py b/tests/unit/manager/services/model_serving/actions/test_search_services.py index 24b45700c88..dd7b5fb14d0 100644 --- a/tests/unit/manager/services/model_serving/actions/test_search_services.py +++ b/tests/unit/manager/services/model_serving/actions/test_search_services.py @@ -165,11 +165,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture diff --git a/tests/unit/manager/services/model_serving/actions/test_update_route.py b/tests/unit/manager/services/model_serving/actions/test_update_route.py index d50c3a903f3..73a277d68b0 100644 --- a/tests/unit/manager/services/model_serving/actions/test_update_route.py +++ b/tests/unit/manager/services/model_serving/actions/test_update_route.py @@ -159,11 +159,12 @@ def model_serving_processors( self, mock_action_monitor: MagicMock, model_serving_service: ModelServingService, + mock_action_validators: ActionValidators, ) -> ModelServingProcessors: return ModelServingProcessors( service=model_serving_service, action_monitors=[mock_action_monitor], - validators=MagicMock(spec=ActionValidators), + validators=mock_action_validators, ) @pytest.fixture @@ -243,7 +244,7 @@ def mock_notify_endpoint_route_update_to_appproxy( route_id=uuid.UUID("11111111-1111-1111-1111-111111111111"), traffic_ratio=0.7, ), - UpdateRouteActionResult(success=True), + UpdateRouteActionResult(route_id=uuid.UUID("11111111-1111-1111-1111-111111111111")), ), ], )