@@ -291,7 +291,6 @@ static inline Data_type dltype_to_data_type(DLDataType dtype) {
291291 return DATA_TYPE_FP16;
292292}
293293
294-
295294static inline Attention_mask_type string_to_mask_type (const std::string& s) {
296295 if (s == " padding" ) return Attention_mask_type::PADDING;
297296 if (s == " causal" ) return Attention_mask_type::CAUSAL;
@@ -309,7 +308,6 @@ static inline Attention_input_layout string_to_input_layout(const std::string& s
309308 return Attention_input_layout::Q_PAGED_KV; // default
310309}
311310
312-
313311void fmha_v2_run (
314312 ffi::TensorView q, // [batch, s_q, num_heads, head_dim]
315313 ffi::TensorView k, // [batch, s_kv, num_kv_heads, head_dim]
@@ -321,12 +319,11 @@ void fmha_v2_run(
321319 ffi::TensorView seq_lens, // [batch]
322320 ffi::TensorView cum_seq_lens_q, // [batch + 1]
323321 ffi::TensorView cum_seq_lens_kv, // [batch + 1]
324- const std::string& input_layout_str,
325- int max_q_len, int max_kv_len, int batch_size, int total_q_tokens,
326- int total_kv_tokens, // Totals from cum_seq_lens (computed in Python)
327- const std::string& mask_mode_str,
328- float scale_softmax, float scale_bmm1, float scale_bmm2, int window_left,
329- int chunked_attention_size, bool has_alibi, float softcapping_scale,
322+ const std::string& input_layout_str, int max_q_len, int max_kv_len, int batch_size,
323+ int total_q_tokens,
324+ int total_kv_tokens, // Totals from cum_seq_lens (computed in Python)
325+ const std::string& mask_mode_str, float scale_softmax, float scale_bmm1, float scale_bmm2,
326+ int window_left, int chunked_attention_size, bool has_alibi, float softcapping_scale,
330327 ffi::TensorView scale_bmm2_d, // Pre-populated scale_bmm2 on device [1] int32
331328 Optional<ffi::TensorView> softmax_stats, // Optional [batch, s_q, num_heads, 2] for (max, sum)
332329 Optional<ffi::TensorView> sinks) {
@@ -473,7 +470,8 @@ void fmha_v2_run(
473470 std::tie (warps_m, warps_n, warps_k) = get_warps (launch_params, sm, data_type, s, b, d, 2 );
474471
475472 // Debug output for warps
476- printf (" DEBUG: get_warps returned warps_m=%zu, warps_n=%zu, warps_k=%zu\n " , warps_m, warps_n, warps_k);
473+ printf (" DEBUG: get_warps returned warps_m=%zu, warps_n=%zu, warps_k=%zu\n " , warps_m, warps_n,
474+ warps_k);
477475 printf (" DEBUG: launch_params: flash_attention=%d, warp_specialization=%d, use_tma=%d\n " ,
478476 launch_params.flash_attention , launch_params.warp_specialization , launch_params.use_tma );
479477 printf (" DEBUG: data_type=%d, sm=%d, s=%zu, d=%zu\n " , int (data_type), sm, s, d);
0 commit comments