diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index 0ea38a915..0448dacea 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -512,6 +512,17 @@ def eager_mode(parser):
                                    help='Whether to enable eager mode. '
                                    'If True, cuda graph would be disabled')
 
+    @staticmethod
+    def prefill_without_permute(parser):
+        """Add argument prefill_without_permute to parser."""
+
+        return parser.add_argument('--prefill-without-permute',
+                                   action='store_true',
+                                   default=False,
+                                   help='Whether to enable prefill_without_permute. '
+                                   'If True, the moe layer would not permute the input, '
+                                   'and would not unpermute the output')
+
     @staticmethod
     def communicator(parser):
         return parser.add_argument('--communicator',
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index d3b0efccd..915ee0fa4 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -297,6 +297,8 @@ class PytorchEngineConfig:
             bit, set it to 4 or 8, respectively
         distributed_executor_backend (str): backend of distributed backend,
             options: ['uni', 'mp', 'ray']
+        prefill_without_permute(bool): whether to use moe without permute.
+            Default to False.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -321,6 +323,7 @@ class PytorchEngineConfig:
     revision: str = None
     quant_policy: Literal[0, 4, 8] = 0
     distributed_executor_backend: str = None
+    prefill_without_permute: bool = False
 
     def __post_init__(self):
         """Check input validation."""
diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index af206ed60..9ed5561b7 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -6,6 +6,7 @@ import torch.distributed as dist
 
 from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency, TokenDispatcherBuilder
+from lmdeploy.pytorch.distributed import prefill_without_permute
 from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
 from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
 from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
@@ -20,6 +21,7 @@ from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
                    FusedMoEW8A8Impl)
 
+is_prefill_without_permute = prefill_without_permute()
 logger = get_logger('lmdeploy')
 
@@ -425,7 +427,8 @@ def forward(self,
                 up_scale: torch.Tensor,
                 down_weights: torch.Tensor,
                 down_scale: torch.Tensor,
-                expert_list: List[int] = None):
+                expert_list: List[int] = None,
+                dlblas_moe_impl: FusedMoEBlockedF8Impl = None):
         """forward."""
         recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = self.token_dispatcher.dispatch(
             hidden_states,
@@ -506,6 +509,21 @@ def __init__(self,
             self.use_deep_gemm = False
             logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')
 
+        try:
+            # use dlblas moe block
+            from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl
+            self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size,
+                                                    ep_group=ep_group,
+                                                    top_k=top_k,
+                                                    num_experts=num_experts,
+                                                    hidden_dim=hidden_dim,
+                                                    renormalize=renormalize,
+                                                    block_size=block_size,
+                                                    out_dtype=out_dtype)
+        except ImportError:
+            self.dlblas_moe = None
+            logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas')
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 topk_weights: torch.Tensor,
@@ -516,8 +534,13 @@ def forward(self,
                 down_scale: torch.Tensor,
                 expert_list: List[int] = None):
         """forward."""
-        topk_weights = _renormalize(topk_weights, self.renormalize)
         step_ctx = get_step_ctx_manager().current_context()
+        # use dlblas moe block
+        if self.dlblas_moe is not None and is_prefill_without_permute:
+            return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale,
+                                           down_weights, down_scale, step_ctx.is_decoding, expert_list)
+
+        topk_weights = _renormalize(topk_weights, self.renormalize)
         moe = None
         if step_ctx.is_decoding is False or self.use_deep_gemm is False:
             moe = FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size,
diff --git a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py
index d3585736d..84fd5b1b2 100644
--- a/lmdeploy/pytorch/backends/cuda/token_dispatcher.py
+++ b/lmdeploy/pytorch/backends/cuda/token_dispatcher.py
@@ -17,7 +17,6 @@
 _buffer_low_latency = None
 _buffer_common = None
 
-
 def get_buffer_common(
     group: dist.ProcessGroup,
     num_max_dispatch_tokens_per_rank: int,
diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py
index 9c9840439..87a0ca55f 100644
--- a/lmdeploy/pytorch/config.py
+++ b/lmdeploy/pytorch/config.py
@@ -97,6 +97,7 @@ class DistConfig:
     dp_rank: int = 0
     world_size: int = None
     attn_config: 'DistConfig' = None
+    prefill_without_permute: bool = False
 
     def __post_init__(self):
         """post init."""
diff --git a/lmdeploy/pytorch/distributed.py b/lmdeploy/pytorch/distributed.py
index 34347222d..c379954bd 100644
--- a/lmdeploy/pytorch/distributed.py
+++ b/lmdeploy/pytorch/distributed.py
@@ -29,6 +29,7 @@ class DistContext:
     ep_gpu_group: dist.ProcessGroup = None
     ep_gpu_groups: List[dist.ProcessGroup] = None
     dist_config: DistConfig = None
+    prefill_without_permute: bool = False
 
     @classmethod
     def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'):
@@ -44,6 +45,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
         ep = dist_config.ep
         world_size = dist_config.world_size
         dp_rank = dist_config.dp_rank
+        prefill_without_permute = dist_config.prefill_without_permute
 
         if world_size == 1:
             return DistContext(dist_config=dist_config)
@@ -104,6 +106,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
             ep_gpu_group=ep_gpu_group,
             ep_gpu_groups=ep_gpu_groups,
             dist_config=dist_config,
+            prefill_without_permute=prefill_without_permute,
         )
         return context
@@ -181,6 +184,11 @@ def get_ep_world_rank():
     return ctx.ep, ctx.ep_rank
 
 
+def prefill_without_permute():
+    ctx = get_dist_manager().current_context()
+    return ctx.prefill_without_permute
+
+
 def _check_group_device(device: str):
     """check group device."""
     assert (device in ['cpu', 'gpu']), ('Expect process group device in ("cpu", "gpu"), '
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index 80df39282..49f10b792 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -89,6 +89,7 @@ def _build_dist_config(engine_config: PytorchEngineConfig):
                              tp=engine_config.tp,
                              ep=engine_config.ep,
                              dp_rank=engine_config.dp_rank,
+                             prefill_without_permute=engine_config.prefill_without_permute,
                              )
     return dist_config
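
Below is a minimal usage sketch, not part of the diff: it shows the new option being passed through `PytorchEngineConfig`, which `_build_dist_config` forwards into `DistConfig` and the CUDA MoE backend reads via `prefill_without_permute()`. The model name and EP size are placeholders, and the permute-free prefill path is only taken when dlblas is installed; the equivalent CLI switch added above is `--prefill-without-permute`.

```python
# Sketch only: enable prefill_without_permute through the PyTorch engine config.
# Placeholder model and expert-parallel size; requires dlblas for the
# permute-free MoE prefill path to actually be used.
from lmdeploy import PytorchEngineConfig, pipeline

backend_config = PytorchEngineConfig(
    ep=8,  # expert parallelism, as consumed by the DeepEP dispatch path
    prefill_without_permute=True,  # new option introduced in this change
)
pipe = pipeline('deepseek-ai/DeepSeek-V3', backend_config=backend_config)
print(pipe('Hello, world!'))
```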