@@ -588,6 +588,20 @@ def __init__(self,
             self.use_deep_gemm = False
             logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
 
+        try:
+            from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl
+            self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size,
+                                                    ep_group=ep_group,
+                                                    top_k=top_k,
+                                                    num_experts=num_experts,
+                                                    hidden_dim=hidden_dim,
+                                                    renormalize=renormalize,
+                                                    block_size=block_size,
+                                                    out_dtype=out_dtype)
+        except ImportError:
+            self.dlblas_moe = None
+            logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas')
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor,
@@ -598,8 +612,11 @@ def forward(self,
                 down_scale: torch.Tensor,
                 expert_list: List[int] = None):
         """forward."""
-        topk_weights = self.do_renormalize(topk_weights)
         step_ctx = get_step_ctx_manager().current_context()
+        if self.dlblas_moe is not None:
+            return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale,
+                                           down_weights, down_scale, step_ctx.is_decoding, expert_list)
+        topk_weights = self.do_renormalize(topk_weights)
         low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
         moe = self.fusedmoe_build(low_latency_mode)
         out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
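For context, the change mirrors the optional-dependency pattern already used for DeepGEMM just above it: probe for the faster backend at construction time, keep None on ImportError, and dispatch to the default fused-MoE path in forward() when the backend is absent. Below is a minimal, self-contained sketch of that pattern under assumed names; fastlib, FastMoEImpl, and MoEDispatcher are hypothetical placeholders, not part of dlblas or this repository.

# Sketch of the optional-backend fallback pattern (hypothetical names).
import logging

logger = logging.getLogger(__name__)


class MoEDispatcher:

    def __init__(self):
        try:
            # Hypothetical optional dependency; ImportError selects the fallback.
            from fastlib import FastMoEImpl
            self.fast_moe = FastMoEImpl()
        except ImportError:
            self.fast_moe = None
            logger.warning('fastlib not found, falling back to the default kernel')

    def forward(self, *args, **kwargs):
        # Prefer the optional backend whenever it imported successfully.
        if self.fast_moe is not None:
            return self.fast_moe.forward(*args, **kwargs)
        return self._default_forward(*args, **kwargs)

    def _default_forward(self, *args, **kwargs):
        # Default implementation used when the optional backend is unavailable.
        ...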