Commit d09484f
[iris] Add event-replay testing system for transitions equivalence
Introduces a curated replay harness for ControllerTransitions:

- IrisEvent: a frozen dataclass union, one variant per public mutation method on ControllerTransitions, with a thin dispatcher that calls the matching method.
- SQL trace hook: ControllerDB now exposes a class-level _trace_callback slot that _configure() registers on every connection (production never sets it).
- Deterministic DB dump: per-table JSON sorted by primary key for state-equivalence assertions.
- 13 curated scenarios exercising submit/assign/run/succeed, retries, preemption, worker-failure cascade, coscheduled timeout, direct-provider lifecycle, prune sweep, endpoints, and reservation claims.
- pytest goldens (DB state strict, SQL trace informational) under lib/iris/tests/cluster/controller/replay/golden/, with a --update-goldens regeneration flag.
- CLI: ``uv run python -m iris.cluster.controller.replay.run [--seed=DIR] [--scenario=NAMES] --out=DIR`` writes db.json and sql.txt per scenario.

Companion to the rjpower/iris-sql-store branch: this PR delivers the testing infrastructure that proves the in-flight refactor preserves DB-state equivalence across all curated scenarios.
1 parent 30f6b6c commit d09484f

38 files changed

Lines changed: 11659 additions & 0 deletions

lib/iris/src/iris/cluster/controller/db.py

Lines changed: 10 additions & 0 deletions
@@ -291,6 +291,14 @@ class ControllerDB:
     AUTH_DB_FILENAME = "auth.sqlite3"
     PROFILES_DB_FILENAME = "profiles.sqlite3"

+    # Class-level SQL trace hook: when set, every connection (writer + read pool)
+    # opened via _configure() registers this callback via
+    # ``conn.set_trace_callback``. Exposed for the replay-system tests
+    # (``iris.cluster.controller.replay``); production never sets this. Must be
+    # set BEFORE ``ControllerDB(...)`` is constructed so the writer connection
+    # picks it up — connections opened earlier will not retroactively trace.
+    _trace_callback: Callable[[str], None] | None = None
+
     def __init__(self, db_dir: Path):
         import time

@@ -412,6 +420,8 @@ def _configure(conn: sqlite3.Connection) -> None:
         conn.execute("PRAGMA synchronous = NORMAL")
         conn.execute("PRAGMA busy_timeout = 5000")
         conn.execute("PRAGMA foreign_keys = ON")
+        if (cb := ControllerDB._trace_callback) is not None:
+            conn.set_trace_callback(cb)

     def optimize(self) -> None:
         """Run PRAGMA optimize to refresh statistics for tables with stale data.
lib/iris/src/iris/cluster/controller/replay/__init__.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Event-replay testing system for ControllerTransitions.

This package defines a frozen ``IrisEvent`` union — one variant per public
mutation method on ``ControllerTransitions`` — together with a dispatcher,
a SQLite trace hook, a deterministic DB dump, and a curated set of
scenarios. Used both as a pytest golden suite and as a CLI for diffing
DB state across branches/checkpoints.
"""

from iris.cluster.controller.replay.db_dump import deterministic_dump
from iris.cluster.controller.replay.dispatcher import apply_event
from iris.cluster.controller.replay.events import (
    AddEndpoint,
    ApplyDirectProviderUpdates,
    ApplyHeartbeatsBatch,
    ApplyTaskUpdates,
    BufferDirectKill,
    CancelJob,
    CancelTasksForTimeout,
    DrainForDirectProvider,
    IrisEvent,
    MarkTaskUnschedulable,
    PreemptTask,
    QueueAssignments,
    RegisterOrRefreshWorker,
    RemoveEndpoint,
    RemoveFinishedJob,
    RemoveWorker,
    ReplaceReservationClaims,
    SubmitJob,
    UpdateWorkerPings,
)
from iris.cluster.controller.replay.scenarios import (
    SCENARIO_NAMES,
    SCENARIOS,
    run_scenario,
)
from iris.cluster.controller.replay.sql_trace import sql_tracing

__all__ = [
    "SCENARIOS",
    "SCENARIO_NAMES",
    "AddEndpoint",
    "ApplyDirectProviderUpdates",
    "ApplyHeartbeatsBatch",
    "ApplyTaskUpdates",
    "BufferDirectKill",
    "CancelJob",
    "CancelTasksForTimeout",
    "DrainForDirectProvider",
    "IrisEvent",
    "MarkTaskUnschedulable",
    "PreemptTask",
    "QueueAssignments",
    "RegisterOrRefreshWorker",
    "RemoveEndpoint",
    "RemoveFinishedJob",
    "RemoveWorker",
    "ReplaceReservationClaims",
    "SubmitJob",
    "UpdateWorkerPings",
    "apply_event",
    "deterministic_dump",
    "run_scenario",
    "sql_tracing",
]
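The ``events`` module itself is not part of this commit view, but the dispatcher further down destructures every variant positionally, so a variant is presumably shaped like this hedged sketch (field names mirror the dispatcher's capture variables; the field types are assumptions):

    from dataclasses import dataclass
    from typing import Any


    @dataclass(frozen=True)
    class SubmitJob:
        """One IrisEvent variant; matched positionally as SubmitJob(job_id, request, ts)."""

        job_id: str
        request: Any  # the submit-job request payload; exact type is an assumption
        ts: float  # timestamp passed through to ControllerTransitions.submit_job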
lib/iris/src/iris/cluster/controller/replay/__main__.py

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Entry point so ``python -m iris.cluster.controller.replay`` works."""

from iris.cluster.controller.replay.run import main

if __name__ == "__main__":
    raise SystemExit(main())
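With this shim in place, ``uv run python -m iris.cluster.controller.replay --out=DIR`` should work alongside the ``iris.cluster.controller.replay.run`` form documented in the commit message.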
lib/iris/src/iris/cluster/controller/replay/db_dump.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Deterministic table-by-table JSON dump of the controller DB.

Produces a stable mapping ``{table_name: [row_dicts]}`` with rows sorted
by primary key. This is the canonical "did the DB end up in the same
state?" check used by the replay golden tests; SQL traces are reviewed
informationally.

Excludes SQLite internal tables and ``schema_migrations`` (which is
populated unconditionally on ``apply_migrations`` and adds noise without
testing any controller behavior).
"""

import base64
from typing import Any

from iris.cluster.controller.db import ControllerDB

EXCLUDED_TABLES: frozenset[str] = frozenset({"schema_migrations"})
"""Tables ignored by ``deterministic_dump`` — schema bookkeeping only."""


def _encode(value: Any) -> Any:
    """Render a SQLite cell value as a JSON-serializable scalar.

    ``bytes`` columns (notably ``job_workdir_files.data``) become
    base64 ASCII so the dump round-trips through ``json.dumps``.
    """
    if isinstance(value, (bytes, bytearray, memoryview)):
        return base64.b64encode(bytes(value)).decode("ascii")
    return value


def _list_user_tables(db: ControllerDB) -> list[str]:
    with db.read_snapshot() as snap:
        rows = snap.fetchall(
            "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
        )
    return [str(row["name"]) for row in rows if str(row["name"]) not in EXCLUDED_TABLES]


def _primary_key_columns(db: ControllerDB, table: str) -> list[str]:
    """Return the table's primary key columns in declared order.

    Falls back to *all* columns (also in declared order) when the table
    has no PK so dumps remain stable regardless of insert order.
    """
    with db.read_snapshot() as snap:
        rows = snap.fetchall(f"PRAGMA table_info({table})")
    pk_cols = sorted(
        ((int(row["pk"]), str(row["name"]), int(row["cid"])) for row in rows if int(row["pk"]) > 0),
        key=lambda triple: triple[0],
    )
    if pk_cols:
        return [name for _, name, _ in pk_cols]
    # No PK declared — sort by every column for determinism.
    return [str(row["name"]) for row in sorted(rows, key=lambda r: int(r["cid"]))]


def deterministic_dump(db: ControllerDB) -> dict[str, list[dict[str, Any]]]:
    """Dump every user table as ``{table: [row_dicts]}``.

    Rows are returned as ordinary dicts in column-declaration order and
    sorted by primary key (or every column when no PK exists). Bytes
    columns are base64-encoded. Used as the canonical state-equivalence
    check for replay scenarios.
    """
    out: dict[str, list[dict[str, Any]]] = {}
    for table in _list_user_tables(db):
        pk = _primary_key_columns(db, table)
        order = ", ".join(pk)
        with db.read_snapshot() as snap:
            rows = snap.fetchall(f"SELECT * FROM {table} ORDER BY {order}")
        out[table] = [{key: _encode(row[key]) for key in row.keys()} for row in rows]
    return out
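In the golden tests this dump is serialized and compared strictly against a stored db.json. A hedged sketch of that assertion (the helper name and JSON formatting here are assumptions; the real goldens live under lib/iris/tests/cluster/controller/replay/golden/):

    import json
    from pathlib import Path

    from iris.cluster.controller.db import ControllerDB
    from iris.cluster.controller.replay.db_dump import deterministic_dump


    def assert_db_matches_golden(db: ControllerDB, golden: Path) -> None:
        """Strict state-equivalence check: canonical dump vs stored db.json."""
        actual = json.dumps(deterministic_dump(db), indent=2, sort_keys=True)
        if actual != golden.read_text():
            raise AssertionError(f"DB state diverged from golden {golden}")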
lib/iris/src/iris/cluster/controller/replay/dispatcher.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Dispatch an :class:`IrisEvent` to the matching ``ControllerTransitions`` method.

This module targets ``main`` semantics: each transition method opens its
own ``with self._db.transaction()`` block and does not take a ``cur``
argument. The branch ``rjpower/iris-sql-store`` rewrites this single
file into the cur-passing form.
"""

from typing import Any

from iris.cluster.controller.replay.events import (
    AddEndpoint,
    ApplyDirectProviderUpdates,
    ApplyHeartbeatsBatch,
    ApplyTaskUpdates,
    BufferDirectKill,
    CancelJob,
    CancelTasksForTimeout,
    DrainForDirectProvider,
    IrisEvent,
    MarkTaskUnschedulable,
    PreemptTask,
    QueueAssignments,
    RegisterOrRefreshWorker,
    RemoveEndpoint,
    RemoveFinishedJob,
    RemoveWorker,
    ReplaceReservationClaims,
    SubmitJob,
    UpdateWorkerPings,
)
from iris.cluster.controller.transitions import ControllerTransitions


def apply_event(transitions: ControllerTransitions, event: IrisEvent) -> Any:
    """Dispatch ``event`` to the matching method on ``transitions``.

    Returns whatever the underlying method returns (``TxResult``,
    ``SubmitJobResult``, ``DirectProviderBatch``, etc.). The dispatcher
    is intentionally thin — its purpose is to keep scenarios free of
    branch-specific calling-convention details.
    """
    match event:
        case SubmitJob(job_id, request, ts):
            return transitions.submit_job(job_id, request, ts)
        case CancelJob(job_id, reason):
            return transitions.cancel_job(job_id, reason)
        case RegisterOrRefreshWorker(worker_id, address, metadata, ts, slice_id, scale_group):
            return transitions.register_or_refresh_worker(
                worker_id=worker_id,
                address=address,
                metadata=metadata,
                ts=ts,
                slice_id=slice_id,
                scale_group=scale_group,
            )
        case QueueAssignments(assignments, direct_dispatch):
            return transitions.queue_assignments(assignments, direct_dispatch=direct_dispatch)
        case ApplyTaskUpdates(request):
            return transitions.apply_task_updates(request)
        case ApplyHeartbeatsBatch(requests):
            return transitions.apply_heartbeats_batch(requests)
        case PreemptTask(task_id, reason):
            return transitions.preempt_task(task_id, reason)
        case CancelTasksForTimeout(task_ids, reason):
            return transitions.cancel_tasks_for_timeout(set(task_ids), reason)
        case MarkTaskUnschedulable(task_id, reason):
            return transitions.mark_task_unschedulable(task_id, reason)
        case RemoveFinishedJob(job_id):
            return transitions.remove_finished_job(job_id)
        case RemoveWorker(worker_id):
            return transitions.remove_worker(worker_id)
        case UpdateWorkerPings(snapshots):
            return transitions.update_worker_pings(snapshots)
        case DrainForDirectProvider(max_promotions):
            return transitions.drain_for_direct_provider(max_promotions)
        case ApplyDirectProviderUpdates(updates):
            return transitions.apply_direct_provider_updates(updates)
        case BufferDirectKill(task_id):
            return transitions.buffer_direct_kill(task_id)
        case AddEndpoint(endpoint):
            return transitions.add_endpoint(endpoint)
        case RemoveEndpoint(endpoint_id):
            return transitions.remove_endpoint(endpoint_id)
        case ReplaceReservationClaims(claims):
            return transitions.replace_reservation_claims(claims)
        case _:
            raise TypeError(f"unhandled IrisEvent variant: {type(event).__name__}")
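A scenario, then, is just an ordered event list replayed through this function. A hedged sketch (the replay loop and the two-event list are illustrative, not the shipped scenario API; events are constructed positionally to mirror the match arms, with None standing in for a real submit-job request payload):

    from iris.cluster.controller.replay import apply_event
    from iris.cluster.controller.replay.events import CancelJob, SubmitJob


    def replay(transitions, events) -> None:
        """Apply each event in order; on main, every call opens its own transaction."""
        for event in events:
            apply_event(transitions, event)


    # Hypothetical two-event scenario: submit a job, then cancel it.
    events = [SubmitJob("job-1", None, 0.0), CancelJob("job-1", "user requested")]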
