Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions csrc/fmha_v2/fmha/warpspec/dma.h
Original file line number Diff line number Diff line change
Expand Up @@ -784,13 +784,14 @@ struct DMA {
uint32_t tensor_size_v[4] = {dv, tokens_per_block, h_kv, INT_MAX};

uint64_t tensor_stride_k[3];
tensor_stride_k[0] = params.k_stride_in_bytes / tokens_per_block; // d
tensor_stride_k[1] = params.k_stride_in_bytes; // d * 64
tensor_stride_k[0] = params.k_stride_in_bytes;
tensor_stride_k[1] = params.k_stride_in_bytes_2;
tensor_stride_k[2] = params.paged_kv_cache.mBytesPerBlock;
uint64_t tensor_stride_v[3];
// we cannot use dv * Kernel_traits::ELEMENT_BYTES because V may be padded (MLA)
tensor_stride_v[0] = params.v_stride_in_bytes / tokens_per_block; // dv
tensor_stride_v[1] = params.v_stride_in_bytes; // dv * 64
// use the values given by caller
tensor_stride_v[0] = params.v_stride_in_bytes;
tensor_stride_v[1] = params.v_stride_in_bytes_2;
tensor_stride_v[2] = params.paged_kv_cache.mBytesPerBlock;

char* kv_ptr = reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr);
Expand Down
7 changes: 7 additions & 0 deletions csrc/fmha_v2/fused_multihead_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,13 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba
int64_t q_stride_in_bytes;
int64_t k_stride_in_bytes;
int64_t v_stride_in_bytes;
// Paged KV uses a 4D tensor; the tensor size is:
// HND = [num_pages, H, page_size, D] or NHD = [num_pages, page_size, H, D],
// so another pair of strides is needed.
// x_stride_in_bytes is the stride of tensor_size[1];
// x_stride_in_bytes_2 is the stride of tensor_size[2].
int64_t k_stride_in_bytes_2;
int64_t v_stride_in_bytes_2;
Comment thread
zhou-yuxin marked this conversation as resolved.

// Paged KV load.
int blocks_per_tma_load;
Expand Down
66 changes: 46 additions & 20 deletions csrc/fmha_v2_run.cu
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ static inline void set_params(
// types
Data_type data_type, Data_type acc_type, Data_type output_dtype,
// attention input layout
Attention_input_layout input_layout,
Attention_input_layout input_layout, const bool is_paged_hnd,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Migrate the non-warp-specialized paged-KV reader before repurposing these stride fields.

This change turns k_stride_in_bytes / v_stride_in_bytes into generic 4-D layout strides, but csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h, Lines 693-700 still interprets those fields with the old per-head-stride contract and derives token stride via >> paged_kv_log2_block_size_. That leaves non-TMA Q_PAGED_KV kernels computing wrong K/V addresses for both HND and NHD. Please update that path to consume *_stride_in_bytes_2 as well, or keep the old meaning until both readers are migrated.

Also applies to: 122-136

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/fmha_v2_run.cu` at line 53, The non-warp-specialized paged-KV reader in
fmha/gmem_tile_qkv_packed.h still assumes k_stride_in_bytes / v_stride_in_bytes
are per-head strides (and computes token stride via >>
paged_kv_log2_block_size_), but the change in Attention_input_layout repurposed
these into generic 4-D layout strides; update the reader to consume the new
secondary fields (k_stride_in_bytes_2 and v_stride_in_bytes_2) when calculating
token and head offsets (or alternatively retain the old per-head semantics for
k_stride_in_bytes/v_stride_in_bytes until both readers are migrated), making the
logic in the non-warp Q_PAGED_KV path consistent with the new layout semantics
and matching how the warp-specialized reader uses *_stride_in_bytes_2.

// sizes
const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv,
const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads,
Expand Down Expand Up @@ -119,8 +119,21 @@ static inline void set_params(
get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type),
paged_kv_pool_ptr);
params.paged_kv_cache.mBlockOffsets = paged_block_offsets;
params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type);
params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type);
// FMHA kernels always access the K/V tensor in the 4D coordinate [num_pages, H_kv, page_size, D].
// The HND or NHD layout is implemented via tensor strides to obtain the correct memory address.
// 4D tensor strides for HND: [block_size, page_size * D, D, 1];
// 4D tensor strides for NHD: [block_size, D, H_kv * D, 1].
if (is_paged_hnd) {
params.k_stride_in_bytes = get_size_in_bytes(d, data_type);
params.v_stride_in_bytes = get_size_in_bytes(dv, data_type);
params.k_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * d, data_type);
params.v_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * dv, data_type);
} else {
params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type);
params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type);
params.k_stride_in_bytes_2 = get_size_in_bytes(d, data_type);
params.v_stride_in_bytes_2 = get_size_in_bytes(dv, data_type);
}
} else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) {
// Layout [B, S, H_kv, D].
params.k_ptr = k_d;
Expand Down Expand Up @@ -252,7 +265,9 @@ static inline void determine_launch_params(
// threshold for adopting flash attention or warp_specialized kernels.
launch_params.flash_attention =
(data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
(s >= 16 && d >= 16) && !force_non_flash_attention;
// When s < 16, non-flash attention is faster,
// but flashinfer only generates flash attention kernels.
(/*s >= 16 &&*/ d >= 16) && !force_non_flash_attention;

// enable warp_specialized kernels when s >= 512 on hopper
// note that warp_specialized kernels need flash attention + tma
Expand Down Expand Up @@ -306,11 +321,18 @@ static inline Attention_mask_type string_to_mask_type(const std::string& s) {
return Attention_mask_type::CAUSAL; // default
}

static inline Attention_input_layout string_to_input_layout(const std::string& s) {
static inline Attention_input_layout string_to_input_layout(const std::string& s,
bool& is_paged_hnd) {
if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV;
if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV;
if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV;
if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV;
if (s == "q_paged_kv_nhd") {
is_paged_hnd = false;
return Attention_input_layout::Q_PAGED_KV;
}
if (s == "q_paged_kv_hnd") {
is_paged_hnd = true;
return Attention_input_layout::Q_PAGED_KV;
}
if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V;
throw std::invalid_argument("Unsupported input_layout: " + s);
Comment thread
coderabbitai[bot] marked this conversation as resolved.
}
Expand All @@ -333,7 +355,8 @@ void fmha_v2_run(
ffi::TensorView scale_bmm2_d, // Pre-populated scale_bmm2 on device [1] int32
Optional<ffi::TensorView> softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum)
Optional<ffi::TensorView> sinks) {
Attention_input_layout input_layout = string_to_input_layout(input_layout_str);
bool is_paged_hnd;
Attention_input_layout input_layout = string_to_input_layout(input_layout_str, is_paged_hnd);
Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str);
Data_type output_dtype = dltype_to_data_type(o.dtype());
// Get device properties
Expand Down Expand Up @@ -363,9 +386,12 @@ void fmha_v2_run(
d = q.shape()[3]; // head_dim_qk
dv = q.shape()[3]; // head_dim_v (same as d for standard attention)
} else if (input_layout == Attention_input_layout::Q_PAGED_KV) {
// q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D]
// q is 3D: [total_tokens, H, D]
h = q.shape()[1];
h_kv = k.shape()[1];
// k/v are 4D paged:
// HND: [num_pages, H_kv, page_size, D]
// NHD: [num_pages, page_size, H_kv, D]
h_kv = k.shape()[is_paged_hnd ? 1 : 2];
d = q.shape()[2];
dv = v.shape()[3];
} else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) {
Expand Down Expand Up @@ -578,16 +604,16 @@ void fmha_v2_run(
}

bert::Fused_multihead_attention_params_v2 params_v2;
set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s,
h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size,
// Paged kv cache.
tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr,
kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d,
static_cast<void*>(cum_seq_lens_kv.data_ptr()),
static_cast<void*>(cum_seq_lens_q.data_ptr()), o.data_ptr(), nullptr, nullptr,
softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, scale_softmax, scale_bmm2,
softcapping_scale_bmm1, false, false, false, has_alibi,
skip_softmax_threshold_scale_factor);
set_params(
params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, is_paged_hnd, b,
s_q, s, h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size,
// Paged kv cache.
tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr,
kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d,
static_cast<void*>(cum_seq_lens_kv.data_ptr()), static_cast<void*>(cum_seq_lens_q.data_ptr()),
o.data_ptr(), nullptr, nullptr, softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1,
scale_softmax, scale_bmm2, softcapping_scale_bmm1, false, false, false, has_alibi,
skip_softmax_threshold_scale_factor);

// For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride
// that we used when expanding the block offsets from [B, M] to [B, 2, M]
Expand Down
10 changes: 5 additions & 5 deletions flashinfer/prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -4148,9 +4148,7 @@ def trtllm_fmha_v2_prefill(
raise ValueError(
f"Q_PAGED_KV_NHD expects paged_KV shape [pages, 2, page_size, num_kv_heads, head_dim], got {tuple(paged_kv.shape)}"
)
# TODO: implement native NHD support in the kernel to avoid this transpose
kv_cache = paged_kv.transpose(-3, -2).contiguous()
k_cache, v_cache = kv_cache.unbind(dim=1)
k_cache, v_cache = paged_kv.unbind(dim=1)
elif input_layout == "Q_PAGED_KV_HND":
assert isinstance(qkv, tuple)
query, paged_kv = qkv[0], qkv[1]
Expand Down Expand Up @@ -4184,9 +4182,11 @@ def trtllm_fmha_v2_prefill(
page_size = 0 # Not applicable for packed layouts
head_dim_v = query.shape[3] # Assume same as head_dim_qk
elif input_layout in ("Q_PAGED_KV_NHD", "Q_PAGED_KV_HND"):
# Q is 3D: [tokens, H, D], Paged KV (HND after any transpose): [num_pages, H_kv, page_size, D]
# Q is 3D: [tokens, H, D]
num_qo_heads = query.shape[1]
page_size = k_cache.shape[2]
# Paged KV NHD: [num_pages, page_size, H_kv, D]
# Paged KV HND: [num_pages, H_kv, page_size, D]
page_size = k_cache.shape[1 if "NHD" in input_layout else 2]
head_dim_v = v_cache.shape[3]
elif input_layout == "CONTIGUOUS_Q_KV":
# Q is 3D: [tokens, H, D], KV is 4D: [tokens, 2, H_kv, D]
Expand Down
Loading