
Commit 0a6581e

[iris] Add event-replay testing system for transitions equivalence (#5165)
## Summary

Adds a pytest-driven replay framework under `lib/iris/tests/cluster/controller/replay/` that drives a fresh `ControllerDB` through curated mutation sequences and asserts the final DB state against committed goldens. **Pure test infrastructure — no production-code changes.**

### Workflow

Generate goldens once on `main` (this PR's commit):

```
uv run pytest lib/iris/tests/cluster/controller/replay/ --update-goldens
```

The committed golden tree then acts as a behavioral fingerprint of `main`. Any branch that preserves observable DB state for all 13 scenarios keeps the tests green without regeneration.
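Conceptually, each scenario reduces to the loop sketched below. This is an illustrative sketch only, not the framework's actual test code: `check_scenario`, `scenario_events`, and `golden_path` are hypothetical names, and it assumes `apply_event` and `deterministic_dump` (both added in this commit, shown further down in the diff) are importable.

```python
# Illustrative sketch of the replay loop, not the framework's real helpers.
# check_scenario / scenario_events / golden_path are hypothetical names.
import json
from pathlib import Path


def check_scenario(transitions, db, scenario_events, golden_path: Path, update: bool) -> None:
    for event in scenario_events:            # curated mutation sequence
        apply_event(transitions, event)      # dispatch to ControllerTransitions
    state = deterministic_dump(db)           # table-by-table final DB state
    rendered = json.dumps(state, indent=2, sort_keys=True) + "\n"
    if update:                               # pytest ... --update-goldens
        golden_path.write_text(rendered)     # regenerate the committed golden
    else:
        assert rendered == golden_path.read_text()
```

Regular runs take the `else` branch; only a run with `--update-goldens` rewrites the golden tree.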
1 parent fdc9349 commit 0a6581e

19 files changed: 3231 additions & 0 deletions
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
```python
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0
```
Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
```python
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Pytest hooks for the replay golden tests."""


def pytest_addoption(parser) -> None:
    """Register ``--update-goldens`` to regenerate the committed golden files."""
    parser.addoption(
        "--update-goldens",
        action="store_true",
        default=False,
        help="Regenerate replay golden files instead of asserting against them.",
    )
```
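Downstream tests would typically read this flag back through pytest's standard `config.getoption` API. A minimal sketch, assuming a fixture named `update_goldens` (the name is illustrative, not necessarily what the replay package uses):

```python
import pytest


@pytest.fixture
def update_goldens(request: pytest.FixtureRequest) -> bool:
    """True when the run was started with --update-goldens."""
    return bool(request.config.getoption("--update-goldens"))
```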
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
```python
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Deterministic table-by-table JSON dump of the controller DB.

Produces a stable mapping ``{table_name: [row_dicts]}`` with rows sorted
by primary key. This is the canonical "did the DB end up in the same
state?" check used by the replay golden tests; SQL traces are reviewed
informationally.

Excludes SQLite internal tables and ``schema_migrations`` (which is
populated unconditionally on ``apply_migrations`` and adds noise without
testing any controller behavior).
"""

import base64
from typing import Any

from iris.cluster.controller.db import ControllerDB

EXCLUDED_TABLES: frozenset[str] = frozenset({"schema_migrations"})
"""Tables ignored by ``deterministic_dump`` — schema bookkeeping only."""


def _encode(value: Any) -> Any:
    """Render a SQLite cell value as a JSON-serializable scalar.

    ``bytes`` columns (notably ``job_workdir_files.data``) become
    base64 ASCII so the dump round-trips through ``json.dumps``.
    """
    if isinstance(value, (bytes, bytearray, memoryview)):
        return base64.b64encode(bytes(value)).decode("ascii")
    return value


def _list_user_tables(db: ControllerDB) -> list[str]:
    with db.read_snapshot() as snap:
        rows = snap.fetchall(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
        )
    return [str(row["name"]) for row in rows if str(row["name"]) not in EXCLUDED_TABLES]


def _primary_key_columns(db: ControllerDB, table: str) -> list[str]:
    """Return the table's primary key columns in declared order.

    Falls back to *all* columns (also in declared order) when the table
    has no PK so dumps remain stable regardless of insert order.
    """
    with db.read_snapshot() as snap:
        rows = snap.fetchall(f"PRAGMA table_info({table})")
    pk_cols = sorted(
        ((int(row["pk"]), str(row["name"]), int(row["cid"])) for row in rows if int(row["pk"]) > 0),
        key=lambda triple: triple[0],
    )
    if pk_cols:
        return [name for _, name, _ in pk_cols]
    # No PK declared — sort by every column for determinism.
    return [str(row["name"]) for row in sorted(rows, key=lambda r: int(r["cid"]))]


def deterministic_dump(db: ControllerDB) -> dict[str, list[dict[str, Any]]]:
    """Dump every user table as ``{table: [row_dicts]}``.

    Rows are returned as ordinary dicts in column-declaration order and
    sorted by primary key (or every column when no PK exists). Bytes
    columns are base64-encoded. Used as the canonical state-equivalence
    check for replay scenarios.
    """
    out: dict[str, list[dict[str, Any]]] = {}
    for table in _list_user_tables(db):
        pk = _primary_key_columns(db, table)
        order = ", ".join(pk)
        with db.read_snapshot() as snap:
            rows = snap.fetchall(f"SELECT * FROM {table} ORDER BY {order}")
        out[table] = [{key: _encode(row[key]) for key in row.keys()} for row in rows]
    return out
```
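One property worth illustrating is the bytes handling: base64 is what lets a dump containing BLOB columns (e.g. `job_workdir_files.data`) survive a JSON round trip. A small self-contained check of that behavior, assuming `_encode` from the module above is importable (it is module-private, so this is illustration only, not part of the committed tests):

```python
import base64
import json

# Minimal illustration: a BLOB cell becomes base64 ASCII, so the dump
# survives json.dumps/json.loads and the original bytes stay recoverable.
row = {"job_id": "job-0001", "data": b"\x00\x01binary payload"}
encoded = {key: _encode(value) for key, value in row.items()}
restored = json.loads(json.dumps(encoded))
assert base64.b64decode(restored["data"]) == row["data"]
```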
Lines changed: 209 additions & 0 deletions
@@ -0,0 +1,209 @@
```python
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""``IrisEvent`` dataclass union + ``apply_event`` dispatcher.

Each variant captures the arguments of one public mutation method on
``ControllerTransitions``. ``apply_event`` opens a write transaction
and invokes the matching method.

Multi-transaction orchestrators (``fail_workers``, ``prune_old_data``)
and ``*_for_test`` helpers are intentionally excluded — scenarios call
those methods directly when needed.
"""

from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any

from iris.cluster.controller.schema import EndpointRow
from iris.cluster.controller.transitions import (
    Assignment,
    ControllerTransitions,
    HeartbeatApplyRequest,
    ReservationClaim,
    TaskUpdate,
)
from iris.cluster.types import JobName, WorkerId
from iris.rpc import controller_pb2, job_pb2
from rigging.timing import Timestamp


@dataclass(frozen=True, slots=True)
class SubmitJob:
    job_id: JobName
    request: controller_pb2.Controller.LaunchJobRequest
    ts: Timestamp


@dataclass(frozen=True, slots=True)
class CancelJob:
    job_id: JobName
    reason: str


@dataclass(frozen=True, slots=True)
class RegisterOrRefreshWorker:
    worker_id: WorkerId
    address: str
    metadata: job_pb2.WorkerMetadata
    ts: Timestamp
    slice_id: str = ""
    scale_group: str = ""


@dataclass(frozen=True, slots=True)
class QueueAssignments:
    assignments: list[Assignment]
    direct_dispatch: bool = False


@dataclass(frozen=True, slots=True)
class ApplyTaskUpdates:
    request: HeartbeatApplyRequest


@dataclass(frozen=True, slots=True)
class ApplyHeartbeatsBatch:
    requests: list[HeartbeatApplyRequest]


@dataclass(frozen=True, slots=True)
class PreemptTask:
    task_id: JobName
    reason: str


@dataclass(frozen=True, slots=True)
class CancelTasksForTimeout:
    task_ids: frozenset[JobName]
    reason: str


@dataclass(frozen=True, slots=True)
class MarkTaskUnschedulable:
    task_id: JobName
    reason: str


@dataclass(frozen=True, slots=True)
class RemoveFinishedJob:
    job_id: JobName


@dataclass(frozen=True, slots=True)
class RemoveWorker:
    worker_id: WorkerId


@dataclass(frozen=True, slots=True)
class UpdateWorkerPings:
    snapshots: Mapping[WorkerId, job_pb2.WorkerResourceSnapshot | None]


@dataclass(frozen=True, slots=True)
class DrainForDirectProvider:
    max_promotions: int = 16


@dataclass(frozen=True, slots=True)
class ApplyDirectProviderUpdates:
    updates: list[TaskUpdate]


@dataclass(frozen=True, slots=True)
class BufferDirectKill:
    task_id: str


@dataclass(frozen=True, slots=True)
class AddEndpoint:
    endpoint: EndpointRow


@dataclass(frozen=True, slots=True)
class RemoveEndpoint:
    endpoint_id: str


@dataclass(frozen=True, slots=True)
class ReplaceReservationClaims:
    claims: dict[WorkerId, ReservationClaim]


IrisEvent = (
    SubmitJob
    | CancelJob
    | RegisterOrRefreshWorker
    | QueueAssignments
    | ApplyTaskUpdates
    | ApplyHeartbeatsBatch
    | PreemptTask
    | CancelTasksForTimeout
    | MarkTaskUnschedulable
    | RemoveFinishedJob
    | RemoveWorker
    | UpdateWorkerPings
    | DrainForDirectProvider
    | ApplyDirectProviderUpdates
    | BufferDirectKill
    | AddEndpoint
    | RemoveEndpoint
    | ReplaceReservationClaims
)


def apply_event(transitions: ControllerTransitions, event: IrisEvent) -> Any:
    """Dispatch ``event`` to the matching method on ``ControllerTransitions``.

    Each transition method opens its own transaction — the dispatcher is
    a thin match on the event variant. Returns whatever the underlying
    method returns (``TxResult``, ``SubmitJobResult``,
    ``DirectProviderBatch``, etc.).
    """
    match event:
        case SubmitJob(job_id, request, ts):
            return transitions.submit_job(job_id, request, ts)
        case CancelJob(job_id, reason):
            return transitions.cancel_job(job_id, reason)
        case RegisterOrRefreshWorker(worker_id, address, metadata, ts, slice_id, scale_group):
            return transitions.register_or_refresh_worker(
                worker_id=worker_id,
                address=address,
                metadata=metadata,
                ts=ts,
                slice_id=slice_id,
                scale_group=scale_group,
            )
        case QueueAssignments(assignments, direct_dispatch):
            return transitions.queue_assignments(assignments, direct_dispatch=direct_dispatch)
        case ApplyTaskUpdates(request):
            return transitions.apply_task_updates(request)
        case ApplyHeartbeatsBatch(requests):
            return transitions.apply_heartbeats_batch(requests)
        case PreemptTask(task_id, reason):
            return transitions.preempt_task(task_id, reason)
        case CancelTasksForTimeout(task_ids, reason):
            return transitions.cancel_tasks_for_timeout(set(task_ids), reason)
        case MarkTaskUnschedulable(task_id, reason):
            return transitions.mark_task_unschedulable(task_id, reason)
        case RemoveFinishedJob(job_id):
            return transitions.remove_finished_job(job_id)
        case RemoveWorker(worker_id):
            return transitions.remove_worker(worker_id)
        case UpdateWorkerPings(snapshots):
            return transitions.update_worker_pings(snapshots)
        case DrainForDirectProvider(max_promotions):
            return transitions.drain_for_direct_provider(max_promotions)
        case ApplyDirectProviderUpdates(updates):
            return transitions.apply_direct_provider_updates(updates)
        case BufferDirectKill(task_id):
            return transitions.buffer_direct_kill(task_id)
        case AddEndpoint(endpoint):
            return transitions.add_endpoint(endpoint)
        case RemoveEndpoint(endpoint_id):
            return transitions.remove_endpoint(endpoint_id)
        case ReplaceReservationClaims(claims):
            return transitions.replace_reservation_claims(claims)
        case _:
            raise TypeError(f"unhandled IrisEvent variant: {type(event).__name__}")
```
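As a usage illustration, replaying a short sequence needs only the event dataclasses and `apply_event`; everything else below is assumed: the `transitions` object would come from a scenario fixture, the IDs are made up, and `JobName`/`WorkerId` are assumed to accept plain strings.

```python
# Hypothetical replay snippet, not one of the committed scenarios.
def replay_minimal(transitions: ControllerTransitions) -> None:
    events: list[IrisEvent] = [
        CancelJob(job_id=JobName("job-0001"), reason="user requested cancel"),
        BufferDirectKill(task_id="job-0001/task-0"),
        RemoveWorker(worker_id=WorkerId("worker-a")),
    ]
    for event in events:
        apply_event(transitions, event)  # each call drives one transition method
```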
