Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions lib/zephyr/src/zephyr/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,10 +993,6 @@ def set_chunk_config(self, prefix: str, execution_id: str) -> None:
self._chunk_prefix = prefix
self._execution_id = execution_id

def start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: bool = False) -> None:
"""Load a new stage's tasks into the queue (legacy compat)."""
self._start_stage(stage_name, tasks, is_last_stage=is_last_stage)

def check_heartbeats(self, timeout: float = 120.0) -> None:
"""Marks stale workers as FAILED, re-queues their in-flight tasks."""
with self._lock:
Expand Down
26 changes: 13 additions & 13 deletions lib/zephyr/tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def test_status_reports_alive_workers_not_total(actor_context, tmp_path):
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])

# Register 3 workers
for i in range(3):
Expand Down Expand Up @@ -371,7 +371,7 @@ def test_no_duplicate_results_on_heartbeat_timeout(actor_context, tmp_path):
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])

# Worker A pulls task (attempt 0)
pulled = coord.pull_task("worker-A")
Expand Down Expand Up @@ -435,7 +435,7 @@ def test_coordinator_accepts_winner_ignores_stale(actor_context, tmp_path):
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])

# Worker A pulls task (attempt 0)
pulled_a = coord.pull_task("worker-A")
Expand Down Expand Up @@ -520,7 +520,7 @@ def test_report_error_requeues_until_max_shard_failures(actor_context, tmp_path)
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])
coord.register_worker("worker-0", MagicMock())

# Each failure should re-queue until the limit
Expand Down Expand Up @@ -553,7 +553,7 @@ def test_heartbeat_timeouts_do_not_count_toward_shard_failures(actor_context, tm
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])
coord.register_worker("worker-0", MagicMock())

# Far more heartbeat timeouts than MAX_SHARD_FAILURES — must not abort.
Expand Down Expand Up @@ -586,7 +586,7 @@ def test_worker_reregistration_does_not_count_toward_shard_failures(actor_contex
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])
coord.register_worker("worker-0", MagicMock())

for _ in range(MAX_SHARD_FAILURES * 5):
Expand All @@ -613,7 +613,7 @@ def test_report_error_still_aborts_at_max_shard_failures_after_preemptions(actor
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])
coord.register_worker("worker-0", MagicMock())

# Several preemption cycles first — these must not count.
Expand Down Expand Up @@ -649,7 +649,7 @@ def test_wait_for_stage_fails_when_all_workers_die(actor_context, tmp_path):
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])

# Register 2 workers
coord.register_worker("worker-0", MagicMock())
Expand Down Expand Up @@ -682,7 +682,7 @@ def test_wait_for_stage_resets_dead_timer_on_recovery(actor_context, tmp_path):
operations=[],
stage_name="test",
)
coord.start_stage("test", [task])
coord._start_stage("test", [task])

# Register and kill a worker
coord.register_worker("worker-0", MagicMock())
Expand Down Expand Up @@ -838,7 +838,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
)

# Non-last stage: empty queue returns None
coord.start_stage("stage-0", [task], is_last_stage=False)
coord._start_stage("stage-0", [task], is_last_stage=False)
pulled = coord.pull_task("worker-A")
assert pulled is not None and pulled != "SHUTDOWN"
_task, attempt, _config = pulled
Expand All @@ -856,7 +856,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
operations=[],
stage_name="test-last",
)
coord.start_stage("stage-1", [task2], is_last_stage=True)
coord._start_stage("stage-1", [task2], is_last_stage=True)
pulled = coord.pull_task("worker-A")
assert pulled is not None and pulled != "SHUTDOWN"
_task, attempt, _config = pulled
Expand All @@ -871,7 +871,7 @@ def test_pull_task_returns_shutdown_on_last_stage_empty_queue(actor_context, tmp
ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test-last2")
for i in range(2)
]
coord.start_stage("stage-2", tasks_2, is_last_stage=True)
coord._start_stage("stage-2", tasks_2, is_last_stage=True)
coord.pull_task("worker-A") # task 0 in-flight
# Queue has one task left; worker-B takes it
coord.pull_task("worker-B") # task 1 in-flight
Expand All @@ -889,7 +889,7 @@ def test_last_stage_deadlock_detected_when_worker_job_dies(actor_context, tmp_pa
ShardTask(shard_idx=i, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test")
for i in range(2)
]
coord.start_stage("last-stage", tasks, is_last_stage=True)
coord._start_stage("last-stage", tasks, is_last_stage=True)

# Set up a mock worker group so _check_worker_group can query it.
mock_group = MagicMock()
Expand Down
6 changes: 3 additions & 3 deletions lib/zephyr/tests/test_worker_group_race.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_check_worker_group_skips_after_completed_stage(coordinator):
coordinator.set_worker_group(mock_group)

task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test")
coordinator.start_stage("last-stage", [task], is_last_stage=True)
coordinator._start_stage("last-stage", [task], is_last_stage=True)
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())

assert coordinator._completed_shards >= coordinator._total_shards
Expand All @@ -56,7 +56,7 @@ def test_check_worker_group_still_aborts_mid_stage(coordinator):
coordinator.set_worker_group(mock_group)

task = ShardTask(shard_idx=0, total_shards=2, shard=ListShard(refs=[]), operations=[], stage_name="test")
coordinator.start_stage("mid-stage", [task, task], is_last_stage=False)
coordinator._start_stage("mid-stage", [task, task], is_last_stage=False)
# Only 1 of 2 shards completed
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())

Expand All @@ -80,7 +80,7 @@ def is_done_with_delay():
coordinator.set_worker_group(mock_group)

task = ShardTask(shard_idx=0, total_shards=1, shard=ListShard(refs=[]), operations=[], stage_name="test")
coordinator.start_stage("last-stage", [task], is_last_stage=True)
coordinator._start_stage("last-stage", [task], is_last_stage=True)
coordinator.report_result("worker-0", 0, 0, TaskResult(shard=ListShard(refs=[])), CounterSnapshot.empty())

t = threading.Thread(target=coordinator._coordinator_loop, daemon=True)
Expand Down
Loading