
Commit 0fd80fd

[zephyr] Fix _check_worker_group false abort after completed stage
_check_worker_group now skips when all shards are completed. Previously it unconditionally treated worker_group.is_done()==True as a crash, even when workers exited cleanly after receiving SHUTDOWN on the last stage. This caused flaky failures on the Iris backend where is_done() checks real job status (Local/Ray hardcode it to False).
1 parent 2fff66f commit 0fd80fd
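
Why only Iris was affected: the bug needs a backend whose is_done() can actually return True. Below is a minimal illustrative sketch of that contrast; the class names and the fake job object are assumptions, and only the hardcoded-False vs. real-status behavior comes from the commit message.

```python
from dataclasses import dataclass


class LocalWorkerGroup:
    """Local/Ray-style group (assumed shape): is_done() is hardcoded."""

    def is_done(self) -> bool:
        # Termination is never reported, so _check_worker_group never
        # fires and the race stays invisible on these backends.
        return False


@dataclass
class FakeIrisJob:
    finished: bool  # stand-in for real Iris job status


class IrisWorkerGroup:
    """Iris-style group (assumed shape): is_done() reflects job status."""

    def __init__(self, job: FakeIrisJob):
        self._job = job

    def is_done(self) -> bool:
        # True once workers exit, including a clean exit after receiving
        # SHUTDOWN on the last stage, which the coordinator used to
        # misread as a crash.
        return self._job.finished


assert LocalWorkerGroup().is_done() is False
assert IrisWorkerGroup(FakeIrisJob(finished=True)).is_done() is True
```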

2 files changed: 32 additions & 36 deletions

lib/zephyr/src/zephyr/execution.py: 5 additions & 0 deletions
```diff
@@ -453,6 +453,11 @@ def _check_worker_group(self) -> None:
         """Abort the pipeline if the worker job has permanently terminated."""
         if self._worker_group is None or self._fatal_error is not None:
             return
+        # After the last stage completes, workers exit cleanly via SHUTDOWN.
+        # The worker job finishing at that point is expected, not a crash.
+        with self._lock:
+            if self._total_shards > 0 and self._completed_shards >= self._total_shards:
+                return
         try:
             if self._worker_group.is_done():
                 self.abort(
```
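
To see the guard in isolation, here is a self-contained sketch; MiniCoordinator is a hypothetical stand-in for the real class, keeping only the fields the new check touches.

```python
import threading


class MiniCoordinator:
    """Hypothetical reduction of the coordinator to the guarded check."""

    def __init__(self, total_shards: int):
        self._lock = threading.Lock()
        self._total_shards = total_shards
        self._completed_shards = 0
        self._fatal_error: str | None = None

    def report_result(self) -> None:
        with self._lock:
            self._completed_shards += 1

    def check_worker_group(self, worker_group_done: bool) -> None:
        # Mirrors the fixed logic: a finished worker group is only fatal
        # while shards are still in flight.
        with self._lock:
            if self._total_shards > 0 and self._completed_shards >= self._total_shards:
                return
        if worker_group_done:
            self._fatal_error = "Worker job terminated permanently"


coord = MiniCoordinator(total_shards=1)
coord.report_result()                              # last shard completes
coord.check_worker_group(worker_group_done=True)   # workers exited via SHUTDOWN
assert coord._fatal_error is None                  # no false abort

mid = MiniCoordinator(total_shards=2)
mid.report_result()                                # only 1 of 2 shards done
mid.check_worker_group(worker_group_done=True)     # group died mid-stage
assert mid._fatal_error is not None                # still aborts, as intended
```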
Regression test file: 27 additions & 36 deletions
```diff
@@ -1,19 +1,17 @@
 # Copyright The Marin Authors
 # SPDX-License-Identifier: Apache-2.0
 
-"""Reproduction for issue #4117: _check_worker_group aborts after last stage completes.
+"""Regression tests for issue #4117: _check_worker_group false abort.
 
-The race:
+The race (before fix):
 1. Last stage completes — all shards done, workers get SHUTDOWN, exit
 2. Main thread is in _collect_results / _regroup_result_refs (between
    _wait_for_stage returning and self.shutdown())
 3. Background coordinator loop calls _check_worker_group
 4. worker_group.is_done() returns True (workers exited cleanly!)
 5. Coordinator calls abort("Worker job terminated permanently...")
-6. Next _wait_for_stage (or the caller reading results) sees fatal_error
 
-_check_worker_group does NOT check whether the stage actually completed.
-It unconditionally treats is_done()==True as a crash.
+Fix: _check_worker_group skips when all shards are completed.
 """
 
 from __future__ import annotations
@@ -34,50 +32,48 @@ def coordinator(actor_context, tmp_path):
     coord.shutdown()
 
 
-def test_check_worker_group_aborts_after_completed_stage(coordinator):
-    """Reproduce: worker group finishing after last stage triggers false abort.
-
-    This is the exact race from issue #4117. The coordinator background loop
-    sees worker_group.is_done() == True and aborts, even though the stage
-    completed successfully.
-    """
-    # Set up a worker group that reports is_done=True (workers exited after SHUTDOWN)
+def test_check_worker_group_skips_after_completed_stage(coordinator):
+    """Worker group finishing after completed stage must not abort. #4117."""
     mock_group = MagicMock()
     mock_group.is_done.return_value = True
     coordinator.set_worker_group(mock_group)
 
-    # Simulate a completed stage: 1 task, 1 completed
     task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test")
     coordinator.start_stage("last-stage", [task], is_last_stage=True)
-
-    # Simulate task completion (worker finished before exiting)
     coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))
 
-    # Verify stage is complete
     assert coordinator._completed_shards >= coordinator._total_shards
 
-    # Now _check_worker_group fires — this is the bug
    coordinator._check_worker_group()
 
-    fatal = coordinator.get_fatal_error()
-    # BUG: fatal_error is set even though the stage completed successfully
-    assert fatal is not None, "Bug not triggered — _check_worker_group should have aborted"
-    assert "Worker job terminated permanently" in fatal
-    print(f"\nBUG REPRODUCED: {fatal!r}")
-    print("Workers exited cleanly after SHUTDOWN, but coordinator treated it as a crash.")
+    assert coordinator.get_fatal_error() is None
+
+
+def test_check_worker_group_still_aborts_mid_stage(coordinator):
+    """Worker group dying while shards are still in-flight must abort."""
+    mock_group = MagicMock()
+    mock_group.is_done.return_value = True
+    coordinator.set_worker_group(mock_group)
+
+    task = ShardTask(shard_idx=0, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test")
+    coordinator.start_stage("mid-stage", [task, task], is_last_stage=False)
+    # Only 1 of 2 shards completed
+    coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))
+
+    coordinator._check_worker_group()
+
+    assert coordinator.get_fatal_error() is not None
+    assert "Worker job terminated permanently" in coordinator.get_fatal_error()
 
 
-def test_check_worker_group_aborts_during_result_collection(coordinator):
-    """Reproduce the full race: background loop fires during post-stage processing."""
+def test_coordinator_loop_no_abort_during_result_collection(coordinator):
+    """Background loop must not abort during post-stage result collection. #4117."""
     mock_group = MagicMock()
-    # Workers alive during stage, then exit after SHUTDOWN
     call_count = 0
 
     def is_done_with_delay():
         nonlocal call_count
         call_count += 1
-        # First few calls: workers still running
-        # After that: workers have exited
         return call_count > 2
 
     mock_group.is_done.side_effect = is_done_with_delay
@@ -87,19 +83,14 @@ def is_done_with_delay():
     coordinator.start_stage("last-stage", [task], is_last_stage=True)
     coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])))
 
-    # Start coordinator background loop
     t = threading.Thread(target=coordinator._coordinator_loop, daemon=True)
     t.start()
 
-    # Simulate the main thread doing post-stage work (collect_results, regroup, etc.)
-    # During this time, the background loop keeps calling _check_worker_group
+    # Simulate the post-stage window where main thread collects/regroups results
     time.sleep(2)
 
     fatal = coordinator.get_fatal_error()
     coordinator.shutdown()
     t.join(timeout=2.0)
 
-    assert fatal is not None, "Bug not triggered — expected abort during result collection window"
-    assert "Worker job terminated permanently" in fatal
-    print(f"\nBUG REPRODUCED: {fatal!r}")
-    print("Background loop aborted during post-stage result collection.")
+    assert fatal is None
```
