Commit 9c8616f

Authored by baodii, with YizhouZ and jikunshang

Tune attention perf to align with IPEX attention functions (vllm-project#162)

* tune perf for decoding kernel
* add block_dispatch for 64 and 128
* prefetch table before gemm
* temp opt: set num_splits_kv = 1 for llama3-8b
* add strategy for num_splits_kv
* Delete tests/flash_attn/test_flash_attn_varlen_func_perf.py
* make format happy
* fix chunked prefill acc issue when not paged
* update UT
* make format happy
* restore and update UT
* Resolve Copilot review comments from PR vllm-project#162
  - Add bounds check for page_local_idx in chunk_prefill_mainloop.hpp
  - Fix get_num_splits to use batch_size instead of num_tokens in flash_api.cpp
  - Add docstring for num_splits_kv in flash_attn_interface.py
* use sm_count to replace hardcoded 20

Signed-off-by: baodii <di.bao@intel.com>
Signed-off-by: Yizhou Wang <yizhou.wang@intel.com>
Co-authored-by: Yizhou Wang <yizhou.wang@intel.com>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>

1 parent 368108f · commit 9c8616f

14 files changed: 319 additions & 183 deletions

csrc/flash_attn/flash_api.cpp

Lines changed: 41 additions & 9 deletions
@@ -7,6 +7,33 @@
 
 namespace FLASH_NAMESPACE {
 
+inline int get_num_splits(
+    const sycl::queue& queue,
+    const int& batch_size,
+    const int& num_heads_kv,
+    const int& max_seqlen_k,
+    const int& block_size) {
+  auto device = queue.get_device();
+  int num_xe_cores =
+      device.get_info<sycl::ext::intel::info::device::gpu_slices>() *
+      device
+          .get_info<sycl::ext::intel::info::device::gpu_subslices_per_slice>();
+  int parallel_ = num_xe_cores;
+  int parallel_2 = num_xe_cores * 2;
+
+  int cur_parallel_d = batch_size * num_heads_kv;
+
+  int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d;
+
+  if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) {
+    num_splits = std::ceil(parallel_2 / static_cast<float>(cur_parallel_d)) - 1;
+  }
+
+  int max_splits = (max_seqlen_k + block_size - 1) / block_size;
+  max_splits = std::min(max_splits, parallel_);
+  return std::min(num_splits, max_splits);
+}
+
 std::vector<at::Tensor> mha_varlen_fwd(
     const at::Tensor& q,
     const at::Tensor& k,
@@ -32,7 +59,8 @@ std::vector<at::Tensor> mha_varlen_fwd(
     int window_size_right,
     const float softcap,
     const bool return_softmax,
-    std::optional<at::Generator> gen_) {
+    std::optional<at::Generator> gen_,
+    std::optional<int> num_splits) {
  auto q_type = q.scalar_type();
  auto k_type = k.scalar_type();
  TORCH_CHECK(
@@ -131,18 +159,22 @@ std::vector<at::Tensor> mha_varlen_fwd(
        is_local,
        is_sink);
  } else {
-    constexpr int partition_size = 512;
-    int num_kv_splits = (max_seqlen_k + partition_size - 1) / partition_size;
-    if (num_kv_splits > 20) num_kv_splits = 20;
-
    int num_tokens = q.size(0);
+    int batch_size = static_cast<int>(cu_seqlens_q.size(0)) - 1;
    int num_heads_q = q.size(1);
    int head_dim = q.size(2);
    int num_heads_kv = k.size(2);
    int block_size = k.size(1);
-    at::Tensor tmp_out = at::empty(
-        {num_tokens, num_heads_q * num_kv_splits, head_dim},
-        q.options().device(q.device()));
+
+    int num_kv_splits = num_splits.value_or(get_num_splits(
+        queue, batch_size, num_heads_kv, max_seqlen_k, block_size));
+
+    at::Tensor tmp_out =
+        num_kv_splits == 1
+            ? out
+            : at::empty(
+                  {num_tokens, num_heads_q * num_kv_splits, head_dim},
+                  q.options().device(q.device()));
    at::Tensor max_logits = at::empty(
        {num_tokens, num_heads_q, num_kv_splits},
        q.options().dtype(at::kFloat).device(q.device()));
@@ -200,7 +232,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      "float softmax_scale, Tensor? softmax_sink, bool zero_tensors, "
      "bool is_causal, int window_size_left, int window_size_right, float "
      "softcap, bool return_softmax, "
-      "Generator? gen) -> Tensor[]");
+      "Generator? gen, int? num_splits) -> Tensor[]");
  ops.impl(
      "varlen_fwd",
      torch::kXPU,
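
Note on the heuristic: get_num_splits sizes the KV split count to the number of Xe cores reported by the device (replacing the old hardcoded cap of 20), and backs off when batch x KV-heads already saturates the device. A minimal standalone sketch, with the SYCL device query replaced by an assumed fixed core count (names and the core count are illustrative, not from the source):

// Standalone sketch of the split-count heuristic in get_num_splits above.
// num_xe_cores is passed in rather than queried via SYCL so this runs anywhere.
#include <algorithm>
#include <cmath>
#include <cstdio>

int get_num_splits_sketch(int batch_size, int num_heads_kv, int max_seqlen_k,
                          int block_size, int num_xe_cores /* assumed */) {
  int parallel_ = num_xe_cores;       // target: fill every Xe core once
  int parallel_2 = num_xe_cores * 2;  // bound: at most two waves of work

  // Parallelism already available without splitting the KV dimension.
  int cur_parallel_d = batch_size * num_heads_kv;

  // Splits needed so batch * heads * splits covers all cores.
  int num_splits = (parallel_ + cur_parallel_d - 1) / cur_parallel_d;

  // If that overshoots one wave, back off so the last wave stays full.
  if (cur_parallel_d * num_splits > parallel_ && num_splits > 1) {
    num_splits = static_cast<int>(
        std::ceil(parallel_2 / static_cast<float>(cur_parallel_d)) - 1);
  }

  // Never split finer than one KV block per split, nor beyond the core count.
  int max_splits = (max_seqlen_k + block_size - 1) / block_size;
  max_splits = std::min(max_splits, parallel_);
  return std::min(num_splits, max_splits);
}

int main() {
  // Assumed device with 20 Xe cores; 8 KV heads, 8K context, page size 64.
  std::printf("batch 4: %d split(s)\n",
              get_num_splits_sketch(4, 8, 8192, 64, 20));  // -> 1
  std::printf("batch 1: %d split(s)\n",
              get_num_splits_sketch(1, 8, 8192, 64, 20));  // -> 4
  return 0;
}

With enough batch x head parallelism the heuristic returns 1, which is why tmp_out above can alias out directly in that case.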

csrc/xpu/attn/xe_2/collective/chunk_prefill_epilogue.hpp

Lines changed: 3 additions & 5 deletions
@@ -518,8 +518,8 @@ class DecodeFwdEpilogue {
      const TensorLSE2D& max_logits, // Global max logits tensor
      int idx_kv_split,
      int head_group_q,
-      TensorSink& tSink // Sink for current head
-  ) {
+      TensorSink& tSink, // Sink for current head
+      int num_kv_splits) {
    using namespace cute;
    using ElementA = typename FragA::element_type;
 
@@ -535,11 +535,9 @@ class DecodeFwdEpilogue {
 
    auto [rA, rA_max, rA_sum, active] = reduce_A(tArA, tA_max, tA_sum, thr_id);
 
-    int thr_id_sg = thr_id % intel::sg_size;
-
    // store exp sum and max logits for current KV split
    // assume seq_len_qo == 1
-    if (thr_id < head_group_q) {
+    if (thr_id < head_group_q && num_kv_splits > 1) {
      exp_sums(thr_id, idx_kv_split) = rA_sum(0);
      max_logits(thr_id, idx_kv_split) = rA_max(0);
    }
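
Note: with num_kv_splits == 1 the decode epilogue now writes straight to the output and skips the per-split exp_sums/max_logits stores; with several splits, those statistics are what lets the partial results be combined afterwards. A hedged sketch of that standard split-softmax merge, assuming each split's partial output is already normalized by its own exp-sum (illustrative names, not the kernel's code):

// Merge per-split partial attention outputs using each split's max logit m_i
// and exp-sum s_i:
//   out = sum_i o_i * s_i * exp(m_i - m) / sum_i s_i * exp(m_i - m)
#include <cmath>
#include <vector>

std::vector<float> merge_splits(
    const std::vector<std::vector<float>>& o,  // [split][head_dim], normalized
    const std::vector<float>& max_logit,       // [split]
    const std::vector<float>& exp_sum) {       // [split]
  int splits = static_cast<int>(o.size());
  int dim = static_cast<int>(o[0].size());

  // Global max keeps the rescaling numerically stable.
  float m = max_logit[0];
  for (int i = 1; i < splits; i++) m = std::fmax(m, max_logit[i]);

  std::vector<float> out(dim, 0.0f);
  float denom = 0.0f;
  for (int i = 0; i < splits; i++) {
    float scale = exp_sum[i] * std::exp(max_logit[i] - m);
    denom += scale;
    for (int d = 0; d < dim; d++) out[d] += o[i][d] * scale;
  }
  for (int d = 0; d < dim; d++) out[d] /= denom;
  return out;
}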

csrc/xpu/attn/xe_2/collective/chunk_prefill_mainloop.hpp

Lines changed: 29 additions & 29 deletions
@@ -194,7 +194,7 @@ struct FMHAFwdMainloop<
  void* const scale_v;
 
  // Paged KV Cache
-  int* ptr_page_table;
+  const int* ptr_page_table;
  int page_size;
  int max_pages_per_seq;
  int total_seqlen_kv;
@@ -236,6 +236,20 @@
    return true;
  }
 
+  CUTLASS_DEVICE int get_paged_idx(int K, int idx_b) {
+    int tiles_per_page = params.page_size / get<1>(TileShapeQK{});
+    int b_offset = idx_b * params.max_pages_per_seq;
+    int page_local_idx = K * get<1>(TileShapeQK{}) / params.page_size;
+
+    // Clamp page_local_idx to the valid range [0, max_pages_per_seq - 1]
+    if (page_local_idx >= params.max_pages_per_seq) {
+      page_local_idx = params.max_pages_per_seq - 1;
+    }
+
+    return params.ptr_page_table[b_offset + page_local_idx] * tiles_per_page +
+        K % tiles_per_page;
+  }
+
  template <typename QVCoord>
  CUTLASS_DEVICE void operator()(
      TensorQ2D const& Q_2D, // (q,d)
@@ -339,24 +353,21 @@
    // ------
 
    // PagedKV
-    int tiles_per_page = params.page_size / get<1>(TileShapeQK{});
-    int page_idx = blk_k0, next_page_idx;
-    int b_offset = idx_b * params.max_pages_per_seq;
+    int page_idx, next_page_idx;
    if constexpr (PagedKV) {
-      int page_local_idx = page_idx * get<1>(TileShapeQK{}) / params.page_size;
-      page_idx =
-          params.ptr_page_table[b_offset + page_local_idx] * tiles_per_page +
-          page_idx % tiles_per_page;
+      next_page_idx = get_paged_idx(blk_k0, idx_b);
    }
 
    /* Initialization steps for first block: Q/K prefetch, O init */
    /* TODO: limit D prefetch for large head size, and reorder K prefetches */
+    CUTLASS_PRAGMA_UNROLL
    for (int D = 0; D < size<3>(pQgQ); D++) {
      prefetch(prefetch_q, pQgQ(_, _, _, D));
    }
 
+    CUTLASS_PRAGMA_UNROLL
    for (int D = 0; D < size<4>(pKgK); D++) {
-      prefetch(prefetch_k, pKgK(_, _, _, page_idx, D));
+      prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D));
    }
 
    clear(tArA);
@@ -378,6 +389,12 @@
      /* Split barrier to keep threads together */
      // barrier_arrive(ScopeSubgroup);
 
+      page_idx = next_page_idx;
+      // next paged_idx
+      if constexpr (PagedKV) {
+        next_page_idx = get_paged_idx(K + 1, idx_b);
+      }
+
      auto tKgK_cache =
          PagedKV ? tKgK(_, _, _, page_idx, _) : tKgK(_, _, _, K, _);
      auto tVgV_cache =
@@ -473,29 +490,12 @@
      }
 
      // sycl::group_barrier(compat::get_nd_item<1>().get_group());
-      barrier();
-
-      // next paged_idx
-      next_page_idx = K + 1;
-      if constexpr (PagedKV) {
-        int next_page_local_idx =
-            next_page_idx * get<1>(TileShapeQK{}) / params.page_size;
-        bool valid_page = next_page_local_idx < params.max_pages_per_seq;
-        if (valid_page) {
-          next_page_idx =
-              params.ptr_page_table[b_offset + next_page_local_idx] *
-                  tiles_per_page +
-              next_page_idx % tiles_per_page;
-        } else {
-          // set to last page
-          next_page_idx = params.max_pages_per_seq * tiles_per_page - 1;
-        }
-      }
-      page_idx = next_page_idx;
+      // barrier();
 
      /* K prefetch */
+      CUTLASS_PRAGMA_UNROLL
      for (int D = 0; D < size<4>(pKgK); D++) {
-        prefetch(prefetch_k, pKgK(_, _, _, page_idx, D));
+        prefetch(prefetch_k, pKgK(_, _, _, next_page_idx, D));
      }
 
      // barrier_wait(ScopeSubgroup);
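
Note: get_paged_idx translates a logical KV tile index into a physical tile index through the per-sequence page table, and the mainloop now computes next_page_idx one iteration ahead so the K prefetch overlaps the current GEMM instead of trailing it. A host-side sketch of the index math under assumed values (names mirror the kernel, the setup is hypothetical):

// Paged-KV index translation, run on the host for clarity.
#include <algorithm>
#include <cstdio>

int get_paged_idx_sketch(const int* page_table, int max_pages_per_seq,
                         int page_size, int kv_tile, int K, int idx_b) {
  int tiles_per_page = page_size / kv_tile;      // e.g. 128 / 64 = 2
  int b_offset = idx_b * max_pages_per_seq;      // this sequence's table row
  int page_local_idx = K * kv_tile / page_size;  // logical page holding tile K

  // Clamp so the speculative K + 1 lookup past the end stays in bounds.
  page_local_idx = std::min(page_local_idx, max_pages_per_seq - 1);

  // Physical tile = physical page * tiles_per_page + offset within the page.
  return page_table[b_offset + page_local_idx] * tiles_per_page +
         K % tiles_per_page;
}

int main() {
  // Hypothetical table: sequence 0 occupies physical pages 5, 2, 9.
  const int page_table[] = {5, 2, 9};
  for (int K = 0; K < 6; K++)  // 6 logical tiles, 2 tiles per page
    std::printf("logical tile %d -> physical tile %d\n", K,
                get_paged_idx_sketch(page_table, 3, 128, 64, K, 0));
  return 0;  // prints 10 11 4 5 18 19
}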

csrc/xpu/attn/xe_2/fmha_utils.hpp

Lines changed: 22 additions & 8 deletions
@@ -83,14 +83,28 @@ struct chunk_policy_head256 {
  using SubgroupLayoutQK = Layout<Shape<_32, _1, _1>>;
};
 
-// define macro for decode policy
-#define DECODE_NUM_SG _4
-#define DECODE_KV_TILE _64 // KV tile size is set to 64 for page size is 64
-
-template <class q_packed, class head_dim>
+// define decode policy
+template <typename q_packed, typename head_dim, typename kv_tile>
struct decode_policy_qpacked_head {
-  using ShapeQK = Shape<q_packed, DECODE_KV_TILE, _64>;
-  using ShapePV = Shape<q_packed, _32, DECODE_KV_TILE>;
+  static_assert(
+      cute::is_same_v<kv_tile, _64> || cute::is_same_v<kv_tile, _128>,
+      "Unsupported kv_tile(page_size) for decode_policy_qpacked_head");
+};
+
+// kv_tile == _64
+template <typename q_packed, typename head_dim>
+struct decode_policy_qpacked_head<q_packed, head_dim, _64> {
+  using ShapeQK = Shape<q_packed, _64, _64>;
+  using ShapePV = Shape<q_packed, _32, _64>;
+  using ShapeOut = Shape<q_packed, head_dim>;
+  using SubgroupLayoutQK = Layout<Shape<_1, _4, _1>>;
+};
+
+// kv_tile == _128
+template <typename q_packed, typename head_dim>
+struct decode_policy_qpacked_head<q_packed, head_dim, _128> {
+  using ShapeQK = Shape<q_packed, _128, _64>;
+  using ShapePV = Shape<q_packed, _32, _128>;
  using ShapeOut = Shape<q_packed, head_dim>;
-  using SubgroupLayoutQK = Layout<Shape<_1, DECODE_NUM_SG, _1>>;
+  using SubgroupLayoutQK = Layout<Shape<_1, _8, _1>>;
};
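
Note: the kv_tile template parameter replaces the old DECODE_KV_TILE/DECODE_NUM_SG macros, so the decode tile shape follows the KV-cache page size at compile time (4 subgroups for page size 64, 8 for 128) and any other size fails the static_assert. A self-contained sketch of that dispatch pattern (stand-in types, not the repo's real launcher):

// Compile-time policy selection keyed on page size, mirroring the diff's
// primary-template static_assert plus per-size specializations.
#include <cstdio>
#include <stdexcept>

template <int N> struct Int { static constexpr int value = N; };
using _64 = Int<64>;   // stand-ins for the cute integral-constant types
using _128 = Int<128>;

template <typename kv_tile> struct decode_policy {
  static_assert(kv_tile::value == 64 || kv_tile::value == 128,
                "Unsupported kv_tile(page_size) for decode_policy");
};
template <> struct decode_policy<_64> { static constexpr int num_sg = 4; };
template <> struct decode_policy<_128> { static constexpr int num_sg = 8; };

template <typename Policy> void run_decode() {
  std::printf("decode launch with %d subgroups\n", Policy::num_sg);
}

void dispatch_decode(int page_size) {
  switch (page_size) {
    case 64:  run_decode<decode_policy<_64>>();  break;
    case 128: run_decode<decode_policy<_128>>(); break;
    default:
      // Other sizes would trip the static_assert at compile time, so
      // reject them at runtime before a policy is ever instantiated.
      throw std::invalid_argument("unsupported page size");
  }
}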

csrc/xpu/attn/xe_2/fmha_xe2.cpp

Lines changed: 2 additions & 2 deletions
@@ -89,10 +89,10 @@ void cutlass_chunk_prefill_impl(
    // query: [batch, num_heads, seq, head_size]
    batch_size = query.size(0);
    num_heads_q = query.size(1);
-    num_heads_kv = key_cache.size(1);
+    num_heads_kv = is_paged ? key_cache.size(2) : key_cache.size(1);
    head_size = query.size(3);
    max_seqlen_q = query.size(2);
-    max_seqlen_k = key_cache.size(2);
+    max_seqlen_k = is_paged ? max_seqlen_q : key_cache.size(2);
  }
  if (is_paged) {
    num_blocks = key_cache.size(0);
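
Note: the is_paged branches select between two cache layouts that expose different dimensions. The contiguous cache appears to be [batch, num_heads_kv, seq, head_size], while the paged cache is [num_blocks, block_size, num_heads_kv, head_size] (consistent with num_blocks = key_cache.size(0) above), and the paged path apparently falls back to the query length as its seqlen bound since the true KV length lives in the block tables. A sketch under those assumed layouts:

// Reading KV-cache dimensions for the two layouts assumed above.
// Tensor4D is a stand-in for at::Tensor restricted to 4-D shapes.
#include <cstdint>

struct Tensor4D {
  int64_t dims[4];
  int64_t size(int i) const { return dims[i]; }
};

void read_kv_dims(const Tensor4D& key_cache, bool is_paged,
                  int64_t max_seqlen_q, int64_t& num_heads_kv,
                  int64_t& max_seqlen_k) {
  if (is_paged) {
    // [num_blocks, block_size, num_heads_kv, head_size]
    num_heads_kv = key_cache.size(2);
    // KV length comes from the block tables, not this tensor's shape.
    max_seqlen_k = max_seqlen_q;
  } else {
    // [batch, num_heads_kv, seq, head_size]
    num_heads_kv = key_cache.size(1);
    max_seqlen_k = key_cache.size(2);
  }
}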

csrc/xpu/attn/xe_2/kernel/chunk_prefill_kernel.hpp

Lines changed: 16 additions & 14 deletions
@@ -299,17 +299,18 @@ class XeFMHAFwdKernel {
      offset_o = s.num_heads_q * s.head_size_vo * qo_cumulative[idx_b];
    }
 
-    auto batch_dim = is_var_len ? 1 : s.batch;
+    auto batch_dim_qo = is_var_len ? 1 : s.batch;
+    auto batch_dim_kv = (PagedKV || is_var_len) ? 1 : s.batch;
    auto total_seqlen_kv =
        PagedKV ? params.mainloop.total_seqlen_kv : seq_len_kv;
    auto shape_Q =
-        make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim);
+        make_shape(seq_len_qo, s.head_size_qk, s.num_heads_q, batch_dim_qo);
    auto shape_K = make_shape(
-        total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim);
+        total_seqlen_kv, s.head_size_qk, s.num_heads_kv, batch_dim_kv);
    auto shape_V = make_shape(
-        s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim);
+        s.head_size_vo, total_seqlen_kv, s.num_heads_kv, batch_dim_kv);
    auto shape_O =
-        make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim);
+        make_shape(seq_len_qo, s.head_size_vo, s.num_heads_q, batch_dim_qo);
 
    auto dcQ = const_cast<ElementQ*>(p.Q + offset_q);
    auto dcK = const_cast<ElementK*>(p.K + offset_k);
@@ -319,10 +320,10 @@
    auto layout_q = is_var_len
        ? make_ordered_layout(shape_Q, Step<_2, _0, _1, _3>{})
        : make_layout(shape_Q, p.dQ);
-    auto layout_k = is_var_len
+    auto layout_k = (PagedKV || is_var_len)
        ? make_ordered_layout(shape_K, Step<_2, _0, _1, _3>{})
        : make_layout(shape_K, p.dK);
-    auto layout_v = is_var_len
+    auto layout_v = (PagedKV || is_var_len)
        ? make_ordered_layout(shape_V, Step<_0, _2, _1, _3>{})
        : make_layout(shape_V, p.dV);
    auto layout_o = is_var_len
@@ -339,12 +340,13 @@
    FragARow tA_max, tA_sum;
 
    // Main loop
-    int l_coord = is_var_len ? 0 : idx_b;
+    int l_coord_qo = is_var_len ? 0 : idx_b;
+    int l_coord_kv = (PagedKV || is_var_len) ? 0 : idx_b;
    CollectiveMainloop mainloop(params.mainloop, shared_storage.mainloop);
    mainloop(
-        Q(_, _, head_q, l_coord),
-        K(_, _, head, l_coord),
-        V(_, _, head, l_coord),
+        Q(_, _, head_q, l_coord_qo),
+        K(_, _, head, l_coord_kv),
+        V(_, _, head, l_coord_kv),
        tArA,
        tA_max,
        tA_sum,
@@ -368,7 +370,7 @@
    if constexpr (Sink) {
      ElementSink s_head = p.ptr_S[head_q];
      epilogue(
-          O(_, _, head_q, l_coord),
+          O(_, _, head_q, l_coord_qo),
          tArA,
          tA_max,
          tA_sum,
@@ -377,7 +379,7 @@
          thr_id);
    } else {
      epilogue(
-          O(_, _, head_q, l_coord),
+          O(_, _, head_q, l_coord_qo),
          tArA,
          tA_max,
          tA_sum,
@@ -389,4 +391,4 @@
  }
};
 
-} // namespace cutlass::fmha::kernel
\ No newline at end of file
+} // namespace cutlass::fmha::kernel
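
Note: splitting batch_dim into batch_dim_qo and batch_dim_kv encodes that a paged KV cache, like the varlen path, is addressed as one flat pool: per-sequence selection happens through the page table, so the K/V batch mode collapses to 1 while Q/O keep the real batch dimension. A small sketch of the coordinate logic, mirroring the diff's conditionals (illustrative struct, not the kernel's types):

// How the batch coordinate differs for Q/O versus paged/varlen K/V.
struct BatchCoords {
  int batch_dim_qo, batch_dim_kv;  // extent of the batch mode in the layout
  int l_coord_qo, l_coord_kv;      // coordinate used to slice that mode
};

BatchCoords make_coords(bool paged_kv, bool is_var_len, int batch, int idx_b) {
  BatchCoords c;
  // Q/O are flattened only for varlen inputs.
  c.batch_dim_qo = is_var_len ? 1 : batch;
  c.l_coord_qo = is_var_len ? 0 : idx_b;
  // K/V are also flat when paged: pages index a global pool, so the
  // per-sequence offset is applied via the page table instead.
  c.batch_dim_kv = (paged_kv || is_var_len) ? 1 : batch;
  c.l_coord_kv = (paged_kv || is_var_len) ? 0 : idx_b;
  return c;
}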
