Commit cf33f0e: add fp8 training examples (#821)
1 parent 21c9a40

2 files changed: +208 -0
examples/low_precision/README.md

Lines changed: 45 additions & 0 deletions
## FP8 training examples

This example demonstrates FP8 training and FP8 inference. Running both in FP8 gives higher inference throughput and a smaller training-inference mismatch, which makes training more stable.
### Files

* `run-qwen3-4b-fp8.sh`: example launch script for Qwen3-4B in FP8.
### Quick Start

1. [optional] Convert your HuggingFace weights to FP8. You can use `tools/convert_hf_to_fp8`, or write an FP8-format model config directly.
2. Start FP8 training:

```bash
cd slime
bash examples/fp8/run-qwen3-4b-fp8.sh
```

The command above launches FP8 training. By slime's design, if the model under `--hf-checkpoint` is FP8, FP8 quantization is used automatically during weight updates.
3. Use the saved checkpoint for evaluation.

Note that TransformerEngine does not save FP8-quantized weights as such; the saved torch_dist checkpoint remains in the original precision (usually bf16). To evaluate under FP8, convert the checkpoint from `torch_dist` to HuggingFace format, then convert that to an FP8 HuggingFace model.
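Conceptually, converting HuggingFace weights to FP8 amounts to choosing a scale per tensor and rounding onto the e4m3 grid (e4m3's largest finite value is 448). Below is a minimal pure-Python sketch of per-tensor scaling; it is an illustration only, not the actual `tools/convert_hf_to_fp8` implementation, and the helper names and toy weights are made up:

```python
import math

E4M3_MAX = 448.0  # largest finite value representable in FP8 e4m3

def to_e4m3(x):
    """Round x to the nearest e4m3 value (normal range only; a simplification)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    step = 2.0 ** (math.floor(math.log2(mag)) - 3)  # 3 mantissa bits
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)

def quantize_tensor(weights):
    """Per-tensor scaling: map the tensor's amax onto e4m3's max, then round."""
    scale = max(abs(w) for w in weights) / E4M3_MAX or 1.0
    return [to_e4m3(w / scale) for w in weights], scale

codes, scale = quantize_tensor([0.02, -1.3, 0.5, 0.007])  # toy bf16 weights
restored = [c * scale for c in codes]                     # dequantized view
```

A real converter would apply this per safetensors shard and store the scales alongside the FP8 codes.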
### Quick Explanation

Here's a quick explanation of how FP8 training is currently implemented in slime:

1. Initialization: if an FP8 recipe is enabled, layers are built inside an FP8 context.

2. Training: weights and activations are quantized online to FP8, and cuBLAS FP8 GEMM kernels are called for the GEMMs in the forward and backward passes.

3. Update weights: during RL weight updates, the training engine exports its model weights. These are dequantized from FP8 to bf16, but since the config under `--hf-checkpoint` is FP8, slime quantizes the bf16 weights back to FP8.

4. Save checkpoint: similarly, checkpoints saved from the training engine are dequantized back to bf16 and written as `torch_dist` checkpoints.
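The online quantization in step 2 can be sketched in the same spirit as `--fp8-recipe blockwise`: each block of a tensor gets its own scale (amax / 448 for e4m3). This is a toy pure-Python illustration with 1-D blocks of 4 values; real recipes use far larger tiles (e.g. 128-wide), and the actual math runs inside cuBLAS FP8 GEMM kernels:

```python
import math

E4M3_MAX = 448.0  # largest finite e4m3 magnitude

def cast_e4m3(x):
    """Round x to the nearest e4m3 value (normal range only; a simplification)."""
    if x == 0.0:
        return 0.0
    mag = min(abs(x), E4M3_MAX)
    step = 2.0 ** (math.floor(math.log2(mag)) - 3)  # 3 mantissa bits: 8 steps per binade
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)

def quantize_blockwise(w, block=4):
    """Per-block scaling: each block stores its FP8 codes plus one scale."""
    out = []
    for i in range(0, len(w), block):
        blk = w[i:i + block]
        scale = max(abs(v) for v in blk) / E4M3_MAX or 1.0
        out.append(([cast_e4m3(v / scale) for v in blk], scale))
    return out

def dequantize_blockwise(qblocks):
    return [c * s for codes, s in qblocks for c in codes]

w = [0.013, -0.9, 2.5, 0.004, 31.0, -0.25, 0.75, 5.0]
w_hat = dequantize_blockwise(quantize_blockwise(w))
# worst relative error stays within e4m3's ~2^-4 rounding bound
err = max(abs(a - b) / abs(a) for a, b in zip(w, w_hat))
```

Per-block scales keep outliers (like the 31.0 above) from destroying the precision of small values elsewhere in the tensor, which is the motivation for the blockwise recipe over per-tensor scaling.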
### TODO

FP8 support is far from complete and still has limitations, for example:

- FP8 weights (`--fp8-param-gather`) save memory, but they currently must be used with TransformerEngine's FusedAdam, which conflicts with the Adam CPU-offload technique commonly used in Megatron-LM.

The slime team will continue to collaborate with the NVIDIA team to contribute a more complete FP8 training infrastructure to the community.
examples/low_precision/run-qwen3-4b-fp8.sh

Lines changed: 163 additions & 0 deletions
```bash
#!/bin/bash

# for rerunning the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# prevent python from buffering stdout/stderr under ray
export PYTHONUNBUFFERED=1

NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/../../scripts/models/qwen3-4B.sh"

CKPT_ARGS=(
    --hf-checkpoint /root/Qwen3-4B-FP8
    --ref-load /root/Qwen3-4B_torch_dist
    --load /root/qwen3-4b_cp8_fp8
    --save /root/rl-model/qwen3-4b_cp8_fp8
    --save-interval 20
)

ROLLOUT_ARGS=(
    --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
    --input-key prompt
    --label-key label
    --apply-chat-template
    --rollout-shuffle
    --rm-type deepscaler
    --num-rollout 3000
    --rollout-batch-size 32
    --n-samples-per-prompt 8
    --rollout-max-response-len 8192
    --rollout-temperature 0.8

    --global-batch-size 256
    --balance-data
)

EVAL_ARGS=(
    --eval-interval 20
    --eval-prompt-data aime /root/data/aime-2024.jsonl
    --n-samples-per-eval-prompt 16
    --eval-max-response-len 16384
    --eval-top-p 0.7
)

PERF_ARGS=(
    --tensor-model-parallel-size 2
    --sequence-parallel
    --pipeline-model-parallel-size 1
    --context-parallel-size 1
    --expert-model-parallel-size 1
    --expert-tensor-parallel-size 1

    --recompute-granularity full
    --recompute-method uniform
    --recompute-num-layers 1

    # --micro-batch-size 1
    --use-dynamic-batch-size
    --max-tokens-per-gpu 9216
)

GRPO_ARGS=(
    --advantage-estimator grpo
    --use-kl-loss
    --kl-loss-coef 0.00
    --kl-loss-type low_var_kl
    --entropy-coef 0.00
    --eps-clip 0.2
    --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
    --optimizer adam
    --lr 1e-6
    --lr-decay-style constant
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.98
)

WANDB_ARGS=(
    # --use-wandb
    # --wandb-project slime-dev
    # --wandb-group qwen3-4B-test
    # --wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
    --rollout-num-gpus-per-engine 2
    --sglang-mem-fraction-static 0.7
)

MISC_ARGS=(
    # default dropout in megatron is 0.1
    --attention-dropout 0.0
    --hidden-dropout 0.0
    # should be good for model performance
    --accumulate-allreduce-grads-in-fp32
    --attention-softmax-in-fp32
    # need to comment this out when using a model with MLA
    --attention-backend flash
)

PRECISE_ARGS=(
    --transformer-impl transformer_engine
    --bf16
    --fp8-format e4m3
    --fp8-recipe blockwise
    --fp8-param-gather
)

# launch the master node of ray in the container
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 8 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build the runtime environment JSON with proper variable substitution.
# Enable NVTE_FP8_BLOCK_SCALING_FP32_SCALES to use fp32 scales in FP8 training.
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/:${SCRIPT_DIR}\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"NVTE_FP8_BLOCK_SCALING_FP32_SCALES\": \"1\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
    --runtime-env-json="${RUNTIME_ENV_JSON}" \
    -- python3 train.py \
    --actor-num-nodes 1 \
    --actor-num-gpus-per-node 8 \
    --colocate \
    ${MODEL_ARGS[@]} \
    ${CKPT_ARGS[@]} \
    ${ROLLOUT_ARGS[@]} \
    ${OPTIMIZER_ARGS[@]} \
    ${GRPO_ARGS[@]} \
    ${WANDB_ARGS[@]} \
    ${PERF_ARGS[@]} \
    ${EVAL_ARGS[@]} \
    ${SGLANG_ARGS[@]} \
    ${MISC_ARGS[@]} \
    ${PRECISE_ARGS[@]}
```
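A side note on the script above: `RUNTIME_ENV_JSON` is assembled with hand-escaped quotes, which is easy to break when editing. A quoting-safe sketch of the same structure using `json.dumps` (the path and flag values below are stand-ins for `${SCRIPT_DIR}` and `${HAS_NVLINK}`):

```python
import json

script_dir = "/path/to/examples/low_precision"  # stand-in for ${SCRIPT_DIR}
has_nvlink = "1"                                # stand-in for ${HAS_NVLINK}

# Same structure as the RUNTIME_ENV_JSON heredoc, but json.dumps handles quoting
runtime_env = {
    "env_vars": {
        "PYTHONPATH": f"/root/Megatron-LM/:{script_dir}",
        "CUDA_DEVICE_MAX_CONNECTIONS": "1",
        "NCCL_NVLS_ENABLE": has_nvlink,
        "NVTE_FP8_BLOCK_SCALING_FP32_SCALES": "1",
    }
}
runtime_env_json = json.dumps(runtime_env)
print(runtime_env_json)
```

In the launch script this could be captured with command substitution (e.g. from a small hypothetical helper script) and passed to `ray job submit --runtime-env-json` unchanged.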
