Skip to content

Commit 2957b05

Browse files
lint fixing
1 parent 95d94ba commit 2957b05

File tree

4 files changed

+13
-8
lines changed

4 files changed

+13
-8
lines changed

primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,9 @@ def __init__(
         )
         args = get_args()

-        if (args.patch_zero_bubble and args.enable_zero_bubble) or (args.patch_moe_overlap and args.overlap_moe_expert_parallel_comm):
+        if (args.patch_zero_bubble and args.enable_zero_bubble) or (
+            args.patch_moe_overlap and args.overlap_moe_expert_parallel_comm
+        ):
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store

             self.grouped_gemm = functools.partial(

primus/backends/megatron/core/extensions/zbpp_gemm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def forward(
     ):
         if wgrad_gemm_backend_func is None:
             wgrad_gemm_backend_func = group_gemm_backend_func
-        ctx.use_main_grad = hasattr(weight, 'main_grad') and weight.main_grad is not None
+        ctx.use_main_grad = hasattr(weight, "main_grad") and weight.main_grad is not None
        if ctx.use_main_grad:
            ctx.weight_main_grad = weight.main_grad
            ctx.weight_shape_ori = weight.shape

primus/backends/megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,14 @@
 ###############################################################################

 import torch
-
-from megatron.core.pipeline_parallel.utils import (
-    get_comm_stream,
-)
-
 from megatron.core.models.common.model_chunk_schedule_plan import (
     TransformerLayerSchedulePlan,
+)
+from megatron.core.models.common.model_chunk_schedule_plan import (
     TransformerModelChunkSchedulePlan as TransformerModelChunkSchedulePlanBase,
 )
+from megatron.core.pipeline_parallel.utils import get_comm_stream
+
 from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
     WeightGradStore,
 )

primus/modules/trainer/megatron/pre_trainer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,14 +234,18 @@ def forward_step(self, data_iterator, model: GPTModel, return_schedule_plan=Fals
             args.overlap_moe_expert_parallel_comm
         ), "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
         if args.patch_moe_overlap:
-            assert not args.delay_wgrad_compute, "Primus MoE overlap handles wgrad separately from the original Megatron implementation"
+            assert (
+                not args.delay_wgrad_compute
+            ), "Primus MoE overlap handles wgrad separately from the original Megatron implementation"
             from primus.backends.megatron.core.pipeline_parallel.zerobubble.zbpp_utils import (
                 WeightGradStore,
             )
+
             WeightGradStore.enable_split_bw()
             from primus.backends.megatron.core.models.common.model_chunk_schedule_plan import (
                 TransformerModelChunkSchedulePlan,
             )
+
         schedule_plan = TransformerModelChunkSchedulePlan(
             model, tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
         )

0 commit comments

Comments
 (0)