
[fmha-v2] Support HND and NHD paged KV cache layouts with conditional stride handling#2799

Open
zhou-yuxin wants to merge 5 commits into flashinfer-ai:main from zhou-yuxin:HND

Conversation

@zhou-yuxin
Contributor

@zhou-yuxin zhou-yuxin commented Mar 17, 2026

Previously, the fmha_v2 paged KV kernels only supported the HND layout. When users passed an NHD layout, transpose(-3, -2).contiguous() was called, which introduced significant overhead. This pull request adds native support for the NHD layout to the fmha_v2 kernels by allowing users to pass in custom TMA strides.
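As a rough sketch of the stride arithmetic involved (the function name `paged_kv_byte_strides` and its signature are illustrative, not the PR's actual API; contiguous pages are assumed):

```python
def paged_kv_byte_strides(h_kv, page_size, head_dim, elt_size, is_hnd):
    """Byte strides of the two inner axes of a contiguous paged KV page.

    HND pages are laid out as [num_pages, h_kv, page_size, head_dim];
    NHD pages are laid out as [num_pages, page_size, h_kv, head_dim].
    Returns (stride_dim1, stride_dim2) in bytes, in the spirit of the PR's
    k_stride_in_bytes / k_stride_in_bytes_2 pair.
    """
    if is_hnd:
        # dim 1 is the head axis, dim 2 is the token axis
        return page_size * head_dim * elt_size, head_dim * elt_size
    # NHD: dim 1 is the token axis, dim 2 is the head axis
    return h_kv * head_dim * elt_size, head_dim * elt_size
```

With fp16 (2-byte) elements, 8 KV heads, page size 16, and head dim 128, HND yields (4096, 256) while NHD yields (2048, 256); only the first stride differs, which is why passing the strides through instead of transposing suffices.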

Summary by CodeRabbit

  • New Features

    • Added support for both HND and NHD paged key/value tensor layouts.
    • Added stride metadata to support 4D paged KV layouts.
  • Bug Fixes

    • Corrected paged key/value stride handling so strides map to layout dimensions.
    • Broadened flash-attention kernel use by removing prior gating.
    • Simplified prefill for paged KV (no extra transpose) and fixed page-size detection.

@coderabbitai
Contributor

coderabbitai Bot commented Mar 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Added explicit secondary per-dimension K/V stride fields for 4D paged-KV layouts, propagated an is_paged_hnd flag from Python through the C++ launcher into host DMA params, changed TMA stride assignment for Q_PAGED_KV to use caller-provided byte strides directly, updated paged-KV byte-offset/head addressing, and adjusted Python prefill extraction for both paged-KV shapes.

Changes

  • API / Params — csrc/fmha_v2/fused_multihead_attention.h: Added k_stride_in_bytes_2 and v_stride_in_bytes_2 to Fused_multihead_attention_params_v2 to carry a second stride for 4D paged-KV layouts.
  • Launcher / runtime — csrc/fmha_v2_run.cu: Propagated is_paged_hnd from string_to_input_layout into fmha_v2_run and set_params; select h_kv and K/V stride fields using is_paged_hnd; unconditional flash-attention flag for built kernels.
  • Host DMA / TMA descriptor — csrc/fmha_v2/fmha/warpspec/dma.h: For Q_PAGED_KV, TMA tensor strides for K and V now use caller-provided byte strides directly (removed division by tokens_per_block) and accept the new _2 stride fields for the alternate tensor dimension.
  • GMEM tiled loads — csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h: Switched to separate byte-stride/offset handling for per-head addressing: introduced head_offset_in_bytes_ from the _2 stride fields, adjusted pointer arithmetic, and changed token_stride_in_bytes_ to int64_t set directly from params.
  • Python prefill — flashinfer/prefill.py: Removed transpose/contiguous for the NHD path; now unbinds directly from the original paged_kv; page_size selection adjusted to use the correct axis depending on NHD vs HND layout.
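The layout-dependent axis selection in flashinfer/prefill.py can be illustrated with a small shape-only sketch (`get_page_size` is a hypothetical helper, not code from the PR):

```python
def get_page_size(paged_kv_shape, layout):
    """Pick the token (page_size) axis of a paged KV cache tensor.

    HND pages end in [..., h_kv, page_size, head_dim]: tokens on axis -2.
    NHD pages end in [..., page_size, h_kv, head_dim]: tokens on axis -3.
    """
    if layout not in ("HND", "NHD"):
        raise ValueError(f"unknown layout: {layout}")
    return paged_kv_shape[-2] if layout == "HND" else paged_kv_shape[-3]
```

Because only the axis index changes, no transpose or copy of the cache is needed to serve either layout.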

Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python prefill
    participant Launch as C++ launcher (fmha_v2_run)
    participant Host as Host DMA (dma.h)
    participant Dev as Device kernel / TMA

    Py->>Launch: fmha_v2_run(input_layout, paged_kv, ...)
    Launch->>Launch: string_to_input_layout(..., is_paged_hnd)
    Launch->>Host: set_params(..., is_paged_hnd, k_stride_in_bytes, k_stride_in_bytes_2, v_stride_in_bytes, v_stride_in_bytes_2)
    Host->>Dev: emit TMA descriptors (select stride [1] vs [2] based on is_paged_hnd)
    Py->>Py: extract k_cache, v_cache (unbind dim per layout)
```

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

run-ci, op: attention

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • cyx-6
  • djmmoss
  • yzh119
  • bkryu
  • jiahanc
  • IwakuraRein

Poem

🐰 I hopped from Python down to C++,

carrying strides in twos and plus.
HND or NHD, I pick the stride,
TMA marches, offsets glide.
Tiny hops — paged KV feels just right!

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check ⚠️ Warning — The PR description explains the motivation (eliminating transpose overhead) and the solution (native NHD support via custom TMA strides), but lacks required template sections like Related Issues, Pre-commit Checks, and Tests completion status. Resolution: add the missing sections from the template (Related Issues, Pre-commit Checks checklist, Tests section, and optional Reviewer Notes) to meet the repository's PR documentation standards.
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 58.33%, which is below the required 80.00% threshold. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check ✅ Passed — The title accurately reflects the main change: adding support for both HND and NHD paged KV cache layouts with conditional stride handling in the fmha-v2 implementation.


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the FlashInfer FMHAv2 implementation by introducing performance optimizations and new attention features. It addresses potential integer overflows in memory addressing, improves flexibility for variable sequence lengths, and integrates a 'skip-softmax' mechanism to reduce computational overhead. The changes also refine paged KV cache handling and expand the kernel generation framework to support these new capabilities, all validated through comprehensive new test cases.

Highlights

  • Integer Overflow Prevention: Added explicit (int64_t) casts to bidx and block_info.bidx multiplications in memory offset calculations across gmem_tile_o_packed.h and gmem_tile_ps.h to prevent potential integer overflow issues with large offsets.
  • Flash Attention Kernel Traits Update: Modified kernel_traits.h to correctly handle S=0 for variable sequence lengths in flash attention, adjusting static assertions and the calculation of TOTAL_BMM2_MMAS_K.
  • Skip-Softmax Attention Feature: Introduced a 'skip-softmax' attention mechanism with new ENABLE_SKIP_SOFTMAX parameters, shared memory voting for warps, and statistics tracking in warpspec/compute.h and warpspec/epilogue.h to optimize attention computation.
  • Paged KV Cache Stride Refinement: Corrected kv_idx_end calculation for causal mask early stopping and updated paged KV cache stride logic in warpspec/dma.h to utilize new k_stride_in_bytes_2 and v_stride_in_bytes_2 fields.
  • FMHAv2 Parameter Expansion: Extended Fused_multihead_attention_params_v2 and Fused_multihead_attention_launch_params in fused_multihead_attention.h to include new paged KV cache strides and skip-softmax related parameters and statistics.
  • New Kernel Generation Templates: Added new Jinja templates (fa_kernel.jinja, kernel.jinja, kernel_hopper.jinja, kernel_hopper_ws.jinja) to facilitate the generation of various FMHAv2 kernel variants.
  • JIT Binding and Runtime Execution: Introduced new C++ source files (fmha_v2_jit_binding.cu, fmha_v2_run.cu) to provide JIT binding and runtime execution capabilities for the FMHAv2 kernels.
  • Python API and Kernel Generation Integration: Updated Python modules (flashinfer/__init__.py, flashinfer/jit/__init__.py, flashinfer/jit/attention/__init__.py, flashinfer/jit/attention/fmha_v2/fmha_library.py, flashinfer/jit/attention/fmha_v2/generator_utils.py, flashinfer/jit/attention/fmha_v2/utils.py, flashinfer/prefill.py) to integrate the new FMHAv2 kernels, including skip-softmax parameters and refined kernel generation logic for TRT-LLM.
  • Comprehensive FMHAv2 Prefill Testing: Added a new test file tests/attention/test_fmha_v2_prefill.py to thoroughly validate the new FMHAv2 prefill functionality, covering deepseek, skip-softmax, and chunked attention scenarios.
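The overflow risk behind the first highlight is easy to reproduce outside CUDA; this pure-Python sketch emulates a signed 32-bit multiply to show why casting to int64_t before multiplying matters (the index and stride values are illustrative):

```python
def as_int32(x):
    """Truncate a Python int to a signed 32-bit value (two's complement)."""
    x &= 0xFFFFFFFF
    return x - (1 << 32) if x >= (1 << 31) else x

bidx = 70_000    # hypothetical block index
stride = 65_536  # hypothetical row stride in bytes

off32 = as_int32(bidx * stride)  # what a 32-bit multiply would produce
off64 = bidx * stride            # with the (int64_t) cast applied first
```

Here the true offset (about 4.6 GB) exceeds INT_MAX, so the 32-bit product silently wraps to a different, still-positive address; the explicit 64-bit widening avoids that.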


Changelog
  • csrc/fmha_v2/fmha/gmem_tile_o_packed.h
    • Added (int64_t) casts to bidx multiplications to prevent integer overflow.
  • csrc/fmha_v2/fmha/gmem_tile_ps.h
    • Added (int64_t) casts to bidx multiplications to prevent integer overflow.
  • csrc/fmha_v2/fmha/hopper/gmem_tile_o_packed.h
    • Added (int64_t) casts to block_info.bidx multiplications to prevent integer overflow.
  • csrc/fmha_v2/fmha/kernel_traits.h
    • Modified static assertions and TOTAL_BMM2_MMAS_K calculation to support S=0 for variable sequence length.
  • csrc/fmha_v2/fmha/warpspec/compute.h
    • Modified compute_single_tile function call to pass skip_softmax_vote pointer.
    • Added logic to update softmax.skip_softmax_threshold based on actual_kv_seqlen when ENABLE_SKIP_SOFTMAX is true.
    • Added conditional atomicAdd operations for skip_softmax_total_blocks and skip_softmax_skipped_blocks under SKIP_SOFTMAX_STAT.
    • Modified compute_single_tile function signature to include uint32_t* skip_softmax_vote parameter.
    • Added initialization of *skip_softmax_vote = 1 at the beginning of compute_single_tile.
    • Added named_barrier_wait before BMM1 if skip-softmax is enabled.
    • Modified softmax.compute_and_update_scale call to pass skip_softmax_vote and added conditional skip logic.
  • csrc/fmha_v2/fmha/warpspec/dma.h
    • Corrected kv_idx_end calculation for causal mask early stopping.
    • Updated tensor_stride_k and tensor_stride_v calculations for paged KV cache to use k_stride_in_bytes_2 and v_stride_in_bytes_2.
  • csrc/fmha_v2/fmha/warpspec/epilogue.h
    • Included fmha/hopper/arrive_wait.h.
    • Added SKIP_SOFTMAX_BARRIER enum to Softmax_base.
    • Added total_blocks, skipped_blocks (under SKIP_SOFTMAX_STAT), and skip_softmax_threshold members to Softmax_base.
    • Modified compute_and_update_scale function signature to include uint32_t* skip_softmax_vote parameter.
    • Implemented skip-softmax logic within compute_and_update_scale, including warp-level voting and conditional return.
    • Changed compute_and_update_scale return type to bool.
  • csrc/fmha_v2/fmha/warpspec/kernel_traits.h
    • Added ENABLE_SKIP_SOFTMAX_ template parameter.
    • Added ENABLE_SKIP_SOFTMAX enum.
    • Added SKIP_SOFTMAX_BARRIER_ID constant.
    • Added skip_softmax_votes array to Shared struct for inter-warp communication.
    • Updated Kernel_traits_Hopper_qgmma_e4m3_fp32 to pass ENABLE_SKIP_SOFTMAX_ to its base class.
    • Added skip_softmax_votes array to Shared struct in Kernel_traits_Hopper_qgmma_e4m3_fp32.
  • csrc/fmha_v2/fused_multihead_attention.h
    • Added k_stride_in_bytes_2 and v_stride_in_bytes_2 fields to Fused_multihead_attention_params_v2 for paged KV cache.
    • Added skip_softmax_threshold_scale_factor, skip_softmax_total_blocks, and skip_softmax_skipped_blocks fields to Fused_multihead_attention_params_v2.
    • Added enable_skip_softmax field to Fused_multihead_attention_launch_params.
  • csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h
    • Added skip_softmax_threshold_scale_factor, skip_softmax_total_blocks, and skip_softmax_skipped_blocks fields to Fused_multihead_attention_params_v2.
  • csrc/fmha_v2/templates/fa_kernel.jinja
    • Added new file for Flash Attention kernel template.
  • csrc/fmha_v2/templates/kernel.jinja
    • Added new file for general kernel template.
  • csrc/fmha_v2/templates/kernel_hopper.jinja
    • Added new file for Hopper-specific kernel template.
  • csrc/fmha_v2/templates/kernel_hopper_ws.jinja
    • Added new file for Hopper warp-specialized kernel template.
  • csrc/fmha_v2_jit_binding.cu
    • Added new file for JIT binding of FMHAv2 functions.
  • csrc/fmha_v2_run.cu
    • Added new file for runtime execution logic of FMHAv2.
  • flashinfer/__init__.py
    • Imported trtllm_fmha_v2_prefill.
  • flashinfer/jit/__init__.py
    • Renamed get_trtllm_fmha_v2_module to gen_trtllm_fmha_v2_sm120_module and added gen_fmha_v2_module.
  • flashinfer/jit/attention/__init__.py
    • Renamed get_trtllm_fmha_v2_module to gen_trtllm_fmha_v2_sm120_module and added gen_fmha_v2_module.
  • flashinfer/jit/attention/fmha_v2/fmha_library.py
    • Updated FMHAv2KernelSpec dataclass to include enable_skip_softmax.
    • Modified select_ldgsts logic for SM120.
    • Updated generate_kernel_spec to configure kv_loop_step and kv_tile_buffers based on head_size and dtype for warp-specialized kernels.
    • Added enable_skip_softmax to is_kernel_spec_valid checks.
    • Updated get_kernel_code to include enable_skip_softmax_flag in template rendering.
    • Updated get_api_code to include enable_skip_softmax in kernel dispatch logic.
    • Updated generate_jit_sources to include enable_skip_softmax in kernel generation.
  • flashinfer/jit/attention/fmha_v2/generator_utils.py
    • Added enable_skip_softmax to kernel_spec namedtuple.
    • Updated get_makefile_code to include enable_skip_softmax_flag.
    • Updated encode_name to include _skipSoftmax tag.
    • Updated get_cubin_header to include enable_skip_softmax in use_cubin_header check and metadata generation.
    • Updated modify_cubin_header to include false for enable_skip_softmax_flag in a specific kernel line.
    • Updated enumerate_hgmma_flash_warpspec_kernels and enumerate_qgmma_flash_warpspec_kernels to include enable_skip_softmax parameter.
    • Updated enumerate_kernels to iterate over enable_skip_softmax values.
  • flashinfer/jit/attention/fmha_v2/utils.py
    • Added enable_skip_softmax to kernel_spec namedtuple and its defaults.
    • Updated encode_name to include _skipSoftmax tag.
    • Updated get_api_code to include enable_skip_softmax in kernel dispatch logic.
  • flashinfer/prefill.py
    • Added _create_scale_bmm2_d_tensor function to handle scale_bmm2_d tensor creation.
    • Renamed get_trtllm_fmha_v2_module to get_trtllm_fmha_v2_sm120_module and added get_trtllm_fmha_v2_module (new JIT-compiled version).
    • Updated fmha_v2_prefill_deepseek to use get_trtllm_fmha_v2_sm120_module.
    • Added trtllm_fmha_v2_prefill function, a new API for TRT-LLM FMHAv2 prefill attention, supporting various input layouts, mask modes, and features like skip-softmax.
  • tests/attention/test_fmha_v2_prefill.py
    • Added new test file for fmha_v2_prefill_deepseek and trtllm_fmha_v2_prefill, including tests for skip-softmax and chunked attention.
Activity
  • Introduced a new 'skip-softmax' attention feature, designed to enhance performance by conditionally skipping softmax and BMM2 computations.
  • Refactored the kernel generation and dispatch logic, specifically for TRT-LLM FMHAv2, to accommodate new features and support advanced architectures like SM12x.
  • Implemented new test cases to thoroughly validate the correctness and functionality of the newly added features, including skip-softmax and chunked attention.
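As a rough mental model of the skip-softmax voting described above (a deliberate simplification, not the kernel's actual shared-memory protocol):

```python
def warp_vote_skip(warp_max_scores, threshold):
    """A tile's softmax/BMM2 is skipped only if every warp votes 'skip'.

    Each warp publishes 1 when its partial max attention score falls
    below the threshold; a single dissenting warp keeps the computation.
    """
    votes = [1 if m < threshold else 0 for m in warp_max_scores]
    return all(v == 1 for v in votes)
```

This mirrors why the vote starts at 1 and must survive all warps: skipping is only safe when the whole tile's contribution is negligible.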

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces a significant "Skip-Softmax" optimization for attention computation, which can improve performance by skipping softmax and BMM2 for negligible attention scores. The implementation is well-structured across the C++ kernel, traits, and epilogue, with careful handling of synchronization. Additionally, the PR includes several important correctness fixes, such as preventing integer overflows in pointer arithmetic and correcting the causal mask boundary calculation. The Python-level changes correctly expose the new features and refactor the JIT module generation for better organization. I've identified one potential issue related to a division-by-zero that should be addressed.

Comment thread csrc/fmha_v2/fmha/warpspec/compute.h
@zhou-yuxin zhou-yuxin marked this pull request as ready for review March 17, 2026 03:13
Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@csrc/fmha_v2_run.cu`:
- Line 53: The non-warp-specialized paged-KV reader in
fmha/gmem_tile_qkv_packed.h still assumes k_stride_in_bytes / v_stride_in_bytes
are per-head strides (and computes token stride via >>
paged_kv_log2_block_size_), but the change in Attention_input_layout repurposed
these into generic 4-D layout strides; update the reader to consume the new
secondary fields (k_stride_in_bytes_2 and v_stride_in_bytes_2) when calculating
token and head offsets (or alternatively retain the old per-head semantics for
k_stride_in_bytes/v_stride_in_bytes until both readers are migrated), making the
logic in the non-warp Q_PAGED_KV path consistent with the new layout semantics
and matching how the warp-specialized reader uses *_stride_in_bytes_2.
- Around line 322-335: In string_to_input_layout initialize the out-parameter
is_paged_hnd at the top (e.g. set is_paged_hnd = false) before any early returns
so it is always defined for callers like set_params; keep the existing per-case
assignments for "q_paged_kv_nhd" and "q_paged_kv_hnd" but ensure a default
initialization at the start of the function to avoid undefined behavior.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 427315a0-6d34-4698-876d-232fa9aabaad

📥 Commits

Reviewing files that changed from the base of the PR and between a5e5cae and 11259d7.

📒 Files selected for processing (4)
  • csrc/fmha_v2/fmha/warpspec/dma.h
  • csrc/fmha_v2/fused_multihead_attention.h
  • csrc/fmha_v2_run.cu
  • flashinfer/prefill.py

Comment thread csrc/fmha_v2_run.cu
```diff
 Data_type data_type, Data_type acc_type, Data_type output_dtype,
 // attention input layout
-Attention_input_layout input_layout,
+Attention_input_layout input_layout, const bool is_paged_hnd,
```
Contributor

⚠️ Potential issue | 🔴 Critical

Migrate the non-warp-specialized paged-KV reader before repurposing these stride fields.

This change turns k_stride_in_bytes / v_stride_in_bytes into generic 4-D layout strides, but csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h, Lines 693-700 still interprets those fields with the old per-head-stride contract and derives token stride via >> paged_kv_log2_block_size_. That leaves non-TMA Q_PAGED_KV kernels computing wrong K/V addresses for both HND and NHD. Please update that path to consume *_stride_in_bytes_2 as well, or keep the old meaning until both readers are migrated.

Also applies to: 122-136


Comment thread csrc/fmha_v2_run.cu
@yzh119 yzh119 changed the title Hnd [fmha-v2] Support HND and NHD paged KV cache layouts with conditional stride handling Mar 17, 2026
@zhou-yuxin zhou-yuxin force-pushed the HND branch 3 times, most recently from f28b2d8 to 11259d7 Compare March 19, 2026 06:16
@zhou-yuxin zhou-yuxin closed this Mar 19, 2026
@zhou-yuxin zhou-yuxin reopened this Mar 19, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment

♻️ Duplicate comments (1)
csrc/fmha_v2_run.cu (1)

324-337: ⚠️ Potential issue | 🔴 Critical

Initialize is_paged_hnd before any early return.

string_to_input_layout() only writes the out-parameter in the paged-KV branches, but fmha_v2_run() still forwards it to set_params() on Line 608 for every layout. Passing that indeterminate bool by value is undefined behavior even if non-paged layouts ignore it later.

Minimal fix

```diff
 static inline Attention_input_layout string_to_input_layout(const std::string& s,
                                                             bool& is_paged_hnd) {
+  is_paged_hnd = false;
   if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV;
   if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV;
   if (s == "q_paged_kv_nhd") {
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fmha_v2_run.cu` around lines 324 - 337, The out-parameter is_paged_hnd
is left uninitialized for non-paged branches in string_to_input_layout;
initialize it deterministically (e.g., set is_paged_hnd = false at the top of
string_to_input_layout) or explicitly set it in every return branch so callers
like fmha_v2_run that pass it into set_params always receive a defined value;
update string_to_input_layout (and verify Attention_input_layout returns) to
ensure is_paged_hnd is always assigned before any return.
🧹 Nitpick comments (1)
csrc/fmha_v2_run.cu (1)

389-395: Validate page_size against the selected paged-KV axis here.

This block now derives h_kv from k.shape() based on HND vs NHD, but it still trusts the separate page_size argument. tokens_per_block later drives both Kv_block_array and the new paged-KV stride math, so a mismatched value will quietly misaddress the cache. An assert against k.shape()[2] for HND / k.shape()[1] for NHD would harden this path.

Possible hardening

```diff
   } else if (input_layout == Attention_input_layout::Q_PAGED_KV) {
     // q is 3D: [total_tokens, H, D]
     h = q.shape()[1];
+    const size_t inferred_page_size = k.shape()[is_paged_hnd ? 2 : 1];
     // k/v are 4D paged:
     //   HND: [num_pages, H_kv, page_size, D]
     //   NHD: [num_pages, page_size, H_kv, D]
     h_kv = k.shape()[is_paged_hnd ? 1 : 2];
     d = q.shape()[2];
     dv = v.shape()[3];
+    assert(inferred_page_size == static_cast<size_t>(page_size) &&
+           "page_size must match the paged KV cache tensor shape");
   } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) {
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fmha_v2_run.cu` around lines 389 - 395, Validate that the provided
page_size matches the paged-KV tensor dimensions right after deriving h_kv;
specifically, when computing h_kv from k.shape() using is_paged_hnd, assert that
page_size == k.shape()[2] for HND layout and page_size == k.shape()[1] for NHD
layout (so tokens_per_block / Kv_block_array and the paged-KV stride math cannot
be driven by a mismatched page_size). Locate the check near the block that sets
h = q.shape()[1], h_kv = k.shape()[is_paged_hnd ? 1 : 2], d = q.shape()[2] and
add the corresponding assertion(s) referencing page_size, is_paged_hnd, and
k.shape().

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 16953942-084d-4ffd-ad8f-e1c12d5abccf

📥 Commits

Reviewing files that changed from the base of the PR and between 67d6fbd and 5846873.

📒 Files selected for processing (1)
  • csrc/fmha_v2_run.cu

Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
csrc/fmha_v2_run.cu (1)

265-273: Remove dead #if 0 gating and make ignored override explicit.

force_non_flash_attention is effectively ignored here; consider replacing the preprocessor block with an explicit comment/cast to avoid ambiguity and future regressions.

♻️ Suggested cleanup

```diff
-#if 0
-  // threshold for adopting flash attention or warp_specialized kernels.
-  launch_params.flash_attention =
-      (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
-      (s >= 16 && d >= 16) && !force_non_flash_attention;
-#else
-  // Currently only flash attention kernels are generated in FlashInfer
-  launch_params.flash_attention = true;
-#endif
+  // Currently only flash-attention kernels are generated in FlashInfer.
+  // Keep this explicit to avoid confusion about `force_non_flash_attention`.
+  (void)force_non_flash_attention;
+  launch_params.flash_attention = true;
```
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fmha_v2_run.cu` around lines 265 - 273, The preprocessor dead-code block
around launch_params.flash_attention makes force_non_flash_attention and the
original gating ambiguous; remove the `#if` 0/#else/#endif and set
launch_params.flash_attention = true explicitly, add a concise comment that
FlashInfer currently always uses flash attention and that
force_non_flash_attention is intentionally ignored, and if
force_non_flash_attention (or other unused symbols like data_type,
DATA_TYPE_FP16/BF16/E4M3) remain in scope, explicitly mark them as intentionally
unused (e.g., cast to void) to avoid warnings and future confusion; locate the
code around launch_params.flash_attention in fmha_v2_run.cu to apply this
change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 523bf6c8-90b0-4dcd-8041-cc86c104e3b9

📥 Commits

Reviewing files that changed from the base of the PR and between 5846873 and 1299a6d.

📒 Files selected for processing (1)
  • csrc/fmha_v2_run.cu

Signed-off-by: Yuxin <yuxinz@nvidia.com>
@qsang-nv qsang-nv self-requested a review as a code owner April 21, 2026 08:43
@qsang-nv
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !579 has been created, and the CI pipeline #49076701 is currently running. I'll report back once the pipeline job completes.

```cpp
// x_stride_in_bytes means the stride of tensor_size[1]
// x_stride_in_bytes_2 means the stride of tensor_size[2]
int64_t k_stride_in_bytes_2;
int64_t v_stride_in_bytes_2;
```
Collaborator

Unsure if this is in scope, but should these be added to fused_multihead_attention_demo_bert_params.h as well?
