6 changes: 6 additions & 0 deletions examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml
@@ -86,6 +86,12 @@ modules:
use_turbo_attention: true
use_turbo_grouped_mlp: true

# MoE overlap
# overlap_moe_expert_parallel_comm: true
# patch_moe_overlap: true
# delay_wgrad_compute: false
# moe_shared_expert_overlap: false

# Cross entropy flags
# cross_entropy_fusion_impl: "te"
# cross_entropy_loss_fusion: true
4 changes: 3 additions & 1 deletion primus/backends/megatron/core/extensions/primus_turbo.py
@@ -731,7 +731,9 @@ def __init__(
)
args = get_args()

if args.patch_zero_bubble and args.enable_zero_bubble:
if (args.patch_zero_bubble and args.enable_zero_bubble) or (
args.patch_moe_overlap and args.overlap_moe_expert_parallel_comm
):
from .zbpp_gemm import grouped_gemm_with_weight_gradient_store

self.grouped_gemm = functools.partial(
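The widened condition above routes grouped GEMMs through the weight-gradient-store path for either the zero-bubble pipeline patch or the new MoE expert-parallel overlap patch, since both features rely on deferring the weight-gradient (dW) GEMM. A minimal sketch of that predicate as a hypothetical standalone helper (the helper name and the getattr defaults are illustrative, not part of the PR):

from types import SimpleNamespace


def use_wgrad_store_grouped_gemm(args) -> bool:
    """Return True when grouped GEMMs should defer their dW work to a weight-gradient store."""
    zero_bubble = getattr(args, "patch_zero_bubble", False) and getattr(args, "enable_zero_bubble", False)
    moe_overlap = getattr(args, "patch_moe_overlap", False) and getattr(
        args, "overlap_moe_expert_parallel_comm", False
    )
    return zero_bubble or moe_overlap


# Example: only the MoE-overlap flag pair is enabled.
args = SimpleNamespace(
    patch_zero_bubble=False,
    enable_zero_bubble=False,
    patch_moe_overlap=True,
    overlap_moe_expert_parallel_comm=True,
)
print(use_wgrad_store_grouped_gemm(args))  # True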
12 changes: 8 additions & 4 deletions primus/backends/megatron/core/extensions/zbpp_gemm.py
@@ -103,7 +103,9 @@ def forward(
):
if wgrad_gemm_backend_func is None:
wgrad_gemm_backend_func = group_gemm_backend_func
ctx.weight_main_grad = weight.main_grad
ctx.use_main_grad = hasattr(weight, "main_grad") and weight.main_grad is not None
if ctx.use_main_grad:
ctx.weight_main_grad = weight.main_grad
ctx.weight_shape_ori = weight.shape
ctx.group_gemm_backend_func = group_gemm_backend_func
ctx.wgrad_gemm_backend_func = wgrad_gemm_backend_func
@@ -129,7 +131,8 @@
def backward(ctx, grad_output):
input, weight, group_lens, group_offs = ctx.saved_tensors
group_gemm_backend_func = ctx.group_gemm_backend_func
weight.main_grad = ctx.weight_main_grad
if ctx.use_main_grad:
weight.main_grad = ctx.weight_main_grad
grad_a = group_gemm_backend_func(
grad_output,
weight,
@@ -154,8 +157,9 @@ def process_wgrad(_weight, _weight_shape_ori, _grad_output, _total_input, handle
trans_b=False,
)
_wgrad = _wgrad.view(_weight_shape_ori)
with torch.no_grad():
_weight.main_grad.add_(_wgrad)
if ctx.use_main_grad:
with torch.no_grad():
_weight.main_grad.add_(_wgrad)

WeightGradStore.put(
weight,
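The edits above make the main_grad handling optional: forward() only stashes weight.main_grad on the ctx when the buffer actually exists, and the deferred accumulation into it is guarded the same way, so weights without a fused-gradient buffer still work. A self-contained sketch of the same guard on a toy dense layer (plain PyTorch; the class and shapes are illustrative, not the PR's grouped GEMM):

import torch


class LinearWithMainGradAccum(torch.autograd.Function):
    """Toy linear op showing the guarded main_grad pattern."""

    @staticmethod
    def forward(ctx, input, weight):
        # Only stash main_grad when the buffer exists (e.g. set up by Megatron's distributed optimizer).
        ctx.use_main_grad = hasattr(weight, "main_grad") and weight.main_grad is not None
        if ctx.use_main_grad:
            ctx.weight_main_grad = weight.main_grad
        ctx.save_for_backward(input, weight)
        return input @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight
        wgrad = grad_output.t() @ input
        if ctx.use_main_grad:
            # Accumulate dW out-of-band and return no autograd gradient for the weight.
            with torch.no_grad():
                ctx.weight_main_grad.add_(wgrad)
            return grad_input, None
        return grad_input, wgrad


# With a main_grad buffer, dW lands there; without one, it flows through weight.grad.
w = torch.nn.Parameter(torch.randn(4, 3))
w.main_grad = torch.zeros_like(w)
x = torch.randn(2, 3, requires_grad=True)
LinearWithMainGradAccum.apply(x, w).sum().backward()
print(w.grad is None, bool(w.main_grad.abs().sum() > 0))  # True True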
Empty file.
146 changes: 146 additions & 0 deletions primus/backends/megatron/core/models/common/model_chunk_schedule_plan.py
@@ -0,0 +1,146 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

import torch
from megatron.core.models.common.model_chunk_schedule_plan import (
TransformerLayerSchedulePlan,
)
from megatron.core.models.common.model_chunk_schedule_plan import (
TransformerModelChunkSchedulePlan as TransformerModelChunkSchedulePlanBase,
)
from megatron.core.pipeline_parallel.utils import get_comm_stream

from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
WeightGradStore,
)


class TransformerModelChunkSchedulePlan(TransformerModelChunkSchedulePlanBase):

@staticmethod
def run(
f_schedule_plan,
b_schedule_plan,
b_grad=None,
pre_forward=None,
pre_backward=None,
post_forward=None,
post_backward=None,
):
"""Model Chunk level 1f1b fine-grained scheduler.

This function schedules the forward and backward passes for a model chunk.
It interleaves the forward and backward functions of multiple Transformer layers
within the chunk, which is needed to overlap submodules across the individual
forward and backward functions.

Assume there are 4 layers in the given model chunk:
Phase 0: p2p_comm_sync -> forward_preprocess -> p2p_comm_sync -> backward_postprocess
Phase 1: forward_layer[0] + backward_layer[3], overlapped execution by schedule_layer_1f1b
Phase 2: forward_layer[1] + backward_layer[2], overlapped execution by schedule_layer_1f1b
Phase 3: forward_layer[2] + backward_layer[1], overlapped execution by schedule_layer_1f1b
Phase 4: forward_layer[3] + backward_layer[0], overlapped execution by schedule_layer_1f1b
Phase 5: send_forward_recv_backward -> send_backward_recv_forward
Phase 6: backward_dw of the first layer -> forward_postprocess -> backward_preprocess

Args:
f_schedule_plan (TransformerModelChunkSchedulePlan): The forward schedule plan
b_schedule_plan (TransformerModelChunkSchedulePlan): The backward schedule plan
b_grad (Tensor or None): The gradient of the loss function
pre_forward (callable or None): The function to call before the forward pass
pre_backward (callable or None): The function to call before the backward pass
post_forward (callable or None): The function to call after the forward pass
post_backward (callable or None): The function to call after the backward pass
Returns:
The output of the forward pass.
"""
f_input = None
if f_schedule_plan:
# pp output send/receive sync
if pre_forward is not None:
pre_forward(f_schedule_plan.vp_stage)
f_schedule_plan.record_current_stream()
f_input = f_schedule_plan.pre_process.forward()

if b_schedule_plan:
b_schedule_plan.record_current_stream()
assert b_grad is not None
if pre_backward is not None:
pre_backward(b_schedule_plan.vp_stage)
b_schedule_plan.record_current_stream()

if b_schedule_plan.post_process is not None:
b_grad = b_schedule_plan.post_process.backward(b_grad)

f_num_layers = f_schedule_plan.num_layers() if f_schedule_plan is not None else 0
b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
overlapped_layers = min(f_num_layers, b_num_layers)

# combined forward and backward pass for overlapped layers
for i in range(overlapped_layers):
f_layer = f_schedule_plan.get_layer(i)
b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
f_input, b_grad = TransformerLayerSchedulePlan.run(
f_layer,
b_layer,
f_input=f_input,
b_grad=b_grad,
is_last_layer_in_bwd=(i == b_num_layers - 1),
)
torch.cuda.nvtx.range_pop()

# backward pass for the remaining layers
for i in range(overlapped_layers, b_num_layers):
b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
_, b_grad = TransformerLayerSchedulePlan.run(
None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
)
torch.cuda.nvtx.range_pop()

# forward pass for the remaining layers
for i in range(overlapped_layers, f_num_layers):
f_layer = f_schedule_plan.get_layer(i)
torch.cuda.nvtx.range_push(f"layer_{i}f")
f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
torch.cuda.nvtx.range_pop()

if f_schedule_plan is not None and post_forward is not None:
# post_forward()/send_forward_recv_forward() is running in the communication stream,
# so the p2p comm could be overlapped with the attn backward
with torch.cuda.stream(get_comm_stream()):
f_schedule_plan.wait_current_stream()
post_forward(f_input, f_schedule_plan.vp_stage)

# post_backward()/send_backward_recv_backward() is running in the computation stream,
# so the p2p comm could be overlapped with the wgrad of attn backward
if b_schedule_plan is not None and post_backward is not None:
b_schedule_plan.wait_current_stream()
post_backward(b_grad, b_schedule_plan.vp_stage)

# Delay the dw in backward pass for overlapping with the p2p comm
if b_num_layers > 0:
WeightGradStore.flush()
WeightGradStore.pop()

# post process forward
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
f_input = f_schedule_plan.post_process.forward(f_input)
# pre process backward
if b_schedule_plan is not None:
b_schedule_plan.pre_process.backward(b_grad)

if f_schedule_plan:
f_schedule_plan.wait_current_stream()
if b_schedule_plan:
b_schedule_plan.wait_current_stream()

# Release the reference as early as possible; this helps avoid memory leaks.
if b_schedule_plan is not None:
b_schedule_plan.release_state()

return f_input
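The scheduling loops above pair forward layer i with backward layer b_num_layers - 1 - i for the overlapped phases, then drain whichever side has layers left over. A small hypothetical helper (not part of the PR) that prints the resulting phase ordering:

def chunk_1f1b_phases(f_num_layers: int, b_num_layers: int):
    """Return the per-phase (forward, backward) layer pairing used by run()."""
    overlapped = min(f_num_layers, b_num_layers)
    phases = []
    for i in range(overlapped):
        # Overlapped 1f1b: earliest remaining forward with latest remaining backward.
        phases.append((f"f{i}", f"b{b_num_layers - 1 - i}"))
    for i in range(overlapped, b_num_layers):
        phases.append((None, f"b{b_num_layers - 1 - i}"))  # leftover backward layers
    for i in range(overlapped, f_num_layers):
        phases.append((f"f{i}", None))  # leftover forward layers
    return phases


print(chunk_1f1b_phases(4, 4))
# [('f0', 'b3'), ('f1', 'b2'), ('f2', 'b1'), ('f3', 'b0')]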
3 changes: 3 additions & 0 deletions primus/configs/modules/megatron/primus_megatron_module.yaml
@@ -30,6 +30,9 @@ no_fp8_weight_transpose_cache: false
# parallelism
decoder_pipeline_manual_split_list: null # int list

# MoE comm & comp Overlap
patch_moe_overlap: false

# perf
pp_warmup: false # set to true to decrease iter-1 time when using pp

23 changes: 20 additions & 3 deletions primus/modules/trainer/megatron/pre_trainer.py
@@ -233,9 +233,26 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=False
assert (
args.overlap_moe_expert_parallel_comm
), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
schedule_plan = model.build_schedule_plan(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
if args.patch_moe_overlap:
assert (
not args.delay_wgrad_compute
), "Primus MoE overlap handles wgrad separately from the original Megatron implementation"
from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
WeightGradStore,
)

WeightGradStore.enable_split_bw()
from primus.backends.megatron.core.models.common.model_chunk_schedule_plan import (
TransformerModelChunkSchedulePlan,
)

schedule_plan = TransformerModelChunkSchedulePlan(
model, tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
else:
schedule_plan = model.build_schedule_plan(
tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
)
return schedule_plan, partial(self.loss_func, loss_mask)
else:
output_tensor = model(
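Taken together, the pieces work as follows: forward_step enables split backward-weight mode before building the schedule plan, the grouped-GEMM autograd function queues its dW work in WeightGradStore during backward, and the chunk scheduler flushes and pops that queue so the dW GEMMs execute while the p2p sends/receives are in flight. A toy, self-contained queue sketching that lifecycle (a simplified stand-in for illustration, not the PR's WeightGradStore API):

from collections import deque


class ToyWeightGradStore:
    """Minimal deferred-dW queue: put() during backward, flush()/pop() afterwards."""

    split_bw = False
    _pending = []      # dW closures queued by the chunk currently running backward
    _ready = deque()   # flushed batches, replayed one batch per pop()

    @classmethod
    def enable_split_bw(cls):
        cls.split_bw = True

    @classmethod
    def put(cls, dw_fn):
        if cls.split_bw:
            cls._pending.append(dw_fn)  # defer the weight-gradient GEMM
        else:
            dw_fn()                     # fall back to computing dW inline

    @classmethod
    def flush(cls):
        cls._ready.append(cls._pending)
        cls._pending = []

    @classmethod
    def pop(cls):
        if cls._ready:
            for dw_fn in cls._ready.popleft():
                dw_fn()                 # run dW now, overlapping the p2p comm


# Usage: two layers defer their dW work, replayed after the sends/receives are issued.
ToyWeightGradStore.enable_split_bw()
ToyWeightGradStore.put(lambda: print("dW layer 0"))
ToyWeightGradStore.put(lambda: print("dW layer 1"))
ToyWeightGradStore.flush()
ToyWeightGradStore.pop()  # prints: dW layer 0 / dW layer 1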