# --- flashinfer/gemm/gemm_base.py -------------------------------------------


def _is_fp4_cudnn_override_shape_trusted(device) -> bool:
    """Decide whether FP4 GEMMs on ``device`` may use cuDNN override-shape.

    The answer is ``True`` only when both conditions hold:

    * ``is_cudnn_override_shape_available()`` reports the feature as present;
    * ``is_sm12x_supported(device)`` is falsy, i.e. the device is not SM12x.

    Why SM12x is excluded: on SM120 (RTX PRO 6000 Blackwell) and SM121
    (DGX Spark GB10) the FP4 override-shape fast path added in #2910 yields
    NaN/Inf results for realistic NVFP4 shapes (seen with
    Nemotron-3-Nano-30B-FP4). The BF16 and MXFP8 override-shape paths use
    different helpers, are not implicated, and keep the fast path everywhere.

    The exclusion is meant to stay until the FP4 override-shape helpers --
    ``_get_real_fp4_shape_from_packed_uint8``,
    ``_expand_block_scale_tensor_shape`` and the ``cache_m`` bucketing in
    ``_get_override_graph`` -- have been audited. Meanwhile SM12x FP4 falls
    back to the static-shape path, which uses the same cuDNN backend and is
    numerically correct.

    ``get_valid_tactics`` and ``forward`` call this with ``a.device`` (in
    place of a bare ``is_cudnn_override_shape_available()`` check) before
    building the graph through ``self._get_override_graph``.

    Args:
        device: Device handle forwarded unchanged to ``is_sm12x_supported``.

    Returns:
        ``True`` if the override-shape path is available and the device is
        not SM12x; ``False`` otherwise, including when the architecture
        query raises.
    """
    if is_cudnn_override_shape_available():
        try:
            # Truthiness is evaluated inside the ``try`` so that a failure
            # while interpreting the result is handled like a failed query.
            arch_is_sm12x = bool(is_sm12x_supported(device))
        except Exception:
            # Fail closed: an unresolved architecture must not re-expose
            # the NaN path.
            return False
        else:
            return not arch_is_sm12x
    return False


# --- tests/gemm/test_mm_fp4.py ----------------------------------------------


# Regression guard: FP4 cuDNN override-shape produced NaN/Inf on SM120/121.
# Until `_is_fp4_cudnn_override_shape_trusted` routed those archs back to the
# static-shape path, each shape listed here gave non-finite output on SM120
# (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10). `_test_mm_fp4`
# requires cosine-sim > 0.97 versus a bfloat16 matmul reference, so NaN/Inf
# output makes the test fail.
@pytest.mark.parametrize(
    "m,n,k",
    [
        (1, 6144, 4096),
        (16, 6144, 4096),
        (32, 4096, 6144),
    ],
)
def test_mm_fp4_cudnn_finite_on_sm12x(m, n, k):
    """Run the NVFP4 cuDNN GEMM check on SM12x devices; skip elsewhere."""
    major_version = get_compute_capability(torch.device("cuda"))[0]
    if major_version != 12:
        pytest.skip("Regression is SM120/121-specific; skip on other archs.")

    # Fixed configuration for the regression: cuDNN backend, NVFP4 inputs,
    # 128x4 scale-factor layout, bfloat16 result, no autotuning.
    fixed_config = dict(
        res_dtype=torch.bfloat16,
        backend="cudnn",
        use_128x4_sf_layout=True,
        auto_tuning=False,
        fp4_type="nvfp4",
    )
    _test_mm_fp4(m, n, k, **fixed_config)