
Commit e08162f

Fix workgroup barrier deadlock

1 parent 1ffbac0

3 files changed: 30 additions & 29 deletions

.github/workflows/ut.yaml (1 addition & 1 deletion)
@@ -119,7 +119,7 @@ jobs:
       - name: test
         run: |
           echo "Running tests with XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }}"
-          XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 SKIP_HANG_KERNEL=1 SKIP_ACC_ERROR_KERNEL=1 pytest -v -s tests/
+          XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 SKIP_ACC_ERROR_KERNEL=1 pytest -v -s tests/
           VLLM_XPU_FORCE_XE_DEFAULT_KERNEL=1 XPU_KERNEL_TEST_SCOPE=${{ env.XPU_KERNEL_TEST_SCOPE }} ZE_AFFINITY_MASK=0,1 pytest -v -s tests/fused_moe/test_grouped_gemm.py::test_grouped_gemm

   clean-repo-pvc:

csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp (29 additions & 7 deletions)
@@ -265,7 +265,7 @@ class XeFMHAFwdKernel {
         cute::min(seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + q_offset_sg));

     // calc sg level seq_len_kv
-    const int seq_len =
+    const int sg_seq_len =
         CausalMask
             ? LocalMask
                   ? cute::min(
@@ -275,17 +275,39 @@ class XeFMHAFwdKernel {
                   : cute::min(
                         seq_len_kv, full_tile_offset + seq_coord + q_sg_tile)
             : seq_len_kv;
-    const int k_block0 =
+    const int sg_k_block0 =
         LocalMask
             ? cute::max(
                   seq_coord + full_tile_offset - params.mainloop.local_left,
                   0) /
                   get<1>(TileShapeQK{})
             : 0;
-    const int k_blocks = cute::ceil_div(seq_len, get<1>(TileShapeQK{}));
-    const int k_blocks_causal =
-        CausalMask ? (seq_coord + full_tile_offset) / get<1>(TileShapeQK{})
-                   : 0;
+    const int sg_k_blocks =
+        cute::ceil_div(sg_seq_len, get<1>(TileShapeQK{}));
+    const int sg_k_blocks_causal =
+        CausalMask
+            ? (seq_coord + full_tile_offset) / get<1>(TileShapeQK{})
+            : 0;
+
+    // The mainloop wraps each K iteration in a workgroup-scoped barrier
+    // pair, so every subgroup in the workgroup must execute the same
+    // K-loop trip count. Reduce the per-SG bounds across the WG:
+    //   k_block0        = min across WG (start no later than any SG)
+    //   k_blocks        = max across WG (end no earlier than any SG)
+    //   k_blocks_causal = min across WG (turn on causal masking no later
+    //                     than any SG needs it)
+    // Per-element causal / local / remainder masking inside the mainloop
+    // handles the widened range safely for SGs that didn't need it.
+    auto wg = sycl::ext::oneapi::this_work_item::get_work_group<3>();
+    const int k_block0 = LocalMask
+        ? sycl::reduce_over_group(wg, sg_k_block0, sycl::minimum<int>{})
+        : 0;
+    const int k_blocks = (CausalMask || LocalMask)
+        ? sycl::reduce_over_group(wg, sg_k_blocks, sycl::maximum<int>{})
+        : sg_k_blocks;
+    const int k_blocks_causal = CausalMask
+        ? sycl::reduce_over_group(wg, sg_k_blocks_causal, sycl::minimum<int>{})
+        : 0;

     int offset_q = 0, offset_k = 0, offset_v = 0, offset_o = 0;
     if constexpr (is_var_len) {
@@ -347,7 +369,7 @@ class XeFMHAFwdKernel {
         k_blocks,
         k_blocks_causal,
         thr_id,
-        seq_len,
+        sg_seq_len,
         full_tile_offset);

     // return softmax_lse
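
A note on the pattern this fix relies on (added here for context, not part of the commit): sycl::group_barrier is only well-defined when every work-item in the workgroup reaches it, so a K-loop whose trip count differs per subgroup can deadlock as soon as one subgroup exits the loop while the rest are still waiting at the barrier. Reducing the per-subgroup bounds over the workgroup, as the diff above does, gives every subgroup the same trip count. The self-contained SYCL sketch below reproduces the pattern in isolation; it assumes a SYCL 2020 toolchain, and all names in it (wg_size, my_blocks, and so on) are illustrative, not taken from the kernel.

    #include <sycl/sycl.hpp>

    int main() {
      sycl::queue q;
      constexpr int wg_size = 64;
      sycl::buffer<int, 1> out{sycl::range<1>{wg_size}};

      q.submit([&](sycl::handler& h) {
        sycl::accessor acc{out, h, sycl::write_only, sycl::no_init};
        h.parallel_for(
            sycl::nd_range<1>{sycl::range<1>{wg_size}, sycl::range<1>{wg_size}},
            [=](sycl::nd_item<1> it) {
              auto wg = it.get_group();
              const int lid = static_cast<int>(it.get_local_id(0));

              // Per-item bound that varies across the workgroup, like the
              // per-subgroup K-block counts in the attention kernel.
              const int my_blocks = 1 + lid / 16;

              // Without this reduction, items with my_blocks == 1 would stop
              // reaching the barrier after one trip while others keep
              // waiting: a deadlock. The group-wide max makes the trip
              // count uniform.
              const int blocks =
                  sycl::reduce_over_group(wg, my_blocks, sycl::maximum<int>{});

              int sum = 0;
              for (int b = 0; b < blocks; ++b) {
                // Masking stands in for the kernel's per-element causal /
                // local / remainder masks: extra iterations do no work.
                if (b < my_blocks) sum += b;
                sycl::group_barrier(wg);  // reached `blocks` times by all
              }
              acc[lid] = sum;
            });
      });
      q.wait();
      return 0;
    }

The trade-off is the one the new comment in the diff spells out: subgroups that needed fewer K blocks pay for a few masked-out iterations, and in exchange the workgroup-scoped barriers inside the loop become legal for every combination of causal and local masking.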

tests/flash_attn/test_flash_attn_varlen_func.py (0 additions & 21 deletions)
@@ -193,15 +193,6 @@ def test_varlen_with_paged_kv(
 ) -> None:
     torch.set_default_device("xpu")
     torch.xpu.set_device("xpu:0")
-    # # FIXME: remove skip
-    if (is_casual and seq_lens[1][0]
-            == 5) and (os.getenv("SKIP_HANG_KERNEL") is not None
-                       and os.getenv("SKIP_HANG_KERNEL") == "1"):
-        pytest.skip("skip casual for seqlen0 to avoid runtime hang on CI.")
-    if (window_size[0] != -1 or window_size[1]
-            != -1) and (os.getenv("SKIP_HANG_KERNEL") is not None
-                        and os.getenv("SKIP_HANG_KERNEL") == "1"):
-        pytest.skip("skip local attn to avoid runtime hang on CI.")
     if block_size == 128 and num_blocks == 32768 and head_size >= 192:
         pytest.skip("skip test cases that may run out of Memory.")
     if stride_pad > 0 and fp8_dtype is not None:
@@ -393,15 +384,6 @@ def test_varlen_with_interleaved_paged_kv(
 ) -> None:
     torch.set_default_device("xpu")
     torch.xpu.set_device("xpu:0")
-    # # FIXME: remove skip
-    if (is_casual and seq_lens[1][0]
-            == 5) and (os.getenv("SKIP_HANG_KERNEL") is not None
-                       and os.getenv("SKIP_HANG_KERNEL") == "1"):
-        pytest.skip("skip casual for seqlen0 to avoid runtime hang on CI.")
-    if (window_size[0] != -1 or window_size[1]
-            != -1) and (os.getenv("SKIP_HANG_KERNEL") is not None
-                        and os.getenv("SKIP_HANG_KERNEL") == "1"):
-        pytest.skip("skip local attn to avoid runtime hang on CI.")
     if block_size == 128 and num_blocks == 32768 and head_size >= 192:
         pytest.skip("skip test cases that may run out of Memory.")

@@ -808,9 +790,6 @@ def test_varlen_with_softmax_lse(
 ) -> None:
     torch.set_default_device("xpu")
     torch.xpu.set_device("xpu:0")
-    if (is_casual and seq_lens[1][0]
-            == 5) and (os.getenv("SKIP_HANG_KERNEL") == "1"):
-        pytest.skip("skip casual for seqlen0 to avoid runtime hang on CI.")
     torch.manual_seed(4242)

     query_lens = [x[0] for x in seq_lens]
