Skip to content

Commit 493c9bb

Browse files
authored
[iris] Reap TPU hosts that keep failing launches; fire on_stop on natural return (#5038)
Three linked fixes for a TPU co-schedule loop. ASSIGNED->WORKER_FAILED now bumps the worker health tracker so a host that repeatedly fails to bring up a task (e.g. iommu/vfio group already held) gets reaped instead of looping forever. Excludes reservation-holder tasks from PollTasksRequest.expected_tasks, since holders are virtual and polling them produced bogus WORKER_FAILEDs that drained preemption budget. ManagedThread now fires on_stop on the natural-return path so docker kill+rm actually runs, releasing the TPU vfio group for the next task. Adds regression tests for each.
1 parent 9fd3b66 commit 493c9bb

5 files changed

Lines changed: 199 additions & 2 deletions

File tree

lib/iris/src/iris/cluster/controller/transitions.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1878,6 +1878,13 @@ def _apply_task_transitions(
18781878
if update.new_state == job_pb2.TASK_STATE_WORKER_FAILED and prior_state == job_pb2.TASK_STATE_ASSIGNED:
18791879
task_state = job_pb2.TASK_STATE_PENDING
18801880
terminal_ms = None
1881+
# ASSIGNED -> WORKER_FAILED means the worker accepted the task but
1882+
# couldn't bring it up (e.g. TPU iommu/vfio already held by another
1883+
# process on the VM). Attribute the failure to the worker so a host
1884+
# that keeps failing launches gets reaped; otherwise the task loops
1885+
# forever without draining preemption budget.
1886+
if worker_id is not None:
1887+
self._health.build_failed(WorkerId(str(worker_id)))
18811888
if update.new_state == job_pb2.TASK_STATE_FAILED and failure_count <= int(
18821889
task_row["max_retries_failure"]
18831890
):
@@ -3128,10 +3135,17 @@ def get_running_tasks_for_poll(
31283135
return {}, {}
31293136

31303137
placeholders = ",".join("?" for _ in worker_ids)
3138+
# Reservation holders are virtual — they live on ``current_worker_id``
3139+
# only as a scheduling anchor and never get a RunTaskRequest. Sending
3140+
# them in PollTasksRequest.expected_tasks makes the worker reconcile
3141+
# against its _tasks dict, miss, and return WORKER_FAILED every cycle,
3142+
# which drains the holder's preemption budget and (post the build-
3143+
# failure health hook) reaps the claimed worker for a harmless miss.
31313144
task_rows = snap.fetchall(
31323145
f"SELECT t.task_id, t.current_attempt_id, t.current_worker_id "
3133-
f"FROM tasks t "
3146+
f"FROM tasks t JOIN jobs j ON j.job_id = t.job_id "
31343147
f"WHERE t.current_worker_id IN ({placeholders}) AND t.state IN (?, ?, ?) "
3148+
f"AND j.is_reservation_holder = 0 "
31353149
f"ORDER BY t.task_id ASC",
31363150
(*worker_ids, *ACTIVE_TASK_STATES),
31373151
)

lib/iris/src/iris/managed_thread.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,14 @@ def _watch_stop() -> None:
110110
raise
111111
finally:
112112
if watcher:
113-
watcher.join(timeout=1.0)
113+
# Wake the watcher regardless of how the target exited so
114+
# on_stop runs on the natural-completion path too. Otherwise
115+
# cleanup (e.g. docker kill+rm for task containers) is
116+
# silently skipped whenever the target returns without an
117+
# explicit stop() — leaving wedged containers that keep
118+
# holding TPU vfio/iommu groups and breaking subsequent tasks.
119+
self._stop_event.set()
120+
watcher.join(timeout=5.0)
114121
if watcher.is_alive():
115122
logger.warning("on_stop callback for %s did not complete", name)
116123

lib/iris/tests/cluster/controller/test_reservation.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,15 @@
5353
from tests.cluster.controller.conftest import (
5454
FakeProvider,
5555
hydrate_worker_attributes as _with_attrs,
56+
make_job_request,
5657
query_job as _query_job,
5758
query_job_row as _query_job_row,
5859
query_task as _query_task,
5960
query_task_with_attempts as _query_task_with_attempts,
6061
query_tasks_for_job as _query_tasks_for_job,
6162
query_worker as _query_worker,
6263
schedulable_tasks as _schedulable_tasks,
64+
submit_job as _submit_job_tasks,
6365
worker_running_tasks as _worker_running_tasks,
6466
)
6567

@@ -1436,6 +1438,49 @@ def test_holder_task_worker_death_no_failure_record(state):
14361438
assert task_row_can_be_scheduled(holder_task), "holder task must be schedulable again"
14371439

14381440

1441+
def test_get_running_tasks_for_poll_excludes_reservation_holders(state):
1442+
"""get_running_tasks_for_poll must filter reservation-holder tasks.
1443+
1444+
Regression: the ping/poll loop feeds its output directly into
1445+
PollTasksRequest.expected_tasks. Holders are virtual — they never reach
1446+
the worker's _tasks dict — so including them makes the worker reconcile,
1447+
miss, and return WORKER_FAILED("Task not found on worker") every cycle.
1448+
That drains the holder's preemption budget and (with the ASSIGNED→
1449+
WORKER_FAILED health hook) reaps the claimed worker every few minutes.
1450+
1451+
Observed in production: ~51 attempts/hour per holder.
1452+
"""
1453+
request = _make_job_request_with_reservation(
1454+
reservation_entries=[_make_reservation_entry(_cpu_device())],
1455+
)
1456+
parent_job_id = _submit_job(state, "res-job", request)
1457+
holder_job_id = parent_job_id.child(RESERVATION_HOLDER_JOB_NAME)
1458+
1459+
holder_tasks = _query_tasks_for_job(state, holder_job_id)
1460+
assert len(holder_tasks) == 1
1461+
holder_task = holder_tasks[0]
1462+
1463+
real_request = make_job_request("real-job")
1464+
(real_task,) = _submit_job_tasks(state, "real-job", real_request)
1465+
1466+
worker_id = _register_worker(state, "w1")
1467+
state.queue_assignments(
1468+
[
1469+
Assignment(task_id=holder_task.task_id, worker_id=worker_id),
1470+
Assignment(task_id=real_task.task_id, worker_id=worker_id),
1471+
]
1472+
)
1473+
1474+
running, _addresses = state.get_running_tasks_for_poll()
1475+
1476+
task_ids = {entry.task_id for entry in running.get(worker_id, [])}
1477+
assert real_task.task_id in task_ids, "real task must still appear for polling"
1478+
assert holder_task.task_id not in task_ids, (
1479+
"reservation holder must be excluded — worker has no in-memory state "
1480+
"for virtual holders, so polling them produces bogus WORKER_FAILEDs"
1481+
)
1482+
1483+
14391484
def test_holder_task_removed_from_worker_when_parent_succeeds(state):
14401485
"""Holder task is cleaned from worker.running_tasks when the parent job succeeds.
14411486

lib/iris/tests/cluster/controller/test_transitions.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,6 +2225,40 @@ def test_worker_failed_from_building_counts_as_preemption(state):
22252225
assert _query_task(state, task.task_id).failure_count == 0
22262226

22272227

2228+
def test_worker_failed_from_assigned_bumps_health_tracker(state):
2229+
"""ASSIGNED -> WORKER_FAILED attributes the failure to the worker.
2230+
2231+
Regression for the TPU-iommu co-schedule loop: the task retries to PENDING
2232+
(no preemption-budget cost) but the health tracker must still bump so that
2233+
a host that repeatedly fails launches eventually crosses the threshold and
2234+
gets reaped.
2235+
"""
2236+
worker_id = register_worker(state, "w1", "host:8080", make_worker_metadata())
2237+
req = make_job_request("job1")
2238+
req.max_retries_preemption = 5
2239+
tasks = submit_job(state, "j1", req)
2240+
task = tasks[0]
2241+
2242+
state.queue_assignments([Assignment(task_id=task.task_id, worker_id=worker_id)])
2243+
assert _query_task(state, task.task_id).state == job_pb2.TASK_STATE_ASSIGNED
2244+
assert state._health.snapshot().get(worker_id) is None
2245+
2246+
transition_task(
2247+
state,
2248+
task.task_id,
2249+
job_pb2.TASK_STATE_WORKER_FAILED,
2250+
error='TPU init failure ("Couldn\'t open iommu group")',
2251+
)
2252+
2253+
# Task retries without consuming preemption budget...
2254+
t = _query_task(state, task.task_id)
2255+
assert t.state == job_pb2.TASK_STATE_PENDING
2256+
assert t.preemption_count == 0
2257+
# ...but the worker is charged a build failure.
2258+
_, build_failures = state._health.snapshot()[worker_id]
2259+
assert build_failures == 1
2260+
2261+
22282262
def test_failed_from_building_bumps_health_tracker(state):
22292263
"""FAILED originating from BUILDING increments the build failure counter.
22302264
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Regression tests for ManagedThread lifecycle.
5+
6+
Focus: on_stop callback must run both when stop() is called externally AND
7+
when the thread target returns on its own. A missed on_stop on the natural-
8+
completion path left task containers un-reaped in production — the container
9+
process stayed wedged on the TPU vfio/iommu group, poisoning the VM for
10+
subsequent tasks.
11+
"""
12+
13+
import threading
14+
import time
15+
16+
from iris.managed_thread import ManagedThread
17+
18+
19+
def test_on_stop_runs_when_stop_is_called():
20+
stopped = threading.Event()
21+
released = threading.Event()
22+
23+
def target(stop_event: threading.Event) -> None:
24+
stop_event.wait(timeout=5.0)
25+
26+
def on_stop() -> None:
27+
stopped.set()
28+
released.set()
29+
30+
t = ManagedThread(target=target, name="stop-called", on_stop=on_stop)
31+
t.start()
32+
t.stop()
33+
t.join()
34+
assert stopped.is_set()
35+
36+
37+
def test_on_stop_runs_when_target_returns_naturally():
38+
"""Regression: target returning on its own must still fire on_stop.
39+
40+
Before the fix, on_stop was only invoked when an explicit stop() set the
41+
stop event. When the target returned naturally (e.g. a task container
42+
exited and the monitoring loop finished), the watcher stayed parked on
43+
stop_event.wait() and the finally block timed out silently, skipping
44+
on_stop. For task threads this meant docker kill + docker rm never ran,
45+
leaving wedged containers holding TPU vfio groups.
46+
"""
47+
on_stop_ran = threading.Event()
48+
49+
def target(_stop_event: threading.Event) -> None:
50+
# Return immediately without touching the stop event.
51+
return
52+
53+
def on_stop() -> None:
54+
on_stop_ran.set()
55+
56+
t = ManagedThread(target=target, name="natural-return", on_stop=on_stop)
57+
t.start()
58+
t.join()
59+
assert on_stop_ran.is_set(), "on_stop must run when target completes naturally"
60+
61+
62+
def test_on_stop_runs_when_target_raises():
63+
"""on_stop must also fire when the target raises — exception path."""
64+
on_stop_ran = threading.Event()
65+
66+
class _Boom(Exception):
67+
pass
68+
69+
def target(_stop_event: threading.Event) -> None:
70+
raise _Boom("task blew up")
71+
72+
def on_stop() -> None:
73+
on_stop_ran.set()
74+
75+
t = ManagedThread(target=target, name="raising-target", on_stop=on_stop)
76+
t.start()
77+
t.join()
78+
assert on_stop_ran.is_set(), "on_stop must run even when target raises"
79+
80+
81+
def test_on_stop_runs_only_once():
82+
"""on_stop must not double-fire when both stop() and natural return occur."""
83+
calls = []
84+
lock = threading.Lock()
85+
86+
def target(stop_event: threading.Event) -> None:
87+
stop_event.wait(timeout=0.2)
88+
89+
def on_stop() -> None:
90+
with lock:
91+
calls.append(time.monotonic())
92+
93+
t = ManagedThread(target=target, name="no-double-fire", on_stop=on_stop)
94+
t.start()
95+
t.stop()
96+
t.join()
97+
assert len(calls) == 1, f"on_stop fired {len(calls)} times, expected 1"

0 commit comments

Comments
 (0)