# [trainer] feat: Add Torchtitan as alternative training engine #5051
Changes from 19 commits
**tests/special_e2e/sft/run_sft_engine.sh test matrix** (`@@ -37,6 +37,14 @@`)

```shell
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh

# test with torchtitan fsdp=2
echo "run with tp1 pp1 cp1 fsdp2 num_gpus2"
BACKEND=torchtitan TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh

# test with torchtitan tp2 fsdp=2
echo "run with tp2 pp1 cp1 fsdp2 num_gpus4"
BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh

python3 tests/special_e2e/sft/compare_sft_engine_results.py

rm -rf ~/verl/test/log
```
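The test matrix above only works when the parallel degrees factor the GPU count. A minimal sketch of that invariant (the `check_mesh` helper is mine, not part of the PR; it assumes torchtitan composes TP/PP/CP/FSDP degrees into a single device mesh):

```python
def check_mesh(num_gpus: int, tp: int = 1, pp: int = 1, cp: int = 1, fsdp: int = 1) -> bool:
    # The test matrix assumes the parallel degrees multiply to the world size:
    # TP_SIZE * PP_SIZE * CP_SIZE * FSDP_SIZE == NUM_GPUS
    return tp * pp * cp * fsdp == num_gpus

# The two torchtitan cases exercised above:
print(check_mesh(2, fsdp=2))        # tp1 pp1 cp1 fsdp2 on 2 GPUs -> True
print(check_mesh(4, tp=2, fsdp=2))  # tp2 pp1 cp1 fsdp2 on 4 GPUs -> True
```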
**New engine config file** (`@@ -0,0 +1,65 @@`)

```yaml
# Target class for this configuration
_target_: verl.workers.config.TorchtitanEngineConfig

# policy for wrapping the model
wrap_policy:
  # Minimum number of parameters to trigger wrapping a layer with FSDP
  min_num_params: 0

# The policy for applying `reshard_after_forward` within an FSDP setup
# Options: "default", "always", "never"
reshard_after_forward: default

# Prefetch the next forward-pass all-gather before the current forward computation
forward_prefetch: false

# Whether to use original parameters
use_orig_params: false

# Mixed precision configuration for FSDP
mixed_precision: false

# Whether to use torch.compile
use_torch_compile: true

# Whether to compute entropy from logits with chunking
entropy_from_logits_with_chunking: false

# Whether to use entropy checkpointing
entropy_checkpointing: false

# Data parallel size (FSDP group size)
data_parallel_size: 1

# Data parallel replicate size
data_parallel_replicate_size: 1

# Data parallel shard size
data_parallel_shard_size: 1

# Tensor parallel size
tensor_parallel_size: 1

# Expert parallel size
expert_parallel_size: 1

# Pipeline parallel size
pipeline_parallel_size: 1

# Context parallel size
context_parallel_size: 1

# Strategy
strategy: torchtitan

# Random seed for reproducibility
seed: 42

# Whether to enable full determinism for distributed training; only for debugging
full_determinism: false

# Whether to run forward only
forward_only: false

# Mixed precision training param dtype
dtype: bfloat16
```

> **Collaborator:** Is there any document that explains these parallelism settings?
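The `_target_` key names the class to construct from this YAML, in the usual Hydra/OmegaConf style. A rough sketch of how such a dotted path is resolved (the `instantiate` helper here is illustrative, demonstrated on a stdlib class rather than the verl config classes):

```python
import importlib

def instantiate(target: str, **kwargs):
    # Split "pkg.module.Class" into module path and class name,
    # import the module, and call the class with the config fields.
    module_path, _, cls_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), cls_name)
    return cls(**kwargs)

# Demonstrated on a stdlib class instead of verl.workers.config.TorchtitanEngineConfig
counter = instantiate("collections.Counter", a=2, b=1)
print(counter["a"])  # 2
```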
**New model config file** (`@@ -0,0 +1,25 @@`)

```yaml
# Format checks enforced on CI:
# 1. Comments must appear above each field.
# 2. There must be a blank line between each field.
# 3. Inline comments (after a field on the same line) are not allowed.
# 4. Indentation level is respected for nested fields.

_target_: verl.workers.config.TorchtitanModelConfig

# Model name (e.g., "qwen3", "llama3")
name: qwen3

# Model flavor/size (e.g., "0.6B", "1.5B", "7B")
flavor: "0.6B"

# Path to HuggingFace model (tokenizer, config, weights, etc.)
path: ./assets/hf/Qwen3-0.6B

# Whether to use remove padding. Only valid when using the HF model definition
use_remove_padding: True

# Attention type for the model (e.g., "sdpa", "flex", "varlen")
attn_type: sdpa

# Attention mask type for the model (e.g., "causal", "document_mask", "block_causal")
attn_mask_type: causal
```
**New optimizer config file** (`@@ -0,0 +1,35 @@`)

```yaml
# Target class for this configuration
_target_: verl.workers.config.TorchtitanOptimizerConfig

# Optimizer name
name: AdamW

# Learning rate
lr: 1e-3

# LR warmup steps ratio
lr_warmup_steps_ratio: 0.0

# Total training steps
total_training_steps: -1

# Weight decay
weight_decay: 0.01

# LR warmup steps
lr_warmup_steps: -1

# Betas for Adam optimizer
betas: [0.9, 0.999]

# Clip gradient
clip_grad: 1.0

# Epsilon for Adam optimizer
eps: 1e-8

# Decay type: "linear", "sqrt", or "cosine"
decay_type: linear

# Minimum LR factor for cosine schedule
min_lr_factor: 0.0
```
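`lr_warmup_steps: -1` and `total_training_steps: -1` read as sentinels meaning "derive at runtime". A hedged sketch of how the warmup step count might be resolved from the ratio (the helper name and the exact precedence are my assumptions, not the PR's code):

```python
def resolve_warmup_steps(lr_warmup_steps: int, lr_warmup_steps_ratio: float,
                         total_training_steps: int) -> int:
    # An explicit non-negative step count wins; the -1 sentinel falls
    # back to deriving warmup from the ratio of total training steps.
    if lr_warmup_steps >= 0:
        return lr_warmup_steps
    return int(lr_warmup_steps_ratio * total_training_steps)

print(resolve_warmup_steps(-1, 0.1, 1000))  # 100, derived from the ratio
print(resolve_warmup_steps(50, 0.1, 1000))  # 50, explicit value wins
```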
**LR schedule hunk** (`@@ -743,6 +743,8 @@ def get_cosine_schedule_with_warmup`)

```python
    assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0

    def lr_lambda(current_step):
        # 0-indexed step, hence the +1 adjustment
        current_step += 1
        if current_step < num_warmup_steps:
            return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
```
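The hunk above cuts off before the cosine tail. A self-contained sketch of the full warmup-then-cosine multiplier (the cosine tail and the `min_lr_ratio` handling are my reconstruction, not the PR's exact code):

```python
import math

def make_cosine_lr_lambda(num_warmup_steps: int, num_training_steps: int,
                          init_lr_ratio: float = 0.0, min_lr_ratio: float = 0.0):
    def lr_lambda(current_step: int) -> float:
        current_step += 1  # 0-indexed step, hence the +1 adjustment
        if current_step < num_warmup_steps:
            # Linear ramp from init_lr_ratio up to 1.0 over the warmup steps
            return init_lr_ratio + (1.0 - init_lr_ratio) * (current_step / max(1, num_warmup_steps))
        progress = (current_step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        # Cosine decay from 1.0 down to min_lr_ratio
        cosine = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
        return min_lr_ratio + (1.0 - min_lr_ratio) * cosine
    return lr_lambda

fn = make_cosine_lr_lambda(num_warmup_steps=10, num_training_steps=100)
print(fn(0))   # 0.1, one tenth of the way through warmup
print(fn(99))  # 0.0, fully decayed at the final step
```

This multiplier is the usual shape passed to `torch.optim.lr_scheduler.LambdaLR`.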
**Config dataclass hunks** (`@@ -27,6 +27,7 @@`)

```python
    "FSDPEngineConfig",
    "McoreEngineConfig",
    "TrainingWorkerConfig",
    "TorchtitanEngineConfig",
    "VeOmniEngineConfig",
    "EngineConfig",
    "EngineRouterReplayConfig",
```

(`@@ -309,6 +310,62 @@`)

```python
        assert self.strategy in ["veomni"], f"strategy {self.strategy} not supported"


@dataclass
class TorchtitanEngineConfig(EngineConfig):
    """Configuration for Torchtitan.

    The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config.

    Args:
        wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy.
        reshard_after_forward (Literal["default", "always", "never"]): The policy for applying
            `reshard_after_forward` within an FSDP setup, default "default"
        forward_prefetch (bool): Whether to prefetch parameters for the next forward pass, default False
        use_orig_params (bool): Whether to use original parameters when initializing FSDP, default False
        mixed_precision (bool): Mixed precision configuration for FSDP, default False
        offload_policy (bool): Whether to offload policy model parameters, default False
        use_torch_compile (bool): Whether to use torch.compile, default True
        entropy_from_logits_with_chunking (bool): Whether to compute entropy from logits with chunking,
            default False
        entropy_checkpointing (bool): Whether to use entropy checkpointing, default False
        data_parallel_size (int): Data parallel group size, default 1
        data_parallel_replicate_size (int): Data parallel replicate size, default 1
        data_parallel_shard_size (int): Data parallel shard degree, default 1
        tensor_parallel_size (int): Tensor parallel size, default 1
        expert_parallel_size (int): Expert parallel size, default 1
        expert_tensor_parallel_size (int): Expert tensor parallel size, default 1
        pipeline_parallel_size (int): Pipeline parallel size, default 1
        context_parallel_size (int): Context parallel size, default 1
        strategy (str): Strategy to use for distributed training, default "torchtitan"
        seed (int): Random seed for reproducibility.
        full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results
            in distributed training. Important: this will negatively impact performance, so only use it for
            debugging.
    """

    wrap_policy: dict[str, Any] = field(default_factory=dict)
    reshard_after_forward: Literal["default", "always", "never"] = "default"
    forward_prefetch: bool = False
    use_orig_params: bool = False
    mixed_precision: bool = False
    offload_policy: bool = False
    use_torch_compile: bool = True
    entropy_from_logits_with_chunking: bool = False
    entropy_checkpointing: bool = False
    data_parallel_size: int = 1
    data_parallel_replicate_size: int = 1
    data_parallel_shard_size: int = 1
    tensor_parallel_size: int = 1
    expert_parallel_size: int = 1
    expert_tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1
    context_parallel_size: int = 1
    strategy: str = "torchtitan"
    seed: int = 42
    full_determinism: bool = False

    def __post_init__(self):
        super().__post_init__()
        assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported"


@dataclass
class TrainingWorkerConfig(BaseConfig):
    model_type: str = None  # model type (language_model/value_model)
```

> **Contributor** (on lines +328 to +337): The descriptions for …
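The `__post_init__` guard above fails fast on an unsupported strategy. A minimal, self-contained stand-in for that validation pattern (`MiniEngineConfig` is illustrative, not a verl class):

```python
from dataclasses import dataclass

@dataclass
class MiniEngineConfig:
    # Minimal stand-in for the pattern above: reject any strategy
    # other than "torchtitan" at construction time.
    strategy: str = "torchtitan"
    tensor_parallel_size: int = 1

    def __post_init__(self):
        assert self.strategy in ["torchtitan"], f"strategy {self.strategy} not supported"

cfg = MiniEngineConfig()  # constructs fine with the default strategy
try:
    MiniEngineConfig(strategy="fsdp")
except AssertionError as e:
    print(e)  # strategy fsdp not supported
```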
> **Reviewer:** Please verify the different parallelism configurations in `tests/special_e2e/sft/test_sft_engine_all.sh`.
>
> **Author:** Sounds good. I will incorporate TP/SP with this PR; the other parallelism modes will come in separate PRs.