Conversation

@wuxun-zhang

The new kernel implements the method below; key points:

  • the number of work groups is fixed to the total number of XeCores
  • the KV sequence length from all sequences is dynamically split across all work groups
  • each XeCore gets a balanced share of work units
[figure: dynamic split of the KV sequence across XeCores; a sketch of the split logic follows]
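For intuition, a minimal sketch of the balanced split, with illustrative names only (not the PR's actual scheduler implementation):

```cpp
// Sketch, assuming total_kv_blocks is the KV tile count summed over all (batch, head)
// pairs and num_xecores is the fixed number of persistent work groups. Each work group
// claims a contiguous range of KV blocks; per-XeCore load differs by at most one block.
struct KvRange { int begin; int end; };

KvRange kv_range_for_workgroup(int wg_id, int num_xecores, int total_kv_blocks) {
  int per_wg    = total_kv_blocks / num_xecores;  // balanced base share
  int remainder = total_kv_blocks % num_xecores;  // first `remainder` WGs take one extra
  int begin = wg_id * per_wg + (wg_id < remainder ? wg_id : remainder);
  int end   = begin + per_wg + (wg_id < remainder ? 1 : 0);
  return {begin, end};
}
```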

As of now there are two limitations:

  • decode support only (seq_len_qo == 1)
  • batch_size * num_heads_q <= total number of XeCores

@pengzhao-intel

Maybe add the limitations of this algorithm in the code as well, especially the one involving atomics.

Comment on lines +353 to +359
// current kernel only supports decode (seq_len_qo == 1)
if (args.kernel.shape.seq_len_qo > 1) {
return false;
}
// current kernel only supports batch * num_heads_q up to the total XeCore count
if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
return false;
}
@wuxun-zhang (Author)

@pengzhao-intel Added checks here in can_implement().

@Antonyvance requested a review from Copilot on November 5, 2025 at 07:29

Copilot AI left a comment


Pull Request Overview

This PR introduces a persistent SDPA (Scaled Dot Product Attention) kernel for decode scenarios that implements dynamic load balancing across XeCores. The key innovation is fixing the number of work groups to match total XeCores and dynamically splitting KV sequence length across all work groups for balanced workload distribution.

Key changes:

  • New persistent tile scheduler (XeFHMAIndividualPersistentTileScheduler) that distributes work evenly across fixed XeCore count
  • New kernel implementation (XeFMHAFwdDynamicSplitKernel) with split-K reduction for partial results (see the sketch after this list)
  • Support infrastructure including atomic operations (atomicSub, atomicLoad) for synchronization
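For context, the split-K reduction merges partial attention results from disjoint KV partitions using the standard flash-decoding rescale-and-accumulate pattern. A minimal sketch under that assumption (illustrative names, not the PR's exact code):

```cpp
#include <algorithm>
#include <cmath>

// Merge two partial attention results computed over disjoint KV partitions.
// m: partial row max, l: partial softmax denominator, o: unnormalized partial output.
void merge_partials(float m1, float l1, const float* o1,
                    float m2, float l2, const float* o2,
                    int n, float& m_out, float& l_out, float* o_out) {
  m_out = std::max(m1, m2);
  float a1 = std::exp(m1 - m_out);  // rescale factor for partition 1
  float a2 = std::exp(m2 - m_out);  // rescale factor for partition 2
  l_out = a1 * l1 + a2 * l2;
  for (int i = 0; i < n; ++i)
    o_out[i] = a1 * o1[i] + a2 * o2[i];  // the final output is o_out[i] / l_out
}
```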

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| include/cutlass/gpu_generics.h | Adds atomic operations (atomicSub, atomicLoad) for synchronization primitives |
| examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | Integrates persistent kernel selection and queries the hardware XeCore count |
| examples/06_bmg_flash_attention/CMakeLists.txt | Adds a build target for persistent kernel testing |
| examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp | Configures the persistent kernel with appropriate tile sizes and subgroup layouts |
| applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp | Implements the persistent tile scheduler with dynamic work distribution |
| applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | Implements the dynamic split-K kernel with partial result reduction |
| applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | Updates the mainloop to use the total block count for remainder masking |


// Important: make sure multiple of 16 element for each copy
// this is for storing partial results from different KV partitions
static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;

Copilot AI Nov 5, 2025


Corrected spelling of 'thead' to 'thread'.

Suggested change
static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;

Comment on lines +552 to +557
int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
+ partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
+ sg_id * intel::sg_size * num_elem_per_thead
+ tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});

Copilot AI Nov 5, 2025


Variable name 'num_elem_per_thead' uses misspelled 'thead' instead of 'thread'. This should be renamed for consistency.

Suggested change
int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
+ partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
+ sg_id * intel::sg_size * num_elem_per_thead
+ tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size
+ partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size
+ sg_id * intel::sg_size * num_elem_per_thread
+ tid_in_sg * num_elem_per_thread;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
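The linearization above is row-major over (batch_head, partition, subgroup, lane), with num_elem_per_thread contiguous elements per lane. An equivalent hypothetical helper (not in the PR) that makes the layout explicit:

```cpp
// Hypothetical refactor of the offset arithmetic above: flatten the index tuple
// (batch_head, partition, subgroup, lane) row-major, with `elems` values per lane.
inline int partial_offset(int batch_head_id, int partition_id, int sg_id, int tid_in_sg,
                          int max_num_partitions, int sgs_per_wg, int sg_size, int elems) {
  return (((batch_head_id * max_num_partitions + partition_id) * sgs_per_wg + sg_id) * sg_size
          + tid_in_sg) * elems;
}
```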

Comment on lines +595 to +600
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
+ i * SGPerWG::value * intel::sg_size * num_elem_per_thead
+ sg_id * intel::sg_size * num_elem_per_thead
+ tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});

Copilot AI Nov 5, 2025


Variable name 'num_elem_per_thead' uses misspelled 'thead' instead of 'thread'. This should be renamed for consistency.

Suggested change
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
+ i * SGPerWG::value * intel::sg_size * num_elem_per_thead
+ sg_id * intel::sg_size * num_elem_per_thead
+ tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread
+ i * SGPerWG::value * intel::sg_size * num_elem_per_thread
+ sg_id * intel::sg_size * num_elem_per_thread
+ tid_in_sg * num_elem_per_thread;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});

CUTLASS_DEVICE int atomicLoad(int *address) {
int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);

Copilot AI Nov 5, 2025


The atomic_ref is constructed with address[0], which dereferences the pointer. This should be *address for clarity and consistency with standard atomic-operation patterns.

Suggested change
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(*address);
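For reference, a self-contained sketch of both primitives built on sycl::atomic_ref with the same ordering and scope as above (free-standing names chosen here to avoid clashing with the PR's atomicLoad/atomicSub; host-side fallbacks omitted):

```cpp
#include <sycl/sycl.hpp>

// Relaxed, device-scope atomics over a plain int*, mirroring the additions to
// include/cutlass/gpu_generics.h.
inline int atomic_load_relaxed(int* address) {
  sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                   sycl::access::address_space::generic_space> atm(*address);
  return atm.load();
}

inline int atomic_sub_relaxed(int* address, int value) {
  sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                   sycl::access::address_space::generic_space> atm(*address);
  return atm.fetch_sub(value);  // returns the value held before the subtraction
}
```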

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
merged_res(i + size(FragA{}.shape())) = tA_max(i);
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);

Copilot AI Nov 5, 2025


Indexing logic appears incorrect. For tA_sum, the offset should be size(FragA{}.shape()) + size(FragARow{}.shape()), not size(FragA{}.shape()) + 1. This will cause tA_max and tA_sum values to overlap/overwrite.

Suggested change
merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
merged_res(i + size(FragA{}.shape()) + size(FragARow{}.shape())) = tA_sum(i);
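Under this reading, the packed per-thread buffer is laid out as [FragA | tA_max | tA_sum]. A sketch of the non-overlapping store, mirroring the suggestion with names taken from the hunk:

```cpp
// Store both row statistics after the accumulator fragment, each in its own
// size(FragARow{}.shape())-element region, so max and sum never overlap.
int base = size(FragA{}.shape());
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
  merged_res(base + i)                            = tA_max(i);
  merged_res(base + size(FragARow{}.shape()) + i) = tA_sum(i);
}
```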

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
tA_max(i) = merged_res(i + size(FragA{}.shape()));
tA_sum(i) = merged_res(i + 1 + size(FragA{}.shape()));

Copilot AI Nov 5, 2025


Indexing logic appears incorrect. This should use offset size(FragA{}.shape()) + size(FragARow{}.shape()) to correctly retrieve tA_sum values, matching the storage layout.

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
tA_max_2(i) = merged_res(i + size(FragA{}.shape()));
tA_sum_2(i) = merged_res(i + 1 + size(FragA{}.shape()));

Copilot AI Nov 5, 2025


Indexing logic appears incorrect. This should use offset size(FragA{}.shape()) + size(FragARow{}.shape()) to correctly retrieve tA_sum values, matching the storage layout.
