@@ -293,6 +293,8 @@ class ZephyrWorkerError(RuntimeError):
293293
class WorkerContext(Protocol):
    """Structural interface exposed to user code running inside a worker.

    Any object providing these three methods satisfies the protocol.
    """

    def get_shared(self, name: str) -> Any:
        """Return the shared object registered under ``name``."""
        ...

    def increment_counter(self, name: str, value: int = 1) -> None:
        """Add ``value`` (default 1) to the user-defined counter ``name``."""
        ...

    def get_counter_snapshot(self) -> dict[str, int]:
        """Return the current counter values as a name -> value mapping."""
        ...
296298
297299
298300_worker_ctx_var : ContextVar [ZephyrWorker | None ] = ContextVar ("zephyr_worker_ctx" , default = None )
@@ -351,6 +353,9 @@ def __init__(self):
351353 self ._chunk_prefix : str = ""
352354 self ._execution_id : str = ""
353355 self ._no_workers_timeout : float = 60.0
356+ # User-defined counters: in-flight per-worker snapshots and global accumulator.
357+ self ._worker_counters : dict [str , dict [str , int ]] = {}
358+ self ._global_counters : dict [str , int ] = {}
354359
355360 # Worker management state (workers self-register via register_worker)
356361 self ._worker_handles : dict [str , ActorHandle ] = {}
@@ -487,6 +492,9 @@ def _maybe_requeue_worker_task(self, worker_id: str) -> None:
487492 self ._task_attempts [task .shard_idx ] += 1
488493 self ._task_queue .append (task )
489494 self ._retries += 1
495+ # Discard in-flight counter snapshot so it doesn't double-count when the
496+ # shard is retried on another worker.
497+ self ._worker_counters .pop (worker_id , None )
490498
491499 def _check_worker_heartbeats (self , timeout : float = 120.0 ) -> None :
492500 """Internal heartbeat check (called with lock held)."""
@@ -565,20 +573,27 @@ def report_result(self, worker_id: str, shard_idx: int, attempt: int, result: Ta
565573 self ._completed_shards += 1
566574 self ._in_flight .pop (worker_id , None )
567575 self ._worker_states [worker_id ] = WorkerState .READY
576+ # Accumulate final counters for this task into the global total
577+ for name , value in self ._worker_counters .pop (worker_id , {}).items ():
578+ self ._global_counters [name ] = self ._global_counters .get (name , 0 ) + value
568579
def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
    """Record a task failure reported by a worker; every error is fatal.

    Under the coordinator lock this refreshes the worker's last-seen
    timestamp, checks in-flight bookkeeping for ``(worker_id, shard_idx)``,
    and then drops the worker's in-flight entry and its counter snapshot,
    stores ``error_info`` as the fatal error, and marks the worker DEAD.

    Args:
        worker_id: Identifier of the reporting worker.
        shard_idx: Index of the shard whose task failed.
        error_info: Error description stored as the fatal error.
    """
    with self._lock:
        # Liveness update and consistency check come first, so a failed
        # check leaves the remaining bookkeeping untouched.
        self._last_seen[worker_id] = time.monotonic()
        self._assert_in_flight_consistent(worker_id, shard_idx)

        # Forget everything tied to the failed task, including its
        # in-flight counter snapshot.
        self._worker_counters.pop(worker_id, None)
        self._in_flight.pop(worker_id, None)

        self._worker_states[worker_id] = WorkerState.DEAD
        self._fatal_error = error_info
577589
578- def heartbeat (self , worker_id : str ) -> None :
579- # No lock needed: _last_seen is only read by _check_worker_heartbeats
590+ def heartbeat (self , worker_id : str , counters : dict [ str , int ] | None = None ) -> None :
591+ # No lock needed for _last_seen: only read by _check_worker_heartbeats
580592 # (which holds the lock), and monotonic float writes are atomic on CPython.
581593 self ._last_seen [worker_id ] = time .monotonic ()
594+ if counters :
595+ with self ._lock :
596+ self ._worker_counters [worker_id ] = counters
582597
583598 def get_status (self ) -> JobStatus :
584599 with self ._lock :
@@ -600,6 +615,15 @@ def get_status(self) -> JobStatus:
600615 },
601616 )
602617
def get_counters(self) -> dict[str, int]:
    """Return counter totals: accumulated global values plus in-flight snapshots.

    The result is a fresh dict, built under the coordinator lock by starting
    from the global totals and adding every worker's current snapshot.
    """
    with self._lock:
        merged = self._global_counters.copy()
        for snapshot in self._worker_counters.values():
            for key in snapshot:
                merged[key] = merged.get(key, 0) + snapshot[key]
    return merged
626+
603627 def get_fatal_error (self ) -> str | None :
604628 with self ._lock :
605629 return self ._fatal_error
@@ -628,6 +652,8 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
628652 self ._task_attempts = {task .shard_idx : 0 for task in tasks }
629653 self ._fatal_error = None
630654 self ._is_last_stage = is_last_stage
655+ self ._worker_counters = {}
656+ self ._global_counters = {}
631657
632658 def _wait_for_stage (self ) -> None :
633659 """Block until current stage completes or error occurs."""
@@ -876,6 +902,9 @@ def __init__(self, coordinator_handle: ActorHandle):
876902 self ._shutdown_event = threading .Event ()
877903 self ._chunk_prefix : str = ""
878904 self ._execution_id : str = ""
905+ self ._counters : dict [str , int ] = {}
906+ self ._counters_lock = threading .Lock ()
907+ self ._last_reported_counters : dict [str , int ] = {}
879908
880909 # Build descriptive worker ID from actor context
881910 actor_ctx = current_actor ()
@@ -911,6 +940,28 @@ def get_shared(self, name: str) -> Any:
911940 )
912941 return self ._shared_data_cache [name ]
913942
def increment_counter(self, name: str, value: int = 1) -> None:
    """Add ``value`` to the counter ``name``, creating it at zero if absent.

    The update happens under ``_counters_lock``.
    """
    with self._counters_lock:
        previous = self._counters.get(name, 0)
        self._counters[name] = previous + value
946+
def get_counter_snapshot(self) -> dict[str, int]:
    """Return a copy of the current counters, taken under ``_counters_lock``."""
    with self._counters_lock:
        snapshot = self._counters.copy()
    return snapshot
950+
951+ def _reset_counters (self ) -> None :
952+ """Clear counters for a new task."""
953+ with self ._counters_lock :
954+ self ._counters .clear ()
955+
956+ def _counters_changed (self ) -> bool :
957+ """Return True if counters have changed since the last heartbeat report."""
958+ with self ._counters_lock :
959+ current = dict (self ._counters )
960+ if current == self ._last_reported_counters :
961+ return False
962+ self ._last_reported_counters = current
963+ return True
964+
914965 def _run_polling (self , coordinator : ActorHandle ) -> None :
915966 """Main polling loop. Runs in a background thread started by __init__."""
916967 logger .info ("[%s] Starting polling loop" , self ._worker_id )
@@ -941,8 +992,13 @@ def _heartbeat_loop(
941992 while not self ._shutdown_event .is_set ():
942993 try :
943994 # Block on result to avoid congesting the coordinator RPC pipe
944- # with fire-and-forget heartbeats.
945- coordinator .heartbeat .remote (self ._worker_id ).result ()
995+ # with fire-and-forget heartbeats. Only send counter snapshot
996+ # when values have changed.
997+ snapshot = self .get_counter_snapshot () if self ._counters_changed () else None
998+ coordinator .heartbeat .remote (
999+ self ._worker_id ,
1000+ snapshot ,
1001+ ).result ()
9461002 heartbeat_count += 1
9471003 consecutive_failures = 0
9481004 if heartbeat_count % 10 == 1 :
@@ -1049,6 +1105,9 @@ def _execute_shard(self, task: ShardTask, config: dict) -> TaskResult:
10491105 self ._chunk_prefix = config ["chunk_prefix" ]
10501106 self ._execution_id = config ["execution_id" ]
10511107
1108+ # Reset counters for the new task
1109+ self ._reset_counters ()
1110+
10521111 _worker_ctx_var .set (self )
10531112
10541113 logger .info (
0 commit comments