
Commit bcb9807

zhaochaoxinghellozmz authored and committed
opt moe_block by dlblas
1 parent f50a1f4 · commit bcb9807

File tree

1 file changed: +18 -1 lines
  • lmdeploy/pytorch/backends/cuda


lmdeploy/pytorch/backends/cuda/moe.py

+18 -1
@@ -588,6 +588,20 @@ def __init__(self,
             self.use_deep_gemm = False
             logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
 
+        try:
+            from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl
+            self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size,
+                                                    ep_group=ep_group,
+                                                    top_k=top_k,
+                                                    num_experts=num_experts,
+                                                    hidden_dim=hidden_dim,
+                                                    renormalize=renormalize,
+                                                    block_size=block_size,
+                                                    out_dtype=out_dtype)
+        except ImportError:
+            self.dlblas_moe = None
+            logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas')
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor,
@@ -598,8 +612,11 @@ def forward(self,
                 down_scale: torch.Tensor,
                 expert_list: List[int] = None):
         """forward."""
-        topk_weights = self.do_renormalize(topk_weights)
         step_ctx = get_step_ctx_manager().current_context()
+        if self.dlblas_moe is not None:
+            return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale,
+                                           down_weights, down_scale, step_ctx.is_decoding, expert_list)
+        topk_weights = self.do_renormalize(topk_weights)
         low_latency_mode = step_ctx.is_decoding and self.use_deep_gemm
         moe = self.fusedmoe_build(low_latency_mode)
         out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights,
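
The change follows a common optional-backend pattern: probe for the accelerated dependency once at construction time with a try/except ImportError, then dispatch to it in forward() with an early return, keeping the existing DeepGEMM/Triton path as the fallback. A minimal standalone sketch of the same pattern, where the fast_backend module and FastMoE class are hypothetical placeholders and not part of lmdeploy or dlblas:

import logging

logger = logging.getLogger(__name__)


class MoEImpl:
    """Sketch of the optional accelerated-backend pattern in this commit."""

    def __init__(self):
        # Probe for the optional backend once; keep None on failure so
        # forward() can fall back to the default path.
        try:
            from fast_backend import FastMoE  # hypothetical accelerated kernel
            self.fast_moe = FastMoE()
        except ImportError:
            self.fast_moe = None
            logger.warning('For higher performance, please install fast_backend')

    def forward(self, hidden_states):
        # Early return into the accelerated path when it is available.
        if self.fast_moe is not None:
            return self.fast_moe.forward(hidden_states)
        # Default (reference) path.
        return self._reference_forward(hidden_states)

    def _reference_forward(self, hidden_states):
        # Placeholder for the existing fused-MoE path.
        return hidden_states

Note one subtlety visible in the diff: topk_weights = self.do_renormalize(topk_weights) now runs only on the fallback path. The dlblas implementation receives the raw routing weights and is constructed with renormalize=renormalize, so it handles renormalization internally.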
