Skip to content

Commit 657d8ee

Browse files
fzyzcjy and zhuzilin
authored
Support true on policy (#566)
Co-authored-by: Zilin Zhu <zhuzilinallen@gmail.com>
1 parent a133f93 commit 657d8ee

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import os
2+
import sys
3+
from pathlib import Path
4+
5+
sys.path.append(str(Path(__file__).resolve().parents[2] / "tests"))
6+
7+
import command_utils as U
8+
9+
# --- Script configuration, overridable via environment variables. ---

# Which Qwen3 checkpoint to train; only these two sizes are supported by this script.
MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-0.6B")
# Validate with an explicit raise (not `assert`, which is stripped under `python -O`).
if MODEL_NAME not in {"Qwen3-0.6B", "Qwen3-4B"}:
    raise ValueError(f"Unsupported SLIME_SCRIPT_MODEL_NAME: {MODEL_NAME!r}")

# Run mode: "normal" full run, or reduced debug variants used by execute().
MODE = os.environ.get("SLIME_SCRIPT_MODE", "normal")
if MODE not in {"normal", "debug_minimal", "debug_one_sample"}:
    raise ValueError(f"Unsupported SLIME_SCRIPT_MODE: {MODE!r}")

# Number of GPUs for the (colocated) actor; must be a positive integer.
NUM_GPUS = int(os.environ.get("SLIME_SCRIPT_NUM_GPUS", "1"))
if NUM_GPUS < 1:
    raise ValueError(f"SLIME_SCRIPT_NUM_GPUS must be >= 1, got {NUM_GPUS}")
16+
17+
18+
def prepare():
    """Fetch run prerequisites: the HF model checkpoint and the GSM8K dataset."""
    setup_commands = (
        "mkdir -p /root/models /root/datasets",
        f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}",
    )
    for command in setup_commands:
        U.exec_command(command)
    U.hf_download_dataset("zhuzilin/gsm8k")
22+
23+
24+
def execute():
    """Assemble the training command line and launch a true-on-policy GRPO run.

    Builds one flag-group string per concern (checkpoint, rollout, optimizer,
    GRPO, SGLang, FSDP, CI, misc, true-on-policy), concatenates them, and hands
    the result to ``U.execute_train`` together with determinism-related
    environment variables. Flag values are scaled down when MODE is a debug
    variant.
    """
    ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "

    # Rollout/data flags; "debug_one_sample" shrinks every size knob to (near) 1
    # so a single sample flows end-to-end quickly.
    rollout_args = (
        "--prompt-data /root/datasets/gsm8k/train.parquet "
        "--input-key messages "
        "--label-key label "
        "--apply-chat-template "
        "--rollout-shuffle "
        "--rm-type math "
        f"--num-rollout {2 if MODE == 'debug_one_sample' else 3000} "
        f"--rollout-batch-size {1 if MODE == 'debug_one_sample' else 32} "
        f"--n-samples-per-prompt {1 if MODE == 'debug_one_sample' else 8} "
        f"--rollout-max-response-len {2 if MODE == 'debug_one_sample' else 1024} "
        "--rollout-temperature 0.8 "
        # temp remove this to make test easier
        # "--over-sampling-batch-size 64 "
        # "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
        f"--global-batch-size {1 if MODE == 'debug_one_sample' else 256} "
    )

    # Periodic GSM8K eval only in the full run; debug modes skip evaluation.
    eval_args = ""
    if MODE == "normal":
        eval_args = (
            "--eval-interval 20 "
            "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
            "--n-samples-per-eval-prompt 1 "
            "--eval-max-response-len 1024 "
            "--eval-top-k 1 "
        )

    # GRPO advantage estimation with all KL/entropy terms zeroed out; clipping
    # uses an asymmetric range (0.2 low / 0.28 high).
    grpo_args = (
        "--advantage-estimator grpo "
        # "--use-kl-loss "
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 0.2 "
        "--eps-clip-high 0.28 "
        # mainly to look at its metric
        "--use-tis "
    )

    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
    )

    # SGLang engine flags. The larger 4B model gets a smaller static memory
    # fraction; CUDA graphs are disabled in debug_one_sample (presumably to
    # ease debugging — TODO confirm).
    sglang_args = (
        "--rollout-num-gpus-per-engine 1 "
        "--sglang-decode-log-interval 1000 "
        "--sglang-enable-metrics "
        f"--sglang-mem-fraction-static {0.2 if MODEL_NAME == 'Qwen3-4B' else 0.4} "
        f"{'--sglang-disable-cuda-graph ' if MODE == 'debug_one_sample' else ''}"
    )

    fsdp_args = (
        # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default)
        # "--fsdp-full-params " # Uncomment this line to enable full params mode
        # Set the bucket size for weight update
        "--update-weight-buffer-size 536870912 " # 512MB
    )

    ci_args = (
        "--ci-test "
        "--ci-disable-kl-checker "
        "--ci-metric-checker-key eval/gsm8k "
        "--ci-metric-checker-threshold 0.71 " # loose threshold at 60 step
    )

    misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " "--train-backend fsdp "

    if MODEL_NAME == "Qwen3-4B":
        misc_args += (
            "--use-dynamic-batch-size "
            # TODO pick a good value
            "--max-tokens-per-gpu 2048 "
        )

    # Flags + env vars that make rollout (SGLang) and training (FSDP) bitwise
    # deterministic so rollout and training logprobs can match exactly.
    true_on_policy_args = (
        "--sglang-enable-deterministic-inference "
        "--sglang-rl-on-policy-target fsdp "
        "--sglang-attention-backend fa3 "
        "--attn-implementation flash_attention_3 "
        "--deterministic-mode "
        "--true-on-policy-mode "
    )
    true_on_policy_envs = {
        # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang
        # "NCCL_ALGO": "Ring",
        "NCCL_ALGO": "allreduce:tree",
        "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }

    # Each group already ends with a trailing space; joining with f-strings
    # yields a single flat CLI string.
    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{sglang_args} "
        f"{U.get_default_wandb_args(__file__)} "
        f"{eval_args} "
        f"{fsdp_args} "
        f"{ci_args} "
        f"{misc_args} "
        f"{true_on_policy_args} "
    )

    U.execute_train(
        train_args=train_args,
        num_gpus=NUM_GPUS,
        model_type=None,
        extra_env_vars={
            **true_on_policy_envs,
            # Extra SGLang debug dumps only when replaying a single sample.
            "SGLANG_DUMPER_ENABLE": "1" if MODE == "debug_one_sample" else "0",
            "SGLANG_TEMP_UTILS_ENABLE_DEBUG_PRINT": "1" if MODE == "debug_one_sample" else "0",
        },
    )
148+
149+
150+
# Script entry point: download model/dataset assets, then launch training.
if __name__ == "__main__":
    prepare()
    execute()

slime/backends/fsdp_utils/actor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,12 @@ class FSDPTrainRayActor(TrainRayActor):
5151
def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int: # type: ignore[override]
5252
super().init(args, role, wandb_run_id, with_ref)
5353

54+
if args.true_on_policy_mode:
55+
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
56+
57+
print("FSDPTrainRayActor call enable_batch_invariant_mode for true-on-policy")
58+
enable_batch_invariant_mode()
59+
5460
# Update rank and world_size for wandb secondary initialization (using actual distributed values)
5561
args.rank = dist.get_rank()
5662
args.world_size = dist.get_world_size()
@@ -454,6 +460,11 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
454460
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
455461
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
456462

463+
train_rollout_logprob_diff = old_log_probs - rollout_log_probs
464+
train_rollout_logprob_diff = sum_of_sample_mean(
465+
train_rollout_logprob_diff, response_lengths, loss_masks
466+
).detach()
467+
457468
loss = pg_loss
458469

459470
if self.args.entropy_coef != 0:
@@ -477,6 +488,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
477488
"pg_loss": pg_loss.detach(),
478489
"pg_clipfrac": pg_clipfrac.detach(),
479490
"ppo_kl": ppo_kl.detach(),
491+
"train_rollout_logprob_diff": train_rollout_logprob_diff,
480492
}
481493

482494
if self.args.use_kl_loss:

slime/backends/fsdp_utils/arguments.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class FSDPArgs:
3232
# FSDP configuration
3333
fsdp_full_params: bool = False # If True, use full_tensor; if False, use shard_tensor
3434

35+
deterministic_mode: bool = False # This name must be the same as Megatron's
3536
# Profile
3637
record_memory_history: bool = False
3738
memory_snapshot_path: str = "snapshot.pickle"

slime/utils/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,12 @@ def add_train_arguments(parser):
9898
default="megatron",
9999
help="The backend for training.",
100100
)
101+
parser.add_argument(
102+
"--true-on-policy-mode",
103+
action="store_true",
104+
default=False,
105+
help="Whether to enable true-on-policy mode.",
106+
)
101107

102108
return parser
103109

0 commit comments

Comments
 (0)