Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 37 additions & 10 deletions lib/iris/src/iris/cluster/worker/task_attempt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from iris.cluster.bundle import BundleStore
from iris.cluster.worker.port_allocator import PortAllocator
from iris.cluster.worker.tpu_health import detect_tpu_init_failure
from iris.cluster.worker.worker_types import LogLine
from iris.cluster.log_store._types import task_log_key
from iris.log_server.client import LogPusher
Expand All @@ -51,6 +52,10 @@

logger = logging.getLogger(__name__)

# How many trailing stderr lines to scan for TPU bad-node signatures.
# Bounded so we don't rescan the full container log on every failure.
_TPU_STDERR_TAIL_LINES = 200

# Signal numbers for interpreting exit codes > 128
_SIGNAL_NAMES = {
6: "SIGABRT",
Expand Down Expand Up @@ -827,21 +832,43 @@ def _monitor_loop(
elif status.exit_code == 0:
self.transition_to(job_pb2.TASK_STATE_SUCCEEDED, exit_code=0)
else:
stderr_line = None
for entry in reversed(log_reader.read_all()):
if entry.source == "stderr" and entry.data:
stderr_line = entry.data
break
stderr_tail: list[str] = [
entry.data for entry in log_reader.read_all() if entry.source == "stderr" and entry.data
]
stderr_line = stderr_tail[-1] if stderr_tail else None
error = _format_exit_error(status.exit_code, status.oom_killed)
if stderr_line:
error = f"{error}. stderr: {stderr_line}"
if status.oom_killed:
self._append_log(source="error", data="Container was OOM killed by the kernel")
self.transition_to(
job_pb2.TASK_STATE_FAILED,
error=error,
exit_code=status.exit_code or -1,
)
# Promote known TPU bad-node signatures to WORKER_FAILED so
# the attempt consumes the preemption budget instead of the
# user-code failure budget. Scan a bounded tail — these
# signatures are emitted close to process exit.
tpu_pattern = detect_tpu_init_failure(stderr_tail[-_TPU_STDERR_TAIL_LINES:])
if tpu_pattern is not None:
Comment thread
rjpower marked this conversation as resolved.
tpu_error = f"TPU init failure ({tpu_pattern!r}): {error}"
logger.warning(
"Task %s: detected TPU bad-node signature %r; " "promoting FAILED -> WORKER_FAILED",
self.task_id,
tpu_pattern,
)
self._append_log(
source="error",
data=f"iris: TPU bad-node signature detected ({tpu_pattern!r}); "
"reporting as worker failure",
)
self.transition_to(
job_pb2.TASK_STATE_WORKER_FAILED,
error=tpu_error,
exit_code=status.exit_code or -1,
)
else:
self.transition_to(
job_pb2.TASK_STATE_FAILED,
error=error,
exit_code=status.exit_code or -1,
)
break

# Stream logs incrementally
Expand Down
51 changes: 51 additions & 0 deletions lib/iris/src/iris/cluster/worker/tpu_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""TPU-level bad-node failure detection.

When a task container exits with a non-zero status, the worker normally marks
the task as ``TASK_STATE_FAILED`` (user-code failure). Some failure signatures
are actually signs that the underlying TPU VM is dirty — typically after a
preemption / teardown where ``/dev/vfio`` is still claimed by a previous
process. We need to promote those to ``TASK_STATE_WORKER_FAILED`` so the
controller treats the attempt as an infra preemption and retries it elsewhere.

Patterns are hard-coded on purpose: these signatures are stable strings
emitted by JAX / libtpu during TPU init, and OPS.md already documents them as
the manual trigger list for bad-node triage.
"""

from collections.abc import Iterable

# Substrings matched against container stderr tail. A single hit promotes the
# attempt from FAILED to WORKER_FAILED.
#
# Keep this list in sync with ``lib/iris/OPS.md`` bad-node triggers.
TPU_INIT_FAILURE_PATTERNS: tuple[str, ...] = (
# /dev/vfio/<n> busy after a dirty preemption — the canonical case from #4783.
"Couldn't open iommu group",
"open(/dev/vfio",
# libtpu / JAX surface when the device is held by another process.
"Failed to initialize TPU system",
"TPU initialization failed",
# Host has no visible accelerator at all (VM came up without TPU attached).
"No accelerator found",
)


def detect_tpu_init_failure(stderr_lines: Iterable[str]) -> str | None:
"""Return the first matching bad-node pattern found in ``stderr_lines``.

``stderr_lines`` is any iterable of stderr strings (typically the tail of
the container log). Returns ``None`` if no pattern matches.

Callers should pass a bounded tail (not the full log) — these signatures
are emitted close to process exit, and scanning the full log wastes work.
"""
for line in stderr_lines:
if not line:
continue
for pattern in TPU_INIT_FAILURE_PATTERNS:
if pattern in line:
return pattern
return None
67 changes: 67 additions & 0 deletions lib/iris/tests/cluster/worker/test_tpu_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Tests for TPU bad-node stderr pattern detection."""

import pytest

from iris.cluster.worker.tpu_health import (
TPU_INIT_FAILURE_PATTERNS,
detect_tpu_init_failure,
)


@pytest.mark.parametrize(
    "stderr_line",
    [
        # Exact failure from #4783:
        "jax.errors.JaxRuntimeError: UNKNOWN: TPU initialization failed: "
        "open(/dev/vfio/0): Device or resource busy: Device or resource busy; "
        "Couldn't open iommu group /dev/vfio/0",
        # libtpu's older init-failure wording
        "Failed to initialize TPU system: some backend error",
        # JAX surface when the VM booted without a TPU attached
        "RuntimeError: No accelerator found on this host",
        # vfio path-only hit
        "libtpu: open(/dev/vfio/0) returned -1",
    ],
)
def test_detects_known_bad_node_signatures(stderr_line: str) -> None:
    """Each known bad-node stderr signature is detected when it stands alone."""
    matched = detect_tpu_init_failure([stderr_line])
    assert matched is not None


def test_detects_from_mixed_log_tail() -> None:
    """The detector picks the matching line out of surrounding benign output."""
    log_tail = [
        "normal startup log line",
        "another info line",
        "Couldn't open iommu group /dev/vfio/0",
        "subsequent error traceback frame",
    ]
    assert detect_tpu_init_failure(log_tail) == "Couldn't open iommu group"


def test_returns_none_on_unrelated_stderr() -> None:
    """Ordinary user-code tracebacks must not be mistaken for bad-node output."""
    user_traceback = [
        "Traceback (most recent call last):",
        'ValueError: bad user config: expected "foo"',
        "",
    ]
    assert detect_tpu_init_failure(user_traceback) is None


def test_empty_input() -> None:
    """An empty stderr tail never matches any pattern."""
    assert detect_tpu_init_failure([]) is None


def test_ignores_empty_lines() -> None:
    """Falsy entries are skipped: neither matched nor a crash.

    The previous version passed ``None or ""``, which evaluates to ``""``
    before the call — so ``None`` never actually reached the detector and the
    ``if not line`` guard was untested. Pass a real ``None`` (deliberately
    violating the ``Iterable[str]`` annotation) to exercise that guard.
    """
    assert detect_tpu_init_failure(["", None]) is None  # type: ignore[list-item]


def test_all_patterns_are_discoverable() -> None:
    """Guard against drift between the pattern list and the detector.

    Every declared pattern, embedded verbatim inside a noisy line, must be
    returned by ``detect_tpu_init_failure``.
    """
    for declared_pattern in TPU_INIT_FAILURE_PATTERNS:
        noisy_line = f"prefix noise {declared_pattern} trailing noise"
        assert detect_tpu_init_failure([noisy_line]) == declared_pattern
87 changes: 86 additions & 1 deletion lib/iris/tests/cluster/worker/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from iris.cluster.worker.worker import Worker, WorkerConfig
from iris.rpc import job_pb2
from rigging.timing import Duration
from tests.cluster.worker.conftest import create_mock_container_handle, create_run_task_request
from iris.cluster.worker.worker_types import LogLine
from tests.cluster.worker.conftest import (
FakeContainerHandle,
FakeLogReader,
create_mock_container_handle,
create_run_task_request,
)
from iris.test_util import wait_for_condition

pytestmark = pytest.mark.timeout(10)
Expand Down Expand Up @@ -151,6 +157,85 @@ def test_task_failure_on_nonzero_exit(mock_worker, mock_runtime):
assert "Exit code: 1" in final_task.error


def test_tpu_bad_node_stderr_promotes_to_worker_failed(mock_worker, mock_runtime):
    """A non-zero exit whose stderr tail carries a TPU bad-node signature is
    promoted from TASK_STATE_FAILED to TASK_STATE_WORKER_FAILED, so the attempt
    draws on the preemption budget and can be retried elsewhere.

    See: https://github.com/marin-community/marin/issues/4783
    """
    tpu_failure_logs = [
        LogLine.now(source="stdout", data="startup: launching vLLM engine"),
        LogLine.now(
            source="stderr",
            data=(
                "jax.errors.JaxRuntimeError: UNKNOWN: TPU initialization failed: "
                "open(/dev/vfio/0): Device or resource busy: Device or resource busy; "
                "Couldn't open iommu group /dev/vfio/0"
            ),
        ),
    ]
    stderr_reader = FakeLogReader(_logs=list(tpu_failure_logs))

    class _StderrBackedHandle(FakeContainerHandle):
        def log_reader(self) -> FakeLogReader:
            return stderr_reader

    handle = _StderrBackedHandle(
        status_sequence=[
            ContainerStatus(phase=ContainerPhase.RUNNING),
            ContainerStatus(phase=ContainerPhase.STOPPED, exit_code=1),
        ]
    )
    mock_runtime.create_container = Mock(return_value=handle)

    task_id = mock_worker.submit_task(create_run_task_request())
    mock_worker.get_task(task_id).thread.join(timeout=15.0)

    final_task = mock_worker.get_task(task_id)
    assert final_task.status == job_pb2.TASK_STATE_WORKER_FAILED
    assert final_task.exit_code == 1
    assert final_task.error is not None
    assert "TPU init failure" in final_task.error
    assert "Couldn't open iommu group" in final_task.error


def test_non_tpu_stderr_still_maps_to_failed(mock_worker, mock_runtime):
    """Unrelated stderr on a non-zero exit stays TASK_STATE_FAILED — no false
    positive promotion to WORKER_FAILED.
    """
    traceback_logs = [
        LogLine.now(source="stderr", data="Traceback (most recent call last):"),
        LogLine.now(source="stderr", data='ValueError: bad user config: expected "foo"'),
    ]
    stderr_reader = FakeLogReader(_logs=list(traceback_logs))

    class _StderrBackedHandle(FakeContainerHandle):
        def log_reader(self) -> FakeLogReader:
            return stderr_reader

    handle = _StderrBackedHandle(
        status_sequence=[
            ContainerStatus(phase=ContainerPhase.RUNNING),
            ContainerStatus(phase=ContainerPhase.STOPPED, exit_code=1),
        ]
    )
    mock_runtime.create_container = Mock(return_value=handle)

    task_id = mock_worker.submit_task(create_run_task_request())
    mock_worker.get_task(task_id).thread.join(timeout=15.0)

    final_task = mock_worker.get_task(task_id)
    assert final_task.status == job_pb2.TASK_STATE_FAILED
    assert final_task.exit_code == 1


def test_task_failure_on_error(mock_worker, mock_runtime):
"""Test task fails when container returns error."""
# Update the mock handle's status to return error after first poll
Expand Down
Loading