
MiniMax-M2.1 Megatron SFT with LoRA and tensor parallel has a bug #7786

@tic-top

Description

Describe the bug
What the bug is and how to reproduce it, ideally with screenshots.

export MODELSCOPE_CACHE=/tmp/modelscope_cache
export MEGATRON_LM_PATH=/tmp/Megatron-LM

export JOB_NAME=minimax_sft
export MODEL_NAME=MiniMaxAI/MiniMax-M2.1
export SAVE_PATH=/mnt/user/models/mg-exps/${JOB_NAME}
mkdir -p ${SAVE_PATH}

NPROC_PER_NODE=8 \
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
megatron sft \
    --model ${MODEL_NAME} \
    --dataset 'liucong/Chinese-DeepSeek-R1-Distill-data-110k-SFT' \
    --load_safetensors true \
    --save_safetensors true \
    --merge_lora false \
    --load_from_cache_file true \
    --train_type lora \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --split_dataset_ratio 0.01 \
    --tensor_model_parallel_size 4 \
    --expert_tensor_parallel_size 1 \
    --expert_model_parallel_size 8 \
    --moe_permute_fusion true \
    --moe_grouped_gemm true \
    --moe_shared_expert_overlap true \
    --moe_aux_loss_coeff 1e-3 \
    --micro_batch_size 1 \
    --global_batch_size 128 \
    --recompute_granularity full \
    --recompute_method uniform \
    --recompute_num_layers 1 \
    --max_epochs 1 \
    --finetune true \
    --cross_entropy_loss_fusion true \
    --lr 1e-4 \
    --lr_warmup_fraction 0.05 \
    --min_lr 1e-5 \
    --save ${SAVE_PATH} \
    --eval_interval 200 \
    --save_interval 200 \
    --max_length 8192 \
    --num_workers 8 \
    --dataset_num_proc 8 \
    --no_save_optim true \
    --no_save_rng true \
    --sequence_parallel true \
    --attention_backend flash \
    --use_hf true \
    --wandb-project test \
    --wandb-exp-name ${JOB_NAME}
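
For context, a quick sketch of the parallel layout those flags imply on the single 8-GPU node used here (just the size bookkeeping behind the command above, not anything read out of Megatron-LM itself):

# Sketch of the parallel sizes requested by the command above
# (assumes the single node with 8 GPUs from the repro; names are illustrative).
world_size = 8

tensor_model_parallel_size = 4    # --tensor_model_parallel_size 4
expert_model_parallel_size = 8    # --expert_model_parallel_size 8
expert_tensor_parallel_size = 1   # --expert_tensor_parallel_size 1

# Dense/attention layers: ranks split into TP groups, the remainder is data parallel.
assert world_size % tensor_model_parallel_size == 0
data_parallel_size = world_size // tensor_model_parallel_size  # -> 2

# MoE layers: experts are spread over EP * ETP = 8 ranks, i.e. the whole node.
assert expert_model_parallel_size * expert_tensor_parallel_size <= world_size

print(f"TP={tensor_model_parallel_size}, DP={data_parallel_size}, "
      f"EP={expert_model_parallel_size}, ETP={expert_tensor_parallel_size}")

The tensor-parallel split of 4 is the size that shows up again in the shape error below.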
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/cli/_megatron/sft.py", line 7, in <module>
[rank3]:     megatron_sft_main()
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/train/sft.py", line 87, in megatron_sft_main
[rank3]:     return MegatronSft(args).main()
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/llm/base.py", line 49, in main
[rank3]:     result = self.run()
[rank3]:              ^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/train/sft.py", line 77, in run
[rank3]:     self.trainer.train(train_dataset, val_dataset, data_collator)
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/trainers/base.py", line 1098, in train
[rank3]:     pretrain(
[rank3]:   File "/tmp/Megatron-LM/megatron/training/training.py", line 737, in pretrain
[rank3]:     iteration, num_floating_point_operations_so_far = train(
[rank3]:                                                       ^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/training/training.py", line 2298, in train
[rank3]:     ) = train_step(
[rank3]:         ^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/trainers/base.py", line 565, in train_step
[rank3]:     return self._origin_train_step(forward_step_func, new_data_iterator, model, optimizer, opt_param_scheduler,
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/training/training.py", line 1268, in train_step
[rank3]:     losses_reduced = forward_backward_func(
[rank3]:                      ^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 595, in forward_backward_no_pipelining
[rank3]:     output_tensor, num_tokens = forward_step(
[rank3]:                                 ^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 402, in forward_step
[rank3]:     output_tensor, loss_func = forward_step_func(data_iterator, model)
[rank3]:                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/trainers/trainer.py", line 150, in forward_step
[rank3]:     output_tensor = model(**data)
[rank3]:                     ^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/distributed/data_parallel_base.py", line 22, in forward
[rank3]:     return self.module(*inputs, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/module.py", line 429, in forward
[rank3]:     outputs = self.module(*inputs, **kwargs)
[rank3]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/model/gpt_model.py", line 299, in forward
[rank3]:     hidden_states = self.decoder(
[rank3]:                     ^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_block.py", line 553, in __call__
[rank3]:     return super().__call__(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/module.py", line 305, in __call__
[rank3]:     return super().__call__(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_block.py", line 669, in forward
[rank3]:     hidden_states = self._checkpointed_forward(
[rank3]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_block.py", line 472, in _checkpointed_forward
[rank3]:     hidden_states, context = checkpoint_handler(
[rank3]:                              ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_block.py", line 456, in checkpoint_handler
[rank3]:     return tensor_parallel.checkpoint(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/tensor_parallel/random.py", line 480, in checkpoint
[rank3]:     return CheckpointFunction.apply(function, distribute_saved_activations, *args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/autograd/function.py", line 581, in apply
[rank3]:     return super().apply(*args, **kwargs)  # type: ignore[misc]
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/tensor_parallel/random.py", line 426, in forward
[rank3]:     outputs = run_function(*args)
[rank3]:               ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_block.py", line 426, in custom_forward
[rank3]:     hidden_states, context = layer(
[rank3]:                              ^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 852, in __call__
[rank3]:     return super().__call__(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/module.py", line 305, in __call__
[rank3]:     return super().__call__(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/init.py", line 545, in forward
[rank3]:     hidden_states, context = self._forward_attention(*_args, **kwargs)
[rank3]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/transformer_layer.py", line 499, in _forward_attention
[rank3]:     attention_output_with_bias = self.self_attention(
[rank3]:                                  ^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/tmp/Megatron-LM/megatron/core/transformer/attention.py", line 728, in forward
[rank3]:     qkv_output = self.get_query_key_value_tensors(
[rank3]:                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/swift/megatron/model/gpt/minimax_m2.py", line 52, in get_query_key_value_tensors
[rank3]:     query = self.q_norm(query.reshape(*query_shape[:-2], -1)).view(query_shape)
[rank3]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]:     return self._call_impl(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]:     return forward_call(*args, **kwargs)
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/transformer_engine/pytorch/ops/op.py", line 522, in forward
[rank3]:     return OperationFuser([self])(
[rank3]:            ^^^^^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/transformer_engine/pytorch/ops/fuser.py", line 493, in __call__
[rank3]:     return forward_func(*args)
[rank3]:            ^^^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/transformer_engine/pytorch/ops/fuser.py", line 138, in forward
[rank3]:     x, fused_op_extra_outputs = op.fuser_forward(
[rank3]:                                 ^^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/transformer_engine/pytorch/ops/op.py", line 483, in fuser_forward
[rank3]:     output = self.op_forward(
[rank3]:              ^^^^^^^^^^^^^^^^
[rank3]:   File "/usr/local/lib/python3.11/site-packages/transformer_engine/pytorch/ops/basic/rmsnorm.py", line 174, in op_forward
[rank3]:     raise ValueError(
[rank3]: ValueError: Input tensor (shape=(344, 1, 1536)) and weight tensor (shape=(6144,)) are not compatible
WARNING:megatron.core.utils:Input tensor (shape=(344, 1, 1536)) and weight tensor (shape=(6144,)) are not compatible

Your hardware and system info
8 × A100 (80 GB)
Docker image: modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.8.1-py311-torch2.9.0-vllm0.13.0-modelscope1.33.0-swift3.12.1

Additional context
I suspect the problem is related to tensor_model_parallel_size.
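
The shapes in the error are consistent with that guess: the q_norm weight has width 6144, the query tensor that reaches it has last dim 1536, and 6144 / 1536 is exactly 4, the tensor_model_parallel_size. A minimal sketch of the arithmetic (which tensor is sharded and which is not is my reading of the traceback, not something confirmed in the swift source):

# Both numbers are copied from the ValueError above; the interpretation
# (q_norm weight built for the full query width, query tensor already TP-sharded)
# is a guess, not confirmed in swift/megatron/model/gpt/minimax_m2.py.
q_norm_weight_width = 6144        # weight tensor shape=(6144,)
per_rank_query_width = 1536       # input tensor shape=(344, 1, 1536)
tensor_model_parallel_size = 4    # --tensor_model_parallel_size 4

assert q_norm_weight_width == per_rank_query_width * tensor_model_parallel_size
# 1536 * 4 == 6144: the mismatch factor is exactly the TP size, which points at
# q_norm (and presumably k_norm) being created/applied with the unsharded query
# width while each TP rank's QKV projection output only carries 1/TP of the channels.

If that reading is right, the same command with --tensor_model_parallel_size 1, or with q_norm built/reshaped per TP rank, should not hit the mismatch.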
