[SYCLTLA] rebase FA2 bwd to latest version #2756
Conversation
Pull request overview
This PR rebases the Flash Attention 2 backward pass implementation to the latest version, introducing significant refactoring to use newer SYCLTLA APIs and simplify the codebase.
Changes:
- Replaced older MMA atom definitions with the simplified `XE_DPAS_TT` architecture
- Removed extensive tile shape static assertions and manual TiledCopy definitions
- Refactored GEMM operations into reusable kernel functions with prefetching support
- Simplified tensor layouts by removing trailing `_1{}` dimensions throughout (see the sketch below)
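As a rough illustration of the last point, here is a minimal CuTe sketch (not from this PR; the shapes are hypothetical and it assumes the CUTLASS CuTe headers) of what dropping a trailing `_1{}` mode looks like:

```cpp
#include <cute/tensor.hpp>  // CuTe, from the CUTLASS repository

using namespace cute;

int main() {
  // Before: a rank-3 shape whose trailing mode is a compile-time 1.
  // It carries no data but forces rank-3 coordinates everywhere
  // the layout is consumed.
  auto old_shape = make_shape(Int<8>{}, Int<16>{}, _1{});

  // After: the equivalent rank-2 shape; partitioning and copy code
  // that consume it become simpler.
  auto new_shape = make_shape(Int<8>{}, Int<16>{});

  print(size(old_shape)); print("\n");  // _128
  print(size(new_shape)); print("\n");  // _128 -- same element count
  return 0;
}
```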
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.h | Updated MMA atom architecture, simplified tile shapes, removed manual TiledCopy definitions, and reordered Param constructor initialization |
| src/ATen/native/transformers/xpu/flash_attn/sycltla/mha_bwd.cpp | Major refactoring: replaced specialized GEMM functions with unified kernels, simplified tensor layouts, updated layout computation logic, changed empty tensor to zeros, and adjusted atom layout constants |
```diff
  CUTLASS_PRAGMA_UNROLL
- for (int mi = 0; mi < size<0>(rdO_2d); ++mi) {
+ for (int mi = 0; mi < NumValperCol; ++mi) {
    float accum = 0.0f;
```
Copilot AI (Jan 21, 2026):
Variable shadowing issue: `accum` is declared inside the loop at line 150, but a variable of the same name is already declared in the outer scope at line 147 (before the loop). The outer declaration on line 147 is never used, and the inner declaration on line 150 shadows it. Remove the unused outer declaration, or turn the inner declaration into an assignment as in the suggestion below.
Suggested change:
```diff
- float accum = 0.0f;
+ accum = 0.0f;
```
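For illustration, a self-contained sketch of the shadowing pattern and the suggested fix (the variable names mirror the diff; the loop bound is arbitrary):

```cpp
#include <cstdio>

int main() {
  // Pattern flagged above: the inner declaration shadows the outer one,
  // so the outer `accum` is never used.
  float accum = 0.0f;              // outer declaration (never read)
  for (int mi = 0; mi < 4; ++mi) {
    float accum = 0.0f;            // shadows the outer `accum`
    accum += static_cast<float>(mi);
  }

  // Suggested fix: reuse the outer variable, resetting it per iteration.
  for (int mi = 0; mi < 4; ++mi) {
    accum = 0.0f;                  // assignment instead of re-declaration
    accum += static_cast<float>(mi);
    std::printf("mi=%d accum=%.1f\n", mi, accum);
  }
  return 0;
}
```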
```diff
- int y = m_offset + get<0>(rC_2d(m, n)) + diagonal_offset;
+ int y = m_offset + get<1>(rC_2d(m, n)) + sg_local_id + diagonal_offset;
+ int x = n_offset + get<0>(rC_2d(m, n));
  if (x > y) {
```
Copilot AI (Jan 21, 2026):
The coordinate calculation has been swapped (`x` and `y` switched) compared to the original implementation. The variable names are swapped but the comparison `if (x > y)` remains the same, which effectively inverts the mask logic. Ensure this change is intentional and correctly implements the causal masking for the transposed layout.
Suggested change:
```diff
- if (x > y) {
+ if (y > x) {
```
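To see why swapping the coordinates without swapping the comparison inverts the mask, here is a self-contained sketch with toy sizes (the row/column interpretation is an assumption for illustration):

```cpp
#include <cstdio>

int main() {
  constexpr int kRows = 4, kCols = 4;  // toy score tile
  for (int y = 0; y < kRows; ++y) {    // y: query (row) position
    for (int x = 0; x < kCols; ++x) {  // x: key (column) position
      // Causal mask: a query must not attend to later keys,
      // so positions with x > y are masked ('.').
      std::printf("%c ", (x > y) ? '.' : '#');
    }
    std::printf("\n");
  }
  // If the roles of x and y are swapped in the coordinate computation
  // but the test stays `x > y`, the lower triangle is masked instead,
  // which is exactly the inversion the comment warns about.
  return 0;
}
```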
```diff
- constexpr int AtomLayoutMSdP = 4;
- constexpr int AtomLayoutNdKV = 4;
+ constexpr int AtomLayoutMSdP = 2;
+ constexpr int AtomLayoutNdKV = 2;
```
Copilot AI (Jan 21, 2026):
The `AtomLayoutMSdP` and `AtomLayoutNdKV` values have been changed from 4 to 2. This is a significant configuration change that affects the layout of the matrix multiplication operations. Ensure this change has been validated with comprehensive testing for correctness and performance implications.
Suggested change:
```diff
- constexpr int AtomLayoutNdKV = 2;
+ constexpr int AtomLayoutNdKV = 4;
```
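For intuition on what this constant controls, a toy arithmetic sketch (all numbers are hypothetical and not taken from the kernel): the atom layout sets how many MMA atoms are laid side by side along a tile dimension, so halving it doubles the iterations each atom needs to cover the tile:

```cpp
#include <cstdio>

int main() {
  // Hypothetical tile sizes, for illustration only.
  constexpr int kTileM        = 64;  // workgroup tile rows
  constexpr int kAtomM        = 8;   // rows covered by one DPAS atom
  constexpr int kAtomLayoutOld = 4;  // old AtomLayoutMSdP
  constexpr int kAtomLayoutNew = 2;  // new AtomLayoutMSdP

  // Iterations each atom performs along M to cover the tile.
  std::printf("AtomLayoutM=4 -> %d iterations per atom\n",
              kTileM / (kAtomM * kAtomLayoutOld));  // 2
  std::printf("AtomLayoutM=2 -> %d iterations per atom\n",
              kTileM / (kAtomM * kAtomLayoutNew));  // 4
  return 0;
}
```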
```diff
  int seqlen_kv_pad = (seqlen_kv + kNPad - 1) / kNPad * kNPad;
  auto tensor_odo = at::empty_like(out, opts.dtype(at::kFloat));
- auto tensor_dqaccum = at::empty(
+ auto tensor_dqaccum = at::zeros(
```
Copilot AI (Jan 21, 2026):
Changed from `at::empty` to `at::zeros`, which initializes all values to zero. This adds initialization overhead that may be unnecessary if every value is overwritten before being read. If the tensor is fully populated before use, consider reverting to `at::empty` for better performance.
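A minimal ATen sketch of the tradeoff (the function and tensor names are hypothetical): `at::zeros` pays for an extra fill kernel, which is only needed if some elements might never be written before they are read, for example rows in a padded region:

```cpp
#include <ATen/ATen.h>

at::Tensor make_dq_accum(int64_t batch, int64_t seqlen_q, int64_t head_dim,
                         const at::TensorOptions& opts) {
  // at::empty: no initialization. Correct only if every element is
  // written before it is read (e.g. by the dQ accumulation kernel).
  // at::zeros: adds a device-wide fill, but guarantees that padded
  // or never-touched elements read back as 0.
  bool fully_overwritten = false;  // assumption for this sketch
  if (fully_overwritten) {
    return at::empty({batch, seqlen_q, head_dim}, opts.dtype(at::kFloat));
  }
  return at::zeros({batch, seqlen_q, head_dim}, opts.dtype(at::kFloat));
}
```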