fix(gemm): skip FP4 cuDNN override-shape path on SM120/SM121 (NaN regression from #2910) #3140
Conversation
The FP4 override-shape fast path in `CudnnFp4GemmRunner`, introduced in flashinfer-ai#2910, returns NaN/Inf output on SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10) for realistic NVFP4 shapes. Confirmed on Nemotron-3-Nano-30B-FP4 via sglang; the corrupt logits trip the torch sampler's "probability tensor contains inf/nan" assert on real requests.

Add `_is_fp4_cudnn_override_shape_trusted(device)`, which returns True only when override-shape is available *and* the device is not SM12x, and use it at the two FP4 call sites (`get_valid_tactics`, `forward`). The static-shape cuDNN path it falls back to is the pre-flashinfer-ai#2910 behavior, uses the same backend, and is numerically correct.

Scope:
- Only `CudnnFp4GemmRunner` is gated. The BF16 and MXFP8 override-shape paths go through different helpers, are not implicated, and keep the flashinfer-ai#2910 fast path on all archs.
- The guard uses `is_sm12x_supported()`, matching the convention already used elsewhere in `gemm_base.py`.
- The helper fails closed (returns False) if the compute capability cannot be resolved, so an error path cannot re-expose the NaN behavior.

This is a guard, not a root-cause fix. Suspected culprits in flashinfer-ai#2910 are `_get_real_fp4_shape_from_packed_uint8`, `_expand_block_scale_tensor_shape`, and the `cache_m = last_positive_power_of_2(actual_m)` bucketing in `_get_override_graph`. A follow-up PR will narrow the fault and remove this guard.

Add an SM12x-only regression test in `tests/gemm/test_mm_fp4.py` that runs `mm_fp4(backend="cudnn")` on shapes known to trigger the NaN and relies on the existing cosine-similarity assertion in `_test_mm_fp4` to catch any non-finite output.
📝 Walkthrough: Added a device-specific trust gate function.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🧹 Nitpick comments (1)
flashinfer/gemm/gemm_base.py (1)
1770-1774: Prefer raw compute capability for the SM12x blocklist.
`is_sm12x_supported()` includes CUDA version checks (≥12.8 for SM120, ≥12.9 for SM121). Using `not is_sm12x_supported(device)` couples the architecture blocklist to toolkit availability: an SM120 device with insufficient CUDA would incorrectly be allowed to use override-shape. Since the NaN/Inf issue is architectural (stated in the docstring as specific to SM120/SM121), checking `major != 12` directly better expresses the intent and avoids unnecessary coupling. This also narrows the overly broad `except Exception` flagged by Ruff.

Proposed refactor:
```diff
 if not is_cudnn_override_shape_available():
     return False
 try:
-    return not is_sm12x_supported(device)
-except Exception:
+    major, _ = get_compute_capability(device)
+except (RuntimeError, ValueError, TypeError):
     # Fail closed: if we cannot resolve the arch, do not re-expose the
     # NaN path.
     return False
+return major != 12
```
📒 Files selected for processing (2)
- flashinfer/gemm/gemm_base.py
- tests/gemm/test_mm_fp4.py
Code Review
This pull request implements a safeguard to disable the cuDNN FP4 override-shape path on SM12x (Blackwell) architectures, which currently produces incorrect NaN/Inf outputs. It adds a regression test specifically for these architectures. A review comment points out that the current implementation of the safeguard relies on a CUDA version check within is_sm12x_supported, which could inadvertently re-enable the buggy path on SM12x systems with older CUDA versions; a direct compute capability check is suggested instead.
```python
if not is_cudnn_override_shape_available():
    return False
try:
    return not is_sm12x_supported(device)
```
Using is_sm12x_supported(device) as a guard here might be problematic because it includes a CUDA version check (requiring CUDA 12.8/12.9+). If a user runs on SM120/121 with an older CUDA version (but with a cuDNN version that supports override shapes), is_sm12x_supported will return False, causing this function to return True (trusted). This would re-expose the NaN/Inf issue on those systems. It is safer to check the compute capability major version directly to cover all SM12x architectures regardless of the CUDA version.
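A pure-Python illustration of the coupling this comment describes, using the version thresholds the reviewers cite (CUDA ≥12.8 for SM120, ≥12.9 for SM121). The helper here is a simplified stand-in for `is_sm12x_supported`, not the real flashinfer function:

```python
def is_sm12x_supported(major: int, minor: int, cuda_version: tuple) -> bool:
    # Stand-in mirroring the described behavior: the result mixes an
    # architecture check with a CUDA toolkit check.
    if major != 12 or minor not in (0, 1):
        return False
    required = (12, 8) if minor == 0 else (12, 9)
    return cuda_version >= required


# SM120 device on CUDA 12.6: is_sm12x_supported(...) is False (toolkit too
# old), so the negated guard wrongly reports the override-shape path as
# trusted, re-exposing the NaN behavior.
buggy_guard = not is_sm12x_supported(12, 0, cuda_version=(12, 6))

# Checking the compute-capability major directly blocks SM12x regardless
# of the installed toolkit, which is what the suggestion recommends.
sm_major = 12
direct_guard = sm_major != 12

print(buggy_guard, direct_guard)  # True False
```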
```diff
-    return not is_sm12x_supported(device)
+    major, _ = get_compute_capability(device)
+    return major != 12
```
The FP4 override-shape fast path in `CudnnFp4GemmRunner` added in #2910 returns NaN/Inf on SM120 (RTX PRO 6000 Blackwell) and SM121 (DGX Spark GB10) for realistic NVFP4 shapes, silently corrupting logits (reproduced on Nemotron-3-Nano-30B-FP4 via sglang).

This PR adds `_is_fp4_cudnn_override_shape_trusted(device)` and routes SM12x FP4 back to the static-shape cuDNN path. BF16 and MXFP8 paths are untouched.

This is a guard only; a follow-up will narrow the fault (suspected in `_get_real_fp4_shape_from_packed_uint8`, `_expand_block_scale_tensor_shape`, or the `cache_m` bucketing in `_get_override_graph`) and remove this guard.

Tests
- SM12x-only regression test added in `tests/gemm/test_mm_fp4.py`.
- `L0_Nemotron-3-Nano-30B-FP4` on SM121 passes with a wheel built from this PR.
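A small sketch of why a cosine-similarity assertion is enough to catch this regression: any NaN in the kernel output propagates into the similarity score, and NaN fails every `>` comparison. Pure Python with no torch dependency; the 0.99 threshold and vectors are illustrative, not the values used in `_test_mm_fp4`.

```python
import math


def cosine_similarity(a, b):
    # Plain cosine similarity: dot(a, b) / (||a|| * ||b||).
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


ref = [0.5, -1.0, 2.0, 0.25]          # reference output
good = [0.49, -1.02, 1.98, 0.26]      # small quantization error
bad = [0.49, float("nan"), 1.98, 0.26]  # one corrupted element

sim_good = cosine_similarity(ref, good)
sim_bad = cosine_similarity(ref, bad)

assert sim_good > 0.99        # healthy output clears the threshold
assert not (sim_bad > 0.99)   # NaN poisons the score, so the check fails
```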