[trainer] feat: Add Nemo-Automodel as alternative training engine #5407
Status: Open

HuiyingLi wants to merge 19 commits into verl-project:main from HuiyingLi:add_automodel_sft_backend.
Commits (19):
- 07554c9 init version with fsdp2
- 11eb1a9 add mp policy config
- 4641e75 add ep and expose more configs
- 697bf68 fix(dataset): call .tolist() before tokenizer.decode() for tiktoken c…
- 41dd4a8 add test
- c33321b format
- 9a14478 revert some format changes
- 4d7a193 fix eval ctx
- 6b3f061 fix exp name
- 3208fbd add expert torch_mm backend to config
- a0b51f8 change copyright
- d2eec66 Merge branch 'main' into add_automodel_sft_backend
- ec3b283 upgrade to automodel r0.3.0 with transformers v5.0.0
- c1e8025 add automodel examples scripts
- 6060737 add docs
- 20cd9dc update optimizer integration
- 1b9c6aa update example scripts
- db0d6ca add dependency req to examples
- 48b7315 format

All commits authored by HuiyingLi.
File 1 (new, +65 lines): Automodel backend documentation (reStructuredText):
Automodel Backend
=================

Last updated: 03/07/2026.

We support the Automodel (``nemo_automodel``) backend by implementing the
``AutomodelEngine`` and ``AutomodelEngineWithLMHead`` engine classes.
The Automodel backend delegates model building, parallelization, optimizer
sharding, LR scheduling, gradient clipping, and checkpointing to
``nemo_automodel``'s infrastructure, while using verl's training loop,
data pipeline, and loss function.

**Requirements**

- Automodel r0.3.0
- transformers v5.0.0

**Pros**

- Supports the FSDP2 and TP distributed strategies out of the box.
- Native support for Mixture-of-Experts (MoE) models with Expert
  Parallelism (EP) via DeepEP.
- TransformerEngine (TE) integration for optimized attention, linear
  layers, and RMSNorm.
- Readily supports any HuggingFace model without checkpoint conversion.

**Cons**

- Pipeline parallelism is not yet supported.

SFT Examples
------------

We provide example SFT training scripts using the Automodel backend in
`examples/sft/gsm8k/ <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k/>`_.

Basic: Qwen2.5-0.5B with FSDP2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A minimal example using ``Qwen/Qwen2.5-0.5B-Instruct`` with FSDP2 and
no model parallelism:

.. code:: shell

    bash examples/sft/gsm8k/run_qwen_05_automodel.sh 4 /tmp/automodel_sft_test

See `run_qwen_05_automodel.sh <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k/run_qwen_05_automodel.sh>`_.

Advanced: Qwen3-30B MoE with Expert Parallelism
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A larger-scale example using ``Qwen/Qwen3-30B-A3B-Base`` (an MoE model)
with Expert Parallelism (EP=8), DeepEP, the TransformerEngine backend, and
the ``torch_mm`` experts backend:

.. code:: shell

    bash examples/sft/gsm8k/run_qwen3_30b_automodel.sh 8 /tmp/automodel_sft_30b

See `run_qwen3_30b_automodel.sh <https://github.com/volcengine/verl/blob/main/examples/sft/gsm8k/run_qwen3_30b_automodel.sh>`_.
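The parallelism sizes passed to these scripts must tile the available GPUs. As a reading aid (not part of the PR), a minimal sketch of the divisibility rule for the 30B example above, assuming EP reuses the data-parallel ranks rather than adding a separate mesh axis:

```shell
# Hypothetical sanity check (not part of the PR) for the mesh sizes used by
# the 30B example: world_size=8, tp=pp=cp=1, ep=8. Assumes EP is folded into
# the data-parallel dimension, so ep does not enter the divisibility product.
world_size=8; tp=1; pp=1; cp=1; ep=8
if [ $(( world_size % (tp * tp * pp * cp / tp) )) -ne 0 ]; then
    echo "world_size must be divisible by tp*pp*cp" >&2
    exit 1
fi
dp=$(( world_size / (tp * pp * cp) ))
echo "data-parallel size: $dp"        # prints: data-parallel size: 8
```

Under this assumption, EP=8 then requires the data-parallel size (8) to be a multiple of ``ep``, which holds here.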
File 2 (new, +75 lines): ``examples/sft/gsm8k/run_qwen3_30b_automodel.sh``:
# Requires: Automodel, transformers>=5.3.0, torchao
# MoE also requires: grouped_gemm (github.com/fanshiqing/grouped_gemm v1.1.4)

set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen3_30b_automodel.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
    -m verl.trainer.sft_trainer \
    data.train_files=$HOME/data/hellaswag_sft/hellaswag_sft.parquet \
    data.val_files=$HOME/data/hellaswag_sft/hellaswag_sft.parquet \
    data.train_batch_size=512 \
    data.max_length=2048 \
    data.truncation=left \
    data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=8192 \
    data.messages_key=messages \
    data.ignore_input_ids_mismatch=True \
    data.train_max_samples=-1 \
    data.val_max_samples=1024 \
    model=hf_model \
    model.path=Qwen/Qwen3-30B-A3B-Base \
    model.trust_remote_code=True \
    model.use_remove_padding=True \
    engine=automodel \
    engine.distributed_strategy=fsdp2 \
    engine.tp_size=1 \
    engine.pp_size=1 \
    engine.cp_size=1 \
    engine.ep_size=8 \
    engine.backend_config.dispatcher=deepep \
    engine.backend_config.attn=te \
    engine.backend_config.linear=te \
    engine.backend_config.rms_norm=torch_fp32 \
    engine.backend_config.enable_fsdp_optimizations=True \
    engine.backend_config.experts=torch_mm \
    engine.activation_checkpointing=True \
    engine.model_dtype=bf16 \
    engine.attn_implementation=te \
    engine.use_torch_compile=False \
    optim=automodel \
    optim.optimizer=FusedAdam \
    optim.optimizer_impl=transformer_engine.pytorch.optimizers.fused_adam \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.1 \
    optim.weight_decay=0 \
    optim.betas='[0.9,0.95]' \
    optim.clip_grad=1.0 \
    optim.init_lr_ratio=0.1 \
    optim.min_lr_ratio=0.01 \
    optim.lr_scheduler_type=cosine \
    optim.master_weights=true \
    optim.store_param_remainders=true \
    optim.exp_avg_dtype=bf16 \
    optim.exp_avg_sq_dtype=bf16 \
    trainer.default_local_dir=$save_path \
    trainer.project_name=hellaswag-sft \
    trainer.experiment_name=hellaswag-sft-qwen3-30b-automodel \
    trainer.total_epochs=2 \
    trainer.total_training_steps=100 \
    trainer.save_freq=-1 \
    trainer.test_freq=10 \
    trainer.logger=console \
    trainer.seed=1111 \
    trainer.nnodes=1 \
    trainer.resume_mode=disable "$@"
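The script consumes its first two positional arguments and forwards the rest to the trainer as overrides, which is why the usage line allows ``[other_configs...]``. A self-contained sketch of that shift-and-forward pattern (the ``demo`` function is illustrative, not part of the PR):

```shell
# Demonstrates the argument-forwarding pattern used by the launch scripts:
# the first two positionals are consumed, the remainder passes through.
demo() {
    nproc_per_node=$1
    save_path=$2
    shift 2                       # now "$@" holds only the extra overrides
    echo "nproc=$nproc_per_node save=$save_path extras=$*"
}
demo 8 /tmp/automodel_sft_30b optim.lr=2e-5 trainer.total_training_steps=50
# prints: nproc=8 save=/tmp/automodel_sft_30b extras=optim.lr=2e-5 trainer.total_training_steps=50
```

So, for example, appending ``optim.lr=2e-5`` after the two required arguments overrides the learning rate without editing the script.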
File 3 (new, +55 lines): ``examples/sft/gsm8k/run_qwen_05_automodel.sh``:
# Requires: Automodel, transformers>=5.3.0, torchao
# MoE also requires: grouped_gemm (github.com/fanshiqing/grouped_gemm v1.1.4)

set -x

if [ "$#" -lt 2 ]; then
    echo "Usage: run_qwen_05_automodel.sh <nproc_per_node> <save_path> [other_configs...]"
    exit 1
fi

nproc_per_node=$1
save_path=$2

# Shift the arguments so $@ refers to the rest
shift 2

torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \
    -m verl.trainer.sft_trainer \
    data.train_files=$HOME/data/gsm8k_sft/train.parquet \
    data.val_files=$HOME/data/gsm8k_sft/test.parquet \
    data.train_batch_size=128 \
    data.pad_mode=no_padding \
    data.truncation=error \
    data.use_dynamic_bsz=True \
    data.max_token_len_per_gpu=2048 \
    data.messages_key=messages \
    data.ignore_input_ids_mismatch=True \
    model=hf_model \
    model.path=Qwen/Qwen2.5-0.5B-Instruct \
    model.use_remove_padding=True \
    engine=automodel \
    engine.distributed_strategy=fsdp2 \
    engine.tp_size=1 \
    engine.pp_size=1 \
    engine.cp_size=1 \
    engine.ep_size=1 \
    engine.use_torch_compile=False \
    optim=automodel \
    optim.lr=1e-5 \
    optim.lr_warmup_steps_ratio=0.2 \
    optim.weight_decay=0.1 \
    optim.betas='[0.9,0.95]' \
    optim.clip_grad=1.0 \
    optim.init_lr_ratio=0 \
    optim.min_lr_ratio=0.1 \
    optim.lr_scheduler_type=cosine \
    trainer.default_local_dir=$save_path \
    trainer.project_name=gsm8k-sft \
    trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-automodel \
    trainer.total_epochs=2 \
    trainer.test_freq=-1 \
    trainer.save_freq=-1 \
    trainer.logger=console \
    trainer.seed=1111 \
    trainer.resume_mode=disable "$@"
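Note that the warmup setting is a ratio, not a step count. Assuming ``lr_warmup_steps_ratio`` denotes the fraction of total optimizer steps spent warming up (an inference from the name, not confirmed by the PR), the ``0.2`` used here works out as:

```shell
# Hypothetical illustration: warmup span implied by lr_warmup_steps_ratio=0.2
# for a run of 100 optimizer steps. Ratio expressed in percent so the
# calculation stays in integer shell arithmetic.
total_steps=100
ratio_pct=20
warmup_steps=$(( total_steps * ratio_pct / 100 ))
echo "warmup steps: $warmup_steps"    # prints: warmup steps: 20
```
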
File 4 (new, +82 lines): Automodel engine default configuration (YAML):
# Target class for this configuration
_target_: verl.workers.config.AutomodelEngineConfig

# Backend strategy identifier
strategy: automodel

# Distributed training strategy: "fsdp2", "megatron_fsdp", or "ddp"
distributed_strategy: fsdp2

# Parallelism sizes
tp_size: 1
pp_size: 1
cp_size: 1
ep_size: 1
dp_replicate_size: 1
sequence_parallel: false
defer_fsdp_grad_sync: true

# Whether to offload model parameters to CPU
param_offload: false

# Whether to offload optimizer state to CPU
optimizer_offload: false

# Whether to enable activation checkpointing
activation_checkpointing: false

# Whether to enable FP8 training
enable_fp8: false

# Whether to enable torch.compile for the model
enable_compile: false

# Model data type for loading weights ("fp32", "bf16", "fp16")
model_dtype: fp32

# Attention implementation ("sdpa", "flash_attention_2", "eager", "te")
attn_implementation: flash_attention_2

# Backend settings
backend_config:
  attn: sdpa            # "te", "sdpa"
  linear: te            # "torch", "te"
  rms_norm: torch_fp32  # "torch", "torch_fp32", "te"
  rope_fusion: true
  dispatcher: torch     # "torch", "deepep"
  experts: gmm          # "gmm", "torch_mm", "torch", "te"
  gate_precision: null
  enable_hf_state_dict_adapter: true
  enable_fsdp_optimizations: false
  fake_balanced_gate: false
  fake_gate_noise: 0.0

# MoE settings (MoEParallelizerConfig)
moe_config:
  ignore_router_for_ac: false
  reshard_after_forward: false
  lm_head_precision: null
  wrap_outer_model: true

# Mixed precision policy (FSDP2 MixedPrecisionPolicy)
mp_param_dtype: bf16
mp_reduce_dtype: fp32
mp_output_dtype: bf16

# Random seed for reproducibility
seed: 42

# Whether to enable full determinism for distributed training; only for debugging
full_determinism: false

# Whether to use forward-only mode
forward_only: false

# Whether to use torch.compile for entropy computation
use_torch_compile: false

# Whether to use chunked entropy computation
entropy_from_logits_with_chunking: false

# Whether to use checkpointing for entropy computation
entropy_checkpointing: false
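For comparison with these defaults, the MoE/TE settings that the 30B example script overrides on the command line would look roughly like the following YAML fragment (a hypothetical sketch assembled from the script's flags, not a file in the PR):

```yaml
# Hypothetical override fragment mirroring run_qwen3_30b_automodel.sh
distributed_strategy: fsdp2
ep_size: 8
model_dtype: bf16
attn_implementation: te
activation_checkpointing: true
backend_config:
  dispatcher: deepep        # DeepEP all-to-all for expert routing
  attn: te
  linear: te
  rms_norm: torch_fp32
  experts: torch_mm
  enable_fsdp_optimizations: true
```

Every key not listed keeps its default from the file above.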
Review comment: Maybe in another PR, we should refactor ``docs/start/install.rst`` to cover install methods for all model engines and rollout engines, and present the options more clearly.