
fix(gemm): skip FP4 cuDNN override-shape path on SM120/SM121 (NaN regression from #2910)#3140

Open
Kh4L wants to merge 1 commit into flashinfer-ai:main from Kh4L:fix/sm120-121-fp4-cudnn-override-shape-guard

Conversation


@Kh4L Kh4L commented Apr 21, 2026

The FP4 override-shape fast path in CudnnFp4GemmRunner added in #2910 returns NaN/Inf on SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10) for realistic NVFP4 shapes, silently corrupting logits (reproduced on Nemotron-3-Nano-30B-FP4 via sglang).

This PR adds _is_fp4_cudnn_override_shape_trusted(device) and routes SM12x FP4 back to the static-shape cuDNN path. BF16 and MXFP8 paths are untouched.

This is a guard only: a follow-up PR will narrow the fault (suspected in _get_real_fp4_shape_from_packed_uint8, _expand_block_scale_tensor_shape, or the cache_m bucketing in _get_override_graph) and remove the guard.

Tests

  • SM12x-only regression test added in tests/gemm/test_mm_fp4.py.
  • End-to-end: sglang L0_Nemotron-3-Nano-30B-FP4 on SM121 passes with a wheel built from this PR.

Summary by CodeRabbit

  • Bug Fixes

    • Enhanced FP4 matrix multiplication reliability by refining cuDNN execution path selection based on GPU architecture capabilities.
  • Tests

    • Added regression test for FP4 operations to ensure numerical correctness on SM12x-class GPU architectures.

The FP4 override-shape fast path in CudnnFp4GemmRunner, introduced in
flashinfer-ai#2910, returns NaN/Inf output on SM120 (RTX PRO 6000 Blackwell) and
SM121 (DGX Spark GB10) for realistic NVFP4 shapes. Confirmed on
Nemotron-3-Nano-30B-FP4 via sglang; the corrupt logits trip the torch
sampler's "probability tensor contains inf/nan" assert on real
requests.

Add `_is_fp4_cudnn_override_shape_trusted(device)`, which returns True
only when override-shape is available *and* the device is not SM12x,
and use it at the two FP4 call sites (get_valid_tactics, forward). The
static-shape cuDNN path it falls back to is the pre-flashinfer-ai#2910 behavior,
uses the same backend, and is numerically correct.

Scope:
- Only CudnnFp4GemmRunner is gated. The BF16 and MXFP8 override-shape
  paths go through different helpers, are not implicated, and keep
  the flashinfer-ai#2910 fast path on all archs.
- The guard uses is_sm12x_supported(), matching the convention
  already used elsewhere in gemm_base.py.
- Helper fails closed (returns False) if compute capability cannot be
  resolved, so an error path cannot re-expose the NaN behavior.

This is a guard, not a root-cause fix. Suspected culprits in flashinfer-ai#2910 are
_get_real_fp4_shape_from_packed_uint8, _expand_block_scale_tensor_shape,
and the `cache_m = last_positive_power_of_2(actual_m)` bucketing in
_get_override_graph. A follow-up PR will narrow the fault and remove
this guard.
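To illustrate why the bucketing is a suspect, here is a hypothetical last_positive_power_of_2, assumed for this sketch to round m down to the nearest power of two (the real helper's semantics may differ):

```python
def last_positive_power_of_2(n: int) -> int:
    # Hypothetical stand-in: largest power of two <= n (assumed semantics).
    p = 1
    while p * 2 <= n:
        p *= 2
    return p

for actual_m in (1, 7, 8, 9, 300):
    cache_m = last_positive_power_of_2(actual_m)
    # A cuDNN graph cached under cache_m is reused for actual_m rows; if any
    # packed-FP4 or block-scale shape is derived from the wrong one of the
    # two values, garbage can leak into the output.
    print(f"actual_m={actual_m} -> cache_m={cache_m}")
```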

Add an SM12x-only regression test in tests/gemm/test_mm_fp4.py that
runs mm_fp4(backend="cudnn") on shapes known to trigger the NaN and
relies on the existing cosine-similarity assertion in _test_mm_fp4 to
catch any non-finite output.
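The assertion pattern the test leans on can be sketched without a GPU; plain nested lists stand in here for the mm_fp4 output tensor, and assert_all_finite is a hypothetical name for illustration only:

```python
import math

def assert_all_finite(rows):
    # Collect the indices of every NaN/Inf entry so the failure is actionable.
    bad = [(i, j) for i, row in enumerate(rows)
           for j, v in enumerate(row)
           if not math.isfinite(v)]
    assert not bad, f"non-finite output at {bad[:5]}"

assert_all_finite([[0.1, -2.5], [3.0, 0.0]])  # finite output passes

corrupt = [[0.1, float("nan")], [float("inf"), 0.0]]
try:
    assert_all_finite(corrupt)
except AssertionError as e:
    # The assertion fires and reports the offending indices.
    print("caught:", e)
```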

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Added a device-specific trust gate function _is_fp4_cudnn_override_shape_trusted() to control FP4 cuDNN override-shape GEMM usage by excluding SM12x architectures and checking override-shape availability. Updated FP4 GEMM control flow to use this gate instead of the previous availability check. Added regression test for SM12x finite result validation.

Changes

Cohort / File(s) / Summary

FP4 cuDNN Override-Shape Gate (flashinfer/gemm/gemm_base.py):
Added _is_fp4_cudnn_override_shape_trusted(device) to gate cuDNN FP4 override-shape usage, requiring override-shape availability while excluding SM12x architectures (fails closed on architecture resolution errors). Updated get_valid_tactics() and forward() to use this gate instead of is_cudnn_override_shape_available() for FP4 GEMM cuDNN path selection.

FP4 SM12x Regression Test (tests/gemm/test_mm_fp4.py):
Added the parametrized test test_mm_fp4_cudnn_finite_on_sm12x() covering SM12x architectures (major version 12). It runs FP4 GEMM with the cuDNN backend and bfloat16 results on a configuration known to trigger the bug, validating that the output is finite.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested labels

op: gemm

Suggested reviewers

  • dhiraj113
  • aleozlx
  • yzh119
  • bkryu

Poem

🐰 A rabbit hops through FP4 terrain,
SM12x brought some NaN pain—
A trust gate stands, "fail closed" so true,
Let cuDNN prove what it can do,
Regression tests now guard the way! 🛡️✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: docstring coverage is 28.57%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly describes the fix: skipping FP4 cuDNN override-shape on SM120/SM121 due to a NaN regression, and references issue #2910 for context.
Description check ✅ Passed The PR description adequately explains the issue, the solution, testing approach, and future work, though it lacks explicit pre-commit and test completion checkboxes from the template.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.



@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)

1770-1774: Prefer raw compute capability for the SM12x blocklist.

is_sm12x_supported() includes CUDA version checks (≥12.8 for SM120, ≥12.9 for SM121). Using not is_sm12x_supported(device) couples the architecture blocklist to toolkit availability: an SM120 device with insufficient CUDA would incorrectly be allowed to use override-shape. Since the NaN/Inf issue is architectural (stated in the docstring as specific to SM120/SM121), checking major != 12 directly better expresses the intent and avoids unnecessary coupling. This also narrows the overly broad except Exception flagged by Ruff.

Proposed refactor
     if not is_cudnn_override_shape_available():
         return False
     try:
-        return not is_sm12x_supported(device)
-    except Exception:
+        major, _ = get_compute_capability(device)
+    except (RuntimeError, ValueError, TypeError):
         # Fail closed: if we cannot resolve the arch, do not re-expose the
         # NaN path.
         return False
+    return major != 12
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 1770 - 1774, The code should block
SM120/SM121 by checking the device's raw compute capability major instead of
calling is_sm12x_supported(device) (which mixes in CUDA-toolkit checks) and
avoid the broad except; in the is_cudnn_override_shape_available() conditional
replace the try/except with reading the device's compute-major (e.g.,
device.compute_capability_major or device.compute_capability[0] /
device.cc_major depending on the device object in this codebase) and return
False when major == 12 (i.e., return major != 12); if that attribute access
might not exist, explicitly handle only the specific exception(s) you expect
(AttributeError/TypeError) and in that narrow except fallback to calling
is_sm12x_supported(device) or propagate the error — do not use a bare except
Exception.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9c310d60-54df-4bc6-a7fa-aa25510974c4

📥 Commits

Reviewing files that changed from the base of the PR and between 9e3d8b9 and 06d5ee5.

📒 Files selected for processing (2)
  • flashinfer/gemm/gemm_base.py
  • tests/gemm/test_mm_fp4.py


@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements a safeguard to disable the cuDNN FP4 override-shape path on SM12x (Blackwell) architectures, which currently produces incorrect NaN/Inf outputs. It adds a regression test specifically for these architectures. A review comment points out that the current implementation of the safeguard relies on a CUDA version check within is_sm12x_supported, which could inadvertently re-enable the buggy path on SM12x systems with older CUDA versions; a direct compute capability check is suggested instead.

    if not is_cudnn_override_shape_available():
        return False
    try:
        return not is_sm12x_supported(device)
Severity: high

Using is_sm12x_supported(device) as a guard here might be problematic because it includes a CUDA version check (requiring CUDA 12.8/12.9+). If a user runs on SM120/121 with an older CUDA version (but with a cuDNN version that supports override shapes), is_sm12x_supported will return False, causing this function to return True (trusted). This would re-expose the NaN/Inf issue on those systems. It is safer to check the compute capability major version directly to cover all SM12x architectures regardless of the CUDA version.

Suggested change:
-        return not is_sm12x_supported(device)
+        major, _ = get_compute_capability(device)
+        return major != 12
