[Megatron] Support optimizer offload for moe when ep > 1 #1638


Merged


@zzong2006 (Contributor) commented May 22, 2025

Checklist Before Starting

  • Search for similar PR(s).

What does this PR do?

This simple PR adds support for ChainedOptimizer offloading in the Megatron-LM training environment.

In Megatron-LM, ChainedOptimizer is used when expert parallelism (expert_model_parallel_size > 1; related to #1467) is enabled, commonly in Mixture-of-Experts (MoE) models.

This has been tested and validated with the Qwen3-235B-A22B model configuration.

High-Level Design

Demonstrate the high-level design if this PR is complex.

Specific Changes

The offload and load helper functions now check whether the optimizer is a Megatron ChainedOptimizer and, if so, handle each wrapped optimizer in turn, so parameter offloading and loading work for both standard and chained optimizers. A sketch of the idea follows.
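
For reference, here is a minimal sketch of the approach, with illustrative names (the exact helper names and state layout in verl may differ). Megatron's ChainedOptimizer wraps a list of inner optimizers, one per parameter group (e.g., dense and expert weights when EP > 1), so the offload path dispatches over them instead of assuming a single optimizer:

import torch
from megatron.core.optimizer import ChainedOptimizer

def offload_megatron_optimizer(optimizer):
    # Move optimizer state tensors to host memory. Handles both a plain
    # Megatron optimizer and the ChainedOptimizer that Megatron-LM builds
    # when expert parallelism is enabled.
    if isinstance(optimizer, ChainedOptimizer):
        for inner in optimizer.chained_optimizers:
            _offload_optimizer_state(inner)
    else:
        _offload_optimizer_state(optimizer)

def _offload_optimizer_state(opt):
    # Illustrative helper: move every tensor in the wrapped torch optimizer's
    # state (e.g., Adam exp_avg / exp_avg_sq) to CPU.
    for state in opt.optimizer.state.values():
        for key, value in state.items():
            if torch.is_tensor(value):
                state[key] = value.to("cpu", non_blocking=True)

Loading mirrors this: the same ChainedOptimizer dispatch moves the state tensors back to the current CUDA device before the next optimizer step.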

API

Demonstrate how the API changes if any.

Usage Example

Provide usage example(s) for easier usage.

...
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=16 \
...

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

Additional Info.

  • Issue Number: Fixes issue # or discussion # if any.
  • Training: [Megatron]
  • Inference: [none]

Checklist Before Submitting

  • Read the Contribute Guide.
  • Apply pre-commit checks.
  • Add [BREAKING] to the PR title if it breaks any API.
  • Update the documentation about your changes in the docs.
  • Add CI test(s) if necessary.

… in offload and load functions. This allows offloading and loading of parameters for both standard optimizers and chained optimizers, improving flexibility in optimizer management.
@zzong2006 changed the title from "Support MoE optimizer offload when ep > 1" to "[Megatron] Support optimizer offload for moe when ep > 1" on May 22, 2025
@CLAassistant commented May 22, 2025

CLA assistant check
All committers have signed the CLA.

@ETOgaosion (Collaborator) commented May 22, 2025

@zzong2006 Thanks for your great contribution! Congratulations, it seems you have successfully run Qwen3-235B-A22B!

For this PR, could you sign the CLA? After the bug fix in #1634, we can run the CI workflow for your great work~

Currently, some CIs are missing for the offloading of Megatron weights/grads/optimizers. If convenient, could you help add the environment variables and options in tests/e2e/run_ppo_trainer_megatron.sh, and try enabling or disabling them in .github/workflows/e2e_ppo_trainer_megatron.yml? I can work with you and PR CI tests to your branch if that's inconvenient~

Also, for the whole Qwen3-235B-A22B post-training process, could you share more info about your configurations, results, and experiences?

@ETOgaosion requested a review from BearBiscuit05 on May 22, 2025 11:19
@zzong2006 (Contributor, Author) commented May 23, 2025

@ETOgaosion thanks for the comment!

✅ I’ve signed the CLA.

I’ll give the CI additions you mentioned a try. From a quick look, it seems most of the existing tests are based on smaller models like qwen3-0.6b. Just to confirm—would it be okay to use a similarly small model for the additional CI tests as well? For Qwen3 MoE models, the smallest available option seems to be 30B-3A, which could lead to significantly longer test times.
In any case, I’ll push a follow-up commit to this PR.

Qwen3-235B-A22B Post-Training Summary

For the Qwen3-235B-A22B setup, I performed GRPO training using 128 H200 GPUs in a Ray cluster environment.

Image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.3

Here’s a detailed breakdown of the configuration and setup:

PP=2    # num layers: 94 
TP=2
CP=1
EP=32
ETP=2
VLLM_TP=8
VLLM_MAX_NUM_SEQS=1
VLLM_GPU_MEMORY_UTILIZATION=0.75
VLLM_MAX_NUM_BATCHED_TOKENS=4096
MAX_ATTEMPTS=60  # Max attempts for retrying operations (e.g., Ray head ready)
SLEEP_INTERVAL=10 # Sleep interval between attempts (seconds)
HF_MODEL_PATH=/data/base_models/Qwen3-235B-A22B
DIST_CKPT_PATH=$BASE_DIR/models/Qwen3-235B-A22B-mcore
CONFIG_FILE_NAME=ppo_megatron_trainer.yaml   # same as verl/trainer/config/ppo_megatron_trainer.yaml
MAX_CONTEXT_LENGTH=8192
MAX_RESPONSE_LENGTH=1024
TRAIN_BATCH_SIZE=128     # num_nodes * num_gpus_per_node = 16 * 8 = 128
NUM_GENERATIONS=8
MICRO_MINI_BATCH_SIZE=1
PROJECT_NAME=verl_grpo_megatron_gsm8k
EXPERIMENT_NAME=qwen3_235b22b_moe_mcore
TIMESTAMP=$(date '+%Y%m%d_%H%M%S')
LOCAL_DIR=$BASE_DIR/models/${PROJECT_NAME}/${EXPERIMENT_NAME}/v${TIMESTAMP}
TENSORBOARD_DIR=$LOCAL_DIR/logs
mkdir -p "$TENSORBOARD_DIR"
export TENSORBOARD_DIR

ROLLOUT_MODE="sync"
gsm8k_train_path=$BASE_DIR/data/gsm8k/train.parquet
gsm8k_test_path=$BASE_DIR/data/gsm8k/test.parquet
TRAIN_FILES="['$gsm8k_train_path']"
TEST_FILES="['$gsm8k_test_path']"

ROLLOUT_NAME=vllm

# Install verl with vllm extra

log_info "Installing project in editable mode with ${ROLLOUT_NAME} extras..."
pip3 install -q -e .[${ROLLOUT_NAME}]
pip3 install -q ray[default]
log_info "Installation complete."

log_info "Starting Ray head node on ${HEAD_NODE_IP}:${HEAD_PORT} with ${SLURM_GPUS_PER_NODE} GPUs..."
ray start --head --node-ip-address="${HEAD_NODE_IP}" --port="${HEAD_PORT}" --num-gpus="${SLURM_GPUS_PER_NODE}" --block 

# skip the script that waits for all worker nodes to join

log_info "All $EXPECTED_TOTAL_NODES nodes are ready in the Ray cluster."
log_info "Launching main PPO training script..."

export PYTHONUNBUFFERED=1

# https://github.com/volcengine/verl/blob/main/docs/examples/config.rst
python3 -m verl.trainer.main_ppo --config-path=${CONFIG_DIR_PATH} \
    --config-name=${CONFIG_FILE_NAME} \
    algorithm.adv_estimator=grpo \
    data.train_files="${TRAIN_FILES}" \
    data.val_files="${TEST_FILES}" \
    data.train_batch_size=${TRAIN_BATCH_SIZE} \
    data.max_prompt_length=${MAX_CONTEXT_LENGTH} \
    data.max_response_length=${MAX_RESPONSE_LENGTH} \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=${HF_MODEL_PATH} \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=${TRAIN_BATCH_SIZE} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${MICRO_MINI_BATCH_SIZE} \
    actor_rollout_ref.actor.checkpoint.contents=['model','extra'] \
    actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${PP} \
    actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${TP} \
    actor_rollout_ref.actor.megatron.context_parallel_size=${CP} \
    actor_rollout_ref.actor.megatron.expert_model_parallel_size=${EP} \
    actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ETP} \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
    actor_rollout_ref.actor.megatron.param_offload=True \
    actor_rollout_ref.actor.megatron.optimizer_offload=True \
    actor_rollout_ref.actor.megatron.grad_offload=True \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0 \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=${MICRO_MINI_BATCH_SIZE} \
    actor_rollout_ref.rollout.tensor_model_parallel_size=${VLLM_TP} \
    actor_rollout_ref.rollout.name=${ROLLOUT_NAME} \
    actor_rollout_ref.rollout.gpu_memory_utilization=${VLLM_GPU_MEMORY_UTILIZATION} \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.max_num_batched_tokens=${VLLM_MAX_NUM_BATCHED_TOKENS} \
    actor_rollout_ref.rollout.max_model_len=${MAX_CONTEXT_LENGTH} \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.rollout.engine_kwargs.swap_space=32 \
    actor_rollout_ref.rollout.n=${NUM_GENERATIONS} \
    actor_rollout_ref.rollout.max_num_seqs=${VLLM_MAX_NUM_SEQS} \
    actor_rollout_ref.rollout.mode=${ROLLOUT_MODE} \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${MICRO_MINI_BATCH_SIZE} \
    actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${PP} \
    actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${TP} \
    actor_rollout_ref.ref.megatron.context_parallel_size=${CP} \
    actor_rollout_ref.ref.megatron.expert_model_parallel_size=${EP} \
    actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${ETP} \
    actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
    actor_rollout_ref.ref.megatron.dist_checkpointing_path=${DIST_CKPT_PATH} \
    actor_rollout_ref.ref.megatron.param_offload=True \
    algorithm.use_kl_in_reward=True \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.val_before_train=True \
    trainer.default_local_dir=${LOCAL_DIR} \
    trainer.logger=['console','tensorboard'] \
    trainer.project_name=${PROJECT_NAME} \
    trainer.experiment_name=${EXPERIMENT_NAME} \
    trainer.n_gpus_per_node=${SLURM_GPUS_PER_NODE} \
    trainer.nnodes=${EXPECTED_TOTAL_NODES} \
    trainer.save_freq=30 \
    trainer.test_freq=5 \
    trainer.max_actor_ckpt_to_keep=1 \
    trainer.max_critic_ckpt_to_keep=1 \
    trainer.total_epochs=5 "$@" # Pass any additional arguments from sbatch

APP_EXIT_STATUS=$?
log_info "Python application finished with exit status $APP_EXIT_STATUS."

log_info "Attempting to stop Ray head node..."
ray stop

The results may vary depending on the dataset, but with gsm8k as a reference, I was able to train with a context length of 32k and a response length of 4k without encountering any memory issues.

Result from TensorBoard:

[TensorBoard training-curve screenshot]

@zzong2006 (Contributor, Author) commented May 23, 2025

Additionally, to share a few more observations:

  1. When running with EP = 64 and TP = 1 (with PP fixed at 2), I encountered OOM (out-of-memory) errors. Training became stable starting from EP = 32 and TP = 2. I also experimented with lowering EP and increasing TP (e.g., TP = 4), which resulted in longer training times (from 9 to 10 minutes per step) without any noticeable memory savings.

  2. With configurations where EP > 1 and TP > 1, I encountered an error unless TP and ETP were set to the same value. (This turned out to be wrong information, sorry: #1638 (comment).)

@ETOgaosion (Collaborator) commented May 23, 2025

Hi @zzong2006 , you did a great job, thanks a lot for sharing your experience.

@ISEEKYAN cc, please take a look at this bug; it seems TP and ETP still need to be the same?

@ETOgaosion (Collaborator) commented May 23, 2025

> I’ll give the CI additions you mentioned a try. From a quick look, it seems most of the existing tests are based on smaller models like qwen3-0.6b. Just to confirm—would it be okay to use a similarly small model for the additional CI tests as well? For Qwen3 MoE models, the smallest available option seems to be 30B-3A, which could lead to significantly longer test times.

For the CI tests, you can just modify tests/e2e/run_ppo_trainer_megatron.sh to add some options; for quick CI tests, we commonly accept small models~

It's like:

PARAM_OFFLOAD=${PARAM_OFFLOAD:-False}
GRAD_OFFLOAD=${GRAD_OFFLOAD:-False}
OPTIMIZER_OFFLOAD=${OPTIMIZER_OFFLOAD:-False}

# pass them to actor.megatron.param_offload etc
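# e.g. (sketch) forward them to the trainer as Hydra overrides, mirroring the
# full script above:
actor_rollout_ref.actor.megatron.param_offload=${PARAM_OFFLOAD} \
actor_rollout_ref.actor.megatron.grad_offload=${GRAD_OFFLOAD} \
actor_rollout_ref.actor.megatron.optimizer_offload=${OPTIMIZER_OFFLOAD} \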

In .github/workflows/e2e_ppo_trainer_megatron.yml, you can set PARAM_OFFLOAD=True by default to cover all models, like the ppo_training-with-reward-model tests do.

Thanks a lot for your help~

@ISEEKYAN (Contributor):

> Hi @zzong2006, you did a great job, thanks a lot for sharing your experience.
>
> @ISEEKYAN cc, please take a look at this bug; it seems TP and ETP still need to be the same?

I have tried and verified TP != ETP; see the example in #1467.

@zzong2006 (Contributor, Author):

@ISEEKYAN It was most likely human error on my side, so I’ll retry with the configuration set to TP = 2 and ETP = 1, and share the results with you afterward.

@ETOgaosion (Collaborator):

@zzong2006 Thanks for contributing the CI tests~ I've made a small refactor and a fix on your branch; could you merge them? zzong2006#1

We may reuse the existing tests, as there are plenty of tests already (qaq)

@zzong2006 (Contributor, Author):

@ISEEKYAN After trying again, it seems that training proceeds correctly even when TP and ETP are set to different values. Apologies for the earlier confusion. I'll update my previous comment accordingly.

For clarity, I’ve attached the training logs for your reference.
Please note that due to the large model size, the logs only include a few training steps.

@ETOgaosion I’ve also added the offload-related CI tests you requested.
Would you mind taking a look and reviewing the changes?

…e_than_one

[feat] reuse existing tests and remove the error config of reward models
@zzong2006 (Contributor, Author):

@ETOgaosion Sure, I’ve merged it. Thanks!

@vermouth1992 (Collaborator):

@ETOgaosion Shall we merge this?

@ETOgaosion (Collaborator):

Yes, it's ready now.

@ETOgaosion merged commit 0286210 into volcengine:main on May 24, 2025
37 checks passed