Commit 1a112a2
Add blackwell check in TRTRTX EP unit tests for FP4/FP8 Custom ops (#26832)
### Description

- Follow-up to [PR-26555](#26555): add a Blackwell check in the TRTRTX EP unit tests for the FP4/FP8 custom ops, so they only run on Blackwell-class GPUs.

### Motivation and Context

- The NVFP4 recipe (a combination of FP4 and FP8) is primarily intended for Blackwell+ GPUs, as they have Tensor Cores for the FP4 data type.
1 parent aab5661 commit 1a112a2
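For context, here is a minimal standalone sketch (not part of this commit) of the compute-capability query the new check relies on; it assumes the CUDA runtime API and mirrors the `major * 100 + minor * 10` encoding and the 1200 (SM 12.0) Blackwell threshold used in the changes below:

```cpp
// Illustrative sketch only; not repository code. Mirrors the encoding used by
// GetCudaArchitecture() in cuda_op_test_utils.cc and the SM 12.0 threshold
// used by the new IsBlackwellOrAbove() test helper.
#include <cuda_runtime_api.h>
#include <iostream>

int main() {
  int device = 0;
  cudaDeviceProp prop{};
  if (cudaGetDevice(&device) != cudaSuccess ||
      cudaGetDeviceProperties(&prop, device) != cudaSuccess) {
    std::cout << "WARNING: CUDA is not available or failed to initialize" << std::endl;
    return 1;
  }
  const int cuda_arch = prop.major * 100 + prop.minor * 10;  // e.g. SM 12.0 -> 1200
  constexpr int kBlackwellMinCapability = 1200;              // Blackwell threshold
  std::cout << "GPU Compute Capability: SM " << prop.major << "." << prop.minor
            << (cuda_arch >= kBlackwellMinCapability ? " (Blackwell+)" : " (pre-Blackwell)")
            << std::endl;
  return 0;
}
```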

File tree

2 files changed: 37 additions, 2 deletions

onnxruntime/test/common/cuda_op_test_utils.cc

Lines changed: 13 additions & 2 deletions
@@ -1,7 +1,9 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#ifdef USE_CUDA
+#include <iostream>
+
+#if defined(USE_CUDA) || defined(USE_NV)
 #include "cuda_runtime_api.h"
 #endif

@@ -13,7 +15,7 @@ int GetCudaArchitecture() {
   // Usually, we test on a single GPU or multiple GPUs of same architecture, so it's fine to cache the result.
   static int cuda_arch = -1;

-#ifdef USE_CUDA
+#if defined(USE_CUDA) || defined(USE_NV)
   if (cuda_arch == -1) {
     int current_device_id = 0;
     cudaGetDevice(&current_device_id);
@@ -26,6 +28,15 @@ int GetCudaArchitecture() {
     if (cudaSuccess == cudaGetDeviceProperties(&prop, current_device_id)) {
       cuda_arch = prop.major * 100 + prop.minor * 10;
     }
+
+    // Log GPU compute capability
+    if (cuda_arch == -1) {
+      std::cout << "WARNING: CUDA is not available or failed to initialize" << std::endl;
+    } else {
+      std::cout << "GPU Compute Capability: SM "
+                << cuda_arch / 100 << "." << (cuda_arch % 100) / 10
+                << " (value: " << cuda_arch << ")" << std::endl;
+    }
   }
 #endif
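As a quick sanity check on the value logged above (not part of the commit): the cached architecture follows the `major * 100 + minor * 10` encoding, so SM 8.9 is reported as 890 and SM 12.0 as 1200. A tiny hypothetical round-trip check:

```cpp
// Hypothetical sanity check, not repository code: verifies the major*100 +
// minor*10 encoding used by GetCudaArchitecture() and its new log line.
#include <cassert>

constexpr int EncodeArch(int major, int minor) { return major * 100 + minor * 10; }

int main() {
  assert(EncodeArch(8, 9) == 890);    // SM 8.9
  assert(EncodeArch(12, 0) == 1200);  // SM 12.0, the Blackwell threshold used in the tests
  assert(1200 / 100 == 12 && (1200 % 100) / 10 == 0);  // decode back to major.minor
  return 0;
}
```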

onnxruntime/test/providers/nv_tensorrt_rtx/nv_basic_test.cc

Lines changed: 24 additions & 0 deletions
@@ -9,6 +9,7 @@
 #include "test/util/include/scoped_env_vars.h"
 #include "test/common/trt_op_test_utils.h"
 #include "test/common/random_generator.h"
+#include "test/common/cuda_op_test_utils.h"
 #include "test/providers/nv_tensorrt_rtx/test_nv_trt_rtx_ep_util.h"

 #include <thread>
@@ -22,6 +23,21 @@ namespace onnxruntime {

 namespace test {

+// Helper function to check if GPU is Blackwell (SM 12.0+) or above
+// Returns true if requirement is met
+// Returns false if CUDA is unavailable or GPU is below SM 12.0
+static bool IsBlackwellOrAbove() {
+  constexpr int kBlackwellMinCapability = 1200;  // SM 12.0 = 12 * 100 + 0 * 10
+  int cuda_arch = GetCudaArchitecture();
+
+  // Check if CUDA is available
+  if (cuda_arch == -1) {
+    return false;
+  }
+
+  return cuda_arch >= kBlackwellMinCapability;
+}
+
 TEST(NvExecutionProviderTest, ContextEmbedAndReload) {
   PathString model_name = ORT_TSTR("nv_execution_provider_test.onnx");
   PathString model_name_ctx = ORT_TSTR("nv_execution_provider_test_ctx.onnx");
@@ -442,6 +458,10 @@ TEST(NvExecutionProviderTest, DataTransfer) {
 }

 TEST(NvExecutionProviderTest, FP8CustomOpModel) {
+  if (!IsBlackwellOrAbove()) {
+    GTEST_SKIP() << "Test requires SM 12.0+ GPU (Blackwell+)";
+  }
+
   PathString model_name = ORT_TSTR("nv_execution_provider_fp8_quantize_dequantize_test.onnx");
   clearFileIfExists(model_name);
   std::string graph_name = "nv_execution_provider_fp8_quantize_dequantize_graph";
@@ -509,6 +529,10 @@ TEST(NvExecutionProviderTest, FP8CustomOpModel) {
 }

 TEST(NvExecutionProviderTest, FP4CustomOpModel) {
+  if (!IsBlackwellOrAbove()) {
+    GTEST_SKIP() << "Test requires SM 12.0+ GPU (Blackwell+)";
+  }
+
   PathString model_name = ORT_TSTR("nv_execution_provider_fp4_dynamic_quantize_test.onnx");
   clearFileIfExists(model_name);
   std::string graph_name = "nv_execution_provider_fp4_dynamic_quantize_graph";
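The new guards rely on GoogleTest's `GTEST_SKIP()`, which marks a test as skipped rather than failed when the hardware requirement is not met. A minimal self-contained sketch of the same pattern (the test name and the stub helper are placeholders, not repository code):

```cpp
// Minimal sketch of the hardware-gated test pattern added in this commit.
// IsBlackwellOrAboveStub() stands in for the real IsBlackwellOrAbove(),
// which queries GetCudaArchitecture().
#include <gtest/gtest.h>

static bool IsBlackwellOrAboveStub() {
  // Real helper: return GetCudaArchitecture() >= 1200;  (SM 12.0+)
  return false;  // pretend we are on pre-Blackwell hardware
}

TEST(ExampleSuite, Fp4Fp8PathRequiresBlackwell) {
  if (!IsBlackwellOrAboveStub()) {
    GTEST_SKIP() << "Test requires SM 12.0+ GPU (Blackwell+)";
  }
  // Exercise the FP4/FP8 custom-op path here; reached only on SM 12.0+ devices.
  SUCCEED();
}
```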
