|
14 | 14 | """Paddle Qwen3-Next model.""" |
15 | 15 |
|
16 | 16 | from functools import partial |
17 | | -from typing import Any, Callable, List, Optional |
| 17 | +from typing import Any, List, Optional |
18 | 18 |
|
19 | 19 | import paddle |
20 | 20 | import paddle.distributed as dist |
|
33 | 33 | from ...nn.norm import mark_as_sequence_parallel_parameter |
34 | 34 | from ...nn.pp_model import GeneralModelForCausalLMPipe, RMSNormPipe, parse_args |
35 | 35 | from ...utils.log import logger |
| 36 | +from ..configuration_utils import PretrainedConfig |
36 | 37 | from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast |
37 | 38 | from ..model_utils import PretrainedModel, register_base_model |
38 | | - |
39 | 39 | from ..qwen2_moe.modeling import Qwen2MoeSparseMoeBlock, load_balancing_loss_func |
40 | | -from ..qwen3_moe.modeling import ( |
41 | | - Qwen3MoeAttention, |
42 | | - Qwen3MoeMLP, |
43 | | -) |
44 | | -from ..configuration_utils import PretrainedConfig |
| 40 | +from ..qwen3_moe.modeling import Qwen3MoeAttention, Qwen3MoeMLP |
45 | 41 | from .configuration import Qwen3NextConfig |
46 | 42 |
|
47 | 43 | __all__ = [ |
@@ -208,9 +204,7 @@ def __init__(self, config: Qwen3NextConfig, device=None): |
208 | 204 | self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) |
209 | 205 | else: |
210 | 206 | self.rope_type = "default" |
211 | | - assert self.rope_type == "default", ( |
212 | | - f"Currently only supports default rope_type, but got {self.rope_type}" |
213 | | - ) |
| 207 | + assert self.rope_type == "default", f"Currently only supports default rope_type, but got {self.rope_type}" |
214 | 208 | self.max_seq_len_cached = config.max_position_embeddings |
215 | 209 | self.original_max_seq_len = config.max_position_embeddings |
216 | 210 |
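The assertion above is only reflowed onto one line; the resolution order (prefer `"rope_type"`, fall back to the legacy `"type"` key, default otherwise) is unchanged. A tiny standalone sketch of that lookup, with purely illustrative config dicts:

```python
# Sketch of the rope_type resolution shown in the hunk above. The example
# configs are made up; only the lookup/fallback order mirrors the diff.
def resolve_rope_type(rope_scaling):
    if rope_scaling is not None:
        rope_type = rope_scaling.get("rope_type", rope_scaling.get("type"))
    else:
        rope_type = "default"
    assert rope_type == "default", f"Currently only supports default rope_type, but got {rope_type}"
    return rope_type

print(resolve_rope_type(None))                 # "default"
print(resolve_rope_type({"type": "default"}))  # legacy "type" key still resolves
```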
|
@@ -298,19 +292,15 @@ def forward( |
298 | 292 | else: |
299 | 293 | bsz, q_len, _ = hidden_states.shape |
300 | 294 |
|
301 | | - query_states, gate = paddle.chunk( |
302 | | - query_states.view(bsz, q_len, -1, self.head_dim * 2), chunks=2, dim=-1 |
303 | | - ) |
| 295 | + query_states, gate = paddle.chunk(query_states.view(bsz, q_len, -1, self.head_dim * 2), chunks=2, dim=-1) |
304 | 296 | gate = gate.reshape(bsz, q_len, -1) |
305 | 297 |
|
306 | 298 | query_states = self.q_norm(query_states.view(bsz, q_len, -1, self.head_dim)) |
307 | 299 | key_states = self.k_norm(key_states.view(bsz, q_len, -1, self.head_dim)) |
308 | 300 | value_states = value_states.reshape(bsz, q_len, -1, self.head_dim) |
309 | 301 |
|
310 | 302 | cos, sin = position_embeddings |
311 | | - query_states, key_states = apply_rotary_pos_emb( |
312 | | - query_states, key_states, cos, sin, unsqueeze_dim=2 |
313 | | - ) |
| 303 | + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2) |
314 | 304 |
|
315 | 305 | if past_key_values is not None: |
316 | 306 | # sin and cos are specific to RoPE models; cache_position needed for the static cache |
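For context on the reshapes reflowed above: the fused query projection carries both the attention query and a per-head gate, so its trailing `head_dim * 2` slot is split in two before normalization and RoPE. A minimal shape sketch, with made-up toy sizes for `bsz`, `q_len`, `num_heads`, and `head_dim` (the real values come from the config):

```python
# Toy illustration of the query/gate split; sizes are assumptions, not config values.
import paddle

bsz, q_len, num_heads, head_dim = 2, 4, 8, 16

# Fused q_proj output: query and gate stacked in the last dimension.
fused_q = paddle.randn([bsz, q_len, num_heads * head_dim * 2])

# Split the trailing head_dim * 2 slot into query and gate, as in the diff.
query_states, gate = paddle.chunk(
    fused_q.reshape([bsz, q_len, -1, head_dim * 2]), chunks=2, axis=-1
)
gate = gate.reshape([bsz, q_len, -1])                             # [bsz, q_len, num_heads * head_dim]
query_states = query_states.reshape([bsz, q_len, -1, head_dim])   # [bsz, q_len, num_heads, head_dim]

print(query_states.shape, gate.shape)
```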
@@ -489,7 +479,12 @@ def apply_mask_to_padding_states(hidden_states, attention_mask): |
489 | 479 | """ |
490 | 480 | Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66 |
491 | 481 | """ |
492 | | - if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: |
| 482 | + if ( |
| 483 | + attention_mask is not None |
| 484 | + and attention_mask.dim() == 2 |
| 485 | + and attention_mask.shape[1] > 1 |
| 486 | + and attention_mask.shape[0] > 1 |
| 487 | + ): |
493 | 488 | dtype = hidden_states.dtype |
494 | 489 | hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) |
495 | 490 |
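The new `dim() == 2` guard means the padding mask is only multiplied in when it is a plain 2D `[batch, seq_len]` mask with more than one token and more than one sample; a 4D causal mask or a single-token decode step passes through untouched. A small runnable sketch of that behaviour, with made-up tensor sizes:

```python
# Sketch of the guarded padding-mask multiplication from the hunk above.
# The hidden size, batch, and mask contents are illustrative.
import paddle

def apply_mask_to_padding_states(hidden_states, attention_mask):
    if (
        attention_mask is not None
        and attention_mask.dim() == 2
        and attention_mask.shape[1] > 1
        and attention_mask.shape[0] > 1
    ):
        dtype = hidden_states.dtype
        hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    return hidden_states

hidden_states = paddle.randn([2, 5, 8])  # [batch, seq, hidden]
padding_mask = paddle.to_tensor([[1, 1, 1, 0, 0],
                                 [1, 1, 1, 1, 1]], dtype="float32")
masked = apply_mask_to_padding_states(hidden_states, padding_mask)
print(float(masked[0, 3:].abs().sum()))  # padded positions of sample 0 are zeroed
```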
|
@@ -945,27 +940,35 @@ def make_base_actions(): |
945 | 940 | if expert_parallel_degree <= 1: |
946 | 941 | actions.update( |
947 | 942 | { |
948 | | - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=True) |
| 943 | + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( |
| 944 | + fn, is_column=True |
| 945 | + ) |
949 | 946 | for e in range(config.num_experts) |
950 | 947 | for k in EXPERT_LAYER_COLWISE |
951 | 948 | } |
952 | 949 | ) |
953 | 950 | actions.update( |
954 | 951 | { |
955 | | - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=False) |
| 952 | + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( |
| 953 | + fn, is_column=False |
| 954 | + ) |
956 | 955 | for e in range(config.num_experts) |
957 | 956 | for k in EXPERT_LAYER_ROWWISE |
958 | 957 | } |
959 | 958 | ) |
960 | 959 | actions.update( |
961 | 960 | { |
962 | | - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(fn, is_column=True) |
| 961 | + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( |
| 962 | + fn, is_column=True |
| 963 | + ) |
963 | 964 | for k in EXPERT_LAYER_COLWISE |
964 | 965 | } |
965 | 966 | ) |
966 | 967 | actions.update( |
967 | 968 | { |
968 | | - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(fn, is_column=False) |
| 969 | + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( |
| 970 | + fn, is_column=False |
| 971 | + ) |
969 | 972 | for k in EXPERT_LAYER_ROWWISE |
970 | 973 | } |
971 | 974 | ) |
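For readers unfamiliar with the pattern being reformatted above: the block builds a dict mapping each expert (and shared-expert) parameter name to a split function pre-bound with `is_column`. A hedged, self-contained sketch of the same pattern; `split_fn`, the key lists, and the expert count are placeholders, not the real PaddleFormers helpers:

```python
# Illustrative sketch of the tensor-parallel action mapping; names are assumptions.
from functools import partial

EXPERT_LAYER_COLWISE = ["gate_proj.weight", "up_proj.weight"]  # assumed key names
EXPERT_LAYER_ROWWISE = ["down_proj.weight"]                    # assumed key names

def split_fn(shape, is_column):
    # Stand-in for the real tensor-parallel split helper.
    return f"split {'columns' if is_column else 'rows'} of tensor with shape {shape}"

def make_expert_actions(prefix, layer_idx, num_experts):
    actions = {}
    actions.update(
        {
            f"{prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(split_fn, is_column=True)
            for e in range(num_experts)
            for k in EXPERT_LAYER_COLWISE
        }
    )
    actions.update(
        {
            f"{prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(split_fn, is_column=False)
            for e in range(num_experts)
            for k in EXPERT_LAYER_ROWWISE
        }
    )
    return actions

actions = make_expert_actions("qwen3_next", layer_idx=0, num_experts=2)
print(actions["qwen3_next.layers.0.mlp.experts.1.down_proj.weight"]((4, 8)))
```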
@@ -1027,9 +1030,7 @@ def forward( |
1027 | 1030 |
|
1028 | 1031 | if cache_position is None: |
1029 | 1032 | past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 |
1030 | | - cache_position = paddle.arange( |
1031 | | - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1] |
1032 | | - ) |
| 1033 | + cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]) |
1033 | 1034 | if position_ids is None: |
1034 | 1035 | position_ids = cache_position.unsqueeze(0) |
1035 | 1036 |
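The reflowed line above computes `cache_position` as the index range of the newly fed tokens, offset by whatever is already in the KV cache, and derives `position_ids` from it when none are supplied. A toy sketch with assumed numbers:

```python
# Toy illustration of the cache_position / position_ids bookkeeping; the counts
# below are made up.
import paddle

past_seen_tokens = 6  # tokens already held in the KV cache
seq_len = 3           # new tokens in this forward pass

# Positions of the new tokens inside the full sequence.
cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + seq_len)

# position_ids gets a leading batch dimension so it broadcasts over the batch.
position_ids = cache_position.unsqueeze(0)

print(cache_position.numpy())  # [6 7 8]
print(position_ids.shape)      # [1, 3]
```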
|
|