11"""DB-backed verification of the BA-5983 revision merge contract.
22
3- A real ``RuntimeVariantRow`` is seeded via a fixture (so the variant's
4- ``default_model_definition`` round-trips through ``PydanticColumn``
5- serialization). The production ``RevisionDraftReader`` +
6- ``RevisionDraft.merge`` pipeline is then run against request drafts
7- built from various ``ModelDefinitionInput`` shapes; the parametrized
8- table only carries the request input and the expected outcome.
9-
10- Two scenario groups, each pinned to its own DB baseline fixture:
11-
12- - ``TestMergeWithFullBaseline`` — variant ships every required field;
13- the parametrized inputs probe how different requests combine with
14- it (inherit-all, partial override).
15- - ``TestMergeRaisesWithIncompleteBaseline`` — variant ships an
16- incomplete definition where ``to_resolved()`` is expected to raise
17- because no source supplies a required nested field. Each parametrize
18- entry pairs an incomplete baseline shape with the expected error
19- pattern; the request is always all-empty.
3+ Each test class seeds one specific ``RuntimeVariantRow.default_model_definition``
4+ shape into the DB (so it round-trips through ``PydanticColumn``
5+ serialization) and runs the production ``RevisionDraftReader`` +
6+ ``RevisionDraft.merge`` pipeline against various request inputs. The
7+ parametrize tables only carry the request ``ModelDefinitionInput`` and
8+ the expected resolved values — the DB baseline is fixed per class via
9+ its ``variant_id`` fixture.
10+
11+ Scenarios are partitioned by baseline shape so each class makes the
12+ "what's in the DB" / "what the user sends" / "what should come out"
13+ relationship obvious at a glance.
2014"""
2115
2216from __future__ import annotations
3832from ai .backend .common .dto .manager .v2 .deployment .request import (
3933 ModelConfigInput ,
4034 ModelDefinitionInput ,
35+ ModelHealthCheckInput ,
36+ ModelServiceConfigInput ,
4137)
4238from ai .backend .common .identifier .runtime_variant import RuntimeVariantID
4339from ai .backend .common .identifier .vfolder import VFolderUUID
@@ -128,10 +124,27 @@ async def _merge_via_reader(
128124 return functools .reduce (RevisionDraft .merge , drafts , RevisionDraft ())
129125
130126
131- class TestMergeWithFullBaseline :
132- """Baseline supplies every required field. The parametrize table
133- pairs each ``ModelDefinitionInput`` shape with the resolved values
134- we expect after merging it against this baseline."""
127+ def _assert_resolved_matches (merged : RevisionDraft , expected : ResolvedExpectation ) -> None :
128+ assert merged .model_definition is not None
129+ resolved = merged .model_definition .to_resolved ()
130+ model = resolved .models [0 ]
131+ assert model .name == expected .name
132+ assert model .model_path == expected .model_path
133+ if expected .service_port is not None :
134+ assert model .service is not None
135+ assert model .service .port == expected .service_port
136+ if expected .health_check_path is not None :
137+ assert model .service is not None
138+ assert model .service .health_check is not None
139+ assert model .service .health_check .path == expected .health_check_path
140+
141+
142+ class TestMergeWithCompleteBaseline :
143+ """The variant ships a fully-populated ``default_model_definition``.
144+
145+ Any request — including all-empty — resolves successfully because
146+ the DB-side baseline already covers every required field.
147+ """
135148
136149 @pytest .fixture
137150 async def variant_id (
@@ -165,7 +178,7 @@ async def variant_id(
165178 service_port = 9000 ,
166179 health_check_path = "/healthz" ,
167180 ),
168- id = "empty_request_inherits_full_baseline " ,
181+ id = "empty_request_inherits_baseline " ,
169182 ),
170183 pytest .param (
171184 ModelDefinitionInput (models = [ModelConfigInput (name = "user-name" )]),
@@ -179,7 +192,7 @@ async def variant_id(
179192 ),
180193 ],
181194 )
182- async def test_merge_resolves_to_expected_values (
195+ async def test_resolves_to_expected_values (
183196 self ,
184197 reader : RevisionDraftReader ,
185198 mounts : MountMetadata ,
@@ -191,81 +204,215 @@ async def test_merge_resolves_to_expected_values(
191204
192205 merged = await _merge_via_reader (reader , variant_id , request , mounts )
193206
194- assert merged . model_definition is not None
195- resolved = merged . model_definition . to_resolved ()
196- model = resolved . models [ 0 ]
197- assert model . name == expected . name
198- assert model . model_path == expected . model_path
199- if expected . service_port is not None :
200- assert model . service is not None
201- assert model . service . port == expected . service_port
202- if expected . health_check_path is not None :
203- assert model . service is not None
204- assert model . service . health_check is not None
205- assert model . service . health_check . path == expected . health_check_path
206-
207-
208- class TestMergeRaisesWithIncompleteBaseline :
209- """Each parametrize entry seeds its own incomplete baseline (via the
210- ``baseline_factory``) and expects ``to_resolved()`` to raise because
211- no source supplies a required field. The request is always
212- all-empty so the failure mode comes entirely from the baseline."""
207+ _assert_resolved_matches ( merged , expected )
208+
209+
210+ class TestMergeWhenBaselineLacksName :
211+ """The variant baseline omits ``name``. The merge succeeds only
212+ when the request supplies one — otherwise ``to_resolved()`` raises.
213+ """
214+
215+ @ pytest . fixture
216+ async def variant_id (
217+ self ,
218+ db_with_variant_table : ExtendedAsyncSAEngine ,
219+ ) -> RuntimeVariantID :
220+ return await _seed_variant (
221+ db_with_variant_table ,
222+ ModelDefinitionDraft (
223+ models = [ ModelConfigDraft ( model_path = "/baseline/path" )],
224+ ),
225+ )
213226
214227 @pytest .mark .parametrize (
215- ("incomplete_baseline " , "error_pattern " ),
228+ ("request_input " , "expected " ),
216229 [
217230 pytest .param (
218- # Reader's mount-destination default also fills model_path,
219- # so the only required ``ModelConfig`` field that ends up
220- # unfilled is ``name``.
221- ModelDefinitionDraft (models = [ModelConfigDraft (model_path = "/p" )]),
222- r"ModelConfig\.name is required" ,
223- id = "name_unfilled" ,
231+ ModelDefinitionInput (models = [ModelConfigInput (name = "from-request" )]),
232+ ResolvedExpectation (name = "from-request" , model_path = "/baseline/path" ),
233+ id = "request_supplies_missing_name" ,
234+ ),
235+ ],
236+ )
237+ async def test_request_supplying_name_resolves (
238+ self ,
239+ reader : RevisionDraftReader ,
240+ mounts : MountMetadata ,
241+ variant_id : RuntimeVariantID ,
242+ request_input : ModelDefinitionInput ,
243+ expected : ResolvedExpectation ,
244+ ) -> None :
245+ request = RevisionDraft (model_definition = request_input .to_draft ())
246+
247+ merged = await _merge_via_reader (reader , variant_id , request , mounts )
248+
249+ _assert_resolved_matches (merged , expected )
250+
251+ async def test_empty_request_raises_name_required (
252+ self ,
253+ reader : RevisionDraftReader ,
254+ mounts : MountMetadata ,
255+ variant_id : RuntimeVariantID ,
256+ ) -> None :
257+ request = RevisionDraft (model_definition = ModelDefinitionInput ().to_draft ())
258+
259+ merged = await _merge_via_reader (reader , variant_id , request , mounts )
260+
261+ assert merged .model_definition is not None
262+ with pytest .raises (ValueError , match = r"ModelConfig\.name is required" ):
263+ merged .model_definition .to_resolved ()
264+
265+
266+ class TestMergeWhenBaselineLacksServicePort :
267+ """The variant baseline supplies a ``service`` block without
268+ ``port``. The merge succeeds only when the request supplies the
269+ port — otherwise ``to_resolved()`` raises.
270+ """
271+
272+ @pytest .fixture
273+ async def variant_id (
274+ self ,
275+ db_with_variant_table : ExtendedAsyncSAEngine ,
276+ ) -> RuntimeVariantID :
277+ return await _seed_variant (
278+ db_with_variant_table ,
279+ ModelDefinitionDraft (
280+ models = [
281+ ModelConfigDraft (
282+ name = "baseline" ,
283+ model_path = "/baseline/path" ,
284+ service = ModelServiceConfigDraft (
285+ health_check = ModelHealthCheckDraft (path = "/healthz" ),
286+ ),
287+ ),
288+ ],
224289 ),
290+ )
291+
292+ @pytest .mark .parametrize (
293+ ("request_input" , "expected" ),
294+ [
225295 pytest .param (
226- ModelDefinitionDraft (
296+ ModelDefinitionInput (
227297 models = [
228- ModelConfigDraft (
229- name = "n" ,
230- model_path = "/p" ,
231- service = ModelServiceConfigDraft (),
298+ ModelConfigInput (
299+ service = ModelServiceConfigInput (port = 8080 ),
232300 ),
233301 ],
234302 ),
235- r"ModelServiceConfig\.port is required" ,
236- id = "service_port_unfilled" ,
303+ ResolvedExpectation (
304+ name = "baseline" ,
305+ model_path = "/baseline/path" ,
306+ service_port = 8080 ,
307+ health_check_path = "/healthz" ,
308+ ),
309+ id = "request_supplies_service_port" ,
237310 ),
311+ ],
312+ )
313+ async def test_request_supplying_port_resolves (
314+ self ,
315+ reader : RevisionDraftReader ,
316+ mounts : MountMetadata ,
317+ variant_id : RuntimeVariantID ,
318+ request_input : ModelDefinitionInput ,
319+ expected : ResolvedExpectation ,
320+ ) -> None :
321+ request = RevisionDraft (model_definition = request_input .to_draft ())
322+
323+ merged = await _merge_via_reader (reader , variant_id , request , mounts )
324+
325+ _assert_resolved_matches (merged , expected )
326+
327+ async def test_empty_request_raises_port_required (
328+ self ,
329+ reader : RevisionDraftReader ,
330+ mounts : MountMetadata ,
331+ variant_id : RuntimeVariantID ,
332+ ) -> None :
333+ request = RevisionDraft (model_definition = ModelDefinitionInput ().to_draft ())
334+
335+ merged = await _merge_via_reader (reader , variant_id , request , mounts )
336+
337+ assert merged .model_definition is not None
338+ with pytest .raises (ValueError , match = r"ModelServiceConfig\.port is required" ):
339+ merged .model_definition .to_resolved ()
340+
341+
342+ class TestMergeWhenBaselineLacksHealthCheckPath :
343+ """The variant baseline supplies ``service.health_check`` without
344+ ``path``. The merge succeeds only when the request supplies the
345+ path — otherwise ``to_resolved()`` raises.
346+ """
347+
348+ @pytest .fixture
349+ async def variant_id (
350+ self ,
351+ db_with_variant_table : ExtendedAsyncSAEngine ,
352+ ) -> RuntimeVariantID :
353+ return await _seed_variant (
354+ db_with_variant_table ,
355+ ModelDefinitionDraft (
356+ models = [
357+ ModelConfigDraft (
358+ name = "baseline" ,
359+ model_path = "/baseline/path" ,
360+ service = ModelServiceConfigDraft (
361+ port = 8080 ,
362+ health_check = ModelHealthCheckDraft (),
363+ ),
364+ ),
365+ ],
366+ ),
367+ )
368+
369+ @pytest .mark .parametrize (
370+ ("request_input" , "expected" ),
371+ [
238372 pytest .param (
239- ModelDefinitionDraft (
373+ ModelDefinitionInput (
240374 models = [
241- ModelConfigDraft (
242- name = "n" ,
243- model_path = "/p" ,
244- service = ModelServiceConfigDraft (
245- port = 8080 ,
246- health_check = ModelHealthCheckDraft (),
375+ ModelConfigInput (
376+ service = ModelServiceConfigInput (
377+ health_check = ModelHealthCheckInput (path = "/ready" ),
247378 ),
248379 ),
249380 ],
250381 ),
251- r"ModelHealthCheck\.path is required" ,
252- id = "health_check_path_unfilled" ,
382+ ResolvedExpectation (
383+ name = "baseline" ,
384+ model_path = "/baseline/path" ,
385+ service_port = 8080 ,
386+ health_check_path = "/ready" ,
387+ ),
388+ id = "request_supplies_health_check_path" ,
253389 ),
254390 ],
255391 )
256- async def test_required_field_unfilled_after_merge_raises (
392+ async def test_request_supplying_path_resolves (
393+ self ,
394+ reader : RevisionDraftReader ,
395+ mounts : MountMetadata ,
396+ variant_id : RuntimeVariantID ,
397+ request_input : ModelDefinitionInput ,
398+ expected : ResolvedExpectation ,
399+ ) -> None :
400+ request = RevisionDraft (model_definition = request_input .to_draft ())
401+
402+ merged = await _merge_via_reader (reader , variant_id , request , mounts )
403+
404+ _assert_resolved_matches (merged , expected )
405+
406+ async def test_empty_request_raises_health_check_path_required (
257407 self ,
258- db_with_variant_table : ExtendedAsyncSAEngine ,
259408 reader : RevisionDraftReader ,
260409 mounts : MountMetadata ,
261- incomplete_baseline : ModelDefinitionDraft ,
262- error_pattern : str ,
410+ variant_id : RuntimeVariantID ,
263411 ) -> None :
264- variant_id = await _seed_variant (db_with_variant_table , incomplete_baseline )
265412 request = RevisionDraft (model_definition = ModelDefinitionInput ().to_draft ())
266413
267414 merged = await _merge_via_reader (reader , variant_id , request , mounts )
268415
269416 assert merged .model_definition is not None
270- with pytest .raises (ValueError , match = error_pattern ):
417+ with pytest .raises (ValueError , match = r"ModelHealthCheck\.path is required" ):
271418 merged .model_definition .to_resolved ()
0 commit comments