Skip to content
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
542ab37
initial try to add Torchtitan Engine
acisseJZhong Jan 23, 2026
17e3f5f
sft running but loss mismatch
acisseJZhong Jan 30, 2026
745cb09
loss become large
acisseJZhong Feb 1, 2026
8f1183c
loss closer but still mismatch
acisseJZhong Feb 1, 2026
ad0f8d6
loss exactly matching with no parallelism
acisseJZhong Feb 3, 2026
9eeb171
non parallelism working
acisseJZhong Feb 6, 2026
200fb15
formatting
acisseJZhong Feb 7, 2026
e303e98
address comments
acisseJZhong Feb 8, 2026
22adbab
address comments
acisseJZhong Feb 8, 2026
1de4d17
address comments
acisseJZhong Feb 10, 2026
f6deb69
address comments
acisseJZhong Feb 10, 2026
82fe47d
address comments
acisseJZhong Feb 10, 2026
0571b51
tp/sp stuck
acisseJZhong Feb 12, 2026
2e6aac0
tp working
acisseJZhong Feb 13, 2026
46cefc9
tp working
acisseJZhong Feb 13, 2026
f73eaad
tp working
acisseJZhong Feb 13, 2026
df16152
tp working
acisseJZhong Feb 13, 2026
f2bd36c
tp working
acisseJZhong Feb 13, 2026
bada868
delete log
acisseJZhong Feb 13, 2026
26da997
address comments
acisseJZhong Feb 13, 2026
9f4510b
address comments
acisseJZhong Feb 13, 2026
902916f
address comments
acisseJZhong Feb 13, 2026
9703d2b
address comments
acisseJZhong Feb 13, 2026
f448b27
remove ci for now
acisseJZhong Feb 13, 2026
95abca1
remove ci for now
acisseJZhong Feb 13, 2026
f55959f
Re-enable FSDP's gradient division
acisseJZhong Feb 13, 2026
71e432b
Re-enable FSDP's gradient division
acisseJZhong Feb 13, 2026
133e69e
trigger ci
acisseJZhong Feb 14, 2026
712b38b
format
acisseJZhong Feb 14, 2026
f61d0ae
remove file
acisseJZhong Feb 14, 2026
ccbece3
move attn_type to engine
acisseJZhong Feb 14, 2026
db55a2e
remove log
acisseJZhong Feb 14, 2026
543b1d4
misc
acisseJZhong Feb 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions tests/special_e2e/sft/run_sft_engine.sh
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
#hf download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

SP_SIZE=${SP_SIZE:-1}
FSDP_SIZE=${FSDP_SIZE:-${NUM_GPUS}}
FSDP_SIZE=${FSDP_SIZE:-1}
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp"}

TP_SIZE=${TP_SIZE:-1}
Expand All @@ -44,6 +44,8 @@ USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}

FSDP_ENGINE_CONFIG="\
engine=${backend} \
model=hf_model \
model.path=$MODEL_PATH \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
Expand All @@ -58,6 +60,8 @@ FSDP_ENGINE_CONFIG="\

VEOMNI_ENGINE_CONFIG="\
engine=${backend} \
model=hf_model \
model.path=$MODEL_PATH \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
Expand All @@ -71,6 +75,8 @@ VEOMNI_ENGINE_CONFIG="\

MEGATRON_ENGINE_CONFIG="\
engine=${backend} \
model=hf_model \
model.path=$MODEL_PATH \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
Expand All @@ -87,6 +93,29 @@ MEGATRON_ENGINE_CONFIG="\
+engine.override_transformer_config.context_parallel_size=${CP_SIZE} \
engine.use_mbridge=True"

# Hydra overrides for the torchtitan engine: model identity (name/flavor),
# attention kernel selection, optimizer/LR schedule, and the parallelism
# degrees (TP/PP/CP/FSDP) taken from the surrounding env vars.
# NOTE(review): name/flavor are hard-coded to qwen3/0.6B here — presumably
# they must match MODEL_PATH's checkpoint; confirm when adding new models.
TORCHTITAN_ENGINE_CONFIG="\
engine=${backend} \
model=hf_model \
model.torchtitan.name=qwen3 \
model.torchtitan.flavor=0.6B \
model.torchtitan.attn_type=flex \
model.path=${MODEL_PATH} \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_factor=0.1 \
optim.decay_type=cosine \
optim.total_training_steps=1000 \
engine.tensor_parallel_size=${TP_SIZE} \
engine.pipeline_parallel_size=${PP_SIZE} \
engine.context_parallel_size=${CP_SIZE} \
engine.data_parallel_shard_size=${FSDP_SIZE} \
engine.use_torch_compile=False"


if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
echo "Using fsdp engine"
Expand All @@ -95,6 +124,10 @@ elif [ "$backend" = "veomni" ]; then
ENGINE_CONFIG="$VEOMNI_ENGINE_CONFIG"
echo "Using veomni engine"
exp_name=gsm8k-${backend}-sp${SP_SIZE}-fsdp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
elif [ "$backend" = "torchtitan" ]; then
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please verify different parallelism in tests/special_e2e/sft/test_sft_engine_all.sh

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I will incorporate TP/SP in this PR; the other parallelism strategies will come in separate PRs.

ENGINE_CONFIG="$TORCHTITAN_ENGINE_CONFIG"
echo "Using torchtitan engine"
exp_name=gsm8k-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-cp${CP_SIZE}-dp${FSDP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}-mode-${mode}
else
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
Expand All @@ -112,8 +145,8 @@ $COMMAND \
data.use_dynamic_bsz=True \
data.max_token_len_per_gpu=2048 \
data.messages_key=messages \
model.path=$MODEL_PATH \
model.use_remove_padding=${USE_REMOVE_PADDING} \
data.ignore_input_ids_mismatch=True \
${ENGINE_CONFIG} \
trainer.test_freq=after_each_epoch \
trainer.save_freq=-1 \
Expand All @@ -128,5 +161,5 @@ $COMMAND \
# trainer.total_training_steps=${TOTAL_TRAIN_STEP} \
# trainer.checkpoint.save_contents=[model,optimizer,extra,hf_model] \
# trainer.max_ckpt_to_keep=1 \
rm -rf "${ckpts_home:?}/*"

rm -rf "${ckpts_home:?}/*"
9 changes: 9 additions & 0 deletions tests/special_e2e/sft/test_sft_engine_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 b
echo "run with tp2 pp2 vpp2 cp2 num_gpus8 mode=ray"
BACKEND=megatron TP_SIZE=2 PP_SIZE=2 VPP_SIZE=${VPP_SIZE} CP_SIZE=2 NUM_GPUS=8 mode=ray bash tests/special_e2e/sft/run_sft_engine.sh

# TODO: Will add back torchtitan CI once everything is ready
# # test with torchtitan fsdp=2
# echo "run with tp1 pp1 cp1 fsdp2 num_gpus2"
# BACKEND=torchtitan TP_SIZE=1 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=2 bash tests/special_e2e/sft/run_sft_engine.sh

# # test with torchtitan tp2 fsdp=2
# echo "run with tp2 pp1 cp1 fsdp2 num_gpus4"
# BACKEND=torchtitan TP_SIZE=2 PP_SIZE=1 CP_SIZE=1 FSDP_SIZE=2 NUM_GPUS=4 bash tests/special_e2e/sft/run_sft_engine.sh

python3 tests/special_e2e/sft/compare_sft_engine_results.py

rm -rf ~/verl/test/log
1 change: 1 addition & 0 deletions tests/special_sanity/check_device_api_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"verl/workers/engine/utils.py", # appear in enable_full_determinism
"verl/workers/engine/fsdp/transformer_impl.py", # appear in default device_name
"verl/workers/engine/veomni/transformer_impl.py", # appear in default device_name
"verl/workers/engine/torchtitan/transformer_impl.py", # appear in default device_name
"verl/workers/rollout/vllm_rollout/vllm_async_server.py", # appear in config.cudagraph_capture_sizes
"verl/workers/rollout/sglang_rollout/async_sglang_server.py", # manually set CUDA_VISIBLE_DEVICES
"verl/workers/rollout/trtllm_rollout/trtllm_async_server.py", # appear in config.cudagraph_capture_sizes
Expand Down
5 changes: 5 additions & 0 deletions verl/trainer/config/_generated_ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,11 @@ actor_rollout_ref:
speculative_num_draft_tokens: 4
method: mtp
num_speculative_tokens: 1
torchtitan:
name: null
flavor: null
attn_type: sdpa
attn_mask_type: causal
lora:
type: lora
merge: false
Expand Down
7 changes: 7 additions & 0 deletions verl/trainer/config/_generated_ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ actor_rollout_ref:
min_lr_ratio: 0.0
num_cycles: 0.5
lr_scheduler_type: constant
zero_indexed_step: true
warmup_style: null
override_optimizer_config: null
fsdp_config:
Expand Down Expand Up @@ -340,6 +341,11 @@ actor_rollout_ref:
speculative_num_draft_tokens: 4
method: mtp
num_speculative_tokens: 1
torchtitan:
name: null
flavor: null
attn_type: sdpa
attn_mask_type: causal
hybrid_engine: true
nccl_timeout: 600
data:
Expand Down Expand Up @@ -399,6 +405,7 @@ critic:
min_lr_ratio: 0.0
num_cycles: 0.5
lr_scheduler_type: constant
zero_indexed_step: true
warmup_style: null
override_optimizer_config: null
model:
Expand Down
5 changes: 5 additions & 0 deletions verl/trainer/config/_generated_ppo_veomni_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ actor_rollout_ref:
speculative_num_draft_tokens: 4
method: mtp
num_speculative_tokens: 1
torchtitan:
name: null
flavor: null
attn_type: sdpa
attn_mask_type: causal
hybrid_engine: true
nccl_timeout: 600
data:
Expand Down
65 changes: 65 additions & 0 deletions verl/trainer/config/engine/torchtitan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Target class for this configuration
_target_: verl.workers.config.TorchtitanEngineConfig

# policy for wrapping the model
wrap_policy:
# Minimum number of parameters to trigger wrapping a layer with FSDP
min_num_params: 0

# The policy for applying `reshard_after_forward` within an FSDP setup
# Options: "default", "always", "never"
reshard_after_forward: default

# Prefetch the next forward-pass all-gather before the current forward computation.
forward_prefetch: false

# Whether to use original parameters
use_orig_params: false

# Mixed precision configuration for FSDP
mixed_precision: false

# Whether to use torch compile
use_torch_compile: true

# Whether to use entropy_from_logits_with_chunking
entropy_from_logits_with_chunking: false

# Whether to use entropy checkpointing
entropy_checkpointing: false

# Data parallel size (FSDP group size)
data_parallel_size: 1

# Data parallel replicate size
data_parallel_replicate_size: 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any documentation explaining these parallelism settings?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


# Data parallel shard size
data_parallel_shard_size: 1

# Tensor parallel size
tensor_parallel_size: 1

# Expert parallel size
expert_parallel_size: 1

# Pipeline parallel size
pipeline_parallel_size: 1

# Context parallel size
context_parallel_size: 1

# Strategy
strategy: torchtitan

# Random seed for reproducibility
seed: 42

# Whether to enable full determinism for distributed training, only for debugging
full_determinism: false

# Whether to use forward only
forward_only: false

# Mixed precision training param dtype
dtype: bfloat16
16 changes: 16 additions & 0 deletions verl/trainer/config/model/hf_model.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,19 @@ mtp:

method: mtp
num_speculative_tokens: 1

# Torchtitan backend configuration
# Only used when engine backend is set to "torchtitan"
torchtitan:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still not desirable. All the models including names and flavors must start from a single huggingface folder. We can introduce a general model_implementation dict so that users can write attn_type and attn_mask_type inside this sub-config

Copy link
Collaborator Author

@acisseJZhong acisseJZhong Feb 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a helper function to derive the model name and flavor from the HF config, and removed attn_mask_type since it's not used. As for attn_type, I moved it to TorchtitanEngineConfig since it's a more torchtitan-specific field (I don't want other training engines to have this field). Please let me know if you have a different opinion @vermouth1992


# model name for torchtitan (e.g., "qwen3", "llama3")
name: null

# model flavor/size (e.g., "0.6B", "8B")
flavor: null

# attention type (e.g., "sdpa", "flex", "varlen")
attn_type: sdpa

# attention mask type (e.g., "causal", "block_causal")
attn_mask_type: causal
3 changes: 3 additions & 0 deletions verl/trainer/config/optim/fsdp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ num_cycles: 0.5
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant

# Whether the LR schedule uses 0-indexed steps
zero_indexed_step: true

# deprecated
warmup_style: null

Expand Down
35 changes: 35 additions & 0 deletions verl/trainer/config/optim/torchtitan.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Optimizer configuration for the torchtitan training engine.
# Target class for this configuration
_target_: verl.workers.config.TorchtitanOptimizerConfig

# Optimizer name (e.g. AdamW)
name: AdamW

# Peak learning rate
lr: 1e-3

# LR warmup duration as a fraction of total training steps
# (presumably used to derive warmup steps when lr_warmup_steps is -1 — TODO confirm)
lr_warmup_steps_ratio: 0.0

# Total training steps; -1 means the value is filled in at runtime
# (presumably by the trainer — verify against caller)
total_training_steps: -1

# Weight decay
weight_decay: 0.01

# Explicit LR warmup steps; -1 defers to lr_warmup_steps_ratio
# (NOTE(review): precedence between the two fields not visible here — confirm)
lr_warmup_steps: -1

# Betas for Adam optimizer
betas: [0.9, 0.999]

# Gradient-norm clipping threshold
clip_grad: 1.0

# Epsilon for Adam optimizer
eps: 1e-8

# LR decay type after warmup: "linear", "sqrt", or "cosine"
decay_type: linear

# Minimum LR as a fraction of the peak LR (floor of the decay schedule)
min_lr_factor: 0.0
16 changes: 11 additions & 5 deletions verl/trainer/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,16 +238,22 @@ def _get_batch_seqlens(self, data):
batch_seqlens: torch.Tensor = data["attention_mask"].sum(dim=-1)
batch_seqlens = batch_seqlens.to(self.device_name) # (global_bsz // dp)

dp_group = self.engine.get_data_parallel_group()
dp_size = self.engine.get_data_parallel_size()

if dp_size == 1 or dp_group is None:
return batch_seqlens.tolist()

output_tensor = torch.empty(
(batch_seqlens.shape[0] * self.engine.get_data_parallel_size(),),
(batch_seqlens.shape[0] * dp_size,),
dtype=batch_seqlens.dtype,
device=self.device_name,
) # (global_bsz,)

torch.distributed.all_gather_into_tensor(
output_tensor=output_tensor,
input_tensor=batch_seqlens,
group=self.engine.get_data_parallel_group(),
group=dp_group,
)

batch_seqlens = output_tensor.tolist()
Expand Down Expand Up @@ -372,9 +378,9 @@ def fit(self):
if self.engine.is_mp_src_rank_with_outputs():
val_loss = torch.mean(torch.tensor(val_losses, device=self.device_name))
# average over data parallel group
torch.distributed.all_reduce(
val_loss, op=torch.distributed.ReduceOp.AVG, group=self.engine.get_data_parallel_group()
)
dp_group = self.engine.get_data_parallel_group()
if dp_group is not None:
torch.distributed.all_reduce(val_loss, op=torch.distributed.ReduceOp.AVG, group=dp_group)

if is_logging:
metric = {"val/loss": val_loss.detach().item()}
Expand Down
2 changes: 1 addition & 1 deletion verl/utils/seqlen_balancing.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def rearrange_micro_batches(
if min_num_micro_batch is not None:
# used to support pp
num_micro_batches = max(min_num_micro_batch, num_micro_batches)
if dist.is_initialized() and same_micro_num_in_dp:
if dist.is_initialized() and same_micro_num_in_dp and dp_group is not None:
num_micro_batches = torch.tensor([num_micro_batches], device=get_device_name())
dist.all_reduce(num_micro_batches, op=dist.ReduceOp.MAX, group=dp_group)
num_micro_batches = num_micro_batches.cpu().item()
Expand Down
6 changes: 6 additions & 0 deletions verl/utils/torch_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ def get_cosine_schedule_with_warmup(
num_cycles: float = 0.5,
last_epoch: int = -1,
init_lr_ratio: float = None,
zero_indexed_step: bool = True,
):
"""
Create a schedule with a learning rate that decreases following the values of the cosine function between the
Expand All @@ -731,6 +732,9 @@ def get_cosine_schedule_with_warmup(
The index of the last epoch when resuming training.
init_lr_ratio (:obj:`float`, `optional`, defaults to None):
The initial lr ratio w.r.t the maximum.
zero_indexed_step (:obj:`bool`, `optional`, defaults to True):
Whether the LR schedule uses 0-indexed steps. If True (default), step counting starts at 0.
If False (used by torchtitan), step counting starts at 1.
Return:
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""
Expand All @@ -743,6 +747,8 @@ def get_cosine_schedule_with_warmup(
assert init_lr_ratio >= 0 and init_lr_ratio <= 1.0

def lr_lambda(current_step):
if not zero_indexed_step:
current_step += 1
if current_step < num_warmup_steps:
return init_lr_ratio + (1.0 - init_lr_ratio) * (float(current_step) / float(max(1, num_warmup_steps)))
progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
Expand Down
Loading
Loading