Skip to content

Commit 699feaf

Browse files
authored
layernorm support uncontiguous (vllm-project#131)
* add ut for static quant fp8 Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * remove useless val Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * layernorm support contiguous and add vec Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * use fixed VEC_SIZE Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * fix typo Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * add ut for uncontiguous input of rms_norm and format Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> * fix corner case Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com> --------- Signed-off-by: Zhu, Zufang <zufang.zhu@intel.com>
1 parent a78e733 commit 699feaf

3 files changed

Lines changed: 221 additions & 29 deletions

File tree

csrc/dispatch_utils.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,24 @@
8787
constexpr bool const_expr = false; \
8888
__VA_ARGS__(); \
8989
}
90+
91+
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
92+
switch (NUM_DIMS) { \
93+
case 2: { \
94+
constexpr int tensor_rank = 2; \
95+
__VA_ARGS__(); \
96+
break; \
97+
} \
98+
case 3: { \
99+
constexpr int tensor_rank = 3; \
100+
__VA_ARGS__(); \
101+
break; \
102+
} \
103+
case 4: { \
104+
constexpr int tensor_rank = 4; \
105+
__VA_ARGS__(); \
106+
break; \
107+
} \
108+
default: \
109+
TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
110+
}

csrc/layernorm.cpp

Lines changed: 156 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,35 @@
77
namespace vllm {
88

99
template <typename scalar_t>
10+
struct alignas(8) vec4_t {
11+
scalar_t val[4];
12+
};
13+
14+
// The vector width is fixed at 4 to avoid excessive branching in the kernel,
15+
// which could degrade performance.
16+
template <typename scalar_t, int NUM_DIMS, int VEC_SIZE = 4>
1017
class rms_norm_kernel {
1118
public:
1219
rms_norm_kernel(
1320
scalar_t* out_,
1421
const scalar_t* input_,
15-
const int64_t input_stride_,
22+
const int64_t input_stride_d2_, // input.stride(-2)
23+
const int64_t input_stride_d3_, // input.stride(-3)
24+
const int64_t input_stride_d4_, // input.stride(-4)
25+
const int64_t input_shape_d2_, // input.size(-2)
26+
const int64_t input_shape_d3_, // input.size(-3)
1627
const scalar_t* weight_,
1728
const float epsilon_,
1829
const int num_tokens_,
1930
const int hidden_size_,
2031
sycl::local_accessor<float, 1> s_variance_)
2132
: out(out_),
2233
input(input_),
23-
input_stride(input_stride_),
34+
input_stride_d2(input_stride_d2_),
35+
input_stride_d3(input_stride_d3_),
36+
input_stride_d4(input_stride_d4_),
37+
input_shape_d2(input_shape_d2_),
38+
input_shape_d3(input_shape_d3_),
2439
weight(weight_),
2540
epsilon(epsilon_),
2641
num_tokens(num_tokens_),
@@ -33,10 +48,80 @@ class rms_norm_kernel {
3348
s_variance.template get_multi_ptr<sycl::access::decorated::no>().get();
3449
float variance = 0.0f;
3550

36-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
37-
idx += item_ct1.get_local_range(2)) {
38-
const float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
51+
const scalar_t* input_row;
52+
if constexpr (NUM_DIMS == 2) {
53+
// 2D for layernorm normal case [batch_size, hidden]
54+
input_row = input + item_ct1.get_group(2) * input_stride_d2;
55+
} else if constexpr (NUM_DIMS == 3) {
56+
// 3D for q/k norm [batch_size, num_heads, head_size]
57+
int batch_idx = item_ct1.get_group(2) / input_shape_d2;
58+
int head_idx = item_ct1.get_group(2) % input_shape_d2;
59+
input_row =
60+
input + batch_idx * input_stride_d3 + head_idx * input_stride_d2;
61+
} else if constexpr (NUM_DIMS == 4) {
62+
// 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
63+
int batch_idx = item_ct1.get_group(2) / (input_shape_d3 * input_shape_d2);
64+
int remaining = item_ct1.get_group(2) % (input_shape_d3 * input_shape_d2);
65+
int seq_idx = remaining / input_shape_d2;
66+
int head_idx = remaining % input_shape_d2;
67+
input_row = input + batch_idx * input_stride_d4 +
68+
seq_idx * input_stride_d3 + head_idx * input_stride_d2;
69+
}
70+
71+
auto vec_op = [&variance](
72+
const vec4_t<scalar_t>& vec, int vec_size = VEC_SIZE) {
73+
for (int i = 0; i < vec_size; ++i) {
74+
float x = static_cast<float>(vec.val[i]);
75+
variance += x * x;
76+
}
77+
};
78+
auto scalar_op = [&variance](const scalar_t& val) {
79+
float x = static_cast<float>(val);
3980
variance += x * x;
81+
};
82+
83+
constexpr int WIDTH = VEC_SIZE * sizeof(scalar_t);
84+
uintptr_t addr_in = reinterpret_cast<uintptr_t>(input_row);
85+
86+
// fast path when the whole region is already aligned
87+
bool can_vec =
88+
((addr_in & (WIDTH - 1)) == 0) && ((hidden_size & (VEC_SIZE - 1)) == 0);
89+
if (can_vec) {
90+
int64_t const num_vec_elems = hidden_size / VEC_SIZE;
91+
auto const* vec_in = reinterpret_cast<const vec4_t<scalar_t>*>(input_row);
92+
for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
93+
i += item_ct1.get_local_range(2)) {
94+
vec4_t<scalar_t> tmp = vec_in[i];
95+
vec_op(tmp);
96+
}
97+
} else {
98+
int misalignment_offset = addr_in & (WIDTH - 1);
99+
int alignment_bytes = WIDTH - misalignment_offset;
100+
int prefix_elems = alignment_bytes & (WIDTH - 1);
101+
prefix_elems /= sizeof(scalar_t);
102+
prefix_elems = prefix_elems < hidden_size ? prefix_elems : hidden_size;
103+
104+
// 1. handle the possibly unaligned prefix with scalar access.
105+
for (int i = item_ct1.get_local_id(2); i < prefix_elems;
106+
i += item_ct1.get_local_range(2)) {
107+
scalar_op(input_row[i]);
108+
}
109+
110+
int64_t const num_vec_elems = (hidden_size - prefix_elems) / VEC_SIZE;
111+
auto const* vec_in =
112+
reinterpret_cast<const vec4_t<scalar_t>*>(input_row + prefix_elems);
113+
for (int i = item_ct1.get_local_id(2); i < num_vec_elems;
114+
i += item_ct1.get_local_range(2)) {
115+
vec4_t<scalar_t> tmp = vec_in[i];
116+
vec_op(tmp);
117+
}
118+
119+
// 3. handle remaining tail elements.
120+
for (int i = item_ct1.get_local_id(2) + num_vec_elems * VEC_SIZE;
121+
i < hidden_size - prefix_elems;
122+
i += item_ct1.get_local_range(2)) {
123+
scalar_op((input_row + prefix_elems)[i]);
124+
}
40125
}
41126

42127
variance = sycl::reduce_over_group(
@@ -49,18 +134,47 @@ class rms_norm_kernel {
49134

50135
item_ct1.barrier(sycl::access::fence_space::local_space);
51136

52-
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
53-
idx += item_ct1.get_local_range(2)) {
54-
float x = (float)input[item_ct1.get_group(2) * input_stride + idx];
55-
out[item_ct1.get_group(2) * hidden_size + idx] =
56-
((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
137+
scalar_t* out_row = out + item_ct1.get_group(2) * hidden_size;
138+
uintptr_t addr_weight = reinterpret_cast<uintptr_t>(weight);
139+
uintptr_t addr_out = reinterpret_cast<uintptr_t>(out_row);
140+
bool can_vec_out = ((addr_in & (WIDTH - 1)) == 0) &&
141+
((addr_weight & (WIDTH - 1)) == 0) &&
142+
((addr_out & (WIDTH - 1)) == 0) &&
143+
((hidden_size & (VEC_SIZE - 1)) == 0);
144+
if (can_vec_out) {
145+
auto* v_in = reinterpret_cast<const vec4_t<scalar_t>*>(input_row);
146+
auto* v_w = reinterpret_cast<const vec4_t<scalar_t>*>(weight);
147+
auto* v_out = reinterpret_cast<vec4_t<scalar_t>*>(out_row);
148+
int64_t const out_num_vec_elems = hidden_size / VEC_SIZE;
149+
float s_variance_val = *s_variance_ptr;
150+
for (int idx = item_ct1.get_local_id(2); idx < out_num_vec_elems;
151+
idx += item_ct1.get_local_range(2)) {
152+
vec4_t<scalar_t> dst;
153+
vec4_t<scalar_t> src1 = v_in[idx];
154+
vec4_t<scalar_t> src2 = v_w[idx];
155+
for (int j = 0; j < VEC_SIZE; j++) {
156+
float x = static_cast<float>(src1.val[j]);
157+
dst.val[j] = ((scalar_t)(x * s_variance_val)) * src2.val[j];
158+
}
159+
v_out[idx] = dst;
160+
}
161+
} else {
162+
for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
163+
idx += item_ct1.get_local_range(2)) {
164+
float x = (float)input_row[idx];
165+
out_row[idx] = ((scalar_t)(x * (*s_variance_ptr))) * weight[idx];
166+
}
57167
}
58168
}
59169

60170
private:
61171
scalar_t* __restrict__ out; // [..., hidden_size]
62172
const scalar_t* __restrict__ input; // [..., hidden_size]
63-
const int64_t input_stride;
173+
const int64_t input_stride_d2;
174+
const int64_t input_stride_d3;
175+
const int64_t input_stride_d4;
176+
const int64_t input_shape_d2;
177+
const int64_t input_shape_d3;
64178
const scalar_t* __restrict__ weight; // [hidden_size]
65179
const float epsilon;
66180
const int num_tokens;
@@ -77,26 +191,39 @@ void call_rms_norm_kernel(
77191
using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;
78192
int hidden_size = input.size(-1);
79193
int num_tokens = input.numel() / hidden_size;
80-
int64_t input_stride = input.stride(-2);
194+
int num_dims = input.dim();
195+
int64_t input_stride_d2 = input.stride(-2);
196+
int64_t input_stride_d3 = (num_dims >= 3) ? input.stride(-3) : 0;
197+
int64_t input_stride_d4 = (num_dims >= 4) ? input.stride(-4) : 0;
198+
int64_t input_shape_d2 = (num_dims >= 3) ? input.size(-2) : 0;
199+
int64_t input_shape_d3 = (num_dims >= 4) ? input.size(-3) : 0;
200+
81201
auto out_ptr = out.data_ptr<scalar_t>();
82202
auto input_ptr = input.data_ptr<scalar_t>();
83203
auto weight_ptr = weight.data_ptr<scalar_t>();
84204
sycl::range<3> grid(1, 1, num_tokens);
85205
sycl::range<3> block(1, 1, std::min(hidden_size, 1024));
86206
auto& queue = vllm::xpu::vllmGetQueue();
87-
queue.submit([&](sycl::handler& cgh) {
88-
sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
89-
cgh.parallel_for(
90-
sycl::nd_range<3>(grid * block, block),
91-
vllm::rms_norm_kernel<sycl_t>(
92-
(sycl_t*)out_ptr,
93-
(const sycl_t*)input_ptr,
94-
input_stride,
95-
(const sycl_t*)weight_ptr,
96-
epsilon,
97-
num_tokens,
98-
hidden_size,
99-
s_variance));
207+
208+
VLLM_DISPATCH_RANK234(num_dims, [&]() {
209+
queue.submit([&](sycl::handler& cgh) {
210+
sycl::local_accessor<float, 1> s_variance(sycl::range<1>(1), cgh);
211+
cgh.parallel_for(
212+
sycl::nd_range<3>(grid * block, block),
213+
vllm::rms_norm_kernel<sycl_t, tensor_rank>(
214+
(sycl_t*)out_ptr,
215+
(const sycl_t*)input_ptr,
216+
input_stride_d2,
217+
input_stride_d3,
218+
input_stride_d4,
219+
input_shape_d2,
220+
input_shape_d3,
221+
(const sycl_t*)weight_ptr,
222+
epsilon,
223+
num_tokens,
224+
hidden_size,
225+
s_variance));
226+
});
100227
});
101228
}
102229

@@ -205,7 +332,10 @@ void rms_norm(
205332
torch::Tensor& weight,
206333
double epsilon) {
207334
TORCH_CHECK(out.is_contiguous());
208-
input = input.contiguous();
335+
if (input.stride(-1) != 1) {
336+
input = input.contiguous();
337+
}
338+
TORCH_CHECK(input.stride(-1) == 1);
209339
TORCH_CHECK(weight.is_contiguous());
210340
VLLM_DISPATCH_FLOATING_TYPES(
211341
input.scalar_type(), "call_rms_norm_kernel", [&] {

tests/test_layernorm.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,19 @@
99

1010
DTYPES = [torch.half, torch.bfloat16]
1111
NUM_TOKENS = [7, 83, 4096] # Arbitrary values for testing
12-
#TODO: add back 5120, 5124, 5125, 5126, 8192, 8199 after ci env issue fixed
12+
# TODO: add back 5120, 5124, 5125, 5126, 8192, 8199 after ci env issue fixed
1313
HIDDEN_SIZES = [8, 768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192,
1414
8199] # Arbitrary values for testing
15-
15+
HEAD_DIMS = [128, 64]
16+
NUM_Q_HEADS = [32, 40, 64]
17+
NUM_KV_HEADS = [8, 32]
1618
ADD_RESIDUAL = [False, True]
1719
SEEDS = [0]
1820
XPU_DEVICES = [
1921
f"xpu:{i}" for i in range(1 if torch.xpu.device_count() == 1 else 2)
2022
]
2123

22-
#override pytest parameters when enable mini pytest
24+
# override pytest parameters when enable mini pytest
2325
MINI_PYTEST_PARAMS = {
2426
"default": {
2527
"num_tokens": [7],
@@ -78,3 +80,42 @@ def test_rms_norm(
7880
else:
7981
opcheck(torch.ops._C.rms_norm,
8082
(out, x, layer.weight.data, layer.variance_epsilon))
83+
84+
85+
@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
86+
@pytest.mark.parametrize("head_dim", HEAD_DIMS)
87+
@pytest.mark.parametrize("num_q_heads", NUM_Q_HEADS)
88+
@pytest.mark.parametrize("num_kv_heads", NUM_KV_HEADS)
89+
@pytest.mark.parametrize("dtype", DTYPES)
90+
@pytest.mark.parametrize("device", XPU_DEVICES)
91+
@pytest.mark.parametrize("seed", SEEDS)
92+
@torch.inference_mode()
93+
def test_rms_norm_uncontigous(
94+
num_tokens: int,
95+
head_dim: int,
96+
num_q_heads: int,
97+
num_kv_heads: int,
98+
dtype: torch.dtype,
99+
device: str,
100+
seed: int,
101+
) -> None:
102+
torch.manual_seed(seed)
103+
torch.set_default_device("xpu")
104+
torch.xpu.set_device(device)
105+
106+
hidden_size = (num_q_heads + 2 * num_kv_heads) * head_dim
107+
qkv = torch.randn(num_tokens, hidden_size, dtype=dtype)
108+
q_size = num_q_heads * head_dim
109+
kv_size = num_kv_heads * head_dim
110+
q, _, _ = qkv.split([q_size, kv_size, kv_size], dim=-1)
111+
q_by_head = q.view(*q.shape[:-1], q.shape[-1] // head_dim, head_dim)
112+
113+
layer = RMSNorm(head_dim).to(dtype=dtype)
114+
ref_out = layer.forward_native(q_by_head)
115+
out = layer(q_by_head)
116+
torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
117+
118+
opcheck(
119+
torch.ops._C.rms_norm,
120+
(out, q_by_head, layer.weight.data, layer.variance_epsilon),
121+
)

0 commit comments

Comments
 (0)