[fmha-v2] Support HND and NHD paged KV cache layouts with conditional stride handling#2799
zhou-yuxin wants to merge 5 commits into flashinfer-ai:main
Conversation
📝 Walkthrough

Added explicit secondary per-dimension K/V stride fields for 4-D paged-KV layouts, and propagated an `is_paged_hnd` flag from the Python prefill path through the C++ launcher into the host DMA and device TMA descriptor setup.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Py as Python prefill
    participant Launch as C++ launcher (fmha_v2_run)
    participant Host as Host DMA (dma.h)
    participant Dev as Device kernel / TMA
    Py->>Launch: fmha_v2_run(input_layout, paged_kv, ...)
    Launch->>Launch: string_to_input_layout(..., is_paged_hnd)
    Launch->>Host: set_params(..., is_paged_hnd, k_stride_in_bytes, k_stride_in_bytes_2, v_stride_in_bytes, v_stride_in_bytes_2)
    Host->>Dev: emit TMA descriptors (select stride [1] vs [2] based on is_paged_hnd)
    Py->>Py: extract k_cache, v_cache (unbind dim per layout)
```
Summary of Changes

This pull request enhances the FlashInfer FMHAv2 implementation with performance optimizations and new attention features. It addresses potential integer overflows in memory addressing, improves flexibility for variable sequence lengths, and integrates a skip-softmax mechanism to reduce computational overhead. The changes also refine paged KV cache handling and expand the kernel generation framework to support these new capabilities, all validated through comprehensive new test cases.

Highlights
Changelog
Activity
Code Review
This pull request introduces a significant "Skip-Softmax" optimization for attention computation, which can improve performance by skipping softmax and BMM2 for negligible attention scores. The implementation is well-structured across the C++ kernel, traits, and epilogue, with careful handling of synchronization. Additionally, the PR includes several important correctness fixes, such as preventing integer overflows in pointer arithmetic and correcting the causal mask boundary calculation. The Python-level changes correctly expose the new features and refactor the JIT module generation for better organization. I've identified one potential issue related to a division-by-zero that should be addressed.
Actionable comments posted: 2
📒 Files selected for processing (4)

- csrc/fmha_v2/fmha/warpspec/dma.h
- csrc/fmha_v2/fused_multihead_attention.h
- csrc/fmha_v2_run.cu
- flashinfer/prefill.py
```diff
   Data_type data_type, Data_type acc_type, Data_type output_dtype,
   // attention input layout
-  Attention_input_layout input_layout,
+  Attention_input_layout input_layout, const bool is_paged_hnd,
```
Migrate the non-warp-specialized paged-KV reader before repurposing these stride fields.
This change turns k_stride_in_bytes / v_stride_in_bytes into generic 4-D layout strides, but csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h, lines 693-700, still interprets those fields under the old per-head-stride contract and derives the token stride via >> paged_kv_log2_block_size_. That leaves non-TMA Q_PAGED_KV kernels computing wrong K/V addresses for both HND and NHD. Please update that path to consume *_stride_in_bytes_2 as well, or keep the old meaning until both readers are migrated.
Also applies to: 122-136
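To make the mismatch concrete, here is a hypothetical sketch (not the actual FlashInfer code; function names are made up for illustration) contrasting the two addressing contracts the comment describes: the old reader derives the per-token stride by shifting the per-head stride, while the new contract carries an explicit stride for dims 1 and 2 of the 4-D cache. The two agree only for a contiguous HND layout:

```python
def old_offset(head, token, k_stride_in_bytes, log2_page_size):
    # Old per-head contract: k_stride_in_bytes is the per-head stride, and the
    # per-token stride is derived by shifting (head_stride / page_size).
    return head * k_stride_in_bytes + token * (k_stride_in_bytes >> log2_page_size)

def new_offset(head, token, stride_1, stride_2, is_paged_hnd):
    # New generic 4-D contract: dims 1 and 2 each carry an explicit byte stride.
    if is_paged_hnd:  # [num_pages, H_kv, page_size, D]
        return head * stride_1 + token * stride_2
    else:             # [num_pages, page_size, H_kv, D]
        return token * stride_1 + head * stride_2

# fp16, page_size=16, D=128: per-head stride 4096 B, per-token stride 256 B.
# HND matches the old shift-based derivation; NHD does not.
print(old_offset(3, 5, 4096, 4))               # → 13568
print(new_offset(3, 5, 4096, 256, True))       # → 13568 (same as old, HND)
print(new_offset(3, 5, 2048, 256, False))      # → 11008 (NHD differs)
```

This is why a reader still using the old shift-based derivation silently misaddresses NHD caches once the fields are repurposed.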
Force-pushed from f28b2d8 to 11259d7
♻️ Duplicate comments (1)
csrc/fmha_v2_run.cu (1)
324-337: ⚠️ Potential issue | 🔴 Critical

Initialize `is_paged_hnd` before any early return.

`string_to_input_layout()` only writes the out-parameter in the paged-KV branches, but `fmha_v2_run()` still forwards it to `set_params()` on line 608 for every layout. Passing that indeterminate `bool` by value is undefined behavior even if non-paged layouts ignore it later.

Minimal fix
```diff
 static inline Attention_input_layout string_to_input_layout(const std::string& s,
                                                             bool& is_paged_hnd) {
+  is_paged_hnd = false;
   if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV;
   if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV;
   if (s == "q_paged_kv_nhd") {
```
🧹 Nitpick comments (1)
csrc/fmha_v2_run.cu (1)
389-395: Validate `page_size` against the selected paged-KV axis here.

This block now derives `h_kv` from `k.shape()` based on HND vs NHD, but it still trusts the separate `page_size` argument. `tokens_per_block` later drives both `Kv_block_array` and the new paged-KV stride math, so a mismatched value will quietly misaddress the cache. An assert against `k.shape()[2]` for HND / `k.shape()[1]` for NHD would harden this path.

Possible hardening
```diff
 } else if (input_layout == Attention_input_layout::Q_PAGED_KV) {
   // q is 3D: [total_tokens, H, D]
   h = q.shape()[1];
+  const size_t inferred_page_size = k.shape()[is_paged_hnd ? 2 : 1];
   // k/v are 4D paged:
   // HND: [num_pages, H_kv, page_size, D]
   // NHD: [num_pages, page_size, H_kv, D]
   h_kv = k.shape()[is_paged_hnd ? 1 : 2];
   d = q.shape()[2];
   dv = v.shape()[3];
+  assert(inferred_page_size == static_cast<size_t>(page_size) &&
+         "page_size must match the paged KV cache tensor shape");
 } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) {
```
🧹 Nitpick comments (1)
csrc/fmha_v2_run.cu (1)
265-273: Remove dead `#if 0` gating and make the ignored override explicit.

`force_non_flash_attention` is effectively ignored here; consider replacing the preprocessor block with an explicit comment/cast to avoid ambiguity and future regressions.

♻️ Suggested cleanup
```diff
-#if 0
-  // threshold for adopting flash attention or warp_specialized kernels.
-  launch_params.flash_attention =
-      (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 ||
-       data_type == DATA_TYPE_E4M3) &&
-      (s >= 16 && d >= 16) && !force_non_flash_attention;
-#else
-  // Currently only flash attention kernels are generated in FlashInfer
-  launch_params.flash_attention = true;
-#endif
+  // Currently only flash-attention kernels are generated in FlashInfer.
+  // Keep this explicit to avoid confusion about `force_non_flash_attention`.
+  (void)force_non_flash_attention;
+  launch_params.flash_attention = true;
```
Signed-off-by: Yuxin <yuxinz@nvidia.com>
/bot run
```cpp
// x_stride_in_bytes means the stride of tensor_size[1]
// x_stride_in_bytes_2 means the stride of tensor_size[2]
int64_t k_stride_in_bytes_2;
int64_t v_stride_in_bytes_2;
```
Unsure if this is in scope, but should these be added to fused_multihead_attention_demo_bert_params.h as well?
Previously, the fmha_v2 paged-KV kernels only supported the HND layout. When users passed in an NHD layout, `transpose(-3, -2).contiguous()` was called, which introduced significant overhead. This pull request adds native support for the NHD layout to the fmha_v2 kernels by allowing users to pass in custom TMA strides.
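The two layouts differ only in which dimension of the 4-D page tensor holds heads versus tokens, so each needs its own pair of byte strides for dims 1 and 2. A minimal, self-contained sketch (hypothetical helper, assuming contiguous caches; not FlashInfer's actual stride code) of how those strides fall out:

```python
def paged_kv_strides(num_pages, h_kv, page_size, head_dim, elem_bytes, layout):
    """Byte strides of dims 1 and 2 for a contiguous 4-D paged KV cache.

    HND: [num_pages, H_kv, page_size, D] -> (per-head, per-token) strides
    NHD: [num_pages, page_size, H_kv, D] -> (per-token, per-head) strides
    """
    if layout == "HND":
        shape = (num_pages, h_kv, page_size, head_dim)
    elif layout == "NHD":
        shape = (num_pages, page_size, h_kv, head_dim)
    else:
        raise ValueError(f"unknown layout: {layout}")
    # Contiguous strides: stride of dim i is the product of all later dim
    # sizes times the element size (innermost dim last).
    strides = [elem_bytes] * 4
    for i in range(2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides[1], strides[2]

# fp16 cache (2 bytes), 16 tokens per page, 8 KV heads, head_dim 128:
print(paged_kv_strides(4, 8, 16, 128, 2, "HND"))  # → (4096, 256)
print(paged_kv_strides(4, 8, 16, 128, 2, "NHD"))  # → (2048, 256)
```

Passing both strides explicitly to the TMA descriptor is what lets the same kernel address either layout without a `transpose(...).contiguous()` copy.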