@@ -59,18 +59,21 @@ inline T gelu_tanh_kernel(const T& x) {
5959 return (T)(0 .5f * f * (1 .0f + sycl::tanh (inner)));
6060}
6161
62- template <typename scalar_t , scalar_t (*ACT_FN)(const scalar_t &),
63- bool act_first>
62+ template <
63+ typename scalar_t ,
64+ scalar_t (*ACT_FN)(const scalar_t &),
65+ bool act_first>
6466inline scalar_t compute (const scalar_t & x, const scalar_t & y) {
6567 return act_first ? ACT_FN (x) * y : x * ACT_FN (y);
6668}
6769
6870template <typename scalar_t , scalar_t (*ACT_FN)(const scalar_t &)>
6971class act_kernel {
7072 public:
71- act_kernel (scalar_t * __restrict__ out, // [..., d]
72- const scalar_t * __restrict__ input, // [..., d]
73- const int d)
73+ act_kernel (
74+ scalar_t * __restrict__ out, // [..., d]
75+ const scalar_t * __restrict__ input, // [..., d]
76+ const int d)
7477 : out_(out), input_(input), d_(d) {}
7578
7679 void operator () [[sycl::reqd_sub_group_size(32 )]] (
@@ -89,13 +92,16 @@ class act_kernel {
8992 const int d_;
9093};
9194
92- template <typename scalar_t , scalar_t (*ACT_FN)(const scalar_t &),
93- bool act_first>
95+ template <
96+ typename scalar_t ,
97+ scalar_t (*ACT_FN)(const scalar_t &),
98+ bool act_first>
9499class act_and_mul_kernel {
95100 public:
96- act_and_mul_kernel (scalar_t * __restrict__ out, // [..., d]
97- const scalar_t * __restrict__ input, // [..., 2, d]
98- const int d)
101+ act_and_mul_kernel (
102+ scalar_t * __restrict__ out, // [..., d]
103+ const scalar_t * __restrict__ input, // [..., 2, d]
104+ const int d)
99105 : out_(out), input_(input), d_(d) {}
100106
101107 void operator () [[sycl::reqd_sub_group_size(32 )]] (
@@ -135,14 +141,18 @@ swigluoai_and_mul(const T& gate, const T& up, float alpha, float limit) {
135141 return (T)((clamped_up + 1 .0f ) * glu);
136142}
137143
138- template <typename scalar_t ,
139- scalar_t (*ACT_FN)(const scalar_t &, const scalar_t &, const float ,
140- const float )>
144+ template <
145+ typename scalar_t ,
146+ scalar_t (*ACT_FN)(
147+ const scalar_t &, const scalar_t &, const float , const float )>
141148class swigluoai_and_mul_kernel {
142149 public:
143- swigluoai_and_mul_kernel (scalar_t * __restrict__ out, // [..., d]
144- const scalar_t * __restrict__ input, // [..., 2, d]
145- const int d, const float alpha, const float limit)
150+ swigluoai_and_mul_kernel (
151+ scalar_t * __restrict__ out, // [..., d]
152+ const scalar_t * __restrict__ input, // [..., 2, d]
153+ const int d,
154+ const float alpha,
155+ const float limit)
146156 : out(out), input(input), d(d), alpha(alpha), limit(limit) {}
147157
148158 void operator ()(sycl::nd_item<1 > item) const {
@@ -171,121 +181,135 @@ class swigluoai_and_mul_kernel {
171181// Launch activation and gating kernel.
172182// Use ACT_FIRST (bool) indicating whether to apply the activation function
173183// first.
174- #define LAUNCH_ACTIVATION_GATE_KERNEL (KERNEL, ACT_FIRST ) \
175- using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t >::Type; \
176- int d = input.size(-1 ) / 2 ; \
177- int64_t num_tokens = input.numel() / input.size(-1 ); \
178- sycl::range<3 > grid (1 , 1 , num_tokens); \
179- sycl::range<3 > block (1 , 1 , std::min(d, 1024 )); \
180- if (num_tokens == 0 ) { \
181- return ; \
182- } \
183- auto out_ptr = out.data_ptr<scalar_t >(); \
184- auto input_ptr = input.data_ptr<scalar_t >(); \
185- at::DeviceGuard device_guard (input.device()); \
186- auto & queue = vllm::xpu::vllmGetQueue(); \
187- queue.submit([&](sycl::handler& cgh) { \
188- cgh.parallel_for (sycl::nd_range<3 >(grid * block, block), \
189- vllm::act_and_mul_kernel<sycl_t , KERNEL, ACT_FIRST>( \
190- (sycl_t *)out_ptr, (sycl_t *)input_ptr, d)); \
184+ #define LAUNCH_ACTIVATION_GATE_KERNEL (KERNEL, ACT_FIRST ) \
185+ using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t >::Type; \
186+ int d = input.size(-1 ) / 2 ; \
187+ int64_t num_tokens = input.numel() / input.size(-1 ); \
188+ sycl::range<3 > grid (1 , 1 , num_tokens); \
189+ sycl::range<3 > block (1 , 1 , std::min(d, 1024 )); \
190+ if (num_tokens == 0 ) { \
191+ return ; \
192+ } \
193+ auto out_ptr = out.data_ptr<scalar_t >(); \
194+ auto input_ptr = input.data_ptr<scalar_t >(); \
195+ at::DeviceGuard device_guard (input.device()); \
196+ auto & queue = vllm::xpu::vllmGetQueue(); \
197+ queue.submit([&](sycl::handler& cgh) { \
198+ cgh.parallel_for ( \
199+ sycl::nd_range<3 >(grid * block, block), \
200+ vllm::act_and_mul_kernel<sycl_t , KERNEL, ACT_FIRST>( \
201+ (sycl_t *)out_ptr, (sycl_t *)input_ptr, d)); \
191202 });
192203
193- void silu_and_mul (torch::Tensor& out, // [..., d]
194- torch::Tensor& input) // [..., 2 * d]
204+ void silu_and_mul (
205+ torch::Tensor& out, // [..., d]
206+ torch::Tensor& input) // [..., 2 * d]
195207{
196208 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " silu_and_mul" , [&] {
197209 LAUNCH_ACTIVATION_GATE_KERNEL (vllm::silu_kernel, true );
198210 });
199211}
200212
201- void mul_and_silu (torch::Tensor& out, // [..., d]
202- torch::Tensor& input) // [..., 2 * d]
213+ void mul_and_silu (
214+ torch::Tensor& out, // [..., d]
215+ torch::Tensor& input) // [..., 2 * d]
203216{
204217 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " mul_and_silu" , [&] {
205218 LAUNCH_ACTIVATION_GATE_KERNEL (vllm::silu_kernel, false );
206219 });
207220}
208221
209- void gelu_and_mul (torch::Tensor& out, // [..., d]
210- torch::Tensor& input) // [..., 2 * d]
222+ void gelu_and_mul (
223+ torch::Tensor& out, // [..., d]
224+ torch::Tensor& input) // [..., 2 * d]
211225{
212226 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " gelu_and_mul" , [&] {
213227 LAUNCH_ACTIVATION_GATE_KERNEL (vllm::gelu_kernel, true );
214228 });
215229}
216230
217- void gelu_tanh_and_mul (torch::Tensor& out, // [..., d]
218- torch::Tensor& input) // [..., 2 * d]
231+ void gelu_tanh_and_mul (
232+ torch::Tensor& out, // [..., d]
233+ torch::Tensor& input) // [..., 2 * d]
219234{
220235 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " gelu_tanh_and_mul" , [&] {
221236 LAUNCH_ACTIVATION_GATE_KERNEL (vllm::gelu_tanh_kernel, true );
222237 });
223238}
224239
225240// Launch element-wise activation kernel.
226- #define LAUNCH_ACTIVATION_KERNEL (KERNEL ) \
227- using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t >::Type; \
228- int d = input.size(-1 ); \
229- int64_t num_tokens = input.numel() / input.size(-1 ); \
230- sycl::range<3 > grid (1 , 1 , num_tokens); \
231- sycl::range<3 > block (1 , 1 , std::min(d, 1024 )); \
232- if (num_tokens == 0 ) { \
233- return ; \
234- } \
235- auto out_ptr = out.data_ptr<scalar_t >(); \
236- auto input_ptr = input.data_ptr<scalar_t >(); \
237- at::DeviceGuard device_guard (input.device()); \
238- auto & queue = vllm::xpu::vllmGetQueue(); \
239- queue.submit([&](sycl::handler& cgh) { \
240- cgh.parallel_for (sycl::nd_range<3 >(grid * block, block), \
241- vllm::act_kernel<sycl_t , KERNEL>((sycl_t *)out_ptr, \
242- (sycl_t *)input_ptr, d)); \
241+ #define LAUNCH_ACTIVATION_KERNEL (KERNEL ) \
242+ using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t >::Type; \
243+ int d = input.size(-1 ); \
244+ int64_t num_tokens = input.numel() / input.size(-1 ); \
245+ sycl::range<3 > grid (1 , 1 , num_tokens); \
246+ sycl::range<3 > block (1 , 1 , std::min(d, 1024 )); \
247+ if (num_tokens == 0 ) { \
248+ return ; \
249+ } \
250+ auto out_ptr = out.data_ptr<scalar_t >(); \
251+ auto input_ptr = input.data_ptr<scalar_t >(); \
252+ at::DeviceGuard device_guard (input.device()); \
253+ auto & queue = vllm::xpu::vllmGetQueue(); \
254+ queue.submit([&](sycl::handler& cgh) { \
255+ cgh.parallel_for ( \
256+ sycl::nd_range<3 >(grid * block, block), \
257+ vllm::act_kernel<sycl_t , KERNEL>( \
258+ (sycl_t *)out_ptr, (sycl_t *)input_ptr, d)); \
243259 });
244260
245- #define LAUNCH_SWIGLUOAI_AND_MUL (KERNEL, ALPHA, LIMIT ) \
246- int d = input.size(-1 ) / 2 ; \
247- int64_t num_tokens = input.numel() / input.size(-1 ); \
248- sycl::range<1 > grid (num_tokens); \
249- sycl::range<1 > block (std::min(d, 1024 )); \
250- at::DeviceGuard device_guard (input.device()); \
251- auto & queue = vllm::xpu::vllmGetQueue(); \
252- VLLM_DISPATCH_FLOATING_TYPES ( \
253- input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
254- queue.submit ([&](sycl::handler& cgh) { \
255- cgh.parallel_for ( \
256- sycl::nd_range<1 >(grid * block, block), \
257- vllm::swigluoai_and_mul_kernel<scalar_t , KERNEL<scalar_t >>( \
258- out.data_ptr <scalar_t >(), input.data_ptr <scalar_t >(), d, \
259- ALPHA, LIMIT)); \
260- }); \
261+ #define LAUNCH_SWIGLUOAI_AND_MUL (KERNEL, ALPHA, LIMIT ) \
262+ int d = input.size(-1 ) / 2 ; \
263+ int64_t num_tokens = input.numel() / input.size(-1 ); \
264+ sycl::range<1 > grid (num_tokens); \
265+ sycl::range<1 > block (std::min(d, 1024 )); \
266+ at::DeviceGuard device_guard (input.device()); \
267+ auto & queue = vllm::xpu::vllmGetQueue(); \
268+ VLLM_DISPATCH_FLOATING_TYPES ( \
269+ input.scalar_type(), "clamp_swiglu_kernel_with_params", [&] { \
270+ queue.submit ([&](sycl::handler& cgh) { \
271+ cgh.parallel_for ( \
272+ sycl::nd_range<1 >(grid * block, block), \
273+ vllm::swigluoai_and_mul_kernel<scalar_t , KERNEL<scalar_t >>( \
274+ out.data_ptr <scalar_t >(), \
275+ input.data_ptr <scalar_t >(), \
276+ d, \
277+ ALPHA, \
278+ LIMIT)); \
279+ }); \
261280 });
262281
263- void gelu_new (torch::Tensor& out, // [..., d]
264- torch::Tensor& input) // [..., d]
282+ void gelu_new (
283+ torch::Tensor& out, // [..., d]
284+ torch::Tensor& input) // [..., d]
265285{
266286 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " gelu_new" , [&] {
267287 LAUNCH_ACTIVATION_KERNEL (vllm::gelu_new_kernel);
268288 });
269289}
270290
271- void gelu_fast (torch::Tensor& out, // [..., d]
272- torch::Tensor& input) // [..., d]
291+ void gelu_fast (
292+ torch::Tensor& out, // [..., d]
293+ torch::Tensor& input) // [..., d]
273294{
274295 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " gelu_fast" , [&] {
275296 LAUNCH_ACTIVATION_KERNEL (vllm::gelu_fast_kernel);
276297 });
277298}
278299
279- void gelu_quick (torch::Tensor& out, // [..., d]
280- torch::Tensor& input) // [..., d]
300+ void gelu_quick (
301+ torch::Tensor& out, // [..., d]
302+ torch::Tensor& input) // [..., d]
281303{
282304 VLLM_DISPATCH_FLOATING_TYPES (input.scalar_type (), " gelu_quick" , [&] {
283305 LAUNCH_ACTIVATION_KERNEL (vllm::gelu_quick_kernel);
284306 });
285307}
286308
287- void swigluoai_and_mul (torch::Tensor& out, // [..., d]
288- torch::Tensor& input, // [..., 2 * d]
289- double alpha, double limit) {
309+ void swigluoai_and_mul (
310+ torch::Tensor& out, // [..., d]
311+ torch::Tensor& input, // [..., 2 * d]
312+ double alpha,
313+ double limit) {
290314 LAUNCH_SWIGLUOAI_AND_MUL (vllm::swigluoai_and_mul, alpha, limit);
291315}
0 commit comments