Skip to content

Commit b0ddb33

Browse files
committed
[bugfix] fix mla rope (#8462)
1 parent 3401560 commit b0ddb33

3 files changed

Lines changed: 7 additions & 13 deletions

File tree

swift/megatron/init.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -876,9 +876,9 @@ def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor):
876876
# x_pe [seqlen, batch, *, qk_pos_emb_head_dim]
877877
x_pe, x_nope = torch.split(
878878
x, [self.index_head_dim - self.qk_pos_emb_head_dim, self.qk_pos_emb_head_dim], dim=-1)
879-
origin_rotary_interleaved = self.config.rotary_interleaved
879+
origin_multi_latent_attention = self.config.multi_latent_attention
880880
try:
881-
self.config.rotary_interleaved = self.config.dsa_indexer_rotary_interleaved
881+
self.config.multi_latent_attention = self.config.dsa_indexer_rotary_interleaved
882882
x_pe = apply_rotary_pos_emb(
883883
x_pe,
884884
rotary_pos_emb,
@@ -887,7 +887,7 @@ def _apply_rope(self, x: torch.Tensor, rotary_pos_emb: torch.Tensor):
887887
cp_group=self.pg_collection.cp,
888888
)
889889
finally:
890-
self.config.rotary_interleaved = origin_rotary_interleaved
890+
self.config.multi_latent_attention = origin_multi_latent_attention
891891
# [seqlen, batch, *, index_head_dim]
892892
x = torch.cat([x_pe, x_nope], dim=-1)
893893
return x

swift/megatron/model/gpt_model.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def _apply_rotary_pos_emb_bshd(
176176

177177
# ideally t_pass is empty so rotary pos embedding is applied to all tensor t
178178
t, t_pass = t[..., :rot_dim], t[..., rot_dim:]
179+
if multi_latent_attention:
180+
x1 = t[..., 0::2]
181+
x2 = t[..., 1::2]
182+
t = torch.cat((x1, x2), dim=-1)
179183

180184
# first part is cosine component
181185
# second part is sine component, need to change signs with _rotate_half method

swift/megatron/model/model_config.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -278,11 +278,7 @@ def __post_init__(self):
278278
assert not self.swiglu
279279
self.gated_linear_unit = True
280280
self.activation_func = quick_gelu
281-
_origin_rotary_interleaved = self.rotary_interleaved
282-
if self.multi_latent_attention and self.rotary_interleaved:
283-
self.rotary_interleaved = False
284281
super().__post_init__()
285-
self.rotary_interleaved = _origin_rotary_interleaved
286282
self._check_npu()
287283
self.variable_seq_lengths = True
288284

@@ -481,8 +477,6 @@ def convert_hf_config(config) -> Dict[str, Any]:
481477
res.pop('num_query_groups', None)
482478
if llm_model_type == 'glm_moe_dsa':
483479
res['experimental_attention_variant'] = 'dsa'
484-
# https://github.com/modelscope/ms-swift/pull/8085
485-
# res['rotary_interleaved'] = False
486480
elif llm_model_type == 'qwen3_next' or hf_model_type in {'qwen3_5', 'qwen3_5_moe'}:
487481
use_mcore_gdn = get_env_args('SWIFT_USE_MCORE_GDN', bool, False)
488482
if use_mcore_gdn and llm_model_type == 'qwen3_next':
@@ -525,10 +519,6 @@ def convert_hf_config(config) -> Dict[str, Any]:
525519
mrope_interleaved = rope_scaling.get('mrope_interleaved', False) or rope_scaling.get('interleaved', False)
526520
res['mrope_interleaved'] = mrope_interleaved
527521

528-
if res.get('multi_latent_attention') and res.get('position_embedding_type') in {
529-
'rope', None
530-
} and 'rotary_interleaved' not in res:
531-
res['rotary_interleaved'] = True
532522
if first_k_dense_replace is not None:
533523
res['moe_layer_freq'] = f'[0]*{first_k_dense_replace}+[1]*{res["num_layers"] - first_k_dense_replace}'
534524
if res.get('moe_router_score_function', 'softmax') == 'sigmoid' and 'moe_router_enable_expert_bias' not in res:

0 commit comments

Comments
 (0)