diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 4dced8f2207..e93768e3db4 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -228,6 +228,7 @@ def _validate_split_kv_size(value: int) -> int: "FD_DETERMINISTIC_LOG_MODE": lambda: bool(int(os.getenv("FD_DETERMINISTIC_LOG_MODE", "0"))), # Whether to use PD REORDER, can set 0 or 1 "FD_PD_REORDER": lambda: int(os.getenv("FD_PD_REORDER", "0")), + "FD_CUTEDSL_MOE_SCALAR_INPUT_SCALE": lambda: int(os.getenv("FD_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "1")), # Whether to enable KV cache lock, enforcing mutual exclusion between # PrefixCacheManager and Worker when accessing GPU KV cache. # Under certain DP+EP configurations, concurrent access (even read-only) diff --git a/fastdeploy/model_executor/layers/linear.py b/fastdeploy/model_executor/layers/linear.py index 2bee885ff43..c83d4680021 100644 --- a/fastdeploy/model_executor/layers/linear.py +++ b/fastdeploy/model_executor/layers/linear.py @@ -142,6 +142,7 @@ def __init__( self.is_quantized = fd_config.model_config.is_quantized and not ( fd_config.quant_config.name() == "mix_quant" and fd_config.quant_config.dense_quant_type is None ) + # key if weight_key: self.weight_key = f"{prefix}.{weight_key}" diff --git a/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py new file mode 100644 index 00000000000..638ad371c1c --- /dev/null +++ b/fastdeploy/model_executor/layers/moe/flashinfer_cutedsl_moe.py @@ -0,0 +1,206 @@ +""" +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +from typing import Any, Optional + +import paddle +from flashinfer import ( + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, +) +from flashinfer.cute_dsl.blockscaled_gemm import grouped_gemm_nt_masked + + +def _dtype_str(dtype) -> str: + """Normalize dtype to string, handling both paddle and torch proxy dtypes.""" + return str(dtype).split(".")[-1] + + +def _is_dtype(tensor, *dtype_names: str) -> bool: + """Check tensor dtype by name, compatible with both paddle and torch proxy tensors.""" + return _dtype_str(tensor.dtype) in dtype_names + + +def _perm(tensor, *dims): + """Permute tensor dims, compatible with both paddle (transpose) and torch proxy (permute).""" + try: + return tensor.transpose(list(dims)) + except TypeError: + return tensor.permute(*dims) + + +def get_cute_dtype(input) -> str: + s = _dtype_str(input.dtype) + if s == "bfloat16": + return "bfloat16" + elif s == "float16": + return "float16" + elif s == "float32": + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def flashinfer_cutedsl_moe_masked( + hidden_states: tuple, + input_global_scale: paddle.Tensor, + w1: paddle.Tensor, + w1_blockscale: paddle.Tensor, + w1_alpha: paddle.Tensor, + w2: paddle.Tensor, + a2_global_scale: paddle.Tensor, + w2_blockscale: paddle.Tensor, + w2_alpha: paddle.Tensor, + masked_m: paddle.Tensor, + down_sm_count: Optional[int] = None, + down_signals: Optional[paddle.Tensor] = None, + down_start_event: Optional[Any] = None, +): + """ + Perform masked Mixture-of-Experts computation with 
FlashInfer's CuteDSL kernels. + + Args: + hidden_states: Either of the following: + * (paddle.Tensor, None): [num_experts, m, k] bf16 — not pre-quantized + * (paddle.Tensor, paddle.Tensor): [m, k//2, num_experts] uint8, + [m, k//16, num_experts] float8_e4m3fn — pre-quantized FP4 from dispatch + input_global_scale: (l,) float32, value is 1/input_scale per expert + w1: [l, 2*n, k//2] uint8, FP4-packed gate+up projection weights + w1_blockscale: float8_e4m3fn blockscale for w1 + w1_alpha: (l,) float32, = input_scale * w1_weight_scale_2 + w2: [l, k, n//2] uint8, FP4-packed down projection weights + a2_global_scale: (l,) float32, 1/input_scale for GEMM2 + w2_blockscale: float8_e4m3fn blockscale for w2 + w2_alpha: (l,) float32, = input_scale * w2_weight_scale_2 + masked_m: (l,) int32, valid token count per expert; max(masked_m) == m + + Returns: + paddle.Tensor: [num_experts, m, k] bf16 + """ + + # === Dtype assertions === + # Use string-based dtype check to be compatible with both paddle and torch proxy tensors + assert _is_dtype(w1, "uint8"), f"w1 must be uint8 (fp4 packed), got {w1.dtype}" + assert _is_dtype(w1_blockscale, "float8_e4m3fn"), f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" + assert _is_dtype(w1_alpha, "float32"), f"w1_alpha must be float32, got {w1_alpha.dtype}" + assert _is_dtype(w2, "uint8"), f"w2 must be uint8 (fp4 packed), got {w2.dtype}" + assert _is_dtype(a2_global_scale, "float32"), f"a2_global_scale must be float32, got {a2_global_scale.dtype}" + assert _is_dtype(w2_blockscale, "float8_e4m3fn"), f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" + assert _is_dtype(w2_alpha, "float32"), f"w2_alpha must be float32, got {w2_alpha.dtype}" + assert len(hidden_states) == 2, f"hidden_states must be a tuple of length 2, got {len(hidden_states)}" + + # intermediate_size derived from w2 last dimension + n = w2.shape[-1] * 2 + + if hidden_states[1] is not None: + # Pre-quantized path: tokens already FP4-packed by dispatch + 
# a_q: [m, k//2, num_experts] uint8 + # a_q_sf:[m, k//16, num_experts] float8_e4m3fn + a_q = hidden_states[0].view(paddle.uint8) + a_q_sf = hidden_states[1].view(paddle.float8_e4m3fn) + m, k_by_2, num_experts = a_q.shape + k = k_by_2 * 2 + else: + # Standard path: bf16 [num_experts, m, k], quantize to FP4 here + num_experts, m, k = hidden_states[0].shape + + assert _is_dtype( + input_global_scale, "float32" + ), f"input_global_scale must be float32, got {input_global_scale.dtype}" + assert list(input_global_scale.shape) == [ + num_experts + ], f"input_global_scale must be (l,), got {input_global_scale.shape}" + + a_q, a_q_sf = scaled_fp4_grouped_quantize( + hidden_states[0], + masked_m, + input_global_scale, + ) + + assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n={2*n}, got {w1.shape[-2]}" + assert w1.shape[-1] * 2 == k, f"w1 last dim * 2 must equal k={k}, got {w1.shape[-1] * 2}" + assert ( + w2.shape[-2] == k and w2.shape[-1] == n // 2 + ), f"w2 shape mismatch, got {list(w2.shape[-2:])}, expected [{k}, {n // 2}]" + assert list(w1_alpha.shape) == [num_experts], f"w1_alpha must be (l,), got {w1_alpha.shape}" + assert list(a2_global_scale.shape) == [num_experts], f"a2_global_scale must be (l,), got {a2_global_scale.shape}" + assert list(w2_alpha.shape) == [num_experts], f"w2_alpha must be (l,), got {w2_alpha.shape}" + + assert _is_dtype(a_q, "uint8") + assert _is_dtype(a_q_sf, "float8_e4m3fn") + + ab_dtype = "float4_e2m1fn" + sf_dtype = "float8_e4m3fn" + c_dtype = "bfloat16" + sf_vec_size = 16 + + # === GEMM1: gate+up projection === + # grouped_gemm_nt_masked requires output in [m, 2*n, l] layout + gateup_output = paddle.empty([num_experts, m, n * 2], dtype=paddle.bfloat16) + gateup_output = gateup_output.transpose([1, 2, 0]) # [m, 2*n, num_experts] + + grouped_gemm_nt_masked( + (a_q, a_q_sf), + (_perm(w1, 1, 2, 0), w1_blockscale), + gateup_output, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + 
alpha=w1_alpha.reshape([1, 1, num_experts]), + alpha_dtype=get_cute_dtype(w1_alpha), + ) # fills gateup_output in logical [m, 2*n, l] + + # === SiLU + mul + quantize intermediate activations to FP4 === + # Input expected as [num_experts, m, 2*n] + diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( + gateup_output.transpose([2, 0, 1]), # [num_experts, m, 2*n] + masked_m, + a2_global_scale, + ) + + if down_start_event is not None: + down_start_event.record() + + # === GEMM2: down projection === + # grouped_gemm_nt_masked requires output in [m, k, l] layout + out = paddle.empty([num_experts, m, k], dtype=paddle.bfloat16) + out = out.transpose([1, 2, 0]) # [m, k, num_experts] + + grouped_gemm_nt_masked( + (diq, diq_sf), + (_perm(w2, 1, 2, 0), w2_blockscale), + out, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w2_alpha.reshape([1, 1, num_experts]), + alpha_dtype=get_cute_dtype(w2_alpha), + **( + dict( + sm_count=down_sm_count, + dst_signals=down_signals, + ) + if down_sm_count is not None or down_signals is not None + else {} + ), + ) # fills out in logical [m, k, l] + + # Return [num_experts, m, k] + return out.transpose([2, 0, 1]) diff --git a/fastdeploy/model_executor/layers/quantization/__init__.py b/fastdeploy/model_executor/layers/quantization/__init__.py index 430724c7e63..edd079784d2 100644 --- a/fastdeploy/model_executor/layers/quantization/__init__.py +++ b/fastdeploy/model_executor/layers/quantization/__init__.py @@ -118,31 +118,57 @@ def parse_quant_config(args, model_config, is_ernie, is_v1_loader): if quant_config_name is None: quant_config = None else: - if not quantization_config.get("is_quantized"): - quantization_config["is_quantized"] = model_config.is_quantized + # Handle both dict and QuantizationConfig object + if hasattr(quantization_config, "to_dict"): + quantization_config_dict = quantization_config.to_dict() + else: + quantization_config_dict = quantization_config if 
isinstance(quantization_config, dict) else {} + + if not quantization_config_dict.get("is_quantized"): + quantization_config_dict["is_quantized"] = model_config.is_quantized if args.dynamic_load_weight and quantization_config is not None: - quantization_config["is_quantized"] = True + quantization_config_dict["is_quantized"] = True quant_cls = get_quantization_config(quant_config_name) - quant_config = quant_cls.from_config(quantization_config) + quant_config = quant_cls.from_config(quantization_config_dict) return quant_config def _get_offline_quant_config_name(quantization_config, is_torch_weight, is_v1_loader): if is_torch_weight: # only support block_wise_fp8 now - quant_method = quantization_config.get("quant_method") - has_block_size = "weight_block_size" in quantization_config + # Handle both dict and QuantizationConfig object + if hasattr(quantization_config, "quant_method"): + quant_method = quantization_config.quant_method + else: + quant_method = quantization_config.get("quant_method") + + has_block_size = ( + "weight_block_size" in quantization_config + if isinstance(quantization_config, dict) + else hasattr(quantization_config, "weight_block_size") + and quantization_config.weight_block_size is not None + ) + if quant_method == "fp8" and has_block_size: quant_config_name = "block_wise_fp8" elif quant_method == "modelopt": - if quantization_config.get("quant_algo", "") == "NVFP4": + # Try to get quant_algo from dict or from to_dict() method + quant_algo = None + if isinstance(quantization_config, dict): + quant_algo = quantization_config.get("quant_algo", "") + elif hasattr(quantization_config, "to_dict"): + quant_algo = quantization_config.to_dict().get("quant_algo", "") + + if quant_algo == "NVFP4": quant_config_name = "modelopt_fp4" else: - raise ValueError("modelopt only supports NVFP4 quantization.") + raise ValueError(f"modelopt only supports NVFP4 quantization, got quant_algo={quant_algo}") elif quant_method == "mxfp4": quant_config_name = 
"mxfp4" else: - raise ValueError("Torch weight offline quantization only supports block-wise FP8.") + raise ValueError( + f"Torch weight offline quantization only supports block-wise FP8, modelopt NVFP4, or mxfp4. Got quant_method={quant_method}" + ) else: quant_config_name = quantization_config["quantization"] return quant_config_name diff --git a/fastdeploy/model_executor/layers/quantization/nvfp4.py b/fastdeploy/model_executor/layers/quantization/nvfp4.py index 4db2c919fce..c5e51bc1d26 100644 --- a/fastdeploy/model_executor/layers/quantization/nvfp4.py +++ b/fastdeploy/model_executor/layers/quantization/nvfp4.py @@ -22,6 +22,7 @@ import fastdeploy from fastdeploy import envs from fastdeploy.model_executor.layers.moe import FusedMoE +from fastdeploy.model_executor.layers.moe.fused_moe_backend_base import MoEMethodBase from fastdeploy.model_executor.utils import ( create_parameter_and_copy, free_tensor, @@ -32,6 +33,8 @@ paddle.compat.enable_torch_proxy(scope={"flashinfer"}) +CUTEDSL_MOE_SCALAR_INPUT_SCALE = bool(envs.FD_CUTEDSL_MOE_SCALAR_INPUT_SCALE) + def next_power_of_2(n: int): return 1 << (n - 1).bit_length() if n > 0 else 1 @@ -97,12 +100,26 @@ def name(self) -> str: @classmethod def from_config(cls, config: dict) -> "ModelOptNvFp4Config": quant_config = config - quant_method = quant_config.get("quant_algo", "") + + # Handle both dict and QuantizationConfig object + if hasattr(quant_config, "to_dict"): + quant_config_dict = quant_config.to_dict() + else: + quant_config_dict = quant_config if isinstance(quant_config, dict) else {} + + # Try to get quant_algo from config or from nested structure + quant_method = quant_config_dict.get("quant_algo", "") + if not quant_method: + # Try from nested quantization key + if "quantization" in quant_config_dict: + quant_method = quant_config_dict["quantization"].get("quant_algo", "") if not quant_method: raise ValueError("Missing 'quant_algo' in quantization config") # Handle kv_cache_quant_algo with proper type 
validation - kv_cache_quant_algo_raw = quant_config.get("kv_cache_quant_algo") + kv_cache_quant_algo_raw = quant_config_dict.get("kv_cache_quant_algo") + if kv_cache_quant_algo_raw is None and "quantization" in quant_config_dict: + kv_cache_quant_algo_raw = quant_config_dict["quantization"].get("kv_cache_quant_algo") if kv_cache_quant_algo_raw is None: # No KV cache quantization by default kv_cache_quant_algo = None @@ -112,7 +129,9 @@ def from_config(cls, config: dict) -> "ModelOptNvFp4Config": raise ValueError(f"kv_cache_quant_algo must be a string, got " f"{type(kv_cache_quant_algo_raw)}") # Handle group_size with proper type validation - group_size_raw = quant_config.get("group_size") + group_size_raw = quant_config_dict.get("group_size") + if group_size_raw is None and "quantization" in quant_config_dict: + group_size_raw = quant_config_dict["quantization"].get("group_size") if group_size_raw is None: group_size = 16 # Default value elif isinstance(group_size_raw, int): @@ -124,7 +143,9 @@ def from_config(cls, config: dict) -> "ModelOptNvFp4Config": raise ValueError(f"group_size must be an integer, got {type(group_size_raw)}") from None # "exclude_modules" is the key in the legacy hf_quant_config.json - exclude_modules = quant_config.get("exclude_modules", []) + exclude_modules = quant_config_dict.get("exclude_modules", []) + if not exclude_modules and "quantization" in quant_config_dict: + exclude_modules = quant_config_dict["quantization"].get("exclude_modules", []) if not isinstance(exclude_modules, list): raise ValueError(f"exclude_modules must be a list, got {type(exclude_modules)}") @@ -152,7 +173,10 @@ def get_quant_method(self, layer) -> Optional[QuantMethodBase]: """ Get quantization method. 
""" - if isinstance(layer, FusedMoE): + + if envs.FD_MOE_BACKEND == "flashinfer-cutedsl": + return ModelOptNvFp4FusedMoECuteDSL(self) + elif isinstance(layer, FusedMoE): return ModelOptNvFp4FusedMoE(self) else: return ModelOptNvFp4LinearMethod(self) @@ -608,3 +632,299 @@ def apply( # flashinfer-trtllm return output + + +class ModelOptNvFp4FusedMoECuteDSL(MoEMethodBase): + def __init__(self, quant_config: ModelOptNvFp4Config): + super().__init__(quant_config) + self.added_weight_attrs = ["up_gate_proj_weight", "down_proj_weight"] + self.added_scale_attrs = [ + "up_gate_proj_weight_scale", + "down_proj_weight_scale", + ] + self.backend = "flashinfer-cutedsl" + + logger.info(f"Using {self.backend} for NVFP4 FusedMoE") + + def create_weights(self, layer, **extra_weight_attrs): + if not self.quant_config.is_checkpoint_nvfp4_serialized: + raise ValueError("NVFP4 quantization was selected, " " dynamic quantization is not supported.") + + self.up_gate_proj_weight_shape = [ + layer.num_local_experts, + layer.moe_intermediate_size * 2, + layer.hidden_size // 2, + ] + self.down_proj_weight_shape = [ + layer.num_local_experts, + layer.hidden_size, + layer.moe_intermediate_size // 2, + ] + self.up_gate_proj_scale_shape = self.up_gate_proj_weight_shape[0:2] + [ + layer.hidden_size // self.quant_config.group_size + ] + self.down_proj_scale_shape = self.down_proj_weight_shape[0:2] + [ + layer.moe_intermediate_size // self.quant_config.group_size + ] + + self.weight_scale_dtype = paddle.float8_e4m3fn + self.weight_dtype = paddle.uint8 + up_gate_proj_weight_name = self.added_weight_attrs[0] + down_proj_weight_name = self.added_weight_attrs[1] + up_gate_proj_scale_name = self.added_scale_attrs[0] + down_proj_scale_name = self.added_scale_attrs[1] + setattr( + layer, + up_gate_proj_weight_name, + layer.create_parameter( + shape=self.up_gate_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + 
down_proj_weight_name, + layer.create_parameter( + shape=self.down_proj_weight_shape, + dtype=self.weight_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale + setattr( + layer, + up_gate_proj_scale_name, + layer.create_parameter( + shape=self.up_gate_proj_scale_shape, + dtype=self.weight_scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + setattr( + layer, + down_proj_scale_name, + layer.create_parameter( + shape=self.down_proj_scale_shape, + dtype=self.weight_scale_dtype, + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + # weight_scale_2 + layer.up_gate_proj_weight_scale_2 = layer.create_parameter( + shape=[layer.num_local_experts, 2], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.down_proj_weight_scale_2 = layer.create_parameter( + shape=[layer.num_local_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + # input_scale + layer.up_gate_proj_input_scale = layer.create_parameter( + shape=[layer.num_experts, 2], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + layer.down_proj_input_scale = layer.create_parameter( + shape=[layer.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ) + + set_weight_attrs( + getattr(layer, up_gate_proj_weight_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + set_weight_attrs( + getattr(layer, up_gate_proj_scale_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + + set_weight_attrs( + getattr(layer, down_proj_weight_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + set_weight_attrs( + getattr(layer, down_proj_scale_name), + {**extra_weight_attrs, "SHARD_ID_TO_SHARDED_DIM": {"gate": 0, "down": 1, "up": 0}}, + ) + + 
set_weight_attrs(layer.up_gate_proj_weight_scale_2, {**extra_weight_attrs, "weight_type": "weight_scale_2"}) + set_weight_attrs(layer.down_proj_weight_scale_2, {**extra_weight_attrs, "weight_type": "weight_scale_2"}) + set_weight_attrs(layer.up_gate_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"}) + set_weight_attrs(layer.down_proj_input_scale, {**extra_weight_attrs, "weight_type": "input_scale"}) + + def process_weights_after_loading(self, layer): + if layer.up_gate_proj_weight_scale_2.ndim == 1: + up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2 + else: + if layer.up_gate_proj_weight_scale_2.shape[1] >= 2 and not paddle.allclose( + layer.up_gate_proj_weight_scale_2[:, 0], + layer.up_gate_proj_weight_scale_2[:, 1], + ): + logger.warning_once( + "up_proj_weight_scale_2 must match gate_proj_weight_scale2. " "Accuracy may be affected" + ) + up_gate_proj_weight_scale_2 = layer.up_gate_proj_weight_scale_2[:, 0] + + free_tensor(layer.up_gate_proj_weight_scale_2) + create_parameter_and_copy(layer, name="up_gate_proj_weight_scale_2", weight=up_gate_proj_weight_scale_2) + + if CUTEDSL_MOE_SCALAR_INPUT_SCALE: + up_gate_proj_input_scale = ( + layer.up_gate_proj_input_scale.max().cast("float32").expand([layer.up_gate_proj_input_scale.shape[0]]) + ) + else: + up_gate_proj_input_scale = layer.up_gate_proj_input_scale.max(axis=1).values.cast("float32") + + down_proj_input_scale = layer.down_proj_input_scale.cast("float32") + + def _slice_scale(w): + assert ( + w.shape[0] == layer.num_experts + ), f"Expected scale shape[0] == num_experts ({layer.num_experts}), got {w.shape[0]}" + assert layer.ep_size * layer.num_local_experts == layer.num_experts + return w[layer.ep_rank * layer.num_local_experts : (layer.ep_rank + 1) * layer.num_local_experts] + + up_gate_proj_input_scale = _slice_scale(up_gate_proj_input_scale) + down_proj_input_scale = _slice_scale(down_proj_input_scale) + + # Step4: 计算并注册 g1_alphas / g2_alphas(= input_scale * 
weight_scale_2) + create_parameter_and_copy( + layer, + "g1_alphas", + (up_gate_proj_input_scale * up_gate_proj_weight_scale_2).cast("float32"), + ) + create_parameter_and_copy( + layer, + "g2_alphas", + (down_proj_input_scale * layer.down_proj_weight_scale_2).cast("float32"), + ) + + create_parameter_and_copy( + layer, + "up_gate_proj_input_scale_quant", + (1.0 / up_gate_proj_input_scale).cast("float32"), + ) + create_parameter_and_copy( + layer, + "down_proj_input_scale_quant", + (1.0 / down_proj_input_scale).cast("float32"), + ) + + assert_dim = 2 + # Step6: 处理 blockscale swizzle(cutedsl 使用与 cutlass 相同的 swizzle 布局) + for name, weight_scale in [ + ("up_gate", layer.up_gate_proj_weight_scale), + ("down", layer.down_proj_weight_scale), + ]: + if weight_scale.shape[assert_dim] % 4 != 0: + logger.warning( + "NVFP4 %s_weight_scale K' not multiple of 4: shape=%s, groop_size=%s", + name, + tuple(weight_scale.shape), + getattr(self.quant_config, "group_size", None), + ) + assert ( + weight_scale.dtype == paddle.float8_e4m3fn + ), f"{name} Weight Blockscale must be represented as FP8-E4M3" + + up_gate_proj_blockscale_swizzled = _process_scale_interleaved(layer.up_gate_proj_weight_scale) + free_tensor(layer.up_gate_proj_weight_scale) + layer.up_gate_proj_weight_scale = None + create_parameter_and_copy( + layer, name="up_gate_proj_blockscale_swizzled", weight=up_gate_proj_blockscale_swizzled + ) + + down_proj_blockscale_swizzled = _process_scale_interleaved(layer.down_proj_weight_scale) + free_tensor(layer.down_proj_weight_scale) + layer.down_proj_weight_scale = None + create_parameter_and_copy(layer, name="down_proj_blockscale_swizzled", weight=down_proj_blockscale_swizzled) + + def apply( + self, + layer, + x, + gate, + topk_ids_hookfunc=None, + ): + """ + NVFP4 CuteDSL MoE forward (no-dispatch version for testing). + + Flow: + 1. gate + topk selection + 2. Manual token permutation → [num_local_experts, max_m, hidden] + masked_m + 3. 
flashinfer_cutedsl_moe_masked (FP4 quantize → GEMM1 → SiLU → GEMM2) + 4. Manual combine → scatter back with routing weights + + Note: Production path should replace steps 2&4 with DeepEP low-latency dispatch/combine. + """ + from fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe import ( + flashinfer_cutedsl_moe_masked, + ) + + # Step 1: gate routing + gate_out = gate(x.cast("float32")) + topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select( + gate_out, + layer.gate_correction_bias, + layer.top_k, + True, # apply_norm_weight + False, + ) + + if topk_ids_hookfunc is not None: + topk_ids_hookfunc(topk_ids) + + num_local_experts = layer.num_local_experts + hidden_dim = x.shape[1] + flat_topk_ids = topk_ids.reshape([-1]) # [bs * top_k] + + # Step 2: compute masked_m (token count per expert) and permute tokens + masked_m = paddle.zeros([num_local_experts], dtype=paddle.int32) + for i in range(num_local_experts): + masked_m[i] = int((flat_topk_ids == i).sum()) + + max_m = int(masked_m.max()) + hidden_states_3d = paddle.zeros([num_local_experts, max_m, hidden_dim], dtype=x.dtype) + for i in range(num_local_experts): + count = int(masked_m[i]) + if count > 0: + mask = flat_topk_ids == i + hidden_states_3d[i, :count, :] = x[mask] + + # Step 3: CuteDSL masked GEMM + # hidden_states[1] is None → scaled_fp4_grouped_quantize is called inside the kernel + ffn_out = flashinfer_cutedsl_moe_masked( + hidden_states=(hidden_states_3d, None), + input_global_scale=layer.up_gate_proj_input_scale_quant, + w1=layer.up_gate_proj_weight, + w1_blockscale=layer.up_gate_proj_blockscale_swizzled, + w1_alpha=layer.g1_alphas, + w2=layer.down_proj_weight, + a2_global_scale=layer.down_proj_input_scale_quant, + w2_blockscale=layer.down_proj_blockscale_swizzled, + w2_alpha=layer.g2_alphas, + masked_m=masked_m, + ) + # ffn_out: [num_local_experts, max_m, hidden_dim] + + # Step 4: scatter output back and apply routing weights + # Use a loop to handle duplicate 
batch_indices correctly (same sample → multiple experts) + bs = x.shape[0] + output = paddle.zeros([bs, hidden_dim], dtype=ffn_out.dtype) + flat_weights = topk_weights.reshape([-1]) # [bs * top_k] + for i in range(num_local_experts): + count = int(masked_m[i]) + if count == 0: + continue + token_indices = paddle.nonzero(flat_topk_ids == i).squeeze(-1) # [count] + batch_indices = (token_indices // layer.top_k).tolist() + weights_list = flat_weights[token_indices].cast(ffn_out.dtype) # [count] + expert_out = ffn_out[i, :count, :] # [count, hidden_dim] + for j, b in enumerate(batch_indices): + output[b] += expert_out[j] * weights_list[j] + + return output diff --git a/fastdeploy/model_executor/layers/utils.py b/fastdeploy/model_executor/layers/utils.py index aa5fd39de0f..c7c0eab2f05 100644 --- a/fastdeploy/model_executor/layers/utils.py +++ b/fastdeploy/model_executor/layers/utils.py @@ -561,15 +561,33 @@ def get_sm_version(): def modules_to_convert(prefix: str, fd_config: FDConfig): import fnmatch + # Check exclude patterns from multiple sources + exclude_patterns = [] + + # 1. Check quantization_config["modules_to_not_convert"] if ( hasattr(fd_config.model_config, "quantization_config") and fd_config.model_config.quantization_config is not None ): if "modules_to_not_convert" in fd_config.model_config.quantization_config: - patterns = fd_config.model_config.quantization_config["modules_to_not_convert"] - for p in patterns: - if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"): - return False - return True - else: - return True + exclude_patterns.extend(fd_config.model_config.quantization_config["modules_to_not_convert"]) + # 2. 
Check quantization_config["ignore"] (used by some models like NVFP4) + if "ignore" in fd_config.model_config.quantization_config: + exclude_patterns.extend(fd_config.model_config.quantization_config["ignore"]) + + # Get the model's actual prefix_name (e.g., "ernie" or "model") + prefix_name = "model" # default + if hasattr(fd_config, "model_config") and hasattr(fd_config.model_config, "pretrained_config"): + prefix_name = getattr(fd_config.model_config.pretrained_config, "prefix_name", "model") + + # Check if prefix matches any exclude pattern + for p in exclude_patterns: + # Direct match + if fnmatch.fnmatch(prefix, p) or fnmatch.fnmatch(prefix, p + ".*"): + return False + # Handle case where pattern uses "model" but actual prefix is "ernie" (or vice versa) + if p.startswith("model."): + adapted_pattern = prefix_name + "." + p[6:] + if fnmatch.fnmatch(prefix, adapted_pattern) or fnmatch.fnmatch(prefix, adapted_pattern + ".*"): + return False + return True diff --git a/fastdeploy/model_executor/utils.py b/fastdeploy/model_executor/utils.py index 6bb9245d9a3..c7fdeeb966b 100644 --- a/fastdeploy/model_executor/utils.py +++ b/fastdeploy/model_executor/utils.py @@ -147,9 +147,9 @@ def process_weight_transpose(layer, weight_name): return if len(weight.shape) == 2: - weight_transpose = weight.transpose([1, 0]) + weight_transpose = weight.transpose([1, 0]).contiguous() elif len(weight.shape) == 3: - weight_transpose = weight.transpose([0, 2, 1]) + weight_transpose = weight.transpose([0, 2, 1]).contiguous() weight_tmp.copy_(weight_transpose, False) free_tensor(weight) setattr(layer, weight_name, weight_tmp) @@ -183,7 +183,7 @@ def fn(model_sublayer_name: str, param=None): else: unquant_moe_cls = type(unquant_moe_layer) if type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls: - # skip unquantized linear + # skip unquantized moe return if not hasattr(quant_method, "process_weights_after_loading"): return @@ -539,11 +539,12 @@ def 
fn(loaded_weight_name, is_moe): if fd_config.quant_config is None or fd_config.quant_config.is_checkpoint_bf16: return loaded_weight_name # Can be extended to other offline quantization suffixes if needed. + current_fd_suffix_map = {} # Default empty map if (is_moe and moe_quant_type == "block_wise_fp8") or (not is_moe and dense_quant_type == "block_wise_fp8"): - fd_suffix_map = fp8_suffix_map + current_fd_suffix_map = fp8_suffix_map if (is_moe and moe_quant_type == "tensor_wise_fp8") or (not is_moe and dense_quant_type == "tensor_wise_fp8"): - fd_suffix_map = tensor_wise_fp8_suffix_map - for ckpt_suffix, fd_suffix in fd_suffix_map.items(): + current_fd_suffix_map = tensor_wise_fp8_suffix_map + for ckpt_suffix, fd_suffix in current_fd_suffix_map.items(): if re.search(rf"{ckpt_suffix}$", loaded_weight_name): loaded_weight_name = loaded_weight_name.replace(ckpt_suffix, fd_suffix) return loaded_weight_name diff --git a/fuwu.sh b/fuwu.sh new file mode 100644 index 00000000000..86a3a79ab5d --- /dev/null +++ b/fuwu.sh @@ -0,0 +1,30 @@ +# online_inference.sh +for name in `env | grep -E 'PADDLE|ENDPOINT' | awk -F'=' '{print $1}'`; do + unset ${name} +done + +rm -rf log_eb +export FD_LOG_DIR=log_eb + +# model_path="/root/paddlejob/tmpspace/models/paddle/benchmark/checkpoint-320-safetensors-PT" +# /root/paddlejob/tmpspace/models/torch/Qwen3-30B-A3B +# model_path="/root/paddlejob/tmpspace/models/paddle/ERNIE-4.5-21B-A3B-Paddle" +model_path="/raid0/ERNIE-4.5-21B-A3B-FP4" + +export PYTHONPATH=/root/paddlejob/workspace/env_run/output/lizexu/FastDeploy:$PYTHONPATH + + +export FD_SAMPLING_CLASS=rejection +export INFERENCE_MSG_QUEUE_ID=8908 + +export FD_MOE_BACKEND="flashinfer-cutlass" + +python -m fastdeploy.entrypoints.openai.api_server \ + --model $model_path \ + --port 8183 \ + --tensor-parallel-size 1 \ + --max-model-len 32768 \ + --enable-overlap-schedule \ + --num-gpu-blocks-override 1024 \ + --max-num-seqs 128 \ + --graph-optimization-config '{"use_cudagraph":false}' diff 
# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import sys
import types
import unittest

import paddle

# Fully-qualified name of the module under test; it is (re)imported after the
# fake flashinfer modules are installed so it binds the mocked kernels.
_TARGET_MODULE = "fastdeploy.model_executor.layers.moe.flashinfer_cutedsl_moe"


def _install_fake_flashinfer_for_cutedsl(
    scaled_fp4_grouped_quantize=None,
    silu_and_mul_scaled_nvfp4_experts_quantize=None,
    grouped_gemm_nt_masked=None,
):
    """
    Install a fake flashinfer module (and cute_dsl.blockscaled_gemm submodule)
    so that importing flashinfer_cutedsl_moe does not require the real
    flashinfer (and thus does not import torch).

    Returns the previous ``sys.modules`` entries (or None for each absent one)
    so callers can restore them with ``_restore_flashinfer_modules``.
    """
    prev_flashinfer = sys.modules.get("flashinfer")
    prev_cute_dsl = sys.modules.get("flashinfer.cute_dsl")
    prev_blockscaled = sys.modules.get("flashinfer.cute_dsl.blockscaled_gemm")

    fake_flashinfer = types.ModuleType("flashinfer")
    if scaled_fp4_grouped_quantize is not None:
        fake_flashinfer.scaled_fp4_grouped_quantize = scaled_fp4_grouped_quantize
    if silu_and_mul_scaled_nvfp4_experts_quantize is not None:
        fake_flashinfer.silu_and_mul_scaled_nvfp4_experts_quantize = silu_and_mul_scaled_nvfp4_experts_quantize

    fake_blockscaled = types.ModuleType("flashinfer.cute_dsl.blockscaled_gemm")
    if grouped_gemm_nt_masked is not None:
        fake_blockscaled.grouped_gemm_nt_masked = grouped_gemm_nt_masked
    fake_cute_dsl = types.ModuleType("flashinfer.cute_dsl")
    fake_cute_dsl.blockscaled_gemm = fake_blockscaled

    sys.modules["flashinfer"] = fake_flashinfer
    sys.modules["flashinfer.cute_dsl"] = fake_cute_dsl
    sys.modules["flashinfer.cute_dsl.blockscaled_gemm"] = fake_blockscaled

    return prev_flashinfer, prev_cute_dsl, prev_blockscaled


def _restore_flashinfer_modules(prev_flashinfer, prev_cute_dsl, prev_blockscaled):
    """
    Undo ``_install_fake_flashinfer_for_cutedsl``: put the original
    ``sys.modules`` entries back, or drop the fakes when there was no
    original entry. Keeps the fakes from leaking into other tests.
    """
    for name, prev in (
        ("flashinfer", prev_flashinfer),
        ("flashinfer.cute_dsl", prev_cute_dsl),
        ("flashinfer.cute_dsl.blockscaled_gemm", prev_blockscaled),
    ):
        if prev is None:
            sys.modules.pop(name, None)
        else:
            sys.modules[name] = prev


class TestFlashinferCuteDslMoeMasked(unittest.TestCase):
    """Unit tests for flashinfer_cutedsl_moe_masked."""

    def test_flashinfer_cutedsl_moe_masked_runs_with_bf16_inputs(self):
        """
        Verify that flashinfer_cutedsl_moe_masked can run end-to-end with
        standard (bf16) inputs when FlashInfer kernels are mocked.
        This directly exercises the path where hidden_states[1] is None.
        """

        num_experts = 2
        m = 3
        k = 32
        n = 8

        # Standard (non-prequantized) path: bf16 [num_experts, m, k], hidden_states[1] is None.
        a_bf16 = paddle.zeros([num_experts, m, k], dtype=paddle.bfloat16)
        hidden_states = (a_bf16, None)

        input_global_scale = paddle.ones([num_experts], dtype=paddle.float32)
        masked_m = paddle.full([num_experts], m, dtype=paddle.int32)

        w1 = paddle.zeros([num_experts, 2 * n, k // 2], dtype=paddle.uint8)
        # blockscale tensors must use float8_e4m3fn to satisfy runtime dtype checks
        w1_blockscale = paddle.zeros([1], dtype=paddle.float8_e4m3fn)
        w1_alpha = paddle.ones([num_experts], dtype=paddle.float32)

        w2 = paddle.zeros([num_experts, k, n // 2], dtype=paddle.uint8)
        a2_global_scale = paddle.ones([num_experts], dtype=paddle.float32)
        w2_blockscale = paddle.zeros([1], dtype=paddle.float8_e4m3fn)
        w2_alpha = paddle.ones([num_experts], dtype=paddle.float32)

        def fake_scaled_fp4_grouped_quantize(x, masked_m, input_global_scale):
            # x: [num_experts, m, k] -> produce pre-quantized tensors with valid shapes.
            num_experts, m, k = x.shape
            a_q = paddle.zeros([m, k // 2, num_experts], dtype=paddle.uint8)
            a_q_sf = paddle.zeros([m, k // 16, num_experts], dtype=paddle.float8_e4m3fn)
            return a_q, a_q_sf

        def fake_grouped_gemm_nt_masked(a, b, out, masked_m, **kwargs):
            # Simply zero out the output tensor while preserving shape and dtype.
            out.set_value(paddle.zeros_like(out))

        def fake_silu_and_mul_scaled_nvfp4_experts_quantize(x, masked_m, a2_global_scale):
            # Return dummy FP4-packed activations; grouped_gemm_nt_masked ignores the contents.
            num_experts, m, k2 = x.shape  # k2 = 2 * n
            n = k2 // 2
            diq = paddle.zeros([m, n // 2, num_experts], dtype=paddle.uint8)
            diq_sf = paddle.zeros([m, n // 8, num_experts], dtype=paddle.float8_e4m3fn)
            return diq, diq_sf

        # Install fake flashinfer BEFORE importing the target module, so that
        # no real flashinfer/torch is loaded, and undo everything on cleanup
        # so the fakes cannot leak into other tests in the session.
        prev_modules = _install_fake_flashinfer_for_cutedsl(
            scaled_fp4_grouped_quantize=fake_scaled_fp4_grouped_quantize,
            silu_and_mul_scaled_nvfp4_experts_quantize=fake_silu_and_mul_scaled_nvfp4_experts_quantize,
            grouped_gemm_nt_masked=fake_grouped_gemm_nt_masked,
        )
        self.addCleanup(_restore_flashinfer_modules, *prev_modules)

        # Drop any cached copy of the target module so the import below binds
        # the fake kernels instead of a previously-imported real flashinfer;
        # restore the original cache entry (if any) on cleanup.
        prev_target = sys.modules.pop(_TARGET_MODULE, None)

        def _restore_target_module():
            sys.modules.pop(_TARGET_MODULE, None)
            if prev_target is not None:
                sys.modules[_TARGET_MODULE] = prev_target

        self.addCleanup(_restore_target_module)

        cutedsl_moe_module = importlib.import_module(_TARGET_MODULE)
        out = cutedsl_moe_module.flashinfer_cutedsl_moe_masked(
            hidden_states=hidden_states,
            input_global_scale=input_global_scale,
            w1=w1,
            w1_blockscale=w1_blockscale,
            w1_alpha=w1_alpha,
            w2=w2,
            a2_global_scale=a2_global_scale,
            w2_blockscale=w2_blockscale,
            w2_alpha=w2_alpha,
            masked_m=masked_m,
        )

        self.assertEqual(list(out.shape), [num_experts, m, k])
        self.assertEqual(out.dtype, paddle.bfloat16)


if __name__ == "__main__":
    unittest.main()