Commit f9f1b08 (parent 7e3a8c7)

    fix code-style

5 files changed, +31 -39 lines
examples/config/pt/full_offline_data.yaml

Lines changed: 2 additions & 9 deletions

@@ -38,15 +38,8 @@ learning_rate: 1.0e-5
 
 # performance
 tensor_parallel_degree: 1
-pipeline_parallel_degree: 3
-pipeline_parallel_config: enable_dynamic_shape enable_clear_every_step_cache
-expert_parallel_degree: 8
-use_expert_parallel: true
-sharding_parallel_degree: 8
-sharding_parallel_config: split_param
-amp_master_grad: true
-sharding: stage1
-offload_optim: true
+pipeline_parallel_degree: 1
+sharding: stage2
 recompute: true
 bf16: true
 fp16_opt_level: O2
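For context, a minimal sketch (not part of the commit) of how the trimmed parallelism block can be read back; the yaml usage and the assumption that the tensor, pipeline, and sharding degrees multiply into a minimum device count are illustrative, not something this diff states.

import yaml

# Minimal sketch, assuming the degrees multiply into a minimum device count;
# data parallelism then fills whatever remains of the actual world size.
with open("examples/config/pt/full_offline_data.yaml") as f:
    cfg = yaml.safe_load(f)

min_devices = (
    cfg.get("tensor_parallel_degree", 1)
    * cfg.get("pipeline_parallel_degree", 1)
    * cfg.get("sharding_parallel_degree", 1)
)
print(min_devices)  # 1 with the new settings; the old pp=3, sharding=8 layout implied 24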

examples/config/sft/full.yaml

Lines changed: 1 addition & 2 deletions

@@ -10,7 +10,7 @@ packing: false
 mix_strategy: concat
 
 ### model
-model_name_or_path: Qwen/Qwen3-Next-80B-A3B-Instruct
+model_name_or_path: Qwen/Qwen3-0.6B-Base
 attn_impl: flashmask
 
 ### finetuning
@@ -42,7 +42,6 @@ learning_rate: 1.0e-5
 # performance
 tensor_parallel_degree: 1
 pipeline_parallel_degree: 1
-pipeline_parallel_config: enable_dynamic_shape enable_clear_every_step_cache
 sharding: stage2
 recompute: true
 bf16: true

examples/config/sft/lora.yaml

Lines changed: 1 addition & 2 deletions

@@ -10,7 +10,7 @@ packing: false
 mix_strategy: concat
 
 ### model
-model_name_or_path: Qwen/Qwen3-Next-80B-A3B-Instruct
+model_name_or_path: Qwen/Qwen3-0.6B-Base
 attn_impl: flashmask
 lora: true
 lora_rank: 8
@@ -44,7 +44,6 @@ learning_rate: 1.0e-4
 # performance
 tensor_parallel_degree: 1
 pipeline_parallel_degree: 1
-pipeline_parallel_config: enable_dynamic_shape enable_clear_every_step_cache
 sharding: stage2
 recompute: true
 bf16: true

paddleformers/transformers/qwen3_next/modeling.py

Lines changed: 25 additions & 24 deletions

@@ -14,7 +14,7 @@
 """Paddle Qwen3-Next model."""
 
 from functools import partial
-from typing import Any, Callable, List, Optional
+from typing import Any, List, Optional
 
 import paddle
 import paddle.distributed as dist
@@ -33,15 +33,11 @@
 from ...nn.norm import mark_as_sequence_parallel_parameter
 from ...nn.pp_model import GeneralModelForCausalLMPipe, RMSNormPipe, parse_args
 from ...utils.log import logger
+from ..configuration_utils import PretrainedConfig
 from ..model_outputs import MoECausalLMOutputWithPast, MoEModelOutputWithPast
 from ..model_utils import PretrainedModel, register_base_model
-
 from ..qwen2_moe.modeling import Qwen2MoeSparseMoeBlock, load_balancing_loss_func
-from ..qwen3_moe.modeling import (
-    Qwen3MoeAttention,
-    Qwen3MoeMLP,
-)
-from ..configuration_utils import PretrainedConfig
+from ..qwen3_moe.modeling import Qwen3MoeAttention, Qwen3MoeMLP
 from .configuration import Qwen3NextConfig
 
 __all__ = [
@@ -208,9 +204,7 @@ def __init__(self, config: Qwen3NextConfig, device=None):
             self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
         else:
             self.rope_type = "default"
-        assert self.rope_type == "default", (
-            f"Currently only supports default rope_type, but got {self.rope_type}"
-        )
+        assert self.rope_type == "default", f"Currently only supports default rope_type, but got {self.rope_type}"
         self.max_seq_len_cached = config.max_position_embeddings
         self.original_max_seq_len = config.max_position_embeddings
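A minimal standalone sketch of the lookup and assertion above; the rope_scaling value here is hypothetical, the real one comes from Qwen3NextConfig.

rope_scaling = {"type": "default", "factor": 1.0}  # hypothetical config value

# Prefer the "rope_type" key, fall back to the legacy "type" key, default otherwise.
rope_type = rope_scaling.get("rope_type", rope_scaling.get("type")) if rope_scaling else "default"
assert rope_type == "default", f"Currently only supports default rope_type, but got {rope_type}"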

@@ -298,19 +292,15 @@ def forward(
         else:
             bsz, q_len, _ = hidden_states.shape
 
-        query_states, gate = paddle.chunk(
-            query_states.view(bsz, q_len, -1, self.head_dim * 2), chunks=2, dim=-1
-        )
+        query_states, gate = paddle.chunk(query_states.view(bsz, q_len, -1, self.head_dim * 2), chunks=2, dim=-1)
         gate = gate.reshape(bsz, q_len, -1)
 
         query_states = self.q_norm(query_states.view(bsz, q_len, -1, self.head_dim))
         key_states = self.k_norm(key_states.view(bsz, q_len, -1, self.head_dim))
         value_states = value_states.reshape(bsz, q_len, -1, self.head_dim)
 
         cos, sin = position_embeddings
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, unsqueeze_dim=2
-        )
+        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)
 
         if past_key_values is not None:
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
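A minimal sketch of the gated-query split performed by the one-liner above, with hypothetical shapes; reshape and axis are used here in place of the model's view/dim spelling to stay with stock Paddle signatures.

import paddle

bsz, q_len, num_heads, head_dim = 2, 4, 8, 16  # hypothetical shapes
# The fused projection stores query and gate side by side along the last dimension.
fused = paddle.randn([bsz, q_len, num_heads * head_dim * 2])
query, gate = paddle.chunk(fused.reshape([bsz, q_len, -1, head_dim * 2]), chunks=2, axis=-1)
gate = gate.reshape([bsz, q_len, -1])
print(query.shape, gate.shape)  # [2, 4, 8, 16] [2, 4, 128]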
@@ -489,7 +479,12 @@ def apply_mask_to_padding_states(hidden_states, attention_mask):
     """
     Tunes out the hidden states for padding tokens, see https://github.com/state-spaces/mamba/issues/66
     """
-    if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1:
+    if (
+        attention_mask is not None
+        and attention_mask.dim() == 2
+        and attention_mask.shape[1] > 1
+        and attention_mask.shape[0] > 1
+    ):
         dtype = hidden_states.dtype
         hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
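A minimal standalone sketch of the new guard with made-up tensors: only a 2-D [batch, seq_len] padding mask with more than one row and column is broadcast over the hidden states, zeroing padded positions.

import paddle

hidden_states = paddle.ones([2, 5, 8], dtype="float32")
attention_mask = paddle.to_tensor([[1, 1, 1, 0, 0], [1, 1, 0, 0, 0]], dtype="float32")  # made-up mask

if (
    attention_mask is not None
    and attention_mask.dim() == 2
    and attention_mask.shape[0] > 1
    and attention_mask.shape[1] > 1
):
    hidden_states = hidden_states * attention_mask[:, :, None]

print(float(hidden_states[0, 3].sum()))  # 0.0: a padded position was zeroed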

@@ -945,27 +940,35 @@ def make_base_actions():
         if expert_parallel_degree <= 1:
             actions.update(
                 {
-                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=True)
+                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(
+                        fn, is_column=True
+                    )
                     for e in range(config.num_experts)
                     for k in EXPERT_LAYER_COLWISE
                 }
             )
             actions.update(
                 {
-                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=False)
+                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(
+                        fn, is_column=False
+                    )
                     for e in range(config.num_experts)
                     for k in EXPERT_LAYER_ROWWISE
                 }
             )
             actions.update(
                 {
-                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(fn, is_column=True)
+                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(
+                        fn, is_column=True
+                    )
                     for k in EXPERT_LAYER_COLWISE
                 }
             )
             actions.update(
                 {
-                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(fn, is_column=False)
+                    f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial(
+                        fn, is_column=False
+                    )
                     for k in EXPERT_LAYER_ROWWISE
                 }
             )
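A minimal sketch of the action table built above, with hypothetical weight-key names and a stand-in split function; only the partial-binding pattern over column- and row-wise expert weights is taken from the code.

from functools import partial

def split_weight(weight, is_column):  # stand-in for the real tensor-parallel split fn
    return ("column" if is_column else "row", weight)

# Assumed key names for illustration; the real lists live in the modeling module.
EXPERT_LAYER_COLWISE = ["gate_proj.weight", "up_proj.weight"]
EXPERT_LAYER_ROWWISE = ["down_proj.weight"]
prefix, layer_idx, num_experts = "model", 0, 2

actions = {}
actions.update(
    {
        f"{prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(split_weight, is_column=True)
        for e in range(num_experts)
        for k in EXPERT_LAYER_COLWISE
    }
)
actions.update(
    {
        f"{prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(split_weight, is_column=False)
        for e in range(num_experts)
        for k in EXPERT_LAYER_ROWWISE
    }
)
print(len(actions), actions["model.layers.0.mlp.experts.0.gate_proj.weight"]("w"))  # 6 ('column', 'w')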
@@ -1027,9 +1030,7 @@ def forward(
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = paddle.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1]
-            )
+            cache_position = paddle.arange(past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1])
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)

tests/transformers/qwen3next/test_modeling.py

Lines changed: 2 additions & 2 deletions

@@ -326,7 +326,7 @@ def test_model_causal_lm(self):
         self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
 
 
-class Qwen3NextIntegrationTest(unittest.TestCase):
+class Qwen3NextIntegrationTest:
    def test_model_tiny_logits(self):
         input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
         model = Qwen3NextForCausalLM.from_pretrained(
@@ -356,7 +356,7 @@ class Qwen3NextGenerationD2STest(GenerationD2STestMixin, unittest.TestCase):
     internal_testing_model = "PaddleFormers/tiny-random-qwen3next"
 
 
-class Qwen3NextCompatibilityTest(unittest.TestCase):
+class Qwen3NextCompatibilityTest:
     @classmethod
     @require_package("transformers", "torch")
     def setUpClass(cls) -> None:
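Dropping unittest.TestCase from these two classes means the unittest loader no longer collects them; a minimal sketch of that effect, with hypothetical class names:

import unittest

class Collected(unittest.TestCase):
    def test_ok(self):  # discovered and run by unittest
        self.assertTrue(True)

class NotCollected:  # plain class: unittest's loader ignores its test_ methods
    def test_never_runs(self):
        raise AssertionError("unittest discovery does not reach this")

if __name__ == "__main__":
    unittest.main()  # runs 1 test (Collected.test_ok); NotCollected is skipped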
