Skip to content

Commit f04996f

Browse files
committed
add controller grace period.
1 parent 28b4914 commit f04996f

2 files changed

Lines changed: 193 additions & 3 deletions

File tree

lib/iris/src/iris/cluster/controller/controller.py

Lines changed: 91 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
DIRECT_PROVIDER_PROMOTION_RATE,
100100
HeartbeatApplyRequest,
101101
ReservationClaim,
102+
RunningTaskEntry,
102103
SchedulingEvent,
103104
TaskUpdate,
104105
)
@@ -1024,6 +1025,13 @@ class Controller:
10241025
the controller will run it in a background thread.
10251026
"""
10261027

1028+
# Grace window during which a freshly-dispatched task (via StartTasks) is
1029+
# treated as expected when polling its worker. Bridges the StartTasks →
1030+
# PollTasks race: a poll's DB snapshot taken before the assignment commit
1031+
# would otherwise omit the task, and the worker would kill it as
1032+
# unexpected. 30s comfortably exceeds any normal StartTasks RPC latency.
1033+
_RECENT_DISPATCH_GRACE_SECONDS = 30.0
1034+
10271035
def __init__(
10281036
self,
10291037
config: ControllerConfig,
@@ -1130,6 +1138,14 @@ def __init__(
11301138
self._poll_thread: ManagedThread | None = None
11311139
self._task_update_queue: queue_mod.Queue[HeartbeatApplyRequest] = queue_mod.Queue()
11321140

1141+
# Track tasks dispatched via StartTasks but whose assignment commit
1142+
# may not yet be visible to a concurrent PollTasks snapshot. Merged
1143+
# into expected_tasks in _poll_all_workers so the worker does not
1144+
# treat a freshly-submitted task as "unexpected" and kill it. See
1145+
# _RECENT_DISPATCH_GRACE_SECONDS and _dispatch_assignments_direct.
1146+
self._recent_dispatches: dict[WorkerId, dict[tuple[str, int], float]] = {}
1147+
self._recent_dispatches_lock = threading.Lock()
1148+
11331149
self._autoscaler: Autoscaler | None = autoscaler
11341150

11351151
self._heartbeat_iteration = 0
@@ -2168,6 +2184,14 @@ def _dispatch_assignments_direct(
21682184
command = [Assignment(task_id=task_id, worker_id=worker_id) for task_id, worker_id in assignments]
21692185
result = self._transitions.queue_assignments(command, direct_dispatch=True)
21702186

2187+
# Register dispatches before issuing StartTasks so a concurrent poll
2188+
# whose DB snapshot predates the assignment commit still sees these
2189+
# tasks as expected. Any poll that observes the entries won't kill
2190+
# the task; any poll that misses them can't yet collide with the
2191+
# StartTasks RPC (not sent until after this point) on the same
2192+
# gRPC channel.
2193+
self._record_recent_dispatches(result.start_requests)
2194+
21712195
# Group StartTasks payloads by (worker_id, address)
21722196
by_worker: dict[tuple[WorkerId, str], list[job_pb2.RunTaskRequest]] = {}
21732197
for worker_id, address, run_request in result.start_requests:
@@ -2318,24 +2342,88 @@ def _run_poll_loop(self, stop_event: threading.Event) -> None:
23182342
except Exception:
23192343
logger.exception("Poll loop iteration failed")
23202344

2345+
def _record_recent_dispatches(
2346+
self,
2347+
start_requests: list[tuple[WorkerId, str, job_pb2.RunTaskRequest]],
2348+
) -> None:
2349+
"""Record (worker, task, attempt) dispatches with a monotonic timestamp."""
2350+
if not start_requests:
2351+
return
2352+
now = time.monotonic()
2353+
with self._recent_dispatches_lock:
2354+
for worker_id, _address, run_request in start_requests:
2355+
self._recent_dispatches.setdefault(worker_id, {})[(run_request.task_id, run_request.attempt_id)] = now
2356+
2357+
def _prune_and_snapshot_recent_dispatches(self) -> dict[WorkerId, set[tuple[str, int]]]:
2358+
"""Drop stale entries and return a per-worker snapshot of keys within the grace window."""
2359+
cutoff = time.monotonic() - self._RECENT_DISPATCH_GRACE_SECONDS
2360+
snapshot: dict[WorkerId, set[tuple[str, int]]] = {}
2361+
with self._recent_dispatches_lock:
2362+
for wid in list(self._recent_dispatches.keys()):
2363+
entries = self._recent_dispatches[wid]
2364+
for key in [k for k, ts in entries.items() if ts < cutoff]:
2365+
del entries[key]
2366+
if not entries:
2367+
del self._recent_dispatches[wid]
2368+
else:
2369+
snapshot[wid] = set(entries.keys())
2370+
return snapshot
2371+
def _poll_all_workers(self) -> None:
    """Poll all workers for task state and feed results into the updater queue.

    Merges recently-dispatched (worker, task, attempt) keys into each
    worker's expected_tasks so a stale DB snapshot can't cause the worker
    to kill a just-submitted task. Suppresses "Task not found on worker"
    updates for entries still within the grace window: the StartTasks RPC
    may not yet have landed when PollTasks arrived.
    """
    if self._config.dry_run:
        return
    running, addresses = self._transitions.get_running_tasks_for_poll()
    grace = self._prune_and_snapshot_recent_dispatches()

    # Fold the recent dispatches into each worker's expected list so a
    # worker holding only freshly-dispatched tasks (not yet visible in
    # the DB snapshot) is still polled with the correct expected set.
    for worker_id, dispatch_keys in grace.items():
        if worker_id not in addresses:
            continue
        expected = running.setdefault(worker_id, [])
        already = {(entry.task_id.to_wire(), entry.attempt_id) for entry in expected}
        for task_wire, attempt_id in dispatch_keys:
            if (task_wire, attempt_id) in already:
                continue
            expected.append(
                RunningTaskEntry(
                    task_id=JobName.from_wire(task_wire),
                    attempt_id=attempt_id,
                )
            )

    if not running:
        return
    for worker_id, updates, error in self._provider.poll_workers(running, addresses):
        if error is not None:
            logger.warning("PollTasks failed for worker %s: %s", worker_id, error)
            continue
        if not updates:
            continue
        # Drop every update for a task still inside the dispatch grace
        # window. The worker's view of a freshly-submitted task is
        # unreliable during this race (either it already has it and would
        # report running, or it doesn't yet and would report "not found"
        # as WORKER_FAILED). Let the next poll settle this once the grace
        # window closes.
        in_grace = grace.get(worker_id, set())
        kept = [u for u in updates if (u.task_id.to_wire(), u.attempt_id) not in in_grace]
        if not kept:
            continue
        self._task_update_queue.put(
            HeartbeatApplyRequest(
                worker_id=worker_id,
                worker_resource_snapshot=None,
                updates=kept,
            )
        )
lib/iris/tests/cluster/controller/test_heartbeat.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,3 +341,105 @@ def test_heartbeat_failure_error_includes_rpc_context():
341341
assert "expected=1" in error
342342

343343
provider.close()
344+
345+
def test_poll_merges_recent_dispatches_into_expected_tasks(tmp_path, worker_metadata):
    """A poll whose DB snapshot predates the assignment commit still includes
    freshly-dispatched tasks in expected_tasks, so the worker won't kill them.

    Reproduces the controller side of the StartTasks→PollTasks race: a
    dispatch is recorded for worker1 with no matching ASSIGNED row in the
    DB, then _poll_all_workers runs and the task must still show up in the
    expected_tasks handed to poll_workers.
    """
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)

    polled: dict[str, list[RunningTaskEntry]] = {}

    class RecordingProvider(FakeProvider):
        def poll_workers(self, running, worker_addresses):
            # Capture exactly what the controller claims to expect per worker.
            for wid, entries in running.items():
                polled[str(wid)] = list(entries)
            return []

    controller = Controller(config=config, provider=RecordingProvider(), db=db)
    _register_worker(controller.state, "worker1", worker_metadata, address="10.0.0.1:10001")

    task_wire = JobName.from_wire("/user/dispatch-race/0").to_wire()
    dispatch = (WorkerId("worker1"), "10.0.0.1:10001", job_pb2.RunTaskRequest(task_id=task_wire, attempt_id=1))
    controller._record_recent_dispatches([dispatch])

    controller._poll_all_workers()

    assert "worker1" in polled, "worker1 must be polled so expected_tasks covers the dispatch"
    observed = {(entry.task_id.to_wire(), entry.attempt_id) for entry in polled["worker1"]}
    assert (task_wire, 1) in observed

    controller.stop()
    db.close()
def test_poll_suppresses_worker_failed_within_grace_window(tmp_path, worker_metadata):
    """A WORKER_FAILED update for a task still in the dispatch grace window
    is dropped: this is the reverse race where PollTasks reaches the worker
    before StartTasks has been applied, producing a spurious 'not found'.
    """
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)

    task_wire = JobName.from_wire("/user/dispatch-race/0").to_wire()

    class FailingProvider(FakeProvider):
        def poll_workers(self, running, worker_addresses):
            # Simulate the worker reporting the freshly-dispatched task as
            # missing before StartTasks has landed.
            failure = TaskUpdate(
                task_id=JobName.from_wire(task_wire),
                attempt_id=1,
                new_state=job_pb2.TASK_STATE_WORKER_FAILED,
                error="Task not found on worker",
            )
            return [(WorkerId("worker1"), [failure], None)]

    controller = Controller(config=config, provider=FailingProvider(), db=db)
    _register_worker(controller.state, "worker1", worker_metadata, address="10.0.0.1:10001")

    dispatch = (WorkerId("worker1"), "10.0.0.1:10001", job_pb2.RunTaskRequest(task_id=task_wire, attempt_id=1))
    controller._record_recent_dispatches([dispatch])

    controller._poll_all_workers()

    assert controller._task_update_queue.empty(), "WORKER_FAILED should be suppressed within the grace window"

    controller.stop()
    db.close()
def test_prune_recent_dispatches_drops_stale_entries(tmp_path, worker_metadata):
    """Stale recent-dispatch entries past the grace window are pruned."""
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)
    controller = Controller(config=config, provider=FakeProvider(), db=db)

    now = time.monotonic()
    stale_age = controller._RECENT_DISPATCH_GRACE_SECONDS + 1.0
    controller._recent_dispatches[WorkerId("worker1")] = {
        ("/user/fresh/0", 0): now,
        ("/user/stale/0", 0): now - stale_age,
    }

    snapshot = controller._prune_and_snapshot_recent_dispatches()

    # Only the fresh entry survives, both in the snapshot and in the
    # underlying tracking dict.
    assert snapshot == {WorkerId("worker1"): {("/user/fresh/0", 0)}}
    assert controller._recent_dispatches[WorkerId("worker1")] == {("/user/fresh/0", 0): now}

    controller.stop()
    db.close()
0 commit comments

Comments
 (0)