[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+ #2693
Open
KshitijLakhani wants to merge 15 commits into NVIDIA:main from
Conversation
Force-pushed from 674394b to 998b3b8
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- for more information, see https://pre-commit.ci
- …pe instead of TH1 for sm120 Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Force-pushed from dc282ea to b2f5864
- for more information, see https://pre-commit.ci
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- …incorrect max logit calculation (includes padded tokens in max calculation) Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- …pa arbitrary kernel call Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- …clude a check for sm120 Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- for more information, see https://pre-commit.ci
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
- Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Greptile Summary

This PR enables SM120 (Blackwell) support for fused attention with THD (ragged sequence) layouts when cuDNN ≥ 9.18.1, by adapting the graph-building logic for a cuDNN stride-validation difference on that architecture. Key changes:
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend - utils.py] --> B{SM120 + cuDNN < 9.18.1?}
B -- Yes --> C[Disable FusedAttention for THD]
B -- No --> D{SM120 + t3hd/th3d layout?}
D -- Yes --> C
D -- No --> E[FusedAttention enabled for SM120 + THD]
E --> F[fused_attn_fwd - Python]
F --> G{T3HD or TH3D layout?}
G -- Yes --> H[NVTE_ERROR - assert in nvte_fused_attn_fwd]
G -- No --> I[fused_attn_arbitrary_seqlen_fwd_impl]
I --> J{SM120 and ragged Q/KV?}
J -- Yes --> K[Keep b=batch, s_q=max_seqlen_q\nBHSD-like layout - passes cuDNN stride check]
J -- No --> L[b=max_b, s_q=max_t_q\nPacked layout - quantization bucket]
K --> M{use_ragged_stats?\nis_ragged_q AND cudnn>=9.6 AND sm<120}
L --> M
M -- False - SM120 --> N[Stats: BHS1 stride\nbatch x heads x max_seqlen_q x 1]
M -- True - non-SM120 --> O[Stats: TH1 stride\nnum_tokens x heads x 1\nwith ragged offset]
N --> P[return_max_logit path in Python\nmax_tensor.ndim==4: mask padded positions\namax over dims 0,2,3 to get shape h]
O --> Q[return_max_logit path in Python\nmax_tensor.ndim==3: amax over dims 0,2 to get shape h]
R[context_parallel.py\nAttnFuncWithCPAndKVP2P] --> S{SM120?}
S -- Yes --> T[softmax_lse_in_packed_format=False\nUse BHS1 format]
S -- No --> U[softmax_lse_in_packed_format = cudnn>=9.6\nUse TH1 packed format]
```
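The backend-selection branch at the top of the flowchart can be sketched as a small predicate. This is a hypothetical illustration only, assuming a tuple-encoded cuDNN version; the function and argument names are not from the actual Transformer Engine code:

```python
# Hypothetical sketch of the SM120/cuDNN gating shown in the flowchart.
# All names here are illustrative, not Transformer Engine APIs.
def fused_attn_supported(sm_arch, cudnn_version, qkv_layout):
    """Return whether FusedAttention can handle this config.

    sm_arch: compute capability as an int (e.g. 120 for SM120)
    cudnn_version: (major, minor, patch) tuple
    qkv_layout: layout string, e.g. "thd_thd_thd", "t3hd", "sbh3d"
    """
    layout = qkv_layout.lower()
    is_thd = "thd" in layout or layout in ("t3hd", "th3d")
    if not is_thd:
        return True  # non-ragged layouts are unaffected by this gate
    if sm_arch == 120:
        # SM120 needs cuDNN >= 9.18.1 for THD fused attention ...
        if cudnn_version < (9, 18, 1):
            return False
        # ... and the interleaved T3HD/TH3D variants stay disabled
        if layout in ("t3hd", "th3d"):
            return False
    return True
```

For example, `thd_thd_thd` on SM120 passes with cuDNN 9.18.1 but not 9.17, while `t3hd` on SM120 is rejected regardless of version.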
Last reviewed commit: bcfef90
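One of the commits above fixes a max-logit calculation that wrongly included padded tokens. The masking idea in the BHS1-stats path can be sketched as follows; NumPy is used for illustration and all names and shapes are assumptions, not the actual code:

```python
# Hypothetical sketch: with BHS1-shaped stats [b, h, s, 1], positions past
# each sequence's valid length must be masked out before the max reduction,
# or padded entries leak into the per-head maximum.
import numpy as np

def max_logit_per_head(stats, seqlens):
    """stats: [b, h, s, 1] max-logit stats; seqlens: [b] valid lengths.
    Returns the per-head max of shape [h], ignoring padded positions."""
    b, h, s, _ = stats.shape
    pos = np.arange(s).reshape(1, 1, s, 1)        # position index along s
    pad = pos >= seqlens.reshape(b, 1, 1, 1)      # True at padded slots
    masked = np.where(pad, -np.inf, stats)        # exclude padding from max
    return masked.max(axis=(0, 2, 3))             # reduce b, s, trailing dim
```

The reduction over dims 0, 2, 3 mirrors the ndim==4 branch described in the flowchart; the TH1 packed path would instead reduce over dims 0 and 2 of a 3-D tensor.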
Comment on lines +639 to +641:

```cpp
NVTE_ERROR(
    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
    "Use thd_thd_thd or other THD layouts instead.");
```
Missing period in forward error message
The forward error message is missing a period that is present in the corresponding backward error message (line 748). Minor inconsistency but worth fixing for uniformity.
Suggested change:

```diff
 NVTE_ERROR(
-    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
+    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
     "Use thd_thd_thd or other THD layouts instead.");
```
Collaborator (Author)

/te-ci L0 L1
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Collaborator (Author)

/te-ci L0 L1
Description
Enable sm120 support for THD layouts in fused attention for cuDNN 9.18.1+
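As a side note, a cuDNN version gate like the one in this PR typically compares against the integer encoding used by `cudnnGetVersion()` in cuDNN 9.x (major*10000 + minor*100 + patch, so 9.18.1 encodes as 91801). A hypothetical helper, not the actual Transformer Engine check:

```python
# Hypothetical sketch of a minimum-version check against the encoded
# cuDNN 9.x version integer (major*10000 + minor*100 + patch).
def cudnn_at_least(version_int, major, minor, patch):
    """True if the encoded cuDNN version meets the given minimum."""
    return version_int >= major * 10000 + minor * 100 + patch
```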
Type of change
Changes

- `get_attention_backends()`: if T3HD or TH3D shapes are used, as cuDNN does not support them. Also, assert in common for the same before calling f16 arbitrary-seqlen fwd/bwd
- `get_attention_backends()` (until fully supported)

Test results:
Ran PyT attention tests on sm120 and no failures:
Checklist: