feat: RMSNorm + RoPE fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope #3148

Open
kahyunnam wants to merge 35 commits into flashinfer-ai:main from kahyunnam:knam/fused_norm_rope_for_video_gen

Conversation


@kahyunnam kahyunnam commented Apr 22, 2026

📌 Description

#2971

Add a fused CUDA kernel for across-heads QK RMSNorm + 3D Rotary Position Embeddings (RoPE) + V copy, targeting video generation DIT (Diffusion Transformer) self-attention workloads such as WAN 2.1/2.2.

This kernel fuses three operations into a single launch:

  1. Across-heads RMSNorm on Q and K (normalizes over the full hidden_dim = num_heads * head_dim, not per-head)
  2. 3D RoPE with frame/height/width spatial decomposition (each head dimension is split into temporal, height, and width frequency channels)
  3. V passthrough copy to a contiguous output buffer
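The across-heads normalization in step 1 is the part that differs from per-head RMSNorm. A minimal pure-Python sketch of the semantics (illustrative only; the real kernel operates on BF16 tensors in-register):

```python
import math

def rmsnorm_across_heads(x, weight, eps=1e-6):
    """Normalize over the full hidden_dim = num_heads * head_dim:
    one RMS statistic is shared by all heads, not one per head."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

# hidden_dim = 2 heads x 4 dims; a single RMS over all 8 values
hidden = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
weight = [1.0] * 8
out = rmsnorm_across_heads(hidden, weight)
```

A per-head variant would compute one RMS per 4-value slice; the across-heads form couples all heads through the shared statistic.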

The kernel optionally produces FP8 E4M3 quantized output, using SM89+ vectorized PTX conversion and SM100+ Blackwell FFMA2 intrinsics (all with scalar fallbacks for SM80+).

API

User-facing API follows the dsv3_ops pattern — implementation lives in flashinfer/norm/ (alongside rmsnorm, gemma_rmsnorm, etc.) with a re-export facade at flashinfer/diffusion_ops/:

from flashinfer.diffusion_ops import fused_qk_rmsnorm_rope

q, k, v = fused_qk_rmsnorm_rope(
    qkv,               # [batch, seq_len, (nq+nk+nv)*head_dim] or [num_tokens, ...]
    q_weight, k_weight, # RMSNorm weights [num_heads_x * head_dim]
    ppf=5, pph=12, ppw=32,
    num_frame_channels=44, num_height_channels=42, num_width_channels=42,
    num_heads_q=24, num_heads_k=24, num_heads_v=24,
    head_dim=128,
)
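Shape handling for the two accepted input layouts can be sketched as follows (a hypothetical helper, not the library's actual validation code; the hidden size 9216 here assumes (24+24+24) heads × head_dim 128):

```python
def infer_output_shape(input_shape, ppf, pph, ppw, num_heads, head_dim):
    """Sketch of the accepted layouts: 3D [batch, seq_len, hidden]
    or flattened 2D [num_tokens, hidden] input."""
    tokens_per_item = ppf * pph * ppw
    if len(input_shape) == 3:
        batch, seq_len, _ = input_shape
        if seq_len != tokens_per_item:
            raise ValueError("seq_len must equal ppf * pph * ppw")
        return (batch, seq_len, num_heads, head_dim)
    num_tokens, _ = input_shape
    if num_tokens % tokens_per_item != 0:
        raise ValueError("num_tokens must be divisible by ppf * pph * ppw")
    return (num_tokens, num_heads, head_dim)

# WAN 2.2 5B production shape: 5 * 12 * 32 = 1920 tokens per batch item
shape_3d = infer_output_shape((1, 1920, 9216), 5, 12, 32, 24, 128)
shape_2d = infer_output_shape((3840, 9216), 5, 12, 32, 24, 128)
```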

Benchmark Results (B200, SM100; CUPTI timing)

python /workspace/flashinfer/benchmarks/bench_fused_qk_rmsnorm_rope.py
GPU: NVIDIA B200
Config: WAN 2.2 5B (num_heads=24, head_dim=128)

Shape                                                Eager (ms)   Fused (ms)    Speedup
------------------------------------------------------------------------------------------
B=1 5x12x32= 1920 (480p production (1920 tokens))        0.2420       0.0389      6.22x
B=1 5x12x8=  480 (480p small (480 tokens))               0.2363       0.0114     20.68x
B=1 5x48x32= 7680 (720p large (7680 tokens))             0.4948       0.1462      3.38x
B=2 5x12x32= 1920 (batch=2 (3840 tokens))                0.2864       0.0750      3.82x
B=1 5x6x4=  120 (tiny (120 tokens))                      0.2467       0.0045     54.67x
B=4 5x12x32= 1920 (batch=4 (7680 tokens))                0.4770       0.1462      3.26x
B=1 5x12x16=  960 (half seq (960 tokens))                0.2376       0.0207     11.49x
B=1 10x12x32= 3840 (double frames (3840 tokens))         0.2823       0.0749      3.77x
------------------------------------------------------------------------------------------

Bug Fix

Found and fixed a bug in the NeoX (non-interleaved) RoPE path: pos_id was computed from dim_idx_x's mapped value but reused for dim_idx_y without recomputation. Since the (dim_idx * 2) & mask mapping can place adjacent elements in a float2 pair into different spatial slices (e.g., height vs width), the y component received an incorrect position ID. Fix: one line to recompute pos_id from dim_idx_y. The interleaved path (used in production) was unaffected.
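The failure mode is easiest to see with the channel-slice mapping written out. The sketch below mirrors the (dim_idx * 2) & mask convention with a WAN-like 44/42/42 channel split; the slice classification is illustrative, not the kernel's actual code:

```python
def spatial_slice(dim_idx, log_head_dim, nf, nh):
    """Map a channel index via the kernel's (dim_idx * 2) & mask
    convention, then classify it as a frame/height/width channel."""
    mask = (1 << log_head_dim) - 1
    mapped = (dim_idx * 2) & mask
    if mapped < nf:
        return "frame", mapped
    if mapped < nf + nh:
        return "height", mapped
    return "width", mapped

# Adjacent elements of a float2 pair (dim_idx and dim_idx + 1)
# can land in different spatial slices.
log_head_dim, nf, nh = 7, 44, 42   # head_dim=128, 44/42/42 split
slice_x, _ = spatial_slice(42, log_head_dim, nf, nh)
slice_y, _ = spatial_slice(43, log_head_dim, nf, nh)
```

Since slice_x and slice_y differ here, reusing the pos_id computed for the x element gives the y element a position from the wrong spatial axis; the fix recomputes pos_id from dim_idx_y's mapped value.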

Architecture Support

SM           Architecture                       Support Level
SM80         Ampere (A100)                      Full — BF16 path, FP8 via software emulation
SM86         Ampere (RTX 3090)                  Same as SM80
SM89         Ada (L40, RTX 4090)                SM80 + native FP8 conversion
SM90         Hopper (H100)                      Primary target
SM100        Blackwell (B200, GB200, RTX 5090)  Primary target + FFMA2
SM103        Blackwell (B300, GB300)            Primary target + FFMA2
SM110–SM121  Blackwell variants                 Expected to work (FFMA2 + FP8)

Known Limitations

  • num_heads ≤ 32: The kernel assigns one warp per head, so the per-block thread count is num_heads × 32, capped by CUDA's 1024-thread block limit at 32 heads. WAN 14B (40 heads) is unsupported; supporting it would require a kernel redesign (multiple heads per warp).
  • BF16 input only: FP16/FP32 input would need new template instantiations.
  • 3D RoPE is specialized: The frame/height/width decomposition targets video-gen DIT models. This is not a general-purpose RoPE.
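The head-count ceiling follows directly from the launch geometry; a sketch of the arithmetic:

```python
WARP_SIZE = 32
MAX_THREADS_PER_BLOCK = 1024  # CUDA hardware limit

def max_supported_heads():
    """One warp per head, so heads are capped at 1024 / 32."""
    return MAX_THREADS_PER_BLOCK // WARP_SIZE

def heads_supported(num_heads):
    """One warp per head => num_heads * 32 threads per block."""
    return num_heads * WARP_SIZE <= MAX_THREADS_PER_BLOCK

assert heads_supported(24)       # WAN 2.2 5B (24 heads)
assert not heads_supported(40)   # WAN 14B: 1280 threads > 1024
```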

Files Changed

include/flashinfer/fused_qk_rmsnorm_rope.cuh    # CUDA kernel + utilities (754 lines)
csrc/norm.cu                                    # TVM-FFI launcher (added to norm module)
csrc/flashinfer_norm_binding.cu                 # TVM-FFI export (added to norm module)
flashinfer/norm/fused_qk_rmsnorm_rope.py        # Python API with validation
flashinfer/norm/__init__.py                     # Re-export
flashinfer/diffusion_ops/__init__.py            # User-facing facade (like dsv3_ops/)
tests/norm/test_fused_qk_rmsnorm_rope.py        # 26 tests
benchmarks/bench_fused_qk_rmsnorm_rope.py       # Benchmark script

🔍 Related Issues

#2971

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Test summary: 26 tests (25 pass, 1 xfail)

  • 8 interleaved correctness shapes (WAN 2.2 5B config)
  • 3 NeoX (non-interleaved) correctness shapes
  • V passthrough (exact BF16 copy)
  • Destination-passing style
  • 2D [num_tokens, hidden] input
  • 3 FP8 output scales (1.0, 0.5, 2.0)
  • RoPE-only mode (is_qk_norm=False)
  • 3 multi-config: WAN 1.3B (12 heads), WAN 5B (24 heads), WAN 14B (40 heads — xfail, exceeds 32-head limit)
  • 5 error-case validation tests

Validated on 4 GPU architectures:

  • NVIDIA A100 (SM80, Ampere)
  • NVIDIA L40S (SM89, Ada)
  • NVIDIA H100 NVL (SM90, Hopper)
  • NVIDIA B200 (SM100, Blackwell)

Reviewer Notes

  • The kernel compiles as part of the existing norm JIT module (gen_norm_module()), so no new JIT spec or AOT registration is needed.
  • The NeoX RoPE bugfix is a one-line change in fused_qk_rmsnorm_rope.cuh — reviewers may want to verify the pos_id recomputation logic for dim_idx_y.
  • The diffusion_ops/ facade follows the exact same pattern as dsv3_ops/ — pure re-export, no implementation.
  • Future video-gen kernels (e.g., fused cross-attention, rmsnorm_silu for WAN) can be added to diffusion_ops/.

Summary by CodeRabbit

  • New Features

    • Fused Q/K RMSNorm + 3D rotary embeddings for video self-attention with optional scaling and FP8 E4M3 output modes.
    • Public Python API and package re-exports for easy invocation; outputs can be preallocated or auto-allocated and V is passed through unchanged.
    • GPU benchmark script to compare fused implementation vs. reference timings across representative shapes.
  • Tests

    • Comprehensive CUDA test suite validating BF16/FP8 modes, RoPE variants, correctness, output semantics, and many error cases.

Single consolidated header (include/flashinfer/fused_qk_norm_rope.cuh)
containing the CUDA kernel and all its dependencies:
- IntFastDiv (signed fast integer division, from TRT-LLM)
- packed_as vector type mapping
- FP8 E4M3 quantization helpers with SM89+ PTX fast paths
- Blackwell FFMA2 intrinsics with SM<100 scalar fallbacks
- fusedQKNormRopeKernel template (across-heads RMSNorm + 3D RoPE + V copy)
- Host launcher with frequency table cache and head_dim dispatch

Adapted from share-qknormrope-layernorm-fusion/fused_QKNorm_RoPE/.
AI-assisted port to FlashInfer conventions (namespace, error macros).

Made-with: Cursor
csrc/fused_qk_norm_rope.cu: TVM-FFI launcher that bridges TensorView
to the raw-pointer kernel. Validates inputs (CUDA, contiguous, BF16),
obtains stream and num_sms internally, delegates to launchFusedQKNormRope.

csrc/flashinfer_fused_qk_norm_rope_binding.cu: TVM-FFI export via
TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, ...).

AI-assisted.

Made-with: Cursor
Move the launcher from separate csrc/fused_qk_norm_rope.cu into
csrc/norm.cu and add the TVM-FFI export to flashinfer_norm_binding.cu.
Delete the standalone files and JIT generator since the kernel now
compiles as part of gen_norm_module().

This keeps all norm operations in a single JIT module, matching the
existing pattern where rmsnorm, gemma_rmsnorm, layernorm etc. share
one compiled .so. Future work will similarly merge rmsnorm_silu.

AI-assisted.

Made-with: Cursor
flashinfer/norm/fused_qk_norm_rope.py: Public API function with
@flashinfer_api and @backend_requirement decorators. Uses the shared
get_norm_module() to call the fused_qk_norm_rope TVM-FFI export.

Comprehensive input validation: dtype (BF16), shape (3D), head_dim
(64/128/256), max heads (<=32), channel sum and even-ness, spatial
dims positivity, seq_len consistency.

Destination-passing style for q_out/k_out/v_out.

AI-assisted.

Made-with: Cursor
- flashinfer/norm/__init__.py: re-export fused_qk_norm_rope, add to __all__
- flashinfer/video_gen_ops/__init__.py: re-export facade (like dsv3_ops/)
  so users can `from flashinfer.video_gen_ops import fused_qk_norm_rope`

AI-assisted.

Made-with: Cursor
tests/norm/test_fused_qk_norm_rope.py: 17 tests covering:
- Interleaved RoPE correctness (4 shapes) — all pass
- NeoX RoPE correctness (2 shapes) — xfail, needs kernel-side validation
- V passthrough (exact BF16 copy)
- Destination-passing style
- FP8 E4M3 output (3 scale values)
- RoPE-only mode (is_qk_norm=False)
- Error cases (non-CUDA, wrong dtype, bad head_dim, channel mismatch, seq_len)

Fix: Move packed_as into fused_rope_detail sub-namespace to avoid
collision with tensorrt_llm::common::packed_as visible through norm.cuh.

AI-assisted.

Made-with: Cursor
- Expanded interleaved shapes from 4 to 8 (matching original test suite)
- NeoX tests remain xfail: reference freq mapping doesn't match kernel's
  (dim_idx * 2) & mask convention — needs kernel author input
- Added NeoX-specific embedding helper (non-interleaved freq layout)
- Relaxed RoPE-only tolerance from 0.01 to 0.05 for BF16 precision
- Validated on all 3 available architectures:
  - L40S (SM89, Ada): 19 passed, 3 xfailed
  - H100 NVL (SM90, Hopper): 19 passed, 3 xfailed
  - A100 (SM80, Ampere): 19 passed, 3 xfailed

AI-assisted.

Made-with: Cursor
Bug: In the NeoX (non-interleaved) RoPE path, pos_id was computed from
dim_idx_x's mapped value but reused for dim_idx_y without recomputation.
Since adjacent elements in a float2 pair can map to different spatial
slices (e.g. x->height, y->width), the y component got the wrong
position ID, causing incorrect RoPE rotations.

Fix: Recompute pos_id from dim_idx_y's mapped value before computing
theta_y (one-line addition).

Tests: NeoX tests now pass (were xfail). 22/22 tests pass on all three
architectures (L40S SM89, H100 SM90, A100 SM80). Also added proper
NeoX reference implementation that matches the kernel's per-element
frequency mapping via (dim_idx * 2) & ((1 << log_head_dim) - 1).

AI-assisted.

Made-with: Cursor
Test configs sourced from official WAN repos:
- wan2.1-1.3B: dim=1536, num_heads=12, head_dim=128 — passes
- wan2.2-5B:   dim=3072, num_heads=24, head_dim=128 — passes
- wan2.1-14B:  dim=5120, num_heads=40, head_dim=128 — xfail
  (40 heads * 32 threads = 1280 > CUDA max 1024 threads/block)

Also updated NeoX reference implementation and removed xfail now
that the kernel's pos_id bug is fixed.

AI-assisted.

Made-with: Cursor
benchmarks/bench_fused_qk_norm_rope.py: Compares fused kernel vs eager
PyTorch (nn.RMSNorm + manual RoPE) across 8 WAN model shapes.

H100 NVL CUPTI results (WAN 2.2 5B config):
- Production shape (1920 tokens): 5.36x speedup
- Small (480 tokens): 22.3x speedup
- Large (7680 tokens): 3.62x speedup
- Tiny (120 tokens): 64x speedup (launch-overhead dominated)

AI-assisted.

Made-with: Cursor
The kernel operates on flattened [num_tokens, hidden] layout. Now the
Python API accepts both:
- 3D [batch, seq_len, hidden] -> outputs [batch, seq_len, heads, head_dim]
- 2D [num_tokens, hidden] -> outputs [num_tokens, heads, head_dim]

For 2D input, num_tokens must be divisible by ppf*pph*ppw.
Test verifies 2D and 3D produce identical results.

AI-assisted.

Made-with: Cursor
- Remove internal provenance details from comments
- List SM90, SM100, SM103 as separate primary target architectures
- Generalize packed_as namespace comment

AI-assisted.

Made-with: Cursor
- Replace 13 per-test `from flashinfer.norm import ...` with one
  top-level `from flashinfer.video_gen_ops import fused_qk_norm_rope`
- Tests now exercise the public user-facing API path
- Also removed internal reference in module docstring

AI-assisted.

Made-with: Cursor
- Remove unused imports (math, os, gen_norm_module)
- Add tuple[int, ...] annotations to fix mypy variable-length tuple error
- Apply clang-format to C++/CUDA files
- Apply ruff format to Python files

All pre-commit checks pass.

AI-assisted.

Made-with: Cursor
The kernel targets DIT (Diffusion Transformer) self-attention which is
used in both video and image diffusion models. diffusion_ops is a
better name that won't need renaming if image diffusion kernels are
added later.

AI-assisted.

Made-with: Cursor
Observed max diffs scale as ~0.5 * output_scale (FP8 quantization
boundary rounding). Previous tolerance was max(1.0*scale, 0.5) giving
2x headroom. Tightened to max(0.75*scale, 0.375) for 1.5x headroom,
with a comment explaining the observed error pattern.
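The tolerance rule described above, written out as a helper (a sketch of the rule, not the test suite's actual code):

```python
def fp8_tolerance(output_scale):
    """Observed max diffs scale as ~0.5 * output_scale (FP8
    quantization boundary rounding); allow 1.5x headroom,
    with an absolute floor of 0.375."""
    return max(0.75 * output_scale, 0.375)

assert fp8_tolerance(1.0) == 0.75   # the three tested scales
assert fp8_tolerance(0.5) == 0.375
assert fp8_tolerance(2.0) == 1.5
```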

AI-assisted.

Made-with: Cursor

coderabbitai Bot commented Apr 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds a fused CUDA operator and full-stack bindings: a kernel and host launcher, C++ entry and TVM FFI export, a Python wrapper and re-exports, comprehensive tests, and a benchmark. The operator performs RMSNorm on Q/K, applies 3D RoPE (interleaved or NeoX), optional YARN scaling and FP8 output, and copies V through.

Changes

Cohort / File(s) Summary
CUDA Kernel & Host Launcher
include/flashinfer/fused_qk_rmsnorm_rope.cuh
New fused kernel template fusedQKNormRopeKernel and host API launchFusedQKNormRope; adds IntFastDiv, FP8 E4M3 helpers, frequency-table caching, dispatch by head_dim/interleave/output_fp8/YARN, and input/launch checks.
C++ Binding & Entry
csrc/flashinfer_norm_binding.cu, csrc/norm.cu
Exports TVM FFI symbol fused_qk_rmsnorm_rope; adds fused_qk_rmsnorm_rope_run entry performing dtype/contiguity/device/stream checks and calling the host launcher with RoPE/norm/quant params.
Python API & Re-exports
flashinfer/norm/fused_qk_rmsnorm_rope.py, flashinfer/norm/__init__.py, flashinfer/diffusion_ops/__init__.py
New Python wrapper fused_qk_rmsnorm_rope(...) with rigorous input/shape/partition validation, optional preallocated outputs, BF16 expectation for inputs/weights, FP8 output controls, flattening to token-major layout, and public re-exports.
Tests
tests/norm/test_fused_qk_rmsnorm_rope.py
Extensive CUDA tests: PyTorch reference implementations (interleaved and NeoX), correctness checks for Q/K (with/without RMSNorm), V passthrough, preallocated-output identity, FP8 output validation, multi-shape coverage, 2D vs 3D path parity, and many negative/error-case tests.
Benchmark
benchmarks/bench_fused_qk_rmsnorm_rope.py
New benchmark comparing eager PyTorch Q/K RMSNorm + interleaved 3D RoPE vs fused operator across WAN-like shapes; reports median GPU times and speedups; GPU selectable via CLI.

Sequence Diagram

sequenceDiagram
    participant Py as Python API
    participant TVM as TVM FFI
    participant Entry as C++ Entry
    participant Host as Host Launcher
    participant Kernel as CUDA Kernel
    participant Dev as Device Memory

    Py->>Py: validate tensors, shapes, RoPE params
    Py->>TVM: call fused_qk_rmsnorm_rope(...)
    TVM->>Entry: fused_qk_rmsnorm_rope_run(...)
    Entry->>Entry: dtype/contiguity checks\nselect device & stream
    Entry->>Host: launchFusedQKNormRope(...)
    Host->>Kernel: dispatch kernel instantiation (head_dim, interleave, FP8, YARN)
    
    par Kernel Ops
        Kernel->>Dev: load QKV and weights
        Kernel->>Kernel: apply RMSNorm on Q/K (optional)
        Kernel->>Kernel: apply 3D RoPE (interleaved or NeoX)
        Kernel->>Kernel: optional YARN scaling & FP8 quant+pack
        Kernel->>Dev: write q_out, k_out, v_out
    end

    Dev->>Py: return or fill output tensors

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Suggested reviewers

  • yzh119
  • aleozlx
  • nvmbreughe
  • cyx-6
  • jimmyzho
  • bkryu

Poem

🐰 I normed Q and K beneath the moon,
RoPE spun frames across each head and tune,
FP8 winked, YARN stretched a tiny boon,
Kernels hum and tests will check by noon,
A rabbit hops — fused ops land soon!

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 28.95%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (4 passed)
  • Description check: ✅ Passed. The PR description is comprehensive, following the template structure with clear Description, Related Issues, Pre-commit Checks (marked complete), and Tests sections. It includes detailed implementation context, benchmarks, architecture support, and known limitations.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Title check: ✅ Passed. The title clearly and specifically describes the main feature: a fused RMSNorm + RoPE operation for WAN (diffusion models) added to the flashinfer.diffusion_ops module.


@kahyunnam kahyunnam changed the title feat: flashinfer.diffusion_ops.fused_qk_norm_rope feat (norm fusion op, diffusion support): flashinfer.diffusion_ops.fused_qk_norm_rope Apr 22, 2026
@kahyunnam kahyunnam changed the title feat (norm fusion op, diffusion support): flashinfer.diffusion_ops.fused_qk_norm_rope feat: flashinfer.diffusion_ops.fused_qk_norm_rope Apr 22, 2026

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a fused QKNorm + 3D RoPE kernel optimized for video generation DIT architectures, providing the CUDA implementation, Python API, benchmarks, and tests. Feedback identifies critical issues regarding potential integer overflows in num_tokens and baseOffset calculations for large sequences. Additionally, the global static frequency cache is not multi-GPU safe, and the use of cudaFree within the launch path introduces performance-degrading device synchronization that prevents CUDA Graph capture.

The kernel performs RMSNorm specifically (not LayerNorm or generic norm).
Rename to fused_qk_rmsnorm_rope for consistency with FlashInfer's
existing naming convention (rmsnorm, fused_add_rmsnorm, gemma_rmsnorm,
fused_rmsnorm_silu).

All files, imports, symbols, and docstrings updated. Internal kernel
function names (fusedQKNormRopeKernel, launchFusedQKNormRope) kept
as-is since they are not part of the public API.

25 passed, 1 xfail after rename.

AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam changed the title feat: flashinfer.diffusion_ops.fused_qk_norm_rope feat: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope Apr 22, 2026
Change num_tokens from int to int64_t in the kernel signature, launcher,
and all offset calculations (baseOffset, v_output_offset, inputOffset,
outputBase/outputOffset, quantize_store_fp8 offset parameter).

This prevents integer overflow when batch_size * seq_len * num_heads *
head_dim exceeds 2^31 (e.g. batch=64 with seq_len=7680 and 24 heads).
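The overflow risk is concrete at the shapes named above once the fused QKV input is counted (nq + nk + nv heads per token):

```python
# Element offsets into the fused QKV buffer exceed 2^31 - 1 at the
# shape named above, so 32-bit offset arithmetic would overflow.
batch, seq_len = 64, 7680
nq = nk = nv = 24
head_dim = 128
qkv_elements = batch * seq_len * (nq + nk + nv) * head_dim
assert qkv_elements > 2**31 - 1   # 4,529,848,320 overflows int32
assert qkv_elements < 2**63 - 1   # comfortably addressable as int64
```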

AI-assisted.

Made-with: Cursor
Replace the single global s_freq_cache with a per-device
std::unordered_map<int, FreqCacheEntry> keyed by CUDA device ID.
This prevents crashes when the kernel is called on different GPUs
within the same process (the old single cache held a device pointer
that was only valid on the GPU that allocated it).

AI-assisted.

Made-with: Cursor
AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam marked this pull request as ready for review April 22, 2026 21:22

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🧹 Nitpick comments (4)
include/flashinfer/fused_qk_rmsnorm_rope.cuh (1)

33-40: abort() on validation failure is hostile to the Python host.

FLASHINFER_FUSED_CHECK calls abort(), which tears down the entire process — including a Python interpreter — on any kernel-side precondition failure (e.g. unsupported head_dim, factor != 1 with attention_factor != 1). These are user-input conditions that should be recoverable. Consider throwing std::runtime_error instead (or returning cudaError_t) so the FFI layer can translate it into a Python exception.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh` around lines 33 - 40, The macro
FLASHINFER_FUSED_CHECK currently calls abort() (and fprintf) on failures which
kills the Python process; change FLASHINFER_FUSED_CHECK to report failures via
exceptions or error codes instead: replace the abort() path with throwing a
std::runtime_error containing the file/line/condition message (and include
<stdexcept>) or return a cudaError_t/other error value so the FFI layer can
convert it to a Python exception; update all call sites in
fused_qk_rmsnorm_rope.cuh that rely on FLASHINFER_FUSED_CHECK to handle the
thrown exception or error return appropriately.
tests/norm/test_fused_qk_rmsnorm_rope.py (2)

312-433: Minor: unused unpacked variables flagged by Ruff.

v_ref (lines 312, 391) and v_fused (lines 325, 404) are never referenced. Rename them to _ (or _v_ref / _v_fused) to satisfy RUF059 and make the intent clear.

♻️ Proposed fix
-    q_ref, k_ref, v_ref = reference_qk_norm_rope(
+    q_ref, k_ref, _ = reference_qk_norm_rope(
         ...
     )
 
     qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
-    q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope(
+    q_fused, k_fused, _ = fused_qk_rmsnorm_rope(
         ...
     )

Apply in both test_interleaved_correctness and test_neox_correctness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 312 - 433, The test
unpacks unused variables v_ref and v_fused in the calls to
reference_qk_norm_rope and fused_qk_rmsnorm_rope inside
test_interleaved_correctness and test_neox_correctness; rename those two unused
unpacked variables to _ (or _v_ref/_v_fused) in the assignments (e.g., change
"q_ref, k_ref, v_ref = reference_qk_norm_rope(...)" and "q_fused, k_fused,
v_fused = fused_qk_rmsnorm_rope(...)" to use "_" for the third element) so Ruff
RUF059 is satisfied while leaving q_ref, k_ref, q_fused, and k_fused unchanged.

285-288: Consider gating tests to supported compute capabilities.

fused_qk_rmsnorm_rope is decorated with supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121]). On an unsupported device the API will raise instead of skipping, causing spurious test failures. Add a top-level skip (e.g. pytest.skip guarded via fused_qk_rmsnorm_rope.is_compute_capability_supported(cc)) so the suite cleanly skips on unsupported GPUs.

As per coding guidelines: "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions ... or API methods like api_name.is_compute_capability_supported(cc)".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 285 - 288, The test
should early-skip on unsupported GPU architectures: inside
test_interleaved_correctness (before using the device/dtype), query the current
device compute capability (e.g., via torch.cuda.get_device_capability(device) or
equivalent) and call fused_qk_rmsnorm_rope.is_compute_capability_supported(cc);
if it returns False, call pytest.skip with a clear message. Ensure you reference
the fused_qk_rmsnorm_rope.is_compute_capability_supported check and use
pytest.skip to avoid raising on unsupported GPUs.
benchmarks/bench_fused_qk_rmsnorm_rope.py (1)

1-211: Consider integrating with the unified benchmarking framework.

This benchmark is implemented as a standalone script. Per project convention, new kernel benchmarks should plug into benchmarks/flashinfer_benchmark.py so they share CLI options, output format, and tracking. CUPTI timing via bench_gpu_time is already correct; only the harness integration is missing.

As per coding guidelines: "Use the unified benchmarking framework in benchmarks/flashinfer_benchmark.py for kernel benchmarking with CUPTI timing support".

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/fused_qk_rmsnorm_rope.py`:
- Line 51: Remove the stray non-ASCII character at the end of the module
docstring sentence "All SM100+/SM89+ features have scalar fallbacks, so SM80 is
the true minimum." in fused_qk_rmsnorm_rope.py (the trailing 'å'); update the
docstring text to end with a normal ASCII period and ensure the docstring
surrounding that sentence (module-level docstring or the docstring in any
function/class near that line) has no other unintended non-ASCII characters.
- Line 33: Remove the unconditional "from . import get_norm_module" import and
add a local accessor function named _get_norm_module inside
fused_qk_rmsnorm_rope.py that imports the package at runtime (e.g. import
importlib; pkg = importlib.import_module("flashinfer.norm")) and returns
pkg.get_norm_module() if that attribute exists, otherwise raise a clear error;
then replace calls to get_norm_module().fused_qk_rmsnorm_rope(...) with
_get_norm_module().fused_qk_rmsnorm_rope(...), referencing the existing
fused_qk_rmsnorm_rope call site.

In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh`:
- Around line 672-678: The CI failure is due to clang-format changes around the
cache allocation block and a LAUNCH_KERNEL macro expansion; run the formatter
and commit the formatted file: run `pre-commit run clang-format --files
include/flashinfer/fused_qk_rmsnorm_rope.cuh` (or run your project's
clang-format) to reformat the code touching the cache.d_ptr allocation block
(the cudaFree/cudaMalloc/cache.alloc_floats section) and the LAUNCH_KERNEL
usage, then add and commit the resulting changes so the pre-commit job passes.
- Around line 659-715: The CUDA calls in the cache setup (cudaGetDevice,
cudaFree, cudaMalloc, cudaMemcpy) are unchecked and the host-to-device copy is
synchronous; update the logic around s_freq_cache_map/ FreqCacheEntry to CHECK
each cuda runtime return (use FLASHINFER_FUSED_CHECK or propagate cudaError_t)
after cudaGetDevice, cudaFree, and cudaMalloc and ensure cache.d_ptr is valid
before using it, and replace the blocking cudaMemcpy with cudaMemcpyAsync on the
caller's stream (thread/stream should be accepted or obtained) so the frequency
table upload (to cache.d_ptr) can be captured in CUDA graphs and won't write to
nullptr; keep the existing cache_miss/table_size logic and still set cache.*
fields only after successful async copy completion or error-checked launch.
- Around line 631-715: Change the header-scope cache to have single-definition
linkage (replace the static unordered_map declaration with inline or move it to
a .cu with an extern here), add a std::mutex (e.g., s_freq_cache_mutex) and wrap
accesses to s_freq_cache_map[device_id] and the entire cache-miss
read-modify-write branch with std::lock_guard to prevent races, replace raw cuda
API calls (cudaGetDevice/cudaMalloc/cudaFree/cudaMemcpy) with checked calls
using TVM_FFI_ICHECK(...) on the returned status, use cudaMemcpyAsync(...,
stream) instead of synchronous cudaMemcpy to honor the caller's stream, and stop
calling cudaFree on the cached d_ptr to preserve the "allocated once and never
freed" cudagraph-compatible behavior (only allocate with cudaMalloc when d_ptr
is null or alloc_floats is too small).


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 010f0d49-9914-43ac-9445-05a33dddf336

📥 Commits

Reviewing files that changed from the base of the PR and between 6ddbdb0 and 2de4c2b.

📒 Files selected for processing (8)
  • benchmarks/bench_fused_qk_rmsnorm_rope.py
  • csrc/flashinfer_norm_binding.cu
  • csrc/norm.cu
  • flashinfer/diffusion_ops/__init__.py
  • flashinfer/norm/__init__.py
  • flashinfer/norm/fused_qk_rmsnorm_rope.py
  • include/flashinfer/fused_qk_rmsnorm_rope.cuh
  • tests/norm/test_fused_qk_rmsnorm_rope.py

AI-assisted.

Made-with: Cursor
Move `from . import get_norm_module` from module top-level to inside
the function body. get_norm_module is only defined in norm/__init__.py
when _USE_CUDA_NORM is True; lazy import avoids failure if the CuTe
DSL path is active at import time.

AI-assisted.

Made-with: Cursor
AI-assisted.

Made-with: Cursor

@coderabbitai (bot) left a comment


Actionable comments posted: 1



Run ID: 38f12cdd-e7fd-4061-9f0e-39707ccab638

📥 Commits

Reviewing files that changed from the base of the PR and between 2de4c2b and 5a6c91b.

📒 Files selected for processing (1)
  • include/flashinfer/fused_qk_rmsnorm_rope.cuh

Remove cudaFree call on the old d_ptr when growing the allocation.
The old pointer is intentionally leaked to remain valid for any
cudagraph that captured a kernel referencing it. This aligns with the
stated design intent ("allocated once and never freed").

Also added comment explaining why cudaMemcpy is synchronous (source
is a stack-allocated array that goes out of scope after the call).

AI-assisted.

Made-with: Cursor
- Add FLASHINFER_FUSED_CHECK on cudaGetDevice and cudaMalloc to catch
  errors early instead of silently dereferencing nullptr
- Switch cudaMemcpy to cudaMemcpyAsync on the caller's stream (CUDA
  spec: async from unpinned host memory is synchronous, so stack-
  allocated h_freq_table is safe; using async form for cudagraph
  capture compatibility)

AI-assisted.

Made-with: Cursor
- Assert num_tokens > 0 and <= INT32_MAX before narrowing to dim3
- Add cudaGetLastError check after kernel launch to surface
  cudaErrorInvalidConfiguration or other launch failures

AI-assisted.

Made-with: Cursor
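The narrowing guard this commit describes can be sketched in isolation (`checked_grid_x` is an illustrative name; in the real code the check happens in C++ before filling a `dim3`):

```python
INT32_MAX = 2**31 - 1

def checked_grid_x(num_tokens: int) -> int:
    """Validate num_tokens before narrowing it to a 32-bit grid
    dimension, so an oversized or empty launch fails loudly instead of
    silently wrapping or producing cudaErrorInvalidConfiguration."""
    if num_tokens <= 0:
        raise ValueError(f"num_tokens must be positive, got {num_tokens}")
    if num_tokens > INT32_MAX:
        raise ValueError(f"num_tokens {num_tokens} exceeds INT32_MAX")
    return num_tokens
```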
@kahyunnam (Member Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !583 has been created, and the CI pipeline #49234040 is currently running. I'll report back once the pipeline job completes.


@coderabbitai (bot) left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
include/flashinfer/fused_qk_rmsnorm_rope.cuh (1)

644-678: ⚠️ Potential issue | 🟠 Major

Remaining gaps from prior review: header-scope static cache + missing mutex.

The CUDA error-checking, async memcpy on caller's stream, and grid-dim/launch-error guards are all in now — nice. Two items from the earlier critical review still stand:

  1. Line 646: s_freq_cache_map is static at namespace scope in a header → internal linkage, one copy per TU that includes this file. It is currently safe only because csrc/norm.cu is the sole includer; any future inclusion will silently create a parallel cache (and, with the deliberate leak-old-pointer policy at L672-676, leak independently). Prefer inline (C++17) or move the definition into a .cu with an extern declaration here.

  2. Lines 667, 671-678: Access to s_freq_cache_map[device_id] and the RMW of the FreqCacheEntry (cudaMalloc + field updates at L674-677) aren't synchronized. If the FFI runtime ever releases the GIL around the launcher (or different Python threads drive the same device), you get an unordered_map data race and a torn write of cache.d_ptr observable by a concurrent kernel launch. A std::mutex around the cache-miss region (and the operator[] lookup) is sufficient.

🛠️ Suggested shape
-// Per-device frequency table cache. Keyed by CUDA device ID so that
-// multi-GPU usage within a single process is safe.
-static std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;
+// Per-device frequency table cache. Keyed by CUDA device ID so that
+// multi-GPU usage within a single process is safe.
+inline std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;
+inline std::mutex s_freq_cache_mutex;

and wrap the lookup + cache-miss RMW region in a std::lock_guard<std::mutex> (plus #include <mutex>).
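As a host-only illustration of that suggested shape, here is a Python analog using a dict plus threading.Lock (the allocator is a stand-in for cudaMalloc; all names are illustrative):

```python
import threading

# device_id -> (buffer, capacity); one shared instance, one shared lock.
_freq_cache: dict = {}
_freq_cache_lock = threading.Lock()

def _mock_device_alloc(n: int) -> list:
    """Stand-in for cudaMalloc in this host-only sketch."""
    return [0.0] * n

def get_or_grow_freq_table(device_id: int, needed: int) -> list:
    """Serialize the lookup and the cache-miss grow under one lock.
    Note the old buffer is simply dropped here; the real code
    deliberately leaks the old device pointer so any cudagraph that
    captured a kernel referencing it stays valid."""
    with _freq_cache_lock:
        buf, cap = _freq_cache.get(device_id, (None, 0))
        if buf is None or cap < needed:
            buf = _mock_device_alloc(needed)
            _freq_cache[device_id] = (buf, needed)
        return buf
```

Holding the lock across both the lookup and the grow is what prevents a torn write of the cached pointer from being observed by a concurrent launch.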

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh` around lines 644 - 678,
s_freq_cache_map is declared static at header scope causing one copy per TU and
racey concurrent access to s_freq_cache_map[device_id] and FreqCacheEntry RMW in
launchFusedQKNormRope; change the declaration to inline std::unordered_map<int,
FreqCacheEntry> s_freq_cache_map (or move the definition into a .cu with an
extern in the header) and add a global std::mutex (e.g., s_freq_cache_mutex)
with `#include` <mutex>, then wrap the operator[] lookup and the cache-miss branch
that does cudaMalloc and updates cache.d_ptr/cache.alloc_floats in a
std::lock_guard<std::mutex> to serialize map access and the pointer/size write.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@flashinfer/norm/fused_qk_rmsnorm_rope.py`:
- Around line 244-254: The pre-allocated destination buffers and weight tensors
are not validated before being .view()'d and passed to the kernel; add checks to
ensure user-supplied q_out/k_out/v_out and q_weight/k_weight match expected
dtype (respecting output_fp8), device (same as qkv.device), shape
(q_out/k_out/v_out shape equals out_shape_q/out_shape_k/out_shape_v and
q_weight/k_weight length equals num_heads_* * head_dim), and are contiguous
(torch.is_contiguous) to prevent silent memory corruption; implement these
validations either by extending _check_fused_qk_rmsnorm_rope to accept
kwargs.get("q_out","k_out","v_out","q_weight","k_weight") or add an inline
validation block before creating qkv_flat/q_out_flat/k_out_flat/v_out_flat that
raises a clear ValueError if any check fails.



Run ID: 7ed2b704-48a6-4653-a1cd-f6770ed6608a

📥 Commits

Reviewing files that changed from the base of the PR and between 5a6c91b and 2c92645.

📒 Files selected for processing (2)
  • flashinfer/norm/fused_qk_rmsnorm_rope.py
  • include/flashinfer/fused_qk_rmsnorm_rope.cuh

Add inline checks before passing to the kernel:
- q_weight/k_weight: size matches num_heads*head_dim, dtype is BF16,
  contiguous
- q_out/k_out/v_out (when user-supplied): shape, dtype (BF16 or FP8
  matching output_fp8 flag), contiguity, and device match

Prevents silent memory corruption from mismatched buffers.

AI-assisted.

Made-with: Cursor
- test_error_wrong_weight_size: wrong q_weight length
- test_error_wrong_output_shape: q_out with wrong num_heads
- test_error_wrong_output_dtype: q_out with float16 instead of bfloat16

Exercises the validation added in the previous commit.

AI-assisted.

Made-with: Cursor
get_norm_module was only defined when _USE_CUDA_NORM was True, causing
ImportError in CI environments where nvidia-cutlass-dsl is installed
(CuTe DSL path active). The fused_qk_rmsnorm_rope kernel has no CuTe
DSL alternative and always needs the CUDA JIT module.

Move get_norm_module out of the conditional block so it's always
available. Existing norm functions still guard calls with
`if _USE_CUDA_NORM:` so behavior is unchanged for them.

AI-assisted.

Made-with: Cursor

@coderabbitai (bot) left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/norm/test_fused_qk_rmsnorm_rope.py (2)

480-481: test_v_passthrough uses is_qk_norm=True with RMSNorm weights=1 — that's fine for V, but misleading for documenting "exact copy".

With is_qk_norm=True the kernel still runs RMSNorm on Q/K; V is passthrough regardless of is_qk_norm. The test name and comment suggest verifying a bit-exact V copy (which it does), but it would be slightly more robust to also cover is_qk_norm=False here so the passthrough guarantee is asserted on both code paths.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 480 - 481, The test
test_v_passthrough currently asserts V is an exact copy only under
is_qk_norm=True; update the test to also run the same assertion with
is_qk_norm=False so the passthrough behavior is validated on both code paths.
Locate the test_v_passthrough function and add a second sub-case (or parametrize
the test) that calls the same kernel/setup with is_qk_norm=False, computes
v_fused and v_expected, and asserts torch.equal(v_fused, v_expected) for that
case as well; keep existing checks for is_qk_norm=True unchanged.

187-204: Vectorize NeoX reference to avoid triple-nested Python loops.

The per-element Python loop over (batch, seq_len, head_dim) is O(B·S·D) host-side scalar work and will dominate test time if NEOX_SHAPES grows. Current shapes keep this tractable, but the loop can be replaced with broadcasted tensor ops using the precomputed freq_per_elem/spatial_dim_per_elem vectors, e.g. gather pos_ids[..., spatial_dim_per_elem] and multiply by freq_per_elem, then cos/sin in one shot.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 187 - 204, Replace the
triple-nested loops that compute cos_out and sin_out by fully vectorized tensor
ops: build a pos_ids tensor for all tokens (shape [batch_size, seq_len, 3])
using the same pos_t/pos_h/pos_w logic (vectorize tok→pos using arange and
div/mod with pph and ppw), then gather the per-head spatial coordinate with
spatial_dim_per_elem (use torch.tensor(spatial_dim_per_elem, device=...) and
torch.take_along_dim or torch.gather after expanding pos_ids) to produce
pos_per_head shape [batch_size, seq_len, head_dim]; multiply that by
freq_per_elem (broadcasted) to get theta, compute cos/sin in one call, and
assign to cos_out[...,0,:] and sin_out[...,0,:] converted to dtype; update uses
of freq_per_elem, spatial_dim_per_elem, cos_out, sin_out, batch_size, seq_len,
head_dim, pph, ppw, device, dtype accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/norm/test_fused_qk_rmsnorm_rope.py`:
- Around line 980-1002: The failing CI is caused by a too-long inline comment in
test_error_wrong_output_shape and by unused-unpack names flagged as RUF059; in
tests modify the long inline comment in fused_qk_rmsnorm_rope's
test_error_wrong_output_shape (replace or wrap the "12 heads, not 24" inline
comment so it doesn't exceed line length) and in test_interleaved_correctness
and test_neox_correctness replace the unused unpack targets v_ref and v_fused
with throwaway names (e.g., _ or _v_ref/_v_fused) or otherwise consume them so
they are not reported as unused; target the occurrences of v_ref/v_fused at the
unpack sites reported (lines referenced in the review) to silence the
unused-unpack warning.
- Around line 14-19: The test uses a hardcoded SM80 architecture guard; instead
call the decorated function's capability helper to decide skipping: use
fused_qk_rmsnorm_rope.is_compute_capability_supported(cc) (or iterate available
device compute capability) to skip tests for unsupported GPUs instead of
checking major < 8 so that SM100/103/110+ are correctly recognized; update the
module-level guard in tests/norm/test_fused_qk_rmsnorm_rope.py to invoke
fused_qk_rmsnorm_rope.is_compute_capability_supported for the current device
compute capability.



Run ID: e1b67ff6-f7ef-42c6-9de5-45eb731cfc2a

📥 Commits

Reviewing files that changed from the base of the PR and between 2c92645 and 9ac8944.

📒 Files selected for processing (2)
  • flashinfer/norm/fused_qk_rmsnorm_rope.py
  • tests/norm/test_fused_qk_rmsnorm_rope.py
✅ Files skipped from review due to trivial changes (1)
  • flashinfer/norm/fused_qk_rmsnorm_rope.py

AI-assisted.

Made-with: Cursor
@kahyunnam (Member Author)

/bot run

@flashinfer-bot (Collaborator)

GitLab MR !583 has been updated with latest changes, and the CI pipeline #49299698 is currently running. I'll report back once the pipeline job completes.


@coderabbitai (bot) left a comment


🧹 Nitpick comments (1)
tests/norm/test_fused_qk_rmsnorm_rope.py (1)

187-202: Vectorize the NeoX reference cos/sin construction.

The triple-nested Python loop over batch × seq_len × head_dim with scalar torch.cos/torch.sin writes scales poorly — for the (2, 5, 12, 32) NEOX shape that is ~983k scalar CUDA ops per test invocation, which unnecessarily stretches CI time for this module. Since freq_per_elem and spatial_dim_per_elem are already per-element tensors and pos_t/pos_h/pos_w are deterministic functions of s, this can be built with a few broadcast ops.

♻️ Suggested vectorization
-    for b in range(batch_size):
-        for s in range(seq_len):
-            tok = s
-            pos_t = tok // (pph * ppw)
-            pos_x = tok % (pph * ppw)
-            pos_h = pos_x // ppw
-            pos_w = pos_x % ppw
-            pos_ids = torch.tensor(
-                [pos_t, pos_h, pos_w], dtype=torch.float64, device=device
-            )
-
-            for d in range(head_dim):
-                pos_id = pos_ids[spatial_dim_per_elem[d]]
-                theta = pos_id * freq_per_elem[d]
-                cos_out[b, s, 0, d] = torch.cos(theta)
-                sin_out[b, s, 0, d] = torch.sin(theta)
+    tok = torch.arange(seq_len, device=device, dtype=torch.float64)
+    pos_t = torch.div(tok, pph * ppw, rounding_mode="floor")
+    pos_x = tok % (pph * ppw)
+    pos_h = torch.div(pos_x, ppw, rounding_mode="floor")
+    pos_w = pos_x % ppw
+    pos_all = torch.stack([pos_t, pos_h, pos_w], dim=-1)  # [seq_len, 3]
+    pos_per_elem = pos_all[:, spatial_dim_per_elem]       # [seq_len, head_dim]
+    theta = pos_per_elem * freq_per_elem                  # [seq_len, head_dim]
+    cos_out[:] = torch.cos(theta).view(1, seq_len, 1, head_dim).expand_as(cos_out)
+    sin_out[:] = torch.sin(theta).view(1, seq_len, 1, head_dim).expand_as(sin_out)
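The tok → (pos_t, pos_h, pos_w) index math shared by the loop and the vectorized version can be checked in isolation with a pure-Python mirror (`decompose_token` is an illustrative name; pph/ppw are the patches-per-height/width parameters from the PR API):

```python
def decompose_token(tok: int, pph: int, ppw: int):
    """Split a flat token index into (frame, height, width) coordinates,
    matching the pos_t/pos_h/pos_w logic in the NeoX reference: frame
    advances every pph*ppw tokens, height every ppw tokens within a
    frame, width every token within a row."""
    pos_t, rem = divmod(tok, pph * ppw)
    pos_h, pos_w = divmod(rem, ppw)
    return pos_t, pos_h, pos_w
```

With the WAN-style defaults pph=12, ppw=32, token 384 (= 12*32) is the first token of the second frame.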
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 187 - 202, The
triple-nested loops computing cos_out and sin_out are scalar and slow; replace
them by vectorized tensor ops: build tensor pos_ids for all sequence positions
(compute pos_t,pos_h,pos_w for s=0..seq_len-1) on device, expand to shape
(batch_size, seq_len, 3) and use spatial_dim_per_elem to index/gather the
appropriate positional component per head element, multiply by freq_per_elem
(broadcasted across batch and seq) to get theta, then compute torch.cos(theta)
and torch.sin(theta) and assign the results into cos_out[...,0,:] and
sin_out[...,0,:]; use the existing symbols cos_out, sin_out, freq_per_elem,
spatial_dim_per_elem, pph, ppw, batch_size, seq_len, head_dim and device to
locate and replace the loop.


Run ID: 81de0d9c-0dc2-42e4-afd9-61d0b24e98be

📥 Commits

Reviewing files that changed from the base of the PR and between 9ac8944 and 7de01bb.

📒 Files selected for processing (3)
  • flashinfer/norm/__init__.py
  • flashinfer/norm/fused_qk_rmsnorm_rope.py
  • tests/norm/test_fused_qk_rmsnorm_rope.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • flashinfer/norm/__init__.py
  • flashinfer/norm/fused_qk_rmsnorm_rope.py

Register fused_qk_rmsnorm_rope in the norm benchmark family:
- Add routine name to benchmark_apis["norm"]
- Add CC-to-backends mapping (SM80+ with cuda backend)
- Add ppf/pph/ppw CSV output columns
- Add --ppf/--pph/--ppw CLI args to parse_norm_args
- Implement testFusedQkRmsnormRope with bandwidth metrics and optional refcheck

Usage:
  python benchmarks/flashinfer_benchmark.py \
    --routine fused_qk_rmsnorm_rope \
    --batch_size 1 --hidden_size 3072 --num_heads 24 \
    --ppf 5 --pph 12 --ppw 32

AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam changed the title feat: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope feat: RMSNorm + RoPe fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope Apr 23, 2026
Merge the separate fused_qk_rmsnorm_rope.py into __init__.py for
consistency with all other norm functions (rmsnorm, gemma_rmsnorm,
fused_add_rmsnorm, fused_rmsnorm_silu, etc.) which are all defined
directly in __init__.py.

Delete the separate file. No behavior change — 28 pass, 1 xfail.

AI-assisted.

Made-with: Cursor
Consistent with existing norm headers (ln_fwd_silu_kernel.cuh,
ln_silu_headers.cuh) already in the norm/ subdirectory.

AI-assisted.

Made-with: Cursor
@kahyunnam kahyunnam changed the title feat: RMSNorm + RoPe fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope feat: RMSNorm + RoPE fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope Apr 23, 2026