Skip to content

[Megatron] Support optimizer offload for moe when ep > 1 #1638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/workflows/e2e_ppo_trainer_megatron.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with validation and saving
run: |
ray stop --force
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 bash tests/e2e/run_ppo_trainer_megatron.sh
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) after resuming
run: |
ray stop --force
Expand Down Expand Up @@ -107,7 +107,7 @@ jobs:
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
run: |
ray stop --force
SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
ALL_OFFLOAD=True SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct bash tests/e2e/run_ppo_trainer_megatron.sh
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
run: |
ray stop --force
Expand Down Expand Up @@ -149,7 +149,7 @@ jobs:
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving
run: |
ray stop --force
VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
ALL_OFFLOAD=True VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
- name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
run: |
ray stop --force
Expand Down Expand Up @@ -306,3 +306,4 @@ jobs:
run: |
rm -rf checkpoints


22 changes: 22 additions & 0 deletions tests/e2e/run_ppo_trainer_megatron.sh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ RM_VPP=${RM_VPP:-$COMMON_VPP}
RM_CP=${RM_CP:-$COMMON_CP}
RM_TP=${RM_TP:-$TRAIN_TP}

ALL_OFFLOAD=${ALL_OFFLOAD:-False}
COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}

ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}

CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']
SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}
if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
Expand All @@ -81,6 +95,9 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=$ACTOR_VPP \
actor_rollout_ref.actor.megatron.context_parallel_size=$ACTOR_CP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$ACTOR_TP \
actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
Expand All @@ -95,6 +112,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
actor_rollout_ref.ref.megatron.context_parallel_size=$REF_CP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$REF_TP \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
critic.optim.lr=2e-5 \
critic.model.path="${MODEL_PATH}" \
critic.model.enable_gradient_checkpointing=False \
Expand All @@ -104,13 +122,17 @@ python3 -m verl.trainer.main_ppo --config-path=config \
critic.megatron.context_parallel_size=$CRITIC_CP \
critic.megatron.tensor_model_parallel_size=$CRITIC_TP \
critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
reward_model.enable=True \
reward_model.model.path="${MODEL_PATH}" \
reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
reward_model.megatron.pipeline_model_parallel_size=$RM_PP \
reward_model.megatron.virtual_pipeline_model_parallel_size=$RM_VPP \
reward_model.megatron.context_parallel_size=$RM_CP \
reward_model.megatron.tensor_model_parallel_size=$RM_TP \
reward_model.megatron.param_offload=${RM_PARAM_OFFLOAD} \
algorithm.use_kl_in_reward=False \
algorithm.kl_penalty=kl \
algorithm.kl_ctrl.kl_coef=0.001 \
Expand Down
2 changes: 0 additions & 2 deletions verl/trainer/config/ppo_megatron_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,6 @@ reward_model:
strategy: megatron
megatron:
param_offload: False
grad_offload: False
optimizer_offload: False
tensor_model_parallel_size: 1
expert_model_parallel_size: 1
expert_tensor_parallel_size: null
Expand Down
83 changes: 54 additions & 29 deletions verl/utils/megatron_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from megatron.core.distributed import DistributedDataParallel as DDP
from megatron.core.distributed import DistributedDataParallelConfig
from megatron.core.enums import ModelType
from megatron.core.optimizer import OptimizerConfig
from megatron.core.optimizer import ChainedOptimizer, OptimizerConfig
from megatron.core.transformer import TransformerConfig
from megatron.core.transformer.module import Float16Module
from megatron.core.utils import get_attr_wrapped_model
Expand Down Expand Up @@ -296,12 +296,18 @@ def load_megatron_model_to_gpu(models, load_grad=True):
@torch.no_grad()
def offload_megatron_copy_params(optimizers):
"""
Offload optimizer parameters to CPU
Offload optimizer parameters to CPU. Supports both Megatron optimizers
and `ChainedOptimizer`, which wraps a list of underlying optimizers.

Args:
optimizers: The optimizer containing parameter groups to offload
optimizers: The optimizer or ChainedOptimizer instance.
"""

def _iter_opts(opt):
if isinstance(opt, ChainedOptimizer):
return opt.chained_optimizers
return [opt]

def offload_tensor_to_cpu(tensor):
if tensor is None:
return
Expand All @@ -321,21 +327,27 @@ def offload_group_to_cpu(group):
else:
offload_tensor_to_cpu(group)

# Offload all parameter groups to CPU
# Offload all parameter groups to CPU for each underlying optimizer

if hasattr(optimizers, "shard_fp32_from_float16_groups"):
offload_group_to_cpu(optimizers.shard_fp32_from_float16_groups)
for _opt in _iter_opts(optimizers):
if hasattr(_opt, "shard_fp32_from_float16_groups"):
offload_group_to_cpu(_opt.shard_fp32_from_float16_groups)


@torch.no_grad()
def load_megatron_copy_params(optimizers):
"""
Load optimizer parameters back to GPU
Load optimizer parameters back to GPU. Handles ChainedOptimizer.

Args:
optimizers: The optimizer containing parameter groups to load
optimizers: Optimizer or ChainedOptimizer instance.
"""

def _iter_opts(opt):
if isinstance(opt, ChainedOptimizer):
return opt.chained_optimizers
return [opt]

def load_tensor_to_gpu(tensor):
if tensor is None:
return
Expand All @@ -356,36 +368,49 @@ def load_group_to_gpu(group):
else:
load_tensor_to_gpu(group)

# Load all parameter groups to GPU
# Load all parameter groups to GPU for each underlying optimizer

if hasattr(optimizers, "shard_fp32_from_float16_groups"):
load_group_to_gpu(optimizers.shard_fp32_from_float16_groups)
for _opt in _iter_opts(optimizers):
if hasattr(_opt, "shard_fp32_from_float16_groups"):
load_group_to_gpu(_opt.shard_fp32_from_float16_groups)


@torch.no_grad()
def offload_megatron_optimizer(optimizers):
offload_megatron_copy_params(optimizers)
opt_state_dict_values = optimizers.optimizer.state.values()
for v in opt_state_dict_values:
if "exp_avg" in v:
v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
if "exp_avg_sq" in v:
v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
gc.collect()
torch.cuda.empty_cache()
def _iter_opts(opt):
if isinstance(opt, ChainedOptimizer):
return opt.chained_optimizers
return [opt]

for _opt in _iter_opts(optimizers):
offload_megatron_copy_params(_opt)
opt_state_dict_values = _opt.optimizer.state.values()
for v in opt_state_dict_values:
if "exp_avg" in v:
v["exp_avg"] = v["exp_avg"].to("cpu", non_blocking=True)
if "exp_avg_sq" in v:
v["exp_avg_sq"] = v["exp_avg_sq"].to("cpu", non_blocking=True)
gc.collect()
torch.cuda.empty_cache()


@torch.no_grad()
def load_megatron_optimizer(optimizers):
load_megatron_copy_params(optimizers)
opt_state_dict_values = optimizers.optimizer.state.values()
for v in opt_state_dict_values:
if "exp_avg" in v:
v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
if "exp_avg_sq" in v:
v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
gc.collect()
torch.cuda.empty_cache()
def _iter_opts(opt):
if isinstance(opt, ChainedOptimizer):
return opt.chained_optimizers
return [opt]

for _opt in _iter_opts(optimizers):
load_megatron_copy_params(_opt)
opt_state_dict_values = _opt.optimizer.state.values()
for v in opt_state_dict_values:
if "exp_avg" in v:
v["exp_avg"] = v["exp_avg"].to(torch.cuda.current_device(), non_blocking=True)
if "exp_avg_sq" in v:
v["exp_avg_sq"] = v["exp_avg_sq"].to(torch.cuda.current_device(), non_blocking=True)
gc.collect()
torch.cuda.empty_cache()


def print_rank_0(message):
Expand Down
Loading