From 7e7d32a122f585769730ccd38221cbe4335d8e41 Mon Sep 17 00:00:00 2001 From: zhaochaoxing Date: Fri, 11 Apr 2025 15:05:44 +0800 Subject: [PATCH 1/5] for dlblas --- lmdeploy/pytorch/backends/cuda/moe.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index af206ed60..580faf23d 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -506,6 +506,20 @@ def __init__(self, self.use_deep_gemm = False logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') + try: + from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl + self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size, + ep_group=ep_group, + top_k=top_k, + num_experts=num_experts, + hidden_dim=hidden_dim, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype) + except ImportError: + self.dlblas_moe = None + logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas') + def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, @@ -516,8 +530,11 @@ def forward(self, down_scale: torch.Tensor, expert_list: List[int] = None): """forward.""" - topk_weights = _renormalize(topk_weights, self.renormalize) step_ctx = get_step_ctx_manager().current_context() + if self.dlblas_moe is not None: + return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, + down_weights, down_scale, step_ctx.is_decoding, expert_list) + topk_weights = _renormalize(topk_weights, self.renormalize) moe = None if step_ctx.is_decoding is False or self.use_deep_gemm is False: moe = FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, From 47d9980ff9a5105cfc9168d5da95df09ba6c052d Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 14 Apr 2025 11:02:56 +0800 Subject: [PATCH 2/5] add args prefill_without_permute --- lmdeploy/cli/utils.py | 11 +++++++++++ lmdeploy/messages.py | 3 +++ lmdeploy/pytorch/backends/cuda/moe.py | 2 ++ lmdeploy/pytorch/config.py | 1 + lmdeploy/pytorch/distributed.py | 8 ++++++++ lmdeploy/pytorch/engine/engine.py | 1 + 6 files changed, 26 insertions(+) diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py index 0ea38a915..0448dacea 100644 --- a/lmdeploy/cli/utils.py +++ b/lmdeploy/cli/utils.py @@ -512,6 +512,17 @@ def eager_mode(parser): help='Whether to enable eager mode. ' 'If True, cuda graph would be disabled') + @staticmethod + def prefill_without_permute(parser): + """Add argument prefill_without_permute to parser.""" + + return parser.add_argument('--prefill-without-permute', + action='store_true', + default=False, + help='Whether to enable prefill_without_permute. ' + 'If True, the moe layer would not permute the input, ' + 'and would not unpermute the output') + @staticmethod def communicator(parser): return parser.add_argument('--communicator', diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py index d3b0efccd..915ee0fa4 100644 --- a/lmdeploy/messages.py +++ b/lmdeploy/messages.py @@ -297,6 +297,8 @@ class PytorchEngineConfig: bit, set it to 4 or 8, respectively distributed_executor_backend (str): backend of distributed backend, options: ['uni', 'mp', 'ray'] + prefill_without_permute(bool): whether to use moe without permute. + Default to False. 
""" dtype: str = 'auto' tp: int = 1 @@ -321,6 +323,7 @@ class PytorchEngineConfig: revision: str = None quant_policy: Literal[0, 4, 8] = 0 distributed_executor_backend: str = None + prefill_without_permute: bool = False def __post_init__(self): """Check input validation.""" diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 580faf23d..83c17bfee 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -6,6 +6,7 @@ import torch.distributed as dist from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency, TokenDispatcherBuilder +from lmdeploy.pytorch.distributed import prefill_without_permute from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8 from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8 @@ -20,6 +21,7 @@ from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder, FusedMoEW8A8Impl) +is_prefill_without_permute = prefill_without_permute() logger = get_logger('lmdeploy') diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 9c9840439..87a0ca55f 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -97,6 +97,7 @@ class DistConfig: dp_rank: int = 0 world_size: int = None attn_config: 'DistConfig' = None + prefill_without_permute: bool = False def __post_init__(self): """post init.""" diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py index 34347222d..c379954bd 100644 --- a/lmdeploy/pytorch/distributed.py +++ b/lmdeploy/pytorch/distributed.py @@ -29,6 +29,7 @@ class DistContext: ep_gpu_group: dist.ProcessGroup = None ep_gpu_groups: List[dist.ProcessGroup] = None dist_config: DistConfig = None + prefill_without_permute: bool = False @classmethod def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'): @@ -44,6 +45,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = ep = dist_config.ep world_size = dist_config.world_size dp_rank = dist_config.dp_rank + prefill_without_permute = dist_config.prefill_without_permute if world_size == 1: return DistContext(dist_config=dist_config) @@ -104,6 +106,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = ep_gpu_group=ep_gpu_group, ep_gpu_groups=ep_gpu_groups, dist_config=dist_config, + prefill_without_permute=prefill_without_permute, ) return context @@ -181,6 +184,11 @@ def get_ep_world_rank(): return ctx.ep, ctx.ep_rank +def prefill_without_permute(): + ctx = get_dist_manager().current_context() + return ctx.prefill_without_permute + + def _check_group_device(device: str): """check group device.""" assert (device in ['cpu', 'gpu']), ('Expect process group device in ("cpu", "gpu"), ' diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py index 80df39282..49f10b792 100644 --- a/lmdeploy/pytorch/engine/engine.py +++ b/lmdeploy/pytorch/engine/engine.py @@ -89,6 +89,7 @@ def _build_dist_config(engine_config: PytorchEngineConfig): tp=engine_config.tp, ep=engine_config.ep, dp_rank=engine_config.dp_rank, + prefill_without_permute=engine_config.prefill_without_permute, ) return dist_config From cb6a364bc0f043343aaa53cab4a0b8c2b26cf255 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 14 Apr 2025 14:47:24 +0800 Subject: [PATCH 3/5] use dlblas moe kernel 
--- lmdeploy/pytorch/backends/cuda/moe.py | 51 ++++++++++++------- .../pytorch/backends/cuda/token_dispatcher.py | 18 +++++-- 2 files changed, 46 insertions(+), 23 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 83c17bfee..9dbd3f95e 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -427,7 +427,8 @@ def forward(self, up_scale: torch.Tensor, down_weights: torch.Tensor, down_scale: torch.Tensor, - expert_list: List[int] = None): + expert_list: List[int] = None, + dlblas_moe_impl: FusedMoEBlockedF8Impl = None): """forward.""" recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = self.token_dispatcher.dispatch( hidden_states, @@ -435,8 +436,13 @@ def forward(self, topk_weights, expert_list, ) - out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, up_weights, up_scale, down_weights, - down_scale) + if is_prefill_without_permute and dlblas_moe_impl != None: + logger.error(f"dlblas_moe_impl is try to run.") + out_states = dlblas_moe_impl.forward(recv_hidden_states, recv_topk_weights, recv_topk_ids, up_weights, up_scale, + down_weights, down_scale) + else: + out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, up_weights, up_scale, down_weights, + down_scale) out_states = self.token_dispatcher.combine(out_states) return out_states @@ -509,15 +515,22 @@ def __init__(self, logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') try: - from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl - self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size, - ep_group=ep_group, - top_k=top_k, - num_experts=num_experts, - hidden_dim=hidden_dim, - renormalize=renormalize, - block_size=block_size, - out_dtype=out_dtype) + # from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl + # self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size, + # ep_group=ep_group, + # top_k=top_k, + # num_experts=num_experts, + # hidden_dim=hidden_dim, + # renormalize=renormalize, + # block_size=block_size, + # out_dtype=out_dtype) + from dlblas.layers.moe.ep_moe import DlblasTritonFusedMoEBlockedF8Impl + self.dlblas_moe = DlblasTritonFusedMoEBlockedF8Impl(top_k=top_k, + num_experts=num_experts, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype, + ep_size=ep_size) except ImportError: self.dlblas_moe = None logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas') @@ -533,19 +546,21 @@ def forward(self, expert_list: List[int] = None): """forward.""" step_ctx = get_step_ctx_manager().current_context() - if self.dlblas_moe is not None: - return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, - down_weights, down_scale, step_ctx.is_decoding, expert_list) - topk_weights = _renormalize(topk_weights, self.renormalize) + if is_prefill_without_permute: + pass + else: + topk_weights = _renormalize(topk_weights, self.renormalize) moe = None if step_ctx.is_decoding is False or self.use_deep_gemm is False: moe = FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, self.out_dtype) + out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, + down_scale, expert_list, self.dlblas_moe) else: moe = FusedMoELowLatency(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, self.out_dtype) - 
out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, - down_scale, expert_list) + out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, + down_scale, expert_list) return out_states diff --git a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py index d3585736d..d430eb7f8 100644 --- a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py +++ b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py @@ -13,10 +13,12 @@ from ..default.token_dispatcher import AlltoAllTokenDispatcher from ..token_dispatcher import TokenDispatcherImpl +from lmdeploy.pytorch.distributed import prefill_without_permute + _buffer_normal = None _buffer_low_latency = None _buffer_common = None - +is_prefill_without_permute = prefill_without_permute() def get_buffer_common( group: dist.ProcessGroup, @@ -154,8 +156,11 @@ def dispatch( self.handle = handle self.topk_idx = topk_idx self.topk_weights = topk_weights - if hidden_states.shape[0] > 0: - hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states) + if is_prefill_without_permute: + pass + else: + if hidden_states.shape[0] > 0: + hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states) return hidden_states, topk_idx, topk_weights, tokens_per_expert def dispatch_normal( @@ -210,8 +215,11 @@ def dispatch_normal( ) def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - if hidden_states.shape[0] > 0: - hidden_states = self.get_restored_hidden_states_by_experts(hidden_states) + if is_prefill_without_permute: + pass + else: + if hidden_states.shape[0] > 0: + hidden_states = self.get_restored_hidden_states_by_experts(hidden_states) hidden_states, event = self.combine_normal(hidden_states, self.handle) self.handle = None return hidden_states.view(self.hidden_shape) From 20875f4c8a6efd310da9c883b3eddcf6a930caf2 Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Mon, 14 Apr 2025 15:14:53 +0800 Subject: [PATCH 4/5] rm debug logs --- lmdeploy/pytorch/backends/cuda/moe.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 9dbd3f95e..2aa6e6b41 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -437,7 +437,6 @@ def forward(self, expert_list, ) if is_prefill_without_permute and dlblas_moe_impl != None: - logger.error(f"dlblas_moe_impl is try to run.") out_states = dlblas_moe_impl.forward(recv_hidden_states, recv_topk_weights, recv_topk_ids, up_weights, up_scale, down_weights, down_scale) else: @@ -515,15 +514,6 @@ def __init__(self, logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') try: - # from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl - # self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size, - # ep_group=ep_group, - # top_k=top_k, - # num_experts=num_experts, - # hidden_dim=hidden_dim, - # renormalize=renormalize, - # block_size=block_size, - # out_dtype=out_dtype) from dlblas.layers.moe.ep_moe import DlblasTritonFusedMoEBlockedF8Impl self.dlblas_moe = DlblasTritonFusedMoEBlockedF8Impl(top_k=top_k, num_experts=num_experts, @@ -547,6 +537,7 @@ def forward(self, """forward.""" step_ctx = get_step_ctx_manager().current_context() if is_prefill_without_permute: + # dlblas_moe support prefill without permute, when use dlblas moe, we will 
renormalize topk_weights in dlblas moe. pass else: topk_weights = _renormalize(topk_weights, self.renormalize) From a793d6eab01d7f006bd1c2cf13a1eca854c63c9f Mon Sep 17 00:00:00 2001 From: hellozmz <407190054@qq.com> Date: Tue, 15 Apr 2025 11:42:24 +0800 Subject: [PATCH 5/5] opt code --- lmdeploy/pytorch/backends/cuda/moe.py | 42 +++++++++---------- .../pytorch/backends/cuda/token_dispatcher.py | 17 ++------ 2 files changed, 24 insertions(+), 35 deletions(-) diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py index 2aa6e6b41..9ed5561b7 100644 --- a/lmdeploy/pytorch/backends/cuda/moe.py +++ b/lmdeploy/pytorch/backends/cuda/moe.py @@ -436,12 +436,8 @@ def forward(self, topk_weights, expert_list, ) - if is_prefill_without_permute and dlblas_moe_impl != None: - out_states = dlblas_moe_impl.forward(recv_hidden_states, recv_topk_weights, recv_topk_ids, up_weights, up_scale, - down_weights, down_scale) - else: - out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, up_weights, up_scale, down_weights, - down_scale) + out_states = self.experts.forward(recv_hidden_states, tokens_per_expert, up_weights, up_scale, down_weights, + down_scale) out_states = self.token_dispatcher.combine(out_states) return out_states @@ -514,13 +510,16 @@ def __init__(self, logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM') try: - from dlblas.layers.moe.ep_moe import DlblasTritonFusedMoEBlockedF8Impl - self.dlblas_moe = DlblasTritonFusedMoEBlockedF8Impl(top_k=top_k, - num_experts=num_experts, - renormalize=renormalize, - block_size=block_size, - out_dtype=out_dtype, - ep_size=ep_size) + # use dlblas moe block + from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl + self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size, + ep_group=ep_group, + top_k=top_k, + num_experts=num_experts, + hidden_dim=hidden_dim, + renormalize=renormalize, + block_size=block_size, + out_dtype=out_dtype) except ImportError: self.dlblas_moe = None logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas') @@ -536,22 +535,21 @@ def forward(self, expert_list: List[int] = None): """forward.""" step_ctx = get_step_ctx_manager().current_context() - if is_prefill_without_permute: - # dlblas_moe support prefill without permute, when use dlblas moe, we will renormalize topk_weights in dlblas moe. 
- pass - else: - topk_weights = _renormalize(topk_weights, self.renormalize) + # use dlblas moe block + if self.dlblas_moe is not None and is_prefill_without_permute: + return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, + down_weights, down_scale, step_ctx.is_decoding, expert_list) + + topk_weights = _renormalize(topk_weights, self.renormalize) moe = None if step_ctx.is_decoding is False or self.use_deep_gemm is False: moe = FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, self.out_dtype) - out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, - down_scale, expert_list, self.dlblas_moe) else: moe = FusedMoELowLatency(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size, self.out_dtype) - out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, - down_scale, expert_list) + out_states = moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale, down_weights, + down_scale, expert_list) return out_states diff --git a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py index d430eb7f8..84fd5b1b2 100644 --- a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py +++ b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py @@ -13,12 +13,9 @@ from ..default.token_dispatcher import AlltoAllTokenDispatcher from ..token_dispatcher import TokenDispatcherImpl -from lmdeploy.pytorch.distributed import prefill_without_permute - _buffer_normal = None _buffer_low_latency = None _buffer_common = None -is_prefill_without_permute = prefill_without_permute() def get_buffer_common( group: dist.ProcessGroup, @@ -156,11 +153,8 @@ def dispatch( self.handle = handle self.topk_idx = topk_idx self.topk_weights = topk_weights - if is_prefill_without_permute: - pass - else: - if hidden_states.shape[0] > 0: - hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states) + if hidden_states.shape[0] > 0: + hidden_states = self.get_permuted_hidden_states_by_experts(hidden_states) return hidden_states, topk_idx, topk_weights, tokens_per_expert def dispatch_normal( @@ -215,11 +209,8 @@ def dispatch_normal( ) def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: - if is_prefill_without_permute: - pass - else: - if hidden_states.shape[0] > 0: - hidden_states = self.get_restored_hidden_states_by_experts(hidden_states) + if hidden_states.shape[0] > 0: + hidden_states = self.get_restored_hidden_states_by_experts(hidden_states) hidden_states, event = self.combine_normal(hidden_states, self.handle) self.handle = None return hidden_states.view(self.hidden_shape)
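
A note on where the series ends up: after PATCH 5/5, `FusedMoEBlockedF8Impl.__init__` in `lmdeploy/pytorch/backends/cuda/moe.py` imports dlblas inside a try/except and leaves `self.dlblas_moe = None` when the package is missing, and `forward` routes to dlblas only when that import succeeded and `prefill_without_permute` is enabled; otherwise the pre-existing `_renormalize` + `FusedMoENormal`/`FusedMoELowLatency` path runs unchanged. The sketch below is a minimal, self-contained illustration of that dispatch pattern, not lmdeploy code: the class name `BlockedF8MoEDispatch`, the stubbed default branch, and passing the flag through the constructor (the real code reads it via `lmdeploy.pytorch.distributed.prefill_without_permute()`) are invented for illustration, while the dlblas import path and the constructor/forward arguments mirror the diff above.

from typing import List, Optional

# Optional dependency: prefer dlblas when it is installed, otherwise keep the
# built-in kernels. Import path and arguments follow the diff above; everything
# else in this sketch is illustrative.
try:
    from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl as DlblasMoE
except ImportError:
    DlblasMoE = None


class BlockedF8MoEDispatch:
    """Illustrative dispatcher mirroring the final forward() routing."""

    def __init__(self, ep_size, ep_group, top_k, num_experts, hidden_dim,
                 renormalize, block_size, out_dtype, prefill_without_permute):
        self.prefill_without_permute = prefill_without_permute
        self.dlblas_moe = None
        if DlblasMoE is not None:
            self.dlblas_moe = DlblasMoE(ep_size=ep_size,
                                        ep_group=ep_group,
                                        top_k=top_k,
                                        num_experts=num_experts,
                                        hidden_dim=hidden_dim,
                                        renormalize=renormalize,
                                        block_size=block_size,
                                        out_dtype=out_dtype)

    def forward(self, hidden_states, topk_weights, topk_ids, gate_up_weights,
                gate_up_scale, down_weights, down_scale, is_decoding,
                expert_list: Optional[List[int]] = None):
        # dlblas path: taken only when the package imported and the flag is set.
        # Per the series, dlblas renormalizes topk_weights itself, so the caller
        # skips _renormalize() on this branch.
        if self.dlblas_moe is not None and self.prefill_without_permute:
            return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids,
                                           gate_up_weights, gate_up_scale,
                                           down_weights, down_scale, is_decoding,
                                           expert_list)
        # Default path: renormalize, then run FusedMoENormal / FusedMoELowLatency
        # exactly as before the series (stubbed out in this sketch).
        raise NotImplementedError('built-in fused MoE kernels go here')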
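
On the configuration side, PATCH 2/5 threads a single boolean end to end: the `--prefill-without-permute` argument helper in `lmdeploy/cli/utils.py`, the `prefill_without_permute` field on `PytorchEngineConfig` in `lmdeploy/messages.py`, then `DistConfig` and `DistContext`, and finally the `prefill_without_permute()` helper that the CUDA MoE backend and token dispatcher read. A hedged usage sketch through the Python API follows; the model path and the `ep` value are placeholders, and the series does not show which CLI subcommands actually register the new flag.

from lmdeploy import PytorchEngineConfig, pipeline

# Placeholder model and parallel settings; the point here is only the new
# prefill_without_permute switch (it defaults to False, so existing behavior
# is unchanged unless it is turned on explicitly).
backend_config = PytorchEngineConfig(
    ep=8,
    prefill_without_permute=True,
)
pipe = pipeline('path/to/fp8-moe-model', backend_config=backend_config)
print(pipe('hello'))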