-
Notifications
You must be signed in to change notification settings - Fork 920
[fmha-v2] Support HND and NHD paged KV cache layouts with conditional stride handling #2799
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
67d6fbd
5846873
1299a6d
115a01e
02ae0b7
abd10a6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -50,7 +50,7 @@ static inline void set_params( | |
| // types | ||
| Data_type data_type, Data_type acc_type, Data_type output_dtype, | ||
| // attention input layout | ||
| Attention_input_layout input_layout, | ||
| Attention_input_layout input_layout, const bool is_paged_hnd, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Migrate the non-warp-specialized paged-KV reader before repurposing these stride fields. This change turns [comment truncated in page capture — see original review thread]. Also applies to: 122-136. 🤖 Prompt for AI Agents |
||
| // sizes | ||
| const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, | ||
| const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads, | ||
|
|
@@ -119,8 +119,21 @@ static inline void set_params( | |
| get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type), | ||
| paged_kv_pool_ptr); | ||
| params.paged_kv_cache.mBlockOffsets = paged_block_offsets; | ||
| params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type); | ||
| params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type); | ||
| // FMHA kernels always access the K/V tensor in 4D coordinate [num_pages, H_kv, page_size, D]. | ||
| // The layout of HND or NHD is implemented by tensor strides to get the correct memory | ||
| // address. 4D tensor strides of HND: [block_size, page_size * D, D ,1] 4D tensor strides of | ||
| // NHD: [block_size, D, H_kv * D, 1] | ||
| if (is_paged_hnd) { | ||
| params.k_stride_in_bytes = get_size_in_bytes(d, data_type); | ||
| params.v_stride_in_bytes = get_size_in_bytes(dv, data_type); | ||
| params.k_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * d, data_type); | ||
| params.v_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * dv, data_type); | ||
| } else { | ||
| params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type); | ||
| params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type); | ||
| params.k_stride_in_bytes_2 = get_size_in_bytes(d, data_type); | ||
| params.v_stride_in_bytes_2 = get_size_in_bytes(dv, data_type); | ||
| } | ||
| } else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) { | ||
| // Layout [B, S, H_kv, D]. | ||
| params.k_ptr = k_d; | ||
|
|
@@ -252,7 +265,9 @@ static inline void determine_launch_params( | |
| // threshold for adopting flash attention or warp_specialized kernels. | ||
| launch_params.flash_attention = | ||
| (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) && | ||
| (s >= 16 && d >= 16) && !force_non_flash_attention; | ||
| // when s < 16, non-flash attention is faster, | ||
| // but in flashinfer only flash attention kernels are generated | ||
| (/*s >= 16 &&*/ d >= 16) && !force_non_flash_attention; | ||
|
|
||
| // enable warp_specialized kernels when s >= 512 on hopper | ||
| // note that warp_specialized kernels need flash attention + tma | ||
|
|
@@ -306,11 +321,18 @@ static inline Attention_mask_type string_to_mask_type(const std::string& s) { | |
| return Attention_mask_type::CAUSAL; // default | ||
| } | ||
|
|
||
| static inline Attention_input_layout string_to_input_layout(const std::string& s) { | ||
| static inline Attention_input_layout string_to_input_layout(const std::string& s, | ||
| bool& is_paged_hnd) { | ||
| if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV; | ||
| if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV; | ||
| if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV; | ||
| if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV; | ||
| if (s == "q_paged_kv_nhd") { | ||
| is_paged_hnd = false; | ||
| return Attention_input_layout::Q_PAGED_KV; | ||
| } | ||
| if (s == "q_paged_kv_hnd") { | ||
| is_paged_hnd = true; | ||
| return Attention_input_layout::Q_PAGED_KV; | ||
| } | ||
| if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V; | ||
| throw std::invalid_argument("Unsupported input_layout: " + s); | ||
|
coderabbitai[bot] marked this conversation as resolved.
|
||
| } | ||
|
|
@@ -333,7 +355,8 @@ void fmha_v2_run( | |
| ffi::TensorView scale_bmm2_d, // Pre-populated scale_bmm2 on device [1] int32 | ||
| Optional<ffi::TensorView> softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum) | ||
| Optional<ffi::TensorView> sinks) { | ||
| Attention_input_layout input_layout = string_to_input_layout(input_layout_str); | ||
| bool is_paged_hnd; | ||
| Attention_input_layout input_layout = string_to_input_layout(input_layout_str, is_paged_hnd); | ||
| Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str); | ||
| Data_type output_dtype = dltype_to_data_type(o.dtype()); | ||
| // Get device properties | ||
|
|
@@ -363,9 +386,12 @@ void fmha_v2_run( | |
| d = q.shape()[3]; // head_dim_qk | ||
| dv = q.shape()[3]; // head_dim_v (same as d for standard attention) | ||
| } else if (input_layout == Attention_input_layout::Q_PAGED_KV) { | ||
| // q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D] | ||
| // q is 3D: [total_tokens, H, D] | ||
| h = q.shape()[1]; | ||
| h_kv = k.shape()[1]; | ||
| // k/v are 4D paged: | ||
| // HND: [num_pages, H_kv, page_size, D] | ||
| // NHD: [num_pages, page_size, H_kv, D] | ||
| h_kv = k.shape()[is_paged_hnd ? 1 : 2]; | ||
| d = q.shape()[2]; | ||
| dv = v.shape()[3]; | ||
| } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) { | ||
|
|
@@ -578,16 +604,16 @@ void fmha_v2_run( | |
| } | ||
|
|
||
| bert::Fused_multihead_attention_params_v2 params_v2; | ||
| set_params(params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, b, s_q, s, | ||
| h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size, | ||
| // Paged kv cache. | ||
| tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, | ||
| kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d, | ||
| static_cast<void*>(cum_seq_lens_kv.data_ptr()), | ||
| static_cast<void*>(cum_seq_lens_q.data_ptr()), o.data_ptr(), nullptr, nullptr, | ||
| softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, scale_softmax, scale_bmm2, | ||
| softcapping_scale_bmm1, false, false, false, has_alibi, | ||
| skip_softmax_threshold_scale_factor); | ||
| set_params( | ||
| params_v2, launch_params, data_type, acc_type, output_dtype, input_layout, is_paged_hnd, b, | ||
| s_q, s, h, h_kv, d, dv, total, 1, sliding_window_size, chunked_attention_size, | ||
| // Paged kv cache. | ||
| tokens_per_block, qkv_packed_d, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, | ||
| kv_cache_block_offsets_d, packed_mask_d, nullptr, attention_sinks_d, | ||
| static_cast<void*>(cum_seq_lens_kv.data_ptr()), static_cast<void*>(cum_seq_lens_q.data_ptr()), | ||
| o.data_ptr(), nullptr, nullptr, softmax_stats_ptr, scale_bmm2_d.data_ptr(), scale_bmm1, | ||
| scale_softmax, scale_bmm2, softcapping_scale_bmm1, false, false, false, has_alibi, | ||
| skip_softmax_threshold_scale_factor); | ||
|
|
||
| // For Q_PAGED_KV layout, override mMaxBlocksPerSeq to match the actual block_tables stride | ||
| // that we used when expanding the block offsets from [B, M] to [B, 2, M] | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.