Skip to content

Commit 8fe86bf

Browse files
nv-alichengclaude
andcommitted
fix(metrics): drain pending count, SIGTERM task GC, NaN display
Addresses the three high-priority findings from the review council: H1: drain_tasks now owns the timeout + cancel-and-await sequence, so the pending count is captured before per-task done callbacks empty the in-flight set. Previously read 0 unconditionally — the documented state==COMPLETE and n_pending_tasks>0 drain-timeout contract was unenforceable. H2: Extract _make_sigterm_handler returning a strong-ref set[Task] that holds the spawned _signal_finalize task; the loop tracks tasks via weakref only, so a discarded create_task() return value can be GC'd mid-flight (Python asyncio docs) — exactly the failure the INTERRUPTED delivery path exists to prevent. H3: _scrub_nonfinite maps producer-side NaN/Inf to None for strict JSON. _display_metric did val * scale_factor with no guard → TypeError on display(), which finalize_benchmark calls outside the report-build try/except. Render N/A for None across named scalars, histogram bucket edges, and percentiles. Tests added (all verified failing pre-fix): - test_drain_timeout_reports_pending_count: forever-blocking pool + drain_timeout_s=0.05, asserts publish_final receives n_pending>0 - test_sigterm_handler_holds_strong_reference_to_finalize_task: drives the handler, asserts task is in the strong-ref set, survives gc.collect(), and self-removes via done-callback on completion - test_sigterm_handler_refreshes_tracked_duration: handler mirrors the ENDED path's tracked_duration_ns refresh before publish_final - test_display_handles_scrubbed_nan_percentiles: dict with scrubbed None percentile values does not crash display(); renders N/A - test_scrub_nonfinite_round_trip_yields_none: registry-side NaN/Inf surfaces as None in snapshot_to_dict and round-trips through json.dumps(allow_nan=False) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent a95ec22 commit 8fe86bf

7 files changed

Lines changed: 434 additions & 69 deletions

File tree

src/inference_endpoint/async_utils/services/metrics_aggregator/__main__.py

Lines changed: 67 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import asyncio
2020
import logging
2121
import signal
22+
from collections.abc import Callable
2223
from contextlib import AbstractContextManager, nullcontext
2324
from pathlib import Path
2425

@@ -28,6 +29,7 @@
2829
from inference_endpoint.utils.logging import setup_logging
2930

3031
from .aggregator import MetricCounterKey, MetricsAggregatorService
32+
from .metrics_table import MetricsTable
3133
from .publisher import MetricsPublisher
3234
from .registry import MetricsRegistry
3335
from .snapshot import MetricsSnapshotCodec
@@ -36,6 +38,63 @@
3638
logger = logging.getLogger(__name__)
3739

3840

41+
def _make_sigterm_handler(
42+
*,
43+
loop: asyncio.AbstractEventLoop,
44+
registry: MetricsRegistry,
45+
publisher: MetricsPublisher,
46+
table: MetricsTable,
47+
shutdown_event: asyncio.Event,
48+
) -> tuple[Callable[[], None], set[asyncio.Task]]:
49+
"""Build the SIGTERM handler that writes the INTERRUPTED final snapshot.
50+
51+
Returns ``(handler, pending_tasks)``. ``pending_tasks`` is the
52+
strong-reference container that keeps spawned finalize tasks alive
53+
while they run: asyncio tracks tasks only by weakref, so a task
54+
whose only reference is the local variable inside the handler can
55+
be garbage-collected mid-execution (per Python's asyncio docs).
56+
Each spawned task self-removes from the set via
57+
``add_done_callback`` once it completes.
58+
59+
Exposed at module level (rather than nested in ``main()``) so the
60+
GC-safety contract is unit-testable without driving the whole
61+
subprocess lifecycle.
62+
"""
63+
pending_tasks: set[asyncio.Task] = set()
64+
65+
async def _signal_finalize() -> None:
66+
try:
67+
# Mirror the ENDED-driven path: refresh tracked_duration_ns
68+
# from the table BEFORE publish_final, otherwise an
69+
# interrupted run whose STOP_PERFORMANCE_TRACKING never
70+
# fired would report duration_ns=0 and QPS=N/A in the final
71+
# report even after processing many tracked samples.
72+
registry.set_counter(
73+
MetricCounterKey.TRACKED_DURATION_NS.value,
74+
table.total_tracked_duration_ns,
75+
)
76+
await publisher.publish_final(
77+
registry,
78+
n_pending_tasks=table.in_flight_tasks_count,
79+
interrupted=True,
80+
)
81+
except Exception: # noqa: BLE001 — best-effort.
82+
logger.exception(
83+
"metrics aggregator: SIGTERM-triggered publish_final failed"
84+
)
85+
shutdown_event.set()
86+
87+
def _on_sigterm() -> None:
88+
logger.warning(
89+
"metrics aggregator received SIGTERM; " "writing INTERRUPTED final snapshot"
90+
)
91+
task = loop.create_task(_signal_finalize())
92+
pending_tasks.add(task)
93+
task.add_done_callback(pending_tasks.discard)
94+
95+
return _on_sigterm, pending_tasks
96+
97+
3998
async def main() -> None:
4099
parser = argparse.ArgumentParser(
41100
description="Metrics aggregator service - subscribes to EventRecords and computes real-time metrics"
@@ -206,37 +265,14 @@ async def main() -> None:
206265
# aggregator's finalize — so we install a no-op handler for
207266
# SIGINT here, which prevents Python's default
208267
# KeyboardInterrupt and lets the parent control the lifecycle.
209-
def _on_sigterm() -> None:
210-
logger.warning(
211-
"metrics aggregator received SIGTERM; "
212-
"writing INTERRUPTED final snapshot"
213-
)
214-
loop.create_task(_signal_finalize())
215-
216-
async def _signal_finalize() -> None:
217-
try:
218-
# Mirror the ENDED-driven path: refresh
219-
# tracked_duration_ns from the table BEFORE
220-
# publish_final, otherwise an interrupted run whose
221-
# STOP_PERFORMANCE_TRACKING never fired would
222-
# report duration_ns=0 and QPS=N/A in the final
223-
# report even after processing many tracked samples.
224-
registry.set_counter(
225-
MetricCounterKey.TRACKED_DURATION_NS.value,
226-
aggregator._table.total_tracked_duration_ns,
227-
)
228-
await publisher.publish_final(
229-
registry,
230-
n_pending_tasks=aggregator._table.in_flight_tasks_count,
231-
interrupted=True,
232-
)
233-
except Exception: # noqa: BLE001 — best-effort.
234-
logger.exception(
235-
"metrics aggregator: SIGTERM-triggered publish_final failed"
236-
)
237-
shutdown_event.set()
238-
239-
loop.add_signal_handler(signal.SIGTERM, _on_sigterm)
268+
on_sigterm, _sigterm_tasks = _make_sigterm_handler(
269+
loop=loop,
270+
registry=registry,
271+
publisher=publisher,
272+
table=aggregator._table,
273+
shutdown_event=shutdown_event,
274+
)
275+
loop.add_signal_handler(signal.SIGTERM, on_sigterm)
240276
# No-op SIGINT handler: silence the default KeyboardInterrupt
241277
# and let the parent's ENDED-driven path drive shutdown.
242278
loop.add_signal_handler(

src/inference_endpoint/async_utils/services/metrics_aggregator/aggregator.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -368,26 +368,18 @@ async def process(self, records: list[EventRecord]) -> None:
368368
# that fires before publish_final reflects the new state.
369369
self._session_state = SessionState.DRAINING
370370
logger.info("Draining %d async tasks...", table.in_flight_tasks_count)
371-
try:
372-
await asyncio.wait_for(
373-
table.drain_tasks(), timeout=self._drain_timeout_s
374-
)
375-
except TimeoutError:
371+
# drain_tasks owns the timeout + cancel-and-await sequence so
372+
# the pending count is captured BEFORE done-callbacks empty
373+
# the in-flight set. Reading in_flight_tasks_count out here
374+
# would always be 0 (see drain_tasks docstring).
375+
n_pending = await table.drain_tasks(timeout=self._drain_timeout_s)
376+
if n_pending > 0:
376377
logger.warning(
377-
"drain_tasks timed out after %.1fs; some async metrics "
378-
"may be incomplete",
378+
"drain_tasks timed out after %.1fs; %d async tasks "
379+
"did not complete and were cancelled",
379380
self._drain_timeout_s,
381+
n_pending,
380382
)
381-
# cancel() only *schedules* cancellation at the next await
382-
# point. Await the cancelled tasks so they actually exit
383-
# before publish_final reads n_pending — otherwise the
384-
# snapshot reports stale-high pending counts and the
385-
# event-loop tear-down emits "Task was destroyed but it
386-
# is pending!" warnings on the cancelled set.
387-
cancelled = table.cancel_in_flight_tasks()
388-
if cancelled:
389-
await asyncio.gather(*cancelled, return_exceptions=True)
390-
n_pending = table.in_flight_tasks_count
391383
logger.info(
392384
"Async tasks drained (n_pending_tasks=%d at finalize)", n_pending
393385
)

src/inference_endpoint/async_utils/services/metrics_aggregator/metrics_table.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -494,26 +494,37 @@ def in_flight_tasks_count(self) -> int:
494494
"""Number of async trigger tasks currently in flight."""
495495
return len(self._in_flight_tasks)
496496

497-
async def drain_tasks(self) -> None:
498-
"""Await all in-flight async trigger tasks."""
499-
if self._in_flight_tasks:
497+
async def drain_tasks(self, *, timeout: float | None = None) -> int:
498+
"""Await in-flight async trigger tasks.
499+
500+
With ``timeout``, the pending set at the timeout boundary is
501+
cancelled and awaited; the count of those pending tasks is
502+
returned (>0 indicates the drain timed out). Without
503+
``timeout``, blocks indefinitely and returns 0 on clean drain.
504+
505+
The pending count must be captured BEFORE the cancel-and-await
506+
step: each task's ``add_done_callback(_in_flight_tasks.discard)``
507+
empties ``_in_flight_tasks`` as cancellation propagates, so
508+
reading ``in_flight_tasks_count`` after this method returns
509+
would always be 0 — making a drain timeout indistinguishable
510+
from a clean run.
511+
"""
512+
if not self._in_flight_tasks:
513+
return 0
514+
if timeout is None:
500515
await asyncio.gather(*self._in_flight_tasks, return_exceptions=True)
501516
self._in_flight_tasks.clear()
502-
503-
def cancel_in_flight_tasks(self) -> list[asyncio.Task]:
504-
"""Cancel every in-flight async trigger task that hasn't finished.
505-
506-
Returns the tasks that were cancelled so callers can await them
507-
(cancellation is only scheduled by ``Task.cancel()`` — the tasks
508-
must still be awaited at a later point for the cancellation to
509-
actually take effect).
510-
"""
511-
cancelled: list[asyncio.Task] = []
512-
for t in list(self._in_flight_tasks):
513-
if not t.done():
517+
return 0
518+
_, still_pending = await asyncio.wait(
519+
list(self._in_flight_tasks), timeout=timeout
520+
)
521+
n_pending = len(still_pending)
522+
if still_pending:
523+
for t in still_pending:
514524
t.cancel()
515-
cancelled.append(t)
516-
return cancelled
525+
await asyncio.gather(*still_pending, return_exceptions=True)
526+
self._in_flight_tasks.clear()
527+
return n_pending
517528

518529
# --- Internal ---
519530

src/inference_endpoint/metrics/report.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,23 +296,32 @@ def _display_metric(
296296
scale_factor: float = 1.0,
297297
newline: str = "",
298298
) -> None:
299+
# ``_scrub_nonfinite`` (snapshot.py) maps producer-side NaN/±Inf to
300+
# ``None`` so the persisted JSON stays strict. Any of the named
301+
# scalars / percentile values below can therefore be ``None`` —
302+
# render an ``N/A`` indicator instead of crashing on
303+
# ``None * scale_factor``.
304+
def _scaled(v: Any) -> str:
305+
if v is None:
306+
return "N/A"
307+
return f"{v * scale_factor:.2f}"
308+
299309
for name, key in [
300310
("Min", "min"),
301311
("Max", "max"),
302312
("Median", "median"),
303313
("Avg.", "avg"),
304314
("Std Dev.", "std_dev"),
305315
]:
306-
fn(f" {name}: {metric_dict[key] * scale_factor:.2f} {unit}{newline}")
316+
fn(f" {name}: {_scaled(metric_dict[key])} {unit}{newline}")
307317

308318
fn(f"\n Histogram:{newline}")
309319
buckets = metric_dict["histogram"]["buckets"]
310320
counts = metric_dict["histogram"]["counts"]
311321

312322
if buckets:
313323
bucket_strs = [
314-
f" [{lo * scale_factor:.2f}, {hi * scale_factor:.2f}"
315-
+ ("]" if i == len(buckets) - 1 else ")")
324+
f" [{_scaled(lo)}, {_scaled(hi)}" + ("]" if i == len(buckets) - 1 else ")")
316325
for i, (lo, hi) in enumerate(buckets)
317326
]
318327
max_count = max(counts)
@@ -325,4 +334,4 @@ def _display_metric(
325334

326335
fn(f"\n Percentiles:{newline}")
327336
for p, val in metric_dict.get("percentiles", {}).items():
328-
fn(f" {p:>6}: {val * scale_factor:.2f} {unit}{newline}")
337+
fn(f" {p:>6}: {_scaled(val)} {unit}{newline}")

tests/unit/async_utils/services/metrics_aggregator/test_aggregator.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,3 +1055,68 @@ async def test_shutdown_drains_async_tasks(self, tmp_path):
10551055
# exercised here. Adding a MockTokenizePool that raises on
10561056
# token_count_async would let us assert no metric is emitted, the
10571057
# aggregator does not crash, and the task set is cleaned up.
1058+
1059+
@pytest.mark.asyncio
1060+
async def test_drain_timeout_reports_pending_count(self, tmp_path):
1061+
"""On drain timeout, publish_final must receive n_pending_tasks > 0.
1062+
1063+
AGENTS.md and the ``MetricsSnapshot.n_pending_tasks`` docstring
1064+
document the consumer contract: a drain-timeout run is detected
1065+
downstream as ``state == COMPLETE and n_pending_tasks > 0``. If
1066+
the producer always reports 0 here, the timeout is silently
1067+
rebadged as a clean run and the Report shows no warning.
1068+
"""
1069+
loop = asyncio.get_event_loop()
1070+
1071+
class BlockingTokenizePool:
1072+
async def token_count_async(self, text, _loop):
1073+
await asyncio.sleep(10.0) # exceeds drain timeout
1074+
return 0
1075+
1076+
def token_count(self, text):
1077+
return 0
1078+
1079+
def close(self):
1080+
pass
1081+
1082+
def __enter__(self):
1083+
return self
1084+
1085+
def __exit__(self, *args):
1086+
self.close()
1087+
1088+
with ManagedZMQContext.scoped(socket_dir=str(tmp_path)) as ctx:
1089+
agg, _, publisher = make_aggregator(
1090+
ctx,
1091+
loop,
1092+
"agg_drain_timeout",
1093+
tokenize_pool=BlockingTokenizePool(),
1094+
)
1095+
agg._drain_timeout_s = 0.05
1096+
try:
1097+
await agg.process(
1098+
[
1099+
session_event(
1100+
SessionEventType.START_PERFORMANCE_TRACKING, ts=0
1101+
),
1102+
sample_event(
1103+
SampleEventType.ISSUED,
1104+
"s1",
1105+
ts=1000,
1106+
data=PromptData(text="some text to tokenize"),
1107+
),
1108+
]
1109+
)
1110+
assert (
1111+
agg._table.in_flight_tasks_count > 0
1112+
), "precondition: ISL task must be in-flight before ENDED"
1113+
await agg.process([session_event(SessionEventType.ENDED, ts=2000)])
1114+
1115+
publisher.publish_final.assert_awaited_once()
1116+
kwargs = publisher.publish_final.await_args.kwargs
1117+
assert kwargs["n_pending_tasks"] > 0, (
1118+
f"drain timeout must report stuck tasks; got "
1119+
f"n_pending_tasks={kwargs['n_pending_tasks']}"
1120+
)
1121+
finally:
1122+
agg.close()

0 commit comments

Comments
 (0)