diff --git a/changes/10300.feature.md b/changes/10300.feature.md new file mode 100644 index 00000000000..4efc92a5657 --- /dev/null +++ b/changes/10300.feature.md @@ -0,0 +1 @@ +Add `update_deployment_policy` GQL mutation diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 88989077722..1e67a3758b6 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -7718,6 +7718,13 @@ type Mutation """Added in 25.16.0""" addModelRevision(input: AddRevisionInput!): AddRevisionPayload! @join__field(graph: STRAWBERRY) + """ + Added in 26.4.0. + Create or update the deployment policy for a given deployment (upsert semantics). + If the deployment already has a policy, it is replaced entirely with the new configuration. + """ + updateDeploymentPolicy(input: UpdateDeploymentPolicyInput!): UpdateDeploymentPolicyPayload! @join__field(graph: STRAWBERRY) + """Create a new notification channel (admin only)""" adminCreateNotificationChannel(input: CreateNotificationChannelInput!): CreateNotificationChannelPayload! @join__field(graph: STRAWBERRY) @@ -12948,6 +12955,35 @@ type UpdateDeploymentPayload deployment: ModelDeployment! } +""" +Added in 26.4.0. +Input for creating or updating a deployment policy (upsert semantics). +Specify the target deployment_id and the desired strategy type. +Exactly one of rolling_update or blue_green must be provided, +matching the chosen strategy type. +If a policy already exists for the deployment, it is replaced entirely. +""" +input UpdateDeploymentPolicyInput + @join__type(graph: STRAWBERRY) +{ + deploymentId: ID! + strategy: DeploymentStrategyType! + rollbackOnFailure: Boolean! = false + rollingUpdate: RollingUpdateConfigInput = null + blueGreen: BlueGreenConfigInput = null +} + +""" +Added in 26.4.0. +Result payload returned after creating or updating a deployment policy. +Contains the full deployment_policy object reflecting the applied configuration. +""" +type UpdateDeploymentPolicyPayload + @join__type(graph: STRAWBERRY) +{ + deploymentPolicy: DeploymentPolicy! +} + """Added in 25.14.0""" input UpdateHuggingFaceRegistryInput @join__type(graph: STRAWBERRY) diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index f334b31750e..4612d288f81 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -4074,6 +4074,13 @@ type Mutation { """Added in 25.16.0""" addModelRevision(input: AddRevisionInput!): AddRevisionPayload! + """ + Added in 26.4.0. + Create or update the deployment policy for a given deployment (upsert semantics). + If the deployment already has a policy, it is replaced entirely with the new configuration. + """ + updateDeploymentPolicy(input: UpdateDeploymentPolicyInput!): UpdateDeploymentPolicyPayload! + """Create a new notification channel (admin only)""" adminCreateNotificationChannel(input: CreateNotificationChannelInput!): CreateNotificationChannelPayload! @@ -7891,6 +7898,31 @@ type UpdateDeploymentPayload { deployment: ModelDeployment! } +""" +Added in 26.4.0. +Input for creating or updating a deployment policy (upsert semantics). +Specify the target deployment_id and the desired strategy type. +Exactly one of rolling_update or blue_green must be provided, +matching the chosen strategy type. +If a policy already exists for the deployment, it is replaced entirely. +""" +input UpdateDeploymentPolicyInput { + deploymentId: ID! + strategy: DeploymentStrategyType! + rollbackOnFailure: Boolean! = false + rollingUpdate: RollingUpdateConfigInput = null + blueGreen: BlueGreenConfigInput = null +} + +""" +Added in 26.4.0. +Result payload returned after creating or updating a deployment policy. +Contains the full deployment_policy object reflecting the applied configuration. +""" +type UpdateDeploymentPolicyPayload { + deploymentPolicy: DeploymentPolicy! +} + """Added in 25.14.0""" input UpdateHuggingFaceRegistryInput { id: ID! diff --git a/src/ai/backend/manager/api/gql/deployment/__init__.py b/src/ai/backend/manager/api/gql/deployment/__init__.py index e45322e0a81..31e5fe6c464 100644 --- a/src/ai/backend/manager/api/gql/deployment/__init__.py +++ b/src/ai/backend/manager/api/gql/deployment/__init__.py @@ -40,6 +40,8 @@ routes, sync_replicas, update_auto_scaling_rule, + # Policy + update_deployment_policy, update_model_deployment, update_route_traffic_status, ) @@ -144,6 +146,9 @@ UpdateAutoScalingRulePayload, UpdateDeploymentInput, UpdateDeploymentPayload, + # Policy (mutation types) + UpdateDeploymentPolicyInputGQL, + UpdateDeploymentPolicyPayloadGQL, UpdateRouteTrafficStatusInputGQL, UpdateRouteTrafficStatusPayloadGQL, get_route_pagination_spec, @@ -204,6 +209,8 @@ "DeploymentStrategyTypeGQL", "RollingUpdateConfigInputGQL", "RollingUpdateStrategySpecGQL", + "UpdateDeploymentPolicyInputGQL", + "UpdateDeploymentPolicyPayloadGQL", # Replica Types "ActivenessStatus", "LivenessStatus", @@ -267,6 +274,8 @@ "deployments", "sync_replicas", "update_model_deployment", + # Resolvers - Policy + "update_deployment_policy", # Resolvers - Replica "replica", "replica_status_changed", diff --git a/src/ai/backend/manager/api/gql/deployment/resolver/__init__.py b/src/ai/backend/manager/api/gql/deployment/resolver/__init__.py index 61b2de6c007..56201b26565 100644 --- a/src/ai/backend/manager/api/gql/deployment/resolver/__init__.py +++ b/src/ai/backend/manager/api/gql/deployment/resolver/__init__.py @@ -21,6 +21,9 @@ sync_replicas, update_model_deployment, ) +from .policy import ( + update_deployment_policy, +) from .replica import ( replica, replica_status_changed, @@ -55,6 +58,8 @@ "delete_model_deployment", "sync_replicas", "deployment_status_changed", + # Policy + "update_deployment_policy", # Replica "replicas", "replica", diff --git a/src/ai/backend/manager/api/gql/deployment/resolver/policy.py b/src/ai/backend/manager/api/gql/deployment/resolver/policy.py new file mode 100644 index 00000000000..1eb9e9b8b99 --- /dev/null +++ b/src/ai/backend/manager/api/gql/deployment/resolver/policy.py @@ -0,0 +1,42 @@ +"""Deployment policy resolver functions.""" + +from __future__ import annotations + +import strawberry +from strawberry import Info + +from ai.backend.manager.api.gql.deployment.types.policy import ( + DeploymentPolicyGQL, + UpdateDeploymentPolicyInputGQL, + UpdateDeploymentPolicyPayloadGQL, +) +from ai.backend.manager.api.gql.types import StrawberryGQLContext +from ai.backend.manager.api.gql.utils import check_admin_only, dedent_strip +from ai.backend.manager.services.deployment.actions.deployment_policy.upsert_deployment_policy import ( + UpsertDeploymentPolicyAction, +) + + +@strawberry.mutation( # type: ignore[misc] + description=dedent_strip(""" + Added in 26.4.0. + Create or update the deployment policy for a given deployment (upsert semantics). + If the deployment already has a policy, it is replaced entirely with the new configuration. + """), +) +async def update_deployment_policy( + input: UpdateDeploymentPolicyInputGQL, + info: Info[StrawberryGQLContext], +) -> UpdateDeploymentPolicyPayloadGQL: + """Update (upsert) a deployment policy for a deployment.""" + check_admin_only() + upserter = input.to_upserter() + + processor = info.context.processors.deployment + result = await processor.upsert_deployment_policy.wait_for_complete( + UpsertDeploymentPolicyAction(upserter=upserter) + ) + + return UpdateDeploymentPolicyPayloadGQL( + deployment_policy=DeploymentPolicyGQL.from_data(result.data), + ) diff --git a/src/ai/backend/manager/api/gql/deployment/types/__init__.py b/src/ai/backend/manager/api/gql/deployment/types/__init__.py index 3fb76bc9ab6..4296bc0309a 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/__init__.py +++ b/src/ai/backend/manager/api/gql/deployment/types/__init__.py @@ -57,6 +57,8 @@ DeploymentStrategyTypeGQL, RollingUpdateConfigInputGQL, RollingUpdateStrategySpecGQL, + UpdateDeploymentPolicyInputGQL, + UpdateDeploymentPolicyPayloadGQL, ) from .replica import ( ActivenessStatus, @@ -166,6 +168,8 @@ "DeploymentStrategyTypeGQL", "RollingUpdateConfigInputGQL", "RollingUpdateStrategySpecGQL", + "UpdateDeploymentPolicyInputGQL", + "UpdateDeploymentPolicyPayloadGQL", # Replica "ActivenessStatus", "LivenessStatus", diff --git a/src/ai/backend/manager/api/gql/deployment/types/policy.py b/src/ai/backend/manager/api/gql/deployment/types/policy.py index 4823dd8ecd3..1ea597c00bf 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/policy.py +++ b/src/ai/backend/manager/api/gql/deployment/types/policy.py @@ -4,13 +4,17 @@ from datetime import datetime from typing import Self +from uuid import UUID import strawberry from strawberry import ID from strawberry.relay import Node, NodeID from ai.backend.common.data.model_deployment.types import DeploymentStrategy +from ai.backend.manager.api.gql.utils import dedent_strip from ai.backend.manager.data.deployment.types import DeploymentPolicyData +from ai.backend.manager.data.deployment.upserter import DeploymentPolicyUpserter +from ai.backend.manager.errors.api import InvalidAPIParameters from ai.backend.manager.errors.deployment import InvalidDeploymentStrategySpec from ai.backend.manager.models.deployment_policy import BlueGreenSpec, RollingUpdateSpec @@ -129,3 +133,63 @@ def to_spec(self) -> BlueGreenSpec: auto_promote=self.auto_promote, promote_delay_seconds=self.promote_delay_seconds, ) + + +# ========== Mutation Input/Payload Types ========== + + +@strawberry.input( + name="UpdateDeploymentPolicyInput", + description=dedent_strip(""" + Added in 26.4.0. + Input for creating or updating a deployment policy (upsert semantics). + Specify the target deployment_id and the desired strategy type. + Exactly one of rolling_update or blue_green must be provided, + matching the chosen strategy type. + If a policy already exists for the deployment, it is replaced entirely. + """), +) +class UpdateDeploymentPolicyInputGQL: + deployment_id: ID + strategy: DeploymentStrategyTypeGQL + rollback_on_failure: bool = False + rolling_update: RollingUpdateConfigInputGQL | None = None + blue_green: BlueGreenConfigInputGQL | None = None + + def to_upserter(self) -> DeploymentPolicyUpserter: + """Convert to DeploymentPolicyUpserter for the service layer.""" + + strategy = DeploymentStrategy(self.strategy.value) + strategy_spec: RollingUpdateSpec | BlueGreenSpec + match strategy: + case DeploymentStrategy.ROLLING: + if self.rolling_update is None: + raise InvalidAPIParameters( + "rolling_update config required for ROLLING strategy" + ) + strategy_spec = self.rolling_update.to_spec() + case DeploymentStrategy.BLUE_GREEN: + if self.blue_green is None: + raise InvalidAPIParameters("blue_green config required for BLUE_GREEN strategy") + strategy_spec = self.blue_green.to_spec() + case _: + raise InvalidAPIParameters(f"Unsupported deployment strategy: {strategy}") + + return DeploymentPolicyUpserter( + deployment_id=UUID(str(self.deployment_id)), + strategy=strategy, + strategy_spec=strategy_spec, + rollback_on_failure=self.rollback_on_failure, + ) + + +@strawberry.type( + name="UpdateDeploymentPolicyPayload", + description=dedent_strip(""" + Added in 26.4.0. + Result payload returned after creating or updating a deployment policy. + Contains the full deployment_policy object reflecting the applied configuration. + """), +) +class UpdateDeploymentPolicyPayloadGQL: + deployment_policy: DeploymentPolicyGQL diff --git a/src/ai/backend/manager/api/gql/schema.py b/src/ai/backend/manager/api/gql/schema.py index b07814bdf3c..559c0bbb705 100644 --- a/src/ai/backend/manager/api/gql/schema.py +++ b/src/ai/backend/manager/api/gql/schema.py @@ -76,6 +76,7 @@ routes, sync_replicas, update_auto_scaling_rule, + update_deployment_policy, update_model_deployment, update_route_traffic_status, ) @@ -430,6 +431,7 @@ class Mutation: delete_model_deployment = delete_model_deployment sync_replicas = sync_replicas add_model_revision = add_model_revision + update_deployment_policy = update_deployment_policy # Notification - Admin APIs admin_create_notification_channel = admin_create_notification_channel admin_update_notification_channel = admin_update_notification_channel diff --git a/tests/unit/manager/api/gql/deployment/test_update_deployment_policy.py b/tests/unit/manager/api/gql/deployment/test_update_deployment_policy.py new file mode 100644 index 00000000000..7aa61cb910f --- /dev/null +++ b/tests/unit/manager/api/gql/deployment/test_update_deployment_policy.py @@ -0,0 +1,276 @@ +"""Tests for update_deployment_policy GQL mutation.""" + +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from datetime import UTC, datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest +from aiohttp import web +from strawberry import ID + +from ai.backend.common.data.model_deployment.types import DeploymentStrategy +from ai.backend.manager.api.gql import utils as gql_utils +from ai.backend.manager.api.gql.deployment.resolver import policy as policy_resolver +from ai.backend.manager.api.gql.deployment.types.policy import ( + BlueGreenConfigInputGQL, + RollingUpdateConfigInputGQL, + UpdateDeploymentPolicyInputGQL, + UpdateDeploymentPolicyPayloadGQL, +) +from ai.backend.manager.data.deployment.types import DeploymentPolicyData +from ai.backend.manager.errors.api import InvalidAPIParameters +from ai.backend.manager.models.deployment_policy import BlueGreenSpec, RollingUpdateSpec +from ai.backend.manager.services.deployment.actions.deployment_policy.upsert_deployment_policy import ( + UpsertDeploymentPolicyActionResult, +) + +SAMPLE_DEPLOYMENT_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + + +# --- Test scenarios --- + + +@dataclass(frozen=True) +class StrategyConversionScenario: + """Input → expected upserter output for a valid strategy conversion.""" + + input: UpdateDeploymentPolicyInputGQL + expected_spec: RollingUpdateSpec | BlueGreenSpec + expected_rollback_on_failure: bool + + +@dataclass(frozen=True) +class MissingConfigScenario: + """Input that should raise due to missing strategy config.""" + + input: UpdateDeploymentPolicyInputGQL + expected_error_match: str + + +# --- Fixtures --- + + +@pytest.fixture +def mock_superadmin_user() -> MagicMock: + """Create mock superadmin user.""" + user = MagicMock() + user.is_superadmin = True + return user + + +@pytest.fixture +def mock_regular_user() -> MagicMock: + """Create mock regular (non-superadmin) user.""" + user = MagicMock() + user.is_superadmin = False + return user + + +@pytest.fixture +def mock_upsert_processor() -> AsyncMock: + """Create mock upsert_deployment_policy processor.""" + return AsyncMock() + + +@pytest.fixture +def mock_info(mock_upsert_processor: AsyncMock) -> MagicMock: + """Create mock strawberry.Info with deployment processors.""" + info = MagicMock() + info.context.processors.deployment.upsert_deployment_policy = mock_upsert_processor + return info + + +@pytest.fixture +def rolling_update_input() -> UpdateDeploymentPolicyInputGQL: + """Input for ROLLING strategy with custom surge/unavailable.""" + return UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.ROLLING, + rollback_on_failure=True, + rolling_update=RollingUpdateConfigInputGQL(max_surge=2, max_unavailable=1), + ) + + +def _make_policy_data( + *, + strategy: DeploymentStrategy = DeploymentStrategy.ROLLING, + strategy_spec: RollingUpdateSpec | BlueGreenSpec | None = None, +) -> DeploymentPolicyData: + """Create a DeploymentPolicyData for mock results.""" + if strategy_spec is None: + strategy_spec = RollingUpdateSpec(max_surge=1, max_unavailable=0) + return DeploymentPolicyData( + id=uuid.uuid4(), + endpoint=uuid.uuid4(), + strategy=strategy, + strategy_spec=strategy_spec, + rollback_on_failure=False, + created_at=datetime(2026, 1, 1, tzinfo=UTC), + updated_at=datetime(2026, 1, 1, tzinfo=UTC), + ) + + +# --- Input type conversion tests --- + + +class TestToUpserterConversion: + """Tests for UpdateDeploymentPolicyInputGQL.to_upserter().""" + + @pytest.mark.parametrize( + "scenario", + [ + pytest.param( + StrategyConversionScenario( + input=UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.ROLLING, + rollback_on_failure=True, + rolling_update=RollingUpdateConfigInputGQL(max_surge=2, max_unavailable=1), + ), + expected_spec=RollingUpdateSpec(max_surge=2, max_unavailable=1), + expected_rollback_on_failure=True, + ), + id="rolling", + ), + pytest.param( + StrategyConversionScenario( + input=UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.BLUE_GREEN, + blue_green=BlueGreenConfigInputGQL( + auto_promote=True, promote_delay_seconds=30 + ), + ), + expected_spec=BlueGreenSpec(auto_promote=True, promote_delay_seconds=30), + expected_rollback_on_failure=False, + ), + id="blue_green", + ), + ], + ) + def test_converts_gql_input_to_upserter(self, scenario: StrategyConversionScenario) -> None: + """Test that GQL input is correctly converted to DeploymentPolicyUpserter.""" + upserter = scenario.input.to_upserter() + + assert upserter.strategy == scenario.input.strategy + assert upserter.strategy_spec == scenario.expected_spec + assert upserter.rollback_on_failure is scenario.expected_rollback_on_failure + + @pytest.mark.parametrize( + "scenario", + [ + pytest.param( + MissingConfigScenario( + input=UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.ROLLING, + ), + expected_error_match="rolling_update", + ), + id="rolling", + ), + pytest.param( + MissingConfigScenario( + input=UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.BLUE_GREEN, + ), + expected_error_match="blue_green", + ), + id="blue_green", + ), + ], + ) + def test_raises_when_strategy_config_is_missing(self, scenario: MissingConfigScenario) -> None: + """Test that to_upserter() raises when matching strategy config is not provided.""" + with pytest.raises(InvalidAPIParameters, match=scenario.expected_error_match): + scenario.input.to_upserter() + + def test_converts_deployment_id_to_uuid(self) -> None: + """Test that string deployment_id is correctly parsed into UUID.""" + input_gql = UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.ROLLING, + rolling_update=RollingUpdateConfigInputGQL(), + ) + upserter = input_gql.to_upserter() + + assert str(upserter.deployment_id) == SAMPLE_DEPLOYMENT_ID + + +# --- Resolver tests --- + + +class TestAdminUpdateDeploymentPolicyResolver: + """Tests for update_deployment_policy resolver.""" + + async def test_delegates_upsert_action_to_processor( + self, + mock_superadmin_user: MagicMock, + mock_upsert_processor: AsyncMock, + mock_info: MagicMock, + rolling_update_input: UpdateDeploymentPolicyInputGQL, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test that resolver delegates to processor and returns payload.""" + # Given + policy_data = _make_policy_data( + strategy=DeploymentStrategy.ROLLING, + strategy_spec=RollingUpdateSpec(max_surge=2, max_unavailable=1), + ) + mock_upsert_processor.wait_for_complete.return_value = UpsertDeploymentPolicyActionResult( + data=policy_data, + created=True, + ) + + monkeypatch.setattr( + gql_utils, + "current_user", + lambda: mock_superadmin_user, + ) + + # When + resolver_fn = policy_resolver.update_deployment_policy.base_resolver + result = await resolver_fn(rolling_update_input, mock_info) + + # Then + mock_upsert_processor.wait_for_complete.assert_called_once() + call_args = mock_upsert_processor.wait_for_complete.call_args + action = call_args[0][0] + + assert str(action.upserter.deployment_id) == SAMPLE_DEPLOYMENT_ID + assert action.upserter.strategy == DeploymentStrategy.ROLLING + assert action.upserter.rollback_on_failure is True + + assert isinstance(result, UpdateDeploymentPolicyPayloadGQL) + + async def test_rejects_non_superadmin( + self, + mock_regular_user: MagicMock, + mock_upsert_processor: AsyncMock, + mock_info: MagicMock, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Test that non-superadmin user is rejected with HTTPForbidden.""" + # Given + monkeypatch.setattr( + gql_utils, + "current_user", + lambda: mock_regular_user, + ) + + input_data = UpdateDeploymentPolicyInputGQL( + deployment_id=ID(SAMPLE_DEPLOYMENT_ID), + strategy=DeploymentStrategy.ROLLING, + rolling_update=RollingUpdateConfigInputGQL(), + ) + + # When / Then + resolver_fn = policy_resolver.update_deployment_policy.base_resolver + with pytest.raises(web.HTTPForbidden): + await resolver_fn(input_data, mock_info) + + mock_upsert_processor.wait_for_complete.assert_not_called()