 from contextlib import contextmanager
 from typing import Callable, List, Optional, Tuple

-import grouped_gemm
 import primus_turbo.pytorch as pt
 import torch
 import torch.nn.functional as F
@@ -731,22 +730,15 @@ def __init__(
             pg_collection,
         )
         args = get_args()
-        grouped_gemm_backend = args.grouped_gemm_backend
-        self.grouped_gemm_backend = grouped_gemm_backend

         if args.patch_zero_bubble and args.enable_zero_bubble:
             from .zbpp_gemm import grouped_gemm_with_weight_gradient_store

             self.grouped_gemm = functools.partial(
-                grouped_gemm_with_weight_gradient_store, gg_backend=grouped_gemm_backend
+                grouped_gemm_with_weight_gradient_store, gg_backend="turbo-gg"
             )
         else:
-            if grouped_gemm_backend == "turbo-gg":
-                self.grouped_gemm = pt.ops.grouped_gemm
-            elif grouped_gemm_backend == "lagacy-gg":
-                self.grouped_gemm = grouped_gemm.ops.gmm
-            else:
-                raise NotImplementedError(f"Grouped gemm backend {grouped_gemm_backend} not implemented")
+            self.grouped_gemm = pt.ops.grouped_gemm

         if args.use_turbo_fused_act_with_probs:
             assert self.config.gated_linear_unit, "turbo_fused_act_with_probs only support with GLU."
@@ -802,13 +794,18 @@ def forward(
             w1 = self.weight1.view(self.num_local_experts, self.config.hidden_size, -1)
             w2 = self.weight2.view(self.num_local_experts, -1, self.config.hidden_size)

-            if self.grouped_gemm_backend == "turbo-gg":
-                tokens_per_expert = tokens_per_expert.cuda()
+            tokens_per_expert = tokens_per_expert.to(w1.device)
             assert w1.is_contiguous(), "w1 must be contiguous"
             assert w2.is_contiguous(), "w2 must be contiguous"
-            fc1_output = self.grouped_gemm(
-                permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, **(gemm_kargs[0])
-            )
+            if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                fc1_output = pt.ops.grouped_gemm_fp8(
+                    permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, config=quant_config
+                )
+            else:
+                fc1_output = self.grouped_gemm(
+                    permuted_local_hidden_states, w1, tokens_per_expert, trans_b=False, **(gemm_kargs[0])
+                )
             if self.activation_recompute:
                 if args.use_turbo_fused_act_with_probs:
                     intermediate_parallel = self.activation_checkpoint.checkpoint(
@@ -821,9 +818,15 @@ def forward(
                     intermediate_parallel = self.activation_checkpoint.checkpoint(
                         self.activation_func_with_probs, fc1_output, permuted_probs.unsqueeze(-1)
                     )
-                fc2_output = self.grouped_gemm(
-                    intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
-                )
+                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                    fc2_output = pt.ops.grouped_gemm_fp8(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
+                    )
+                else:
+                    fc2_output = self.grouped_gemm(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
+                    )
                 self.activation_checkpoint.discard_output_and_register_recompute(fc2_output)
             else:
                 if args.use_turbo_fused_act_with_probs:
@@ -834,9 +837,15 @@ def forward(
                     intermediate_parallel = self.activation_func_with_probs(
                         fc1_output, permuted_probs.unsqueeze(-1)
                     )
-                fc2_output = self.grouped_gemm(
-                    intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
-                )
+                if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
+                    quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
+                    fc2_output = pt.ops.grouped_gemm_fp8(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, config=quant_config
+                    )
+                else:
+                    fc2_output = self.grouped_gemm(
+                        intermediate_parallel, w2, tokens_per_expert, trans_b=False, **(gemm_kargs[1])
+                    )
         else:
             # No token is allocated for local experts.
             assert torch.count_nonzero(tokens_per_expert) == 0
@@ -925,9 +934,7 @@ def __init__(
             deepep_num_use_cu=args.turbo_deepep_num_cu,
             deepep_num_worst_tokens=num_worst_tokens,
             deepep_use_cuda_num_tokens_per_expert=(
-                args.use_turbo_grouped_mlp
-                and args.moe_use_legacy_grouped_gemm
-                and args.grouped_gemm_backend == "turbo-gg"
+                args.use_turbo_grouped_mlp and args.moe_use_legacy_grouped_gemm
             ),
             deepep_async_finish=True,
             deepep_allocate_on_comm_stream=True,
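
For reference, a minimal sketch of the dispatch pattern the updated forward path follows: the FP8 grouped GEMM is used whenever the Turbo FP8 global state is enabled, otherwise whichever callable __init__ assigned to self.grouped_gemm. Only names visible in the diff are used; the helper name, argument names, and the surrounding imports are assumptions for illustration, not part of the commit.

def _grouped_gemm_forward(self, inputs, weights, tokens_per_expert, gemm_kwargs):
    # Hypothetical helper mirroring the branch added in forward() above.
    if PrimusTurboFP8GlobalStateManager.is_turbo_fp8_enabled():
        # FP8 path: pull the quantization config from the global Turbo FP8 state
        # and call the FP8 grouped GEMM directly.
        quant_config = PrimusTurboFP8GlobalStateManager.get_turbo_fp8_quant_config()
        return pt.ops.grouped_gemm_fp8(
            inputs, weights, tokens_per_expert, trans_b=False, config=quant_config
        )
    # Non-FP8 path: self.grouped_gemm is pt.ops.grouped_gemm, or the zero-bubble
    # partial wrapping grouped_gemm_with_weight_gradient_store.
    return self.grouped_gemm(inputs, weights, tokens_per_expert, trans_b=False, **gemm_kwargs)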