 namespace vllm {
 
 template <typename scalar_t>
+struct alignas(8) vec4_t {
+  scalar_t val[4];
+};
+
+// The vector width is fixed at 4 to avoid excessive branching in the kernel,
+// which could degrade performance.
+template <typename scalar_t, int NUM_DIMS, int VEC_SIZE = 4>
 class rms_norm_kernel {
  public:
   rms_norm_kernel(
       scalar_t* out_,
       const scalar_t* input_,
-      const int64_t input_stride_,
+      const int64_t input_stride_d2_,  // input.stride(-2)
+      const int64_t input_stride_d3_,  // input.stride(-3)
+      const int64_t input_stride_d4_,  // input.stride(-4)
+      const int64_t input_shape_d2_,   // input.size(-2)
+      const int64_t input_shape_d3_,   // input.size(-3)
       const scalar_t* weight_,
       const float epsilon_,
       const int num_tokens_,
       const int hidden_size_,
       sycl::local_accessor<float, 1> s_variance_)
       : out(out_),
         input(input_),
-        input_stride(input_stride_),
+        input_stride_d2(input_stride_d2_),
+        input_stride_d3(input_stride_d3_),
+        input_stride_d4(input_stride_d4_),
+        input_shape_d2(input_shape_d2_),
+        input_shape_d3(input_shape_d3_),
         weight(weight_),
         epsilon(epsilon_),
         num_tokens(num_tokens_),
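Note: passing per-dimension strides and sizes, rather than a single row stride, is what lets the kernel walk non-contiguous views directly. A minimal sketch of the addressing this enables, assuming PyTorch-style strides for a 3D [batch_size, num_heads, head_size] tensor (the helper name is ours, not from this patch):

    // Sketch: map a flattened row index back to a base pointer, mirroring
    // the kernel's NUM_DIMS == 3 branch below.
    template <typename scalar_t>
    const scalar_t* row_base(const scalar_t* base, int64_t row,
                             int64_t num_heads, int64_t stride_d3,
                             int64_t stride_d2) {
      int64_t batch_idx = row / num_heads;  // which batch entry
      int64_t head_idx = row % num_heads;   // which head inside that entry
      return base + batch_idx * stride_d3 + head_idx * stride_d2;
    }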
@@ -33,10 +48,80 @@ class rms_norm_kernel {
         s_variance.template get_multi_ptr<sycl::access::decorated::no>().get();
     float variance = 0.0f;
 
-    for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
-         idx += item_ct1.get_local_range(2)) {
-      const float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
+    const scalar_t* input_row;
+    if constexpr (NUM_DIMS == 2) {
+      // 2D: the common layernorm case [batch_size, hidden_size]
+      input_row = input + item_ct1.get_group(2) * input_stride_d2;
+    } else if constexpr (NUM_DIMS == 3) {
+      // 3D: q/k norm [batch_size, num_heads, head_size]
+      int batch_idx = item_ct1.get_group(2) / input_shape_d2;
+      int head_idx = item_ct1.get_group(2) % input_shape_d2;
+      input_row =
+          input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
+    } else if constexpr (NUM_DIMS == 4) {
+      // 4D: transformers model_impl qk norm [batch, seq, head, head_dim]
+      int batch_idx = item_ct1.get_group(2) / (input_shape_d3 * input_shape_d2);
+      int remaining = item_ct1.get_group(2) % (input_shape_d3 * input_shape_d2);
+      int seq_idx = remaining / input_shape_d2;
+      int head_idx = remaining % input_shape_d2;
+      input_row = input + batch_idx * input_stride_d4 +
+                  seq_idx * input_stride_d3 + head_idx * input_stride_d2;
+    }
+
+    auto vec_op = [&variance](
+                      const vec4_t<scalar_t>& vec, int vec_size = VEC_SIZE) {
+      for (int i = 0; i < vec_size; ++i) {
+        float x = static_cast<float>(vec.val[i]);
+        variance += x * x;
+      }
+    };
+    auto scalar_op = [&variance](const scalar_t& val) {
+      float x = static_cast<float>(val);
       variance += x * x;
+    };
+
+    constexpr int WIDTH = VEC_SIZE * sizeof(scalar_t);
+    uintptr_t addr_in = reinterpret_cast<uintptr_t>(input_row);
+
+    // fast path when the whole region is already aligned
+    bool can_vec =
+        ((addr_in & (WIDTH - 1)) == 0) && ((hidden_size & (VEC_SIZE - 1)) == 0);
+    if (can_vec) {
+      int64_t const num_vec_elems = hidden_size / VEC_SIZE;
+      auto const* vec_in = reinterpret_cast<const vec4_t<scalar_t>*>(input_row);
+      for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
+           i += item_ct1.get_local_range(2)) {
+        vec4_t<scalar_t> tmp = vec_in[i];
+        vec_op(tmp);
+      }
+    } else {
+      int misalignment_offset = addr_in & (WIDTH - 1);
+      int alignment_bytes = WIDTH - misalignment_offset;
+      int prefix_elems = alignment_bytes & (WIDTH - 1);
+      prefix_elems /= sizeof(scalar_t);
+      prefix_elems = prefix_elems < hidden_size ? prefix_elems : hidden_size;
+
+      // 1. handle the possibly unaligned prefix with scalar access.
+      for (int i = item_ct1.get_local_id(2); i < prefix_elems;
+           i += item_ct1.get_local_range(2)) {
+        scalar_op(input_row[i]);
+      }
+
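+      // 2. process the aligned middle region with vectorized access.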
+      int64_t const num_vec_elems = (hidden_size - prefix_elems) / VEC_SIZE;
+      auto const* vec_in =
+          reinterpret_cast<const vec4_t<scalar_t>*>(input_row + prefix_elems);
+      for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
+           i += item_ct1.get_local_range(2)) {
+        vec4_t<scalar_t> tmp = vec_in[i];
+        vec_op(tmp);
+      }
+
+      // 3. handle remaining tail elements.
+      for (int i = item_ct1.get_local_id(2) + num_vec_elems * VEC_SIZE;
+           i < hidden_size - prefix_elems;
+           i += item_ct1.get_local_range(2)) {
+        scalar_op((input_row + prefix_elems)[i]);
+      }
     }
 
     variance = sycl::reduce_over_group(
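The loop above accumulates the per-row sum of squares, vectorizing wherever alignment allows. The same prefix/body/tail decomposition, distilled into a single-threaded sketch (our own illustration, not code from this patch):

    #include <cstdint>

    // Sketch: sum of squares over n floats using the kernel's alignment
    // strategy: scalar prefix up to a 16-byte boundary, then a 4-wide
    // vector body, then a scalar tail.
    float sum_squares(const float* p, int n) {
      constexpr int VEC = 4;                      // elements per vector
      constexpr int WIDTH = VEC * sizeof(float);  // 16 bytes
      std::uintptr_t addr = reinterpret_cast<std::uintptr_t>(p);
      int prefix = static_cast<int>(
          ((WIDTH - (addr & (WIDTH - 1))) & (WIDTH - 1)) / sizeof(float));
      prefix = prefix < n ? prefix : n;
      float acc = 0.0f;
      for (int i = 0; i < prefix; ++i) acc += p[i] * p[i];  // unaligned prefix
      const int body = (n - prefix) / VEC;  // whole aligned vectors
      for (int v = 0; v < body; ++v) {
        for (int j = 0; j < VEC; ++j) {
          float x = p[prefix + v * VEC + j];
          acc += x * x;
        }
      }
      for (int i = prefix + body * VEC; i < n; ++i)  // remaining tail
        acc += p[i] * p[i];
      return acc;
    }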
@@ -49,18 +134,47 @@ class rms_norm_kernel {
 
     item_ct1.barrier(sycl::access::fence_space::local_space);
 
-    for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
-         idx += item_ct1.get_local_range(2)) {
-      float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
-      out[item_ct1.get_group(2) * hidden_size + idx] =
-          ((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
+    scalar_t* out_row = out + item_ct1.get_group(2) * hidden_size;
+    uintptr_t addr_weight = reinterpret_cast<uintptr_t>(weight);
+    uintptr_t addr_out = reinterpret_cast<uintptr_t>(out_row);
+    bool can_vec_out = ((addr_in & (WIDTH - 1)) == 0) &&
+                       ((addr_weight & (WIDTH - 1)) == 0) &&
+                       ((addr_out & (WIDTH - 1)) == 0) &&
+                       ((hidden_size & (VEC_SIZE - 1)) == 0);
+    if (can_vec_out) {
+      auto* v_in = reinterpret_cast<const vec4_t<scalar_t>*>(input_row);
+      auto* v_w = reinterpret_cast<const vec4_t<scalar_t>*>(weight);
+      auto* v_out = reinterpret_cast<vec4_t<scalar_t>*>(out_row);
+      int64_t const out_num_vec_elems = hidden_size / VEC_SIZE;
+      float s_variance_val = *s_variance_ptr;
+      for (int idx = item_ct1.get_local_id(2); idx < out_num_vec_elems;
+           idx += item_ct1.get_local_range(2)) {
+        vec4_t<scalar_t> dst;
+        vec4_t<scalar_t> src1 = v_in[idx];
+        vec4_t<scalar_t> src2 = v_w[idx];
+        for (int j = 0; j < VEC_SIZE; j++) {
+          float x = static_cast<float>(src1.val[j]);
+          dst.val[j] = ((scalar_t)(x * s_variance_val)) * src2.val[j];
+        }
+        v_out[idx] = dst;
+      }
+    } else {
+      for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
+           idx += item_ct1.get_local_range(2)) {
+        float x = (float)input_row[idx];
+        out_row[idx] = ((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
+      }
     }
   }
 
  private:
   scalar_t* __restrict__ out;          // [..., hidden_size]
   const scalar_t* __restrict__ input;  // [..., hidden_size]
-  const int64_t input_stride;
+  const int64_t input_stride_d2;
+  const int64_t input_stride_d3;
+  const int64_t input_stride_d4;
+  const int64_t input_shape_d2;
+  const int64_t input_shape_d3;
   const scalar_t* __restrict__ weight;  // [hidden_size]
   const float epsilon;
   const int num_tokens;
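Both write paths apply the standard RMSNorm transform y[i] = x[i] / sqrt(mean(x^2) + eps) * w[i], with *s_variance_ptr holding the precomputed reciprocal RMS. A minimal scalar reference (our own sketch, not code from this patch):

    #include <cmath>

    // Sketch: RMSNorm over one row; inv_rms plays the role of *s_variance_ptr.
    void rms_norm_row(float* y, const float* x, const float* w, int n,
                      float eps) {
      float ss = 0.0f;
      for (int i = 0; i < n; ++i) ss += x[i] * x[i];
      float inv_rms = 1.0f / std::sqrt(ss / n + eps);
      for (int i = 0; i < n; ++i) y[i] = (x[i] * inv_rms) * w[i];
    }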
@@ -77,26 +191,39 @@ void call_rms_norm_kernel(
   using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
-  int64_t input_stride = input.stride(-2);
+  int num_dims = input.dim();
+  int64_t input_stride_d2 = input.stride(-2);
+  int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
+  int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
+  int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
+  int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
+
   auto out_ptr = out.data_ptr<scalar_t>();
   auto input_ptr = input.data_ptr<scalar_t>();
   auto weight_ptr = weight.data_ptr<scalar_t>();
   sycl::range<3> grid(1, 1, num_tokens);
   sycl::range<3> block(1, 1, std::min(hidden_size, 1024));
   auto& queue = vllm::xpu::vllmGetQueue();
-  queue.submit([&](sycl::handler& cgh) {
-    sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
-    cgh.parallel_for(
-        sycl::nd_range<3>(grid * block, block),
-        vllm::rms_norm_kernel<sycl_t>(
-            (sycl_t*)out_ptr,
-            (const sycl_t*)input_ptr,
-            input_stride,
-            (const sycl_t*)weight_ptr,
-            epsilon,
-            num_tokens,
-            hidden_size,
-            s_variance));
+
+  VLLM_DISPATCH_RANK234(num_dims, [&]() {
+    queue.submit([&](sycl::handler& cgh) {
+      sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
+      cgh.parallel_for(
+          sycl::nd_range<3>(grid * block, block),
+          vllm::rms_norm_kernel<sycl_t, tensor_rank>(
+              (sycl_t*)out_ptr,
+              (const sycl_t*)input_ptr,
+              input_stride_d2,
+              input_stride_d3,
+              input_stride_d4,
+              input_shape_d2,
+              input_shape_d3,
+              (const sycl_t*)weight_ptr,
+              epsilon,
+              num_tokens,
+              hidden_size,
+              s_variance));
+    });
   });
 }
 
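VLLM_DISPATCH_RANK234 is not shown in this diff; it evidently binds the runtime num_dims to the compile-time tensor_rank used to instantiate the kernel. A hypothetical definition of such a macro, in the style of PyTorch's AT_DISPATCH family (an assumption, not the repository's actual code):

    // Hypothetical sketch: expand the caller's lambda inside a switch so
    // each supported rank gets its own constexpr tensor_rank and therefore
    // its own kernel instantiation.
    #define VLLM_DISPATCH_CASE_RANK(RANK, ...) \
      case RANK: {                             \
        constexpr int tensor_rank = RANK;      \
        __VA_ARGS__();                         \
        break;                                 \
      }

    #define VLLM_DISPATCH_RANK234(NUM_DIMS, ...)         \
      switch (NUM_DIMS) {                                 \
        VLLM_DISPATCH_CASE_RANK(2, __VA_ARGS__)           \
        VLLM_DISPATCH_CASE_RANK(3, __VA_ARGS__)           \
        VLLM_DISPATCH_CASE_RANK(4, __VA_ARGS__)           \
        default:                                          \
          TORCH_CHECK(false, "unsupported tensor rank");  \
      }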
@@ -205,7 +332,10 @@ void rms_norm(
     torch::Tensor& weight,
     double epsilon) {
   TORCH_CHECK(out.is_contiguous());
-  input = input.contiguous();
+  if (input.stride(-1) != 1) {
+    input = input.contiguous();
+  }
+  TORCH_CHECK(input.stride(-1) == 1);
   TORCH_CHECK(weight.is_contiguous());
   VLLM_DISPATCH_FLOATING_TYPES(
       input.scalar_type(), "call_rms_norm_kernel", [&] {
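With this change, an input is copied only when its innermost dimension is strided; views that are non-contiguous across outer dimensions, such as q/k slices, now reach the kernel without a copy. A hypothetical caller-side illustration (shapes invented for the example):

    // Hypothetical example: a narrowed q view is non-contiguous overall,
    // but its last dimension has stride 1, so no .contiguous() copy is made.
    auto qkv = torch::randn({8, 3 * 1024});
    auto q = qkv.narrow(/*dim=*/1, /*start=*/0, /*length=*/1024)
                 .view({8, 8, 128});  // strides {3072, 128, 1}
    // rms_norm(out, q, weight, 1e-6);  // dispatches the rank-3 kernel, no copy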