
Add TGV NVFP4 GEMM tactic to mm_fp4 cute-dsl backend (SM100/SM103)#3141

Open
Sinestro38 wants to merge 1 commit into flashinfer-ai:main from Sinestro38:tgv-nvfp4-mm-fp4-backend

Conversation


@Sinestro38 Sinestro38 commented Apr 21, 2026

Summary

  • Ports the TGV NVFP4 GEMM CuTe-DSL kernel (low-latency Blackwell GEMM tuned for small-N / decode shapes) to flashinfer/cute_dsl/tgv_gemm_nvfp4.py.
  • Exposes it as an additional tactic inside the existing "cute-dsl" backend of mm_fp4 (not a new top-level backend). The autotuner picks between the standard SM100/SM103 tactics and TGV.
  • Tactic is only added when the problem satisfies the TGV alignment constraints: NVFP4, M%128==0, N%8==0, K%256==0.
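The eligibility gate above can be sketched as a small predicate (the function name is hypothetical; the real check lives in `get_valid_tactics` inside `flashinfer/gemm/gemm_base.py`):

```python
# Hypothetical sketch of the TGV tactic eligibility gate described above.
def tgv_tactic_eligible(m: int, n: int, real_k: int, use_nvfp4: bool) -> bool:
    # TGV uses a fixed CTA tile of (128, 8, 256), so M, N, and the
    # unpacked K must each divide the tile extent evenly.
    return use_nvfp4 and m % 128 == 0 and n % 8 == 0 and real_k % 256 == 0
```

For example, a decode shape like M=128, N=8, K=4096 qualifies, while M=96 or K=4000 does not.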

Notes

  • Uses the tcgen05.cp (non-tcgen05.st) SFB path, which matches the layout produced by nvfp4_quantize(sfLayout=layout_128x4, do_shuffle=False) so scale factors pass through zero-copy.
  • alpha is applied post-hoc (out.mul_) since the kernel doesn't fuse it.
  • Compiled kernels cached by (out_dtype, enable_pdl); problem M/N/K are runtime-dynamic.
  • The tactic falls back to an M-contiguous scratch + copy-back when out is row-major (TGV writes M-contiguous). Allocating out column-major upstream avoids this extra copy.
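The compile-on-miss caching described in the notes can be sketched in plain Python (the cache dict and `compile_kernel` argument are hypothetical stand-ins for the module-level cache and `compile_tgv_gemm_nvfp4`):

```python
# Minimal sketch of the compile-on-miss kernel cache. The key deliberately
# omits M/N/K: those stay runtime-dynamic, so one compiled kernel serves
# every shape that passes the eligibility check.
_TGV_KERNEL_CACHE: dict = {}

def get_or_compile(out_dtype, enable_pdl, compile_kernel):
    key = (out_dtype, enable_pdl)
    kernel = _TGV_KERNEL_CACHE.get(key)
    if kernel is None:
        # compile_kernel stands in for compile_tgv_gemm_nvfp4
        kernel = compile_kernel(out_dtype=out_dtype, use_pdl=enable_pdl)
        _TGV_KERNEL_CACHE[key] = kernel
    return kernel
```

A second call with the same (out_dtype, enable_pdl) pair returns the cached kernel without recompiling; changing either key component triggers a fresh compile.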

Test plan

  • Run on SM100 hardware: mm_fp4(..., backend="cute-dsl") with autotuning enabled on decode-shape problems (e.g., M=128, N=8, K=4096) and confirm the TGV tactic wins.
  • Verify numerical match against a cutlass/cudnn reference.
  • Benchmark latency vs the existing SM100/SM103 cute-dsl tactics on small-N shapes.

Summary by CodeRabbit

  • New Features
    • Added a TGV NVFP4 GEMM backend for FP4 matrix multiplication on supported GPUs; available when selected and shape constraints are met (M % 128 == 0, N % 8 == 0, K % 256 == 0).
    • Exposes TGV NVFP4 kernel and compile helper via the DSL interface when that DSL is available.
  • Documentation
    • Updated mm_fp4 docs to describe the new tactic, selection rules, and execution behavior.

Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Adds a TGV (CuTe-DSL) FP4 GEMM backend hook into mm_fp4, including kernel compilation/caching and runtime runner, and re-exports TGV kernel/compiler symbols from flashinfer.cute_dsl when CuTe-DSL is available.

Changes

Cohort / File(s) Summary
CuTe DSL Public API
flashinfer/cute_dsl/__init__.py
Conditionally re-export TgvGemmKernel as TgvGemmNvfp4Kernel and compile_tgv_gemm_nvfp4 from .tgv_gemm_nvfp4 when CuTe-DSL is available; extend __all__ accordingly.
FP4 GEMM Backend Integration
flashinfer/gemm/gemm_base.py
Add "tgv" tactic to FP4 autotuner when use_nvfp4 enabled and shapes meet M%128==0, N%8==0, real_k%256==0; implement _run_tgv_nvfp4_tactic with a module-level kernel cache, compile-on-miss, CUTLASS/CuTe pointer binding, conditional scratch-output transpose/copyback, and optional post-multiply by alpha_tensor.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant mm_fp4 as mm_fp4
    participant Router as Backend Router
    participant Req as TGV Requirements
    participant Runner as TGV Runner
    participant Cache as Kernel Cache
    participant Compiler as Compiler
    participant GPU as GPU Device

    Caller->>mm_fp4: mm_fp4(..., backend="tgv" or auto)
    mm_fp4->>Router: select "tgv" path
    Router->>Req: validate CuTe-DSL, SM, and shape constraints (M%128,N%8,K%256)
    alt requirements met
        Req-->>Router: OK
        Router->>Runner: run TGV NVFP4 tactic
        Runner->>Cache: lookup kernel key (out_dtype, enable_pdl)
        alt cache hit
            Cache-->>Runner: cached kernel
        else cache miss
            Runner->>Compiler: compile_tgv_gemm_nvfp4(...)
            Compiler-->>Runner: compiled kernel
            Runner->>Cache: store compiled kernel
        end
        Runner->>GPU: bind pointers & launch kernel
        GPU-->>Runner: execution complete
        Runner->>Runner: transpose/copy output if layout mismatch
        alt alpha_tensor provided
            Runner->>Runner: apply post-hoc scaling
        end
        Runner-->>mm_fp4: result tensor
    else requirements failed
        Req-->>mm_fp4: fallthrough/error
    end
    mm_fp4-->>Caller: return tensor
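The "transpose/copy output if layout mismatch" step in the diagram corresponds to a stride test: TGV writes M-contiguous output, so a row-major `out` needs a scratch buffer. A NumPy sketch of that test (the helper name is hypothetical; the real code checks `out.stride(0) == 1 and out.stride(1) == m` on the torch tensor):

```python
import numpy as np

def needs_scratch(out: np.ndarray) -> bool:
    """True when `out` is not M-contiguous, so the TGV kernel must write
    into a column-major scratch buffer that is copied back afterwards."""
    m = out.shape[0]
    # Convert byte strides to element strides; stride(0) == 1 and
    # stride(1) == m means the output is column-major (M-contiguous).
    s0 = out.strides[0] // out.itemsize
    s1 = out.strides[1] // out.itemsize
    return not (s0 == 1 and s1 == m)
```

A default C-order (row-major) array needs the scratch path, while `np.asfortranarray` of the same array does not.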

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • nv-yunzheq
  • yzh119
  • bkryu
  • cyx-6
  • jiahanc
  • jimmyzho
  • nvmbreughe

Poem

🐇 I hop through kernels, swift and spry,
I compile TGV under moonlit sky,
Cache a hop, launch a bound,
FP4 vectors hum around —
A rabbit's cheer for speedy math!

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage (⚠️ Warning): docstring coverage is 60.00%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
  • Description check (❓ Inconclusive): the PR description is comprehensive, well-structured, and clearly explains the changes, but it does not follow the required template sections (Related Issues, Pre-commit Checks, Tests). Resolution: add the missing template sections by linking related issues, confirming pre-commit checks, and checking off test items or explaining what testing has been done.
✅ Passed checks (3 passed)
  • Title check (✅ Passed): the title accurately describes the main change: adding a TGV NVFP4 GEMM tactic to the mm_fp4 cute-dsl backend for SM100/SM103.
  • Linked Issues check (✅ Passed): skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): skipped because no linked issues were found for this pull request.



Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for the "tgv" backend, a low-latency NVFP4 GEMM optimized for small-N shapes on SM100/SM103 architectures. The implementation includes requirement validation, a specialized runner with kernel caching, and documentation updates. Feedback highlights the need to use symbolic shapes during kernel compilation to properly support dynamic problem sizes and recommends ensuring the scaling tensor is on the correct device before multiplication to prevent potential CPU-GPU synchronization overhead.

Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment on lines +5181 to +5187
compiled_gemm = compile_tgv_gemm_nvfp4(
    a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
    acc_dtype=cutlass.Float32,
    problem_mnkl=problem_mnkl,
    sf_vec_size=16,
    use_pdl=enable_pdl,
)

critical

The compile_tgv_gemm_nvfp4 call passes concrete problem_mnkl values (M, N, K, L) during compilation. This will cause the kernel to be specialized for the first shape encountered. Since the kernel is cached by (out_dtype, enable_pdl) but used for dynamic shapes (as stated in the PR description), this specialization will lead to incorrect results or crashes when the runner is reused for different problem sizes. To support dynamic shapes, problem_mnkl should be passed as None during compilation, allowing the CuTe-DSL to use symbolic integers.

Suggested change
-compiled_gemm = compile_tgv_gemm_nvfp4(
-    a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
-    acc_dtype=cutlass.Float32,
-    problem_mnkl=problem_mnkl,
-    sf_vec_size=16,
-    use_pdl=enable_pdl,
-)
+compiled_gemm = _TGV_MM_FP4_KERNEL_CACHE.get(cache_key)
+if compiled_gemm is None:
+    compiled_gemm = compile_tgv_gemm_nvfp4(
+        a_ptr, b_ptr, sfa_ptr, sfb_ptr, c_ptr,
+        acc_dtype=cutlass.Float32,
+        problem_mnkl=None,
+        sf_vec_size=16,
+        use_pdl=enable_pdl,
+    )
+    _TGV_MM_FP4_KERNEL_CACHE[cache_key] = compiled_gemm

Comment thread flashinfer/gemm/gemm_base.py Outdated

# alpha is applied post-hoc; the TGV kernel does not fuse it.
if alpha_tensor is not None:
out.mul_(alpha_tensor.to(out.dtype))

medium

Applying alpha via out.mul_ might trigger an expensive CPU-GPU synchronization if alpha_tensor is a CPU tensor (which is common for global scaling factors in PyTorch). It is safer to ensure the tensor is on the correct device and has the correct dtype before the in-place multiplication to maintain low latency.

Suggested change
 if alpha_tensor is not None:
-    out.mul_(alpha_tensor.to(out.dtype))
+    out.mul_(alpha_tensor.to(device=out.device, dtype=out.dtype))

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/gemm/gemm_base.py (1)

5433-5437: ⚠️ Potential issue | 🟡 Minor

Update the enable_pdl docs for TGV.

The new runner passes enable_pdl into compile_tgv_gemm_nvfp4, so this parameter is no longer only used by cute_dsl.

Proposed fix
     enable_pdl: bool
-        Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl``
+        Whether to enable Programmatic Dependent Launch (PDL) for the ``cute_dsl`` and ``tgv``
         backend, defaults to ``True``. PDL allows overlapping the tail of one kernel
         with the start of the next for reduced launch latency. This parameter is
-        only used by the ``cute_dsl`` backend and is ignored by other backends.
+        ignored by other backends.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 5433 - 5437, Update the docstring
for the enable_pdl parameter to reflect that it is also used by the TGV path:
change its description where defined (in gemm_base.py near the enable_pdl
parameter) to state that enable_pdl enables Programmatic Dependent Launch for
the cute_dsl backend and is also passed into compile_tgv_gemm_nvfp4 for the TGV
runner (i.e., not ignored by other backends). Keep the existing explanation of
what PDL does and that it defaults to True, but remove the claim that it is only
used by cute_dsl.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Line 5130: Rename the ambiguous variable `l` (triggers Ruff E741) to a
descriptive dimension name (e.g., `length_dim`, `layers`, or another name
matching the surrounding semantics) in the assignment `l = 1` and update every
subsequent use in the same scope (function/class) to the new identifier; ensure
you update any related docstrings/comments and type hints or default arguments
that referenced `l` so the code and names remain consistent and tests/linters
pass.

---

Outside diff comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 5433-5437: Update the docstring for the enable_pdl parameter to
reflect that it is also used by the TGV path: change its description where
defined (in gemm_base.py near the enable_pdl parameter) to state that enable_pdl
enables Programmatic Dependent Launch for the cute_dsl backend and is also
passed into compile_tgv_gemm_nvfp4 for the TGV runner (i.e., not ignored by
other backends). Keep the existing explanation of what PDL does and that it
defaults to True, but remove the claim that it is only used by cute_dsl.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: ee26291d-5ab3-44b7-9f63-e8d285edd92e

📥 Commits

Reviewing files that changed from the base of the PR and between 9e3d8b9 and a228a63.

📒 Files selected for processing (3)
  • flashinfer/cute_dsl/__init__.py
  • flashinfer/cute_dsl/tgv_gemm_nvfp4.py
  • flashinfer/gemm/gemm_base.py

Comment thread flashinfer/gemm/gemm_base.py Outdated
k_packed = a.shape[1]
n = b.shape[1]
real_k = k_packed * 2
l = 1

⚠️ Potential issue | 🟡 Minor

Rename l to satisfy Ruff E741.

The static analysis failure is valid; l is visually ambiguous. Use a descriptive dimension name instead.

Proposed fix
-            l = 1
+            batch_l = 1
...
-                cutlass.Int32(real_k), cutlass.Int32(l),
+                cutlass.Int32(real_k), cutlass.Int32(batch_l),
🧰 Tools
🪛 Ruff (0.15.10)

[error] 5130-5130: Ambiguous variable name: l

(E741)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` at line 5130, Rename the ambiguous variable `l`
(triggers Ruff E741) to a descriptive dimension name (e.g., `length_dim`,
`layers`, or another name matching the surrounding semantics) in the assignment
`l = 1` and update every subsequent use in the same scope (function/class) to
the new identifier; ensure you update any related docstrings/comments and type
hints or default arguments that referenced `l` so the code and names remain
consistent and tests/linters pass.

Comment thread flashinfer/gemm/gemm_base.py Outdated
Comment on lines +5132 to +5168
# TGV writes the output M-contiguous (column-major in 2D). mm_fp4's
# default `out` is row-major, so if needed allocate a transposed
# scratch and copy back.
if out.stride(0) == 1 and out.stride(1) == m:
    kernel_out = out
    copy_back = False
else:
    kernel_out = torch.empty_strided(
        (m, n), (1, m), dtype=out.dtype, device=out.device
    )
    copy_back = True

# B comes in as (K_packed, N) col-major. The kernel expects (N, K_packed, L)
# with K contiguous, which is exactly b.T's row-major view.
kernel_b = b.T
kernel_b_sf = b_descale.T

a_ptr = make_ptr(
    cutlass.Float4E2M1FN, a.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
b_ptr = make_ptr(
    cutlass.Float4E2M1FN, kernel_b.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
sfa_ptr = make_ptr(
    cutlass.Float8E4M3FN, a_descale.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=32,
)
sfb_ptr = make_ptr(
    cutlass.Float8E4M3FN, kernel_b_sf.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)
c_ptr = make_ptr(
    c_cutlass_dtype, kernel_out.data_ptr(),
    cute.AddressSpace.gmem, assumed_align=16,
)

⚠️ Potential issue | 🟠 Major

Validate out before using it as the kernel destination.

mm_fp4 does not currently validate a caller-provided out; this runner then derives the C pointer type from out_dtype while writing into out.dtype storage. A mismatched shape/device/dtype can produce wrong results or unsafe writes before copy_back has a chance to fail cleanly.

Proposed fix
             real_k = k_packed * 2
-            l = 1
+            l = 1
+
+            expected_out_shape = (m, n)
+            if out.shape != expected_out_shape:
+                raise ValueError(
+                    f"Output shape mismatch. Expected {expected_out_shape}, got {out.shape}."
+                )
+            if out.device != a.device:
+                raise ValueError(
+                    f"Output device mismatch. Expected {a.device}, got {out.device}."
+                )
+            if out.dtype != out_dtype:
+                raise ValueError(
+                    f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+                )
 
             # TGV writes the output M-contiguous (column-major in 2D). mm_fp4's

Ports the TGV GEMM NVFP4 CuTe-DSL kernel (low-latency Blackwell GEMM
tuned for small-N / decode shapes) and exposes it as an additional
tactic inside the existing "cute-dsl" backend of mm_fp4, rather than a
new top-level backend. The autotuner picks between the standard
SM100/SM103 tactics and TGV.

- flashinfer/cute_dsl/tgv_gemm_nvfp4.py: the kernel module, exporting
  TgvGemmNvfp4Kernel and compile_tgv_gemm_nvfp4.
- flashinfer/cute_dsl/__init__.py: re-export the public symbols.
- flashinfer/gemm/gemm_base.py: in _cute_dsl_gemm_fp4_runner, append a
  TGV tactic to get_valid_tactics when the problem is NVFP4 with
  M%128==0, N%8==0, K%256==0, and dispatch to _run_tgv_nvfp4_tactic
  in forward. Uses the tcgen05.cp (non-STTM) SFB path so the layout
  produced by nvfp4_quantize(sfLayout=layout_128x4, do_shuffle=False)
  passes through zero-copy. alpha is applied post-hoc since the kernel
  does not fuse it. Compiled kernels cached by (out_dtype, enable_pdl);
  problem M/N/K are runtime-dynamic.
@Sinestro38 Sinestro38 force-pushed the tgv-nvfp4-mm-fp4-backend branch from a508f7a to f4a4276 Compare April 21, 2026 23:09
@Sinestro38 Sinestro38 changed the title Add TGV NVFP4 GEMM backend to mm_fp4 (SM100/SM103) Add TGV NVFP4 GEMM tactic to mm_fp4 cute-dsl backend (SM100/SM103) Apr 21, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (2)
flashinfer/gemm/gemm_base.py (2)

4643-4647: ⚠️ Potential issue | 🟡 Minor

Rename l to avoid Ruff E741.

l is visually ambiguous and the current static-analysis failure is valid.

Proposed rename
-    l = 1
+    batch_l = 1
     problem_mnkl = (
         cutlass.Int32(m), cutlass.Int32(n),
-        cutlass.Int32(real_k), cutlass.Int32(l),
+        cutlass.Int32(real_k), cutlass.Int32(batch_l),
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4643 - 4647, The variable named
`l` is ambiguous and triggers Ruff E741; rename it to a clearer identifier
(e.g., `l_dim` or `depth_l`) and update its declaration and all usages,
including the tuple `problem_mnkl = (cutlass.Int32(m), cutlass.Int32(n),
cutlass.Int32(real_k), cutlass.Int32(l))` so the tuple uses the new name (e.g.,
`cutlass.Int32(l_dim)`). Make sure you rename the variable consistently
throughout the surrounding scope and any functions or expressions that reference
`l` to avoid unresolved names.

4608-4641: ⚠️ Potential issue | 🟠 Major

Validate out before deriving the kernel destination pointer.

This TGV path still trusts out.shape, out.device, and out.dtype before constructing kernel_out/c_ptr. A mismatched caller-provided out can route a CUDA kernel to the wrong storage or fail only after the kernel has already run.

Proposed guard
+    expected_out_shape = (m, n)
+    if out.shape != expected_out_shape:
+        raise ValueError(
+            f"Output shape mismatch. Expected {expected_out_shape}, got {out.shape}."
+        )
+    if out.device != a.device:
+        raise ValueError(
+            f"Output device mismatch. Expected {a.device}, got {out.device}."
+        )
+    if out.dtype != out_dtype:
+        raise ValueError(
+            f"Output dtype mismatch. Expected {out_dtype}, got {out.dtype}."
+        )
+
     if out.stride(0) == 1 and out.stride(1) == m:
         kernel_out = out
         copy_back = False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4608 - 4641, Validate the
caller-provided out tensor before deriving kernel_out and c_ptr: check that
out.shape == (m, n), out.device == a.device (or match expected device), and
out.dtype is compatible with c_cutlass_dtype; if any check fails, raise a clear
exception (or fall back to allocating a correctly-shaped/device/dtype tensor)
instead of proceeding to create kernel_out/c_ptr; update the logic around
kernel_out, copy_back and c_ptr (functions/variables: kernel_out, c_ptr, out, m,
n, c_cutlass_dtype) so the kernel never gets a pointer into mismatched storage.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4817-4824: The current logic appends a TGV tactic into
valid_tactics guarded only by use_nvfp4, which lets callers request
backend="tgv" yet still get included when backend="cute-dsl"; update the
selection so TGV is a distinct opt-in backend path: either (A) add a backend
check (e.g., backend == "tgv") around the block that appends the TGV tactic (the
use_nvfp4 / valid_tactics.append(...) site) and similarly at the other
occurrence (around the similar block at the other location), or (B) if
backend=="tgv" is requested force-only TGV tactics by filtering valid_tactics to
include only entries tagged "tgv" after collection; apply the same change to the
corresponding block at the other lines referenced so TGV is only chosen when the
caller explicitly selects the TGV backend.

---

Duplicate comments:
In `@flashinfer/gemm/gemm_base.py`:
- Around line 4643-4647: The variable named `l` is ambiguous and triggers Ruff
E741; rename it to a clearer identifier (e.g., `l_dim` or `depth_l`) and update
its declaration and all usages, including the tuple `problem_mnkl =
(cutlass.Int32(m), cutlass.Int32(n), cutlass.Int32(real_k), cutlass.Int32(l))`
so the tuple uses the new name (e.g., `cutlass.Int32(l_dim)`). Make sure you
rename the variable consistently throughout the surrounding scope and any
functions or expressions that reference `l` to avoid unresolved names.
- Around line 4608-4641: Validate the caller-provided out tensor before deriving
kernel_out and c_ptr: check that out.shape == (m, n), out.device == a.device (or
match expected device), and out.dtype is compatible with c_cutlass_dtype; if any
check fails, raise a clear exception (or fall back to allocating a
correctly-shaped/device/dtype tensor) instead of proceeding to create
kernel_out/c_ptr; update the logic around kernel_out, copy_back and c_ptr
(functions/variables: kernel_out, c_ptr, out, m, n, c_cutlass_dtype) so the
kernel never gets a pointer into mismatched storage.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 5c9d27f7-39d3-4fd6-a86b-61d63dafc84c

📥 Commits

Reviewing files that changed from the base of the PR and between a228a63 and f4a4276.

📒 Files selected for processing (3)
  • flashinfer/cute_dsl/__init__.py
  • flashinfer/cute_dsl/tgv_gemm_nvfp4.py
  • flashinfer/gemm/gemm_base.py

Comment on lines +4817 to +4824
# --- TGV NVFP4 tactic (SM100/SM103, low-latency small-N) ---
# Fixed CTA tile (128, 8, 256). Only eligible when the problem
# satisfies the TGV alignment constraints; we still let the
# autotuner pick between this and the standard SM100/SM103 tactics.
if use_nvfp4 and m % 128 == 0 and n % 8 == 0 and real_k % 256 == 0:
    valid_tactics.append(
        ((128, 8), (1, 1), False, False, "tgv", None)
    )

⚠️ Potential issue | 🟠 Major

Wire TGV as a selectable backend if it is meant to be opt-in.

This adds TGV as one candidate inside backend="cute-dsl", so callers cannot request backend="tgv" and backend="cute-dsl" may still pick the standard CuTe-DSL tactics. If the intended API contract is an opt-in TGV backend, add a tgv backend requirement/runner path or make the TGV path force only TGV tactics.

Also applies to: 5365-5369

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/gemm/gemm_base.py` around lines 4817 - 4824, The current logic
appends a TGV tactic into valid_tactics guarded only by use_nvfp4, which lets
callers request backend="tgv" yet still get included when backend="cute-dsl";
update the selection so TGV is a distinct opt-in backend path: either (A) add a
backend check (e.g., backend == "tgv") around the block that appends the TGV
tactic (the use_nvfp4 / valid_tactics.append(...) site) and similarly at the
other occurrence (around the similar block at the other location), or (B) if
backend=="tgv" is requested force-only TGV tactics by filtering valid_tactics to
include only entries tagged "tgv" after collection; apply the same change to the
corresponding block at the other lines referenced so TGV is only chosen when the
caller explicitly selects the TGV backend.


2 participants