Add TGV NVFP4 GEMM tactic to mm_fp4 cute-dsl backend (SM100/SM103) #3141

Sinestro38 wants to merge 1 commit into flashinfer-ai:main
Conversation
📝 Walkthrough

Adds a TGV (CuTe-DSL) FP4 GEMM backend hook into mm_fp4, including kernel compilation/caching and a runtime runner, and re-exports the TGV kernel/compiler symbols from `flashinfer/cute_dsl/__init__.py`.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant mm_fp4 as mm_fp4
    participant Router as Backend Router
    participant Req as TGV Requirements
    participant Runner as TGV Runner
    participant Cache as Kernel Cache
    participant Compiler as Compiler
    participant GPU as GPU Device

    Caller->>mm_fp4: mm_fp4(..., backend="tgv" or auto)
    mm_fp4->>Router: select "tgv" path
    Router->>Req: validate CuTe-DSL, SM, and shape constraints (M%128, N%8, K%256)
    alt requirements met
        Req-->>Router: OK
        Router->>Runner: run TGV NVFP4 tactic
        Runner->>Cache: lookup kernel key (out_dtype, enable_pdl)
        alt cache hit
            Cache-->>Runner: cached kernel
        else cache miss
            Runner->>Compiler: compile_tgv_gemm_nvfp4(...)
            Compiler-->>Runner: compiled kernel
            Runner->>Cache: store compiled kernel
        end
        Runner->>GPU: bind pointers & launch kernel
        GPU-->>Runner: execution complete
        Runner->>Runner: transpose/copy output if layout mismatch
        alt alpha_tensor provided
            Runner->>Runner: apply post-hoc scaling
        end
        Runner-->>mm_fp4: result tensor
    else requirements failed
        Req-->>mm_fp4: fallthrough/error
    end
    mm_fp4-->>Caller: return tensor
```
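The requirement check in the diagram reduces to a few modulus tests against TGV's fixed CTA tile of (128, 8, 256). A minimal sketch (the function name is illustrative, not the PR's actual helper):

```python
def tgv_eligible(m: int, n: int, k: int, use_nvfp4: bool = True) -> bool:
    """Mirror the TGV shape constraints from the diagram above.

    TGV uses a fixed CTA tile of (128, 8, 256), so each problem
    dimension must be a multiple of the matching tile extent, and
    the problem must be NVFP4.
    """
    return use_nvfp4 and m % 128 == 0 and n % 8 == 0 and k % 256 == 0
```

For example, the decode shape M=128, N=8, K=4096 from the test plan qualifies, while M=100 does not.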
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Code Review
This pull request adds support for the "tgv" backend, a low-latency NVFP4 GEMM optimized for small-N shapes on SM100/SM103 architectures. The implementation includes requirement validation, a specialized runner with kernel caching, and documentation updates. Feedback highlights the need to use symbolic shapes during kernel compilation to properly support dynamic problem sizes and recommends ensuring the scaling tensor is on the correct device before multiplication to prevent potential CPU-GPU synchronization overhead.
```python
compiled_gemm = compile_tgv_gemm_nvfp4(
    a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
    acc_dtype=cutlass.Float32,
    problem_mnkl=problem_mnkl,
    sf_vec_size=16,
    use_pdl=enable_pdl,
)
```
The compile_tgv_gemm_nvfp4 call passes concrete problem_mnkl values (M, N, K, L) during compilation. This will cause the kernel to be specialized for the first shape encountered. Since the kernel is cached by (out_dtype, enable_pdl) but used for dynamic shapes (as stated in the PR description), this specialization will lead to incorrect results or crashes when the runner is reused for different problem sizes. To support dynamic shapes, problem_mnkl should be passed as None during compilation, allowing the CuTe-DSL to use symbolic integers.
```diff
-compiled_gemm = compile_tgv_gemm_nvfp4(
-    a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
-    acc_dtype=cutlass.Float32,
-    problem_mnkl=problem_mnkl,
-    sf_vec_size=16,
-    use_pdl=enable_pdl,
-)
+compiled_gemm = _TGV_MM_FP4_KERNEL_CACHE.get(cache_key)
+if compiled_gemm is None:
+    compiled_gemm = compile_tgv_gemm_nvfp4(
+        a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
+        acc_dtype=cutlass.Float32,
+        problem_mnkl=None,
+        sf_vec_size=16,
+        use_pdl=enable_pdl,
+    )
+    _TGV_MM_FP4_KERNEL_CACHE[cache_key] = compiled_gemm
```
```python
# alpha is applied post-hoc; the TGV kernel does not fuse it.
if alpha_tensor is not None:
    out.mul_(alpha_tensor.to(out.dtype))
```
Applying alpha via out.mul_ might trigger an expensive CPU-GPU synchronization if alpha_tensor is a CPU tensor (which is common for global scaling factors in PyTorch). It is safer to ensure the tensor is on the correct device and has the correct dtype before the in-place multiplication to maintain low latency.
```diff
-out.mul_(alpha_tensor.to(out.dtype))
+if alpha_tensor is not None:
+    out.mul_(alpha_tensor.to(device=out.device, dtype=out.dtype))
```
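A CPU-only sketch of the suggested pattern (assumes PyTorch; `apply_alpha` is an illustrative helper, not part of the PR):

```python
from typing import Optional

import torch

def apply_alpha(out: torch.Tensor, alpha_tensor: Optional[torch.Tensor]) -> torch.Tensor:
    # Move and cast alpha in a single .to() call before the in-place
    # multiply, so the operand never mixes devices or dtypes.
    # .to() is a no-op when device and dtype already match.
    if alpha_tensor is not None:
        out.mul_(alpha_tensor.to(device=out.device, dtype=out.dtype))
    return out
```

On a CUDA build, feeding a CPU `alpha_tensor` into `mul_` against a device tensor would pull the scalar across the PCIe bus on every call; converting it up front (or keeping alpha on-device) preserves the low-latency goal.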
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)
5433-5437: ⚠️ Potential issue | 🟡 Minor

Update the `enable_pdl` docs for TGV.

The new runner passes `enable_pdl` into `compile_tgv_gemm_nvfp4`, so this parameter is no longer only used by `cute_dsl`.

Proposed fix

```diff
 enable_pdl: bool
-    Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl``
+    Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl`` and ``tgv``
     backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
     with the start of the next for reduced launch latency. This parameter is
-    only used by the ``cute_dsl`` backend and is ignored by other backends.
+    ignored by other backends.
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 5433-5437, update the docstring for the `enable_pdl` parameter to state that it enables Programmatic Dependent Launch for the `cute_dsl` backend and is also passed into `compile_tgv_gemm_nvfp4` for the TGV runner (i.e., not ignored by all other backends). Keep the existing explanation of what PDL does and that it defaults to `True`, but remove the claim that it is only used by `cute_dsl`.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: ee26291d-5ab3-44b7-9f63-e8d285edd92e
📒 Files selected for processing (3)
- flashinfer/cute_dsl/__init__.py
- flashinfer/cute_dsl/tgv_gemm_nvfp4.py
- flashinfer/gemm/gemm_base.py
```python
k_packed = a.shape[1]
n = b.shape[1]
real_k = k_packed * 2
l = 1
```
Rename `l` to satisfy Ruff E741.

The static analysis failure is valid; `l` is visually ambiguous. Use a descriptive dimension name instead.

Proposed fix

```diff
-l = 1
+batch_l = 1
 ...
-cutlass.Int32(real_k), cutlass.Int32(l),
+cutlass.Int32(real_k), cutlass.Int32(batch_l),
```

🧰 Tools
🪛 Ruff (0.15.10)
[error] 5130-5130: Ambiguous variable name: l
(E741)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` at line 5130, Rename the ambiguous variable `l`
(triggers Ruff E741) to a descriptive dimension name (e.g., `length_dim`,
`layers`, or another name matching the surrounding semantics) in the assignment
`l = 1` and update every subsequent use in the same scope (function/class) to
the new identifier; ensure you update any related docstrings/comments and type
hints or default arguments that referenced `l` so the code and names remain
consistent and tests/linters pass.
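Applied to the flagged line, the rename is mechanical; sample values below stand in for the runtime dimensions:

```python
# Ruff E741 flags `l` (along with `I` and `O`) because it is easily
# confused with the digit 1. A descriptive name for the batch (L)
# dimension sidesteps the lint without changing behavior.
m, n, k_packed = 128, 8, 2048   # illustrative values only
real_k = k_packed * 2           # FP4 packs two elements per byte
batch_l = 1                     # was: l = 1
problem_mnkl = (m, n, real_k, batch_l)
```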
```python
# TGV writes the output M-contiguous (column-major in 2D). mm_fp4's
# default `out` is row-major, so if needed allocate a transposed
# scratch and copy back.
if out.stride(0) == 1 and out.stride(1) == m:
    kernel_out = out
    copy_back = False
else:
    kernel_out = torch.empty_strided(
        (m, n), (1, m), dtype=out.dtype, device=out.device
    )
    copy_back = True

# B comes in as (K_packed, N) col-major. The kernel expects (N, K_packed, L)
# with K contiguous, which is exactly b.T's row-major view.
kernel_b = b.T
kernel_b_sf = b_descale.T

a_ptr = make_ptr(
    cutlass.Float4E2M1FN, a.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
b_ptr = make_ptr(
    cutlass.Float4E2M1FN, kernel_b.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
sfa_ptr = make_ptr(
    cutlass.Float8E4M3FN, a_descale.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=32,
)
sfb_ptr = make_ptr(
    cutlass.Float8E4M3FN, kernel_b_sf.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
c_ptr = make_ptr(
    c_cutlass_dtype, kernel_out.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
```
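The stride test at the top of this snippet checks for an M-contiguous (column-major) 2-D layout. The rule itself is framework-independent and can be stated in plain Python:

```python
def col_major_strides(shape):
    # Column-major: the first axis has unit stride; each later axis's
    # stride is the product of all earlier extents.
    strides, step = [], 1
    for extent in shape:
        strides.append(step)
        step *= extent
    return tuple(strides)

def is_m_contiguous(shape, strides):
    # Mirrors the runner's `out.stride(0) == 1 and out.stride(1) == m` check.
    m, _n = shape
    return strides[0] == 1 and strides[1] == m
```

A row-major (m, n) tensor has strides (n, 1) and fails the check, which is why the runner allocates the `(1, m)`-strided scratch via `torch.empty_strided` and copies back.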
Validate `out` before using it as the kernel destination.

`mm_fp4` does not currently validate a caller-provided `out`; this runner then derives the C pointer type from `out_dtype` while writing into `out.dtype` storage. A mismatched shape/device/dtype can produce wrong results or unsafe writes before `copy_back` has a chance to fail cleanly.

Proposed fix

```diff
 real_k = k_packed * 2
 l = 1
+
+expected_out_shape = (m, n)
+if out.shape != expected_out_shape:
+    raise ValueError(
+        f"Output shape mismatch. Expected {expected_out_shape}, got {out.shape}."
+    )
+if out.device != a.device:
+    raise ValueError(
+        f"Output device mismatch. Expected {a.device}, got {out.device}."
+    )
+if out.dtype != out_dtype:
+    raise ValueError(
+        f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+    )
 # TGV writes the output M-contiguous (column-major in 2D). mm_fp4's
```

---

Ports the TGV GEMM NVFP4 CuTe-DSL kernel (low-latency Blackwell GEMM tuned for small-N / decode shapes) and exposes it as an additional tactic inside the existing `"cute-dsl"` backend of `mm_fp4`, rather than a new top-level backend. The autotuner picks between the standard SM100/SM103 tactics and TGV.

- `flashinfer/cute_dsl/tgv_gemm_nvfp4.py`: the kernel module, exporting `TgvGemmNvfp4Kernel` and `compile_tgv_gemm_nvfp4`.
- `flashinfer/cute_dsl/__init__.py`: re-export the public symbols.
- `flashinfer/gemm/gemm_base.py`: in `_cute_dsl_gemm_fp4_runner`, append a TGV tactic to `get_valid_tactics` when the problem is NVFP4 with `M%128==0`, `N%8==0`, `K%256==0`, and dispatch to `_run_tgv_nvfp4_tactic` in `forward`.

Uses the `tcgen05.cp` (non-STTM) SFB path so the layout produced by `nvfp4_quantize(sfLayout=layout_128x4, do_shuffle=False)` passes through zero-copy. `alpha` is applied post-hoc since the kernel does not fuse it. Compiled kernels are cached by `(out_dtype, enable_pdl)`; problem M/N/K are runtime-dynamic.
Force-pushed from a508f7a to f4a4276 (Compare)
Actionable comments posted: 1
♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)
4643-4647: ⚠️ Potential issue | 🟡 Minor

Rename `l` to avoid Ruff E741.

`l` is visually ambiguous and the current static-analysis failure is valid.

Proposed rename

```diff
-l = 1
+batch_l = 1
 problem_mnkl = (
     cutlass.Int32(m), cutlass.Int32(n),
-    cutlass.Int32(real_k), cutlass.Int32(l),
+    cutlass.Int32(real_k), cutlass.Int32(batch_l),
 )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 4643-4647, the variable named `l` is ambiguous and triggers Ruff E741; rename it to a clearer identifier (e.g., `l_dim` or `depth_l`) and update its declaration and all usages, including the tuple `problem_mnkl = (cutlass.Int32(m), cutlass.Int32(n), cutlass.Int32(real_k), cutlass.Int32(l))`, so the tuple uses the new name (e.g., `cutlass.Int32(l_dim)`). Rename the variable consistently throughout the surrounding scope to avoid unresolved names.
4608-4641: ⚠️ Potential issue | 🟠 Major

Validate `out` before deriving the kernel destination pointer.

This TGV path still trusts `out.shape`, `out.device`, and `out.dtype` before constructing `kernel_out`/`c_ptr`. A mismatched caller-provided `out` can route a CUDA kernel to the wrong storage or fail only after the kernel has already run.

Proposed guard

```diff
+expected_out_shape = (m, n)
+if out.shape != expected_out_shape:
+    raise ValueError(
+        f"Output shape mismatch. Expected {expected_out_shape}, got {out.shape}."
+    )
+if out.device != a.device:
+    raise ValueError(
+        f"Output device mismatch. Expected {a.device}, got {out.device}."
+    )
+if out.dtype != out_dtype:
+    raise ValueError(
+        f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+    )
+
 if out.stride(0) == 1 and out.stride(1) == m:
     kernel_out = out
     copy_back = False
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/gemm/gemm_base.py` around lines 4608-4641, validate the caller-provided `out` tensor before deriving `kernel_out` and `c_ptr`: check that `out.shape == (m, n)`, `out.device == a.device`, and `out.dtype` is compatible with `c_cutlass_dtype`; if any check fails, raise a clear exception (or fall back to allocating a correctly-shaped/device/dtype tensor) instead of proceeding, so the kernel never gets a pointer into mismatched storage.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 5c9d27f7-39d3-4fd6-a86b-61d63dafc84c
📒 Files selected for processing (3)
- flashinfer/cute_dsl/__init__.py
- flashinfer/cute_dsl/tgv_gemm_nvfp4.py
- flashinfer/gemm/gemm_base.py
```python
# --- TGV NVFP4 tactic (SM100/SM103, low-latency small-N) ---
# Fixed CTA tile (128, 8, 256). Only eligible when the problem
# satisfies the TGV alignment constraints; we still let the
# autotuner pick between this and the standard SM100/SM103 tactics.
if use_nvfp4 and m % 128 == 0 and n % 8 == 0 and real_k % 256 == 0:
    valid_tactics.append(
        ((128, 8), (1, 1), False, False, "tgv", None)
    )
```
Wire TGV as a selectable backend if it is meant to be opt-in.
This adds TGV as one candidate inside `backend="cute-dsl"`, so callers cannot request `backend="tgv"`, and `backend="cute-dsl"` may still pick the standard CuTe-DSL tactics. If the intended API contract is an opt-in TGV backend, add a `tgv` backend requirement/runner path, or make the TGV path force only TGV tactics.
Also applies to: 5365-5369
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/gemm/gemm_base.py` around lines 4817 - 4824, The current logic
appends a TGV tactic into valid_tactics guarded only by use_nvfp4, which lets
callers request backend="tgv" yet still get included when backend="cute-dsl";
update the selection so TGV is a distinct opt-in backend path: either (A) add a
backend check (e.g., backend == "tgv") around the block that appends the TGV
tactic (the use_nvfp4 / valid_tactics.append(...) site) and similarly at the
other occurrence (around the similar block at the other location), or (B) if
backend=="tgv" is requested force-only TGV tactics by filtering valid_tactics to
include only entries tagged "tgv" after collection; apply the same change to the
corresponding block at the other lines referenced so TGV is only chosen when the
caller explicitly selects the TGV backend.
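Option (B) from the comment — filtering the collected tactics down to TGV when it is explicitly requested — can be sketched as follows (illustrative; the tuple layout is taken from the snippet above, and the function name is not the PR's):

```python
def filter_tactics(valid_tactics, backend):
    # Tactic tuples carry their backend tag at index 4 (e.g. "tgv").
    if backend == "tgv":
        tgv_only = [t for t in valid_tactics if t[4] == "tgv"]
        if not tgv_only:
            raise ValueError(
                "backend='tgv' requested but the problem shape is not TGV-eligible"
            )
        return tgv_only
    return valid_tactics
```

This keeps auto/`"cute-dsl"` behavior unchanged while making an explicit `"tgv"` request either run TGV or fail loudly instead of silently falling back.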
Summary

- Ports the TGV GEMM NVFP4 CuTe-DSL kernel into `flashinfer/cute_dsl/tgv_gemm_nvfp4.py`.
- Exposes it as an additional tactic inside the existing `"cute-dsl"` backend of `mm_fp4` (not a new top-level backend). The autotuner picks between the standard SM100/SM103 tactics and TGV.
- The TGV tactic is eligible only when `M%128==0`, `N%8==0`, `K%256==0`.

Notes

- Uses the `tcgen05.cp` (non-`tcgen05.st`) SFB path, which matches the layout produced by `nvfp4_quantize(sfLayout=layout_128x4, do_shuffle=False)`, so scale factors pass through zero-copy.
- `alpha` is applied post-hoc (`out.mul_`) since the kernel doesn't fuse it.
- Compiled kernels are cached by `(out_dtype, enable_pdl)`; problem M/N/K are runtime-dynamic.
- A transposed scratch and copy-back are used when `out` is row-major (TGV writes M-contiguous). Allocating `out` column-major upstream avoids this extra copy.

Test plan

- Run `mm_fp4(..., backend="cute-dsl")` with autotuning enabled on decode-shape problems (e.g., `M=128, N=8, K=4096`) and confirm the TGV tactic wins.