|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | + |
| 4 | +import slime.utils.external_utils.command_utils as U |
| 5 | + |
# When set (default "1"), the test assumes memory-constrained GPUs and uses a
# lower sglang static memory fraction below — presumably "1" parses as truthy;
# confirm against U.get_bool_env_var.
TIGHT_DEVICE_MEMORY = U.get_bool_env_var("SLIME_TEST_TIGHT_DEVICE_MEMORY", "1")

# HuggingFace model name (under the Qwen/ org) and the matching megatron
# model-type string passed to U.execute_train.
MODEL_NAME = "Qwen2.5-0.5B-Instruct"
MODEL_TYPE = "qwen2.5-0.5B"
# Total GPUs on the single test node; must cover the 2+1+1 engine groups below.
NUM_GPUS = 4

# Inline sglang config: same model, 3 engine groups with different parallelism.
# Group 1: 2 GPUs, 2 GPUs/engine (tp=2) → 1 engine
# Group 2: 1 GPU, 1 GPU/engine (tp=1) → 1 engine
# Group 3: 1 GPU, placeholder → reserves 1 GPU slot, no engine created
SGLANG_CONFIG_YAML = """\
sglang:
  - name: default
    engine_groups:
      - worker_type: regular
        num_gpus: 2
        num_gpus_per_engine: 2
      - worker_type: regular
        num_gpus: 1
        num_gpus_per_engine: 1
      - worker_type: placeholder
        num_gpus: 1
"""
| 29 | + |
| 30 | + |
def prepare():
    """Fetch the model checkpoint and the gsm8k dataset needed by this test."""
    U.exec_command("mkdir -p /root/models /root/datasets")
    local_model_dir = f"/root/models/{MODEL_NAME}"
    U.exec_command(f"huggingface-cli download Qwen/{MODEL_NAME} --local-dir {local_model_dir}")
    U.hf_download_dataset("zhuzilin/gsm8k")
| 35 | + |
| 36 | + |
def execute():
    """Run a short GRPO training job exercising the inline sglang config.

    The config defines three engine groups with mixed parallelism (tp=2,
    tp=1, and a placeholder GPU slot); it is written to a temporary YAML
    file and passed to training via --sglang-config.
    """
    # Write the inline config to a temp file. delete=False so the file
    # outlives the `with` block and the training subprocess can read it;
    # closing via `with` fixes the previously leaked open file handle.
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".yaml", prefix="sglang_config_", delete=False
    ) as config_file:
        config_file.write(SGLANG_CONFIG_YAML)
        config_path = config_file.name

    try:
        # Checkpoint: both the HF checkpoint and the reference model point at
        # the same downloaded model directory.
        ckpt_args = f"--hf-checkpoint /root/models/{MODEL_NAME}/ --ref-load /root/models/{MODEL_NAME}/ "

        rollout_args = (
            "--prompt-data /root/datasets/gsm8k/train.parquet "
            "--input-key messages "
            "--label-key label "
            "--apply-chat-template "
            "--rollout-shuffle "
            "--rm-type math "
            "--num-rollout 3 "
            "--rollout-batch-size 8 "
            "--n-samples-per-prompt 4 "
            "--rollout-max-response-len 1024 "
            "--rollout-temperature 0.8 "
            "--over-sampling-batch-size 16 "
            "--dynamic-sampling-filter-path slime.rollout.filter_hub.dynamic_sampling_filters.check_reward_nonzero_std "
            "--global-batch-size 32 "
        )

        eval_args = (
            "--eval-interval 20 "
            "--eval-prompt-data gsm8k /root/datasets/gsm8k/test.parquet "
            "--n-samples-per-eval-prompt 1 "
            "--eval-max-response-len 1024 "
            "--eval-top-k 1 "
        )

        # All model-parallel sizes are 1: the 0.5B model fits on a single GPU.
        perf_args = (
            "--tensor-model-parallel-size 1 "
            "--sequence-parallel "
            "--pipeline-model-parallel-size 1 "
            "--context-parallel-size 1 "
            "--expert-model-parallel-size 1 "
            "--expert-tensor-parallel-size 1 "
            "--use-dynamic-batch-size "
            "--max-tokens-per-gpu 9216 "
        )

        # KL and entropy coefficients are zero, so only the clipped surrogate
        # loss is active; the kl-loss flags still exercise that code path.
        grpo_args = (
            "--advantage-estimator grpo "
            "--use-kl-loss "
            "--kl-loss-coef 0.00 "
            "--kl-loss-type low_var_kl "
            "--entropy-coef 0.00 "
            "--eps-clip 0.2 "
            "--eps-clip-high 0.28 "
        )

        optimizer_args = (
            "--optimizer adam "
            "--lr 1e-6 "
            "--lr-decay-style constant "
            "--weight-decay 0.1 "
            "--adam-beta1 0.9 "
            "--adam-beta2 0.98 "
        )

        # Lower the static memory fraction on memory-constrained CI GPUs.
        sglang_args = (
            "--rollout-num-gpus-per-engine 1 "
            f"--sglang-mem-fraction-static {0.6 if TIGHT_DEVICE_MEMORY else 0.7} "
            "--sglang-enable-metrics "
            f"--sglang-config {config_path} "
        )

        ci_args = "--ci-test "

        misc_args = (
            "--attention-dropout 0.0 "
            "--hidden-dropout 0.0 "
            "--accumulate-allreduce-grads-in-fp32 "
            "--attention-softmax-in-fp32 "
            "--attention-backend flash "
            "--actor-num-nodes 1 "
            "--actor-num-gpus-per-node 4 "
            "--colocate "
            "--megatron-to-hf-mode bridge "
        )

        train_args = (
            f"{ckpt_args} "
            f"{rollout_args} "
            f"{optimizer_args} "
            f"{grpo_args} "
            f"{U.get_default_wandb_args(__file__)} "
            f"{perf_args} "
            f"{eval_args} "
            f"{sglang_args} "
            f"{ci_args} "
            f"{misc_args} "
        )

        U.execute_train(
            train_args=train_args,
            num_gpus_per_node=NUM_GPUS,
            megatron_model_type=MODEL_TYPE,
        )
    finally:
        # Remove the temp config so repeated CI runs don't leave files behind
        # (it was previously leaked because of delete=False).
        os.remove(config_path)
| 140 | + |
| 141 | + |
if __name__ == "__main__":
    prepare()
    # Strip every proxy setting (both cases) before launching training.
    for proxy_key in ("http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"):
        os.environ.pop(proxy_key, None)
    execute()
0 commit comments