
Commit 5b706fb

Authored by yefei12, chenyefei.cyf, GeLee-Q, Gao016, and yzlnew

add FP8 training and inference script for Qwen3-30B-A3B model (#845)

Co-authored-by: chenyefei.cyf <chenyefei.cyf@U-9V5T77LW-2356.local>
Co-authored-by: GeLee-Q <865038696@qq.com>
Co-authored-by: Gao016 <yngao016@163.com>
Co-authored-by: yzlnew <yzlnew@gmail.com>

Parent: cf33f0e

File tree

2 files changed: +189 additions, -1 deletion

examples/low_precision/README.md

Lines changed: 9 additions & 1 deletion
@@ -6,6 +6,8 @@ This is an example of FP8 training and FP8 inference. Under FP8 training and inf
 
 * `run-qwen3-4b-fp8.sh`: example launch script with Qwen3‑4B in FP8.
 
+* `run-qwen3-30b-a3b-fp8-two-nodes.sh`: example launch script for running Qwen3‑30B‑A3B in FP8 across two nodes.
+
 ### Quick Start
 
 1. [optional] Convert your HuggingFace weights to FP8 format. You can use `tools/convert_hf_to_fp8`, or directly write an FP8 format model config.
@@ -14,8 +16,14 @@ This is an example of FP8 training and FP8 inference. Under FP8 training and inf
 
 ```
 cd slime
-bash examples/fp8/run-qwen3-4b-fp8.sh
+
+# Qwen3‑4B FP8 training (single node)
+bash examples/low_precision/run-qwen3-4b-fp8.sh
+
+# Qwen3‑30B‑A3B FP8 training (two nodes)
+bash examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh
 ```
+
 Following the above command will launch FP8 training. According to slime's design, if the model under `--hf-checkpoint` is FP8, it will automatically use FP8 quantization in weight updates.
 
 3. Use the saved checkpoint for evaluation
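The conversion in step 1 produces the blockwise-scaled FP8 (E4M3) weights that the training script later selects with `--fp8-format e4m3 --fp8-recipe blockwise`: each weight block stores one FP32 scale so that the block fits the E4M3 range (largest finite value 448). A minimal NumPy sketch of the scaling idea only, with hypothetical helper names; the real `tools/convert_hf_to_fp8` also rounds values onto the 8-bit e4m3 grid and writes HuggingFace-format tensors:

```python
import numpy as np

E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


def quantize_blockwise(w: np.ndarray, block: int = 128):
    """One FP32 scale per `block` consecutive values (illustrative only).

    Returns (scaled values clipped to the E4M3 range, per-block scales).
    Real FP8 kernels additionally round onto the 8-bit e4m3 grid.
    """
    pad = (-len(w)) % block
    wp = np.pad(w, (0, pad)).reshape(-1, block)
    scales = np.abs(wp).max(axis=1, keepdims=True) / E4M3_MAX
    scales = np.where(scales == 0, 1.0, scales)  # avoid div-by-zero blocks
    q = np.clip(wp / scales, -E4M3_MAX, E4M3_MAX)
    return q, scales.ravel()


def dequantize_blockwise(q: np.ndarray, scales: np.ndarray, n=None):
    """Undo the per-block scaling; `n` trims the padding added above."""
    w = (q * scales[:, None]).ravel()
    return w if n is None else w[:n]


w = np.linspace(-1000.0, 1000.0, 999).astype(np.float32)
q, scales = quantize_blockwise(w)
print("max |q|:", float(np.abs(q).max()))  # stays within the E4M3 range
```

Because this sketch only rescales (it skips the 8-bit rounding), dequantization recovers the input up to float rounding; the precision loss in real FP8 training comes from the e4m3 grid, not from the scales.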
examples/low_precision/run-qwen3-30b-a3b-fp8-two-nodes.sh

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
#!/bin/bash

# for rerunning the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python
pkill -9 redis

set -ex

# prevent Ray from buffering stdout/stderr
export PYTHONUNBUFFERED=1

NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/../../scripts/models/qwen3-30B-A3B.sh"

# Base directory for checkpoints and related files (adjust if necessary)
BASE_DIR="/root"

CKPT_ARGS=(
    --hf-checkpoint "${BASE_DIR}/Qwen3-30B-A3B-FP8/"
    --ref-load "${BASE_DIR}/Qwen3-30B-A3B_torch_dist/"
    --load "${BASE_DIR}/Qwen3-30B-A3B_slime/"
    --save "${BASE_DIR}/Qwen3-30B-A3B_slime/"
    --save-interval 20
)

ROLLOUT_ARGS=(
    --prompt-data "${BASE_DIR}/dapo-math-17k.jsonl"
    --input-key prompt
    --label-key label
    --apply-chat-template
    --rollout-shuffle
    --rm-type math
    --num-rollout 200
    --rollout-batch-size 16
    --n-samples-per-prompt 8
    --rollout-max-response-len 8192
    --rollout-temperature 0.8

    --global-batch-size 128
    --balance-data
)

EVAL_ARGS=(
    --eval-interval 20
    --eval-prompt-data aime "${BASE_DIR}/aime-2024.jsonl"
    --n-samples-per-eval-prompt 16
    --eval-max-response-len 16384
    --eval-top-p 0.7
)

PERF_ARGS=(
    --tensor-model-parallel-size 1
    --sequence-parallel
    --pipeline-model-parallel-size 4
    --context-parallel-size 1
    --expert-model-parallel-size 4
    --expert-tensor-parallel-size 1

    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1

    # --micro-batch-size 1
    --use-dynamic-batch-size
    --max-tokens-per-gpu 20480

    # use DeepEP for Megatron
    --moe-enable-deepep
    --moe-token-dispatcher-type flex

    # fp8
    --transformer-impl transformer_engine
    --bf16
    --fp8-format e4m3
    --fp8-recipe blockwise
    # --fp8-param-gather
)

GRPO_ARGS=(
    --advantage-estimator grpo
    --use-kl-loss
    --kl-loss-coef 0.00
    --kl-loss-type low_var_kl
    --entropy-coef 0.00
    --eps-clip 0.2
    --eps-clip-high 0.28
    --use-tis
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-6
    --lr-decay-style constant
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98

    --optimizer-cpu-offload
    --overlap-cpu-optimizer-d2h-h2d
    --use-precision-aware-optimizer
)

WANDB_ARGS=(
    # --use-wandb
    # --wandb-project slime-dev
    # --wandb-group qwen3-30B-A3B-test
    # --wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 8
    --sglang-mem-fraction-static 0.6
    --sglang-cuda-graph-bs 1 2 4 8 $(seq 16 8 256)
    --sglang-expert-parallel-size 8
    --use-slime-router
    # --use-rollout-routing-replay
)

MISC_ARGS=(
    # default dropout in Megatron is 0.1
    --attention-dropout 0.0
    --hidden-dropout 0.0
    # should be good for model performance
    --accumulate-allreduce-grads-in-fp32
    --attention-softmax-in-fp32
    # need to comment this out when using a model with MLA
    --attention-backend flash
)

# Get Ray head node info automatically
ip=$(ps aux | grep dashboard | grep -oP '(?<=--node-ip-address=)[0-9\.]+' | head -1)
port=$(ps aux | grep dashboard | grep -oP '(?<=dashboard-port=)\d+' | head -1)
export HEAD_NODE_ADDRESS="$ip"
export DASHBOARD_PORT="$port"
echo "Detected Ray Head IP: $HEAD_NODE_ADDRESS, Port: $DASHBOARD_PORT"

export RAY_ADDRESS="http://${HEAD_NODE_ADDRESS}:${DASHBOARD_PORT}"

# Enable NVTE_FP8_BLOCK_SCALING_FP32_SCALES to use FP32 scales in FP8 training
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"NVTE_FP8_BLOCK_SCALING_FP32_SCALES\": \"1\",
    \"NCCL_TIMEOUT_MS\": \"36000000\"
  }
}"

ray job submit --address="${RAY_ADDRESS}" \
  --runtime-env-json="${RUNTIME_ENV_JSON}" \
  -- python3 train.py \
  --actor-num-nodes 2 \
  --actor-num-gpus-per-node 8 \
  --colocate \
  ${MODEL_ARGS[@]} \
  ${CKPT_ARGS[@]} \
  ${ROLLOUT_ARGS[@]} \
  ${OPTIMIZER_ARGS[@]} \
  ${GRPO_ARGS[@]} \
  ${WANDB_ARGS[@]} \
  ${PERF_ARGS[@]} \
  ${EVAL_ARGS[@]} \
  ${SGLANG_ARGS[@]} \
  ${MISC_ARGS[@]}
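The README notes that slime automatically switches to FP8 quantization in weight updates when the model under `--hf-checkpoint` (here `${BASE_DIR}/Qwen3-30B-A3B-FP8/`) is FP8. Whether a HuggingFace checkpoint is FP8 is typically declared in its `config.json`. A hedged Python sketch of such a check; the `quantization_config`/`quant_method` field names follow common FP8 checkpoints, and this is an illustration, not slime's actual detection code:

```python
import json
import pathlib
import tempfile


def is_fp8_checkpoint(ckpt_dir: str) -> bool:
    """Heuristic: does config.json declare an FP8 quantization_config?

    Field names follow common FP8 HuggingFace checkpoints; this is an
    illustrative check, not slime's actual detection logic.
    """
    cfg = json.loads((pathlib.Path(ckpt_dir) / "config.json").read_text())
    qc = cfg.get("quantization_config") or {}
    return qc.get("quant_method") == "fp8"


# Demo against a throwaway directory standing in for a real checkpoint.
with tempfile.TemporaryDirectory() as d:
    (pathlib.Path(d) / "config.json").write_text(json.dumps({
        "model_type": "qwen3_moe",
        "quantization_config": {"quant_method": "fp8",
                                "weight_block_size": [128, 128]},
    }))
    print(is_fp8_checkpoint(d))  # prints True
```

If the flag is absent (a plain BF16 checkpoint), the check returns False, which matches the README's behavior of FP8 weight updates being conditional on the `--hf-checkpoint` model format.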
