Skip to content

Commit 20f2472

Browse files
authored
[Fusion][Torch.compiler] Add fuse_norm_quant, fuse_act_quant and fused_qk_norm_rope kernel (vllm-project#267)
* Add fuse_norm_quant, fuse_act_quant and fused_qk_norm_rope kernel
* Fix format
* Fix format
* Add fused_qk_norm_rope head_dim=512 case and update vec_size

Signed-off-by: Lai, Yejing <yejing.lai@intel.com>
Signed-off-by: Yejing Lai <yejing.lai@intel.com>
1 parent 3fb2f1e commit 20f2472

13 files changed

Lines changed: 2306 additions & 18 deletions

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -439,8 +439,10 @@ if(BASIC_KERNELS_ENABLED)
439439
set(VLLM_EXT_SRC
440440
"csrc/cache.cpp"
441441
"csrc/layernorm.cpp"
442+
"csrc/layernorm_quant.cpp"
442443
"csrc/activation.cpp"
443444
"csrc/pos_encoding_kernels.cpp"
445+
"csrc/fused_qknorm_rope.cpp"
444446
"csrc/torch_bindings.cpp"
445447
"csrc/quantization/fp8/fp8_quant.cpp"
446448
"csrc/quantization/fp4/mxfp4_quant.cpp"

csrc/activation.cpp

Lines changed: 122 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,14 @@
11
#include <sycl/sycl.hpp>
22
#include <cmath>
33
#include <algorithm>
4+
#include <numeric>
45
#include "utils.h"
56
#include "dispatch_utils.h"
67

8+
#include <c10/util/Float8_e4m3fn.h>
9+
#include <c10/util/Float8_e5m2.h>
10+
#include "quantization/fp8/quant_utils.h"
11+
712
#define VLLM_LDG(arg) *(arg)
813

914
namespace vllm {
@@ -172,6 +177,56 @@ class act_and_mul_vec_kernel {
172177
const int d_;
173178
};
174179

180+
template <
181+
typename scalar_t,
182+
scalar_t (*ACT_FN)(const scalar_t&),
183+
typename fp8_type,
184+
int VEC_SIZE>
185+
class act_and_mul_quant_vec_kernel {
186+
public:
187+
act_and_mul_quant_vec_kernel(
188+
fp8_type* __restrict__ out, // [..., d]
189+
const scalar_t* __restrict__ input, // [..., 2 * d]
190+
const float* __restrict__ scale, // [1]
191+
const int d)
192+
: out_(out), input_(input), scale_(scale), d_(d) {}
193+
194+
void operator()(sycl::nd_item<1> item) const {
195+
using vec_t = vllm::xpu::aligned_vec<scalar_t, VEC_SIZE>;
196+
197+
const int64_t token_idx = item.get_group(0);
198+
const int64_t offset = item.get_local_linear_id();
199+
const int64_t step = item.get_local_range(0);
200+
const int64_t bound = d_ / VEC_SIZE;
201+
202+
const float inv_scale = 1.0f / (*scale_);
203+
const float fp8_max = static_cast<float>(fp8::quant_type_max_v<fp8_type>);
204+
205+
// x and y halves are laid out contiguously: [x0..xd-1, y0..yd-1]
206+
const auto* v_x =
207+
reinterpret_cast<const vec_t*>(input_) + token_idx * bound * 2;
208+
const auto* v_y = v_x + bound;
209+
210+
for (int64_t i = offset; i < bound; i += step) {
211+
vec_t xv = v_x[i];
212+
vec_t yv = v_y[i];
213+
#pragma unroll
214+
for (int j = 0; j < VEC_SIZE; j++) {
215+
float val = static_cast<float>(ACT_FN(xv[j]) * yv[j]) * inv_scale;
216+
float clamped = sycl::fmax(-fp8_max, sycl::fmin(val, fp8_max));
217+
out_[token_idx * d_ + i * VEC_SIZE + j] =
218+
static_cast<fp8_type>(clamped);
219+
}
220+
}
221+
}
222+
223+
private:
224+
fp8_type* __restrict__ out_; // [..., d]
225+
const scalar_t* __restrict__ input_; // [..., 2 * d]
226+
const float* __restrict__ scale_; // [1]
227+
const int d_;
228+
};
229+
175230
template <typename T>
176231
[[intel::device_indirectly_callable]] inline __attribute__((always_inline)) T
177232
swigluoai_and_mul(const T& gate, const T& up, float alpha, float limit) {
@@ -352,6 +407,73 @@ void silu_and_mul(
352407
});
353408
}
354409

410+
// Fused SiLU + Mul + FP8 Quantization
// Input: [..., 2*d] in FP16/BF16, Output: [..., d] in FP8
// Emits one `case N:` arm of the vec_size dispatch switch: picks a
// work-group size covering the d/N vectors of one token (capped at 1024
// work-items) and enqueues one work-group per token.
#define LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, N)                           \
  case N: {                                                               \
    int64_t group_size = std::min<int64_t>(d / N, 1024);                  \
    VLLM_DISPATCH_FP8_TYPES(                                              \
        out.scalar_type(), "act_and_mul_quant_vec_kernel_fp8", [&] {      \
          auto out_ptr = out.data_ptr<fp8_t>();                           \
          queue.submit([&](sycl::handler& cgh) {                          \
            cgh.parallel_for(                                             \
                sycl::nd_range<1>(num_tokens * group_size, group_size),   \
                vllm::act_and_mul_quant_vec_kernel<sycl_t, KERNEL, fp8_t, \
                                                   N>(                    \
                    out_ptr, reinterpret_cast<const sycl_t*>(input_ptr),  \
                    scale_ptr, d));                                       \
          });                                                             \
        });                                                               \
    break;                                                                \
  }
429+
430+
// Shared body for fused act+mul+quant entry points: computes the launch
// geometry, selects the widest vector width that divides d while keeping
// a full work-group busy, then dispatches to the matching kernel arm.
#define LAUNCH_ACTIVATION_GATE_QUANT_KERNEL(KERNEL)                          \
  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;                   \
  int d = input.size(-1) / 2;                                                \
  int64_t num_tokens = input.numel() / input.size(-1);                       \
  if (num_tokens == 0) {                                                     \
    return;                                                                  \
  }                                                                          \
  at::DeviceGuard device_guard(input.device());                              \
  auto& queue = vllm::xpu::vllmGetQueue();                                   \
  auto input_ptr = input.data_ptr<scalar_t>();                               \
  auto scale_ptr = scale.data_ptr<float>();                                  \
  /* Start at a 16-byte vector and halve while a work-group of up to */      \
  /* 1024 items would still cover all of d at the narrower width.    */      \
  int vec_size = static_cast<int>(sizeof(float) * 4 / sizeof(sycl_t));       \
  {                                                                          \
    int64_t wg_cap =                                                         \
        std::min(static_cast<int64_t>(d), static_cast<int64_t>(1024));       \
    while (vec_size > 1 && (vec_size / 2) * wg_cap >= d) {                   \
      vec_size = vec_size / 2;                                               \
    }                                                                        \
  }                                                                          \
  if (d % vec_size != 0) vec_size = 1;                                       \
  switch (vec_size) {                                                        \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 1);                                 \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 2);                                 \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 4);                                 \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 8);                                 \
    LAUNCH_ACT_AND_MUL_QUANT_VEC(KERNEL, 16);                                \
    default:                                                                 \
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);             \
  }
460+
461+
// Fused SiLU-and-mul with FP8 output quantization.
//
//   out   [..., d]      fp8 (e4m3fn or e5m2); receives silu(x) * y / scale
//   input [..., 2 * d]  fp16/bf16; x and y halves stored contiguously per row
//   scale [1]           fp32 quantization scale
//
// NOTE(review): the kernel indexes raw pointers, so both tensors are
// assumed contiguous — callers are expected to guarantee this.
void silu_and_mul_quant(
    torch::Tensor& out, // [..., d] FP8
    torch::Tensor& input, // [..., 2 * d] FP16/BF16
    torch::Tensor& scale) // [1] FP32
{
  TORCH_CHECK(
      out.dtype() == torch::kFloat8_e4m3fn ||
          out.dtype() == torch::kFloat8_e5m2,
      "silu_and_mul_quant: out must be fp8 (e4m3fn or e5m2), got ",
      out.dtype());
  TORCH_CHECK(
      input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16,
      "silu_and_mul_quant: input must be fp16 or bf16, got ", input.dtype());
  TORCH_CHECK(
      input.size(-1) % 2 == 0,
      "silu_and_mul_quant: input last dim must be even, got ",
      input.size(-1));
  // The kernel writes rows of width d = input.size(-1) / 2; a mismatched
  // output tensor would be written out of bounds.
  TORCH_CHECK(
      out.size(-1) * 2 == input.size(-1),
      "silu_and_mul_quant: out.size(-1) must equal input.size(-1) / 2");
  // The kernel dereferences exactly one float from scale.
  TORCH_CHECK(
      scale.numel() == 1,
      "silu_and_mul_quant: scale must hold exactly one element, got ",
      scale.numel());
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul_quant", [&] {
    LAUNCH_ACTIVATION_GATE_QUANT_KERNEL(vllm::silu_kernel);
  });
}
476+
355477
void mul_and_silu(
356478
torch::Tensor& out, // [..., d]
357479
torch::Tensor& input) // [..., 2 * d]

0 commit comments

Comments
 (0)