
Commit 0286210

Authored by zzong2006, kakao-charlie-cs, and ETOgaosion
[Megatron] Support optimizer offload for moe when ep > 1 (#1638)
### Checklist Before Starting

- [x] Search for similar PR(s).

### What does this PR do?

This simple PR adds support for offloading [ChainedOptimizer](https://github.com/NVIDIA/Megatron-LM/blob/75b1ca13618bded85c81fb572f58df83ba095dc9/megatron/core/optimizer/optimizer.py#L938) in the Megatron-LM training environment. In Megatron-LM, ChainedOptimizer is used when expert parallelism is enabled (expert_parallel > 1, related to #1467), which is commonly the case for Mixture-of-Experts (MoE) models. The change has been tested and validated with the Qwen3-235B-A22B model configuration.

### High-Level Design

> Demonstrate the high-level design if this PR is complex.

### Specific Changes

> List the specific changes.

### API

> Demonstrate how the API changes if any.

### Usage Example

```python
...
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=16 \
...
```

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### Additional Info.

- **Issue Number**: Fixes issue # or discussion # if any.
- **Training**: [Megatron]
- **Inference**: [none]

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl?tab=readme-ov-file#contribution-guide).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl?tab=readme-ov-file#code-linting-and-formatting).
- [ ] Add `[BREAKING]` to the PR title if it breaks any API.
- [ ] Update the documentation about your changes in the [docs](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add CI test(s) if necessary.

---------

Co-authored-by: charlie.cs <[email protected]>
Co-authored-by: ETOgaosion <[email protected]>
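At its core, the change teaches the existing offload helpers in `verl/utils/megatron_utils.py` to unwrap a `ChainedOptimizer` into its underlying optimizers before moving any state. The sketch below restates that pattern in isolation; `_iter_opts` mirrors the helper added in the diff, while `offload_adam_moments` is a hypothetical name used here only for illustration. It assumes, as in the diff, that each Megatron optimizer wraps a torch optimizer in `.optimizer` and stores Adam moments under `exp_avg` / `exp_avg_sq`.

```python
# Illustrative sketch only: `offload_adam_moments` is a hypothetical name;
# the real entry points are offload_megatron_optimizer / load_megatron_optimizer
# in verl/utils/megatron_utils.py (see the diff below).
from megatron.core.optimizer import ChainedOptimizer


def _iter_opts(opt):
    """Return the underlying optimizer(s), whether or not `opt` is chained."""
    if isinstance(opt, ChainedOptimizer):
        return opt.chained_optimizers
    return [opt]


def offload_adam_moments(optimizer):
    """Move the Adam moment tensors of every underlying optimizer to CPU."""
    for _opt in _iter_opts(optimizer):
        for state in _opt.optimizer.state.values():
            for key in ("exp_avg", "exp_avg_sq"):
                if key in state:
                    state[key] = state[key].to("cpu", non_blocking=True)
```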
1 parent 7225544 commit 0286210

File tree

4 files changed: +80 additions, -34 deletions


.github/workflows/e2e_ppo_trainer_megatron.yml

Lines changed: 4 additions & 3 deletions

@@ -65,7 +65,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with validation and saving
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) after resuming
         run: |
           ray stop --force
@@ -107,7 +107,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
         run: |
           ray stop --force
-          SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
         run: |
           ray stop --force
@@ -149,7 +149,7 @@ jobs:
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+          ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
       - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
         run: |
           ray stop --force
@@ -306,3 +306,4 @@ jobs:
         run: |
           rm -rf checkpoints

+

tests/e2e/run_ppo_trainer_megatron.sh

Lines changed: 22 additions & 0 deletions

@@ -55,6 +55,20 @@ RM_VPP=${RM_VPP:-$COMMON_VPP}
 RM_CP=${RM_CP:-$COMMON_CP}
 RM_TP=${RM_TP:-$TRAIN_TP}

+ALL_OFFLOAD=${ALL_OFFLOAD:-False}
+COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
+COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
+COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}
+
+ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
+ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
+REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
+CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
+RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
+
 CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']
 SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}
 if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
@@ -81,6 +95,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \
     actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \
     actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \
+    actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
+    actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
+    actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
     actor_rollout_ref.actor.use_kl_loss=True \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
@@ -95,6 +112,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \
     actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \
     actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
+    actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
     critic.optim.lr=2e-5 \
     critic.model.path="${MODEL_PATH}" \
     critic.model.enable_gradient_checkpointing=False \
@@ -104,13 +122,17 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     critic.megatron.context_parallel_size=$CRITIC_CP \
     critic.megatron.tensor_model_parallel_size=$CRITIC_TP \
     critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
+    critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
+    critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
+    critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
     reward_model.enable=True \
     reward_model.model.path="${MODEL_PATH}" \
     reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
     reward_model.megatron.pipeline_model_parallel_size=$RM_PP \
     reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \
     reward_model.megatron.context_parallel_size=$RM_CP \
     reward_model.megatron.tensor_model_parallel_size=$RM_TP \
+    reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \
     algorithm.use_kl_in_reward=False \
     algorithm.kl_penalty=kl \
     algorithm.kl_ctrl.kl_coef=0.001 \
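The new variables resolve in a cascade: each role-specific flag falls back to the matching COMMON_*_OFFLOAD value, which in turn falls back to ALL_OFFLOAD (default False). The snippet below is not part of the change; it is only a Python restatement of that resolution order, with a hypothetical helper name, to make the precedence explicit (it ignores the shell's empty-string-counts-as-unset subtlety of `${VAR:-default}`).

```python
import os


def resolve_offload_flag(role_var: str, common_var: str) -> str:
    # Hypothetical helper mirroring the script's fallback chain:
    # role-specific -> COMMON_* -> ALL_OFFLOAD -> "False".
    all_offload = os.environ.get("ALL_OFFLOAD", "False")
    common = os.environ.get(common_var, all_offload)
    return os.environ.get(role_var, common)


# e.g. with only ALL_OFFLOAD=True set, every flag resolves to "True":
print(resolve_offload_flag("ACTOR_OPTIMIZER_OFFLOAD", "COMMON_OPTIMIZER_OFFLOAD"))
```

Setting ALL_OFFLOAD=True in CI therefore exercises param, grad, and optimizer offload for every role with a single switch, which is exactly what the workflow change above does.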

verl/trainer/config/ppo_megatron_trainer.yaml

Lines changed: 0 additions & 2 deletions

@@ -228,8 +228,6 @@ reward_model:
   strategy: megatron
   megatron:
     param_offload: False
-    grad_offload: False
-    optimizer_offload: False
     tensor_model_parallel_size: 1
     expert_model_parallel_size: 1
     expert_tensor_parallel_size: null
verl/utils/megatron_utils.py

Lines changed: 54 additions & 29 deletions

@@ -25,7 +25,7 @@
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.enums import ModelType
-from megatron.core.optimizer import OptimizerConfig
+from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig
 from megatron.core.transformer import TransformerConfig
 from megatron.core.transformer.module import Float16Module
 from megatron.core.utils import get_attr_wrapped_model
@@ -296,12 +296,18 @@ def load_megatron_model_to_gpu(models, load_grad=True):
 @torch.no_grad()
 def offload_megatron_copy_params(optimizers):
     """
-    Offload optimizer parameters to CPU
+    Offload optimizer parameters to CPU. Supports both Megatron optimizers
+    and `ChainedOptimizer`, which wraps a list of underlying optimizers.

     Args:
-        optimizers: The optimizer containing parameter groups to offload
+        optimizers: The optimizer or ChainedOptimizer instance.
     """

+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
     def offload_tensor_to_cpu(tensor):
         if tensor is None:
             return
@@ -321,21 +327,27 @@ def offload_group_to_cpu(group):
         else:
             offload_tensor_to_cpu(group)

-    # Offload all parameter groups to CPU
+    # Offload all parameter groups to CPU for each underlying optimizer

-    if hasattr(optimizers, "shard_fp32_from_float16_groups"):
-        offload_group_to_cpu(optimizers.shard_fp32_from_float16_groups)
+    for _opt in _iter_opts(optimizers):
+        if hasattr(_opt, "shard_fp32_from_float16_groups"):
+            offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)


 @torch.no_grad()
 def load_megatron_copy_params(optimizers):
     """
-    Load optimizer parameters back to GPU
+    Load optimizer parameters back to GPU. Handles ChainedOptimizer.

     Args:
-        optimizers: The optimizer containing parameter groups to load
+        optimizers: Optimizer or ChainedOptimizer instance.
     """

+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
     def load_tensor_to_gpu(tensor):
         if tensor is None:
             return
@@ -356,36 +368,49 @@ def load_group_to_gpu(group):
         else:
             load_tensor_to_gpu(group)

-    # Load all parameter groups to GPU
+    # Load all parameter groups to GPU for each underlying optimizer

-    if hasattr(optimizers, "shard_fp32_from_float16_groups"):
-        load_group_to_gpu(optimizers.shard_fp32_from_float16_groups)
+    for _opt in _iter_opts(optimizers):
+        if hasattr(_opt, "shard_fp32_from_float16_groups"):
+            load_group_to_gpu(_opt.shard_fp32_from_float16_groups)


 @torch.no_grad()
 def offload_megatron_optimizer(optimizers):
-    offload_megatron_copy_params(optimizers)
-    opt_state_dict_values = optimizers.optimizer.state.values()
-    for v in opt_state_dict_values:
-        if "exp_avg" in v:
-            v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
-        if "exp_avg_sq" in v:
-            v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
-    gc.collect()
-    torch.cuda.empty_cache()
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
+    for _opt in _iter_opts(optimizers):
+        offload_megatron_copy_params(_opt)
+        opt_state_dict_values = _opt.optimizer.state.values()
+        for v in opt_state_dict_values:
+            if "exp_avg" in v:
+                v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
+            if "exp_avg_sq" in v:
+                v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
+        gc.collect()
+        torch.cuda.empty_cache()


 @torch.no_grad()
 def load_megatron_optimizer(optimizers):
-    load_megatron_copy_params(optimizers)
-    opt_state_dict_values = optimizers.optimizer.state.values()
-    for v in opt_state_dict_values:
-        if "exp_avg" in v:
-            v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
-        if "exp_avg_sq" in v:
-            v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
-    gc.collect()
-    torch.cuda.empty_cache()
+    def _iter_opts(opt):
+        if isinstance(opt, ChainedOptimizer):
+            return opt.chained_optimizers
+        return [opt]
+
+    for _opt in _iter_opts(optimizers):
+        load_megatron_copy_params(_opt)
+        opt_state_dict_values = _opt.optimizer.state.values()
+        for v in opt_state_dict_values:
+            if "exp_avg" in v:
+                v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
+            if "exp_avg_sq" in v:
+                v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
+        gc.collect()
+        torch.cuda.empty_cache()


 def print_rank_0(message):
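Below is a hypothetical call site showing how a worker might use these helpers around a training phase; the surrounding function and variable names are illustrative only, not part of this commit. With this change the same calls work whether `optimizer` is a plain Megatron optimizer or a `ChainedOptimizer` (as produced when expert_model_parallel_size > 1).

```python
# Hypothetical usage sketch around the helpers changed in this commit.
from verl.utils.megatron_utils import load_megatron_optimizer, offload_megatron_optimizer


def training_phase(optimizer, run_update_steps):
    # Bring optimizer state (param copies + Adam moments) back to GPU.
    load_megatron_optimizer(optimizer)
    try:
        run_update_steps()
    finally:
        # Free GPU memory for rollout/inference by moving state back to CPU.
        offload_megatron_optimizer(optimizer)
```

Because offload_megatron_copy_params and load_megatron_copy_params now also unwrap ChainedOptimizer internally, the chained and non-chained cases follow the same per-optimizer logic.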
