Commit 68f577e

feat(zerobubble): zerobubble add lagacy group gemm & te backend (#241)
1 parent c69ee12 commit 68f577e

File tree

9 files changed (+1183, -29 lines)


primus/backends/megatron/core/extensions/primus_turbo.py

Lines changed: 22 additions & 5 deletions
@@ -3,9 +3,11 @@
 #
 # See LICENSE for license information.
 ###############################################################################
+import functools
 from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple
 
+import grouped_gemm
 import primus_turbo.pytorch as pt
 import torch
 import transformer_engine as te
@@ -728,12 +730,22 @@ def __init__(
             pg_collection,
         )
         args = get_args()
+        grouped_gemm_backend = args.grouped_gemm_backend
+        self.grouped_gemm_backend = grouped_gemm_backend
+
         if args.patch_zero_bubble and args.enable_zero_bubble:
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store
 
-            self.grouped_gemm = grouped_gemm_with_weight_gradient_store
+            self.grouped_gemm = functools.partial(
+                grouped_gemm_with_weight_gradient_store, gg_backend=grouped_gemm_backend
+            )
         else:
-            self.grouped_gemm = pt.ops.grouped_gemm
+            if grouped_gemm_backend == "turbo-gg":
+                self.grouped_gemm = pt.ops.grouped_gemm
+            elif grouped_gemm_backend == "lagacy-gg":
+                self.grouped_gemm = grouped_gemm.ops.gmm
+            else:
+                raise NotImplementedError(f"Grouped gemm backend {grouped_gemm_backend} not implemented")
 
     def forward(
         self,
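Note: a minimal standalone sketch of the backend switch added above, outside the zero-bubble path. The toy shapes, the bfloat16 dtype, and the CPU-resident tokens_per_expert tensor for the legacy grouped_gemm package are assumptions about its public gmm(a, b, batch_sizes) API, not something this commit states.

import torch
import grouped_gemm  # legacy backend, selected by "lagacy-gg"


def pick_grouped_gemm(backend):
    # Mirrors the dispatch added in __init__ above.
    if backend == "turbo-gg":
        import primus_turbo.pytorch as pt  # turbo backend, only needed for "turbo-gg"
        return pt.ops.grouped_gemm
    elif backend == "lagacy-gg":
        return grouped_gemm.ops.gmm
    raise NotImplementedError(f"Grouped gemm backend {backend} not implemented")


# Hypothetical toy problem: 2 experts, 4 tokens each, hidden size 8, ffn size 16.
a = torch.randn(8, 8, device="cuda", dtype=torch.bfloat16)
b = torch.randn(2, 8, 16, device="cuda", dtype=torch.bfloat16)
tokens_per_expert = torch.tensor([4, 4], dtype=torch.int64)  # assumed to stay on CPU for legacy gmm

out = pick_grouped_gemm("lagacy-gg")(a, b, tokens_per_expert)  # expected shape (8, 16)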
@@ -760,6 +772,7 @@ def forward(
         if permuted_local_hidden_states.nelement() != 0:
             # Reshape the weights for the grouped GEMMs.
             if args.patch_zero_bubble and args.enable_zero_bubble:
+
                 w1 = self.weight1
                 w2 = self.weight2
 
@@ -769,7 +782,8 @@ def forward(
                 w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
                 w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
 
-            tokens_per_expert = tokens_per_expert.cuda()
+            if self.grouped_gemm_backend == "turbo-gg":
+                tokens_per_expert = tokens_per_expert.cuda()
             assert w1.is_contiguous(), "w1 must be contiguous"
             assert w2.is_contiguous(), "w2 must be contiguous"
             fc1_output = self.grouped_gemm(
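Note: the .cuda() move is now gated on the turbo backend. A plausible reading, which is an assumption rather than something stated in the diff, is that the turbo kernel consumes tokens_per_expert as a device tensor while the legacy grouped_gemm kernel expects the counts on the host. A hypothetical helper restating the gate:

import torch


def prepare_tokens_per_expert(tokens_per_expert, backend):
    # Hypothetical helper mirroring the gated .cuda() call in forward() above.
    if backend == "turbo-gg":
        return tokens_per_expert.cuda()  # turbo-gg reads the expert counts on device
    return tokens_per_expert  # lagacy-gg keeps the expert counts on CPU


counts = torch.tensor([4, 4], dtype=torch.int64)
print(prepare_tokens_per_expert(counts, "lagacy-gg").device)  # cpu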
@@ -870,8 +884,11 @@ def __init__(
             deepep_use_comm_stream=args.turbo_deepep_use_comm_stream,
             deepep_num_use_cu=args.turbo_deepep_num_cu,
             deepep_num_worst_tokens=num_worst_tokens,
-            deepep_use_cuda_num_tokens_per_expert=args.use_turbo_grouped_mlp
-            and args.moe_use_legacy_grouped_gemm,
+            deepep_use_cuda_num_tokens_per_expert=(
+                args.use_turbo_grouped_mlp
+                and args.moe_use_legacy_grouped_gemm
+                and args.grouped_gemm_backend == "turbo-gg"
+            ),
             deepep_async_finish=True,
             deepep_allocate_on_comm_stream=True,
         )
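Note: the widened condition reads as "only ask DeepEP for a CUDA-resident tokens-per-expert count when the turbo grouped MLP, the legacy grouped-GEMM path, and the turbo-gg backend are all enabled". A small restatement with hypothetical flag values (only the names are copied from the hunk above):

# Hypothetical values for illustration only.
use_turbo_grouped_mlp = True
moe_use_legacy_grouped_gemm = True
grouped_gemm_backend = "lagacy-gg"

deepep_use_cuda_num_tokens_per_expert = (
    use_turbo_grouped_mlp
    and moe_use_legacy_grouped_gemm
    and grouped_gemm_backend == "turbo-gg"
)
print(deepep_use_cuda_num_tokens_per_expert)  # False: the legacy backend keeps the counts on the host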
