Commit 14413cf

Add support for GLM5 (#1599)
1 parent 834cf80

File tree

21 files changed: +2120 −34 lines

README.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 1. **High-Performance Training**: Supports efficient training in various modes by connecting Megatron with SGLang;
 2. **Flexible Data Generation**: Enables arbitrary training data generation workflows through custom data generation interfaces and server-based engines.
 
-slime is the RL-framework behind [GLM-4.7](https://z.ai/blog/glm-4.7), [GLM-4.6](https://z.ai/blog/glm-4.6), [GLM-4.5](https://z.ai/blog/glm-4.5) and apart from models from Z.ai, we also supports the following models:
+slime is the RL-framework behind [GLM-5](https://z.ai/blog/glm-5), [GLM-4.7](https://z.ai/blog/glm-4.7), [GLM-4.6](https://z.ai/blog/glm-4.6), [GLM-4.5](https://z.ai/blog/glm-4.5) and apart from models from Z.ai, we also supports the following models:
 - Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series;
 - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1);
 - Llama 3.

README_zh.md

Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@
 1. **High-Performance Training**: Efficient training in various modes by connecting Megatron with SGLang;
 2. **Flexible Data Generation**: Arbitrary training-data generation workflows through custom data generation interfaces and server-based engines.
 
-slime is the RL training framework behind [GLM-4.7](https://z.ai/blog/glm-4.7), [GLM-4.6](https://z.ai/blog/glm-4.6), and [GLM-4.5](https://z.ai/blog/glm-4.5); beyond these, slime also supports:
+slime is the RL training framework behind [GLM-5](https://z.ai/blog/glm-5), [GLM-4.7](https://z.ai/blog/glm-4.7), [GLM-4.6](https://z.ai/blog/glm-4.6), and [GLM-4.5](https://z.ai/blog/glm-4.5); beyond these, slime also supports:
 - Qwen3 series (Qwen3Next, Qwen3MoE, Qwen3), Qwen2.5 series;
 - DeepSeek V3 series (DeepSeek V3, V3.1, DeepSeek R1);
 - Llama 3.

scripts/models/glm5-744B-A40B.sh

Lines changed: 55 additions & 0 deletions (new file)

MOE_ROUTED_EXPERTS=256
MOE_ACTIVE_ROUTED_EXPERTS=8
MOE_SHARED_EXPERTS=1

NHIDDEN=6144
MOE_FFN_HIDDEN=2048
MOE_SHARED_EXPERT_INTERMEDIATE_SIZE=$(($MOE_FFN_HIDDEN * $MOE_SHARED_EXPERTS))
FFN_HIDDEN=12288
N_DENSE_LAYERS=3
N_MOE_LAYERS=75
NHEADS=64

MODEL_ARGS=(
   --spec "slime_plugins.models.glm5.glm5" "get_glm5_spec"
   # quoted so the bracket pattern is never subject to pathname expansion
   --moe-layer-freq "[0]*$N_DENSE_LAYERS+[1]*$N_MOE_LAYERS"
   --num-experts $MOE_ROUTED_EXPERTS
   --moe-shared-expert-intermediate-size $MOE_SHARED_EXPERT_INTERMEDIATE_SIZE
   --moe-router-topk $MOE_ACTIVE_ROUTED_EXPERTS
   --moe-grouped-gemm
   --moe-permute-fusion
   --moe-ffn-hidden-size $MOE_FFN_HIDDEN
   --moe-router-score-function sigmoid
   --moe-router-pre-softmax
   --moe-router-enable-expert-bias
   --moe-router-bias-update-rate 0
   --moe-router-load-balancing-type seq_aux_loss
   --moe-router-topk-scaling-factor 2.5
   --moe-aux-loss-coeff 0
   --moe-router-dtype fp32
   --make-vocab-size-divisible-by 16
   --num-layers $((N_DENSE_LAYERS + N_MOE_LAYERS))
   --hidden-size $NHIDDEN
   --ffn-hidden-size $FFN_HIDDEN
   --num-attention-heads $NHEADS
   --disable-bias-linear
   --swiglu
   --untie-embeddings-and-output-weights
   --position-embedding-type rope
   --no-position-embedding
   --normalization RMSNorm
   --qk-layernorm
   --multi-latent-attention
   --q-lora-rank 2048
   --kv-lora-rank 512
   --qk-head-dim 192
   --v-head-dim 256
   --kv-channels 192
   --qk-pos-emb-head-dim 64
   --vocab-size 154880
   --rotary-base 1000000
   --enable-experimental

   # slime specific args
   --allgather-cp
)
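The per-layer MoE pattern and derived sizes in the config above can be sanity-checked with a short sketch. This is an illustrative snippet, not part of slime; the variable names mirror the shell config, and the `[0]*N+[1]*M` expansion follows the `--moe-layer-freq` pattern syntax:

```python
# Mirror of the shell arithmetic above; names copied from the config.
N_DENSE_LAYERS = 3
N_MOE_LAYERS = 75
MOE_FFN_HIDDEN = 2048
MOE_SHARED_EXPERTS = 1

# --moe-layer-freq "[0]*3+[1]*75" expands to a per-layer mask:
# 0 = dense FFN layer, 1 = MoE layer.
moe_layer_freq = [0] * N_DENSE_LAYERS + [1] * N_MOE_LAYERS

num_layers = len(moe_layer_freq)                   # value passed to --num-layers
shared_size = MOE_FFN_HIDDEN * MOE_SHARED_EXPERTS  # --moe-shared-expert-intermediate-size

print(num_layers, sum(moe_layer_freq), shared_size)  # 78 75 2048
```

So the 78-layer stack is 3 dense layers followed by 75 MoE layers, matching `--num-layers $((N_DENSE_LAYERS + N_MOE_LAYERS))`.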

scripts/run-glm5-744B-A40B.sh

Lines changed: 184 additions & 0 deletions (new file)

#!/bin/bash

# for rerunning the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# prevent python from buffering stdout/stderr so ray logs stream promptly
export PYTHONUNBUFFERED=1

NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/glm5-744B-A40B.sh"

CKPT_ARGS=(
   --hf-checkpoint $BASE_DIR/GLM-5
   --ref-load $BASE_DIR/GLM-5_torch_dist/
   --load $BASE_DIR/GLM-5_slime/
   --save $BASE_DIR/GLM-5_slime/
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data $BASE_DIR/dapo-math-17k/dapo-math-17k.jsonl
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle

   --rm-type deepscaler

   --num-rollout 3000
   --rollout-batch-size 8
   --n-samples-per-prompt 8
   --rollout-max-response-len 32768
   --rollout-temperature 1

   --global-batch-size 64
)

PERF_ARGS=(
   --tensor-model-parallel-size 4
   --sequence-parallel
   --pipeline-model-parallel-size 4
   --decoder-last-pipeline-num-layers 18
   --expert-model-parallel-size 32
   --expert-tensor-parallel-size 1
   --context-parallel-size 2

   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1

   --use-dynamic-batch-size
   --max-tokens-per-gpu 16384
   --data-pad-size-multiplier 4096
   --log-probs-chunk-size 1024
)

GRPO_ARGS=(
   --advantage-estimator grpo
   # --use-kl-loss
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --kl-coef 0.00
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6

   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98

   --optimizer-cpu-offload
   --overlap-cpu-optimizer-d2h-h2d
   --use-precision-aware-optimizer
)

WANDB_ARGS=(
   # --use-wandb
   # --wandb-project slime-dev
   # --wandb-group glm5-test
   # --wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 64
   --sglang-mem-fraction-static 0.70
   --sglang-enable-dp-attention
   --sglang-ep-size 64
   --sglang-dp-size 64
   --sglang-moe-dense-tp-size 1
   --sglang-enable-dp-lm-head

   --sglang-moe-a2a-backend deepep
   --sglang-deepep-mode auto

   --prefill-num-servers 1

   # mtp
   --sglang-speculative-algorithm EAGLE
   --sglang-speculative-num-steps 3
   --sglang-speculative-eagle-topk 1
   --sglang-speculative-num-draft-tokens 4

   # dsa
   --sglang-page-size 64
   --sglang-nsa-decode-backend flashmla_sparse
   --sglang-nsa-prefill-backend flashmla_sparse
   --sglang-attention-backend nsa
   --sglang-cuda-graph-max-bs 8

   --sglang-max-running-requests 512
   --sglang-chunked-prefill-size 131072

   --sglang-watchdog-timeout 3600
)

MISC_ARGS=(
   # default dropout in megatron is 0.1
   --attention-dropout 0.0
   --hidden-dropout 0.0
   # should be good for model performance
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   # comment this out when using a model with MLA
   --attention-backend flash

   # use deepep for megatron
   --moe-enable-deepep
   --moe-token-dispatcher-type flex
)

# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"no_proxy\": \"${no_proxy}\",
    \"MASTER_ADDR\": \"${MASTER_ADDR}\",
    \"INDEXER_ROPE_NEOX_STYLE\": \"0\",
    \"SGLANG_DEEPEP_NUM_MAX_DISPATCH_TOKENS_PER_RANK\": \"32\",
    \"NVSHMEM_DISABLE_NCCL\": \"1\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 32 \
   --actor-num-gpus-per-node 8 \
   --colocate \
   --update-weight-buffer-size $(( 1024 * 1024 * 1024 * 2 )) \
   ${MODEL_ARGS[@]} \
   ${CKPT_ARGS[@]} \
   ${ROLLOUT_ARGS[@]} \
   ${OPTIMIZER_ARGS[@]} \
   ${GRPO_ARGS[@]} \
   ${WANDB_ARGS[@]} \
   ${PERF_ARGS[@]} \
   ${EVAL_ARGS[@]} \
   ${SGLANG_ARGS[@]} \
   ${MISC_ARGS[@]}
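The NVLink probe near the top of the run script counts `NV<k>` link entries in the `nvidia-smi topo -m` matrix. A pure-Python equivalent of that grep pipeline, shown here only to illustrate the detection logic (the sample topology string is made up):

```python
import re

def count_nvlink_refs(topo_output: str) -> int:
    # Equivalent of: nvidia-smi topo -m | grep -o 'NV[0-9][0-9]*' | wc -l
    return len(re.findall(r"NV[0-9][0-9]*", topo_output))

# Hypothetical two-GPU topology matrix, for illustration only.
sample = (
    "      GPU0   GPU1\n"
    "GPU0   X     NV18\n"
    "GPU1   NV18   X\n"
)
has_nvlink = 1 if count_nvlink_refs(sample) > 0 else 0
print(has_nvlink)  # the script exports this flag as NCCL_NVLS_ENABLE
```

On a PCIe-only box every cell reads `X`, `PIX`, `PHB`, etc., no `NV<k>` entry matches, and `HAS_NVLINK` falls back to 0.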

slime/backends/megatron_utils/actor.py

Lines changed: 6 additions & 8 deletions

@@ -223,14 +223,12 @@ def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
                 continue
             rollout_data[key] = [
                 torch.tensor(
-                    (
-                        slice_log_prob_with_cp(
-                            log_prob,
-                            total_length,
-                            response_length,
-                            self.args.qkv_format,
-                            rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
-                        )
+                    slice_log_prob_with_cp(
+                        log_prob,
+                        total_length,
+                        response_length,
+                        self.args.qkv_format,
+                        rollout_data["max_seq_lens"][i] if self.args.qkv_format == "bshd" else None,
                     ),
                     device=torch.cuda.current_device(),
                     dtype=torch.float32,

slime/backends/megatron_utils/data.py

Lines changed: 45 additions & 13 deletions

@@ -27,6 +27,7 @@ def get_batch(
     keys: Sequence[str],
     pad_multiplier: int = 128,
     qkv_format: str = "thd",
+    allgather_cp: bool = False,
 ) -> dict[str, torch.Tensor | PackedSeqParams | list[torch.Tensor] | None]:
     """
     Generate a CP-ready micro-batch with packed sequence parameters.

@@ -64,31 +65,53 @@ def get_batch(
     batch["unconcat_tokens"] = tokens
 
     cp_size = mpu.get_context_parallel_world_size()
+    cp_rank = mpu.get_context_parallel_rank()
 
     if qkv_format == "bshd":
         max_seqlen = batch["max_seq_lens"][0]
         assert max([t.size(0) for t in tokens]) <= max_seqlen
         tokens = [slice_with_cp(t, pad_token_id, qkv_format, max_seqlen) for t in tokens]
         tokens = torch.stack(tokens)
+
     elif qkv_format == "thd":
-        tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens]
+        if allgather_cp:
+            # DSA mode: concatenate all sequences first, then slice once with CP.
+            # We also pad the *global* concatenated stream to make per-rank chunks equal.
+            cu_seqlens_list: list[int] = [0]
+            for t in tokens:
+                cu_seqlens_list.append(cu_seqlens_list[-1] + t.size(0))
+
+            tokens = torch.cat(tokens, dim=0)
+
+            # Pad global stream so (1) divisible by cp_size (equal chunks),
+            # (2) divisible by pad_size (reduce fragmentation).
+            global_pad_size = cp_size * pad_size
+            pad = (global_pad_size - tokens.size(0) % global_pad_size) % global_pad_size
+            if pad != 0:
+                tokens = F.pad(tokens, (0, pad), value=pad_token_id)
+                cu_seqlens_list.append(cu_seqlens_list[-1] + pad)
+
+            cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int, device=torch.cuda.current_device())
+            tokens = tokens.chunk(cp_size, dim=0)[cp_rank]
+        else:
+            tokens = [slice_with_cp(t, pad_token_id, qkv_format) for t in tokens]
 
-        cu_seqlens = [0]
-        for t in tokens:
-            cu_seqlens.append(cu_seqlens[-1] + t.size(0))
+            cu_seqlens = [0]
+            for t in tokens:
+                cu_seqlens.append(cu_seqlens[-1] + t.size(0))
 
-        tokens = torch.cat(tokens)
+            tokens = torch.cat(tokens)
 
-        # Always pad to reduce memory fragmentation and maybe make the computation faster
-        pad = (pad_size - tokens.size(0) % pad_size) % pad_size
-        if pad != 0:
-            tokens = F.pad(tokens, (0, pad), value=pad_token_id)
-            cu_seqlens.append(cu_seqlens[-1] + pad)
+            # Always pad to reduce memory fragmentation and maybe make the computation faster
+            pad = (pad_size - tokens.size(0) % pad_size) % pad_size
+            if pad != 0:
+                tokens = F.pad(tokens, (0, pad), value=pad_token_id)
+                cu_seqlens.append(cu_seqlens[-1] + pad)
 
-        # thd requires the cu_seqlens to be of the origin length
-        cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size
-        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+            # thd requires the cu_seqlens to be of the origin length
+            cu_seqlens = torch.tensor(cu_seqlens, dtype=torch.int).cuda() * cp_size
 
+        max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
     packed_seq_params = PackedSeqParams(
         cu_seqlens_q=cu_seqlens,
         cu_seqlens_kv=cu_seqlens,

@@ -115,11 +138,20 @@ def get_batch(
         prompt_length = total_length - response_length
         # Align mask to token stream positions (prompt_length-1 left pad, 1 right pad)
         loss_mask = F.pad(loss_mask, (prompt_length - 1, 1), value=0)
+        if allgather_cp:
+            loss_masks.append(loss_mask)
+            continue
         loss_mask = slice_with_cp(loss_mask, 0, qkv_format, max_seqlen)
         loss_masks.append(loss_mask)
 
     if qkv_format == "bshd":
         loss_masks = torch.stack(loss_masks)
+    elif qkv_format == "thd" and allgather_cp:
+        # DSA: concatenate first (same as tokens), pad globally (same pad as above), then slice once.
+        loss_masks = torch.cat(loss_masks, dim=0)
+        if pad != 0:
+            loss_masks = F.pad(loss_masks, (0, pad), value=0)
+        loss_masks = loss_masks.chunk(cp_size, dim=0)[cp_rank].unsqueeze(0)
     elif qkv_format == "thd":
         loss_masks = torch.cat(loss_masks)
         loss_masks = F.pad(loss_masks, (0, pad), value=0).unsqueeze(0)
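The `allgather_cp` branch above packs every sequence into a single token stream, pads it so the stream splits into equal per-rank chunks, and records `cu_seqlens` boundaries before padding. A minimal pure-Python sketch of that packing arithmetic (the helper name `pack_for_allgather_cp` is invented for illustration; the real code does the same on CUDA tensors with `torch.cat`, `F.pad`, and `chunk`):

```python
def pack_for_allgather_cp(seqs, cp_size, pad_size, pad_id=0):
    """Concatenate sequences, pad globally, split into equal CP chunks."""
    # Record sequence boundaries before concatenation.
    cu_seqlens = [0]
    for s in seqs:
        cu_seqlens.append(cu_seqlens[-1] + len(s))
    stream = [tok for s in seqs for tok in s]

    # Pad so the stream length is divisible by cp_size (equal chunks)
    # and by pad_size (less fragmentation), as in the diff above.
    g = cp_size * pad_size
    pad = (g - len(stream) % g) % g
    if pad:
        stream += [pad_id] * pad
        cu_seqlens.append(cu_seqlens[-1] + pad)

    chunk_len = len(stream) // cp_size
    chunks = [stream[i * chunk_len:(i + 1) * chunk_len] for i in range(cp_size)]
    return chunks, cu_seqlens

chunks, cu = pack_for_allgather_cp([[1] * 5, [2] * 7], cp_size=2, pad_size=4)
print(cu)                        # [0, 5, 12, 16]
print([len(c) for c in chunks])  # [8, 8]
```

Because the padding is appended as one extra pseudo-sequence in `cu_seqlens`, attention kernels still see correct per-sequence boundaries, and each CP rank receives an identical-length chunk of the global stream.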
