Skip to content

Commit 1683b76

Browse files
authored
[MLA] add merge_attn_states sycl kernel (#64)
* add merge_attn_states; fix; format (Signed-off-by: Kunshang Ji <[email protected]>)
* fix format (Signed-off-by: Kunshang Ji <[email protected]>)
* add blank line (Signed-off-by: Kunshang Ji <[email protected]>)
* update (Signed-off-by: Kunshang Ji <[email protected]>)
* fix comments (Signed-off-by: Kunshang Ji <[email protected]>)
---------
Signed-off-by: Kunshang Ji <[email protected]>
1 parent 391b8ba commit 1683b76

File tree

7 files changed

+562
-0
lines changed

7 files changed

+562
-0
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ if(BASIC_KERNELS_ENABLED)
410410
"csrc/quantization/fp8/fp8_quant.cpp"
411411
"csrc/quantization/fp4/mxfp4_quant.cpp"
412412
"csrc/xpu_view.cpp"
413+
"csrc/attention/merge_attn_states.cpp"
413414
"csrc/tensor_utils.cpp"
414415
"csrc/utils/mem_cpy.cpp"
415416
"csrc/topk_per_row.cpp")
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
#include <ATen/ATen.h>
2+
#include <ATen/DeviceGuard.h>
3+
4+
#include <sycl/sycl.hpp>
5+
#include <optional>
6+
#include <torch/all.h>
7+
#include <algorithm>
8+
#include "utils.h"
9+
10+
namespace vllm {

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
//
// Work decomposition: one work-item per 128-bit pack of one (token, head)
// pair. Output tensors are addressed through per-head strides; the LSE
// tensors are indexed as [num_heads, num_tokens]
// (`head_idx * num_tokens + token_idx` below).
template <typename scalar_t, const uint NUM_THREADS>
void merge_attn_states_kernel(
    scalar_t* output,        // merged attention output
    float* output_lse,       // merged LSE, may be nullptr (not requested)
    const scalar_t* prefix_output,
    const float* prefix_lse,
    const scalar_t* suffix_output,
    const float* suffix_lse,
    const uint num_tokens,
    const uint num_heads,
    const uint head_size,
    const uint prefix_head_stride,
    const uint output_head_stride,
    const sycl::nd_item<3>& item_ct1) {
  // All element loads/stores are performed as 16-byte (128-bit) packs.
  using pack_128b_t = sycl::uint4;
  const uint pack_size = 16 / sizeof(scalar_t);
  const uint threads_per_head = head_size / pack_size;

  // Flattened global work-item id over (token, head, pack).
  const uint global_idx =
      item_ct1.get_group(2) * NUM_THREADS + item_ct1.get_local_id(2);
  const uint token_head_threads = num_tokens * num_heads * threads_per_head;

  // Tail guard: the last work-group may be only partially populated.
  if (global_idx >= token_head_threads) return;

  // global_idx -> token_idx + head_idx + pack_idx
  const uint token_head_idx = global_idx / threads_per_head;
  const uint pack_idx = global_idx % threads_per_head;

  const uint token_idx = token_head_idx / num_heads;
  const uint head_idx = token_head_idx % num_heads;

  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
  // NOTE(review): suffix_output is addressed with prefix_head_stride, i.e.
  // prefix and suffix are assumed to share the same head stride — confirm
  // with callers.
  const uint src_head_offset = token_idx * num_heads * prefix_head_stride +
                               head_idx * prefix_head_stride;
  const uint dst_head_offset = token_idx * num_heads * output_head_stride +
                               head_idx * output_head_stride;
  const scalar_t* prefix_head_ptr = prefix_output + src_head_offset;
  const scalar_t* suffix_head_ptr = suffix_output + src_head_offset;
  scalar_t* output_head_ptr = output + dst_head_offset;

  // LSE tensors are laid out [num_heads, num_tokens].
  float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
  float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
  // Normalize any infinite LSE (+/-inf) to -inf so it cannot win the max
  // below; a -inf LSE marks an empty/absent partial result.
  p_lse = sycl::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
  s_lse = sycl::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

  const float max_lse = sycl::fmax(p_lse, s_lse);

  /* In certain edge cases, MLA can produce p_lse = s_lse = -inf;
     continuing the pipeline then yields NaN. Root cause: with chunked prefill
     a batch may be split into two chunks; if a request in that batch has no
     prefix hit, every LSE entry for that request's position is -inf, and at
     this moment we merge cross-attention at first. For now we simply emit
     prefix_output (expected to be all zeros) and prefix_lse (-inf) to fix
     this problem.
  */
  if (sycl::isinf(max_lse)) {
    if (pack_offset < head_size) {
      // Pack 128b load
      pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
          prefix_head_ptr)[pack_offset / pack_size];

      // Pack 128b storage
      reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
          p_out_pack;
    }
    // We only need to write to output_lse once per head.
    if (output_lse != nullptr && pack_idx == 0) {
      output_lse[head_idx * num_tokens + token_idx] = max_lse;
    }
    return;
  }

  // Standard log-sum-exp merge: weight each partial output by its share of
  // the combined normalizer (shifted by max_lse for numerical stability).
  p_lse = p_lse - max_lse;
  s_lse = s_lse - max_lse;
  const float p_se = sycl::native::exp(p_lse);
  const float s_se = sycl::native::exp(s_lse);
  const float out_se = p_se + s_se;
  const float p_scale = p_se / out_se;
  const float s_scale = s_se / out_se;

  if (pack_offset < head_size) {
    // Pack 128b load
    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
        prefix_head_ptr)[pack_offset / pack_size];
    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
        suffix_head_ptr)[pack_offset / pack_size];
    pack_128b_t o_out_pack;

#pragma unroll
    for (uint i = 0; i < pack_size; ++i) {
      // Always use float for FMA to keep high precision.
      // half(uint16_t), bfloat16, float -> float.
      const float p_out_f = vllm::xpu::to_float(
          reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
      const float s_out_f = vllm::xpu::to_float(
          reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
      // float -> half(uint16_t), bfloat16, float.
      vllm::xpu::from_float(
          reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
    }

    // Pack 128b storage
    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
        o_out_pack;
  }
  // We only need to write to output_lse once per head.
  if (output_lse != nullptr && pack_idx == 0) {
    float out_lse = sycl::log((float)out_se) + max_lse;
    output_lse[head_idx * num_tokens + token_idx] = out_lse;
  }
}

}  // namespace vllm
129+
130+
// Dispatches `fn` with the concrete SYCL element type matching the ATen
// scalar dtype. `fn` is a macro that calls a function with
// template<typename scalar_t>. Only fp32/fp16/bf16 are supported; any other
// dtype fails with TORCH_CHECK.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
  {                                                                     \
    if (scalar_dtype == at::ScalarType::Float) {                        \
      fn(float);                                                        \
    } else if (scalar_dtype == at::ScalarType::Half) {                  \
      fn(sycl::half);                                                   \
    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
      fn(sycl::ext::oneapi::bfloat16);                                  \
    } else {                                                            \
      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
    }                                                                   \
  }
145+
146+
// Submits merge_attn_states_kernel to `queue`. Expects `output`,
// `prefix_output`, `suffix_output`, `prefix_lse`, `suffix_lse`,
// `output_lse_ptr`, the dimension/stride locals, and `grid`/`block` to be in
// scope at the expansion site (see merge_attn_states_launcher). The *_ctN
// locals capture everything the device lambda needs by value.
#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)            \
  {                                                                \
    ((sycl::queue)(queue)).submit([&](sycl::handler& cgh) {        \
      auto output_data_ptr_ct0 =                                   \
          reinterpret_cast<scalar_t*>(output.data_ptr());          \
      auto output_lse_ptr_ct1 = output_lse_ptr;                    \
      auto prefix_output_data_ptr_ct2 =                            \
          reinterpret_cast<scalar_t*>(prefix_output.data_ptr());   \
      auto prefix_lse_data_ptr_ct3 =                               \
          reinterpret_cast<float*>(prefix_lse.data_ptr());         \
      auto suffix_output_data_ptr_ct4 =                            \
          reinterpret_cast<scalar_t*>(suffix_output.data_ptr());   \
      auto suffix_lse_data_ptr_ct5 =                               \
          reinterpret_cast<float*>(suffix_lse.data_ptr());         \
      auto num_tokens_ct6 = num_tokens;                            \
      auto num_heads_ct7 = num_heads;                              \
      auto head_size_ct8 = head_size;                              \
      auto prefix_head_stride_ct9 = prefix_head_stride;            \
      auto output_head_stride_ct10 = output_head_stride;           \
                                                                   \
      cgh.parallel_for(                                            \
          sycl::nd_range<3>(grid * block, block),                  \
          [=](sycl::nd_item<3> item_ct1) {                         \
            vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>( \
                output_data_ptr_ct0,                               \
                output_lse_ptr_ct1,                                \
                prefix_output_data_ptr_ct2,                        \
                prefix_lse_data_ptr_ct3,                           \
                suffix_output_data_ptr_ct4,                        \
                suffix_lse_data_ptr_ct5,                           \
                num_tokens_ct6,                                    \
                num_heads_ct7,                                     \
                head_size_ct8,                                     \
                prefix_head_stride_ct9,                            \
                output_head_stride_ct10,                           \
                item_ct1);                                         \
          });                                                      \
    });                                                            \
  }
185+
186+
/*@brief Merges the attention states from prefix and suffix
187+
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
188+
*
189+
* @param output [n,h,d] The output tensor to store the merged attention states.
190+
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
191+
* @param prefix_output [n,h,d] The prefix attention states.
192+
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
193+
* states.
194+
* @param suffix_output [n,h,d] The suffix attention states.
195+
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
196+
* states.
197+
*/
198+
template <typename scalar_t>
199+
void merge_attn_states_launcher(
200+
torch::Tensor& output,
201+
std::optional<torch::Tensor> output_lse,
202+
const torch::Tensor& prefix_output,
203+
const torch::Tensor& prefix_lse,
204+
const torch::Tensor& suffix_output,
205+
const torch::Tensor& suffix_lse) {
206+
constexpr uint NUM_THREADS = 128;
207+
const uint num_tokens = output.size(0);
208+
const uint num_heads = output.size(1);
209+
const uint head_size = output.size(2);
210+
const uint prefix_head_stride = prefix_output.stride(1);
211+
const uint output_head_stride = output.stride(1);
212+
const uint pack_size = 16 / sizeof(scalar_t);
213+
TORCH_CHECK(
214+
head_size % pack_size == 0,
215+
"headsize must be multiple of pack_size:",
216+
pack_size);
217+
float* output_lse_ptr = nullptr;
218+
if (output_lse.has_value()) {
219+
output_lse_ptr = output_lse.value().data_ptr<float>();
220+
}
221+
// Process one pack elements per thread. for float, the
222+
// pack_size is 4 for half/bf16, the pack_size is 8.
223+
const uint threads_per_head = head_size / pack_size;
224+
const uint total_threads = num_tokens * num_heads * threads_per_head;
225+
226+
sycl::range<3> block(1, 1, NUM_THREADS);
227+
sycl::range<3> grid(1, 1, (total_threads + NUM_THREADS - 1) / NUM_THREADS);
228+
229+
at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
230+
at::DeviceGuard device_guard(curDevice);
231+
auto& queue = vllm::xpu::vllmGetQueue();
232+
233+
LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
234+
}
235+
236+
// Adapter handed to DISPATCH_BY_SCALAR_DTYPE: forwards the tensor arguments
// in scope at the expansion site (merge_attn_states) to the typed launcher.
#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
  {                                               \
    merge_attn_states_launcher<scalar_t>(         \
        output,                                   \
        output_lse,                               \
        prefix_output,                            \
        prefix_lse,                               \
        suffix_output,                            \
        suffix_lse);                              \
  }
246+
247+
// Op entry point (registered in torch_bindings.cpp): dispatches on the dtype
// of `output` (fp32/fp16/bf16) and launches the merge kernel for that type.
void merge_attn_states(
    torch::Tensor& output,
    std::optional<torch::Tensor> output_lse,
    const torch::Tensor& prefix_output,
    const torch::Tensor& prefix_lse,
    const torch::Tensor& suffix_output,
    const torch::Tensor& suffix_lse) {
  DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}

csrc/ops.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,11 @@ void xpu_memcpy_sync(
179179
int64_t n_bytes,
180180
int64_t kind,
181181
int64_t device = -1);
182+
183+
// Merges partial (split-KV) attention results from prefix and suffix into
// `output`; `output_lse` optionally receives the merged log-sum-exp values.
// Implemented in csrc/attention/merge_attn_states.cpp.
void merge_attn_states(
    torch::Tensor& output,
    std::optional<torch::Tensor> output_lse,
    const torch::Tensor& prefix_output,
    const torch::Tensor& prefix_lse,
    const torch::Tensor& suffix_output,
    const torch::Tensor& suffix_lse);

csrc/torch_bindings.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
133133
"xpu_memcpy_sync(int dst_ptr, int src_ptr, int n_bytes, int kind, "
134134
"int device=-1) -> ()");
135135
ops.impl("xpu_memcpy_sync", &xpu_memcpy_sync);
136+
137+
// Merge attn states
138+
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
139+
// can be used to combine partial attention results (in the split-KV case)
140+
ops.def(
141+
"merge_attn_states("
142+
" Tensor! output,"
143+
" Tensor!? output_lse,"
144+
" Tensor prefix_output,"
145+
" Tensor prefix_lse,"
146+
" Tensor suffix_output,"
147+
" Tensor suffix_lse) -> ()");
148+
ops.impl("merge_attn_states", torch::kXPU, &merge_attn_states);
136149
}
137150

138151
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {

csrc/utils.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,22 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vec {
126126
scalar_t const& operator[](int index) const { return val[index]; }
127127
};
128128

129+
// Scalar conversion helpers: kernels compute in float regardless of the
// storage dtype and use these overloads to convert on load/store.

// From float to float.
inline void from_float(float& dst, float src) { dst = src; }
// From float32 to float16.
inline void from_float(sycl::half& dst, float src) { dst = sycl::half(src); }
// From float32 to bfloat16.
inline void from_float(sycl::ext::oneapi::bfloat16& dst, float src) {
  dst = sycl::ext::oneapi::bfloat16(src);
}

// From float to float.
inline float to_float(float u) { return u; }
// From float16 to float32.
inline float to_float(sycl::half u) { return float(u); }
// From bfloat16 to float32.
inline float to_float(sycl::ext::oneapi::bfloat16 u) { return float(u); }
144+
129145
} // namespace xpu
130146

131147
} // namespace vllm

tests/register_ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,19 @@ def deepseek_scaling_rope(
7676
rotary_dim, is_neox_style)
7777

7878

79+
# merge attn states ops
def merge_attn_states(
    output: torch.Tensor,
    prefix_output: torch.Tensor,
    prefix_lse: torch.Tensor,
    suffix_output: torch.Tensor,
    suffix_lse: torch.Tensor,
    output_lse: torch.Tensor | None = None,
) -> None:
    """Merge partial (split-KV) attention results in-place into ``output``.

    Thin wrapper over the custom op ``_C.merge_attn_states``. Note the
    argument-order difference: the registered C++ schema takes
    ``output_lse`` as its second argument, while this Python API keeps it
    last with a default of ``None`` (merged LSE not requested).
    """
    torch.ops._C.merge_attn_states(output, output_lse, prefix_output,
                                   prefix_lse, suffix_output, suffix_lse)
90+
91+
7992
def reshape_and_cache(
8093
key: torch.Tensor,
8194
value: torch.Tensor,

0 commit comments

Comments
 (0)