
Conversation

@kocchop (Collaborator) commented Nov 5, 2025

Description

Added THD packed format support and configurable context parallel strategies (all_gather/ring) for TransformerEngine's DotProductAttention. Included comprehensive GPU tests for packed attention (sm90+) and ring attention modes.

  • Added max_segments_per_seq config for packed-sequence control
  • Added support for context_parallel_strategy selection (all_gather vs. ring)
  • Added GPU tests for both packed and ring attention scenarios

Note: Context parallelism combined with packing is temporarily disabled until it is fully supported.
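For readers unfamiliar with the two context_parallel_strategy options, the single-device sketch below (plain numpy, not the PR's implementation and not a TransformerEngine API; all names are illustrative) contrasts their numerical core: an "all_gather"-style pass that attends over the full K/V at once versus a "ring"-style pass that visits K/V one chunk at a time and merges partial results with an online softmax.

# Illustrative only: contrasts the two context-parallel strategies numerically.
# "all_gather" materializes the full K/V before attending; "ring"-style attention
# consumes K/V chunk by chunk and merges partial results with an online softmax.
# Plain single-device numpy; no TransformerEngine APIs are used here.
import numpy as np

def full_attention(q, k, v):
    # Reference: softmax(q k^T) v over the whole sequence ("all_gather"-like).
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    return p @ v

def chunked_attention(q, k, v, num_chunks):
    # "Ring"-like: visit K/V chunks one at a time, keeping a running max (m),
    # softmax normalizer (l), and weighted-value accumulator (acc).
    m = np.full(q.shape[0], -np.inf)
    l = np.zeros(q.shape[0])
    acc = np.zeros((q.shape[0], v.shape[1]))
    for k_c, v_c in zip(np.array_split(k, num_chunks), np.array_split(v, num_chunks)):
        s = q @ k_c.T
        m_new = np.maximum(m, s.max(axis=-1))
        corr = np.exp(m - m_new)          # rescale previous partial sums
        p = np.exp(s - m_new[:, None])
        l = l * corr + p.sum(axis=-1)
        acc = acc * corr[:, None] + p @ v_c
        m = m_new
    return acc / l[:, None]

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
assert np.allclose(full_attention(q, k, v), chunked_attention(q, k, v, num_chunks=4))

Broadly, the all_gather strategy holds the full K/V per device in exchange for simpler communication, while the ring strategy streams K/V chunks between devices to bound per-step memory; the sketch only checks that the two accumulation orders agree numerically.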

Tests

Tests for packed sequences with fused attention and for ring attention have been added.

Checklist

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

generate_padding_batch_eval: False
# Maximum number of segments that can be packed into a single sequence
# This needs to be passed to TransformerEngine's DotProductAttention layer for packing
max_segments_per_seq: 32
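To make the role of max_segments_per_seq concrete, here is a small illustrative sketch (plain numpy, not MaxText or TransformerEngine code; all names are assumptions) that packs a few variable-length segments into one fixed-length sequence and builds the per-token segment IDs and positions that packed (THD) attention implementations typically need to keep segments from attending across boundaries.

# Illustrative sketch only: pack variable-length segments into one sequence and
# build segment IDs / positions. The 32-segment cap mirrors the config above.
import numpy as np

MAX_SEGMENTS_PER_SEQ = 32   # mirrors max_segments_per_seq; hypothetical constant name
PACKED_SEQ_LEN = 16

segment_lengths = [5, 3, 6]                 # example variable-length segments
assert len(segment_lengths) <= MAX_SEGMENTS_PER_SEQ

segment_ids = np.zeros(PACKED_SEQ_LEN, dtype=np.int32)   # 0 marks padding
positions = np.zeros(PACKED_SEQ_LEN, dtype=np.int32)
offset = 0
for seg_id, length in enumerate(segment_lengths, start=1):
    segment_ids[offset:offset + length] = seg_id
    positions[offset:offset + length] = np.arange(length)
    offset += length

print(segment_ids)   # [1 1 1 1 1 2 2 2 3 3 3 3 3 3 0 0]
print(positions)     # [0 1 2 3 4 0 1 2 0 1 2 3 4 5 0 0]
# Tokens may only attend to tokens that share the same non-zero segment ID:
attend_mask = (segment_ids[:, None] == segment_ids[None, :]) & (segment_ids[:, None] > 0)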
Collaborator:
Can you add a ValueError in pyconfig so that setting this on TPU errors out? This is only supported on GPU.

Collaborator (Author):
I can make it a warning for TPUs, but it does not get used in the TPU path, so I'm not sure erroring out would make sense.

Collaborator (Author):
Added a warning. @gobbleturk, please check.

Collaborator:
I realize it doesn't get used in the TPU path, but I'm worried some of our TPU users might just read the variable name "max_segments_per_seq" and try it out on TPU. They will not be happy to see that it silently does nothing; I would prefer an error.

Collaborator (Author):
Hi @gobbleturk, I removed the warning and squashed everything into a single commit. Could you please check?
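For context, a minimal sketch of the kind of pyconfig-style validation discussed in this thread might look like the following. The function name, config access, and hardware string are hypothetical and are not the actual MaxText pyconfig code.

# Hypothetical sketch only; not the actual pyconfig implementation.
def validate_packing_config(config: dict, hardware: str) -> None:
    """Error out if the GPU-only packing option is set on non-GPU hardware."""
    if hardware != "gpu" and config.get("max_segments_per_seq", 0) > 0:
        raise ValueError(
            "max_segments_per_seq is only used by the GPU (TransformerEngine) "
            f"attention path and has no effect on hardware='{hardware}'; please unset it."
        )

# Example: a TPU run that sets max_segments_per_seq would fail fast.
# validate_packing_config({"max_segments_per_seq": 32}, hardware="tpu")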

@kocchop kocchop requested a review from gobbleturk November 6, 2025 02:04
    Added THD packed format support and configurable context parallel strategies
    (all_gather/ring) for TransformerEngine's DotProductAttention. Included
    comprehensive GPU tests for packed attention (sm90+) and ring attention modes.

    - Added max_segments_per_seq config for packed sequence control
    - Supported context_parallel_strategy selection (all_gather vs ring)
    - Handled dummy attn mask separately for packed sequences
    - Added config validations and softened the synth data + packing check
    - Added GPU tests for both packed and ring attention scenarios

    Note: Context parallelism combined with packing is temporarily disabled until it is fully supported.
@copybara-service copybara-service bot merged commit 9d0c860 into AI-Hypercomputer:main Nov 7, 2025
35 of 36 checks passed