diff --git a/lib/iris/src/iris/cluster/worker/task_attempt.py b/lib/iris/src/iris/cluster/worker/task_attempt.py index c1c5da84e8..669d41796e 100644 --- a/lib/iris/src/iris/cluster/worker/task_attempt.py +++ b/lib/iris/src/iris/cluster/worker/task_attempt.py @@ -36,6 +36,7 @@ ) from iris.cluster.bundle import BundleStore from iris.cluster.worker.port_allocator import PortAllocator +from iris.cluster.worker.tpu_health import detect_tpu_init_failure from iris.cluster.worker.worker_types import LogLine from iris.cluster.log_store._types import task_log_key from iris.log_server.client import LogPusher @@ -51,6 +52,9 @@ logger = logging.getLogger(__name__) +# Trailing stderr lines scanned for TPU bad-node signatures on non-zero exit. +_TPU_STDERR_TAIL_LINES = 200 + # Signal numbers for interpreting exit codes > 128 _SIGNAL_NAMES = { 6: "SIGABRT", @@ -827,21 +831,39 @@ def _monitor_loop( elif status.exit_code == 0: self.transition_to(job_pb2.TASK_STATE_SUCCEEDED, exit_code=0) else: - stderr_line = None - for entry in reversed(log_reader.read_all()): - if entry.source == "stderr" and entry.data: - stderr_line = entry.data - break + stderr_tail: list[str] = [ + entry.data for entry in log_reader.read_all() if entry.source == "stderr" and entry.data + ] + stderr_line = stderr_tail[-1] if stderr_tail else None error = _format_exit_error(status.exit_code, status.oom_killed) if stderr_line: error = f"{error}. stderr: {stderr_line}" if status.oom_killed: self._append_log(source="error", data="Container was OOM killed by the kernel") - self.transition_to( - job_pb2.TASK_STATE_FAILED, - error=error, - exit_code=status.exit_code or -1, - ) + # Promote known TPU bad-node signatures to WORKER_FAILED. + tpu_pattern = detect_tpu_init_failure(stderr_tail[-_TPU_STDERR_TAIL_LINES:]) + if tpu_pattern is not None: + logger.warning( + "Task %s: TPU bad-node signature %r; promoting FAILED -> WORKER_FAILED", + self.task_id, + tpu_pattern, + ) + self._append_log( + source="error", + data=f"iris: TPU bad-node signature detected ({tpu_pattern!r}); " + "reporting as worker failure", + ) + self.transition_to( + job_pb2.TASK_STATE_WORKER_FAILED, + error=f"TPU init failure ({tpu_pattern!r}): {error}", + exit_code=status.exit_code or -1, + ) + else: + self.transition_to( + job_pb2.TASK_STATE_FAILED, + error=error, + exit_code=status.exit_code or -1, + ) break # Stream logs incrementally diff --git a/lib/iris/src/iris/cluster/worker/tpu_health.py b/lib/iris/src/iris/cluster/worker/tpu_health.py new file mode 100644 index 0000000000..ab7d39ebcb --- /dev/null +++ b/lib/iris/src/iris/cluster/worker/tpu_health.py @@ -0,0 +1,26 @@ +# Copyright The Marin Authors +# SPDX-License-Identifier: Apache-2.0 + +"""TPU bad-node stderr signatures. Hits promote FAILED -> WORKER_FAILED.""" + +from collections.abc import Iterable + +# Keep in sync with lib/iris/OPS.md bad-node triggers. +TPU_INIT_FAILURE_PATTERNS: tuple[str, ...] = ( + "Couldn't open iommu group", + "open(/dev/vfio", + "Failed to initialize TPU system", + "TPU initialization failed", + "No accelerator found", +) + + +def detect_tpu_init_failure(stderr_lines: Iterable[str]) -> str | None: + """Return the first matching bad-node pattern found in ``stderr_lines``, or None.""" + for line in stderr_lines: + if not line: + continue + for pattern in TPU_INIT_FAILURE_PATTERNS: + if pattern in line: + return pattern + return None diff --git a/lib/iris/tests/cluster/worker/test_worker.py b/lib/iris/tests/cluster/worker/test_worker.py index 650b31d7d6..6b7c61fa0c 100644 --- a/lib/iris/tests/cluster/worker/test_worker.py +++ b/lib/iris/tests/cluster/worker/test_worker.py @@ -28,7 +28,13 @@ from iris.cluster.worker.worker import Worker, WorkerConfig from iris.rpc import job_pb2 from rigging.timing import Duration -from tests.cluster.worker.conftest import create_mock_container_handle, create_run_task_request +from iris.cluster.worker.worker_types import LogLine +from tests.cluster.worker.conftest import ( + FakeContainerHandle, + FakeLogReader, + create_mock_container_handle, + create_run_task_request, +) from iris.test_util import wait_for_condition pytestmark = pytest.mark.timeout(10) @@ -151,6 +157,78 @@ def test_task_failure_on_nonzero_exit(mock_worker, mock_runtime): assert "Exit code: 1" in final_task.error +def test_tpu_bad_node_stderr_promotes_to_worker_failed(mock_worker, mock_runtime): + """Non-zero exit with TPU bad-node stderr -> WORKER_FAILED (issue #4783).""" + bad_node_stderr = [ + LogLine.now(source="stdout", data="startup: launching vLLM engine"), + LogLine.now( + source="stderr", + data=( + "jax.errors.JaxRuntimeError: UNKNOWN: TPU initialization failed: " + "open(/dev/vfio/0): Device or resource busy: Device or resource busy; " + "Couldn't open iommu group /dev/vfio/0" + ), + ), + ] + populated_reader = FakeLogReader(_logs=list(bad_node_stderr)) + + class _HandleWithStderr(FakeContainerHandle): + def log_reader(self) -> FakeLogReader: + return populated_reader + + mock_handle = _HandleWithStderr( + status_sequence=[ + ContainerStatus(phase=ContainerPhase.RUNNING), + ContainerStatus(phase=ContainerPhase.STOPPED, exit_code=1), + ] + ) + mock_runtime.create_container = Mock(return_value=mock_handle) + + request = create_run_task_request() + task_id = mock_worker.submit_task(request) + + task = mock_worker.get_task(task_id) + task.thread.join(timeout=15.0) + + final_task = mock_worker.get_task(task_id) + assert final_task.status == job_pb2.TASK_STATE_WORKER_FAILED + assert final_task.exit_code == 1 + assert final_task.error is not None + assert "TPU init failure" in final_task.error + assert "Couldn't open iommu group" in final_task.error + + +def test_non_tpu_stderr_still_maps_to_failed(mock_worker, mock_runtime): + """Non-zero exit with unrelated stderr stays FAILED (no false promotion).""" + user_stderr = [ + LogLine.now(source="stderr", data="Traceback (most recent call last):"), + LogLine.now(source="stderr", data='ValueError: bad user config: expected "foo"'), + ] + populated_reader = FakeLogReader(_logs=list(user_stderr)) + + class _HandleWithStderr(FakeContainerHandle): + def log_reader(self) -> FakeLogReader: + return populated_reader + + mock_handle = _HandleWithStderr( + status_sequence=[ + ContainerStatus(phase=ContainerPhase.RUNNING), + ContainerStatus(phase=ContainerPhase.STOPPED, exit_code=1), + ] + ) + mock_runtime.create_container = Mock(return_value=mock_handle) + + request = create_run_task_request() + task_id = mock_worker.submit_task(request) + + task = mock_worker.get_task(task_id) + task.thread.join(timeout=15.0) + + final_task = mock_worker.get_task(task_id) + assert final_task.status == job_pb2.TASK_STATE_FAILED + assert final_task.exit_code == 1 + + def test_task_failure_on_error(mock_worker, mock_runtime): """Test task fails when container returns error.""" # Update the mock handle's status to return error after first poll