Skip to content

Commit 5805663

Browse files
committed
[iris] Lift transactions to entrypoints; ControllerTransitions methods take cur
Transactions used to be opened ~22 times inside ControllerTransitions methods. That had three problems: (1) RPC handlers couldn't compose multiple transitions atomically — the read-validate-then-write pattern left a TOCTOU window; (2) transitions.py had a self._db backdoor next to its self._store, blurring the layering rule; (3) the per-method ``with self._db.transaction()`` boilerplate said the same thing 22 times. Move tx scope to the caller. Every "one-event" transition now takes a ``cur: TransactionCursor`` (or ``snap: Tx`` for read paths); the RPC handlers in service.py and the loop iterations in controller.py each open exactly one ``with self._store.transaction() as cur:`` around their work. Multi-tx orchestrators (``fail_workers`` chunked, ``prune_old_data`` per-row) keep their own internal tx loops because that's their reason for existing. Test helpers prefixed ``_for_test`` keep their auto-tx wrappers since tests call them ad-hoc. Post-commit work that used to live at the tail of methods (attribute cache invalidation in register/remove worker, the ``worker_registered`` audit log) moves into ``cur.on_commit(...)`` hooks so it only fires once the data is durable — and a rolled-back transaction can't leave the in-memory scheduling cache ahead of the DB row. The ``self._db`` field on ``ControllerTransitions`` becomes a thin ``@property`` that delegates to ``self._store._db``. Production code in transitions.py never touches it; the property only exists so existing test helpers that read directly via ``state._db.snapshot()`` keep working. ``ControllerStore.optimize()`` is added so transitions' post-prune ``self._store.optimize()`` call has a place to land. Tests: ~190 call sites in tests/ updated to the cur-passing form (most mechanically via a regex-based transformer). conftest helpers ``submit_job`` / ``submit_direct_job`` / ``register_worker`` / ``dispatch_task`` / ``transition_task`` / ``sync_k8s`` open a tx and pass cur. 
The ``state._db.snapshot()`` reads in conftest are switched to ``state._db.read_snapshot()`` so they don't try to BEGIN on the writer connection from inside an open writer tx. Verification: 850 controller-suite tests pass (one pre-existing failure unchanged); pyrefly clean for the controller package; full iris test suite green.
1 parent edac00f commit 5805663

18 files changed

Lines changed: 1942 additions & 1631 deletions

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 36 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,13 +1456,16 @@ def _sync_direct_provider(self) -> None:
14561456
assert isinstance(self._provider, K8sTaskProvider)
14571457
provider = self._provider
14581458
max_promotions = self._promotion_bucket.available
1459-
batch = self._transitions.drain_for_direct_provider(
1460-
max_promotions=max_promotions,
1461-
)
1459+
with self._store.transaction() as cur:
1460+
batch = self._transitions.drain_for_direct_provider(
1461+
cur,
1462+
max_promotions=max_promotions,
1463+
)
14621464
if batch.tasks_to_run:
14631465
self._promotion_bucket.try_acquire(len(batch.tasks_to_run))
14641466
result = provider.sync(batch)
1465-
tx_result = self._transitions.apply_direct_provider_updates(result.updates)
1467+
with self._store.transaction() as cur:
1468+
tx_result = self._transitions.apply_direct_provider_updates(cur, result.updates)
14661469
self._provider_scheduling_events = list(result.scheduling_events) if result.scheduling_events else []
14671470
self._provider_capacity = result.capacity
14681471
if tx_result.tasks_to_kill:
@@ -1625,7 +1628,8 @@ def _cleanup_stale_claims(self, claims: dict[WorkerId, ReservationClaim] | None
16251628
for wid in stale:
16261629
del claims[wid]
16271630
if stale and persisted:
1628-
self._transitions.replace_reservation_claims(claims)
1631+
with self._store.transaction() as cur:
1632+
self._transitions.replace_reservation_claims(cur, claims)
16291633
log_event("reservation_claims_cleaned", "controller", count=len(stale))
16301634
return bool(stale)
16311635

@@ -1673,7 +1677,8 @@ def _claim_workers_for_reservations(self, claims: dict[WorkerId, ReservationClai
16731677
changed = True
16741678
break
16751679
if changed and persisted:
1676-
self._transitions.replace_reservation_claims(claims)
1680+
with self._store.transaction() as cur:
1681+
self._transitions.replace_reservation_claims(cur, claims)
16771682
log_event("reservation_claims_updated", "controller", total_claims=len(claims))
16781683
return changed
16791684

@@ -1769,7 +1774,8 @@ def _refresh_reservation_claims(self) -> dict[WorkerId, ReservationClaim]:
17691774
if self._config.dry_run:
17701775
logger.info("[DRY-RUN] Would update %d reservation claims", len(claims))
17711776
else:
1772-
self._transitions.replace_reservation_claims(claims)
1777+
with self._store.transaction() as cur:
1778+
self._transitions.replace_reservation_claims(cur, claims)
17731779
return claims
17741780

17751781
def _read_scheduling_state(self) -> _SchedulingStateRead:
@@ -1982,7 +1988,10 @@ def _apply_preemptions(
19821988
)
19831989
preemptions = _run_preemption_pass(unscheduled, running_info, context)
19841990
for preemptor_name, victim_id in preemptions:
1985-
preempt_result = self._transitions.preempt_task(victim_id, reason=f"Preempted by {preemptor_name}")
1991+
with self._store.transaction() as cur:
1992+
preempt_result = self._transitions.preempt_task(
1993+
cur, victim_id, reason=f"Preempted by {preemptor_name}"
1994+
)
19861995
self.kill_tasks_on_workers(preempt_result.tasks_to_kill)
19871996
if preemptions:
19881997
logger.info("Preemption pass: %d tasks preempted", len(preemptions))
@@ -2052,7 +2061,8 @@ def _enforce_execution_timeouts(self) -> None:
20522061
for task in timed_out:
20532062
logger.warning("Task %s exceeded execution timeout, killing", task.task_id)
20542063
task_ids = {t.task_id for t in timed_out}
2055-
result = self._transitions.cancel_tasks_for_timeout(task_ids, reason="Execution timeout exceeded")
2064+
with self._store.transaction() as cur:
2065+
result = self._transitions.cancel_tasks_for_timeout(cur, task_ids, reason="Execution timeout exceeded")
20562066
if result.tasks_to_kill:
20572067
self.kill_tasks_on_workers(result.tasks_to_kill, result.task_kill_workers)
20582068

@@ -2067,10 +2077,12 @@ def _mark_task_unschedulable(self, task: TaskRow) -> None:
20672077
else:
20682078
timeout = None
20692079
logger.warning(f"Task {task.task_id} exceeded scheduling timeout ({timeout}), marking as UNSCHEDULABLE")
2070-
result = self._transitions.mark_task_unschedulable(
2071-
task.task_id,
2072-
reason=f"Scheduling timeout exceeded ({timeout})",
2073-
)
2080+
with self._store.transaction() as cur:
2081+
result = self._transitions.mark_task_unschedulable(
2082+
cur,
2083+
task.task_id,
2084+
reason=f"Scheduling timeout exceeded ({timeout})",
2085+
)
20742086
if result.tasks_to_kill:
20752087
self.kill_tasks_on_workers(result.tasks_to_kill, result.task_kill_workers)
20762088

@@ -2099,8 +2111,9 @@ def kill_tasks_on_workers(
20992111
self._stop_tasks_direct(task_ids, task_kill_workers)
21002112
return
21012113
# K8s: buffer direct kills for the provider sync loop.
2102-
for task_id in task_ids:
2103-
self._transitions.buffer_direct_kill(task_id.to_wire())
2114+
with self._store.transaction() as cur:
2115+
for task_id in task_ids:
2116+
self._transitions.buffer_direct_kill(cur, task_id.to_wire())
21042117

21052118
# =========================================================================
21062119
# Worker lifecycle RPC dispatch (StartTasks / StopTasks / Ping / PollTasks)
@@ -2116,7 +2129,8 @@ def _dispatch_assignments_direct(
21162129
logger.info("[DRY-RUN] Would assign task %s to worker %s", task_id, worker_id)
21172130
return
21182131
command = [Assignment(task_id=task_id, worker_id=worker_id) for task_id, worker_id in assignments]
2119-
result = self._transitions.queue_assignments(command, direct_dispatch=True)
2132+
with self._store.transaction() as cur:
2133+
result = self._transitions.queue_assignments(cur, command, direct_dispatch=True)
21202134

21212135
# Group StartTasks payloads by (worker_id, address)
21222136
by_worker: dict[tuple[WorkerId, str], list[job_pb2.RunTaskRequest]] = {}
@@ -2245,7 +2259,8 @@ def _run_ping_loop(self, stop_event: threading.Event) -> None:
22452259
self._health.ping(result.worker_id, healthy=True)
22462260
ping_snapshots[result.worker_id] = result.resource_snapshot if update_resources else None
22472261

2248-
self._transitions.update_worker_pings(ping_snapshots)
2262+
with self._store.transaction() as cur:
2263+
self._transitions.update_worker_pings(cur, ping_snapshots)
22492264

22502265
unhealthy = self._health.workers_over_threshold()
22512266
if unhealthy:
@@ -2268,7 +2283,8 @@ def _poll_all_workers(self) -> None:
22682283
"""Poll all workers for task state and feed results into the updater queue."""
22692284
if self._config.dry_run:
22702285
return
2271-
running, addresses = self._transitions.get_running_tasks_for_poll()
2286+
with self._store.read_snapshot() as snap:
2287+
running, addresses = self._transitions.get_running_tasks_for_poll(snap)
22722288
if not running:
22732289
return
22742290
poll_results = self._provider.poll_workers(running, addresses)
@@ -2296,7 +2312,8 @@ def _run_task_updater_loop(self, stop_event: threading.Event) -> None:
22962312
if not requests or stop_event.is_set():
22972313
continue
22982314
try:
2299-
results = self._transitions.apply_heartbeats_batch(requests)
2315+
with self._store.transaction() as cur:
2316+
results = self._transitions.apply_heartbeats_batch(cur, requests)
23002317
all_tasks_to_kill: set[JobName] = set()
23012318
all_task_kill_workers: dict[JobName, WorkerId] = {}
23022319
for result in results:

lib/iris/src/iris/cluster/controller/service.py

Lines changed: 35 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,19 +1161,22 @@ def launch_job(
11611161
if not is_job_finished(existing_job.state):
11621162
return controller_pb2.Controller.LaunchJobResponse(job_id=job_id.to_wire())
11631163
# Job finished, replace it (KEEP only preserves running jobs)
1164-
self._transitions.remove_finished_job(job_id)
1164+
with self._store.transaction() as cur:
1165+
self._transitions.remove_finished_job(cur, job_id)
11651166
elif policy == job_pb2.EXISTING_JOB_POLICY_RECREATE:
1166-
if not is_job_finished(existing_job.state):
1167-
self._transitions.cancel_job(job_id, "Replaced by new submission")
1168-
self._transitions.remove_finished_job(job_id)
1167+
with self._store.transaction() as cur:
1168+
if not is_job_finished(existing_job.state):
1169+
self._transitions.cancel_job(cur, job_id, "Replaced by new submission")
1170+
self._transitions.remove_finished_job(cur, job_id)
11691171
elif is_job_finished(existing_job.state):
11701172
# Default/UNSPECIFIED: replace finished jobs
11711173
logger.info(
11721174
"Replacing finished job %s (state=%s) with new submission",
11731175
job_id,
11741176
job_pb2.JobState.Name(existing_job.state),
11751177
)
1176-
self._transitions.remove_finished_job(job_id)
1178+
with self._store.transaction() as cur:
1179+
self._transitions.remove_finished_job(cur, job_id)
11771180
else:
11781181
raise ConnectError(Code.ALREADY_EXISTS, f"Job {job_id} already exists and is still running")
11791182

@@ -1228,7 +1231,8 @@ def launch_job(
12281231
f"Job {job_id} is unschedulable: {error} (constraints: {constraints})",
12291232
)
12301233

1231-
self._transitions.submit_job(job_id, request, Timestamp.now())
1234+
with self._store.transaction() as cur:
1235+
self._transitions.submit_job(cur, job_id, request, Timestamp.now())
12321236
self._controller.wake()
12331237

12341238
with self._db.read_snapshot() as q:
@@ -1380,7 +1384,8 @@ def terminate_job(
13801384
self._authorize_job_owner(job_id)
13811385
# cancel_job uses a recursive CTE to walk the full subtree in a single
13821386
# transaction, so there is no need to recurse manually.
1383-
result = self._transitions.cancel_job(job_id, reason="Terminated by user")
1387+
with self._store.transaction() as cur:
1388+
result = self._transitions.cancel_job(cur, job_id, reason="Terminated by user")
13841389
if result.tasks_to_kill:
13851390
self._controller.kill_tasks_on_workers(result.tasks_to_kill, result.task_kill_workers)
13861391
return job_pb2.Empty()
@@ -1625,14 +1630,16 @@ def register(
16251630
)
16261631
worker_id = WorkerId(request.worker_id)
16271632

1628-
self._transitions.register_or_refresh_worker(
1629-
worker_id=worker_id,
1630-
address=request.address,
1631-
metadata=request.metadata,
1632-
ts=Timestamp.now(),
1633-
slice_id=request.slice_id,
1634-
scale_group=request.scale_group,
1635-
)
1633+
with self._store.transaction() as cur:
1634+
self._transitions.register_or_refresh_worker(
1635+
cur,
1636+
worker_id=worker_id,
1637+
address=request.address,
1638+
metadata=request.metadata,
1639+
ts=Timestamp.now(),
1640+
slice_id=request.slice_id,
1641+
scale_group=request.scale_group,
1642+
)
16361643

16371644
logger.info("Worker registered: %s at %s", worker_id, request.address)
16381645
return controller_pb2.Controller.RegisterResponse(
@@ -1711,7 +1718,9 @@ def register_endpoint(
17111718
registered_at=Timestamp.now(),
17121719
)
17131720

1714-
if not self._transitions.add_endpoint(endpoint):
1721+
with self._store.transaction() as cur:
1722+
added = self._transitions.add_endpoint(cur, endpoint)
1723+
if not added:
17151724
raise ConnectError(
17161725
Code.FAILED_PRECONDITION,
17171726
f"Task {request.task_id} is already terminal; endpoint not registered",
@@ -1725,7 +1734,8 @@ def unregister_endpoint(
17251734
ctx: Any,
17261735
) -> job_pb2.Empty:
17271736
"""Unregister a service endpoint. Idempotent."""
1728-
self._transitions.remove_endpoint(request.endpoint_id)
1737+
with self._store.transaction() as cur:
1738+
self._transitions.remove_endpoint(cur, request.endpoint_id)
17291739
return job_pb2.Empty()
17301740

17311741
def list_endpoints(
@@ -2683,12 +2693,14 @@ def update_task_status(
26832693
"""
26842694
updates = task_updates_from_proto(request.updates)
26852695
if updates:
2686-
self._transitions.apply_task_updates(
2687-
HeartbeatApplyRequest(
2688-
worker_id=WorkerId(request.worker_id),
2689-
worker_resource_snapshot=None,
2690-
updates=updates,
2696+
with self._store.transaction() as cur:
2697+
self._transitions.apply_task_updates(
2698+
cur,
2699+
HeartbeatApplyRequest(
2700+
worker_id=WorkerId(request.worker_id),
2701+
worker_resource_snapshot=None,
2702+
updates=updates,
2703+
),
26912704
)
2692-
)
26932705
self._controller.wake()
26942706
return controller_pb2.Controller.UpdateTaskStatusResponse()

lib/iris/src/iris/cluster/controller/stores.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2070,3 +2070,6 @@ def transaction(self):
20702070

20712071
def read_snapshot(self):
20722072
return self._db.read_snapshot()
2073+
2074+
def optimize(self) -> None:
2075+
self._db.optimize()

0 commit comments

Comments
 (0)