From dec9cf4436c6002e97df79c08e118c99346cc9b3 Mon Sep 17 00:00:00 2001
From: fzyzcjy
Date: Fri, 24 Oct 2025 14:47:06 +0800
Subject: [PATCH] cp

---
 examples/true_on_policy/run_simple.py  | 171 +++++++++++++++++++++++++
 slime/backends/fsdp_utils/actor.py     |  14 ++
 slime/backends/fsdp_utils/arguments.py |   3 +
 slime/utils/arguments.py               |   6 +
 4 files changed, 194 insertions(+)
 create mode 100644 examples/true_on_policy/run_simple.py

diff --git a/examples/true_on_policy/run_simple.py b/examples/true_on_policy/run_simple.py
new file mode 100644
index 000000000..f3b459492
--- /dev/null
+++ b/examples/true_on_policy/run_simple.py
@@ -0,0 +1,171 @@
+"""Launch a simple true-on-policy GRPO run (FSDP training, SGLang rollout) on GSM8K.
+
+Configuration comes from environment variables:
+
+- ``SLIME_SCRIPT_MODEL_NAME``: ``Qwen3-0.6B`` (default) or ``Qwen3-4B``.
+- ``SLIME_SCRIPT_MODE``: ``normal`` (default), ``debug_minimal`` or ``debug_one_sample``.
+- ``SLIME_SCRIPT_NUM_GPUS``: GPU count handed to the launcher (default ``1``).
+"""
+
+import os
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).resolve().parents[2] / "tests"))
+
+import command_utils as U
+
+
+def _env_choice(name: str, default: str, choices: tuple[str, ...]) -> str:
+    """Return env var *name* (or *default* when unset) after checking it is one of *choices*.
+
+    Raises ``ValueError`` instead of relying on ``assert``, which is stripped under ``python -O``.
+    """
+    value = os.environ.get(name, default)
+    if value not in choices:
+        raise ValueError(f"{name}={value!r} is not supported; expected one of {choices}")
+    return value
+
+
+MODEL_NAME = _env_choice("SLIME_SCRIPT_MODEL_NAME", "Qwen3-0.6B", ("Qwen3-0.6B", "Qwen3-4B"))
+MODE = _env_choice("SLIME_SCRIPT_MODE", "normal", ("normal", "debug_minimal", "debug_one_sample"))
+NUM_GPUS = int(os.environ.get("SLIME_SCRIPT_NUM_GPUS", "1"))
+
+
+def prepare():
+    """Create /root/models and /root/datasets, then fetch the model checkpoint and the gsm8k dataset."""
+    U.exec_command("mkdir -p /root/models /root/datasets")
+    U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir /root/models/{MODEL_NAME}")
+    U.hf_download_dataset("zhuzilin/gsm8k")
+
+
+def execute():
+    """Assemble the training command-line arguments and start training through ``U.execute_train``."""
+    ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME} "
+
+    rollout_args = (
+        "--prompt-data /root/datasets/gsm8k/train.parquet "
+        "--input-key messages "
+        "--label-key label "
+        "--apply-chat-template "
+        "--rollout-shuffle "
+        "--rm-type math "
+        f"--num-rollout {2 if MODE == 'debug_one_sample' else 3000} "
+        f"--rollout-batch-size {1 if MODE == 'debug_one_sample' else 32} "
+        f"--n-samples-per-prompt {1 if MODE == 'debug_one_sample' else 8} "
+        f"--rollout-max-response-len {2 if MODE == 'debug_one_sample' else 1024} "
+        "--rollout-temperature 0.8 "
+        # temp remove this to make test easier
+        # "--over-sampling-batch-size 64 "
+        # "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
+        f"--global-batch-size {1 if MODE == 'debug_one_sample' else 256} "
+    )
+
+    eval_args = ""
+    if MODE == "normal":
+        eval_args = (
+            "--eval-interval 20 "
+            "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
+            "--n-samples-per-eval-prompt 1 "
+            "--eval-max-response-len 1024 "
+            "--eval-top-k 1 "
+        )
+
+    grpo_args = (
+        "--advantage-estimator grpo "
+        # "--use-kl-loss "
+        "--kl-loss-coef 0.00 "
+        "--kl-loss-type low_var_kl "
+        "--kl-coef 0.00 "
+        "--entropy-coef 0.00 "
+        "--eps-clip 0.2 "
+        "--eps-clip-high 0.28 "
+        # mainly to look at its metric
+        "--use-tis "
+    )
+
+    optimizer_args = (
+        "--optimizer adam "
+        "--lr 1e-6 "
+        "--lr-decay-style constant "
+        "--weight-decay 0.1 "
+        "--adam-beta1 0.9 "
+        "--adam-beta2 0.98 "
+    )
+
+    sglang_args = (
+        "--rollout-num-gpus-per-engine 1 "
+        "--sglang-decode-log-interval 1000 "
+        "--sglang-enable-metrics "
+        f"--sglang-mem-fraction-static {0.2 if MODEL_NAME == 'Qwen3-4B' else 0.4} "
+        f"{'--sglang-disable-cuda-graph ' if MODE == 'debug_one_sample' else ''}"
+    )
+
+    fsdp_args = (
+        # Set to true for FULL_STATE_DICT mode, false for SHARDED_STATE_DICT mode (default)
+        # "--fsdp-full-params "  # Uncomment this line to enable full params mode
+        # Set the bucket size for weight update
+        "--update-weight-buffer-size 536870912 "  # 512MB
+    )
+
+    ci_args = (
+        "--ci-test "
+        "--ci-disable-kl-checker "
+        "--ci-metric-checker-key eval/gsm8k "
+        "--ci-metric-checker-threshold 0.71 "  # loose threshold at 60 step
+    )
+
+    misc_args = "--actor-num-nodes 1 " f"--actor-num-gpus-per-node {NUM_GPUS} " "--colocate " "--train-backend fsdp "
+
+    if MODEL_NAME == "Qwen3-4B":
+        misc_args += (
+            "--use-dynamic-batch-size "
+            # TODO pick a good value
+            "--max-tokens-per-gpu 2048 "
+        )
+
+    true_on_policy_args = (
+        "--sglang-enable-deterministic-inference "
+        "--sglang-rl-on-policy-target fsdp "
+        "--sglang-attention-backend fa3 "
+        "--attn-implementation flash_attention_3 "
+        "--deterministic-mode "
+        "--true-on-policy-mode "
+    )
+    true_on_policy_envs = {
+        # TODO note: "Ring" in original RL PR, "allreduce:tree" in SGLang
+        # "NCCL_ALGO": "Ring",
+        "NCCL_ALGO": "allreduce:tree",
+        "NVTE_ALLOW_NONDETERMINISTIC_ALGO": "0",
+        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
+    }
+
+    train_args = (
+        f"{ckpt_args} "
+        f"{rollout_args} "
+        f"{optimizer_args} "
+        f"{grpo_args} "
+        f"{sglang_args} "
+        f"{U.get_default_wandb_args(__file__)} "
+        f"{eval_args} "
+        f"{fsdp_args} "
+        f"{ci_args} "
+        f"{misc_args} "
+        f"{true_on_policy_args} "
+    )
+
+    U.execute_train(
+        train_args=train_args,
+        num_gpus=NUM_GPUS,
+        model_type=None,
+        extra_env_vars={
+            **true_on_policy_envs,
+            "SGLANG_DUMPER_ENABLE": "1" if MODE == "debug_one_sample" else "0",
+            "SGLANG_TEMP_UTILS_ENABLE_DEBUG_PRINT": "1" if MODE == "debug_one_sample" else "0",
+        },
+    )
+
+
+if __name__ == "__main__":
+    prepare()
+    execute()
diff --git a/slime/backends/fsdp_utils/actor.py b/slime/backends/fsdp_utils/actor.py
index 6040d68d3..1c1d6e730 100644
--- a/slime/backends/fsdp_utils/actor.py
+++ b/slime/backends/fsdp_utils/actor.py
@@ -50,6 +50,12 @@ class FSDPTrainRayActor(TrainRayActor):
     def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = False) -> int:  # type: ignore[override]
         super().init(args, role, wandb_run_id, with_ref)
 
+        if args.true_on_policy_mode:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            print("FSDPTrainRayActor call enable_batch_invariant_mode for true-on-policy")
+            enable_batch_invariant_mode()
+
         # Update rank and world_size for wandb secondary initialization (using actual distributed values)
         args.rank = dist.get_rank()
         args.world_size = dist.get_world_size()
@@ -447,6 +453,13 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
                 pg_clipfrac = sum_of_sample_mean(pg_clipfrac, response_lengths, loss_masks)
                 ppo_kl = sum_of_sample_mean(ppo_kl.abs(), response_lengths, loss_masks)
 
+                # Signed (old_log_probs - rollout_log_probs) gap, reduced like the metrics above.
+                # NOTE(review): confirm rollout_log_probs is bound on every path (e.g. without --use-tis).
+                train_rollout_logprob_diff = old_log_probs - rollout_log_probs
+                train_rollout_logprob_diff = sum_of_sample_mean(
+                    train_rollout_logprob_diff, response_lengths, loss_masks
+                ).detach()
+
                 loss = pg_loss
 
                 if self.args.entropy_coef != 0:
@@ -470,6 +483,7 @@ def train(self, rollout_id: int, rollout_data_ref: Box) -> None:
                     "pg_loss": pg_loss.detach(),
                     "pg_clipfrac": pg_clipfrac.detach(),
                     "ppo_kl": ppo_kl.detach(),
+                    "train_rollout_logprob_diff": train_rollout_logprob_diff,
                 }
 
                 if self.args.use_kl_loss:
diff --git a/slime/backends/fsdp_utils/arguments.py b/slime/backends/fsdp_utils/arguments.py
index 76d8ddf98..4dd5fb724 100644
--- a/slime/backends/fsdp_utils/arguments.py
+++ b/slime/backends/fsdp_utils/arguments.py
@@ -32,6 +32,9 @@ class FSDPArgs:
     # FSDP configuration
     fsdp_full_params: bool = False  # If True, use full_tensor; if False, use shard_tensor
 
+    # Others
+    deterministic_mode: bool = False  # This name must be the same as Megatron's
+
     # YAML bookkeeping
     config: str | None = None
 
diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py
index b65fcdc15..3632cbd07 100644
--- a/slime/utils/arguments.py
+++ b/slime/utils/arguments.py
@@ -98,6 +98,12 @@ def add_train_arguments(parser):
                 default="megatron",
                 help="The backend for training.",
             )
+            parser.add_argument(
+                "--true-on-policy-mode",
+                action="store_true",
+                default=False,
+                help="Whether to enable true-on-policy mode.",
+            )
 
             return parser
 