@@ -265,7 +265,7 @@ class XeFMHAFwdKernel {
265265 cute::min (seq_len_qo, (blk_q * get<0 >(TileShapeQK{}) + q_offset_sg));
266266
267267 // calc sg level seq_len_kv
268- const int seq_len =
268+ const int sg_seq_len =
269269 CausalMask
270270 ? LocalMask
271271 ? cute::min (
@@ -275,17 +275,39 @@ class XeFMHAFwdKernel {
275275 : cute::min (
276276 seq_len_kv, full_tile_offset + seq_coord + q_sg_tile)
277277 : seq_len_kv;
278- const int k_block0 =
278+ const int sg_k_block0 =
279279 LocalMask
280280 ? cute::max (
281281 seq_coord + full_tile_offset - params.mainloop .local_left ,
282282 0 ) /
283283 get<1 >(TileShapeQK{})
284284 : 0 ;
285- const int k_blocks = cute::ceil_div (seq_len, get<1 >(TileShapeQK{}));
286- const int k_blocks_causal =
287- CausalMask ? (seq_coord + full_tile_offset) / get<1 >(TileShapeQK{})
288- : 0 ;
285+ const int sg_k_blocks =
286+ cute::ceil_div (sg_seq_len, get<1 >(TileShapeQK{}));
287+ const int sg_k_blocks_causal =
288+ CausalMask
289+ ? (seq_coord + full_tile_offset) / get<1 >(TileShapeQK{})
290+ : 0 ;
291+
292+ // The mainloop wraps each K iteration in a workgroup-scoped barrier
293+ // pair, so every subgroup in the workgroup must execute the same
294+ // K-loop trip count. Reduce the per-SG bounds across the WG:
295+ // k_block0 = min across WG (start no later than any SG)
296+ // k_blocks = max across WG (end no earlier than any SG)
297+ // k_blocks_causal = min across WG (turn on causal masking no later
298+ // than any SG needs it)
299+ // Per-element causal / local / remainder masking inside the mainloop
300+ // handles the widened range safely for SGs that didn't need it.
301+ auto wg = sycl::ext::oneapi::this_work_item::get_work_group<3 >();
302+ const int k_block0 = LocalMask
303+ ? sycl::reduce_over_group (wg, sg_k_block0, sycl::minimum<int >{})
304+ : 0 ;
305+ const int k_blocks = (CausalMask || LocalMask)
306+ ? sycl::reduce_over_group (wg, sg_k_blocks, sycl::maximum<int >{})
307+ : sg_k_blocks;
308+ const int k_blocks_causal = CausalMask
309+ ? sycl::reduce_over_group (wg, sg_k_blocks_causal, sycl::minimum<int >{})
310+ : 0 ;
289311
290312 int offset_q = 0 , offset_k = 0 , offset_v = 0 , offset_o = 0 ;
291313 if constexpr (is_var_len) {
@@ -347,7 +369,7 @@ class XeFMHAFwdKernel {
347369 k_blocks,
348370 k_blocks_causal,
349371 thr_id,
350- seq_len ,
372+ sg_seq_len ,
351373 full_tile_offset);
352374
353375 // return softmax_lse
0 commit comments