Skip to content

Commit fd0f1b2

Browse files
ravwojdyla-agent, ravwojdyla, github-actions[bot], and claude
authored
[iris] Add user-defined counters (MapReduce-style per-job stats) (#4085)
Add map<string, int64> counters to TaskStatus, JobStatus, and WorkerTaskStatus in cluster.proto. Task code calls iris.counters.increment(name, value), which writes to a JSON file in the workdir; the worker monitor loop reads that file each poll cycle and forwards the values through the heartbeat to the controller, which stores them per task in a new counters_json DB column. get_job_status() sums counters across all tasks. Includes migration 0013 and three unit tests.

Fixes #3995

---------

Co-authored-by: Rafal Wojdyla <ravwojdyla@gmail.com>
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
Co-authored-by: Rafal Wojdyla <ravwojdyla@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent d6679bf commit fd0f1b2

6 files changed

Lines changed: 167 additions & 6 deletions

File tree

lib/iris/src/iris/cluster/controller/transitions.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1047,7 +1047,6 @@ def _apply_single_heartbeat(
10471047
"UPDATE tasks SET resource_usage_proto = ? WHERE task_id = ?",
10481048
(usage_payload, update.task_id.to_wire()),
10491049
)
1050-
10511050
terminal_ms: int | None = None
10521051
started_ms: int | None = None
10531052
task_state = prior_state

lib/iris/src/iris/cluster/worker/task_attempt.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,6 @@ def to_proto(self) -> cluster_pb2.TaskStatus:
444444
proto.build_metrics.build_started.CopyFrom(self.build_started.to_proto())
445445
if self.build_finished is not None:
446446
proto.build_metrics.build_finished.CopyFrom(self.build_finished.to_proto())
447-
448447
return proto
449448

450449
def _check_cancelled(self) -> None:

lib/zephyr/src/zephyr/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import logging
77

8+
from zephyr import counters
89
from zephyr.dataset import Dataset, ShardInfo
910
from zephyr.execution import WorkerContext, ZephyrContext, zephyr_worker_ctx
1011
from zephyr.expr import Expr, col, lit
@@ -25,6 +26,7 @@
2526
"atomic_rename",
2627
"col",
2728
"compute_plan",
29+
"counters",
2830
"lit",
2931
"load_file",
3032
"load_jsonl",

lib/zephyr/src/zephyr/counters.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""User-defined counters for Zephyr tasks.
5+
6+
Task code can increment named counters during execution; counters are
7+
aggregated across all tasks and exposed via the coordinator's ``get_counters()``
8+
actor method.
9+
10+
Usage::
11+
12+
from zephyr import counters
13+
14+
counters.increment("documents_processed", 100)
15+
counters.increment("validation_errors")
16+
17+
Counter values are accumulated in-memory on each worker and sent to the
18+
coordinator via the heartbeat loop.
19+
20+
Outside of a Zephyr worker context, all calls are silent no-ops.
21+
"""
22+
23+
import logging
24+
25+
from zephyr.execution import _worker_ctx_var
26+
27+
logger = logging.getLogger(__name__)
28+
29+
30+
def increment(name: str, value: int = 1) -> None:
31+
"""Increment a named counter by ``value`` (default 1).
32+
33+
O(1) in-memory update. Thread-safe. No-op outside a Zephyr worker.
34+
"""
35+
worker = _worker_ctx_var.get()
36+
if worker is None:
37+
return
38+
worker.increment_counter(name, value)
39+
40+
41+
def get_counters() -> dict[str, int]:
42+
"""Return a snapshot of the current task's counters.
43+
44+
Returns an empty dict outside a Zephyr worker context.
45+
"""
46+
worker = _worker_ctx_var.get()
47+
if worker is None:
48+
return {}
49+
return worker.get_counter_snapshot()

lib/zephyr/src/zephyr/execution.py

Lines changed: 63 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,8 @@ class ZephyrWorkerError(RuntimeError):
293293

294294
class WorkerContext(Protocol):
295295
def get_shared(self, name: str) -> Any: ...
296+
def increment_counter(self, name: str, value: int = 1) -> None: ...
297+
def get_counter_snapshot(self) -> dict[str, int]: ...
296298

297299

298300
_worker_ctx_var: ContextVar[ZephyrWorker | None] = ContextVar("zephyr_worker_ctx", default=None)
@@ -351,6 +353,9 @@ def __init__(self):
351353
self._chunk_prefix: str = ""
352354
self._execution_id: str = ""
353355
self._no_workers_timeout: float = 60.0
356+
# User-defined counters: in-flight per-worker snapshots and global accumulator.
357+
self._worker_counters: dict[str, dict[str, int]] = {}
358+
self._global_counters: dict[str, int] = {}
354359

355360
# Worker management state (workers self-register via register_worker)
356361
self._worker_handles: dict[str, ActorHandle] = {}
@@ -487,6 +492,9 @@ def _maybe_requeue_worker_task(self, worker_id: str) -> None:
487492
self._task_attempts[task.shard_idx] += 1
488493
self._task_queue.append(task)
489494
self._retries += 1
495+
# Discard in-flight counter snapshot so it doesn't double-count when the
496+
# shard is retried on another worker.
497+
self._worker_counters.pop(worker_id, None)
490498

491499
def _check_worker_heartbeats(self, timeout: float = 120.0) -> None:
492500
"""Internal heartbeat check (called with lock held)."""
@@ -565,20 +573,27 @@ def report_result(self, worker_id: str, shard_idx: int, attempt: int, result: Ta
565573
self._completed_shards += 1
566574
self._in_flight.pop(worker_id, None)
567575
self._worker_states[worker_id] = WorkerState.READY
576+
# Accumulate final counters for this task into the global total
577+
for name, value in self._worker_counters.pop(worker_id, {}).items():
578+
self._global_counters[name] = self._global_counters.get(name, 0) + value
568579

569580
def report_error(self, worker_id: str, shard_idx: int, error_info: str) -> None:
570581
"""Worker reports a task failure. All errors are fatal."""
571582
with self._lock:
572583
self._last_seen[worker_id] = time.monotonic()
573584
self._assert_in_flight_consistent(worker_id, shard_idx)
574585
self._in_flight.pop(worker_id, None)
586+
self._worker_counters.pop(worker_id, None)
575587
self._fatal_error = error_info
576588
self._worker_states[worker_id] = WorkerState.DEAD
577589

578-
def heartbeat(self, worker_id: str) -> None:
579-
# No lock needed: _last_seen is only read by _check_worker_heartbeats
590+
def heartbeat(self, worker_id: str, counters: dict[str, int] | None = None) -> None:
591+
# No lock needed for _last_seen: only read by _check_worker_heartbeats
580592
# (which holds the lock), and monotonic float writes are atomic on CPython.
581593
self._last_seen[worker_id] = time.monotonic()
594+
if counters:
595+
with self._lock:
596+
self._worker_counters[worker_id] = counters
582597

583598
def get_status(self) -> JobStatus:
584599
with self._lock:
@@ -600,6 +615,15 @@ def get_status(self) -> JobStatus:
600615
},
601616
)
602617

618+
def get_counters(self) -> dict[str, int]:
619+
"""Return global counter totals: completed-task values plus current in-flight snapshots."""
620+
with self._lock:
621+
totals = dict(self._global_counters)
622+
for ctrs in self._worker_counters.values():
623+
for name, value in ctrs.items():
624+
totals[name] = totals.get(name, 0) + value
625+
return totals
626+
603627
def get_fatal_error(self) -> str | None:
604628
with self._lock:
605629
return self._fatal_error
@@ -628,6 +652,8 @@ def _start_stage(self, stage_name: str, tasks: list[ShardTask], is_last_stage: b
628652
self._task_attempts = {task.shard_idx: 0 for task in tasks}
629653
self._fatal_error = None
630654
self._is_last_stage = is_last_stage
655+
self._worker_counters = {}
656+
self._global_counters = {}
631657

632658
def _wait_for_stage(self) -> None:
633659
"""Block until current stage completes or error occurs."""
@@ -876,6 +902,9 @@ def __init__(self, coordinator_handle: ActorHandle):
876902
self._shutdown_event = threading.Event()
877903
self._chunk_prefix: str = ""
878904
self._execution_id: str = ""
905+
self._counters: dict[str, int] = {}
906+
self._counters_lock = threading.Lock()
907+
self._last_reported_counters: dict[str, int] = {}
879908

880909
# Build descriptive worker ID from actor context
881910
actor_ctx = current_actor()
@@ -911,6 +940,28 @@ def get_shared(self, name: str) -> Any:
911940
)
912941
return self._shared_data_cache[name]
913942

943+
def increment_counter(self, name: str, value: int = 1) -> None:
944+
with self._counters_lock:
945+
self._counters[name] = self._counters.get(name, 0) + value
946+
947+
def get_counter_snapshot(self) -> dict[str, int]:
948+
with self._counters_lock:
949+
return dict(self._counters)
950+
951+
def _reset_counters(self) -> None:
952+
"""Clear counters for a new task."""
953+
with self._counters_lock:
954+
self._counters.clear()
955+
956+
def _counters_changed(self) -> bool:
957+
"""Return True if counters have changed since the last heartbeat report."""
958+
with self._counters_lock:
959+
current = dict(self._counters)
960+
if current == self._last_reported_counters:
961+
return False
962+
self._last_reported_counters = current
963+
return True
964+
914965
def _run_polling(self, coordinator: ActorHandle) -> None:
915966
"""Main polling loop. Runs in a background thread started by __init__."""
916967
logger.info("[%s] Starting polling loop", self._worker_id)
@@ -941,8 +992,13 @@ def _heartbeat_loop(
941992
while not self._shutdown_event.is_set():
942993
try:
943994
# Block on result to avoid congesting the coordinator RPC pipe
944-
# with fire-and-forget heartbeats.
945-
coordinator.heartbeat.remote(self._worker_id).result()
995+
# with fire-and-forget heartbeats. Only send counter snapshot
996+
# when values have changed.
997+
snapshot = self.get_counter_snapshot() if self._counters_changed() else None
998+
coordinator.heartbeat.remote(
999+
self._worker_id,
1000+
snapshot,
1001+
).result()
9461002
heartbeat_count += 1
9471003
consecutive_failures = 0
9481004
if heartbeat_count % 10 == 1:
@@ -1049,6 +1105,9 @@ def _execute_shard(self, task: ShardTask, config: dict) -> TaskResult:
10491105
self._chunk_prefix = config["chunk_prefix"]
10501106
self._execution_id = config["execution_id"]
10511107

1108+
# Reset counters for the new task
1109+
self._reset_counters()
1110+
10521111
_worker_ctx_var.set(self)
10531112

10541113
logger.info(

lib/zephyr/tests/test_counters.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright The Marin Authors
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Tests for Zephyr user-defined counters: worker API and heartbeat plumbing."""
5+
6+
import threading
7+
8+
from zephyr import counters
9+
from zephyr.execution import _worker_ctx_var
10+
11+
12+
class FakeWorker:
13+
"""Minimal WorkerContext implementation for testing counters."""
14+
15+
def __init__(self):
16+
self._counters: dict[str, int] = {}
17+
self._counters_lock = threading.Lock()
18+
19+
def get_shared(self, name: str):
20+
raise NotImplementedError
21+
22+
def increment_counter(self, name: str, value: int = 1) -> None:
23+
with self._counters_lock:
24+
self._counters[name] = self._counters.get(name, 0) + value
25+
26+
def get_counter_snapshot(self) -> dict[str, int]:
27+
with self._counters_lock:
28+
return dict(self._counters)
29+
30+
31+
def test_counters_increment_and_snapshot():
32+
"""increment() accumulates in-memory; get_counter_snapshot() returns current values."""
33+
worker = FakeWorker()
34+
token = _worker_ctx_var.set(worker)
35+
try:
36+
counters.increment("docs", 10)
37+
counters.increment("docs", 5)
38+
counters.increment("errors", 1)
39+
40+
snapshot = counters.get_counters()
41+
assert snapshot == {"docs": 15, "errors": 1}
42+
finally:
43+
_worker_ctx_var.reset(token)
44+
45+
46+
def test_counters_noop_outside_worker():
47+
"""increment() is a no-op when not inside a Zephyr worker context."""
48+
token = _worker_ctx_var.set(None)
49+
try:
50+
counters.increment("anything", 999) # should not raise
51+
assert counters.get_counters() == {}
52+
finally:
53+
_worker_ctx_var.reset(token)

0 commit comments

Comments
 (0)