Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions megatron/core/extensions/transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2575,6 +2575,7 @@ def fused_apply_rotary_pos_emb_thd(
freqs: torch.Tensor,
cp_size: int = 1,
cp_rank: int = 0,
interleaved: bool = False,
) -> torch.Tensor:
"""
Apply rotary positional embedding to input tensor T in `thd` format with CP support.
Expand All @@ -2588,6 +2589,7 @@ def fused_apply_rotary_pos_emb_thd(
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
interleaved=interleaved,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: The interleaved parameter is correctly passed in this if branch (TE >= 1.12.0), but the else branch at line 2596-2598 does not pass interleaved=interleaved to apply_rotary_pos_emb. This means on older TE versions, the interleaved setting will be silently ignored. Since this PR is specifically adding interleaved support, the fallback path should be fixed too:

# line 2596-2598, else branch:
return apply_rotary_pos_emb(
    t, freqs, tensor_format="thd", fused=True, cu_seqlens=cu_seqlens,
    interleaved=interleaved,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If older TE versions don't have the interleaved argument we should assert False in the else branch.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jaredcasper I looked into the TE code and found out that interleaved only supports for TE>2.3.0.
Hence, I added a condition check for min_TE when using interleaved=True, if it's <2.3.0 will assert error. (https://github.com/huvunvidia/Megatron-LM/blob/ce883edd14063fa0f298b8b8bbaaea5f0ba893c9/megatron/core/extensions/transformer_engine.py#L2583)
With that we don't need to modify subsequent code.

)
else:
assert cp_size == 1, "Only TE >= 1.12 supports RoPE fusion for THD format with CP."
Expand Down
7 changes: 6 additions & 1 deletion megatron/core/models/common/embeddings/rope_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,12 @@ def apply_rotary_pos_emb(
else:
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb_thd(
t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
t,
cu_seqlens,
freqs,
cp_size=cp_group.size(),
cp_rank=cp_group.rank(),
interleaved=config.rotary_interleaved,
)
# use unfused implementation
if cu_seqlens is None:
Expand Down
Loading