1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -155,6 +155,7 @@ if(VLLM_GPU_LANG STREQUAL "SYCL")
"csrc/pos_encoding_kernels.cpp"
"csrc/torch_bindings.cpp"
"csrc/quantization/fp8/fp8_quant.cpp"
"csrc/attention/merge_attn_states.cpp"
)
include_directories("/usr/include")
list(APPEND VLLM_INCLUDE_DIR ${CMPLR_ROOT}/include/)
209 changes: 209 additions & 0 deletions csrc/attention/merge_attn_states.cpp
@@ -0,0 +1,209 @@
#include <ATen/ATen.h>
#include <ATen/DeviceGuard.h>

#include <sycl/sycl.hpp>
#include <optional>
#include <torch/all.h>
#include <algorithm>
#include "utils.h"

namespace vllm {

// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
template <typename scalar_t, const uint NUM_THREADS>
void merge_attn_states_kernel(scalar_t* output, float* output_lse,
const scalar_t* prefix_output,
const float* prefix_lse,
const scalar_t* suffix_output,
const float* suffix_lse, const uint num_tokens,
const uint num_heads, const uint head_size,
const sycl::nd_item<3>& item_ct1) {
using pack_128b_t = sycl::uint4;
const uint pack_size = 16 / sizeof(scalar_t);
const uint threads_per_head = head_size / pack_size;

const uint global_idx =
item_ct1.get_group(2) * NUM_THREADS + item_ct1.get_local_id(2);
const uint token_head_threads = num_tokens * num_heads * threads_per_head;

if (global_idx >= token_head_threads) return;

// global_idx -> token_idx + head_idx + pack_idx
const uint token_head_idx = global_idx / threads_per_head;
const uint pack_idx = global_idx % threads_per_head;

const uint token_idx = token_head_idx / num_heads;
const uint head_idx = token_head_idx % num_heads;

const uint pack_offset = pack_idx * pack_size; // (0~15)*8, etc.
const uint head_offset =
token_idx * num_heads * head_size + head_idx * head_size;
const scalar_t* prefix_head_ptr = prefix_output + head_offset;
const scalar_t* suffix_head_ptr = suffix_output + head_offset;
scalar_t* output_head_ptr = output + head_offset;

float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
p_lse = sycl::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
s_lse = sycl::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;

const float max_lse = sycl::fmax(p_lse, s_lse);
p_lse = p_lse - max_lse;
s_lse = s_lse - max_lse;
const float p_se = sycl::native::exp(p_lse);
const float s_se = sycl::native::exp(s_lse);
const float out_se = p_se + s_se;
const float p_scale = p_se / out_se;
const float s_scale = s_se / out_se;

if (pack_offset < head_size) {
// Pack 128b load
pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(
prefix_head_ptr)[pack_offset / pack_size];
pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(
suffix_head_ptr)[pack_offset / pack_size];
pack_128b_t o_out_pack;

#pragma unroll
for (uint i = 0; i < pack_size; ++i) {
// Always use float for FMA to keep high precision.
// half(uint16_t), bfloat16, float -> float.
const float p_out_f = vllm::xpu::to_float(
reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
const float s_out_f = vllm::xpu::to_float(
reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
// fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
// float -> half(uint16_t), bfloat16, float.
vllm::xpu::from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i],
o_out_f);
}

// Pack 128b storage
reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] =
o_out_pack;
}
// Only one thread per (token, head) pair needs to write output_lse.
if (output_lse != nullptr && pack_idx == 0) {
float out_lse = sycl::log((float)out_se) + max_lse;
output_lse[head_idx * num_tokens + token_idx] = out_lse;
}
}

} // namespace vllm

// The following macro is used to dispatch the conversion function based on
// the output data type. The FN is a macro that calls a function with
// template<typename scalar_t>.
#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn) \
{ \
if (scalar_dtype == at::ScalarType::Float) { \
fn(float); \
} else if (scalar_dtype == at::ScalarType::Half) { \
fn(sycl::half); \
} else if (scalar_dtype == at::ScalarType::BFloat16) { \
fn(sycl::ext::oneapi::bfloat16); \
} else { \
TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
} \
}

#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS) \
{ \
((sycl::queue)(queue)).submit([&](sycl::handler& cgh) { \
auto output_data_ptr_ct0 = \
reinterpret_cast<scalar_t*>(output.data_ptr()); \
auto output_lse_ptr_ct1 = output_lse_ptr; \
auto prefix_output_data_ptr_ct2 = \
reinterpret_cast<scalar_t*>(prefix_output.data_ptr()); \
auto prefix_lse_data_ptr_ct3 = \
reinterpret_cast<float*>(prefix_lse.data_ptr()); \
auto suffix_output_data_ptr_ct4 = \
reinterpret_cast<scalar_t*>(suffix_output.data_ptr()); \
auto suffix_lse_data_ptr_ct5 = \
reinterpret_cast<float*>(suffix_lse.data_ptr()); \
auto num_tokens_ct6 = num_tokens; \
auto num_heads_ct7 = num_heads; \
auto head_size_ct8 = head_size; \
\
cgh.parallel_for( \
sycl::nd_range<3>(grid * block, block), \
[=](sycl::nd_item<3> item_ct1) { \
vllm::merge_attn_states_kernel<scalar_t, NUM_THREADS>( \
output_data_ptr_ct0, output_lse_ptr_ct1, \
prefix_output_data_ptr_ct2, prefix_lse_data_ptr_ct3, \
suffix_output_data_ptr_ct4, suffix_lse_data_ptr_ct5, \
num_tokens_ct6, num_heads_ct7, head_size_ct8, item_ct1); \
}); \
}); \
}

/*@brief Merges the attention states from prefix and suffix
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
Copilot AI · Nov 10, 2025

The dimension description for output_lse is incorrect. According to the code (line 90) and the test file (line 141), output_lse should be [h,n], not [h,d].

Suggested change:
- * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
+ * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
* @param prefix_output [n,h,d] The prefix attention states.
* @param prefix_lse [h,n] The log-sum-exp values for the prefix attention
* states.
* @param suffix_output [n,h,d] The suffix attention states.
* @param suffix_lse [h,n] The log-sum-exp values for the suffix attention
* states.
*/
template <typename scalar_t>
void merge_attn_states_launcher(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
constexpr uint NUM_THREADS = 128;
const uint num_tokens = output.size(0);
const uint num_heads = output.size(1);
const uint head_size = output.size(2);
const uint pack_size = 16 / sizeof(scalar_t);
TORCH_CHECK(head_size % pack_size == 0,
"headsize must be multiple of pack_size:", pack_size);
TORCH_CHECK(output.stride(-2) == head_size && output.stride(-1) == 1,
"output heads must be contiguous in memory");
TORCH_CHECK(
prefix_output.stride(-2) == head_size && prefix_output.stride(-1) == 1,
"prefix_output heads must be contiguous in memory");
TORCH_CHECK(
suffix_output.stride(-2) == head_size && suffix_output.stride(-1) == 1,
"suffix_output heads must be contiguous in memory");
float* output_lse_ptr = nullptr;
if (output_lse.has_value()) {
output_lse_ptr = output_lse.value().data_ptr<float>();
}
// Each thread processes one pack of elements. For float, pack_size is 4;
// for half/bf16, pack_size is 8.
const uint threads_per_head = head_size / pack_size;
const uint total_threads = num_tokens * num_heads * threads_per_head;

sycl::range<3> block(1, 1, NUM_THREADS);
sycl::range<3> grid(1, 1, (total_threads + NUM_THREADS - 1) / NUM_THREADS);

at::Device curDevice = at::Device(at::kXPU, at::xpu::current_device());
at::DeviceGuard device_guard(curDevice);
auto& queue = vllm::xpu::vllmGetQueue();

LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
}

#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
{ \
merge_attn_states_launcher<scalar_t>(output, output_lse, prefix_output, \
prefix_lse, suffix_output, \
suffix_lse); \
}

void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse) {
DISPATCH_BY_SCALAR_DTYPE(output.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
}
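For reference, the merge the kernel performs is the standard log-sum-exp rescaling from section 2.2 of the paper. Below is a minimal PyTorch sketch of the same math, assuming the [n,h,d] output layout and [h,n] LSE layout used by the launcher; the function name `merge_attn_states_ref` is illustrative and not part of this PR.

```python
import torch

def merge_attn_states_ref(prefix_output: torch.Tensor,  # [n, h, d]
                          prefix_lse: torch.Tensor,     # [h, n]
                          suffix_output: torch.Tensor,  # [n, h, d]
                          suffix_lse: torch.Tensor):    # [h, n]
    # Replace +inf with -inf so a missing split contributes zero weight,
    # mirroring the isinf handling in the kernel.
    p_lse = torch.where(torch.isinf(prefix_lse),
                        torch.full_like(prefix_lse, float("-inf")), prefix_lse)
    s_lse = torch.where(torch.isinf(suffix_lse),
                        torch.full_like(suffix_lse, float("-inf")), suffix_lse)

    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)  # [h, n]
    s_se = torch.exp(s_lse - max_lse)  # [h, n]
    out_se = p_se + s_se

    # Broadcast the per-(head, token) scales onto the [n, h, d] outputs.
    p_scale = (p_se / out_se).transpose(0, 1).unsqueeze(-1)  # [n, h, 1]
    s_scale = (s_se / out_se).transpose(0, 1).unsqueeze(-1)  # [n, h, 1]
    output = prefix_output.float() * p_scale + suffix_output.float() * s_scale

    output_lse = torch.log(out_se) + max_lse  # [h, n]
    return output.to(prefix_output.dtype), output_lse
```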
7 changes: 7 additions & 0 deletions csrc/ops.h
@@ -63,3 +63,10 @@ void dynamic_per_token_scaled_fp8_quant(

void swigluoai_and_mul(torch::Tensor& out, torch::Tensor& input,
double alpha = 1.702, double limit = 7.0);

void merge_attn_states(torch::Tensor& output,
std::optional<torch::Tensor> output_lse,
const torch::Tensor& prefix_output,
const torch::Tensor& prefix_lse,
const torch::Tensor& suffix_output,
const torch::Tensor& suffix_lse);
13 changes: 13 additions & 0 deletions csrc/torch_bindings.cpp
@@ -87,6 +87,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"limit=7.0) "
"-> ()");
ops.impl("swigluoai_and_mul", torch::kXPU, &swigluoai_and_mul);

// Merge attn states
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
// can be used to combine partial attention results (in the split-KV case)
ops.def(
"merge_attn_states("
" Tensor! output,"
" Tensor!? output_lse,"
" Tensor prefix_output,"
" Tensor prefix_lse,"
" Tensor suffix_output,"
" Tensor suffix_lse) -> ()");
ops.impl("merge_attn_states", torch::kXPU, &merge_attn_states);
}

TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
16 changes: 16 additions & 0 deletions csrc/utils.h
@@ -66,6 +66,22 @@ struct alignas(sizeof(scalar_t) * vec_size) aligned_vec {
scalar_t const& operator[](int index) const { return val[index]; }
};

// From float to float.
inline void from_float(float& dst, float src) { dst = src; }
// From float32 to float16.
inline void from_float(sycl::half& dst, float src) { dst = sycl::half(src); }
// From float32 to bfloat16.
inline void from_float(sycl::ext::oneapi::bfloat16& dst, float src) {
dst = sycl::ext::oneapi::bfloat16(src);
}

// From float to float.
inline float to_float(float u) { return u; }
// From float16 to float32.
inline float to_float(sycl::half u) { return float(u); }
// From bfloat16 to float32.
inline float to_float(sycl::ext::oneapi::bfloat16 u) { return float(u); }

} // namespace xpu

} // namespace vllm
13 changes: 13 additions & 0 deletions tests/register_ops.py
@@ -76,6 +76,19 @@ def deepseek_scaling_rope(
rotary_dim, is_neox_style)


# merge attn states ops
def merge_attn_states(
output: torch.Tensor,
prefix_output: torch.Tensor,
prefix_lse: torch.Tensor,
suffix_output: torch.Tensor,
suffix_lse: torch.Tensor,
output_lse: torch.Tensor | None = None,
) -> None:
torch.ops._C.merge_attn_states(output, output_lse, prefix_output,
prefix_lse, suffix_output, suffix_lse)


def reshape_and_cache(
key: torch.Tensor,
value: torch.Tensor,
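For context, a usage sketch of the new Python wrapper follows. The import path, shapes, and device string are illustrative assumptions; running it requires an XPU build with the _C extension loaded.

```python
import torch
from tests.register_ops import merge_attn_states  # illustrative import path

n, h, d = 4, 8, 128
device = "xpu"

# Outputs follow the [n, h, d] layout; LSE tensors follow the [h, n] layout.
output = torch.empty(n, h, d, dtype=torch.bfloat16, device=device)
output_lse = torch.empty(h, n, dtype=torch.float32, device=device)

prefix_output = torch.randn(n, h, d, dtype=torch.bfloat16, device=device)
suffix_output = torch.randn(n, h, d, dtype=torch.bfloat16, device=device)
prefix_lse = torch.randn(h, n, dtype=torch.float32, device=device)
suffix_lse = torch.randn(h, n, dtype=torch.float32, device=device)

# Merges the prefix and suffix partial attention results in place.
merge_attn_states(output, prefix_output, prefix_lse,
                  suffix_output, suffix_lse, output_lse)
```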