Skip to content

Commit 1b4770e

Browse files
baodii and Copilot authored
[Decode Attn] Change strategy of num_splits to avoid acc issue (vllm-project#204)
* change strategy of num_splits Signed-off-by: baodii <di.bao@intel.com> * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> Signed-off-by: baodi <di.bao@intel.com> --------- Signed-off-by: baodii <di.bao@intel.com> Signed-off-by: baodi <di.bao@intel.com> Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
1 parent 69da26d commit 1b4770e

3 files changed

Lines changed: 95 additions & 39 deletions

File tree

csrc/flash_attn/flash_api.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,41 @@ inline int get_num_splits(
1818
device.get_info<sycl::ext::intel::info::device::gpu_slices>() *
1919
device
2020
.get_info<sycl::ext::intel::info::device::gpu_subslices_per_slice>();
21-
int parallel_ = num_xe_cores;
22-
int parallel_2 = num_xe_cores * 2;
2321

24-
int cur_parallel_d = batch_size * num_heads_kv;
22+
int cur_parallel = batch_size * num_heads_kv;
23+
int kv_blocks = (max_seqlen_k + block_size - 1) / block_size;
2524

26-
int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d;
25+
// Below 128 KV blocks the per-split FMHA compute is too small relative
26+
// to the ReduceSplitK overhead, regardless of block size.
27+
if (kv_blocks < 128) return 1;
2728

28-
if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) {
29-
num_splits = std::ceil(parallel_2 / static_cast<float>(cur_parallel_d)) - 1;
29+
int target_splits;
30+
if (cur_parallel < num_xe_cores) {
31+
// Under-utilized: fill GPU cores.
32+
// Scale by block_size since larger blocks mean more compute per WG.
33+
int eff_parallel = cur_parallel * block_size / 64;
34+
eff_parallel = std::max(1, eff_parallel);
35+
target_splits = (num_xe_cores + eff_parallel - 1) / eff_parallel;
36+
} else if (cur_parallel <= num_xe_cores * 2) {
37+
// Well-utilized zone (1x-2x oversubscription):
38+
// GPU is busy, splitting adds overhead without benefit.
39+
return 1;
40+
} else {
41+
// Heavily oversubscribed (>2x): shorter WGs help.
42+
// But gate out when compute is already saturated.
43+
int eff_parallel = cur_parallel * block_size / 64;
44+
if (eff_parallel >= num_xe_cores * 8) return 1;
45+
target_splits = std::max(1, kv_blocks / 64);
46+
int par_cap = std::max(1, num_xe_cores * 8 / cur_parallel);
47+
target_splits = std::min(target_splits, par_cap);
3048
}
3149

32-
int max_splits = (max_seqlen_k + block_size - 1) / block_size;
33-
max_splits = std::min(max_splits, parallel_);
34-
return std::min(num_splits, max_splits);
50+
// Each split must process at least 32 KV blocks.
51+
int max_splits_blocks = std::max(1, kv_blocks / 32);
52+
// Hard cap: more splits give diminishing returns and increase
53+
// ReduceSplitK overhead and temporary buffer memory.
54+
int num_splits = std::min({target_splits, max_splits_blocks, 8});
55+
return std::max(1, num_splits);
3556
}
3657

3758
std::vector<at::Tensor> mha_varlen_fwd(
@@ -181,10 +202,11 @@ std::vector<at::Tensor> mha_varlen_fwd(
181202
: at::empty(
182203
{num_tokens, num_heads_q * num_kv_splits, head_dim},
183204
q.options().device(q.device()));
184-
at::Tensor max_logits = at::empty(
205+
at::Tensor max_logits = at::full(
185206
{num_tokens, num_heads_q, num_kv_splits},
207+
-std::numeric_limits<float>::infinity(),
186208
q.options().dtype(at::kFloat).device(q.device()));
187-
at::Tensor exp_sums = at::empty(
209+
at::Tensor exp_sums = at::zeros(
188210
{num_tokens, num_heads_q, num_kv_splits},
189211
q.options().dtype(at::kFloat).device(q.device()));
190212

csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,8 @@ class DecodeFwdEpilogue {
519519
int idx_kv_split,
520520
int head_group_q,
521521
TensorSink& tSink, // Sink for current head
522-
int num_kv_splits) {
522+
int num_kv_splits,
523+
bool is_single_split) {
523524
using namespace cute;
524525
using ElementA = typename FragA::element_type;
525526

@@ -535,25 +536,36 @@ class DecodeFwdEpilogue {
535536

536537
auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);
537538

538-
// store exp sum and max logits for current KV split
539+
// Always store exp sum and max logits for current KV split.
539540
// assume seq_len_qo == 1
540-
if (thr_id < head_group_q && num_kv_splits > 1) {
541-
exp_sums(thr_id, idx_kv_split) = rA_sum(0);
542-
max_logits(thr_id, idx_kv_split) = rA_max(0);
541+
if (thr_id < head_group_q) {
542+
if (is_single_split) {
543+
// Sentinel values: make ReduceSplitK a pass-through copy.
544+
exp_sums(thr_id, idx_kv_split) = ElementA(1);
545+
max_logits(thr_id, idx_kv_split) = ElementA(0);
546+
} else if (num_kv_splits > 1) {
547+
exp_sums(thr_id, idx_kv_split) = rA_sum(0);
548+
max_logits(thr_id, idx_kv_split) = rA_max(0);
549+
}
543550
}
544551

545552
/* Some subgroups may not have any work to do; if so, quit early. */
546553
if (!active) return;
547554

548-
/* Complete softmax, dividing out sums. */
549-
CUTLASS_PRAGMA_UNROLL
550-
for (int i = 0; i < rA_sum.size(); i++) {
551-
rA_sum(i) = ElementA(1) / rA_sum(i);
552-
}
555+
/* Complete softmax: normalize output for single-split sequences
556+
(so ReduceSplitK pass-through gives correct result).
557+
For multi-split, store unnormalized to avoid divide-multiply
558+
precision loss in the reduce roundtrip. */
559+
if (is_single_split || num_kv_splits <= 1) {
560+
CUTLASS_PRAGMA_UNROLL
561+
for (int i = 0; i < rA_sum.size(); i++) {
562+
rA_sum(i) = ElementA(1) / rA_sum(i);
563+
}
553564

554-
CUTLASS_PRAGMA_UNROLL
555-
for (int i = 0; i < rA.size(); i++) {
556-
rA(i) *= broadcast<0>(rA_sum, rA, i);
565+
CUTLASS_PRAGMA_UNROLL
566+
for (int i = 0; i < rA.size(); i++) {
567+
rA(i) *= broadcast<0>(rA_sum, rA, i);
568+
}
557569
}
558570

559571
/* Tile output */
@@ -585,8 +597,7 @@ class DecodeFwdEpilogue {
585597
using namespace sycl::ext::oneapi::this_work_item;
586598

587599
if constexpr (ReduceK{} == _1{}) {
588-
ReduceFragARow rA_max;
589-
return std::make_tuple(tArA, rA_max, tA_sum, true);
600+
return std::make_tuple(tArA, tA_max, tA_sum, true);
590601
} else {
591602
/* Identify A tile ID and k block for this subgroup. */
592603
auto thr_vak = group<1, 3>(TiledMMAPV{}.get_thr_layout_vmnk())

csrc/xpu/attn/xe_2/kernel/paged_decode_kernel.hpp

Lines changed: 36 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,29 @@ class XeFMHAFwdSplitKVKernel {
319319

320320
int num_blocks_per_split =
321321
cute::ceil_div(windowed_k_blocks, num_kv_splits);
322-
int kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split;
323-
int num_effective_kv_blocks = cute::min(
324-
windowed_k_blocks - idx_kv_split * num_blocks_per_split,
325-
num_blocks_per_split);
322+
323+
// Per-sequence split decision: short sequences are treated as
324+
// single-split even when num_kv_splits > 1, avoiding precision
325+
// loss from the split-reduce roundtrip.
326+
constexpr int kMinBlocksForSplit = 128;
327+
bool is_single_split =
328+
(num_kv_splits > 1) && (windowed_k_blocks < kMinBlocksForSplit);
329+
330+
int kv_split_offset;
331+
int num_effective_kv_blocks;
332+
if (is_single_split) {
333+
// Split 0 processes all blocks; splits 1+ skip entirely.
334+
if (idx_kv_split > 0) {
335+
continue;
336+
}
337+
kv_split_offset = k_block0;
338+
num_effective_kv_blocks = windowed_k_blocks;
339+
} else {
340+
kv_split_offset = k_block0 + idx_kv_split * num_blocks_per_split;
341+
num_effective_kv_blocks = cute::min(
342+
windowed_k_blocks - idx_kv_split * num_blocks_per_split,
343+
num_blocks_per_split);
344+
}
326345

327346
if (num_effective_kv_blocks <= 0) {
328347
// no need computation
@@ -409,7 +428,8 @@ class XeFMHAFwdSplitKVKernel {
409428
idx_kv_split,
410429
head_group_q,
411430
sinks_per_kv,
412-
num_kv_splits);
431+
num_kv_splits,
432+
is_single_split);
413433
} else {
414434
epilogue(
415435
O(_, _, head, idx_kv_split, l_coord),
@@ -423,7 +443,8 @@ class XeFMHAFwdSplitKVKernel {
423443
idx_kv_split,
424444
head_group_q,
425445
sinks,
426-
num_kv_splits);
446+
num_kv_splits,
447+
is_single_split);
427448
}
428449
}
429450
}
@@ -702,16 +723,18 @@ class ReduceSplitK {
702723
ElementLSE local_max_logit = shared_storage.max_logits_slm_array[i];
703724
ElementLSE local_exp_sum = shared_storage.exp_sums_slm_array[i];
704725

726+
// Skip splits with no valid data (short sequences treated as
727+
// single-split have exp_sums=0 / max_logits=-inf for unused splits).
728+
if (local_exp_sum <= ElementLSE(0)) continue;
729+
705730
ElementLSE rescale =
706731
sycl::native::exp2(local_max_logit - global_max_logits);
707732

708-
// in FMHA epilogue, it's divided by local_exp_sum, here we multiply
709-
// back
710-
ElementLSE adjusted_o_accum =
711-
static_cast<ElementLSE>(
712-
Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord)) *
713-
local_exp_sum;
714-
acc += adjusted_o_accum * rescale;
733+
// Partial outputs are unnormalized (not divided by exp_sum in the
734+
// epilogue), so combine them directly with the rescale factor.
735+
ElementLSE o_accum_val = static_cast<ElementLSE>(
736+
Oaccum(seq_idx, idx, i * num_heads_q + head_q, l_coord));
737+
acc += o_accum_val * rescale;
715738

716739
// update global exp sum
717740
global_exp_sums += local_exp_sum * rescale;

0 commit comments

Comments (0)