Skip to content

perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels#2904

Merged
bkryu merged 14 commits intoflashinfer-ai:mainfrom
bkryu:optimize_quant
Apr 1, 2026
Merged

perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels#2904
bkryu merged 14 commits intoflashinfer-ai:mainfrom
bkryu:optimize_quant

Conversation

@bkryu
Copy link
Copy Markdown
Collaborator

@bkryu bkryu commented Mar 27, 2026

📌 Description

Summary

  • Adopt dual-path kernel architecture (linear flat + swizzled row-based) for MXFP4 and NVFP4 CuTe-DSL quantization kernels.
  • Architecture chanes to MXFP8 quantization for better perf.
  • Expand benchmark scripts and test coverage across all three quantization kernels. Compares exact output match between CUDA & CuTe-DSL backends
  • All mxfp4, mxfp8, and nvfp4 quantization have exact bitwise match in for CUDA vs. CuTe DSL backends in both the output and scaling factors.

Kernel changes
mxfp8_quantize.py

  • Adaptive 2T/SF dispatch: 2 threads per SF block for large problems (total_sf >= 65536), 4 threads for small problems for better memory bandwidth utilization.
  • Integer UE8M0 conversion (float_to_ue8m0_fast, ue8m0_to_inv_scale_fast): replaces SFU-based lg2.approx/ex2.approx with integer bit manipulation, freeing the SFU pipeline
  • reduce_max_2threads: 1-shuffle XOR reduction for the 2T path
  • Remove unused self.dtype and self.K attributes

mxfp4_quantize.py

  • Add swizzled kernel. Previously only supported linear layout.
  • Swizzled kernel: small-K multi-row path and large-K column-loop path, compile-time selected via const_expr(needs_col_loop)
  • Inline padding for swizzled layout (row and column) — eliminates the expensive separate flat-iteration padding passes that caused 5x+ regression at small M
  • Dynamic thread count via _compute_optimal_threads(K) for 100% thread utilization

nvfp4_quantize.py

  • Same dual-path split: NVFP4QuantizeLinearKernel + NVFP4QuantizeSwizzledKernel
  • Supports all three SF layouts (128x4, 8x4, linear) with compile-time dispatch
  • Remove unused self.row_tile_size and self.ROW_ITERATIONS from TMA kernel

quantization_cute_dsl_utils.py

  • ue8m0_to_inv_scale_fast: integer bit construction replacing ex2.approx
  • reduce_max_2threads: 1-shuffle reduction for 2T/SF MXFP8 path
  • 2T/SF constants: ELTS_PER_THREAD, THREADS_PER_SF, SF_BLOCKS_PER_WARP + legacy 4T variants
  • MXFP8_2T_SF_THRESHOLD = 65536

Test changes

  • test_fp4_quantize.py and test_fp8_quantize.py: Add more problem sizes.
  • test_fp4_quantize_padding.py: Add both-backend parametrization and CUDA-vs-CuTe-DSL parity test for linear layout padding.

Perf comparison between backends on B200

Click to see mxfp8 performance comparison

Linear (gmean 1.42x)

mxfp8_backend_comparison_linear_bfloat16

Swizzled (gmean 1.37x)

mxfp8_backend_comparison_swizzled_bfloat16
Click to see mxfp4 performance comparison

Linear (gmean 1.41x)

mxfp4_quantize_backend_comparison_linear_bfloat16

Swizzled (gmean 1.39x)

mxfp4_quantize_backend_comparison_swizzled_bfloat16
Click to see nvfp4 performance comparison

Linear (gmean 1.34x)

nvfp4_quantize_backend_comparison_linear_bfloat16

Swizzled (gmean 1.32x)

nvfp4_quantize_backend_comparison_swizzled_bfloat16

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • New Features

    • Benchmarks now run and report both linear and swizzled scale-factor layouts separately, generating distinct heatmaps/tables and layout-aware labels.
    • Added a new NVFP4 quantization benchmark with bandwidth and comparison modes.
  • Refactor

    • Quantize kernels split into layout-specific implementations and introduced dual-mode threading optimizations for better performance across sizes.
  • Tests

    • Expanded parameter sweeps, added backend parameterization, and capability-aware skips.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Mar 27, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Use the checkboxes below for quick actions:

  • ▶️ Resume reviews
  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

Adds layout-specific MXFP4/MXFP8/NVFP4 quantize kernels (linear and swizzled), threads layout flags through benchmarks and tests, introduces a new NVFP4 benchmark script, updates CuTe-DSL utilities/intrinsics, and expands test shape grids and gating.

Changes

Cohort / File(s) Summary
Benchmarks — MXFP4
benchmarks/bench_mxfp4_quantize_backend_comparison.py
Threaded is_sf_swizzled_layout through correctness and timing flows; switch to flashinfer.quantization.fp4_quantization.fp4_quantize; compute per-run global_sf; run separate linear/swizzled sweeps; configurable layout_name for reporting/heatmaps; function signatures updated.
Benchmarks — MXFP8
benchmarks/bench_mxfp8_quantize_backend_comparison.py
Added verify_mxfp8_correctness(...) with quant/scale agreement and cosine-similarity checks; run per-(m,k) verification and skip failed cases; expanded small-M sweep values and improved console status output.
Benchmarks — NVFP4 (new)
benchmarks/bench_nvfp4_quantize_backend_comparison.py
New script implementing NVFP4 linear/swizzled benchmarks and correctness checks, bandwidth-mode measurement, SM/CuTe-DSL gating, time/bandwidth sweeps, and heatmap/table generation.
Kernel exports
flashinfer/quantization/kernels/__init__.py
Replaced single MXFP4QuantizeKernel export with MXFP4QuantizeLinearKernel and MXFP4QuantizeSwizzledKernel.
MXFP4 kernels
flashinfer/quantization/kernels/mxfp4_quantize.py
Split unified MXFP4 kernel into MXFP4QuantizeLinearKernel and MXFP4QuantizeSwizzledKernel; added layout-specific thread/block strategies, use_4t_per_sf modes, cache-key inclusion, and adjusted launch metadata and buffer sizing.
MXFP8 kernels
flashinfer/quantization/kernels/mxfp8_quantize.py
Added use_2t_per_sf mode (2T/SF vs 4T/SF), recomputed thread/SF tiling, added 2-thread reduction/load path, updated cache helpers and exports.
NVFP4 kernels
flashinfer/quantization/kernels/nvfp4_quantize.py
Added NVFP4QuantizeLinearKernel; refactored swizzled planning via _compute_optimal_threads, added column-loop vs multi-row paths, and return layout-specific compiled metadata.
Quantization utils
flashinfer/quantization/quantization_cute_dsl_utils.py
Changed MXFP8 defaults to 2T/SF (16 elts/thread) while preserving legacy small-problem constants; replaced PTX ex2-based scale inverse with integer float construction; added reduce_max_2threads and new constants.
CuTe-DSL intrinsics
flashinfer/cute_dsl/fp4_common.py
Added st_global_u32 CuTe user-op emitting st.global.u32 PTX inline asm.
Tests — FP4/FP8
tests/utils/test_fp4_quantize.py, tests/utils/test_fp4_quantize_padding.py, tests/utils/test_fp8_quantize.py
Expanded MXFP4/NVFP4 shape lists to include tiny/odd/large-K cases; added backend param and CuTe-DSL availability gating in padding tests; updated MXFP8 tests to new compiled-kernel helper names and enlarged m/k parameter grids.

Sequence Diagram(s)

mermaid
sequenceDiagram
participant Host as Host (runner)
participant CUDA as CUDA backend
participant CuTe as CuTe-DSL backend
participant Disk as Disk (heatmaps/logs)

Host->>CUDA: generate input, compute global_sf, call quantize(backend="cuda")
Host->>CuTe: compile/select kernel(layout, use_?_per_sf), call quantize(backend="cute-dsl")
CUDA-->>Host: return quant, scale, dequant, timing
CuTe-->>Host: return quant, scale, dequant, timing
Host->>Host: compare quant/scale, compute cosine similarity, record pass/fail
Host->>Disk: save heatmap/table per layout

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

cute-dsl, run-ci

Suggested reviewers

  • yzh119
  • cyx-6
  • aleozlx
  • jimmyzho
  • yongwww
  • nv-yunzheq

Poem

🐰 I hopped through threads and scales with care,

Linear rows and swizzled hops now pair.
Benchmarks hum and kernels split with glee,
Heatmaps glow — a carrot patch of QC.

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Description check ⚠️ Warning The PR description is incomplete; it is missing the 'Related Issues' section and the 'Tests' checkbox is unchecked despite test changes being present. Complete the 'Related Issues' section and check the 'Tests' checkbox to confirm all tests are passing.
Docstring Coverage ⚠️ Warning Docstring coverage is 70.42% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (1 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main optimization objective: CuTe-DSL fp4 and fp8 quantization kernel performance improvements.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@bkryu bkryu changed the title perf: Optimize CuTe-DSL quantization kernels perf: Optimize CuTe-DSL fp4 and fp8 quantization kernels Mar 27, 2026
Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces significant optimizations and structural improvements to the MXFP4, MXFP8, and NVFP4 quantization kernels in the CuTe-DSL backend. Key changes include the implementation of a dual-path optimization strategy (linear vs. swizzled layouts) to improve thread utilization, the addition of correctness verification in benchmark suites, and the introduction of a new NVFP4 benchmark. I have identified an inconsistency in the run_benchmark_sweep function docstring where a no_verify parameter is documented but missing from the function signature.

Comment thread benchmarks/bench_mxfp4_quantize_backend_comparison.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
flashinfer/quantization/kernels/__init__.py (1)

44-52: ⚠️ Potential issue | 🟡 Minor

Re-export NVFP4QuantizeLinearKernel here.

flashinfer.quantization.kernels.nvfp4_quantize now publishes NVFP4QuantizeLinearKernel, but this package surface still omits it. That leaves from flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel broken even though the layout split is now public.

🔧 Proposed fix
 from .nvfp4_quantize import (
+    NVFP4QuantizeLinearKernel,
     NVFP4QuantizeSwizzledKernel,
     nvfp4_quantize_cute_dsl,
 )
@@
+    "NVFP4QuantizeLinearKernel",
     "NVFP4QuantizeSwizzledKernel",
     "nvfp4_quantize_cute_dsl",
 ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/__init__.py` around lines 44 - 52, The
package __all__ list is missing NVFP4QuantizeLinearKernel which prevents
re-exporting it; update the __all__ in the kernels package to include
"NVFP4QuantizeLinearKernel" alongside the other symbols (e.g., add
"NVFP4QuantizeLinearKernel" to the __all__ list that currently contains
"NVFP4QuantizeSwizzledKernel", "nvfp4_quantize_cute_dsl", etc.) so that from
flashinfer.quantization.kernels import NVFP4QuantizeLinearKernel works as
expected.
benchmarks/bench_mxfp4_quantize_backend_comparison.py (2)

223-257: ⚠️ Potential issue | 🟠 Major

Count swizzled padding in the bandwidth numerator.

This helper always treats scale-factor traffic as m * k / 32, but the swizzled MXFP4 path writes padded_m * padded_sf_cols bytes in flashinfer/quantization/kernels/mxfp4_quantize.py Lines 668-670. The reported TB/s is therefore inflated, especially for small M.

📏 Proposed fix
-def compute_bandwidth_tb_per_sec(
-    m: int, k: int, dtype: torch.dtype, time_ms: float
+def compute_bandwidth_tb_per_sec(
+    m: int,
+    k: int,
+    dtype: torch.dtype,
+    time_ms: float,
+    is_sf_swizzled_layout: bool,
 ) -> float:
@@
-    num_scale_factors = num_elements // SF_VEC_SIZE
+    if is_sf_swizzled_layout:
+        padded_m = ((m + 128 - 1) // 128) * 128
+        padded_sf_cols = (((k // SF_VEC_SIZE) + 3) // 4) * 4
+        num_scale_factors = padded_m * padded_sf_cols
+    else:
+        num_scale_factors = num_elements // SF_VEC_SIZE

You'll also need to thread is_sf_swizzled_layout through the run_bandwidth_sweep call site.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 223 -
257, compute_bandwidth_tb_per_sec currently computes scale-factor bytes as
num_elements // SF_VEC_SIZE which ignores swizzled padding and thus overstates
TB/s for swizzled layout; modify compute_bandwidth_tb_per_sec to accept an
is_sf_swizzled_layout flag (or padded dims) and when true compute scale-factor
traffic using padded_m and padded_sf_cols (matching the swizzled write size used
in mxfp4_quantize.py for the MXFP4 path) instead of m * k / SF_VEC_SIZE, i.e.,
calculate num_scale_factors = padded_m * padded_sf_cols and include that in
problem_bytes; also thread the new is_sf_swizzled_layout argument through
run_bandwidth_sweep call sites so the bandwidth helper knows when to use padded
counts.

138-158: ⚠️ Potential issue | 🟠 Major

Exclude non-bitwise-equal cases from the MXFP4 timing sweep.

This has the same hole as the NVFP4 benchmark: quant_match_pct and scale_match_pct are recorded, but the case still counts as verified if cosine stays above 0.9. That makes the benchmark tables look valid even when the backends diverge.

✅ Proposed fix
         # Check backend agreement
         quant_match_pct = (quant_cuda == quant_cute).float().mean().item() * 100
         scale_match_pct = (scale_cuda == scale_cute).float().mean().item() * 100
+        if not torch.equal(quant_cuda, quant_cute) or not torch.equal(
+            scale_cuda, scale_cute
+        ):
+            return (
+                False,
+                f"Backend mismatch: quant={quant_match_pct:.1f}%, scale={scale_match_pct:.1f}%",
+                quant_match_pct,
+                scale_match_pct,
+            )

         # FP4 quantization should have cosine similarity > 0.9
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@benchmarks/bench_mxfp4_quantize_backend_comparison.py` around lines 138 -
158, The current verification returns success whenever cosine similarity
(cos_sim_cuda or cos_sim_cute) >= 0.9 even if quantized outputs differ; change
the logic to require bitwise-equal quantization and scales before marking a case
as verified: compute quant_match_pct and scale_match_pct and if either is <
100.0, return False (or exclude from timing sweep) with a clear message
including quant_match_pct and scale_match_pct, otherwise continue to the cosine
checks; update the block that currently checks cos_sim_cuda/cos_sim_cute so that
the bitwise-equality check (quant_match_pct==100 and scale_match_pct==100) is
performed first.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@benchmarks/bench_mxfp8_quantize_backend_comparison.py`:
- Around line 98-129: The comparison is wrong: scale_match_pct compares scale
tensors with float8 semantics and the cosine similarity uses raw FP8 payloads
instead of dequantized floats. Fix by comparing raw uint8 bytes for both
quantized payloads and scale carriers (use scale_cuda.view(torch.uint8) and
scale_cute.view(torch.uint8) to compute scale_match_pct), and compute cosine
similarity on dequantized outputs (dequantize quant_cuda and quant_cute using
their corresponding scale_cuda/scale_cute and the FP8 format—don’t just
.to(torch.float32) on the raw bytes; produce dq_cuda and dq_cute as true float32
reconstructions before calling torch.nn.functional.cosine_similarity). Ensure
you keep the existing variable names (quant_cuda, quant_cute, scale_cuda,
scale_cute, dq_cuda, dq_cute) so the rest of the function uses the corrected
values.

In `@benchmarks/bench_nvfp4_quantize_backend_comparison.py`:
- Around line 141-161: After computing quant_match_pct and scale_match_pct, add
a strict backend-agreement check that fails the verification if either
percentage is less than 100; specifically, if quant_match_pct < 100 or
scale_match_pct < 100 return (False, f"Backend mismatch:
quant_match_pct={quant_match_pct:.4f}%, scale_match_pct={scale_match_pct:.4f}%",
quant_match_pct, scale_match_pct). Keep this check alongside the existing
cosine-threshold checks (using cos_sim_cuda and cos_sim_cute) so the function
only returns success when both roundtrip quality and bitwise agreement between
quant_cuda and quant_cute (and scales) are satisfied.

In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 668-685: The reshape of scale_output after kernel execution uses
num_sf_blocks_per_row but the swizzled path and the allocation use
padded_sf_cols (scale_output_size = padded_m * padded_sf_cols), causing a
runtime size mismatch for 4-way padded SF columns; update the reshape to use
padded_sf_cols instead of num_sf_blocks_per_row (i.e., scale_output =
scale_output.reshape(-1, padded_sf_cols)) and ensure this change is applied
alongside references to padded_m/padded_sf_cols around kernel_fn and
scale_output allocation so the swizzled scale layout is consistent.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 114-116: The warp-count computation in
_compute_optimal_warps_for_k() always uses the constant SF_BLOCKS_PER_WARP (16)
when computing gcd/divisibility, but the legacy path should use
SF_BLOCKS_PER_WARP_SMALL (8) when use_2t_per_sf is False; change the logic so
the function selects the active sf_blocks_per_warp (e.g., set sf_blocks_per_warp
= SF_BLOCKS_PER_WARP if use_2t_per_sf else SF_BLOCKS_PER_WARP_SMALL) and use
that variable in the gcd calculation and any subsequent divisibility/warp-count
math (replace uses of SF_BLOCKS_PER_WARP in the gcd_val and warp derivation with
sf_blocks_per_warp) so rows_per_block and warp/thread counts are computed
correctly for K like 3072.

---

Outside diff comments:
In `@benchmarks/bench_mxfp4_quantize_backend_comparison.py`:
- Around line 223-257: compute_bandwidth_tb_per_sec currently computes
scale-factor bytes as num_elements // SF_VEC_SIZE which ignores swizzled padding
and thus overstates TB/s for swizzled layout; modify
compute_bandwidth_tb_per_sec to accept an is_sf_swizzled_layout flag (or padded
dims) and when true compute scale-factor traffic using padded_m and
padded_sf_cols (matching the swizzled write size used in mxfp4_quantize.py for
the MXFP4 path) instead of m * k / SF_VEC_SIZE, i.e., calculate
num_scale_factors = padded_m * padded_sf_cols and include that in problem_bytes;
also thread the new is_sf_swizzled_layout argument through run_bandwidth_sweep
call sites so the bandwidth helper knows when to use padded counts.
- Around line 138-158: The current verification returns success whenever cosine
similarity (cos_sim_cuda or cos_sim_cute) >= 0.9 even if quantized outputs
differ; change the logic to require bitwise-equal quantization and scales before
marking a case as verified: compute quant_match_pct and scale_match_pct and if
either is < 100.0, return False (or exclude from timing sweep) with a clear
message including quant_match_pct and scale_match_pct, otherwise continue to the
cosine checks; update the block that currently checks cos_sim_cuda/cos_sim_cute
so that the bitwise-equality check (quant_match_pct==100 and
scale_match_pct==100) is performed first.

In `@flashinfer/quantization/kernels/__init__.py`:
- Around line 44-52: The package __all__ list is missing
NVFP4QuantizeLinearKernel which prevents re-exporting it; update the __all__ in
the kernels package to include "NVFP4QuantizeLinearKernel" alongside the other
symbols (e.g., add "NVFP4QuantizeLinearKernel" to the __all__ list that
currently contains "NVFP4QuantizeSwizzledKernel", "nvfp4_quantize_cute_dsl",
etc.) so that from flashinfer.quantization.kernels import
NVFP4QuantizeLinearKernel works as expected.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 9aa2c7e3-005c-4fea-94b0-406037e2f3b6

📥 Commits

Reviewing files that changed from the base of the PR and between 31b63bc and 8a9545c.

📒 Files selected for processing (11)
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py
  • benchmarks/bench_mxfp8_quantize_backend_comparison.py
  • benchmarks/bench_nvfp4_quantize_backend_comparison.py
  • flashinfer/quantization/kernels/__init__.py
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • flashinfer/quantization/quantization_cute_dsl_utils.py
  • tests/utils/test_fp4_quantize.py
  • tests/utils/test_fp4_quantize_padding.py
  • tests/utils/test_fp8_quantize.py

Comment thread benchmarks/bench_mxfp8_quantize_backend_comparison.py
Comment thread benchmarks/bench_nvfp4_quantize_backend_comparison.py
Comment thread flashinfer/quantization/kernels/mxfp4_quantize.py
Comment thread flashinfer/quantization/kernels/mxfp8_quantize.py Outdated
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 384-385: The swizzled-path calculation of threads_per_row uses
num_sf_blocks_per_row so sf_col_idx never iterates into the padded SF columns
(indices num_sf_blocks_per_row .. padded_sf_cols-1), leaving those padding
columns uninitialized (scale_output from torch.empty). Update the swizzled
path(s) where threads_per_row is set (referenced by threads_per_row,
num_sf_blocks_per_row, _threads_per_sf) to size rows by padded_sf_cols (i.e.,
use padded_sf_cols * _threads_per_sf) so sf_col_idx loops cover padded_sf_cols,
and ensure the padding-column clear block and any zeroing logic that uses
sf_col_idx and padded_sf_cols runs for those extra columns; also ensure
scale_output is allocated/initialized accordingly rather than relying on
torch.empty leaving bytes undefined.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 53a4f10b-27c1-47c7-b4e6-a08b0bf9046c

📥 Commits

Reviewing files that changed from the base of the PR and between 8a9545c and 0eda88e.

📒 Files selected for processing (4)
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py
  • tests/utils/test_fp8_quantize.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/utils/test_fp8_quantize.py
  • benchmarks/bench_mxfp4_quantize_backend_comparison.py

Comment thread flashinfer/quantization/kernels/mxfp8_quantize.py
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/quantization/kernels/mxfp8_quantize.py (1)

377-390: ⚠️ Potential issue | 🔴 Critical

Require exact divisibility before using the multi-row swizzled mapping.

If _compute_optimal_warps() falls back to 16 warps, Lines 389-390 can still select the small-K path even when col_units_per_block is not divisible by num_sf_blocks_per_row. For K=1056 (33 SF blocks), Line 574 then creates a partial extra row (512 / 66 threads-per-row in 2T/SF, 512 / 132 in 4T/SF), so those tail threads alias the next row batch’s first row. Gate the multi-row path on exact divisibility and otherwise fall back to needs_col_loop=True.

🔧 Suggested fix
         # rows_per_block = col_units_per_block // num_sf_blocks_per_row
-        # With optimal warps, this should divide evenly for small K
-        if self.num_sf_blocks_per_row <= col_units_per_block:
+        # Multi-row processing requires exact row tiling; otherwise the tail
+        # threads spill into a partial extra row and overlap the next row batch.
+        if (
+            self.num_sf_blocks_per_row <= col_units_per_block
+            and col_units_per_block % self.num_sf_blocks_per_row == 0
+        ):
             self.rows_per_block = col_units_per_block // self.num_sf_blocks_per_row
             self.needs_col_loop = False
         else:
             self.rows_per_block = 1
             self.needs_col_loop = True

Also applies to: 574-577

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/quantization/kernels/mxfp8_quantize.py` around lines 377 - 390,
The multi-row swizzle path currently assumes col_units_per_block divides
num_sf_blocks_per_row; update the gating logic in the block that computes
self.warps_per_block (call site uses _compute_optimal_warps) to require exact
divisibility before enabling rows_per_block: check (col_units_per_block %
self.num_sf_blocks_per_row == 0) and only then set self.rows_per_block =
col_units_per_block // self.num_sf_blocks_per_row; otherwise do not set
rows_per_block and force the fallback by setting self.needs_col_loop = True (and
ensure any later code that relies on rows_per_block uses the fallback when
needs_col_loop is True). This change addresses aliasing when
_compute_optimal_warps returns a fallback warp count (e.g., 16) for K values
like 1056.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Outside diff comments:
In `@flashinfer/quantization/kernels/mxfp8_quantize.py`:
- Around line 377-390: The multi-row swizzle path currently assumes
col_units_per_block divides num_sf_blocks_per_row; update the gating logic in
the block that computes self.warps_per_block (call site uses
_compute_optimal_warps) to require exact divisibility before enabling
rows_per_block: check (col_units_per_block % self.num_sf_blocks_per_row == 0)
and only then set self.rows_per_block = col_units_per_block //
self.num_sf_blocks_per_row; otherwise do not set rows_per_block and force the
fallback by setting self.needs_col_loop = True (and ensure any later code that
relies on rows_per_block uses the fallback when needs_col_loop is True). This
change addresses aliasing when _compute_optimal_warps returns a fallback warp
count (e.g., 16) for K values like 1056.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: e7ea1b97-a0cb-4167-aa30-cb8db8e97129

📥 Commits

Reviewing files that changed from the base of the PR and between 0eda88e and 07ba503.

📒 Files selected for processing (1)
  • flashinfer/quantization/kernels/mxfp8_quantize.py

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 27, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !468 has been created, and the CI pipeline #47118274 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #47118274: 13/20 passed

@bkryu bkryu self-assigned this Mar 31, 2026
@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 31, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !468 has been updated with latest changes, and the CI pipeline #47306420 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #47306420: 7/20 passed

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 31, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !468 has been updated with latest changes, and the CI pipeline #47380772 is currently running. I'll report back once the pipeline job completes.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/quantization/kernels/mxfp4_quantize.py`:
- Around line 932-968: The swizzled path lacks a fail-fast check for the K
multiple required by swizzled MXFP4; add an explicit guard when sf_layout ==
SF_LAYOUT_128x4 that asserts or raises if k % 128 != 0 (use the existing k
variable) before allocating outputs or calling kernel_fn so errors surface
immediately (rather than later at scale_output.reshape); update the branch that
handles the non-SF_LAYOUT_LINEAR case (the block using padded_m, padded_sf_cols,
kernel_fn, and scale_output.reshape) to perform this check and emit a clear
error mentioning SF_LAYOUT_128x4 and the required k % 128 == 0 constraint.
- Around line 252-322: Replace the Unicode multiplication character '×' with
ASCII 'x' in the inline comments within the mxfp4_quantize kernel to satisfy
Ruff RUF003; specifically update comment occurrences like "Each thread loads 8
elements (1×128-bit)" and "Each thread loads 32 elements (4×128-bit)" (near the
code that uses MXFP4_SF_VEC_SIZE, ld_global_v4_u32, st_global_u32 and
compute_sf_index_linear_gpu) so comments read "1x128-bit" and "4x128-bit" (and
any other similar occurrences) without changing code logic.

In `@flashinfer/quantization/kernels/nvfp4_quantize.py`:
- Line 1255: The current assignment to enable_pdl overwrites an explicit
True/False from the caller; change it so auto-detection only happens when the
caller passed None. Specifically, in nvfp4_quantize.py replace the expression
that sets enable_pdl using device_support_pdl(input.device) unless enable_pdl is
None (e.g., if enable_pdl is None: enable_pdl =
device_support_pdl(input.device)) so that explicit enable_pdl=True or False
passed into the function remains respected while None triggers device capability
detection.
- Around line 84-91: The helper _compute_swizzled_layout_sf_size is dead code in
this module—either delete the _compute_swizzled_layout_sf_size function to
remove the unused symbol, or wire it into the buffer allocation/launch path that
prepares swizzled scale-factor buffers (call _compute_swizzled_layout_sf_size
from the launch/allocation routine that computes padded_row/padded_column before
allocating the SF buffer and remove any duplicate copies in fp4_quantization or
fp8_quantization); ensure you update any related imports/exports (remove from
__all__ if present) and eliminate duplicate implementations to centralize the
computation under the function name _compute_swizzled_layout_sf_size if you
choose to keep it.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c4247646-c0f7-437f-8b4b-add14daf5fce

📥 Commits

Reviewing files that changed from the base of the PR and between 7150697 and 9269e93.

📒 Files selected for processing (3)
  • flashinfer/quantization/kernels/mxfp4_quantize.py
  • flashinfer/quantization/kernels/mxfp8_quantize.py
  • flashinfer/quantization/kernels/nvfp4_quantize.py

Comment thread flashinfer/quantization/kernels/mxfp4_quantize.py Outdated
Comment thread flashinfer/quantization/kernels/mxfp4_quantize.py
Comment thread flashinfer/quantization/kernels/nvfp4_quantize.py Outdated
Comment thread flashinfer/quantization/kernels/nvfp4_quantize.py
@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 31, 2026

/bot stop

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

The GitLab CI pipeline #47380772 has been cancelled.

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 31, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !468 has been updated with latest changes, and the CI pipeline #47383493 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[CANCELING] Pipeline #47383493: canceled

@bkryu
Copy link
Copy Markdown
Collaborator Author

bkryu commented Mar 31, 2026

/bot run

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

GitLab MR !468 has been updated with latest changes, and the CI pipeline #47391372 is currently running. I'll report back once the pipeline job completes.

@flashinfer-bot
Copy link
Copy Markdown
Collaborator

[FAILED] Pipeline #47391372: 10/20 passed

@bkryu bkryu merged commit d476c61 into flashinfer-ai:main Apr 1, 2026
30 of 34 checks passed
@bkryu bkryu deleted the optimize_quant branch April 1, 2026 19:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants