[Arm] Enable SDPA Fusion with Sink input on ARM CPUs#33566
Conversation
@alvoron, could you please review?
Pull request overview
This PR enables SDPA (Scaled Dot Product Attention) fusion with Sink input support for ARM CPUs, extending functionality previously limited to x86_64 platforms. The changes remove platform-specific restrictions and implement the necessary ARM-specific kernel modifications to handle sink inputs during attention computation.
Changes:
- Removed x86_64-only restriction for sink input support in SDPA fusion
- Integrated sink input processing into ARM's ACL (Arm Compute Library) attention kernel
- Optimized SVE (Scalable Vector Extension) operations by replacing predicated-zeroing with predicated-merging variants
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| `stateful_sdpa_fusion.cpp` | Removes the preprocessor directive that restricted sink input support to x86_64 |
| `scaled_attn.cpp` | Enables the sink input parameter in the ARM ACL kernel and passes it to the softmax computation |
| `softmax_kernel.hpp` | Refactors the SVE loop structure, optimizes predicate usage, and adds sink value processing in softmax normalization |
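As an aside on the predicate optimization above, the difference between predicated-zeroing (`_z`) and predicated-merging (`_m`) SVE operations can be emulated per lane in scalar code. This is a hypothetical sketch of the semantics only; `add_z` and `add_m` are stand-ins, not the ACLE intrinsics:

```cpp
// Per-lane emulation of SVE predication semantics (hypothetical stand-ins):
//   _z (zeroing): lanes where the predicate is false become 0 in the result
//   _m (merging): lanes where the predicate is false keep the first operand
float add_z(bool active, float a, float b) { return active ? a + b : 0.0f; }
float add_m(bool active, float a, float b) { return active ? a + b : a; }
```

Merging variants avoid writing zeros over lanes that already hold valid partial results, which is why they can replace zeroing variants in loops that accumulate in place.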
```cpp
        pg_u8 = svwhilelt_b8(0, static_cast<int>(inc));
        pg_u16 = svwhilelt_b16(0, static_cast<int>(inc));
    }
    for (; i + vec_len_f16_sve() < size; i += vec_len_f16_sve()) {
```
The loop condition `i + vec_len_f16_sve() < size` prevents processing the last full vector when `size` is an exact multiple of the vector length. It should be `i + vec_len_f16_sve() <= size`, matching the pattern used in `exp_reduce_sum_f32` (`i + svcnth() <= size`), so that all complete vectors are processed.
Suggested change:

```diff
-    for (; i + vec_len_f16_sve() < size; i += vec_len_f16_sve()) {
+    for (; i + vec_len_f16_sve() <= size; i += vec_len_f16_sve()) {
```
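The off-by-one can be checked with a scalar model of the loop bound; `kVecLen` is a stand-in for `vec_len_f16_sve()`:

```cpp
#include <cstddef>

// Stand-in for vec_len_f16_sve(); any fixed vector length shows the issue.
constexpr std::size_t kVecLen = 8;

// Buggy bound: '<' skips the final full vector when size is a multiple of kVecLen.
std::size_t full_vectors_lt(std::size_t size) {
    std::size_t n = 0;
    for (std::size_t i = 0; i + kVecLen < size; i += kVecLen) ++n;
    return n;
}

// Fixed bound: '<=' processes every complete vector.
std::size_t full_vectors_le(std::size_t size) {
    std::size_t n = 0;
    for (std::size_t i = 0; i + kVecLen <= size; i += kVecLen) ++n;
    return n;
}
```

With `size == 16` the `<` bound runs one iteration and leaves a full vector to the tail path; the `<=` bound runs two.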
```cpp
    ov::float16 sum = 0.0f;
    if (sink != nullptr) {
        max = max > static_cast<const ov::float16>(*sink) ? max : static_cast<const ov::float16>(*sink);
```
The ternary operator can be replaced with `std::max` for better readability: `max = std::max(max, static_cast<ov::float16>(*sink));`
```cpp
    if (dst_precision == ov::element::f32) {
        exp_reduce_sum_f32(a, max, len, sum);
        if (sink != nullptr) {
            sum += std::exp(*sink - max);
```
This computation is duplicated in both branches of the if-else statement. Consider computing `sink_contrib = std::exp(*sink - max)` once before the if-else block and adding it to `sum` in both branches to reduce code duplication.
```cpp
    if (sink != nullptr) {
        max = max > static_cast<const ov::float16>(*sink) ? max : static_cast<const ov::float16>(*sink);
    }
    if (dst_precision == ov::element::f32) {
        exp_reduce_sum_f32(a, max, len, sum);
        if (sink != nullptr) {
            sum += std::exp(*sink - max);
        }
        ov::float16 scalar = 1.0f / sum;
        multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
        // apply causal mask to final result instead of attn_score
        if (total_size > len)
            memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
    } else {
        exp_reduce_sum_f32(a, max, len, sum);
        if (sink != nullptr) {
            sum += std::exp(*sink - max);
        }
```
Suggested change:

```diff
-    if (sink != nullptr) {
-        max = max > static_cast<const ov::float16>(*sink) ? max : static_cast<const ov::float16>(*sink);
-    }
-    if (dst_precision == ov::element::f32) {
-        exp_reduce_sum_f32(a, max, len, sum);
-        if (sink != nullptr) {
-            sum += std::exp(*sink - max);
-        }
-        ov::float16 scalar = 1.0f / sum;
-        multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
-        // apply causal mask to final result instead of attn_score
-        if (total_size > len)
-            memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
-    } else {
-        exp_reduce_sum_f32(a, max, len, sum);
-        if (sink != nullptr) {
-            sum += std::exp(*sink - max);
-        }
+    ov::float16 sink_contrib = 0.0f;
+    if (sink != nullptr) {
+        max = max > static_cast<const ov::float16>(*sink) ? max : static_cast<const ov::float16>(*sink);
+        sink_contrib = std::exp(*sink - max);
+    }
+    exp_reduce_sum_f32(a, max, len, sum);
+    if (sink != nullptr) {
+        sum += sink_contrib;
+    }
+    if (dst_precision == ov::element::f32) {
+        ov::float16 scalar = 1.0f / sum;
+        multiply_scalar(a, static_cast<float*>(a_dst), scalar, len);
+        // apply causal mask to final result instead of attn_score
+        if (total_size > len)
+            memset(static_cast<float*>(a_dst) + len, 0, sizeof(float) * (total_size - len));
+    } else {
```
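For reference, the sink-aware softmax can be sketched as scalar code. This is a hypothetical model of the kernel's logic with `sink_contrib` hoisted as suggested; `softmax_with_sink` is illustrative, not the kernel's API:

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar sketch of softmax with an attention-sink logit: the sink joins the
// running max and the normalizing sum but produces no output element of its
// own. sink_contrib is computed once, mirroring the suggested de-duplication.
std::vector<float> softmax_with_sink(const std::vector<float>& scores, const float* sink) {
    float mx = *std::max_element(scores.begin(), scores.end());
    if (sink != nullptr) {
        mx = std::max(mx, *sink);
    }
    const float sink_contrib = (sink != nullptr) ? std::exp(*sink - mx) : 0.0f;
    float sum = sink_contrib;
    std::vector<float> out(scores.size());
    for (std::size_t i = 0; i < scores.size(); ++i) {
        out[i] = std::exp(scores[i] - mx);
        sum += out[i];
    }
    for (float& v : out) {
        v /= sum;  // with a sink, the outputs sum to less than 1
    }
    return out;
}
```

With two equal scores and a sink of the same value, each output is 1/3 rather than 1/2, since the sink absorbs one share of the probability mass.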
```diff
@@ -1411,15 +1403,20 @@ inline void attn_softmax_kernel<ov::float16>(ov::float16* a,
     }

     ov::float16 sum = 0.0f;
     if (sink != nullptr) {
         max = std::max(max, static_cast<const ov::float16>(*sink));
```
`sink` is fp32, so `max` could be inf after casting to fp16. Is it safe to pass inf as `max` into `exp_reduce_sum_f32`?
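The overflow concern can be modeled in scalar code; `emulate_fp16_cast` is a hypothetical stand-in for the fp32-to-fp16 conversion (fp16's largest finite value is 65504; `ov::float16` itself is not used here):

```cpp
#include <cmath>
#include <limits>

// Emulated fp32 -> fp16 cast: values above fp16's largest finite value
// (65504) overflow to +inf. Rounding is ignored; this only shows the
// overflow path that the comment above is asking about.
float emulate_fp16_cast(float x) {
    const float kFp16Max = 65504.0f;
    return (x > kFp16Max) ? std::numeric_limits<float>::infinity() : x;
}

// One term of the softmax numerator: if max_val is +inf, exp(score - inf)
// underflows to exactly 0 for every finite score.
float exp_term(float score, float max_val) {
    return std::exp(score - max_val);
}
```

If `max` saturates to +inf, every `exp(score - max)` becomes exactly 0, so the softmax denominator is 0 and the subsequent division yields NaN.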
It's resolved now with the latest rebasing.
Force-pushed from 8dcdae9 to a1b7ad3.
Hi @alvoron,

build_jenkins
@abhijain1204fujitsu could you please rebase one more time?
Force-pushed from a1b7ad3 to c9934c9.
I have rebased it. Please check. Thanks!
f2dbee2

[Arm] Enable SDPA Fusion with Sink input on ARM CPUs (#33566)

This PR enables SDPA fusion on ARM when attention contains the additional input **Sink**.

Checked and validated the operation with the GPT-OSS 20B model. With this change, enhanced performance during LLM inference has been observed; refer to the table below.

<img width="604" height="120" alt="image" src="https://github.com/user-attachments/assets/c77b7908-4c8a-4a94-9feb-7f887bb9697b" />

** All values are in tokens per second (decoding throughput).

Machine: Graviton4, single socket, 96 cores.

Kindly review the PR and share feedback if any. Thanks!

This work is contributed by @ashwins990 & @abhijain1204fujitsu

Co-authored-by: Ashwin <Ashwin.Sekhar@fujitsu.com>