
Commit 5e29f58

Merge branch 'main' into featute/cli/layered_env
2 parents 25684a9 + 5a03df4 commit 5e29f58

7 files changed: +186, -8 lines

examples/megatron/configs/MI300X/deepseek_v3-pretrain.yaml

Lines changed: 6 additions & 0 deletions
@@ -86,6 +86,12 @@ modules:
   use_turbo_attention: true
   use_turbo_grouped_mlp: true

+  # MoE overlap
+  # overlap_moe_expert_parallel_comm: true
+  # patch_moe_overlap: true
+  # delay_wgrad_compute: false
+  # moe_shared_expert_overlap: false
+
   # Cross entropy flags
   # cross_entropy_fusion_impl: "te"
   # cross_entropy_loss_fusion: true

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 3 additions & 1 deletion
@@ -731,7 +731,9 @@ def __init__(
         )
         args = get_args()

-        if args.patch_zero_bubble and args.enable_zero_bubble:
+        if (args.patch_zero_bubble and args.enable_zero_bubble) or (
+            args.patch_moe_overlap and args.overlap_moe_expert_parallel_comm
+        ):
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store

             self.grouped_gemm = functools.partial(
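For context, a minimal sketch (not the Primus code itself) of the predicate this hunk implements: the grouped GEMM with a weight-gradient store is now selected when either the zero-bubble flag pair or the new MoE-overlap flag pair is fully enabled. Here args is assumed to be the Megatron argument namespace carrying the flags shown above.

def _uses_wgrad_store_grouped_gemm(args) -> bool:
    # Illustrative only: mirrors the condition in the diff above.
    zero_bubble = args.patch_zero_bubble and args.enable_zero_bubble
    moe_overlap = args.patch_moe_overlap and args.overlap_moe_expert_parallel_comm
    return zero_bubble or moe_overlap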

primus/backends/megatron/core/extensions/zbpp_gemm.py

Lines changed: 8 additions & 4 deletions
@@ -103,7 +103,9 @@ def forward(
     ):
         if wgrad_gemm_backend_func is None:
             wgrad_gemm_backend_func = group_gemm_backend_func
-        ctx.weight_main_grad = weight.main_grad
+        ctx.use_main_grad = hasattr(weight, "main_grad") and weight.main_grad is not None
+        if ctx.use_main_grad:
+            ctx.weight_main_grad = weight.main_grad
         ctx.weight_shape_ori = weight.shape
         ctx.group_gemm_backend_func = group_gemm_backend_func
         ctx.wgrad_gemm_backend_func = wgrad_gemm_backend_func
@@ -129,7 +131,8 @@ def forward(
     def backward(ctx, grad_output):
         input, weight, group_lens, group_offs = ctx.saved_tensors
         group_gemm_backend_func = ctx.group_gemm_backend_func
-        weight.main_grad = ctx.weight_main_grad
+        if ctx.use_main_grad:
+            weight.main_grad = ctx.weight_main_grad
         grad_a = group_gemm_backend_func(
             grad_output,
             weight,
@@ -154,8 +157,9 @@ def process_wgrad(_weight, _weight_shape_ori, _grad_output, _total_input, handle
                 trans_b=False,
             )
             _wgrad = _wgrad.view(_weight_shape_ori)
-            with torch.no_grad():
-                _weight.main_grad.add_(_wgrad)
+            if ctx.use_main_grad:
+                with torch.no_grad():
+                    _weight.main_grad.add_(_wgrad)

         WeightGradStore.put(
             weight,
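As a standalone illustration of the guard introduced above, here is a sketch under the assumption that weight may or may not carry a fused main_grad buffer (for example, depending on the distributed-optimizer setup); the accumulation now only runs when the buffer actually exists. The helper name is hypothetical.

import torch

def _accumulate_wgrad(weight: torch.nn.Parameter, wgrad: torch.Tensor) -> None:
    # Hypothetical helper mirroring the patched logic: skip main_grad handling
    # entirely when the parameter has no fused-gradient buffer attached.
    use_main_grad = hasattr(weight, "main_grad") and weight.main_grad is not None
    if use_main_grad:
        with torch.no_grad():
            # reshape plays the role of the .view(_weight_shape_ori) in the diff.
            weight.main_grad.add_(wgrad.reshape(weight.shape))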

primus/backends/megatron/core/models/common/__init__.py

Whitespace-only changes.
primus/backends/megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

import torch
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerLayerSchedulePlan,
)
from megatron.core.models.common.model_chunk_schedule_plan import (
    TransformerModelChunkSchedulePlan as TransformerModelChunkSchedulePlanBase,
)
from megatron.core.pipeline_parallel.utils import get_comm_stream

from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
    WeightGradStore,
)


class TransformerModelChunkSchedulePlan(TransformerModelChunkSchedulePlanBase):

    @staticmethod
    def run(
        f_schedule_plan,
        b_schedule_plan,
        b_grad=None,
        pre_forward=None,
        pre_backward=None,
        post_forward=None,
        post_backward=None,
    ):
        """Model-chunk-level 1F1B fine-grained scheduler.

        This function schedules the forward and backward passes for a model chunk.
        It interleaves the forward and backward functions of multiple Transformer
        layers within the chunk, which is needed to overlap submodules across the
        individual forward and backward functions.

        Assuming the given model chunk has 4 layers:
        Phase 0: p2p_comm_sync -> forward_preprocess -> p2p_comm_sync -> backward_postprocess
        Phase 1: forward_layer[0] + backward_layer[3], overlapped execution by schedule_layer_1f1b
        Phase 2: forward_layer[1] + backward_layer[2], overlapped execution by schedule_layer_1f1b
        Phase 3: forward_layer[2] + backward_layer[1], overlapped execution by schedule_layer_1f1b
        Phase 4: forward_layer[3] + backward_layer[0], overlapped execution by schedule_layer_1f1b
        Phase 5: send_forward_recv_backward -> send_backward_recv_forward
        Phase 6: backward_dw of the first layer -> forward_postprocess -> backward_preprocess

        Args:
            f_schedule_plan (TransformerModelChunkSchedulePlan): The forward schedule plan
            b_schedule_plan (TransformerModelChunkSchedulePlan): The backward schedule plan
            b_grad (Tensor or None): The gradient of the loss function
            pre_forward (callable or None): The function to call before the forward pass
            pre_backward (callable or None): The function to call before the backward pass
            post_forward (callable or None): The function to call after the forward pass
            post_backward (callable or None): The function to call after the backward pass
        Returns:
            The output of the forward pass.
        """
        f_input = None
        if f_schedule_plan:
            # pp output send/receive sync
            if pre_forward is not None:
                pre_forward(f_schedule_plan.vp_stage)
            f_schedule_plan.record_current_stream()
            f_input = f_schedule_plan.pre_process.forward()

        if b_schedule_plan:
            b_schedule_plan.record_current_stream()
            assert b_grad is not None
            if pre_backward is not None:
                pre_backward(b_schedule_plan.vp_stage)
                b_schedule_plan.record_current_stream()

            if b_schedule_plan.post_process is not None:
                b_grad = b_schedule_plan.post_process.backward(b_grad)

        f_num_layers = f_schedule_plan.num_layers() if f_schedule_plan is not None else 0
        b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
        overlapped_layers = min(f_num_layers, b_num_layers)

        # combined forward and backward pass for overlapped layers
        for i in range(overlapped_layers):
            f_layer = f_schedule_plan.get_layer(i)
            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
            f_input, b_grad = TransformerLayerSchedulePlan.run(
                f_layer,
                b_layer,
                f_input=f_input,
                b_grad=b_grad,
                is_last_layer_in_bwd=(i == b_num_layers - 1),
            )
            torch.cuda.nvtx.range_pop()

        # backward pass for the remaining layers
        for i in range(overlapped_layers, b_num_layers):
            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
            _, b_grad = TransformerLayerSchedulePlan.run(
                None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
            )
            torch.cuda.nvtx.range_pop()

        # forward pass for the remaining layers
        for i in range(overlapped_layers, f_num_layers):
            f_layer = f_schedule_plan.get_layer(i)
            torch.cuda.nvtx.range_push(f"layer_{i}f")
            f_input, _ = TransformerLayerSchedulePlan.run(f_layer, None, f_input=f_input)
            torch.cuda.nvtx.range_pop()

        if f_schedule_plan is not None and post_forward is not None:
            # post_forward()/send_forward_recv_forward() runs in the communication stream,
            # so the p2p comm can be overlapped with the attention backward
            with torch.cuda.stream(get_comm_stream()):
                f_schedule_plan.wait_current_stream()
                post_forward(f_input, f_schedule_plan.vp_stage)

        # post_backward()/send_backward_recv_backward() runs in the computation stream,
        # so the p2p comm can be overlapped with the wgrad of the attention backward
        if b_schedule_plan is not None and post_backward is not None:
            b_schedule_plan.wait_current_stream()
            post_backward(b_grad, b_schedule_plan.vp_stage)

        # Delay the dw in the backward pass so it overlaps with the p2p comm
        if b_num_layers > 0:
            WeightGradStore.flush()
            WeightGradStore.pop()

        # post process forward
        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
            f_input = f_schedule_plan.post_process.forward(f_input)
        # pre process backward
        if b_schedule_plan is not None:
            b_schedule_plan.pre_process.backward(b_grad)

        if f_schedule_plan:
            f_schedule_plan.wait_current_stream()
        if b_schedule_plan:
            b_schedule_plan.wait_current_stream()

        # Release references as early as possible; this helps avoid memory leaks.
        if b_schedule_plan is not None:
            b_schedule_plan.release_state()

        return f_input
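To make the control flow of run() above easier to follow, here is a toy sketch of the interleaving pattern only; plain callables stand in for TransformerLayerSchedulePlan.run, and the pre/post hooks, CUDA streams, NVTX ranges, and WeightGradStore handling are deliberately omitted. It is an illustration, not part of the Primus API: pair forward layer i with backward layer b_n - 1 - i, then drain whichever side still has layers.

def interleave_1f1b(forward_layers, backward_layers, f_input, b_grad):
    """Toy 1F1B interleaving illustration; not the Primus scheduler itself."""
    f_n, b_n = len(forward_layers), len(backward_layers)
    overlapped = min(f_n, b_n)
    # Overlapped phase: forward layer i paired with backward layer (b_n - 1 - i).
    for i in range(overlapped):
        f_input = forward_layers[i](f_input)
        b_grad = backward_layers[b_n - 1 - i](b_grad)
    # Drain the remaining backward layers, then the remaining forward layers.
    for i in range(overlapped, b_n):
        b_grad = backward_layers[b_n - 1 - i](b_grad)
    for i in range(overlapped, f_n):
        f_input = forward_layers[i](f_input)
    return f_input, b_grad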

primus/configs/modules/megatron/primus_megatron_module.yaml

Lines changed: 3 additions & 0 deletions
@@ -30,6 +30,9 @@ no_fp8_weight_transpose_cache: false
 # parallelism
 decoder_pipeline_manual_split_list: null # int list

+# MoE comm & comp Overlap
+patch_moe_overlap: false
+
 # perf
 pp_warmup: false # set to true to decrease iter-1 time when using pp

primus/modules/trainer/megatron/pre_trainer.py

Lines changed: 20 additions & 3 deletions
@@ -233,9 +233,26 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals
             assert (
                 args.overlap_moe_expert_parallel_comm
             ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
-            schedule_plan = model.build_schedule_plan(
-                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
-            )
+            if args.patch_moe_overlap:
+                assert (
+                    not args.delay_wgrad_compute
+                ), "Primus MoE overlap handles wgrad separately from the original Megatron implementation"
+                from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
+                    WeightGradStore,
+                )
+
+                WeightGradStore.enable_split_bw()
+                from primus.backends.megatron.core.models.common.model_chunk_schedule_plan import (
+                    TransformerModelChunkSchedulePlan,
+                )
+
+                schedule_plan = TransformerModelChunkSchedulePlan(
+                    model, tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                )
+            else:
+                schedule_plan = model.build_schedule_plan(
+                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+                )
             return schedule_plan, partial(self.loss_func, loss_mask)
         else:
             output_tensor = model(
