
Commit 26b5133

format
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 47fd0fd commit 26b5133

File tree

csrc/attention/merge_attn_states.cpp
csrc/torch_bindings.cpp
csrc/utils.h
tests/register_ops.py
tests/test_merge_attn_states.py

5 files changed: +92 -122 lines changed

csrc/attention/merge_attn_states.cpp

Lines changed: 11 additions & 10 deletions
```diff
@@ -66,17 +66,18 @@ void merge_attn_states_kernel(scalar_t* output, float* output_lse,
   pack_128b_t o_out_pack;

 #pragma unroll
-  for (uint i = 0; i < pack_size; ++i) {
+  for (uint i = 0; i < pack_size; ++i) {
     // Always use float for FMA to keep high precision.
     // half(uint16_t), bfloat16, float -> float.
-    const float p_out_f =
-        vllm::xpu::to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
-    const float s_out_f =
-        vllm::xpu::to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+    const float p_out_f = vllm::xpu::to_float(
+        reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+    const float s_out_f = vllm::xpu::to_float(
+        reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
     // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
     const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
     // float -> half(uint16_t), bfloat16, float.
-    vllm::xpu::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+    vllm::xpu::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
+                          o_out_f);
   }

   // Pack 128b storage
```
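The "Always use float for FMA" comment above is the key numeric choice this hunk rewraps: values are widened with `to_float`, combined in float32, and narrowed once with `from_float`. A standalone sketch of what that buys (plain PyTorch, illustrative only, not part of the commit):

```python
import torch

# Mirrors the kernel's two options for the merge FMA:
#   widened: convert bf16 -> float32, then multiply-add (rounds once at the end)
#   narrow:  multiply-add directly in bf16 (rounds after every step)
p_out = torch.randn(1024, dtype=torch.bfloat16)
s_out = torch.randn(1024, dtype=torch.bfloat16)
p_scale, s_scale = 0.6, 0.4

widened = p_out.float() * p_scale + s_out.float() * s_scale
narrow = (p_out * p_scale + s_out * s_scale).float()
print((widened - narrow).abs().max())  # typically nonzero: bf16 rounding error
```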
```diff
@@ -100,7 +101,7 @@ void merge_attn_states_kernel(scalar_t* output, float* output_lse,
   if (scalar_dtype == at::ScalarType::Float) {             \
     fn(float);                                             \
   } else if (scalar_dtype == at::ScalarType::Half) {       \
-    fn(sycl::half);                                       \
+    fn(sycl::half);                                        \
   } else if (scalar_dtype == at::ScalarType::BFloat16) {   \
     fn(sycl::ext::oneapi::bfloat16);                       \
   } else {                                                 \
```
```diff
@@ -110,7 +111,7 @@ void merge_attn_states_kernel(scalar_t* output, float* output_lse,

 #define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)      \
   {                                                          \
-  ((sycl::queue)(queue)).submit([&](sycl::handler& cgh) {    \
+    ((sycl::queue)(queue)).submit([&](sycl::handler& cgh) {  \
       auto output_data_ptr_ct0 =                             \
           reinterpret_cast<scalar_t*>(output.data_ptr());    \
       auto output_lse_ptr_ct1 = output_lse_ptr;              \
```
```diff
@@ -181,8 +182,8 @@ void merge_attn_states_launcher(torch::Tensor& output,
   const uint threads_per_head = head_size / pack_size;
   const uint total_threads = num_tokens * num_heads * threads_per_head;

-  sycl::range<3> block(1,1,NUM_THREADS);
-  sycl::range<3> grid(1,1,(total_threads + NUM_THREADS - 1) / NUM_THREADS);
+  sycl::range<3> block(1, 1, NUM_THREADS);
+  sycl::range<3> grid(1, 1, (total_threads + NUM_THREADS - 1) / NUM_THREADS);

   at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
   at::DeviceGuard device_guard(curDevice);
```
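For context on the launch geometry reformatted here: each work-item handles one 128-bit pack of a head, so the launcher derives `threads_per_head` from `head_size / pack_size` and rounds the grid up to whole work-groups. A back-of-envelope sketch (plain Python; the concrete sizes are illustrative assumptions, not values from the repo):

```python
# 128-bit packs: for a 16-bit half/bfloat16 element, 128 / 16 = 8 per pack.
num_tokens, num_heads, head_size = 512, 16, 128  # assumed example shapes
pack_size = 128 // 16                            # = 8 elements per pack
NUM_THREADS = 128                                # assumed work-group size

threads_per_head = head_size // pack_size                  # 16
total_threads = num_tokens * num_heads * threads_per_head  # 131072
# Same ceil-division as (total_threads + NUM_THREADS - 1) / NUM_THREADS:
num_groups = -(-total_threads // NUM_THREADS)              # 1024 work-groups
print(threads_per_head, total_threads, num_groups)
```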

csrc/torch_bindings.cpp

Lines changed: 7 additions & 7 deletions
```diff
@@ -92,13 +92,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   // Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
   // can be used to combine partial attention results (in the split-KV case)
   ops.def(
-      "merge_attn_states("
-      "    Tensor! output,"
-      "    Tensor!? output_lse,"
-      "    Tensor prefix_output,"
-      "    Tensor prefix_lse,"
-      "    Tensor suffix_output,"
-      "    Tensor suffix_lse) -> ()");
+      "merge_attn_states("
+      "   Tensor! output,"
+      "   Tensor!? output_lse,"
+      "   Tensor prefix_output,"
+      "   Tensor prefix_lse,"
+      "   Tensor suffix_output,"
+      "   Tensor suffix_lse) -> ()");
   ops.impl("merge_attn_states", torch::kXPU, &merge_attn_states);
 }
```
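The schema above marks `output` as mutable (`Tensor!`) and `output_lse` as optional and mutable (`Tensor!?`), and returns nothing. A hedged call sketch (shapes follow the comments in tests/test_merge_attn_states.py; assumes the extension that defines `torch.ops._C` is built and loaded, and the sizes are illustrative):

```python
import torch

# Illustrative shapes: 4 tokens, 8 heads, head size 128.
output = torch.empty(4, 8, 128, dtype=torch.float16, device="xpu")
output_lse = torch.empty(8, 4, dtype=torch.float32, device="xpu")
prefix_output = torch.randn(4, 8, 128, dtype=torch.float16, device="xpu")
suffix_output = torch.randn(4, 8, 128, dtype=torch.float16, device="xpu")
prefix_lse = torch.randn(8, 4, dtype=torch.float32, device="xpu")
suffix_lse = torch.randn(8, 4, dtype=torch.float32, device="xpu")

# The Tensor!/Tensor!? arguments are written in place; "-> ()" returns None.
torch.ops._C.merge_attn_states(output, output_lse, prefix_output,
                               prefix_lse, suffix_output, suffix_lse)
```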

csrc/utils.h

Lines changed: 3 additions & 8 deletions
```diff
@@ -69,23 +69,18 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vec {
 // From float to float.
 inline void from_float(float& dst, float src) { dst = src; }
 // From float32 to float16.
-inline void from_float(sycl::half& dst, float src) {
-  dst = sycl::half(src);
-}
+inline void from_float(sycl::half& dst, float src) { dst = sycl::half(src); }
 // From float32 to bfloat16.
 inline void from_float(sycl::ext::oneapi::bfloat16& dst, float src) {
   dst = sycl::ext::oneapi::bfloat16(src);
 }

 // From float to float.
-inline float to_float(float u) { return u; }
+inline float to_float(float u) { return u; }
 // From float16 to float32.
 inline float to_float(sycl::half u) { return float(u); }
 // From bfloat16 to float32.
-inline float to_float(sycl::ext::oneapi::bfloat16 u) {
-  return float(u);
-}
-
+inline float to_float(sycl::ext::oneapi::bfloat16 u) { return float(u); }

 }  // namespace xpu
```

tests/register_ops.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -85,9 +85,9 @@ def merge_attn_states(
     suffix_lse: torch.Tensor,
     output_lse: torch.Tensor | None = None,
 ) -> None:
-    torch.ops._C.merge_attn_states(
-        output, output_lse, prefix_output, prefix_lse, suffix_output, suffix_lse
-    )
+    torch.ops._C.merge_attn_states(output, output_lse, prefix_output,
+                                   prefix_lse, suffix_output, suffix_lse)
+

 def reshape_and_cache(
     key: torch.Tensor,
```

tests/test_merge_attn_states.py

Lines changed: 68 additions & 94 deletions
```diff
@@ -1,30 +1,29 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
 """Tests for merge_attn_states function.

 Run `pytest tests/test_merge_attn_states.py`.
 """

+import logging
+
 import pytest
 import torch
-import logging

 from tests.register_ops import merge_attn_states as merge_attn_states_xpu

 logger = logging.getLogger("vllm_xpu_kernel")


-
 # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
 # can be used to combine partial attention results (in the split-KV case)
 def merge_attn_states_torch(
-    output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
-    prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
-    prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
-    suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
-    suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
-    output_lse: torch.Tensor | None = None,  # [NUM_HEADS, NUM_TOKENS]
+        output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+        prefix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+        prefix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
+        suffix_output: torch.Tensor,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+        suffix_lse: torch.Tensor,  # [NUM_HEADS, NUM_TOKENS]
+        output_lse: torch.Tensor | None = None,  # [NUM_HEADS, NUM_TOKENS]
 ):
     p_lse = prefix_lse
     s_lse = suffix_lse
```
```diff
@@ -42,8 +41,10 @@ def merge_attn_states_torch(
     output_lse = torch.log(out_se) + max_lse
     p_scale = p_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
     s_scale = s_lse_exp / out_se  # [NUM_HEADS, NUM_TOKENS]
-    p_scale = torch.transpose(p_scale, 0, 1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
-    s_scale = torch.transpose(s_scale, 0, 1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
+    p_scale = torch.transpose(p_scale, 0,
+                              1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
+    s_scale = torch.transpose(s_scale, 0,
+                              1).unsqueeze(2)  # [NUM_TOKENS, NUM_HEADS, 1]
     output = prefix_output * p_scale + suffix_output * s_scale
     return output, output_lse
```
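This hunk is the tail of the reference implementation's numerically stable log-sum-exp merge: judging by the names, the lines above it compute `max_lse`, the shifted exponentials, and their sum `out_se`. A scalar sketch of the whole computation (one head, one token; plain PyTorch, not from the diff):

```python
import torch

# Stable merge of two partial attention results (one head, one token).
p_lse, s_lse = torch.tensor(2.0), torch.tensor(0.5)   # partial log-sum-exps
p_out, s_out = torch.tensor(1.0), torch.tensor(-1.0)  # partial outputs

m = torch.maximum(p_lse, s_lse)
p_exp, s_exp = torch.exp(p_lse - m), torch.exp(s_lse - m)
out_se = p_exp + s_exp
merged_lse = torch.log(out_se) + m        # == torch.logaddexp(p_lse, s_lse)
merged_out = p_out * (p_exp / out_se) + s_out * (s_exp / out_se)
print(merged_lse.item(), merged_out.item())
```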
```diff
@@ -66,13 +67,10 @@ def merge_attn_states_torch(
 }


-
 def generate_markdown_table():
     global all_case_info
-    table_header = (
-        "| tokens | heads | headsize | dtype "
-        "| device | torch | cuda | speedup |"
-    )
+    table_header = ("| tokens | heads | headsize | dtype "
+                    "| device | torch | cuda | speedup |")
     table_separator = "| --- | --- | --- | --- | --- | --- | --- | --- |"

     def shortly_dtype(dtype: torch.dtype) -> str:
```
```diff
@@ -96,36 +94,33 @@ def shortly_device(device: str) -> str:
         ) = info
         dtype = shortly_dtype(dtype)
         device = shortly_device(device)
-        print(
-            f"| {num_tokens} | {num_heads} | {head_size} "
-            f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
-            f"| {avg_time_xpu_kernel:.5f}ms "
-            f"| {performance_improved:.4f}x |"
-        )
+        print(f"| {num_tokens} | {num_heads} | {head_size} "
+              f"| {dtype} | {device} | {avg_time_torch_kernel:.5f}ms "
+              f"| {avg_time_xpu_kernel:.5f}ms "
+              f"| {performance_improved:.4f}x |")


 @pytest.mark.parametrize("num_tokens", NUM_BATCH_TOKENS)
 @pytest.mark.parametrize("num_query_heads", NUM_QUERY_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
 @pytest.mark.parametrize("output_dtype", DTYPES)
 @torch.inference_mode()
-def test_merge_attn_states(
-    num_tokens: int, num_query_heads: int, head_size: int, output_dtype: torch.dtype
-):
+def test_merge_attn_states(num_tokens: int, num_query_heads: int,
+                           head_size: int, output_dtype: torch.dtype):

     NUM_TOKENS = num_tokens
     NUM_HEADS = num_query_heads
     HEAD_SIZE = head_size

-    logger.debug(
-        f"\nNUM_TOKENS:{NUM_TOKENS}, NUM_HEADS:{NUM_HEADS}, "
-        f"HEAD_SIZE:{HEAD_SIZE}, DTYPE: {output_dtype}, "
-        f"Device: xpu."
-    )
-
     # prefix_lse and suffix_lse contain inf and normal values
-    prefix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="xpu")
-    suffix_lse = torch.randn(NUM_HEADS, NUM_TOKENS, dtype=torch.float32, device="xpu")
+    prefix_lse = torch.randn(NUM_HEADS,
+                             NUM_TOKENS,
+                             dtype=torch.float32,
+                             device="xpu")
+    suffix_lse = torch.randn(NUM_HEADS,
+                             NUM_TOKENS,
+                             dtype=torch.float32,
+                             device="xpu")

     # Generate boolean masks
     mask_prefix = torch.rand(NUM_HEADS, NUM_TOKENS) < 0.1
```
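The masks above are presumably used, in lines outside this diff, to inject inf sentinels into roughly 10% of the LSE entries. The reason `-inf` is the safe sentinel for a missing partial result: `exp(-inf) == 0`, so that branch contributes zero weight to the merge and leaves the other LSE untouched. A two-line check (plain PyTorch, illustrative):

```python
import torch

neg_inf = torch.tensor(float("-inf"))
print(torch.exp(neg_inf))                            # tensor(0.): zero weight
print(torch.logaddexp(neg_inf, torch.tensor(1.5)))   # tensor(1.5000): no effect
```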
```diff
@@ -140,18 +135,18 @@ def test_merge_attn_states(

     # Other input tensors (need to be initialized but
     # no actual calculation needed)
-    output = torch.zeros(
-        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="xpu"
-    )
-    output_lse = torch.zeros(
-        (NUM_HEADS, NUM_TOKENS), dtype=torch.float32, device="xpu"
-    )
-    prefix_output = torch.randn(
-        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="xpu"
-    )
-    suffix_output = torch.randn(
-        (NUM_TOKENS, NUM_HEADS, HEAD_SIZE), dtype=output_dtype, device="xpu"
-    )
+    output = torch.zeros((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
+                         dtype=output_dtype,
+                         device="xpu")
+    output_lse = torch.zeros((NUM_HEADS, NUM_TOKENS),
+                             dtype=torch.float32,
+                             device="xpu")
+    prefix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
+                                dtype=output_dtype,
+                                device="xpu")
+    suffix_output = torch.randn((NUM_TOKENS, NUM_HEADS, HEAD_SIZE),
+                                dtype=output_dtype,
+                                device="xpu")

     warmup_times = 2
     repeat_times = 20
```
```diff
@@ -226,60 +221,39 @@

     # 2. Performance compare
     performance_improved = avg_time_torch_kernel / avg_time_xpu_kernel
-    logger.debug(f" Torch time: {avg_time_torch_kernel:.6f}ms")
-    logger.debug(
-        f" XPU time: {avg_time_xpu_kernel:.6f}ms, "
-        f"Performance: {performance_improved:.5f}x"
-    )
-    logger.debug("-" * 100)
-
-    # 4. Correctness compare
+    # print(f" Torch time: {avg_time_torch_kernel:.6f}ms")
+    # print(f" XPU time: {avg_time_xpu_kernel:.6f}ms, "
+    #       f"Performance: {performance_improved:.5f}x")
+    # print("-" * 100)
+
+    # 3. Correctness compare
     # Liger Kernel: Efficient Triton Kernels for LLM Training
     # https://arxiv.org/pdf/2410.10989, 3.3 Correctness
     # use rtol = 1e-2 for bfloat16.
     rtol = 1e-2 if output_dtype == torch.bfloat16 else 1e-3

-    def diff(a: torch.Tensor, b: torch.Tensor):
-        max_diff = torch.max(torch.abs(a.float() - b.float()))
-        return max_diff
-
-    # Use Triton output as reference because we want to replace
-    # the Triton kernel with custom XPU kernel for merge attn
-    # states operation.
-    torch.testing.assert_close(
-        output_xpu.float(), output_torch.float(), atol=1e-3, rtol=rtol
-    )
-    logger.debug("Output all match, max abs diff:")
-    logger.debug(f" (XPU vs Torch) : {diff(output_torch, output_xpu)}")
-    logger.debug("-" * 100)
-
-    torch.testing.assert_close(
-        output_lse_xpu.float(), output_lse_torch.float(), atol=1e-3, rtol=rtol
-    )
-    logger.debug("Output LSE all match, max abs diff:")
-    logger.debug(f" (XPU vs Torch) : {diff(output_lse_torch, output_lse_xpu)}")
-    logger.debug("-" * 100)
-
-    logger.debug(
-        "All output values test passed! All inf values "
-        "are correctly replaced with -inf."
-    )
-    logger.debug("-" * 100)
+    # Use torch output as reference
+    torch.testing.assert_close(output_xpu.float(),
+                               output_torch.float(),
+                               atol=1e-3,
+                               rtol=rtol)
+
+    torch.testing.assert_close(output_lse_xpu.float(),
+                               output_lse_torch.float(),
+                               atol=1e-3,
+                               rtol=rtol)

     device = "xpu"
-    all_case_info.append(
-        (
-            NUM_TOKENS,
-            NUM_HEADS,
-            HEAD_SIZE,
-            output_dtype,
-            device,
-            avg_time_torch_kernel,
-            avg_time_xpu_kernel,
-            performance_improved,
-        )
-    )
-    if len(all_case_info) == (
-        len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) * len(NUM_QUERY_HEADS) * len(DTYPES)
-    ):
+    all_case_info.append((
+        NUM_TOKENS,
+        NUM_HEADS,
+        HEAD_SIZE,
+        output_dtype,
+        device,
+        avg_time_torch_kernel,
+        avg_time_xpu_kernel,
+        performance_improved,
+    ))
+    if len(all_case_info) == (len(NUM_BATCH_TOKENS) * len(HEAD_SIZES) *
+                              len(NUM_QUERY_HEADS) * len(DTYPES)):
         generate_markdown_table()
```
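On the tolerances this hunk keeps: `torch.testing.assert_close` passes when `|actual - expected| <= atol + rtol * |expected|`, so with `atol=1e-3` and the bfloat16 `rtol=1e-2` a value near 1.0 may drift by about 1.1e-2 and still pass. A minimal check (illustrative numbers, not from the test):

```python
import torch

expected = torch.tensor([1.0])
actual = expected + 0.0105  # within atol + rtol * |expected| = 0.011
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-2)  # passes
```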
