Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions lib/zephyr/src/zephyr/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import logging
import os
import pickle
import re
import signal
import sys
from datetime import datetime, timezone
Expand Down Expand Up @@ -62,6 +61,9 @@
# unbounded. `_check_worker_group` backstops if workers fully exhaust Iris retries.
MAX_SHARD_FAILURES = 3

ZEPHYR_STAGE_ITEM_COUNT_KEY = "zephyr/stage/{stage_name}/item_count"
ZEPHYR_STAGE_BYTES_PROCESSED_KEY = "zephyr/stage/{stage_name}/bytes_processed"


class ShardFailureKind(enum.StrEnum):
"""TASK failures count toward MAX_SHARD_FAILURES; INFRA failures (preemption) do not."""
Expand Down Expand Up @@ -378,6 +380,9 @@ def __init__(self):
self._initialized: bool = False
self._pipeline_running: bool = False

# Set at each _start_stage so _log_status can show average throughput since stage start.
self._stage_monotonic_start: float | None = None

# Lock for accessing coordinator state from background thread
self._lock = threading.Lock()

Expand Down Expand Up @@ -492,8 +497,19 @@ def _log_status(self) -> None:
retried = {idx: att for idx, att in self._task_attempts.items() if att > 0}
alive = sum(1 for s in states if s in {WorkerState.READY, WorkerState.BUSY})
dead = sum(1 for s in states if s in {WorkerState.FAILED, WorkerState.DEAD})

totals = self.get_counters()
items = totals.get(ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=self._stage_name), 0)
bytes_processed = totals.get(ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=self._stage_name), 0)
elapsed = time.monotonic() - (
self._stage_monotonic_start if self._stage_monotonic_start is not None else float("inf")
)
item_rate = items / elapsed
byte_rate = bytes_processed / elapsed

logger.info(
"[%s] [%s] %d/%d complete, %d in-flight, %d queued, %d/%d workers alive, %d dead",
"[%s] [%s] %d/%d complete, %d in-flight, %d queued, %d/%d workers alive, %d dead; "
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should also log these in each worker subprocess?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually had this during testing, but removed it once I had it logged in the controller. Added it back in a separate commit so it's easy to pull back out if we want to.

"items=%d (%.1f/s), bytes_processed=%.1fMiB (%.1fMiB/s)",
self._execution_id,
self._stage_name,
self._completed_shards,
Expand All @@ -503,6 +519,10 @@ def _log_status(self) -> None:
alive,
len(self._worker_handles),
dead,
items,
item_rate,
bytes_processed / (1024 * 1024),
byte_rate / (1024 * 1024),
)
if retried:
attempts_histogram = dict(sorted(Counter(retried.values()).items()))
Expand Down Expand Up @@ -768,6 +788,7 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
# Only reset in-flight worker snapshots; completed snapshots
# accumulate across stages for full pipeline visibility.
self._worker_counters = {}
self._stage_monotonic_start = time.monotonic()

def _wait_for_stage(self) -> None:
"""Block until current stage completes or error occurs."""
Expand Down Expand Up @@ -865,7 +886,7 @@ def run_pipeline(
aux_per_shard = self._compute_join_aux(stage.operations, shards, stage_idx)

# Build and submit tasks
tasks = _compute_tasks_from_shards(shards, stage, aux_per_shard, stage_name=stage_label)
tasks = _compute_tasks_from_shards(shards, stage, stage_name=stage_label, aux_per_shard=aux_per_shard)
logger.info("[%s] Starting stage %s with %d tasks", self._execution_id, stage_label, len(tasks))
self._start_stage(stage_label, tasks, is_last_stage=(stage_idx == last_worker_stage_idx))

Expand Down Expand Up @@ -1760,15 +1781,12 @@ def _build_source_shards(source_items: list[SourceItem]) -> list[Shard]:
def _compute_tasks_from_shards(
shard_refs: list[Shard],
stage,
stage_name: str,
aux_per_shard: list[dict[int, Shard]] | None = None,
stage_name: str | None = None,
) -> list[ShardTask]:
"""Convert shard references into ShardTasks for the coordinator."""
total = len(shard_refs)
tasks = []
# Sanitize for use as a path component: replace non-alphanumeric runs with '-'
raw_name = stage_name or stage.stage_name(max_length=60)
output_stage_name = re.sub(r"[^a-zA-Z0-9_.-]+", "-", raw_name).strip("-")
Comment thread
wmoss marked this conversation as resolved.

for i, shard in enumerate(shard_refs):
aux_shards = None
Expand All @@ -1781,7 +1799,7 @@ def _compute_tasks_from_shards(
total_shards=total,
shard=shard,
operations=stage.operations,
stage_name=output_stage_name,
stage_name=stage_name,
aux_shards=aux_shards,
)
)
Expand Down
97 changes: 93 additions & 4 deletions lib/zephyr/src/zephyr/subprocess_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,23 @@

import logging
import os
import re
import sys
import threading
import time
import traceback
from contextlib import suppress
from typing import Any
from typing import Any, TypeVar
from collections.abc import Iterator

import cloudpickle
from rigging.filesystem import open_url

from zephyr import counters
from zephyr.execution import (
CounterSnapshot,
ZEPHYR_STAGE_BYTES_PROCESSED_KEY,
ZEPHYR_STAGE_ITEM_COUNT_KEY,
_shared_data_path,
_worker_ctx_var,
_write_stage_output,
Expand All @@ -35,7 +41,6 @@

logger = logging.getLogger(__name__)


SUBPROCESS_COUNTER_FLUSH_INTERVAL = 5.0
"""How often the subprocess flushes its counter snapshot to the counter file.

Expand All @@ -44,6 +49,25 @@
"""


T = TypeVar("T")


class StatisticsGenerator:
    """Instruments an iterator, recording per-stage item and byte counters.

    Every yielded element bumps the stage's item-count counter by one and the
    bytes-processed counter by ``sys.getsizeof`` of the element (shallow size —
    contained objects are not followed).
    """

    def __init__(self, stage_name: str) -> None:
        self._stage_name = stage_name

    def wrap(self, gen: Iterator[T]) -> Iterator[T]:
        """Yield each element of ``gen`` unchanged, updating counters as it passes."""
        # Keys are fixed for the lifetime of this wrapper, so format them once.
        item_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=self._stage_name)
        byte_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=self._stage_name)
        for element in gen:
            counters.increment(item_key, 1)
            counters.increment(byte_key, sys.getsizeof(element))
            yield element


class _SubprocessWorkerContext:
"""Lightweight WorkerContext for subprocess shard execution.

Expand Down Expand Up @@ -105,6 +129,45 @@ def _periodic_counter_writer(
logger.warning("Failed to flush counter file to %s", counter_file, exc_info=True)


def _periodic_status_logger(
    stop_event: threading.Event,
    ctx: _SubprocessWorkerContext,
    stage_name: str,
    execution_id: str,
    shard_idx: int,
    total_shards: int,
    monotonic_start: float,
    interval: float,
) -> None:
    """Log ``item_count`` / ``bytes_processed`` rates on a fixed interval (cf. coordinator ``_log_status``).

    Runs in a dedicated daemon thread so logs are attributed to that thread name.
    Reads ``ctx._counters`` the same way as the counter flusher (shallow copy).
    """
    # Counter keys are stable for the shard's lifetime; compute them up front.
    items_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=stage_name)
    bytes_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=stage_name)
    mib = 1024 * 1024
    while not stop_event.wait(timeout=interval):
        # Interpreter teardown can race the daemon thread; exit quietly.
        if sys.is_finalizing():
            return
        item_total = ctx._counters.get(items_key, 0)
        byte_total = ctx._counters.get(bytes_key, 0)
        elapsed = time.monotonic() - monotonic_start
        logger.info(
            "[%s] [%s] [%s] shard %d/%d; items=%d (%.1f/s), bytes_processed=%.1fMiB (%.1fMiB/s)",
            execution_id,
            stage_name,
            threading.current_thread().name,
            shard_idx,
            total_shards,
            item_total,
            item_total / elapsed,
            byte_total / mib,
            byte_total / elapsed / mib,
        )


def execute_shard(task_file: str, result_file: str) -> None:
"""Entry point for subprocess shard execution.

Expand Down Expand Up @@ -135,6 +198,7 @@ def execute_shard(task_file: str, result_file: str) -> None:
counter_file = f"{result_file}.counters"
stop_event = threading.Event()
flusher: threading.Thread | None = None
status_logger: threading.Thread | None = None
result_or_error: Any
ctx: _SubprocessWorkerContext | None = None
try:
Expand All @@ -144,6 +208,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
ctx = _SubprocessWorkerContext(chunk_prefix, execution_id)
_worker_ctx_var.set(ctx)

shard_monotonic_start = time.monotonic()

flusher = threading.Thread(
target=_periodic_counter_writer,
args=(stop_event, ctx, counter_file, SUBPROCESS_COUNTER_FLUSH_INTERVAL),
Expand All @@ -152,19 +218,40 @@ def execute_shard(task_file: str, result_file: str) -> None:
)
flusher.start()

status_logger = threading.Thread(
target=_periodic_status_logger,
args=(
stop_event,
ctx,
task.stage_name,
execution_id,
task.shard_idx,
task.total_shards,
shard_monotonic_start,
SUBPROCESS_COUNTER_FLUSH_INTERVAL,
),
daemon=True,
name="zephyr-subprocess-status-logger",
)
status_logger.start()

stage_ctx = StageContext(
shard=task.shard,
shard_idx=task.shard_idx,
total_shards=task.total_shards,
aux_shards=task.aux_shards,
)

stage_dir = f"{chunk_prefix}/{execution_id}/{task.stage_name}"
# Sanitize for use as a path component: replace non-alphanumeric runs with '-'
output_stage_name = re.sub(r"[^a-zA-Z0-9_.-]+", "-", task.stage_name).strip("-")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved this into the subprocess worker so I have access to the original stage name. As far as I can tell, the stage name was previously only used here, so it's safe to do the transformation at this point.

stage_dir = f"{chunk_prefix}/{execution_id}/{output_stage_name}"
external_sort_dir = f"{stage_dir}-external-sort/shard-{task.shard_idx:04d}"
scatter_op = next((op for op in task.operations if isinstance(op, Scatter)), None)

output_counter = StatisticsGenerator(task.stage_name)

result_or_error = _write_stage_output(
run_stage(stage_ctx, task.operations, external_sort_dir=external_sort_dir),
output_counter.wrap(run_stage(stage_ctx, task.operations, external_sort_dir=external_sort_dir)),
source_shard=task.shard_idx,
stage_dir=stage_dir,
shard_idx=task.shard_idx,
Expand All @@ -184,6 +271,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
stop_event.set()
if flusher is not None and flusher.is_alive():
flusher.join(timeout=2.0)
if status_logger is not None and status_logger.is_alive():
status_logger.join(timeout=2.0)

with open(result_file, "wb") as f:
counters_out = dict(ctx._counters) if ctx is not None else {}
Expand Down
42 changes: 42 additions & 0 deletions lib/zephyr/tests/subprocess_worker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright The Marin Authors
# SPDX-License-Identifier: Apache-2.0

"""Tests for zephyr.subprocess_worker."""

import cloudpickle

import zephyr.subprocess_worker as sw
from zephyr.execution import (
ZEPHYR_STAGE_BYTES_PROCESSED_KEY,
ZEPHYR_STAGE_ITEM_COUNT_KEY,
ListShard,
ShardTask,
)
from zephyr.shuffle import MemChunk


def test_execute_shard_sets_stage_scoped_output_counters(tmp_path):
    """execute_shard emits output counters under the task's stage name."""
    stage_name = "test"
    execution_id = "test-exec"
    chunk_prefix = str(tmp_path / "chunks")
    task = ShardTask(
        shard_idx=0,
        total_shards=1,
        shard=ListShard(refs=[MemChunk(list(range(10)))]),
        operations=[],
        stage_name=stage_name,
    )

    task_path = tmp_path / "task.pkl"
    result_path = tmp_path / "result.pkl"
    task_path.write_bytes(cloudpickle.dumps((task, chunk_prefix, execution_id)))

    sw.execute_shard(str(task_path), str(result_path))

    _result_or_error, counters_out = cloudpickle.loads(result_path.read_bytes())

    item_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=stage_name)
    byte_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=stage_name)
    # Ten list items should flow through the stage output counter.
    assert counters_out[item_key] == 10
    assert counters_out[byte_key] > 0
Loading