|
1 | 1 | #include <sycl/sycl.hpp> |
2 | 2 | #include <cmath> |
3 | 3 | #include <algorithm> |
| 4 | +#include <numeric> |
4 | 5 | #include "utils.h" |
5 | 6 | #include "dispatch_utils.h" |
6 | 7 |
|
| 8 | +#include <c10/util/Float8_e4m3fn.h> |
| 9 | +#include <c10/util/Float8_e5m2.h> |
| 10 | +#include "quantization/fp8/quant_utils.h" |
| 11 | + |
7 | 12 | #define VLLM_LDG(arg) *(arg) |
8 | 13 |
|
9 | 14 | namespace vllm { |
@@ -172,6 +177,56 @@ class act_and_mul_vec_kernel { |
172 | 177 | const int d_; |
173 | 178 | }; |
174 | 179 |
|
| 180 | +template < |
| 181 | + typename scalar_t, |
| 182 | + scalar_t (*ACT_FN)(const scalar_t&), |
| 183 | + typename fp8_type, |
| 184 | + int VEC_SIZE> |
| 185 | +class act_and_mul_quant_vec_kernel { |
| 186 | + public: |
| 187 | + act_and_mul_quant_vec_kernel( |
| 188 | + fp8_type* __restrict__ out, // [..., d] |
| 189 | + const scalar_t* __restrict__ input, // [..., 2 * d] |
| 190 | + const float* __restrict__ scale, // [1] |
| 191 | + const int d) |
| 192 | + : out_(out), input_(input), scale_(scale), d_(d) {} |
| 193 | + |
| 194 | + void operator()(sycl::nd_item<1> item) const { |
| 195 | + using vec_t = vllm::xpu::aligned_vec<scalar_t, VEC_SIZE>; |
| 196 | + |
| 197 | + const int64_t token_idx = item.get_group(0); |
| 198 | + const int64_t offset = item.get_local_linear_id(); |
| 199 | + const int64_t step = item.get_local_range(0); |
| 200 | + const int64_t bound = d_ / VEC_SIZE; |
| 201 | + |
| 202 | + const float inv_scale = 1.0f / (*scale_); |
| 203 | + const float fp8_max = static_cast<float>(fp8::quant_type_max_v<fp8_type>); |
| 204 | + |
| 205 | + // x and y halves are laid out contiguously: [x0..xd-1, y0..yd-1] |
| 206 | + const auto* v_x = |
| 207 | + reinterpret_cast<const vec_t*>(input_) + token_idx * bound * 2; |
| 208 | + const auto* v_y = v_x + bound; |
| 209 | + |
| 210 | + for (int64_t i = offset; i < bound; i += step) { |
| 211 | + vec_t xv = v_x[i]; |
| 212 | + vec_t yv = v_y[i]; |
| 213 | +#pragma unroll |
| 214 | + for (int j = 0; j < VEC_SIZE; j++) { |
| 215 | + float val = static_cast<float>(ACT_FN(xv[j]) * yv[j]) * inv_scale; |
| 216 | + float clamped = sycl::fmax(-fp8_max, sycl::fmin(val, fp8_max)); |
| 217 | + out_[token_idx * d_ + i * VEC_SIZE + j] = |
| 218 | + static_cast<fp8_type>(clamped); |
| 219 | + } |
| 220 | + } |
| 221 | + } |
| 222 | + |
| 223 | + private: |
| 224 | + fp8_type* __restrict__ out_; // [..., d] |
| 225 | + const scalar_t* __restrict__ input_; // [..., 2 * d] |
| 226 | + const float* __restrict__ scale_; // [1] |
| 227 | + const int d_; |
| 228 | +}; |
| 229 | + |
175 | 230 | template <typename T> |
176 | 231 | [[intel::device_indirectly_callable]] inline __attribute__((always_inline)) T |
177 | 232 | swigluoai_and_mul(const T& gate, const T& up, float alpha, float limit) { |
@@ -352,6 +407,73 @@ void silu_and_mul( |
352 | 407 | }); |
353 | 408 | } |
354 | 409 |
|
// Fused SiLU + Mul + FP8 Quantization
// Input: [..., 2*d] in FP16/BF16, Output: [..., d] in FP8
// Dispatches to the vectorized kernel (VEC_SIZE=1..8) based on alignment.
//
// Expands to one `case N:` of the vec-size switch in
// LAUNCH_ACTIVATION_GATE_QUANT_KERNEL below: picks a work-group size that
// covers the d/N vector chunks of one token (capped at 1024 work-items),
// resolves the FP8 output dtype, and launches
// act_and_mul_quant_vec_kernel with one work-group per token.
// NOTE: relies on names bound by the enclosing macro expansion
// (d, num_tokens, queue, input_ptr, scale_ptr, out, sycl_t).
#define LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, N)                            \
  case N: {                                                                \
    int64_t wg_size =                                                      \
        std::min(static_cast<int64_t>(d / N), static_cast<int64_t>(1024)); \
    VLLM_DISPATCH_FP8_TYPES(                                               \
        out.scalar_type(), "act_and_mul_quant_vec_kernel_fp8", [&] {       \
          auto out_ptr = out.data_ptr<fp8_t>();                            \
          queue.submit([&](sycl::handler& cgh) {                           \
            cgh.parallel_for(                                              \
                sycl::nd_range<1>(num_tokens * wg_size, wg_size),          \
                vllm::act_and_mul_quant_vec_kernel<sycl_t, KERNEL, fp8_t, N>( \
                    out_ptr, (sycl_t*)input_ptr, scale_ptr, d));           \
          });                                                              \
        });                                                                \
    break;                                                                 \
  }
| 429 | + |
// Shared launcher body for gated-activation + FP8-quant ops.
// Expects `out`, `input`, `scale` tensors and a `scalar_t` typedef in
// scope (the latter supplied by VLLM_DISPATCH_FLOATING_TYPES). Chooses a
// vector width, then dispatches via LAUNCH_ACT_AND_MUL_QUANT_VEC.
// (Only /* */ comments are safe inside the macro: a // comment would
// swallow the backslash continuation.)
#define LAUNCH_ACTIVATION_GATE_QUANT_KERNEL(KERNEL)                        \
  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;                 \
  int d = input.size(-1) / 2;                                              \
  int64_t num_tokens = input.numel() / input.size(-1);                     \
  if (num_tokens == 0) {                                                   \
    return;                                                                \
  }                                                                        \
  auto input_ptr = input.data_ptr<scalar_t>();                             \
  auto scale_ptr = scale.data_ptr<float>();                                \
  at::DeviceGuard device_guard(input.device());                            \
  auto& queue = vllm::xpu::vllmGetQueue();                                 \
  /* Compute vec_size like non-quant path: gcd(4*sizeof(float)/sizeof, d) */ \
  int vec_size = static_cast<int>(sizeof(float) * 4 / sizeof(sycl_t));     \
  {                                                                        \
    int64_t tmp_wg =                                                       \
        std::min(static_cast<int64_t>(d), static_cast<int64_t>(1024));     \
    /* Halve vec_size while a narrower width still lets one work-group */  \
    /* span the whole row, i.e. pick the smallest sufficient width. */     \
    while (vec_size > 1 && (vec_size >> 1) * tmp_wg >= d) {                \
      vec_size = vec_size >> 1;                                            \
    }                                                                      \
  }                                                                        \
  /* Fall back to scalar loads when d is not vector-aligned. */            \
  if (d % vec_size != 0) vec_size = 1;                                     \
  switch (vec_size) {                                                      \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 1);                               \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 2);                               \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 4);                               \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 8);                               \
    /* 16 appears unreachable for 2-byte types (initial max is 8); */      \
    /* presumably defensive for future element types — confirm. */         \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 16);                              \
    default:                                                               \
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);           \
  }
| 460 | + |
| 461 | +void silu_and_mul_quant( |
| 462 | + torch::Tensor& out, // [..., d] FP8 |
| 463 | + torch::Tensor& input, // [..., 2 * d] FP16/BF16 |
| 464 | + torch::Tensor& scale) // [1] FP32 |
| 465 | +{ |
| 466 | + TORCH_CHECK( |
| 467 | + out.dtype() == torch::kFloat8_e4m3fn || |
| 468 | + out.dtype() == torch::kFloat8_e5m2); |
| 469 | + TORCH_CHECK( |
| 470 | + input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16); |
| 471 | + TORCH_CHECK(input.size(-1) % 2 == 0); |
| 472 | + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_quant", [&] { |
| 473 | + LAUNCH_ACTIVATION_GATE_QUANT_KERNEL(vllm::silu_kernel); |
| 474 | + }); |
| 475 | +} |
| 476 | + |
355 | 477 | void mul_and_silu( |
356 | 478 | torch::Tensor& out, // [..., d] |
357 | 479 | torch::Tensor& input) // [..., 2 * d] |
|
0 commit comments