@@ -50,7 +50,7 @@ static inline void set_params(
     // types
     Data_type data_type, Data_type acc_type, Data_type output_dtype,
     // attention input layout
-    Attention_input_layout input_layout,
+    Attention_input_layout input_layout, const bool is_paged_hnd,
     // sizes
     const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv,
     const size_t d, const size_t dv, const size_t total, const size_t num_grouped_heads,
@@ -119,8 +119,21 @@ static inline void set_params(
         get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type),
         paged_kv_pool_ptr);
     params.paged_kv_cache.mBlockOffsets = paged_block_offsets;
-    params.k_stride_in_bytes = get_size_in_bytes(tokens_per_block * d, data_type);
-    params.v_stride_in_bytes = get_size_in_bytes(tokens_per_block * dv, data_type);
+    // FMHA kernels always access the K/V tensor with the 4D coordinate [num_pages, H_kv,
+    // page_size, D]. HND vs. NHD is handled purely through tensor strides, so the same code
+    // computes the correct memory address for either layout. 4D strides for HND:
+    // [block_size, page_size * D, D, 1]; for NHD: [block_size, D, H_kv * D, 1].
+    if (is_paged_hnd) {
+      params.k_stride_in_bytes = get_size_in_bytes(d, data_type);
+      params.v_stride_in_bytes = get_size_in_bytes(dv, data_type);
+      params.k_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * d, data_type);
+      params.v_stride_in_bytes_2 = get_size_in_bytes(tokens_per_block * dv, data_type);
+    } else {
+      params.k_stride_in_bytes = get_size_in_bytes(h_kv * d, data_type);
+      params.v_stride_in_bytes = get_size_in_bytes(h_kv * dv, data_type);
+      params.k_stride_in_bytes_2 = get_size_in_bytes(d, data_type);
+      params.v_stride_in_bytes_2 = get_size_in_bytes(dv, data_type);
+    }
   } else if (input_layout == Attention_input_layout::SEPARATE_Q_K_V) {
     // Layout [B, S, H_kv, D].
     params.k_ptr = k_d;
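
Note on the stride pair added in this hunk: in both branches, k_stride_in_bytes reads as the per-token stride and k_stride_in_bytes_2 as the per-KV-head stride, so the kernel-side address math can stay identical for HND and NHD with only the stride values changing. A minimal sketch of that assumed addressing (illustrative names, not the actual kernel code):

    #include <cstddef>

    // Byte offset of a K element at logical coordinate (page, kv_head, token, dim),
    // given the two strides set above plus the per-block and per-element sizes.
    //   HND: k_stride_in_bytes = D * elem_bytes,        k_stride_in_bytes_2 = page_size * D * elem_bytes
    //   NHD: k_stride_in_bytes = H_kv * D * elem_bytes, k_stride_in_bytes_2 = D * elem_bytes
    size_t k_element_offset_bytes(size_t page, size_t kv_head, size_t token, size_t dim,
                                  size_t block_bytes, size_t k_stride_in_bytes,
                                  size_t k_stride_in_bytes_2, size_t elem_bytes) {
      // One expression covers both layouts; only the stride values differ.
      return page * block_bytes + kv_head * k_stride_in_bytes_2 + token * k_stride_in_bytes +
             dim * elem_bytes;
    }
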
@@ -247,10 +260,15 @@ static inline void determine_launch_params(
   launch_params.multi_processor_count = props.multiProcessorCount;
   launch_params.device_l2_cache_size = props.l2CacheSize;

+#if 0
   // threshold for adopting flash attention or warp_specialized kernels.
   launch_params.flash_attention =
       (data_type == DATA_TYPE_FP16 || data_type == DATA_TYPE_BF16 || data_type == DATA_TYPE_E4M3) &&
       (s >= 16 && d >= 16) && !force_non_flash_attention;
+#else
+  // Currently only flash attention kernels are generated in FlashInfer
+  launch_params.flash_attention = true;
+#endif

   // enable warp_speialized kernels when s >= 512 on hopper
   // note that warp_speialized kernels need flash attention + tma
@@ -304,11 +322,18 @@ static inline Attention_mask_type string_to_mask_type(const std::string& s) {
   return Attention_mask_type::CAUSAL;  // default
 }

-static inline Attention_input_layout string_to_input_layout(const std::string& s) {
+static inline Attention_input_layout string_to_input_layout(const std::string& s,
+                                                            bool& is_paged_hnd) {
+  is_paged_hnd = false;
   if (s == "packed_qkv") return Attention_input_layout::PACKED_QKV;
   if (s == "contiguous_q_kv") return Attention_input_layout::CONTIGUOUS_Q_KV;
-  if (s == "q_paged_kv_nhd") return Attention_input_layout::Q_PAGED_KV;
-  if (s == "q_paged_kv_hnd") return Attention_input_layout::Q_PAGED_KV;
+  if (s == "q_paged_kv_nhd") {
+    return Attention_input_layout::Q_PAGED_KV;
+  }
+  if (s == "q_paged_kv_hnd") {
+    is_paged_hnd = true;
+    return Attention_input_layout::Q_PAGED_KV;
+  }
   if (s == "separate_q_k_v") return Attention_input_layout::SEPARATE_Q_K_V;
   throw std::invalid_argument("Unsupported input_layout: " + s);
 }
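
For clarity, a hypothetical call of the revised helper (the layout strings are the ones matched above; everything else here is illustrative):

    bool is_paged_hnd = false;
    Attention_input_layout layout = string_to_input_layout("q_paged_kv_hnd", is_paged_hnd);
    // layout == Attention_input_layout::Q_PAGED_KV and is_paged_hnd is now true;
    // "q_paged_kv_nhd" maps to the same enum value but leaves is_paged_hnd false.
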
@@ -330,7 +355,8 @@ void fmha_v2_run(
     float skip_softmax_threshold_scale_factor,
     Optional<ffi::TensorView> softmax_stats,  // Optional [batch, s_q, num_heads, 2] for (max, sum)
     Optional<ffi::TensorView> sinks) {
-  Attention_input_layout input_layout = string_to_input_layout(input_layout_str);
+  bool is_paged_hnd;
+  Attention_input_layout input_layout = string_to_input_layout(input_layout_str, is_paged_hnd);
   Attention_mask_type attention_mask_type = string_to_mask_type(mask_mode_str);
   Data_type output_dtype = dltype_to_data_type(o.dtype());
   // Get device properties
@@ -360,9 +386,12 @@ void fmha_v2_run(
     d = q.shape()[3];   // head_dim_qk
     dv = q.shape()[3];  // head_dim_v (same as d for standard attention)
   } else if (input_layout == Attention_input_layout::Q_PAGED_KV) {
-    // q is 3D: [total_tokens, H, D], k/v are 4D paged: [num_pages, H_kv, page_size, D]
+    // q is 3D: [total_tokens, H, D]
     h = q.shape()[1];
-    h_kv = k.shape()[1];
+    // k/v are 4D paged:
+    // HND: [num_pages, H_kv, page_size, D]
+    // NHD: [num_pages, page_size, H_kv, D]
+    h_kv = k.shape()[is_paged_hnd ? 1 : 2];
     d = q.shape()[2];
     dv = v.shape()[3];
   } else if (input_layout == Attention_input_layout::CONTIGUOUS_Q_KV) {
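
For reference, the paged-shape handling above can be summarized as a small helper; this is only a sketch (the page_size index is inferred from the two layouts, it is not something this hunk reads):

    #include <array>
    #include <cstdint>

    struct PagedKvDims { int64_t h_kv, page_size, head_dim; };

    // Extract logical dims from a 4D paged K/V shape for either layout.
    //   HND: [num_pages, H_kv, page_size, D]   NHD: [num_pages, page_size, H_kv, D]
    PagedKvDims parse_paged_kv_shape(const std::array<int64_t, 4>& shape, bool is_paged_hnd) {
      return {shape[is_paged_hnd ? 1 : 2],  // H_kv, mirrors k.shape()[is_paged_hnd ? 1 : 2]
              shape[is_paged_hnd ? 2 : 1],  // page_size (inferred complement of the H_kv index)
              shape[3]};                    // D is the innermost dim in both layouts
    }
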