diff --git a/docs/cli_options.md b/docs/cli_options.md index 8a947b34e..e8d93a48d 100644 --- a/docs/cli_options.md +++ b/docs/cli_options.md @@ -735,7 +735,7 @@ Duration in seconds to ramp warmup request rate from a proportional minimum to t #### `--gpu-telemetry` `` -Enable GPU telemetry console display and optionally specify: (1) 'dashboard' for realtime dashboard mode, (2) custom DCGM exporter URLs (e.g., http://node1:9401/metrics), (3) custom metrics CSV file (e.g., custom_gpu_metrics.csv). Default endpoints localhost:9400 and localhost:9401 are always attempted. Example: --gpu-telemetry dashboard node1:9400 custom.csv. +Enable GPU telemetry console display and optionally specify: (1) 'pynvml' to use local pynvml library instead of DCGM HTTP endpoints, (2) 'dashboard' for realtime dashboard mode, (3) custom DCGM exporter URLs (e.g., http://node1:9401/metrics), (4) custom metrics CSV file (e.g., custom_gpu_metrics.csv). Default: DCGM mode with localhost:9400 and localhost:9401 endpoints. Examples: --gpu-telemetry pynvml | --gpu-telemetry dashboard node1:9400. #### `--no-gpu-telemetry` diff --git a/docs/tutorials/gpu-telemetry.md b/docs/tutorials/gpu-telemetry.md index 7ae6348e2..7a76bf856 100644 --- a/docs/tutorials/gpu-telemetry.md +++ b/docs/tutorials/gpu-telemetry.md @@ -1,5 +1,5 @@ @@ -9,7 +9,7 @@ This guide shows you how to collect GPU metrics (power, utilization, memory, tem ## Overview -This guide covers two setup paths depending on your inference backend: +This guide covers three setup paths depending on your inference backend and requirements: ### Path 1: Dynamo (Built-in DCGM) If you're using **Dynamo**, it comes with DCGM pre-configured on port 9401. No additional setup needed! Just use the `--gpu-telemetry` flag to enable console display and optionally add additional DCGM url endpoints. URLs can be specified with or without the `http://` prefix (e.g., `localhost:9400` or `http://localhost:9400`). @@ -17,6 +17,9 @@ If you're using **Dynamo**, it comes with DCGM pre-configured on port 9401. No a ### Path 2: Other Inference Servers (Custom DCGM) If you're using **any other inference backend**, you'll need to set up DCGM separately. +### Path 3: Local GPU Monitoring (pynvml) +If you want **simple local GPU monitoring without DCGM**, use `--gpu-telemetry pynvml`. This uses NVIDIA's nvidia-ml-py Python library (commonly known as pynvml) to collect metrics directly from the GPU driver. No HTTP endpoints or additional containers required. + ## Prerequisites - NVIDIA GPU with CUDA support @@ -36,14 +39,23 @@ AIPerf provides GPU telemetry collection with the `--gpu-telemetry` flag. Here's | **Custom URLs** | `aiperf profile --model MODEL ... --gpu-telemetry node1:9400 http://node2:9400/metrics` | `http://localhost:9400/metrics` + `http://localhost:9401/metrics` + [custom URLs](#multi-node-gpu-telemetry-example) | ✅ Yes | ❌ No | ✅ Yes | | **Dashboard + URLs** | `aiperf profile --model MODEL ... --gpu-telemetry dashboard localhost:9400` | `http://localhost:9400/metrics` + `http://localhost:9401/metrics` + [custom URLs](#multi-node-gpu-telemetry-example) | ✅ Yes | ✅ Yes ([see dashboard](#real-time-dashboard-view)) | ✅ Yes | | **Custom metrics** | `aiperf profile --model MODEL ... --gpu-telemetry custom_gpu_metrics.csv` | `http://localhost:9400/metrics` + `http://localhost:9401/metrics` + [custom metrics from CSV](#customizing-displayed-metrics) | ✅ Yes | ❌ No | ✅ Yes | +| **pynvml mode** | `aiperf profile --model MODEL ... 
--gpu-telemetry pynvml` | Local GPUs via pynvml library ([see pynvml section](#3-using-pynvml-local-gpu-monitoring)) | ✅ Yes | ❌ No | ✅ Yes | +| **pynvml + dashboard** | `aiperf profile --model MODEL ... --gpu-telemetry pynvml dashboard` | Local GPUs via pynvml library | ✅ Yes | ✅ Yes ([see dashboard](#real-time-dashboard-view)) | ✅ Yes | | **Disabled** | `aiperf profile --model MODEL ... --no-gpu-telemetry` | None | ❌ No | ❌ No | ❌ No | > [!IMPORTANT] -> The default endpoints `http://localhost:9400/metrics` and `http://localhost:9401/metrics` are ALWAYS attempted for telemetry collection, regardless of whether the `--gpu-telemetry` flag is used. The flag primarily controls whether metrics are displayed on the console and allows you to specify additional custom DCGM exporter endpoints. To completely disable GPU telemetry collection, use `--no-gpu-telemetry`. +> **DCGM mode (default):** The default endpoints `http://localhost:9400/metrics` and `http://localhost:9401/metrics` are always attempted for telemetry collection, regardless of whether the `--gpu-telemetry` flag is used. The flag primarily controls whether metrics are displayed on the console and allows you to specify additional custom DCGM exporter endpoints. +> +> **pynvml mode:** When using `--gpu-telemetry pynvml`, DCGM endpoints are NOT used. Metrics are collected directly from local GPUs via the nvidia-ml-py library. +> +> To completely disable GPU telemetry collection, use `--no-gpu-telemetry`. > [!NOTE] > When specifying custom DCGM exporter URLs, the `http://` prefix is optional. URLs like `localhost:9400` will automatically be treated as `http://localhost:9400`. Both formats work identically. +> [!TIP] +> For simple local GPU monitoring without DCGM setup, use `--gpu-telemetry pynvml`. This collects metrics directly from the NVIDIA driver using the nvidia-ml-py library. See [Path 3: pynvml](#3-using-pynvml-local-gpu-monitoring) for details. + ### Real-Time Dashboard View Adding `dashboard` to the `--gpu-telemetry` flag enables a live terminal UI (TUI) that displays GPU metrics in real-time during your benchmark runs: @@ -300,8 +312,85 @@ aiperf profile \ > [!TIP] > The `dashboard` keyword enables a live terminal UI for real-time GPU telemetry visualization. Press `5` to maximize the GPU Telemetry panel during the benchmark run. +--- + +# 3: Using pynvml (Local GPU Monitoring) + +For simple local GPU monitoring without DCGM infrastructure, AIPerf supports direct GPU metrics collection using NVIDIA's nvidia-ml-py Python library (commonly known as pynvml). This approach requires no additional containers, HTTP endpoints, or DCGM setup. 
+ +## Prerequisites + +- NVIDIA GPU with driver installed +- nvidia-ml-py package: `pip install nvidia-ml-py` + +## When to Use pynvml + +| Scenario | Recommended Approach | +|----------|---------------------| +| Local development/testing | pynvml | +| Single-node inference server | pynvml or DCGM | +| Multi-node distributed setup | DCGM (HTTP endpoints required) | +| Production with existing DCGM | DCGM | +| Quick GPU monitoring without setup | pynvml | + +## Run AIPerf with pynvml + +```bash +aiperf profile \ + --model Qwen/Qwen3-0.6B \ + --endpoint-type chat \ + --endpoint /v1/chat/completions \ + --streaming \ + --url localhost:8000 \ + --synthetic-input-tokens-mean 100 \ + --synthetic-input-tokens-stddev 0 \ + --output-tokens-mean 200 \ + --output-tokens-stddev 0 \ + --extra-inputs min_tokens:200 \ + --extra-inputs ignore_eos:true \ + --concurrency 4 \ + --request-count 64 \ + --warmup-request-count 1 \ + --num-dataset-entries 8 \ + --random-seed 100 \ + --gpu-telemetry pynvml +``` + > [!TIP] -> The `dashboard` keyword enables a live terminal UI for real-time GPU telemetry visualization. Press `5` to maximize the GPU Telemetry panel during the benchmark run. +> Add `dashboard` after `pynvml` for the real-time terminal UI: `--gpu-telemetry pynvml dashboard` + +## Metrics Collected via pynvml + +The nvidia-ml-py library (pynvml) collects the following metrics directly from the NVIDIA driver: + +| Metric | Description | Unit | +|--------|-------------|------| +| GPU Power Usage | Current power draw | W | +| Energy Consumption | Total energy since boot | MJ | +| GPU Utilization | GPU compute utilization | % | +| Memory Utilization | Memory controller utilization | % | +| GPU Memory Used | Framebuffer memory in use | GB | +| GPU Temperature | GPU die temperature | °C | +| SM Utilization | Streaming multiprocessor utilization | % | +| Decoder Utilization | Video decoder utilization | % | +| Encoder Utilization | Video encoder utilization | % | +| JPEG Utilization | JPEG decoder utilization | % | +| Power Violation | Throttling duration due to power limits | µs | + +> [!NOTE] +> Not all metrics are available on all GPU models. AIPerf gracefully handles missing metrics and reports only what the hardware supports. 
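+
+AIPerf performs the equivalent NVML calls internally when you pass `--gpu-telemetry pynvml`. The snippet below is only an illustrative sketch of how the nvidia-ml-py library exposes a few of these metrics directly from the driver — the function names are real NVML bindings, but this is not AIPerf's collector code.
+
+```python
+import pynvml  # provided by the nvidia-ml-py package
+
+pynvml.nvmlInit()
+try:
+    for index in range(pynvml.nvmlDeviceGetCount()):
+        handle = pynvml.nvmlDeviceGetHandleByIndex(index)
+        name = pynvml.nvmlDeviceGetName(handle)
+        if isinstance(name, bytes):  # older nvidia-ml-py versions return bytes
+            name = name.decode("utf-8")
+        power_w = pynvml.nvmlDeviceGetPowerUsage(handle) / 1000.0   # mW -> W
+        util = pynvml.nvmlDeviceGetUtilizationRates(handle)         # .gpu / .memory in %
+        mem_gb = pynvml.nvmlDeviceGetMemoryInfo(handle).used / 1e9  # bytes -> GB
+        temp_c = pynvml.nvmlDeviceGetTemperature(handle, pynvml.NVML_TEMPERATURE_GPU)
+        print(f"GPU {index} ({name}): {power_w:.0f} W, {util.gpu}% util, "
+              f"{mem_gb:.1f} GB used, {temp_c} °C")
+finally:
+    pynvml.nvmlShutdown()
+```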
+ +## Comparing DCGM vs pynvml + +| Feature | DCGM | pynvml | +|---------|------|--------| +| Setup complexity | Requires container/service | Just install nvidia-ml-py Python package | +| Multi-node support | Yes (via HTTP endpoints) | No (local only) | +| Metrics granularity | High (profiling-level metrics) | Standard (driver-level metrics) | +| Kubernetes integration | Native with dcgm-exporter | Not applicable | +| XID error reporting | Yes | No | + +--- ## Multi-Node GPU Telemetry Example diff --git a/pyproject.toml b/pyproject.toml index 23e258391..38320f4e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "matplotlib>=3.10.0", "msgspec>=0.19.0,<1.0.0", "numpy~=1.26.4", + "nvidia-ml-py", # Note: No version specified to be most compatible with CUDA version "orjson~=3.10.18", "pandas~=2.3.3", "pillow~=11.1.0", diff --git a/src/aiperf/common/config/user_config.py b/src/aiperf/common/config/user_config.py index 2076a8261..297075d50 100644 --- a/src/aiperf/common/config/user_config.py +++ b/src/aiperf/common/config/user_config.py @@ -23,7 +23,12 @@ from aiperf.common.config.loadgen_config import LoadGeneratorConfig from aiperf.common.config.output_config import OutputConfig from aiperf.common.config.tokenizer_config import TokenizerConfig -from aiperf.common.enums import CustomDatasetType, GPUTelemetryMode, ServerMetricsFormat +from aiperf.common.enums import ( + CustomDatasetType, + GPUTelemetryCollectorType, + GPUTelemetryMode, + ServerMetricsFormat, +) from aiperf.common.enums.plugin_enums import EndpointType from aiperf.common.enums.timing_enums import ArrivalPattern, TimingMode from aiperf.common.utils import load_json_str @@ -414,11 +419,12 @@ def _count_dataset_entries(self) -> int: Field( description=( "Enable GPU telemetry console display and optionally specify: " - "(1) 'dashboard' for realtime dashboard mode, " - "(2) custom DCGM exporter URLs (e.g., http://node1:9401/metrics), " - "(3) custom metrics CSV file (e.g., custom_gpu_metrics.csv). " - "Default endpoints localhost:9400 and localhost:9401 are always attempted. " - "Example: --gpu-telemetry dashboard node1:9400 custom.csv" + "(1) 'pynvml' to use local pynvml library instead of DCGM HTTP endpoints, " + "(2) 'dashboard' for realtime dashboard mode, " + "(3) custom DCGM exporter URLs (e.g., http://node1:9401/metrics), " + "(4) custom metrics CSV file (e.g., custom_gpu_metrics.csv). " + "Default: DCGM mode with localhost:9400 and localhost:9401 endpoints. 
" + "Examples: --gpu-telemetry pynvml | --gpu-telemetry dashboard node1:9400" ), ), BeforeValidator(parse_str_or_list), @@ -441,12 +447,15 @@ def _count_dataset_entries(self) -> int: ] = False _gpu_telemetry_mode: GPUTelemetryMode = GPUTelemetryMode.SUMMARY + _gpu_telemetry_collector_type: GPUTelemetryCollectorType = ( + GPUTelemetryCollectorType.DCGM + ) _gpu_telemetry_urls: list[str] = [] _gpu_telemetry_metrics_file: Path | None = None @model_validator(mode="after") def _parse_gpu_telemetry_config(self) -> Self: - """Parse gpu_telemetry list into mode, URLs, and metrics file.""" + """Parse gpu_telemetry list into mode, collector type, URLs, and metrics file.""" if ( "no_gpu_telemetry" in self.model_fields_set and "gpu_telemetry" in self.model_fields_set @@ -460,6 +469,7 @@ def _parse_gpu_telemetry_config(self) -> Self: return self mode = GPUTelemetryMode.SUMMARY + collector_type = GPUTelemetryCollectorType.DCGM urls = [] metrics_file = None @@ -469,17 +479,35 @@ def _parse_gpu_telemetry_config(self) -> Self: metrics_file = Path(item) if not metrics_file.exists(): raise ValueError(f"GPU metrics file not found: {item}") - continue - + # Check for pynvml collector type + elif item.lower() == "pynvml": + collector_type = GPUTelemetryCollectorType.PYNVML + try: + import pynvml # noqa: F401 + except ImportError as e: + raise ValueError( + "pynvml package not installed. Install with: pip install nvidia-ml-py" + ) from e # Check for dashboard mode - if item in ["dashboard"]: + elif item in ["dashboard"]: mode = GPUTelemetryMode.REALTIME_DASHBOARD - # Check for URLs + # Check for URLs (only applicable for DCGM collector) elif item.startswith("http") or ":" in item: normalized_url = item if item.startswith("http") else f"http://{item}" urls.append(normalized_url) + else: + raise ValueError( + f"Invalid GPU telemetry item: {item}. Valid options are: 'pynvml', 'dashboard', '.csv' file, and URLs." + ) + + if collector_type == GPUTelemetryCollectorType.PYNVML and urls: + raise ValueError( + "Cannot use pynvml with DCGM URLs. Use either 'pynvml' for local " + "GPU monitoring or URLs for DCGM endpoints, not both." 
+ ) self._gpu_telemetry_mode = mode + self._gpu_telemetry_collector_type = collector_type self._gpu_telemetry_urls = urls self._gpu_telemetry_metrics_file = metrics_file return self @@ -494,6 +522,11 @@ def gpu_telemetry_mode(self, value: GPUTelemetryMode) -> None: """Set the GPU telemetry display mode.""" self._gpu_telemetry_mode = value + @property + def gpu_telemetry_collector_type(self) -> GPUTelemetryCollectorType: + """Get the GPU telemetry collector type (DCGM or PYNVML).""" + return self._gpu_telemetry_collector_type + @property def gpu_telemetry_urls(self) -> list[str]: """Get the parsed GPU telemetry DCGM endpoint URLs.""" diff --git a/src/aiperf/common/enums/__init__.py b/src/aiperf/common/enums/__init__.py index e77e47d8d..6b7a1b0d3 100644 --- a/src/aiperf/common/enums/__init__.py +++ b/src/aiperf/common/enums/__init__.py @@ -109,6 +109,7 @@ SystemState, ) from aiperf.common.enums.telemetry_enums import ( + GPUTelemetryCollectorType, GPUTelemetryMode, ) from aiperf.common.enums.timing_enums import ( @@ -150,6 +151,7 @@ "ExportLevel", "FrequencyMetricUnit", "FrequencyMetricUnitInfo", + "GPUTelemetryCollectorType", "GPUTelemetryMode", "GenericMetricUnit", "ImageFormat", diff --git a/src/aiperf/common/enums/telemetry_enums.py b/src/aiperf/common/enums/telemetry_enums.py index 1efd870de..75ba9e5df 100644 --- a/src/aiperf/common/enums/telemetry_enums.py +++ b/src/aiperf/common/enums/telemetry_enums.py @@ -1,9 +1,19 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 from aiperf.common.enums.base_enums import CaseInsensitiveStrEnum +class GPUTelemetryCollectorType(CaseInsensitiveStrEnum): + """GPU telemetry collector implementation type.""" + + DCGM = "dcgm" + """Collects GPU telemetry metrics from DCGM Prometheus exporter.""" + + PYNVML = "pynvml" + """Collects GPU telemetry metrics using the pynvml Python library.""" + + class GPUTelemetryMode(CaseInsensitiveStrEnum): """GPU telemetry display mode.""" diff --git a/src/aiperf/common/mixins/base_metrics_collector_mixin.py b/src/aiperf/common/mixins/base_metrics_collector_mixin.py index 0844e3b24..13d914c33 100644 --- a/src/aiperf/common/mixins/base_metrics_collector_mixin.py +++ b/src/aiperf/common/mixins/base_metrics_collector_mixin.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """Base mixin for async HTTP metrics data collectors. @@ -160,7 +160,7 @@ class BaseMetricsCollectorMixin(AIPerfLifecycleMixin, ABC, Generic[TRecord]): - Precise HTTP timing capture for correlation analysis Used by: - - GPUTelemetryDataCollector (DCGM metrics from GPU monitoring) + - DCGMTelemetryCollector (DCGM metrics from GPU monitoring) - ServerMetricsDataCollector (Prometheus metrics from inference servers) Example: diff --git a/src/aiperf/common/models/telemetry_models.py b/src/aiperf/common/models/telemetry_models.py index cdeac36be..3c3064296 100644 --- a/src/aiperf/common/models/telemetry_models.py +++ b/src/aiperf/common/models/telemetry_models.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -28,7 +28,9 @@ class TelemetryMetrics(AIPerfBaseModel): default=None, description="Cumulative energy consumption in MJ" ) gpu_utilization: float | None = Field( - default=None, description="GPU utilization percentage (0-100)" + default=None, + description="GPU utilization percentage (0-100). " + "Percent of time over the past sample period during which one or more kernels was executing on the GPU.", ) gpu_memory_used: float | None = Field( default=None, description="GPU memory used in GB" @@ -36,6 +38,24 @@ class TelemetryMetrics(AIPerfBaseModel): gpu_temperature: float | None = Field( default=None, description="GPU temperature in °C" ) + mem_utilization: float | None = Field( + default=None, + description="Memory bandwidth utilization percentage (0-100). " + "Percent of time over the past sample period during which global (device) memory was being read or written.", + ) + sm_utilization: float | None = Field( + default=None, + description="Streaming multiprocessor utilization percentage (0-100)", + ) + decoder_utilization: float | None = Field( + default=None, description="Video decoder (NVDEC) utilization percentage (0-100)" + ) + encoder_utilization: float | None = Field( + default=None, description="Video encoder (NVENC) utilization percentage (0-100)" + ) + jpg_utilization: float | None = Field( + default=None, description="JPEG decoder utilization percentage (0-100)" + ) xid_errors: float | None = Field( default=None, description="Value of the last XID error encountered" ) @@ -92,7 +112,7 @@ class TelemetryRecord(GpuMetadata): description="Nanosecond wall-clock timestamp when telemetry was collected (time_ns)" ) dcgm_url: str = Field( - description="Source DCGM endpoint URL (e.g., 'http://node1:9401/metrics')" + description="Source identifier (DCGM URL e.g., 'http://node1:9401/metrics' or 'pynvml://localhost')" ) telemetry_data: TelemetryMetrics = Field( description="GPU metrics snapshot collected at this timestamp" diff --git a/src/aiperf/controller/system_controller.py b/src/aiperf/controller/system_controller.py index b85ab7087..db72616e9 100644 --- a/src/aiperf/controller/system_controller.py +++ b/src/aiperf/controller/system_controller.py @@ -191,7 +191,6 @@ async def _start_services(self) -> None: # Start optional services before waiting for registration so they can participate in configuration if not self.user_config.gpu_telemetry_disabled: - self.debug("Starting optional TelemetryManager service") await self.service_manager.run_service(ServiceType.GPU_TELEMETRY_MANAGER) else: self.info("GPU telemetry disabled via --no-gpu-telemetry") diff --git a/src/aiperf/gpu_telemetry/__init__.py b/src/aiperf/gpu_telemetry/__init__.py index 80ebbfd9b..ff73916f8 100644 --- a/src/aiperf/gpu_telemetry/__init__.py +++ b/src/aiperf/gpu_telemetry/__init__.py @@ -19,11 +19,17 @@ from aiperf.gpu_telemetry.constants import ( DCGM_TO_FIELD_MAPPING, GPU_TELEMETRY_METRICS_CONFIG, - SCALING_FACTORS, + PYNVML_SOURCE_IDENTIFIER, get_gpu_telemetry_metrics_config, ) -from aiperf.gpu_telemetry.data_collector import ( - GPUTelemetryDataCollector, +from aiperf.gpu_telemetry.dcgm_collector import ( + DCGMTelemetryCollector, +) +from aiperf.gpu_telemetry.factories import ( + GPUTelemetryCollectorFactory, + GPUTelemetryCollectorProtocol, + TErrorCallback, + TRecordCallback, ) from aiperf.gpu_telemetry.jsonl_writer import ( GPUTelemetryJSONLWriter, @@ -34,15 +40,23 @@ from aiperf.gpu_telemetry.metrics_config import ( 
MetricsConfigLoader, ) +from aiperf.gpu_telemetry.pynvml_collector import ( + PyNVMLTelemetryCollector, +) __all__ = [ + "DCGMTelemetryCollector", "DCGM_TO_FIELD_MAPPING", "GPUTelemetryAccumulator", - "GPUTelemetryDataCollector", + "GPUTelemetryCollectorFactory", + "GPUTelemetryCollectorProtocol", "GPUTelemetryJSONLWriter", "GPUTelemetryManager", "GPU_TELEMETRY_METRICS_CONFIG", "MetricsConfigLoader", - "SCALING_FACTORS", + "PYNVML_SOURCE_IDENTIFIER", + "PyNVMLTelemetryCollector", + "TErrorCallback", + "TRecordCallback", "get_gpu_telemetry_metrics_config", ] diff --git a/src/aiperf/gpu_telemetry/constants.py b/src/aiperf/gpu_telemetry/constants.py index 49a486feb..f7438ac8b 100644 --- a/src/aiperf/gpu_telemetry/constants.py +++ b/src/aiperf/gpu_telemetry/constants.py @@ -1,7 +1,7 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Constants specific to GPU telemetry collection.""" +"""Constants for GPU telemetry collection (DCGM and pynvml).""" from aiperf.common.enums.metric_enums import ( EnergyMetricUnit, @@ -13,19 +13,20 @@ TemperatureMetricUnit, ) -# Unit conversion scaling factors -SCALING_FACTORS = { - "energy_consumption": 1e-9, # mJ to MJ - "gpu_memory_used": 1.048576 * 1e-3, # MiB to GB -} +# Source identifier for pynvml collector (used in TelemetryRecord.dcgm_url field) +PYNVML_SOURCE_IDENTIFIER = "pynvml://localhost" # DCGM field mapping to telemetry record fields DCGM_TO_FIELD_MAPPING = { "DCGM_FI_DEV_POWER_USAGE": "gpu_power_usage", "DCGM_FI_DEV_TOTAL_ENERGY_CONSUMPTION": "energy_consumption", "DCGM_FI_DEV_GPU_UTIL": "gpu_utilization", + "DCGM_FI_DEV_MEM_COPY_UTIL": "mem_utilization", "DCGM_FI_DEV_FB_USED": "gpu_memory_used", "DCGM_FI_DEV_GPU_TEMP": "gpu_temperature", + "DCGM_FI_DEV_ENC_UTIL": "encoder_utilization", + "DCGM_FI_DEV_DEC_UTIL": "decoder_utilization", + "DCGM_FI_PROF_SM_ACTIVE": "sm_utilization", "DCGM_FI_DEV_XID_ERRORS": "xid_errors", "DCGM_FI_DEV_POWER_VIOLATION": "power_violation", } @@ -41,6 +42,11 @@ ("GPU Utilization", "gpu_utilization", GenericMetricUnit.PERCENT), ("GPU Memory Used", "gpu_memory_used", MetricSizeUnit.GIGABYTES), ("GPU Temperature", "gpu_temperature", TemperatureMetricUnit.CELSIUS), + ("Memory Utilization", "mem_utilization", GenericMetricUnit.PERCENT), + ("SM Utilization", "sm_utilization", GenericMetricUnit.PERCENT), + ("Decoder Utilization", "decoder_utilization", GenericMetricUnit.PERCENT), + ("Encoder Utilization", "encoder_utilization", GenericMetricUnit.PERCENT), + ("JPEG Utilization", "jpg_utilization", GenericMetricUnit.PERCENT), ("XID Errors", "xid_errors", GenericMetricUnit.COUNT), ("Power Violation", "power_violation", MetricTimeUnit.MICROSECONDS), ] diff --git a/src/aiperf/gpu_telemetry/data_collector.py b/src/aiperf/gpu_telemetry/dcgm_collector.py similarity index 85% rename from src/aiperf/gpu_telemetry/data_collector.py rename to src/aiperf/gpu_telemetry/dcgm_collector.py index 591f64e30..0492399e2 100644 --- a/src/aiperf/gpu_telemetry/data_collector.py +++ b/src/aiperf/gpu_telemetry/dcgm_collector.py @@ -1,10 +1,12 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 import time from prometheus_client.parser import text_string_to_metric_families +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import GPUTelemetryCollectorType from aiperf.common.environment import Environment from aiperf.common.mixins import ( BaseMetricsCollectorMixin, @@ -12,21 +14,35 @@ TRecordCallback, ) from aiperf.common.models import GpuMetadata, TelemetryMetrics, TelemetryRecord -from aiperf.gpu_telemetry.constants import ( - DCGM_TO_FIELD_MAPPING, - SCALING_FACTORS, +from aiperf.gpu_telemetry.constants import DCGM_TO_FIELD_MAPPING +from aiperf.gpu_telemetry.factories import ( + GPUTelemetryCollectorFactory, + GPUTelemetryCollectorProtocol, ) -__all__ = ["GPUTelemetryDataCollector"] +__all__ = ["DCGMTelemetryCollector"] +# Unit conversion scaling factors for DCGM metrics +SCALING_FACTORS = { + "energy_consumption": 1e-9, # mJ -> MJ + "gpu_memory_used": 1.048576e-3, # MiB -> GB + "sm_utilization": 100, # ratio (0-1) -> percentage (0-100) + "power_violation": 1e-3, # ns -> µs +} -class GPUTelemetryDataCollector(BaseMetricsCollectorMixin[TelemetryRecord]): - """Collects GPU telemetry metrics from DCGM exporter endpoints. + +@implements_protocol(GPUTelemetryCollectorProtocol) +@GPUTelemetryCollectorFactory.register(GPUTelemetryCollectorType.DCGM) +class DCGMTelemetryCollector(BaseMetricsCollectorMixin[TelemetryRecord]): + """Collects GPU telemetry metrics from DCGM exporter HTTP endpoints. Async collector that fetches GPU metrics from DCGM Prometheus exporter and converts them to TelemetryRecord objects. Extends BaseMetricsCollectorMixin for HTTP collection patterns and uses prometheus_client for robust metric parsing. + This is the default collector type for GPU telemetry when DCGM is available. + For local GPU monitoring without DCGM, see PyNVMLTelemetryCollector. + Features: - Async HTTP collection with aiohttp - DCGM Prometheus format parsing @@ -173,9 +189,9 @@ def _apply_scaling_factors(self, metrics: dict) -> dict: """Apply scaling factors to convert raw DCGM units to display units. Converts metrics from DCGM's native units to human-readable units: - - Power: milliwatts -> watts (multiply by 0.001) - - Memory: bytes -> megabytes (multiply by 1e-6) - - Frequency: MHz values (no scaling needed) + - Energy: millijoules -> megajoules + - Memory: MiB -> GB + - SM utilization: ratio (0-1) -> percentage (0-100) Only applies scaling to metrics present in the input dict. None values are preserved. diff --git a/src/aiperf/gpu_telemetry/factories.py b/src/aiperf/gpu_telemetry/factories.py new file mode 100644 index 000000000..a8d70062d --- /dev/null +++ b/src/aiperf/gpu_telemetry/factories.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Factory for GPU telemetry collectors.""" + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +from aiperf.common.enums import GPUTelemetryCollectorType +from aiperf.common.factories import AIPerfFactory + +if TYPE_CHECKING: + from aiperf.common.models import ErrorDetails, TelemetryRecord + + +@runtime_checkable +class GPUTelemetryCollectorProtocol(Protocol): + """Protocol for GPU telemetry collectors. + + Defines the interface for collectors that gather GPU metrics from various sources + (DCGM HTTP endpoints, pynvml library, etc.) and deliver them via callbacks. 
+ """ + + @property + def id(self) -> str: + """Get the collector's unique identifier.""" + ... + + @property + def endpoint_url(self) -> str: + """Get the source identifier (URL for DCGM, 'pynvml://localhost' for pynvml).""" + ... + + async def initialize(self) -> None: + """Initialize the collector resources.""" + ... + + async def start(self) -> None: + """Start the background collection task.""" + ... + + async def stop(self) -> None: + """Stop the collector and clean up resources.""" + ... + + async def is_url_reachable(self) -> bool: + """Check if the collector source is available. + + For DCGM: Tests HTTP endpoint reachability. + For pynvml: Tests NVML library initialization. + + Returns: + True if the source is available and ready for collection. + """ + ... + + +# Type aliases for callbacks +TRecordCallback = Callable[[list["TelemetryRecord"], str], Awaitable[None]] +TErrorCallback = Callable[["ErrorDetails", str], Awaitable[None]] + + +class GPUTelemetryCollectorFactory( + AIPerfFactory[GPUTelemetryCollectorType, GPUTelemetryCollectorProtocol] +): + """Factory for creating GPU telemetry collector instances. + + Supports multiple collector implementations: + - DCGM: HTTP-based collection from DCGM Prometheus exporter + - PYNVML: Direct collection using pynvml Python library + + Example: + collector = GPUTelemetryCollectorFactory.create_instance( + GPUTelemetryCollectorType.DCGM, + dcgm_url="http://localhost:9400/metrics", + collection_interval=0.333, + record_callback=my_callback, + ) + """ + + pass diff --git a/src/aiperf/gpu_telemetry/manager.py b/src/aiperf/gpu_telemetry/manager.py index c19b61e9a..2b5772093 100644 --- a/src/aiperf/gpu_telemetry/manager.py +++ b/src/aiperf/gpu_telemetry/manager.py @@ -9,6 +9,7 @@ from aiperf.common.enums import ( CommAddress, CommandType, + GPUTelemetryCollectorType, ServiceType, ) from aiperf.common.environment import Environment @@ -25,7 +26,9 @@ PushClientProtocol, ServiceProtocol, ) -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.constants import PYNVML_SOURCE_IDENTIFIER +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector +from aiperf.gpu_telemetry.factories import GPUTelemetryCollectorProtocol __all__ = ["GPUTelemetryManager"] @@ -68,7 +71,7 @@ def __init__( CommAddress.RECORDS, ) - self._collectors: dict[str, GPUTelemetryDataCollector] = {} + self._collectors: dict[str, GPUTelemetryCollectorProtocol] = {} self._collector_id_to_url: dict[str, str] = {} self._telemetry_disabled = user_config.gpu_telemetry_disabled @@ -76,6 +79,10 @@ def __init__( user_config.gpu_telemetry is not None and not self._telemetry_disabled ) + # Store the collector type (DCGM or PYNVML) + self._collector_type = user_config.gpu_telemetry_collector_type + + # DCGM-specific endpoint configuration user_endpoints = user_config.gpu_telemetry_urls or [] if isinstance(user_endpoints, str): user_endpoints = [user_endpoints] @@ -157,9 +164,9 @@ async def _profile_configure_command( ) -> None: """Configure the telemetry collectors but don't start them yet. - Creates TelemetryDataCollector instances for each configured DCGM endpoint, + Creates collector instances based on configured type (DCGM or PYNVML), tests reachability, and sends status message to RecordsManager. - If no endpoints are reachable, disables telemetry and stops the service. + If no collectors can be created, disables telemetry and stops the service. 
Args: message: Profile configuration command from SystemController @@ -175,11 +182,72 @@ async def _profile_configure_command( self._collectors.clear() self._collector_id_to_url.clear() + + if self._collector_type == GPUTelemetryCollectorType.PYNVML: + await self._configure_pynvml_collector() + else: + await self._configure_dcgm_collectors() + + async def _configure_pynvml_collector(self) -> None: + """Configure a single PyNVML collector for local GPU monitoring.""" + self.debug("GPU Telemetry: Configuring pynvml collector") + + try: + # Import here to defer pynvml check until actually needed + from aiperf.gpu_telemetry.pynvml_collector import PyNVMLTelemetryCollector + + collector_id = "pynvml_collector" + collector = PyNVMLTelemetryCollector( + collection_interval=self._collection_interval, + record_callback=self._on_telemetry_records, + error_callback=self._on_telemetry_error, + collector_id=collector_id, + ) + + is_available = await collector.is_url_reachable() + if is_available: + self._collectors[PYNVML_SOURCE_IDENTIFIER] = collector + self._collector_id_to_url[collector_id] = PYNVML_SOURCE_IDENTIFIER + self.debug("GPU Telemetry: pynvml collector configured successfully") + await self._send_telemetry_status( + enabled=True, + reason=None, + endpoints_configured=[PYNVML_SOURCE_IDENTIFIER], + endpoints_reachable=[PYNVML_SOURCE_IDENTIFIER], + ) + else: + self.warning("GPU Telemetry: pynvml not available or no GPUs found") + await self._send_telemetry_status( + enabled=False, + reason="pynvml not available or no GPUs found", + endpoints_configured=[PYNVML_SOURCE_IDENTIFIER], + endpoints_reachable=[], + ) + except RuntimeError as e: + # pynvml package not installed + self.error(f"GPU Telemetry: {e}") + await self._send_telemetry_status( + enabled=False, + reason=str(e), + endpoints_configured=[], + endpoints_reachable=[], + ) + except Exception as e: # noqa: BLE001 - fault-tolerant telemetry + self.error(f"GPU Telemetry: Failed to configure pynvml collector: {e}") + await self._send_telemetry_status( + enabled=False, + reason=f"pynvml configuration failed: {e}", + endpoints_configured=[], + endpoints_reachable=[], + ) + + async def _configure_dcgm_collectors(self) -> None: + """Configure DCGM collectors for HTTP-based GPU telemetry.""" for dcgm_url in self._dcgm_endpoints: self.debug(f"GPU Telemetry: Testing reachability of {dcgm_url}") collector_id = f"collector_{dcgm_url.replace(':', '_').replace('/', '_')}" self._collector_id_to_url[collector_id] = dcgm_url - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url=dcgm_url, collection_interval=self._collection_interval, record_callback=self._on_telemetry_records, @@ -241,13 +309,13 @@ async def _on_start_profiling(self, message) -> None: return started_count = 0 - for dcgm_url, collector in self._collectors.items(): + for source_url, collector in self._collectors.items(): try: await collector.initialize() await collector.start() started_count += 1 - except Exception as e: - self.error(f"Failed to start collector for {dcgm_url}: {e}") + except Exception as e: # noqa: BLE001 - fault-tolerant telemetry + self.error(f"Failed to start collector for {source_url}: {e}") if started_count == 0: self.warning("No GPU telemetry collectors successfully started") @@ -307,11 +375,11 @@ async def _stop_all_collectors(self) -> None: if not self._collectors: return - for dcgm_url, collector in self._collectors.items(): + for source_url, collector in self._collectors.items(): try: await collector.stop() - except Exception as 
e: - self.error(f"Failed to stop collector for {dcgm_url}: {e}") + except Exception as e: # noqa: BLE001 - fault-tolerant telemetry + self.error(f"Failed to stop collector for {source_url}: {e}") async def _on_telemetry_records( self, records: list[TelemetryRecord], collector_id: str diff --git a/src/aiperf/gpu_telemetry/pynvml_collector.py b/src/aiperf/gpu_telemetry/pynvml_collector.py new file mode 100644 index 000000000..70ecea662 --- /dev/null +++ b/src/aiperf/gpu_telemetry/pynvml_collector.py @@ -0,0 +1,459 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""PyNVML-based GPU telemetry collector. + +Collects GPU metrics directly using the pynvml Python library, providing an +alternative to DCGM HTTP endpoints for local GPU monitoring. +""" + +import asyncio +import contextlib +import threading +import time +from dataclasses import dataclass + +import pynvml + +from aiperf.common.decorators import implements_protocol +from aiperf.common.enums import GPUTelemetryCollectorType +from aiperf.common.environment import Environment +from aiperf.common.hooks import background_task, on_init, on_stop +from aiperf.common.mixins import AIPerfLifecycleMixin +from aiperf.common.models import ( + ErrorDetails, + GpuMetadata, + TelemetryMetrics, + TelemetryRecord, +) +from aiperf.gpu_telemetry.constants import PYNVML_SOURCE_IDENTIFIER +from aiperf.gpu_telemetry.factories import ( + GPUTelemetryCollectorFactory, + GPUTelemetryCollectorProtocol, + TErrorCallback, + TRecordCallback, +) + +__all__ = ["PyNVMLTelemetryCollector"] + + +@dataclass(frozen=True) +class ScalingFactors: + """Unit conversion scaling factors for NVML metrics.""" + + gpu_power_usage = 1e-3 # mW -> W + energy_consumption = 1e-9 # mJ -> MJ + gpu_memory_used = 1e-9 # bytes -> GB + power_violation = 1e-3 # ns -> µs + + +@dataclass(slots=True) +class GpuDeviceState: + """Per-GPU state for NVML telemetry collection. + + Args: + handle: NVML device handle + metadata: GPU metadata + gpm_samples: GPM samples (prev, curr) if GPM supported, else None + """ + + handle: object + metadata: GpuMetadata + gpm_samples: tuple[object, object] | None = None + + +@implements_protocol(GPUTelemetryCollectorProtocol) +@GPUTelemetryCollectorFactory.register(GPUTelemetryCollectorType.PYNVML) +class PyNVMLTelemetryCollector(AIPerfLifecycleMixin): + """Collects GPU telemetry metrics using the pynvml Python library. + + Direct collector that uses NVIDIA's pynvml library to gather GPU metrics + locally without requiring a DCGM HTTP endpoint. Useful for environments + where DCGM is not deployed or for simple local GPU monitoring. + + Features: + - Direct NVML API access via pynvml + - Automatic GPU discovery and enumeration + - Same TelemetryRecord output format as DCGM collector + - Callback-based record delivery + + Requirements: + - pynvml package installed: `pip install nvidia-ml-py` + - NVIDIA driver installed with NVML support + + Args: + collection_interval: Interval in seconds between metric collections (default: from Environment) + record_callback: Optional async callback to receive collected records. + Signature: async (records: list[TelemetryRecord], collector_id: str) -> None + error_callback: Optional async callback to receive collection errors. 
+ Signature: async (error: ErrorDetails, collector_id: str) -> None + collector_id: Unique identifier for this collector instance + + Raises: + RuntimeError: If pynvml package is not installed + """ + + def __init__( + self, + collection_interval: float = Environment.GPU.COLLECTION_INTERVAL, + record_callback: TRecordCallback | None = None, + error_callback: TErrorCallback | None = None, + collector_id: str = "pynvml_collector", + ) -> None: + super().__init__(id=collector_id) + self._collection_interval = collection_interval + self._record_callback = record_callback + self._error_callback = error_callback + + # Per-GPU state (populated on init) + self._gpus: list[GpuDeviceState] = [] + + # NVML initialization state and thread safety + self._nvml_initialized = False + self._nvml_lock = threading.Lock() + + @property + def endpoint_url(self) -> str: + """Get the source identifier for this collector. + + Returns: + 'pynvml://localhost' to identify records from pynvml collection. + """ + return PYNVML_SOURCE_IDENTIFIER + + @property + def collection_interval(self) -> float: + """Get the collection interval in seconds.""" + return self._collection_interval + + async def is_url_reachable(self) -> bool: + """Check if NVML is available and can be initialized. + + Tests NVML availability by attempting initialization if not already done. + This allows pre-flight checks before starting collection. + + Returns: + True if NVML is available and can access at least one GPU. + """ + # If already initialized, just check if we have GPUs + if self._nvml_initialized: + return len(self._gpus) > 0 + + try: + return await asyncio.to_thread(self._probe_nvml_devices) + except Exception: + return False + + def _probe_nvml_devices(self) -> bool: + """Probe NVML to check if GPUs are available. + + Synchronous helper that performs blocking NVML calls to check availability. + Called via asyncio.to_thread to avoid blocking the event loop. + + Returns: + True if NVML can be initialized and at least one GPU is available. + """ + pynvml.nvmlInit() + try: + count = pynvml.nvmlDeviceGetCount() + return count > 0 + finally: + pynvml.nvmlShutdown() + + @on_init + async def _initialize_nvml(self) -> None: + """Initialize NVML and discover GPUs. + + Called automatically during initialization phase. + Initializes the NVML library and enumerates available GPUs. + + Raises: + RuntimeError: If NVML initialization or GPU discovery fails. 
+ """ + try: + pynvml.nvmlInit() + except pynvml.NVMLError as e: + raise RuntimeError(f"Failed to initialize NVML: {e}") from e + + self._nvml_initialized = True + + try: + device_count = pynvml.nvmlDeviceGetCount() + except pynvml.NVMLError as e: + # Cleanup NVML if device enumeration fails + self._shutdown_nvml_sync() + raise RuntimeError(f"Failed to get GPU device count: {e}") from e + + self._gpus = [] + + for i in range(device_count): + gpu = self._create_gpu_for_device_index(i) + if gpu: + self._gpus.append(gpu) + + gpm_count = sum(1 for gpu in self._gpus if gpu.gpm_samples) + self.info( + f"PyNVML initialized with {len(self._gpus)} GPU(s) " + f"({gpm_count} with GPM support)" + ) + + def _create_gpu_for_device_index(self, index: int) -> GpuDeviceState | None: + """Initialize a GPU for telemetry collection.""" + try: + handle = pynvml.nvmlDeviceGetHandleByIndex(index) + except pynvml.NVMLError as e: + self.warning(f"Failed to get handle for GPU {index}: {e}") + return None + + # Gather static metadata for this GPU + try: + uuid = pynvml.nvmlDeviceGetUUID(handle) + except pynvml.NVMLError: + uuid = f"GPU-unknown-{index}" + + try: + name = pynvml.nvmlDeviceGetName(handle) + # pynvml may return bytes in some versions + if isinstance(name, bytes): + name = name.decode("utf-8") + except pynvml.NVMLError: + name = "Unknown GPU" + + try: + pci_info = pynvml.nvmlDeviceGetPciInfo(handle) + pci_bus_id = pci_info.busId + if isinstance(pci_bus_id, bytes): + pci_bus_id = pci_bus_id.decode("utf-8") + except pynvml.NVMLError: + pci_bus_id = None + + # Create GPU state with metadata + # gpu_index in metadata reflects original NVML index for display + gpu = GpuDeviceState( + handle=handle, + metadata=GpuMetadata( + gpu_index=index, + gpu_uuid=uuid, + gpu_model_name=name, + pci_bus_id=pci_bus_id, + device=f"nvidia{index}", + hostname="localhost", + ), + ) + + # Check GPM support and allocate samples for efficient SM utilization + self._init_gpm_for_device(gpu) + return gpu + + def _init_gpm_for_device(self, gpu: GpuDeviceState) -> None: + """Initialize GPM (GPU Performance Metrics) for efficient SM utilization.""" + try: + if not pynvml.nvmlGpmQueryDeviceSupport(gpu.handle).isSupportedDevice: + return + sample1 = pynvml.nvmlGpmSampleAlloc() + sample2 = pynvml.nvmlGpmSampleAlloc() + # Take initial sample so delta computation works on first collection + pynvml.nvmlGpmSampleGet(gpu.handle, sample1) + gpu.gpm_samples = (sample1, sample2) + self.debug(lambda: f"GPM enabled for GPU {gpu.metadata.gpu_index}") + except pynvml.NVMLError: + # GPM unavailable, will use process API fallback + self.debug(lambda: f"GPM not supported for GPU {gpu.metadata.gpu_index}") + + def _free_gpm_samples(self) -> None: + """Free all allocated GPM sample buffers.""" + for gpu in self._gpus: + if gpu.gpm_samples: + for sample in gpu.gpm_samples: + with contextlib.suppress(pynvml.NVMLError): + pynvml.nvmlGpmSampleFree(sample) + gpu.gpm_samples = None + + def _get_sm_utilization_gpm(self, gpu: GpuDeviceState) -> float | None: + """Get SM utilization using GPM API (device-level, more efficient).""" + prev_sample, curr_sample = gpu.gpm_samples # type: ignore[misc] + try: + pynvml.nvmlGpmSampleGet(gpu.handle, curr_sample) + metrics_get = pynvml.c_nvmlGpmMetricsGet_t() + metrics_get.version = pynvml.NVML_GPM_METRICS_GET_VERSION + metrics_get.sample1 = prev_sample + metrics_get.sample2 = curr_sample + metrics_get.numMetrics = 1 + metrics_get.metrics[0].metricId = pynvml.NVML_GPM_METRIC_SM_UTIL + pynvml.nvmlGpmMetricsGet(metrics_get) + 
sm_util = metrics_get.metrics[0].value + except pynvml.NVMLError: + sm_util = None + gpu.gpm_samples = (curr_sample, prev_sample) # Swap for next iteration + return sm_util + + def _shutdown_nvml_sync(self) -> None: + """Synchronous NVML shutdown helper. + + Thread-safe shutdown that clears all state. Can be called from + any context (init cleanup or stop phase). + """ + with self._nvml_lock: + if not self._nvml_initialized: + return + + # Free GPM samples before NVML shutdown + self._free_gpm_samples() + + try: + pynvml.nvmlShutdown() + except Exception as e: + self.warning(f"Error during NVML shutdown: {e!r}") + finally: + # Always clear state regardless of shutdown success + self._nvml_initialized = False + self._gpus = [] + + @on_stop + async def _shutdown_nvml(self) -> None: + """Shutdown NVML library. + + Called automatically during shutdown phase. + Thread-safe - waits for any in-progress collection to complete. + """ + await asyncio.to_thread(self._shutdown_nvml_sync) + self.debug("PyNVML shutdown complete") + + @background_task(immediate=True, interval=lambda self: self.collection_interval) + async def _collect_metrics_loop(self) -> None: + """Background task for collecting metrics at regular intervals. + + Runs continuously during collector's RUNNING state, triggering a metrics + collection every collection_interval seconds. + """ + await self._collect_and_process_metrics() + + async def _collect_and_process_metrics(self) -> None: + """Collect metrics from all GPUs and send via callback. + + Gathers current metrics from all discovered GPUs using NVML APIs, + converts them to TelemetryRecord objects, and delivers via callback. + Uses asyncio.to_thread() to avoid blocking the event loop with NVML calls. + """ + try: + records = await asyncio.to_thread(self._collect_gpu_metrics) + if records and self._record_callback: + await self._record_callback(records, self.id) + except Exception as e: + if self._error_callback: + try: + await self._error_callback(ErrorDetails.from_exception(e), self.id) + except Exception as callback_error: + self.error(f"Failed to send error via callback: {callback_error}") + else: + self.error(f"Metrics collection error: {e}") + + def _collect_gpu_metrics(self) -> list[TelemetryRecord]: + """Collect metrics from all GPUs using NVML APIs. + + Thread-safe - acquires lock to prevent collection during shutdown. + + Returns: + List of TelemetryRecord objects, one per GPU. 
+ """ + with self._nvml_lock: + if not self._nvml_initialized or not self._gpus: + return [] + + current_timestamp = time.time_ns() + records = [] + NVMLError = pynvml.NVMLError + + for gpu in self._gpus: + handle = gpu.handle + telemetry_data = TelemetryMetrics() + + # Power usage (milliwatts -> watts) + with contextlib.suppress(NVMLError): + power_mw = pynvml.nvmlDeviceGetPowerUsage(handle) + telemetry_data.gpu_power_usage = ( + power_mw * ScalingFactors.gpu_power_usage + ) + + # Total energy consumption (millijoules -> megajoules) + with contextlib.suppress(NVMLError): + energy_mj = pynvml.nvmlDeviceGetTotalEnergyConsumption(handle) + telemetry_data.energy_consumption = ( + energy_mj * ScalingFactors.energy_consumption + ) + + # GPU and memory utilization (percent) + with contextlib.suppress(NVMLError): + util = pynvml.nvmlDeviceGetUtilizationRates(handle) + telemetry_data.gpu_utilization = float(util.gpu) + telemetry_data.mem_utilization = float(util.memory) + + # Memory used (bytes -> gigabytes) + with contextlib.suppress(NVMLError): + mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle) + telemetry_data.gpu_memory_used = ( + mem_info.used * ScalingFactors.gpu_memory_used + ) + + # Temperature (Celsius) + with contextlib.suppress(NVMLError): + temp = pynvml.nvmlDeviceGetTemperature( + handle, pynvml.NVML_TEMPERATURE_GPU + ) + telemetry_data.gpu_temperature = float(temp) + + # Video decoder utilization (percent) + with contextlib.suppress(NVMLError): + dec_util, _ = pynvml.nvmlDeviceGetDecoderUtilization(handle) + telemetry_data.decoder_utilization = float(dec_util) + + # Video encoder utilization (percent) + with contextlib.suppress(NVMLError): + enc_util, _ = pynvml.nvmlDeviceGetEncoderUtilization(handle) + telemetry_data.encoder_utilization = float(enc_util) + + # JPEG decoder utilization (percent) + with contextlib.suppress(NVMLError): + jpg_util, _ = pynvml.nvmlDeviceGetJpgUtilization(handle) + telemetry_data.jpg_utilization = float(jpg_util) + + # SM utilization: prefer GPM (device-level) over process enumeration + sm_util: float | None = None + if gpu.gpm_samples: + sm_util = self._get_sm_utilization_gpm(gpu) + + # Fallback to process-level API if GPM unavailable or returned None + if sm_util is None: + with contextlib.suppress(NVMLError): + process_utils = pynvml.nvmlDeviceGetProcessesUtilizationInfo( + handle, 0 + ) + sm_util = ( + sum(p.smUtil for p in process_utils) + if process_utils + else 0.0 + ) + + if sm_util is not None: + telemetry_data.sm_utilization = min(float(sm_util), 100.0) + + # Power violation / throttling duration (nanoseconds -> microseconds) + with contextlib.suppress(NVMLError): + violation = pynvml.nvmlDeviceGetViolationStatus( + handle, pynvml.NVML_PERF_POLICY_POWER + ) + telemetry_data.power_violation = ( + violation.violationTime * ScalingFactors.power_violation + ) + + # Create record if any metrics were collected + if telemetry_data.model_fields_set: + record = TelemetryRecord( + timestamp_ns=current_timestamp, + dcgm_url=PYNVML_SOURCE_IDENTIFIER, + **gpu.metadata.model_dump(), + telemetry_data=telemetry_data, + ) + records.append(record) + + return records diff --git a/tests/integration/test_custom_gpu_metrics.py b/tests/integration/test_custom_gpu_metrics.py index b8dc2aaf8..1dbf886a4 100644 --- a/tests/integration/test_custom_gpu_metrics.py +++ b/tests/integration/test_custom_gpu_metrics.py @@ -12,11 +12,12 @@ GenericMetricUnit, TemperatureMetricUnit, ) -from aiperf.gpu_telemetry.constants import ( - GPU_TELEMETRY_METRICS_CONFIG, -) from 
tests.harness.utils import AIPerfCLI, AIPerfMockServer +# DCGMFaker provides 8 of the 12 default metrics defined in GPU_TELEMETRY_METRICS_CONFIG. +# Missing from DCGMFaker: encoder_utilization, decoder_utilization, sm_utilization, jpg_utilization +DCGM_FAKER_DEFAULT_METRIC_COUNT = 8 + @pytest.mark.skipif( platform.system() == "Darwin", @@ -44,7 +45,7 @@ def custom_gpu_metrics_csv(self, tmp_path: Path) -> Path: # Custom temperature metrics (DCGMFaker returns this) DCGM_FI_DEV_MEMORY_TEMP, gauge, Memory temperature (in °C) -# Custom utilization metric (DCGMFaker returns this) +# This is already a default metric (maps to mem_utilization), included to test deduplication DCGM_FI_DEV_MEM_COPY_UTIL, gauge, Memory copy utilization (in %) """ csv_path.write_text(csv_content) @@ -118,18 +119,20 @@ async def test_custom_metrics_csv_loading_basic( for gpu_data in endpoint_data.gpus.values(): assert gpu_data.metrics is not None - default_metric_count = len(GPU_TELEMETRY_METRICS_CONFIG) + # 8 defaults from DCGMFaker + 3 custom (sm_clock, mem_clock, memory_temp) + # Note: DCGM_FI_DEV_MEM_COPY_UTIL maps to default "mem_utilization", not added as custom + expected_min_metrics = DCGM_FAKER_DEFAULT_METRIC_COUNT + 3 - assert len(gpu_data.metrics) >= default_metric_count, ( - f"Expected at least {default_metric_count} default metrics, " + assert len(gpu_data.metrics) >= expected_min_metrics, ( + f"Expected at least {expected_min_metrics} metrics, " f"got {len(gpu_data.metrics)}" ) + # These are the actual custom metrics added (mem_copy_util is a default as mem_utilization) custom_metric_names = [ "sm_clock", "mem_clock", "memory_temp", - "mem_copy_util", ] for metric_name in custom_metric_names: assert metric_name in gpu_data.metrics, ( @@ -162,11 +165,12 @@ async def test_custom_metrics_csv_loading_basic( ), ( f"memory_temp unit is {gpu_data.metrics['memory_temp'].unit}, expected {TemperatureMetricUnit.CELSIUS.value}" ) + # DCGM_FI_DEV_MEM_COPY_UTIL maps to default "mem_utilization" (not "mem_copy_util") assert ( - gpu_data.metrics["mem_copy_util"].unit + gpu_data.metrics["mem_utilization"].unit == GenericMetricUnit.PERCENT.value ), ( - f"mem_copy_util unit is {gpu_data.metrics['mem_copy_util'].unit}, expected {GenericMetricUnit.PERCENT.value}" + f"mem_utilization unit is {gpu_data.metrics['mem_utilization'].unit}, expected {GenericMetricUnit.PERCENT.value}" ) async def test_custom_metrics_deduplication( @@ -209,12 +213,11 @@ async def test_custom_metrics_deduplication( assert "sm_clock" in gpu_data.metrics assert "mem_clock" in gpu_data.metrics - default_metric_count = len(GPU_TELEMETRY_METRICS_CONFIG) - custom_metrics_added = 2 + # 8 defaults from DCGMFaker + 2 custom (sm_clock, mem_clock) + # GPU_UTIL and POWER_USAGE from CSV are already defaults, so not added as custom + expected_min_metrics = DCGM_FAKER_DEFAULT_METRIC_COUNT + 2 - assert ( - len(gpu_data.metrics) >= default_metric_count + custom_metrics_added - ) + assert len(gpu_data.metrics) >= expected_min_metrics async def test_invalid_csv_fallback_to_defaults( self, @@ -245,8 +248,9 @@ async def test_invalid_csv_fallback_to_defaults( for gpu_data in endpoint_data.gpus.values(): assert "sm_clock" in gpu_data.metrics - default_metric_count = len(GPU_TELEMETRY_METRICS_CONFIG) - assert len(gpu_data.metrics) >= default_metric_count + # 8 defaults from DCGMFaker + 1 valid custom (sm_clock) + expected_min_metrics = DCGM_FAKER_DEFAULT_METRIC_COUNT + 1 + assert len(gpu_data.metrics) >= expected_min_metrics async def test_nonexistent_csv_file_error( self, cli: 
AIPerfCLI, aiperf_mock_server: AIPerfMockServer, tmp_path: Path diff --git a/tests/integration/test_dcgm_faker.py b/tests/integration/test_dcgm_faker.py index 0745cf90b..0e36aa5ff 100644 --- a/tests/integration/test_dcgm_faker.py +++ b/tests/integration/test_dcgm_faker.py @@ -6,7 +6,7 @@ from aiperf_mock_server.dcgm_faker import GPU_CONFIGS, DCGMFaker from pytest import approx -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector class TestDCGMFaker: @@ -20,7 +20,7 @@ def test_faker_output_parsed_by_real_telemetry_collector(self, gpu_name): print(metrics_text) # Use real TelemetryDataCollector to parse the output - collector = GPUTelemetryDataCollector(dcgm_url="http://fake") + collector = DCGMTelemetryCollector(dcgm_url="http://fake") records = collector._parse_metrics_to_records(metrics_text) # Should get 2 TelemetryRecord objects (one per GPU) @@ -63,7 +63,7 @@ def test_faker_output_parsed_by_real_telemetry_collector(self, gpu_name): def test_load_affects_telemetry_records(self): """Test that load changes affect TelemetryRecords when parsed by real collector.""" faker = DCGMFaker(gpu_name="b200", num_gpus=1, seed=42) - collector = GPUTelemetryDataCollector(dcgm_url="http://fake") + collector = DCGMTelemetryCollector(dcgm_url="http://fake") # Low load faker.set_load(0.1) diff --git a/tests/unit/common/config/test_user_config.py b/tests/unit/common/config/test_user_config.py index 059dc51b1..f77d732c7 100644 --- a/tests/unit/common/config/test_user_config.py +++ b/tests/unit/common/config/test_user_config.py @@ -23,7 +23,11 @@ UserConfig, ) from aiperf.common.config.prompt_config import InputTokensConfig -from aiperf.common.enums import EndpointType, GPUTelemetryMode +from aiperf.common.enums import ( + EndpointType, + GPUTelemetryCollectorType, + GPUTelemetryMode, +) from aiperf.common.enums.dataset_enums import DatasetSamplingStrategy from aiperf.common.enums.timing_enums import ArrivalPattern, TimingMode @@ -299,7 +303,6 @@ def test_urls_extraction(self): "dashboard", "http://node1:9401/metrics", "https://node2:9401/metrics", - "summary", ], ) @@ -307,7 +310,6 @@ def test_urls_extraction(self): assert "http://node1:9401/metrics" in config.gpu_telemetry_urls assert "https://node2:9401/metrics" in config.gpu_telemetry_urls assert "dashboard" not in config.gpu_telemetry_urls - assert "summary" not in config.gpu_telemetry_urls @pytest.mark.parametrize( "gpu_telemetry,expected_urls", @@ -345,6 +347,59 @@ def test_csv_file_not_found(self): with pytest.raises(ValueError, match="GPU metrics file not found"): make_config(gpu_telemetry=["dashboard", "/nonexistent/path/metrics.csv"]) + def test_pynvml_with_urls_raises_error(self): + """Test that using pynvml with DCGM URLs raises an error.""" + with pytest.raises(ValueError, match="Cannot use pynvml with DCGM URLs"): + make_config(gpu_telemetry=["pynvml", "http://localhost:9401/metrics"]) + + def test_pynvml_with_multiple_urls_raises_error(self): + """Test that using pynvml with multiple DCGM URLs raises an error.""" + with pytest.raises(ValueError, match="Cannot use pynvml with DCGM URLs"): + make_config( + gpu_telemetry=[ + "pynvml", + "http://node1:9401/metrics", + "http://node2:9401/metrics", + ] + ) + + def test_pynvml_with_dashboard_allowed(self): + """Test that pynvml can be used with dashboard mode.""" + config = make_config(gpu_telemetry=["pynvml", "dashboard"]) + + assert config.gpu_telemetry_collector_type == GPUTelemetryCollectorType.PYNVML + 
assert config.gpu_telemetry_mode == GPUTelemetryMode.REALTIME_DASHBOARD + assert config.gpu_telemetry_urls == [] + + def test_pynvml_only(self): + """Test that pynvml can be used alone.""" + config = make_config(gpu_telemetry=["pynvml"]) + + assert config.gpu_telemetry_collector_type == GPUTelemetryCollectorType.PYNVML + assert config.gpu_telemetry_mode == GPUTelemetryMode.SUMMARY + assert config.gpu_telemetry_urls == [] + + @pytest.mark.parametrize( + "invalid_item", + [ + "unknown", + "invalid_option", + "dcgm", + "gpu", + "telemetry", + "metrics", + ], + ) + def test_unknown_item_raises_error(self, invalid_item): + """Test that unknown items in gpu_telemetry raise an error.""" + with pytest.raises(ValueError, match="Invalid GPU telemetry item"): + make_config(gpu_telemetry=[invalid_item]) + + def test_unknown_item_with_valid_items_raises_error(self): + """Test that unknown items mixed with valid items still raise an error.""" + with pytest.raises(ValueError, match="Invalid GPU telemetry item"): + make_config(gpu_telemetry=["dashboard", "unknown_option"]) + # ============================================================================= # Load Generator Validation Tests diff --git a/tests/unit/gpu_telemetry/test_pynvml_collector.py b/tests/unit/gpu_telemetry/test_pynvml_collector.py new file mode 100644 index 000000000..4fb160d46 --- /dev/null +++ b/tests/unit/gpu_telemetry/test_pynvml_collector.py @@ -0,0 +1,847 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for PyNVMLTelemetryCollector. + +Tests use mocked pynvml module to verify collector behavior without requiring +actual GPU hardware. +""" + +import asyncio +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pytest import param + +from aiperf.common.models import TelemetryRecord +from aiperf.gpu_telemetry.constants import PYNVML_SOURCE_IDENTIFIER +from aiperf.gpu_telemetry.pynvml_collector import ScalingFactors + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_pynvml(): + """Create a mock pynvml module with typical NVML responses.""" + mock_module = MagicMock() + + # NVML error class and constants + mock_module.NVMLError = Exception + mock_module.NVML_TEMPERATURE_GPU = 0 + mock_module.NVML_PERF_POLICY_POWER = 0 + + # Mock device count + mock_module.nvmlDeviceGetCount.return_value = 2 + + # Create mock device handles + mock_handles = [MagicMock(), MagicMock()] + + def get_handle_by_index(idx): + return mock_handles[idx] + + mock_module.nvmlDeviceGetHandleByIndex.side_effect = get_handle_by_index + + # GPU data with device-dependent return values + mock_module.nvmlDeviceGetUUID.side_effect = lambda h: ( + "GPU-abc123" if h == mock_handles[0] else "GPU-def456" + ) + mock_module.nvmlDeviceGetName.side_effect = lambda h: ( + "NVIDIA GeForce RTX 4090" if h == mock_handles[0] else "NVIDIA GeForce RTX 4080" + ) + + # PCI info + pci_info_0 = SimpleNamespace(busId="00000000:01:00.0") + pci_info_1 = SimpleNamespace(busId="00000000:02:00.0") + mock_module.nvmlDeviceGetPciInfo.side_effect = lambda h: ( + pci_info_0 if h == mock_handles[0] else pci_info_1 + ) + + # Power usage (milliwatts): 350W, 280W + mock_module.nvmlDeviceGetPowerUsage.side_effect = lambda h: ( + 350000 if h == mock_handles[0] else 280000 + ) + + # 
Energy consumption (millijoules): 1000J, 800J + mock_module.nvmlDeviceGetTotalEnergyConsumption.side_effect = lambda h: ( + 1000000000 if h == mock_handles[0] else 800000000 + ) + + # Utilization rates (GPU and memory bandwidth) + util_0 = SimpleNamespace(gpu=95, memory=45) + util_1 = SimpleNamespace(gpu=75, memory=35) + mock_module.nvmlDeviceGetUtilizationRates.side_effect = lambda h: ( + util_0 if h == mock_handles[0] else util_1 + ) + + # Memory info (bytes): 20 GB, 16 GB + mem_0 = SimpleNamespace(used=20 * 1024 * 1024 * 1024) + mem_1 = SimpleNamespace(used=16 * 1024 * 1024 * 1024) + mock_module.nvmlDeviceGetMemoryInfo.side_effect = lambda h: ( + mem_0 if h == mock_handles[0] else mem_1 + ) + + # Temperature (Celsius) + mock_module.nvmlDeviceGetTemperature.side_effect = lambda h, t: ( + 72 if h == mock_handles[0] else 68 + ) + + # Video decoder/encoder/JPEG utilization (percent, sampling_period) + mock_module.nvmlDeviceGetDecoderUtilization.side_effect = lambda h: ( + (25, 1000) if h == mock_handles[0] else (15, 1000) + ) + mock_module.nvmlDeviceGetEncoderUtilization.side_effect = lambda h: ( + (30, 1000) if h == mock_handles[0] else (20, 1000) + ) + mock_module.nvmlDeviceGetJpgUtilization.side_effect = lambda h: ( + (10, 1000) if h == mock_handles[0] else (5, 1000) + ) + + # Process utilization info + proc_util_0 = SimpleNamespace(smUtil=85, encUtil=28, decUtil=22, jpgUtil=8) + proc_util_1 = SimpleNamespace(smUtil=65, encUtil=18, decUtil=12, jpgUtil=3) + mock_module.nvmlDeviceGetProcessesUtilizationInfo.side_effect = lambda h, t: ( + [proc_util_0] if h == mock_handles[0] else [proc_util_1] + ) + + # Power violation status (nanoseconds): 5ms, 2ms + violation_0 = SimpleNamespace(violationTime=5000000) + violation_1 = SimpleNamespace(violationTime=2000000) + mock_module.nvmlDeviceGetViolationStatus.side_effect = lambda h, p: ( + violation_0 if h == mock_handles[0] else violation_1 + ) + + # GPM (GPU Performance Metrics) support - disabled by default + gpm_support = SimpleNamespace(isSupportedDevice=False) + mock_module.nvmlGpmQueryDeviceSupport.return_value = gpm_support + mock_module.nvmlGpmSampleAlloc.return_value = MagicMock() + mock_module.nvmlGpmSampleFree.return_value = None + mock_module.nvmlGpmSampleGet.return_value = MagicMock() + + # GPM metrics get - for computing SM utilization + mock_module.NVML_GPM_METRICS_GET_VERSION = 1 + mock_module.NVML_GPM_METRIC_SM_UTIL = 2 + mock_module.c_nvmlGpmMetricsGet_t = MagicMock + + return mock_module + + +@pytest.fixture +def patch_pynvml(mock_pynvml): + """Patch pynvml module reference in the collector module for testing.""" + from aiperf.gpu_telemetry import pynvml_collector + from aiperf.gpu_telemetry.pynvml_collector import PyNVMLTelemetryCollector + + # Patch the pynvml reference in the collector module's namespace + with patch.object(pynvml_collector, "pynvml", mock_pynvml): + yield mock_pynvml, PyNVMLTelemetryCollector + + +@pytest.fixture +def collector(patch_pynvml): + """Create an uninitialized collector.""" + _, PyNVMLTelemetryCollector = patch_pynvml + return PyNVMLTelemetryCollector() + + +@pytest.fixture +async def initialized_collector(patch_pynvml): + """Create and initialize a collector, yielding it for tests, then stopping.""" + _, PyNVMLTelemetryCollector = patch_pynvml + collector = PyNVMLTelemetryCollector() + await collector.initialize() + yield collector + await collector.stop() + + +# --------------------------------------------------------------------------- +# Test Initialization +# 
--------------------------------------------------------------------------- + + +class TestPyNVMLTelemetryCollectorInitialization: + """Test PyNVMLTelemetryCollector initialization.""" + + def test_initialization_with_custom_values(self, patch_pynvml): + """Test collector initializes with custom values.""" + _, PyNVMLTelemetryCollector = patch_pynvml + + collector = PyNVMLTelemetryCollector( + collection_interval=0.5, + collector_id="test_collector", + ) + + assert collector.id == "test_collector" + assert collector.collection_interval == 0.5 + assert collector.endpoint_url == PYNVML_SOURCE_IDENTIFIER + assert not collector.was_initialized + assert not collector.was_started + + def test_initialization_default_values(self, patch_pynvml): + """Test collector uses default values when not specified.""" + _, PyNVMLTelemetryCollector = patch_pynvml + + collector = PyNVMLTelemetryCollector() + + assert collector.id == "pynvml_collector" + assert collector.collection_interval == 0.333 + assert collector._record_callback is None + assert collector._error_callback is None + + +# --------------------------------------------------------------------------- +# Test Reachability +# --------------------------------------------------------------------------- + + +class TestPyNVMLReachability: + """Test NVML reachability checks.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "device_count,init_error,expected", + [ + param(2, None, True, id="success"), + param(0, None, False, id="no_gpus"), + param(2, Exception("fail"), False, id="nvml_error"), + ], + ) + async def test_is_url_reachable( + self, patch_pynvml, device_count, init_error, expected + ): + """Test reachability under various conditions.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + mock_pynvml.nvmlDeviceGetCount.return_value = device_count + if init_error: + mock_pynvml.nvmlInit.side_effect = init_error + + collector = PyNVMLTelemetryCollector() + result = await collector.is_url_reachable() + + assert result is expected + + +# --------------------------------------------------------------------------- +# Test Lifecycle +# --------------------------------------------------------------------------- + + +class TestPyNVMLLifecycle: + """Test collector lifecycle management.""" + + @pytest.mark.asyncio + async def test_initialize_discovers_gpus(self, initialized_collector): + """Test initialization discovers and catalogs GPUs.""" + assert initialized_collector._nvml_initialized + assert len(initialized_collector._gpus) == 2 + + # Verify GPU metadata + assert initialized_collector._gpus[0].metadata.gpu_uuid == "GPU-abc123" + assert ( + initialized_collector._gpus[0].metadata.gpu_model_name + == "NVIDIA GeForce RTX 4090" + ) + assert initialized_collector._gpus[1].metadata.gpu_uuid == "GPU-def456" + assert ( + initialized_collector._gpus[1].metadata.gpu_model_name + == "NVIDIA GeForce RTX 4080" + ) + + @pytest.mark.asyncio + async def test_stop_shuts_down_nvml(self, patch_pynvml): + """Test stop properly shuts down NVML.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + assert collector._nvml_initialized + + await collector.stop() + + assert not collector._nvml_initialized + mock_pynvml.nvmlShutdown.assert_called() + + @pytest.mark.asyncio + async def test_stop_before_init_safe(self, collector): + """Test stopping before initialization doesn't cause issues.""" + await collector.stop() # Should not raise + + @pytest.mark.asyncio + async def 
test_stop_clears_device_handles(self, patch_pynvml): + """Test stop clears device handles and metadata.""" + _, PyNVMLTelemetryCollector = patch_pynvml + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + assert len(collector._gpus) == 2 + + await collector.stop() + + assert collector._gpus == [] + + @pytest.mark.asyncio + async def test_init_failure_nvml_init_raises(self, patch_pynvml): + """Test initialization fails gracefully when nvmlInit raises.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + mock_pynvml.nvmlInit.side_effect = mock_pynvml.NVMLError("Driver not loaded") + + collector = PyNVMLTelemetryCollector() + + with pytest.raises(asyncio.CancelledError, match="Failed to initialize NVML"): + await collector.initialize() + + assert not collector._nvml_initialized + + @pytest.mark.asyncio + async def test_init_failure_device_count_raises(self, patch_pynvml): + """Test initialization cleans up when nvmlDeviceGetCount fails.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + mock_pynvml.nvmlDeviceGetCount.side_effect = mock_pynvml.NVMLError( + "Device enumeration failed" + ) + + collector = PyNVMLTelemetryCollector() + + with pytest.raises( + asyncio.CancelledError, match="Failed to get GPU device count" + ): + await collector.initialize() + + # Should have cleaned up NVML + assert not collector._nvml_initialized + mock_pynvml.nvmlShutdown.assert_called() + + @pytest.mark.asyncio + async def test_init_skips_failed_device_handles(self, patch_pynvml): + """Test initialization continues when individual GPU handle fails.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # First GPU fails, second succeeds + mock_pynvml.nvmlDeviceGetHandleByIndex.side_effect = [ + mock_pynvml.NVMLError("GPU 0 failed"), + MagicMock(), + ] + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # Should have initialized with only the second GPU + assert collector._nvml_initialized + assert len(collector._gpus) == 1 + + await collector.stop() + + +# --------------------------------------------------------------------------- +# Test Metrics Collection +# --------------------------------------------------------------------------- + + +class TestPyNVMLMetricsCollection: + """Test GPU metrics collection.""" + + @pytest.mark.asyncio + async def test_collect_gpu_metrics(self, initialized_collector): + """Test metrics collection returns correct TelemetryRecord objects.""" + records = initialized_collector._collect_gpu_metrics() + + assert len(records) == 2 + assert all(isinstance(r, TelemetryRecord) for r in records) + + # Verify both GPUs have expected metadata and metrics + gpu0 = next(r for r in records if r.gpu_index == 0) + gpu1 = next(r for r in records if r.gpu_index == 1) + + # GPU 0 verification + assert gpu0.dcgm_url == PYNVML_SOURCE_IDENTIFIER + assert gpu0.gpu_uuid == "GPU-abc123" + assert gpu0.gpu_model_name == "NVIDIA GeForce RTX 4090" + assert gpu0.telemetry_data.gpu_power_usage == pytest.approx(350.0, rel=0.01) + assert gpu0.telemetry_data.gpu_utilization == 95.0 + assert gpu0.telemetry_data.mem_utilization == 45.0 + assert gpu0.telemetry_data.gpu_temperature == 72.0 + assert gpu0.telemetry_data.gpu_memory_used == pytest.approx(20.0, rel=0.1) + assert gpu0.telemetry_data.encoder_utilization == 30.0 + assert gpu0.telemetry_data.decoder_utilization == 25.0 + assert gpu0.telemetry_data.jpg_utilization == 10.0 + assert gpu0.telemetry_data.sm_utilization == 85.0 + assert gpu0.telemetry_data.power_violation == 5000.0 + + # GPU 
1 verification + assert gpu1.gpu_uuid == "GPU-def456" + assert gpu1.telemetry_data.gpu_power_usage == pytest.approx(280.0, rel=0.01) + assert gpu1.telemetry_data.gpu_utilization == 75.0 + assert gpu1.telemetry_data.mem_utilization == 35.0 + assert gpu1.telemetry_data.encoder_utilization == 20.0 + assert gpu1.telemetry_data.decoder_utilization == 15.0 + assert gpu1.telemetry_data.jpg_utilization == 5.0 + assert gpu1.telemetry_data.sm_utilization == 65.0 + assert gpu1.telemetry_data.power_violation == 2000.0 + + @pytest.mark.asyncio + async def test_collect_handles_nvml_errors(self, patch_pynvml): + """Test collection continues when individual metrics fail.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + mock_pynvml.nvmlDeviceGetPowerUsage.side_effect = mock_pynvml.NVMLError( + "Power not supported" + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + records = collector._collect_gpu_metrics() + + # Should still get records with other metrics + assert len(records) == 2 + for r in records: + assert r.telemetry_data.gpu_power_usage is None + assert r.telemetry_data.gpu_utilization is not None + assert r.telemetry_data.gpu_temperature is not None + + await collector.stop() + + @pytest.mark.asyncio + async def test_collect_returns_empty_when_not_initialized(self, collector): + """Test collection returns empty list when NVML not initialized.""" + records = collector._collect_gpu_metrics() + assert records == [] + + +# --------------------------------------------------------------------------- +# Test Callbacks +# --------------------------------------------------------------------------- + + +class TestPyNVMLCallbacks: + """Test callback functionality.""" + + @pytest.mark.asyncio + async def test_record_callback_called(self, patch_pynvml): + """Test record callback is called with collected records.""" + _, PyNVMLTelemetryCollector = patch_pynvml + + mock_callback = AsyncMock() + collector = PyNVMLTelemetryCollector( + record_callback=mock_callback, + collector_id="test_collector", + ) + + await collector.initialize() + await collector._collect_and_process_metrics() + await collector.stop() + + mock_callback.assert_called_once() + records, collector_id = mock_callback.call_args[0] + + assert len(records) == 2 + assert collector_id == "test_collector" + + @pytest.mark.asyncio + async def test_error_callback_on_exception(self, patch_pynvml): + """Test error callback is called when collection fails.""" + _, PyNVMLTelemetryCollector = patch_pynvml + + mock_error_callback = AsyncMock() + collector = PyNVMLTelemetryCollector( + error_callback=mock_error_callback, + collector_id="test_collector", + ) + + await collector.initialize() + + # Force an error by making the collect method raise + collector._collect_gpu_metrics = MagicMock( + side_effect=Exception("Collection failed") + ) + + await collector._collect_and_process_metrics() + await collector.stop() + + mock_error_callback.assert_called_once() + error = mock_error_callback.call_args[0][0] + assert hasattr(error, "message") + + +# --------------------------------------------------------------------------- +# Test Scaling Factors +# --------------------------------------------------------------------------- + + +class TestPyNVMLScalingFactors: + """Test unit scaling factors.""" + + @pytest.mark.parametrize( + "field,factor,raw_value,expected", + [ + param("gpu_power_usage", 1e-3, 350000, 350.0, id="power_mW_to_W"), + param("energy_consumption", 1e-9, 1e9, 1.0, id="energy_mJ_to_MJ"), + 
param("gpu_memory_used", 1e-9, 20e9, 20.0, id="memory_bytes_to_GB"), + ], + ) + def test_scaling_factor(self, field, factor, raw_value, expected): + """Test scaling factors convert units correctly.""" + assert getattr(ScalingFactors, field) == factor + assert raw_value * getattr(ScalingFactors, field) == expected + + +# --------------------------------------------------------------------------- +# Test Edge Cases +# --------------------------------------------------------------------------- + + +class TestPyNVMLEdgeCases: + """Test edge cases and error handling.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "api_method,return_value,metadata_field,expected", + [ + param( + "nvmlDeviceGetName", + b"NVIDIA RTX 4090", + "gpu_model_name", + "NVIDIA RTX 4090", + id="name", + ), + param( + "nvmlDeviceGetPciInfo", + SimpleNamespace(busId=b"00000000:01:00.0"), + "pci_bus_id", + "00000000:01:00.0", + id="pci_bus_id", + ), + ], + ) + async def test_handles_bytes_values( + self, patch_pynvml, api_method, return_value, metadata_field, expected + ): + """Test handles API values returned as bytes.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + if api_method == "nvmlDeviceGetName": + mock_pynvml.nvmlDeviceGetName.side_effect = lambda h: return_value + else: + mock_pynvml.nvmlDeviceGetPciInfo.return_value = return_value + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + assert getattr(collector._gpus[0].metadata, metadata_field) == expected + + await collector.stop() + + @pytest.mark.asyncio + async def test_energy_consumption_collected(self, initialized_collector): + """Test energy consumption metric is collected and scaled correctly.""" + records = initialized_collector._collect_gpu_metrics() + + gpu0 = next(r for r in records if r.gpu_index == 0) + # 1000000000 mJ * 1e-9 = 1.0 MJ + assert gpu0.telemetry_data.energy_consumption == pytest.approx(1.0, rel=0.01) + + @pytest.mark.asyncio + async def test_sm_utilization_sums_multiple_processes(self, patch_pynvml): + """Test SM utilization sums across multiple processes on same GPU.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Multiple processes on GPU 0 + proc1 = SimpleNamespace(smUtil=40, encUtil=10, decUtil=5, jpgUtil=2) + proc2 = SimpleNamespace(smUtil=35, encUtil=8, decUtil=3, jpgUtil=1) + mock_pynvml.nvmlDeviceGetProcessesUtilizationInfo.side_effect = lambda h, t: ( + [proc1, proc2] if h == mock_pynvml.nvmlDeviceGetHandleByIndex(0) else [] + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + records = collector._collect_gpu_metrics() + gpu0 = next(r for r in records if r.gpu_index == 0) + + # Should sum: 40 + 35 = 75 + assert gpu0.telemetry_data.sm_utilization == 75.0 + + await collector.stop() + + @pytest.mark.asyncio + async def test_empty_process_list_zero_sm_utilization(self, patch_pynvml): + """Test SM utilization is 0.0 when no processes are running.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Clear side_effect and set return_value (side_effect takes precedence) + mock_pynvml.nvmlDeviceGetProcessesUtilizationInfo.side_effect = None + mock_pynvml.nvmlDeviceGetProcessesUtilizationInfo.return_value = [] + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + records = collector._collect_gpu_metrics() + + for r in records: + assert r.telemetry_data.sm_utilization == 0.0 + + await collector.stop() + + @pytest.mark.asyncio + async def test_sm_utilization_capped_at_100(self, patch_pynvml): + """Test SM utilization 
is capped at 100% when sum exceeds it.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Multiple processes with high utilization that sum > 100% + proc1 = SimpleNamespace(smUtil=60, encUtil=10, decUtil=5, jpgUtil=2) + proc2 = SimpleNamespace(smUtil=55, encUtil=8, decUtil=3, jpgUtil=1) + mock_pynvml.nvmlDeviceGetProcessesUtilizationInfo.side_effect = lambda h, t: ( + [proc1, proc2] if h == mock_pynvml.nvmlDeviceGetHandleByIndex(0) else [] + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + records = collector._collect_gpu_metrics() + gpu0 = next(r for r in records if r.gpu_index == 0) + + # Sum would be 60 + 55 = 115, but should be capped at 100.0 + assert gpu0.telemetry_data.sm_utilization == 100.0 + + await collector.stop() + + @pytest.mark.asyncio + async def test_is_url_reachable_when_already_initialized( + self, initialized_collector + ): + """Test reachability returns True when already initialized with GPUs.""" + result = await initialized_collector.is_url_reachable() + assert result is True + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "api_method,fallback_field,fallback_pattern", + [ + param("nvmlDeviceGetUUID", "gpu_uuid", "GPU-unknown-0", id="uuid"), + param("nvmlDeviceGetName", "gpu_model_name", "Unknown GPU", id="name"), + param("nvmlDeviceGetPciInfo", "pci_bus_id", None, id="pci"), + ], + ) + async def test_metadata_fallback_on_error( + self, patch_pynvml, api_method, fallback_field, fallback_pattern + ): + """Test fallback values when metadata APIs fail during initialization.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + getattr(mock_pynvml, api_method).side_effect = mock_pynvml.NVMLError( + "Not supported" + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + assert getattr(collector._gpus[0].metadata, fallback_field) == fallback_pattern + + await collector.stop() + + +# --------------------------------------------------------------------------- +# Test GPM (GPU Performance Metrics) +# --------------------------------------------------------------------------- + + +class TestPyNVMLGPM: + """Test GPM (GPU Performance Metrics) functionality for efficient SM utilization.""" + + @pytest.mark.asyncio + async def test_gpm_not_supported_uses_process_api(self, patch_pynvml): + """Test fallback to process API when GPM is not supported.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # GPM not supported (default in fixture) + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # Should not have GPM enabled + assert all(gpu.gpm_samples is None for gpu in collector._gpus) + + # Should still collect SM utilization via process API + records = collector._collect_gpu_metrics() + assert all(r.telemetry_data.sm_utilization is not None for r in records) + + await collector.stop() + + @pytest.mark.asyncio + async def test_gpm_supported_allocates_samples(self, patch_pynvml): + """Test GPM sample allocation when device supports GPM.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Enable GPM support + gpm_support = SimpleNamespace(isSupportedDevice=True) + mock_pynvml.nvmlGpmQueryDeviceSupport.return_value = gpm_support + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # Should have GPM enabled for both GPUs (gpm_samples not None means supported) + assert all(gpu.gpm_samples is not None for gpu in collector._gpus) + + # Each GPU should have two sample buffers allocated + initial sample taken + assert 
mock_pynvml.nvmlGpmSampleAlloc.call_count == 4 # 2 GPUs * 2 samples each + assert mock_pynvml.nvmlGpmSampleGet.call_count == 2 # Initial sample per GPU + + await collector.stop() + + @pytest.mark.asyncio + async def test_gpm_samples_freed_on_shutdown(self, patch_pynvml): + """Test GPM samples are freed during shutdown.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Enable GPM support + gpm_support = SimpleNamespace(isSupportedDevice=True) + mock_pynvml.nvmlGpmQueryDeviceSupport.return_value = gpm_support + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + assert all(gpu.gpm_samples is not None for gpu in collector._gpus) + + await collector.stop() + + # GPU list should be cleared on shutdown + assert collector._gpus == [] + # 4 samples freed (2 GPUs * 2 samples each) + assert mock_pynvml.nvmlGpmSampleFree.call_count == 4 + + @pytest.mark.asyncio + async def test_gpm_first_collection_uses_gpm(self, patch_pynvml): + """Test first collection uses GPM (initial sample taken during init).""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Enable GPM support + gpm_support = SimpleNamespace(isSupportedDevice=True) + mock_pynvml.nvmlGpmQueryDeviceSupport.return_value = gpm_support + + # Mock GPM metrics result + def mock_gpm_metrics_get(metrics_get): + metrics_get.metrics[0].value = 88.5 # SM utilization from GPM + return metrics_get + + mock_pynvml.nvmlGpmMetricsGet.side_effect = mock_gpm_metrics_get + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # Initial sample taken during init (one per GPU) + assert mock_pynvml.nvmlGpmSampleGet.call_count == 2 + + # First collection should use GPM directly (no fallback needed) + records = collector._collect_gpu_metrics() + + # GPM metrics should have been queried + assert mock_pynvml.nvmlGpmMetricsGet.called + + # SM utilization should come from GPM + gpu0 = next(r for r in records if r.gpu_index == 0) + assert gpu0.telemetry_data.sm_utilization == 88.5 + + await collector.stop() + + @pytest.mark.asyncio + async def test_gpm_query_support_failure_disables_gpm(self, patch_pynvml): + """Test GPM is disabled when nvmlGpmQueryDeviceSupport fails.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # GPM query fails + mock_pynvml.nvmlGpmQueryDeviceSupport.side_effect = mock_pynvml.NVMLError( + "GPM not available" + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # GPM should be disabled (no samples allocated) + assert all(gpu.gpm_samples is None for gpu in collector._gpus) + + # Should still work via process API + records = collector._collect_gpu_metrics() + assert all(r.telemetry_data.sm_utilization is not None for r in records) + + await collector.stop() + + @pytest.mark.asyncio + async def test_gpm_sample_alloc_failure_disables_gpm(self, patch_pynvml): + """Test GPM is disabled when sample allocation fails.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # GPM supported but allocation fails + gpm_support = SimpleNamespace(isSupportedDevice=True) + mock_pynvml.nvmlGpmQueryDeviceSupport.return_value = gpm_support + mock_pynvml.nvmlGpmSampleAlloc.side_effect = mock_pynvml.NVMLError( + "Allocation failed" + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # GPM should be disabled due to allocation failure (no samples) + assert all(gpu.gpm_samples is None for gpu in collector._gpus) + + await collector.stop() + + @pytest.mark.asyncio + async def 
test_gpm_metrics_get_failure_falls_back_to_process_api( + self, patch_pynvml + ): + """Test fallback to process API when nvmlGpmMetricsGet fails.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + # Enable GPM support + gpm_support = SimpleNamespace(isSupportedDevice=True) + mock_pynvml.nvmlGpmQueryDeviceSupport.return_value = gpm_support + + # GPM metrics get fails + mock_pynvml.nvmlGpmMetricsGet.side_effect = mock_pynvml.NVMLError( + "Metrics query failed" + ) + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # First collection - primes the sample buffer + collector._collect_gpu_metrics() + + # Second collection - GPM fails, should fall back to process API + records = collector._collect_gpu_metrics() + + # Should still get SM utilization from process API fallback + assert all(r.telemetry_data.sm_utilization is not None for r in records) + + # Process API should have been called + mock_pynvml.nvmlDeviceGetProcessesUtilizationInfo.assert_called() + + await collector.stop() + + @pytest.mark.asyncio + async def test_gpm_mixed_support(self, patch_pynvml): + """Test handling when only some GPUs support GPM.""" + mock_pynvml, PyNVMLTelemetryCollector = patch_pynvml + + mock_handles = [MagicMock(), MagicMock()] + mock_pynvml.nvmlDeviceGetHandleByIndex.side_effect = lambda i: mock_handles[i] + + # GPU 0 supports GPM, GPU 1 does not + def gpm_support_check(handle): + if handle == mock_handles[0]: + return SimpleNamespace(isSupportedDevice=True) + raise mock_pynvml.NVMLError("Not supported") + + mock_pynvml.nvmlGpmQueryDeviceSupport.side_effect = gpm_support_check + + collector = PyNVMLTelemetryCollector() + await collector.initialize() + + # GPU 0 should have GPM (samples allocated), GPU 1 should not + assert collector._gpus[0].gpm_samples is not None + assert collector._gpus[1].gpm_samples is None + + await collector.stop() diff --git a/tests/unit/gpu_telemetry/test_telemetry_data_collector.py b/tests/unit/gpu_telemetry/test_telemetry_data_collector.py index ea3b7d7ef..f1f40c10d 100644 --- a/tests/unit/gpu_telemetry/test_telemetry_data_collector.py +++ b/tests/unit/gpu_telemetry/test_telemetry_data_collector.py @@ -8,14 +8,14 @@ import pytest from aiperf.common.models.telemetry_models import TelemetryRecord -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector -class TestGPUTelemetryDataCollectorCore: - """Test core GPUTelemetryDataCollector functionality. +class TestDCGMTelemetryCollectorCore: + """Test core DCGMTelemetryCollector functionality. This test class focuses exclusively on the data collection, parsing, - and lifecycle management of the GPUTelemetryDataCollector using the new async architecture. + and lifecycle management of the DCGMTelemetryCollector using the new async architecture. Key areas tested: - Initialization and configuration @@ -26,13 +26,13 @@ class TestGPUTelemetryDataCollectorCore: """ def test_collector_initialization_complete(self): - """Test GPUTelemetryDataCollector initialization with custom parameters. + """Test DCGMTelemetryCollector initialization with custom parameters. Verifies that the collector properly stores configuration parameters including DCGM URL, collection interval, and collector ID. Also checks that the initial lifecycle state is correct. 
""" - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.1, collector_id="test_collector", @@ -46,13 +46,13 @@ def test_collector_initialization_complete(self): assert not collector.was_started def test_collector_initialization_minimal(self): - """Test GPUTelemetryDataCollector initialization with minimal parameters. + """Test DCGMTelemetryCollector initialization with minimal parameters. Verifies that the collector applies correct default values when only the required DCGM URL is provided. Tests default collection interval and default collector ID generation. """ - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") assert collector.endpoint_url == "http://localhost:9401/metrics" assert collector.collection_interval == 0.333 # Default collection interval @@ -76,7 +76,7 @@ def test_complete_parsing_single_gpu(self, sample_dcgm_data): Tests proper unit scaling (MiB→GB for memory, mJ→MJ for energy) and that all metadata and metric values are correctly assigned. """ - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") records = collector._parse_metrics_to_records(sample_dcgm_data) assert len(records) == 1 @@ -101,7 +101,7 @@ def test_complete_parsing_multi_gpu(self, multi_gpu_dcgm_data): metrics for multiple GPUs and create separate TelemetryRecord objects for each. Tests that GPU-specific metadata is correctly associated with the right GPU. """ - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") records = collector._parse_metrics_to_records(multi_gpu_dcgm_data) assert len(records) == 3 @@ -137,7 +137,7 @@ def test_empty_response_handling(self): Note: For full pipeline testing with empty responses, see test_telemetry_integration.py::test_empty_dcgm_response_handling() """ - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") empty_cases = [ "", # Empty @@ -156,7 +156,7 @@ class TestHttpCommunication: @pytest.mark.asyncio async def test_endpoint_reachability_success(self): """Test DCGM endpoint reachability check with successful HTTP response.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") with patch("aiohttp.ClientSession.head") as mock_head: # Mock successful HEAD response with Prometheus content-type @@ -180,7 +180,7 @@ async def test_endpoint_reachability_success(self): @pytest.mark.asyncio async def test_endpoint_reachability_failures(self, time_traveler): """Test DCGM endpoint reachability check with various failure scenarios.""" - collector = GPUTelemetryDataCollector("http://nonexistent:9401/metrics") + collector = DCGMTelemetryCollector("http://nonexistent:9401/metrics") with ( patch("aiohttp.ClientSession.head") as mock_head, @@ -205,7 +205,7 @@ async def test_endpoint_reachability_failures(self, time_traveler): @pytest.mark.asyncio async def test_endpoint_reachability_head_fallback(self): """Test that HEAD request falls back to GET when HEAD returns non-200.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") with ( patch("aiohttp.ClientSession.head") 
as mock_head, @@ -234,7 +234,7 @@ async def test_endpoint_reachability_head_fallback(self): @pytest.mark.asyncio async def test_endpoint_reachability_without_session(self): """Test reachability check creates temporary session when collector not initialized.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") # Don't initialize - should create temporary session with patch("aiohttp.ClientSession") as mock_session_class: @@ -266,7 +266,7 @@ async def test_endpoint_reachability_without_session(self): @pytest.mark.asyncio async def test_metrics_fetching(self, sample_dcgm_data): """Test successful HTTP fetching of DCGM metrics.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") with patch("aiohttp.ClientSession.get") as mock_get: # Mock successful response with sample data @@ -287,7 +287,7 @@ async def test_metrics_fetching(self, sample_dcgm_data): @pytest.mark.asyncio async def test_fetch_metrics_session_closed(self): """Test fetch_metrics raises error when session is closed.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") await collector.initialize() @@ -301,7 +301,7 @@ async def test_fetch_metrics_session_closed(self): @pytest.mark.asyncio async def test_fetch_metrics_when_stop_requested(self): """Test fetch_metrics raises CancelledError when stop is requested.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") await collector.initialize() @@ -319,7 +319,7 @@ async def test_fetch_metrics_when_stop_requested(self): @pytest.mark.asyncio async def test_fetch_metrics_no_session(self): """Test fetch_metrics raises error when session not initialized.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") # Don't initialize - session is None with pytest.raises(RuntimeError, match="HTTP session not initialized"): @@ -334,7 +334,7 @@ async def test_successful_collection_loop(self, faker): """Test successful telemetry collection with proper lifecycle management.""" mock_callback = AsyncMock() - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.1, record_callback=mock_callback, @@ -368,7 +368,7 @@ async def test_error_handling_in_collection_loop(self): """ mock_error_callback = AsyncMock() - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.05, error_callback=mock_error_callback, @@ -398,7 +398,7 @@ async def test_callback_exception_resilience(self, faker): """ mock_callback = AsyncMock(side_effect=ValueError("Callback failed")) - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.1, record_callback=mock_callback, @@ -431,7 +431,7 @@ async def test_multiple_start_calls_safety(self): Note: For testing multiple start/stop cycles with separate instances (real-world usage), see integration test test_telemetry_collector_multiple_start_stop() """ - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = 
DCGMTelemetryCollector("http://localhost:9401/metrics") await collector.initialize() @@ -446,7 +446,7 @@ async def test_multiple_start_calls_safety(self): @pytest.mark.asyncio async def test_stop_before_start_safety(self): """Test that stopping before starting doesn't cause issues.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") # Should handle stop before start gracefully await collector.stop() # Should not raise exceptions @@ -457,7 +457,7 @@ class TestDataProcessingEdgeCases: def test_unit_scaling_accuracy(self): """Test accuracy of unit scaling factors for different metrics.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") test_metrics = { "gpu_power_usage": 100.0, # Should remain unchanged (W) @@ -475,7 +475,7 @@ def test_unit_scaling_accuracy(self): def test_temporal_consistency_in_batches(self, sample_dcgm_data): """Test that all records in a batch have consistent timestamps.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") records = collector._parse_metrics_to_records(sample_dcgm_data) @@ -486,7 +486,7 @@ def test_temporal_consistency_in_batches(self, sample_dcgm_data): def test_mixed_quality_response_resilience(self): """Test resilience when DCGM response contains mix of valid/invalid data.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") mixed_response = """ # Valid metric @@ -506,14 +506,14 @@ def test_mixed_quality_response_resilience(self): @pytest.mark.asyncio async def test_empty_url_reachability(self): """Test URL reachability check with empty URL.""" - collector = GPUTelemetryDataCollector("") + collector = DCGMTelemetryCollector("") result = await collector.is_url_reachable() assert result is False def test_invalid_prometheus_format_handling(self): """Test handling of completely invalid Prometheus format.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") # Invalid format that cannot be parsed invalid_data = "invalid prometheus {{{{{ data" @@ -524,7 +524,7 @@ def test_invalid_prometheus_format_handling(self): def test_nan_inf_values_filtering(self): """Test that NaN and inf values are filtered out during parsing.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") metrics_with_invalid_values = """ # NaN value @@ -550,7 +550,7 @@ def test_nan_inf_values_filtering(self): def test_invalid_gpu_index_handling(self): """Test handling of non-numeric GPU index values.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") invalid_gpu_index_data = """ # Invalid GPU index (not a number) @@ -571,7 +571,7 @@ async def test_error_callback_exception_handling(self): side_effect=RuntimeError("Error callback failed") ) - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.05, error_callback=mock_error_callback, @@ -590,7 +590,7 @@ async def test_error_callback_exception_handling(self): @pytest.mark.asyncio async def 
test_collection_without_callbacks(self, faker): - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.1, ) @@ -610,7 +610,7 @@ async def test_collection_without_callbacks(self, faker): async def test_collection_with_empty_records(self): mock_callback = AsyncMock() - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://localhost:9401/metrics", collection_interval=0.1, record_callback=mock_callback, @@ -631,7 +631,7 @@ async def test_collection_with_empty_records(self): def test_scaling_factors_with_none_values(self): """Test that scaling factors handle None values correctly.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") metrics_with_none = { "gpu_power_usage": None, @@ -649,7 +649,7 @@ def test_scaling_factors_with_none_values(self): def test_scaling_factors_preserves_unscaled_metrics(self): """Test that metrics without scaling factors are preserved as-is.""" - collector = GPUTelemetryDataCollector("http://localhost:9401/metrics") + collector = DCGMTelemetryCollector("http://localhost:9401/metrics") metrics = { "gpu_power_usage": 100.0, diff --git a/tests/unit/gpu_telemetry/test_telemetry_integration.py b/tests/unit/gpu_telemetry/test_telemetry_integration.py index 5385fc999..98de85e36 100644 --- a/tests/unit/gpu_telemetry/test_telemetry_integration.py +++ b/tests/unit/gpu_telemetry/test_telemetry_integration.py @@ -1,10 +1,10 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 """ Integration tests for GPU telemetry collection pipeline. -Tests the end-to-end flow from GPUTelemetryDataCollector through GPUTelemetryAccumulator +Tests the end-to-end flow from DCGMTelemetryCollector through GPUTelemetryAccumulator with realistic mock data and callback mechanisms. """ @@ -16,14 +16,14 @@ from aiperf.common.config import ServiceConfig, UserConfig from aiperf.common.config.endpoint_config import EndpointConfig from aiperf.gpu_telemetry.accumulator import GPUTelemetryAccumulator -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector class TestGPUTelemetryIntegration: """Integration tests for complete telemetry collection and processing pipeline. This test class verifies the end-to-end integration between: - - GPUTelemetryDataCollector (DCGM data collection) + - DCGMTelemetryCollector (DCGM data collection) - GPUTelemetryAccumulator (hierarchical data organization) - Multi-node telemetry aggregation - Error handling across the pipeline @@ -130,7 +130,7 @@ async def test_multi_node_telemetry_collection_and_processing( Integration test for multi-node telemetry collection through processing pipeline. Tests the complete flow: - 1. GPUTelemetryDataCollector fetches from multiple DCGM endpoints + 1. DCGMTelemetryCollector fetches from multiple DCGM endpoints 2. Records are processed through callbacks 3. GPUTelemetryAccumulator stores in hierarchical structure 4. 
Statistical aggregation produces MetricResult objects @@ -157,7 +157,7 @@ def mock_aiohttp_get(url, **kwargs): return mock_context_manager with patch("aiohttp.ClientSession.get", side_effect=mock_aiohttp_get): - collector1 = GPUTelemetryDataCollector( + collector1 = DCGMTelemetryCollector( dcgm_url="http://node1:9401/metrics", collection_interval=0.05, record_callback=self.record_callback, @@ -165,7 +165,7 @@ def mock_aiohttp_get(url, **kwargs): collector_id="node1_collector", ) - collector2 = GPUTelemetryDataCollector( + collector2 = DCGMTelemetryCollector( dcgm_url="http://node2:9401/metrics", collection_interval=0.05, record_callback=self.record_callback, @@ -305,7 +305,7 @@ def mock_aiohttp_get_error(url, **kwargs): return mock_context_manager with patch("aiohttp.ClientSession.get", side_effect=mock_aiohttp_get_error): - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://node1:9401/metrics", collection_interval=0.05, record_callback=failing_record_callback, @@ -361,7 +361,7 @@ def mock_aiohttp_get_empty(url, **kwargs): return mock_context_manager with patch("aiohttp.ClientSession.get", side_effect=mock_aiohttp_get_empty): - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://empty-node:9401/metrics", collection_interval=0.1, record_callback=self.record_callback, @@ -410,7 +410,7 @@ def mock_aiohttp_get_scaling(url, **kwargs): return mock_context_manager with patch("aiohttp.ClientSession.get", side_effect=mock_aiohttp_get_scaling): - collector = GPUTelemetryDataCollector( + collector = DCGMTelemetryCollector( dcgm_url="http://testnode:9401/metrics", collection_interval=0.1, record_callback=self.record_callback, diff --git a/tests/unit/gpu_telemetry/test_telemetry_manager.py b/tests/unit/gpu_telemetry/test_telemetry_manager.py index df22fb0a0..b94a22e6a 100644 --- a/tests/unit/gpu_telemetry/test_telemetry_manager.py +++ b/tests/unit/gpu_telemetry/test_telemetry_manager.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 from unittest.mock import AsyncMock, MagicMock, patch @@ -7,6 +7,7 @@ from aiperf.common.config import UserConfig from aiperf.common.config.endpoint_config import EndpointConfig +from aiperf.common.enums import GPUTelemetryCollectorType from aiperf.common.environment import Environment from aiperf.common.messages import ( ProfileConfigureCommand, @@ -15,7 +16,8 @@ TelemetryStatusMessage, ) from aiperf.common.models import ErrorDetails -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.constants import PYNVML_SOURCE_IDENTIFIER +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector from aiperf.gpu_telemetry.manager import GPUTelemetryManager @@ -395,7 +397,7 @@ def close_coroutine(coro): manager.error = MagicMock() # Mock error logging # Add a mock collector that will fail to start - mock_collector = AsyncMock(spec=GPUTelemetryDataCollector) + mock_collector = AsyncMock(spec=DCGMTelemetryCollector) mock_collector.initialize.side_effect = Exception("Failed to initialize") manager._collectors["http://localhost:9400/metrics"] = mock_collector @@ -440,8 +442,8 @@ async def test_stop_all_collectors_success(self): manager = self._create_test_manager() # Create mock collectors - mock_collector1 = AsyncMock(spec=GPUTelemetryDataCollector) - mock_collector2 = AsyncMock(spec=GPUTelemetryDataCollector) + mock_collector1 = AsyncMock(spec=DCGMTelemetryCollector) + mock_collector2 = AsyncMock(spec=DCGMTelemetryCollector) manager._collectors = { "http://node1:9401/metrics": mock_collector1, @@ -469,9 +471,9 @@ async def test_stop_all_collectors_handles_failures(self): manager = self._create_test_manager() # Create mock collectors - one fails, one succeeds - mock_collector1 = AsyncMock(spec=GPUTelemetryDataCollector) + mock_collector1 = AsyncMock(spec=DCGMTelemetryCollector) mock_collector1.stop.side_effect = Exception("Stop failed") - mock_collector2 = AsyncMock(spec=GPUTelemetryDataCollector) + mock_collector2 = AsyncMock(spec=DCGMTelemetryCollector) manager._collectors = { "http://node1:9401/metrics": mock_collector1, @@ -630,6 +632,7 @@ def _create_test_manager(self): manager._user_explicitly_configured_telemetry = False manager._telemetry_disabled = False manager._collection_interval = 0.333 + manager._collector_type = GPUTelemetryCollectorType.DCGM manager.error = MagicMock() manager.debug = MagicMock() return manager @@ -640,9 +643,9 @@ async def test_configure_no_shutdown_when_no_endpoints_reachable(self): manager = self._create_test_manager() manager.publish = AsyncMock() - # Mock GPUTelemetryDataCollector to return unreachable + # Mock DCGMTelemetryCollector to return unreachable with patch.object( - GPUTelemetryDataCollector, "is_url_reachable", return_value=False + DCGMTelemetryCollector, "is_url_reachable", return_value=False ): configure_msg = ProfileConfigureCommand( command_id="test", service_id="system_controller", config={} @@ -668,9 +671,9 @@ async def test_configure_sends_enabled_status_when_endpoints_reachable(self): manager = self._create_test_manager() manager.publish = AsyncMock() - # Mock GPUTelemetryDataCollector to return reachable + # Mock DCGMTelemetryCollector to return reachable with patch.object( - GPUTelemetryDataCollector, "is_url_reachable", return_value=True + DCGMTelemetryCollector, "is_url_reachable", return_value=True ): configure_msg = ProfileConfigureCommand( command_id="test", service_id="system_controller", config={} @@ -742,7 +745,7 @@ async def 
test_start_no_redundant_reachability_check(self): manager.publish = AsyncMock() # Add mock collector - mock_collector = AsyncMock(spec=GPUTelemetryDataCollector) + mock_collector = AsyncMock(spec=DCGMTelemetryCollector) manager._collectors["http://localhost:9400/metrics"] = mock_collector start_msg = ProfileStartCommand( @@ -772,6 +775,7 @@ def _create_test_manager(self, user_requested, user_endpoints): manager._user_explicitly_configured_telemetry = user_requested manager._telemetry_disabled = False manager._collection_interval = 0.333 + manager._collector_type = GPUTelemetryCollectorType.DCGM manager.error = MagicMock() manager.debug = MagicMock() return manager @@ -817,7 +821,7 @@ async def test_show_custom_urls_when_defaults_unreachable(self): # Mock all endpoints as unreachable with patch.object( - GPUTelemetryDataCollector, "is_url_reachable", return_value=False + DCGMTelemetryCollector, "is_url_reachable", return_value=False ): configure_msg = ProfileConfigureCommand( command_id="test", service_id="system_controller", config={} @@ -871,7 +875,7 @@ async def test_hide_defaults_when_not_requested_and_all_unreachable(self): # Mock all endpoints as unreachable with patch.object( - GPUTelemetryDataCollector, "is_url_reachable", return_value=False + DCGMTelemetryCollector, "is_url_reachable", return_value=False ): configure_msg = ProfileConfigureCommand( command_id="test", service_id="system_controller", config={} @@ -884,3 +888,148 @@ async def test_hide_defaults_when_not_requested_and_all_unreachable(self): assert ( len(call_args.endpoints_configured) == 0 ) # No user endpoints, defaults hidden + + +class TestPynvmlCollectorIntegration: + """Test PYNVML collector integration in manager's configure phase.""" + + def _create_test_manager(self): + """Helper to create a TelemetryManager instance configured for PYNVML.""" + manager = GPUTelemetryManager.__new__(GPUTelemetryManager) + manager.service_id = "test_manager" + manager._collectors = {} + manager._collector_id_to_url = {} + manager._dcgm_endpoints = list(Environment.GPU.DEFAULT_DCGM_ENDPOINTS) + manager._user_provided_endpoints = [] + manager._user_explicitly_configured_telemetry = False + manager._telemetry_disabled = False + manager._collection_interval = 0.333 + manager._collector_type = GPUTelemetryCollectorType.PYNVML + manager.error = MagicMock() + manager.warning = MagicMock() + manager.debug = MagicMock() + return manager + + @pytest.mark.asyncio + async def test_configure_pynvml_collector_success(self): + """Test successful PYNVML collector configuration when GPUs are available.""" + manager = self._create_test_manager() + manager.publish = AsyncMock() + + mock_collector = AsyncMock() + mock_collector.is_url_reachable = AsyncMock(return_value=True) + + with patch( + "aiperf.gpu_telemetry.pynvml_collector.PyNVMLTelemetryCollector", + return_value=mock_collector, + ): + configure_msg = ProfileConfigureCommand( + command_id="test", service_id="system_controller", config={} + ) + await manager._profile_configure_command(configure_msg) + + # Should have sent enabled status + manager.publish.assert_called_once() + call_args = manager.publish.call_args[0][0] + assert isinstance(call_args, TelemetryStatusMessage) + assert call_args.enabled is True + assert call_args.reason is None + assert PYNVML_SOURCE_IDENTIFIER in call_args.endpoints_configured + assert PYNVML_SOURCE_IDENTIFIER in call_args.endpoints_reachable + + # Should have collector registered + assert PYNVML_SOURCE_IDENTIFIER in manager._collectors + assert ( + 
manager._collector_id_to_url["pynvml_collector"] == PYNVML_SOURCE_IDENTIFIER + ) + + @pytest.mark.asyncio + async def test_configure_pynvml_collector_no_gpus_found(self): + """Test PYNVML collector configuration when no GPUs are available.""" + manager = self._create_test_manager() + manager.publish = AsyncMock() + + mock_collector = AsyncMock() + mock_collector.is_url_reachable = AsyncMock(return_value=False) + + with patch( + "aiperf.gpu_telemetry.pynvml_collector.PyNVMLTelemetryCollector", + return_value=mock_collector, + ): + configure_msg = ProfileConfigureCommand( + command_id="test", service_id="system_controller", config={} + ) + await manager._profile_configure_command(configure_msg) + + # Should have sent disabled status + manager.publish.assert_called_once() + call_args = manager.publish.call_args[0][0] + assert isinstance(call_args, TelemetryStatusMessage) + assert call_args.enabled is False + assert call_args.reason == "pynvml not available or no GPUs found" + assert PYNVML_SOURCE_IDENTIFIER in call_args.endpoints_configured + assert call_args.endpoints_reachable == [] + + # Should have no collectors registered + assert len(manager._collectors) == 0 + + # Should have logged warning + manager.warning.assert_called_once() + + @pytest.mark.asyncio + async def test_configure_pynvml_collector_package_not_installed(self): + """Test PYNVML collector configuration when pynvml package is not installed.""" + manager = self._create_test_manager() + manager.publish = AsyncMock() + + with patch( + "aiperf.gpu_telemetry.pynvml_collector.PyNVMLTelemetryCollector", + side_effect=RuntimeError( + "pynvml package not installed. Install with: pip install nvidia-ml-py" + ), + ): + configure_msg = ProfileConfigureCommand( + command_id="test", service_id="system_controller", config={} + ) + await manager._profile_configure_command(configure_msg) + + # Should have sent disabled status with RuntimeError message + manager.publish.assert_called_once() + call_args = manager.publish.call_args[0][0] + assert isinstance(call_args, TelemetryStatusMessage) + assert call_args.enabled is False + assert "pynvml package not installed" in call_args.reason + assert call_args.endpoints_configured == [] + assert call_args.endpoints_reachable == [] + + # Should have logged error + manager.error.assert_called_once() + assert "pynvml package not installed" in str(manager.error.call_args) + + @pytest.mark.asyncio + async def test_configure_pynvml_collector_general_exception(self): + """Test PYNVML collector configuration handles unexpected exceptions.""" + manager = self._create_test_manager() + manager.publish = AsyncMock() + + with patch( + "aiperf.gpu_telemetry.pynvml_collector.PyNVMLTelemetryCollector", + side_effect=ValueError("Unexpected initialization error"), + ): + configure_msg = ProfileConfigureCommand( + command_id="test", service_id="system_controller", config={} + ) + await manager._profile_configure_command(configure_msg) + + # Should have sent disabled status with general error message + manager.publish.assert_called_once() + call_args = manager.publish.call_args[0][0] + assert isinstance(call_args, TelemetryStatusMessage) + assert call_args.enabled is False + assert "pynvml configuration failed" in call_args.reason + assert call_args.endpoints_configured == [] + assert call_args.endpoints_reachable == [] + + # Should have logged error about failed configuration + manager.error.assert_called_once() + assert "Failed to configure pynvml collector" in str(manager.error.call_args) diff --git 
a/tests/unit/server/test_dcgm_faker.py b/tests/unit/server/test_dcgm_faker.py index e5f8f92cb..120e0103f 100644 --- a/tests/unit/server/test_dcgm_faker.py +++ b/tests/unit/server/test_dcgm_faker.py @@ -12,7 +12,7 @@ ) from pytest import approx -from aiperf.gpu_telemetry.data_collector import GPUTelemetryDataCollector +from aiperf.gpu_telemetry.dcgm_collector import DCGMTelemetryCollector class TestGPUConfig: @@ -196,7 +196,7 @@ def test_faker_output_parsed_by_real_telemetry_collector(self, gpu_name): print(metrics_text) # Use real TelemetryDataCollector to parse the output - collector = GPUTelemetryDataCollector(dcgm_url="http://fake") + collector = DCGMTelemetryCollector(dcgm_url="http://fake") records = collector._parse_metrics_to_records(metrics_text) # Should get 2 TelemetryRecord objects (one per GPU) @@ -239,7 +239,7 @@ def test_faker_output_parsed_by_real_telemetry_collector(self, gpu_name): def test_load_affects_telemetry_records(self): """Test that load changes affect TelemetryRecords when parsed by real collector.""" faker = DCGMFaker(gpu_name="b200", num_gpus=1, seed=42) - collector = GPUTelemetryDataCollector(dcgm_url="http://fake") + collector = DCGMTelemetryCollector(dcgm_url="http://fake") # Low load faker.set_load(0.1) @@ -262,7 +262,7 @@ def test_load_affects_telemetry_records(self): def test_metrics_clamped_to_bounds(self): """Test that all metrics are clamped to [0, max] bounds.""" faker = DCGMFaker(gpu_name="h100", num_gpus=2, seed=42) - collector = GPUTelemetryDataCollector(dcgm_url="http://fake") + collector = DCGMTelemetryCollector(dcgm_url="http://fake") # Test extreme high load faker.set_load(1.0)
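
As a reference for the unit handling these tests assert, the sketch below reproduces the pynvml scaling factors checked in `TestPyNVMLScalingFactors` in isolation (milliwatts to watts, millijoules to megajoules, bytes to gigabytes). It is a standalone illustration with arbitrary example readings, not code from this patch, and uses `math.isclose` rather than exact float equality to sidestep rounding noise.

```python
import math

# Scaling factors mirroring the values asserted in TestPyNVMLScalingFactors:
# raw NVML readings are converted to the units reported in telemetry records.
SCALING = {
    "gpu_power_usage": 1e-3,     # milliwatts  -> watts
    "energy_consumption": 1e-9,  # millijoules -> megajoules
    "gpu_memory_used": 1e-9,     # bytes       -> gigabytes
}

# Arbitrary example readings (hypothetical, not taken from the patch).
raw = {
    "gpu_power_usage": 350_000,           # 350,000 mW
    "energy_consumption": 1_000_000_000,  # 1e9 mJ
    "gpu_memory_used": 20e9,              # 20e9 bytes
}

scaled = {name: value * SCALING[name] for name, value in raw.items()}
expected = {
    "gpu_power_usage": 350.0,    # W
    "energy_consumption": 1.0,   # MJ
    "gpu_memory_used": 20.0,     # GB
}
for name, value in expected.items():
    assert math.isclose(scaled[name], value), name
```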