Skip to content

Commit 43282cc

Browse files
authored
[Zephyr] Remove legacy start_stage method (#5145)
`ZephyrCoordinator::start_stage` is only used in tests (and commented as legacy). Remove `start_stage` and replace its usage with the internal `_start_stage` method throughout the test suite.
1 parent 0497c64 commit 43282cc

3 files changed

Lines changed: 16 additions & 20 deletions

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -993,10 +993,6 @@ def set_chunk_config(self, prefix: str, execution_id: str) -> None:
993993
self._chunk_prefix = prefix
994994
self._execution_id = execution_id
995995

996-
def start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
997-
"""Load a new stage's tasks into the queue (legacy compat)."""
998-
self._start_stage(stage_name, tasks, is_last_stage=is_last_stage)
999-
1000996
def check_heartbeats(self, timeout: float = 120.0) -> None:
1001997
"""Marks stale workers as FAILED, re-queues their in-flight tasks."""
1002998
with self._lock:

lib/zephyr/tests/test_execution.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def test_status_reports_alive_workers_not_total(actor_context, tmp_path):
307307
operations=[],
308308
stage_name="test",
309309
)
310-
coord.start_stage("test", [task])
310+
coord._start_stage("test", [task])
311311

312312
# Register 3 workers
313313
for i in range(3):
@@ -371,7 +371,7 @@ def test_no_duplicate_results_on_heartbeat_timeout(actor_context, tmp_path):
371371
operations=[],
372372
stage_name="test",
373373
)
374-
coord.start_stage("test", [task])
374+
coord._start_stage("test", [task])
375375

376376
# Worker A pulls task (attempt 0)
377377
pulled = coord.pull_task("worker-A")
@@ -435,7 +435,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path):
435435
operations=[],
436436
stage_name="test",
437437
)
438-
coord.start_stage("test", [task])
438+
coord._start_stage("test", [task])
439439

440440
# Worker A pulls task (attempt 0)
441441
pulled_a = coord.pull_task("worker-A")
@@ -520,7 +520,7 @@ def test_report_error_requeues_until_max_shard_failures(actor_context, tmp_path)
520520
operations=[],
521521
stage_name="test",
522522
)
523-
coord.start_stage("test", [task])
523+
coord._start_stage("test", [task])
524524
coord.register_worker("worker-0", MagicMock())
525525

526526
# Each failure should re-queue until the limit
@@ -553,7 +553,7 @@ def test_heartbeat_timeouts_do_not_count_toward_shard_failures(actor_context, tm
553553
operations=[],
554554
stage_name="test",
555555
)
556-
coord.start_stage("test", [task])
556+
coord._start_stage("test", [task])
557557
coord.register_worker("worker-0", MagicMock())
558558

559559
# Far more heartbeat timeouts than MAX_SHARD_FAILURES — must not abort.
@@ -586,7 +586,7 @@ def test_worker_reregistration_does_not_count_toward_shard_failures(actor_contex
586586
operations=[],
587587
stage_name="test",
588588
)
589-
coord.start_stage("test", [task])
589+
coord._start_stage("test", [task])
590590
coord.register_worker("worker-0", MagicMock())
591591

592592
for _ in range(MAX_SHARD_FAILURES * 5):
@@ -613,7 +613,7 @@ def test_report_error_still_aborts_at_max_shard_failures_after_preemptions(actor
613613
operations=[],
614614
stage_name="test",
615615
)
616-
coord.start_stage("test", [task])
616+
coord._start_stage("test", [task])
617617
coord.register_worker("worker-0", MagicMock())
618618

619619
# Several preemption cycles first — these must not count.
@@ -649,7 +649,7 @@ def test_wait_for_stage_fails_when_all_workers_die(actor_context, tmp_path):
649649
operations=[],
650650
stage_name="test",
651651
)
652-
coord.start_stage("test", [task])
652+
coord._start_stage("test", [task])
653653

654654
# Register 2 workers
655655
coord.register_worker("worker-0", MagicMock())
@@ -682,7 +682,7 @@ def test_wait_for_stage_resets_dead_timer_on_recovery(actor_context, tmp_path):
682682
operations=[],
683683
stage_name="test",
684684
)
685-
coord.start_stage("test", [task])
685+
coord._start_stage("test", [task])
686686

687687
# Register and kill a worker
688688
coord.register_worker("worker-0", MagicMock())
@@ -838,7 +838,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
838838
)
839839

840840
# Non-last stage: empty queue returns None
841-
coord.start_stage("stage-0", [task], is_last_stage=False)
841+
coord._start_stage("stage-0", [task], is_last_stage=False)
842842
pulled = coord.pull_task("worker-A")
843843
assert pulled is not None and pulled != "SHUTDOWN"
844844
_task, attempt, _config = pulled
@@ -856,7 +856,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
856856
operations=[],
857857
stage_name="test-last",
858858
)
859-
coord.start_stage("stage-1", [task2], is_last_stage=True)
859+
coord._start_stage("stage-1", [task2], is_last_stage=True)
860860
pulled = coord.pull_task("worker-A")
861861
assert pulled is not None and pulled != "SHUTDOWN"
862862
_task, attempt, _config = pulled
@@ -871,7 +871,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
871871
ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test-last2")
872872
for i in range(2)
873873
]
874-
coord.start_stage("stage-2", tasks_2, is_last_stage=True)
874+
coord._start_stage("stage-2", tasks_2, is_last_stage=True)
875875
coord.pull_task("worker-A") # task 0 in-flight
876876
# Queue has one task left; worker-B takes it
877877
coord.pull_task("worker-B") # task 1 in-flight
@@ -889,7 +889,7 @@ def test_last_stage_deadlock_detected_when_worker_job_dies(actor_context, tmp_pa
889889
ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test")
890890
for i in range(2)
891891
]
892-
coord.start_stage("last-stage", tasks, is_last_stage=True)
892+
coord._start_stage("last-stage", tasks, is_last_stage=True)
893893

894894
# Set up a mock worker group so _check_worker_group can query it.
895895
mock_group = MagicMock()

lib/zephyr/tests/test_worker_group_race.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_check_worker_group_skips_after_completed_stage(coordinator):
3939
coordinator.set_worker_group(mock_group)
4040

4141
task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test")
42-
coordinator.start_stage("last-stage", [task], is_last_stage=True)
42+
coordinator._start_stage("last-stage", [task], is_last_stage=True)
4343
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())
4444

4545
assert coordinator._completed_shards >= coordinator._total_shards
@@ -56,7 +56,7 @@ def test_check_worker_group_still_aborts_mid_stage(coordinator):
5656
coordinator.set_worker_group(mock_group)
5757

5858
task = ShardTask(shard_idx=0, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test")
59-
coordinator.start_stage("mid-stage", [task, task], is_last_stage=False)
59+
coordinator._start_stage("mid-stage", [task, task], is_last_stage=False)
6060
# Only 1 of 2 shards completed
6161
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())
6262

@@ -80,7 +80,7 @@ def is_done_with_delay():
8080
coordinator.set_worker_group(mock_group)
8181

8282
task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test")
83-
coordinator.start_stage("last-stage", [task], is_last_stage=True)
83+
coordinator._start_stage("last-stage", [task], is_last_stage=True)
8484
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())
8585

8686
t = threading.Thread(target=coordinator._coordinator_loop, daemon=True)

0 commit comments

Comments
 (0)