Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 152 additions & 0 deletions examples/true_on_policy/run_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parents[2] / "tests"))

import command_utils as U

# Which model checkpoint to run; this example supports only these two Qwen3 sizes.
# NOTE: use explicit raises (not `assert`) so validation survives `python -O`.
MODEL_NAME = os.environ.get("SLIME_SCRIPT_MODEL_NAME", "Qwen3-0.6B")
if MODEL_NAME not in {"Qwen3-0.6B", "Qwen3-4B"}:
    raise ValueError(f"Unsupported SLIME_SCRIPT_MODEL_NAME: {MODEL_NAME!r}")

# Run mode: "normal" full training, or trimmed debug variants for fast iteration.
MODE = os.environ.get("SLIME_SCRIPT_MODE", "normal")
if MODE not in {"normal", "debug_minimal", "debug_one_sample"}:
    raise ValueError(f"Unsupported SLIME_SCRIPT_MODE: {MODE!r}")

# Number of GPUs for the colocated actor / rollout setup (default: 1).
NUM_GPUS = int(os.environ.get("SLIME_SCRIPT_NUM_GPUS", "1"))


def prepare():
    """Download the model checkpoint and the GSM8K dataset into /root."""
    local_model_dir = f"/root/models/{MODEL_NAME}"
    U.exec_command("mkdir -p /root/models /root/datasets")
    U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir {local_model_dir}")
    U.hf_download_dataset("zhuzilin/gsm8k")


def execute() -> None:
    """Assemble the slime CLI argument string and launch FSDP training.

    Every config group below is a space-separated run of CLI flags; each
    piece ends with a trailing space so plain string concatenation works.
    The "debug_one_sample" mode shrinks batch sizes and lengths to the
    minimum so a single training step can be inspected end to end.
    """
    # Model checkpoint to fine-tune (downloaded by prepare()).
    ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "

    # Rollout / data-loading settings for the GSM8K train split.
    rollout_args = (
        "--prompt-data /root/datasets/gsm8k/train.parquet "
        "--input-key messages "
        "--label-key label "
        "--apply-chat-template "
        "--rollout-shuffle "
        "--rm-type math "
        f"--num-rollout {2 if MODE == 'debug_one_sample' else 3000} "
        f"--rollout-batch-size {1 if MODE == 'debug_one_sample' else 32} "
        f"--n-samples-per-prompt {1 if MODE == 'debug_one_sample' else 8} "
        f"--rollout-max-response-len {2 if MODE == 'debug_one_sample' else 1024} "
        "--rollout-temperature 0.8 "
        # Temporarily removed to make the test easier:
        # "--over-sampling-batch-size 64 "
        # "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
        f"--global-batch-size {1 if MODE == 'debug_one_sample' else 256} "
    )

    # Periodic evaluation on the GSM8K test split; only in full runs.
    eval_args = ""
    if MODE == "normal":
        eval_args = (
            "--eval-interval 20 "
            "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
            "--n-samples-per-eval-prompt 1 "
            "--eval-max-response-len 1024 "
            "--eval-top-k 1 "
        )

    # GRPO advantage estimation; KL and entropy terms are disabled (coef 0).
    grpo_args = (
        "--advantage-estimator grpo "
        # "--use-kl-loss "
        "--kl-loss-coef 0.00 "
        "--kl-loss-type low_var_kl "
        "--kl-coef 0.00 "
        "--entropy-coef 0.00 "
        "--eps-clip 0.2 "
        "--eps-clip-high 0.28 "
        # mainly to look at its metric
        "--use-tis "
    )

    # Adam with a constant, small learning rate.
    optimizer_args = (
        "--optimizer adam "
        "--lr 1e-6 "
        "--lr-decay-style constant "
        "--weight-decay 0.1 "
        "--adam-beta1 0.9 "
        "--adam-beta2 0.98 "
    )

    # SGLang rollout engine settings; the larger model gets a smaller
    # static memory fraction, presumably to leave room for weights —
    # TODO confirm the intent of these fractions.
    sglang_args = (
        "--rollout-num-gpus-per-engine 1 "
        "--sglang-decode-log-interval 1000 "
        "--sglang-enable-metrics "
        f"--sglang-mem-fraction-static {0.2 if MODEL_NAME == 'Qwen3-4B' else 0.4} "
        f"{'--sglang-disable-cuda-graph ' if MODE == 'debug_one_sample' else ''}"
    )

    fsdp_args = (
        # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default)
        # "--fsdp-full-params " # Uncomment this line to enable full params mode
        # Set the bucket size for weight update
        "--update-weight-buffer-size 536870912 "  # 512MB
    )

    # CI gating: check the eval metric against a threshold.
    ci_args = (
        "--ci-test "
        "--ci-disable-kl-checker "
        "--ci-metric-checker-key eval/gsm8k "
        "--ci-metric-checker-threshold 0.71 "  # loose threshold at step 60
    )

    # Topology: single node, colocated actor + rollout, FSDP backend.
    misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " "--train-backend fsdp "

    if MODEL_NAME == "Qwen3-4B":
        misc_args += (
            "--use-dynamic-batch-size "
            # TODO pick a good value
            "--max-tokens-per-gpu 2048 "
        )

    # Flags that make rollout and training bitwise-consistent
    # (deterministic kernels on both the SGLang and FSDP sides).
    true_on_policy_args = (
        "--sglang-enable-deterministic-inference "
        "--sglang-rl-on-policy-target fsdp "
        "--sglang-attention-backend fa3 "
        "--attn-implementation flash_attention_3 "
        "--deterministic-mode "
        "--true-on-policy-mode "
    )
    # Environment needed for deterministic NCCL / cuBLAS behavior.
    true_on_policy_envs = {
        # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang
        # "NCCL_ALGO": "Ring",
        "NCCL_ALGO": "allreduce:tree",
        "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }

    # Final command line: concatenation of all groups above.
    train_args = (
        f"{ckpt_args} "
        f"{rollout_args} "
        f"{optimizer_args} "
        f"{grpo_args} "
        f"{sglang_args} "
        f"{U.get_default_wandb_args(__file__)} "
        f"{eval_args} "
        f"{fsdp_args} "
        f"{ci_args} "
        f"{misc_args} "
        f"{true_on_policy_args} "
    )

    U.execute_train(
        train_args=train_args,
        num_gpus=NUM_GPUS,
        model_type=None,
        extra_env_vars={
            **true_on_policy_envs,
            # Extra debug dumping/printing only in single-sample debug mode.
            "SGLANG_DUMPER_ENABLE": "1" if MODE == "debug_one_sample" else "0",
            "SGLANG_TEMP_UTILS_ENABLE_DEBUG_PRINT": "1" if MODE == "debug_one_sample" else "0",
        },
    )


def main():
    """Entry point: fetch assets, then run training."""
    prepare()
    execute()


if __name__ == "__main__":
    main()
12 changes: 12 additions & 0 deletions slime/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class FSDPTrainRayActor(TrainRayActor):
def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int: # type: ignore[override]
super().init(args, role, wandb_run_id, with_ref)

if args.true_on_policy_mode:
from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode

print("FSDPTrainRayActor call enable_batch_invariant_mode for true-on-policy")
enable_batch_invariant_mode()

# Update rank and world_size for wandb secondary initialization (using actual distributed values)
args.rank = dist.get_rank()
args.world_size = dist.get_world_size()
Expand Down Expand Up @@ -454,6 +460,11 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)

train_rollout_logprob_diff = old_log_probs - rollout_log_probs
train_rollout_logprob_diff = sum_of_sample_mean(
train_rollout_logprob_diff, response_lengths, loss_masks
).detach()

loss = pg_loss

if self.args.entropy_coef != 0:
Expand All @@ -477,6 +488,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
"pg_loss": pg_loss.detach(),
"pg_clipfrac": pg_clipfrac.detach(),
"ppo_kl": ppo_kl.detach(),
"train_rollout_logprob_diff": train_rollout_logprob_diff,
}

if self.args.use_kl_loss:
Expand Down
1 change: 1 addition & 0 deletions slime/backends/fsdp_utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class FSDPArgs:
# FSDP configuration
fsdp_full_params: bool = False # If True, use full_tensor; if False, use shard_tensor

deterministic_mode: bool = False # This name must be the same as Megatron's
# Profile
record_memory_history: bool = False
memory_snapshot_path: str = "snapshot.pickle"
Expand Down
6 changes: 6 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ def add_train_arguments(parser):
default="megatron",
help="The backend for training.",
)
parser.add_argument(
"--true-on-policy-mode",
action="store_true",
default=False,
help="Whether to enable true-on-policy mode.",
)

return parser

Expand Down
Loading