|
13 | 13 | from pydantic import Field, field_validator |
14 | 14 |
|
15 | 15 | from ai.backend.common.api_handlers import SENTINEL, BaseRequestModel, Sentinel |
16 | | -from ai.backend.common.config import ModelDefinitionDraft |
| 16 | +from ai.backend.common.config import ( |
| 17 | + ModelDefinitionDraft, |
| 18 | + PreStartAction, |
| 19 | +) |
17 | 20 | from ai.backend.common.data.model_deployment.types import ( |
18 | 21 | DeploymentStrategy, |
19 | 22 | RouteHealthStatus, |
|
80 | 83 | "EnvironmentVariablesInput", |
81 | 84 | "ExtraVFolderMountInput", |
82 | 85 | "ImageInput", |
| 86 | + "ModelConfigInput", |
| 87 | + "ModelDefinitionInput", |
83 | 88 | "ModelDeploymentMetadataInput", |
84 | 89 | "ModelDeploymentNetworkAccessInput", |
| 90 | + "ModelHealthCheckInput", |
| 91 | + "ModelMetadataInput", |
85 | 92 | "ModelMountConfigInput", |
86 | 93 | "ModelRuntimeConfigInput", |
| 94 | + "ModelServiceConfigInput", |
87 | 95 | "ReplicaFilter", |
88 | 96 | "ReplicaOrder", |
89 | 97 | "ReplicaStatusFilter", |
|
116 | 124 | ) |
117 | 125 |
|
118 | 126 |
|
class ModelHealthCheckInput(BaseRequestModel):
    """Optional health-check settings supplied with a model service input.

    Every field defaults to ``None``, so a request may provide any subset
    of them.
    """

    # NOTE(review): time-like values are presumably in seconds -- confirm
    # against the health-check model used by ``ModelDefinitionDraft``.
    interval: float | None = None
    path: str | None = None
    max_retries: int | None = None
    max_wait_time: float | None = None
    expected_status_code: int | None = None
    initial_delay: float | None = None
| 134 | + |
| 135 | + |
class ModelMetadataInput(BaseRequestModel):
    """Optional descriptive metadata supplied with a model config input.

    Every field defaults to ``None``, so a request may provide any subset
    of them.
    """

    # Authorship and versioning.
    author: str | None = None
    title: str | None = None
    version: str | None = None
    # NOTE(review): timestamps are accepted as plain strings here; the
    # expected format is not visible in this module -- confirm.
    created: str | None = None
    last_modified: str | None = None

    # Free-form description and classification.
    description: str | None = None
    task: str | None = None
    category: str | None = None
    architecture: str | None = None
    framework: list[str] | None = None
    label: list[str] | None = None
    license: str | None = None

    # Arbitrary mapping; its key schema is not defined in this class.
    min_resource: dict[str, Any] | None = None
| 150 | + |
| 151 | + |
class ModelServiceConfigInput(BaseRequestModel):
    """Optional service settings supplied with a model config input.

    Every field defaults to ``None``, so a request may provide any subset
    of them.
    """

    pre_start_actions: list[PreStartAction] | None = None
    # The command is given as a list of strings rather than a single string.
    start_command: list[str] | None = None
    shell: str | None = None
    port: int | None = None
    # Nested, likewise all-optional, health-check settings.
    health_check: ModelHealthCheckInput | None = None
| 158 | + |
| 159 | + |
class ModelConfigInput(BaseRequestModel):
    """Optional per-model entry of a :class:`ModelDefinitionInput`.

    Every field defaults to ``None``, so a request may provide any subset
    of them.
    """

    name: str | None = None
    model_path: str | None = None
    # Nested, likewise all-optional, sub-inputs.
    service: ModelServiceConfigInput | None = None
    metadata: ModelMetadataInput | None = None
| 165 | + |
| 166 | + |
class ModelDefinitionInput(BaseRequestModel):
    """All-optional v2 input counterpart of :class:`ModelDefinitionDraft`.

    Anything a request leaves out is supplied by lower-priority sources in
    the revision merge chain (runtime variant baseline, revision preset,
    vfolder ``model-definition.yaml``, ``model_mount_destination``
    default). Required fields are not enforced here; that happens later,
    in ``ModelDefinitionDraft.to_resolved``, once the merge is done.
    """

    models: list[ModelConfigInput] | None = None

    def to_draft(self) -> ModelDefinitionDraft:
        """Build a :class:`ModelDefinitionDraft` from the fields the caller set.

        Returns:
            A draft validated from this input's explicitly provided fields.
        """
        # Dump with ``exclude_unset=True`` so the draft's
        # ``model_fields_set`` reflects only what the caller actually
        # sent. A full dump would mark every field as explicitly set
        # (to ``None``) and overwrite lower-priority sources during the
        # revision merge.
        provided_fields = self.model_dump(exclude_unset=True)
        return ModelDefinitionDraft.model_validate(provided_fields)
| 186 | + |
| 187 | + |
119 | 188 | class ClusterConfigInput(BaseRequestModel): |
120 | 189 | """Cluster configuration input for a revision.""" |
121 | 190 |
|
@@ -240,7 +309,7 @@ class CreateRevisionInputDTO(BaseRequestModel): |
240 | 309 | image: ImageInput = Field(description="Container image") |
241 | 310 | model_runtime_config: ModelRuntimeConfigInput = Field(description="Runtime configuration") |
242 | 311 | model_mount_config: ModelMountConfigInput = Field(description="Model mount configuration") |
243 | | - model_definition: ModelDefinitionDraft | None = Field( |
| 312 | + model_definition: ModelDefinitionInput | None = Field( |
244 | 313 | default=None, |
245 | 314 | description="Model definition to override the default values generated by the server", |
246 | 315 | ) |
@@ -276,7 +345,7 @@ class AddRevisionGQLInputDTO(BaseRequestModel): |
276 | 345 | image: ImageInput = Field(description="Container image") |
277 | 346 | model_runtime_config: ModelRuntimeConfigInput = Field(description="Runtime configuration") |
278 | 347 | model_mount_config: ModelMountConfigInput = Field(description="Model mount configuration") |
279 | | - model_definition: ModelDefinitionDraft | None = Field( |
| 348 | + model_definition: ModelDefinitionInput | None = Field( |
280 | 349 | default=None, |
281 | 350 | description="Model definition to override the default values generated by the server", |
282 | 351 | ) |
@@ -403,7 +472,7 @@ class RevisionInput(BaseRequestModel): |
403 | 472 | default="/models", description="Mount destination for model vfolder" |
404 | 473 | ) |
405 | 474 | model_definition_path: str = Field(description="Path to model definition file") |
406 | | - model_definition: ModelDefinitionDraft | None = Field( |
| 475 | + model_definition: ModelDefinitionInput | None = Field( |
407 | 476 | default=None, |
408 | 477 | description="Model definition to override the default values generated by the server", |
409 | 478 | ) |
|
0 commit comments