diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py
index e901e40597a..78454a6fbf4 100644
--- a/megatron/core/extensions/transformer_engine.py
+++ b/megatron/core/extensions/transformer_engine.py
@@ -2575,11 +2575,26 @@ def fused_apply_rotary_pos_emb_thd(
     freqs: torch.Tensor,
     cp_size: int = 1,
     cp_rank: int = 0,
+    interleaved: bool = False,
 ) -> torch.Tensor:
     """
     Apply rotary positional embedding to input tensor T in `thd` format with CP support.
     """
-    if is_te_min_version("1.12.0", check_equality=True):
+    if interleaved:
+        assert is_te_min_version("2.3.0"), "Only TE >= 2.3.0 supports interleaved fused RoPE."
+
+    if is_te_min_version("2.3.0", check_equality=True):
+        return apply_rotary_pos_emb(
+            t,
+            freqs,
+            tensor_format="thd",
+            fused=True,
+            cu_seqlens=cu_seqlens,
+            cp_size=cp_size,
+            cp_rank=cp_rank,
+            interleaved=interleaved,
+        )
+    elif is_te_min_version("1.12.0", check_equality=True):
         return apply_rotary_pos_emb(
             t,
             freqs,
diff --git a/megatron/core/models/common/embeddings/rope_utils.py b/megatron/core/models/common/embeddings/rope_utils.py
index e39540eb1d1..2fd19194813 100644
--- a/megatron/core/models/common/embeddings/rope_utils.py
+++ b/megatron/core/models/common/embeddings/rope_utils.py
@@ -288,7 +288,12 @@ def apply_rotary_pos_emb(
         else:
            assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
            return fused_apply_rotary_pos_emb_thd(
-                t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
+                t,
+                cu_seqlens,
+                freqs,
+                cp_size=cp_group.size(),
+                cp_rank=cp_group.rank(),
+                interleaved=config.rotary_interleaved,
             )
     # use unfused implementation
     if cu_seqlens is None: