feat: RMSNorm + RoPE fusion for WAN: flashinfer.diffusion_ops.fused_qk_rmsnorm_rope #3148
kahyunnam wants to merge 35 commits into flashinfer-ai:main
Conversation
Single consolidated header (include/flashinfer/fused_qk_norm_rope.cuh) containing the CUDA kernel and all its dependencies:
- IntFastDiv (signed fast integer division, from TRT-LLM)
- packed_as vector type mapping
- FP8 E4M3 quantization helpers with SM89+ PTX fast paths
- Blackwell FFMA2 intrinsics with SM<100 scalar fallbacks
- fusedQKNormRopeKernel template (across-heads RMSNorm + 3D RoPE + V copy)
- Host launcher with frequency table cache and head_dim dispatch

Adapted from share-qknormrope-layernorm-fusion/fused_QKNorm_RoPE/. AI-assisted port to FlashInfer conventions (namespace, error macros).
Made-with: Cursor
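The per-head RMSNorm that the kernel fuses with RoPE can be modeled in a few lines of plain Python (a reference sketch, not the CUDA implementation; function names are illustrative, and normalization over each head's head_dim vector with a small eps is assumed):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    # Reference RMSNorm over one head's head_dim vector: scale x by the
    # reciprocal root-mean-square, then by the learned per-channel weight.
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

# With unit weights the output keeps its direction and its
# root-mean-square is normalized to ~1.
out = rmsnorm([1.0, 2.0, 3.0, 4.0], [1.0] * 4)
```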
csrc/fused_qk_norm_rope.cu: TVM-FFI launcher that bridges TensorView to the raw-pointer kernel. Validates inputs (CUDA, contiguous, BF16), obtains stream and num_sms internally, delegates to launchFusedQKNormRope.
csrc/flashinfer_fused_qk_norm_rope_binding.cu: TVM-FFI export via TVM_FFI_DLL_EXPORT_TYPED_FUNC(run, ...).
AI-assisted. Made-with: Cursor
Move the launcher from separate csrc/fused_qk_norm_rope.cu into csrc/norm.cu and add the TVM-FFI export to flashinfer_norm_binding.cu. Delete the standalone files and JIT generator since the kernel now compiles as part of gen_norm_module(). This keeps all norm operations in a single JIT module, matching the existing pattern where rmsnorm, gemma_rmsnorm, layernorm etc. share one compiled .so. Future work will similarly merge rmsnorm_silu. AI-assisted. Made-with: Cursor
flashinfer/norm/fused_qk_norm_rope.py: Public API function with @flashinfer_api and @backend_requirement decorators. Uses the shared get_norm_module() to call the fused_qk_norm_rope TVM-FFI export.

Comprehensive input validation: dtype (BF16), shape (3D), head_dim (64/128/256), max heads (<=32), channel sum and evenness, spatial dims positivity, seq_len consistency. Destination-passing style for q_out/k_out/v_out.

AI-assisted. Made-with: Cursor
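The validation gates listed above can be distilled into a plain-Python guard (a hypothetical sketch; the real API validates tensors, but the dispatch logic has this shape):

```python
def validate_inputs(ndim, head_dim, num_heads):
    # Mirror of the documented gates: 3D input, head_dim in {64, 128, 256},
    # and at most 32 heads per tensor.
    if ndim != 3:
        raise ValueError(f"expected 3D [batch, seq_len, hidden] input, got {ndim}D")
    if head_dim not in (64, 128, 256):
        raise ValueError(f"head_dim must be 64, 128 or 256, got {head_dim}")
    if num_heads > 32:
        raise ValueError(f"at most 32 heads are supported, got {num_heads}")

validate_inputs(3, 128, 12)  # a valid WAN-like config passes silently
```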
- flashinfer/norm/__init__.py: re-export fused_qk_norm_rope, add to __all__
- flashinfer/video_gen_ops/__init__.py: re-export facade (like dsv3_ops/) so users can `from flashinfer.video_gen_ops import fused_qk_norm_rope`
AI-assisted. Made-with: Cursor
tests/norm/test_fused_qk_norm_rope.py: 17 tests covering:
- Interleaved RoPE correctness (4 shapes) — all pass
- NeoX RoPE correctness (2 shapes) — xfail, needs kernel-side validation
- V passthrough (exact BF16 copy)
- Destination-passing style
- FP8 E4M3 output (3 scale values)
- RoPE-only mode (is_qk_norm=False)
- Error cases (non-CUDA, wrong dtype, bad head_dim, channel mismatch, seq_len)

Fix: Move packed_as into fused_rope_detail sub-namespace to avoid collision with tensorrt_llm::common::packed_as visible through norm.cuh.
AI-assisted. Made-with: Cursor
- Expanded interleaved shapes from 4 to 8 (matching original test suite)
- NeoX tests remain xfail: reference freq mapping doesn't match kernel's (dim_idx * 2) & mask convention — needs kernel author input
- Added NeoX-specific embedding helper (non-interleaved freq layout)
- Relaxed RoPE-only tolerance from 0.01 to 0.05 for BF16 precision
- Validated on all 3 available architectures:
  - L40S (SM89, Ada): 19 passed, 3 xfailed
  - H100 NVL (SM90, Hopper): 19 passed, 3 xfailed
  - A100 (SM80, Ampere): 19 passed, 3 xfailed
AI-assisted. Made-with: Cursor
Bug: In the NeoX (non-interleaved) RoPE path, pos_id was computed from dim_idx_x's mapped value but reused for dim_idx_y without recomputation. Since adjacent elements in a float2 pair can map to different spatial slices (e.g. x->height, y->width), the y component got the wrong position ID, causing incorrect RoPE rotations.

Fix: Recompute pos_id from dim_idx_y's mapped value before computing theta_y (one-line addition).

Tests: NeoX tests now pass (were xfail). 22/22 tests pass on all three architectures (L40S SM89, H100 SM90, A100 SM80). Also added a proper NeoX reference implementation that matches the kernel's per-element frequency mapping via (dim_idx * 2) & ((1 << log_head_dim) - 1).

AI-assisted. Made-with: Cursor
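The mapping and the bug are easy to model in pure Python. The (dim_idx * 2) & mask convention is from the commit message; the spatial split values and helper names below are illustrative:

```python
def neox_mapped_dim(dim_idx, log_head_dim):
    # Kernel convention: map a NeoX element index to its
    # interleaved-equivalent index.
    return (dim_idx * 2) & ((1 << log_head_dim) - 1)

def spatial_axis(mapped_dim, splits):
    # Illustrative: which 3D position component (t/h/w) a mapped
    # index falls into, given per-axis channel widths.
    acc = 0
    for axis, width in enumerate(splits):
        acc += width
        if mapped_dim < acc:
            return axis
    raise ValueError("mapped_dim out of range")

log_head_dim = 7       # head_dim = 128
splits = (44, 42, 42)  # illustrative t/h/w channel split
# Adjacent NeoX indices in one float2 pair can land on different axes,
# so pos_id must be recomputed from y's mapped value, not reused from x.
ax = spatial_axis(neox_mapped_dim(21, log_head_dim), splits)  # 42 -> axis 0
ay = spatial_axis(neox_mapped_dim(22, log_head_dim), splits)  # 44 -> axis 1
```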
Test configs sourced from official WAN repos:
- wan2.1-1.3B: dim=1536, num_heads=12, head_dim=128 — passes
- wan2.2-5B: dim=3072, num_heads=24, head_dim=128 — passes
- wan2.1-14B: dim=5120, num_heads=40, head_dim=128 — xfail (40 heads * 32 threads = 1280 > CUDA max 1024 threads/block)

Also updated the NeoX reference implementation and removed xfail now that the kernel's pos_id bug is fixed.
AI-assisted. Made-with: Cursor
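The 14B xfail is pure arithmetic given the thread mapping stated above (32 threads per head, one block covering all heads of a token, which is the implied layout):

```python
THREADS_PER_HEAD = 32            # kernel's per-head thread group
CUDA_MAX_THREADS_PER_BLOCK = 1024

def fits_in_one_block(num_heads):
    # A block covers all heads of a token, so the 14B config
    # (40 heads) exceeds the hardware block-size limit.
    return num_heads * THREADS_PER_HEAD <= CUDA_MAX_THREADS_PER_BLOCK
```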
benchmarks/bench_fused_qk_norm_rope.py: Compares fused kernel vs eager PyTorch (nn.RMSNorm + manual RoPE) across 8 WAN model shapes.

H100 NVL CUPTI results (WAN 2.2 5B config):
- Production shape (1920 tokens): 5.36x speedup
- Small (480 tokens): 22.3x speedup
- Large (7680 tokens): 3.62x speedup
- Tiny (120 tokens): 64x speedup (launch-overhead dominated)
AI-assisted. Made-with: Cursor
The kernel operates on flattened [num_tokens, hidden] layout. Now the Python API accepts both:
- 3D [batch, seq_len, hidden] -> outputs [batch, seq_len, heads, head_dim]
- 2D [num_tokens, hidden] -> outputs [num_tokens, heads, head_dim]

For 2D input, num_tokens must be divisible by ppf*pph*ppw. A test verifies 2D and 3D produce identical results.
AI-assisted. Made-with: Cursor
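The shape handling can be sketched as a hypothetical helper (the real API works on tensors, but the reduction to the kernel's flattened layout follows this logic):

```python
def to_flat_tokens(shape, ppf, pph, ppw):
    # 3D [batch, seq_len, hidden] and 2D [num_tokens, hidden] both
    # reduce to the kernel's (num_tokens, hidden) layout.
    if len(shape) == 3:
        batch, seq_len, hidden = shape
        return batch * seq_len, hidden
    if len(shape) == 2:
        num_tokens, hidden = shape
        if num_tokens % (ppf * pph * ppw) != 0:
            raise ValueError("num_tokens must be divisible by ppf*pph*ppw")
        return num_tokens, hidden
    raise ValueError("expected 2D or 3D input")
```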
- Remove internal provenance details from comments - List SM90, SM100, SM103 as separate primary target architectures - Generalize packed_as namespace comment AI-assisted. Made-with: Cursor
- Replace 13 per-test `from flashinfer.norm import ...` with one top-level `from flashinfer.video_gen_ops import fused_qk_norm_rope` - Tests now exercise the public user-facing API path - Also removed internal reference in module docstring AI-assisted. Made-with: Cursor
AI-assisted. Made-with: Cursor
- Remove unused imports (math, os, gen_norm_module)
- Add tuple[int, ...] annotations to fix mypy variable-length tuple error
- Apply clang-format to C++/CUDA files
- Apply ruff format to Python files

All pre-commit checks pass.
AI-assisted. Made-with: Cursor
The kernel targets DIT (Diffusion Transformer) self-attention which is used in both video and image diffusion models. diffusion_ops is a better name that won't need renaming if image diffusion kernels are added later. AI-assisted. Made-with: Cursor
Observed max diffs scale as ~0.5 * output_scale (FP8 quantization boundary rounding). Previous tolerance was max(1.0*scale, 0.5) giving 2x headroom. Tightened to max(0.75*scale, 0.375) for 1.5x headroom, with a comment explaining the observed error pattern. AI-assisted. Made-with: Cursor
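Written out, the tightened tolerance is a one-liner (mirroring the formula stated above):

```python
def fp8_tolerance(output_scale):
    # Observed max diff scales as ~0.5 * output_scale (FP8 quantization
    # boundary rounding); allow 1.5x headroom over that.
    return max(0.75 * output_scale, 0.375)
```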
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review.
📝 Walkthrough

Adds a fused CUDA operator and full-stack bindings: a kernel and host launcher, C++ entry and TVM FFI export, a Python wrapper and re-exports, comprehensive tests, and a benchmark. The operator performs RMSNorm on Q/K, applies 3D RoPE (interleaved or NeoX), optional YARN scaling and FP8 output, and copies V through.
Sequence Diagram

sequenceDiagram
  participant Py as Python API
  participant TVM as TVM FFI
  participant Entry as C++ Entry
  participant Host as Host Launcher
  participant Kernel as CUDA Kernel
  participant Dev as Device Memory
  Py->>Py: validate tensors, shapes, RoPE params
  Py->>TVM: call fused_qk_rmsnorm_rope(...)
  TVM->>Entry: fused_qk_rmsnorm_rope_run(...)
  Entry->>Entry: dtype/contiguity checks, select device & stream
  Entry->>Host: launchFusedQKNormRope(...)
  Host->>Kernel: dispatch kernel instantiation (head_dim, interleave, FP8, YARN)
  par Kernel Ops
    Kernel->>Dev: load QKV and weights
    Kernel->>Kernel: apply RMSNorm on Q/K (optional)
    Kernel->>Kernel: apply 3D RoPE (interleaved or NeoX)
    Kernel->>Kernel: optional YARN scaling & FP8 quant+pack
    Kernel->>Dev: write q_out, k_out, v_out
  end
  Dev->>Py: return or fill output tensors
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Code Review
This pull request introduces a fused QKNorm + 3D RoPE kernel optimized for video generation DIT architectures, providing the CUDA implementation, Python API, benchmarks, and tests. Feedback identifies critical issues regarding potential integer overflows in num_tokens and baseOffset calculations for large sequences. Additionally, the global static frequency cache is not multi-GPU safe, and the use of cudaFree within the launch path introduces performance-degrading device synchronization that prevents CUDA Graph capture.
The kernel performs RMSNorm specifically (not LayerNorm or generic norm). Rename to fused_qk_rmsnorm_rope for consistency with FlashInfer's existing naming convention (rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, fused_rmsnorm_silu). All files, imports, symbols, and docstrings updated. Internal kernel function names (fusedQKNormRopeKernel, launchFusedQKNormRope) kept as-is since they are not part of the public API. 25 passed, 1 xfail after rename. AI-assisted. Made-with: Cursor
Change num_tokens from int to int64_t in the kernel signature, launcher, and all offset calculations (baseOffset, v_output_offset, inputOffset, outputBase/outputOffset, quantize_store_fp8 offset parameter). This prevents integer overflow when batch_size * seq_len * num_heads * head_dim exceeds 2^31 (e.g. batch=64 with seq_len=7680 and 24 heads). AI-assisted. Made-with: Cursor
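A quick check of the motivating example shows the overflow. The packed-QKV layout assumed here (Q, K and V heads sharing one hidden dimension, 24 heads each) is an illustration, not a statement of the exact kernel layout:

```python
INT32_MAX = 2**31 - 1

def qkv_elems(batch, seq_len, num_heads, head_dim):
    # Packed QKV input: three tensors of num_heads each share one hidden dim,
    # so input offsets index up to batch * seq_len * 3 * num_heads * head_dim.
    hidden = 3 * num_heads * head_dim
    return batch * seq_len * hidden

n = qkv_elems(64, 7680, 24, 128)  # element count reachable by input offsets
```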
Replace the single global s_freq_cache with a per-device std::unordered_map<int, FreqCacheEntry> keyed by CUDA device ID. This prevents crashes when the kernel is called on different GPUs within the same process (the old single cache held a device pointer that was only valid on the GPU that allocated it). AI-assisted. Made-with: Cursor
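The per-device cache structure can be sketched in Python (names mirror the commit message; the device pointer is an opaque placeholder here):

```python
class FreqCacheEntry:
    # Hypothetical mirror of the C++ struct: a device allocation plus
    # its capacity, valid only on the device that allocated it.
    def __init__(self):
        self.d_ptr = None
        self.alloc_floats = 0

_freq_cache = {}  # keyed by CUDA device ID

def get_cache(device_id):
    # One entry per device: a pointer cached for device 0 is never
    # handed to a kernel launched on device 1.
    return _freq_cache.setdefault(device_id, FreqCacheEntry())
```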
AI-assisted. Made-with: Cursor
Actionable comments posted: 5
🧹 Nitpick comments (4)
include/flashinfer/fused_qk_rmsnorm_rope.cuh (1)
33-40: abort() on validation failure is hostile to the Python host.

FLASHINFER_FUSED_CHECK calls abort(), which tears down the entire process — including a Python interpreter — on any kernel-side precondition failure (e.g. unsupported head_dim, factor != 1 with attention_factor != 1). These are user-input conditions that should be recoverable. Consider throwing std::runtime_error instead (or returning cudaError_t) so the FFI layer can translate it into a Python exception.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh` around lines 33 - 40, The macro FLASHINFER_FUSED_CHECK currently calls abort() (and fprintf) on failures which kills the Python process; change FLASHINFER_FUSED_CHECK to report failures via exceptions or error codes instead: replace the abort() path with throwing a std::runtime_error containing the file/line/condition message (and include <stdexcept>) or return a cudaError_t/other error value so the FFI layer can convert it to a Python exception; update all call sites in fused_qk_rmsnorm_rope.cuh that rely on FLASHINFER_FUSED_CHECK to handle the thrown exception or error return appropriately.

tests/norm/test_fused_qk_rmsnorm_rope.py (2)
312-433: Minor: unused unpacked variables flagged by Ruff.
v_ref (lines 312, 391) and v_fused (lines 325, 404) are never referenced. Rename them to _ (or _v_ref/_v_fused) to satisfy RUF059 and make the intent clear.

♻️ Proposed fix

- q_ref, k_ref, v_ref = reference_qk_norm_rope(
+ q_ref, k_ref, _ = reference_qk_norm_rope(
      ...
  )
  qkv_combined = torch.cat([query, key, value], dim=-1).contiguous()
- q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope(
+ q_fused, k_fused, _ = fused_qk_rmsnorm_rope(
      ...
  )

Apply in both test_interleaved_correctness and test_neox_correctness.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 312 - 433, The test unpacks unused variables v_ref and v_fused in the calls to reference_qk_norm_rope and fused_qk_rmsnorm_rope inside test_interleaved_correctness and test_neox_correctness; rename those two unused unpacked variables to _ (or _v_ref/_v_fused) in the assignments (e.g., change "q_ref, k_ref, v_ref = reference_qk_norm_rope(...)" and "q_fused, k_fused, v_fused = fused_qk_rmsnorm_rope(...)" to use "_" for the third element) so Ruff RUF059 is satisfied while leaving q_ref, k_ref, q_fused, and k_fused unchanged.
285-288: Consider gating tests to supported compute capabilities.
fused_qk_rmsnorm_rope is decorated with supported_compute_capability([80, 86, 89, 90, 100, 103, 110, 120, 121]). On an unsupported device the API will raise instead of skipping, causing spurious test failures. Add a top-level skip (e.g. pytest.skip guarded via fused_qk_rmsnorm_rope.is_compute_capability_supported(cc)) so the suite cleanly skips on unsupported GPUs.

As per coding guidelines: "Skip test execution on unsupported GPU architectures using flashinfer.utils check functions ... or API methods like api_name.is_compute_capability_supported(cc)".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 285 - 288, The test should early-skip on unsupported GPU architectures: inside test_interleaved_correctness (before using the device/dtype), query the current device compute capability (e.g., via torch.cuda.get_device_capability(device) or equivalent) and call fused_qk_rmsnorm_rope.is_compute_capability_supported(cc); if it returns False, call pytest.skip with a clear message. Ensure you reference the fused_qk_rmsnorm_rope.is_compute_capability_supported check and use pytest.skip to avoid raising on unsupported GPUs.

benchmarks/bench_fused_qk_rmsnorm_rope.py (1)
1-211: Consider integrating with the unified benchmarking framework.

This benchmark is implemented as a standalone script. Per project convention, new kernel benchmarks should plug into benchmarks/flashinfer_benchmark.py so they share CLI options, output format, and tracking. CUPTI timing via bench_gpu_time is already correct; only the harness integration is missing.

As per coding guidelines: "Use the unified benchmarking framework in benchmarks/flashinfer_benchmark.py for kernel benchmarking with CUPTI timing support".
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/fused_qk_rmsnorm_rope.py`:
- Line 51: Remove the stray non-ASCII character at the end of the module
docstring sentence "All SM100+/SM89+ features have scalar fallbacks, so SM80 is
the true minimum." in fused_qk_rmsnorm_rope.py (the trailing 'å'); update the
docstring text to end with a normal ASCII period and ensure the docstring
surrounding that sentence (module-level docstring or the docstring in any
function/class near that line) has no other unintended non-ASCII characters.
- Line 33: Remove the unconditional "from . import get_norm_module" import and
add a local accessor function named _get_norm_module inside
fused_qk_rmsnorm_rope.py that imports the package at runtime (e.g. import
importlib; pkg = importlib.import_module("flashinfer.norm")) and returns
pkg.get_norm_module() if that attribute exists, otherwise raise a clear error;
then replace calls to get_norm_module().fused_qk_rmsnorm_rope(...) with
_get_norm_module().fused_qk_rmsnorm_rope(...), referencing the existing
fused_qk_rmsnorm_rope call site.
In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh`:
- Around line 672-678: The CI failure is due to clang-format changes around the
cache allocation block and a LAUNCH_KERNEL macro expansion; run the formatter
and commit the formatted file: run `pre-commit run clang-format --files
include/flashinfer/fused_qk_rmsnorm_rope.cuh` (or run your project's
clang-format) to reformat the code touching the cache.d_ptr allocation block
(the cudaFree/cudaMalloc/cache.alloc_floats section) and the LAUNCH_KERNEL
usage, then add and commit the resulting changes so the pre-commit job passes.
- Around line 659-715: The CUDA calls in the cache setup (cudaGetDevice,
cudaFree, cudaMalloc, cudaMemcpy) are unchecked and the host-to-device copy is
synchronous; update the logic around s_freq_cache_map/ FreqCacheEntry to CHECK
each cuda runtime return (use FLASHINFER_FUSED_CHECK or propagate cudaError_t)
after cudaGetDevice, cudaFree, and cudaMalloc and ensure cache.d_ptr is valid
before using it, and replace the blocking cudaMemcpy with cudaMemcpyAsync on the
caller's stream (thread/stream should be accepted or obtained) so the frequency
table upload (to cache.d_ptr) can be captured in CUDA graphs and won't write to
nullptr; keep the existing cache_miss/table_size logic and still set cache.*
fields only after successful async copy completion or error-checked launch.
- Around line 631-715: Change the header-scope cache to have single-definition
linkage (replace the static unordered_map declaration with inline or move it to
a .cu with an extern here), add a std::mutex (e.g., s_freq_cache_mutex) and wrap
accesses to s_freq_cache_map[device_id] and the entire cache-miss
read-modify-write branch with std::lock_guard to prevent races, replace raw cuda
API calls (cudaGetDevice/cudaMalloc/cudaFree/cudaMemcpy) with checked calls
using TVM_FFI_ICHECK(...) on the returned status, use cudaMemcpyAsync(...,
stream) instead of synchronous cudaMemcpy to honor the caller's stream, and stop
calling cudaFree on the cached d_ptr to preserve the "allocated once and never
freed" cudagraph-compatible behavior (only allocate with cudaMalloc when d_ptr
is null or alloc_floats is too small).
---
Nitpick comments:
In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh`:
- Around line 33-40: The macro FLASHINFER_FUSED_CHECK currently calls abort()
(and fprintf) on failures which kills the Python process; change
FLASHINFER_FUSED_CHECK to report failures via exceptions or error codes instead:
replace the abort() path with throwing a std::runtime_error containing the
file/line/condition message (and include <stdexcept>) or return a
cudaError_t/other error value so the FFI layer can convert it to a Python
exception; update all call sites in fused_qk_rmsnorm_rope.cuh that rely on
FLASHINFER_FUSED_CHECK to handle the thrown exception or error return
appropriately.
In `@tests/norm/test_fused_qk_rmsnorm_rope.py`:
- Around line 312-433: The test unpacks unused variables v_ref and v_fused in
the calls to reference_qk_norm_rope and fused_qk_rmsnorm_rope inside
test_interleaved_correctness and test_neox_correctness; rename those two unused
unpacked variables to _ (or _v_ref/_v_fused) in the assignments (e.g., change
"q_ref, k_ref, v_ref = reference_qk_norm_rope(...)" and "q_fused, k_fused,
v_fused = fused_qk_rmsnorm_rope(...)" to use "_" for the third element) so Ruff
RUF059 is satisfied while leaving q_ref, k_ref, q_fused, and k_fused unchanged.
- Around line 285-288: The test should early-skip on unsupported GPU
architectures: inside test_interleaved_correctness (before using the
device/dtype), query the current device compute capability (e.g., via
torch.cuda.get_device_capability(device) or equivalent) and call
fused_qk_rmsnorm_rope.is_compute_capability_supported(cc); if it returns False,
call pytest.skip with a clear message. Ensure you reference the
fused_qk_rmsnorm_rope.is_compute_capability_supported check and use pytest.skip
to avoid raising on unsupported GPUs.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 010f0d49-9914-43ac-9445-05a33dddf336
📒 Files selected for processing (8)
benchmarks/bench_fused_qk_rmsnorm_rope.py
csrc/flashinfer_norm_binding.cu
csrc/norm.cu
flashinfer/diffusion_ops/__init__.py
flashinfer/norm/__init__.py
flashinfer/norm/fused_qk_rmsnorm_rope.py
include/flashinfer/fused_qk_rmsnorm_rope.cuh
tests/norm/test_fused_qk_rmsnorm_rope.py
AI-assisted. Made-with: Cursor
Move `from . import get_norm_module` from module top-level to inside the function body. get_norm_module is only defined in norm/__init__.py when _USE_CUDA_NORM is True; lazy import avoids failure if the CuTe DSL path is active at import time. AI-assisted. Made-with: Cursor
AI-assisted. Made-with: Cursor
Actionable comments posted: 1
📒 Files selected for processing (1)
include/flashinfer/fused_qk_rmsnorm_rope.cuh
Remove cudaFree call on the old d_ptr when growing the allocation.
The old pointer is intentionally leaked to remain valid for any
cudagraph that captured a kernel referencing it. This aligns with the
stated design intent ("allocated once and never freed").
Also added comment explaining why cudaMemcpy is synchronous (source
is a stack-allocated array that goes out of scope after the call).
AI-assisted.
Made-with: Cursor
- Add FLASHINFER_FUSED_CHECK on cudaGetDevice and cudaMalloc to catch errors early instead of silently dereferencing nullptr - Switch cudaMemcpy to cudaMemcpyAsync on the caller's stream (CUDA spec: async from unpinned host memory is synchronous, so stack- allocated h_freq_table is safe; using async form for cudagraph capture compatibility) AI-assisted. Made-with: Cursor
- Assert num_tokens > 0 and <= INT32_MAX before narrowing to dim3 - Add cudaGetLastError check after kernel launch to surface cudaErrorInvalidConfiguration or other launch failures AI-assisted. Made-with: Cursor
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (1)
include/flashinfer/fused_qk_rmsnorm_rope.cuh (1)
644-678: ⚠️ Potential issue | 🟠 Major

Remaining gaps from prior review: header-scope static cache + missing mutex.

The CUDA error-checking, async memcpy on caller's stream, and grid-dim/launch-error guards are all in now — nice. Two items from the earlier critical review still stand:

Line 646: s_freq_cache_map is static at namespace scope in a header → internal linkage, one copy per TU that includes this file. It is currently safe only because csrc/norm.cu is the sole includer; any future inclusion will silently create a parallel cache (and, with the deliberate leak-old-pointer policy at L672-676, leak independently). Prefer inline (C++17) or move the definition into a .cu with an extern declaration here.

Lines 667, 671-678: Access to s_freq_cache_map[device_id] and the RMW of the FreqCacheEntry (cudaMalloc + field updates at L674-677) aren't synchronized. If the FFI runtime ever releases the GIL around the launcher (or different Python threads drive the same device), you get an unordered_map data race and a torn write of cache.d_ptr observable by a concurrent kernel launch. A std::mutex around the cache-miss region (and the operator[] lookup) is sufficient.

🛠️ Suggested shape

-// Per-device frequency table cache. Keyed by CUDA device ID so that
-// multi-GPU usage within a single process is safe.
-static std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;
+// Per-device frequency table cache. Keyed by CUDA device ID so that
+// multi-GPU usage within a single process is safe.
+inline std::unordered_map<int, FreqCacheEntry> s_freq_cache_map;
+inline std::mutex s_freq_cache_mutex;

and wrap the lookup + cache-miss RMW region in a std::lock_guard<std::mutex> (plus #include <mutex>).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh` around lines 644 - 678, s_freq_cache_map is declared static at header scope causing one copy per TU and racey concurrent access to s_freq_cache_map[device_id] and FreqCacheEntry RMW in launchFusedQKNormRope; change the declaration to inline std::unordered_map<int, FreqCacheEntry> s_freq_cache_map (or move the definition into a .cu with an extern in the header) and add a global std::mutex (e.g., s_freq_cache_mutex) with `#include` <mutex>, then wrap the operator[] lookup and the cache-miss branch that does cudaMalloc and updates cache.d_ptr/cache.alloc_floats in a std::lock_guard<std::mutex> to serialize map access and the pointer/size write.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/norm/fused_qk_rmsnorm_rope.py`:
- Around line 244-254: The pre-allocated destination buffers and weight tensors
are not validated before being .view()'d and passed to the kernel; add checks to
ensure user-supplied q_out/k_out/v_out and q_weight/k_weight match expected
dtype (respecting output_fp8), device (same as qkv.device), shape
(q_out/k_out/v_out shape equals out_shape_q/out_shape_k/out_shape_v and
q_weight/k_weight length equals num_heads_* * head_dim), and are contiguous
(torch.is_contiguous) to prevent silent memory corruption; implement these
validations either by extending _check_fused_qk_rmsnorm_rope to accept
kwargs.get("q_out","k_out","v_out","q_weight","k_weight") or add an inline
validation block before creating qkv_flat/q_out_flat/k_out_flat/v_out_flat that
raises a clear ValueError if any check fails.
---
Duplicate comments:
In `@include/flashinfer/fused_qk_rmsnorm_rope.cuh`:
- Around line 644-678: s_freq_cache_map is declared static at header scope
causing one copy per TU and racey concurrent access to
s_freq_cache_map[device_id] and FreqCacheEntry RMW in launchFusedQKNormRope;
change the declaration to inline std::unordered_map<int, FreqCacheEntry>
s_freq_cache_map (or move the definition into a .cu with an extern in the
header) and add a global std::mutex (e.g., s_freq_cache_mutex) with `#include`
<mutex>, then wrap the operator[] lookup and the cache-miss branch that does
cudaMalloc and updates cache.d_ptr/cache.alloc_floats in a
std::lock_guard<std::mutex> to serialize map access and the pointer/size write.
📒 Files selected for processing (2)
flashinfer/norm/fused_qk_rmsnorm_rope.py
include/flashinfer/fused_qk_rmsnorm_rope.cuh
Add inline checks before passing to the kernel:
- q_weight/k_weight: size matches num_heads*head_dim, dtype is BF16, contiguous
- q_out/k_out/v_out (when user-supplied): shape, dtype (BF16 or FP8 matching output_fp8 flag), contiguity, and device match

Prevents silent memory corruption from mismatched buffers.
AI-assisted. Made-with: Cursor
- test_error_wrong_weight_size: wrong q_weight length
- test_error_wrong_output_shape: q_out with wrong num_heads
- test_error_wrong_output_dtype: q_out with float16 instead of bfloat16

Exercises the validation added in the previous commit.
AI-assisted. Made-with: Cursor
get_norm_module was only defined when _USE_CUDA_NORM was True, causing ImportError in CI environments where nvidia-cutlass-dsl is installed (CuTe DSL path active). The fused_qk_rmsnorm_rope kernel has no CuTe DSL alternative and always needs the CUDA JIT module. Move get_norm_module out of the conditional block so it's always available. Existing norm functions still guard calls with `if _USE_CUDA_NORM:` so behavior is unchanged for them. AI-assisted. Made-with: Cursor
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/norm/test_fused_qk_rmsnorm_rope.py (2)
480-481: test_v_passthrough uses is_qk_norm=True with RMSNorm weights=1 — that's fine for V, but misleading for documenting "exact copy".

With is_qk_norm=True the kernel still runs RMSNorm on Q/K; V is passthrough regardless of is_qk_norm. The test name and comment suggest verifying a bit-exact V copy (which it does), but it would be slightly more robust to also cover is_qk_norm=False here so the passthrough guarantee is asserted on both code paths.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 480 - 481, The test test_v_passthrough currently asserts V is an exact copy only under is_qk_norm=True; update the test to also run the same assertion with is_qk_norm=False so the passthrough behavior is validated on both code paths. Locate the test_v_passthrough function and add a second sub-case (or parametrize the test) that calls the same kernel/setup with is_qk_norm=False, computes v_fused and v_expected, and asserts torch.equal(v_fused, v_expected) for that case as well; keep existing checks for is_qk_norm=True unchanged.
187-204: Vectorize NeoX reference to avoid triple-nested Python loops.

The per-element Python loop over (batch, seq_len, head_dim) is O(B·S·D) host-side scalar work and will dominate test time if NEOX_SHAPES grows. Current shapes keep this tractable, but the loop can be replaced with broadcasted tensor ops using the precomputed freq_per_elem/spatial_dim_per_elem vectors, e.g. gather pos_ids[..., spatial_dim_per_elem] and multiply by freq_per_elem, then cos/sin in one shot.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/norm/test_fused_qk_rmsnorm_rope.py` around lines 187 - 204, Replace the triple-nested loops that compute cos_out and sin_out by fully vectorized tensor ops: build a pos_ids tensor for all tokens (shape [batch_size, seq_len, 3]) using the same pos_t/pos_h/pos_w logic (vectorize tok→pos using arange and div/mod with pph and ppw), then gather the per-head spatial coordinate with spatial_dim_per_elem (use torch.tensor(spatial_dim_per_elem, device=...) and torch.take_along_dim or torch.gather after expanding pos_ids) to produce pos_per_head shape [batch_size, seq_len, head_dim]; multiply that by freq_per_elem (broadcasted) to get theta, compute cos/sin in one call, and assign to cos_out[...,0,:] and sin_out[...,0,:] converted to dtype; update uses of freq_per_elem, spatial_dim_per_elem, cos_out, sin_out, batch_size, seq_len, head_dim, pph, ppw, device, dtype accordingly.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/norm/test_fused_qk_rmsnorm_rope.py`:
- Around line 980-1002: The failing CI is caused by a too-long inline comment in
test_error_wrong_output_shape and by unused-unpack names flagged as RUF059; in
tests modify the long inline comment in fused_qk_rmsnorm_rope's
test_error_wrong_output_shape (replace or wrap the "12 heads, not 24" inline
comment so it doesn't exceed line length) and in test_interleaved_correctness
and test_neox_correctness replace the unused unpack targets v_ref and v_fused
with throwaway names (e.g., _ or _v_ref/_v_fused) or otherwise consume them so
they are not reported as unused; target the occurrences of v_ref/v_fused at the
unpack sites reported (lines referenced in the review) to silence the
unused-unpack warning.
- Around lines 14-19: The test uses a hardcoded SM80 architecture guard.
Instead, call the decorated function's capability helper to decide skipping:
use fused_qk_rmsnorm_rope.is_compute_capability_supported(cc) with the current
device's compute capability, rather than checking major < 8, so that
SM100/103/110+ GPUs are correctly recognized. Update the module-level guard in
tests/norm/test_fused_qk_rmsnorm_rope.py accordingly.
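A minimal pure-Python sketch of the suggested guard shape. Here `is_supported` is a stand-in for `fused_qk_rmsnorm_rope.is_compute_capability_supported`, and the capability set below is invented for illustration only:

```python
# Hypothetical sketch of the review's suggestion: decide test skipping via
# the op's own capability predicate instead of a hardcoded `major < 8` check.
def should_skip(major: int, minor: int, is_supported) -> bool:
    cc = major * 10 + minor  # e.g. (10, 0) -> SM100
    return not is_supported(cc)

# Invented capability table, for illustration only.
supported = {80, 86, 89, 90, 100, 103, 110}

assert should_skip(7, 5, supported.__contains__) is True    # SM75: skip
assert should_skip(8, 0, supported.__contains__) is False   # SM80: run
assert should_skip(10, 0, supported.__contains__) is False  # SM100: run
```

Delegating to a predicate keeps the test guard correct as new architectures are added, instead of re-encoding the support matrix in the test file.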
---
Nitpick comments:
In `@tests/norm/test_fused_qk_rmsnorm_rope.py`:
- Around line 480-481: The test test_v_passthrough currently asserts V is an
exact copy only under is_qk_norm=True; update the test to also run the same
assertion with is_qk_norm=False so the passthrough behavior is validated on both
code paths. Locate the test_v_passthrough function and add a second sub-case (or
parametrize the test) that calls the same kernel/setup with is_qk_norm=False,
computes v_fused and v_expected, and asserts torch.equal(v_fused, v_expected)
for that case as well; keep existing checks for is_qk_norm=True unchanged.
- Around line 187-204: Replace the triple-nested loops that compute cos_out and
sin_out by fully vectorized tensor ops: build a pos_ids tensor for all tokens
(shape [batch_size, seq_len, 3]) using the same pos_t/pos_h/pos_w logic
(vectorize tok→pos using arange and div/mod with pph and ppw), then gather the
per-head spatial coordinate with spatial_dim_per_elem (use
torch.tensor(spatial_dim_per_elem, device=...) and torch.take_along_dim or
torch.gather after expanding pos_ids) to produce pos_per_head shape [batch_size,
seq_len, head_dim]; multiply that by freq_per_elem (broadcasted) to get theta,
compute cos/sin in one call, and assign to cos_out[...,0,:] and sin_out[...,0,:]
converted to dtype; update uses of freq_per_elem, spatial_dim_per_elem, cos_out,
sin_out, batch_size, seq_len, head_dim, pph, ppw, device, dtype accordingly.
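For the V-passthrough nitpick, a hedged pytest sketch of the parametrized shape; `fused_call` is a stand-in for the real kernel invocation, which is assumed (per the review) to copy V through unchanged on both paths:

```python
import pytest

def fused_call(v, is_qk_norm):
    # Stand-in for the fused kernel: V is copied through unchanged
    # regardless of whether QK normalization is enabled.
    return list(v)

@pytest.mark.parametrize("is_qk_norm", [True, False])
def test_v_passthrough(is_qk_norm):
    v_expected = [1.0, 2.0, 3.0]
    v_fused = fused_call(v_expected, is_qk_norm)
    assert v_fused == v_expected
```

Parametrizing keeps the existing `is_qk_norm=True` assertions intact while exercising the `False` path with the same test body.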
📒 Files selected for processing (2)
- flashinfer/norm/fused_qk_rmsnorm_rope.py
- tests/norm/test_fused_qk_rmsnorm_rope.py
✅ Files skipped from review due to trivial changes (1)
- flashinfer/norm/fused_qk_rmsnorm_rope.py
AI-assisted. Made-with: Cursor
/bot run
🧹 Nitpick comments (1)
tests/norm/test_fused_qk_rmsnorm_rope.py (1)
187-202: Vectorize the NeoX reference cos/sin construction. The triple-nested Python loop over batch × seq_len × head_dim with scalar `torch.cos`/`torch.sin` writes scales poorly — for the (2, 5, 12, 32) NeoX shape that is ~983k scalar CUDA ops per test invocation, which unnecessarily stretches CI time for this module. Since `freq_per_elem` and `spatial_dim_per_elem` are already per-element tensors and `pos_t`/`pos_h`/`pos_w` are deterministic functions of `s`, this can be built with a few broadcast ops.
♻️ Suggested vectorization
```diff
-for b in range(batch_size):
-    for s in range(seq_len):
-        tok = s
-        pos_t = tok // (pph * ppw)
-        pos_x = tok % (pph * ppw)
-        pos_h = pos_x // ppw
-        pos_w = pos_x % ppw
-        pos_ids = torch.tensor(
-            [pos_t, pos_h, pos_w], dtype=torch.float64, device=device
-        )
-
-        for d in range(head_dim):
-            pos_id = pos_ids[spatial_dim_per_elem[d]]
-            theta = pos_id * freq_per_elem[d]
-            cos_out[b, s, 0, d] = torch.cos(theta)
-            sin_out[b, s, 0, d] = torch.sin(theta)
+tok = torch.arange(seq_len, device=device, dtype=torch.float64)
+pos_t = torch.div(tok, pph * ppw, rounding_mode="floor")
+pos_x = tok % (pph * ppw)
+pos_h = torch.div(pos_x, ppw, rounding_mode="floor")
+pos_w = pos_x % ppw
+pos_all = torch.stack([pos_t, pos_h, pos_w], dim=-1)  # [seq_len, 3]
+pos_per_elem = pos_all[:, spatial_dim_per_elem]  # [seq_len, head_dim]
+theta = pos_per_elem * freq_per_elem  # [seq_len, head_dim]
+cos_out[:] = torch.cos(theta).view(1, seq_len, 1, head_dim).expand_as(cos_out)
+sin_out[:] = torch.sin(theta).view(1, seq_len, 1, head_dim).expand_as(sin_out)
```
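As a standalone sanity check on the suggested rewrite, the same loop-vs-broadcast equivalence can be demonstrated in NumPy (toy sizes; `pph`/`ppw` and the per-element tables below are invented, mirroring the test's variables of the same names):

```python
import numpy as np

# Toy sizes standing in for the test's seq_len/head_dim; pph and ppw are
# invented patch counts mirroring the test's variables.
seq_len, head_dim, pph, ppw = 13, 8, 2, 3
rng = np.random.default_rng(0)
freq_per_elem = rng.random(head_dim)                 # per-element frequency
spatial_dim_per_elem = rng.integers(0, 3, head_dim)  # 0=t, 1=h, 2=w

# Reference: scalar double loop, as in the original test code.
theta_loop = np.empty((seq_len, head_dim))
for s in range(seq_len):
    pos_t, pos_x = divmod(s, pph * ppw)
    pos_h, pos_w = divmod(pos_x, ppw)
    pos_ids = np.array([pos_t, pos_h, pos_w], dtype=np.float64)
    for d in range(head_dim):
        theta_loop[s, d] = pos_ids[spatial_dim_per_elem[d]] * freq_per_elem[d]

# Vectorized: decompose all token indices at once, then gather per element.
tok = np.arange(seq_len)
pos_t, pos_x = np.divmod(tok, pph * ppw)
pos_h, pos_w = np.divmod(pos_x, ppw)
pos_all = np.stack([pos_t, pos_h, pos_w], axis=-1)   # [seq_len, 3]
theta_vec = pos_all[:, spatial_dim_per_elem] * freq_per_elem

assert np.allclose(theta_loop, theta_vec)
```

The fancy-indexing gather `pos_all[:, spatial_dim_per_elem]` plays the role of `torch.gather`/`take_along_dim` in the torch version.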
📒 Files selected for processing (3)
- flashinfer/norm/__init__.py
- flashinfer/norm/fused_qk_rmsnorm_rope.py
- tests/norm/test_fused_qk_rmsnorm_rope.py
🚧 Files skipped from review as they are similar to previous changes (2)
- flashinfer/norm/__init__.py
- flashinfer/norm/fused_qk_rmsnorm_rope.py
Register fused_qk_rmsnorm_rope in the norm benchmark family:
- Add routine name to benchmark_apis["norm"]
- Add CC-to-backends mapping (SM80+ with cuda backend)
- Add ppf/pph/ppw CSV output columns
- Add --ppf/--pph/--ppw CLI args to parse_norm_args
- Implement testFusedQkRmsnormRope with bandwidth metrics and optional refcheck
Usage:

```shell
python benchmarks/flashinfer_benchmark.py \
  --routine fused_qk_rmsnorm_rope \
  --batch_size 1 --hidden_size 3072 --num_heads 24 \
  --ppf 5 --pph 12 --ppw 32
```
AI-assisted.
Made-with: Cursor
Merge the separate fused_qk_rmsnorm_rope.py into __init__.py for consistency with all other norm functions (rmsnorm, gemma_rmsnorm, fused_add_rmsnorm, fused_rmsnorm_silu, etc.) which are all defined directly in __init__.py. Delete the separate file. No behavior change — 28 pass, 1 xfail. AI-assisted. Made-with: Cursor
Consistent with existing norm headers (ln_fwd_silu_kernel.cuh, ln_silu_headers.cuh) already in the norm/ subdirectory. AI-assisted. Made-with: Cursor
📌 Description
#2971
Add a fused CUDA kernel for across-heads QK RMSNorm + 3D Rotary Position Embeddings (RoPE) + V copy, targeting video generation DIT (Diffusion Transformer) self-attention workloads such as WAN 2.1/2.2.
This kernel fuses three operations into a single launch:
- Across-heads QK RMSNorm (RMS computed over the full `hidden_dim = num_heads * head_dim`, not per-head)
- 3D RoPE (temporal/height/width position axes)
- V copy

Optional FP8 E4M3 quantized output is supported, with SM89+ vectorized PTX conversion and SM100+ Blackwell FFMA2 intrinsics (all with scalar fallbacks for SM80+).
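The across-heads distinction can be sketched with a short NumPy reference (illustrative shapes and epsilon; a sketch, not the kernel's actual implementation):

```python
import numpy as np

def rmsnorm_across_heads(x, weight, eps=1e-6):
    # RMS is computed over the full hidden_dim = num_heads * head_dim,
    # unlike per-head RMSNorm, which would normalize each head separately.
    tokens, num_heads, head_dim = x.shape
    flat = x.reshape(tokens, num_heads * head_dim)
    rms = np.sqrt(np.mean(flat**2, axis=-1, keepdims=True) + eps)
    return ((flat / rms) * weight).reshape(x.shape)

x = np.random.default_rng(0).standard_normal((4, 3, 8))
y = rmsnorm_across_heads(x, np.ones(3 * 8))
assert y.shape == x.shape
# Per token, the mean square over the whole hidden_dim is ~1 after norm.
assert np.allclose((y.reshape(4, -1) ** 2).mean(-1), 1.0, atol=1e-3)
```

A per-head variant would instead reduce over the last axis only; normalizing over the flattened `hidden_dim` couples all heads of a token, matching the across-heads behavior described above.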
API
User-facing API follows the `dsv3_ops` pattern — implementation lives in `flashinfer/norm/` (alongside `rmsnorm`, `gemma_rmsnorm`, etc.) with a re-export facade at `flashinfer/diffusion_ops/`.
Benchmark Results (B200 (sm100), CUPTI)
Bug Fix
Found and fixed a bug in the NeoX (non-interleaved) RoPE path:
`pos_id` was computed from `dim_idx_x`'s mapped value but reused for `dim_idx_y` without recomputation. Since the `(dim_idx * 2) & mask` mapping can place adjacent elements in a float2 pair into different spatial slices (e.g., height vs width), the y component received an incorrect position ID. Fix: one line to recompute `pos_id` from `dim_idx_y`. The interleaved path (used in production) was unaffected.
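The described failure mode can be reproduced with a tiny index sketch (toy `head_dim` and made-up t/h/w slice boundaries; the mapping follows the description above, not the kernel source):

```python
# Toy reproduction of the NeoX-path indexing hazard described above.
# head_dim and the t/h/w slice boundaries are illustrative, not WAN's.
head_dim = 8
mask = head_dim - 1

def mapped(dim_idx):
    # NeoX de-interleave mapping from the description: (dim_idx * 2) & mask.
    return (dim_idx * 2) & mask

def spatial_slice(rot_idx):
    # Assign rotation indices to temporal/height/width frequency slices.
    return "t" if rot_idx < 2 else ("h" if rot_idx < 5 else "w")

# A float2 pair covers adjacent storage indices (dim_idx_x, dim_idx_y).
dim_idx_x, dim_idx_y = 3, 4
sx, sy = spatial_slice(mapped(dim_idx_x)), spatial_slice(mapped(dim_idx_y))

# The two halves of one pair land in different spatial slices, so reusing
# pos_id computed for dim_idx_x gives dim_idx_y the wrong position ID.
assert (mapped(dim_idx_x), mapped(dim_idx_y)) == (6, 0)
assert sx != sy
```

Because the two halves of one float2 pair can map into different spatial slices, `pos_id` must be looked up separately for `dim_idx_y` — which is exactly the one-line fix.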
Architecture Support
Known Limitations
`num_heads ≤ 32`: The kernel uses one warp per head, so `max_heads × 32 = 1024` threads per block (the CUDA maximum). WAN 14B (40 heads) is unsupported; supporting it would require a kernel redesign (multi-head per warp).
Files Changed
🔍 Related Issues
#2971
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- Tests have been added or updated as needed (`unittest`, etc.).
Test summary: 26 tests (25 pass, 1 xfail)
- `[num_tokens, hidden]` input
- V passthrough (`is_qk_norm=False`)
Validated on 3 GPU architectures:
Reviewer Notes
- The kernel compiles as part of the existing `norm` JIT module (`gen_norm_module()`), so no new JIT spec or AOT registration is needed.
- The NeoX bug fix is in `fused_qk_rmsnorm_rope.cuh` — reviewers may want to verify the `pos_id` recomputation logic for `dim_idx_y`.
- The `diffusion_ops/` facade follows the exact same pattern as `dsv3_ops/` — pure re-export, no implementation.
- Additional fused ops (e.g., `rmsnorm_silu` for WAN) can be added to `diffusion_ops/`.
Summary by CodeRabbit
New Features
Tests