diff --git a/docker/patch/latest/megatron.patch b/docker/patch/latest/megatron.patch
index ac18b9408..7483233aa 100644
--- a/docker/patch/latest/megatron.patch
+++ b/docker/patch/latest/megatron.patch
@@ -93,19 +93,6 @@ index 860ee64a9..80944b702 100755
          sharded_state_dict_keys_map={
              "mlp.0.weight": "mlp.linear_fc1.layer_norm_weight",
              "mlp.0.bias": "mlp.linear_fc1.layer_norm_bias",
-diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py
-index 6aec66e6d..3ac631935 100644
---- a/megatron/core/models/gpt/gpt_model.py
-+++ b/megatron/core/models/gpt/gpt_model.py
-@@ -446,7 +446,7 @@ class GPTModel(LanguageModule):
-         if self.share_embeddings_and_output_weights:
-             output_weight = self.shared_embedding_or_output_weight()
- 
--        if mtp_in_postprocess:
-+        if mtp_in_postprocess and labels is not None:
-             hidden_states = self.mtp(
-                 input_ids=input_ids,
-                 position_ids=position_ids,
 diff --git a/megatron/core/parallel_state.py b/megatron/core/parallel_state.py
 index a40c85a88..86688c331 100644
 --- a/megatron/core/parallel_state.py
@@ -149,6 +136,148 @@ index 63ee9d1f5..b90b744c1 100644
             )
         ops.append(recv_next_op)
     if len(ops) > 0:
+diff --git a/megatron/core/transformer/attention.py b/megatron/core/transformer/attention.py
+index c749bac43..dde8d50e7 100644
+--- a/megatron/core/transformer/attention.py
++++ b/megatron/core/transformer/attention.py
+@@ -670,7 +670,10 @@ class Attention(MegatronModule, ABC):
+         # Get the query, key and value tensors based on the type of attention -
+         # self or cross attn.
+         nvtx_range_push(suffix="qkv")
+-        query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
++        if self.config.use_gated_attention:
++            query, gate, key, value = self.get_query_gate_key_value_tensors(hidden_states, key_value_states)
++        else:
++            query, key, value = self.get_query_key_value_tensors(hidden_states, key_value_states)
+         nvtx_range_pop(suffix="qkv")
+ 
+         # ===================================================
+@@ -842,6 +845,11 @@ class Attention(MegatronModule, ABC):
+         # Output. [sq, b, h]
+         # =================
+ 
++        if self.config.use_gated_attention:
++            nvtx_range_push(suffix="sigmoid_gate")
++            core_attn_out = core_attn_out * torch.sigmoid(gate)
++            nvtx_range_pop(suffix="sigmoid_gate")
++
+         nvtx_range_push(suffix="linear_proj")
+         output, bias = self.linear_proj(core_attn_out)
+         nvtx_range_pop(suffix="linear_proj")
+@@ -879,19 +887,34 @@ class SelfAttention(Attention):
+             model_comm_pgs=model_comm_pgs,
+         )
+ 
+-        self.linear_qkv = build_module(
+-            submodules.linear_qkv,
+-            self.config.hidden_size,
+-            self.query_projection_size + 2 * self.kv_projection_size,
+-            config=self.config,
+-            init_method=self.config.init_method,
+-            gather_output=False,
+-            bias=self.config.add_bias_linear or self.config.add_qkv_bias,
+-            skip_bias_add=False,
+-            is_expert=False,
+-            tp_comm_buffer_name='qkv',
+-            tp_group=self.model_comm_pgs.tp,
+-        )
++        if self.config.use_gated_attention:
++            self.linear_qgkv = build_module(
++                submodules.linear_qkv,
++                self.config.hidden_size,
++                2 * (self.query_projection_size + self.kv_projection_size),
++                config=self.config,
++                init_method=self.config.init_method,
++                gather_output=False,
++                bias=self.config.add_bias_linear or self.config.add_qkv_bias,
++                skip_bias_add=False,
++                is_expert=False,
++                tp_comm_buffer_name='qkv',
++                tp_group=self.model_comm_pgs.tp,
++            )
++        else:
++            self.linear_qkv = build_module(
++                submodules.linear_qkv,
++                self.config.hidden_size,
++                self.query_projection_size + 2 * self.kv_projection_size,
++                config=self.config,
++                init_method=self.config.init_method,
++                gather_output=False,
++                bias=self.config.add_bias_linear or self.config.add_qkv_bias,
++                skip_bias_add=False,
++                is_expert=False,
++                tp_comm_buffer_name='qkv',
++                tp_group=self.model_comm_pgs.tp,
++            )
+ 
+         if submodules.q_layernorm is not None:
+             self.q_layernorm = build_module(
+@@ -1036,6 +1059,65 @@ class SelfAttention(Attention):
+ 
+         return query, key, value
+ 
++    # adapted from https://github.com/alibaba/Pai-Megatron-Patch/blob/8e6cbb0556ba09933ab4a4edb23c0af1d19d9960/megatron_patch/model/qwen3_next/gated_attention.py#L192
++    def get_query_gate_key_value_tensors(self, hidden_states, key_value_states=None):
++        """
++        Derives `query`, `gate`, `key` and `value` tensors from `hidden_states`.
++ """ ++ # Attention heads [sq, b, h] --> [sq, b, ng * 2 * (np/ng + 1) * hn)] ++ mixed_qgkv, _ = self.linear_qgkv(hidden_states) ++ ++ # [sq, b, hp] --> [sq, b, ng, 2 * (np/ng + 1) * hn] ++ new_tensor_shape = mixed_qgkv.size()[:-1] + ( ++ self.num_query_groups_per_partition, ++ ( ++ 2 * (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 1) ++ * self.hidden_size_per_attention_head ++ ), ++ ) ++ mixed_qgkv = mixed_qgkv.view(*new_tensor_shape) ++ ++ split_arg_list = [ ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ ( ++ self.num_attention_heads_per_partition ++ // self.num_query_groups_per_partition ++ * self.hidden_size_per_attention_head ++ ), ++ self.hidden_size_per_attention_head, ++ self.hidden_size_per_attention_head, ++ ] ++ ++ if SplitAlongDim is not None: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = SplitAlongDim(mixed_qgkv, 3, split_arg_list) ++ else: ++ ++ # [sq, b, ng, (np/ng + 2) * hn] ++ # --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] ++ (query, gate, key, value) = torch.split(mixed_qgkv, split_arg_list, dim=3) ++ ++ # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] ++ query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head) ++ gate = gate.reshape(query.size(0), query.size(1), -1) ++ ++ if self.q_layernorm is not None: ++ query = self.q_layernorm(query) ++ ++ if self.k_layernorm is not None: ++ key = self.k_layernorm(key) ++ ++ if self.config.test_mode: ++ self.run_realtime_tests() ++ ++ return query, gate, key, value ++ + def backward_dw(self) -> NoReturn: + """Execute weight update operations""" + self._backward_qkv_proj() diff --git a/megatron/core/transformer/moe/moe_utils.py b/megatron/core/transformer/moe/moe_utils.py index 235b6f6af..fbcffe278 100644 --- a/megatron/core/transformer/moe/moe_utils.py @@ -177,16 +306,42 @@ index 6b20b8622..459e65921 100644 def _maintain_float32_expert_bias(self): """ Maintain the expert bias in float32. +diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py +index b7884e18e..7ea47da8a 100755 +--- a/megatron/core/transformer/multi_token_prediction.py ++++ b/megatron/core/transformer/multi_token_prediction.py +@@ -681,9 +681,6 @@ class MultiTokenPredictionLayer(MegatronModule): + [s, b, h], and optionally the updated context tensor if cross-attention is used. + """ + assert context is None, f"multi token prediction + cross attention is not yet supported." +- assert ( +- packed_seq_params is None +- ), f"multi token prediction + sequence packing is not yet supported." 
+ 
+         input_ids, position_ids, decoder_input, hidden_states = self._get_embeddings(
+             input_ids=input_ids,
+@@ -910,9 +907,7 @@ class MultiTokenPredictionBlock(MegatronModule):
+             # to the hidden_states_list
+             hidden_states_list.append(hidden_states)
+ 
+-        # concat the hidden states of all mtp layers
+-        hidden_states = torch.cat(hidden_states_list, dim=0)
+-        return hidden_states
++        return hidden_states_list
+ 
+     def sharded_state_dict(
+         self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
-index d55bebe7e..1e1d9c781 100644
+index d55bebe7e..1eecbbd38 100644
 --- a/megatron/core/transformer/transformer_config.py
 +++ b/megatron/core/transformer/transformer_config.py
-@@ -173,6 +173,9 @@ class TransformerConfig(ModelParallelConfig):
+@@ -173,6 +173,10 @@ class TransformerConfig(ModelParallelConfig):
      qk_layernorm: bool = False
      """Whether to apply `normalization` type of normalization to the query and key embeddings."""
 
 +    post_self_attn_layernorm: bool = False
 +    post_mlp_layernorm: bool = False
++    use_gated_attention: bool = False
 +
      test_mode: bool = False
      """Whether to run real-time tests."""
@@ -262,7 +417,7 @@ index 84f22bdea..f0f3f8e86 100644
 # discard the output of the pre-mlp layernorm and register the recompute
 # as a gradient hook of mlp_output_with_bias[0]
 diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
-index e3459c5ee..2a2fefac3 100644
+index e3459c5ee..7346bf35b 100644
 --- a/megatron/training/arguments.py
 +++ b/megatron/training/arguments.py
 @@ -937,8 +937,6 @@ def validate_args(args, defaults={}):
-
-
      if args.num_experts is not None and args.moe_ffn_hidden_size is None:
          args.moe_ffn_hidden_size = args.ffn_hidden_size
          print("Warning: moe_ffn_hidden_size is not set, using ffn_hidden_size for MoE instead.")
-@@ -1198,6 +1196,9 @@ def core_transformer_config_from_args(args, config_class=None):
+@@ -1198,6 +1196,10 @@ def core_transformer_config_from_args(args, config_class=None):
      if args.is_hybrid_model:
          kw_args['is_hybrid_model'] = args.is_hybrid_model
 
 +    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
 +    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
++    kw_args['use_gated_attention'] = args.use_gated_attention
 +
      # handle quantization config
      # NOTE: Kitchen arguments are only added to the namespace when
      # Kitchen library is available.
-@@ -1488,6 +1489,10 @@ def _add_network_size_args(parser):
+@@ -1488,6 +1489,12 @@ def _add_network_size_args(parser):
                         action='store_true',
                         help='If set, use original BERT residula connection '
                         'ordering.')
 +    group.add_argument('--post-self-attn-layernorm', action='store_true',
 +                       help='If set, use post self attention layernorm.')
 +    group.add_argument('--post-mlp-layernorm', action='store_true',
 +                       help='If set, use post MLP layernorm.')
++    group.add_argument('--use-gated-attention', action='store_true',
++                       help='If set, use gated attention as in Qwen3Next')
      group.add_argument('--openai-gelu', action='store_true',
                         help='Use OpenAIs GeLU implementation. This option'
                         'should not be used unless for backward compatibility'
diff --git a/megatron/training/tokenizer/tokenizer.py b/megatron/training/tokenizer/tokenizer.py
index 5cf222ccc..d1554ca4c 100644
--- a/megatron/training/tokenizer/tokenizer.py
+++ b/megatron/training/tokenizer/tokenizer.py
@@ -138,6 +138,8 @@ class _HuggingFaceTokenizer(MegatronTokenizer):
                 f"The transformers library must be installed to use huggingface_tokenizer_provider"
             )
 
+        if "trust_remote_code" not in kwargs:
+            kwargs["trust_remote_code"] = True
         # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(
             pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
diff --git a/scripts/models/qwen3-next-80B-A3B.sh b/scripts/models/qwen3-next-80B-A3B.sh
index 38bed96cb..21eda3023 100644
--- a/scripts/models/qwen3-next-80B-A3B.sh
+++ b/scripts/models/qwen3-next-80B-A3B.sh
@@ -25,6 +25,7 @@ MODEL_ARGS=(
     --num-layers 48
     --hidden-size 2048
     --ffn-hidden-size 5120
+    --use-gated-attention
     --normalization RMSNorm
     --apply-layernorm-1p
diff --git a/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py b/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py
index 05f8c60c1..6b6dfc327 100644
--- a/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py
+++ b/slime/backends/megatron_utils/megatron_to_hf/qwen3_next.py
@@ -75,11 +75,17 @@ def convert_qwen3_next_to_hf(args, name, param):
 
     if rest == "self_attention.linear_proj.weight":
         return [(f"model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
-    elif rest == "self_attention.linear_qkv.weight":
+    elif rest == "self_attention.linear_qgkv.weight":
         param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
-        q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
-        q_param = q_param.reshape(-1, args.hidden_size)
+        q_param, k_param, v_param = torch.split(
+            param, split_size_or_sections=[2 * value_num_per_group, 1, 1], dim=1
+        )
+        q_param = (
+            q_param.reshape(args.num_query_groups, 2, value_num_per_group, head_dim, args.hidden_size)
+            .transpose(1, 2)
+            .reshape(-1, args.hidden_size)
+        )
         k_param = k_param.reshape(-1, args.hidden_size)
         v_param = v_param.reshape(-1, args.hidden_size)
         return [
@@ -87,7 +93,7 @@ def convert_qwen3_next_to_hf(args, name, param):
             (f"model.layers.{layer_idx}.self_attn.k_proj.weight", k_param),
             (f"model.layers.{layer_idx}.self_attn.v_proj.weight", v_param),
         ]
-    elif rest == "self_attention.linear_qkv.bias":
+    elif rest == "self_attention.linear_qgkv.bias":
         param = param.view(args.num_query_groups, -1)
         q_bias, k_bias, v_bias = torch.split(
             param,
@@ -110,7 +116,7 @@ def convert_qwen3_next_to_hf(args, name, param):
         ]
     elif rest == "mlp.linear_fc2.weight":
         return [(f"model.layers.{layer_idx}.mlp.down_proj.weight", param)]
-    elif rest == "self_attention.linear_qkv.layer_norm_weight":
+    elif rest == "self_attention.linear_qgkv.layer_norm_weight":
         return [(f"model.layers.{layer_idx}.input_layernorm.weight", param)]
     elif rest == "mlp.linear_fc1.layer_norm_weight":
         return [(f"model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
diff --git a/slime_plugins/mbridge/qwen3_next.py b/slime_plugins/mbridge/qwen3_next.py
index 4b4182a2e..377ba18f6 100644
--- a/slime_plugins/mbridge/qwen3_next.py
+++ b/slime_plugins/mbridge/qwen3_next.py
@@ -1,30 +1,79 @@
+import torch
 from mbridge.core import register_model
 from mbridge.models import Qwen2MoEBridge
 
 
 @register_model("qwen3_next")
 class Qwen3NextBridge(Qwen2MoEBridge):
-    _ATTENTION_MAPPING = Qwen2MoEBridge._ATTENTION_MAPPING | {
-        f"self_attention.{weight_name}": ["model.layers.{layer_number}." + weight_name]
-        for weight_name in [
-            "input_layernorm.weight",
-            # linear attn
-            "linear_attn.A_log",
-            "linear_attn.conv1d.weight",
-            "linear_attn.dt_bias",
-            "linear_attn.in_proj_ba.weight",
-            "linear_attn.in_proj_qkvz.weight",
-            "linear_attn.norm.weight",
-            "linear_attn.out_proj.weight",
-            # gated attn
-            "self_attn.k_norm.weight",
-            "self_attn.k_proj.weight",
-            "self_attn.o_proj.weight",
-            "self_attn.q_norm.weight",
-            "self_attn.q_proj.weight",
-            "self_attn.v_proj.weight",
-        ]
-    }
+    _ATTENTION_MAPPING = (
+        Qwen2MoEBridge._ATTENTION_MAPPING
+        | {
+            f"self_attention.{weight_name}": ["model.layers.{layer_number}." + weight_name]
+            for weight_name in [
+                "input_layernorm.weight",
+                # linear attn
+                "linear_attn.A_log",
+                "linear_attn.conv1d.weight",
+                "linear_attn.dt_bias",
+                "linear_attn.in_proj_ba.weight",
+                "linear_attn.in_proj_qkvz.weight",
+                "linear_attn.norm.weight",
+                "linear_attn.out_proj.weight",
+                # gated attn
+                "self_attn.k_norm.weight",
+                "self_attn.k_proj.weight",
+                "self_attn.o_proj.weight",
+                "self_attn.q_norm.weight",
+                "self_attn.q_proj.weight",
+                "self_attn.v_proj.weight",
+            ]
+        }
+        | {
+            "self_attention.linear_qgkv.layer_norm_weight": ["model.layers.{layer_number}.input_layernorm.weight"],
+            "self_attention.linear_qgkv.weight": [
+                "model.layers.{layer_number}.self_attn.q_proj.weight",
+                "model.layers.{layer_number}.self_attn.k_proj.weight",
+                "model.layers.{layer_number}.self_attn.v_proj.weight",
+            ],
+        }
+    )
+
+    def _weight_to_mcore_format(
+        self, mcore_weights_name: str, hf_weights: list[torch.Tensor]
+    ) -> tuple[list[str], list[torch.Tensor]]:
+        if "self_attention.linear_qgkv." in mcore_weights_name and "layer_norm" not in mcore_weights_name:
+            # merge qkv
+            assert len(hf_weights) == 3
+            num_key_value_heads = self.hf_config.num_key_value_heads
+            hidden_dim = self.hf_config.hidden_size
+            num_attention_heads = self.hf_config.num_attention_heads
+            num_querys_per_group = num_attention_heads // self.hf_config.num_key_value_heads
+            head_dim = getattr(self.hf_config, "head_dim", hidden_dim // num_attention_heads)
+            group_dim = head_dim * num_attention_heads // num_key_value_heads
+            q, k, v = hf_weights
+            # q k v might be tp split
+            real_num_key_value_heads = q.shape[0] // (2 * group_dim)
+            q = (
+                q.view(
+                    [
+                        real_num_key_value_heads,
+                        num_querys_per_group,
+                        2,
+                        head_dim,
+                        -1,
+                    ]
+                )
+                .transpose(1, 2)
+                .flatten(1, 3)
+            )
+            k = k.view([real_num_key_value_heads, head_dim, -1])
+            v = v.view([real_num_key_value_heads, head_dim, -1])
+            out_shape = [-1, hidden_dim] if ".bias" not in mcore_weights_name else [-1]
+
+            qgkv = torch.cat([q, k, v], dim=1).view(*out_shape).contiguous()
+            return qgkv
+
+        return super()._weight_to_mcore_format(mcore_weights_name, hf_weights)
 
     def _build_config(self):
         return self._build_base_config(
@@ -46,4 +95,6 @@ def _build_config(self):
             # Qwen specific
             moe_router_pre_softmax=False,
             qk_layernorm=True,
+            # Qwen3 Next specific
+            use_gated_attention=True,
         )
diff --git a/slime_plugins/models/hf_attention.py b/slime_plugins/models/hf_attention.py
index 0a3c33d80..4c2db88c4 100644
--- a/slime_plugins/models/hf_attention.py
+++ b/slime_plugins/models/hf_attention.py
@@ -73,21 +73,19 @@ def forward(
 
         # TODO: preprocess this for each batch to prevent tolist in the training step
         whole_hidden_states_list = []
+        local_cu_seqlens = cu_seqlens // cp_size
         for i in range(len(cu_seqlens) - 1):
             seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
             chunk_size = seqlen // 2
-            whole_hidden_states_list.append(
-                torch.cat(
-                    [
-                        hidden_states_list[cp_rank][cu_seqlens[i] : cu_seqlens[i] + chunk_size]
-                        for cp_rank in range(cp_size)
-                    ]
-                    + [
-                        hidden_states_list[cp_rank][cu_seqlens[i] + chunk_size : cu_seqlens[i + 1]]
-                        for cp_rank in range(cp_size)
-                    ][::-1],
-                    dim=0,
-                )
+            whole_hidden_states_list.extend(
+                [
+                    hidden_states_list[cp_rank][local_cu_seqlens[i] : local_cu_seqlens[i] + chunk_size]
+                    for cp_rank in range(cp_size)
+                ]
+                + [
+                    hidden_states_list[cp_rank][local_cu_seqlens[i] + chunk_size : local_cu_seqlens[i + 1]]
+                    for cp_rank in range(cp_size)
+                ][::-1],
             )
 
         hidden_states = torch.cat(whole_hidden_states_list, dim=0)
@@ -102,12 +100,19 @@ def forward(
         output = self.hf_forward(hidden_states, position_ids, packed_seq_params)
         bias = None
 
+        if mpu.get_context_parallel_world_size() > 1:
+            output_list = []
+            for i in range(len(cu_seqlens) - 1):
+                seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
+                chunk_size = seqlen // 2
+
+        output = output.permute(1, 0, 2)  # [seq_len, bsz, hidden_dim]
+
         if self.args.sequence_parallel:
             output = tensor_parallel.scatter_to_sequence_parallel_region(
-                hidden_states, group=mpu.get_tensor_model_parallel_group()
+                output, group=mpu.get_tensor_model_parallel_group()
             )
 
-        output = output.permute(1, 0, 2)  # [seq_len, bsz, hidden_dim]
         return output, bias
 
     @abstractmethod
diff --git a/slime_plugins/models/qwen3_next.py b/slime_plugins/models/qwen3_next.py
index e89ffd0c0..b013db13c 100644
--- a/slime_plugins/models/qwen3_next.py
+++ b/slime_plugins/models/qwen3_next.py
@@ -1,19 +1,19 @@
+import copy
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_block import get_num_layers_to_build
+from megatron.core.transformer.transformer_layer import get_transformer_layer_offset
+from transformers import AutoConfig
 from transformers.activations import ACT2FN
 
 try:
     from fla.modules import FusedRMSNormGated, ShortConvolution
     from fla.ops.gated_delta_rule import chunk_gated_delta_rule
-    from transformers.models.qwen3_next.modeling_qwen3_next import (
-        Qwen3NextAttention,
-        Qwen3NextRMSNorm,
-        Qwen3NextRotaryEmbedding,
-    )
+    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextAttention, Qwen3NextRMSNorm
 except ImportError:
     pass
 
@@ -181,32 +181,15 @@ def __init__(
         if Qwen3NextAttention is None:
             raise ImportError("Please install transformers>=4.35.0 to use Qwen3NextAttention.")
 
-        self.layer_type = self.hf_config.layer_types[self.hf_layer_idx]
-        if self.layer_type == "linear_attention":
-            self.linear_attn = Qwen3NextGatedDeltaNet(self.hf_config, self.hf_layer_idx)
-        elif self.layer_type == "full_attention":
-            self.rotary_emb = Qwen3NextRotaryEmbedding(config=self.hf_config)
-            self.self_attn = Qwen3NextAttention(self.hf_config, self.hf_layer_idx)
-
+        self.linear_attn = Qwen3NextGatedDeltaNet(self.hf_config, self.hf_layer_idx)
         self.input_layernorm = Qwen3NextRMSNorm(self.hf_config.hidden_size, eps=self.hf_config.rms_norm_eps)
 
     def hf_forward(self, hidden_states, position_ids, packed_seq_params):
         hidden_states = self.input_layernorm(hidden_states)
-
-        if self.layer_type == "linear_attention":
-            hidden_states = self.linear_attn(
-                hidden_states=hidden_states,
-                cu_seqlens=packed_seq_params.cu_seqlens_q,
-            )
-        elif self.layer_type == "full_attention":
-            # Self Attention
-            position_embeddings = self.rotary_emb(hidden_states, position_ids)
-            hidden_states, _ = self.self_attn(
-                hidden_states=hidden_states,
-                attention_mask=None,
-                position_ids=position_ids,
-                position_embeddings=position_embeddings,
-            )
+        hidden_states = self.linear_attn(
+            hidden_states=hidden_states,
+            cu_seqlens=packed_seq_params.cu_seqlens_q,
+        )
 
         return hidden_states
 
@@ -228,12 +211,17 @@
     # Slice the layer specs to only include the layers that are built in this pipeline stage.
     # Note: MCore layer_number starts at 1
     num_layers_to_build = get_num_layers_to_build(config, vp_stage=vp_stage)
+    offset = get_transformer_layer_offset(config, vp_stage=vp_stage)
+
+    hf_config = AutoConfig.from_pretrained(args.hf_checkpoint, trust_remote_code=True)
 
     for layer_id in range(num_layers_to_build):
-        transformer_layer_spec.layer_specs[layer_id].submodules.self_attention = ModuleSpec(
-            module=Attention,
-            params={"args": args},
-        )
+        if hf_config.layer_types[layer_id + offset] == "linear_attention":
+            layer_specs = copy.deepcopy(transformer_layer_spec.layer_specs[layer_id])
+            layer_specs.submodules.self_attention = ModuleSpec(
+                module=Attention,
+                params={"args": args},
+            )
+            transformer_layer_spec.layer_specs[layer_id] = layer_specs
         transformer_layer_spec.layer_specs[layer_id].submodules.mlp.submodules.shared_experts.params = {"gate": True}
-
     return transformer_layer_spec
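
Reviewer aid (not part of the patch): a minimal, self-contained sketch of the qgkv layout and sigmoid gating that the attention.py hunks implement. Shapes follow the comments in get_query_gate_key_value_tensors (sq, b, ng, np, hn); all names below are illustrative stand-ins for, not calls into, the patched Megatron modules.

import torch

sq, b, ng, heads_per_group, hn = 4, 2, 2, 4, 16  # seq, batch, groups, np/ng, head dim
np_ = ng * heads_per_group                       # total attention heads

# linear_qgkv packs, per query group: np/ng query heads, np/ng gate heads,
# one key head and one value head -> ng * 2 * (np/ng + 1) * hn features.
mixed_qgkv = torch.randn(sq, b, ng * 2 * (heads_per_group + 1) * hn)
mixed_qgkv = mixed_qgkv.view(sq, b, ng, 2 * (heads_per_group + 1) * hn)

# Four-way split, mirroring split_arg_list in the patch.
query, gate, key, value = torch.split(
    mixed_qgkv,
    [heads_per_group * hn, heads_per_group * hn, hn, hn],
    dim=3,
)
query = query.reshape(sq, b, np_, hn)  # [sq, b, np, hn]
gate = gate.reshape(sq, b, np_ * hn)   # flattened to match core_attn_out

# After core attention, the output is gated elementwise with sigmoid(gate).
core_attn_out = torch.randn(sq, b, np_ * hn)  # stand-in for the attention output
gated_out = core_attn_out * torch.sigmoid(gate)
assert gated_out.shape == (sq, b, np_ * hn)

The per-group [query | gate] interleaving keeps tensor-parallel sharding by query group valid, and it is why both convert_qwen3_next_to_hf and Qwen3NextBridge._weight_to_mcore_format reshape with an extra dimension of 2 and transpose before flattening when converting between the fused qgkv weight and HF's separate q/k/v projections.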