77namespace vllm {
88
99template <typename scalar_t >
10- void rms_norm_kernel (scalar_t * __restrict__ out, // [..., hidden_size]
11- const scalar_t * __restrict__ input, // [..., hidden_size]
12- const int64_t input_stride,
13- const scalar_t * __restrict__ weight, // [hidden_size]
14- const float epsilon, const int num_tokens,
15- const int hidden_size, const sycl::nd_item<3 >& item_ct1,
16- float * s_variance) {
17- float variance = 0 .0f ;
18-
19- for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
20- idx += item_ct1.get_local_range (2 )) {
21- const float x = (float )input[item_ct1.get_group (2 ) * input_stride + idx];
22- variance += x * x;
10+ class rms_norm_kernel {
11+ public:
12+ rms_norm_kernel (scalar_t * out_, const scalar_t * input_,
13+ const int64_t input_stride_, const scalar_t * weight_,
14+ const float epsilon_, const int num_tokens_,
15+ const int hidden_size_,
16+ sycl::local_accessor<float , 1 > s_variance_)
17+ : out(out_),
18+ input (input_),
19+ input_stride(input_stride_),
20+ weight(weight_),
21+ epsilon(epsilon_),
22+ num_tokens(num_tokens_),
23+ hidden_size(hidden_size_),
24+ s_variance(s_variance_) {}
25+
26+ void operator () [[intel::reqd_sub_group_size(32 )]] (
27+ const sycl::nd_item<3 >& item_ct1) const {
28+ float * s_variance_ptr = s_variance.get_pointer ();
29+ float variance = 0 .0f ;
30+
31+ for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
32+ idx += item_ct1.get_local_range (2 )) {
33+ const float x = (float )input[item_ct1.get_group (2 ) * input_stride + idx];
34+ variance += x * x;
35+ }
36+
37+ variance = sycl::reduce_over_group (
38+ sycl::ext::oneapi::this_work_item::get_work_group<3 >(), variance,
39+ sycl::plus<>());
40+ if (item_ct1.get_local_id (2 ) == 0 ) {
41+ *s_variance_ptr = sycl::rsqrt (variance / hidden_size + epsilon);
42+ }
43+
44+ item_ct1.barrier (sycl::access::fence_space::local_space);
45+
46+ for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
47+ idx += item_ct1.get_local_range (2 )) {
48+ float x = (float )input[item_ct1.get_group (2 ) * hidden_size + idx];
49+ out[item_ct1.get_group (2 ) * input_stride + idx] =
50+ ((scalar_t )(x * (*s_variance_ptr))) * weight[idx];
51+ }
2352 }
2453
25- variance = sycl::reduce_over_group (
26- sycl::ext::oneapi::this_work_item::get_work_group<3 >(), variance,
27- sycl::plus<>());
28- if (item_ct1.get_local_id (2 ) == 0 ) {
29- *s_variance = sycl::rsqrt (variance / hidden_size + epsilon);
30- }
31-
32- item_ct1.barrier (sycl::access::fence_space::local_space);
33-
34- for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
35- idx += item_ct1.get_local_range (2 )) {
36- float x = (float )input[item_ct1.get_group (2 ) * hidden_size + idx];
37- out[item_ct1.get_group (2 ) * input_stride + idx] =
38- ((scalar_t )(x * (*s_variance))) * weight[idx];
39- }
40- }
54+ private:
55+ scalar_t * __restrict__ out; // [..., hidden_size]
56+ const scalar_t * __restrict__ input; // [..., hidden_size]
57+ const int64_t input_stride;
58+ const scalar_t * __restrict__ weight; // [hidden_size]
59+ const float epsilon;
60+ const int num_tokens;
61+ const int hidden_size;
62+ sycl::local_accessor<float , 1 > s_variance;
63+ };
4164
4265template <typename scalar_t >
4366void call_rms_norm_kernel (torch::Tensor& out, torch::Tensor& input,
@@ -54,52 +77,74 @@ void call_rms_norm_kernel(torch::Tensor& out, torch::Tensor& input,
5477 auto & queue = vllm::xpu::vllmGetQueue ();
5578 queue.submit ([&](sycl::handler& cgh) {
5679 sycl::local_accessor<float , 1 > s_variance (sycl::range<1 >(1 ), cgh);
57- cgh.parallel_for (
58- sycl::nd_range<3 >(grid * block, block),
59- [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (32 )]] {
60- rms_norm_kernel<sycl_t >((sycl_t *)out_ptr, (const sycl_t *)input_ptr,
61- input_stride, (const sycl_t *)weight_ptr,
62- epsilon, num_tokens, hidden_size, item_ct1,
63- s_variance.get_pointer ());
64- });
80+ cgh.parallel_for (sycl::nd_range<3 >(grid * block, block),
81+ vllm::rms_norm_kernel<sycl_t >(
82+ (sycl_t *)out_ptr, (const sycl_t *)input_ptr,
83+ input_stride, (const sycl_t *)weight_ptr, epsilon,
84+ num_tokens, hidden_size, s_variance));
6585 });
6686}
6787
6888template <typename scalar_t >
69- void fused_add_rms_norm_kernel (
70- scalar_t * __restrict__ input, // [..., hidden_size]
71- scalar_t * __restrict__ residual, // [..., hidden_size]
72- const int64_t input_stride,
73- const scalar_t * __restrict__ weight, // [hidden_size]
74- const float epsilon, const int num_tokens, const int hidden_size,
75- const sycl::nd_item<3 >& item_ct1, float * s_variance) {
76- float variance = 0 .0f ;
77-
78- for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
79- idx += item_ct1.get_local_range (2 )) {
80- scalar_t z = (scalar_t )input[item_ct1.get_group (2 ) * input_stride + idx];
81- z += residual[item_ct1.get_group (2 ) * hidden_size + idx];
82- float x = (float )z;
83- variance += x * x;
84- residual[item_ct1.get_group (2 ) * hidden_size + idx] = z;
89+ class fused_add_rms_norm_kernel {
90+ public:
91+ fused_add_rms_norm_kernel (
92+ scalar_t * __restrict__ input_, // [..., hidden_size]
93+ scalar_t * __restrict__ residual_, // [..., hidden_size]
94+ const int64_t input_stride_,
95+ const scalar_t * __restrict__ weight_, // [hidden_size]
96+ const float epsilon_, const int num_tokens_, const int hidden_size_,
97+ sycl::local_accessor<float , 1 > s_variance_)
98+ : input(input_),
99+ residual (residual_),
100+ input_stride(input_stride_),
101+ weight(weight_),
102+ epsilon(epsilon_),
103+ num_tokens(num_tokens_),
104+ hidden_size(hidden_size_),
105+ s_variance(s_variance_) {}
106+
107+ void operator () [[intel::reqd_sub_group_size(32 )]] (
108+ const sycl::nd_item<3 >& item_ct1) const {
109+ float * s_variance_ptr = s_variance.get_pointer ();
110+ float variance = 0 .0f ;
111+
112+ for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
113+ idx += item_ct1.get_local_range (2 )) {
114+ scalar_t z = (scalar_t )input[item_ct1.get_group (2 ) * input_stride + idx];
115+ z += residual[item_ct1.get_group (2 ) * hidden_size + idx];
116+ float x = (float )z;
117+ variance += x * x;
118+ residual[item_ct1.get_group (2 ) * hidden_size + idx] = z;
119+ }
120+
121+ variance = sycl::reduce_over_group (
122+ sycl::ext::oneapi::this_work_item::get_work_group<3 >(), variance,
123+ sycl::plus<>());
124+ if (item_ct1.get_local_id (2 ) == 0 ) {
125+ *s_variance_ptr = sycl::rsqrt (variance / hidden_size + epsilon);
126+ }
127+
128+ item_ct1.barrier (sycl::access::fence_space::local_space);
129+
130+ for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
131+ idx += item_ct1.get_local_range (2 )) {
132+ float x = (float )residual[item_ct1.get_group (2 ) * hidden_size + idx];
133+ input[item_ct1.get_group (2 ) * input_stride + idx] =
134+ ((scalar_t )(x * (*s_variance_ptr))) * weight[idx];
135+ }
85136 }
86137
87- variance = sycl::reduce_over_group (
88- sycl::ext::oneapi::this_work_item::get_work_group<3 >(), variance,
89- sycl::plus<>());
90- if (item_ct1.get_local_id (2 ) == 0 ) {
91- *s_variance = sycl::rsqrt (variance / hidden_size + epsilon);
92- }
93-
94- item_ct1.barrier (sycl::access::fence_space::local_space);
95-
96- for (int idx = item_ct1.get_local_id (2 ); idx < hidden_size;
97- idx += item_ct1.get_local_range (2 )) {
98- float x = (float )residual[item_ct1.get_group (2 ) * hidden_size + idx];
99- input[item_ct1.get_group (2 ) * input_stride + idx] =
100- ((scalar_t )(x * (*s_variance))) * weight[idx];
101- }
102- }
138+ private:
139+ scalar_t * __restrict__ input; // [..., hidden_size]
140+ scalar_t * __restrict__ residual; // [..., hidden_size]
141+ const int64_t input_stride;
142+ const scalar_t * __restrict__ weight; // [hidden_size]
143+ const float epsilon;
144+ const int num_tokens;
145+ const int hidden_size;
146+ sycl::local_accessor<float , 1 > s_variance; // local memory for variance
147+ };
103148
104149template <typename scalar_t >
105150void call_fused_add_rms_norm_kernel (torch::Tensor& input,
@@ -116,16 +161,12 @@ void call_fused_add_rms_norm_kernel(torch::Tensor& input,
116161 sycl::range<3 > block (1 , 1 , std::min (hidden_size, 1024 ));
117162 auto & queue = vllm::xpu::vllmGetQueue ();
118163 queue.submit ([&](sycl::handler& cgh) {
119- sycl::local_accessor<float , 1 > shared_vals (sycl::range<1 >(32 ), cgh);
120164 sycl::local_accessor<float , 1 > s_variance (sycl::range<1 >(1 ), cgh);
121- cgh.parallel_for (
122- sycl::nd_range<3 >(grid * block, block),
123- [=](sycl::nd_item<3 > item_ct1) [[intel::reqd_sub_group_size (32 )]] {
124- fused_add_rms_norm_kernel<sycl_t >(
125- (sycl_t *)input_ptr, (sycl_t *)residual_ptr, input_stride,
126- (const sycl_t *)weight_ptr, epsilon, num_tokens, hidden_size,
127- item_ct1, s_variance.get_pointer ());
128- });
165+ cgh.parallel_for (sycl::nd_range<3 >(grid * block, block),
166+ fused_add_rms_norm_kernel<sycl_t >(
167+ (sycl_t *)input_ptr, (sycl_t *)residual_ptr,
168+ input_stride, (const sycl_t *)weight_ptr, epsilon,
169+ num_tokens, hidden_size, s_variance));
129170 });
130171}
131172
0 commit comments