
Fix workgroup barrier deadlock#312

Draft
frost-intel wants to merge 3 commits into vllm-project:main from frost-intel:flash_attn_xe2_barrier_deadlock

Conversation

@frost-intel

Purpose

Every subgroup in the workgroup must execute the same number of K-loop iterations, because each iteration synchronizes at a workgroup barrier. However, the Xe2 FMHA kernel computes k_block0, k_blocks, and k_blocks_causal from seq_coord, which depends on q_offset_sg, a per-subgroup broadcast of the thread's row coordinate. Under causal masking, different subgroups within the same WG therefore compute different loop bounds and execute different iteration counts, leaving some subgroups stuck at barrier_wait forever.
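The failure mode can be sketched with Python threads standing in for subgroups. A minimal model, not the kernel's code: a threading.Barrier with a timeout stands in for the hardware barrier_wait so the demo reports failure instead of hanging, and the iteration counts are arbitrary illustrative values.

```python
import threading

def run_workgroup(iters_per_sg, timeout=0.5):
    """Each thread (subgroup) hits a shared barrier once per loop iteration.
    Returns True iff every barrier rendezvous succeeded."""
    barrier = threading.Barrier(len(iters_per_sg))
    ok = [True] * len(iters_per_sg)

    def subgroup(tid, iters):
        for _ in range(iters):
            try:
                barrier.wait(timeout=timeout)
            except threading.BrokenBarrierError:
                # On real hardware this is the hang: a subgroup waiting
                # for peers that have already exited the loop.
                ok[tid] = False
                return

    threads = [threading.Thread(target=subgroup, args=(i, n))
               for i, n in enumerate(iters_per_sg)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return all(ok)

print(run_workgroup([6, 6]))  # uniform trip counts: all barriers complete
print(run_workgroup([6, 8]))  # mismatched trip counts: barrier never completes
```

With uniform trip counts every rendezvous completes; with mismatched counts the longer-running thread waits at a barrier its peer will never reach.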

As a concrete example, take (seq_q=129, seq_k=463) with causal masking, TileQ=128, tile_k=64:

SG0 (row 0) -> k_blocks = 6
SG7 (row 112) -> k_blocks = 8

I'm fairly confident that this is the root cause of the hang in the PVC flash_attn kernels. I'm not sure why this hasn't been an issue on BMG. However, I'm not an expert, so I'd welcome any feedback on this solution.

This fix computes tight per-SG bounds as before, then reduces across the workgroup to obtain uniform loop bounds. This resolves the hang in the xe_2 FMHA chunk-prefill kernel that occurred under causal masking with short variable-length q, and under sliding-window (local) masking. These are the cases currently guarded by SKIP_HANG_KERNEL in tests/flash_attn/test_flash_attn_varlen_func.py.
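The arithmetic behind the divergent bounds, and the reduction that fixes it, can be modeled in a few lines. This is an illustrative sketch, not the kernel's code: it assumes each subgroup owns 16 query rows and takes its tight causal bound from its last row, and it assumes a max-reduction is the right way to make the upper bound uniform; the row split and function names are hypothetical.

```python
import math

def per_sg_k_blocks(sg_first_row, rows_per_sg, seq_q, seq_k, tile_k):
    # Tight causal bound for one subgroup: its last row attends to keys
    # 0 .. (row + seq_k - seq_q), clamped to seq_k.
    last_row = sg_first_row + rows_per_sg - 1
    visible_keys = min(last_row + 1 + (seq_k - seq_q), seq_k)
    return math.ceil(visible_keys / tile_k)

# The (seq_q=129, seq_k=463) causal case with TileQ=128, tile_k=64,
# modeled as 8 subgroups x 16 rows:
bounds = [per_sg_k_blocks(sg * 16, 16, 129, 463, 64) for sg in range(8)]
print(bounds)  # non-uniform trip counts, e.g. SG0 -> 6, SG7 -> 8

# The fix: reduce across the workgroup so every SG runs the same loop
# (each SG still masks out blocks beyond its own tight bound).
uniform_k_blocks = max(bounds)
print(uniform_k_blocks)  # 8
```

Under this model every subgroup now iterates uniform_k_blocks times and reaches every barrier, at the cost of a few masked-out extra iterations for the earlier rows.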

Test Plan

pytest -v -s tests/flash_attn/
Note: this no longer requires SKIP_HANG_KERNEL=1.

Test Result

All tests pass, no hang.

Signed-off-by: frost-intel <frost.mitchell@intel.com>
@frost-intel force-pushed the flash_attn_xe2_barrier_deadlock branch from e08162f to a05baa8 on April 24, 2026 at 17:42
@frost-intel
Author

@YizhouZ Can you review this?

@jikunshang
Collaborator

cc @xuechendi
