1919import re
2020import sys
2121import threading
22+ import time
2223import traceback
2324from contextlib import suppress
2425from typing import Any , TypeVar
3031from zephyr import counters
3132from zephyr .execution import (
3233 CounterSnapshot ,
34+ ZEPHYR_STAGE_BYTES_PROCESSED_KEY ,
35+ ZEPHYR_STAGE_ITEM_COUNT_KEY ,
3336 _shared_data_path ,
3437 _worker_ctx_var ,
3538 _write_stage_output ,
@@ -57,8 +60,11 @@ def __init__(self, stage_name: str) -> None:
5760
def wrap(self, gen: Iterator[T]) -> Iterator[T]:
    """Yield items from *gen* unchanged while recording per-stage counters.

    For every item yielded, increments the stage's item-count counter by 1
    and its bytes-processed counter by ``sys.getsizeof(item)``.

    NOTE(review): ``sys.getsizeof`` is shallow — it does not follow
    references into contained objects, so the byte counter understates the
    true payload size for containers. Confirm that is acceptable for the
    dashboards consuming this counter.
    """
    # The counter keys depend only on the stage name, which is fixed for
    # the lifetime of this wrapper — format them once instead of per item.
    item_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=self._stage_name)
    byte_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=self._stage_name)
    for item in gen:
        counters.increment(item_key, 1)
        counters.increment(byte_key, sys.getsizeof(item))
        yield item
6369
6470
@@ -123,6 +129,45 @@ def _periodic_counter_writer(
123129 logger .warning ("Failed to flush counter file to %s" , counter_file , exc_info = True )
124130
125131
def _periodic_status_logger(
    stop_event: threading.Event,
    ctx: _SubprocessWorkerContext,
    stage_name: str,
    execution_id: str,
    shard_idx: int,
    total_shards: int,
    monotonic_start: float,
    interval: float,
) -> None:
    """Log ``item_count`` / ``bytes_processed`` rates on a fixed interval (cf. coordinator ``_log_status``).

    Runs in a dedicated daemon thread so logs are attributed to that thread name.
    Reads ``ctx._counters`` the same way as the counter flusher (shallow copy).

    Args:
        stop_event: Set by the parent to request shutdown; its ``wait`` also
            paces the loop, so no separate sleep is needed.
        ctx: Worker context whose ``_counters`` mapping is sampled each tick.
        stage_name: Stage whose counters are reported.
        execution_id: Execution identifier included in every log line.
        shard_idx: Index of this shard (logged as ``shard i/n``).
        total_shards: Total number of shards for the stage.
        monotonic_start: ``time.monotonic()`` timestamp captured when the
            shard started; rates are averaged over the whole elapsed time,
            not per interval.
        interval: Seconds between successive log lines.
    """
    # The counter keys are invariant for the lifetime of this thread.
    item_key = ZEPHYR_STAGE_ITEM_COUNT_KEY.format(stage_name=stage_name)
    byte_key = ZEPHYR_STAGE_BYTES_PROCESSED_KEY.format(stage_name=stage_name)
    while not stop_event.wait(timeout=interval):
        # Bail out quietly during interpreter shutdown; logging/handlers may
        # already be torn down.
        if sys.is_finalizing():
            return
        items = ctx._counters.get(item_key, 0)
        bytes_processed = ctx._counters.get(byte_key, 0)
        # Clamp elapsed time to a positive epsilon: with a very small
        # ``interval`` and a coarse monotonic clock, ``elapsed`` could be 0
        # and the rate computations below would raise ZeroDivisionError.
        elapsed = max(time.monotonic() - monotonic_start, 1e-9)
        item_rate = items / elapsed
        byte_rate = bytes_processed / elapsed
        logger.info(
            "[%s] [%s] [%s] shard %d/%d; items=%d (%.1f/s), bytes_processed=%.1fMiB (%.1fMiB/s)",
            execution_id,
            stage_name,
            threading.current_thread().name,
            shard_idx,
            total_shards,
            items,
            item_rate,
            bytes_processed / (1024 * 1024),
            byte_rate / (1024 * 1024),
        )
170+
126171def execute_shard (task_file : str , result_file : str ) -> None :
127172 """Entry point for subprocess shard execution.
128173
@@ -153,6 +198,7 @@ def execute_shard(task_file: str, result_file: str) -> None:
153198 counter_file = f"{ result_file } .counters"
154199 stop_event = threading .Event ()
155200 flusher : threading .Thread | None = None
201+ status_logger : threading .Thread | None = None
156202 result_or_error : Any
157203 ctx : _SubprocessWorkerContext | None = None
158204 try :
@@ -162,6 +208,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
162208 ctx = _SubprocessWorkerContext (chunk_prefix , execution_id )
163209 _worker_ctx_var .set (ctx )
164210
211+ shard_monotonic_start = time .monotonic ()
212+
165213 flusher = threading .Thread (
166214 target = _periodic_counter_writer ,
167215 args = (stop_event , ctx , counter_file , SUBPROCESS_COUNTER_FLUSH_INTERVAL ),
@@ -170,6 +218,23 @@ def execute_shard(task_file: str, result_file: str) -> None:
170218 )
171219 flusher .start ()
172220
221+ status_logger = threading .Thread (
222+ target = _periodic_status_logger ,
223+ args = (
224+ stop_event ,
225+ ctx ,
226+ task .stage_name ,
227+ execution_id ,
228+ task .shard_idx ,
229+ task .total_shards ,
230+ shard_monotonic_start ,
231+ SUBPROCESS_COUNTER_FLUSH_INTERVAL ,
232+ ),
233+ daemon = True ,
234+ name = "zephyr-subprocess-status-logger" ,
235+ )
236+ status_logger .start ()
237+
173238 stage_ctx = StageContext (
174239 shard = task .shard ,
175240 shard_idx = task .shard_idx ,
@@ -206,6 +271,8 @@ def execute_shard(task_file: str, result_file: str) -> None:
206271 stop_event .set ()
207272 if flusher is not None and flusher .is_alive ():
208273 flusher .join (timeout = 2.0 )
274+ if status_logger is not None and status_logger .is_alive ():
275+ status_logger .join (timeout = 2.0 )
209276
210277 with open (result_file , "wb" ) as f :
211278 counters_out = dict (ctx ._counters ) if ctx is not None else {}
0 commit comments