|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include <fcntl.h> |
| 16 | +#include <stdio.h> |
| 17 | +#include <stdlib.h> |
| 18 | +#include <string.h> |
| 19 | +#include <sys/mman.h> |
| 20 | +#include <sys/stat.h> |
| 21 | +#include <sys/types.h> |
| 22 | +#include <unistd.h> |
| 23 | +#include <algorithm> |
| 24 | +#include "helper.h" |
| 25 | + |
// In-place 32-point Walsh-Hadamard transform across the 32 lanes of a warp.
// Each lane holds one element; five XOR-shuffle butterfly stages combine
// lane pairs at strides 1, 2, 4, 8, 16. Arithmetic is done in bf16.
// NOTE(review): assumes the full warp is active (mask 0xffffffff) — callers
// must launch with blockDim.x a multiple of 32 for the tail warp; confirm.
__device__ __forceinline__ void hadamard32_warp(__nv_bfloat16& x) {
  const int lane = threadIdx.x & 31;
#pragma unroll
  for (int stride = 1; stride < 32; stride <<= 1) {
    // Partner lane differs only in the current stride bit.
    const __nv_bfloat16 partner = __shfl_xor_sync(0xffffffff, x, stride);
    // Upper lane of each pair takes (partner - x), lower takes (x + partner):
    // the standard Hadamard butterfly.
    x = (lane & stride) ? (partner - x) : (x + partner);
  }
}
| 36 | + |
// Fused Hadamard-transform + FP8(e4m3) quantization for MoE activations,
// gather (non-tiled) variant: each input token row is replicated top_k times
// in the output, once per routed expert.
//
// Layout (derived from the index arithmetic below):
//   input     : [num_tokens, intermediate_size] bf16
//   topk_ids  : [num_tokens, top_k] row-major expert ids
//   scale     : per-expert dequant scale, indexed by expert id
//   out       : [num_tokens * top_k, intermediate_size] fp8_e4m3
//   numel     : total number of OUTPUT elements
// Launch: 1-D grid, one thread per output element; blockDim.x should be a
// multiple of 32 so hadamard32_warp sees full warps.
__global__ void MoeFusedHadamardQuantFp8Kernel(
    const __nv_bfloat16* __restrict__ input,
    const float* __restrict__ scale,
    const int64_t* __restrict__ topk_ids,
    __nv_fp8_e4m3* out,
    const int top_k,
    const int intermediate_size,
    const int64_t numel
) {
    // Promote to 64-bit BEFORE the multiply: blockIdx.x * blockDim.x is
    // otherwise evaluated in 32-bit unsigned and silently wraps once the
    // grid covers >= 2^32 elements (numel is 64-bit, so such sizes are legal).
    int64_t out_idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (out_idx >= numel) return;

    // Decompose the flat output index into (token, routed slot, channel).
    int64_t token_idx = out_idx / (top_k * intermediate_size);
    int64_t topk_idx = (out_idx / intermediate_size) % top_k;
    int64_t inter_idx = out_idx % intermediate_size;

    // Source element: the same token row feeds all top_k output copies.
    int64_t input_idx = token_idx * intermediate_size + inter_idx;
    // numel / top_k is the input element count; guard is redundant with the
    // out_idx check above but kept as a cheap safety net.
    if (input_idx >= numel / top_k) return;

    // Per-expert scale selected by this slot's routing id.
    int64_t expert_id = topk_ids[token_idx * top_k + topk_idx];
    float scale_value = scale[expert_id];

    __nv_bfloat16 x = input[input_idx];
    hadamard32_warp(x);

    // Quantize: divide by the expert scale in fp32, then narrow to e4m3.
    float x_fp32 = __bfloat162float(x);
    float quantized = x_fp32 / scale_value;
    out[out_idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
| 66 | + |
// Fused Hadamard-transform + FP8(e4m3) quantization, tiled variant: input and
// output have identical shape (no top_k replication), and topk_ids holds ONE
// expert id per input row (indexed by token_idx alone, unlike the gather
// kernel which indexes [token, top_k]).
//
//   input/out : [num_rows, intermediate_size]
//   topk_ids  : [num_rows] expert ids
//   scale     : per-expert dequant scale
//   numel     : element count of input (== output)
// Launch: 1-D grid, one thread per element; blockDim.x should be a multiple
// of 32 so hadamard32_warp sees full warps. top_k is unused here but kept so
// both kernels share a launch signature.
__global__ void MoeFusedHadamardQuantFp8TiledKernel(
    const __nv_bfloat16* __restrict__ input,
    const float* __restrict__ scale,
    const int64_t* __restrict__ topk_ids,
    __nv_fp8_e4m3* out,
    const int top_k,
    const int intermediate_size,
    const int64_t numel
) {
    // 64-bit promotion before the multiply: the 32-bit unsigned product
    // blockIdx.x * blockDim.x wraps for grids covering >= 2^32 elements.
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= numel) return;

    // One expert id per row; scale is looked up per expert.
    int64_t token_idx = idx / intermediate_size;
    int64_t expert_id = topk_ids[token_idx];
    float scale_value = scale[expert_id];

    __nv_bfloat16 x = input[idx];
    hadamard32_warp(x);

    // Quantize in fp32, then narrow to e4m3.
    float x_fp32 = __bfloat162float(x);
    float quantized = x_fp32 / scale_value;
    out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
| 90 | + |
// Host launcher for the fused MoE Hadamard + FP8 quantization op.
//
// tiled == true : output has the same element count as input; topk_ids is
//                 read per input row (tiled kernel).
// tiled == false: output is input replicated top_k times (one copy per
//                 routed expert); topk_ids is read as [tokens, top_k].
//
// Returns a single FLOAT8_E4M3FN tensor of shape
// [numel / intermediate_size, intermediate_size] on input's place.
// Preconditions (from the kernels' indexing): input is bf16 and its
// innermost dimension is intermediate_size — TODO confirm against callers.
std::vector<paddle::Tensor> MoeFusedHadamardQuantFp8(
    const paddle::Tensor &input,
    const paddle::Tensor &scale,
    const paddle::Tensor &topk_ids,
    const int top_k,
    const int intermediate_size,
    const bool tiled) {
  int64_t numel = input.numel();
  // Non-tiled output carries top_k copies of every input element.
  if (!tiled) numel *= top_k;
  paddle::Tensor out = GetEmptyTensor(
      {numel / intermediate_size, intermediate_size},
      paddle::DataType::FLOAT8_E4M3FN,
      input.place());
  constexpr int64_t thread_per_block = 256;
  // Ceil-div so the tail partial block is still launched; kernels bounds-check.
  int64_t block_per_grid = (numel + thread_per_block - 1) / thread_per_block;
  auto stream = input.stream();
  if (tiled) {
    MoeFusedHadamardQuantFp8TiledKernel<<<block_per_grid, thread_per_block, 0, stream>>>(
        // paddle::bfloat16 used in both branches for consistency (the
        // non-tiled branch previously spelled it phi::dtype::bfloat16).
        reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
        scale.data<float>(),
        topk_ids.data<int64_t>(),
        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
        top_k,
        intermediate_size,
        numel
    );
  } else {
    MoeFusedHadamardQuantFp8Kernel<<<block_per_grid, thread_per_block, 0, stream>>>(
        reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
        scale.data<float>(),
        topk_ids.data<int64_t>(),
        reinterpret_cast<__nv_fp8_e4m3*>(out.mutable_data<phi::dtype::float8_e4m3fn>()),
        top_k,
        intermediate_size,
        numel
    );
  }
  return {out};
}
| 130 | + |
// Registers `moe_fused_hadamard_quant_fp8` as a static Paddle custom op:
// three tensor inputs (activations, per-expert scales, routing ids), one
// FP8 output, and the routing/shape attributes consumed by the launcher.
PD_BUILD_STATIC_OP(moe_fused_hadamard_quant_fp8)
    .Inputs({"input", "scale", "topk_ids"})
    .Outputs({"output"})
    .Attrs({"top_k: int",
            "intermediate_size: int",
            "tiled: bool"})
    .SetKernelFn(PD_KERNEL(MoeFusedHadamardQuantFp8));
| 138 | + |
| 139 | + |
// Single-tensor convenience wrapper around MoeFusedHadamardQuantFp8 for
// call sites that want the output directly instead of a one-element vector.
paddle::Tensor MoeFusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const paddle::Tensor &scale,
    const paddle::Tensor &topk_ids,
    const int top_k,
    const int intermediate_size,
    const bool tiled) {
  auto outputs =
      MoeFusedHadamardQuantFp8(input, scale, topk_ids, top_k, intermediate_size, tiled);
  return outputs.front();
}
| 149 | + |
| 150 | + |
// Elementwise fused Hadamard transform + FP8(e4m3) quantization with one
// uniform scalar scale (non-MoE path).
//   input : bf16, numel elements; out : fp8_e4m3, same count.
// Launch: 1-D grid, one thread per element; blockDim.x should be a multiple
// of 32 so hadamard32_warp sees full warps.
__global__ void FusedHadamardQuantFp8Kernel(
    const __nv_bfloat16* __restrict__ input,
    __nv_fp8_e4m3* out,
    const float scale,
    const int64_t numel) {
    // Promote before multiplying: blockIdx.x * blockDim.x wraps in 32-bit
    // unsigned arithmetic once the grid covers >= 2^32 elements.
    int64_t idx = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    if (idx >= numel) return;

    __nv_bfloat16 x = input[idx];
    hadamard32_warp(x);

    // Quantize in fp32, then narrow to e4m3.
    float x_fp32 = __bfloat162float(x);
    float quantized = x_fp32 / scale;
    out[idx] = static_cast<__nv_fp8_e4m3>(quantized);
}
| 166 | + |
// Host launcher for the non-MoE fused Hadamard + FP8 quantization op.
// Allocates an FP8 tensor shaped like `input`, launches one thread per
// element on input's stream, and returns the result as a one-element vector.
std::vector<paddle::Tensor> FusedHadamardQuantFp8(
    const paddle::Tensor &input,
    const float scale) {
  const int64_t total_elems = input.numel();
  paddle::Tensor quantized = GetEmptyTensor(
      input.dims(),
      paddle::DataType::FLOAT8_E4M3FN,
      input.place());

  constexpr int64_t kThreadsPerBlock = 256;
  // Ceil-div so a partial tail block is launched; the kernel bounds-checks.
  const int64_t num_blocks =
      (total_elems + kThreadsPerBlock - 1) / kThreadsPerBlock;

  FusedHadamardQuantFp8Kernel<<<num_blocks, kThreadsPerBlock, 0, input.stream()>>>(
      reinterpret_cast<const __nv_bfloat16*>(input.data<paddle::bfloat16>()),
      reinterpret_cast<__nv_fp8_e4m3*>(quantized.mutable_data<phi::dtype::float8_e4m3fn>()),
      scale,
      total_elems
  );
  return {quantized};
}
| 186 | + |
// Registers `fused_hadamard_quant_fp8` as a static Paddle custom op: one
// bf16 tensor input, one FP8 output, and a scalar float scale attribute.
PD_BUILD_STATIC_OP(fused_hadamard_quant_fp8)
    .Inputs({"input"})
    .Outputs({"output"})
    .Attrs({"scale: float"})
    .SetKernelFn(PD_KERNEL(FusedHadamardQuantFp8));
| 192 | + |
| 193 | + |
// Single-tensor convenience wrapper around FusedHadamardQuantFp8 for call
// sites that want the output directly instead of a one-element vector.
paddle::Tensor FusedHadamardQuantFp8Func(
    const paddle::Tensor &input,
    const float scale) {
  auto outputs = FusedHadamardQuantFp8(input, scale);
  return outputs.front();
}
0 commit comments