Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 91 additions & 3 deletions lib/iris/src/iris/cluster/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
DIRECT_PROVIDER_PROMOTION_RATE,
HeartbeatApplyRequest,
ReservationClaim,
RunningTaskEntry,
SchedulingEvent,
TaskUpdate,
)
Expand Down Expand Up @@ -1024,6 +1025,13 @@ class Controller:
the controller will run it in a background thread.
"""

# Grace window during which a freshly-dispatched task (via StartTasks) is
# treated as expected when polling its worker. Bridges the StartTasks →
# PollTasks race: a poll's DB snapshot taken before the assignment commit
# would otherwise omit the task, and the worker would kill it as
# unexpected. 30s comfortably exceeds any normal StartTasks RPC latency.
_RECENT_DISPATCH_GRACE_SECONDS = 30.0
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm scared of these magic numbers - but it makes sense!

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I want to get rid of this soon — we'll switch to a more sensible model that avoids the race entirely.


def __init__(
self,
config: ControllerConfig,
Expand Down Expand Up @@ -1130,6 +1138,14 @@ def __init__(
self._poll_thread: ManagedThread | None = None
self._task_update_queue: queue_mod.Queue[HeartbeatApplyRequest] = queue_mod.Queue()

# Track tasks dispatched via StartTasks but whose assignment commit
# may not yet be visible to a concurrent PollTasks snapshot. Merged
# into expected_tasks in _poll_all_workers so the worker does not
# treat a freshly-submitted task as "unexpected" and kill it. See
# _RECENT_DISPATCH_GRACE_SECONDS and _dispatch_assignments_direct.
self._recent_dispatches: dict[WorkerId, dict[tuple[str, int], float]] = {}
self._recent_dispatches_lock = threading.Lock()

self._autoscaler: Autoscaler | None = autoscaler

self._heartbeat_iteration = 0
Expand Down Expand Up @@ -2168,6 +2184,14 @@ def _dispatch_assignments_direct(
command = [Assignment(task_id=task_id, worker_id=worker_id) for task_id, worker_id in assignments]
result = self._transitions.queue_assignments(command, direct_dispatch=True)

# Register dispatches before issuing StartTasks so a concurrent poll
# whose DB snapshot predates the assignment commit still sees these
# tasks as expected. Any poll that observes the entries won't kill
# the task; any poll that misses them can't yet collide with the
# StartTasks RPC (not sent until after this point) on the same
# gRPC channel.
self._record_recent_dispatches(result.start_requests)

# Group StartTasks payloads by (worker_id, address)
by_worker: dict[tuple[WorkerId, str], list[job_pb2.RunTaskRequest]] = {}
for worker_id, address, run_request in result.start_requests:
Expand Down Expand Up @@ -2318,24 +2342,88 @@ def _run_poll_loop(self, stop_event: threading.Event) -> None:
except Exception:
logger.exception("Poll loop iteration failed")

def _record_recent_dispatches(
    self,
    start_requests: list[tuple[WorkerId, str, job_pb2.RunTaskRequest]],
) -> None:
    """Record (worker, task, attempt) dispatches with a monotonic timestamp."""
    if not start_requests:
        return
    stamp = time.monotonic()
    with self._recent_dispatches_lock:
        for worker_id, _address, request in start_requests:
            # One timestamp map per worker; all entries from this batch
            # share a single monotonic reading.
            per_worker = self._recent_dispatches.setdefault(worker_id, {})
            per_worker[(request.task_id, request.attempt_id)] = stamp

def _prune_and_snapshot_recent_dispatches(self) -> dict[WorkerId, set[tuple[str, int]]]:
    """Drop stale entries and return a per-worker snapshot of keys within the grace window."""
    oldest_allowed = time.monotonic() - self._RECENT_DISPATCH_GRACE_SECONDS
    live: dict[WorkerId, set[tuple[str, int]]] = {}
    with self._recent_dispatches_lock:
        emptied: list[WorkerId] = []
        for wid, entries in self._recent_dispatches.items():
            # Delete expired keys in place, then snapshot whatever remains.
            expired = [key for key, stamp in entries.items() if stamp < oldest_allowed]
            for key in expired:
                del entries[key]
            if entries:
                live[wid] = set(entries)
            else:
                emptied.append(wid)
        # Remove workers whose last entry just expired so the dict does
        # not accumulate empty inner maps.
        for wid in emptied:
            del self._recent_dispatches[wid]
    return live

def _poll_all_workers(self) -> None:
    """Poll all workers for task state and feed results into the updater queue.

    Merges recently-dispatched (worker, task, attempt) keys into each
    worker's expected_tasks so a stale DB snapshot can't cause the worker
    to kill a just-submitted task. Suppresses "Task not found on worker"
    updates for entries still within the grace window: the StartTasks RPC
    may not yet have landed when PollTasks arrived.
    """
    if self._config.dry_run:
        return
    # Order matters: the recent-dispatch snapshot is taken AFTER the DB
    # snapshot, so any dispatch recorded before the DB read is covered by
    # one of the two.
    running, addresses = self._transitions.get_running_tasks_for_poll()
    recent = self._prune_and_snapshot_recent_dispatches()

    # Fold the recent dispatches into each worker's expected list. A worker
    # whose only tasks are freshly dispatched (not yet visible in the DB
    # snapshot) must still be polled with the correct expected set.
    for wid, dispatch_keys in recent.items():
        if wid not in addresses:
            continue
        expected = running.setdefault(wid, [])
        already_known = {(entry.task_id.to_wire(), entry.attempt_id) for entry in expected}
        for task_wire, attempt_id in dispatch_keys:
            if (task_wire, attempt_id) in already_known:
                continue
            expected.append(
                RunningTaskEntry(
                    task_id=JobName.from_wire(task_wire),
                    attempt_id=attempt_id,
                )
            )

    if not running:
        return
    for worker_id, updates, error in self._provider.poll_workers(running, addresses):
        if error is not None:
            logger.warning("PollTasks failed for worker %s: %s", worker_id, error)
            continue
        if not updates:
            continue
        # Every update for a task still inside the dispatch grace window is
        # dropped. The worker's view of a freshly-submitted task is
        # unreliable during the race (it either already runs the task, or
        # hasn't received it yet and reports "not found" as WORKER_FAILED).
        # A later poll settles the state once the window closes.
        in_grace = recent.get(worker_id, set())
        surviving = [u for u in updates if (u.task_id.to_wire(), u.attempt_id) not in in_grace]
        if not surviving:
            continue
        self._task_update_queue.put(
            HeartbeatApplyRequest(
                worker_id=worker_id,
                worker_resource_snapshot=None,
                updates=surviving,
            )
        )

Expand Down
102 changes: 102 additions & 0 deletions lib/iris/tests/cluster/controller/test_heartbeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,3 +341,105 @@ def test_heartbeat_failure_error_includes_rpc_context():
assert "expected=1" in error

provider.close()


def test_poll_merges_recent_dispatches_into_expected_tasks(tmp_path, worker_metadata):
    """A poll whose DB snapshot predates the assignment commit still includes
    freshly-dispatched tasks in expected_tasks, so the worker won't kill them.

    Reproduces the controller side of the StartTasks→PollTasks race: we record
    a dispatch for worker1 without committing any matching ASSIGNED row in the
    DB, then run _poll_all_workers and confirm the task appears in the
    expected_tasks passed to poll_workers.
    """
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)

    captured: dict[str, list[RunningTaskEntry]] = {}

    class RecordingProvider(FakeProvider):
        def poll_workers(self, running, worker_addresses):
            for wid, entries in running.items():
                captured[str(wid)] = list(entries)
            return []

    controller = Controller(config=config, provider=RecordingProvider(), db=db)
    try:
        state = controller.state
        _register_worker(state, "worker1", worker_metadata, address="10.0.0.1:10001")

        task_wire = JobName.from_wire("/user/dispatch-race/0").to_wire()
        run_request = job_pb2.RunTaskRequest(task_id=task_wire, attempt_id=1)
        controller._record_recent_dispatches([(WorkerId("worker1"), "10.0.0.1:10001", run_request)])

        controller._poll_all_workers()

        assert "worker1" in captured, "worker1 must be polled so expected_tasks covers the dispatch"
        wire_keys = {(e.task_id.to_wire(), e.attempt_id) for e in captured["worker1"]}
        assert (task_wire, 1) in wire_keys
    finally:
        # Release controller threads and the DB handle even when an
        # assertion above fails, so a failing run doesn't leak resources
        # into later tests in the session.
        controller.stop()
        db.close()


def test_poll_suppresses_worker_failed_within_grace_window(tmp_path, worker_metadata):
    """A WORKER_FAILED update for a task still in the dispatch grace window
    is dropped: this is the reverse race where PollTasks reaches the worker
    before StartTasks has been applied, producing a spurious 'not found'.
    """
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)

    task_wire = JobName.from_wire("/user/dispatch-race/0").to_wire()

    class FailingProvider(FakeProvider):
        def poll_workers(self, running, worker_addresses):
            # Simulate the worker reporting a just-dispatched task as
            # missing — the exact shape of the spurious race outcome.
            return [
                (
                    WorkerId("worker1"),
                    [
                        TaskUpdate(
                            task_id=JobName.from_wire(task_wire),
                            attempt_id=1,
                            new_state=job_pb2.TASK_STATE_WORKER_FAILED,
                            error="Task not found on worker",
                        )
                    ],
                    None,
                )
            ]

    controller = Controller(config=config, provider=FailingProvider(), db=db)
    try:
        state = controller.state
        _register_worker(state, "worker1", worker_metadata, address="10.0.0.1:10001")

        run_request = job_pb2.RunTaskRequest(task_id=task_wire, attempt_id=1)
        controller._record_recent_dispatches([(WorkerId("worker1"), "10.0.0.1:10001", run_request)])

        controller._poll_all_workers()

        assert controller._task_update_queue.empty(), "WORKER_FAILED should be suppressed within the grace window"
    finally:
        # Clean up even on assertion failure so a failing run doesn't leak
        # controller threads or the DB handle into later tests.
        controller.stop()
        db.close()


def test_prune_recent_dispatches_drops_stale_entries(tmp_path, worker_metadata):
    """Stale recent-dispatch entries past the grace window are pruned."""
    db = ControllerDB(db_dir=tmp_path)
    config = ControllerConfig(remote_state_dir="file:///tmp/iris-test-state", local_state_dir=tmp_path)
    controller = Controller(config=config, provider=FakeProvider(), db=db)
    try:
        grace = controller._RECENT_DISPATCH_GRACE_SECONDS
        now = time.monotonic()
        # One fresh entry and one just past the grace window.
        controller._recent_dispatches[WorkerId("worker1")] = {
            ("/user/fresh/0", 0): now,
            ("/user/stale/0", 0): now - (grace + 1.0),
        }

        snapshot = controller._prune_and_snapshot_recent_dispatches()

        # The stale key is gone from both the snapshot and the backing
        # dict; the fresh key survives with its original timestamp.
        assert snapshot == {WorkerId("worker1"): {("/user/fresh/0", 0)}}
        assert controller._recent_dispatches[WorkerId("worker1")] == {("/user/fresh/0", 0): now}
    finally:
        # Close resources even if an assertion fails, so a failing run
        # doesn't leak controller threads or the DB handle.
        controller.stop()
        db.close()
Loading