From 06d5ee5bd8978e0165635e4b2dbfbe49987a13e6 Mon Sep 17 00:00:00 2001 From: Serge Panev Date: Tue, 21 Apr 2026 15:40:17 -0700 Subject: [PATCH] fix(gemm): skip FP4 cuDNN override-shape path on SM120/SM121 The FP4 override-shape fast path in CudnnFp4GemmRunner, introduced in #2910, returns NaN/Inf output on SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10) for realistic NVFP4 shapes. Confirmed on Nemotron-3-Nano-30B-FP4 via sglang; the corrupt logits trip the torch sampler's "probability tensor contains inf/nan" assert on real requests. Add `_is_fp4_cudnn_override_shape_trusted(device)`, which returns True only when override-shape is available *and* the device is not SM12x, and use it at the two FP4 call sites (get_valid_tactics, forward). The static-shape cuDNN path it falls back to is the pre-#2910 behavior, uses the same backend, and is numerically correct. Scope: - Only CudnnFp4GemmRunner is gated. The BF16 and MXFP8 override-shape paths go through different helpers, are not implicated, and keep the #2910 fast path on all archs. - The guard uses is_sm12x_supported(), matching the convention already used elsewhere in gemm_base.py. - Helper fails closed (returns False) if compute capability cannot be resolved, so an error path cannot re-expose the NaN behavior. This is a guard, not a root-cause fix. Suspected culprits in #2910 are _get_real_fp4_shape_from_packed_uint8, _expand_block_scale_tensor_shape, and the `cache_m = last_positive_power_of_2(actual_m)` bucketing in _get_override_graph. A follow-up PR will narrow the fault and remove this guard. Add an SM12x-only regression test in tests/gemm/test_mm_fp4.py that runs mm_fp4(backend="cudnn") on shapes known to trigger the NaN and relies on the existing cosine-similarity assertion in _test_mm_fp4 to catch any non-finite output. --- flashinfer/gemm/gemm_base.py | 30 ++++++++++++++++++++++++++++-- tests/gemm/test_mm_fp4.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+), 2 deletions(-) diff --git a/flashinfer/gemm/gemm_base.py b/flashinfer/gemm/gemm_base.py index decc213e6f..d36687845b 100755 --- a/flashinfer/gemm/gemm_base.py +++ b/flashinfer/gemm/gemm_base.py @@ -1751,6 +1751,32 @@ def is_cudnn_override_shape_available() -> bool: return False +def _is_fp4_cudnn_override_shape_trusted(device) -> bool: + """Return True iff the cuDNN FP4 override-shape path is both available and + numerically trusted on ``device``. + + On SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10), the FP4 + override-shape fast path introduced in #2910 produces NaN/Inf output for + realistic NVFP4 shapes (observed on Nemotron-3-Nano-30B-FP4). The BF16 and + MXFP8 override-shape paths go through different helpers and are not + implicated — they still take the fast path on all archs. + + Until the FP4 override-shape helpers + (``_get_real_fp4_shape_from_packed_uint8``, + ``_expand_block_scale_tensor_shape``, and the ``cache_m`` bucketing in + ``_get_override_graph``) are audited, force SM12x FP4 back to the + static-shape path — same cuDNN backend, numerically correct. + """ + if not is_cudnn_override_shape_available(): + return False + try: + return not is_sm12x_supported(device) + except Exception: + # Fail closed: if we cannot resolve the arch, do not re-expose the + # NaN path. + return False + + # One cudnn handle per each GPU _cudnn_handles: dict[int, int] = {} @@ -4274,7 +4300,7 @@ def get_valid_tactics( # currently cudnn backend does not support alpha for dynamic-shape # remove this restriction once cudnn suppport it - if is_cudnn_override_shape_available(): + if _is_fp4_cudnn_override_shape_trusted(a.device): graph = self._get_override_graph( a, b, alpha, out_dtype, block_size, use_nvfp4 ) @@ -4334,7 +4360,7 @@ def forward( # currently cudnn backend does not support alpha for dynamic-shape # remove this restriction once cudnn suppport it - if is_cudnn_override_shape_available(): + if _is_fp4_cudnn_override_shape_trusted(a.device): graph = self._get_override_graph( a, b, alpha, out_dtype, block_size, use_nvfp4 ) diff --git a/tests/gemm/test_mm_fp4.py b/tests/gemm/test_mm_fp4.py index 519536ae0b..5fda86c4d4 100644 --- a/tests/gemm/test_mm_fp4.py +++ b/tests/gemm/test_mm_fp4.py @@ -141,5 +141,35 @@ def test_mm_fp4_backend_auto( _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type) +# Regression guard for the FP4 cuDNN override-shape NaN on SM120/121. +# Before `_is_fp4_cudnn_override_shape_trusted` forced these archs back to +# the static-shape path, the shapes below returned non-finite output on +# SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10). `_test_mm_fp4` +# checks cosine-sim > 0.97 against a bfloat16 matmul reference, so any +# NaN/Inf output fails the test. +@pytest.mark.parametrize( + "m,n,k", + [ + (1, 6144, 4096), + (16, 6144, 4096), + (32, 4096, 6144), + ], +) +def test_mm_fp4_cudnn_finite_on_sm12x(m, n, k): + compute_capability = get_compute_capability(torch.device("cuda")) + if compute_capability[0] != 12: + pytest.skip("Regression is SM120/121-specific; skip on other archs.") + _test_mm_fp4( + m, + n, + k, + res_dtype=torch.bfloat16, + backend="cudnn", + use_128x4_sf_layout=True, + auto_tuning=False, + fp4_type="nvfp4", + ) + + if __name__ == "__main__": pytest.main([__file__])