Fix workgroup barrier deadlock #312
Draft
frost-intel wants to merge 3 commits into vllm-project:main from
Conversation
Signed-off-by: frost-intel <frost.mitchell@intel.com>
e08162f to a05baa8
Author
@YizhouZ Can you review this?
Collaborator
cc @xuechendi
Purpose
Every subgroup in the workgroup should execute the same number of K-loops. The Xe2 FMHA kernel computes k_block0, k_blocks, and k_blocks_causal from seq_coord, which depends on q_offset_sg, a per-subgroup broadcast of the thread's row coordinate. Under causal masking, different subgroups within the same workgroup therefore compute different loop bounds and execute different iteration counts, leaving some subgroups stuck at barrier_wait forever.

As a concrete example, for (seq_q=129, seq_k=463) with causal masking, TileQ=128, and tile_k=64:

SG0 (row 0) -> k_blocks = 6
SG7 (row 112) -> k_blocks = 8

I'm fairly confident that this is the root cause of the hang in the PVC flash_attn kernels. I'm not sure why this hasn't been an issue on BMG. However, I'm not an expert, so I'd welcome any feedback on this solution.
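For illustration only, here is a standalone sketch (not the kernel's actual code) of how causal loop bounds can diverge across subgroups. It assumes 8 subgroups covering the TileQ=128 rows (16 rows per SG) and a simple ceil-divide bound, and it reproduces the 6 vs. 8 counts above:

```cpp
// Illustrative sketch, not the kernel's exact bound computation.
// Assumes 8 subgroups, 16 query rows per subgroup, and a causal bound of
// (seq_k - seq_q) + q_row + rows_per_sg visible keys per subgroup.
#include <cstdio>

int main() {
  const int seq_q = 129, seq_k = 463;
  const int tile_q = 128, tile_k = 64, num_sgs = 8;
  const int rows_per_sg = tile_q / num_sgs;   // 16 rows per subgroup (assumption)
  const int causal_offset = seq_k - seq_q;    // keys visible to query row 0

  for (int sg = 0; sg < num_sgs; ++sg) {
    int q_row = sg * rows_per_sg;                                 // per-SG row coordinate
    int visible_keys = causal_offset + q_row + rows_per_sg;       // last key this SG may touch
    int k_blocks = (visible_keys + tile_k - 1) / tile_k;          // ceil-divide into K blocks
    std::printf("SG%d (row %d) -> k_blocks = %d\n", sg, q_row, k_blocks);
  }
  return 0;
}
// Prints SG0 (row 0) -> k_blocks = 6 ... SG7 (row 112) -> k_blocks = 8,
// so subgroups in the same workgroup would run different K-loop counts.
```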
This fix computes tight per-SG bounds as before, then reduces across the workgroup to obtain uniform bounds for the loop. This change resolves the hang in the xe_2 FMHA chunk-prefill kernel that occurred under causal masking with short variable-length q, and under sliding-window (local) masking. These are the cases currently guarded by SKIP_HANG_KERNEL in tests/flash_attn/test_flash_attn_varlen_func.py.
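As a rough sketch of the idea, assuming a SYCL-style kernel (the actual implementation may use the kernel's own reduction primitives and different variable names), the per-SG bounds can be reduced to workgroup-uniform bounds like this:

```cpp
// Minimal sketch of the approach: keep the tight per-subgroup bounds, then
// agree on workgroup-uniform bounds so every subgroup runs the same number of
// K-loop iterations and all of them reach each barrier. The function name and
// calling context are hypothetical; variable names follow the PR description.
#include <sycl/sycl.hpp>

template <typename Group>
void make_k_bounds_uniform(Group wg, int &k_block0, int &k_blocks) {
  // Earliest start and latest end across all subgroups in the workgroup.
  k_block0 = sycl::reduce_over_group(wg, k_block0, sycl::minimum<int>());
  k_blocks = sycl::reduce_over_group(wg, k_blocks, sycl::maximum<int>());
  // Iterations outside a subgroup's tight bounds are fully masked, so the
  // extra passes cost some redundant work but keep the barriers collective.
}
```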
Test Plan
pytest -v -s tests/flash_attn/

Note: this no longer requires SKIP_HANG_KERNEL=1.
Test Result
All tests pass, no hang.