11# Copyright The Marin Authors
22# SPDX-License-Identifier: Apache-2.0
33
4- """Reproduction for issue #4117: _check_worker_group aborts after last stage completes.
4+ """Regression tests for issue #4117: _check_worker_group false abort.
55
6- The race:
6+ The race (before fix):
771. Last stage completes — all shards done, workers get SHUTDOWN, exit
882. Main thread is in _collect_results / _regroup_result_refs (between
99 _wait_for_stage returning and self.shutdown())
10103. Background coordinator loop calls _check_worker_group
11114. worker_group.is_done() returns True (workers exited cleanly!)
12125. Coordinator calls abort("Worker job terminated permanently...")
13- 6. Next _wait_for_stage (or the caller reading results) sees fatal_error
1413
15- _check_worker_group does NOT check whether the stage actually completed.
16- It unconditionally treats is_done()==True as a crash.
14+ Fix: _check_worker_group skips when all shards are completed.
1715"""
1816
1917from __future__ import annotations
@@ -34,50 +32,48 @@ def coordinator(actor_context, tmp_path):
3432 coord .shutdown ()
3533
3634
37- def test_check_worker_group_aborts_after_completed_stage (coordinator ):
38- """Reproduce: worker group finishing after last stage triggers false abort.
39-
40- This is the exact race from issue #4117. The coordinator background loop
41- sees worker_group.is_done() == True and aborts, even though the stage
42- completed successfully.
43- """
44- # Set up a worker group that reports is_done=True (workers exited after SHUTDOWN)
35+ def test_check_worker_group_skips_after_completed_stage (coordinator ):
36+ """Worker group finishing after completed stage must not abort. #4117."""
4537 mock_group = MagicMock ()
4638 mock_group .is_done .return_value = True
4739 coordinator .set_worker_group (mock_group )
4840
49- # Simulate a completed stage: 1 task, 1 completed
5041 task = ShardTask (shard_idx = 0 , total_shards = 1 , shard = ListShard (refs = []), operations = [], stage_name = "test" )
5142 coordinator .start_stage ("last-stage" , [task ], is_last_stage = True )
52-
53- # Simulate task completion (worker finished before exiting)
5443 coordinator .report_result ("worker-0" , 0 , 0 , TaskResult (shard = ListShard (refs = [])))
5544
56- # Verify stage is complete
5745 assert coordinator ._completed_shards >= coordinator ._total_shards
5846
59- # Now _check_worker_group fires — this is the bug
6047 coordinator ._check_worker_group ()
6148
62- fatal = coordinator .get_fatal_error ()
63- # BUG: fatal_error is set even though the stage completed successfully
64- assert fatal is not None , "Bug not triggered — _check_worker_group should have aborted"
65- assert "Worker job terminated permanently" in fatal
66- print (f"\n BUG REPRODUCED: { fatal !r} " )
67- print ("Workers exited cleanly after SHUTDOWN, but coordinator treated it as a crash." )
49+ assert coordinator .get_fatal_error () is None
50+
51+
52+ def test_check_worker_group_still_aborts_mid_stage (coordinator ):
53+ """Worker group dying while shards are still in-flight must abort."""
54+ mock_group = MagicMock ()
55+ mock_group .is_done .return_value = True
56+ coordinator .set_worker_group (mock_group )
57+
58+ task = ShardTask (shard_idx = 0 , total_shards = 2 , shard = ListShard (refs = []), operations = [], stage_name = "test" )
59+ coordinator .start_stage ("mid-stage" , [task , task ], is_last_stage = False )
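60+ # Only 1 of 2 shards is reported complete, so the stage is still in flight (NOTE(review): both list entries are the same shard_idx=0 task — confirm this is intended)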
61+ coordinator .report_result ("worker-0" , 0 , 0 , TaskResult (shard = ListShard (refs = [])))
62+
63+ coordinator ._check_worker_group ()
64+
65+ assert coordinator .get_fatal_error () is not None
66+ assert "Worker job terminated permanently" in coordinator .get_fatal_error ()
6867
6968
70- def test_check_worker_group_aborts_during_result_collection (coordinator ):
71- """Reproduce the full race: background loop fires during post-stage processing."""
69+ def test_coordinator_loop_no_abort_during_result_collection (coordinator ):
70+ """Background loop must not abort during post-stage result collection. #4117."""
7271 mock_group = MagicMock ()
73- # Workers alive during stage, then exit after SHUTDOWN
7472 call_count = 0
7573
7674 def is_done_with_delay ():
7775 nonlocal call_count
7876 call_count += 1
79- # First few calls: workers still running
80- # After that: workers have exited
8177 return call_count > 2
8278
8379 mock_group .is_done .side_effect = is_done_with_delay
@@ -87,19 +83,14 @@ def is_done_with_delay():
8783 coordinator .start_stage ("last-stage" , [task ], is_last_stage = True )
8884 coordinator .report_result ("worker-0" , 0 , 0 , TaskResult (shard = ListShard (refs = [])))
8985
90- # Start coordinator background loop
9186 t = threading .Thread (target = coordinator ._coordinator_loop , daemon = True )
9287 t .start ()
9388
94- # Simulate the main thread doing post-stage work (collect_results, regroup, etc.)
95- # During this time, the background loop keeps calling _check_worker_group
89+ # Simulate the post-stage window where the main thread collects/regroups results
9690 time .sleep (2 )
9791
9892 fatal = coordinator .get_fatal_error ()
9993 coordinator .shutdown ()
10094 t .join (timeout = 2.0 )
10195
102- assert fatal is not None , "Bug not triggered — expected abort during result collection window"
103- assert "Worker job terminated permanently" in fatal
104- print (f"\n BUG REPRODUCED: { fatal !r} " )
105- print ("Background loop aborted during post-stage result collection." )
96+ assert fatal is None
0 commit comments