
[Refactor] Merge rollout controller into rollout manager#304

Merged
zhuzilin merged 18 commits into THUDM:main from PopSoda2002:refactor/rollout
Sep 16, 2025

Conversation

Collaborator

@PopSoda2002 commented Sep 7, 2025

Motivation:

  • Follow the software principle of high cohesion, low coupling by merging the two functionally similar classes into one

Co-authored with @Williamren97 and @MortalHappiness

Solution:

  • Merge common functionality, such as generate and eval, into RolloutManager
  • Reorganize files and functions: move public functions to the front and private ones to the back, and remove buffer.py
  • Update the corresponding rollout manager startup code
  • Refactor some functions, e.g. split generate into several smaller functions
  • Pass Ray actor handles such as RolloutEngine between Ray actors such as RolloutManager, resolving results with ray.get()

Result

We used Qwen3-4B to verify that the refactor does not change the baseline behavior.

  • colocate
    script:
#!/bin/bash

# for rerun the task
pkill -9 sglang
sleep 3
ray stop --force
pkill -9 ray
pkill -9 python
sleep 3
pkill -9 ray
pkill -9 python

set -ex

# prevent ray from buffering stdout/stderr
export PYTHONUNBUFFERED=1

NVLINK_COUNT=$(nvidia-smi | grep -o "NVLink" | wc -l)
if [ "$NVLINK_COUNT" -gt 0 ]; then
    HAS_NVLINK=1
else
    HAS_NVLINK=0
fi
echo "HAS_NVLINK: $HAS_NVLINK (detected $NVLINK_COUNT NVLink references)"

SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
source "${SCRIPT_DIR}/models/qwen3-4B.sh"

CKPT_ARGS=(
   --hf-checkpoint /root/Qwen3-4B
   #--hf-checkpoint /root/Qwen3-4B-FP8
   --ref-load /root/Qwen3-4B_torch_dist
   --load /root/Qwen3-4B_slime/
   --save /root/Qwen3-4B_slime/
   --save-interval 20
)

ROLLOUT_ARGS=(
   --prompt-data /root/dapo-math-17k/dapo-math-17k.jsonl
   --input-key prompt
   --label-key label
   --apply-chat-template
   --rollout-shuffle
   --rm-type deepscaler
   --num-rollout 3000
   --rollout-batch-size 32
   --n-samples-per-prompt 4
   --rollout-max-response-len 8192
   --rollout-temperature 0.8

   --global-batch-size 128
   --balance-data
)

EVAL_ARGS=(
   --eval-interval 20
   --eval-prompt-data aime /root/aime-2024/aime-2024.jsonl
   --n-samples-per-eval-prompt 4
   --eval-max-response-len 16384
   --eval-top-p 0.7
)

PERF_ARGS=(
   --tensor-model-parallel-size 2
   --sequence-parallel
   --pipeline-model-parallel-size 1
   --context-parallel-size 1
   --expert-model-parallel-size 1
   --expert-tensor-parallel-size 1

   --recompute-granularity full
   --recompute-method uniform
   --recompute-num-layers 1

   # --micro-batch-size 1
   --use-dynamic-batch-size
   --max-tokens-per-gpu 9216
)

GRPO_ARGS=(
   --advantage-estimator grpo
   --use-kl-loss
   --kl-loss-coef 0.00
   --kl-loss-type low_var_kl
   --entropy-coef 0.00
   --eps-clip 0.2
   --eps-clip-high 0.28
)

OPTIMIZER_ARGS=(
   --optimizer adam
   --lr 1e-6
   --lr-decay-style constant
   --weight-decay 0.1
   --adam-beta1 0.9
   --adam-beta2 0.98
)

export WANDB_API_KEY="your-wandb-api-key"  # do not commit real keys
export CUDA_VISIBLE_DEVICES=5,6
WANDB_ARGS=(
   --use-wandb
   --wandb-project slime-dev
   --wandb-group qwen3-4B-test-huapeng
   --wandb-key ${WANDB_API_KEY}
)

SGLANG_ARGS=(
   --rollout-num-gpus-per-engine 1
   --sglang-mem-fraction-static 0.7
)

MISC_ARGS=(
   # default dropout in megatron is 0.1
   --attention-dropout 0.0
   --hidden-dropout 0.0
   # should be good for model performance
   --accumulate-allreduce-grads-in-fp32
   --attention-softmax-in-fp32
   # need to comment this when using model with MLA
   --attention-backend flash
)

# launch the master node of ray in container
export MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}
ray start --head --node-ip-address ${MASTER_ADDR} --num-gpus 2 --disable-usage-stats --dashboard-host=0.0.0.0 --dashboard-port=8265

# Build the runtime environment JSON with proper variable substitution
RUNTIME_ENV_JSON="{
  \"env_vars\": {
    \"PYTHONPATH\": \"/root/Megatron-LM/\",
    \"CUDA_DEVICE_MAX_CONNECTIONS\": \"1\",
    \"NCCL_NVLS_ENABLE\": \"${HAS_NVLINK}\"
  }
}"

ray job submit --address="http://127.0.0.1:8265" \
   --runtime-env-json="${RUNTIME_ENV_JSON}" \
   -- python3 train.py \
   --actor-num-nodes 1 \
   --actor-num-gpus-per-node 2 \
   --colocate \
   ${MODEL_ARGS[@]} \
   ${CKPT_ARGS[@]} \
   ${ROLLOUT_ARGS[@]} \
   ${OPTIMIZER_ARGS[@]} \
   ${GRPO_ARGS[@]} \
   ${DISTRIBUTED_ARGS[@]} \
   ${WANDB_ARGS[@]} \
   ${PERF_ARGS[@]} \
   ${EVAL_ARGS[@]} \
   ${SGLANG_ARGS[@]} \
   ${MISC_ARGS[@]}

baseline(main branch) vs change(refactor branch):
(screenshots: training curves for baseline vs. refactor)
Within the jitter range, the two runs remain basically consistent, and the spike timings are all aligned (wandb link).

  • disaggregate

path = Path(path_template.format(rollout_id=self.rollout_id))
print(f"Save debug rollout data to {path}")
path.parent.mkdir(parents=True, exist_ok=True)
torch.save(data, path)
Contributor

Why not

            torch.save(
                dict(
                    rollout_id=self.rollout_id,
                    samples=[sample.to_dict() for sample in data],
                ),
                path,
            )

?

Collaborator Author

I just copied and pasted this part; I think we can do the replacement in the future.

Contributor

torch.save(
    dict(
        rollout_id=self.rollout_id,
        samples=[sample.to_dict() for sample in data],
    ),
    path,
)

Did you copy from here?

Collaborator Author

Sorry, I think I used an older version of this part; I've synced to the newest now. Thanks for your help!
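For reference, the dict-based debug dump discussed in this thread can be round-tripped like this. The path and sample fields are illustrative stand-ins, not values from the PR:

```python
from pathlib import Path
import torch

# Illustrative stand-ins for the rollout data (names assumed).
rollout_id = 0
samples = [{"prompt": "p", "response": "r"}]

path = Path(f"/tmp/debug_rollout_{rollout_id}.pt")
path.parent.mkdir(parents=True, exist_ok=True)

# Save a plain dict, as suggested in the review thread.
torch.save({"rollout_id": rollout_id, "samples": samples}, path)

# Load it back; plain containers load fine under the default settings.
loaded = torch.load(path)
print(loaded["rollout_id"], len(loaded["samples"]))  # 0 1
```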

@PopSoda2002 PopSoda2002 marked this pull request as ready for review September 14, 2025 17:46
Contributor

@Williamren97 left a comment

LGTM :)

PopSoda2002 and others added 5 commits September 15, 2025 17:31
Co-authored-by: William Ren <williamren97@gmail.com>
Co-authored-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
Co-authored-by: Chengxing Xie <91449279+yitianlian@users.noreply.github.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Haoran Wang <70007833+UbeCc@users.noreply.github.com>
@zhaochenyang20
Collaborator

@PopSoda2002 fix the lint locally.

@zhuzilin zhuzilin merged commit 20d0679 into THUDM:main Sep 16, 2025
3 of 4 checks passed
@zhuzilin zhuzilin mentioned this pull request Sep 17, 2025
llltttwww pushed a commit to llltttwww/slime that referenced this pull request Nov 30, 2025
* Refactor rollout manager

* Polish

* experiment

* rebase

* Clean code

* add engine lock

* Clean code

* Clean

* Move position of router and engine

* remove logging

* clean code

* log

* pre commit

* resolve conflicts

* precommit

* Refactor rollout manager

Co-authored-by: William Ren <williamren97@gmail.com>
Co-authored-by: Chi-Sheng Liu <chishengliu@chishengliu.com>

* Refactor rollout manager

Co-authored-by: Chi-Sheng Liu <chishengliu@chishengliu.com>

* add coauthor

Co-authored-by: Chengxing Xie <91449279+yitianlian@users.noreply.github.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Haoran Wang <70007833+UbeCc@users.noreply.github.com>

---------

Co-authored-by: William Ren <williamren97@gmail.com>
Co-authored-by: Chi-Sheng Liu <chishengliu@chishengliu.com>
Co-authored-by: Chengxing Xie <91449279+yitianlian@users.noreply.github.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: Haoran Wang <70007833+UbeCc@users.noreply.github.com>
yueming-yuan pushed a commit to yueming-yuan/slime that referenced this pull request Dec 29, 2025
Yangruipis pushed a commit to rednote-ai/slime that referenced this pull request Feb 28, 2026
5 participants