167 changes: 163 additions & 4 deletions src/aiperf/common/models/telemetry_models.py
@@ -1,13 +1,15 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
from numpy.typing import NDArray
from pydantic import ConfigDict, Field

from aiperf.common.exceptions import NoMetricValue
from aiperf.common.models.base_models import AIPerfBaseModel
from aiperf.common.models.export_models import TelemetryExportData
from aiperf.common.models.record_models import MetricResult
from aiperf.common.models.server_metrics_models import TimeRangeFilter


class TelemetryMetrics(AIPerfBaseModel):
@@ -236,14 +238,17 @@ def to_metric_result(
arr, [1, 5, 10, 25, 50, 75, 90, 95, 99]
)

# Use sample std (ddof=1) for an unbiased variance estimate; 0.0 for a single sample
std_dev = float(np.std(arr, ddof=1)) if len(arr) > 1 else 0.0

return MetricResult(
tag=tag,
header=header,
unit=unit,
min=float(np.min(arr)),
max=float(np.max(arr)),
avg=float(np.mean(arr)),
std=float(np.std(arr)),
std=std_dev,
count=len(arr),
current=float(arr[-1]),
p1=p1,
@@ -257,6 +262,148 @@ def to_metric_result(
p99=p99,
)

def get_time_mask(self, time_filter: TimeRangeFilter | None) -> NDArray[np.bool_]:
"""Get boolean mask for points within time range.

Uses np.searchsorted for O(log n) binary search on sorted timestamps,
then slice assignment for mask creation (10-100x faster than element-wise
boolean comparisons for large arrays).

Args:
time_filter: Time range filter specifying start_ns and/or end_ns bounds.
None returns all-True mask.

Returns:
Boolean mask array where True indicates timestamp within range
"""
if time_filter is None:
return np.ones(self._size, dtype=bool)

timestamps = self.timestamps
first_idx = 0
last_idx = self._size

# O(log n) binary search for range boundaries
if time_filter.start_ns is not None:
first_idx = int(
np.searchsorted(timestamps, time_filter.start_ns, side="left")
)
if time_filter.end_ns is not None:
last_idx = int(
np.searchsorted(timestamps, time_filter.end_ns, side="right")
)

# Single allocation + slice assignment
mask = np.zeros(self._size, dtype=bool)
mask[first_idx:last_idx] = True
return mask

def get_reference_idx(self, time_filter: TimeRangeFilter | None) -> int | None:
"""Get index of last point BEFORE time filter start (for delta calculation).

Uses np.searchsorted for O(log n) lookup. Returns None if no baseline exists
(i.e., time_filter is None, start_ns is None, or no data before start_ns).

Args:
time_filter: Time range filter. Reference point is found before start_ns.

Returns:
Index of last timestamp before start_ns, or None if no baseline exists
"""
if time_filter is None or time_filter.start_ns is None:
return None
insert_pos = int(
np.searchsorted(self.timestamps, time_filter.start_ns, side="left")
)
return insert_pos - 1 if insert_pos > 0 else None

def to_metric_result_filtered(
self,
metric_name: str,
tag: str,
header: str,
unit: str,
time_filter: TimeRangeFilter | None = None,
is_counter: bool = False,
) -> MetricResult:
"""Compute stats with time filtering and optional delta for counters.

For gauges: Uses vectorized NumPy stats on the filtered array (np.mean, np.std, np.percentile).
For counters: Computes the delta from the reference point before the profiling start.

Args:
metric_name: Name of the metric to analyze
tag: Unique identifier for this metric
header: Human-readable name for display
unit: Unit of measurement
time_filter: Optional time range filter to exclude warmup/cooldown periods
is_counter: If True, compute delta from baseline instead of statistics

Returns:
MetricResult with min/max/avg/percentiles for gauges, or delta for counters

Raises:
NoMetricValue: If no data for this metric or no data in filtered range
"""
arr = self.get_metric_array(metric_name)
if arr is None or len(arr) == 0:
raise NoMetricValue(
f"No telemetry data available for metric '{metric_name}'"
)

# Common: apply time filter
time_mask = self.get_time_mask(time_filter)
filtered = arr[time_mask]
if len(filtered) == 0:
raise NoMetricValue(f"No data in time range for metric '{metric_name}'")

if is_counter:
# Counter: compute delta from baseline
reference_idx = self.get_reference_idx(time_filter)
reference_value = (
arr[reference_idx] if reference_idx is not None else filtered[0]
)
raw_delta = float(filtered[-1] - reference_value)

# Handle counter resets (e.g., DCGM restart) by clamping to 0
delta = max(raw_delta, 0.0)

# Counters report a single delta value, not a distribution
return MetricResult(
tag=tag,
header=header,
unit=unit,
avg=delta,
)

# Gauge: vectorized stats on filtered data
p1, p5, p10, p25, p50, p75, p90, p95, p99 = np.percentile(
filtered, [1, 5, 10, 25, 50, 75, 90, 95, 99]
)

# Use sample std (ddof=1) for an unbiased variance estimate; 0.0 for a single sample
std_dev = float(np.std(filtered, ddof=1)) if len(filtered) > 1 else 0.0

return MetricResult(
tag=tag,
header=header,
unit=unit,
min=float(np.min(filtered)),
max=float(np.max(filtered)),
avg=float(np.mean(filtered)),
std=std_dev,
count=len(filtered),
p1=p1,
p5=p5,
p10=p10,
p25=p25,
p50=p50,
p75=p75,
p90=p90,
p95=p95,
p99=p99,
)

def __len__(self) -> int:
"""Return the number of snapshots in the time series."""
return self._size
@@ -292,19 +439,31 @@ def add_record(self, record: TelemetryRecord) -> None:
self.time_series.append_snapshot(valid_metrics, record.timestamp_ns)

def get_metric_result(
self, metric_name: str, tag: str, header: str, unit: str
self,
metric_name: str,
tag: str,
header: str,
unit: str,
time_filter: TimeRangeFilter | None = None,
is_counter: bool = False,
) -> MetricResult:
"""Get MetricResult for a specific metric.
"""Get MetricResult for a specific metric with optional time filtering.

Args:
metric_name: Name of the metric to analyze
tag: Unique identifier for this metric
header: Human-readable name for display
unit: Unit of measurement
time_filter: Optional time range filter to exclude warmup/cooldown periods
is_counter: If True, compute delta from baseline instead of statistics

Returns:
MetricResult with statistical summary for the specified metric
"""
if time_filter is not None or is_counter:
return self.time_series.to_metric_result_filtered(
metric_name, tag, header, unit, time_filter, is_counter
)
return self.time_series.to_metric_result(metric_name, tag, header, unit)


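To make the new windowing concrete, here is a minimal standalone sketch of the searchsorted technique used by get_time_mask and get_reference_idx above (the array values and variable names are illustrative only, not part of the module's API):

import numpy as np

# Sorted nanosecond timestamps, as the time series maintains them.
timestamps = np.array([100, 200, 300, 400, 500], dtype=np.int64)
start_ns, end_ns = 250, 450

# O(log n) binary search for the window boundaries.
first = int(np.searchsorted(timestamps, start_ns, side="left"))  # -> 2
last = int(np.searchsorted(timestamps, end_ns, side="right"))    # -> 4

# Single allocation + slice assignment instead of element-wise comparisons.
mask = np.zeros(len(timestamps), dtype=bool)
mask[first:last] = True  # [False, False, True, True, False]

# Baseline for counter deltas: last sample strictly before the window start.
reference_idx = first - 1 if first > 0 else None  # -> 1 (timestamp 200)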
2 changes: 2 additions & 0 deletions src/aiperf/gpu_telemetry/__init__.py
@@ -18,6 +18,7 @@
)
from aiperf.gpu_telemetry.constants import (
DCGM_TO_FIELD_MAPPING,
GPU_TELEMETRY_COUNTER_METRICS,
GPU_TELEMETRY_METRICS_CONFIG,
SCALING_FACTORS,
get_gpu_telemetry_metrics_config,
@@ -41,6 +42,7 @@
"GPUTelemetryDataCollector",
"GPUTelemetryJSONLWriter",
"GPUTelemetryManager",
"GPU_TELEMETRY_COUNTER_METRICS",
"GPU_TELEMETRY_METRICS_CONFIG",
"MetricsConfigLoader",
"SCALING_FACTORS",
64 changes: 49 additions & 15 deletions src/aiperf/gpu_telemetry/accumulator.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
@@ -22,10 +22,14 @@
TelemetryExportData,
TelemetrySummary,
)
from aiperf.common.models.server_metrics_models import TimeRangeFilter
from aiperf.common.models.telemetry_models import TelemetryHierarchy, TelemetryRecord
from aiperf.common.protocols import GPUTelemetryAccumulatorProtocol, PubClientProtocol
from aiperf.exporters.display_units_utils import normalize_endpoint_display
from aiperf.gpu_telemetry.constants import get_gpu_telemetry_metrics_config
from aiperf.gpu_telemetry.constants import (
GPU_TELEMETRY_COUNTER_METRICS,
get_gpu_telemetry_metrics_config,
)
from aiperf.post_processors.base_metrics_processor import BaseMetricsProcessor


@@ -176,10 +180,8 @@ async def summarize(self) -> list[MetricResult]:

header = f"{metric_display} | {endpoint_display} | GPU {gpu_index} | {model_name}"

unit = unit_enum.value

result = telemetry_data.get_metric_result(
metric_name, tag, header, unit
metric_name, tag, header, unit_enum
)
results.append(result)
except NoMetricValue:
@@ -197,30 +199,51 @@

def export_results(
self,
start_ns: int,
end_ns: int,
start_ns: int | None = None,
end_ns: int | None = None,
error_summary: list[ErrorDetailsCount] | None = None,
) -> "TelemetryExportData | None":
"""Export accumulated telemetry data as a TelemetryExportData object.

Transforms the internal numpy-backed telemetry hierarchy into a serializable
format with pre-computed metric statistics for each GPU.

Time filtering is applied to exclude warmup periods from statistics:
- Gauge metrics (power, utilization, etc.): Stats computed on filtered data only
- Counter metrics (energy, errors): Delta computed from baseline before start_ns

Args:
start_ns: Start time of collection in nanoseconds
end_ns: End time of collection in nanoseconds
            start_ns: Start time of the profiling phase in nanoseconds (excludes warmup).
                If None, all data from the beginning is included.
            end_ns: End time of the profiling phase in nanoseconds. If None, all data
                after start_ns is included (including the final scrape after profiling completes).
error_summary: Optional list of error counts

Returns:
TelemetryExportData object with pre-computed metrics for each GPU
"""
# Create time filter for warmup exclusion
# Note: end_ns is typically None to include the final telemetry scrape
# that occurs after PROFILE_COMPLETE but before export
time_filter = TimeRangeFilter(start_ns=start_ns, end_ns=end_ns)

# Build summary
# When start_ns/end_ns is None, use current time as the timestamp
start_time = (
datetime.fromtimestamp(start_ns / NANOS_PER_SECOND)
if start_ns is not None
else datetime.now()
)
end_time = (
datetime.fromtimestamp(end_ns / NANOS_PER_SECOND)
if end_ns is not None
else datetime.now()
)
summary = TelemetrySummary(
endpoints_configured=list(self._hierarchy.dcgm_endpoints.keys()),
endpoints_successful=list(self._hierarchy.dcgm_endpoints.keys()),
start_time=datetime.fromtimestamp(start_ns / NANOS_PER_SECOND),
end_time=datetime.fromtimestamp(end_ns / NANOS_PER_SECOND),
start_time=start_time,
end_time=end_time,
)

# Build endpoints dict with pre-computed metrics
Expand All @@ -238,16 +261,27 @@ def export_results(
metrics_dict = {}

for (
_metric_display,
metric_display,
metric_key,
unit,
unit_enum,
) in get_gpu_telemetry_metrics_config():
try:
is_counter = metric_key in GPU_TELEMETRY_COUNTER_METRICS
metric_result = gpu_data.get_metric_result(
metric_key, metric_key, metric_key, unit
metric_key,
metric_key,
metric_display,
unit_enum,
time_filter=time_filter,
is_counter=is_counter,
)
metrics_dict[metric_key] = metric_result.to_json_result()
except Exception:
except NoMetricValue:
continue
except Exception as e:
self.warning(
f"Failed to compute metric '{metric_key}' for GPU {gpu_uuid[:12]}: {e}"
)
Comment on lines +264 to +284

⚠️ Potential issue | 🟡 Minor

Use self.exception() with explicit # noqa: BLE001 to justify blanket exception handling.

Ruff flags BLE001 here; while the blanket catch serves resilience (preventing one metric failure from stopping others), using self.warning() drops the traceback, hindering debugging. Change to self.exception() to log the full stack trace and add # noqa: BLE001 to explicitly document this as intentional exception masking.

Suggested change
                        except Exception:  # noqa: BLE001
                            self.exception(
                                f"Failed to compute metric '{metric_key}' for GPU {gpu_uuid[:12]}"
                            )
🧰 Tools
🪛 Ruff (0.14.13)

266-266: Do not catch blind exception: Exception

(BLE001)

🤖 Prompt for AI Agents
In `@src/aiperf/gpu_telemetry/accumulator.py` around lines 249-269: the blanket
except Exception in the metric loop is intentional but loses the traceback. Replace
the self.warning(...) call with self.exception(...) so the full stack trace is
logged when gpu_data.get_metric_result or other code in the loop over
get_gpu_telemetry_metrics_config() fails, and add an inline "# noqa: BLE001"
comment on that except Exception line to document the intentional broad exception
masking. Keep the existing except NoMetricValue: continue behavior.

continue

gpu_summary = GpuSummary(
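For the counter branch above, a short hedged sketch of the delta semantics with hypothetical values (the real computation lives in TelemetryTimeSeries.to_metric_result_filtered):

import numpy as np

def counter_delta(values, first_idx, reference_idx):
    # Baseline is the last sample before the profiling window when one
    # exists, otherwise the first in-window sample.
    baseline = values[reference_idx] if reference_idx is not None else values[first_idx]
    # Clamp negative deltas (e.g. a counter reset after a DCGM restart) to 0.
    return max(float(values[-1] - baseline), 0.0)

energy = np.array([10.0, 12.0, 15.0, 21.0])  # cumulative; index 0 is warmup
print(counter_delta(energy, first_idx=1, reference_idx=0))  # 11.0 (21 - 10)

reset = np.array([50.0, 60.0, 2.0, 5.0])  # counter reset mid-run
print(counter_delta(reset, first_idx=0, reference_idx=None))  # 0.0 (clamped)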
11 changes: 10 additions & 1 deletion src/aiperf/gpu_telemetry/constants.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Constants specific to GPU telemetry collection."""
@@ -45,6 +45,15 @@
("Power Violation", "power_violation", MetricTimeUnit.MICROSECONDS),
]

# Metrics that are cumulative counters (need delta calculation).
# These metrics accumulate over time (e.g., total energy consumed since boot),
# so we compute the delta between baseline and final values rather than statistics.
GPU_TELEMETRY_COUNTER_METRICS: set[str] = {
"energy_consumption",
"xid_errors",
"power_violation",
}


def get_gpu_telemetry_metrics_config() -> list[tuple[str, str, MetricUnitT]]:
"""Get the current GPU telemetry metrics configuration."""
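A small caller-side sketch of how the new set is consumed, mirroring the loop in export_results above (the printed classification is illustrative):

from aiperf.gpu_telemetry.constants import (
    GPU_TELEMETRY_COUNTER_METRICS,
    get_gpu_telemetry_metrics_config,
)

for display, key, unit in get_gpu_telemetry_metrics_config():
    is_counter = key in GPU_TELEMETRY_COUNTER_METRICS
    kind = "delta (counter)" if is_counter else "distribution (gauge)"
    print(f"{key:20s} -> {kind}")
# energy_consumption, xid_errors and power_violation take the delta path;
# every other metric is summarized as a gauge distribution.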