diff --git a/docs/index.rst b/docs/index.rst index 3b27486e062..1d3bcf239ff 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -90,6 +90,7 @@ verl is fast with: workers/ray_trainer workers/fsdp_workers workers/megatron_workers + workers/automodel_workers workers/sglang_worker workers/trtllm_worker workers/model_engine diff --git a/docs/workers/automodel_workers.rst b/docs/workers/automodel_workers.rst new file mode 100644 index 00000000000..55864db4360 --- /dev/null +++ b/docs/workers/automodel_workers.rst @@ -0,0 +1,65 @@ +Automodel Backend +================= + +Last updated: 03/07/2026. + +We support the Automodel (nemo_automodel) backend by implementing the +``AutomodelEngine`` and ``AutomodelEngineWithLMHead`` engine classes. +The Automodel backend delegates model building, parallelization, optimizer +sharding, LR scheduling, gradient clipping, and checkpointing to +nemo_automodel's infrastructure while using verl's training loop, +data pipeline, and loss function. + +**Requirements** + +- Automodel r0.3.0 +- transformers v5.0.0 + +**Pros** + +- Supports FSDP2 and TP distributed strategies out of + the box. + +- Native support for Mixture-of-Experts (MoE) models with Expert + Parallelism (EP) via DeepEP. + +- TransformerEngine (TE) integration for optimized attention, linear + layers, and RMSNorm. + +- Readily supports any HuggingFace model without checkpoint conversion. + +**Cons** + +- Pipeline parallelism is not yet supported. + + +SFT Examples +------------ + +We provide example SFT training scripts using the Automodel backend in +`examples/sft/gsm8k/ `_. + +Basic: Qwen2.5-0.5B with FSDP2 +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A minimal example using ``Qwen/Qwen2.5-0.5B-Instruct`` with FSDP2 and +no parallelism: + +.. code:: shell + + bash examples/sft/gsm8k/run_qwen_05_automodel.sh 4 /tmp/automodel_sft_test + +See `run_qwen_05_automodel.sh `_. 
+ +Advanced: Qwen3-30B MoE with Expert Parallelism +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A larger-scale example using ``Qwen/Qwen3-30B-A3B-Base`` (MoE model) +with Expert Parallelism (EP=8), DeepEP, TransformerEngine backend, and +torch_mm experts backend: + +.. code:: shell + + bash examples/sft/gsm8k/run_qwen3_30b_automodel.sh 8 /tmp/automodel_sft_30b + +See `run_qwen3_30b_automodel.sh `_. diff --git a/examples/sft/gsm8k/run_qwen3_30b_automodel.sh b/examples/sft/gsm8k/run_qwen3_30b_automodel.sh new file mode 100644 index 00000000000..95d699d218a --- /dev/null +++ b/examples/sft/gsm8k/run_qwen3_30b_automodel.sh @@ -0,0 +1,75 @@ +# Requires: Automodel, transformers>=5.3.0, torchao +# MoE also requires: grouped_gemm (github.com/fanshiqing/grouped_gemm v1.1.4) + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen3_30b_automodel.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.sft_trainer \ + data.train_files=$HOME/data/hellaswag_sft/hellaswag_sft.parquet \ + data.val_files=$HOME/data/hellaswag_sft/hellaswag_sft.parquet \ + data.train_batch_size=512 \ + data.max_length=2048 \ + data.truncation=left \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=8192 \ + data.messages_key=messages \ + data.ignore_input_ids_mismatch=True \ + data.train_max_samples=-1 \ + data.val_max_samples=1024 \ + model=hf_model \ + model.path=Qwen/Qwen3-30B-A3B-Base \ + model.trust_remote_code=True \ + model.use_remove_padding=True \ + engine=automodel \ + engine.distributed_strategy=fsdp2 \ + engine.tp_size=1 \ + engine.pp_size=1 \ + engine.cp_size=1 \ + engine.ep_size=8 \ + engine.backend_config.dispatcher=deepep \ + engine.backend_config.attn=te \ + engine.backend_config.linear=te \ + engine.backend_config.rms_norm=torch_fp32 \ + engine.backend_config.enable_fsdp_optimizations=True \ + 
engine.backend_config.experts=torch_mm \ + engine.activation_checkpointing=True \ + engine.model_dtype=bf16 \ + engine.attn_implementation=te \ + engine.use_torch_compile=False \ + optim=automodel \ + optim.optimizer=FusedAdam \ + optim.optimizer_impl=transformer_engine.pytorch.optimizers.fused_adam \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.1 \ + optim.weight_decay=0 \ + optim.betas='[0.9,0.95]' \ + optim.clip_grad=1.0 \ + optim.init_lr_ratio=0.1 \ + optim.min_lr_ratio=0.01 \ + optim.lr_scheduler_type=cosine \ + optim.master_weights=true \ + optim.store_param_remainders=true \ + optim.exp_avg_dtype=bf16 \ + optim.exp_avg_sq_dtype=bf16 \ + trainer.default_local_dir=$save_path \ + trainer.project_name=hellaswag-sft \ + trainer.experiment_name=hellaswag-sft-qwen3-30b-automodel \ + trainer.total_epochs=2 \ + trainer.total_training_steps=100 \ + trainer.save_freq=-1 \ + trainer.test_freq=10 \ + trainer.logger=console \ + trainer.seed=1111 \ + trainer.nnodes=1 \ + trainer.resume_mode=disable $@ diff --git a/examples/sft/gsm8k/run_qwen_05_automodel.sh b/examples/sft/gsm8k/run_qwen_05_automodel.sh new file mode 100644 index 00000000000..d3c7dd8b01c --- /dev/null +++ b/examples/sft/gsm8k/run_qwen_05_automodel.sh @@ -0,0 +1,55 @@ +# Requires: Automodel, transformers>=5.3.0, torchao +# MoE also requires: grouped_gemm (github.com/fanshiqing/grouped_gemm v1.1.4) + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_automodel.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.sft_trainer \ + data.train_files=$HOME/data/gsm8k_sft/train.parquet \ + data.val_files=$HOME/data/gsm8k_sft/test.parquet \ + data.train_batch_size=128 \ + data.pad_mode=no_padding \ + data.truncation=error \ + data.use_dynamic_bsz=True \ + data.max_token_len_per_gpu=2048 \ + data.messages_key=messages \ + 
data.ignore_input_ids_mismatch=True \ + model=hf_model \ + model.path=Qwen/Qwen2.5-0.5B-Instruct \ + model.use_remove_padding=True \ + engine=automodel \ + engine.distributed_strategy=fsdp2 \ + engine.tp_size=1 \ + engine.pp_size=1 \ + engine.cp_size=1 \ + engine.ep_size=1 \ + engine.use_torch_compile=False \ + optim=automodel \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas='[0.9,0.95]' \ + optim.clip_grad=1.0 \ + optim.init_lr_ratio=0 \ + optim.min_lr_ratio=0.1 \ + optim.lr_scheduler_type=cosine \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-automodel \ + trainer.total_epochs=2 \ + trainer.test_freq=-1 \ + trainer.save_freq=-1 \ + trainer.logger=console \ + trainer.seed=1111 \ + trainer.resume_mode=disable $@ diff --git a/tests/special_e2e/sft/run_sft_engine.sh b/tests/special_e2e/sft/run_sft_engine.sh index 9fe80afae13..e7350ee99cf 100644 --- a/tests/special_e2e/sft/run_sft_engine.sh +++ b/tests/special_e2e/sft/run_sft_engine.sh @@ -112,6 +112,22 @@ TORCHTITAN_ENGINE_CONFIG="\ engine.data_parallel_shard_size=${FSDP_SIZE} \ engine.use_torch_compile=False" +AUTOMODEL_ENGINE_CONFIG="\ + engine=${backend} \ + model=hf_model \ + model.path=${MODEL_PATH} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_ratio=0.1 \ + optim.lr_scheduler_type=cosine \ + engine.tp_size=${TP_SIZE} \ + engine.cp_size=${CP_SIZE} \ + engine.use_torch_compile=False" + if [ "$backend" = "fsdp" ]; then ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" @@ -125,6 +141,10 @@ elif [ "$backend" = "torchtitan" ]; then ENGINE_CONFIG="$TORCHTITAN_ENGINE_CONFIG" echo "Using torchtitan engine" exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +elif [ "$backend" = 
"automodel" ]; then + ENGINE_CONFIG="$AUTOMODEL_ENGINE_CONFIG" + echo "Using automodel engine" + exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} else ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" echo "Using megatron engine" diff --git a/tests/special_e2e/sft/test_sft_engine_all.sh b/tests/special_e2e/sft/test_sft_engine_all.sh index 21524ce1d09..5bf2927eb46 100644 --- a/tests/special_e2e/sft/test_sft_engine_all.sh +++ b/tests/special_e2e/sft/test_sft_engine_all.sh @@ -46,6 +46,14 @@ BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 m # echo "run with tp2 pp1 cp1 fsdp2 num_gpus4" # BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh +# # test with automodel dp=2 +# echo "run with automodel tp1 pp1 cp1 dp2 num_gpus2" +# BACKEND=automodel TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh + +# # test with automodel tp2 dp=2 +# echo "run with automodel tp2 pp1 cp1 dp2 num_gpus4" +# BACKEND=automodel TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh + python3 tests/special_e2e/sft/compare_sft_engine_results.py rm -rf ~/verl/test/log diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index 46461590e94..dda18d5278d 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -44,6 +44,7 @@ "verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name "verl/workers/engine/torchtitan/transformer_impl.py", # appear in default device_name "verl/workers/engine/torchtitan/utils.py", # appear in torch.cuda.empty_cache() + "verl/workers/engine/automodel/transformer_impl.py", # appear in default device_name "verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in 
config.cudagraph_capture_sizes "verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES "verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes diff --git a/verl/trainer/config/engine/automodel.yaml b/verl/trainer/config/engine/automodel.yaml new file mode 100644 index 00000000000..ea731aec88c --- /dev/null +++ b/verl/trainer/config/engine/automodel.yaml @@ -0,0 +1,82 @@ +# Target class for this configuration +_target_: verl.workers.config.AutomodelEngineConfig + +# Backend strategy identifier +strategy: automodel + +# Distributed training strategy: "fsdp2", "megatron_fsdp", or "ddp" +distributed_strategy: fsdp2 + +# Parallelism sizes +tp_size: 1 +pp_size: 1 +cp_size: 1 +ep_size: 1 +dp_replicate_size: 1 +sequence_parallel: false +defer_fsdp_grad_sync: true + +# Whether to offload model parameters to CPU +param_offload: false + +# Whether to offload optimizer state to CPU +optimizer_offload: false + +# Whether to enable activation checkpointing +activation_checkpointing: false + +# Whether to enable FP8 training +enable_fp8: false + +# Whether to enable torch.compile for the model +enable_compile: false + +# Model data type for loading weights ("fp32", "bf16", "fp16") +model_dtype: fp32 + +# Attention implementation ("sdpa", "flash_attention_2", "eager", "te") +attn_implementation: flash_attention_2 + +# Backend settings +backend_config: + attn: sdpa # "te", "sdpa" + linear: te # "torch", "te" + rms_norm: torch_fp32 # "torch", "torch_fp32", "te" + rope_fusion: true + dispatcher: torch # "torch", "deepep" + experts: gmm # "gmm", "torch_mm", "torch", "te" + gate_precision: null + enable_hf_state_dict_adapter: true + enable_fsdp_optimizations: false + fake_balanced_gate: false + fake_gate_noise: 0.0 + +# MoE settings (MoEParallelizerConfig) +moe_config: + ignore_router_for_ac: false + reshard_after_forward: false + lm_head_precision: null + wrap_outer_model: true + +# Mixed 
precision policy (FSDP2 MixedPrecisionPolicy) +mp_param_dtype: bf16 +mp_reduce_dtype: fp32 +mp_output_dtype: bf16 + +# Random seed for reproducibility +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging +full_determinism: false + +# Whether to use forward only mode +forward_only: false + +# Whether to use torch compile for entropy computation +use_torch_compile: false + +# Whether to use chunked entropy computation +entropy_from_logits_with_chunking: false + +# Whether to use checkpointing for entropy computation +entropy_checkpointing: false diff --git a/verl/trainer/config/optim/automodel.yaml b/verl/trainer/config/optim/automodel.yaml new file mode 100644 index 00000000000..9e06ffc6ce0 --- /dev/null +++ b/verl/trainer/config/optim/automodel.yaml @@ -0,0 +1,56 @@ +# Target class for this configuration +_target_: verl.workers.config.AutomodelOptimizerConfig + +optimizer: AdamW + +# Module path to import optimizer from +optimizer_impl: torch.optim + +# Learning rate (maps to max_lr in Automodel's OptimizerParamScheduler) +lr: 1e-5 + +# LR warmup steps ratio (used when lr_warmup_steps <= 0) +lr_warmup_steps_ratio: 0.0 + +# Total training steps (injected at runtime) +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps (set > 0 to override lr_warmup_steps_ratio) +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient norm +clip_grad: 1.0 + +# Initial LR ratio for warmup start (init_lr = lr * init_lr_ratio) +init_lr_ratio: 0.1 + +# Minimum LR ratio after decay (min_lr = lr * min_lr_ratio) +min_lr_ratio: 0.01 + +# LR scheduler type (Automodel OptimizerParamScheduler decay style) +# Options: "constant", "cosine", "linear", "inverse-square-root" +lr_scheduler_type: cosine + +# Weight decay increment style: "constant", "linear", or "cosine" +wd_incr_style: constant + +# Kept for backward compatibility (unused by Automodel scheduler) +num_cycles: 0.5 
+zero_indexed_step: true + +# Common optimizer kwargs +eps: 1e-8 +master_weights: false +store_param_remainders: false +exp_avg_dtype: null # "fp32", "bf16" +exp_avg_sq_dtype: null # "fp32", "bf16" +master_weight_dtype: null # "fp32", "bf16" + +# Additional optimizer kwargs (passed directly to constructor) +override_optimizer_config: {} diff --git a/verl/utils/dataset/multiturn_sft_dataset.py b/verl/utils/dataset/multiturn_sft_dataset.py index 081d1dcfafa..5e950842298 100644 --- a/verl/utils/dataset/multiturn_sft_dataset.py +++ b/verl/utils/dataset/multiturn_sft_dataset.py @@ -64,7 +64,8 @@ def print_assembled_message(tokenizer, message_list, input_ids, loss_mask, attn_ sep = "\n\n" str = f"tokenized entire message:\n{tokenized}" str += sep - str += f"tokenized seperately :\n{tokenizer.decode(input_ids)}" + decoded_ids = input_ids.tolist() if hasattr(input_ids, "tolist") else input_ids + str += f"tokenized seperately :\n{tokenizer.decode(decoded_ids)}" logger.debug(str) diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index 41fc8181c2a..b193c6f0669 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -29,6 +29,7 @@ "TrainingWorkerConfig", "TorchtitanEngineConfig", "VeOmniEngineConfig", + "AutomodelEngineConfig", "EngineConfig", "EngineRouterReplayConfig", "QATEngineConfig", @@ -396,6 +397,127 @@ def __post_init__(self): assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported" +@dataclass +class AutomodelEngineConfig(EngineConfig): + """Configuration for Automodel (nemo_automodel) backend. + + The Automodel backend uses NeMoAutoModelForCausalLM for model loading and + supports FSDP2, MegatronFSDP, and DDP distributed strategies with optional + TP, CP, and EP parallelism. + + Args: + strategy (str): Backend strategy identifier, must be "automodel". + distributed_strategy (str): Distributed training strategy: "fsdp2", "megatron_fsdp", or "ddp". + tp_size (int): Tensor parallel size. 
+ pp_size (int): Pipeline parallel size (only pp_size=1 supported initially). + cp_size (int): Context parallel size. + ep_size (int): Expert parallel size for MoE models. + dp_replicate_size (int): Data-parallel replicate size for HSDP. 1 = pure sharding. + sequence_parallel (bool): Enable sequence parallelism in the TP plan. + defer_fsdp_grad_sync (bool): Defer FSDP gradient sync to the final micro-batch. + activation_checkpointing (bool): Whether to enable activation checkpointing. + enable_fp8 (bool): Whether to enable FP8 training. + enable_compile (bool): Whether to enable torch.compile for the model. + model_dtype (str): Model data type for loading weights. "fp32" loads in float32 + (matching FSDP golden), "auto" uses the dtype from the model config. + attn_implementation (str): Attention implementation to use ("sdpa", "flash_attention_2", "eager", "te"). + + Backend settings (nemo_automodel BackendConfig): + backend_config (dict): Dict of kwargs passed directly to + nemo_automodel.components.models.common.BackendConfig(**backend_config). + Controls how model layers are implemented (TE vs PyTorch) and MoE dispatch. + See automodel.yaml for all predefined keys with defaults. + Key fields: + attn (str): Attention backend. "te" = TransformerEngine fused attention, + "sdpa" = PyTorch scaled dot-product attention. Default: "sdpa". + linear (str): Linear layer backend. "te" = TE fused linear (with FP8 support), + "torch" = standard PyTorch linear. Default: "te". + rms_norm (str): RMSNorm backend. "te" = TE fused RMSNorm, "torch" = PyTorch, + "torch_fp32" = PyTorch in FP32 (better numerical stability for MoE). + Default: "torch_fp32". + rope_fusion (bool): Enable fused RoPE kernel (requires CP=1). Default: true. + experts (str): MoE expert computation backend. + "gmm" = grouped_gemm (requires pip install grouped_gemm), + "torch_mm" = torch._grouped_mm (no external dependency), + "te" = TE GroupedLinear. Default: "gmm". 
+ dispatcher (str): MoE token dispatch strategy. + "torch" = standard all-gather + local compute, + "deepep" = DeepEP optimized all-to-all (higher throughput). + Default: "torch". + Note: "deepep" with experts="gmm" matches the legacy enable_deepep=True behavior. + enable_fsdp_optimizations (bool): Enable FSDP-specific optimizations in Automodel. + Default: false. + enable_hf_state_dict_adapter (bool): Enable HuggingFace state dict adapter for + checkpoint compatibility. Default: true. + fake_balanced_gate (bool): Use fake balanced gating for debugging. Default: false. + fake_gate_noise (float): Noise added to fake balanced gate. Default: 0.0. + gate_precision: Gate computation precision. Default: null (auto). + Full reference: nemo_automodel/components/models/common/backend_config.py + + MoE / Expert Parallelism settings: + moe_config (dict): Dict of kwargs passed directly to + nemo_automodel.components.moe.parallelizer.MoEParallelizerConfig(**moe_config). + Controls MoE parallelization behavior within FSDP2. + See automodel.yaml for all predefined keys with defaults. + Key fields: + ignore_router_for_ac (bool): Exclude router from activation checkpointing. + Default: false. + reshard_after_forward (bool): Reshard expert params after forward pass + (trades compute for memory). Default: false. + lm_head_precision: Precision for the LM head. Default: null (auto). + wrap_outer_model (bool): Whether to FSDP-wrap the outermost model module. + Default: true. + Full reference: nemo_automodel/components/moe/parallelizer.py + + Mixed precision policy (FSDP2): + mp_param_dtype (str): Parameter dtype for FSDP2 mixed precision policy. + mp_reduce_dtype (str): Reduce dtype for FSDP2 mixed precision policy. + mp_output_dtype (str): Output dtype for FSDP2 mixed precision policy. + + Entropy computation: + entropy_from_logits_with_chunking (bool): Whether to use chunked entropy computation. + use_torch_compile (bool): Whether to use torch.compile for entropy computation. 
+ entropy_checkpointing (bool): Whether to use checkpointing for entropy computation. + """ + + strategy: str = "automodel" + distributed_strategy: str = "fsdp2" + # Parallelism sizes + tp_size: int = 1 + pp_size: int = 1 + cp_size: int = 1 + ep_size: int = 1 + dp_replicate_size: int = 1 + sequence_parallel: bool = False + defer_fsdp_grad_sync: bool = True + # Model settings + activation_checkpointing: bool = False + enable_fp8: bool = False + enable_compile: bool = False + model_dtype: str = "fp32" + attn_implementation: str = "flash_attention_2" + # Backend settings + backend_config: dict = field(default_factory=dict) + # MoE settings + moe_config: dict = field(default_factory=dict) + # Mixed precision policy + mp_param_dtype: str = "bf16" + mp_reduce_dtype: str = "fp32" + mp_output_dtype: str = "bf16" + # Entropy computation + entropy_from_logits_with_chunking: bool = False + use_torch_compile: bool = True + entropy_checkpointing: bool = False + + def __post_init__(self): + super().__post_init__() + assert self.strategy == "automodel", f"strategy must be 'automodel', got {self.strategy}" + assert self.distributed_strategy in ["fsdp2", "megatron_fsdp", "ddp"], ( + f"distributed_strategy {self.distributed_strategy} not supported" + ) + assert self.pp_size == 1, "Pipeline parallelism (pp_size > 1) is not yet supported for automodel backend" + + @dataclass class TrainingWorkerConfig(BaseConfig): model_type: str = None # model type (language_model/value_model) diff --git a/verl/workers/config/optimizer.py b/verl/workers/config/optimizer.py index b7f05bef518..47afdd3bf2e 100644 --- a/verl/workers/config/optimizer.py +++ b/verl/workers/config/optimizer.py @@ -26,6 +26,7 @@ "build_optimizer", "VeOmniOptimizerConfig", "TorchtitanOptimizerConfig", + "AutomodelOptimizerConfig", ] @@ -170,6 +171,50 @@ class TorchtitanOptimizerConfig(OptimizerConfig): min_lr_factor: float = 0.0 +@dataclass +class AutomodelOptimizerConfig(OptimizerConfig): + """Automodel optimizer 
configuration extending base OptimizerConfig. + + Uses the same optimizer building mechanism as FSDP (dynamic import from optimizer_impl). + LR scheduling is handled by Automodel's OptimizerParamScheduler. + + Args: + optimizer (str): Optimizer class name (e.g., "AdamW"). + optimizer_impl (str): Module path to import optimizer from (e.g., "torch.optim"). + lr (float): Learning rate (maps to max_lr in OptimizerParamScheduler). + init_lr_ratio (Optional[float]): Initial LR ratio for warmup start (init_lr = lr * init_lr_ratio). + min_lr_ratio (Optional[float]): Minimum LR ratio after decay (min_lr = lr * min_lr_ratio). + lr_scheduler_type (str): LR decay style: "constant", "cosine", "linear", or "inverse-square-root". + wd_incr_style (str): Weight decay increment style: "constant", "linear", or "cosine". + num_cycles (float): Kept for backward compatibility (unused by Automodel scheduler). + zero_indexed_step (bool): Kept for backward compatibility (unused by Automodel scheduler). + """ + + _mutable_fields = OptimizerConfig._mutable_fields.copy() + _mutable_fields.add("lr_scheduler_type") + + optimizer: str = "AdamW" + optimizer_impl: str = "torch.optim" + init_lr_ratio: Optional[float] = 0.1 + min_lr_ratio: Optional[float] = 0.01 + lr_scheduler_type: str = "cosine" + wd_incr_style: str = "constant" + num_cycles: float = 0.5 + zero_indexed_step: bool = True + # Common optimizer kwargs + eps: float = 1e-8 + master_weights: bool = False + store_param_remainders: bool = False + exp_avg_dtype: Optional[str] = None # "fp32", "bf16", "fp16", or "torch.float32" etc. + exp_avg_sq_dtype: Optional[str] = None # "fp32", "bf16", "fp16", or "torch.float32" etc. + master_weight_dtype: Optional[str] = None # "fp32", "bf16", "fp16", or "torch.float32" etc. 
+ override_optimizer_config: Optional[dict] = None + + def __post_init__(self): + assert self.lr_scheduler_type in ["constant", "cosine", "linear", "inverse-square-root"] + return super().__post_init__() + + def build_optimizer(parameters, config: FSDPOptimizerConfig): """Build an optimizer based on the configuration. diff --git a/verl/workers/engine/__init__.py b/verl/workers/engine/__init__.py index 8f01080fdcb..009f0a8fc8b 100644 --- a/verl/workers/engine/__init__.py +++ b/verl/workers/engine/__init__.py @@ -37,6 +37,14 @@ VeOmniEngine = None VeOmniEngineWithLMHead = None +try: + from .automodel import AutomodelEngine, AutomodelEngineWithLMHead + + __all__ += ["AutomodelEngine", "AutomodelEngineWithLMHead"] +except ImportError: + AutomodelEngine = None + AutomodelEngineWithLMHead = None + # Mindspeed must be imported before Megatron to ensure the related monkey patches take effect as expected try: from .mindspeed import MindspeedEngineWithLMHead diff --git a/verl/workers/engine/automodel/__init__.py b/verl/workers/engine/automodel/__init__.py new file mode 100644 index 00000000000..a839342706b --- /dev/null +++ b/verl/workers/engine/automodel/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .transformer_impl import AutomodelEngine, AutomodelEngineWithLMHead + +__all__ = [ + "AutomodelEngine", + "AutomodelEngineWithLMHead", +] diff --git a/verl/workers/engine/automodel/transformer_impl.py b/verl/workers/engine/automodel/transformer_impl.py new file mode 100644 index 00000000000..fc71384a323 --- /dev/null +++ b/verl/workers/engine/automodel/transformer_impl.py @@ -0,0 +1,713 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Automodel (nemo_automodel) engine for verl SFT training. + +This engine delegates model building, parallelization, optimizer sharding, +LR scheduling, gradient clipping, and checkpointing to Automodel's +infrastructure while using verl's training loop, data pipeline, and loss function. 
+""" + +import gc +import logging +import os +from contextlib import nullcontext +from typing import Any, Callable, Optional + +import torch +import torch.distributed +from huggingface_hub.constants import HF_HUB_CACHE +from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig +from nemo_automodel.components.optim.scheduler import OptimizerParamScheduler +from nemo_automodel.components.training.utils import ( + prepare_for_final_backward, + prepare_for_grad_accumulation, + scale_grads_and_clip_grad_norm, +) +from tensordict import TensorDict +from torch.distributed.tensor import DTensor + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import CheckpointConfig +from verl.utils import tensordict_utils as tu +from verl.utils.dataset.dataset_utils import DatasetPadMode +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_id, get_device_name +from verl.utils.model import convert_weight_keys, extract_multi_modal_inputs +from verl.utils.torch_functional import logprobs_from_logits +from verl.workers.config import AutomodelEngineConfig, AutomodelOptimizerConfig, HFModelConfig + +from ..base import BaseEngine, BaseEngineCtx, EngineRegistry +from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches +from .utils import ( + build_automodel_model, + build_distributed_config_from_engine_config, + get_dp_group_size, + get_dp_rank, + get_pp_rank, + get_tp_rank, + load_automodel_model_to_gpu, + load_automodel_optimizer, + maybe_fully_shard_optimizer, + offload_automodel_model_to_cpu, + offload_automodel_optimizer, +) + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class AutomodelEngine(BaseEngine): + """Engine implementation using Automodel for distributed training.""" + + def __init__( + self, + model_config: HFModelConfig, + engine_config: AutomodelEngineConfig, + optimizer_config: 
AutomodelOptimizerConfig, + checkpoint_config: CheckpointConfig, + **kwargs, + ): + super().__init__() + + self.model_config = model_config + self.engine_config = engine_config + self.optimizer_config = optimizer_config + self.checkpoint_config = checkpoint_config + + self.mode = None + self.rank = torch.distributed.get_rank() + + # Apply compatibility patches early in the process + from nemo_automodel._transformers.utils import apply_cache_compatibility_patches + from nemo_automodel.shared.te_patches import apply_te_patches + + apply_cache_compatibility_patches() + apply_te_patches() + + world_size = torch.distributed.get_world_size() + self.distributed_config, self.device_mesh, self.moe_mesh = build_distributed_config_from_engine_config( + self.engine_config, world_size + ) + + if self.engine_config.full_determinism: + enable_full_determinism(seed=self.engine_config.seed) + + self._is_offload_param = self.engine_config.param_offload + self._is_offload_optimizer = self.engine_config.optimizer_offload + + if self.engine_config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.engine_config.use_torch_compile + else entropy_from_logits + ) + + @property + def is_param_offload_enabled(self) -> bool: + return self._is_offload_param + + @property + def is_optimizer_offload_enabled(self) -> bool: + return self._is_offload_optimizer + + def initialize(self): + """Build the model, optimizer, LR scheduler, and checkpointer using Automodel infrastructure.""" + self.module = build_automodel_model( + self.model_config, self.engine_config, self.distributed_config, self.device_mesh, self.moe_mesh + ) + log_gpu_memory_usage("After Automodel model build", logger=logger) + + if not self.engine_config.forward_only: + self.optimizer = self._build_optimizer(self.module) + # 
# NOTE(review): this span was a whitespace-mangled unified diff; newlines and
# indentation have been reconstructed. Only comments/docstrings/annotations added.
            # maybe shard optimizer for MegatronFSDP
            maybe_fully_shard_optimizer(self.module, self.optimizer, self.distributed_config)
            self.lr_scheduler = self._build_lr_scheduler(self.optimizer)
        else:
            # forward-only engines (e.g. reference/rollout) carry no optimizer state
            self.optimizer = None
            self.lr_scheduler = None
        self._build_checkpointer()

        # Immediately offload whatever the config asks for; params/grads share the
        # `_is_offload_param` flag, optimizer state has its own flag.
        self.to(
            device="cpu",
            model=self._is_offload_param,
            optimizer=self._is_offload_optimizer,
            grad=self._is_offload_param,
        )

        log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger)
        torch.cuda.empty_cache()

    def _build_optimizer(self, module):
        """Build the optimizer via Automodel's ``build_optimizer``.

        Translates verl's ``optimizer_config`` into the ConfigNode dict that
        nemo_automodel expects (``_target_`` plus keyword arguments), so
        optimizer sharding decisions stay inside Automodel.
        """
        from nemo_automodel.components.config.loader import ConfigNode
        from nemo_automodel.recipes.llm.train_ft import build_optimizer as automodel_build_optimizer

        config = self.optimizer_config

        # ``_target_`` is the dotted path Automodel instantiates, e.g.
        # "torch.optim.AdamW" from (optimizer_impl="torch.optim", optimizer="AdamW").
        opt_dict = {
            "_target_": f"{config.optimizer_impl}.{config.optimizer}",
            "lr": config.lr,
            "weight_decay": config.weight_decay,
            "eps": config.eps,
            "betas": list(config.betas),
        }

        # Optional (e.g. TE/Apex fused optimizer) flags — only forwarded when set.
        if config.master_weights:
            opt_dict["master_weights"] = config.master_weights
        if config.store_param_remainders:
            opt_dict["store_param_remainders"] = config.store_param_remainders

        # Map verl's short dtype names to the dotted names Automodel resolves;
        # unknown values pass through unchanged.
        _short_to_torch = {"bf16": "torch.bfloat16", "fp32": "torch.float32", "fp16": "torch.float16"}
        for attr in ("exp_avg_dtype", "exp_avg_sq_dtype", "master_weight_dtype"):
            val = getattr(config, attr, None)
            if val is not None:
                opt_dict[attr] = _short_to_torch.get(val, val)

        # User-supplied overrides win over everything derived above.
        if config.override_optimizer_config:
            opt_dict.update(config.override_optimizer_config)

        cfg_opt = ConfigNode(opt_dict)
        optimizers = automodel_build_optimizer(module, cfg_opt, self.distributed_config, self.device_mesh)
        # Automodel may return one optimizer per model part; without pipeline
        # parallelism exactly one is expected here.
        assert len(optimizers) == 1, f"Expected 1 optimizer, got {len(optimizers)}"
        return optimizers[0]

    def _build_lr_scheduler(self, optimizer):
        """Build Automodel's ``OptimizerParamScheduler`` (warmup + decay, constant WD)."""
        cfg = self.optimizer_config
        total_steps = cfg.total_training_steps
        num_warmup_steps = cfg.lr_warmup_steps

        # An explicit step count takes precedence; otherwise derive from ratio.
        if num_warmup_steps <= 0:
            num_warmup_steps = int(cfg.lr_warmup_steps_ratio * total_steps)

        base_lr = cfg.lr
        init_lr_ratio = cfg.init_lr_ratio if cfg.init_lr_ratio is not None else 0.1
        min_lr_ratio = cfg.min_lr_ratio if cfg.min_lr_ratio is not None else 0.01

        if self.rank == 0:
            print(
                f"Automodel LR Scheduler: total_steps={total_steps}, warmup={num_warmup_steps}, "
                f"decay_style={cfg.lr_scheduler_type}, init_lr={base_lr * init_lr_ratio:.2e}, "
                f"max_lr={base_lr:.2e}, min_lr={base_lr * min_lr_ratio:.2e}"
            )

        # Weight decay is held constant (start_wd == end_wd).
        scheduler = OptimizerParamScheduler(
            optimizer=optimizer,
            init_lr=base_lr * init_lr_ratio,
            max_lr=base_lr,
            min_lr=base_lr * min_lr_ratio,
            lr_warmup_steps=num_warmup_steps,
            lr_decay_steps=total_steps,
            lr_decay_style=cfg.lr_scheduler_type,
            start_wd=cfg.weight_decay,
            end_wd=cfg.weight_decay,
            wd_incr_steps=total_steps,
            wd_incr_style=getattr(cfg, "wd_incr_style", "constant"),
        )
        return scheduler

    def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only: bool = False) -> Any:
        """Split ``data`` into micro-batches and run forward (and optionally backward) on each.

        Gradient accumulation happens implicitly via repeated ``loss.backward()``
        calls; the optimizer step is NOT taken here (see ``optimizer_step``).
        """
        # Token count is all-reduced over the DP group so per-token loss
        # normalization is consistent across ranks.
        batch_num_tokens = data["loss_mask"].sum().to(get_device_id())
        torch.distributed.all_reduce(
            batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=self.get_data_parallel_group()
        )
        tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item())
        tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size())

        # same_micro_num_in_dp=True keeps every DP rank at the same micro-batch
        # count so collectives inside the model cannot desynchronize.
        micro_batches, indices = prepare_micro_batches(
            data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True
        )

        output_lst = []
        ctx = torch.no_grad() if forward_only else nullcontext()

        if not forward_only:
            prepare_for_grad_accumulation([self.module])

        # Set MoE aux loss backward scale to counteract FSDP's gradient allreduce.
        if self.engine_config.ep_size > 1:
            from nemo_automodel.components.moe.megatron.moe_utils import MoEAuxLossAutoScaler

            MoEAuxLossAutoScaler.main_loss_backward_scale = torch.tensor(
                float(get_dp_group_size(self.device_mesh, include_cp=True))
            )

        num_micro_batches = len(micro_batches)
        for i, micro_batch in enumerate(micro_batches):
            # Signal final backward for MoE
            if not forward_only and i == num_micro_batches - 1:
                prepare_for_final_backward([self.module])

            with ctx:
                loss, meta_info = self.forward_step(micro_batch, loss_function=loss_function, forward_only=forward_only)
                if not forward_only:
                    loss.backward()
                output_lst.append(meta_info)

        # Re-assemble per-micro-batch outputs back into the original batch order.
        return postprocess_batch_func(output_lst=output_lst, indices=indices, data=data)

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
        """Abstract hook: run one micro-batch forward pass and compute the loss."""
        raise NotImplementedError("forward_step must be implemented in subclass")

    def optimizer_zero_grad(self):
        """Clear accumulated gradients."""
        self.optimizer.zero_grad()

    def optimizer_step(self):
        """Clip gradients, take an optimizer step, and return the grad norm (float).

        Non-finite grad norms skip the update and zero the gradients instead.
        """
        # Automodel's helper both scales grads (for grad accumulation/DP) and
        # clips by global norm; EP-sharded params are handled via moe_mesh.
        grad_norm = scale_grads_and_clip_grad_norm(
            max_grad_norm=self.optimizer_config.clip_grad,
            model_parts=[self.module],
            norm_type=2.0,
            pp_enabled=False,
            device_mesh=self.device_mesh,
            moe_mesh=self.moe_mesh,
            ep_axis_name="ep" if self.moe_mesh is not None and "ep" in self.moe_mesh.mesh_dim_names else None,
            pp_axis_name=None,
            foreach=True,
            num_label_tokens=None,
            dp_group_size=get_dp_group_size(self.device_mesh, include_cp=True),
        )

        if isinstance(grad_norm, torch.Tensor):
            grad_norm_val = grad_norm.item()
        else:
            grad_norm_val = float(grad_norm)

        # If grad_norm is not finite, skip the update
        if not torch.isfinite(torch.tensor(grad_norm_val)):
            print(f"WARN: grad_norm is not finite: {grad_norm_val}")
            self.optimizer.zero_grad()
        else:
            self.optimizer.step()
            # DeepSeek-style aux-loss-free MoE load balancing updates gate bias
            # after each optimizer step, when the model exposes the hook.
            if hasattr(self.module, "update_moe_gate_bias"):
                self.module.update_moe_gate_bias()

        return grad_norm_val

    def lr_scheduler_step(self):
        """Step Automodel's OptimizerParamScheduler and return current LR."""
        self.lr_scheduler.step(increment=1)
        lr = self.optimizer.param_groups[0]["lr"]
        return lr

    def get_data_parallel_rank(self):
        """This rank's index along the DP mesh dim (global rank if no mesh)."""
        if self.device_mesh is not None:
            return self.device_mesh.get_local_rank("dp")
        return torch.distributed.get_rank()

    def get_data_parallel_size(self):
        """Size of the DP mesh dim (world size if no mesh)."""
        if self.device_mesh is not None:
            return self.device_mesh["dp"].size()
        return torch.distributed.get_world_size()

    def get_data_parallel_group(self):
        """Process group for DP collectives (WORLD if no mesh)."""
        if self.device_mesh is not None:
            return self.device_mesh.get_group(mesh_dim="dp")
        return torch.distributed.group.WORLD

    def is_mp_src_rank_with_outputs(self):
        """True when this rank owns model outputs — i.e. TP rank 0, or TP disabled."""
        if self.device_mesh is not None and "tp" in self.device_mesh.mesh_dim_names:
            if self.device_mesh["tp"].size() > 1:
                return self.device_mesh.get_local_rank("tp") == 0
        return True

    def train_mode(self, **kwargs):
        """Context manager putting the module in train mode (zeroes grads on exit)."""
        return AutomodelTrainModeCtx(self, **kwargs)

    def eval_mode(self, **kwargs):
        """Context manager putting the module in eval mode (reshards FSDP on exit)."""
        return AutomodelEvalModeCtx(self, **kwargs)

    def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True):
        """Move model and/or optimizer state between the accelerator and CPU.

        ``device`` must be either the accelerator name (``get_device_name()``)
        or ``"cpu"``; anything else raises ``ValueError``.
        """
        super().to(device=device, model=model, optimizer=optimizer, grad=grad)

        # Forward-only engines have no optimizer/grad state to shuttle.
        if self.engine_config.forward_only:
            return

        device_name = get_device_name()
        assert device in (device_name, "cpu")

        if device == device_name:
            if model:
                load_automodel_model_to_gpu(self.module)
            if optimizer and self.optimizer is not None:
                load_automodel_optimizer(self.optimizer, get_device_id())
            gc.collect()
        elif device == "cpu":
            if model:
                offload_automodel_model_to_cpu(self.module)
            if optimizer and self.optimizer is not None:
                offload_automodel_optimizer(self.optimizer)
        else:
            raise ValueError(f"Invalid device type: {device}")

    def _build_checkpointer(self):
        """Create Automodel's ``Checkpointer`` (consolidated safetensors output)."""
        ckpt_config = CheckpointingConfig(
            enabled=True,
            checkpoint_dir="checkpoints/",
            model_save_format="safetensors",
            model_cache_dir=HF_HUB_CACHE,
            model_repo_id=self.model_config.path,
            save_consolidated=True,
            is_peft=False,
        )
        self.checkpointer = Checkpointer(
            config=ckpt_config,
            dp_rank=get_dp_rank(self.device_mesh, include_cp=True),
            tp_rank=get_tp_rank(self.device_mesh),
            pp_rank=get_pp_rank(self.device_mesh),
            moe_mesh=self.moe_mesh,
        )

    def save_checkpoint(
        self,
        local_path: str,
        hdfs_path: Optional[str] = None,
        global_step: int = 0,
        max_ckpt_to_keep: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Save model, optimizer, and LR scheduler using Automodel's Checkpointer."""
        # Checkpointer needs params on the accelerator; load if offloaded
        # (or if some earlier path left the module on CPU).
        origin_module_device = next(self.module.parameters()).device.type
        if self._is_offload_param or origin_module_device == "cpu":
            load_automodel_model_to_gpu(self.module)

        # Save model weights
        self.checkpointer.save_model(self.module, local_path)

        # Save optimizer and LR scheduler state
        if self.optimizer is not None:
            scheduler_list = [self.lr_scheduler] if self.lr_scheduler is not None else None
            self.checkpointer.save_optimizer(self.optimizer, self.module, local_path, scheduler=scheduler_list)

        # All ranks must finish writing before anyone proceeds (and re-offloads).
        torch.distributed.barrier()
        if self._is_offload_param:
            offload_automodel_model_to_cpu(self.module)

    def load_checkpoint(
        self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: bool = True, **kwargs
    ) -> None:
        """Load model, optimizer, and LR scheduler using Automodel's Checkpointer."""
        if self._is_offload_param:
            load_automodel_model_to_gpu(self.module)

        # Accept both layouts: <local_path>/model/... and weights directly in <local_path>.
        model_path = os.path.join(local_path, "model")
        if not os.path.isdir(model_path):
            model_path = local_path
        self.checkpointer.load_model(self.module, model_path)

        if self.optimizer is not None:
            scheduler_list = [self.lr_scheduler] if self.lr_scheduler is not None else None
            self.checkpointer.load_optimizer(self.optimizer, self.module, local_path, scheduler=scheduler_list)

        torch.distributed.barrier()
        if self._is_offload_param:
            offload_automodel_model_to_cpu(self.module)

        if self._is_offload_optimizer and self.optimizer is not None:
            offload_automodel_optimizer(self.optimizer)

    def get_per_tensor_param(self, **kwargs):
        """Yield ``(name, full_tensor)`` pairs for every parameter (DTensors gathered).

        Returns a ``(generator, None)`` tuple. NOTE(review): the generator is
        consumed after the module may have been offloaded back to CPU — callers
        presumably materialize from the captured ``params`` dict; verify usage.
        """
        load_automodel_model_to_gpu(self.module)

        params = self.module.state_dict()
        # Strip wrapper prefixes so keys match the plain HF module's naming.
        params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))

        if self._is_offload_param:
            offload_automodel_model_to_cpu(self.module)

        def param_generator():
            for name, param in params.items():
                # DTensor shards are gathered into a regular full tensor.
                unsharded_tensor = param.full_tensor() if isinstance(param, DTensor) else param
                yield name, unsharded_tensor

        return param_generator(), None


class AutomodelEvalModeCtx(BaseEngineCtx):
    """Eval-mode context: ``module.eval()`` on enter, FSDP reshard on exit."""

    def __init__(self, engine: AutomodelEngine, **kwargs):
        super().__init__(engine=engine, mode="eval", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, AutomodelEngine)
        super().__enter__()
        self.engine.module.eval()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, AutomodelEngine)
        # Reshard the root FSDP module
        if hasattr(self.engine.module, "reshard"):
            self.engine.module.reshard()
        super().__exit__(exc_type, exc_value, traceback)


class AutomodelTrainModeCtx(BaseEngineCtx):
    """Train-mode context: ``module.train()`` on enter, grads zeroed on exit."""

    def __init__(self, engine: AutomodelEngine, **kwargs):
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, AutomodelEngine)
        super().__enter__()
        self.engine.module.train()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, AutomodelEngine)
        # Leave no stale gradients behind for the next phase.
        self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)


@EngineRegistry.register(model_type="language_model", backend=["automodel"], device=["cuda"])
class AutomodelEngineWithLMHead(AutomodelEngine):
    """Automodel engine for language model with LM head training."""

    def prepare_model_inputs(self, micro_batch: TensorDict):
        """Build model kwargs from a micro-batch.

        Returns ``(model_inputs, output_args)`` where ``output_args`` carries
        tensors that ``prepare_model_outputs`` needs (rolled labels, temperature).
        Inputs arrive as jagged nested tensors (pad_mode NO_PADDING); with
        remove-padding they are flattened to a single packed sequence, otherwise
        they are padded back to a dense (batch, max_seq_len) layout.
        """
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
        temperature = micro_batch["temperature"]
        # Keep the scalar form for the fused-kernel path (asserted non-tensor below).
        temperature_item = temperature
        if use_fused_kernels:
            assert not isinstance(temperature, torch.Tensor), (
                "use_fused_kernels does not support per sample temperature yet"
            )
            assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"

        multi_modal_inputs = extract_multi_modal_inputs(micro_batch.get("multi_modal_inputs", []))
        input_ids = micro_batch["input_ids"]
        position_ids = micro_batch["position_ids"]

        # Normalize temperature to a per-sample float32 tensor.
        if not isinstance(temperature, torch.Tensor):
            temperature = torch.tensor([temperature] * input_ids.shape[0], device=input_ids.device)

        temperature = temperature.to(torch.float32)
        assert temperature.shape[0] == input_ids.shape[0]

        output_args = {}

        if use_remove_padding:
            # Expand per-sample temperature to per-token, matching the packed layout.
            temperature_rmpad = verl_F.expand_as_nested(temperature, input_ids).values()
            temperature_rmpad = temperature_rmpad.unsqueeze(0)

            if pad_mode == DatasetPadMode.NO_PADDING:
                # Pack all sequences into one (1, total_tokens) row.
                input_ids_rmpad = input_ids.values().unsqueeze(0)
                if position_ids.dim() == 3:
                    # mrope-style position ids carry an extra leading axis.
                    position_ids_rmpad = position_ids.values().unsqueeze(1)
                else:
                    position_ids_rmpad = position_ids.values().unsqueeze(0)
            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

            # Next-token labels: shift left by one over the packed sequence.
            input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)

            input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)
            temperature_rmpad = temperature_rmpad.squeeze(0)
            output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
            output_args["temperature_rmpad"] = temperature_rmpad

            model_inputs = {
                "input_ids": input_ids_rmpad,
                "attention_mask": None,
                "position_ids": position_ids_rmpad,
            }

            # For TE attention backend, pass cu_seqlens
            if self.engine_config.attn_implementation == "te":
                cu_seqlens = input_ids.offsets().to(torch.int32)
                max_seqlen = cu_seqlens.diff().max().item()
                model_inputs["qkv_format"] = "thd"
                model_inputs["cu_seqlens"] = cu_seqlens.unsqueeze(0)
                model_inputs["max_seqlen"] = max_seqlen

        else:
            if pad_mode == DatasetPadMode.NO_PADDING:
                input_ids = micro_batch["input_ids"]
                position_ids = micro_batch["position_ids"]
                loss_mask = micro_batch["loss_mask"]

                pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0)
                batch_size = micro_batch.batch_size[0]
                seq_len_effective = input_ids.offsets().diff()
                max_seq_len = max(seq_len_effective)

                # Labels are computed on the jagged values BEFORE padding so the
                # roll does not cross into pad tokens of the dense layout.
                input_ids_rmpad_rolled = torch.roll(input_ids.values(), shifts=-1, dims=0)
                output_args["input_ids_rmpad_rolled"] = input_ids_rmpad_rolled
                output_args["temperature"] = temperature

                input_ids = torch.nested.to_padded_tensor(
                    input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
                )

                if position_ids.dim() == 3:
                    # 4-channel (mrope) position ids; pad then move channels first.
                    position_ids = torch.nested.to_padded_tensor(
                        position_ids, padding=0, output_size=(batch_size, 4, max_seq_len)
                    ).transpose(0, 1)
                else:
                    position_ids = torch.nested.to_padded_tensor(
                        position_ids, padding=0, output_size=(batch_size, max_seq_len)
                    )

                # Attention mask: 1 over each real token (derived from loss_mask
                # lengths), 0 over padding.
                attention_mask_list = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]
                attention_mask = torch.nested.as_nested_tensor(attention_mask_list, layout=torch.jagged)
                attention_mask = torch.nested.to_padded_tensor(
                    attention_mask, padding=0, output_size=(batch_size, max_seq_len)
                )

                model_inputs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "position_ids": position_ids,
                }

            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

        extra_args = {}
        if use_fused_kernels:
            extra_args["temperature"] = temperature_item
            extra_args["return_dict"] = True

        model_inputs.update(multi_modal_inputs)
        model_inputs.update(extra_args)

        return model_inputs, output_args

    def prepare_model_outputs(self, output, output_args, micro_batch: TensorDict):
        """Turn raw model output into ``{"log_probs": ..., "entropy": ...}``.

        ``log_probs``/``entropy`` are returned as jagged nested tensors keyed by
        the micro-batch's sequence offsets.
        """
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        use_fused_kernels = tu.get_non_tensor_data(data=micro_batch, key="use_fused_kernels", default=False)
        calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False)

        # Some model impls return bare logits; normalize to an object with .logits.
        if isinstance(output, torch.Tensor):
            from types import SimpleNamespace

            output = SimpleNamespace(logits=output)

        model_output = {}
        input_ids = micro_batch["input_ids"]

        if use_remove_padding:
            input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
            temperature_rmpad = output_args["temperature_rmpad"]

            if use_fused_kernels:
                # Fused path already produced per-token log-probs/entropy.
                log_probs = output.log_probs.squeeze(0)
                entropy_rmpad = output.entropy.squeeze(0)
            else:
                logits_rmpad = output.logits.squeeze(0)
                # With TP, logits are DTensors sharded on vocab dim; gather for log_softmax.
                if isinstance(logits_rmpad, DTensor):
                    logits_rmpad = logits_rmpad.full_tensor()
                # clamp guards against division by ~0 temperatures.
                logits_rmpad = logits_rmpad / temperature_rmpad.clamp(min=1e-8).unsqueeze(-1).to(logits_rmpad.dtype)

                # In-place logprob backward is incompatible with reusing logits
                # for the entropy computation below.
                inplace_backward = True
                if calculate_entropy:
                    inplace_backward = False
                log_probs = logprobs_from_logits(
                    logits=logits_rmpad,
                    labels=input_ids_rmpad_rolled,
                    inplace_backward=inplace_backward,
                )

                if calculate_entropy:
                    if not self.engine_config.entropy_checkpointing:
                        entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)
                    else:
                        # Recompute in backward to save the (tokens, vocab) activation.
                        entropy_rmpad = torch.utils.checkpoint.checkpoint(
                            self.compute_entropy_from_logits, logits_rmpad
                        )

            if pad_mode == DatasetPadMode.NO_PADDING:
                # Re-jag the flat per-token results using the input offsets.
                cu_seqlens = input_ids.offsets()
                log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
                if calculate_entropy:
                    entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
            else:
                raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

        else:
            response_length = tu.get_non_tensor_data(data=micro_batch, key="max_response_length", default=1024)
            if use_fused_kernels:
                # Slice the response window: positions predicting response tokens.
                log_probs = output.log_probs[:, -response_length - 1 : -1]
                entropy = output.entropy[:, -response_length - 1 : -1]
            else:
                logits = output.logits
                # With TP, logits are DTensors sharded on vocab dim; gather for log_softmax.
                if isinstance(logits, DTensor):
                    logits = logits.full_tensor()
                temperature = output_args["temperature"]
                temperature = temperature.unsqueeze(-1).unsqueeze(-1)
                logits = logits / temperature.clamp(min=1e-8).to(logits.dtype)

                if calculate_entropy:
                    if not self.engine_config.entropy_checkpointing:
                        entropy = verl_F.entropy_from_logits(logits)
                    else:
                        entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits)

                if pad_mode == DatasetPadMode.NO_PADDING:
                    # Narrow the dense logits back to each sequence's true length,
                    # flatten, then compute log-probs against the rolled labels.
                    cu_seqlens = input_ids.offsets()
                    seq_lengths = cu_seqlens.diff()
                    starts = torch.zeros_like(seq_lengths, dtype=torch.int64)
                    logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged)
                    logits_rmpad = torch.cat([t for t in logits.unbind()])
                    input_ids_rmpad_rolled = output_args["input_ids_rmpad_rolled"]
                    log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)
                    log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens)
                    if calculate_entropy:
                        entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged)
                        entropy_rmpad = torch.cat([t for t in entropy.unbind()])
                        entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens)
                else:
                    raise NotImplementedError(f"pad_mode {pad_mode} not implemented")

        model_output["log_probs"] = log_probs
        if calculate_entropy:
            model_output["entropy"] = entropy

        return model_output

    def forward_step(self, micro_batch: TensorDict, loss_function, forward_only):
        """Run forward pass, compute loss, and return outputs."""
        device_name = get_device_name()
        micro_batch = micro_batch.to(get_device_id())
        model_inputs, output_args = self.prepare_model_inputs(micro_batch=micro_batch)

        # bf16 autocast for the model forward; loss math happens outside.
        with torch.autocast(device_type=device_name, dtype=torch.bfloat16):
            raw_output = self.module(
                **model_inputs,
                use_cache=False,
            )

        model_output = self.prepare_model_outputs(
            output=raw_output, output_args=output_args, micro_batch=micro_batch
        )

        if loss_function is not None:
            loss, metrics = loss_function(
                model_output=model_output, data=micro_batch, dp_group=self.get_data_parallel_group()
            )
        else:
            assert forward_only, "forward_only must be True when loss_function is None"
            # Dummy loss: never backpropagated on the forward-only path.
            loss = torch.tensor(1.0, device=device_name)
            metrics = {}

        output = {
            "model_output": model_output,
            "loss": loss.detach().item(),
            "metrics": metrics,
        }

        return loss, output

# ---------------------------------------------------------------------------
# (diff metadata) new file: verl/workers/engine/automodel/utils.py
# index 00000000000..c10cf9a2db2
# ---------------------------------------------------------------------------
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE(review): this span was a whitespace-mangled unified diff; newlines and
# indentation have been reconstructed. Only comments/docstrings/annotations added.

"""Utility functions for the Automodel engine integration."""

import torch
import torch.distributed

from verl.utils.device import get_device_id, get_torch_device


def get_dp_rank(device_mesh, include_cp=False):
    """Get data-parallel rank from device mesh.

    With ``include_cp`` and an active CP dim (>1), the flattened ``dp_cp``
    dim is used; without a mesh, rank 0 is returned.
    """
    if device_mesh is None:
        return 0
    if include_cp and "cp" in device_mesh.mesh_dim_names and device_mesh["cp"].size() > 1:
        return device_mesh.get_local_rank("dp_cp")
    return device_mesh.get_local_rank("dp")


def get_tp_rank(device_mesh):
    """Get tensor-parallel rank from device mesh (0 when TP absent or size 1)."""
    if device_mesh is None or "tp" not in device_mesh.mesh_dim_names or device_mesh["tp"].size() == 1:
        return 0
    return device_mesh.get_local_rank("tp")


def get_pp_rank(device_mesh):
    """Get pipeline-parallel rank from device mesh (0 when PP absent or size 1)."""
    if device_mesh is None or "pp" not in device_mesh.mesh_dim_names or device_mesh["pp"].size() == 1:
        return 0
    return device_mesh.get_local_rank("pp")


def get_dp_group_size(device_mesh, include_cp=False):
    """Get data-parallel group size from device mesh.

    Falls back to the world size when no mesh (or no ``dp`` dim) exists.
    """
    if device_mesh is None:
        return torch.distributed.get_world_size()
    if include_cp and "cp" in device_mesh.mesh_dim_names and device_mesh["cp"].size() > 1:
        return device_mesh["dp_cp"].size()
    if "dp" in device_mesh.mesh_dim_names:
        return device_mesh["dp"].size()
    return torch.distributed.get_world_size()


def maybe_fully_shard_optimizer(model, optimizer, distributed_config):
    """Call fully_shard_optimizer for MegatronFSDP strategy.

    No-op for FSDP2/DDP configs or single-process runs.
    """
    from nemo_automodel.components.distributed.config import MegatronFSDPConfig

    if isinstance(distributed_config, MegatronFSDPConfig) and torch.distributed.get_world_size() > 1:
        from megatron_fsdp.fully_shard import fully_shard_optimizer

        fully_shard_optimizer(model, optimizer)


def build_distributed_config_from_engine_config(engine_config, world_size):
    """Build v5 distributed config, device_mesh, and moe_mesh from engine config.

    Args:
        engine_config: AutomodelEngineConfig instance.
        world_size: Total number of processes in the job.

    Returns:
        Tuple of (distributed_config, device_mesh, moe_mesh).

    Raises:
        ValueError: for an unrecognized ``distributed_strategy``.
    """
    from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config, MegatronFSDPConfig
    from nemo_automodel.components.distributed.mesh_utils import create_device_mesh

    strategy = engine_config.distributed_strategy

    if strategy == "fsdp2":
        from torch.distributed.fsdp import MixedPrecisionPolicy

        from verl.utils.torch_dtypes import PrecisionType

        # Mixed-precision policy comes straight from the engine config's
        # param/reduce/output dtype strings (e.g. "bf16").
        mp_policy = MixedPrecisionPolicy(
            param_dtype=PrecisionType.to_dtype(engine_config.mp_param_dtype),
            reduce_dtype=PrecisionType.to_dtype(engine_config.mp_reduce_dtype),
            output_dtype=PrecisionType.to_dtype(engine_config.mp_output_dtype),
            cast_forward_inputs=True,
        )

        distributed_config = FSDP2Config(
            sequence_parallel=engine_config.sequence_parallel,
            mp_policy=mp_policy,
            activation_checkpointing=engine_config.activation_checkpointing,
            defer_fsdp_grad_sync=engine_config.defer_fsdp_grad_sync,
        )

    elif strategy == "megatron_fsdp":
        distributed_config = MegatronFSDPConfig(
            activation_checkpointing=engine_config.activation_checkpointing,
        )

    elif strategy == "ddp":
        distributed_config = DDPConfig(
            activation_checkpointing=engine_config.activation_checkpointing,
        )

    else:
        raise ValueError(f"Unsupported distributed_strategy: {strategy}")

    device_mesh, moe_mesh = create_device_mesh(
        distributed_config,
        tp_size=engine_config.tp_size,
        pp_size=engine_config.pp_size,
        cp_size=engine_config.cp_size,
        ep_size=engine_config.ep_size,
        dp_replicate_size=engine_config.dp_replicate_size,
        world_size=world_size,
    )

    return distributed_config, device_mesh, moe_mesh


def build_automodel_model(model_config, engine_config, distributed_config, device_mesh, moe_mesh):
    """Build a model using NeMoAutoModelForCausalLM.from_pretrained().

    Args:
        model_config: HFModelConfig with model path and settings.
        engine_config: AutomodelEngineConfig with distributed settings.
        distributed_config: FSDP2Config, MegatronFSDPConfig, or DDPConfig instance.
        device_mesh: Pre-created device mesh (or None for DDP).
        moe_mesh: Pre-created MoE mesh (or None).

    Returns:
        A HuggingFace model with Automodel's distributed infrastructure applied.
    """
    from nemo_automodel._transformers.auto_model import NeMoAutoModelForCausalLM

    kwargs = {}

    if engine_config.enable_fp8:
        from nemo_automodel.components.quantization.fp8 import FP8Config

        kwargs["fp8_config"] = FP8Config()

    if engine_config.enable_compile:
        from nemo_automodel.components.utils.compile_utils import CompileConfig

        kwargs["compile_config"] = CompileConfig()

    # Qwen/Llama with ep_size<=1: use HF implementation.
    from transformers import AutoConfig

    _cfg = AutoConfig.from_pretrained(model_config.path, trust_remote_code=model_config.trust_remote_code)
    _arch = (getattr(_cfg, "architectures", None) or [""])[0].lower()
    if engine_config.ep_size <= 1 and ("qwen" in _arch or "llama" in _arch):
        kwargs["force_hf"] = True

    # Backend selection (attention/linear/rms_norm impls) only applies to
    # Automodel-native model implementations, not the forced-HF path.
    if engine_config.backend_config and not kwargs.get("force_hf", False):
        from nemo_automodel.components.models.common.utils import BackendConfig

        backend_kwargs = dict(engine_config.backend_config)
        kwargs["backend"] = BackendConfig(**backend_kwargs)

    # MoE config for MoEParallelizerConfig
    if engine_config.ep_size > 1:
        from nemo_automodel.components.moe.config import MoEParallelizerConfig

        moe_kwargs = dict(engine_config.moe_config) if engine_config.moe_config else {}
        # Reuse the FSDP2 mixed-precision policy for MoE parts unless the user
        # explicitly set one.
        if hasattr(distributed_config, "mp_policy"):
            moe_kwargs.setdefault("mp_policy", distributed_config.mp_policy)

        kwargs["moe_config"] = MoEParallelizerConfig(**moe_kwargs)

    kwargs["attn_implementation"] = engine_config.attn_implementation

    from verl.utils.torch_dtypes import PrecisionType

    kwargs["torch_dtype"] = PrecisionType.to_dtype(engine_config.model_dtype)

    model = NeMoAutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path=model_config.path,
        device_mesh=device_mesh,
        moe_mesh=moe_mesh,
        distributed_config=distributed_config,
        activation_checkpointing=engine_config.activation_checkpointing,
        trust_remote_code=model_config.trust_remote_code,
        **kwargs,
    )

    return model


@torch.no_grad()
def offload_automodel_model_to_cpu(model, empty_cache=True):
    """Offload an FSDP2-wrapped model to CPU (reshard, move to CPU, optional cache clear).

    NOTE(review): touches FSDP2 private state (``_fsdp_param_group``,
    ``TrainingState``) to force IDLE before resharding — re-verify against the
    pinned torch version on upgrades.
    """
    from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
    from torch.distributed.fsdp._fully_shard._fsdp_state import _get_module_fsdp_state

    for module in model.modules():
        state = _get_module_fsdp_state(module)
        if state is None:
            continue
        fsdp_param_group = state._fsdp_param_group

        if fsdp_param_group is None:
            continue

        # Force IDLE so reshard() is legal regardless of where training stopped.
        fsdp_param_group._training_state = TrainingState.IDLE

    model.reshard()
    model.cpu()
    if empty_cache:
        get_torch_device().empty_cache()


@torch.no_grad()
def load_automodel_model_to_gpu(model):
    """Load model back to GPU."""
    device = get_device_id()
    model.to(device, non_blocking=True)


@torch.no_grad()
def offload_automodel_optimizer(optimizer):
    """Offload optimizer state tensors to CPU (no-op for an empty/stateless optimizer).

    NOTE(review): device-to-CPU copies with ``non_blocking=True`` are only safe
    if a synchronization happens before the CPU tensors are read — confirm
    callers synchronize (or that the copy targets pinned memory).
    """
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to("cpu", non_blocking=True)


@torch.no_grad()
def load_automodel_optimizer(optimizer, device_id):
    """Load optimizer state tensors back onto ``device_id``."""
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device_id, non_blocking=True)