Skip to content

Commit e82cd8f

Browse files
jopemachineclaude
andcommitted
refactor(BA-5983): make to_draft a method of ModelDefinitionInput
Address review feedback by moving the standalone ``to_model_definition_draft`` helper into ``ModelDefinitionInput.to_draft`` and update the test fixtures to use ``ModelDefinitionInput`` (matching the field type) instead of ``ModelDefinitionDraft``. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 9c1e003 commit e82cd8f

3 files changed

Lines changed: 12 additions & 14 deletions

File tree

src/ai/backend/common/dto/manager/v2/deployment/request.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -176,13 +176,8 @@ class ModelDefinitionInput(BaseRequestModel):
176176

177177
models: list[ModelConfigInput] | None = None
178178

179-
180-
def to_model_definition_draft(
181-
input: ModelDefinitionInput | None,
182-
) -> ModelDefinitionDraft | None:
183-
if input is None:
184-
return None
185-
return ModelDefinitionDraft.model_validate(input.model_dump())
179+
def to_draft(self) -> ModelDefinitionDraft:
180+
return ModelDefinitionDraft.model_validate(self.model_dump())
186181

187182

188183
class ClusterConfigInput(BaseRequestModel):

src/ai/backend/manager/api/adapters/deployment/adapter.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,6 @@
6262
SyncReplicaInput,
6363
UpdateDeploymentInput,
6464
UpsertDeploymentPolicyInput,
65-
to_model_definition_draft,
6665
)
6766
from ai.backend.common.dto.manager.v2.deployment.response import (
6867
AccessTokenNode,
@@ -506,7 +505,9 @@ async def create(
506505
else None,
507506
),
508507
mounts=mounts_creator,
509-
model_definition=to_model_definition_draft(initial_revision.model_definition),
508+
model_definition=initial_revision.model_definition.to_draft()
509+
if initial_revision.model_definition is not None
510+
else None,
510511
revision_preset_id=initial_revision.revision_preset_id,
511512
execution=ExecutionSpec(
512513
runtime_variant_id=initial_revision.model_runtime_config.runtime_variant_id,
@@ -1111,7 +1112,9 @@ async def add_revision(
11111112
else None,
11121113
inference_runtime_config=input.model_runtime_config.inference_runtime_config,
11131114
),
1114-
model_definition=to_model_definition_draft(input.model_definition),
1115+
model_definition=input.model_definition.to_draft()
1116+
if input.model_definition is not None
1117+
else None,
11151118
revision_preset_id=input.revision_preset_id,
11161119
)
11171120
action_result = await self._processors.deployment.add_model_revision.wait_for_complete(

tests/unit/common/dto/manager/v2/deployment/test_request.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
from pydantic import ValidationError
1212

1313
from ai.backend.common.api_handlers import SENTINEL, Sentinel
14-
from ai.backend.common.config import ModelDefinitionDraft
1514
from ai.backend.common.data.model_deployment.types import DeploymentStrategy
1615
from ai.backend.common.dto.manager.v2.deployment.request import (
1716
ActivateDeploymentInput,
@@ -25,6 +24,7 @@
2524
DeploymentStrategyInput,
2625
ExtraVFolderMountInput,
2726
ImageInput,
27+
ModelDefinitionInput,
2828
ModelDeploymentMetadataInput,
2929
ModelDeploymentNetworkAccessInput,
3030
ModelMountConfigInput,
@@ -55,7 +55,7 @@ def _make_revision_input(**kwargs: object) -> RevisionInput:
5555
"runtime_variant_id": RuntimeVariantID(uuid.uuid4()),
5656
"model_vfolder_id": VFolderUUID(uuid.uuid4()),
5757
"model_definition_path": "/models/model.yaml",
58-
"model_definition": ModelDefinitionDraft(),
58+
"model_definition": ModelDefinitionInput(),
5959
}
6060
defaults.update(kwargs)
6161
return RevisionInput(**defaults)
@@ -82,7 +82,7 @@ def _make_create_revision_input_dto(**kwargs: object) -> CreateRevisionInputDTO:
8282
mount_destination="/models",
8383
definition_path="/models/model.yaml",
8484
),
85-
"model_definition": ModelDefinitionDraft(),
85+
"model_definition": ModelDefinitionInput(),
8686
}
8787
defaults.update(kwargs)
8888
return CreateRevisionInputDTO(**defaults)
@@ -103,7 +103,7 @@ def test_valid_creation_with_required_fields(self) -> None:
103103
runtime_variant_id=runtime_variant_id,
104104
model_vfolder_id=model_id,
105105
model_definition_path="/models/def.yaml",
106-
model_definition=ModelDefinitionDraft(),
106+
model_definition=ModelDefinitionInput(),
107107
)
108108
assert rev.image_id == image_id
109109
assert rev.cluster_mode == ClusterMode.SINGLE_NODE

0 commit comments

Comments
 (0)