Persistent SDPA kernel #608
base: main
Conversation
maybe add the limitation of this algorithm in the code as well, especially for the one with atomics.
```cpp
if (args.kernel.shape.seq_len_qo > 1) {
  return false;
}
// current kernel only support num batch heads less than total XeCore count
if (args.kernel.shape.batch * args.kernel.shape.num_heads_q > args.hw_info.sm_count) {
  return false;
}
```
@pengzhao-intel Added checks here in can_implement().
Pull Request Overview
This PR introduces a persistent SDPA (Scaled Dot Product Attention) kernel for decode scenarios that implements dynamic load balancing across XeCores. The key innovation is fixing the number of work groups to match total XeCores and dynamically splitting KV sequence length across all work groups for balanced workload distribution.
Key changes:
- New persistent tile scheduler (`XeFHMAIndividualPersistentTileScheduler`) that distributes work evenly across the fixed XeCore count
- New kernel implementation (`XeFMHAFwdDynamicSplitKernel`) with split-K reduction for partial results
- Support infrastructure including atomic operations (`atomicSub`, `atomicLoad`) for synchronization
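For orientation, here is a minimal sketch of the kind of even split a persistent scheduler performs when the work-group count is fixed to the XeCore count. The helper name `partition_kv` and its arguments are hypothetical, not taken from the PR:

```cpp
#include <algorithm>

// Half-open range of KV blocks assigned to one work group.
struct WorkRange { int kv_begin; int kv_end; };

// Split kv_blocks as evenly as possible across num_wgs persistently resident work groups:
// every work group gets kv_blocks / num_wgs blocks, and the first kv_blocks % num_wgs
// work groups take one extra block.
inline WorkRange partition_kv(int wg_id, int num_wgs, int kv_blocks) {
  int base  = kv_blocks / num_wgs;
  int extra = kv_blocks % num_wgs;
  int begin = wg_id * base + std::min(wg_id, extra);
  int end   = begin + base + (wg_id < extra ? 1 : 0);
  return {begin, end};
}
```

Each work group then produces a partial result for its KV slice, and the split-K reduction described above merges those partials into the final output.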
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
Summary per file:
| File | Description |
|---|---|
| `include/cutlass/gpu_generics.h` | Adds atomic operations (`atomicSub`, `atomicLoad`) for synchronization primitives |
| `examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp` | Integrates persistent kernel selection and queries hardware XeCore count |
| `examples/06_bmg_flash_attention/CMakeLists.txt` | Adds build target for persistent kernel testing |
| `examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp` | Configures persistent kernel with appropriate tile sizes and subgroup layouts |
| `applications/flash_attention_v2/kernel/xe_tile_scheduler.hpp` | Implements persistent tile scheduler with dynamic work distribution |
| `applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp` | Implements dynamic split-K kernel with partial result reduction |
| `applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp` | Updates mainloop to use total block count for remainder masking |
```cpp
// Important: make sure multiple of 16 element for each copy
// this is for storing partial results from different KV partitions
static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
```
Copilot AI (Nov 5, 2025):
Corrected spelling of 'thead' to 'thread'.
Suggested change:

```diff
-static constexpr int num_elem_per_thead = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
+static constexpr int num_elem_per_thread = (size(FragA{}.shape()) + 2 * size(FragARow{}.shape()) + 15) / 16 * 16;
```
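As an aside, the `(x + 15) / 16 * 16` expression rounds the per-thread element count up to the next multiple of 16, which is what the "multiple of 16 element for each copy" comment asks for. A tiny illustrative check (the input values are made up):

```cpp
// Illustrative only: the round-up-to-a-multiple-of-16 idiom used above.
constexpr int round_up_16(int x) { return (x + 15) / 16 * 16; }
static_assert(round_up_16(36) == 48, "36 elements pad up to 48");
static_assert(round_up_16(48) == 48, "exact multiples are left unchanged");
```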
```cpp
int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
             + partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
             + sg_id * intel::sg_size * num_elem_per_thead
             + tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
```
Copilot AI (Nov 5, 2025):
Variable name 'num_elem_per_thead' uses misspelled 'thead' instead of 'thread'. This should be renamed for consistency.
Suggested change:

```diff
-int offset = batch_head_id * max_num_partitions * num_elem_per_thead * SGPerWG::value * intel::sg_size
-             + partition_id * num_elem_per_thead * SGPerWG::value * intel::sg_size
-             + sg_id * intel::sg_size * num_elem_per_thead
-             + tid_in_sg * num_elem_per_thead;
-Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
-Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
+int offset = batch_head_id * max_num_partitions * num_elem_per_thread * SGPerWG::value * intel::sg_size
+             + partition_id * num_elem_per_thread * SGPerWG::value * intel::sg_size
+             + sg_id * intel::sg_size * num_elem_per_thread
+             + tid_in_sg * num_elem_per_thread;
+Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
+Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
```
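For readers following the offset arithmetic: it linearizes a four-level [batch_head][partition][subgroup][lane] layout of the partial-results buffer, with `num_elem_per_thread` elements per lane. A hedged restatement of the same expression as a standalone helper (names are copied from the snippet; the helper itself is not part of the PR):

```cpp
// Same linearization as the offset expression above, written as one factored helper.
inline int partial_offset(int batch_head_id, int partition_id, int sg_id, int tid_in_sg,
                          int max_num_partitions, int sgs_per_wg, int sg_size,
                          int num_elem_per_thread) {
  // [batch_head][partition][subgroup][lane] -> element offset into the partial-results buffer
  return (((batch_head_id * max_num_partitions + partition_id) * sgs_per_wg + sg_id) * sg_size
          + tid_in_sg) * num_elem_per_thread;
}
```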
```cpp
int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
             + i * SGPerWG::value * intel::sg_size * num_elem_per_thead
             + sg_id * intel::sg_size * num_elem_per_thead
             + tid_in_sg * num_elem_per_thead;
Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
```
Copilot AI (Nov 5, 2025):
Variable name 'num_elem_per_thead' uses misspelled 'thead' instead of 'thread'. This should be renamed for consistency.
Suggested change:

```diff
-int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thead
-             + i * SGPerWG::value * intel::sg_size * num_elem_per_thead
-             + sg_id * intel::sg_size * num_elem_per_thead
-             + tid_in_sg * num_elem_per_thead;
-Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thead>{}));
-Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thead>{});
+int offset = wg_id * max_num_partitions * SGPerWG::value * intel::sg_size * num_elem_per_thread
+             + i * SGPerWG::value * intel::sg_size * num_elem_per_thread
+             + sg_id * intel::sg_size * num_elem_per_thread
+             + tid_in_sg * num_elem_per_thread;
+Tensor tPartial = make_tensor(params.partial_results_ptr + offset, make_shape(Int<num_elem_per_thread>{}));
+Tensor merged_res = make_tensor<ElementA>(Int<num_elem_per_thread>{});
```
```cpp
CUTLASS_DEVICE int atomicLoad(int *address) {
  int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
  auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
```
Copilot AI (Nov 5, 2025):
The `atomic_ref` is constructed with `address[0]`, which dereferences the pointer. This should be `*address` for clarity and consistency with standard atomic operation patterns.
Suggested change:

```diff
-  auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(address[0]);
+  auto atm = sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device, sycl::access::address_space::generic_space>(*address);
```
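For completeness, here is a minimal, self-contained sketch of the relaxed, device-scope atomic load/sub pattern being discussed, assuming a SYCL 2020 toolchain and mirroring the `#if defined(__SYCL_DEVICE_ONLY__)` guard style; this is a sketch of the pattern, not the PR's exact `gpu_generics.h` code:

```cpp
#include <sycl/sycl.hpp>

// Relaxed, device-scope atomic load of a plain int in the generic address space.
inline int atomic_load_relaxed(int *address) {
  int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
  sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                   sycl::access::address_space::generic_space>
      atm(*address);
  result = atm.load();
#endif
  return result;
}

// Relaxed, device-scope atomic subtraction; returns the value before the subtraction.
inline int atomic_sub_relaxed(int *address, int value) {
  int result = 0;
#if defined(__SYCL_DEVICE_ONLY__)
  sycl::atomic_ref<int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                   sycl::access::address_space::generic_space>
      atm(*address);
  result = atm.fetch_sub(value);
#endif
  return result;
}
```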
```cpp
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
  merged_res(i + size(FragA{}.shape())) = tA_max(i);
  merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
```
Copilot AI (Nov 5, 2025):
Indexing logic appears incorrect. For tA_sum, the offset should be size(FragA{}.shape()) + size(FragARow{}.shape()), not size(FragA{}.shape()) + 1. This will cause tA_max and tA_sum values to overlap/overwrite.
Suggested change:

```diff
-  merged_res(i + 1 + size(FragA{}.shape())) = tA_sum(i);
+  merged_res(i + size(FragA{}.shape()) + size(FragARow{}.shape())) = tA_sum(i);
```
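To make the intended layout concrete, a hedged sketch of the packing with the corrected offsets, assuming the `[ FragA values | tA_max values | tA_sum values ]` ordering implied by the comment (names are taken from the snippets above; this is not the PR's final code):

```cpp
// Sketch: pack partial results as [ FragA | tA_max | tA_sum ] with non-overlapping blocks.
const int frag_a_size   = size(FragA{}.shape());
const int frag_row_size = size(FragARow{}.shape());

CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < frag_row_size; ++i) {
  merged_res(i + frag_a_size)                 = tA_max(i);  // max block starts right after FragA
  merged_res(i + frag_a_size + frag_row_size) = tA_sum(i);  // sum block starts after the max block
}
```

The two read-back loops below would then use the same two offsets when unpacking `tA_max`/`tA_sum` (and `tA_max_2`/`tA_sum_2`).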
```cpp
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
  tA_max(i) = merged_res(i + size(FragA{}.shape()));
  tA_sum(i) = merged_res(i + 1 + size(FragA{}.shape()));
```
Copilot AI (Nov 5, 2025):
Indexing logic appears incorrect. This should use offset size(FragA{}.shape()) + size(FragARow{}.shape()) to correctly retrieve tA_sum values, matching the storage layout.
```cpp
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(FragARow{}.shape()); ++i) {
  tA_max_2(i) = merged_res(i + size(FragA{}.shape()));
  tA_sum_2(i) = merged_res(i + 1 + size(FragA{}.shape()));
```
Copilot AI (Nov 5, 2025):
Indexing logic appears incorrect. This should use offset size(FragA{}.shape()) + size(FragARow{}.shape()) to correctly retrieve tA_sum values, matching the storage layout.
The new kernel implements the method described below; the key points are:

As of now there are two limitations:

- decode only (`seq_len_qo == 1`)
- `batch_size * num_heads_q` must not exceed the total number of XeCores