Add CausalMask support with new flash attention api #604
base: main
Signed-off-by: Chen, Xi2 <[email protected]>
| class TiledCopyK_ = void, // Optional TiledCopy for loading K | ||
| class TiledCopyV_ = void> // Optional TiledCopy for loading V | ||
| class TiledCopyV_ = void, // Optional TiledCopy for loading V | ||
| class SubgroupLayoutQK_ = void> // Optional SubgroupLayout for QK |
@ClarkChin08 -- You can derive SubgroupLayoutQK_ from TiledMMAQK_, so the user doesn't have to pass it in.
| int item_id = get_sub_group().get_local_id()[0]; | ||
| int base_col = item_id + K * get<1>(TileShapeQK{}); | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int n = 0; n < shape<2>(tSrS.shape()); ++n) { | ||
| int col_idx = base_col + n * get<1>(MmaAtomShapeQK()); | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int m = 0; m < shape<0>(tSrS.shape()); ++m) { | ||
| int row_idx = seq_coord + m; | ||
| if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { | ||
| tSrS(m, 0, n) = ElementS(-INFINITY); | ||
| } | ||
| } | ||
| } |
We should avoid making assumptions about the layout of MMA blocks within tSrS, as they can easily be broken by different choices for TiledMMAQK. Also, this code is assuming size<1>(tSrS) == 1.
Instead, I'd suggest using coordinate tensors to identify the coordinates of thread-owned data -- something like this (code below untested):
auto cS_thread = thr_mma_qk.partition_C(cP); /* local S coordinates, within WG tile */
for (int i = 0; i < tSrS.size(); i++)
if (get<1>(cS_thread(i)) >= seq_len - base_col)
tSrS(i) = ElementS(-INFINITY);
Also, by using this method, you can avoid referring to SubgroupLayoutQK at all, because cS_thread is already aware of the subgroup tile's position within the workgroup tile.
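For readers without a CuTe build at hand, the coordinate-based idea can be illustrated in plain C++. This is only a sketch, not the PR's code: `Coord`, `apply_causal_mask`, and the flat `coords`/`scores` vectors are hypothetical stand-ins for `cS_thread` and `tSrS`, and the mask is assumed to be bottom-right aligned (a query at row r may attend to keys up to column r + seq_len_kv - seq_len_qo).

```cpp
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical stand-in for one entry of a coordinate tensor: the (row, col)
// position, within the workgroup tile, of one score owned by this thread.
struct Coord {
  int row;
  int col;
};

// Mask every owned score whose key column lies beyond the causal boundary.
// tile_row0 / tile_col0 are the global coordinates of the tile's top-left corner.
// Assumption: bottom-right-aligned causal mask.
inline void apply_causal_mask(std::vector<float>& scores,
                              const std::vector<Coord>& coords,
                              int tile_row0, int tile_col0,
                              int seq_len_qo, int seq_len_kv) {
  const int diag_shift = seq_len_kv - seq_len_qo;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    const int row = tile_row0 + coords[i].row;  // global query index
    const int col = tile_col0 + coords[i].col;  // global key index
    if (col > row + diag_shift) {
      scores[i] = -std::numeric_limits<float>::infinity();
    }
  }
}
```

Because the loop only consults the coordinates attached to each owned value, it makes no assumption about how MMA blocks are arranged inside the fragment, which is the point of the suggestion above.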
Thanks @petercad, let me give it a try.
Pull Request Overview
This PR adds support for causal masking in the flash attention implementation by introducing a new SubgroupLayoutQK template parameter and implementing the causal mask logic in the mainloop.
Key Changes:
- Added `SubgroupLayoutQK` template parameter to the collective mainloop and kernel interfaces
- Implemented causal masking logic that applies `-INFINITY` to attention scores beyond the causal boundary
- Updated the example runner to conditionally instantiate causal or non-causal configurations based on user options
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| applications/flash_attention_v2/collective/xe_fmha_fwd_mainloop.hpp | Implements causal mask logic and removes the static assertion that previously blocked causal mask usage |
| applications/flash_attention_v2/kernel/xe_fhma_fwd_kernel.hpp | Adds subgroup layout type alias and computes sequence coordinates for causal masking |
| examples/06_bmg_flash_attention/xe_fmha_fwd_runner.hpp | Adds SubgroupLayoutQK template parameter to mainloop type |
| examples/06_bmg_flash_attention/06_xe_fmha_fwd.cpp | Conditionally selects causal or non-causal kernel based on is_causal option |
| auto discard_seq_coord = s.seq_len_qo - offset; | ||
| auto full_tile_offset = s.seq_len_kv - offset; | ||
| int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ)); |
Copilot AI (Nov 5, 2025):
This calculation is overly complex and difficult to understand. Consider breaking it into intermediate variables with descriptive names to clarify the computation of tile offset within the subgroup layout.
| int seq_coord = cute::min(s.seq_len_qo, (blk_q * get<0>(TileShapeQK{}) + (sub_group_id / get<1>(shape(SubgroupLayoutQK{}))) * SGTileQ)); | |
| // Break down the seq_coord calculation for clarity | |
| int tile_shape_qk_0 = get<0>(TileShapeQK{}); | |
| int subgroup_layout_qk_1 = get<1>(shape(SubgroupLayoutQK{})); | |
| int blk_q_offset = blk_q * tile_shape_qk_0; | |
| int subgroup_tile_offset = (sub_group_id / subgroup_layout_qk_1) * SGTileQ; | |
| int raw_seq_coord = blk_q_offset + subgroup_tile_offset; | |
| int seq_coord = cute::min(s.seq_len_qo, raw_seq_coord); |
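As a rough check of the arithmetic in this breakdown, the same computation can be written as a standalone function and evaluated on concrete numbers. The function name and parameter names below are hypothetical. The division by `sg_cols` only yields the subgroup's row index if subgroup IDs are numbered with the second layout mode varying fastest, which is an assumption here.

```cpp
#include <algorithm>

// Hypothetical plain-C++ mirror of the suggested breakdown.
//   wg_tile_q : query rows per workgroup tile (get<0>(TileShapeQK{}) in the PR)
//   sg_cols   : subgroups along the second mode of the subgroup layout
//   sg_tile_q : query rows per subgroup tile (SGTileQ in the PR)
// Assumption: sub_group_id / sg_cols is the subgroup's row within the layout.
constexpr int first_query_row(int seq_len_qo, int blk_q, int wg_tile_q,
                              int sub_group_id, int sg_cols, int sg_tile_q) {
  const int blk_q_offset = blk_q * wg_tile_q;
  const int subgroup_tile_offset = (sub_group_id / sg_cols) * sg_tile_q;
  // Clamp so that tiles hanging past the end of Q start at seq_len_qo.
  return std::min(seq_len_qo, blk_q_offset + subgroup_tile_offset);
}
```

For example, with `blk_q = 2`, a 128-row workgroup tile, subgroup 5 in a layout with 2 subgroup columns, and 16 rows per subgroup, the result is 2 * 128 + (5 / 2) * 16 = 288, provided `seq_len_qo` is at least that large.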
| int col_idx = base_col + n * get<1>(MmaAtomShapeQK()); | ||
| CUTLASS_PRAGMA_UNROLL | ||
| for (int m = 0; m < shape<0>(tSrS.shape()); ++m) { | ||
| int row_idx = seq_coord + m; |
Copilot AI (Nov 5, 2025):
The causal mask condition logic lacks explanation. Add a comment describing what this inequality represents in terms of the attention matrix and why these specific offsets are used.
| int row_idx = seq_coord + m; | |
| int row_idx = seq_coord + m; | |
| // Causal mask: For each (row_idx, col_idx) in the attention matrix, | |
| // set positions where the query (row) would attend to future keys (col) to -INFINITY. | |
| // The offsets (full_tile_offset, discard_seq_coord) adjust for tiling and indexing, | |
| // ensuring that only positions where col_idx > row_idx (i.e., future positions) | |
| // are masked out, preserving causality in the attention mechanism. |
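To make the inequality concrete, here is a small standalone sketch that evaluates it over a whole attention matrix. It assumes `offset` is `min(seq_len_qo, seq_len_kv)`, which the diff excerpt does not show; `is_masked` and `render_mask` are hypothetical helper names. Whatever value `offset` takes, it cancels algebraically, so the condition is equivalent to `col_idx > row_idx + (seq_len_kv - seq_len_qo)`, i.e. a causal diagonal aligned to the bottom-right corner.

```cpp
#include <algorithm>
#include <string>

// Mirrors the condition in the diff.
// Assumption: offset = min(seq_len_qo, seq_len_kv).
constexpr bool is_masked(int row_idx, int col_idx,
                         int seq_len_qo, int seq_len_kv) {
  const int offset = std::min(seq_len_qo, seq_len_kv);
  const int discard_seq_coord = seq_len_qo - offset;  // leading Q rows with no visible keys
  const int full_tile_offset = seq_len_kv - offset;   // leading K columns visible to every row
  return col_idx - full_tile_offset > row_idx - discard_seq_coord;
}

// Renders the mask row by row: '.' = score kept, 'x' = score set to -INFINITY.
inline std::string render_mask(int seq_len_qo, int seq_len_kv) {
  std::string out;
  for (int row = 0; row < seq_len_qo; ++row) {
    for (int col = 0; col < seq_len_kv; ++col) {
      out += is_masked(row, col, seq_len_qo, seq_len_kv) ? 'x' : '.';
    }
    out += '\n';
  }
  return out;
}
```

With `seq_len_qo = 2` and `seq_len_kv = 4`, only the last key of the first query row is masked. With `seq_len_qo = 4` and `seq_len_kv = 2`, the first two query rows are masked entirely, which appears to be what `discard_seq_coord` accounts for.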
No description provided.