Skip to content

Commit 14631a4

Browse files
Vectorize act-and-mul kernels for speedup (vllm-project#207)
Add vectorized memory access to the activation-and-mul kernels using aligned_vec loads/stores with dynamic vec_size dispatch (1-16). Switch from a 3D to a 1D nd_range for simpler indexing. All 4 fused ops (silu_and_mul, mul_and_silu, gelu_and_mul, gelu_tanh_and_mul) now use the vectorized path. The original scalar kernel is retained as the VEC_SIZE=1 fallback.

Benchmark results (avg GPU time in us, 200 iterations, no per-iter sync):

| Model | Tokens | Dtype | d (intermediate_size) | Baseline (us) | Vectorized (us) | Change |
|-------|--------|-------|-----------------------|---------------|-----------------|--------|
| llama3-70b | 128 | fp16 | 28672 | 24.01 | 8.96 | -62.7% |
| llama3-70b | 128 | bf16 | 28672 | 27.25 | 11.25 | -58.7% |
| llama3-70b | 512 | fp16 | 28672 | 262.79 | 202.13 | -23.1% |
| llama3-70b | 512 | bf16 | 28672 | 261.46 | 202.67 | -22.5% |
| llama3-70b | 1024 | fp16 | 28672 | 545.11 | 424.03 | -22.2% |
| llama3-70b | 1024 | bf16 | 28672 | 545.13 | 424.82 | -22.1% |
| llama3-70b | 2048 | fp16 | 28672 | 1108.82 | 872.10 | -21.3% |
| llama3-70b | 2048 | bf16 | 28672 | 1108.13 | 872.70 | -21.2% |
| llama3-8b | 128 | fp16 | 14336 | 33.05 | 6.51 | -80.3% |
| llama3-8b | 128 | bf16 | 14336 | 26.65 | 6.15 | -76.9% |
| llama3-8b | 512 | fp16 | 14336 | 169.74 | 92.10 | -45.7% |
| llama3-8b | 512 | bf16 | 14336 | 139.62 | 93.25 | -33.2% |
| llama3-8b | 1024 | fp16 | 14336 | 261.68 | 201.64 | -22.9% |
| llama3-8b | 1024 | bf16 | 14336 | 260.92 | 201.73 | -22.7% |
| llama3-8b | 2048 | fp16 | 14336 | 539.98 | 420.75 | -22.1% |
| llama3-8b | 2048 | bf16 | 14336 | 541.28 | 422.87 | -21.9% |
| qwen-14b | 512 | fp16 | 13824 | 116.04 | 85.32 | -26.5% |
| qwen-14b | 512 | bf16 | 13824 | 114.37 | 85.69 | -25.1% |
| qwen-14b | 1024 | fp16 | 13824 | 238.41 | 193.29 | -18.9% |
| qwen-14b | 1024 | bf16 | 13824 | 254.00 | 193.76 | -23.7% |
| qwen-14b | 2048 | fp16 | 13824 | 527.05 | 407.07 | -22.8% |
| qwen-14b | 2048 | bf16 | 13824 | 521.38 | 403.80 | -22.6% |
| qwen-32b | 128 | fp16 | 27648 | 20.65 | 6.29 | -69.5% |
| qwen-32b | 128 | bf16 | 27648 | 21.35 | 6.89 | -67.7% |
| qwen-32b | 512 | fp16 | 27648 | 253.84 | 193.79 | -23.7% |
| qwen-32b | 512 | bf16 | 27648 | 253.64 | 193.84 | -23.6% |
| qwen-32b | 1024 | fp16 | 27648 | 526.81 | 407.99 | -22.6% |
| qwen-32b | 1024 | bf16 | 27648 | 523.08 | 408.52 | -21.9% |
| qwen-32b | 2048 | fp16 | 27648 | 1069.97 | 838.01 | -21.7% |
| qwen-32b | 2048 | bf16 | 27648 | 1068.91 | 838.34 | -21.6% |

Signed-off-by: Ma, Liangliang <liangliang.ma@intel.com>
1 parent 1b4770e commit 14631a4

1 file changed

Lines changed: 91 additions & 4 deletions

File tree

csrc/activation.cpp

Lines changed: 91 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,49 @@ class act_and_mul_kernel {
121121
const int d_;
122122
};
123123

124+
// Vectorized version of act_and_mul_kernel using aligned vector loads/stores.
125+
// Each work-item processes VEC_SIZE elements per iteration, reducing memory
126+
// transactions and improving bandwidth utilization.
127+
template <
128+
typename scalar_t,
129+
scalar_t (*ACT_FN)(const scalar_t&),
130+
bool act_first,
131+
int VEC_SIZE>
132+
class act_and_mul_vec_kernel {
133+
public:
134+
act_and_mul_vec_kernel(
135+
scalar_t* __restrict__ out,
136+
const scalar_t* __restrict__ input,
137+
const int d)
138+
: out_(out), input_(input), d_(d) {}
139+
140+
void operator()(sycl::nd_item<1> item) const {
141+
using vec_t = vllm::xpu::aligned_vec<scalar_t, VEC_SIZE>;
142+
const int64_t token_idx = item.get_group(0);
143+
const int64_t offset = item.get_local_linear_id();
144+
const int64_t step = item.get_local_range(0);
145+
const int64_t bound = d_ / VEC_SIZE;
146+
147+
for (int64_t i = offset; i < bound; i += step) {
148+
auto x_vec =
149+
reinterpret_cast<const vec_t*>(input_)[token_idx * bound * 2 + i];
150+
auto y_vec = reinterpret_cast<const vec_t*>(
151+
input_)[token_idx * bound * 2 + i + bound];
152+
vec_t out_vec;
153+
#pragma unroll
154+
for (int j = 0; j < VEC_SIZE; ++j) {
155+
out_vec[j] = compute<scalar_t, ACT_FN, act_first>(x_vec[j], y_vec[j]);
156+
}
157+
reinterpret_cast<vec_t*>(out_)[token_idx * bound + i] = out_vec;
158+
}
159+
}
160+
161+
private:
162+
scalar_t* __restrict__ out_;
163+
const scalar_t* __restrict__ input_;
164+
const int d_;
165+
};
166+
124167
template <typename T>
125168
[[intel::device_indirectly_callable]] inline __attribute__((always_inline)) T
126169
swigluoai_and_mul(const T& gate, const T& up, float alpha, float limit) {
@@ -201,12 +244,56 @@ class swigluoai_and_mul_kernel {
201244
(sycl_t*)out_ptr, (sycl_t*)input_ptr, d)); \
202245
});
203246

247+
// Emits one `case N:` arm that launches the vectorized act-and-mul kernel
// with VEC_SIZE = N over a 1D nd_range (one work-group per token).
// Expects `queue`, `num_tokens`, `wg_size`, `out_ptr`, `input_ptr`, `d`
// and `sycl_t` to be in scope (set up by LAUNCH_ACTIVATION_GATE_KERNEL_VEC).
#define VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, N)                  \
  case N: {                                                           \
    queue.submit([&](sycl::handler& cgh) {                            \
      cgh.parallel_for(                                               \
          sycl::nd_range<1>(num_tokens * wg_size, wg_size),           \
          vllm::act_and_mul_vec_kernel<sycl_t, KERNEL, ACT_FIRST, N>( \
              (sycl_t*)out_ptr, (sycl_t*)input_ptr, d));              \
    });                                                               \
    break;                                                            \
  }

// Vectorized launch: picks the widest vec_size in {16,8,4,2,1} that both
// keeps every work-item busy and evenly divides d, then dispatches via the
// switch below. vec_size=1 is the scalar fallback.
#define LAUNCH_ACTIVATION_GATE_KERNEL_VEC(KERNEL, ACT_FIRST)               \
  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;                 \
  int d = input.size(-1) / 2;                                              \
  int64_t num_tokens = input.numel() / input.size(-1);                     \
  if (num_tokens == 0) {                                                   \
    return;                                                                \
  }                                                                        \
  auto out_ptr = out.data_ptr<scalar_t>();                                 \
  auto input_ptr = input.data_ptr<scalar_t>();                             \
  at::DeviceGuard device_guard(input.device());                            \
  auto& queue = vllm::xpu::vllmGetQueue();                                 \
  /* Start at a 16-byte access per work-item (8 for fp16/bf16, 4 fp32). */ \
  int vec_size = static_cast<int>(sizeof(float) * 4 / sizeof(scalar_t));   \
  {                                                                        \
    /* Shrink while a narrower width still covers d in one pass, so */     \
    /* small rows keep all work-items of the group busy. */                \
    int64_t tmp_wg =                                                       \
        std::min(static_cast<int64_t>(d), static_cast<int64_t>(1024));     \
    while (vec_size > 1 && (vec_size >> 1) * tmp_wg >= d) {                \
      vec_size = vec_size >> 1;                                            \
    }                                                                      \
  }                                                                        \
  /* Halve until vec_size divides d; previously this jumped straight */    \
  /* to 1, forcing the scalar path even when 2/4/8 would divide d. */      \
  while (vec_size > 1 && d % vec_size != 0) {                              \
    vec_size >>= 1;                                                        \
  }                                                                        \
  int64_t wg_size = std::min(                                              \
      static_cast<int64_t>(d / vec_size), static_cast<int64_t>(1024));     \
  switch (vec_size) {                                                      \
    VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, 1);                          \
    VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, 2);                          \
    VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, 4);                          \
    VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, 8);                          \
    VEC_LAUNCH_ACT_AND_MUL(KERNEL, ACT_FIRST, 16);                         \
    default:                                                               \
      TORCH_CHECK(false, "Unsupported vector size: ", vec_size);           \
  }
290+
204291
// Fused SiLU-and-multiply over the last dimension: the row's first half is
// passed through vllm::silu_kernel first (act_first=true) and combined with
// the second half via the compute<> helper. Dispatches on input dtype and
// launches the vectorized kernel path.
void silu_and_mul(
    torch::Tensor& out,   // [..., d]
    torch::Tensor& input) // [..., 2 * d]
{
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    LAUNCH_ACTIVATION_GATE_KERNEL_VEC(vllm::silu_kernel, true);
  });
}
212299

@@ -215,7 +302,7 @@ void mul_and_silu(
// Mirror of silu_and_mul with the operand order flipped (act_first=false):
// the activation is applied to the second operand of compute<>. Same dtype
// dispatch and vectorized launch path.
void mul_and_silu(
    torch::Tensor& out,   // [..., d]
    torch::Tensor& input) // [..., 2 * d]
{
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "mul_and_silu", [&] {
    LAUNCH_ACTIVATION_GATE_KERNEL_VEC(vllm::silu_kernel, false);
  });
}
221308

@@ -224,7 +311,7 @@ void gelu_and_mul(
// Fused GELU-and-multiply: same structure as silu_and_mul but with
// vllm::gelu_kernel as the activation (act_first=true). Uses the
// vectorized launch path.
void gelu_and_mul(
    torch::Tensor& out,   // [..., d]
    torch::Tensor& input) // [..., 2 * d]
{
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
    LAUNCH_ACTIVATION_GATE_KERNEL_VEC(vllm::gelu_kernel, true);
  });
}
230317

@@ -233,7 +320,7 @@ void gelu_tanh_and_mul(
// Fused tanh-approximation GELU-and-multiply: identical dispatch/launch
// shape to gelu_and_mul, substituting vllm::gelu_tanh_kernel.
void gelu_tanh_and_mul(
    torch::Tensor& out,   // [..., d]
    torch::Tensor& input) // [..., 2 * d]
{
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
    LAUNCH_ACTIVATION_GATE_KERNEL_VEC(vllm::gelu_tanh_kernel, true);
  });
}
239326

0 commit comments

Comments
 (0)