flashinfer/gemm/gemm_base.py (28 additions, 2 deletions)

@@ -1751,6 +1751,32 @@ def is_cudnn_override_shape_available() -> bool:
    return False


def _is_fp4_cudnn_override_shape_trusted(device) -> bool:
    """Return True iff the cuDNN FP4 override-shape path is both available and
    numerically trusted on ``device``.

    On SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10), the FP4
    override-shape fast path introduced in #2910 produces NaN/Inf output for
    realistic NVFP4 shapes (observed on Nemotron-3-Nano-30B-FP4). The BF16 and
    MXFP8 override-shape paths go through different helpers and are not
    implicated; they still take the fast path on all archs.

    Until the FP4 override-shape helpers
    (``_get_real_fp4_shape_from_packed_uint8``,
    ``_expand_block_scale_tensor_shape``, and the ``cache_m`` bucketing in
    ``_get_override_graph``) are audited, force SM12x FP4 back to the
    static-shape path: same cuDNN backend, numerically correct.
    """
    if not is_cudnn_override_shape_available():
        return False
    try:
        return not is_sm12x_supported(device)
Contributor (severity: high):

Using is_sm12x_supported(device) as a guard here might be problematic because it includes a CUDA version check (requiring CUDA 12.8/12.9+). If a user runs on SM120/121 with an older CUDA version (but with a cuDNN version that supports override shapes), is_sm12x_supported will return False, causing this function to return True (trusted). This would re-expose the NaN/Inf issue on those systems. It is safer to check the compute capability major version directly to cover all SM12x architectures regardless of the CUDA version.

Suggested change:
-        return not is_sm12x_supported(device)
+        major, _ = get_compute_capability(device)
+        return major != 12
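
To make the failure mode concrete, here is a minimal sketch of the fail-open path, assuming (as the comment above describes) that is_sm12x_supported couples the arch check to a CUDA 12.8+ requirement; cuda_version() is a hypothetical stand-in, not a real flashinfer API:

# Hypothetical sketch, not the real is_sm12x_supported implementation.
def is_sm12x_supported_sketch(device) -> bool:
    major, _ = get_compute_capability(device)
    return major == 12 and cuda_version() >= (12, 8)  # assumed CUDA gate

# On SM120 with, say, CUDA 12.4: the sketch returns False, so
# `not is_sm12x_supported_sketch(device)` evaluates to True ("trusted")
# and the NaN-prone override-shape path runs again. Checking
# `major != 12` directly keeps the guard tied to the architecture alone.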

    except Exception:
        # Fail closed: if we cannot resolve the arch, do not re-expose the
        # NaN path.
        return False


# One cudnn handle per GPU
_cudnn_handles: dict[int, int] = {}

@@ -4274,7 +4300,7 @@ def get_valid_tactics(

        # currently the cudnn backend does not support alpha for dynamic-shape;
        # remove this restriction once cudnn supports it
-        if is_cudnn_override_shape_available():
+        if _is_fp4_cudnn_override_shape_trusted(a.device):
            graph = self._get_override_graph(
                a, b, alpha, out_dtype, block_size, use_nvfp4
            )
@@ -4334,7 +4360,7 @@ def forward(

        # currently the cudnn backend does not support alpha for dynamic-shape;
        # remove this restriction once cudnn supports it
-        if is_cudnn_override_shape_available():
+        if _is_fp4_cudnn_override_shape_trusted(a.device):
            graph = self._get_override_graph(
                a, b, alpha, out_dtype, block_size, use_nvfp4
            )
tests/gemm/test_mm_fp4.py (30 additions, 0 deletions)

@@ -141,5 +141,35 @@ def test_mm_fp4_backend_auto(
    _test_mm_fp4(m, n, k, res_dtype, "auto", use_128x4_sf_layout, auto_tuning, fp4_type)


# Regression guard for the FP4 cuDNN override-shape NaN on SM120/121.
# Before `_is_fp4_cudnn_override_shape_trusted` forced these archs back to
# the static-shape path, the shapes below returned non-finite output on
# SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10). `_test_mm_fp4`
# checks cosine-sim > 0.97 against a bfloat16 matmul reference, so any
# NaN/Inf output fails the test.
@pytest.mark.parametrize(
    "m,n,k",
    [
        (1, 6144, 4096),
        (16, 6144, 4096),
        (32, 4096, 6144),
    ],
)
def test_mm_fp4_cudnn_finite_on_sm12x(m, n, k):
    compute_capability = get_compute_capability(torch.device("cuda"))
    if compute_capability[0] != 12:
        pytest.skip("Regression is SM120/121-specific; skip on other archs.")
    _test_mm_fp4(
        m,
        n,
        k,
        res_dtype=torch.bfloat16,
        backend="cudnn",
        use_128x4_sf_layout=True,
        auto_tuning=False,
        fp4_type="nvfp4",
    )
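
For context on why the cosine-similarity check described above is a reliable NaN/Inf detector, a minimal sketch (the real check lives inside _test_mm_fp4; the helper name and exact form here are illustrative):

import torch

def cosine_sim_ok(out: torch.Tensor, ref: torch.Tensor, threshold: float = 0.97) -> bool:
    # Flatten both tensors and compare direction against the reference.
    sim = torch.nn.functional.cosine_similarity(
        out.float().flatten(), ref.float().flatten(), dim=0
    )
    # Any NaN/Inf in `out` propagates into `sim` as NaN, and NaN compares
    # False against the threshold, so non-finite output can never pass.
    return bool(sim > threshold)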


if __name__ == "__main__":
    pytest.main([__file__])