Enable packed sequences and ring attention for CUDNN Flash Attention #2600
Conversation
```yaml
generate_padding_batch_eval: False
# Maximum number of segments that can be packed into a single sequence.
# This needs to be passed to TransformerEngine's DotProductAttention layer for packing.
max_segments_per_seq: 32
```
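For context, a minimal sketch of how this option would be threaded through to TransformerEngine on the GPU path. The exact kwargs on `DotProductAttention` (including `qkv_layout` and `max_segments_per_seq`) reflect this PR's description but should be verified against your TE version:

```python
# Sketch only: wiring max_segments_per_seq into TE's attention layer for
# packed (THD) inputs. Kwarg names are assumptions based on the PR summary.
from transformer_engine.jax.flax import DotProductAttention

def make_packed_attention(config, head_dim, num_query_heads, num_kv_heads):
  return DotProductAttention(
      head_dim=head_dim,
      num_attention_heads=num_query_heads,
      num_gqa_groups=num_kv_heads,
      attn_mask_type="padding_causal",
      qkv_layout="THD_THD_THD",  # packed layout instead of the default BSHD
      max_segments_per_seq=config.max_segments_per_seq,  # from the config above
  )
```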
**@gobbleturk:** Can you add a ValueError in pyconfig so that it errors out if this is set on TPU? This is only supported on GPU.
**Author:** I can make it a warning for TPUs, but it does not get used in the TPU path, so I'm not sure erroring out would make sense.
**Author:** Added a warning @gobbleturk, please check.
**@gobbleturk:** I realize it doesn't get used in the TPU path, but I'm worried some of our TPU users might read the variable name `max_segments_per_seq` and try it out on TPU. They will not be happy to find that it silently does nothing; I would prefer an error.
**Author:** Hi @gobbleturk, I removed the warning and squashed everything into a single commit. Could you please check?
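A minimal sketch of the requested validation, assuming MaxText-style raw config keys. The key names and the placement in pyconfig are illustrative, not the merged code:

```python
# Hypothetical pyconfig check: max_segments_per_seq only affects the GPU
# (TransformerEngine) path, so fail fast rather than silently ignore it on TPU.
def validate_gpu_only_packing(raw_keys: dict) -> None:
  # Treating max_segments_per_seq > 1 as "set" is an assumption here.
  if raw_keys.get("hardware") == "tpu" and raw_keys.get("max_segments_per_seq", 1) > 1:
    raise ValueError(
        "max_segments_per_seq is only supported with cuDNN Flash Attention on GPU; "
        "remove it from the config (or set it to 1) when running on TPU."
    )
```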
Added THD packed format support and configurable context parallel strategies (all_gather/ring) for TransformerEngine's DotProductAttention. Included comprehensive GPU tests for packed attention (sm90+) and ring attention modes.
- Added max_segments_per_seq config for packed sequence control
- Supported context_parallel_strategy selection (all_gather vs ring; see the sketch after this list)
- Handled the dummy attention mask separately for packed sequences
- Added config validations and softened the synthetic data + packing check
- Added GPU tests for both packed and ring attention scenarios

Note: Context parallelism with packing is temporarily disabled pending full support.
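A minimal sketch of how the strategy string could map onto TE's context-parallel modes. The `CPStrategy` enum and its member names are assumptions inferred from the all_gather/ring options above, so verify them against your TransformerEngine version:

```python
# Sketch only: map the context_parallel_strategy config value to a TE strategy.
from transformer_engine.jax.attention import CPStrategy  # assumed import path

_CP_STRATEGIES = {
    "all_gather": CPStrategy.ALL_GATHER,  # gather full KV on every rank
    "ring": CPStrategy.RING,              # ring attention: pass KV chunks peer to peer
}

def cp_strategy_from_config(name: str) -> CPStrategy:
  if name not in _CP_STRATEGIES:
    raise ValueError(
        f"Unknown context_parallel_strategy: {name!r}; expected one of {sorted(_CP_STRATEGIES)}"
    )
  return _CP_STRATEGIES[name]
```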
Force-pushed 73af807 to 2454858.
Merged commit 9d0c860 into AI-Hypercomputer:main.
Description
Added THD packed format support and configurable context parallel strategies (all_gather/ring) for TransformerEngine's DotProductAttention. Included comprehensive GPU tests for packed attention (sm90+) and ring attention modes.
Note: Context parallelism with packing is temporarily disabled pending full support.
Tests
Appropriate tests for packed sequences with fused attention and for ring attention have been added.
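As a rough illustration of how such tests can gate on GPU capability (the helper below is illustrative, not the PR's actual test code):

```python
# Sketch: skip packed-attention tests unless running on an sm90+ GPU.
import jax
import pytest

def _has_sm90() -> bool:
  dev = jax.devices()[0]
  if dev.platform != "gpu":
    return False
  # On CUDA devices, compute_capability is a string such as "9.0".
  return float(getattr(dev, "compute_capability", "0")) >= 9.0

@pytest.mark.skipif(not _has_sm90(), reason="packed cuDNN attention requires sm90+")
def test_packed_thd_attention_matches_padded_bshd():
  # Compare packed (THD) fused attention against a padded (BSHD) reference.
  ...
```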