Skip to content

Commit 8f74efc

Browse files
jopemachineclaude
andcommitted
test(BA-5983): add DB-backed revision merge test
Insert a real ``RuntimeVariantRow`` with a baseline ``default_model_definition`` (round-trips through ``PydanticColumn`` serialization), then exercise the production ``RevisionDraftReader`` + ``RevisionDraft.merge`` pipeline against a request draft built from ``ModelDefinitionInput.to_draft()``. Scenarios: - Empty input + baseline supplying full required tree → resolved ``ModelConfig`` carries baseline values verbatim. - Partial request (name only) + baseline (name + model_path) → request wins on ``name``; baseline's ``model_path`` survives. - Baseline missing ``name`` → ``to_resolved()`` raises ``ModelConfig.name is required``. - Baseline supplying empty ``service`` → ``ModelServiceConfig.port is required``. - Baseline supplying empty ``health_check`` → ``ModelHealthCheck.path is required``. The synthetic merge tests still cover the pure merge functions in-process; this file pins the DB → reader → merge → resolve loop that the ``add_model_revision`` action actually runs in production. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 83b3ac4 commit 8f74efc

1 file changed

Lines changed: 260 additions & 0 deletions

File tree

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
"""DB-backed verification of the BA-5983 revision merge contract.
2+
3+
Inserts real ``RuntimeVariantRow`` records (so the variant's
4+
``default_model_definition`` round-trips through ``PydanticColumn``
5+
serialization) and runs the production ``RevisionDraftReader`` +
6+
``RevisionDraft.merge`` pipeline against a request draft built from
7+
``ModelDefinitionInput.to_draft()``. The resolved output is then
8+
inspected to confirm:
9+
10+
- An empty request inherits every required field from the variant
11+
baseline; the resolved ``ModelDefinition`` carries the baseline
12+
values verbatim.
13+
- A request that supplies a subset of fields overrides only those
14+
fields; baseline-supplied fields survive.
15+
- When no source supplies a required field, ``to_resolved()`` raises
16+
``ValueError`` with the field-specific message.
17+
18+
This exercises the full read path (DB → ``PydanticColumn`` →
19+
``RuntimeVariantData`` → ``RevisionDraft``) plus the merge and
20+
resolve phases that the ``add_model_revision`` action ultimately runs.
21+
"""
22+
23+
from __future__ import annotations
24+
25+
import functools
26+
import uuid
27+
from collections.abc import AsyncGenerator
28+
from dataclasses import dataclass
29+
from unittest.mock import MagicMock
30+
31+
import pytest
32+
33+
from ai.backend.common.config import (
34+
ModelConfigDraft,
35+
ModelDefinitionDraft,
36+
ModelHealthCheckDraft,
37+
ModelServiceConfigDraft,
38+
)
39+
from ai.backend.common.dto.manager.v2.deployment.request import (
40+
ModelConfigInput,
41+
ModelDefinitionInput,
42+
)
43+
from ai.backend.common.identifier.runtime_variant import RuntimeVariantID
44+
from ai.backend.common.identifier.vfolder import VFolderUUID
45+
from ai.backend.manager.data.deployment.types import MountMetadata, RevisionDraft
46+
from ai.backend.manager.models.runtime_variant import RuntimeVariantRow
47+
from ai.backend.manager.models.utils import ExtendedAsyncSAEngine
48+
from ai.backend.manager.repositories.deployment.repository import DeploymentRepository
49+
from ai.backend.manager.sokovan.deployment.revision_draft import RevisionDraftReader
50+
from ai.backend.testutils.db import with_tables
51+
52+
53+
@dataclass(frozen=True)
54+
class ResolvedExpectation:
55+
"""Expected attributes on the resolved ``ModelConfig`` at ``models[0]``."""
56+
57+
name: str
58+
model_path: str
59+
service_port: int | None = None
60+
health_check_path: str | None = None
61+
62+
63+
class TestRevisionMergeWithRealVariantBaseline:
64+
@pytest.fixture
65+
async def db_with_variant_table(
66+
self,
67+
database_connection: ExtendedAsyncSAEngine,
68+
) -> AsyncGenerator[ExtendedAsyncSAEngine, None]:
69+
async with with_tables(database_connection, [RuntimeVariantRow]):
70+
yield database_connection
71+
72+
@pytest.fixture
73+
def reader(
74+
self,
75+
db_with_variant_table: ExtendedAsyncSAEngine,
76+
) -> RevisionDraftReader:
77+
# ``load_deployment_revision_read_bundle`` only touches the
78+
# runtime_variants table when ``preset_id`` is ``None``; the
79+
# other repository dependencies are not exercised and can be
80+
# stubbed out.
81+
repo = DeploymentRepository(
82+
db=db_with_variant_table,
83+
storage_manager=MagicMock(),
84+
valkey_stat=MagicMock(),
85+
valkey_live=MagicMock(),
86+
valkey_schedule=MagicMock(),
87+
)
88+
return RevisionDraftReader(deployment_repository=repo)
89+
90+
@pytest.fixture
91+
def mounts(self) -> MountMetadata:
92+
return MountMetadata(
93+
model_vfolder_id=VFolderUUID(uuid.uuid4()),
94+
model_definition_path=None,
95+
model_mount_destination="/models",
96+
extra_mounts=[],
97+
)
98+
99+
@staticmethod
100+
async def _seed_variant_baseline(
101+
db: ExtendedAsyncSAEngine,
102+
baseline: ModelDefinitionDraft,
103+
) -> RuntimeVariantID:
104+
variant_id = RuntimeVariantID(uuid.uuid4())
105+
async with db.begin_session() as sess:
106+
sess.add(
107+
RuntimeVariantRow(
108+
id=variant_id,
109+
name=f"test-variant-{variant_id.hex[:8]}",
110+
description="BA-5983 merge-test variant baseline",
111+
reads_vfolder_config_files=False,
112+
default_model_definition=baseline,
113+
)
114+
)
115+
await sess.commit()
116+
return variant_id
117+
118+
@staticmethod
119+
async def _merge_via_reader(
120+
reader: RevisionDraftReader,
121+
variant_id: RuntimeVariantID,
122+
request: RevisionDraft,
123+
mounts: MountMetadata,
124+
) -> RevisionDraft:
125+
drafts = await reader.read_for_deployment_revision(
126+
runtime_variant_id=variant_id,
127+
request_draft=request,
128+
mounts=mounts,
129+
preset_id=None,
130+
)
131+
return functools.reduce(RevisionDraft.merge, drafts, RevisionDraft())
132+
133+
@pytest.mark.parametrize(
134+
("baseline", "request_input", "expected"),
135+
[
136+
pytest.param(
137+
ModelDefinitionDraft(
138+
models=[
139+
ModelConfigDraft(
140+
name="baseline-llama",
141+
model_path="/models/baseline",
142+
service=ModelServiceConfigDraft(
143+
port=9000,
144+
health_check=ModelHealthCheckDraft(path="/healthz"),
145+
),
146+
),
147+
],
148+
),
149+
ModelDefinitionInput(),
150+
ResolvedExpectation(
151+
name="baseline-llama",
152+
model_path="/models/baseline",
153+
service_port=9000,
154+
health_check_path="/healthz",
155+
),
156+
id="empty_request_inherits_full_baseline",
157+
),
158+
pytest.param(
159+
ModelDefinitionDraft(
160+
models=[
161+
ModelConfigDraft(name="baseline-name", model_path="/baseline/path"),
162+
],
163+
),
164+
ModelDefinitionInput(models=[ModelConfigInput(name="user-name")]),
165+
ResolvedExpectation(name="user-name", model_path="/baseline/path"),
166+
id="request_overrides_name_baseline_keeps_model_path",
167+
),
168+
],
169+
)
170+
async def test_merge_resolves_to_expected_values(
171+
self,
172+
db_with_variant_table: ExtendedAsyncSAEngine,
173+
reader: RevisionDraftReader,
174+
mounts: MountMetadata,
175+
baseline: ModelDefinitionDraft,
176+
request_input: ModelDefinitionInput,
177+
expected: ResolvedExpectation,
178+
) -> None:
179+
variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline)
180+
request = RevisionDraft(model_definition=request_input.to_draft())
181+
182+
merged = await self._merge_via_reader(reader, variant_id, request, mounts)
183+
184+
assert merged.model_definition is not None
185+
resolved = merged.model_definition.to_resolved()
186+
model = resolved.models[0]
187+
assert model.name == expected.name
188+
assert model.model_path == expected.model_path
189+
if expected.service_port is not None:
190+
assert model.service is not None
191+
assert model.service.port == expected.service_port
192+
if expected.health_check_path is not None:
193+
assert model.service is not None
194+
assert model.service.health_check is not None
195+
assert model.service.health_check.path == expected.health_check_path
196+
197+
@pytest.mark.parametrize(
198+
("baseline", "error_pattern"),
199+
[
200+
pytest.param(
201+
# baseline supplies model_path only; reader's mount-destination
202+
# default would also fill model_path → only ``name`` remains unfilled.
203+
ModelDefinitionDraft(models=[ModelConfigDraft(model_path="/p")]),
204+
r"ModelConfig\.name is required",
205+
id="name_unfilled_across_baseline_and_request",
206+
),
207+
pytest.param(
208+
# baseline supplies name + model_path + an empty service →
209+
# service.port has no default and no override.
210+
ModelDefinitionDraft(
211+
models=[
212+
ModelConfigDraft(
213+
name="n",
214+
model_path="/p",
215+
service=ModelServiceConfigDraft(),
216+
),
217+
],
218+
),
219+
r"ModelServiceConfig\.port is required",
220+
id="service_port_unfilled_across_baseline_and_request",
221+
),
222+
pytest.param(
223+
# baseline supplies service.port but an empty health_check →
224+
# health_check.path has no default.
225+
ModelDefinitionDraft(
226+
models=[
227+
ModelConfigDraft(
228+
name="n",
229+
model_path="/p",
230+
service=ModelServiceConfigDraft(
231+
port=8080,
232+
health_check=ModelHealthCheckDraft(),
233+
),
234+
),
235+
],
236+
),
237+
r"ModelHealthCheck\.path is required",
238+
id="health_check_path_unfilled_across_baseline_and_request",
239+
),
240+
],
241+
)
242+
async def test_required_field_unfilled_after_merge_raises(
243+
self,
244+
db_with_variant_table: ExtendedAsyncSAEngine,
245+
reader: RevisionDraftReader,
246+
mounts: MountMetadata,
247+
baseline: ModelDefinitionDraft,
248+
error_pattern: str,
249+
) -> None:
250+
# Request is always an all-empty ``ModelDefinitionInput`` for these
251+
# scenarios — the merge result depends entirely on whether the
252+
# baseline (or reader-supplied defaults) cover every required field.
253+
variant_id = await self._seed_variant_baseline(db_with_variant_table, baseline)
254+
request = RevisionDraft(model_definition=ModelDefinitionInput().to_draft())
255+
256+
merged = await self._merge_via_reader(reader, variant_id, request, mounts)
257+
258+
assert merged.model_definition is not None
259+
with pytest.raises(ValueError, match=error_pattern):
260+
merged.model_definition.to_resolved()

0 commit comments

Comments
 (0)