forked from NVIDIA/Megatron-LM
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrope_utils.py
More file actions
351 lines (297 loc) · 12.4 KB
/
rope_utils.py
File metadata and controls
351 lines (297 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from megatron.core.transformer.transformer_config import TransformerConfig
import logging
import torch
from torch import Tensor
from megatron.core import parallel_state
logger = logging.getLogger(__name__)
try:
from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb
except ImportError:
fused_apply_rotary_pos_emb = None
try:
from megatron.core.extensions.transformer_engine import fused_apply_rotary_pos_emb_thd
except ImportError:
fused_apply_rotary_pos_emb_thd = None
try:
from flash_attn.layers.rotary import apply_rotary_emb as apply_rotary_emb_flash
except ImportError:
apply_rotary_emb_flash = None
__all__ = [
'apply_rotary_pos_emb',
'apply_rotary_emb_flash',
'apply_rotary_pos_emb_with_cos_sin',
'fused_apply_rotary_pos_emb',
'fused_apply_rotary_pos_emb_thd',
'get_pos_emb_on_this_cp_rank',
]
def get_pos_emb_on_this_cp_rank(
pos_emb: Tensor, seq_dim: int, cp_group: torch.distributed.ProcessGroup
) -> Tensor:
"""Get the position embedding on the current context parallel rank.
Args:
pos_emb (Tensor): Positional embedding tensor
seq_dim (int): Sequence dimension
cp_group (torch.distributed.ProcessGroup): The context parallel group
"""
if cp_group is None:
raise ValueError("cp_group must be provided to get positional embedding per CP rank")
cp_size = cp_group.size()
cp_rank = cp_group.rank()
cp_idx = torch.tensor(
[cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
).cuda(non_blocking=True)
pos_emb = pos_emb.view(
*pos_emb.shape[:seq_dim], 2 * cp_size, -1, *pos_emb.shape[(seq_dim + 1) :]
)
pos_emb = pos_emb.index_select(seq_dim, cp_idx)
pos_emb = pos_emb.view(*pos_emb.shape[:seq_dim], -1, *pos_emb.shape[(seq_dim + 2) :])
return pos_emb
def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
"""Change sign so the last dimension becomes [-odd, +even]
Args:
x (Tensor): Input tensor
Returns:
Tensor: Tensor rotated half
"""
if not rotary_interleaved:
x1, x2 = torch.chunk(x, 2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
else:
x1 = x[:, :, :, ::2]
x2 = x[:, :, :, 1::2]
x_new = torch.stack((-x2, x1), dim=-1)
return x_new.view(x_new.shape[0], x_new.shape[1], x_new.shape[2], -1)
def _apply_rotary_pos_emb_bshd(
t: Tensor,
freqs: Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
) -> Tensor:
"""Apply rotary positional embedding to input tensor T.
check https://kexue.fm/archives/8265 for detailed formulas
Args:
t (Tensor): Input tensor T is of shape [seq_length, ... , dim]
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim]
Returns:
Tensor: The input tensor after applying RoPE
"""
rot_dim = freqs.shape[-1]
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
if multi_latent_attention:
x1 = t[..., 0::2]
x2 = t[..., 1::2]
t = torch.cat((x1, x2), dim=-1)
# first part is cosine component
# second part is sine component, need to change signs with _rotate_half method
cos_ = (torch.cos(freqs) * mscale).to(t.dtype)
sin_ = (torch.sin(freqs) * mscale).to(t.dtype)
t = (t * cos_) + (_rotate_half(t, rotary_interleaved) * sin_)
return torch.cat((t, t_pass), dim=-1)
def _get_thd_freqs_on_this_cp_rank(
cp_rank: int, cp_size: int, x: Tensor, freqs: Tensor, offset: int = 0
) -> Tensor:
"""Get the correct frequency slice for this context parallel rank with optional sequence offset.
Args:
cp_rank: Current context parallel rank
cp_size: Total context parallel size
x: Input tensor for current sequence
freqs: Frequency tensor - either full batch positions or max sequence length
offset: Starting position offset for this sequence in the original batch (default: 0)
Returns:
Tensor: Frequency slice corresponding to this CP rank's portion of the sequence
Note:
This function supports two modes based on the offset parameter:
1. offset > 0: Exact mapping mode - freqs contains all positions across all sequences.
The offset ensures each sequence gets frequencies from its actual position within
the overall batch. Critical for non-1D RoPE in VLMs where spatial positions matter.
2. offset = 0: Traditional mode - freqs contains only max sequence length positions.
All sequences use frequencies starting from position 0, preserving backward
compatibility.
"""
if cp_size > 1:
cp_seg = x.size(0) // 2
full_seqlen = cp_size * x.size(0)
# Apply offset to both forward and backward segments for context parallelism
# offset=0: traditional behavior, freqs[0:cp_seg] and freqs[...]
# offset>0: exact mapping, freqs[offset+0:offset+cp_seg] and freqs[offset+...]
return torch.cat(
[
freqs[offset + cp_rank * cp_seg : offset + (cp_rank + 1) * cp_seg],
freqs[
offset
+ full_seqlen
- (cp_rank + 1) * cp_seg : offset
+ full_seqlen
- cp_rank * cp_seg
],
]
)
else:
# For single context parallel rank:
# offset=0: use freqs[0:x.size(0)] (traditional)
# offset>0: use freqs[offset:offset+x.size(0)] (exact mapping)
return freqs[offset : offset + x.size(0)]
def _apply_rotary_pos_emb_thd(
t: Tensor,
cu_seqlens: Tensor,
freqs: Tensor,
rotary_interleaved: bool = False,
multi_latent_attention: bool = False,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
) -> Tensor:
"""A baseline implementation of applying RoPE for `thd` format.
Args:
t (Tensor): Input tensor T is of shape [t, h, d]
cu_seqlens(Tensor): Cumulative sum of sequence lengths in a batch for `t`,
with shape [b + 1] and dtype torch.int32.
freqs (Tensor): Rotary Positional embedding tensor freq is of shape [max_s, 1, 1, d]
cp_group (torch.distributed.ProcessGroup): The context parallel group
Returns:
Tensor: Shape [t, h, d]. The input tensor after applying RoPE.
"""
if cp_group is None:
raise ValueError("cp_group must be provided for THD format RoPE")
cp_size = cp_group.size()
cp_rank = cp_group.rank()
seqlens = ((cu_seqlens[1:] - cu_seqlens[:-1]) // cp_size).tolist()
# Handle two different frequency tensor formats:
# 1. If freqs.size(0) == cu_seqlens[-1]: freqs contains all positions across all sequences
# -> Use offset-based mapping for exact positional correspondence
# 2. Otherwise: freqs contains only max sequence length positions
# -> Use traditional mapping without offsets (map first :seqlen part)
if freqs.dim() >= 1 and freqs.size(0) == cu_seqlens[-1]:
# CASE 1: Exact mapping with offsets
# Build packed freqs in one pass, then apply once to the whole packed tensor
sequence_splits = torch.split(t, seqlens)
freq_slices = []
for i, x in enumerate(sequence_splits):
# cu_seqlens[i] is the starting offset of this sequence in the original batch
seq_start_offset = cu_seqlens[i].item()
freq_slices.append(
_get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs, seq_start_offset)
)
freqs_packed = torch.cat(freq_slices, dim=0)
return _apply_rotary_pos_emb_bshd(
t.unsqueeze(1),
freqs_packed,
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
).squeeze(1)
else:
# CASE 2: Traditional mapping without offsets
# Build packed freqs for all sequences using the standard mapping, then apply once
sequence_splits = torch.split(t, seqlens)
freqs_packed = torch.cat(
[_get_thd_freqs_on_this_cp_rank(cp_rank, cp_size, x, freqs) for x in sequence_splits],
dim=0,
)
return _apply_rotary_pos_emb_bshd(
t.unsqueeze(1),
freqs_packed,
rotary_interleaved=rotary_interleaved,
multi_latent_attention=multi_latent_attention,
mscale=mscale,
).squeeze(1)
def apply_rotary_pos_emb(
t: Tensor,
freqs: Tensor,
config: TransformerConfig,
cu_seqlens: Optional[Tensor] = None,
mscale: float = 1.0,
cp_group: torch.distributed.ProcessGroup = None,
):
"""
Reroute to the appropriate apply_rotary_pos_emb function depending on
fused/unfused kernels, or bshd (conventional) / thd (packed seq) format
"""
global fused_apply_rotary_pos_emb, fused_apply_rotary_pos_emb_thd
# Keep for backward compatibility. Will deprecate in the future.
if cp_group is None:
cp_group = parallel_state.get_context_parallel_group()
if config.apply_rope_fusion:
if cu_seqlens is None:
# NOTE: TE backends do not support mRoPE in bshd format when bs > 1.
use_unfused = False
if config.mrope_section is not None and freqs.shape[1] > 1:
# TODO: Add a check in TransformerConfig and remove this unfused implementation.
warnings.warn(
"apply_rope_fusion does not support mRoPE in bshd format when bs > 1. "
"Please set apply_rope_fusion to false. This will become an error in v0.16."
)
use_unfused = True
if mscale != 1.0:
warnings.warn(
f"mscale={mscale} is not supported by TE's fused RoPE. "
"Using unfused implementation."
)
use_unfused = True
if not use_unfused:
assert fused_apply_rotary_pos_emb is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb(t, freqs, interleaved=config.rotary_interleaved)
else:
assert fused_apply_rotary_pos_emb_thd is not None, "apply_rope_fusion is not available."
return fused_apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
cp_size=cp_group.size(),
cp_rank=cp_group.rank(),
interleaved=config.rotary_interleaved,
)
# use unfused implementation
if cu_seqlens is None:
return _apply_rotary_pos_emb_bshd(
t,
freqs,
rotary_interleaved=config.rotary_interleaved,
multi_latent_attention=config.multi_latent_attention,
mscale=mscale,
)
else:
return _apply_rotary_pos_emb_thd(
t,
cu_seqlens,
freqs,
rotary_interleaved=config.rotary_interleaved,
multi_latent_attention=config.multi_latent_attention,
mscale=mscale,
cp_group=cp_group,
)
def apply_rotary_pos_emb_with_cos_sin(
t: Tensor, cos: Tensor, sin: Tensor, rotary_interleaved: bool = False
) -> Tensor:
"""
This function applies rotary positional embedding to the target tensor t
using precomputed cos and sin of size (seq_len, d_rot / 2)
"""
cos = cos.to(t.dtype)
sin = sin.to(t.dtype)
if apply_rotary_emb_flash is None:
# Combine cos and sin into freqs
freqs = torch.stack([cos, sin], dim=-1).flatten(start_dim=-2)
# Expand freqs to match t's shape
while freqs.dim() < t.dim():
freqs = freqs.unsqueeze(1)
freqs = freqs.expand(t.shape[:-1] + (-1,))
y = _apply_rotary_pos_emb_bshd(
t,
freqs,
rotary_interleaved=rotary_interleaved,
multi_latent_attention=False,
mscale=1.0,
)
else:
# Use Flash Attention's optimized kernel for rotary embedding
t = t.permute(1, 0, 2, 3)
y = apply_rotary_emb_flash(t, cos, sin, rotary_interleaved)
y = y.permute(1, 0, 2, 3)
return y