diff --git a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h index 00797d0a01..6e105107fd 100644 --- a/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h +++ b/csrc/fmha_v2/fmha/gmem_tile_qkv_packed.h @@ -879,12 +879,14 @@ struct Gmem_tile_paged_kv { // Do not load/store if the thread is in the padded area col_in_bytes_ = cta_col_offset_in_bytes + col * BYTES_PER_LDG; - int64_t kv_stride_in_bytes = - qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes; - // The head offset. - head_stride_in_bytes_ = (int64_t)(binfo.bidh / params.h_q_per_kv) * kv_stride_in_bytes; - // When V is padded (like MLA), we cannot use VALID_BYTES_PER_ROW - token_stride_in_bytes_ = kv_stride_in_bytes >> paged_kv_log2_block_size_; + // The head stride in bytes. + int64_t head_stride_in_bytes = + qkv_offset == 1 ? params.k_stride_in_bytes_2 : params.v_stride_in_bytes_2; + // The head offset in bytes. + head_offset_in_bytes_ = (binfo.bidh / params.h_q_per_kv) * head_stride_in_bytes; + + // The token stride in bytes. + token_stride_in_bytes_ = qkv_offset == 1 ? params.k_stride_in_bytes : params.v_stride_in_bytes; // Take the CTA offset to modify the sequence length. // Actually we don't need that for flash attention. @@ -908,7 +910,7 @@ struct Gmem_tile_paged_kv { void const* ptrs[LDGS]; // Offset for the new paged kv pointer. - uint64_t const head_col_in_bytes = head_stride_in_bytes_ + col_in_bytes_; + uint64_t const head_col_in_bytes = head_offset_in_bytes_ + col_in_bytes_; // Update paged_kv ptr for each LDG (reuse is possible). #pragma unroll @@ -966,9 +968,9 @@ struct Gmem_tile_paged_kv { int row_; int64_t col_in_bytes_; // Keep track of the head offset. - int64_t head_stride_in_bytes_; + int64_t head_offset_in_bytes_; // // for DeepSeek MLA, the stride of V tokens != VALID_BYTES_PER_ROW - int32_t token_stride_in_bytes_; + int64_t token_stride_in_bytes_; // The sequence length. 
int actual_seqlen_; // The past sequence length (kv_seqlen - q_seqlen) considering chunked context. diff --git a/csrc/fmha_v2/fmha/warpspec/dma.h b/csrc/fmha_v2/fmha/warpspec/dma.h index 6934087270..2d0ac446a7 100644 --- a/csrc/fmha_v2/fmha/warpspec/dma.h +++ b/csrc/fmha_v2/fmha/warpspec/dma.h @@ -784,13 +784,14 @@ struct DMA { uint32_t tensor_size_v[4] = {dv, tokens_per_block, h_kv, INT_MAX}; uint64_t tensor_stride_k[3]; - tensor_stride_k[0] = params.k_stride_in_bytes / tokens_per_block; // d - tensor_stride_k[1] = params.k_stride_in_bytes; // d * 64 + tensor_stride_k[0] = params.k_stride_in_bytes; + tensor_stride_k[1] = params.k_stride_in_bytes_2; tensor_stride_k[2] = params.paged_kv_cache.mBytesPerBlock; uint64_t tensor_stride_v[3]; // we cannot use dv * Kernel_traits::ELEMENT_BYTES because V may be padded (MLA) - tensor_stride_v[0] = params.v_stride_in_bytes / tokens_per_block; // dv - tensor_stride_v[1] = params.v_stride_in_bytes; // dv * 64 + // use the values given by caller + tensor_stride_v[0] = params.v_stride_in_bytes; + tensor_stride_v[1] = params.v_stride_in_bytes_2; tensor_stride_v[2] = params.paged_kv_cache.mBytesPerBlock; char* kv_ptr = reinterpret_cast<char*>(params.paged_kv_cache.mPoolPtr); diff --git a/csrc/fmha_v2/fused_multihead_attention.h b/csrc/fmha_v2/fused_multihead_attention.h index 7049103d7f..b180889890 100644 --- a/csrc/fmha_v2/fused_multihead_attention.h +++ b/csrc/fmha_v2/fused_multihead_attention.h @@ -237,6 +237,13 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba int64_t q_stride_in_bytes; int64_t k_stride_in_bytes; int64_t v_stride_in_bytes; + // Paged KV uses 4D tensor, the tensor size is: + // HND = [num_pages, H, page_size, D] or NHD = [num_pages, page_size, H, D] + // so need another pair of stride. + // x_stride_in_bytes means the stride of tensor_size[1] + // x_stride_in_bytes_2 means the stride of tensor_size[2] + int64_t k_stride_in_bytes_2; + int64_t v_stride_in_bytes_2; // Paged KV load. 
int blocks_per_tma_load; diff --git a/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h b/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h index 62294e2f0a..3b2db8ccbf 100644 --- a/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h +++ b/csrc/fmha_v2/fused_multihead_attention_demo_bert_params.h @@ -177,4 +177,12 @@ struct Fused_multihead_attention_params_v2 { uint32_t* skip_softmax_total_blocks; uint32_t* skip_softmax_skipped_blocks; #endif + + // Paged KV uses 4D tensor, the tensor size is: + // HND = [num_pages, H, page_size, D] or NHD = [num_pages, page_size, H, D] + // so need another pair of stride. + // x_stride_in_bytes means the stride of tensor_size[1] + // x_stride_in_bytes_2 means the stride of tensor_size[2] + int64_t k_stride_in_bytes_2; + int64_t v_stride_in_bytes_2; }; diff --git a/csrc/fmha_v2_run.cu b/csrc/fmha_v2_run.cu index 3dfff9b967..ef0aef295e 100644 --- a/csrc/fmha_v2_run.cu +++ b/csrc/fmha_v2_run.cu @@ -50,7 +50,7 @@ static inline void set_params( // types Data_type data_type, Data_type acc_type, Data_type output_dtype, // attention input layout - Attention_input_layout input_layout, + Attention_input_layout input_layout, const bool is_paged_hnd, // sizes const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads, @@ -119,8 +119,21 @@ static inline void set_params( get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type), paged_kv_pool_ptr); params.paged_kv_cache.mBlockOffsets = paged_block_offsets; - params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type); - params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type); + // FMHA kernels always access the K/V tensor in 4D coordinate [num_pages, H_kv, page_size, D]. + // The layout of HND or NHD is implemented by tensor strides to get the correct memory + // address. 
4D tensor strides of HND: [block_size, page_size * D, D ,1] 4D tensor strides of + // NHD: [block_size, D, H_kv * D, 1] + if (is_paged_hnd) { + params.k_stride_in_bytes = get_size_in_bytes(d, data_type); + params.v_stride_in_bytes = get_size_in_bytes(dv, data_type); + params.k_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * d, data_type); + params.v_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * dv, data_type); + } else { + params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type); + params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type); + params.k_stride_in_bytes_2 = get_size_in_bytes(d, data_type); + params.v_stride_in_bytes_2 = get_size_in_bytes(dv, data_type); + } } else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) { // Layout [B, S, H_kv, D]. params.k_ptr = k_d; @@ -249,10 +262,15 @@ static inline void determine_launch_params( launch_params.multi_processor_count = props.multiProcessorCount; launch_params.device_l2_cache_size = props.l2CacheSize; +#if 0 // threshold for adopting flash attention or warp_specialized kernels. 
launch_params.flash_attention = (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) && (s >= 16 && d >= 16) && !force_non_flash_attention; +#else + // Currently only flash attention kernels are generated in FlashInfer + launch_params.flash_attention = true; +#endif // enable warp_speialized kernels when s >= 512 on hopper // note that warp_speialized kernels need flash attention + tma @@ -306,11 +324,18 @@ static inline Attention_mask_type string_to_mask_type(const std::string& s) { return Attention_mask_type::CAUSAL; // default } -static inline Attention_input_layout string_to_input_layout(const std::string& s) { +static inline Attention_input_layout string_to_input_layout(const std::string& s, + bool& is_paged_hnd) { + is_paged_hnd = false; if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV; if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV; - if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV; - if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV; + if (s == "q_paged_kv_nhd") { + return Attention_input_layout::Q_PAGED_KV; + } + if (s == "q_paged_kv_hnd") { + is_paged_hnd = true; + return Attention_input_layout::Q_PAGED_KV; + } if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V; throw std::invalid_argument("Unsupported input_layout: " + s); } @@ -333,7 +358,8 @@ void fmha_v2_run( ffi::TensorView scale_bmm2_d, // Pre-populated scale_bmm2 on device [1] int32 Optional softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum) Optional sinks) { - Attention_input_layout input_layout = string_to_input_layout(input_layout_str); + bool is_paged_hnd; + Attention_input_layout input_layout = string_to_input_layout(input_layout_str, is_paged_hnd); Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str); Data_type output_dtype = dltype_to_data_type(o.dtype()); // Get device properties @@ -363,9 +389,12 @@ void 
fmha_v2_run( d = q.shape()[3]; // head_dim_qk dv = q.shape()[3]; // head_dim_v (same as d for standard attention) } else if (input_layout == Attention_input_layout::Q_PAGED_KV) { - // q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D] + // q is 3D: [total_tokens, H, D] h = q.shape()[1]; - h_kv = k.shape()[1]; + // k/v are 4D paged: + // HND: [num_pages, H_kv, page_size, D] + // NHD: [num_pages, page_size, H_kv, D] + h_kv = k.shape()[is_paged_hnd ? 1 : 2]; d = q.shape()[2]; dv = v.shape()[3]; } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) { @@ -578,16 +607,16 @@ void fmha_v2_run( } bert::Fused_multihead_attention_params_v2 params_v2; - set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s, - h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size, - // Paged kv cache. - tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, - kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d, - static_cast(cum_seq_lens_kv.data_ptr()), - static_cast(cum_seq_lens_q.data_ptr()), o.data_ptr(), nullptr, nullptr, - softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, scale_softmax, scale_bmm2, - softcapping_scale_bmm1, false, false, false, has_alibi, - skip_softmax_threshold_scale_factor); + set_params( + params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, is_paged_hnd, b, + s_q, s, h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size, + // Paged kv cache. 
+ tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, + kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d, + static_cast(cum_seq_lens_kv.data_ptr()), static_cast(cum_seq_lens_q.data_ptr()), + o.data_ptr(), nullptr, nullptr, softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, + scale_softmax, scale_bmm2, softcapping_scale_bmm1, false, false, false, has_alibi, + skip_softmax_threshold_scale_factor); // For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride // that we used when expanding the block offsets from [B, M] to [B, 2, M] diff --git a/flashinfer/prefill.py b/flashinfer/prefill.py index 4ec6a29e7d..7a647bb86c 100755 --- a/flashinfer/prefill.py +++ b/flashinfer/prefill.py @@ -4361,9 +4361,7 @@ def trtllm_fmha_v2_prefill( raise ValueError( f"Q_PAGED_KV_NHD expects paged_KV shape [pages, 2, page_size, num_kv_heads, head_dim], got {tuple(paged_kv.shape)}" ) - # TODO: implement native NHD support in the kernel to avoid this transpose - kv_cache = paged_kv.transpose(-3, -2).contiguous() - k_cache, v_cache = kv_cache.unbind(dim=1) + k_cache, v_cache = paged_kv.unbind(dim=1) elif input_layout == "Q_PAGED_KV_HND": assert isinstance(qkv, tuple) query, paged_kv = qkv[0], qkv[1] @@ -4397,9 +4395,11 @@ def trtllm_fmha_v2_prefill( page_size = 0 # Not applicable for packed layouts head_dim_v = query.shape[3] # Assume same as head_dim_qk elif input_layout in ("Q_PAGED_KV_NHD", "Q_PAGED_KV_HND"): - # Q is 3D: [tokens, H, D], Paged KV (HND after any transpose): [num_pages, H_kv, page_size, D] + # Q is 3D: [tokens, H, D] num_qo_heads = query.shape[1] - page_size = k_cache.shape[2] + # Paged KV NHD: [num_pages, page_size, H_kv, D] + # Paged KV HND: [num_pages, H_kv, page_size, D] + page_size = k_cache.shape[1 if "NHD" in input_layout else 2] head_dim_v = v_cache.shape[3] elif input_layout == "CONTIGUOUS_Q_KV": # Q is 3D: [tokens, H, D], KV is 4D: [tokens, 2, H_kv, D]