Skip to content

Commit b3a025a

Browse files
committed
Add xLLM partial RoPE layout support
1 parent fa5f026 commit b3a025a

3 files changed

Lines changed: 58 additions & 4 deletions

File tree

megatron/core/models/common/embeddings/rope_utils.py

Lines changed: 48 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,42 @@ def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
8989
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
9090

9191

92+
def _xllm_layout_to_hf(t: Tensor) -> Tensor:
93+
return t.reshape(*t.shape[:-1], -1, 2).transpose(-1, -2).reshape_as(t)
94+
95+
96+
def _hf_layout_to_xllm(t: Tensor) -> Tensor:
97+
return t.reshape(*t.shape[:-1], 2, -1).transpose(-1, -2).reshape_as(t)
98+
99+
100+
def _apply_xllm_partial_rotary_pos_emb_bshd(
101+
t: Tensor, freqs: Tensor, rotary_interleaved: bool = False, mscale: float = 1.0
102+
) -> Tensor:
103+
rot_dim = freqs.shape[-1]
104+
if rot_dim * 2 != t.shape[-1]:
105+
raise ValueError(
106+
"xLLM partial RoPE layout currently expects rope_head_dim * 2 == head_dim, "
107+
f"got rope_head_dim={rot_dim}, head_dim={t.shape[-1]}"
108+
)
109+
110+
x = _hf_layout_to_xllm(t)
111+
x_rope, x_pass = x[..., :rot_dim], x[..., rot_dim:]
112+
x_rope_hf = _xllm_layout_to_hf(x_rope)
113+
114+
cos_ = (torch.cos(freqs) * mscale).to(x_rope_hf.dtype)
115+
sin_ = (torch.sin(freqs) * mscale).to(x_rope_hf.dtype)
116+
y_rope_hf = (x_rope_hf * cos_) + (_rotate_half(x_rope_hf, rotary_interleaved) * sin_)
117+
y = torch.cat((_hf_layout_to_xllm(y_rope_hf), x_pass), dim=-1)
118+
return _xllm_layout_to_hf(y)
119+
120+
92121
def _apply_rotary_pos_emb_bshd(
93122
t: Tensor,
94123
freqs: Tensor,
95124
rotary_interleaved: bool = False,
96125
multi_latent_attention: bool = False,
97126
mscale: float = 1.0,
127+
xllm_partial_rope_layout: bool = False,
98128
) -> Tensor:
99129
"""Apply rotary positional embedding to input tensor T.
100130
@@ -108,6 +138,12 @@ def _apply_rotary_pos_emb_bshd(
108138
Tensor: The input tensor after applying RoPE
109139
"""
110140
rot_dim = freqs.shape[-1]
141+
if xllm_partial_rope_layout and rot_dim < t.shape[-1]:
142+
if multi_latent_attention:
143+
raise ValueError("xLLM partial RoPE layout is not compatible with MLA tensors")
144+
return _apply_xllm_partial_rotary_pos_emb_bshd(
145+
t, freqs, rotary_interleaved=rotary_interleaved, mscale=mscale
146+
)
111147

112148
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
113149
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
@@ -183,6 +219,7 @@ def _apply_rotary_pos_emb_thd(
183219
multi_latent_attention: bool = False,
184220
mscale: float = 1.0,
185221
cp_group: torch.distributed.ProcessGroup = None,
222+
xllm_partial_rope_layout: bool = False,
186223
) -> Tensor:
187224
"""A baseline implementation of applying RoPE for `thd` format.
188225
@@ -228,6 +265,7 @@ def _apply_rotary_pos_emb_thd(
228265
rotary_interleaved=rotary_interleaved,
229266
multi_latent_attention=multi_latent_attention,
230267
mscale=mscale,
268+
xllm_partial_rope_layout=xllm_partial_rope_layout,
231269
).squeeze(1)
232270
else:
233271
# CASE 2: Traditional mapping without offsets
@@ -244,6 +282,7 @@ def _apply_rotary_pos_emb_thd(
244282
rotary_interleaved=rotary_interleaved,
245283
multi_latent_attention=multi_latent_attention,
246284
mscale=mscale,
285+
xllm_partial_rope_layout=xllm_partial_rope_layout,
247286
).squeeze(1)
248287

249288

@@ -276,6 +315,8 @@ def apply_rotary_pos_emb(
276315
"Please set apply_rope_fusion to false. This will become an error in v0.16."
277316
)
278317
use_unfused = True
318+
if getattr(config, "xllm_partial_rope_layout", False):
319+
use_unfused = True
279320
if mscale != 1.0:
280321
warnings.warn(
281322
f"mscale={mscale} is not supported by TE's fused RoPE. "
@@ -286,10 +327,11 @@ def apply_rotary_pos_emb(
286327
assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
287328
return fused_apply_rotary_pos_emb(t, freqs, interleaved=config.rotary_interleaved)
288329
else:
289-
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
290-
return fused_apply_rotary_pos_emb_thd(
291-
t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
292-
)
330+
if not getattr(config, "xllm_partial_rope_layout", False):
331+
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
332+
return fused_apply_rotary_pos_emb_thd(
333+
t, cu_seqlens, freqs, cp_size=cp_group.size(), cp_rank=cp_group.rank()
334+
)
293335
# use unfused implementation
294336
if cu_seqlens is None:
295337
return _apply_rotary_pos_emb_bshd(
@@ -298,6 +340,7 @@ def apply_rotary_pos_emb(
298340
rotary_interleaved=config.rotary_interleaved,
299341
multi_latent_attention=config.multi_latent_attention,
300342
mscale=mscale,
343+
xllm_partial_rope_layout=getattr(config, "xllm_partial_rope_layout", False),
301344
)
302345
else:
303346
return _apply_rotary_pos_emb_thd(
@@ -308,6 +351,7 @@ def apply_rotary_pos_emb(
308351
multi_latent_attention=config.multi_latent_attention,
309352
mscale=mscale,
310353
cp_group=cp_group,
354+
xllm_partial_rope_layout=getattr(config, "xllm_partial_rope_layout", False),
311355
)
312356

313357

megatron/core/transformer/transformer_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,14 @@ class TransformerConfig(ModelParallelConfig):
198198
"""True is rotate pairs of even and odd dimensions (RoFormer style), False is rotate pairs of
199199
first half and second half (LLaMa style). Default to False."""
200200

201+
xllm_partial_rope_layout: bool = False
202+
"""Apply partial RoPE using xLLM's HF/SGLang head-dimension layout.
203+
204+
This is only intended for xLLM checkpoints where rotary_percent < 1.0;
205+
standard Megatron partial RoPE rotates the first contiguous rotary slice,
206+
while xLLM rotates the slice after converting to the xLLM head layout.
207+
"""
208+
201209
window_size: Optional[Tuple[int, int]] = None
202210
"""If not None, then will use sliding window attention. The size of the window is specified by
203211
the numbers inside the tuple; -1 is special value meaning "infinite window size"."""

megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,8 @@ def _add_network_size_args(parser):
16821682
help='Base to use for rotary positional embeddings, default 10000')
16831683
group.add_argument('--rotary-percent', type=float, default=1.0,
16841684
help='Percent of rotary dimension to use, default 100%%')
1685+
group.add_argument('--xllm-partial-rope-layout', action='store_true',
1686+
help='Use xLLM HF/SGLang head-dimension layout for partial RoPE.')
16851687
group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
16861688
help='Sequence length interpolation factor for rotary embeddings.')
16871689
group.add_argument('--use-rope-scaling', action='store_true',

0 commit comments

Comments
 (0)