Skip to content

Commit 83b3ac4

Browse files
jopemachineclaude
andcommitted
fix(BA-5983): preserve unset semantics in ModelDefinitionInput.to_draft
``ModelDefinitionInput.to_draft()`` used ``model_dump()`` with default arguments, which dumps every field — including the unset ones at their ``None`` default. Round-tripping that through ``model_validate`` left the resulting draft with ``model_fields_set`` containing every field, so every ``None`` looked "explicitly set" and clobbered lower-priority baselines during the revision merge chain. Switch to ``model_dump(exclude_unset=True)`` so the resulting draft's ``model_fields_set`` reflects only what the caller actually provided. This is what makes the BA-5983 scenario actually work end-to-end: a request that omits ``name`` / ``model_path`` / ``service.port`` / ``health_check.path`` lets the variant baseline (or preset) fill them in instead of nulling them out. Extend the merge test to cover this directly — every "missing required field" scenario now layers a baseline draft together with a request draft so the merge actually combines fields across sources. Without the fix, the model_path / service_port / health_check_path cases would raise the wrong error first (e.g. ``ModelConfig.name is required`` fires before ``model_path``) because every request-side ``None`` would clobber the baseline's preserved value. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent d3cf04c commit 83b3ac4

2 files changed

Lines changed: 86 additions & 35 deletions

File tree

src/ai/backend/common/dto/manager/v2/deployment/request.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,12 @@ class ModelDefinitionInput(BaseRequestModel):
177177
models: list[ModelConfigInput] | None = None
178178

179179
def to_draft(self) -> ModelDefinitionDraft:
180-
return ModelDefinitionDraft.model_validate(self.model_dump())
180+
# ``exclude_unset=True`` keeps the resulting draft's
181+
# ``model_fields_set`` aligned with what the caller actually
182+
# provided. Without it, every field would appear "explicitly
183+
# set" (to ``None``) and clobber lower-priority sources during
184+
# the revision merge.
185+
return ModelDefinitionDraft.model_validate(self.model_dump(exclude_unset=True))
181186

182187

183188
class ClusterConfigInput(BaseRequestModel):

tests/unit/manager/sokovan/deployment/test_model_definition_merge.py

Lines changed: 80 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -173,58 +173,104 @@ def test_merge_resolves_to_expected_values(
173173
class TestMergeRaisesWhenAllSourcesAreEmpty:
174174
"""When neither the request nor any baseline source supplies a
175175
required field, ``to_resolved()`` must raise at the persistence
176-
boundary — preserving the pre-BA-5983 contract."""
176+
boundary — preserving the pre-BA-5983 contract.
177+
178+
Each scenario layers a baseline draft (variant-style) together with
179+
a request draft so the merge actually combines fields across
180+
sources; the target required field remains unfilled in every layer
181+
and the resolved-time check fires on it specifically."""
177182

178183
@pytest.mark.parametrize(
179-
("request_input", "error_pattern"),
184+
("drafts", "error_pattern"),
180185
[
181186
pytest.param(
182-
ModelDefinitionInput(models=[ModelConfigInput(model_path="/p")]),
187+
[
188+
# baseline supplies model_path; request adds nothing.
189+
RevisionDraft(
190+
model_definition=ModelDefinitionDraft(
191+
models=[ModelConfigDraft(model_path="/baseline/path")],
192+
),
193+
),
194+
RevisionDraft(
195+
model_definition=ModelDefinitionInput(
196+
models=[ModelConfigInput()],
197+
).to_draft(),
198+
),
199+
],
183200
r"ModelConfig\.name is required",
184-
id="missing_name",
201+
id="name_unfilled_across_baseline_and_request",
185202
),
186203
pytest.param(
187-
ModelDefinitionInput(models=[ModelConfigInput(name="n")]),
204+
[
205+
# baseline supplies name; request adds nothing.
206+
RevisionDraft(
207+
model_definition=ModelDefinitionDraft(
208+
models=[ModelConfigDraft(name="baseline-name")],
209+
),
210+
),
211+
RevisionDraft(
212+
model_definition=ModelDefinitionInput(
213+
models=[ModelConfigInput()],
214+
).to_draft(),
215+
),
216+
],
188217
r"ModelConfig\.model_path is required",
189-
id="missing_model_path",
218+
id="model_path_unfilled_across_baseline_and_request",
190219
),
191220
pytest.param(
192-
ModelDefinitionInput(
193-
models=[
194-
ModelConfigInput(
195-
name="n",
196-
model_path="/p",
197-
service=ModelServiceConfigInput(),
198-
)
199-
],
200-
),
221+
[
222+
# baseline supplies the outer ModelConfig fields;
223+
# request adds an empty service (no port anywhere).
224+
RevisionDraft(
225+
model_definition=ModelDefinitionDraft(
226+
models=[ModelConfigDraft(name="n", model_path="/p")],
227+
),
228+
),
229+
RevisionDraft(
230+
model_definition=ModelDefinitionInput(
231+
models=[ModelConfigInput(service=ModelServiceConfigInput())],
232+
).to_draft(),
233+
),
234+
],
201235
r"ModelServiceConfig\.port is required",
202-
id="missing_service_port",
236+
id="service_port_unfilled_across_baseline_and_request",
203237
),
204238
pytest.param(
205-
ModelDefinitionInput(
206-
models=[
207-
ModelConfigInput(
208-
name="n",
209-
model_path="/p",
210-
service=ModelServiceConfigInput(
211-
port=8080,
212-
health_check=ModelHealthCheckInput(),
213-
),
214-
)
215-
],
216-
),
239+
[
240+
# baseline supplies a service with port; request adds
241+
# an empty health_check (no path anywhere).
242+
RevisionDraft(
243+
model_definition=ModelDefinitionDraft(
244+
models=[
245+
ModelConfigDraft(
246+
name="n",
247+
model_path="/p",
248+
service=ModelServiceConfigDraft(port=8080),
249+
)
250+
],
251+
),
252+
),
253+
RevisionDraft(
254+
model_definition=ModelDefinitionInput(
255+
models=[
256+
ModelConfigInput(
257+
service=ModelServiceConfigInput(
258+
health_check=ModelHealthCheckInput(),
259+
),
260+
)
261+
],
262+
).to_draft(),
263+
),
264+
],
217265
r"ModelHealthCheck\.path is required",
218-
id="missing_health_check_path",
266+
id="health_check_path_unfilled_across_baseline_and_request",
219267
),
220268
],
221269
)
222-
def test_missing_required_field_raises(
223-
self, request_input: ModelDefinitionInput, error_pattern: str
270+
def test_required_field_unfilled_after_merge_raises(
271+
self, drafts: list[RevisionDraft], error_pattern: str
224272
) -> None:
225-
request = RevisionDraft(model_definition=request_input.to_draft())
226-
227-
merged = _merge(request)
273+
merged = _merge(*drafts)
228274

229275
assert merged.model_definition is not None
230276
with pytest.raises(ValueError, match=error_pattern):

0 commit comments

Comments
 (0)