@@ -78,10 +78,11 @@ def backward(ctx, do):
 class ApplyRotaryEmbQKV_(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, qkv, cos, sin):
+    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None):
         """
         qkv: (batch_size, seqlen, 3, nheads, headdim)
         cos, sin: (seqlen, rotary_dim / 2)
+        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
         rotary_dim must be <= headdim
         Apply rotary embedding *inplace* to the first rotary_dim of q and k.
         """
@@ -91,29 +92,31 @@ def forward(ctx, qkv, cos, sin):
         rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
+        cos_k = cos if cos_k is None else cos_k
+        sin_k = sin if sin_k is None else sin_k
+        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
         q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
         k1, k2 = qkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(k1, k2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), k1, k2, False)
-        ctx.save_for_backward(cos, sin)
+        rotary_emb.apply_rotary(k1, k2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), k1, k2, False)
+        ctx.save_for_backward(cos, sin, cos_k, sin_k)
         return qkv
 
     @staticmethod
     def backward(ctx, dqkv):
-        cos, sin = ctx.saved_tensors
+        cos, sin, cos_k, sin_k = ctx.saved_tensors
         _, seqlen, _, _, headdim = dqkv.shape
         rotary_dim = cos.shape[-1]
         rotary_dim *= 2
         dq1, dq2 = dqkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(dq1, dq2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), dq1, dq2, True)
         dk1, dk2 = dqkv[:, :, 1, :, :rotary_dim].chunk(2, dim=-1)
-        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos[:seqlen], 's d -> s 1 d'),
-                                rearrange(sin[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
-        return dqkv, None, None
+        rotary_emb.apply_rotary(dk1, dk2, rearrange(cos_k[:seqlen], 's d -> s 1 d'),
+                                rearrange(sin_k[:seqlen], 's d -> s 1 d'), dk1, dk2, True)
+        return dqkv, None, None, None, None
 
 
 apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
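
The separate cos_k/sin_k tables let the key rotation use different (reciprocal) scaling than the query rotation, which is what the XPos change further down relies on. For readers without the CUDA extension: rotary_emb.apply_rotary is a fused kernel, but its math is a plain 2D rotation of the two halves of the rotary dimensions, with conjugate=True negating sin for the backward pass. A minimal PyTorch sketch of that rotation (apply_rotary_ref is an illustrative name, not the kernel's API; the real kernel writes into out1/out2 in place):

import torch

def apply_rotary_ref(x1, x2, cos, sin, conjugate=False):
    # x1, x2: (batch, seqlen, nheads, rotary_dim // 2)
    # cos, sin: broadcastable, e.g. (seqlen, 1, rotary_dim // 2)
    if conjugate:  # backward pass rotates by the opposite angle
        sin = -sin
    return x1 * cos - x2 * sin, x1 * sin + x2 * cos
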
@@ -134,15 +137,24 @@ class RotaryEmbedding(torch.nn.Module):
 
     """
 
-    def __init__(self, dim: int, base=10000, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+        """
+        If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
+        Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
+        """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
+        self.scale_base = scale_base
+        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
+        self.register_buffer("scale", scale)
 
         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None
+        self._cos_k_cached = None
+        self._sin_k_cached = None
 
     def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
@@ -157,14 +169,31 @@ def _update_cos_sin_cache(self, x, seqlen_offset=0):
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
             freqs = torch.outer(t, self.inv_freq)
-            self._cos_cached = torch.cos(freqs).to(x.dtype)
-            self._sin_cached = torch.sin(freqs).to(x.dtype)
+            if self.scale is None:
+                self._cos_cached = torch.cos(freqs).to(x.dtype)
+                self._sin_cached = torch.sin(freqs).to(x.dtype)
+            else:
+                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
+                          - seqlen // 2) / self.scale_base)
+                scale = self.scale ** rearrange(power, 's -> s 1')
+                # We want the multiplication by scale to happen in fp32
+                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
+                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
+                self._cos_k_cached = (torch.cos(freqs) / scale).to(x.dtype)
+                self._sin_k_cached = (torch.sin(freqs) / scale).to(x.dtype)
 
     def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         seqlen_offset: can be used in generation where the qkv being passed in is only the last
         token in the batch.
         """
         self._update_cos_sin_cache(qkv, seqlen_offset)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
-                                     self._sin_cached[seqlen_offset:])
+        if self.scale is None:
+            return apply_rotary_emb_qkv_(
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:]
+            )
+        else:
+            return apply_rotary_emb_qkv_(
+                qkv, self._cos_cached[seqlen_offset:], self._sin_cached[seqlen_offset:],
+                self._cos_k_cached[seqlen_offset:], self._sin_k_cached[seqlen_offset:]
+            )
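
A quick usage sketch of the new path, under two assumptions: the class lives at flash_attn.layers.rotary in this repo's layout, and the fused kernel needs CUDA tensors:

import torch
from flash_attn.layers.rotary import RotaryEmbedding  # module path assumed

batch, seqlen, nheads, headdim = 2, 128, 12, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, device='cuda', dtype=torch.float16)
rotary = RotaryEmbedding(headdim, scale_base=512).cuda()  # scale_base > 0 switches on XPos
qkv = rotary(qkv)  # rotates q and k in place; k uses the reciprocal-scale tables

With the default scale_base=0, scale is None and forward falls back to the original two-argument call, so existing call sites are unaffected.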