Skip to content

Commit a64819a

Browse files
committed
Merge remote-tracking branch 'upstream/main' into chunked_prefill_fix
2 parents 8b774ed + 87eb3c2 commit a64819a

File tree

2 files changed: +22 −2 lines changed

megatron/core/extensions/transformer_engine.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2575,11 +2575,26 @@ def fused_apply_rotary_pos_emb_thd(
25752575
freqs: torch.Tensor,
25762576
cp_size: int = 1,
25772577
cp_rank: int = 0,
2578+
interleaved: bool = False,
25782579
) -> torch.Tensor:
25792580
"""
25802581
Apply rotary positional embedding to input tensor T in `thd` format with CP support.
25812582
"""
2582-
if is_te_min_version("1.12.0", check_equality=True):
2583+
if interleaved:
2584+
assert is_te_min_version("2.3.0"), "Only TE >= 2.3.0 supports interleaved fused RoPE."
2585+
2586+
if is_te_min_version("2.3.0", check_equality=True):
2587+
return apply_rotary_pos_emb(
2588+
t,
2589+
freqs,
2590+
tensor_format="thd",
2591+
fused=True,
2592+
cu_seqlens=cu_seqlens,
2593+
cp_size=cp_size,
2594+
cp_rank=cp_rank,
2595+
interleaved=interleaved,
2596+
)
2597+
elif is_te_min_version("1.12.0", check_equality=True):
25832598
return apply_rotary_pos_emb(
25842599
t,
25852600
freqs,

megatron/core/models/common/embeddings/rope_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,12 @@ def apply_rotary_pos_emb(
288288
else:
289289
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
290290
return fused_apply_rotary_pos_emb_thd(
291-
t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
291+
t,
292+
cu_seqlens,
293+
freqs,
294+
cp_size=cp_group.size(),
295+
cp_rank=cp_group.rank(),
296+
interleaved=config.rotary_interleaved,
292297
)
293298
# use unfused implementation
294299
if cu_seqlens is None:

Comments (0) — no commit comments