
Zmz/prefill without permute by dlblas #3430


Draft · wants to merge 5 commits into base: main
11 changes: 11 additions & 0 deletions lmdeploy/cli/utils.py
@@ -512,6 +512,17 @@ def eager_mode(parser):
help='Whether to enable eager mode. '
'If True, cuda graph would be disabled')

@staticmethod
def prefill_without_permute(parser):
"""Add argument prefill_without_permute to parser."""
Collaborator (review comment): Regarding dlblas' option, I recommend using env variables.

return parser.add_argument('--prefill-without-permute',
action='store_true',
default=False,
help='Whether to enable prefill_without_permute. '
'If True, the MoE layer will not permute the input '
'and will not unpermute the output.')

@staticmethod
def communicator(parser):
return parser.add_argument('--communicator',
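Note on the review comment above: if the dlblas option were driven by an environment variable rather than a CLI flag, it could be read roughly as sketched below. This is an illustration only; the variable name LMDEPLOY_PREFILL_WITHOUT_PERMUTE is hypothetical and not part of this PR.

import os

# Hypothetical env-var alternative to --prefill-without-permute (not in this PR).
def prefill_without_permute_from_env() -> bool:
    """Return True when LMDEPLOY_PREFILL_WITHOUT_PERMUTE is set to a truthy value."""
    return os.getenv('LMDEPLOY_PREFILL_WITHOUT_PERMUTE', '0').lower() in ('1', 'true', 'yes')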
3 changes: 3 additions & 0 deletions lmdeploy/messages.py
@@ -297,6 +297,8 @@ class PytorchEngineConfig:
bit, set it to 4 or 8, respectively
distributed_executor_backend (str): backend of distributed backend,
options: ['uni', 'mp', 'ray']
prefill_without_permute (bool): whether the MoE layer skips permuting
the input and unpermuting the output during prefill. Default to False.
"""
dtype: str = 'auto'
tp: int = 1
@@ -321,6 +323,7 @@
revision: str = None
quant_policy: Literal[0, 4, 8] = 0
distributed_executor_backend: str = None
prefill_without_permute: bool = False

def __post_init__(self):
"""Check input validation."""
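For context, a minimal usage sketch of the new config field through the Python API. The model path and ep size are placeholders, not values prescribed by this PR.

from lmdeploy import PytorchEngineConfig, pipeline

# Placeholder model and expert-parallel size; prefill_without_permute enables
# the dlblas prefill path that skips the MoE permute/unpermute steps.
backend_config = PytorchEngineConfig(ep=8, prefill_without_permute=True)
pipe = pipeline('deepseek-ai/DeepSeek-V3', backend_config=backend_config)
print(pipe('Hello, world'))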
27 changes: 25 additions & 2 deletions lmdeploy/pytorch/backends/cuda/moe.py
@@ -6,6 +6,7 @@
import torch.distributed as dist

from lmdeploy.pytorch.backends.cuda.token_dispatcher import DeepEPTokenDispatcherLowLatency, TokenDispatcherBuilder
from lmdeploy.pytorch.distributed import prefill_without_permute
from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import fused_moe_blocked_fp8
from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
@@ -20,6 +21,7 @@
from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl, FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
FusedMoEW8A8Impl)

is_prefill_without_permute = prefill_without_permute()
logger = get_logger('lmdeploy')


@@ -425,7 +427,8 @@ def forward(self,
up_scale: torch.Tensor,
down_weights: torch.Tensor,
down_scale: torch.Tensor,
expert_list: List[int] = None):
expert_list: List[int] = None,
dlblas_moe_impl: FusedMoEBlockedF8Impl = None):
"""forward."""
recv_hidden_states, recv_topk_ids, recv_topk_weights, tokens_per_expert = self.token_dispatcher.dispatch(
hidden_states,
@@ -506,6 +509,21 @@ def __init__(self,
self.use_deep_gemm = False
logger.warning('For higher performance, please install DeepGEMM https://github.com/deepseek-ai/DeepGEMM')

try:
# use dlblas moe block
from dlblas.layers.moe.ep_moe import FusedMoEBlockedF8Impl
self.dlblas_moe = FusedMoEBlockedF8Impl(ep_size=ep_size,
ep_group=ep_group,
top_k=top_k,
num_experts=num_experts,
hidden_dim=hidden_dim,
renormalize=renormalize,
block_size=block_size,
out_dtype=out_dtype)
except ImportError:
self.dlblas_moe = None
logger.warning('For higher performance, please install dlblas https://github.com/DeepLink-org/dlBlas')

def forward(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
@@ -516,8 +534,13 @@ def forward(self,
down_scale: torch.Tensor,
expert_list: List[int] = None):
"""forward."""
topk_weights = _renormalize(topk_weights, self.renormalize)
step_ctx = get_step_ctx_manager().current_context()
# use dlblas moe block
if self.dlblas_moe is not None and is_prefill_without_permute:
return self.dlblas_moe.forward(hidden_states, topk_weights, topk_ids, gate_up_weights, gate_up_scale,
down_weights, down_scale, step_ctx.is_decoding, expert_list)

topk_weights = _renormalize(topk_weights, self.renormalize)
moe = None
if step_ctx.is_decoding is False or self.use_deep_gemm is False:
moe = FusedMoENormal(self.ep_size, self.ep_group, self.num_experts, self.hidden_dim, self.block_size,
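Background on the permute step this PR skips: a conventional fused-MoE layer first gathers token copies so rows routed to the same expert are contiguous before the grouped expert GEMMs, then scatters the outputs back to token order and combines the top-k copies. The sketch below illustrates that permute/unpermute pair in plain PyTorch; it is an illustration of the general technique only, not the dlblas or lmdeploy implementation.

import torch

def permute_by_expert(hidden_states: torch.Tensor, topk_ids: torch.Tensor):
    """Group token copies so rows routed to the same expert are contiguous."""
    top_k = topk_ids.shape[1]
    flat_ids = topk_ids.reshape(-1)                  # (num_tokens * top_k,)
    sort_idx = torch.argsort(flat_ids, stable=True)  # order that groups rows by expert id
    token_idx = sort_idx // top_k                    # source token of each grouped row
    permuted = hidden_states[token_idx]              # (num_tokens * top_k, hidden_dim)
    return permuted, sort_idx

def unpermute(expert_out: torch.Tensor, sort_idx: torch.Tensor, topk_weights: torch.Tensor) -> torch.Tensor:
    """Scatter expert outputs back to token order and combine the top-k copies."""
    num_tokens, top_k = topk_weights.shape
    restored = torch.empty_like(expert_out)
    restored[sort_idx] = expert_out                  # undo the expert grouping
    restored = restored.view(num_tokens, top_k, -1)
    return (restored * topk_weights.unsqueeze(-1)).sum(dim=1)

Skipping these two steps avoids the extra gather/scatter traffic during prefill, which is what --prefill-without-permute opts into when dlblas is installed.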
1 change: 0 additions & 1 deletion lmdeploy/pytorch/backends/cuda/token_dispatcher.py
@@ -17,7 +17,6 @@
_buffer_low_latency = None
_buffer_common = None


def get_buffer_common(
group: dist.ProcessGroup,
num_max_dispatch_tokens_per_rank: int,
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
@@ -97,6 +97,7 @@ class DistConfig:
dp_rank: int = 0
world_size: int = None
attn_config: 'DistConfig' = None
prefill_without_permute: bool = False

def __post_init__(self):
"""post init."""
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/distributed.py
@@ -29,6 +29,7 @@ class DistContext:
ep_gpu_group: dist.ProcessGroup = None
ep_gpu_groups: List[dist.ProcessGroup] = None
dist_config: DistConfig = None
prefill_without_permute: bool = False

@classmethod
def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str = 'nccl'):
@@ -44,6 +45,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
ep = dist_config.ep
world_size = dist_config.world_size
dp_rank = dist_config.dp_rank
prefill_without_permute = dist_config.prefill_without_permute

if world_size == 1:
return DistContext(dist_config=dist_config)
@@ -104,6 +106,7 @@ def build(cls, rank: int = 0, dist_config: DistConfig = None, ccl_backend: str =
ep_gpu_group=ep_gpu_group,
ep_gpu_groups=ep_gpu_groups,
dist_config=dist_config,
prefill_without_permute=prefill_without_permute,
)
return context

@@ -181,6 +184,11 @@ def get_ep_world_rank():
return ctx.ep, ctx.ep_rank


def prefill_without_permute():
ctx = get_dist_manager().current_context()
return ctx.prefill_without_permute


def _check_group_device(device: str):
"""check group device."""
assert (device in ['cpu', 'gpu']), ('Expect process group device in ("cpu", "gpu"), '
1 change: 1 addition & 0 deletions lmdeploy/pytorch/engine/engine.py
@@ -89,6 +89,7 @@ def _build_dist_config(engine_config: PytorchEngineConfig):
tp=engine_config.tp,
ep=engine_config.ep,
dp_rank=engine_config.dp_rank,
prefill_without_permute=engine_config.prefill_without_permute,
)
return dist_config
