Skip to content

Commit ab90de3

Browse files
committed
extract grpo_npu script from patch
1 parent 5ece81d commit ab90de3

File tree

2 files changed

+163
-169
lines changed

2 files changed

+163
-169
lines changed

docker/npu_patch/qwen3_vl_8b_multi_turn_grpo/slime.patch

Lines changed: 0 additions & 169 deletions
Original file line numberDiff line numberDiff line change
@@ -1,172 +1,3 @@
1-
diff --git a/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py
2-
new file mode 100644
3-
index 00000000..8c862444
4-
--- /dev/null
5-
+++ b/examples/geo3k_vlm_multi_turn/run_geo3k_vlm_multi_turn_grpo_npu.py
6-
@@ -0,0 +1,163 @@
7-
+import os
8-
+
9-
+import slime.utils.misc as U
10-
+from slime.utils.external_utils.command_utils import execute_train_npu
11-
+
12-
+MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct")
13-
+assert MODEL_NAME in {
14-
+ "Qwen3-VL-2B-Instruct",
15-
+ "Qwen3-VL-4B-Instruct",
16-
+ "Qwen3-VL-8B-Instruct",
17-
+ "Qwen3-VL-2B-Thinking",
18-
+ "Qwen3-VL-4B-Thinking",
19-
+ "Qwen3-VL-8B-Thinking",
20-
+}
21-
+
22-
+EXTERNAL_RAY = int(os.environ.get("SLIME_SCRIPT_EXTERNAL_RAY", "0"))
23-
+TRAIN_BACKEND = os.environ.get("SLIME_SCRIPT_TRAIN_BACKEND", "fsdp").lower()
24-
+assert TRAIN_BACKEND in {"fsdp", "megatron"}
25-
+
26-
+DATASET_NAME = "VeraIsHere/geo3k_imgurl_processed"
27-
+DATA_ROOT = "/root/datasets/geo3k_imgurl_processed"
28-
+TRAIN_DATA_PATH = os.path.join(DATA_ROOT, "train.parquet")
29-
+
30-
+
31-
+def get_megatron_model_type(model_name: str) -> str:
32-
+ model_type = model_name.replace("-Instruct", "").replace("-Thinking", "")
33-
+ model_type = model_type.replace("Qwen3-VL-", "qwen3-")
34-
+ return model_type.replace("-2B", "-1.7B")
35-
+
36-
+
37-
+def execute():
38-
+ ckpt_args = f"--hf-checkpoint /home/data/{MODEL_NAME} "
39-
+
40-
+ wandb_args = (
41-
+ (
42-
+ "--use-wandb "
43-
+ "--wandb-project slime-dev "
44-
+ "--wandb-group geo3k_vlm_multi_turn "
45-
+ f"--wandb-key '{wandb_api_key}' "
46-
+ )
47-
+ if (wandb_api_key := os.environ.get("WANDB_API_KEY"))
48-
+ else ""
49-
+ )
50-
+
51-
+ rollout_args = (
52-
+ f"--prompt-data {TRAIN_DATA_PATH} "
53-
+ "--input-key problem "
54-
+ "--label-key answer "
55-
+ '--multimodal-keys \'{"image": "images"}\' '
56-
+ "--rm-type math "
57-
+ "--apply-chat-template "
58-
+ "--custom-generate-function-path examples.geo3k_vlm_multi_turn.rollout.generate "
59-
+ "--custom-config-path examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml "
60-
+ "--rollout-shuffle "
61-
+ "--num-rollout 3000 "
62-
+ "--rollout-batch-size 32 "
63-
+ "--n-samples-per-prompt 8 "
64-
+ "--rollout-max-response-len 4096 "
65-
+ "--rollout-temperature 1 "
66-
+ "--global-batch-size 256 "
67-
+ )
68-
+
69-
+ # eval_args = (
70-
+ # "--eval-interval 20 "
71-
+ # f"--eval-prompt-data geo3k_eval {TRAIN_DATA_PATH}@[0:64] "
72-
+ # "--n-samples-per-eval-prompt 1 "
73-
+ # "--eval-max-response-len 4096 "
74-
+ # "--eval-top-k 1 "
75-
+ # )
76-
+
77-
+ grpo_args = (
78-
+ "--advantage-estimator grpo "
79-
+ "--kl-loss-coef 0.00 "
80-
+ "--kl-loss-type low_var_kl "
81-
+ "--kl-coef 0.00 "
82-
+ "--entropy-coef 0.00 "
83-
+ "--eps-clip 0.2 "
84-
+ "--eps-clip-high 0.28 "
85-
+ "--use-kl-loss "
86-
+ )
87-
+
88-
+ optimizer_args = (
89-
+ "--optimizer adam "
90-
+ "--lr 1e-6 "
91-
+ "--lr-decay-style constant "
92-
+ "--weight-decay 0.1 "
93-
+ "--adam-beta1 0.9 "
94-
+ "--adam-beta2 0.98 "
95-
+
96-
+ "--optimizer-cpu-offload "
97-
+ "--overlap-cpu-optimizer-d2h-h2d "
98-
+ "--use-precision-aware-optimizer "
99-
+ )
100-
+
101-
+ sglang_args = (
102-
+ "--rollout-num-gpus-per-engine 1 "
103-
+ "--sglang-mem-fraction-static 0.6 "
104-
+ f"--sglang-cuda-graph-bs {' '.join(map(str, [4, 8] + list(range(16, 257, 8))))} "
105-
+ "--sglang-device npu "
106-
+ "--sglang-disable-radix-cache "
107-
+ "--sglang-chunked-prefill-size 32768 "
108-
+ "--sglang-max-prefill-tokens 4000 "
109-
+ "--sglang-max-total-tokens 327680 "
110-
+ )
111-
+
112-
+ megatron_args = (
113-
+ "--train-backend megatron "
114-
+ f"--load /home/data/{MODEL_NAME} "
115-
+ f"--ref-load /home/data/{MODEL_NAME} "
116-
+ "--tensor-model-parallel-size 4 "
117-
+ "--sequence-parallel "
118-
+ "--pipeline-model-parallel-size 1 "
119-
+ "--context-parallel-size 1 "
120-
+ "--expert-model-parallel-size 1 "
121-
+ "--expert-tensor-parallel-size 1 "
122-
+ "--recompute-granularity full "
123-
+ "--recompute-method uniform "
124-
+ "--recompute-num-layers 1 "
125-
+ "--use-dynamic-batch-size "
126-
+ "--max-tokens-per-gpu 16384 "
127-
+ "--balance-data "
128-
+ "--attention-dropout 0.0 "
129-
+ "--hidden-dropout 0.0 "
130-
+ "--accumulate-allreduce-grads-in-fp32 "
131-
+ "--attention-softmax-in-fp32 "
132-
+ "--attention-backend flash "
133-
+ "--megatron-to-hf-mode bridge "
134-
+ )
135-
+
136-
+ misc_args = (
137-
+ "--actor-num-nodes 1 " f"--actor-num-gpus-per-node 8 " f"--rollout-num-gpus 8 "
138-
+ "--no-gradient-accumulation-fusion "
139-
+ "--use-flash-attn "
140-
+ )
141-
+
142-
+ if TRAIN_BACKEND == "megatron":
143-
+ backend_args = megatron_args
144-
+ megatron_model_type = get_megatron_model_type(MODEL_NAME)
145-
+ os.environ["MODEL_ARGS_ROTARY_BASE"] = "5000000"
146-
+ else:
147-
+ exit()
148-
+
149-
+ train_args = (
150-
+ f"{ckpt_args} "
151-
+ f"{rollout_args} "
152-
+ f"{optimizer_args} "
153-
+ f"{grpo_args} "
154-
+ f"{sglang_args} "
155-
+ f"{backend_args} "
156-
+ f"{misc_args} "
157-
+ f"{wandb_args} "
158-
+ # f"{get_default_wandb_args(__file__)} "
159-
+ )
160-
+
161-
+ execute_train_npu(
162-
+ train_args=train_args,
163-
+ megatron_model_type=megatron_model_type,
164-
+ extra_env_vars=({"WANDB_API_KEY": os.environ["WANDB_API_KEY"]} if os.environ.get("WANDB_API_KEY") else {}),
165-
+ )
166-
+
167-
+
168-
+if __name__ == "__main__":
169-
+ execute()
1701
diff --git a/slime/backends/megatron_utils/__init__.py b/slime/backends/megatron_utils/__init__.py
1712
index a4666fbe..96e7b1b0 100644
1723
--- a/slime/backends/megatron_utils/__init__.py
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
import os

import slime.utils.misc as U
from slime.utils.external_utils.command_utils import execute_train_npu

# Which Qwen3-VL checkpoint to train; overridable via the environment.
MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-VL-2B-Instruct")
assert MODEL_NAME in {
    f"Qwen3-VL-{size}B-{variant}"
    for size in (2, 4, 8)
    for variant in ("Instruct", "Thinking")
}

# Non-zero -> attach to an already-running external Ray cluster.
EXTERNAL_RAY = int(os.environ.get("SLIME_SCRIPT_EXTERNAL_RAY", "0"))
# Training backend selector; only "fsdp" and "megatron" are recognized.
TRAIN_BACKEND = os.environ.get("SLIME_SCRIPT_TRAIN_BACKEND", "fsdp").lower()
assert TRAIN_BACKEND in {"fsdp", "megatron"}

# Geo3K multimodal dataset, pre-processed into parquet with image URLs.
DATASET_NAME = "VeraIsHere/geo3k_imgurl_processed"
DATA_ROOT = "/root/datasets/geo3k_imgurl_processed"
TRAIN_DATA_PATH = os.path.join(DATA_ROOT, "train.parquet")
24+
25+
def get_megatron_model_type(model_name: str) -> str:
    """Map an HF checkpoint name to the matching Megatron model-type string.

    E.g. "Qwen3-VL-2B-Instruct" -> "qwen3-1.7B": the variant suffix is
    dropped, the family prefix is lowered, and the "2B" VL checkpoint maps
    onto the 1.7B base architecture.
    """
    result = model_name
    for old, new in (
        ("-Instruct", ""),
        ("-Thinking", ""),
        ("Qwen3-VL-", "qwen3-"),
        ("-2B", "-1.7B"),
    ):
        result = result.replace(old, new)
    return result
29+
30+
31+
def execute():
    """Assemble the full training command line and launch GRPO training on NPU.

    Only the Megatron backend is wired up in this script; any other backend
    (including the default "fsdp") aborts with an explicit error instead of
    the previous silent `exit()` (which returned status 0 and hid the
    misconfiguration).
    """
    ckpt_args = f"--hf-checkpoint /home/data/{MODEL_NAME} "

    # Enable wandb logging only when an API key is present in the environment.
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key:
        wandb_args = (
            "--use-wandb "
            "--wandb-project slime-dev "
            "--wandb-group geo3k_vlm_multi_turn "
            f"--wandb-key '{wandb_api_key}' "
        )
    else:
        wandb_args = ""

    rollout_args = (
        f"--prompt-data {TRAIN_DATA_PATH} "
        "--input-key problem "
        "--label-key answer "
        '--multimodal-keys \'{"image": "images"}\' '
        "--rm-type math "
        "--apply-chat-template "
        "--custom-generate-function-path examples.geo3k_vlm_multi_turn.rollout.generate "
        "--custom-config-path examples/geo3k_vlm_multi_turn/geo3k_vlm_multi_turn_config.yaml "
        "--rollout-shuffle "
        "--num-rollout 3000 "
        "--rollout-batch-size 32 "
        "--n-samples-per-prompt 8 "
        "--rollout-max-response-len 4096 "
        "--rollout-temperature 1 "
        "--global-batch-size 256 "
    )

    # Kept for future use: periodic eval on a held-out slice.
    # eval_args = (
    #     "--eval-interval 20 "
    #     f"--eval-prompt-data geo3k_eval {TRAIN_DATA_PATH}@[0:64] "
    #     "--n-samples-per-eval-prompt 1 "
    #     "--eval-max-response-len 4096 "
    #     "--eval-top-k 1 "
    # )

    # GRPO advantage estimation with KL/entropy terms disabled (coef 0.0)
    # and asymmetric PPO-style clipping.
    grpo_args = (
        "--advantage-estimator grpo "
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 0.2 "
        "--eps-clip-high 0.28 "
        "--use-kl-loss "
    )

    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
        # Offload optimizer state to CPU, overlapping the D2H/H2D copies.
        "--optimizer-cpu-offload "
        "--overlap-cpu-optimizer-d2h-h2d "
        "--use-precision-aware-optimizer "
    )

    sglang_args = (
        "--rollout-num-gpus-per-engine 1 "
        "--sglang-mem-fraction-static 0.6 "
        f"--sglang-cuda-graph-bs {' '.join(map(str, [4, 8] + list(range(16, 257, 8))))} "
        "--sglang-device npu "
        "--sglang-disable-radix-cache "
        "--sglang-chunked-prefill-size 32768 "
        "--sglang-max-prefill-tokens 4000 "
        "--sglang-max-total-tokens 327680 "
    )

    megatron_args = (
        "--train-backend megatron "
        f"--load /home/data/{MODEL_NAME} "
        f"--ref-load /home/data/{MODEL_NAME} "
        "--tensor-model-parallel-size 4 "
        "--sequence-parallel "
        "--pipeline-model-parallel-size 1 "
        "--context-parallel-size 1 "
        "--expert-model-parallel-size 1 "
        "--expert-tensor-parallel-size 1 "
        "--recompute-granularity full "
        "--recompute-method uniform "
        "--recompute-num-layers 1 "
        "--use-dynamic-batch-size "
        "--max-tokens-per-gpu 16384 "
        "--balance-data "
        "--attention-dropout 0.0 "
        "--hidden-dropout 0.0 "
        "--accumulate-allreduce-grads-in-fp32 "
        "--attention-softmax-in-fp32 "
        "--attention-backend flash "
        "--megatron-to-hf-mode bridge "
    )

    misc_args = (
        "--actor-num-nodes 1 " "--actor-num-gpus-per-node 8 " "--rollout-num-gpus 8 "
        "--no-gradient-accumulation-fusion "
        "--use-flash-attn "
    )

    if TRAIN_BACKEND == "megatron":
        backend_args = megatron_args
        megatron_model_type = get_megatron_model_type(MODEL_NAME)
        # Qwen3-VL uses a non-default rotary base; forwarded to model args.
        os.environ["MODEL_ARGS_ROTARY_BASE"] = "5000000"
    else:
        # Fail loudly with a nonzero status instead of a silent exit(0).
        raise SystemExit(
            f"train backend '{TRAIN_BACKEND}' is not supported by this NPU "
            "script; set SLIME_SCRIPT_TRAIN_BACKEND=megatron"
        )

    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{sglang_args} "
        f"{backend_args} "
        f"{misc_args} "
        f"{wandb_args} "
        # f"{get_default_wandb_args(__file__)} "
    )

    execute_train_npu(
        train_args=train_args,
        megatron_model_type=megatron_model_type,
        extra_env_vars=({"WANDB_API_KEY": os.environ["WANDB_API_KEY"]} if os.environ.get("WANDB_API_KEY") else {}),
    )


if __name__ == "__main__":
    execute()

0 commit comments

Comments
 (0)