Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion custom_ops/gpu_ops/moe/prefill_permute_to_masked_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,21 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
}
}
}
case paddle::DataType::UINT8: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
switch (topk) {
DISPATCH_TOPK(paddle::DataType::UINT8, paddle::DataType::FLOAT32, 4)
DISPATCH_TOPK(paddle::DataType::UINT8, paddle::DataType::FLOAT32, 6)
DISPATCH_TOPK(paddle::DataType::UINT8, paddle::DataType::FLOAT32, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 6 or 8");
}
}
default:
PD_THROW("Unsupported scale dtype for UINT8 x, must be float32");
}
}
case paddle::DataType::BFLOAT16: {
switch (scale.dtype()) {
case paddle::DataType::FLOAT32: {
Expand All @@ -235,10 +250,20 @@ std::vector<paddle::Tensor> PrefillPermuteToMaskedGemm(
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
case paddle::DataType::UINT8: {
switch (topk) {
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::UINT8, 4)
DISPATCH_TOPK(
paddle::DataType::BFLOAT16, paddle::DataType::UINT8, 8)
default:
PD_THROW("Unsupported topk value, must be 4 or 8");
}
}
}
}
default:
PD_THROW("Unsupported dtype, must be float8_e4m3fn or bfloat16");
PD_THROW("Unsupported dtype, must be uint8, float8_e4m3fn or bfloat16");
}

#undef DISPATCH_TOPK
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,8 @@ def _validate_split_kv_size(value: int) -> int:
# train-infer consistency, used in RL
# Whether to align RoPE and moe gate precision with training
"FD_ENABLE_RL": lambda: int(os.getenv("FD_ENABLE_RL", "0")),
# Whether to enable FP4 communication quantization for DeepEP prefill dispatch
"FD_USE_NVFP4_COMM_QUANT": lambda: bool(int(os.getenv("FD_USE_NVFP4_COMM_QUANT", "0"))),
# Whether to use phi FP8 quantization; if 1, use the Paddle default.
"FD_USE_PHI_FP8_QUANT": lambda: bool(int(os.getenv("FD_USE_PHI_FP8_QUANT", "1"))),
# Enables the Paddle/phi combined TopK operator only when topk_method == noaux_tc,
Expand Down
2 changes: 2 additions & 0 deletions fastdeploy/model_executor/forward_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class ForwardMeta:
moe_num_chunk: int = 1
max_moe_num_chunk: int = 1

audio_token_num: int = 0
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 audio_token_num 字段与本 PR 的 FP4 通信量化功能似乎无关,且搜索整个代码库后未发现任何读取或写入此字段的代码。

请确认:这个字段是否应该在单独的 PR 中提交?如果是后续功能的前置准备,建议在注释中说明用途。


# for zero size
is_zero_size: bool = False
# for prefill
Expand Down
80 changes: 60 additions & 20 deletions fastdeploy/model_executor/layers/quantization/nvfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,6 +676,22 @@ def apply_ep_prefill(
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

use_fp4_comm_quant = envs.FD_USE_NVFP4_COMM_QUANT

if use_fp4_comm_quant:
# FP4 communication quantization: quantize to FP4 before dispatch,
# reducing communication volume by ~2x vs BF16.
x_fp4, x_fp4_scale = fp4_quantize(
x, layer.up_gate_proj_input_scale_quant, sf_vec_size=16, is_sf_swizzled_layout=False
)
x_fp4_scale = x_fp4_scale.view(paddle.float32) # float8_e4m3fn -> float32
dispatch_input = x_fp4
dispatch_scale = x_fp4_scale
else:
# BF16 communication: dispatch BF16 data without pre-quantization.
dispatch_input = x
dispatch_scale = None

event = deep_ep.Buffer.capture()

if self.ep_prefill_runner.num_worst_tokens <= 0:
Expand All @@ -690,11 +706,12 @@ def apply_ep_prefill(
handle,
event,
) = self.ep_prefill_runner.dispatch(
x,
dispatch_input,
topk_idx,
topk_weights,
expert_alignment=128,
previous_event=event,
x_scale_tensor=dispatch_scale,
)

if self.ep_prefill_runner.num_worst_tokens > 0:
Expand Down Expand Up @@ -752,25 +769,48 @@ def apply_ep_prefill(
)
)

max_token_num = layer.ep_size * max_tokens_per_rank
permute_input = permute_input.reshape([layer.num_local_experts, max_token_num, recv_x_value.shape[-1]])

# ffn_out: [num_local_experts, m, hidden_size]
# NVFP4 dispatch returns BF16 (no pre-quantized scale), so permute_scale is empty.
# Use per-expert 1/input_scale (up_gate_proj_input_scale_quant) as input_global_scale,
# consistent with apply_ep_decode which also uses this value directly.
ffn_out = flashinfer_cutedsl_moe_masked(
hidden_states=(permute_input, None),
input_global_scale=layer.up_gate_proj_input_scale_quant.expand([layer.num_local_experts]),
w1=layer.up_gate_proj_weight,
w1_blockscale=layer.up_gate_proj_blockscale_swizzled,
w1_alpha=layer.g1_alphas,
w2=layer.down_proj_weight,
a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
w2_blockscale=layer.down_proj_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=token_nums_per_expert.squeeze(-1),
)
if recv_x_scale is not None:
# FP4 pre-quantized dispatch path:
# permute_input is uint8 [E, M, hidden//2] (FP4 packed)
# permute_scale is float32 [E, M, hidden//64] with custom strides
# from C++ kernel (physical layout [E, S, M], non-contiguous).
# Convert scale to float8_e4m3fn, then apply swizzle for
# grouped_gemm_nt_masked which expects SFA in swizzled layout
# (32, 4, rm, 4, rk, l) logical / (l, rm, rk, 32, 4, 4) physical.
# This is the same _process_scale_interleaved used for weight
# blockscale, converting flat [E, M, K] to swizzled layout.
permute_scale_fp8 = permute_scale.contiguous().view(paddle.float8_e4m3fn)
permute_scale_swizzled = _process_scale_interleaved(permute_scale_fp8)
permute_input_t = permute_input.transpose([1, 2, 0])
permute_scale_swizzled_t = permute_scale_swizzled.transpose([1, 2, 0])

ffn_out = flashinfer_cutedsl_moe_masked(
hidden_states=(permute_input_t, permute_scale_swizzled_t),
input_global_scale=None,
w1=layer.up_gate_proj_weight,
w1_blockscale=layer.up_gate_proj_blockscale_swizzled,
w1_alpha=layer.g1_alphas,
w2=layer.down_proj_weight,
a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
w2_blockscale=layer.down_proj_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=token_nums_per_expert.squeeze(-1),
)
else:
# BF16 dispatch path: permute_input is BF16, quantize to FP4
# inside flashinfer_cutedsl_moe_masked
ffn_out = flashinfer_cutedsl_moe_masked(
hidden_states=(permute_input, None),
input_global_scale=layer.up_gate_proj_input_scale_quant.expand([layer.num_local_experts]),
w1=layer.up_gate_proj_weight,
w1_blockscale=layer.up_gate_proj_blockscale_swizzled,
w1_alpha=layer.g1_alphas,
w2=layer.down_proj_weight,
a2_global_scale=layer.down_proj_input_scale_quant.expand([layer.num_local_experts]),
w2_blockscale=layer.down_proj_blockscale_swizzled,
w2_alpha=layer.g2_alphas,
masked_m=token_nums_per_expert.squeeze(-1),
)

tmp_ffn_out = call_depermute_prefill_combine(
x=ffn_out,
Expand Down
3 changes: 3 additions & 0 deletions fastdeploy/model_executor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,9 @@ def slice_fn(weight_or_paramter, output_dim, start, end, step=1):

def process_weight_transpose(layer, weight_name):
weight = getattr(layer, weight_name)
if not weight._is_initialized():
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 调试日志使用了非正式中文 "权重没初始化啊!",不适合生产代码。

  1. 日志语言应与项目其他日志保持一致(英文),且语气应正式;
  2. 日志级别建议使用 logger.warning 而非 logger.info——权重未初始化导致跳过 transpose 可能掩盖上游问题,warning 级别更有助于排查。

建议修改为:

if not weight._is_initialized():
    logger.warning("Weight '%s' is not initialized, skipping transpose.", weight_name)
    return

logger.info("权重没初始化啊!")
return
if len(weight.shape) == 2:
weight_shape = weight.shape[::-1]
elif len(weight.shape) == 3:
Expand Down
Loading