diff --git a/tests/special_e2e/sft/run_sft_engine.sh b/tests/special_e2e/sft/run_sft_engine.sh index b334281908f..9fe80afae13 100644 --- a/tests/special_e2e/sft/run_sft_engine.sh +++ b/tests/special_e2e/sft/run_sft_engine.sh @@ -30,7 +30,7 @@ MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} #hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}" SP_SIZE=${SP_SIZE:-1} -FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}} +FSDP_SIZE=${FSDP_SIZE:-1} FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"} TP_SIZE=${TP_SIZE:-1} @@ -44,6 +44,8 @@ USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True} FSDP_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -58,6 +60,8 @@ FSDP_ENGINE_CONFIG="\ VEOMNI_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -71,6 +75,8 @@ VEOMNI_ENGINE_CONFIG="\ MEGATRON_ENGINE_CONFIG="\ engine=${backend} \ + model=hf_model \ + model.path=$MODEL_PATH \ optim=${backend} \ optim.lr=1e-5 \ optim.lr_warmup_steps_ratio=0.2 \ @@ -87,6 +93,26 @@ MEGATRON_ENGINE_CONFIG="\ +engine.override_transformer_config.context_parallel_size=${CP_SIZE} \ engine.use_mbridge=True" +TORCHTITAN_ENGINE_CONFIG="\ + engine=${backend} \ + model=hf_model \ + model.path=${MODEL_PATH} \ + optim=${backend} \ + optim.lr=1e-5 \ + optim.lr_warmup_steps_ratio=0.2 \ + optim.weight_decay=0.1 \ + optim.betas="[0.9,0.95]" \ + optim.clip_grad=1.0 \ + optim.min_lr_factor=0.1 \ + optim.decay_type=cosine \ + optim.total_training_steps=1000 \ + engine.tensor_parallel_size=${TP_SIZE} \ + engine.pipeline_parallel_size=${PP_SIZE} \ + engine.context_parallel_size=${CP_SIZE} \ + engine.data_parallel_shard_size=${FSDP_SIZE} \ + engine.use_torch_compile=False" + + if [ "$backend" = "fsdp" ]; then ENGINE_CONFIG="$FSDP_ENGINE_CONFIG" echo "Using fsdp engine" @@ -95,6 +121,10 @@ elif [ "$backend" = "veomni" ]; then 
ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG" echo "Using veomni engine" exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} +elif [ "$backend" = "torchtitan" ]; then + ENGINE_CONFIG="$TORCHTITAN_ENGINE_CONFIG" + echo "Using torchtitan engine" + exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode} else ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG" echo "Using megatron engine" @@ -112,8 +142,8 @@ $COMMAND \ data.use_dynamic_bsz=True \ data.max_token_len_per_gpu=2048 \ data.messages_key=messages \ - model.path=$MODEL_PATH \ model.use_remove_padding=${USE_REMOVE_PADDING} \ + data.ignore_input_ids_mismatch=True \ ${ENGINE_CONFIG} \ trainer.test_freq=after_each_epoch \ trainer.save_freq=-1 \ @@ -128,5 +158,5 @@ $COMMAND \ # trainer.total_training_steps=${TOTAL_TRAIN_STEP} \ # trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \ # trainer.max_ckpt_to_keep=1 \ - -rm -rf "${ckpts_home:?}/*" \ No newline at end of file + +rm -rf "${ckpts_home:?}/*" diff --git a/tests/special_e2e/sft/test_sft_engine_all.sh b/tests/special_e2e/sft/test_sft_engine_all.sh index 96f5f195692..21524ce1d09 100644 --- a/tests/special_e2e/sft/test_sft_engine_all.sh +++ b/tests/special_e2e/sft/test_sft_engine_all.sh @@ -37,6 +37,15 @@ BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 b echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray" BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh +# TODO: Will add back torchtitan CI once everything is ready +# # test with torchtitan fsdp=2 +# echo "run with tp1 pp1 cp1 fsdp2 num_gpus2" +# BACKEND=torchtitan TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh + +# # test with torchtitan tp2 fsdp=2 +# echo "run with tp2 pp1 cp1 fsdp2 
num_gpus4" +# BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh + python3 tests/special_e2e/sft/compare_sft_engine_results.py rm -rf ~/verl/test/log diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index b743913b77c..8310583c631 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -42,6 +42,7 @@ "verl/workers/engine/utils.py", # appear in enable_full_determinism "verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name "verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name + "verl/workers/engine/torchtitan/transformer_impl.py", # appear in default device_name "verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes "verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES "verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes diff --git a/verl/trainer/config/_generated_ppo_trainer.yaml b/verl/trainer/config/_generated_ppo_trainer.yaml index a4002ced76d..a0606078ae9 100644 --- a/verl/trainer/config/_generated_ppo_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_trainer.yaml @@ -21,6 +21,7 @@ actor_rollout_ref: min_lr_ratio: 0.0 num_cycles: 0.5 lr_scheduler_type: constant + zero_indexed_step: true warmup_style: null override_optimizer_config: null fsdp_config: @@ -399,6 +400,7 @@ critic: min_lr_ratio: 0.0 num_cycles: 0.5 lr_scheduler_type: constant + zero_indexed_step: true warmup_style: null override_optimizer_config: null model: diff --git a/verl/trainer/config/engine/torchtitan.yaml b/verl/trainer/config/engine/torchtitan.yaml new file mode 100644 index 00000000000..d28f4cab273 --- /dev/null +++ b/verl/trainer/config/engine/torchtitan.yaml @@ -0,0 +1,68 @@ +# Target class for this 
configuration +_target_: verl.workers.config.TorchtitanEngineConfig + +# policy for wrapping the model +wrap_policy: + # Minimum number of parameters to trigger wrapping a layer with FSDP + min_num_params: 0 + +# The policy for applying `reshard_after_forward` within an FSDP setup +# Options: "default", "always", "never" +reshard_after_forward: default + +# Prefetch the next forward-pass all-gather before the current forward computation. +forward_prefetch: false + +# Whether to use original parameters +use_orig_params: false + +# Mixed precision configuration for FSDP +mixed_precision: false + +# Whether to use torch compile +use_torch_compile: true + +# Whether to use entropy_from_logits_with_chunking +entropy_from_logits_with_chunking: false + +# Whether to use entropy checkpointing +entropy_checkpointing: false + +# Data parallel size (FSDP group size) +data_parallel_size: 1 + +# Data parallel replicate size +data_parallel_replicate_size: 1 + +# Data parallel shard size +data_parallel_shard_size: 1 + +# Tensor parallel size +tensor_parallel_size: 1 + +# Expert parallel size +expert_parallel_size: 1 + +# Pipeline parallel size +pipeline_parallel_size: 1 + +# Context parallel size +context_parallel_size: 1 + +# Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen") +attn_type: flex + +# Strategy +strategy: torchtitan + +# Random seed for reproducibility +seed: 42 + +# Whether to enable full determinism for distributed training, only for debugging +full_determinism: false + +# Whether to use forward only +forward_only: false + +# Mixed precision training param dtype +dtype: bfloat16 diff --git a/verl/trainer/config/optim/fsdp.yaml b/verl/trainer/config/optim/fsdp.yaml index a7dd99b1ee2..ce6ced773b6 100644 --- a/verl/trainer/config/optim/fsdp.yaml +++ b/verl/trainer/config/optim/fsdp.yaml @@ -38,6 +38,9 @@ num_cycles: 0.5 # LR scheduler type: "constant" or "cosine" lr_scheduler_type: constant +# Whether the LR schedule uses 0-indexed steps 
+zero_indexed_step: true + # deprecated warmup_style: null diff --git a/verl/trainer/config/optim/torchtitan.yaml b/verl/trainer/config/optim/torchtitan.yaml new file mode 100644 index 00000000000..baea31ee527 --- /dev/null +++ b/verl/trainer/config/optim/torchtitan.yaml @@ -0,0 +1,35 @@ +# Target class for this configuration +_target_: verl.workers.config.TorchtitanOptimizerConfig + +# Optimizer name +name: AdamW + +# Learning rate +lr: 1e-3 + +# LR warmup steps ratio +lr_warmup_steps_ratio: 0.0 + +# Total training steps +total_training_steps: -1 + +# Weight decay +weight_decay: 0.01 + +# LR warmup steps +lr_warmup_steps: -1 + +# Betas for Adam optimizer +betas: [0.9, 0.999] + +# Clip gradient +clip_grad: 1.0 + +# Epsilon for Adam optimizer +eps: 1e-8 + +# Decay type: "linear", "sqrt", or "cosine" +decay_type: linear + +# Minimum LR factor for cosine schedule +min_lr_factor: 0.0 diff --git a/verl/trainer/sft_trainer.py b/verl/trainer/sft_trainer.py index 979d92b04a1..d23ebc5fa90 100644 --- a/verl/trainer/sft_trainer.py +++ b/verl/trainer/sft_trainer.py @@ -238,8 +238,14 @@ def _get_batch_seqlens(self, data): batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1) batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp) + dp_group = self.engine.get_data_parallel_group() + dp_size = self.engine.get_data_parallel_size() + + if dp_size == 1 or dp_group is None: + return batch_seqlens.tolist() + output_tensor = torch.empty( - (batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),), + (batch_seqlens.shape[0] * dp_size,), dtype=batch_seqlens.dtype, device=self.device_name, ) # (global_bsz,) @@ -247,7 +253,7 @@ def _get_batch_seqlens(self, data): torch.distributed.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=batch_seqlens, - group=self.engine.get_data_parallel_group(), + group=dp_group, ) batch_seqlens = output_tensor.tolist() @@ -372,9 +378,9 @@ def fit(self): if self.engine.is_mp_src_rank_with_outputs(): val_loss = 
torch.mean(torch.tensor(val_losses, device=self.device_name)) # average over data parallel group - torch.distributed.all_reduce( - val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() - ) + dp_group = self.engine.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.AVG, group=dp_group) if is_logging: metric = {"val/loss": val_loss.detach().item()} diff --git a/verl/utils/seqlen_balancing.py b/verl/utils/seqlen_balancing.py index 46f82240448..51097f50a51 100644 --- a/verl/utils/seqlen_balancing.py +++ b/verl/utils/seqlen_balancing.py @@ -388,7 +388,7 @@ def rearrange_micro_batches( if min_num_micro_batch is not None: # used to support pp num_micro_batches = max(min_num_micro_batch, num_micro_batches) - if dist.is_initialized() and same_micro_num_in_dp: + if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None: num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name()) dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group) num_micro_batches = num_micro_batches.cpu().item() diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py index 2802e3642f1..8666bec2d16 100644 --- a/verl/utils/torch_functional.py +++ b/verl/utils/torch_functional.py @@ -710,6 +710,7 @@ def get_cosine_schedule_with_warmup( num_cycles: float = 0.5, last_epoch: int = -1, init_lr_ratio: float = None, + zero_indexed_step: bool = True, ): """ Create a schedule with a learning rate that decreases following the values of the cosine function between the @@ -731,6 +732,9 @@ def get_cosine_schedule_with_warmup( The index of the last epoch when resuming training. init_lr_ratio (:obj:`float`, `optional`, defaults to None): The initial lr ratio w.r.t the maximum. + zero_indexed_step (:obj:`bool`, `optional`, defaults to True): + Whether the LR schedule uses 0-indexed steps. If True (default), step counting starts at 0. 
+ If False (used by torchtitan), step counting starts at 1. Return: :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule. """ @@ -743,6 +747,8 @@ def get_cosine_schedule_with_warmup( assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0 def lr_lambda(current_step): + if not zero_indexed_step: + current_step += 1 if current_step < num_warmup_steps: return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps))) progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index 7f8e573ad79..08d7fa293aa 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -27,6 +27,7 @@ "FSDPEngineConfig", "McoreEngineConfig", "TrainingWorkerConfig", + "TorchtitanEngineConfig", "VeOmniEngineConfig", "EngineConfig", "EngineRouterReplayConfig", @@ -309,6 +310,65 @@ def __post_init__(self): assert self.strategy in ["veomni"], f"strategy {self.strategy} not supported" +@dataclass +class TorchtitanEngineConfig(EngineConfig): + """Configuration for Torchtitan. + + The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. + + Args: + wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy. 
+ reshard_after_forward (Literal["default", "always", "never"]): The policy for applying + `reshard_after_forward` within an FSDP setup, default "default" + forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False + use_orig_params (bool): Whether to use original parameters when initialize FSDP, default False + mixed_precision (bool): Mixed precision configuration for FSDP, default False + offload_policy (bool): Whether to offload policy model parameters, default False + data_parallel_size (int): Data parallel group size, default 1 + data_parallel_replicate_size (int): Data parallel replicate size, default 1 + data_parallel_shard_size (int): Data parallel shard degree, default 1 + tensor_parallel_size (int): Tensor parallel size, default 1 + expert_parallel_size (int): Expert parallel size, default 1 + expert_tensor_parallel_size (int): Expert tensor parallel size, default 1 + pipeline_parallel_size (int): Pipeline parallel size, default 1 + context_parallel_size (int): Context parallel size, default 1 + attn_type (str): Attention type for torchtitan's model (e.g., "sdpa", "flex", "varlen"), + default "flex" + strategy (str): Strategy to use for distributed training, default "torchtitan" + seed (int): Random seed for reproducibility. + full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results + in distributed training. Important: this will negatively impact performance, so only use it for + debugging. 
+ + """ + + wrap_policy: dict[str, Any] = field(default_factory=dict) + reshard_after_forward: Literal["default", "always", "never"] = "default" + forward_prefetch: bool = False + use_orig_params: bool = False + mixed_precision: bool = False + offload_policy: bool = False + use_torch_compile: bool = True + entropy_from_logits_with_chunking: bool = False + entropy_checkpointing: bool = False + data_parallel_size: int = 1 + data_parallel_replicate_size: int = 1 + data_parallel_shard_size: int = 1 + tensor_parallel_size: int = 1 + expert_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + pipeline_parallel_size: int = 1 + context_parallel_size: int = 1 + attn_type: str = "flex" + strategy: str = "torchtitan" + seed: int = 42 + full_determinism: bool = False + + def __post_init__(self): + super().__post_init__() + assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported" + + @dataclass class TrainingWorkerConfig(BaseConfig): model_type: str = None # model type (language_model/value_model) diff --git a/verl/workers/config/optimizer.py b/verl/workers/config/optimizer.py index bdb87667c25..b7f05bef518 100644 --- a/verl/workers/config/optimizer.py +++ b/verl/workers/config/optimizer.py @@ -19,7 +19,14 @@ from verl.base_config import BaseConfig -__all__ = ["OptimizerConfig", "FSDPOptimizerConfig", "McoreOptimizerConfig", "build_optimizer", "VeOmniOptimizerConfig"] +__all__ = [ + "OptimizerConfig", + "FSDPOptimizerConfig", + "McoreOptimizerConfig", + "build_optimizer", + "VeOmniOptimizerConfig", + "TorchtitanOptimizerConfig", +] @dataclass @@ -88,6 +95,8 @@ class FSDPOptimizerConfig(OptimizerConfig): min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule. lr_scheduler_type (str): LR scheduler type: "constant" or "cosine". num_cycles (float): Number of cosine cycles in LR schedule. + zero_indexed_step (bool): Whether the LR schedule uses 0-indexed steps. If True (default), + step counting starts at 0. 
If False, step counting starts at 1. """ _mutable_fields = OptimizerConfig._mutable_fields.copy() @@ -101,6 +110,7 @@ class FSDPOptimizerConfig(OptimizerConfig): lr_scheduler_type: str = "constant" num_cycles: float = 0.5 override_optimizer_config: Optional[dict] = None + zero_indexed_step: bool = True def __post_init__(self): if self.warmup_style is not None: @@ -143,6 +153,23 @@ class McoreOptimizerConfig(OptimizerConfig): override_optimizer_config: Optional[dict] = None +@dataclass +class TorchtitanOptimizerConfig(OptimizerConfig): + """Torchtitan optimizer configuration extending base OptimizerConfig. + + Args: + name (str): Optimizer name; default is "AdamW". + eps (float): Epsilon value for AdamW optimizer, default 1e-8. + decay_type (str): LR decay schedule type: "linear", "sqrt", or "cosine". + min_lr_factor (float): Minimum learning rate factor. + """ + + name: str = "AdamW" + eps: float = 1e-8 + decay_type: str = "linear" + min_lr_factor: float = 0.0 + + def build_optimizer(parameters, config: FSDPOptimizerConfig): """Build an optimizer based on the configuration. 
diff --git a/verl/workers/engine/__init__.py b/verl/workers/engine/__init__.py index 7b8be1002c0..8f01080fdcb 100644 --- a/verl/workers/engine/__init__.py +++ b/verl/workers/engine/__init__.py @@ -21,6 +21,14 @@ "FSDPEngineWithLMHead", ] +try: + from .torchtitan import TorchTitanEngine, TorchTitanEngineWithLMHead + + __all__ += ["TorchTitanEngine", "TorchTitanEngineWithLMHead"] +except ImportError: + TorchTitanEngine = None + TorchTitanEngineWithLMHead = None + try: from .veomni import VeOmniEngine, VeOmniEngineWithLMHead diff --git a/verl/workers/engine/fsdp/transformer_impl.py b/verl/workers/engine/fsdp/transformer_impl.py index 6820afbaa61..dbe9eb2f4e1 100644 --- a/verl/workers/engine/fsdp/transformer_impl.py +++ b/verl/workers/engine/fsdp/transformer_impl.py @@ -412,6 +412,7 @@ def _build_lr_scheduler(self, optimizer): lr_scheduler_type = optim_config.lr_scheduler_type min_lr_ratio = optim_config.min_lr_ratio num_cycles = optim_config.num_cycles + zero_indexed_step = optim_config.zero_indexed_step if num_warmup_steps <= 0: num_warmup_steps_ratio = optim_config.lr_warmup_steps_ratio num_warmup_steps = int(num_warmup_steps_ratio * total_steps) @@ -428,6 +429,7 @@ def _build_lr_scheduler(self, optimizer): num_training_steps=total_steps, min_lr_ratio=min_lr_ratio, num_cycles=num_cycles, + zero_indexed_step=zero_indexed_step, ) else: raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported") diff --git a/verl/workers/engine/torchtitan/__init__.py b/verl/workers/engine/torchtitan/__init__.py new file mode 100644 index 00000000000..345757277af --- /dev/null +++ b/verl/workers/engine/torchtitan/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .transformer_impl import TorchTitanEngine, TorchTitanEngineWithLMHead + +__all__ = ["TorchTitanEngine", "TorchTitanEngineWithLMHead"] diff --git a/verl/workers/engine/torchtitan/transformer_impl.py b/verl/workers/engine/torchtitan/transformer_impl.py new file mode 100644 index 00000000000..56108ce6dcb --- /dev/null +++ b/verl/workers/engine/torchtitan/transformer_impl.py @@ -0,0 +1,684 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +The concrete Engine implementation using PyTorch TorchTitan parallelism (FSDP2 + TP + PP) +""" + +import gc +import logging +import os +import re +from contextlib import nullcontext +from typing import Any, Callable, Optional + +import torch +import torch.distributed +from tensordict import TensorDict +from torch.distributed.tensor import DTensor +from torchtitan.config.job_config import ( + Checkpoint, + Compile, + JobConfig, + LRScheduler, + Model, + Optimizer, + Parallelism, + Training, +) +from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input +from torchtitan.distributed.parallel_dims import ParallelDims +from torchtitan.train import Trainer + +import verl.utils.torch_functional as verl_F +from verl.trainer.config import CheckpointConfig +from verl.utils import tensordict_utils as tu +from verl.utils.dataset.dataset_utils import DatasetPadMode +from verl.utils.debug import log_gpu_memory_usage +from verl.utils.device import get_device_id, get_device_name +from verl.utils.fsdp_utils import ( + load_fsdp_model_to_gpu, + load_fsdp_optimizer, + offload_fsdp_model_to_cpu, + offload_fsdp_optimizer, +) +from verl.utils.model import extract_multi_modal_inputs +from verl.utils.torch_functional import logprobs_from_logits +from verl.workers.config import HFModelConfig, TorchtitanEngineConfig, TorchtitanOptimizerConfig +from verl.workers.engine.torchtitan.utils import ( + derive_torchtitan_name_and_flavor, + enable_fsdp_gradient_division, + get_attention_masks, +) + +from ..base import BaseEngine, BaseEngineCtx, EngineRegistry +from ..utils import enable_full_determinism, postprocess_batch_func, prepare_micro_batches + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + + +class TorchTitanEngine(BaseEngine): + """ + Concrete Engine implementation using PyTorch TorchTitan parallelism. 
+ + Supports model sharding with FSDP2, tensor parallelism, activation/optimizer offloading, + LoRA, and sequence parallelism following the TorchTitan design. + """ + + def __init__( + self, + model_config: HFModelConfig, + engine_config: TorchtitanEngineConfig, + optimizer_config: TorchtitanOptimizerConfig, + checkpoint_config: CheckpointConfig, + ): + """ + Initialize the TorchTitanEngine. + + Sets up distributed device meshes for tensor and data parallelism, LoRA, and offload policies. + + Args: + model_config: Configuration for HuggingFace model. + engine_config: Configuration for FSDP/TorchTitan engine (uses FSDP2). + optimizer_config: Configuration for optimizer. + checkpoint_config: Configuration for checkpointing. + """ + super().__init__() + + self.model_config = model_config + self.engine_config = engine_config + self.optimizer_config = optimizer_config + self.checkpoint_config = checkpoint_config + + # Disable torchtitan's dataloader since verl has its own data loading + # Ideally torchtitan trainer init should not initialize dataloader + import torchtitan.protocols.train_spec as train_spec_module + + original_get_train_spec = train_spec_module.get_train_spec + + def _get_train_spec_without_dataloader(model_name): + train_spec = original_get_train_spec(model_name) + train_spec.build_dataloader_fn = None + return train_spec + + train_spec_module.get_train_spec = _get_train_spec_without_dataloader + + # Derive torchtitan model name and flavor from HF config + torchtitan_name, torchtitan_flavor = derive_torchtitan_name_and_flavor(self.model_config.hf_config) + + # Get train_spec and directly override model_args before Trainer init + train_spec = train_spec_module.get_train_spec(torchtitan_name) + model_args = train_spec.model_args.get(torchtitan_flavor) + if model_args is not None: + if hasattr(model_args, "attn_type"): + model_args.attn_type = self.engine_config.attn_type + + model = Model( + name=torchtitan_name, + flavor=torchtitan_flavor, + 
hf_assets_path=self.model_config.path, + ) + optimizer = Optimizer( + name=self.optimizer_config.name, + lr=self.optimizer_config.lr, + eps=self.optimizer_config.eps, + beta1=self.optimizer_config.betas[0], + beta2=self.optimizer_config.betas[1], + weight_decay=self.optimizer_config.weight_decay, + ) + + total_steps = self.optimizer_config.total_training_steps + lr_warmup_steps = self.optimizer_config.lr_warmup_steps + if lr_warmup_steps is None or lr_warmup_steps <= 0: + lr_warmup_steps = int(self.optimizer_config.lr_warmup_steps_ratio * total_steps) + + lr_scheduler = LRScheduler( + warmup_steps=lr_warmup_steps, + decay_type=self.optimizer_config.decay_type, + min_lr_factor=self.optimizer_config.min_lr_factor, + ) + parallelism = Parallelism( + data_parallel_replicate_degree=self.engine_config.data_parallel_replicate_size, + data_parallel_shard_degree=self.engine_config.data_parallel_shard_size, + fsdp_reshard_after_forward=self.engine_config.reshard_after_forward, + tensor_parallel_degree=self.engine_config.tensor_parallel_size, + pipeline_parallel_degree=self.engine_config.pipeline_parallel_size, + context_parallel_degree=self.engine_config.context_parallel_size, + expert_parallel_degree=self.engine_config.expert_parallel_size, + expert_tensor_parallel_degree=self.engine_config.expert_tensor_parallel_size, + ) + checkpoint = Checkpoint( + enable=True, + initial_load_in_hf=True, + initial_load_model_only=True, + initial_load_path=model_config.path, + ) + compile = Compile(enable=self.engine_config.use_torch_compile) + if self.engine_config.offload_policy or self.engine_config.forward_only: + training = Training(enable_cpu_offload=True) + else: + training = Training() + + # Construct Torchtitan's JobConfig + self.config = JobConfig( + model=model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + parallelism=parallelism, + checkpoint=checkpoint, + compile=compile, + training=training, + ) + self.trainer = Trainer(self.config) + + self._init_device_mesh() + + 
# Re-enable FSDP's gradient division for verl's loss scaling. + # TorchTitan disables gradient division by default (for global token normalization), + # but verl's loss function multiplies by dp_size to compensate for gradient averaging. + if self.engine_config.data_parallel_shard_size > 1: + dp_size = self.get_data_parallel_size() + for model_part in self.trainer.model_parts: + enable_fsdp_gradient_division(model_part, dp_size) + + if self.engine_config.full_determinism: + enable_full_determinism(seed=self.engine_config.seed) + + # set FSDP offload params + self._is_offload_param = self.engine_config.param_offload + self._is_offload_optimizer = self.engine_config.optimizer_offload + + if self.engine_config.entropy_from_logits_with_chunking: + entropy_from_logits = verl_F.entropy_from_logits_with_chunking + else: + entropy_from_logits = verl_F.entropy_from_logits + + self.compute_entropy_from_logits = ( + torch.compile(entropy_from_logits, dynamic=True) + if self.engine_config.use_torch_compile + else entropy_from_logits + ) + + @property + def is_param_offload_enabled(self) -> bool: + return self._is_offload_param + + @property + def is_optimizer_offload_enabled(self) -> bool: + return self._is_offload_optimizer + + def is_mp_src_rank_with_outputs(self): + """ + Whether the current rank is the first rank in model parallel group that contains model outputs + """ + is_collect = True + # TP: outputs are on TP rank 0 + if self.parallel_dims.tp > 1: + tp_mesh = self.parallel_dims.get_optional_mesh("tp") + is_collect = is_collect and (tp_mesh.get_local_rank() == 0) + # PP: outputs are on the last PP rank + if self.parallel_dims.pp > 1: + pp_mesh = self.parallel_dims.get_optional_mesh("pp") + is_collect = is_collect and (pp_mesh.get_local_rank() == self.parallel_dims.pp - 1) + # CP: outputs are on CP rank 0 + if self.parallel_dims.cp > 1: + cp_mesh = self.parallel_dims.get_optional_mesh("cp") + is_collect = is_collect and (cp_mesh.get_local_rank() == 0) + return 
is_collect + + def initialize(self): + """ + Build the model, optimizer, and learning rate scheduler with TorchTitan parallelism. + + Applies device, dtype, and precision configurations, including mixed precision. + Sets up checkpoint manager. + """ + self.module = self.trainer.model_parts + self.checkpointer = self.trainer.checkpointer + # load initial HF weights + self.checkpointer.load() + + if not self.engine_config.forward_only: + self.optimizer = self.trainer.optimizers + self.lr_scheduler = self.trainer.lr_schedulers + else: + self.optimizer = None + self.lr_scheduler = None + + self.to( + device="cpu", + model=self._is_offload_param, + optimizer=self._is_offload_optimizer, + grad=self._is_offload_param, + ) + + log_gpu_memory_usage("After offload model/optimizer/grad during init", logger=logger) + + def _init_device_mesh(self): + """Initialize the device mesh for TorchTitan style parallelism.""" + world_size = torch.distributed.get_world_size() + self.parallel_dims = ParallelDims( + dp_shard=self.engine_config.data_parallel_shard_size, + dp_replicate=self.engine_config.data_parallel_replicate_size, + cp=self.engine_config.context_parallel_size, + tp=self.engine_config.tensor_parallel_size, + pp=self.engine_config.pipeline_parallel_size, + ep=self.engine_config.expert_parallel_size, + etp=self.engine_config.expert_tensor_parallel_size, + world_size=world_size, + ) + self.device_mesh = self.parallel_dims.build_mesh() + + def train_mode(self, **kwargs): + """Return a context manager for training mode.""" + return EngineTrainModeCtx(self, **kwargs) + + def eval_mode(self, **kwargs): + """Return a context manager for evaluation mode.""" + return EngineEvalModeCtx(self, **kwargs) + + def get_data_parallel_rank(self): + mesh = self._get_data_parallel_mesh() + return 0 if mesh is None else mesh.get_local_rank() + + def get_data_parallel_size(self): + return self.engine_config.data_parallel_shard_size * self.engine_config.data_parallel_replicate_size + + def 
get_data_parallel_group(self): + mesh = self._get_data_parallel_mesh() + return mesh.get_group() if mesh is not None else None + + def _get_data_parallel_mesh(self): + """Get the data parallel mesh, handling hybrid/fully/replicate shard modes.""" + mesh = self.parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"]) + if mesh is None: + mesh = self.parallel_dims.get_optional_mesh("fsdp") + if mesh is None: + mesh = self.parallel_dims.get_optional_mesh("dp_replicate") + return mesh + + def forward_backward_batch(self, data: TensorDict, loss_function: Callable, forward_only=False): + """Perform forward and optionally backward pass on a batch.""" + tu.assign_non_tensor(data, sp_size=self.engine_config.tensor_parallel_size) + + # Compute num_tokens in global batch for loss normalization + batch_num_tokens = data["loss_mask"].sum().to(get_device_id()) + dp_group = self.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(batch_num_tokens, op=torch.distributed.ReduceOp.SUM, group=dp_group) + tu.assign_non_tensor(data, batch_num_tokens=batch_num_tokens.item()) + tu.assign_non_tensor(data, dp_size=self.get_data_parallel_size()) + + micro_batches, indices = prepare_micro_batches( + data=data, dp_group=self.get_data_parallel_group(), same_micro_num_in_dp=True + ) + + output_lst = [] + + ctx = torch.no_grad() if forward_only else nullcontext() + + for micro_batch in micro_batches: + with ctx: + loss, output = self.forward_step(micro_batch, loss_function=loss_function, forward_only=forward_only) + if not forward_only: + loss.backward() + output_lst.append(output) + + return postprocess_batch_func(output_lst=output_lst, indices=indices, data=data) + + def model_forward_step( + self, + *, + inputs: torch.Tensor, + extra_inputs: dict[str, torch.Tensor] | None = None, + extra_kwargs: dict[str, torch.Tensor] | None = None, + ) -> torch.Tensor: + """ + Perform a forward pass through the trainer model without backward. 
+ """ + model_parts = self.module + parallel_dims = self.parallel_dims + + if parallel_dims.pp_enabled: + raise NotImplementedError( + "Pipeline parallelism is not yet supported in model_forward_step. " + "This will be implemented in a follow-up PR." + ) + else: + # Non-PP forward + assert len(model_parts) == 1 + with self.trainer.train_context(): + with self.trainer.maybe_enable_amp: + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) + + if isinstance(pred, DTensor): + pred = pred.full_tensor() + return pred + + def forward_step(self, micro_batch: TensorDict, loss_function, forward_only): + raise NotImplementedError("forward_step must be implemented in subclass") + + def optimizer_zero_grad(self): + """Zero gradients.""" + self.optimizer.zero_grad() + + def optimizer_step(self): + """Perform optimizer step with gradient clipping.""" + grad_norm = dist_utils.clip_grad_norm_( + [p for m in self.module for p in m.parameters()], + self.config.training.max_norm, + foreach=True, + pp_mesh=self.parallel_dims.get_optional_mesh("pp"), + ep_enabled=self.parallel_dims.ep_enabled, + ) + + # if grad_norm is not finite, skip the update + if not torch.isfinite(grad_norm): + logger.warning(f"grad_norm is not finite: {grad_norm}") + self.optimizer.zero_grad() + else: + self.optimizer.step() + return grad_norm.item() + + def lr_scheduler_step(self): + """Advance learning rate scheduler.""" + self.lr_scheduler.step() + lr = self.lr_scheduler.schedulers[0].get_last_lr()[0] + return lr + + def to(self, device: str, model: bool = True, optimizer: bool = True, grad: bool = True): + """Move model and/or optimizer to CPU or GPU.""" + super().to(device=device, model=model, optimizer=optimizer, grad=grad) + + if self.engine_config.forward_only: + return + + device_name = get_device_name() + assert device in (device_name, "cpu") + if device == device_name: + if model: + for module in self.module: + load_fsdp_model_to_gpu(module) + if optimizer and self.optimizer is not None: + 
load_fsdp_optimizer(self.optimizer, device) + gc.collect() + elif device == "cpu": + if model: + for module in self.module: + offload_fsdp_model_to_cpu(module) + if optimizer and self.optimizer is not None: + offload_fsdp_optimizer(self.optimizer) + else: + raise ValueError(f"Invalid device type: {device}") + + def save_checkpoint( + self, + local_path: str, + hdfs_path: Optional[str] = None, + global_step: int = 0, + max_ckpt_to_keep: Optional[int] = None, + **kwargs, + ) -> None: + """Save checkpoint.""" + if self._is_offload_param: + for module in self.module: + load_fsdp_model_to_gpu(module) + + # Override TorchTitan's folder to use verl's path + parent_dir = os.path.dirname(local_path) + self.checkpointer.folder = parent_dir + + if max_ckpt_to_keep is not None: + self.checkpointer.keep_latest_k = max_ckpt_to_keep + + self.checkpointer.save(curr_step=global_step) + + torch.distributed.barrier() + if self._is_offload_param: + for module in self.module: + offload_fsdp_model_to_cpu(module) + + def load_checkpoint( + self, local_path: str, hdfs_path: Optional[str] = None, del_local_after_load: int = True, **kwargs + ) -> None: + """Load checkpoint.""" + if self._is_offload_param: + for module in self.module: + load_fsdp_model_to_gpu(module) + + # Override TorchTitan's folder to use verl's path + parent_dir = os.path.dirname(local_path) + self.checkpointer.folder = parent_dir + + # Extract step number from path (verl uses global_step_N format) + match = re.search(r"global_step_(\d+)", local_path) + if match: + step = int(match.group(1)) + self.checkpointer.load(step=step) + else: + # Fallback to latest + self.checkpointer.load(step=-1) + + torch.distributed.barrier() + if self._is_offload_param: + for module in self.module: + offload_fsdp_model_to_cpu(module) + + if self._is_offload_optimizer: + offload_fsdp_optimizer(self.optimizer) + + +class EngineEvalModeCtx(BaseEngineCtx): + def __init__(self, engine: TorchTitanEngine, **kwargs): + 
class EngineTrainModeCtx(BaseEngineCtx):
    """Context manager that switches a TorchTitanEngine into training mode."""

    def __init__(self, engine: TorchTitanEngine, **kwargs):
        super().__init__(engine=engine, mode="train", **kwargs)

    def __enter__(self):
        assert isinstance(self.engine, TorchTitanEngine)
        super().__enter__()
        for part in self.engine.module:
            part.train()

    def __exit__(self, exc_type, exc_value, traceback):
        assert isinstance(self.engine, TorchTitanEngine)
        # Drop any leftover gradients before leaving train mode.
        self.engine.optimizer_zero_grad()
        super().__exit__(exc_type, exc_value, traceback)


@EngineRegistry.register(model_type="language_model", backend=["torchtitan"], device=["cuda", "npu"])
class TorchTitanEngineWithLMHead(TorchTitanEngine):
    """TorchTitan engine implementation for language models with LM head."""

    def prepare_model_inputs(self, micro_batch: TensorDict):
        """Build ``(input_ids, extra_inputs, extra_kwargs, output_args)`` for the LM forward.

        Handles both remove-padding (packed single-row) and padded layouts of the
        jagged ``input_ids``/``position_ids`` nested tensors.
        """
        use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True)
        pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING)
        assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported"

        multi_modal_inputs = extract_multi_modal_inputs(micro_batch.get("multi_modal_inputs", []))
        input_ids = micro_batch["input_ids"]
        position_ids = micro_batch["position_ids"]
        output_args = {}

        if use_remove_padding:
            # Pack every sequence into a single row; position resets keep the
            # document boundaries recoverable.
            input_ids = input_ids.values().unsqueeze(0)
            if position_ids.dim() == 3:
                position_ids = position_ids.values().unsqueeze(1)
            else:
                position_ids = position_ids.values().unsqueeze(0)

            # Next-token prediction targets.
            labels = torch.roll(input_ids, shifts=-1, dims=1)
            attention_mask = get_attention_masks(
                input_batch=input_ids,
                positions=position_ids,
                attn_type=self.trainer.model_args.attn_type,
            )
        else:
            loss_mask = micro_batch["loss_mask"]
            pad_token_id = tu.get_non_tensor_data(data=micro_batch, key="pad_token_id", default=0)
            batch_size = micro_batch.batch_size[0]
            max_seq_len = max(input_ids.offsets().diff())

            labels = torch.roll(input_ids.values(), shifts=-1, dims=0)
            input_ids = torch.nested.to_padded_tensor(
                input_ids, padding=pad_token_id, output_size=(batch_size, max_seq_len)
            )

            if position_ids.dim() == 3:
                position_ids = torch.nested.to_padded_tensor(
                    position_ids, padding=0, output_size=(batch_size, 4, max_seq_len)
                ).transpose(0, 1)
            else:
                position_ids = torch.nested.to_padded_tensor(
                    position_ids, padding=0, output_size=(batch_size, max_seq_len)
                )

            # Dense attention mask: ones over real tokens, zeros over padding.
            ones_per_seq = [torch.ones_like(t, dtype=torch.int32) for t in loss_mask]
            attention_mask = torch.nested.as_nested_tensor(ones_per_seq, layout=torch.jagged)
            attention_mask = torch.nested.to_padded_tensor(
                attention_mask, padding=0, output_size=(batch_size, max_seq_len)
            )

        extra_inputs = {"positions": position_ids}
        # For arguments, like attention_masks, we have to put them in a separate
        # dict as extra_inputs are not forwarded to other stages in PP, but
        # extra_kwargs are.
        extra_kwargs: dict[str, Any] = {"attention_masks": attention_mask}
        if self.parallel_dims.cp_enabled:
            input_ids, labels, extra_kwargs = prepare_context_parallel_input(
                input_ids,
                labels,
                extra_kwargs,
                self.parallel_dims.get_mesh("cp"),
                self.trainer.device,
                self.trainer.job_config.parallelism.context_parallel_load_balancer,
            )

        # TODO(jessicazhong): multimodal is not yet supported for Torchtitan engine
        extra_inputs.update(multi_modal_inputs)
        output_args["labels"] = labels
        return input_ids, extra_inputs, extra_kwargs, output_args
+ extra_kwargs: dict[str, Any] = {"attention_masks": attention_mask} + if self.parallel_dims.cp_enabled: + input_ids, labels, extra_kwargs = prepare_context_parallel_input( + input_ids, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.trainer.device, + self.trainer.job_config.parallelism.context_parallel_load_balancer, + ) + + # TODO(jessicazhong): multimodal is not yet supported for Torchtitan engine + extra_inputs.update(multi_modal_inputs) + output_args["labels"] = labels + return input_ids, extra_inputs, extra_kwargs, output_args + + def prepare_model_outputs(self, logits, output_args, micro_batch: TensorDict): + use_remove_padding = tu.get_non_tensor_data(data=micro_batch, key="use_remove_padding", default=True) + pad_mode = tu.get_non_tensor_data(data=micro_batch, key="pad_mode", default=DatasetPadMode.NO_PADDING) + assert pad_mode == DatasetPadMode.NO_PADDING, f"pad_mode {pad_mode} not supported" + + temperature = micro_batch["temperature"] + calculate_entropy = tu.get_non_tensor_data(data=micro_batch, key="calculate_entropy", default=False) + labels = output_args["labels"] + model_output = {} + + input_ids = micro_batch["input_ids"] + cu_seqlens = input_ids.offsets() + if use_remove_padding: + labels = labels.squeeze(0) + logits_rmpad = logits.squeeze(0) + # PyTorch's autograd doesn't allow in-place modification of views when gradients need to flow back + logits_rmpad = logits_rmpad / temperature + + inplace_backward = True + if calculate_entropy: + inplace_backward = False + log_probs = logprobs_from_logits( + logits=logits_rmpad, + labels=labels, + inplace_backward=inplace_backward, + ) + + if calculate_entropy: + if not self.engine_config.entropy_checkpointing: + entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad) + else: + entropy_rmpad = torch.utils.checkpoint.checkpoint(self.compute_entropy_from_logits, logits_rmpad) + + log_probs = torch.nested.nested_tensor_from_jagged(log_probs.squeeze(0), cu_seqlens) + if 
calculate_entropy: + entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) + else: + logits.div_(temperature) + if calculate_entropy: + if not self.engine_config.entropy_checkpointing: + entropy = verl_F.entropy_from_logits(logits) + else: + entropy = torch.utils.checkpoint.checkpoint(verl_F.entropy_from_logits, logits) + + seq_lengths = cu_seqlens.diff() + starts = torch.zeros_like(seq_lengths, dtype=torch.int64) + logits = torch.nested.narrow(logits, 1, starts, seq_lengths, layout=torch.jagged) + logits_rmpad = torch.cat([t for t in logits.unbind()]) + log_probs = logprobs_from_logits(logits=logits_rmpad, labels=output_args["labels"]) + log_probs = torch.nested.nested_tensor_from_jagged(log_probs, cu_seqlens) + if calculate_entropy: + entropy = torch.nested.narrow(entropy, 1, starts, seq_lengths, layout=torch.jagged) + entropy_rmpad = torch.cat([t for t in entropy.unbind()]) + entropy = torch.nested.nested_tensor_from_jagged(entropy_rmpad, cu_seqlens) + + model_output["log_probs"] = log_probs + if calculate_entropy: + model_output["entropy"] = entropy + + return model_output + + def forward_step(self, micro_batch: TensorDict, loss_function, forward_only): + device_name = get_device_name() + micro_batch = micro_batch.to(get_device_id()) + input_ids, extra_inputs, extra_kwargs, output_args = self.prepare_model_inputs(micro_batch=micro_batch) + + with torch.autocast(device_type=device_name, dtype=torch.bfloat16): + logits = self.model_forward_step(inputs=input_ids, extra_inputs=extra_inputs, extra_kwargs=extra_kwargs) + + model_output = self.prepare_model_outputs(logits=logits, output_args=output_args, micro_batch=micro_batch) + + if loss_function is not None: + loss, metrics = loss_function( + model_output=model_output, data=micro_batch, dp_group=self.get_data_parallel_group() + ) + else: + assert forward_only, "forward_only must be True when loss_function is None" + loss = torch.tensor(1.0, device=device_name) + metrics = {} + + output = { + 
"model_output": model_output, + "loss": loss.detach().item(), + "metrics": metrics, + } + + return loss, output diff --git a/verl/workers/engine/torchtitan/utils.py b/verl/workers/engine/torchtitan/utils.py new file mode 100644 index 00000000000..686fb94e6b2 --- /dev/null +++ b/verl/workers/engine/torchtitan/utils.py @@ -0,0 +1,213 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import torch +import torch.nn as nn +from torch.distributed._composable.fsdp import FSDPModule +from torch.nn.attention.flex_attention import _mask_mod_signature, and_masks +from torchtitan.models.attention import VarlenMetadata, create_attention_mask, get_causal_mask_mod +from torchtitan.protocols.model import AttentionMasksType + +logger = logging.getLogger(__name__) + +# Mapping from HuggingFace model_type to torchtitan model name. +# Torchtitan models not mapped here: +# - flux: diffusion model, not applicable to verl's RL/SFT workflows +# - llama3_ft: fault-tolerant variant of llama3, same HF models (mapped via "llama") +_HF_MODEL_TYPE_TO_TORCHTITAN_NAME = { + "qwen2": "qwen3", + "qwen3": "qwen3", + "qwen2_moe": "qwen3", + "qwen3_moe": "qwen3", + "llama": "llama3", + "llama4": "llama4", + "deepseek_v3": "deepseek_v3", + "gpt_oss": "gpt_oss", +} + + +def derive_torchtitan_name_and_flavor(hf_config) -> tuple[str, str]: + """Derive torchtitan model name and flavor from a HuggingFace config. 
+ + The name is mapped from ``hf_config.model_type``. The flavor is found by + matching architecture parameters (dim, n_layers, vocab_size) against the + known flavors registered in the torchtitan TrainSpec. + + Args: + hf_config: A HuggingFace AutoConfig object. + + Returns: + A ``(name, flavor)`` tuple. + + Raises: + ValueError: If model_type is unsupported or no matching flavor is found. + """ + import torchtitan.protocols.train_spec as train_spec_module + + model_type = getattr(hf_config, "model_type", None) + if model_type is None: + raise ValueError("HuggingFace config does not have 'model_type' field") + + name = _HF_MODEL_TYPE_TO_TORCHTITAN_NAME.get(model_type) + if name is None: + raise ValueError( + f"Cannot derive torchtitan model name from HF model_type '{model_type}'. " + f"Supported types: {list(_HF_MODEL_TYPE_TO_TORCHTITAN_NAME.keys())}." + ) + + train_spec = train_spec_module.get_train_spec(name) + + hidden_size = hf_config.hidden_size + num_layers = hf_config.num_hidden_layers + vocab_size = hf_config.vocab_size + + for flavor_name, model_args in train_spec.model_args.items(): + if ( + getattr(model_args, "dim", None) == hidden_size + and getattr(model_args, "n_layers", None) == num_layers + and getattr(model_args, "vocab_size", None) == vocab_size + ): + logger.info( + f"Auto-derived torchtitan name='{name}', flavor='{flavor_name}' from HF model_type='{model_type}'" + ) + return name, flavor_name + + raise ValueError( + f"No matching torchtitan flavor found for model_type='{model_type}' " + f"(hidden_size={hidden_size}, num_hidden_layers={num_layers}, " + f"vocab_size={vocab_size}). " + f"Available flavors for '{name}': {list(train_spec.model_args.keys())}." + ) + + +def enable_fsdp_gradient_division(model: nn.Module, dp_size: int) -> None: + """ + Re-enable FSDP's automatic gradient division. + + TorchTitan calls disable_fsdp_gradient_division() which sets gradient_divide_factor=1.0. 
+ This re-enables it by setting the factor to the specified dp_size, so gradients are + averaged across FSDP ranks. This is needed for verl's loss scaling (loss * dp_size) + to work correctly. + + Args: + model: The model (or model part) to enable gradient division on. + dp_size: The data parallel size to use as the gradient divide factor. + """ + + for module in model.modules(): + if isinstance(module, FSDPModule): + module.set_gradient_divide_factor(float(dp_size)) + + +def get_attention_masks( + input_batch: torch.Tensor, + positions: torch.Tensor, + attn_type: str, +) -> AttentionMasksType: + match attn_type: + case "flex": + return _get_flex_attention_masks( + input_batch, + positions, + ) + case "varlen": + return _create_varlen_metadata_for_document( + input_batch, + positions, + ) + case _: + raise TypeError("Only varlen and flex attn masks are supported") + + +def _get_document_mask_mod(positions: torch.Tensor) -> _mask_mod_signature: + # Detect boundaries from position resets + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + sequence_indices = (position_diff != 1).cumsum(-1) # [batch, seq] + + def document_mask(b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor) -> torch.Tensor: + return sequence_indices[b, q_idx] == sequence_indices[b, kv_idx] + + return document_mask + + +def _get_flex_attention_masks( + input_batch: torch.Tensor, + positions: torch.Tensor, +) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + B = input_batch.shape[0] + mask_mods.append(_get_document_mask_mod(positions=positions)) + return create_attention_mask(and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]) + + +def _create_varlen_metadata_for_document(input_batch: torch.Tensor, positions: torch.Tensor) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch: Input token IDs with 
shape [batch, seq]. + positions: Position IDs with shape [batch, seq]. Boundaries detected where + position diff != 1 (i.e., position resets). + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + + # Detect boundaries from position resets (where diff != 1) + first_dummy_value = positions[:, :1] - 1 + position_diff = torch.diff(positions, prepend=first_dummy_value, dim=-1) + # boundary_mask[b, i] is True if position i starts a new document + boundary_mask = position_diff != 1 # [batch, seq] + boundary_mask[:, 0] = True + + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + + for b in range(batch_size): + # Find positions where new documents start + boundary_positions = boundary_mask[b].nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + boundary_positions, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat(cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)]) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 6f8029600ea..60f2e8d811d 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -179,17 +179,20 @@ def _postprocess_output(self, output, *, global_token_num, 
delta_time, forward_o # Here each metric in metrics can be a list (micro-batch metrics) or a singleton # we should always sum the loss of each micro-batch as we scale by global_bsz/global_token loss = torch.sum(torch.tensor(output.pop("loss"), device=self.device_name)) - torch.distributed.all_reduce( - loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group() - ) + dp_group = self.engine.get_data_parallel_group() + if dp_group is not None: + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG, group=dp_group) loss = loss.item() # For grad_norm, we do not perform all reduce because it is already been done when clipping grad grad_norm = metrics.pop("grad_norm", None) lr = metrics.pop("lr", None) - # For other metrics, we perform all gather in dp group - final_metrics = allgather_dict_into_dict(data=metrics, group=self.engine.get_data_parallel_group()) + # For other metrics, we perform all gather in dp group (only if DP > 1) + if dp_group is not None: + final_metrics = allgather_dict_into_dict(data=metrics, group=dp_group) + else: + final_metrics = metrics final_metrics["loss"] = loss if grad_norm is not None: final_metrics["grad_norm"] = grad_norm