@@ -89,12 +89,42 @@ def _rotate_half(x: Tensor, rotary_interleaved: bool) -> Tensor:
8989 return x_new .view (x_new .shape [0 ], x_new .shape [1 ], x_new .shape [2 ], - 1 )
9090
9191
92+ def _xllm_layout_to_hf (t : Tensor ) -> Tensor :
93+ return t .reshape (* t .shape [:- 1 ], - 1 , 2 ).transpose (- 1 , - 2 ).reshape_as (t )
94+
95+
96+ def _hf_layout_to_xllm (t : Tensor ) -> Tensor :
97+ return t .reshape (* t .shape [:- 1 ], 2 , - 1 ).transpose (- 1 , - 2 ).reshape_as (t )
98+
99+
100+ def _apply_xllm_partial_rotary_pos_emb_bshd (
101+ t : Tensor , freqs : Tensor , rotary_interleaved : bool = False , mscale : float = 1.0
102+ ) -> Tensor :
103+ rot_dim = freqs .shape [- 1 ]
104+ if rot_dim * 2 != t .shape [- 1 ]:
105+ raise ValueError (
106+ "xLLM partial RoPE layout currently expects rope_head_dim * 2 == head_dim, "
107+ f"got rope_head_dim={ rot_dim } , head_dim={ t .shape [- 1 ]} "
108+ )
109+
110+ x = _hf_layout_to_xllm (t )
111+ x_rope , x_pass = x [..., :rot_dim ], x [..., rot_dim :]
112+ x_rope_hf = _xllm_layout_to_hf (x_rope )
113+
114+ cos_ = (torch .cos (freqs ) * mscale ).to (x_rope_hf .dtype )
115+ sin_ = (torch .sin (freqs ) * mscale ).to (x_rope_hf .dtype )
116+ y_rope_hf = (x_rope_hf * cos_ ) + (_rotate_half (x_rope_hf , rotary_interleaved ) * sin_ )
117+ y = torch .cat ((_hf_layout_to_xllm (y_rope_hf ), x_pass ), dim = - 1 )
118+ return _xllm_layout_to_hf (y )
119+
120+
92121def _apply_rotary_pos_emb_bshd (
93122 t : Tensor ,
94123 freqs : Tensor ,
95124 rotary_interleaved : bool = False ,
96125 multi_latent_attention : bool = False ,
97126 mscale : float = 1.0 ,
127+ xllm_partial_rope_layout : bool = False ,
98128) -> Tensor :
99129 """Apply rotary positional embedding to input tensor T.
100130
@@ -108,6 +138,12 @@ def _apply_rotary_pos_emb_bshd(
108138 Tensor: The input tensor after applying RoPE
109139 """
110140 rot_dim = freqs .shape [- 1 ]
141+ if xllm_partial_rope_layout and rot_dim < t .shape [- 1 ]:
142+ if multi_latent_attention :
143+ raise ValueError ("xLLM partial RoPE layout is not compatible with MLA tensors" )
144+ return _apply_xllm_partial_rotary_pos_emb_bshd (
145+ t , freqs , rotary_interleaved = rotary_interleaved , mscale = mscale
146+ )
111147
112148 # ideally t_pass is empty so rotary pos embedding is applied to all tensor t
113149 t , t_pass = t [..., :rot_dim ], t [..., rot_dim :]
@@ -183,6 +219,7 @@ def _apply_rotary_pos_emb_thd(
183219 multi_latent_attention : bool = False ,
184220 mscale : float = 1.0 ,
185221 cp_group : torch .distributed .ProcessGroup = None ,
222+ xllm_partial_rope_layout : bool = False ,
186223) -> Tensor :
187224 """A baseline implementation of applying RoPE for `thd` format.
188225
@@ -228,6 +265,7 @@ def _apply_rotary_pos_emb_thd(
228265 rotary_interleaved = rotary_interleaved ,
229266 multi_latent_attention = multi_latent_attention ,
230267 mscale = mscale ,
268+ xllm_partial_rope_layout = xllm_partial_rope_layout ,
231269 ).squeeze (1 )
232270 else :
233271 # CASE 2: Traditional mapping without offsets
@@ -244,6 +282,7 @@ def _apply_rotary_pos_emb_thd(
244282 rotary_interleaved = rotary_interleaved ,
245283 multi_latent_attention = multi_latent_attention ,
246284 mscale = mscale ,
285+ xllm_partial_rope_layout = xllm_partial_rope_layout ,
247286 ).squeeze (1 )
248287
249288
@@ -276,6 +315,8 @@ def apply_rotary_pos_emb(
276315 "Please set apply_rope_fusion to false. This will become an error in v0.16."
277316 )
278317 use_unfused = True
318+ if getattr (config , "xllm_partial_rope_layout" , False ):
319+ use_unfused = True
279320 if mscale != 1.0 :
280321 warnings .warn (
281322 f"mscale={ mscale } is not supported by TE's fused RoPE. "
@@ -286,10 +327,11 @@ def apply_rotary_pos_emb(
286327 assert fused_apply_rotary_pos_emb is not None , "apply_rope_fusion is not available."
287328 return fused_apply_rotary_pos_emb (t , freqs , interleaved = config .rotary_interleaved )
288329 else :
289- assert fused_apply_rotary_pos_emb_thd is not None , "apply_rope_fusion is not available."
290- return fused_apply_rotary_pos_emb_thd (
291- t , cu_seqlens , freqs , cp_size = cp_group .size (), cp_rank = cp_group .rank ()
292- )
330+ if not getattr (config , "xllm_partial_rope_layout" , False ):
331+ assert fused_apply_rotary_pos_emb_thd is not None , "apply_rope_fusion is not available."
332+ return fused_apply_rotary_pos_emb_thd (
333+ t , cu_seqlens , freqs , cp_size = cp_group .size (), cp_rank = cp_group .rank ()
334+ )
293335 # use unfused implementation
294336 if cu_seqlens is None :
295337 return _apply_rotary_pos_emb_bshd (
@@ -298,6 +340,7 @@ def apply_rotary_pos_emb(
298340 rotary_interleaved = config .rotary_interleaved ,
299341 multi_latent_attention = config .multi_latent_attention ,
300342 mscale = mscale ,
343+ xllm_partial_rope_layout = getattr (config , "xllm_partial_rope_layout" , False ),
301344 )
302345 else :
303346 return _apply_rotary_pos_emb_thd (
@@ -308,6 +351,7 @@ def apply_rotary_pos_emb(
308351 multi_latent_attention = config .multi_latent_attention ,
309352 mscale = mscale ,
310353 cp_group = cp_group ,
354+ xllm_partial_rope_layout = getattr (config , "xllm_partial_rope_layout" , False ),
311355 )
312356
313357
0 commit comments