Skip to content

Commit 21de283

Browse files
authored
use functor for rms_norm kernels (#18)
* use functor instead of lambda for rms norm kernels

  Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

* update

  Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>

---------

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
1 parent 56106ab commit 21de283

File tree

1 file changed

+119
-78
lines changed

1 file changed

+119
-78
lines changed

csrc/layernorm.cpp

Lines changed: 119 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,37 +7,60 @@
77
namespace vllm {
88

99
template <typename scalar_t>
10-
void rms_norm_kernel(scalar_t* __restrict__ out, // [..., hidden_size]
11-
const scalar_t* __restrict__ input, // [..., hidden_size]
12-
const int64_t input_stride,
13-
const scalar_t* __restrict__ weight, // [hidden_size]
14-
const float epsilon, const int num_tokens,
15-
const int hidden_size, const sycl::nd_item<3>& item_ct1,
16-
float* s_variance) {
17-
float variance = 0.0f;
18-
19-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
20-
idx += item_ct1.get_local_range(2)) {
21-
const float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
22-
variance += x * x;
10+
class rms_norm_kernel {
11+
public:
12+
rms_norm_kernel(scalar_t* out_, const scalar_t* input_,
13+
const int64_t input_stride_, const scalar_t* weight_,
14+
const float epsilon_, const int num_tokens_,
15+
const int hidden_size_,
16+
sycl::local_accessor<float, 1> s_variance_)
17+
: out(out_),
18+
input(input_),
19+
input_stride(input_stride_),
20+
weight(weight_),
21+
epsilon(epsilon_),
22+
num_tokens(num_tokens_),
23+
hidden_size(hidden_size_),
24+
s_variance(s_variance_) {}
25+
26+
void operator() [[intel::reqd_sub_group_size(32)]] (
27+
const sycl::nd_item<3>& item_ct1) const {
28+
float* s_variance_ptr = s_variance.get_pointer();
29+
float variance = 0.0f;
30+
31+
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
32+
idx += item_ct1.get_local_range(2)) {
33+
const float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
34+
variance += x * x;
35+
}
36+
37+
variance = sycl::reduce_over_group(
38+
sycl::ext::oneapi::this_work_item::get_work_group<3>(), variance,
39+
sycl::plus<>());
40+
if (item_ct1.get_local_id(2) == 0) {
41+
*s_variance_ptr = sycl::rsqrt(variance / hidden_size + epsilon);
42+
}
43+
44+
item_ct1.barrier(sycl::access::fence_space::local_space);
45+
46+
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
47+
idx += item_ct1.get_local_range(2)) {
48+
float x = (float)input[item_ct1.get_group(2) * hidden_size + idx];
49+
out[item_ct1.get_group(2) * input_stride + idx] =
50+
((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
51+
}
2352
}
2453

25-
variance = sycl::reduce_over_group(
26-
sycl::ext::oneapi::this_work_item::get_work_group<3>(), variance,
27-
sycl::plus<>());
28-
if (item_ct1.get_local_id(2) == 0) {
29-
*s_variance = sycl::rsqrt(variance / hidden_size + epsilon);
30-
}
31-
32-
item_ct1.barrier(sycl::access::fence_space::local_space);
33-
34-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
35-
idx += item_ct1.get_local_range(2)) {
36-
float x = (float)input[item_ct1.get_group(2) * hidden_size + idx];
37-
out[item_ct1.get_group(2) * input_stride + idx] =
38-
((scalar_t)(x * (*s_variance))) * weight[idx];
39-
}
40-
}
54+
private:
55+
scalar_t* __restrict__ out; // [..., hidden_size]
56+
const scalar_t* __restrict__ input; // [..., hidden_size]
57+
const int64_t input_stride;
58+
const scalar_t* __restrict__ weight; // [hidden_size]
59+
const float epsilon;
60+
const int num_tokens;
61+
const int hidden_size;
62+
sycl::local_accessor<float, 1> s_variance;
63+
};
4164

4265
template <typename scalar_t>
4366
void call_rms_norm_kernel(torch::Tensor& out, torch::Tensor& input,
@@ -54,52 +77,74 @@ void call_rms_norm_kernel(torch::Tensor& out, torch::Tensor& input,
5477
auto& queue = vllm::xpu::vllmGetQueue();
5578
queue.submit([&](sycl::handler& cgh) {
5679
sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
57-
cgh.parallel_for(
58-
sycl::nd_range<3>(grid * block, block),
59-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
60-
rms_norm_kernel<sycl_t>((sycl_t*)out_ptr, (const sycl_t*)input_ptr,
61-
input_stride, (const sycl_t*)weight_ptr,
62-
epsilon, num_tokens, hidden_size, item_ct1,
63-
s_variance.get_pointer());
64-
});
80+
cgh.parallel_for(sycl::nd_range<3>(grid * block, block),
81+
vllm::rms_norm_kernel<sycl_t>(
82+
(sycl_t*)out_ptr, (const sycl_t*)input_ptr,
83+
input_stride, (const sycl_t*)weight_ptr, epsilon,
84+
num_tokens, hidden_size, s_variance));
6585
});
6686
}
6787

6888
template <typename scalar_t>
69-
void fused_add_rms_norm_kernel(
70-
scalar_t* __restrict__ input, // [..., hidden_size]
71-
scalar_t* __restrict__ residual, // [..., hidden_size]
72-
const int64_t input_stride,
73-
const scalar_t* __restrict__ weight, // [hidden_size]
74-
const float epsilon, const int num_tokens, const int hidden_size,
75-
const sycl::nd_item<3>& item_ct1, float* s_variance) {
76-
float variance = 0.0f;
77-
78-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
79-
idx += item_ct1.get_local_range(2)) {
80-
scalar_t z = (scalar_t)input[item_ct1.get_group(2) * input_stride + idx];
81-
z += residual[item_ct1.get_group(2) * hidden_size + idx];
82-
float x = (float)z;
83-
variance += x * x;
84-
residual[item_ct1.get_group(2) * hidden_size + idx] = z;
89+
class fused_add_rms_norm_kernel {
90+
public:
91+
fused_add_rms_norm_kernel(
92+
scalar_t* __restrict__ input_, // [..., hidden_size]
93+
scalar_t* __restrict__ residual_, // [..., hidden_size]
94+
const int64_t input_stride_,
95+
const scalar_t* __restrict__ weight_, // [hidden_size]
96+
const float epsilon_, const int num_tokens_, const int hidden_size_,
97+
sycl::local_accessor<float, 1> s_variance_)
98+
: input(input_),
99+
residual(residual_),
100+
input_stride(input_stride_),
101+
weight(weight_),
102+
epsilon(epsilon_),
103+
num_tokens(num_tokens_),
104+
hidden_size(hidden_size_),
105+
s_variance(s_variance_) {}
106+
107+
void operator() [[intel::reqd_sub_group_size(32)]] (
108+
const sycl::nd_item<3>& item_ct1) const {
109+
float* s_variance_ptr = s_variance.get_pointer();
110+
float variance = 0.0f;
111+
112+
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
113+
idx += item_ct1.get_local_range(2)) {
114+
scalar_t z = (scalar_t)input[item_ct1.get_group(2) * input_stride + idx];
115+
z += residual[item_ct1.get_group(2) * hidden_size + idx];
116+
float x = (float)z;
117+
variance += x * x;
118+
residual[item_ct1.get_group(2) * hidden_size + idx] = z;
119+
}
120+
121+
variance = sycl::reduce_over_group(
122+
sycl::ext::oneapi::this_work_item::get_work_group<3>(), variance,
123+
sycl::plus<>());
124+
if (item_ct1.get_local_id(2) == 0) {
125+
*s_variance_ptr = sycl::rsqrt(variance / hidden_size + epsilon);
126+
}
127+
128+
item_ct1.barrier(sycl::access::fence_space::local_space);
129+
130+
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
131+
idx += item_ct1.get_local_range(2)) {
132+
float x = (float)residual[item_ct1.get_group(2) * hidden_size + idx];
133+
input[item_ct1.get_group(2) * input_stride + idx] =
134+
((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
135+
}
85136
}
86137

87-
variance = sycl::reduce_over_group(
88-
sycl::ext::oneapi::this_work_item::get_work_group<3>(), variance,
89-
sycl::plus<>());
90-
if (item_ct1.get_local_id(2) == 0) {
91-
*s_variance = sycl::rsqrt(variance / hidden_size + epsilon);
92-
}
93-
94-
item_ct1.barrier(sycl::access::fence_space::local_space);
95-
96-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
97-
idx += item_ct1.get_local_range(2)) {
98-
float x = (float)residual[item_ct1.get_group(2) * hidden_size + idx];
99-
input[item_ct1.get_group(2) * input_stride + idx] =
100-
((scalar_t)(x * (*s_variance))) * weight[idx];
101-
}
102-
}
138+
private:
139+
scalar_t* __restrict__ input; // [..., hidden_size]
140+
scalar_t* __restrict__ residual; // [..., hidden_size]
141+
const int64_t input_stride;
142+
const scalar_t* __restrict__ weight; // [hidden_size]
143+
const float epsilon;
144+
const int num_tokens;
145+
const int hidden_size;
146+
sycl::local_accessor<float, 1> s_variance; // local memory for variance
147+
};
103148

104149
template <typename scalar_t>
105150
void call_fused_add_rms_norm_kernel(torch::Tensor& input,
@@ -116,16 +161,12 @@ void call_fused_add_rms_norm_kernel(torch::Tensor& input,
116161
sycl::range<3> block(1, 1, std::min(hidden_size, 1024));
117162
auto& queue = vllm::xpu::vllmGetQueue();
118163
queue.submit([&](sycl::handler& cgh) {
119-
sycl::local_accessor<float, 1> shared_vals(sycl::range<1>(32), cgh);
120164
sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
121-
cgh.parallel_for(
122-
sycl::nd_range<3>(grid * block, block),
123-
[=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
124-
fused_add_rms_norm_kernel<sycl_t>(
125-
(sycl_t*)input_ptr, (sycl_t*)residual_ptr, input_stride,
126-
(const sycl_t*)weight_ptr, epsilon, num_tokens, hidden_size,
127-
item_ct1, s_variance.get_pointer());
128-
});
165+
cgh.parallel_for(sycl::nd_range<3>(grid * block, block),
166+
fused_add_rms_norm_kernel<sycl_t>(
167+
(sycl_t*)input_ptr, (sycl_t*)residual_ptr,
168+
input_stride, (const sycl_t*)weight_ptr, epsilon,
169+
num_tokens, hidden_size, s_variance));
129170
});
130171
}
131172

0 commit comments

Comments
 (0)