Skip to content

Commit ef6649a

Browse files
authored
[Optimize] Optimize tensorwise fp8 performance (PaddlePaddle#2729)
* [Optimize] Optimize tensorwise fp8 performance
1 parent 1b54a28 commit ef6649a

File tree

6 files changed

+318
-88
lines changed

6 files changed

+318
-88
lines changed

custom_ops/gpu_ops/cpp_extensions.cc

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,28 @@ std::vector<paddle::Tensor> NoauxTc(
468468
int topk,
469469
float routed_scaling_factor);
470470

471+
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
472+
const paddle::Tensor& x,
473+
const paddle::Tensor& y,
474+
const paddle::optional<paddle::Tensor>& bias,
475+
bool trans_x,
476+
bool trans_y,
477+
float scale, // only support per-tensor quantization
478+
std::string output_dtype,
479+
std::string activation_type);
480+
481+
paddle::Tensor MoeFusedHadamardQuantFp8Func(
482+
const paddle::Tensor &input,
483+
const paddle::Tensor &scale,
484+
const paddle::Tensor &topk_ids,
485+
const int top_k,
486+
const int intermediate_size,
487+
const bool tiled);
488+
489+
paddle::Tensor FusedHadamardQuantFp8Func(
490+
const paddle::Tensor &input,
491+
const float scale);
492+
471493
PYBIND11_MODULE(fastdeploy_ops, m) {
472494

473495
m.def("get_expert_token_num", &GetExpertTokenNum, py::arg("topk_ids"),
@@ -697,38 +719,21 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
697719
"text_image_gather_scatter function");
698720

699721
m.def("count_tokens_per_expert_func", &count_tokens_per_expert_func);
722+
700723
m.def("tritonmoe_preprocess_func", &tritonmoe_preprocess_kernel);
701724

702725
m.def("MoeWna16MarlinGemmApi", &MoeWna16MarlinGemmApi,
703-
py::arg("a"),
704-
py::arg("c_or_none"),
705-
py::arg("b_q_weight"),
706-
py::arg("b_scales"),
707-
py::arg("global_scale_or_none"),
708-
py::arg("b_zeros_or_none"),
709-
py::arg("g_idx_or_none"),
710-
py::arg("perm_or_none"),
711-
py::arg("workspace"),
712-
py::arg("sorted_token_ids"),
713-
py::arg("expert_ids"),
714-
py::arg("num_tokens_post_padded"),
715-
py::arg("topk_weights"),
716-
py::arg("moe_block_size"),
717-
py::arg("top_k"),
718-
py::arg("mul_topk_weights"),
719-
py::arg("is_ep"),
720-
py::arg("b_q_type_str"),
721-
py::arg("size_m"),
722-
py::arg("size_n"),
723-
py::arg("size_k"),
724-
py::arg("is_k_full"),
725-
py::arg("use_atomic_add"),
726-
py::arg("use_fp32_reduce"),
727-
py::arg("is_zp_float"));
726+
py::arg("a"), py::arg("c_or_none"), py::arg("b_q_weight"),
727+
py::arg("b_scales"), py::arg("global_scale_or_none"), py::arg("b_zeros_or_none"),
728+
py::arg("g_idx_or_none"), py::arg("perm_or_none"), py::arg("workspace"), py::arg("sorted_token_ids"),
729+
py::arg("expert_ids"), py::arg("num_tokens_post_padded"), py::arg("topk_weights"), py::arg("moe_block_size"),
730+
py::arg("top_k"), py::arg("mul_topk_weights"), py::arg("is_ep"), py::arg("b_q_type_str"),
731+
py::arg("size_m"), py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"), py::arg("use_atomic_add"),
732+
py::arg("use_fp32_reduce"), py::arg("is_zp_float"));
733+
728734
m.def("get_position_ids_and_mask_encoder_batch", &GetPositionIdsAndMaskEncoderBatch,
729735
"get_position_ids_and_mask_encoder_batch function");
730736

731-
732737
/**
733738
* cutlass_scaled_mm.cu
734739
* cutlass_scaled_mm
@@ -753,6 +758,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
753758
m.def("dynamic_per_token_scaled_fp8_quant", &DynamicPerTokenScaledFp8Quant,
754759
"dynamic_per_token_scaled_fp8_quant function",
755760
py::arg("out"), py::arg("input"), py::arg("scales"), py::arg("scale_ub"));
761+
756762
m.def("decode_mla_write_cache", &DecodeMLAWriteCacheKernel, "decode_mla_write_cache function");
757763

758764
m.def("prefill_mla_write_cache", &PrefillMLAWriteCacheKernel, "prefill_mla_write_cache function");
@@ -762,4 +768,16 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
762768
m.def("multi_head_latent_attention", &MultiHeadLatentAttention, "multi_head_latent_attention function");
763769

764770
m.def("noaux_tc",&NoauxTc, "noaux_tc for Deepseekv3 MoE compute");
771+
772+
m.def("cutlass_fp8_fp8_half_gemm_fused", &cutlass_fp8_fp8_half_gemm_func,
773+
py::arg("x"), py::arg("y"), py::arg("bias"), py::arg("transpose_x"),
774+
py::arg("transpose_y"), py::arg("scale"), py::arg("output_dtype"),
775+
py::arg("activation_type"), "cutlass_fp8_fp8_half_gemm_fused function");
776+
777+
m.def("moe_fused_hadamard_quant_fp8", &MoeFusedHadamardQuantFp8Func,
778+
py::arg("input"), py::arg("scale"), py::arg("topk_ids"),
779+
py::arg("top_k"), py::arg("intermediate_size"), py::arg("tiled"), "moe_fused_hadamard_quant_fp8 function");
780+
781+
m.def("fused_hadamard_quant_fp8", &FusedHadamardQuantFp8Func,
782+
py::arg("input"), py::arg("scale"), "fused_hadamard_quant_fp8 function");
765783
}

custom_ops/gpu_ops/fp8_gemm_with_cutlass/fp8_fp8_half_gemm.cu

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#include "fp8_fp8_half_cuda_core_gemm.h"
2020

2121

22-
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
22+
paddle::Tensor cutlass_fp8_fp8_half_gemm_func(
2323
const paddle::Tensor& x,
2424
const paddle::Tensor& y,
2525
const paddle::optional<paddle::Tensor>& bias,
@@ -142,7 +142,7 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
142142
{
143143
if(output_dtype == "bfloat16") {
144144
cuda_core_gemm_launcher<__nv_fp8_e4m3, __nv_bfloat16>(params);
145-
145+
146146
} else {
147147
cuda_core_gemm_launcher<__nv_fp8_e4m3, half>(params);
148148
}
@@ -174,7 +174,21 @@ std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
174174
fuse_gemm_config};
175175
fp8_fp8_gemm_scale_bias_act(params);
176176
}
177-
return {out};
177+
return out;
178+
}
179+
180+
std::vector<paddle::Tensor> cutlass_fp8_fp8_half_gemm(
181+
const paddle::Tensor& x,
182+
const paddle::Tensor& y,
183+
const paddle::optional<paddle::Tensor>& bias,
184+
bool trans_x,
185+
bool trans_y,
186+
float scale, // only support per-tensor quantization
187+
std::string output_dtype,
188+
std::string activation_type) {
189+
return {cutlass_fp8_fp8_half_gemm_func(
190+
x, y, bias, trans_x, trans_y, scale,
191+
output_dtype, activation_type)};
178192
}
179193

180194
std::vector<std::vector<int64_t>> CutlassFp8Fp8HalfGemmFusedInferShape(
custom_ops/gpu_ops/fused_hadamard_quant_fp8.cu

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <fcntl.h>
16+
#include <stdio.h>
17+
#include <stdlib.h>
18+
#include <string.h>
19+
#include <sys/mman.h>
20+
#include <sys/stat.h>
21+
#include <sys/types.h>
22+
#include <unistd.h>
23+
#include <algorithm>
24+
#include "helper.h"
25+
26+
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
27+
int lane_id = threadIdx.x % 32;
28+
#pragma unroll
29+
for (int step = 0; step < 5; ++step) {
30+
const int lane_mask = 1 << step;
31+
const __nv_bfloat16 sign = (lane_id & lane_mask) ? -1.f : 1.f;
32+
__nv_bfloat16 x_val_other = __shfl_xor_sync(0xffffffff, x, lane_mask);
33+
x = sign * x + x_val_other;
34+
}
35+
}
36+
37+
__global__ void MoeFusedHadamardQuantFp8Kernel(
38+
const __nv_bfloat16* __restrict__ input,
39+
const float* __restrict__ scale,
40+
const int64_t* __restrict__ topk_ids,
41+
__nv_fp8_e4m3* out,
42+
const int top_k,
43+
const int intermediate_size,
44+
const int64_t numel
45+
) {
46+
int64_t out_idx = blockIdx.x * blockDim.x + threadIdx.x;
47+
if (out_idx >= numel) return;
48+
49+
int64_t token_idx = out_idx / (top_k * intermediate_size);
50+
int64_t topk_idx = (out_idx / intermediate_size) % top_k;
51+
int64_t inter_idx = out_idx % intermediate_size;
52+
53+
int64_t input_idx = token_idx * intermediate_size + inter_idx;
54+
if (input_idx >= numel / top_k) return;
55+
56+
int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
57+
float scale_value = scale[expert_id];
58+
59+
__nv_bfloat16 x = input[input_idx];
60+
hadamard32_warp(x);
61+
62+
float x_fp32 = __bfloat162float(x);
63+
float quantized = x_fp32 / scale_value;
64+
out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
65+
}
66+
67+
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
68+
const __nv_bfloat16* __restrict__ input,
69+
const float* __restrict__ scale,
70+
const int64_t* __restrict__ topk_ids,
71+
__nv_fp8_e4m3* out,
72+
const int top_k,
73+
const int intermediate_size,
74+
const int64_t numel
75+
) {
76+
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
77+
if (idx >= numel) return;
78+
79+
int64_t token_idx = idx / intermediate_size;
80+
int64_t expert_id = topk_ids[token_idx];
81+
float scale_value = scale[expert_id];
82+
83+
__nv_bfloat16 x = input[idx];
84+
hadamard32_warp(x);
85+
86+
float x_fp32 = __bfloat162float(x);
87+
float quantized = x_fp32 / scale_value;
88+
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
89+
}
90+
91+
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
92+
const paddle::Tensor &input,
93+
const paddle::Tensor &scale,
94+
const paddle::Tensor &topk_ids,
95+
const int top_k,
96+
const int intermediate_size,
97+
const bool tiled) {
98+
int64_t numel = input.numel();
99+
if (!tiled) numel *= top_k;
100+
paddle::Tensor out = GetEmptyTensor(
101+
{numel / intermediate_size, intermediate_size},
102+
paddle::DataType::FLOAT8_E4M3FN,
103+
input.place());
104+
constexpr int64_t thread_per_block = 256;
105+
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
106+
auto stream = input.stream();
107+
if (tiled) {
108+
MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
109+
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
110+
scale.data<float>(),
111+
topk_ids.data<int64_t>(),
112+
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
113+
top_k,
114+
intermediate_size,
115+
numel
116+
);
117+
} else {
118+
MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
119+
reinterpret_cast<const __nv_bfloat16*>(input.data<phi::dtype::bfloat16>()),
120+
scale.data<float>(),
121+
topk_ids.data<int64_t>(),
122+
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
123+
top_k,
124+
intermediate_size,
125+
numel
126+
);
127+
}
128+
return {out};
129+
}
130+
131+
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
132+
.Inputs({"input", "scale", "topk_ids"})
133+
.Outputs({"output"})
134+
.Attrs({"top_k: int",
135+
"intermediate_size: int",
136+
"tiled: bool"})
137+
.SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
138+
139+
140+
paddle::Tensor MoeFusedHadamardQuantFp8Func(
141+
const paddle::Tensor &input,
142+
const paddle::Tensor &scale,
143+
const paddle::Tensor &topk_ids,
144+
const int top_k,
145+
const int intermediate_size,
146+
const bool tiled) {
147+
return MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled)[0];
148+
}
149+
150+
151+
__global__ void FusedHadamardQuantFp8Kernel(
152+
const __nv_bfloat16* __restrict__ input,
153+
__nv_fp8_e4m3* out,
154+
const float scale,
155+
const int64_t numel) {
156+
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
157+
if (idx >= numel) return;
158+
159+
__nv_bfloat16 x = input[idx];
160+
hadamard32_warp(x);
161+
162+
float x_fp32 = __bfloat162float(x);
163+
float quantized = x_fp32 / scale;
164+
out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
165+
}
166+
167+
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
168+
const paddle::Tensor &input,
169+
const float scale) {
170+
int64_t numel = input.numel();
171+
paddle::Tensor out = GetEmptyTensor(
172+
input.dims(),
173+
paddle::DataType::FLOAT8_E4M3FN,
174+
input.place());
175+
constexpr int64_t thread_per_block = 256;
176+
int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
177+
auto stream = input.stream();
178+
FusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
179+
reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
180+
reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
181+
scale,
182+
numel
183+
);
184+
return {out};
185+
}
186+
187+
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
188+
.Inputs({"input"})
189+
.Outputs({"output"})
190+
.Attrs({"scale: float"})
191+
.SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
192+
193+
194+
paddle::Tensor FusedHadamardQuantFp8Func(
195+
const paddle::Tensor &input,
196+
const float scale) {
197+
return FusedHadamardQuantFp8(input, scale)[0];
198+
}

custom_ops/setup_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ def find_end_files(directory, end_str):
442442
"gpu_ops/scaled_gemm_f8_i4_f16_weight_quantize.cu",
443443
"gpu_ops/cutlass_kernels/cutlass_heuristic.cu",
444444
"gpu_ops/cutlass_kernels/cutlass_preprocessors.cu",
445+
"gpu_ops/fused_hadamard_quant_fp8.cu"
445446
]
446447

447448
sources += find_end_files(fp8_auto_gen_directory, ".cu")

0 commit comments

Comments (0)