
Commit 496e4f5

Implement XPos (Sun et al.)
1 parent c2407de commit 496e4f5

3 files changed: 47 additions & 16 deletions


flash_attn/layers/rotary.py

Lines changed: 43 additions & 14 deletions
@@ -78,10 +78,11 @@ def backward(ctx, do):
 class ApplyRotaryEmbQKV_(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
         """
             qkv: (batch_size, seqlen, 3, nheads, headdim)
             cos, sin: (seqlen, rotary_dim / 2)
+            cos_k, sin_k: (seqlen, rotary_dim / 2), optional
         rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
@@ -91,29 +92,31 @@ def forward(ctx, qkv, cos, sin):
         rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        cos_k = cos if cos_k is None else cos_k
+        sin_k = sin if sin_k is None else sin_k
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
         q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
         k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
-        ctx.save_for_backward(cos, sin)
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        ctx.save_for_backward(cos, sin, cos_k, sin_k)
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
-        cos, sin = ctx.saved_tensors
+        cos, sin, cos_k, sin_k = ctx.saved_tensors
         _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
         dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
-        return dqkv, None, None
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None
 
 
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
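
For readers following the diff: the fused rotary_emb.apply_rotary(x1, x2, cos, sin, out1, out2, conj) kernel is not part of this commit. As a hedged reference (an assumption about the kernel's semantics, not code from the repo), it can be read as the usual half-split rotary rotation, with conj reversing the rotation for the backward pass; the only change above is that the key halves are now rotated with cos_k/sin_k instead of cos/sin. A minimal pure-PyTorch sketch, using a hypothetical helper name:

    # Assumed reference semantics for the fused kernel (illustrative only).
    import torch

    def apply_rotary_ref(x1, x2, cos, sin, conj=False):
        # x1, x2: (batch, seqlen, nheads, rotary_dim // 2) -- the two halves of the rotary dims
        # cos, sin: (seqlen, 1, rotary_dim // 2), broadcast over batch and heads
        if conj:
            sin = -sin  # reverse the rotation, as the backward pass needs
        out1 = x1 * cos - x2 * sin
        out2 = x1 * sin + x2 * cos
        return out1, out2
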
@@ -134,15 +137,24 @@ class RotaryEmbedding(torch.nn.Module):
 
     """
 
-    def __init__(self, dim: int, base=10000, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+        """
+        If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+        Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+        """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
+        self.scale_base = scale_base
+        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
+        self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
 
     def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
@@ -157,14 +169,31 @@ def _update_cos_sin_cache(self, x, seqlen_offset=0):
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = torch.cos(freqs).to(x.dtype)
-            self._sin_cached = torch.sin(freqs).to(x.dtype)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(x.dtype)
+                self._sin_cached = torch.sin(freqs).to(x.dtype)
+            else:
+                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                          - seqlen // 2) / self.scale_base)
+                scale = self.scale ** rearrange(power, 's -> s 1')
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
-                                     self._sin_cached[seqlen_offset:])
+        if self.scale is None:
+            return apply_rotary_emb_qkv_(
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
+            )
+        else:
+            return apply_rotary_emb_qkv_(
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
+            )
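
A short sketch (not part of the commit) of why the reciprocal scaling of the cached cos/sin above implements XPos: queries are rotated with cos * scale and sin * scale while keys use cos / scale and sin / scale, so the q.k dot product picks up a factor scale[m] / scale[n] that depends only on the relative offset m - n. The dimensions and scale base below are illustrative, not repo defaults.

    import torch

    dim, seqlen, scale_base = 64, 8, 512
    zeta = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)   # per-dimension base scale, in (0, 1)
    power = (torch.arange(seqlen) - seqlen // 2) / scale_base    # per-position exponent
    scale = zeta ** power[:, None]                               # (seqlen, dim // 2); multiplies q, divides k

    # In the attention score q_m . k_n the two factors combine into
    # scale[m] / scale[n] = zeta ** ((m - n) / scale_base), a decay in relative distance only.
    m, n = 5, 2
    assert torch.allclose(scale[m] / scale[n], zeta ** ((m - n) / scale_base))
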

flash_attn/models/gpt.py

Lines changed: 2 additions & 1 deletion
@@ -36,11 +36,12 @@ def create_mixer_cls(config, layer_idx=None):
         softmax_scale /= float(layer_idx + 1)
     dwconv = getattr(config, 'attn_dwconv', False)
     rotary_emb_dim = int(getattr(config, 'rotary_emb_fraction', 0.0) * head_dim)
+    rotary_emb_scale_base = getattr(config, 'rotary_emb_scale_base', 0)
     use_flash_attn = getattr(config, 'use_flash_attn', False)
     fused_bias_fc = getattr(config, 'fused_bias_fc', False)
     mixer_cls = partial(MHA, num_heads=config.num_attention_heads, dropout=config.attn_pdrop,
                         softmax_scale=softmax_scale, causal=True, dwconv=dwconv,
-                        rotary_emb_dim=rotary_emb_dim,
+                        rotary_emb_dim=rotary_emb_dim, rotary_emb_scale_base=rotary_emb_scale_base,
                         fused_bias_fc=fused_bias_fc, use_flash_attn=use_flash_attn)
     return mixer_cls
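
Because the new option is read with getattr, it can be enabled by setting an extra attribute on the config object. A hedged usage sketch, assuming the transformers GPT2Config that this model file is written against; the values are illustrative and the default of 0 keeps plain rotary embeddings:

    from transformers import GPT2Config

    config = GPT2Config(n_embd=1024, n_head=16)   # head_dim = 64
    config.use_flash_attn = True
    config.rotary_emb_fraction = 0.5              # rotary_emb_dim = 0.5 * 64 = 32
    config.rotary_emb_scale_base = 512            # > 0 turns the rotary embedding into XPos
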

flash_attn/modules/mha.py

Lines changed: 2 additions & 1 deletion
@@ -283,6 +283,7 @@ class MHA(nn.Module):
 
     def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.0,
                  softmax_scale=None, causal=False, dwconv=False, rotary_emb_dim=0,
+                 rotary_emb_scale_base=0,
                  fused_bias_fc=False, use_flash_attn=False, return_residual=False,
                  checkpointing=False, device=None, dtype=None) -> None:
         """
@@ -308,7 +309,7 @@ def __init__(self, embed_dim, num_heads, cross_attn=False, bias=True, dropout=0.
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim)
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
 
         if fused_bias_fc and FusedDenseTD is None:
             raise ImportError('fused_dense is not installed')
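
For completeness, a hedged sketch of constructing the MHA module directly with the new argument; the dimensions and scale base are illustrative, and only keyword names visible in the diff above are used:

    import torch
    from flash_attn.modules.mha import MHA

    mha = MHA(embed_dim=1024, num_heads=16, causal=True,
              rotary_emb_dim=32,            # apply rotary embedding to the first 32 dims of each head
              rotary_emb_scale_base=512,    # > 0 enables the XPos scaling shown in rotary.py
              use_flash_attn=True,
              device='cuda', dtype=torch.float16)
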
