From 030b77ecb5c4cce47a71020b8ac99cf319f3230f Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 19:34:21 +0900 Subject: [PATCH 01/18] fix(BA-5983): make ModelConfig GQL input fields optional MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The addModelRevision mutation rejected requests that omitted ModelConfigInput.name (and similarly model_path, ModelServiceConfigInput.port, ModelHealthCheckInput.path), even though those fields are routinely supplied by other layers in the revision merge chain — the runtime variant's default_model_definition, the model vfolder's model-definition.yaml, a revision preset, or the model_mount_destination default. Bind the input types (ModelConfigInputGQL, ModelDefinitionInputGQL, ModelServiceConfigInputGQL, ModelHealthCheckInputGQL) to the *Draft DTO variants from common/config so every field that the merge chain can supply is nullable at the GQL boundary. Required-field validation stays where it belongs — in {ModelConfig,ModelHealthCheck,ModelServiceConfig}Draft .to_resolved() after the full merge. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../api/gql/deployment/types/revision.py | 91 +++++++++++++------ 1 file changed, 62 insertions(+), 29 deletions(-) diff --git a/src/ai/backend/manager/api/gql/deployment/types/revision.py b/src/ai/backend/manager/api/gql/deployment/types/revision.py index a08c6114708..f4b90b1aa19 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/revision.py +++ b/src/ai/backend/manager/api/gql/deployment/types/revision.py @@ -14,19 +14,19 @@ from strawberry.scalars import JSON from ai.backend.common.config import ( - ModelConfig as ModelConfigDTO, + ModelConfigDraft as ModelConfigDraftDTO, ) from ai.backend.common.config import ( - ModelDefinition as ModelDefinitionDTO, + ModelDefinitionDraft as ModelDefinitionDraftDTO, ) from ai.backend.common.config import ( - ModelHealthCheck as ModelHealthCheckDTO, + ModelHealthCheckDraft as ModelHealthCheckDraftDTO, ) from ai.backend.common.config import ( ModelMetadata as ModelMetadataDTO, ) from ai.backend.common.config import ( - ModelServiceConfig as ModelServiceConfigDTO, + ModelServiceConfigDraft as ModelServiceConfigDraftDTO, ) from ai.backend.common.config import ( PreStartAction as PreStartActionDTO, @@ -808,22 +808,30 @@ class PreStartActionInputGQL(PydanticInputMixin[PreStartActionDTO]): ), name="ModelHealthCheckInput", ) -class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckDTO]): - interval: float = gql_field( - description="Interval in seconds between health checks.", default=10.0 +class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckDraftDTO]): + interval: float | None = gql_field( + description="Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted.", + default=None, ) - path: str = gql_field(description="Path to check for health status.") - max_retries: int = gql_field( - description="Maximum number of retries for health check.", default=10 + path: str | None = gql_field( + description="Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted.", + default=None, ) - max_wait_time: float = gql_field( - description="Maximum time in seconds to wait for a health check response.", default=15.0 + max_retries: int | None = gql_field( + description="Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted.", + default=None, ) - expected_status_code: int = gql_field( - description="Expected HTTP status code for a healthy response.", default=200 + max_wait_time: float | None = gql_field( + description="Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted.", + default=None, ) - initial_delay: float = gql_field( - description="Initial delay in seconds before the first health check.", default=60.0 + expected_status_code: int | None = gql_field( + description="Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted.", + default=None, + ) + initial_delay: float | None = gql_field( + description="Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted.", + default=None, ) @@ -834,19 +842,25 @@ class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckDTO]): ), name="ModelServiceConfigInput", ) -class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigDTO]): - pre_start_actions: list[PreStartActionInputGQL] = gql_field( +class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigDraftDTO]): + pre_start_actions: list[PreStartActionInputGQL] | None = gql_field( description="List of pre-start actions to execute before starting the model service.", - default=strawberry.UNSET, + default=None, ) start_command: list[str] | None = gql_field( description="Command to start the model service.", default=None ) - shell: str = gql_field( - description="Shell configured for the model service.", - default="/bin/bash", + shell: str | None = gql_field( + description="Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted.", + default=None, + ) + port: int | None = gql_field( + description=( + "Port number for the model service. May be supplied by the runtime variant" + " default model definition or the model vfolder's model-definition.yaml when omitted." + ), + default=None, ) - port: int = gql_field(description="Port number for the model service. Must be greater than 1.") health_check: ModelHealthCheckInputGQL | None = gql_field( description="Health check configuration for the model service.", default=None ) @@ -888,9 +902,23 @@ class ModelMetadataInputGQL(PydanticInputMixin[ModelMetadataDTO]): ), name="ModelConfigInput", ) -class ModelConfigInputGQL(PydanticInputMixin[ModelConfigDTO]): - name: str = gql_field(description="Name of the model.") - model_path: str = gql_field(description="Path to the model file.") +class ModelConfigInputGQL(PydanticInputMixin[ModelConfigDraftDTO]): + name: str | None = gql_field( + description=( + "Name of the model. May be supplied by the runtime variant default model" + " definition, a revision preset, or the model vfolder's model-definition.yaml" + " when omitted; the merge chain produces the final value." + ), + default=None, + ) + model_path: str | None = gql_field( + description=( + "Path to the model file. Defaults to the model mount destination when not" + " overridden by the runtime variant, a revision preset, the vfolder's" + " model-definition.yaml, or this request." + ), + default=None, + ) service: ModelServiceConfigInputGQL | None = gql_field( description="Configuration for the model service.", default=None ) @@ -906,9 +934,14 @@ class ModelConfigInputGQL(PydanticInputMixin[ModelConfigDTO]): ), name="ModelDefinitionInput", ) -class ModelDefinitionInputGQL(PydanticInputMixin[ModelDefinitionDTO]): - models: list[ModelConfigInputGQL] = gql_field( - description="List of models in the model definition." +class ModelDefinitionInputGQL(PydanticInputMixin[ModelDefinitionDraftDTO]): + models: list[ModelConfigInputGQL] | None = gql_field( + description=( + "List of model entries in the model definition. Omit to inherit the entire" + " definition from the runtime variant default, the revision preset, or the" + " model vfolder's model-definition.yaml; provide entries to override per-index." + ), + default=None, ) From a1fd8c2c91de1902f80d315eaf6179eef1297d62 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 19:34:52 +0900 Subject: [PATCH 02/18] chore: add news fragment for PR #11531 Co-Authored-By: Claude Opus 4.7 (1M context) --- changes/11531.fix.md | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/11531.fix.md diff --git a/changes/11531.fix.md b/changes/11531.fix.md new file mode 100644 index 00000000000..ad8cd7320d4 --- /dev/null +++ b/changes/11531.fix.md @@ -0,0 +1 @@ +Make ModelConfig / ModelDefinition / ModelServiceConfig / ModelHealthCheck GraphQL input fields optional so addModelRevision can inherit values from the runtime variant, model-definition.yaml, or revision preset. From b84c950f8fe073f1623ffb12c3b1056beb11ddda Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 10:38:09 +0000 Subject: [PATCH 03/18] chore: update api schema dump Co-authored-by: octodog --- .../graphql-reference/supergraph.graphql | 68 ++++++++++++------- .../graphql-reference/v2-schema.graphql | 68 ++++++++++++------- 2 files changed, 90 insertions(+), 46 deletions(-) diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 6e75b72e73b..248f1a4ef67 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -9234,11 +9234,15 @@ Added in 26.4.0. Configuration for a single model within a model definition. input ModelConfigInput @join__type(graph: STRAWBERRY) { - """Name of the model.""" - name: String! + """ + Name of the model. May be supplied by the runtime variant default model definition, a revision preset, or the model vfolder's model-definition.yaml when omitted; the merge chain produces the final value. + """ + name: String = null - """Path to the model file.""" - modelPath: String! + """ + Path to the model file. Defaults to the model mount destination when not overridden by the runtime variant, a revision preset, the vfolder's model-definition.yaml, or this request. + """ + modelPath: String = null """Configuration for the model service.""" service: ModelServiceConfigInput = null @@ -9263,8 +9267,10 @@ Added in 26.4.0. Model definition containing a list of model configurations. input ModelDefinitionInput @join__type(graph: STRAWBERRY) { - """List of models in the model definition.""" - models: [ModelConfigInput!]! + """ + List of model entries in the model definition. Omit to inherit the entire definition from the runtime variant default, the revision preset, or the model vfolder's model-definition.yaml; provide entries to override per-index. + """ + models: [ModelConfigInput!] = null } """ @@ -9430,23 +9436,35 @@ type ModelHealthCheck input ModelHealthCheckInput @join__type(graph: STRAWBERRY) { - """Interval in seconds between health checks.""" - interval: Float! = 10 + """ + Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted. + """ + interval: Float = null - """Path to check for health status.""" - path: String! + """ + Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. + """ + path: String = null - """Maximum number of retries for health check.""" - maxRetries: Int! = 10 + """ + Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted. + """ + maxRetries: Int = null - """Maximum time in seconds to wait for a health check response.""" - maxWaitTime: Float! = 15 + """ + Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted. + """ + maxWaitTime: Float = null - """Expected HTTP status code for a healthy response.""" - expectedStatusCode: Int! = 200 + """ + Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted. + """ + expectedStatusCode: Int = null - """Initial delay in seconds before the first health check.""" - initialDelay: Float! = 60 + """ + Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted. + """ + initialDelay: Float = null } """Added in 26.4.2. Metadata describing a model entry.""" @@ -9803,16 +9821,20 @@ input ModelServiceConfigInput """ List of pre-start actions to execute before starting the model service. """ - preStartActions: [PreStartActionInput!]! + preStartActions: [PreStartActionInput!] = null """Command to start the model service.""" startCommand: [String!] = null - """Shell configured for the model service.""" - shell: String! = "/bin/bash" + """ + Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted. + """ + shell: String = null - """Port number for the model service. Must be greater than 1.""" - port: Int! + """ + Port number for the model service. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. + """ + port: Int = null """Health check configuration for the model service.""" healthCheck: ModelHealthCheckInput = null diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index fc690ad83fb..0d9a0633b50 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -6044,11 +6044,15 @@ type ModelConfig { Added in 26.4.0. Configuration for a single model within a model definition. """ input ModelConfigInput { - """Name of the model.""" - name: String! + """ + Name of the model. May be supplied by the runtime variant default model definition, a revision preset, or the model vfolder's model-definition.yaml when omitted; the merge chain produces the final value. + """ + name: String = null - """Path to the model file.""" - modelPath: String! + """ + Path to the model file. Defaults to the model mount destination when not overridden by the runtime variant, a revision preset, the vfolder's model-definition.yaml, or this request. + """ + modelPath: String = null """Configuration for the model service.""" service: ModelServiceConfigInput = null @@ -6069,8 +6073,10 @@ type ModelDefinition { Added in 26.4.0. Model definition containing a list of model configurations. """ input ModelDefinitionInput { - """List of models in the model definition.""" - models: [ModelConfigInput!]! + """ + List of model entries in the model definition. Omit to inherit the entire definition from the runtime variant default, the revision preset, or the model vfolder's model-definition.yaml; provide entries to override per-index. + """ + models: [ModelConfigInput!] = null } """ @@ -6217,23 +6223,35 @@ type ModelHealthCheck { """Added in 26.4.0. Health check configuration for a model service.""" input ModelHealthCheckInput { - """Interval in seconds between health checks.""" - interval: Float! = 10 + """ + Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted. + """ + interval: Float = null - """Path to check for health status.""" - path: String! + """ + Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. + """ + path: String = null - """Maximum number of retries for health check.""" - maxRetries: Int! = 10 + """ + Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted. + """ + maxRetries: Int = null - """Maximum time in seconds to wait for a health check response.""" - maxWaitTime: Float! = 15 + """ + Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted. + """ + maxWaitTime: Float = null - """Expected HTTP status code for a healthy response.""" - expectedStatusCode: Int! = 200 + """ + Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted. + """ + expectedStatusCode: Int = null - """Initial delay in seconds before the first health check.""" - initialDelay: Float! = 60 + """ + Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted. + """ + initialDelay: Float = null } """Added in 26.4.2. Metadata describing a model entry.""" @@ -6554,16 +6572,20 @@ input ModelServiceConfigInput { """ List of pre-start actions to execute before starting the model service. """ - preStartActions: [PreStartActionInput!]! + preStartActions: [PreStartActionInput!] = null """Command to start the model service.""" startCommand: [String!] = null - """Shell configured for the model service.""" - shell: String! = "/bin/bash" + """ + Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted. + """ + shell: String = null - """Port number for the model service. Must be greater than 1.""" - port: Int! + """ + Port number for the model service. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. + """ + port: Int = null """Health check configuration for the model service.""" healthCheck: ModelHealthCheckInput = null From 91472f48000b44f7dc452695a1f3967d3b24da95 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 19:45:13 +0900 Subject: [PATCH 04/18] refactor(BA-5983): introduce v2 ModelConfig/ModelDefinition Input DTOs Re-route the GraphQL/REST v2 input boundary through dedicated ``Model{HealthCheck,Metadata,ServiceConfig,Config,Definition}Input`` DTOs in ``common/dto/manager/v2/deployment/request.py`` instead of binding directly to the merge-chain ``*Draft`` domain models from ``common/config``. The ``*Draft`` types remain the internal merge-chain representation; the boundary types are owned by the v2 DTO package. - New v2 Input DTOs mirror the structure of the corresponding ``*Draft`` types (every field optional) so the request layer stays permissive and lower-priority sources (runtime variant default, revision preset, vfolder ``model-definition.yaml``) can supply whatever the request omits. - ``CreateRevisionInputDTO``/``AddRevisionGQLInputDTO``/v2 ``RevisionInput`` now type ``model_definition`` as ``ModelDefinitionInput`` instead of leaking ``ModelDefinitionDraft`` from ``common/config`` into the v2 DTO package. - Add ``to_model_definition_draft`` converter alongside the DTOs and call it at the GQL adapter boundary (``manager/api/adapters/ deployment/adapter.py``) before constructing ``ModelRevisionCreator``; the legacy REST path is unchanged (still uses the deprecated ``common/dto/manager/deployment`` ``RevisionInput``). - Re-bind ``Model*InputGQL`` types in ``manager/api/gql/deployment/types/revision.py`` to the new v2 Input DTOs and drop the temporary ``*DraftDTO`` aliases introduced in the previous commit. Required-field validation still happens in ``ModelConfigDraft.to_resolved`` / ``ModelHealthCheckDraft.to_resolved`` / ``ModelServiceConfigDraft.to_resolved`` after the merge, which is called in ``DeploymentController.add_revision``. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../dto/manager/v2/deployment/request.py | 136 +++++++++++++++++- .../api/adapters/deployment/adapter.py | 5 +- .../api/gql/deployment/types/revision.py | 40 +++--- 3 files changed, 155 insertions(+), 26 deletions(-) diff --git a/src/ai/backend/common/dto/manager/v2/deployment/request.py b/src/ai/backend/common/dto/manager/v2/deployment/request.py index e84677eb734..2ffbff9875f 100644 --- a/src/ai/backend/common/dto/manager/v2/deployment/request.py +++ b/src/ai/backend/common/dto/manager/v2/deployment/request.py @@ -13,7 +13,10 @@ from pydantic import Field, field_validator from ai.backend.common.api_handlers import SENTINEL, BaseRequestModel, Sentinel -from ai.backend.common.config import ModelDefinitionDraft +from ai.backend.common.config import ( + ModelDefinitionDraft, + PreStartAction, +) from ai.backend.common.data.model_deployment.types import ( DeploymentStrategy, RouteHealthStatus, @@ -80,10 +83,15 @@ "EnvironmentVariablesInput", "ExtraVFolderMountInput", "ImageInput", + "ModelConfigInput", + "ModelDefinitionInput", "ModelDeploymentMetadataInput", "ModelDeploymentNetworkAccessInput", + "ModelHealthCheckInput", + "ModelMetadataInput", "ModelMountConfigInput", "ModelRuntimeConfigInput", + "ModelServiceConfigInput", "ReplicaFilter", "ReplicaOrder", "ReplicaStatusFilter", @@ -116,6 +124,126 @@ ) +class ModelHealthCheckInput(BaseRequestModel): + """Input for the model service health check. + + Every field is optional. Lower-priority sources in the revision merge + chain (runtime variant default, revision preset, vfolder + ``model-definition.yaml``) supply values for any field the request + omits; the strict :class:`~ai.backend.common.config.ModelHealthCheck` + is materialized only at the persistence boundary by + :meth:`~ai.backend.common.config.ModelHealthCheckDraft.to_resolved`. + """ + + interval: float | None = Field( + default=None, description="Interval in seconds between health checks." + ) + path: str | None = Field(default=None, description="Path to check for health status.") + max_retries: int | None = Field( + default=None, description="Maximum number of retries for health check." + ) + max_wait_time: float | None = Field( + default=None, + description="Maximum time in seconds to wait for a health check response.", + ) + expected_status_code: int | None = Field( + default=None, description="Expected HTTP status code for a healthy response." + ) + initial_delay: float | None = Field( + default=None, description="Initial delay in seconds before the first health check." + ) + + +class ModelMetadataInput(BaseRequestModel): + """Input for model metadata. Every field is optional.""" + + author: str | None = Field(default=None, description="Author of the model.") + title: str | None = Field(default=None, description="Title of the model.") + version: str | None = Field(default=None, description="Version of the model.") + created: str | None = Field(default=None, description="Creation date of the model.") + last_modified: str | None = Field(default=None, description="Last modified date of the model.") + description: str | None = Field(default=None, description="Description of the model.") + task: str | None = Field(default=None, description="Task type of the model.") + category: str | None = Field(default=None, description="Category of the model.") + architecture: str | None = Field(default=None, description="Architecture of the model.") + framework: list[str] | None = Field(default=None, description="Frameworks used by the model.") + label: list[str] | None = Field(default=None, description="Labels for the model.") + license: str | None = Field(default=None, description="License of the model.") + min_resource: dict[str, Any] | None = Field( + default=None, description="Minimum resource requirements for the model." + ) + + +class ModelServiceConfigInput(BaseRequestModel): + """Input for the model service configuration. Every field is optional. + + Lower-priority sources in the revision merge chain supply any field + the request omits. + """ + + pre_start_actions: list[PreStartAction] | None = Field( + default=None, + description="List of pre-start actions to execute before starting the model service.", + ) + start_command: list[str] | None = Field( + default=None, description="Command to start the model service." + ) + shell: str | None = Field(default=None, description="Shell configured for the model service.") + port: int | None = Field(default=None, description="Port number for the model service.") + health_check: ModelHealthCheckInput | None = Field( + default=None, description="Health check configuration for the model service." + ) + + +class ModelConfigInput(BaseRequestModel): + """Input for a single model entry within a model definition. + + Every field is optional so the revision merge chain (runtime variant + default, revision preset, vfolder ``model-definition.yaml``, and the + ``model_mount_destination`` default) can fill in whatever the request + leaves unset. + """ + + name: str | None = Field(default=None, description="Name of the model.") + model_path: str | None = Field(default=None, description="Path to the model file.") + service: ModelServiceConfigInput | None = Field( + default=None, description="Configuration for the model service." + ) + metadata: ModelMetadataInput | None = Field( + default=None, description="Metadata about the model." + ) + + +class ModelDefinitionInput(BaseRequestModel): + """Input for the full model definition. + + A request may omit ``models`` entirely to inherit the definition from + a lower-priority source, or supply an entry to override per-index. + """ + + models: list[ModelConfigInput] | None = Field( + default=None, description="List of model entries in the model definition." + ) + + +def to_model_definition_draft( + input: ModelDefinitionInput | None, +) -> ModelDefinitionDraft | None: + """Project a v2 :class:`ModelDefinitionInput` onto the internal merge-chain + :class:`~ai.backend.common.config.ModelDefinitionDraft`. + + The two types are structurally isomorphic by design (both all-optional + super-sets of the strict :class:`~ai.backend.common.config.ModelDefinition`), + so a dump-and-validate round-trip is sufficient. Required-field + enforcement still happens in + :meth:`~ai.backend.common.config.ModelDefinitionDraft.to_resolved` after + the full merge. + """ + if input is None: + return None + return ModelDefinitionDraft.model_validate(input.model_dump()) + + class ClusterConfigInput(BaseRequestModel): """Cluster configuration input for a revision.""" @@ -233,7 +361,7 @@ class CreateRevisionInputDTO(BaseRequestModel): image: ImageInput = Field(description="Container image") model_runtime_config: ModelRuntimeConfigInput = Field(description="Runtime configuration") model_mount_config: ModelMountConfigInput = Field(description="Model mount configuration") - model_definition: ModelDefinitionDraft | None = Field( + model_definition: ModelDefinitionInput | None = Field( default=None, description="Model definition to override the default values generated by the server", ) @@ -269,7 +397,7 @@ class AddRevisionGQLInputDTO(BaseRequestModel): image: ImageInput = Field(description="Container image") model_runtime_config: ModelRuntimeConfigInput = Field(description="Runtime configuration") model_mount_config: ModelMountConfigInput = Field(description="Model mount configuration") - model_definition: ModelDefinitionDraft | None = Field( + model_definition: ModelDefinitionInput | None = Field( default=None, description="Model definition to override the default values generated by the server", ) @@ -396,7 +524,7 @@ class RevisionInput(BaseRequestModel): default="/models", description="Mount destination for model vfolder" ) model_definition_path: str = Field(description="Path to model definition file") - model_definition: ModelDefinitionDraft | None = Field( + model_definition: ModelDefinitionInput | None = Field( default=None, description="Model definition to override the default values generated by the server", ) diff --git a/src/ai/backend/manager/api/adapters/deployment/adapter.py b/src/ai/backend/manager/api/adapters/deployment/adapter.py index 865f09af4f8..a6cdef4e500 100644 --- a/src/ai/backend/manager/api/adapters/deployment/adapter.py +++ b/src/ai/backend/manager/api/adapters/deployment/adapter.py @@ -62,6 +62,7 @@ SyncReplicaInput, UpdateDeploymentInput, UpsertDeploymentPolicyInput, + to_model_definition_draft, ) from ai.backend.common.dto.manager.v2.deployment.response import ( AccessTokenNode, @@ -505,7 +506,7 @@ async def create( else None, ), mounts=mounts_creator, - model_definition=initial_revision.model_definition, + model_definition=to_model_definition_draft(initial_revision.model_definition), revision_preset_id=initial_revision.revision_preset_id, execution=ExecutionSpec( runtime_variant_id=initial_revision.model_runtime_config.runtime_variant_id, @@ -1110,7 +1111,7 @@ async def add_revision( else None, inference_runtime_config=input.model_runtime_config.inference_runtime_config, ), - model_definition=input.model_definition, + model_definition=to_model_definition_draft(input.model_definition), revision_preset_id=input.revision_preset_id, ) action_result = await self._processors.deployment.add_model_revision.wait_for_complete( diff --git a/src/ai/backend/manager/api/gql/deployment/types/revision.py b/src/ai/backend/manager/api/gql/deployment/types/revision.py index f4b90b1aa19..33811d3e8a3 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/revision.py +++ b/src/ai/backend/manager/api/gql/deployment/types/revision.py @@ -13,21 +13,6 @@ from strawberry.relay import Connection, Edge, NodeID from strawberry.scalars import JSON -from ai.backend.common.config import ( - ModelConfigDraft as ModelConfigDraftDTO, -) -from ai.backend.common.config import ( - ModelDefinitionDraft as ModelDefinitionDraftDTO, -) -from ai.backend.common.config import ( - ModelHealthCheckDraft as ModelHealthCheckDraftDTO, -) -from ai.backend.common.config import ( - ModelMetadata as ModelMetadataDTO, -) -from ai.backend.common.config import ( - ModelServiceConfigDraft as ModelServiceConfigDraftDTO, -) from ai.backend.common.config import ( PreStartAction as PreStartActionDTO, ) @@ -56,12 +41,27 @@ from ai.backend.common.dto.manager.v2.deployment.request import ( ImageInput as ImageInputDTO, ) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelConfigInput as ModelConfigInputDTO, +) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelDefinitionInput as ModelDefinitionInputDTO, +) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelHealthCheckInput as ModelHealthCheckInputDTO, +) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelMetadataInput as ModelMetadataInputDTO, +) from ai.backend.common.dto.manager.v2.deployment.request import ( ModelMountConfigInput as ModelMountConfigInputDTO, ) from ai.backend.common.dto.manager.v2.deployment.request import ( ModelRuntimeConfigInput as ModelRuntimeConfigInputDTO, ) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelServiceConfigInput as ModelServiceConfigInputDTO, +) from ai.backend.common.dto.manager.v2.deployment.request import ( ResourceConfigInput as ResourceConfigInputDTO, ) @@ -808,7 +808,7 @@ class PreStartActionInputGQL(PydanticInputMixin[PreStartActionDTO]): ), name="ModelHealthCheckInput", ) -class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckDraftDTO]): +class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckInputDTO]): interval: float | None = gql_field( description="Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted.", default=None, @@ -842,7 +842,7 @@ class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckDraftDTO]): ), name="ModelServiceConfigInput", ) -class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigDraftDTO]): +class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigInputDTO]): pre_start_actions: list[PreStartActionInputGQL] | None = gql_field( description="List of pre-start actions to execute before starting the model service.", default=None, @@ -873,7 +873,7 @@ class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigDraftDTO]) ), name="ModelMetadataInput", ) -class ModelMetadataInputGQL(PydanticInputMixin[ModelMetadataDTO]): +class ModelMetadataInputGQL(PydanticInputMixin[ModelMetadataInputDTO]): author: str | None = gql_field(description="Author of the model.", default=None) title: str | None = gql_field(description="Title of the model.", default=None) version: str | None = gql_field(description="Version of the model.", default=None) @@ -902,7 +902,7 @@ class ModelMetadataInputGQL(PydanticInputMixin[ModelMetadataDTO]): ), name="ModelConfigInput", ) -class ModelConfigInputGQL(PydanticInputMixin[ModelConfigDraftDTO]): +class ModelConfigInputGQL(PydanticInputMixin[ModelConfigInputDTO]): name: str | None = gql_field( description=( "Name of the model. May be supplied by the runtime variant default model" @@ -934,7 +934,7 @@ class ModelConfigInputGQL(PydanticInputMixin[ModelConfigDraftDTO]): ), name="ModelDefinitionInput", ) -class ModelDefinitionInputGQL(PydanticInputMixin[ModelDefinitionDraftDTO]): +class ModelDefinitionInputGQL(PydanticInputMixin[ModelDefinitionInputDTO]): models: list[ModelConfigInputGQL] | None = gql_field( description=( "List of model entries in the model definition. Omit to inherit the entire" From 47012b533fead961bb307d75f0146f8e6e31ae2f Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 20:02:16 +0900 Subject: [PATCH 05/18] chore(BA-5983): trim repetitive descriptions on Model* input types Drop verbose per-field descriptions on the new v2 ModelConfig / ModelDefinition / ModelServiceConfig / ModelHealthCheck Input DTOs and their GQL counterparts where the field name is already self-describing or the description merely repeated the merge-chain note. Keep one class-level note on ModelDefinitionInput as the single canonical explanation. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../dto/manager/v2/deployment/request.py | 129 +++++------------- .../api/gql/deployment/types/revision.py | 56 ++------ 2 files changed, 46 insertions(+), 139 deletions(-) diff --git a/src/ai/backend/common/dto/manager/v2/deployment/request.py b/src/ai/backend/common/dto/manager/v2/deployment/request.py index 2ffbff9875f..a6fad82fdef 100644 --- a/src/ai/backend/common/dto/manager/v2/deployment/request.py +++ b/src/ai/backend/common/dto/manager/v2/deployment/request.py @@ -125,120 +125,61 @@ class ModelHealthCheckInput(BaseRequestModel): - """Input for the model service health check. - - Every field is optional. Lower-priority sources in the revision merge - chain (runtime variant default, revision preset, vfolder - ``model-definition.yaml``) supply values for any field the request - omits; the strict :class:`~ai.backend.common.config.ModelHealthCheck` - is materialized only at the persistence boundary by - :meth:`~ai.backend.common.config.ModelHealthCheckDraft.to_resolved`. - """ - - interval: float | None = Field( - default=None, description="Interval in seconds between health checks." - ) - path: str | None = Field(default=None, description="Path to check for health status.") - max_retries: int | None = Field( - default=None, description="Maximum number of retries for health check." - ) - max_wait_time: float | None = Field( - default=None, - description="Maximum time in seconds to wait for a health check response.", - ) - expected_status_code: int | None = Field( - default=None, description="Expected HTTP status code for a healthy response." - ) - initial_delay: float | None = Field( - default=None, description="Initial delay in seconds before the first health check." - ) + interval: float | None = None + path: str | None = None + max_retries: int | None = None + max_wait_time: float | None = None + expected_status_code: int | None = None + initial_delay: float | None = None class ModelMetadataInput(BaseRequestModel): - """Input for model metadata. Every field is optional.""" - - author: str | None = Field(default=None, description="Author of the model.") - title: str | None = Field(default=None, description="Title of the model.") - version: str | None = Field(default=None, description="Version of the model.") - created: str | None = Field(default=None, description="Creation date of the model.") - last_modified: str | None = Field(default=None, description="Last modified date of the model.") - description: str | None = Field(default=None, description="Description of the model.") - task: str | None = Field(default=None, description="Task type of the model.") - category: str | None = Field(default=None, description="Category of the model.") - architecture: str | None = Field(default=None, description="Architecture of the model.") - framework: list[str] | None = Field(default=None, description="Frameworks used by the model.") - label: list[str] | None = Field(default=None, description="Labels for the model.") - license: str | None = Field(default=None, description="License of the model.") - min_resource: dict[str, Any] | None = Field( - default=None, description="Minimum resource requirements for the model." - ) + author: str | None = None + title: str | None = None + version: str | None = None + created: str | None = None + last_modified: str | None = None + description: str | None = None + task: str | None = None + category: str | None = None + architecture: str | None = None + framework: list[str] | None = None + label: list[str] | None = None + license: str | None = None + min_resource: dict[str, Any] | None = None class ModelServiceConfigInput(BaseRequestModel): - """Input for the model service configuration. Every field is optional. - - Lower-priority sources in the revision merge chain supply any field - the request omits. - """ - - pre_start_actions: list[PreStartAction] | None = Field( - default=None, - description="List of pre-start actions to execute before starting the model service.", - ) - start_command: list[str] | None = Field( - default=None, description="Command to start the model service." - ) - shell: str | None = Field(default=None, description="Shell configured for the model service.") - port: int | None = Field(default=None, description="Port number for the model service.") - health_check: ModelHealthCheckInput | None = Field( - default=None, description="Health check configuration for the model service." - ) + pre_start_actions: list[PreStartAction] | None = None + start_command: list[str] | None = None + shell: str | None = None + port: int | None = None + health_check: ModelHealthCheckInput | None = None class ModelConfigInput(BaseRequestModel): - """Input for a single model entry within a model definition. - - Every field is optional so the revision merge chain (runtime variant - default, revision preset, vfolder ``model-definition.yaml``, and the - ``model_mount_destination`` default) can fill in whatever the request - leaves unset. - """ - - name: str | None = Field(default=None, description="Name of the model.") - model_path: str | None = Field(default=None, description="Path to the model file.") - service: ModelServiceConfigInput | None = Field( - default=None, description="Configuration for the model service." - ) - metadata: ModelMetadataInput | None = Field( - default=None, description="Metadata about the model." - ) + name: str | None = None + model_path: str | None = None + service: ModelServiceConfigInput | None = None + metadata: ModelMetadataInput | None = None class ModelDefinitionInput(BaseRequestModel): - """Input for the full model definition. + """All-optional v2 input mirror of :class:`ModelDefinitionDraft`. - A request may omit ``models`` entirely to inherit the definition from - a lower-priority source, or supply an entry to override per-index. + Fields a request omits are filled by lower-priority sources in the + revision merge chain (runtime variant baseline, revision preset, + vfolder ``model-definition.yaml``, ``model_mount_destination`` + default). Required-field enforcement happens later in + ``ModelDefinitionDraft.to_resolved`` after the merge. """ - models: list[ModelConfigInput] | None = Field( - default=None, description="List of model entries in the model definition." - ) + models: list[ModelConfigInput] | None = None def to_model_definition_draft( input: ModelDefinitionInput | None, ) -> ModelDefinitionDraft | None: - """Project a v2 :class:`ModelDefinitionInput` onto the internal merge-chain - :class:`~ai.backend.common.config.ModelDefinitionDraft`. - - The two types are structurally isomorphic by design (both all-optional - super-sets of the strict :class:`~ai.backend.common.config.ModelDefinition`), - so a dump-and-validate round-trip is sufficient. Required-field - enforcement still happens in - :meth:`~ai.backend.common.config.ModelDefinitionDraft.to_resolved` after - the full merge. - """ if input is None: return None return ModelDefinitionDraft.model_validate(input.model_dump()) diff --git a/src/ai/backend/manager/api/gql/deployment/types/revision.py b/src/ai/backend/manager/api/gql/deployment/types/revision.py index 33811d3e8a3..11e309995fe 100644 --- a/src/ai/backend/manager/api/gql/deployment/types/revision.py +++ b/src/ai/backend/manager/api/gql/deployment/types/revision.py @@ -810,28 +810,20 @@ class PreStartActionInputGQL(PydanticInputMixin[PreStartActionDTO]): ) class ModelHealthCheckInputGQL(PydanticInputMixin[ModelHealthCheckInputDTO]): interval: float | None = gql_field( - description="Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted.", - default=None, - ) - path: str | None = gql_field( - description="Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted.", - default=None, + description="Interval in seconds between health checks.", default=None ) + path: str | None = gql_field(description="Path to check for health status.", default=None) max_retries: int | None = gql_field( - description="Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted.", - default=None, + description="Maximum number of retries for health check.", default=None ) max_wait_time: float | None = gql_field( - description="Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted.", - default=None, + description="Maximum time in seconds to wait for a health check response.", default=None ) expected_status_code: int | None = gql_field( - description="Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted.", - default=None, + description="Expected HTTP status code for a healthy response.", default=None ) initial_delay: float | None = gql_field( - description="Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted.", - default=None, + description="Initial delay in seconds before the first health check.", default=None ) @@ -851,16 +843,9 @@ class ModelServiceConfigInputGQL(PydanticInputMixin[ModelServiceConfigInputDTO]) description="Command to start the model service.", default=None ) shell: str | None = gql_field( - description="Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted.", - default=None, - ) - port: int | None = gql_field( - description=( - "Port number for the model service. May be supplied by the runtime variant" - " default model definition or the model vfolder's model-definition.yaml when omitted." - ), - default=None, + description="Shell configured for the model service.", default=None ) + port: int | None = gql_field(description="Port number for the model service.", default=None) health_check: ModelHealthCheckInputGQL | None = gql_field( description="Health check configuration for the model service.", default=None ) @@ -903,22 +888,8 @@ class ModelMetadataInputGQL(PydanticInputMixin[ModelMetadataInputDTO]): name="ModelConfigInput", ) class ModelConfigInputGQL(PydanticInputMixin[ModelConfigInputDTO]): - name: str | None = gql_field( - description=( - "Name of the model. May be supplied by the runtime variant default model" - " definition, a revision preset, or the model vfolder's model-definition.yaml" - " when omitted; the merge chain produces the final value." - ), - default=None, - ) - model_path: str | None = gql_field( - description=( - "Path to the model file. Defaults to the model mount destination when not" - " overridden by the runtime variant, a revision preset, the vfolder's" - " model-definition.yaml, or this request." - ), - default=None, - ) + name: str | None = gql_field(description="Name of the model.", default=None) + model_path: str | None = gql_field(description="Path to the model file.", default=None) service: ModelServiceConfigInputGQL | None = gql_field( description="Configuration for the model service.", default=None ) @@ -936,12 +907,7 @@ class ModelConfigInputGQL(PydanticInputMixin[ModelConfigInputDTO]): ) class ModelDefinitionInputGQL(PydanticInputMixin[ModelDefinitionInputDTO]): models: list[ModelConfigInputGQL] | None = gql_field( - description=( - "List of model entries in the model definition. Omit to inherit the entire" - " definition from the runtime variant default, the revision preset, or the" - " model vfolder's model-definition.yaml; provide entries to override per-index." - ), - default=None, + description="List of models in the model definition.", default=None ) From 9c1e0039ad0cbb35a260c787aff7bd31c04c162b Mon Sep 17 00:00:00 2001 From: Gyubong Date: Fri, 8 May 2026 20:03:51 +0900 Subject: [PATCH 06/18] chore: update api schema dump Co-Authored-By: Claude Opus 4.7 (1M context) --- .../graphql-reference/supergraph.graphql | 44 +++++-------------- .../graphql-reference/v2-schema.graphql | 44 +++++-------------- 2 files changed, 22 insertions(+), 66 deletions(-) diff --git a/docs/manager/graphql-reference/supergraph.graphql b/docs/manager/graphql-reference/supergraph.graphql index 248f1a4ef67..e14e0a9edc6 100644 --- a/docs/manager/graphql-reference/supergraph.graphql +++ b/docs/manager/graphql-reference/supergraph.graphql @@ -9234,14 +9234,10 @@ Added in 26.4.0. Configuration for a single model within a model definition. input ModelConfigInput @join__type(graph: STRAWBERRY) { - """ - Name of the model. May be supplied by the runtime variant default model definition, a revision preset, or the model vfolder's model-definition.yaml when omitted; the merge chain produces the final value. - """ + """Name of the model.""" name: String = null - """ - Path to the model file. Defaults to the model mount destination when not overridden by the runtime variant, a revision preset, the vfolder's model-definition.yaml, or this request. - """ + """Path to the model file.""" modelPath: String = null """Configuration for the model service.""" @@ -9267,9 +9263,7 @@ Added in 26.4.0. Model definition containing a list of model configurations. input ModelDefinitionInput @join__type(graph: STRAWBERRY) { - """ - List of model entries in the model definition. Omit to inherit the entire definition from the runtime variant default, the revision preset, or the model vfolder's model-definition.yaml; provide entries to override per-index. - """ + """List of models in the model definition.""" models: [ModelConfigInput!] = null } @@ -9436,34 +9430,22 @@ type ModelHealthCheck input ModelHealthCheckInput @join__type(graph: STRAWBERRY) { - """ - Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted. - """ + """Interval in seconds between health checks.""" interval: Float = null - """ - Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. - """ + """Path to check for health status.""" path: String = null - """ - Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted. - """ + """Maximum number of retries for health check.""" maxRetries: Int = null - """ - Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted. - """ + """Maximum time in seconds to wait for a health check response.""" maxWaitTime: Float = null - """ - Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted. - """ + """Expected HTTP status code for a healthy response.""" expectedStatusCode: Int = null - """ - Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted. - """ + """Initial delay in seconds before the first health check.""" initialDelay: Float = null } @@ -9826,14 +9808,10 @@ input ModelServiceConfigInput """Command to start the model service.""" startCommand: [String!] = null - """ - Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted. - """ + """Shell configured for the model service.""" shell: String = null - """ - Port number for the model service. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. - """ + """Port number for the model service.""" port: Int = null """Health check configuration for the model service.""" diff --git a/docs/manager/graphql-reference/v2-schema.graphql b/docs/manager/graphql-reference/v2-schema.graphql index 0d9a0633b50..116e48f8b81 100644 --- a/docs/manager/graphql-reference/v2-schema.graphql +++ b/docs/manager/graphql-reference/v2-schema.graphql @@ -6044,14 +6044,10 @@ type ModelConfig { Added in 26.4.0. Configuration for a single model within a model definition. """ input ModelConfigInput { - """ - Name of the model. May be supplied by the runtime variant default model definition, a revision preset, or the model vfolder's model-definition.yaml when omitted; the merge chain produces the final value. - """ + """Name of the model.""" name: String = null - """ - Path to the model file. Defaults to the model mount destination when not overridden by the runtime variant, a revision preset, the vfolder's model-definition.yaml, or this request. - """ + """Path to the model file.""" modelPath: String = null """Configuration for the model service.""" @@ -6073,9 +6069,7 @@ type ModelDefinition { Added in 26.4.0. Model definition containing a list of model configurations. """ input ModelDefinitionInput { - """ - List of model entries in the model definition. Omit to inherit the entire definition from the runtime variant default, the revision preset, or the model vfolder's model-definition.yaml; provide entries to override per-index. - """ + """List of models in the model definition.""" models: [ModelConfigInput!] = null } @@ -6223,34 +6217,22 @@ type ModelHealthCheck { """Added in 26.4.0. Health check configuration for a model service.""" input ModelHealthCheckInput { - """ - Interval in seconds between health checks. Falls back to the runtime variant or built-in default (10.0) when omitted. - """ + """Interval in seconds between health checks.""" interval: Float = null - """ - Path to check for health status. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. - """ + """Path to check for health status.""" path: String = null - """ - Maximum number of retries for health check. Falls back to the runtime variant or built-in default (10) when omitted. - """ + """Maximum number of retries for health check.""" maxRetries: Int = null - """ - Maximum time in seconds to wait for a health check response. Falls back to the runtime variant or built-in default (15.0) when omitted. - """ + """Maximum time in seconds to wait for a health check response.""" maxWaitTime: Float = null - """ - Expected HTTP status code for a healthy response. Falls back to the runtime variant or built-in default (200) when omitted. - """ + """Expected HTTP status code for a healthy response.""" expectedStatusCode: Int = null - """ - Initial delay in seconds before the first health check. Falls back to the runtime variant or built-in default (60.0) when omitted. - """ + """Initial delay in seconds before the first health check.""" initialDelay: Float = null } @@ -6577,14 +6559,10 @@ input ModelServiceConfigInput { """Command to start the model service.""" startCommand: [String!] = null - """ - Shell configured for the model service. Falls back to the runtime variant or built-in default (/bin/bash) when omitted. - """ + """Shell configured for the model service.""" shell: String = null - """ - Port number for the model service. May be supplied by the runtime variant default model definition or the model vfolder's model-definition.yaml when omitted. - """ + """Port number for the model service.""" port: Int = null """Health check configuration for the model service.""" From e82cd8f58e0d511eb715e0c15a22616b745bbebf Mon Sep 17 00:00:00 2001 From: Gyubong Date: Sun, 10 May 2026 18:10:48 +0900 Subject: [PATCH 07/18] refactor(BA-5983): make to_draft a method of ModelDefinitionInput Address review feedback by moving the standalone ``to_model_definition_draft`` helper into ``ModelDefinitionInput.to_draft`` and update the test fixtures to use ``ModelDefinitionInput`` (matching the field type) instead of ``ModelDefinitionDraft``. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../backend/common/dto/manager/v2/deployment/request.py | 9 ++------- .../backend/manager/api/adapters/deployment/adapter.py | 9 ++++++--- .../common/dto/manager/v2/deployment/test_request.py | 8 ++++---- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/ai/backend/common/dto/manager/v2/deployment/request.py b/src/ai/backend/common/dto/manager/v2/deployment/request.py index a6fad82fdef..bf697ed0cfc 100644 --- a/src/ai/backend/common/dto/manager/v2/deployment/request.py +++ b/src/ai/backend/common/dto/manager/v2/deployment/request.py @@ -176,13 +176,8 @@ class ModelDefinitionInput(BaseRequestModel): models: list[ModelConfigInput] | None = None - -def to_model_definition_draft( - input: ModelDefinitionInput | None, -) -> ModelDefinitionDraft | None: - if input is None: - return None - return ModelDefinitionDraft.model_validate(input.model_dump()) + def to_draft(self) -> ModelDefinitionDraft: + return ModelDefinitionDraft.model_validate(self.model_dump()) class ClusterConfigInput(BaseRequestModel): diff --git a/src/ai/backend/manager/api/adapters/deployment/adapter.py b/src/ai/backend/manager/api/adapters/deployment/adapter.py index a6cdef4e500..e091f5db073 100644 --- a/src/ai/backend/manager/api/adapters/deployment/adapter.py +++ b/src/ai/backend/manager/api/adapters/deployment/adapter.py @@ -62,7 +62,6 @@ SyncReplicaInput, UpdateDeploymentInput, UpsertDeploymentPolicyInput, - to_model_definition_draft, ) from ai.backend.common.dto.manager.v2.deployment.response import ( AccessTokenNode, @@ -506,7 +505,9 @@ async def create( else None, ), mounts=mounts_creator, - model_definition=to_model_definition_draft(initial_revision.model_definition), + model_definition=initial_revision.model_definition.to_draft() + if initial_revision.model_definition is not None + else None, revision_preset_id=initial_revision.revision_preset_id, execution=ExecutionSpec( runtime_variant_id=initial_revision.model_runtime_config.runtime_variant_id, @@ -1111,7 +1112,9 @@ async def add_revision( else None, inference_runtime_config=input.model_runtime_config.inference_runtime_config, ), - model_definition=to_model_definition_draft(input.model_definition), + model_definition=input.model_definition.to_draft() + if input.model_definition is not None + else None, revision_preset_id=input.revision_preset_id, ) action_result = await self._processors.deployment.add_model_revision.wait_for_complete( diff --git a/tests/unit/common/dto/manager/v2/deployment/test_request.py b/tests/unit/common/dto/manager/v2/deployment/test_request.py index 922490cf579..b5756ea086d 100644 --- a/tests/unit/common/dto/manager/v2/deployment/test_request.py +++ b/tests/unit/common/dto/manager/v2/deployment/test_request.py @@ -11,7 +11,6 @@ from pydantic import ValidationError from ai.backend.common.api_handlers import SENTINEL, Sentinel -from ai.backend.common.config import ModelDefinitionDraft from ai.backend.common.data.model_deployment.types import DeploymentStrategy from ai.backend.common.dto.manager.v2.deployment.request import ( ActivateDeploymentInput, @@ -25,6 +24,7 @@ DeploymentStrategyInput, ExtraVFolderMountInput, ImageInput, + ModelDefinitionInput, ModelDeploymentMetadataInput, ModelDeploymentNetworkAccessInput, ModelMountConfigInput, @@ -55,7 +55,7 @@ def _make_revision_input(**kwargs: object) -> RevisionInput: "runtime_variant_id": RuntimeVariantID(uuid.uuid4()), "model_vfolder_id": VFolderUUID(uuid.uuid4()), "model_definition_path": "/models/model.yaml", - "model_definition": ModelDefinitionDraft(), + "model_definition": ModelDefinitionInput(), } defaults.update(kwargs) return RevisionInput(**defaults) @@ -82,7 +82,7 @@ def _make_create_revision_input_dto(**kwargs: object) -> CreateRevisionInputDTO: mount_destination="/models", definition_path="/models/model.yaml", ), - "model_definition": ModelDefinitionDraft(), + "model_definition": ModelDefinitionInput(), } defaults.update(kwargs) return CreateRevisionInputDTO(**defaults) @@ -103,7 +103,7 @@ def test_valid_creation_with_required_fields(self) -> None: runtime_variant_id=runtime_variant_id, model_vfolder_id=model_id, model_definition_path="/models/def.yaml", - model_definition=ModelDefinitionDraft(), + model_definition=ModelDefinitionInput(), ) assert rev.image_id == image_id assert rev.cluster_mode == ClusterMode.SINGLE_NODE From 709dc9e1ff1822d616baa1d13cbc900bef963b6b Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 10:08:16 +0900 Subject: [PATCH 08/18] test(BA-5983): cover ModelDefinitionInput merge + to_resolved behavior Pin the BA-5983 contract: GraphQL/REST inputs accept all-optional fields, but ``to_resolved()`` after the revision merge chain still raises when no source supplies a required field. Three groups: - ``ModelDefinitionInput.to_draft()`` produces a valid empty/partial draft without raising - Empty request merges with variant baseline / preset and resolves to the baseline values (partial overrides also combine correctly) - Missing ``name`` / ``model_path`` / ``port`` / health-check ``path`` with no baseline raises ``ValueError`` at ``to_resolved()`` Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_model_definition_merge.py | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/unit/manager/sokovan/deployment/test_model_definition_merge.py diff --git a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py new file mode 100644 index 00000000000..e72b480e9bd --- /dev/null +++ b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py @@ -0,0 +1,231 @@ +"""Verify that nullable v2 ``ModelDefinitionInput`` fields still result in +correct required-field enforcement after the revision merge chain. + +This pins the BA-5983 behavior: the GraphQL/REST boundary accepts +all-optional fields, but ``to_resolved()`` at the persistence boundary +must still raise when no merge layer (request, preset, variant baseline) +supplies a required value. +""" + +from __future__ import annotations + +import functools + +import pytest + +from ai.backend.common.config import ( + ModelConfigDraft, + ModelDefinitionDraft, + ModelHealthCheckDraft, + ModelServiceConfigDraft, +) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelConfigInput, + ModelDefinitionInput, + ModelHealthCheckInput, + ModelServiceConfigInput, +) +from ai.backend.manager.data.deployment.types import RevisionDraft + + +def _merge(*drafts: RevisionDraft) -> RevisionDraft: + return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) + + +class TestModelDefinitionInputToDraft: + """``ModelDefinitionInput.to_draft`` is the bridge between the + all-optional DTO and the merge-chain draft. The conversion itself + must never raise — required-field enforcement is deferred to + ``to_resolved()`` after the merge.""" + + def test_empty_input_yields_empty_draft(self) -> None: + draft = ModelDefinitionInput().to_draft() + assert isinstance(draft, ModelDefinitionDraft) + assert draft.models is None + + def test_partial_input_preserves_nones(self) -> None: + draft = ModelDefinitionInput( + models=[ModelConfigInput(name="only-name")], + ).to_draft() + assert draft.models is not None + assert draft.models[0].name == "only-name" + assert draft.models[0].model_path is None + + def test_nested_service_input_round_trips(self) -> None: + draft = ModelDefinitionInput( + models=[ + ModelConfigInput( + name="m", + service=ModelServiceConfigInput( + port=8080, + health_check=ModelHealthCheckInput(path="/healthz"), + ), + ) + ] + ).to_draft() + assert draft.models is not None + svc = draft.models[0].service + assert svc is not None + assert svc.port == 8080 + assert svc.health_check is not None + assert svc.health_check.path == "/healthz" + + +class TestEmptyInputMergesWithBaseline: + """Empty (all-null) request input must let lower-priority sources + (variant baseline, preset) fill the required fields, and the merged + draft must resolve cleanly.""" + + def test_baseline_fills_required_fields_when_request_is_empty(self) -> None: + variant_baseline = RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft(name="llama", model_path="/models/llama"), + ] + ), + ) + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = _merge(variant_baseline, request) + + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + assert resolved.models[0].name == "llama" + assert resolved.models[0].model_path == "/models/llama" + + def test_preset_fills_required_fields_when_request_is_empty(self) -> None: + preset = RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="from-preset", + model_path="/preset/path", + service=ModelServiceConfigDraft( + port=9000, + health_check=ModelHealthCheckDraft(path="/ready"), + ), + ) + ] + ), + ) + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = _merge(preset, request) + + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + assert resolved.models[0].name == "from-preset" + assert resolved.models[0].model_path == "/preset/path" + assert resolved.models[0].service is not None + assert resolved.models[0].service.port == 9000 + assert resolved.models[0].service.health_check is not None + assert resolved.models[0].service.health_check.path == "/ready" + + def test_request_partial_override_combines_with_baseline(self) -> None: + variant_baseline = RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), + ] + ), + ) + request = RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput(name="user-name")], + ).to_draft(), + ) + + merged = _merge(variant_baseline, request) + + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + assert resolved.models[0].name == "user-name" + assert resolved.models[0].model_path == "/baseline/path" + + +class TestMergeRaisesWhenAllSourcesAreEmpty: + """When neither the request nor any baseline source supplies a + required field, ``to_resolved()`` must raise at the persistence + boundary — preserving the pre-BA-5983 contract.""" + + def test_missing_model_name_raises(self) -> None: + request = RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput(model_path="/p")], + ).to_draft(), + ) + + merged = _merge(request) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelConfig\.name is required"): + merged.model_definition.to_resolved() + + def test_missing_model_path_raises(self) -> None: + request = RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput(name="n")], + ).to_draft(), + ) + + merged = _merge(request) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelConfig\.model_path is required"): + merged.model_definition.to_resolved() + + def test_missing_service_port_raises(self) -> None: + request = RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ + ModelConfigInput( + name="n", + model_path="/p", + service=ModelServiceConfigInput(), + ) + ], + ).to_draft(), + ) + + merged = _merge(request) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelServiceConfig\.port is required"): + merged.model_definition.to_resolved() + + def test_missing_health_check_path_raises(self) -> None: + request = RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ + ModelConfigInput( + name="n", + model_path="/p", + service=ModelServiceConfigInput( + port=8080, + health_check=ModelHealthCheckInput(), + ), + ) + ], + ).to_draft(), + ) + + merged = _merge(request) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelHealthCheck\.path is required"): + merged.model_definition.to_resolved() + + def test_empty_request_with_no_baseline_yields_empty_resolved(self) -> None: + """A completely empty merge chain resolves to an empty ModelDefinition. + + The ``add_revision`` controller guards against this case separately + (``model_definition.models must contain at least one entry``); the + resolved type itself permits an empty models list. + """ + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = _merge(request) + + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + assert resolved.models == [] From c52e6f06f3fca973c8ec05c611a04f88f67325db Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:09:15 +0900 Subject: [PATCH 09/18] docs: refer to RevisionDraft.merge instead of removed helper The ``merge_revision_drafts`` helper was removed in #11250 in favor of the ``RevisionDraft.merge(self, other)`` instance method, but ``RevisionDraftReader`` docstring still pointed at the old name. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../manager/sokovan/deployment/revision_draft/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ai/backend/manager/sokovan/deployment/revision_draft/reader.py b/src/ai/backend/manager/sokovan/deployment/revision_draft/reader.py index f70bcfa5c77..0493b52e452 100644 --- a/src/ai/backend/manager/sokovan/deployment/revision_draft/reader.py +++ b/src/ai/backend/manager/sokovan/deployment/revision_draft/reader.py @@ -45,8 +45,8 @@ class RevisionDraftReader: """Fan out the DB + storage reads that feed the revision merge chain. One public method per API path (legacy create, legacy modify, v2 add). - Each returns the ordered list of drafts the controller feeds into - ``merge_revision_drafts`` — lowest priority first. The model mount + Each returns the ordered list of drafts the controller layers via + ``RevisionDraft.merge`` — lowest priority first. The model mount destination is added as the lowest-priority ``model_path`` default. """ From aa14f1fd26340c2420ffb3def60c90888d3a117e Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:10:53 +0900 Subject: [PATCH 10/18] test(BA-5983): parametrize missing-required-field merge tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Collapse the four ``test_missing_*_raises`` cases into a single parametrized test. Each scenario carries the ``ModelDefinitionInput`` shape and the expected ``ValueError`` pattern; new cases can be added as ``pytest.param`` entries. The remaining tests (DTO→draft conversions, baseline/preset merge shape) are intentionally left non-parametrized — their assertions differ in depth and the per-source distinction is meaningful. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_model_definition_merge.py | 110 ++++++++---------- 1 file changed, 49 insertions(+), 61 deletions(-) diff --git a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py index e72b480e9bd..ff7ab02e480 100644 --- a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py +++ b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py @@ -148,71 +148,59 @@ class TestMergeRaisesWhenAllSourcesAreEmpty: required field, ``to_resolved()`` must raise at the persistence boundary — preserving the pre-BA-5983 contract.""" - def test_missing_model_name_raises(self) -> None: - request = RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput(model_path="/p")], - ).to_draft(), - ) - - merged = _merge(request) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelConfig\.name is required"): - merged.model_definition.to_resolved() - - def test_missing_model_path_raises(self) -> None: - request = RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput(name="n")], - ).to_draft(), - ) - - merged = _merge(request) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelConfig\.model_path is required"): - merged.model_definition.to_resolved() - - def test_missing_service_port_raises(self) -> None: - request = RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ - ModelConfigInput( - name="n", - model_path="/p", - service=ModelServiceConfigInput(), - ) - ], - ).to_draft(), - ) - - merged = _merge(request) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelServiceConfig\.port is required"): - merged.model_definition.to_resolved() - - def test_missing_health_check_path_raises(self) -> None: - request = RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ - ModelConfigInput( - name="n", - model_path="/p", - service=ModelServiceConfigInput( - port=8080, - health_check=ModelHealthCheckInput(), - ), - ) - ], - ).to_draft(), - ) + @pytest.mark.parametrize( + ("request_input", "error_pattern"), + [ + pytest.param( + ModelDefinitionInput(models=[ModelConfigInput(model_path="/p")]), + r"ModelConfig\.name is required", + id="missing_name", + ), + pytest.param( + ModelDefinitionInput(models=[ModelConfigInput(name="n")]), + r"ModelConfig\.model_path is required", + id="missing_model_path", + ), + pytest.param( + ModelDefinitionInput( + models=[ + ModelConfigInput( + name="n", + model_path="/p", + service=ModelServiceConfigInput(), + ) + ], + ), + r"ModelServiceConfig\.port is required", + id="missing_service_port", + ), + pytest.param( + ModelDefinitionInput( + models=[ + ModelConfigInput( + name="n", + model_path="/p", + service=ModelServiceConfigInput( + port=8080, + health_check=ModelHealthCheckInput(), + ), + ) + ], + ), + r"ModelHealthCheck\.path is required", + id="missing_health_check_path", + ), + ], + ) + def test_missing_required_field_raises( + self, request_input: ModelDefinitionInput, error_pattern: str + ) -> None: + request = RevisionDraft(model_definition=request_input.to_draft()) merged = _merge(request) assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelHealthCheck\.path is required"): + with pytest.raises(ValueError, match=error_pattern): merged.model_definition.to_resolved() def test_empty_request_with_no_baseline_yields_empty_resolved(self) -> None: From d3cf04c70e9f1264709d1e5b3f9ef7dd3616c20f Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:14:54 +0900 Subject: [PATCH 11/18] test(BA-5983): parametrize remaining merge-behavior test groups MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply parametrize to the two test groups that previously held one test method per scenario: - ``TestModelDefinitionInputToDraft`` — collapse three round-trip cases into a single ``test_to_draft_preserves_input_shape`` that asserts ``draft.model_dump() == input.model_dump()``. The invariant is the same for every scenario (empty / partial / nested); the inputs now live as ``pytest.param`` entries. - ``TestEmptyInputMergesWithBaseline`` — extract a ``ResolvedExpectation`` dataclass so the three "merge produces correct resolved value" scenarios share one test body. New cases (additional sources, deeper overrides) only need a new ``pytest.param``. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_model_definition_merge.py | 205 ++++++++++-------- 1 file changed, 116 insertions(+), 89 deletions(-) diff --git a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py index ff7ab02e480..d4709a574df 100644 --- a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py +++ b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py @@ -10,6 +10,7 @@ from __future__ import annotations import functools +from dataclasses import dataclass import pytest @@ -32,43 +33,56 @@ def _merge(*drafts: RevisionDraft) -> RevisionDraft: return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) +@dataclass(frozen=True) +class ResolvedExpectation: + """Expected attributes on the resolved ``ModelConfig`` at ``models[0]``. + + Only the named-string fields participate; ``None`` means the + corresponding nested object should not be asserted (the scenario + does not exercise it). + """ + + name: str + model_path: str + service_port: int | None = None + health_check_path: str | None = None + + class TestModelDefinitionInputToDraft: """``ModelDefinitionInput.to_draft`` is the bridge between the all-optional DTO and the merge-chain draft. The conversion itself must never raise — required-field enforcement is deferred to - ``to_resolved()`` after the merge.""" + ``to_resolved()`` after the merge — and must preserve every field + the input carries (including ``None`` placeholders).""" - def test_empty_input_yields_empty_draft(self) -> None: - draft = ModelDefinitionInput().to_draft() + @pytest.mark.parametrize( + "input_dto", + [ + pytest.param(ModelDefinitionInput(), id="empty"), + pytest.param( + ModelDefinitionInput(models=[ModelConfigInput(name="only-name")]), + id="partial_name_only", + ), + pytest.param( + ModelDefinitionInput( + models=[ + ModelConfigInput( + name="m", + service=ModelServiceConfigInput( + port=8080, + health_check=ModelHealthCheckInput(path="/healthz"), + ), + ) + ] + ), + id="nested_service_and_health_check", + ), + ], + ) + def test_to_draft_preserves_input_shape(self, input_dto: ModelDefinitionInput) -> None: + draft = input_dto.to_draft() assert isinstance(draft, ModelDefinitionDraft) - assert draft.models is None - - def test_partial_input_preserves_nones(self) -> None: - draft = ModelDefinitionInput( - models=[ModelConfigInput(name="only-name")], - ).to_draft() - assert draft.models is not None - assert draft.models[0].name == "only-name" - assert draft.models[0].model_path is None - - def test_nested_service_input_round_trips(self) -> None: - draft = ModelDefinitionInput( - models=[ - ModelConfigInput( - name="m", - service=ModelServiceConfigInput( - port=8080, - health_check=ModelHealthCheckInput(path="/healthz"), - ), - ) - ] - ).to_draft() - assert draft.models is not None - svc = draft.models[0].service - assert svc is not None - assert svc.port == 8080 - assert svc.health_check is not None - assert svc.health_check.path == "/healthz" + assert draft.model_dump() == input_dto.model_dump() class TestEmptyInputMergesWithBaseline: @@ -76,71 +90,84 @@ class TestEmptyInputMergesWithBaseline: (variant baseline, preset) fill the required fields, and the merged draft must resolve cleanly.""" - def test_baseline_fills_required_fields_when_request_is_empty(self) -> None: - variant_baseline = RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft(name="llama", model_path="/models/llama"), - ] + @pytest.mark.parametrize( + ("drafts", "expected"), + [ + pytest.param( + [ + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ModelConfigDraft(name="llama", model_path="/models/llama")], + ), + ), + RevisionDraft(model_definition=ModelDefinitionInput().to_draft()), + ], + ResolvedExpectation(name="llama", model_path="/models/llama"), + id="variant_baseline_fills_required", ), - ) - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = _merge(variant_baseline, request) - - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - assert resolved.models[0].name == "llama" - assert resolved.models[0].model_path == "/models/llama" - - def test_preset_fills_required_fields_when_request_is_empty(self) -> None: - preset = RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="from-preset", - model_path="/preset/path", - service=ModelServiceConfigDraft( - port=9000, - health_check=ModelHealthCheckDraft(path="/ready"), + pytest.param( + [ + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="from-preset", + model_path="/preset/path", + service=ModelServiceConfigDraft( + port=9000, + health_check=ModelHealthCheckDraft(path="/ready"), + ), + ) + ], ), - ) - ] + ), + RevisionDraft(model_definition=ModelDefinitionInput().to_draft()), + ], + ResolvedExpectation( + name="from-preset", + model_path="/preset/path", + service_port=9000, + health_check_path="/ready", + ), + id="preset_fills_nested_required", ), - ) - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = _merge(preset, request) - - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - assert resolved.models[0].name == "from-preset" - assert resolved.models[0].model_path == "/preset/path" - assert resolved.models[0].service is not None - assert resolved.models[0].service.port == 9000 - assert resolved.models[0].service.health_check is not None - assert resolved.models[0].service.health_check.path == "/ready" - - def test_request_partial_override_combines_with_baseline(self) -> None: - variant_baseline = RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), - ] + pytest.param( + [ + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), + ], + ), + ), + RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput(name="user-name")], + ).to_draft(), + ), + ], + ResolvedExpectation(name="user-name", model_path="/baseline/path"), + id="request_partial_overrides_baseline", ), - ) - request = RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput(name="user-name")], - ).to_draft(), - ) - - merged = _merge(variant_baseline, request) + ], + ) + def test_merge_resolves_to_expected_values( + self, drafts: list[RevisionDraft], expected: ResolvedExpectation + ) -> None: + merged = _merge(*drafts) assert merged.model_definition is not None resolved = merged.model_definition.to_resolved() - assert resolved.models[0].name == "user-name" - assert resolved.models[0].model_path == "/baseline/path" + model = resolved.models[0] + assert model.name == expected.name + assert model.model_path == expected.model_path + if expected.service_port is not None: + assert model.service is not None + assert model.service.port == expected.service_port + if expected.health_check_path is not None: + assert model.service is not None + assert model.service.health_check is not None + assert model.service.health_check.path == expected.health_check_path class TestMergeRaisesWhenAllSourcesAreEmpty: From 83b3ac49d8be19ef26db3db75b8013805f665d67 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:19:25 +0900 Subject: [PATCH 12/18] fix(BA-5983): preserve unset semantics in ModelDefinitionInput.to_draft MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``ModelDefinitionInput.to_draft()`` used ``model_dump()`` with default arguments, which dumps every field — including the unset ones at their ``None`` default. Round-tripping that through ``model_validate`` left the resulting draft with ``model_fields_set`` containing every field, so every ``None`` looked "explicitly set" and clobbered lower-priority baselines during the revision merge chain. Switch to ``model_dump(exclude_unset=True)`` so the resulting draft's ``model_fields_set`` reflects only what the caller actually provided. This is what makes the BA-5983 scenario actually work end-to-end: a request that omits ``name`` / ``model_path`` / ``service.port`` / ``health_check.path`` lets the variant baseline (or preset) fill them in instead of nulling them out. Extend the merge test to cover this directly — every "missing required field" scenario now layers a baseline draft together with a request draft so the merge actually combines fields across sources. Without the fix, the model_path / service_port / health_check_path cases would raise the wrong error first (e.g. ``ModelConfig.name is required`` fires before ``model_path``) because every request-side ``None`` would clobber the baseline's preserved value. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../dto/manager/v2/deployment/request.py | 7 +- .../deployment/test_model_definition_merge.py | 114 ++++++++++++------ 2 files changed, 86 insertions(+), 35 deletions(-) diff --git a/src/ai/backend/common/dto/manager/v2/deployment/request.py b/src/ai/backend/common/dto/manager/v2/deployment/request.py index bf697ed0cfc..f819ec72d0d 100644 --- a/src/ai/backend/common/dto/manager/v2/deployment/request.py +++ b/src/ai/backend/common/dto/manager/v2/deployment/request.py @@ -177,7 +177,12 @@ class ModelDefinitionInput(BaseRequestModel): models: list[ModelConfigInput] | None = None def to_draft(self) -> ModelDefinitionDraft: - return ModelDefinitionDraft.model_validate(self.model_dump()) + # ``exclude_unset=True`` keeps the resulting draft's + # ``model_fields_set`` aligned with what the caller actually + # provided. Without it, every field would appear "explicitly + # set" (to ``None``) and clobber lower-priority sources during + # the revision merge. + return ModelDefinitionDraft.model_validate(self.model_dump(exclude_unset=True)) class ClusterConfigInput(BaseRequestModel): diff --git a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py index d4709a574df..2ccb29f8ed5 100644 --- a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py +++ b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py @@ -173,58 +173,104 @@ def test_merge_resolves_to_expected_values( class TestMergeRaisesWhenAllSourcesAreEmpty: """When neither the request nor any baseline source supplies a required field, ``to_resolved()`` must raise at the persistence - boundary — preserving the pre-BA-5983 contract.""" + boundary — preserving the pre-BA-5983 contract. + + Each scenario layers a baseline draft (variant-style) together with + a request draft so the merge actually combines fields across + sources; the target required field remains unfilled in every layer + and the resolved-time check fires on it specifically.""" @pytest.mark.parametrize( - ("request_input", "error_pattern"), + ("drafts", "error_pattern"), [ pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(model_path="/p")]), + [ + # baseline supplies model_path; request adds nothing. + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ModelConfigDraft(model_path="/baseline/path")], + ), + ), + RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput()], + ).to_draft(), + ), + ], r"ModelConfig\.name is required", - id="missing_name", + id="name_unfilled_across_baseline_and_request", ), pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(name="n")]), + [ + # baseline supplies name; request adds nothing. + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ModelConfigDraft(name="baseline-name")], + ), + ), + RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput()], + ).to_draft(), + ), + ], r"ModelConfig\.model_path is required", - id="missing_model_path", + id="model_path_unfilled_across_baseline_and_request", ), pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - name="n", - model_path="/p", - service=ModelServiceConfigInput(), - ) - ], - ), + [ + # baseline supplies the outer ModelConfig fields; + # request adds an empty service (no port anywhere). + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ModelConfigDraft(name="n", model_path="/p")], + ), + ), + RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ModelConfigInput(service=ModelServiceConfigInput())], + ).to_draft(), + ), + ], r"ModelServiceConfig\.port is required", - id="missing_service_port", + id="service_port_unfilled_across_baseline_and_request", ), pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - name="n", - model_path="/p", - service=ModelServiceConfigInput( - port=8080, - health_check=ModelHealthCheckInput(), - ), - ) - ], - ), + [ + # baseline supplies a service with port; request adds + # an empty health_check (no path anywhere). + RevisionDraft( + model_definition=ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="n", + model_path="/p", + service=ModelServiceConfigDraft(port=8080), + ) + ], + ), + ), + RevisionDraft( + model_definition=ModelDefinitionInput( + models=[ + ModelConfigInput( + service=ModelServiceConfigInput( + health_check=ModelHealthCheckInput(), + ), + ) + ], + ).to_draft(), + ), + ], r"ModelHealthCheck\.path is required", - id="missing_health_check_path", + id="health_check_path_unfilled_across_baseline_and_request", ), ], ) - def test_missing_required_field_raises( - self, request_input: ModelDefinitionInput, error_pattern: str + def test_required_field_unfilled_after_merge_raises( + self, drafts: list[RevisionDraft], error_pattern: str ) -> None: - request = RevisionDraft(model_definition=request_input.to_draft()) - - merged = _merge(request) + merged = _merge(*drafts) assert merged.model_definition is not None with pytest.raises(ValueError, match=error_pattern): From 8f74efcd550d4d0b834b29ffe854938d4b06c45e Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:30:34 +0900 Subject: [PATCH 13/18] test(BA-5983): add DB-backed revision merge test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Insert a real ``RuntimeVariantRow`` with a baseline ``default_model_definition`` (round-trips through ``PydanticColumn`` serialization), then exercise the production ``RevisionDraftReader`` + ``RevisionDraft.merge`` pipeline against a request draft built from ``ModelDefinitionInput.to_draft()``. Scenarios: - Empty input + baseline supplying full required tree → resolved ``ModelConfig`` carries baseline values verbatim. - Partial request (name only) + baseline (name + model_path) → request wins on ``name``; baseline's ``model_path`` survives. - Baseline missing ``name`` → ``to_resolved()`` raises ``ModelConfig.name is required``. - Baseline supplying empty ``service`` → ``ModelServiceConfig.port is required``. - Baseline supplying empty ``health_check`` → ``ModelHealthCheck.path is required``. The synthetic merge tests still cover the pure merge functions in-process; this file pins the DB → reader → merge → resolve loop that the ``add_model_revision`` action actually runs in production. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_revision_merge_db.py | 260 ++++++++++++++++++ 1 file changed, 260 insertions(+) create mode 100644 tests/unit/manager/repositories/deployment/test_revision_merge_db.py diff --git a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py new file mode 100644 index 00000000000..cf4b4de86f5 --- /dev/null +++ b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py @@ -0,0 +1,260 @@ +"""DB-backed verification of the BA-5983 revision merge contract. + +Inserts real ``RuntimeVariantRow`` records (so the variant's +``default_model_definition`` round-trips through ``PydanticColumn`` +serialization) and runs the production ``RevisionDraftReader`` + +``RevisionDraft.merge`` pipeline against a request draft built from +``ModelDefinitionInput.to_draft()``. The resolved output is then +inspected to confirm: + +- An empty request inherits every required field from the variant + baseline; the resolved ``ModelDefinition`` carries the baseline + values verbatim. +- A request that supplies a subset of fields overrides only those + fields; baseline-supplied fields survive. +- When no source supplies a required field, ``to_resolved()`` raises + ``ValueError`` with the field-specific message. + +This exercises the full read path (DB → ``PydanticColumn`` → +``RuntimeVariantData`` → ``RevisionDraft``) plus the merge and +resolve phases that the ``add_model_revision`` action ultimately runs. +""" + +from __future__ import annotations + +import functools +import uuid +from collections.abc import AsyncGenerator +from dataclasses import dataclass +from unittest.mock import MagicMock + +import pytest + +from ai.backend.common.config import ( + ModelConfigDraft, + ModelDefinitionDraft, + ModelHealthCheckDraft, + ModelServiceConfigDraft, +) +from ai.backend.common.dto.manager.v2.deployment.request import ( + ModelConfigInput, + ModelDefinitionInput, +) +from ai.backend.common.identifier.runtime_variant import RuntimeVariantID +from ai.backend.common.identifier.vfolder import VFolderUUID +from ai.backend.manager.data.deployment.types import MountMetadata, RevisionDraft +from ai.backend.manager.models.runtime_variant import RuntimeVariantRow +from ai.backend.manager.models.utils import ExtendedAsyncSAEngine +from ai.backend.manager.repositories.deployment.repository import DeploymentRepository +from ai.backend.manager.sokovan.deployment.revision_draft import RevisionDraftReader +from ai.backend.testutils.db import with_tables + + +@dataclass(frozen=True) +class ResolvedExpectation: + """Expected attributes on the resolved ``ModelConfig`` at ``models[0]``.""" + + name: str + model_path: str + service_port: int | None = None + health_check_path: str | None = None + + +class TestRevisionMergeWithRealVariantBaseline: + @pytest.fixture + async def db_with_variant_table( + self, + database_connection: ExtendedAsyncSAEngine, + ) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + async with with_tables(database_connection, [RuntimeVariantRow]): + yield database_connection + + @pytest.fixture + def reader( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + ) -> RevisionDraftReader: + # ``load_deployment_revision_read_bundle`` only touches the + # runtime_variants table when ``preset_id`` is ``None``; the + # other repository dependencies are not exercised and can be + # stubbed out. + repo = DeploymentRepository( + db=db_with_variant_table, + storage_manager=MagicMock(), + valkey_stat=MagicMock(), + valkey_live=MagicMock(), + valkey_schedule=MagicMock(), + ) + return RevisionDraftReader(deployment_repository=repo) + + @pytest.fixture + def mounts(self) -> MountMetadata: + return MountMetadata( + model_vfolder_id=VFolderUUID(uuid.uuid4()), + model_definition_path=None, + model_mount_destination="/models", + extra_mounts=[], + ) + + @staticmethod + async def _seed_variant_baseline( + db: ExtendedAsyncSAEngine, + baseline: ModelDefinitionDraft, + ) -> RuntimeVariantID: + variant_id = RuntimeVariantID(uuid.uuid4()) + async with db.begin_session() as sess: + sess.add( + RuntimeVariantRow( + id=variant_id, + name=f"test-variant-{variant_id.hex[:8]}", + description="BA-5983 merge-test variant baseline", + reads_vfolder_config_files=False, + default_model_definition=baseline, + ) + ) + await sess.commit() + return variant_id + + @staticmethod + async def _merge_via_reader( + reader: RevisionDraftReader, + variant_id: RuntimeVariantID, + request: RevisionDraft, + mounts: MountMetadata, + ) -> RevisionDraft: + drafts = await reader.read_for_deployment_revision( + runtime_variant_id=variant_id, + request_draft=request, + mounts=mounts, + preset_id=None, + ) + return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) + + @pytest.mark.parametrize( + ("baseline", "request_input", "expected"), + [ + pytest.param( + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="baseline-llama", + model_path="/models/baseline", + service=ModelServiceConfigDraft( + port=9000, + health_check=ModelHealthCheckDraft(path="/healthz"), + ), + ), + ], + ), + ModelDefinitionInput(), + ResolvedExpectation( + name="baseline-llama", + model_path="/models/baseline", + service_port=9000, + health_check_path="/healthz", + ), + id="empty_request_inherits_full_baseline", + ), + pytest.param( + ModelDefinitionDraft( + models=[ + ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), + ], + ), + ModelDefinitionInput(models=[ModelConfigInput(name="user-name")]), + ResolvedExpectation(name="user-name", model_path="/baseline/path"), + id="request_overrides_name_baseline_keeps_model_path", + ), + ], + ) + async def test_merge_resolves_to_expected_values( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + reader: RevisionDraftReader, + mounts: MountMetadata, + baseline: ModelDefinitionDraft, + request_input: ModelDefinitionInput, + expected: ResolvedExpectation, + ) -> None: + variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline) + request = RevisionDraft(model_definition=request_input.to_draft()) + + merged = await self._merge_via_reader(reader, variant_id, request, mounts) + + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + model = resolved.models[0] + assert model.name == expected.name + assert model.model_path == expected.model_path + if expected.service_port is not None: + assert model.service is not None + assert model.service.port == expected.service_port + if expected.health_check_path is not None: + assert model.service is not None + assert model.service.health_check is not None + assert model.service.health_check.path == expected.health_check_path + + @pytest.mark.parametrize( + ("baseline", "error_pattern"), + [ + pytest.param( + # baseline supplies model_path only; reader's mount-destination + # default would also fill model_path → only ``name`` remains unfilled. + ModelDefinitionDraft(models=[ModelConfigDraft(model_path="/p")]), + r"ModelConfig\.name is required", + id="name_unfilled_across_baseline_and_request", + ), + pytest.param( + # baseline supplies name + model_path + an empty service → + # service.port has no default and no override. + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="n", + model_path="/p", + service=ModelServiceConfigDraft(), + ), + ], + ), + r"ModelServiceConfig\.port is required", + id="service_port_unfilled_across_baseline_and_request", + ), + pytest.param( + # baseline supplies service.port but an empty health_check → + # health_check.path has no default. + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="n", + model_path="/p", + service=ModelServiceConfigDraft( + port=8080, + health_check=ModelHealthCheckDraft(), + ), + ), + ], + ), + r"ModelHealthCheck\.path is required", + id="health_check_path_unfilled_across_baseline_and_request", + ), + ], + ) + async def test_required_field_unfilled_after_merge_raises( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + reader: RevisionDraftReader, + mounts: MountMetadata, + baseline: ModelDefinitionDraft, + error_pattern: str, + ) -> None: + # Request is always an all-empty ``ModelDefinitionInput`` for these + # scenarios — the merge result depends entirely on whether the + # baseline (or reader-supplied defaults) cover every required field. + variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline) + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = await self._merge_via_reader(reader, variant_id, request, mounts) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=error_pattern): + merged.model_definition.to_resolved() From 33b5ab2448fa7f240d08baccd97230c989022001 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:33:19 +0900 Subject: [PATCH 14/18] test(BA-5983): seed DB baseline via fixture; parametrize on input/result MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restructure the DB-backed merge test so the baseline ``ModelDefinitionDraft`` lives in a fixture (the DB-side value) and the parametrize tables only carry the actual request input + expected outcome. Two scenario groups, each pinned to its own baseline fixture: - ``TestMergeWithFullBaseline`` — baseline ships every required field; parametrize cases probe how different requests (empty, partial override) combine with that baseline. - ``TestMergeRaisesWithIncompleteBaseline`` — baseline ships an incomplete definition so the resolve-time check fires. The request is always all-empty here; the parametrize table only varies the baseline shape + expected error pattern. Shared helpers (``_seed_variant``, ``_merge_via_reader``) move to module scope so both classes consume the same DB and reader setup. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_revision_merge_db.py | 247 +++++++++--------- 1 file changed, 129 insertions(+), 118 deletions(-) diff --git a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py index cf4b4de86f5..fe8db8e8faf 100644 --- a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py +++ b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py @@ -1,23 +1,22 @@ """DB-backed verification of the BA-5983 revision merge contract. -Inserts real ``RuntimeVariantRow`` records (so the variant's +A real ``RuntimeVariantRow`` is seeded via a fixture (so the variant's ``default_model_definition`` round-trips through ``PydanticColumn`` -serialization) and runs the production ``RevisionDraftReader`` + -``RevisionDraft.merge`` pipeline against a request draft built from -``ModelDefinitionInput.to_draft()``. The resolved output is then -inspected to confirm: +serialization). The production ``RevisionDraftReader`` + +``RevisionDraft.merge`` pipeline is then run against request drafts +built from various ``ModelDefinitionInput`` shapes; the parametrized +table only carries the request input and the expected outcome. -- An empty request inherits every required field from the variant - baseline; the resolved ``ModelDefinition`` carries the baseline - values verbatim. -- A request that supplies a subset of fields overrides only those - fields; baseline-supplied fields survive. -- When no source supplies a required field, ``to_resolved()`` raises - ``ValueError`` with the field-specific message. +Two scenario groups, each pinned to its own DB baseline fixture: -This exercises the full read path (DB → ``PydanticColumn`` → -``RuntimeVariantData`` → ``RevisionDraft``) plus the merge and -resolve phases that the ``add_model_revision`` action ultimately runs. +- ``TestMergeWithFullBaseline`` — variant ships every required field; + the parametrized inputs probe how different requests combine with + it (inherit-all, partial override). +- ``TestMergeRaisesWithIncompleteBaseline`` — variant ships an + incomplete definition where ``to_resolved()`` is expected to raise + because no source supplies a required nested field. Each parametrize + entry pairs an incomplete baseline shape with the expected error + pattern; the request is always all-empty. """ from __future__ import annotations @@ -60,92 +59,105 @@ class ResolvedExpectation: health_check_path: str | None = None -class TestRevisionMergeWithRealVariantBaseline: - @pytest.fixture - async def db_with_variant_table( - self, - database_connection: ExtendedAsyncSAEngine, - ) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: - async with with_tables(database_connection, [RuntimeVariantRow]): - yield database_connection +@pytest.fixture +async def db_with_variant_table( + database_connection: ExtendedAsyncSAEngine, +) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: + async with with_tables(database_connection, [RuntimeVariantRow]): + yield database_connection - @pytest.fixture - def reader( - self, - db_with_variant_table: ExtendedAsyncSAEngine, - ) -> RevisionDraftReader: - # ``load_deployment_revision_read_bundle`` only touches the - # runtime_variants table when ``preset_id`` is ``None``; the - # other repository dependencies are not exercised and can be - # stubbed out. - repo = DeploymentRepository( - db=db_with_variant_table, - storage_manager=MagicMock(), - valkey_stat=MagicMock(), - valkey_live=MagicMock(), - valkey_schedule=MagicMock(), - ) - return RevisionDraftReader(deployment_repository=repo) - @pytest.fixture - def mounts(self) -> MountMetadata: - return MountMetadata( - model_vfolder_id=VFolderUUID(uuid.uuid4()), - model_definition_path=None, - model_mount_destination="/models", - extra_mounts=[], - ) +@pytest.fixture +def reader( + db_with_variant_table: ExtendedAsyncSAEngine, +) -> RevisionDraftReader: + # ``load_deployment_revision_read_bundle`` only touches the + # runtime_variants table when ``preset_id`` is ``None``; the other + # repository dependencies are not exercised here and can be stubbed. + repo = DeploymentRepository( + db=db_with_variant_table, + storage_manager=MagicMock(), + valkey_stat=MagicMock(), + valkey_live=MagicMock(), + valkey_schedule=MagicMock(), + ) + return RevisionDraftReader(deployment_repository=repo) - @staticmethod - async def _seed_variant_baseline( - db: ExtendedAsyncSAEngine, - baseline: ModelDefinitionDraft, - ) -> RuntimeVariantID: - variant_id = RuntimeVariantID(uuid.uuid4()) - async with db.begin_session() as sess: - sess.add( - RuntimeVariantRow( - id=variant_id, - name=f"test-variant-{variant_id.hex[:8]}", - description="BA-5983 merge-test variant baseline", - reads_vfolder_config_files=False, - default_model_definition=baseline, - ) + +@pytest.fixture +def mounts() -> MountMetadata: + return MountMetadata( + model_vfolder_id=VFolderUUID(uuid.uuid4()), + model_definition_path=None, + model_mount_destination="/models", + extra_mounts=[], + ) + + +async def _seed_variant( + db: ExtendedAsyncSAEngine, + baseline: ModelDefinitionDraft, +) -> RuntimeVariantID: + variant_id = RuntimeVariantID(uuid.uuid4()) + async with db.begin_session() as sess: + sess.add( + RuntimeVariantRow( + id=variant_id, + name=f"test-variant-{variant_id.hex[:8]}", + description="BA-5983 merge-test variant baseline", + reads_vfolder_config_files=False, + default_model_definition=baseline, ) - await sess.commit() - return variant_id + ) + await sess.commit() + return variant_id - @staticmethod - async def _merge_via_reader( - reader: RevisionDraftReader, - variant_id: RuntimeVariantID, - request: RevisionDraft, - mounts: MountMetadata, - ) -> RevisionDraft: - drafts = await reader.read_for_deployment_revision( - runtime_variant_id=variant_id, - request_draft=request, - mounts=mounts, - preset_id=None, + +async def _merge_via_reader( + reader: RevisionDraftReader, + variant_id: RuntimeVariantID, + request: RevisionDraft, + mounts: MountMetadata, +) -> RevisionDraft: + drafts = await reader.read_for_deployment_revision( + runtime_variant_id=variant_id, + request_draft=request, + mounts=mounts, + preset_id=None, + ) + return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) + + +class TestMergeWithFullBaseline: + """Baseline supplies every required field. The parametrize table + pairs each ``ModelDefinitionInput`` shape with the resolved values + we expect after merging it against this baseline.""" + + @pytest.fixture + async def variant_id( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + ) -> RuntimeVariantID: + return await _seed_variant( + db_with_variant_table, + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="baseline-llama", + model_path="/models/baseline", + service=ModelServiceConfigDraft( + port=9000, + health_check=ModelHealthCheckDraft(path="/healthz"), + ), + ), + ], + ), ) - return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) @pytest.mark.parametrize( - ("baseline", "request_input", "expected"), + ("request_input", "expected"), [ pytest.param( - ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="baseline-llama", - model_path="/models/baseline", - service=ModelServiceConfigDraft( - port=9000, - health_check=ModelHealthCheckDraft(path="/healthz"), - ), - ), - ], - ), ModelDefinitionInput(), ResolvedExpectation( name="baseline-llama", @@ -156,30 +168,28 @@ async def _merge_via_reader( id="empty_request_inherits_full_baseline", ), pytest.param( - ModelDefinitionDraft( - models=[ - ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), - ], - ), ModelDefinitionInput(models=[ModelConfigInput(name="user-name")]), - ResolvedExpectation(name="user-name", model_path="/baseline/path"), - id="request_overrides_name_baseline_keeps_model_path", + ResolvedExpectation( + name="user-name", + model_path="/models/baseline", + service_port=9000, + health_check_path="/healthz", + ), + id="request_overrides_name_only", ), ], ) async def test_merge_resolves_to_expected_values( self, - db_with_variant_table: ExtendedAsyncSAEngine, reader: RevisionDraftReader, mounts: MountMetadata, - baseline: ModelDefinitionDraft, + variant_id: RuntimeVariantID, request_input: ModelDefinitionInput, expected: ResolvedExpectation, ) -> None: - variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline) request = RevisionDraft(model_definition=request_input.to_draft()) - merged = await self._merge_via_reader(reader, variant_id, request, mounts) + merged = await _merge_via_reader(reader, variant_id, request, mounts) assert merged.model_definition is not None resolved = merged.model_definition.to_resolved() @@ -194,19 +204,25 @@ async def test_merge_resolves_to_expected_values( assert model.service.health_check is not None assert model.service.health_check.path == expected.health_check_path + +class TestMergeRaisesWithIncompleteBaseline: + """Each parametrize entry seeds its own incomplete baseline (via the + ``baseline_factory``) and expects ``to_resolved()`` to raise because + no source supplies a required field. The request is always + all-empty so the failure mode comes entirely from the baseline.""" + @pytest.mark.parametrize( - ("baseline", "error_pattern"), + ("incomplete_baseline", "error_pattern"), [ pytest.param( - # baseline supplies model_path only; reader's mount-destination - # default would also fill model_path → only ``name`` remains unfilled. + # Reader's mount-destination default also fills model_path, + # so the only required ``ModelConfig`` field that ends up + # unfilled is ``name``. ModelDefinitionDraft(models=[ModelConfigDraft(model_path="/p")]), r"ModelConfig\.name is required", - id="name_unfilled_across_baseline_and_request", + id="name_unfilled", ), pytest.param( - # baseline supplies name + model_path + an empty service → - # service.port has no default and no override. ModelDefinitionDraft( models=[ ModelConfigDraft( @@ -217,11 +233,9 @@ async def test_merge_resolves_to_expected_values( ], ), r"ModelServiceConfig\.port is required", - id="service_port_unfilled_across_baseline_and_request", + id="service_port_unfilled", ), pytest.param( - # baseline supplies service.port but an empty health_check → - # health_check.path has no default. ModelDefinitionDraft( models=[ ModelConfigDraft( @@ -235,7 +249,7 @@ async def test_merge_resolves_to_expected_values( ], ), r"ModelHealthCheck\.path is required", - id="health_check_path_unfilled_across_baseline_and_request", + id="health_check_path_unfilled", ), ], ) @@ -244,16 +258,13 @@ async def test_required_field_unfilled_after_merge_raises( db_with_variant_table: ExtendedAsyncSAEngine, reader: RevisionDraftReader, mounts: MountMetadata, - baseline: ModelDefinitionDraft, + incomplete_baseline: ModelDefinitionDraft, error_pattern: str, ) -> None: - # Request is always an all-empty ``ModelDefinitionInput`` for these - # scenarios — the merge result depends entirely on whether the - # baseline (or reader-supplied defaults) cover every required field. - variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline) + variant_id = await _seed_variant(db_with_variant_table, incomplete_baseline) request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - merged = await self._merge_via_reader(reader, variant_id, request, mounts) + merged = await _merge_via_reader(reader, variant_id, request, mounts) assert merged.model_definition is not None with pytest.raises(ValueError, match=error_pattern): From 58eb89b3539850878de2ba14242a6c64eeed8f99 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:42:11 +0900 Subject: [PATCH 15/18] test(BA-5983): one fixture per DB baseline shape; parametrize only inputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Restructure so each baseline shape lives in its own test class with its own ``variant_id`` fixture seeding the DB. The parametrize tables inside each class only carry the request ``ModelDefinitionInput`` and the expected outcome — there is no ``baseline``/``incomplete_baseline`` parameter, since "what's in the DB" is the fixture's responsibility. Four classes, one per baseline shape: - ``TestMergeWithCompleteBaseline`` — variant ships every required field; any request resolves successfully. - ``TestMergeWhenBaselineLacksName`` — request must supply ``name``, otherwise ``to_resolved()`` raises. - ``TestMergeWhenBaselineLacksServicePort`` — request must supply ``service.port``. - ``TestMergeWhenBaselineLacksHealthCheckPath`` — request must supply ``service.health_check.path``. Each "lacks-X" class pairs a parametrized success test (request supplies the missing field) with a dedicated failure test (empty request → ``to_resolved()`` raises). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../deployment/test_revision_merge_db.py | 289 +++++++++++++----- 1 file changed, 218 insertions(+), 71 deletions(-) diff --git a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py index fe8db8e8faf..767f37ccd05 100644 --- a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py +++ b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py @@ -1,22 +1,16 @@ """DB-backed verification of the BA-5983 revision merge contract. -A real ``RuntimeVariantRow`` is seeded via a fixture (so the variant's -``default_model_definition`` round-trips through ``PydanticColumn`` -serialization). The production ``RevisionDraftReader`` + -``RevisionDraft.merge`` pipeline is then run against request drafts -built from various ``ModelDefinitionInput`` shapes; the parametrized -table only carries the request input and the expected outcome. - -Two scenario groups, each pinned to its own DB baseline fixture: - -- ``TestMergeWithFullBaseline`` — variant ships every required field; - the parametrized inputs probe how different requests combine with - it (inherit-all, partial override). -- ``TestMergeRaisesWithIncompleteBaseline`` — variant ships an - incomplete definition where ``to_resolved()`` is expected to raise - because no source supplies a required nested field. Each parametrize - entry pairs an incomplete baseline shape with the expected error - pattern; the request is always all-empty. +Each test class seeds one specific ``RuntimeVariantRow.default_model_definition`` +shape into the DB (so it round-trips through ``PydanticColumn`` +serialization) and runs the production ``RevisionDraftReader`` + +``RevisionDraft.merge`` pipeline against various request inputs. The +parametrize tables only carry the request ``ModelDefinitionInput`` and +the expected resolved values — the DB baseline is fixed per class via +its ``variant_id`` fixture. + +Scenarios are partitioned by baseline shape so each class makes the +"what's in the DB" / "what the user sends" / "what should come out" +relationship obvious at a glance. """ from __future__ import annotations @@ -38,6 +32,8 @@ from ai.backend.common.dto.manager.v2.deployment.request import ( ModelConfigInput, ModelDefinitionInput, + ModelHealthCheckInput, + ModelServiceConfigInput, ) from ai.backend.common.identifier.runtime_variant import RuntimeVariantID from ai.backend.common.identifier.vfolder import VFolderUUID @@ -128,10 +124,27 @@ async def _merge_via_reader( return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) -class TestMergeWithFullBaseline: - """Baseline supplies every required field. The parametrize table - pairs each ``ModelDefinitionInput`` shape with the resolved values - we expect after merging it against this baseline.""" +def _assert_resolved_matches(merged: RevisionDraft, expected: ResolvedExpectation) -> None: + assert merged.model_definition is not None + resolved = merged.model_definition.to_resolved() + model = resolved.models[0] + assert model.name == expected.name + assert model.model_path == expected.model_path + if expected.service_port is not None: + assert model.service is not None + assert model.service.port == expected.service_port + if expected.health_check_path is not None: + assert model.service is not None + assert model.service.health_check is not None + assert model.service.health_check.path == expected.health_check_path + + +class TestMergeWithCompleteBaseline: + """The variant ships a fully-populated ``default_model_definition``. + + Any request — including all-empty — resolves successfully because + the DB-side baseline already covers every required field. + """ @pytest.fixture async def variant_id( @@ -165,7 +178,7 @@ async def variant_id( service_port=9000, health_check_path="/healthz", ), - id="empty_request_inherits_full_baseline", + id="empty_request_inherits_baseline", ), pytest.param( ModelDefinitionInput(models=[ModelConfigInput(name="user-name")]), @@ -179,7 +192,7 @@ async def variant_id( ), ], ) - async def test_merge_resolves_to_expected_values( + async def test_resolves_to_expected_values( self, reader: RevisionDraftReader, mounts: MountMetadata, @@ -191,81 +204,215 @@ async def test_merge_resolves_to_expected_values( merged = await _merge_via_reader(reader, variant_id, request, mounts) - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - model = resolved.models[0] - assert model.name == expected.name - assert model.model_path == expected.model_path - if expected.service_port is not None: - assert model.service is not None - assert model.service.port == expected.service_port - if expected.health_check_path is not None: - assert model.service is not None - assert model.service.health_check is not None - assert model.service.health_check.path == expected.health_check_path - - -class TestMergeRaisesWithIncompleteBaseline: - """Each parametrize entry seeds its own incomplete baseline (via the - ``baseline_factory``) and expects ``to_resolved()`` to raise because - no source supplies a required field. The request is always - all-empty so the failure mode comes entirely from the baseline.""" + _assert_resolved_matches(merged, expected) + + +class TestMergeWhenBaselineLacksName: + """The variant baseline omits ``name``. The merge succeeds only + when the request supplies one — otherwise ``to_resolved()`` raises. + """ + + @pytest.fixture + async def variant_id( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + ) -> RuntimeVariantID: + return await _seed_variant( + db_with_variant_table, + ModelDefinitionDraft( + models=[ModelConfigDraft(model_path="/baseline/path")], + ), + ) @pytest.mark.parametrize( - ("incomplete_baseline", "error_pattern"), + ("request_input", "expected"), [ pytest.param( - # Reader's mount-destination default also fills model_path, - # so the only required ``ModelConfig`` field that ends up - # unfilled is ``name``. - ModelDefinitionDraft(models=[ModelConfigDraft(model_path="/p")]), - r"ModelConfig\.name is required", - id="name_unfilled", + ModelDefinitionInput(models=[ModelConfigInput(name="from-request")]), + ResolvedExpectation(name="from-request", model_path="/baseline/path"), + id="request_supplies_missing_name", + ), + ], + ) + async def test_request_supplying_name_resolves( + self, + reader: RevisionDraftReader, + mounts: MountMetadata, + variant_id: RuntimeVariantID, + request_input: ModelDefinitionInput, + expected: ResolvedExpectation, + ) -> None: + request = RevisionDraft(model_definition=request_input.to_draft()) + + merged = await _merge_via_reader(reader, variant_id, request, mounts) + + _assert_resolved_matches(merged, expected) + + async def test_empty_request_raises_name_required( + self, + reader: RevisionDraftReader, + mounts: MountMetadata, + variant_id: RuntimeVariantID, + ) -> None: + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = await _merge_via_reader(reader, variant_id, request, mounts) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelConfig\.name is required"): + merged.model_definition.to_resolved() + + +class TestMergeWhenBaselineLacksServicePort: + """The variant baseline supplies a ``service`` block without + ``port``. The merge succeeds only when the request supplies the + port — otherwise ``to_resolved()`` raises. + """ + + @pytest.fixture + async def variant_id( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + ) -> RuntimeVariantID: + return await _seed_variant( + db_with_variant_table, + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="baseline", + model_path="/baseline/path", + service=ModelServiceConfigDraft( + health_check=ModelHealthCheckDraft(path="/healthz"), + ), + ), + ], ), + ) + + @pytest.mark.parametrize( + ("request_input", "expected"), + [ pytest.param( - ModelDefinitionDraft( + ModelDefinitionInput( models=[ - ModelConfigDraft( - name="n", - model_path="/p", - service=ModelServiceConfigDraft(), + ModelConfigInput( + service=ModelServiceConfigInput(port=8080), ), ], ), - r"ModelServiceConfig\.port is required", - id="service_port_unfilled", + ResolvedExpectation( + name="baseline", + model_path="/baseline/path", + service_port=8080, + health_check_path="/healthz", + ), + id="request_supplies_service_port", ), + ], + ) + async def test_request_supplying_port_resolves( + self, + reader: RevisionDraftReader, + mounts: MountMetadata, + variant_id: RuntimeVariantID, + request_input: ModelDefinitionInput, + expected: ResolvedExpectation, + ) -> None: + request = RevisionDraft(model_definition=request_input.to_draft()) + + merged = await _merge_via_reader(reader, variant_id, request, mounts) + + _assert_resolved_matches(merged, expected) + + async def test_empty_request_raises_port_required( + self, + reader: RevisionDraftReader, + mounts: MountMetadata, + variant_id: RuntimeVariantID, + ) -> None: + request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) + + merged = await _merge_via_reader(reader, variant_id, request, mounts) + + assert merged.model_definition is not None + with pytest.raises(ValueError, match=r"ModelServiceConfig\.port is required"): + merged.model_definition.to_resolved() + + +class TestMergeWhenBaselineLacksHealthCheckPath: + """The variant baseline supplies ``service.health_check`` without + ``path``. The merge succeeds only when the request supplies the + path — otherwise ``to_resolved()`` raises. + """ + + @pytest.fixture + async def variant_id( + self, + db_with_variant_table: ExtendedAsyncSAEngine, + ) -> RuntimeVariantID: + return await _seed_variant( + db_with_variant_table, + ModelDefinitionDraft( + models=[ + ModelConfigDraft( + name="baseline", + model_path="/baseline/path", + service=ModelServiceConfigDraft( + port=8080, + health_check=ModelHealthCheckDraft(), + ), + ), + ], + ), + ) + + @pytest.mark.parametrize( + ("request_input", "expected"), + [ pytest.param( - ModelDefinitionDraft( + ModelDefinitionInput( models=[ - ModelConfigDraft( - name="n", - model_path="/p", - service=ModelServiceConfigDraft( - port=8080, - health_check=ModelHealthCheckDraft(), + ModelConfigInput( + service=ModelServiceConfigInput( + health_check=ModelHealthCheckInput(path="/ready"), ), ), ], ), - r"ModelHealthCheck\.path is required", - id="health_check_path_unfilled", + ResolvedExpectation( + name="baseline", + model_path="/baseline/path", + service_port=8080, + health_check_path="/ready", + ), + id="request_supplies_health_check_path", ), ], ) - async def test_required_field_unfilled_after_merge_raises( + async def test_request_supplying_path_resolves( + self, + reader: RevisionDraftReader, + mounts: MountMetadata, + variant_id: RuntimeVariantID, + request_input: ModelDefinitionInput, + expected: ResolvedExpectation, + ) -> None: + request = RevisionDraft(model_definition=request_input.to_draft()) + + merged = await _merge_via_reader(reader, variant_id, request, mounts) + + _assert_resolved_matches(merged, expected) + + async def test_empty_request_raises_health_check_path_required( self, - db_with_variant_table: ExtendedAsyncSAEngine, reader: RevisionDraftReader, mounts: MountMetadata, - incomplete_baseline: ModelDefinitionDraft, - error_pattern: str, + variant_id: RuntimeVariantID, ) -> None: - variant_id = await _seed_variant(db_with_variant_table, incomplete_baseline) request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) merged = await _merge_via_reader(reader, variant_id, request, mounts) assert merged.model_definition is not None - with pytest.raises(ValueError, match=error_pattern): + with pytest.raises(ValueError, match=r"ModelHealthCheck\.path is required"): merged.model_definition.to_resolved() From 90d5645d43faef15741bbfc5841670d85b5d6b20 Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:47:46 +0900 Subject: [PATCH 16/18] test(BA-5983): consolidate merge tests into the DB-backed file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The synthetic merge test file (test_model_definition_merge.py) duplicated every scenario the DB-backed test already covers — same "empty input + baseline → resolved values" and "missing field → raise" shapes, but with hand-built drafts instead of the real ``RevisionDraftReader`` path. Keep the realistic version (test_revision_merge_db.py) as the single source of truth for the merge contract. The one piece worth preserving from the synthetic file is the pure ``to_draft`` conversion check. Move that into the existing DTO test module as ``TestModelDefinitionInputToDraft``, tightened to assert ``draft.model_fields_set == input.model_fields_set`` — the precise invariant the merge logic relies on (BA-5983 broke because ``model_dump()`` was used without ``exclude_unset=True``, leaving every field "set" on the draft). Co-Authored-By: Claude Opus 4.7 (1M context) --- .../dto/manager/v2/deployment/test_request.py | 47 +++ .../deployment/test_model_definition_merge.py | 292 ------------------ 2 files changed, 47 insertions(+), 292 deletions(-) delete mode 100644 tests/unit/manager/sokovan/deployment/test_model_definition_merge.py diff --git a/tests/unit/common/dto/manager/v2/deployment/test_request.py b/tests/unit/common/dto/manager/v2/deployment/test_request.py index b5756ea086d..1dc0cea59ae 100644 --- a/tests/unit/common/dto/manager/v2/deployment/test_request.py +++ b/tests/unit/common/dto/manager/v2/deployment/test_request.py @@ -11,6 +11,7 @@ from pydantic import ValidationError from ai.backend.common.api_handlers import SENTINEL, Sentinel +from ai.backend.common.config import ModelDefinitionDraft from ai.backend.common.data.model_deployment.types import DeploymentStrategy from ai.backend.common.dto.manager.v2.deployment.request import ( ActivateDeploymentInput, @@ -24,11 +25,14 @@ DeploymentStrategyInput, ExtraVFolderMountInput, ImageInput, + ModelConfigInput, ModelDefinitionInput, ModelDeploymentMetadataInput, ModelDeploymentNetworkAccessInput, + ModelHealthCheckInput, ModelMountConfigInput, ModelRuntimeConfigInput, + ModelServiceConfigInput, ResourceConfigInput, ResourceGroupInput, ResourceSlotEntryInput, @@ -150,6 +154,49 @@ def test_with_extra_mounts(self) -> None: assert rev.extra_mounts[0].mount_destination == "/data" +class TestModelDefinitionInputToDraft: + """``ModelDefinitionInput.to_draft`` converts the all-optional v2 + input DTO into the ``ModelDefinitionDraft`` consumed by the + revision merge chain. The conversion must preserve unset semantics + so omitted fields stay unset on the resulting draft — otherwise + every ``None`` would clobber lower-priority sources during merge + (BA-5983). + """ + + @pytest.mark.parametrize( + "input_dto", + [ + pytest.param(ModelDefinitionInput(), id="empty"), + pytest.param( + ModelDefinitionInput(models=[ModelConfigInput(name="only-name")]), + id="partial_name_only", + ), + pytest.param( + ModelDefinitionInput( + models=[ + ModelConfigInput( + name="m", + service=ModelServiceConfigInput( + port=8080, + health_check=ModelHealthCheckInput(path="/healthz"), + ), + ) + ] + ), + id="nested_service_and_health_check", + ), + ], + ) + def test_to_draft_preserves_set_fields(self, input_dto: ModelDefinitionInput) -> None: + draft = input_dto.to_draft() + assert isinstance(draft, ModelDefinitionDraft) + # ``model_fields_set`` on the draft must match what the caller + # explicitly set on the input — that is what the merge logic + # uses to distinguish "unset → defer to baseline" from + # "explicitly None → clobber baseline". + assert draft.model_fields_set == input_dto.model_fields_set + + class TestExtraVFolderMountInput: """Tests for ExtraVFolderMountInput model.""" diff --git a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py b/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py deleted file mode 100644 index 2ccb29f8ed5..00000000000 --- a/tests/unit/manager/sokovan/deployment/test_model_definition_merge.py +++ /dev/null @@ -1,292 +0,0 @@ -"""Verify that nullable v2 ``ModelDefinitionInput`` fields still result in -correct required-field enforcement after the revision merge chain. - -This pins the BA-5983 behavior: the GraphQL/REST boundary accepts -all-optional fields, but ``to_resolved()`` at the persistence boundary -must still raise when no merge layer (request, preset, variant baseline) -supplies a required value. -""" - -from __future__ import annotations - -import functools -from dataclasses import dataclass - -import pytest - -from ai.backend.common.config import ( - ModelConfigDraft, - ModelDefinitionDraft, - ModelHealthCheckDraft, - ModelServiceConfigDraft, -) -from ai.backend.common.dto.manager.v2.deployment.request import ( - ModelConfigInput, - ModelDefinitionInput, - ModelHealthCheckInput, - ModelServiceConfigInput, -) -from ai.backend.manager.data.deployment.types import RevisionDraft - - -def _merge(*drafts: RevisionDraft) -> RevisionDraft: - return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) - - -@dataclass(frozen=True) -class ResolvedExpectation: - """Expected attributes on the resolved ``ModelConfig`` at ``models[0]``. - - Only the named-string fields participate; ``None`` means the - corresponding nested object should not be asserted (the scenario - does not exercise it). - """ - - name: str - model_path: str - service_port: int | None = None - health_check_path: str | None = None - - -class TestModelDefinitionInputToDraft: - """``ModelDefinitionInput.to_draft`` is the bridge between the - all-optional DTO and the merge-chain draft. The conversion itself - must never raise — required-field enforcement is deferred to - ``to_resolved()`` after the merge — and must preserve every field - the input carries (including ``None`` placeholders).""" - - @pytest.mark.parametrize( - "input_dto", - [ - pytest.param(ModelDefinitionInput(), id="empty"), - pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(name="only-name")]), - id="partial_name_only", - ), - pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - name="m", - service=ModelServiceConfigInput( - port=8080, - health_check=ModelHealthCheckInput(path="/healthz"), - ), - ) - ] - ), - id="nested_service_and_health_check", - ), - ], - ) - def test_to_draft_preserves_input_shape(self, input_dto: ModelDefinitionInput) -> None: - draft = input_dto.to_draft() - assert isinstance(draft, ModelDefinitionDraft) - assert draft.model_dump() == input_dto.model_dump() - - -class TestEmptyInputMergesWithBaseline: - """Empty (all-null) request input must let lower-priority sources - (variant baseline, preset) fill the required fields, and the merged - draft must resolve cleanly.""" - - @pytest.mark.parametrize( - ("drafts", "expected"), - [ - pytest.param( - [ - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ModelConfigDraft(name="llama", model_path="/models/llama")], - ), - ), - RevisionDraft(model_definition=ModelDefinitionInput().to_draft()), - ], - ResolvedExpectation(name="llama", model_path="/models/llama"), - id="variant_baseline_fills_required", - ), - pytest.param( - [ - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="from-preset", - model_path="/preset/path", - service=ModelServiceConfigDraft( - port=9000, - health_check=ModelHealthCheckDraft(path="/ready"), - ), - ) - ], - ), - ), - RevisionDraft(model_definition=ModelDefinitionInput().to_draft()), - ], - ResolvedExpectation( - name="from-preset", - model_path="/preset/path", - service_port=9000, - health_check_path="/ready", - ), - id="preset_fills_nested_required", - ), - pytest.param( - [ - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft(name="baseline-name", model_path="/baseline/path"), - ], - ), - ), - RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput(name="user-name")], - ).to_draft(), - ), - ], - ResolvedExpectation(name="user-name", model_path="/baseline/path"), - id="request_partial_overrides_baseline", - ), - ], - ) - def test_merge_resolves_to_expected_values( - self, drafts: list[RevisionDraft], expected: ResolvedExpectation - ) -> None: - merged = _merge(*drafts) - - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - model = resolved.models[0] - assert model.name == expected.name - assert model.model_path == expected.model_path - if expected.service_port is not None: - assert model.service is not None - assert model.service.port == expected.service_port - if expected.health_check_path is not None: - assert model.service is not None - assert model.service.health_check is not None - assert model.service.health_check.path == expected.health_check_path - - -class TestMergeRaisesWhenAllSourcesAreEmpty: - """When neither the request nor any baseline source supplies a - required field, ``to_resolved()`` must raise at the persistence - boundary — preserving the pre-BA-5983 contract. - - Each scenario layers a baseline draft (variant-style) together with - a request draft so the merge actually combines fields across - sources; the target required field remains unfilled in every layer - and the resolved-time check fires on it specifically.""" - - @pytest.mark.parametrize( - ("drafts", "error_pattern"), - [ - pytest.param( - [ - # baseline supplies model_path; request adds nothing. - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ModelConfigDraft(model_path="/baseline/path")], - ), - ), - RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput()], - ).to_draft(), - ), - ], - r"ModelConfig\.name is required", - id="name_unfilled_across_baseline_and_request", - ), - pytest.param( - [ - # baseline supplies name; request adds nothing. - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ModelConfigDraft(name="baseline-name")], - ), - ), - RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput()], - ).to_draft(), - ), - ], - r"ModelConfig\.model_path is required", - id="model_path_unfilled_across_baseline_and_request", - ), - pytest.param( - [ - # baseline supplies the outer ModelConfig fields; - # request adds an empty service (no port anywhere). - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ModelConfigDraft(name="n", model_path="/p")], - ), - ), - RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ModelConfigInput(service=ModelServiceConfigInput())], - ).to_draft(), - ), - ], - r"ModelServiceConfig\.port is required", - id="service_port_unfilled_across_baseline_and_request", - ), - pytest.param( - [ - # baseline supplies a service with port; request adds - # an empty health_check (no path anywhere). - RevisionDraft( - model_definition=ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="n", - model_path="/p", - service=ModelServiceConfigDraft(port=8080), - ) - ], - ), - ), - RevisionDraft( - model_definition=ModelDefinitionInput( - models=[ - ModelConfigInput( - service=ModelServiceConfigInput( - health_check=ModelHealthCheckInput(), - ), - ) - ], - ).to_draft(), - ), - ], - r"ModelHealthCheck\.path is required", - id="health_check_path_unfilled_across_baseline_and_request", - ), - ], - ) - def test_required_field_unfilled_after_merge_raises( - self, drafts: list[RevisionDraft], error_pattern: str - ) -> None: - merged = _merge(*drafts) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=error_pattern): - merged.model_definition.to_resolved() - - def test_empty_request_with_no_baseline_yields_empty_resolved(self) -> None: - """A completely empty merge chain resolves to an empty ModelDefinition. - - The ``add_revision`` controller guards against this case separately - (``model_definition.models must contain at least one entry``); the - resolved type itself permits an empty models list. - """ - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = _merge(request) - - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - assert resolved.models == [] From 8366021b1a4ead22d95b997a9fd853d7496f518f Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 12:52:50 +0900 Subject: [PATCH 17/18] test(BA-5983): drop tests added in this PR MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Remove the DB-backed revision merge test and the to_draft invariant test added in this PR. Existing tests in ``tests/unit/common/dto/manager/v2/deployment/test_request.py`` remain — those were updated only to swap ``ModelDefinitionDraft()`` for ``ModelDefinitionInput()`` so they typecheck against the new field type. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../dto/manager/v2/deployment/test_request.py | 47 -- .../deployment/test_revision_merge_db.py | 418 ------------------ 2 files changed, 465 deletions(-) delete mode 100644 tests/unit/manager/repositories/deployment/test_revision_merge_db.py diff --git a/tests/unit/common/dto/manager/v2/deployment/test_request.py b/tests/unit/common/dto/manager/v2/deployment/test_request.py index 1dc0cea59ae..b5756ea086d 100644 --- a/tests/unit/common/dto/manager/v2/deployment/test_request.py +++ b/tests/unit/common/dto/manager/v2/deployment/test_request.py @@ -11,7 +11,6 @@ from pydantic import ValidationError from ai.backend.common.api_handlers import SENTINEL, Sentinel -from ai.backend.common.config import ModelDefinitionDraft from ai.backend.common.data.model_deployment.types import DeploymentStrategy from ai.backend.common.dto.manager.v2.deployment.request import ( ActivateDeploymentInput, @@ -25,14 +24,11 @@ DeploymentStrategyInput, ExtraVFolderMountInput, ImageInput, - ModelConfigInput, ModelDefinitionInput, ModelDeploymentMetadataInput, ModelDeploymentNetworkAccessInput, - ModelHealthCheckInput, ModelMountConfigInput, ModelRuntimeConfigInput, - ModelServiceConfigInput, ResourceConfigInput, ResourceGroupInput, ResourceSlotEntryInput, @@ -154,49 +150,6 @@ def test_with_extra_mounts(self) -> None: assert rev.extra_mounts[0].mount_destination == "/data" -class TestModelDefinitionInputToDraft: - """``ModelDefinitionInput.to_draft`` converts the all-optional v2 - input DTO into the ``ModelDefinitionDraft`` consumed by the - revision merge chain. The conversion must preserve unset semantics - so omitted fields stay unset on the resulting draft — otherwise - every ``None`` would clobber lower-priority sources during merge - (BA-5983). - """ - - @pytest.mark.parametrize( - "input_dto", - [ - pytest.param(ModelDefinitionInput(), id="empty"), - pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(name="only-name")]), - id="partial_name_only", - ), - pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - name="m", - service=ModelServiceConfigInput( - port=8080, - health_check=ModelHealthCheckInput(path="/healthz"), - ), - ) - ] - ), - id="nested_service_and_health_check", - ), - ], - ) - def test_to_draft_preserves_set_fields(self, input_dto: ModelDefinitionInput) -> None: - draft = input_dto.to_draft() - assert isinstance(draft, ModelDefinitionDraft) - # ``model_fields_set`` on the draft must match what the caller - # explicitly set on the input — that is what the merge logic - # uses to distinguish "unset → defer to baseline" from - # "explicitly None → clobber baseline". - assert draft.model_fields_set == input_dto.model_fields_set - - class TestExtraVFolderMountInput: """Tests for ExtraVFolderMountInput model.""" diff --git a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py b/tests/unit/manager/repositories/deployment/test_revision_merge_db.py deleted file mode 100644 index 767f37ccd05..00000000000 --- a/tests/unit/manager/repositories/deployment/test_revision_merge_db.py +++ /dev/null @@ -1,418 +0,0 @@ -"""DB-backed verification of the BA-5983 revision merge contract. - -Each test class seeds one specific ``RuntimeVariantRow.default_model_definition`` -shape into the DB (so it round-trips through ``PydanticColumn`` -serialization) and runs the production ``RevisionDraftReader`` + -``RevisionDraft.merge`` pipeline against various request inputs. The -parametrize tables only carry the request ``ModelDefinitionInput`` and -the expected resolved values — the DB baseline is fixed per class via -its ``variant_id`` fixture. - -Scenarios are partitioned by baseline shape so each class makes the -"what's in the DB" / "what the user sends" / "what should come out" -relationship obvious at a glance. -""" - -from __future__ import annotations - -import functools -import uuid -from collections.abc import AsyncGenerator -from dataclasses import dataclass -from unittest.mock import MagicMock - -import pytest - -from ai.backend.common.config import ( - ModelConfigDraft, - ModelDefinitionDraft, - ModelHealthCheckDraft, - ModelServiceConfigDraft, -) -from ai.backend.common.dto.manager.v2.deployment.request import ( - ModelConfigInput, - ModelDefinitionInput, - ModelHealthCheckInput, - ModelServiceConfigInput, -) -from ai.backend.common.identifier.runtime_variant import RuntimeVariantID -from ai.backend.common.identifier.vfolder import VFolderUUID -from ai.backend.manager.data.deployment.types import MountMetadata, RevisionDraft -from ai.backend.manager.models.runtime_variant import RuntimeVariantRow -from ai.backend.manager.models.utils import ExtendedAsyncSAEngine -from ai.backend.manager.repositories.deployment.repository import DeploymentRepository -from ai.backend.manager.sokovan.deployment.revision_draft import RevisionDraftReader -from ai.backend.testutils.db import with_tables - - -@dataclass(frozen=True) -class ResolvedExpectation: - """Expected attributes on the resolved ``ModelConfig`` at ``models[0]``.""" - - name: str - model_path: str - service_port: int | None = None - health_check_path: str | None = None - - -@pytest.fixture -async def db_with_variant_table( - database_connection: ExtendedAsyncSAEngine, -) -> AsyncGenerator[ExtendedAsyncSAEngine, None]: - async with with_tables(database_connection, [RuntimeVariantRow]): - yield database_connection - - -@pytest.fixture -def reader( - db_with_variant_table: ExtendedAsyncSAEngine, -) -> RevisionDraftReader: - # ``load_deployment_revision_read_bundle`` only touches the - # runtime_variants table when ``preset_id`` is ``None``; the other - # repository dependencies are not exercised here and can be stubbed. - repo = DeploymentRepository( - db=db_with_variant_table, - storage_manager=MagicMock(), - valkey_stat=MagicMock(), - valkey_live=MagicMock(), - valkey_schedule=MagicMock(), - ) - return RevisionDraftReader(deployment_repository=repo) - - -@pytest.fixture -def mounts() -> MountMetadata: - return MountMetadata( - model_vfolder_id=VFolderUUID(uuid.uuid4()), - model_definition_path=None, - model_mount_destination="/models", - extra_mounts=[], - ) - - -async def _seed_variant( - db: ExtendedAsyncSAEngine, - baseline: ModelDefinitionDraft, -) -> RuntimeVariantID: - variant_id = RuntimeVariantID(uuid.uuid4()) - async with db.begin_session() as sess: - sess.add( - RuntimeVariantRow( - id=variant_id, - name=f"test-variant-{variant_id.hex[:8]}", - description="BA-5983 merge-test variant baseline", - reads_vfolder_config_files=False, - default_model_definition=baseline, - ) - ) - await sess.commit() - return variant_id - - -async def _merge_via_reader( - reader: RevisionDraftReader, - variant_id: RuntimeVariantID, - request: RevisionDraft, - mounts: MountMetadata, -) -> RevisionDraft: - drafts = await reader.read_for_deployment_revision( - runtime_variant_id=variant_id, - request_draft=request, - mounts=mounts, - preset_id=None, - ) - return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft()) - - -def _assert_resolved_matches(merged: RevisionDraft, expected: ResolvedExpectation) -> None: - assert merged.model_definition is not None - resolved = merged.model_definition.to_resolved() - model = resolved.models[0] - assert model.name == expected.name - assert model.model_path == expected.model_path - if expected.service_port is not None: - assert model.service is not None - assert model.service.port == expected.service_port - if expected.health_check_path is not None: - assert model.service is not None - assert model.service.health_check is not None - assert model.service.health_check.path == expected.health_check_path - - -class TestMergeWithCompleteBaseline: - """The variant ships a fully-populated ``default_model_definition``. - - Any request — including all-empty — resolves successfully because - the DB-side baseline already covers every required field. - """ - - @pytest.fixture - async def variant_id( - self, - db_with_variant_table: ExtendedAsyncSAEngine, - ) -> RuntimeVariantID: - return await _seed_variant( - db_with_variant_table, - ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="baseline-llama", - model_path="/models/baseline", - service=ModelServiceConfigDraft( - port=9000, - health_check=ModelHealthCheckDraft(path="/healthz"), - ), - ), - ], - ), - ) - - @pytest.mark.parametrize( - ("request_input", "expected"), - [ - pytest.param( - ModelDefinitionInput(), - ResolvedExpectation( - name="baseline-llama", - model_path="/models/baseline", - service_port=9000, - health_check_path="/healthz", - ), - id="empty_request_inherits_baseline", - ), - pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(name="user-name")]), - ResolvedExpectation( - name="user-name", - model_path="/models/baseline", - service_port=9000, - health_check_path="/healthz", - ), - id="request_overrides_name_only", - ), - ], - ) - async def test_resolves_to_expected_values( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - request_input: ModelDefinitionInput, - expected: ResolvedExpectation, - ) -> None: - request = RevisionDraft(model_definition=request_input.to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - _assert_resolved_matches(merged, expected) - - -class TestMergeWhenBaselineLacksName: - """The variant baseline omits ``name``. The merge succeeds only - when the request supplies one — otherwise ``to_resolved()`` raises. - """ - - @pytest.fixture - async def variant_id( - self, - db_with_variant_table: ExtendedAsyncSAEngine, - ) -> RuntimeVariantID: - return await _seed_variant( - db_with_variant_table, - ModelDefinitionDraft( - models=[ModelConfigDraft(model_path="/baseline/path")], - ), - ) - - @pytest.mark.parametrize( - ("request_input", "expected"), - [ - pytest.param( - ModelDefinitionInput(models=[ModelConfigInput(name="from-request")]), - ResolvedExpectation(name="from-request", model_path="/baseline/path"), - id="request_supplies_missing_name", - ), - ], - ) - async def test_request_supplying_name_resolves( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - request_input: ModelDefinitionInput, - expected: ResolvedExpectation, - ) -> None: - request = RevisionDraft(model_definition=request_input.to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - _assert_resolved_matches(merged, expected) - - async def test_empty_request_raises_name_required( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - ) -> None: - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelConfig\.name is required"): - merged.model_definition.to_resolved() - - -class TestMergeWhenBaselineLacksServicePort: - """The variant baseline supplies a ``service`` block without - ``port``. The merge succeeds only when the request supplies the - port — otherwise ``to_resolved()`` raises. - """ - - @pytest.fixture - async def variant_id( - self, - db_with_variant_table: ExtendedAsyncSAEngine, - ) -> RuntimeVariantID: - return await _seed_variant( - db_with_variant_table, - ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="baseline", - model_path="/baseline/path", - service=ModelServiceConfigDraft( - health_check=ModelHealthCheckDraft(path="/healthz"), - ), - ), - ], - ), - ) - - @pytest.mark.parametrize( - ("request_input", "expected"), - [ - pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - service=ModelServiceConfigInput(port=8080), - ), - ], - ), - ResolvedExpectation( - name="baseline", - model_path="/baseline/path", - service_port=8080, - health_check_path="/healthz", - ), - id="request_supplies_service_port", - ), - ], - ) - async def test_request_supplying_port_resolves( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - request_input: ModelDefinitionInput, - expected: ResolvedExpectation, - ) -> None: - request = RevisionDraft(model_definition=request_input.to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - _assert_resolved_matches(merged, expected) - - async def test_empty_request_raises_port_required( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - ) -> None: - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelServiceConfig\.port is required"): - merged.model_definition.to_resolved() - - -class TestMergeWhenBaselineLacksHealthCheckPath: - """The variant baseline supplies ``service.health_check`` without - ``path``. The merge succeeds only when the request supplies the - path — otherwise ``to_resolved()`` raises. - """ - - @pytest.fixture - async def variant_id( - self, - db_with_variant_table: ExtendedAsyncSAEngine, - ) -> RuntimeVariantID: - return await _seed_variant( - db_with_variant_table, - ModelDefinitionDraft( - models=[ - ModelConfigDraft( - name="baseline", - model_path="/baseline/path", - service=ModelServiceConfigDraft( - port=8080, - health_check=ModelHealthCheckDraft(), - ), - ), - ], - ), - ) - - @pytest.mark.parametrize( - ("request_input", "expected"), - [ - pytest.param( - ModelDefinitionInput( - models=[ - ModelConfigInput( - service=ModelServiceConfigInput( - health_check=ModelHealthCheckInput(path="/ready"), - ), - ), - ], - ), - ResolvedExpectation( - name="baseline", - model_path="/baseline/path", - service_port=8080, - health_check_path="/ready", - ), - id="request_supplies_health_check_path", - ), - ], - ) - async def test_request_supplying_path_resolves( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - request_input: ModelDefinitionInput, - expected: ResolvedExpectation, - ) -> None: - request = RevisionDraft(model_definition=request_input.to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - _assert_resolved_matches(merged, expected) - - async def test_empty_request_raises_health_check_path_required( - self, - reader: RevisionDraftReader, - mounts: MountMetadata, - variant_id: RuntimeVariantID, - ) -> None: - request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft()) - - merged = await _merge_via_reader(reader, variant_id, request, mounts) - - assert merged.model_definition is not None - with pytest.raises(ValueError, match=r"ModelHealthCheck\.path is required"): - merged.model_definition.to_resolved() From 139acc40c18640440cfc5e031e407332d283a0ee Mon Sep 17 00:00:00 2001 From: Gyubong Date: Mon, 11 May 2026 13:12:46 +0900 Subject: [PATCH 18/18] refactor(BA-5983): defer default-value application to strict Pydantic types ``ModelHealthCheckDraft.to_resolved`` and ``ModelServiceConfigDraft.to_resolved`` previously duplicated every default value (``10.0``, ``10``, ``15.0``, ``200``, ``60.0`` for the health-check fields; ``[]`` and ``"/bin/bash"`` for the service config) inline as ``if self.x is not None else `` branches. Those literals were already declared on the strict ``ModelHealthCheck`` / ``ModelServiceConfig`` classes via ``Field(default=...)``, so the project carried the same constants in two places at risk of drift. Switch ``to_resolved`` to drop ``None`` scalars via ``model_dump(exclude_none=True)`` and let the strict type's field defaults apply during ``model_validate``/constructor call. Required- field checks (``path``, ``port``) stay as explicit ``ValueError`` raises so the error message remains domain-specific rather than a generic ``pydantic.ValidationError``. The nested ``health_check`` draft is resolved out-of-band before the strict service config is composed, since it carries its own required-field check. Behavior-preserving: - Every ``Field(default=...)`` on the strict type matches the literal the old code wrote inline (verified field-by-field). - ``pre_start_actions``: old ``or []`` and new ``exclude_none=True`` produce identical results for the two reachable shapes (``None`` and ``list[PreStartAction]``). - Field-level constraints (``gt``, ``ge``) still fire because both the old constructor call and the new ``model_validate`` execute Pydantic validators. Existing tests in test_config / test_revision_draft_merge / test_model_definition_start_command_compat / test_revision_draft_reader still pass. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/ai/backend/common/config.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/ai/backend/common/config.py b/src/ai/backend/common/config.py index 16caaad4347..b0cfc37ffd1 100644 --- a/src/ai/backend/common/config.py +++ b/src/ai/backend/common/config.py @@ -534,16 +534,9 @@ class ModelHealthCheckDraft(BaseConfigModel): def to_resolved(self) -> ModelHealthCheck: if self.path is None: raise ValueError("ModelHealthCheck.path is required") - return ModelHealthCheck( - interval=self.interval if self.interval is not None else 10.0, - path=self.path, - max_retries=self.max_retries if self.max_retries is not None else 10, - max_wait_time=self.max_wait_time if self.max_wait_time is not None else 15.0, - expected_status_code=( - self.expected_status_code if self.expected_status_code is not None else 200 - ), - initial_delay=self.initial_delay if self.initial_delay is not None else 60.0, - ) + # Drop unset (None) fields so the strict type's ``Field(default=...)`` + # declarations remain the single source of truth for default values. + return ModelHealthCheck.model_validate(self.model_dump(exclude_none=True)) class ModelServiceConfigDraft(BaseConfigModel): @@ -561,12 +554,13 @@ def _coerce_start_command(cls, value: Any) -> Any: def to_resolved(self) -> ModelServiceConfig: if self.port is None: raise ValueError("ModelServiceConfig.port is required") + # Drop unset (None) scalars so the strict type's ``Field(default=...)`` + # declarations remain the single source of truth for default values; + # resolve the nested ``health_check`` draft explicitly so its own + # required-field check (``path``) fires with a clear error message. return ModelServiceConfig( - pre_start_actions=self.pre_start_actions or [], - start_command=self.start_command, - shell=self.shell if self.shell is not None else "/bin/bash", - port=self.port, - health_check=(self.health_check.to_resolved() if self.health_check else None), + **self.model_dump(exclude_none=True, exclude={"health_check"}), + health_check=self.health_check.to_resolved() if self.health_check else None, )