Skip to content

Commit 4584c5d

Browse files
committed
Add a per-worker logging thread
1 parent c5e7ded commit 4584c5d

3 files changed

Lines changed: 82 additions & 7 deletions

File tree

lib/zephyr/src/zephyr/execution.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@
6161
# unbounded. `_check_worker_group` backstops if workers fully exhaust Iris retries.
6262
MAX_SHARD_FAILURES = 3
6363

64+
ZEPHYR_STAGE_ITEM_COUNT_KEY = "zephyr/stage/{stage_name}/item_count"
65+
ZEPHYR_STAGE_BYTES_PROCESSED_KEY = "zephyr/stage/{stage_name}/bytes_processed"
66+
6467

6568
class ShardFailureKind(enum.StrEnum):
6669
"""TASK failures count toward MAX_SHARD_FAILURES; INFRA failures (preemption) do not."""
@@ -496,8 +499,8 @@ def _log_status(self) -> None:
496499
dead = sum(1 for s in states if s in {WorkerState.FAILED, WorkerState.DEAD})
497500

498501
totals = self.get_counters()
499-
items = totals.get(f"zephyr/stage/{self._stage_name}/item_count", 0)
500-
bytes_processed = totals.get(f"zephyr/stage/{self._stage_name}/bytes_processed", 0)
502+
items = totals.get(ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=self._stage_name), 0)
503+
bytes_processed = totals.get(ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=self._stage_name), 0)
501504
elapsed = time.monotonic() - (
502505
self._stage_monotonic_start if self._stage_monotonic_start is not None else float("inf")
503506
)

lib/zephyr/src/zephyr/subprocess_worker.py

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import re
2020
import sys
2121
import threading
22+
import time
2223
import traceback
2324
from contextlib import suppress
2425
from typing import Any, TypeVar
@@ -30,6 +31,8 @@
3031
from zephyr import counters
3132
from zephyr.execution import (
3233
CounterSnapshot,
34+
ZEPHYR_STAGE_BYTES_PROCESSED_KEY,
35+
ZEPHYR_STAGE_ITEM_COUNT_KEY,
3336
_shared_data_path,
3437
_worker_ctx_var,
3538
_write_stage_output,
@@ -57,8 +60,11 @@ def __init__(self, stage_name: str) -> None:
5760

5861
def wrap(self, gen: Iterator[T]) -> Iterator[T]:
    """Pass items through from *gen*, bumping this stage's counters per item.

    Increments the stage-scoped ``item_count`` by one and ``bytes_processed``
    by the item's shallow size for every item yielded.
    """
    # Keys are loop-invariant; format them once instead of per item.
    item_counter = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=self._stage_name)
    byte_counter = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=self._stage_name)
    for element in gen:
        counters.increment(item_counter, 1)
        # NOTE: sys.getsizeof is shallow — it does not follow references.
        counters.increment(byte_counter, sys.getsizeof(element))
        yield element
6369

6470

@@ -123,6 +129,45 @@ def _periodic_counter_writer(
123129
logger.warning("Failed to flush counter file to %s", counter_file, exc_info=True)
124130

125131

132+
def _periodic_status_logger(
    stop_event: threading.Event,
    ctx: _SubprocessWorkerContext,
    stage_name: str,
    execution_id: str,
    shard_idx: int,
    total_shards: int,
    monotonic_start: float,
    interval: float,
) -> None:
    """Log ``item_count`` / ``bytes_processed`` rates on a fixed interval (cf. coordinator ``_log_status``).

    Runs in a dedicated daemon thread so logs are attributed to that thread name.
    Reads ``ctx._counters`` the same way as the counter flusher (shallow copy).
    """
    item_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=stage_name)
    byte_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=stage_name)
    mib = 1024 * 1024
    # wait() returns True once stop_event is set, ending the loop; otherwise
    # it times out after `interval` and we emit one status line.
    while not stop_event.wait(timeout=interval):
        # Bail out quietly if the interpreter is shutting down underneath us.
        if sys.is_finalizing():
            return
        item_total = ctx._counters.get(item_key, 0)
        byte_total = ctx._counters.get(byte_key, 0)
        elapsed = time.monotonic() - monotonic_start
        logger.info(
            "[%s] [%s] [%s] shard %d/%d; items=%d (%.1f/s), bytes_processed=%.1fMiB (%.1fMiB/s)",
            execution_id,
            stage_name,
            threading.current_thread().name,
            shard_idx,
            total_shards,
            item_total,
            item_total / elapsed,
            byte_total / mib,
            (byte_total / elapsed) / mib,
        )
169+
170+
126171
def execute_shard(task_file: str, result_file: str) -> None:
127172
"""Entry point for subprocess shard execution.
128173
@@ -153,6 +198,7 @@ def execute_shard(task_file: str, result_file: str) -> None:
153198
counter_file = f"{result_file}.counters"
154199
stop_event = threading.Event()
155200
flusher: threading.Thread | None = None
201+
status_logger: threading.Thread | None = None
156202
result_or_error: Any
157203
ctx: _SubprocessWorkerContext | None = None
158204
try:
@@ -162,6 +208,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
162208
ctx = _SubprocessWorkerContext(chunk_prefix, execution_id)
163209
_worker_ctx_var.set(ctx)
164210

211+
shard_monotonic_start = time.monotonic()
212+
165213
flusher = threading.Thread(
166214
target=_periodic_counter_writer,
167215
args=(stop_event, ctx, counter_file, SUBPROCESS_COUNTER_FLUSH_INTERVAL),
@@ -170,6 +218,23 @@ def execute_shard(task_file: str, result_file: str) -> None:
170218
)
171219
flusher.start()
172220

221+
status_logger = threading.Thread(
222+
target=_periodic_status_logger,
223+
args=(
224+
stop_event,
225+
ctx,
226+
task.stage_name,
227+
execution_id,
228+
task.shard_idx,
229+
task.total_shards,
230+
shard_monotonic_start,
231+
SUBPROCESS_COUNTER_FLUSH_INTERVAL,
232+
),
233+
daemon=True,
234+
name="zephyr-subprocess-status-logger",
235+
)
236+
status_logger.start()
237+
173238
stage_ctx = StageContext(
174239
shard=task.shard,
175240
shard_idx=task.shard_idx,
@@ -206,6 +271,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
206271
stop_event.set()
207272
if flusher is not None and flusher.is_alive():
208273
flusher.join(timeout=2.0)
274+
if status_logger is not None and status_logger.is_alive():
275+
status_logger.join(timeout=2.0)
209276

210277
with open(result_file, "wb") as f:
211278
counters_out = dict(ctx._counters) if ctx is not None else {}

lib/zephyr/tests/subprocess_worker_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
import cloudpickle
77

88
import zephyr.subprocess_worker as sw
9-
from zephyr.execution import ListShard, ShardTask
9+
from zephyr.execution import (
10+
ZEPHYR_STAGE_BYTES_PROCESSED_KEY,
11+
ZEPHYR_STAGE_ITEM_COUNT_KEY,
12+
ListShard,
13+
ShardTask,
14+
)
1015
from zephyr.shuffle import MemChunk
1116

1217

@@ -33,5 +38,5 @@ def test_execute_shard_sets_stage_scoped_output_counters(tmp_path):
3338
with open(result_file, "rb") as f:
3439
_result_or_error, counters_out = cloudpickle.load(f)
3540

36-
assert counters_out[f"zephyr/stage/{stage_name}/item_count"] == 10
37-
assert counters_out[f"zephyr/stage/{stage_name}/bytes_processed"] > 0
41+
assert counters_out[ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=stage_name)] == 10
42+
assert counters_out[ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=stage_name)] > 0

0 commit comments

Comments
 (0)