Skip to content

Commit 5c6f621

Browse files
rjpower and Helw150
authored and committed
[zephyr] Fix _check_worker_group false abort after completed stage (#4140)
_check_worker_group unconditionally treated worker_group.is_done()==True as a crash. After the last stage, workers exit cleanly via SHUTDOWN, Iris marks the job finished, and the coordinator background loop aborts with "Worker job terminated permanently" even though all shards completed. Only affects Iris (Local/Ray hardcode is_done to False). Adds a completed-shards guard to _check_worker_group and three regression tests. Fixes #4117
1 parent 8217640 commit 5c6f621

2 files changed

Lines changed: 101 additions & 0 deletions

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,11 @@ def _check_worker_group(self) -> None:
453453
"""Abort the pipeline if the worker job has permanently terminated."""
454454
if self._worker_group is None or self._fatal_error is not None:
455455
return
456+
# After the last stage completes, workers exit cleanly via SHUTDOWN.
457+
# The worker job finishing at that point is expected, not a crash.
458+
with self._lock:
459+
if self._total_shards > 0 and self._completed_shards >= self._total_shards:
460+
return
456461
try:
457462
if self._worker_group.is_done():
458463
self.abort(
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Regression tests for issue #4117: _check_worker_group false abort.
5+
6+
The race (before fix):
7+
1. Last stage completes — all shards done, workers get SHUTDOWN, exit
8+
2. Main thread is in _collect_results / _regroup_result_refs (between
9+
_wait_for_stage returning and self.shutdown())
10+
3. Background coordinator loop calls _check_worker_group
11+
4. worker_group.is_done() returns True (workers exited cleanly!)
12+
5. Coordinator calls abort("Worker job terminated permanently...")
13+
14+
Fix: _check_worker_group skips when all shards are completed.
15+
"""
16+
17+
from __future__ import annotations
18+
19+
import threading
20+
import time
21+
from unittest.mock import MagicMock
22+
23+
import pytest
24+
from zephyr.execution import ListShard, ShardTask, TaskResult, ZephyrCoordinator
25+
26+
27+
@pytest.fixture
def coordinator(actor_context, tmp_path):
    """Yield a ZephyrCoordinator writing chunks under a temp dir; shut it down after the test."""
    coord_under_test = ZephyrCoordinator()
    coord_under_test.set_chunk_config(str(tmp_path / "chunks"), "test-exec")
    yield coord_under_test
    coord_under_test.shutdown()
33+
34+
35+
def test_check_worker_group_skips_after_completed_stage(coordinator):
    """Worker group finishing after completed stage must not abort. #4117."""
    # Simulate a worker job that has already exited (cleanly, via SHUTDOWN).
    finished_group = MagicMock()
    finished_group.is_done.return_value = True
    coordinator.set_worker_group(finished_group)

    only_task = ShardTask(
        shard_idx=0,
        total_shards=1,
        shard=ListShard(refs=[]),
        operations=[],
        stage_name="test",
    )
    coordinator.start_stage("last-stage", [only_task], is_last_stage=True)
    coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))

    # Precondition for the guard: every shard of the stage has completed.
    assert coordinator._completed_shards >= coordinator._total_shards

    coordinator._check_worker_group()

    # The completed-shards guard must short-circuit before is_done() can trigger abort.
    assert coordinator.get_fatal_error() is None
50+
51+
52+
def test_check_worker_group_still_aborts_mid_stage(coordinator):
    """Worker group dying while shards are still in-flight must abort."""
    dead_group = MagicMock()
    dead_group.is_done.return_value = True
    coordinator.set_worker_group(dead_group)

    pending_task = ShardTask(
        shard_idx=0,
        total_shards=2,
        shard=ListShard(refs=[]),
        operations=[],
        stage_name="test",
    )
    coordinator.start_stage("mid-stage", [pending_task, pending_task], is_last_stage=False)
    # Report just one of the two shards so the stage is still in-flight.
    coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))

    coordinator._check_worker_group()

    # With shards outstanding, a finished worker job is a genuine crash.
    fatal = coordinator.get_fatal_error()
    assert fatal is not None
    assert "Worker job terminated permanently" in fatal
67+
68+
69+
def test_coordinator_loop_no_abort_during_result_collection(coordinator):
    """Background loop must not abort during post-stage result collection. #4117."""
    flaky_group = MagicMock()
    polls = 0

    def report_done_after_two_polls():
        # First two polls say "still running", then the job looks finished —
        # mimicking workers exiting cleanly right after the last stage.
        nonlocal polls
        polls += 1
        return polls > 2

    flaky_group.is_done.side_effect = report_done_after_two_polls
    coordinator.set_worker_group(flaky_group)

    final_task = ShardTask(
        shard_idx=0,
        total_shards=1,
        shard=ListShard(refs=[]),
        operations=[],
        stage_name="test",
    )
    coordinator.start_stage("last-stage", [final_task], is_last_stage=True)
    coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))

    background = threading.Thread(target=coordinator._coordinator_loop, daemon=True)
    background.start()

    # Simulate the post-stage window where main thread collects/regroups results
    time.sleep(2)

    fatal = coordinator.get_fatal_error()
    coordinator.shutdown()
    background.join(timeout=2.0)

    assert fatal is None

0 commit comments

Comments (0)