@@ -3,9 +3,11 @@
 #
 # See LICENSE for license information.
 ###############################################################################
+import functools
 from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple
 
+import grouped_gemm
 import primus_turbo.pytorch as pt
 import torch
 import transformer_engine as te
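The new imports pull in `functools` for the partial binding below and the legacy `grouped_gemm` package alongside `primus_turbo`. As a quick sanity aid, a hypothetical helper (not part of this commit) could report which of the two backends is importable in the current environment:

```python
# Hypothetical helper, not part of this commit: report which grouped-GEMM
# backends can actually be imported in the current environment.
def available_backends():
    found = []
    try:
        import primus_turbo.pytorch  # noqa: F401
        found.append("turbo-gg")
    except ImportError:
        pass
    try:
        import grouped_gemm  # noqa: F401
        found.append("lagacy-gg")
    except ImportError:
        pass
    return found
```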
@@ -728,12 +730,22 @@ def __init__(
             pg_collection,
         )
         args = get_args()
+        grouped_gemm_backend = args.grouped_gemm_backend
+        self.grouped_gemm_backend = grouped_gemm_backend
+
         if args.patch_zero_bubble and args.enable_zero_bubble:
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store
 
-            self.grouped_gemm = grouped_gemm_with_weight_gradient_store
+            self.grouped_gemm = functools.partial(
+                grouped_gemm_with_weight_gradient_store, gg_backend=grouped_gemm_backend
+            )
         else:
-            self.grouped_gemm = pt.ops.grouped_gemm
+            if grouped_gemm_backend == "turbo-gg":
+                self.grouped_gemm = pt.ops.grouped_gemm
+            elif grouped_gemm_backend == "lagacy-gg":
+                self.grouped_gemm = grouped_gemm.ops.gmm
+            else:
+                raise NotImplementedError(f"Grouped gemm backend {grouped_gemm_backend} not implemented")
 
     def forward(
         self,
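This hunk routes `self.grouped_gemm` through `args.grouped_gemm_backend`: the zero-bubble path pre-binds the backend with `functools.partial` so call sites stay unchanged, while the regular path selects `pt.ops.grouped_gemm` ("turbo-gg") or `grouped_gemm.ops.gmm` ("lagacy-gg") directly. A minimal sketch of the same pre-binding pattern, with a stand-in function in place of the real kernel:

```python
import functools

# Stand-in for grouped_gemm_with_weight_gradient_store; only the keyword
# pre-binding is being illustrated here, not the real kernel.
def gemm_with_store(a, b, tokens_per_expert, *, gg_backend):
    return f"ran with backend={gg_backend}"

grouped = functools.partial(gemm_with_store, gg_backend="turbo-gg")
# Call sites keep passing the same positional arguments as before;
# gg_backend is already bound, so nothing in forward() has to change.
print(grouped("a", "b", "counts"))  # -> ran with backend=turbo-gg
```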
@@ -760,6 +772,7 @@ def forward(
         if permuted_local_hidden_states.nelement() != 0:
             # Reshape the weights for the grouped GEMMs.
             if args.patch_zero_bubble and args.enable_zero_bubble:
+
                 w1 = self.weight1
                 w2 = self.weight2
 
@@ -769,7 +782,8 @@ def forward(
                 w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
                 w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)
 
-            tokens_per_expert = tokens_per_expert.cuda()
+            if self.grouped_gemm_backend == "turbo-gg":
+                tokens_per_expert = tokens_per_expert.cuda()
             assert w1.is_contiguous(), "w1 must be contiguous"
             assert w2.is_contiguous(), "w2 must be contiguous"
             fc1_output = self.grouped_gemm(
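The `.cuda()` transfer is now gated on the turbo backend. The apparent intent (an assumption, based on this guard and the dispatcher change below) is that `pt.ops.grouped_gemm` consumes the per-expert token counts on the GPU, while `grouped_gemm.ops.gmm` takes them wherever they already live, typically on the host. The guard in isolation:

```python
import torch

def prepare_tokens_per_expert(tokens_per_expert: torch.Tensor, backend: str) -> torch.Tensor:
    # Only the turbo backend wants the counts on the GPU; any other backend
    # receives the tensor untouched, exactly as in the hunk above.
    if backend == "turbo-gg":
        return tokens_per_expert.cuda()
    return tokens_per_expert
```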
@@ -870,8 +884,11 @@ def __init__(
             deepep_use_comm_stream=args.turbo_deepep_use_comm_stream,
             deepep_num_use_cu=args.turbo_deepep_num_cu,
             deepep_num_worst_tokens=num_worst_tokens,
-            deepep_use_cuda_num_tokens_per_expert=args.use_turbo_grouped_mlp
-            and args.moe_use_legacy_grouped_gemm,
+            deepep_use_cuda_num_tokens_per_expert=(
+                args.use_turbo_grouped_mlp
+                and args.moe_use_legacy_grouped_gemm
+                and args.grouped_gemm_backend == "turbo-gg"
+            ),
             deepep_async_finish=True,
             deepep_allocate_on_comm_stream=True,
         )
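The dispatcher flag now mirrors the same condition: CUDA-resident token counts are only requested from DeepEP when the turbo grouped-GEMM backend is actually in use, so the legacy `grouped_gemm.ops.gmm` path is never handed device tensors it does not expect. Written as a standalone predicate (hypothetical helper, for illustration only):

```python
def use_cuda_num_tokens_per_expert(args) -> bool:
    # All three conditions from the hunk above must hold before DeepEP is
    # asked to produce the per-expert token counts on the GPU.
    return (
        args.use_turbo_grouped_mlp
        and args.moe_use_legacy_grouped_gemm
        and args.grouped_gemm_backend == "turbo-gg"
    )
```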