Skip to content

Commit 29c11ba

Browse files
authored
[megatron] support qwen3.5 models for megatron, bump mbridge + megatron-core to latest (NovaSky-AI#1425)
GPU CI: https://github.com/NovaSky-AI/SkyRL/actions/runs/23869520430 Megatron GPU CI: https://github.com/NovaSky-AI/SkyRL/actions/runs/23869278330 Megatron GPU CI #2: https://github.com/NovaSky-AI/SkyRL/actions/runs/24045414612 megatron gpu CI #3: https://github.com/NovaSky-AI/SkyRL/actions/runs/24054807024 WandB run for Qwen3.5-0.8B: https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k_megatron/runs/5cm9tg0j <img width="555" height="625" alt="image" src="https://github.com/user-attachments/assets/d3867343-6bc7-49a3-9d29-6c62f20381b3" /> <!-- devin-review-badge-begin --> --- <a href="https://app.devin.ai/review/novasky-ai/skyrl/pull/1425" target="_blank"> <picture> <source media="(prefers-color-scheme: dark)" srcset="https://static.devin.ai/assets/gh-open-in-devin-review-dark.svg?v=1"> <img src="https://static.devin.ai/assets/gh-open-in-devin-review-light.svg?v=1" alt="Open with Devin"> </picture> </a> <!-- devin-review-badge-end -->
1 parent 3b0a148 commit 29c11ba

3 files changed

Lines changed: 2370 additions & 4945 deletions

File tree

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
set -x
2+
3+
# Colocated GRPO training+generation for Qwen3.5-0.8B on GSM8K with Megatron.
4+
5+
# uv run examples/train/gsm8k/gsm8k_dataset.py --output_dir $HOME/data/gsm8k
6+
# export WANDB_API_KEY=<your_key_here>
7+
# bash examples/train/megatron/run_megatron_qwen3.5.sh
8+
9+
DATA_DIR="$HOME/data/gsm8k"
10+
LOGGER="wandb" # change to "console" to print to stdout
11+
MODEL_NAME="Qwen/Qwen3.5-0.8B"
12+
13+
INFERENCE_BACKEND="vllm" # currently only vllm is supported for megatron
14+
15+
NUM_NODES=1
16+
NUM_GPUS=4
17+
18+
MEGATRON_TP=1
19+
MEGATRON_PP=1
20+
MEGATRON_CP=1
21+
22+
NUM_INFERENCE_ENGINES=1
23+
INFERENCE_ENGINE_TP=4
24+
25+
# Qwen3.5 flags
26+
USE_SAMPLE_PACKING=false # sample packing is not yet supported for GDN layers in megatron - see: https://github.com/NVIDIA/Megatron-LM/pull/2644
27+
28+
uv run --isolated --extra megatron -m skyrl.train.entrypoints.main_base \
29+
data.train_data="['$DATA_DIR/train.parquet']" \
30+
data.val_data="['$DATA_DIR/validation.parquet']" \
31+
trainer.algorithm.advantage_estimator="grpo" \
32+
trainer.policy.model.path=$MODEL_NAME \
33+
trainer.placement.colocate_all=true \
34+
trainer.strategy=megatron \
35+
trainer.placement.policy_num_nodes=$NUM_NODES \
36+
trainer.placement.policy_num_gpus_per_node=$NUM_GPUS \
37+
generator.inference_engine.num_engines=$NUM_INFERENCE_ENGINES \
38+
generator.inference_engine.tensor_parallel_size=$INFERENCE_ENGINE_TP \
39+
trainer.policy.megatron_config.tensor_model_parallel_size=$MEGATRON_TP \
40+
trainer.policy.megatron_config.pipeline_model_parallel_size=$MEGATRON_PP \
41+
trainer.policy.megatron_config.context_parallel_size=$MEGATRON_CP \
42+
trainer.use_sample_packing=$USE_SAMPLE_PACKING \
43+
trainer.epochs=20 \
44+
trainer.eval_batch_size=1024 \
45+
trainer.eval_before_train=false \
46+
trainer.eval_interval=5 \
47+
trainer.update_epochs_per_batch=1 \
48+
trainer.train_batch_size=128 \
49+
trainer.policy_mini_batch_size=64 \
50+
trainer.micro_forward_batch_size_per_gpu=4 \
51+
trainer.micro_train_batch_size_per_gpu=4 \
52+
trainer.ckpt_interval=10 \
53+
trainer.max_prompt_length=512 \
54+
generator.sampling_params.max_generate_length=1024 \
55+
trainer.policy.optimizer_config.lr=1.0e-6 \
56+
trainer.algorithm.use_kl_loss=false \
57+
generator.inference_engine.backend=$INFERENCE_BACKEND \
58+
generator.inference_engine.run_engines_locally=true \
59+
generator.inference_engine.weight_sync_backend=nccl \
60+
generator.inference_engine.async_engine=true \
61+
generator.batched=true \
62+
environment.env_class=gsm8k \
63+
generator.n_samples_per_prompt=5 \
64+
generator.inference_engine.gpu_memory_utilization=0.6 \
65+
trainer.logger="$LOGGER" \
66+
trainer.project_name="gsm8k_megatron" \
67+
trainer.run_name="gsm8k_megatron_tp${MEGATRON_TP}_pp${MEGATRON_PP}_cp${MEGATRON_CP}_qwen3.5-0.8b" \
68+
trainer.resume_mode=null \
69+
trainer.ckpt_path="$HOME/ckpts/gsm8k_megatron_ckpt" \
70+
$@

pyproject.toml

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,10 @@ megatron = [
124124
"torch==2.10.0; sys_platform == 'linux'",
125125
"flashinfer-python==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
126126
"torchvision; sys_platform == 'linux'",
127-
"megatron-bridge==0.3.1; sys_platform == 'linux'",
128-
"megatron-core==0.16.1; sys_platform == 'linux'",
127+
# megatron-bridge requires Python 3.12+; pin megatron-core to the same
128+
# constraint so both packages are consistently available (or absent).
129+
"megatron-bridge; sys_platform == 'linux' and python_version >= '3.12'",
130+
"megatron-core; sys_platform == 'linux' and python_version >= '3.12'",
129131
"flashinfer-jit-cache==0.6.6; sys_platform == 'linux' and platform_machine == 'x86_64'",
130132
"nvidia-modelopt; sys_platform == 'linux'",
131133
]
@@ -215,8 +217,8 @@ override-dependencies = [
215217
"mamba-ssm; sys_platform == 'never'",
216218
"causal-conv1d; sys_platform == 'never'",
217219
"transformer-engine[pytorch]==2.10.0; sys_platform == 'linux'",
218-
"megatron-core==0.16.1; sys_platform == 'linux'",
219220
"transformers>=5.0.0,<=5.3.0; sys_platform == 'linux'",
221+
"megatron-core>=0.16.0; sys_platform == 'linux'",
220222
"ml_dtypes>=0.5.0; sys_platform == 'linux'",
221223
]
222224

@@ -261,6 +263,9 @@ torchvision = [
261263
{ index = "pytorch-cpu", marker = "sys_platform == 'darwin'" },
262264
]
263265
harbor = { git = "https://github.com/laude-institute/harbor", rev = "8c040e1bb010201fd3c75bee3dede2407b9f57cd" }
266+
megatron-bridge = {git = "https://github.com/NVIDIA-NeMo/Megatron-Bridge", rev = "420a7da37afea5eb4e8d3899d540c830b9c4cda2", marker = "sys_platform == 'linux'"}
267+
# megatron-core dev branch: https://github.com/NVIDIA/Megatron-LM/tree/dev latest as of 4/1/26
268+
megatron-core = {git = "https://github.com/NVIDIA/Megatron-LM", rev = "4ef64ebc468cd3da41a22d46a2db37163694e8e2", marker = "sys_platform == 'linux'"}
264269

265270
[tool.black]
266271
line-length = 120

0 commit comments

Comments
 (0)