Skip to content

Commit de7a8b5

Browse files
committed
add back run-kimi-k2-Instruct.sh for #1344
1 parent ed7780e commit de7a8b5

File tree

1 file changed

+181
-0
lines changed

1 file changed

+181
-0
lines changed

scripts/run-kimi-k2-Instruct.sh

Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
#!/bin/bash
# Launch RL training for Kimi-K2-Instruct with slime (Megatron backend) and
# SGLang rollout engines, submitted to a local Ray cluster at 127.0.0.1:8265.
#
# Requires:
#   - BASE_DIR pointing at the checkpoint/data root (may also be set by the
#     sourced models/kimi-k2.sh — TODO confirm which side owns it)
#   - a running Ray head node serving the job API on port 8265

# For re-running the task: tear down leftovers from a previous run.
# Kept before `set -e` on purpose — pkill exits non-zero when nothing matches.
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# Prevent Python (and therefore Ray workers) from buffering stdout/stderr.
# FIX: the variable is PYTHONUNBUFFERED (any non-empty value); the previous
# name PYTHONBUFFERED is not recognized by CPython and had no effect.
export PYTHONUNBUFFERED=1

# Detect NVLink by counting NVx link entries in the GPU topology matrix;
# used below to toggle NCCL's NVLS collectives.
NVLINK_COUNT=$(nvidia-smi topo -m 2>/dev/null | grep -o 'NV[0-9][0-9]*' | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
# Provides MODEL_ARGS for the submit command below.
source "${SCRIPT_DIR}/models/kimi-k2.sh"

# Fail fast with a clear message instead of building broken paths.
: "${BASE_DIR:?BASE_DIR must be set (checkpoint/data root)}"

CKPT_ARGS=(
   --hf-checkpoint "$BASE_DIR/Kimi-K2-Instruct/"
   # --hf-checkpoint "$BASE_DIR/Kimi-K2-bf16/"
   --ref-load "$BASE_DIR/Kimi-K2_torch_dist/"
   --load "$BASE_DIR/Kimi-K2_slime/"
   --save "$BASE_DIR/Kimi-K2_slime/"
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data "$BASE_DIR/dapo-math-17k/dapo-math-17k.jsonl"
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle

   --rm-type math

   --num-rollout 100
   --rollout-batch-size 128
   --n-samples-per-prompt 8
   --rollout-max-response-len 32768
   --rollout-temperature 1

   # --global-batch-size 1024

   --over-sampling-batch-size 256
   --dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std

   --num-steps-per-rollout 4
   --balance-data
)

EVAL_ARGS=(
   --eval-interval 20
   --eval-prompt-data aime "$BASE_DIR/rl_data/aime-2024.jsonl"
   --n-samples-per-eval-prompt 8
   --eval-max-response-len 32768
   --eval-top-p 1
)

PERF_ARGS=(
   --tensor-model-parallel-size 8
   --sequence-parallel
   --pipeline-model-parallel-size 8
   --context-parallel-size 4
   --expert-model-parallel-size 32
   --expert-tensor-parallel-size 1
   --decoder-last-pipeline-num-layers 5

   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1

   --use-dynamic-batch-size
   --max-tokens-per-gpu 16384
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-kl-loss
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6

   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98

   --optimizer-cpu-offload
   --overlap-cpu-optimizer-d2h-h2d
   --use-precision-aware-optimizer
)

WANDB_ARGS=(
   # --use-wandb
   # --wandb-project slime-dev
   # --wandb-group kimi-k2-test
   # --wandb-key ${WANDB_KEY}
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 16
   --sglang-mem-fraction-static 0.7

   # dp attention
   --sglang-enable-dp-attention
   --sglang-dp-size 8
   --sglang-moe-dense-tp-size 1
   --sglang-enable-dp-lm-head

   --sglang-ep-size 16

   # enable deepep for sglang
   # --sglang-moe-a2a-backend deepep
   # --sglang-deepep-mode auto

   # make every dp rank has 128 concurrency
   --sglang-server-concurrency 1024
)

MISC_ARGS=(
   # default dropout in megatron is 0.1
   --attention-dropout 0.0
   --hidden-dropout 0.0
   # should be good for model performance
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   # need to comment this when using model with MLA
   --attention-backend flash

   # use deepep for megatron
   --moe-enable-deepep
   --moe-token-dispatcher-type flex
)

# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\",
    \"no_proxy\": \"${no_proxy}\",
    \"MASTER_ADDR\": \"${MASTER_ADDR}\"
  }
}"

# FIX: the line after --update-weight-buffer-size previously lacked a trailing
# backslash, which truncated the command — none of the *_ARGS arrays were
# passed to train.py and the shell then tried to run ${MODEL_ARGS[@]} as a
# command. Array expansions are also quoted now so option values survive
# word-splitting intact.
ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 32 \
   --actor-num-gpus-per-node 8 \
   --colocate \
   --update-weight-buffer-size $((4 * 512 * 1024 * 1024)) \
   "${MODEL_ARGS[@]}" \
   "${CKPT_ARGS[@]}" \
   "${ROLLOUT_ARGS[@]}" \
   "${OPTIMIZER_ARGS[@]}" \
   "${GRPO_ARGS[@]}" \
   "${WANDB_ARGS[@]}" \
   "${PERF_ARGS[@]}" \
   "${EVAL_ARGS[@]}" \
   "${SGLANG_ARGS[@]}" \
   "${MISC_ARGS[@]}"

0 commit comments

Comments
 (0)