diff --git a/lib/zephyr/src/zephyr/execution.py b/lib/zephyr/src/zephyr/execution.py index 215e7a115c..8a099849a0 100644 --- a/lib/zephyr/src/zephyr/execution.py +++ b/lib/zephyr/src/zephyr/execution.py @@ -993,10 +993,6 @@ def set_chunk_config(self, prefix: str, execution_id: str) -> None: self._chunk_prefix = prefix self._execution_id = execution_id - def start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None: - """Load a new stage's tasks into the queue (legacy compat).""" - self._start_stage(stage_name, tasks, is_last_stage=is_last_stage) - def check_heartbeats(self, timeout: float = 120.0) -> None: """Marks stale workers as FAILED, re-queues their in-flight tasks.""" with self._lock: diff --git a/lib/zephyr/tests/test_execution.py b/lib/zephyr/tests/test_execution.py index a2a48f87be..68d92386b8 100644 --- a/lib/zephyr/tests/test_execution.py +++ b/lib/zephyr/tests/test_execution.py @@ -307,7 +307,7 @@ def test_status_reports_alive_workers_not_total(actor_context, tmp_path): operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) # Register 3 workers for i in range(3): @@ -371,7 +371,7 @@ def test_no_duplicate_results_on_heartbeat_timeout(actor_context, tmp_path): operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) # Worker A pulls task (attempt 0) pulled = coord.pull_task("worker-A") @@ -435,7 +435,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path): operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) # Worker A pulls task (attempt 0) pulled_a = coord.pull_task("worker-A") @@ -520,7 +520,7 @@ def test_report_error_requeues_until_max_shard_failures(actor_context, tmp_path) operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) coord.register_worker("worker-0", MagicMock()) # Each failure should re-queue until the limit @@ -553,7 +553,7 @@ def test_heartbeat_timeouts_do_not_count_toward_shard_failures(actor_context, tm operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) coord.register_worker("worker-0", MagicMock()) # Far more heartbeat timeouts than MAX_SHARD_FAILURES — must not abort. @@ -586,7 +586,7 @@ def test_worker_reregistration_does_not_count_toward_shard_failures(actor_contex operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) coord.register_worker("worker-0", MagicMock()) for _ in range(MAX_SHARD_FAILURES * 5): @@ -613,7 +613,7 @@ def test_report_error_still_aborts_at_max_shard_failures_after_preemptions(actor operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) coord.register_worker("worker-0", MagicMock()) # Several preemption cycles first — these must not count. @@ -649,7 +649,7 @@ def test_wait_for_stage_fails_when_all_workers_die(actor_context, tmp_path): operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) # Register 2 workers coord.register_worker("worker-0", MagicMock()) @@ -682,7 +682,7 @@ def test_wait_for_stage_resets_dead_timer_on_recovery(actor_context, tmp_path): operations=[], stage_name="test", ) - coord.start_stage("test", [task]) + coord._start_stage("test", [task]) # Register and kill a worker coord.register_worker("worker-0", MagicMock()) @@ -838,7 +838,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp ) # Non-last stage: empty queue returns None - coord.start_stage("stage-0", [task], is_last_stage=False) + coord._start_stage("stage-0", [task], is_last_stage=False) pulled = coord.pull_task("worker-A") assert pulled is not None and pulled != "SHUTDOWN" _task, attempt, _config = pulled @@ -856,7 +856,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp operations=[], stage_name="test-last", ) - coord.start_stage("stage-1", [task2], is_last_stage=True) + coord._start_stage("stage-1", [task2], is_last_stage=True) pulled = coord.pull_task("worker-A") assert pulled is not None and pulled != "SHUTDOWN" _task, attempt, _config = pulled @@ -871,7 +871,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test-last2") for i in range(2) ] - coord.start_stage("stage-2", tasks_2, is_last_stage=True) + coord._start_stage("stage-2", tasks_2, is_last_stage=True) coord.pull_task("worker-A") # task 0 in-flight # Queue has one task left; worker-B takes it coord.pull_task("worker-B") # task 1 in-flight @@ -889,7 +889,7 @@ def test_last_stage_deadlock_detected_when_worker_job_dies(actor_context, tmp_pa ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test") for i in range(2) ] - coord.start_stage("last-stage", tasks, is_last_stage=True) + coord._start_stage("last-stage", tasks, is_last_stage=True) # Set up a mock worker group so _check_worker_group can query it. mock_group = MagicMock() diff --git a/lib/zephyr/tests/test_worker_group_race.py b/lib/zephyr/tests/test_worker_group_race.py index 3419a949f2..2392bfae3a 100644 --- a/lib/zephyr/tests/test_worker_group_race.py +++ b/lib/zephyr/tests/test_worker_group_race.py @@ -39,7 +39,7 @@ def test_check_worker_group_skips_after_completed_stage(coordinator): coordinator.set_worker_group(mock_group) task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test") - coordinator.start_stage("last-stage", [task], is_last_stage=True) + coordinator._start_stage("last-stage", [task], is_last_stage=True) coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty()) assert coordinator._completed_shards >= coordinator._total_shards @@ -56,7 +56,7 @@ def test_check_worker_group_still_aborts_mid_stage(coordinator): coordinator.set_worker_group(mock_group) task = ShardTask(shard_idx=0, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test") - coordinator.start_stage("mid-stage", [task, task], is_last_stage=False) + coordinator._start_stage("mid-stage", [task, task], is_last_stage=False) # Only 1 of 2 shards completed coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty()) @@ -80,7 +80,7 @@ def is_done_with_delay(): coordinator.set_worker_group(mock_group) task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test") - coordinator.start_stage("last-stage", [task], is_last_stage=True) + coordinator._start_stage("last-stage", [task], is_last_stage=True) coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty()) t = threading.Thread(target=coordinator._coordinator_loop, daemon=True)