Skip to content

Commit 53bf959

Browse files
iris/worker: protect freshly-submitted tasks from StartTasks→PollTasks race (#5043)
## Summary

Fixes #5041. When the iris controller dispatches a task via `StartTasks` and polls the worker for state via `PollTasks` before its own view of `expected_tasks` has caught up, the worker treated the just-submitted task as "unexpected" and killed it. That kill rolled up the workers-pool job to `JOB_STATE_KILLED` and cascaded the surviving tasks with `error="Job was terminated."`, surfacing in zephyr as the misleading `"Worker job terminated permanently… Workers likely crashed"` abort.

`handle_heartbeat` already guarded against this race by passing `extra_expected_keys` for the tasks it had just submitted in that RPC. `handle_poll_tasks` did not — the PollTasks path has no "tasks_to_run" field because StartTasks is a separate RPC — so freshly-submitted tasks had no protection.

### Approach

- Track recent submissions on the worker: `submit_task` now records `(task_id, attempt_id) -> monotonic_time` in `self._recent_submissions`.
- `_reconcile_expected_tasks` treats keys within a 30s grace window as expected.
- Stale entries are pruned on each reconciliation so the dict stays bounded.
- `_reset_worker_state` clears the tracking alongside `self._tasks`.

This fixes both `handle_poll_tasks` and the more general case where heartbeat-submitted tasks still need race protection on the following tick. The bespoke `extra_keys` set in `handle_heartbeat` is gone since `submit_task` now populates `_recent_submissions` for all entry points.
1 parent 09234be commit 53bf959

2 files changed

Lines changed: 126 additions & 14 deletions

File tree

lib/iris/src/iris/cluster/worker/worker.py

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def worker_config_from_proto(
125125
class Worker:
126126
"""Unified worker managing all components and lifecycle."""
127127

128+
# Grace period during which a freshly-submitted task is treated as
129+
# "expected" by reconciliation, even if it hasn't yet appeared in the
130+
# controller's expected_tasks list. Protects against the StartTasks →
131+
# PollTasks race where the controller polls before its internal view
132+
# catches up with the task it just assigned.
133+
_RECENT_SUBMISSION_GRACE_SECONDS = 30.0
134+
128135
def __init__(
129136
self,
130137
config: WorkerConfig,
@@ -176,6 +183,11 @@ def __init__(
176183
# Task state: maps (task_id, attempt_id) -> TaskAttempt.
177184
# Preserves all attempts so logs for historical attempts remain accessible.
178185
self._tasks: dict[tuple[str, int], TaskAttempt] = {}
186+
# Freshly-submitted tasks -> monotonic submission time. Used by
187+
# reconciliation to grant a grace period before a task becomes
188+
# eligible for "unexpected, kill" if the controller hasn't yet
189+
# listed it in expected_tasks. See _RECENT_SUBMISSION_GRACE_SECONDS.
190+
self._recent_submissions: dict[tuple[str, int], float] = {}
179191
self._lock = threading.Lock()
180192

181193
self._host_metrics = HostMetricsCollector(disk_path=str(self._cache_dir))
@@ -586,6 +598,7 @@ def _reset_worker_state(self) -> None:
586598
# Clear task tracking
587599
with self._lock:
588600
self._tasks.clear()
601+
self._recent_submissions.clear()
589602

590603
# Replace the task thread container so new tasks get a fresh group.
591604
self._task_threads = self._threads.create_child("tasks")
@@ -701,6 +714,7 @@ def submit_task(self, request: job_pb2.RunTaskRequest) -> str:
701714

702715
with self._lock:
703716
self._tasks[key] = attempt
717+
self._recent_submissions[key] = time.monotonic()
704718

705719
# Start execution in a monitored non-daemon thread. When stop() is called,
706720
# the on_stop callback kills the container so attempt.run() exits promptly.
@@ -791,19 +805,33 @@ def _missing_task_status(task_id: str, expected_attempt_id: int) -> job_pb2.Work
791805
finished_at=timestamp_to_proto(Timestamp.now()),
792806
)
793807

808+
def _prune_and_get_recent_submission_keys(self) -> set[tuple[str, int]]:
809+
"""Return keys submitted within the grace window, pruning stale entries.
810+
811+
Caller must hold ``self._lock``. Stale entries (older than the grace
812+
window) are removed from ``self._recent_submissions`` so the dict
813+
doesn't grow unbounded.
814+
"""
815+
now = time.monotonic()
816+
cutoff = now - self._RECENT_SUBMISSION_GRACE_SECONDS
817+
stale = [key for key, ts in self._recent_submissions.items() if ts < cutoff]
818+
for key in stale:
819+
del self._recent_submissions[key]
820+
return set(self._recent_submissions)
821+
794822
def _reconcile_expected_tasks(
795823
self,
796824
expected_entries,
797-
extra_expected_keys: set[tuple[str, int]] | None = None,
798825
) -> tuple[list[job_pb2.WorkerTaskStatus], list[tuple[str, int]]]:
799826
"""Build status entries for expected tasks; collect non-terminal local tasks
800827
not in the expected set as targets to kill.
801828
802829
Caller must hold ``self._lock``.
803830
804-
``extra_expected_keys`` keeps freshly-submitted tasks (e.g. ``tasks_to_run``
805-
on the legacy heartbeat) from being killed when they aren't yet in the
806-
controller's expected set.
831+
Freshly-submitted tasks (``self._recent_submissions``) are protected
832+
from reconciliation kills via the grace window, which covers the
833+
StartTasks → PollTasks race where the controller polls before its
834+
internal view catches up with a task it just assigned.
807835
"""
808836
tasks: list[job_pb2.WorkerTaskStatus] = []
809837
expected_keys: set[tuple[str, int]] = set()
@@ -817,8 +845,7 @@ def _reconcile_expected_tasks(
817845
tasks.append(self._missing_task_status(task_id, expected_attempt_id))
818846
else:
819847
tasks.append(self._encode_task_status(task, task_id))
820-
if extra_expected_keys:
821-
expected_keys |= extra_expected_keys
848+
expected_keys |= self._prune_and_get_recent_submission_keys()
822849
tasks_to_kill: list[tuple[str, int]] = []
823850
for key, task in self._tasks.items():
824851
if key not in expected_keys and task.status not in self._TERMINAL_STATES:
@@ -851,8 +878,12 @@ def handle_heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.Heartbe
851878
found on worker"). This happens when the worker has reset its state
852879
(_tasks.clear() in _reset_worker_state) between heartbeats — from
853880
the controller's perspective this is equivalent to a worker restart.
854-
4. Kill unexpected tasks — any task in self._tasks that is NOT in
855-
expected_tasks or tasks_to_run is killed (controller no longer wants it)
881+
4. Kill unexpected tasks — any non-terminal task in self._tasks that is
882+
NOT in expected_tasks and is not within the recent-submission grace
883+
window is killed (controller no longer wants it). The grace window
884+
keeps tasks just submitted via StartTasks or this heartbeat's
885+
tasks_to_run from being killed when the controller hasn't yet
886+
listed them in expected_tasks.
856887
857888
The ordering guarantee between steps 1 and 3 is critical: a task that
858889
appears in both tasks_to_run and expected_tasks (which is always the case
@@ -891,12 +922,11 @@ def handle_heartbeat(self, request: job_pb2.HeartbeatRequest) -> job_pb2.Heartbe
891922
logger.warning("Heartbeat: failed to kill task %s: %s", task_id, e)
892923

893924
with slow_log(logger, "heartbeat reconciliation", threshold_ms=200):
894-
# tasks_to_run was just submitted above; carry those keys so a
895-
# newly-assigned task isn't killed if the controller hasn't yet
896-
# listed it in expected_tasks.
897-
extra_keys = {(r.task_id, r.attempt_id) for r in request.tasks_to_run}
925+
# tasks_to_run was just submitted above; those keys live in
926+
# self._recent_submissions and are protected from the race by
927+
# _reconcile_expected_tasks' grace-window logic.
898928
with self._lock:
899-
tasks, tasks_to_kill = self._reconcile_expected_tasks(request.expected_tasks, extra_keys)
929+
tasks, tasks_to_kill = self._reconcile_expected_tasks(request.expected_tasks)
900930

901931
# Kill removed tasks asynchronously outside lock to avoid deadlock
902932
for task_id, attempt_id in tasks_to_kill:
@@ -959,7 +989,12 @@ def handle_stop_tasks(self, request: worker_pb2.Worker.StopTasksRequest) -> work
959989
return worker_pb2.Worker.StopTasksResponse()
960990

961991
def handle_poll_tasks(self, request: worker_pb2.Worker.PollTasksRequest) -> worker_pb2.Worker.PollTasksResponse:
962-
"""Report status of expected tasks and kill unexpected tasks."""
992+
"""Report status of expected tasks and kill unexpected tasks.
993+
994+
Freshly-submitted tasks (via StartTasks) are protected from the
995+
StartTasks → PollTasks race by the recent-submission grace window
996+
applied in _reconcile_expected_tasks.
997+
"""
963998
with self._lock:
964999
tasks, tasks_to_kill = self._reconcile_expected_tasks(request.expected_tasks)
9651000
for task_id, attempt_id in tasks_to_kill:

lib/iris/tests/cluster/worker/test_worker.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from iris.cluster.worker.service import WorkerServiceImpl
2828
from iris.cluster.worker.worker import Worker, WorkerConfig
2929
from iris.rpc import job_pb2
30+
from iris.rpc import worker_pb2
3031
from rigging.timing import Duration
3132
from iris.cluster.worker.worker_types import LogLine
3233
from tests.cluster.worker.conftest import (
@@ -479,6 +480,11 @@ def slow_stop(force=False):
479480
task = mock_worker.get_task(task_id_wire)
480481
wait_for_condition(lambda: task.status == job_pb2.TASK_STATE_RUNNING)
481482

483+
# Clear recent-submissions tracking to simulate the task having been
484+
# around long enough for the grace window to have elapsed; this test
485+
# exercises reconciliation-driven kill, not grace-window protection.
486+
mock_worker._recent_submissions.clear()
487+
482488
# Send heartbeat with empty expected_tasks -- the worker should kill
483489
# the running task because it's no longer expected
484490
heartbeat_req = job_pb2.HeartbeatRequest(expected_tasks=[])
@@ -494,6 +500,77 @@ def slow_stop(force=False):
494500
assert task.status == job_pb2.TASK_STATE_KILLED
495501

496502

503+
def test_poll_tasks_grace_window_protects_freshly_submitted_task(mock_worker, mock_runtime):
504+
"""PollTasks must not kill a task submitted moments before the controller polls.
505+
506+
Reproduces the StartTasks → PollTasks race from iris #5041: the controller
507+
dispatches a task via StartTasks but polls before its own expected_tasks view
508+
includes the new task. Without the grace window, the worker would read the
509+
task as "unexpected" and kill it, cascading the whole pool to KILLED.
510+
"""
511+
mock_handle = create_mock_container_handle(status_sequence=[ContainerStatus(phase=ContainerPhase.RUNNING)] * 1000)
512+
mock_runtime.create_container = Mock(return_value=mock_handle)
513+
514+
task_id_wire = JobName.root("test-user", "poll-race").task(0).to_wire()
515+
request = create_run_task_request(task_id=task_id_wire)
516+
mock_worker.submit_task(request)
517+
518+
task = mock_worker.get_task(task_id_wire)
519+
wait_for_condition(lambda: task.status == job_pb2.TASK_STATE_RUNNING)
520+
521+
# Controller polls with the just-submitted task missing from expected_tasks
522+
# (race: controller hasn't reconciled its own StartTasks response yet).
523+
mock_worker.handle_poll_tasks(worker_pb2.Worker.PollTasksRequest(expected_tasks=[]))
524+
525+
# The task must not have been marked for kill.
526+
assert task.should_stop is False
527+
assert task.status == job_pb2.TASK_STATE_RUNNING
528+
529+
# Clean up.
530+
mock_worker.kill_task(task_id_wire)
531+
task.thread.join(timeout=15.0)
532+
533+
534+
def test_poll_tasks_kills_task_outside_grace_window(mock_worker, mock_runtime):
535+
"""Once the grace window has elapsed, reconciliation resumes killing unexpected tasks."""
536+
mock_handle = create_mock_container_handle(status_sequence=[ContainerStatus(phase=ContainerPhase.RUNNING)] * 1000)
537+
mock_runtime.create_container = Mock(return_value=mock_handle)
538+
539+
task_id_wire = JobName.root("test-user", "poll-post-grace").task(0).to_wire()
540+
request = create_run_task_request(task_id=task_id_wire)
541+
mock_worker.submit_task(request)
542+
543+
task = mock_worker.get_task(task_id_wire)
544+
wait_for_condition(lambda: task.status == job_pb2.TASK_STATE_RUNNING)
545+
546+
# Simulate grace window elapsing by clearing recent-submissions tracking.
547+
mock_worker._recent_submissions.clear()
548+
549+
mock_worker.handle_poll_tasks(worker_pb2.Worker.PollTasksRequest(expected_tasks=[]))
550+
551+
assert task.should_stop is True
552+
task.thread.join(timeout=15.0)
553+
assert task.status == job_pb2.TASK_STATE_KILLED
554+
555+
556+
def test_recent_submissions_prune_removes_stale_entries(mock_worker):
557+
"""Stale recent-submission entries are pruned to keep the dict bounded."""
558+
key_fresh = ("task-fresh", 0)
559+
key_stale = ("task-stale", 0)
560+
grace = mock_worker._RECENT_SUBMISSION_GRACE_SECONDS
561+
now = time.monotonic()
562+
# now - (grace + 1): clearly older than the window -> should be pruned
563+
mock_worker._recent_submissions[key_stale] = now - (grace + 1)
564+
mock_worker._recent_submissions[key_fresh] = now
565+
566+
with mock_worker._lock:
567+
recent = mock_worker._prune_and_get_recent_submission_keys()
568+
569+
assert key_fresh in recent
570+
assert key_stale not in recent
571+
assert key_stale not in mock_worker._recent_submissions
572+
573+
497574
def test_kill_nonexistent_task(mock_worker):
498575
"""Test killing a nonexistent task returns False."""
499576
result = mock_worker.kill_task(JobName.root("test-user", "nonexistent-task").task(0).to_wire())

0 commit comments

Comments
 (0)