[MLA] add merge_attn_states sycl kernel
#64
base: main
Conversation
Signed-off-by: Kunshang Ji <[email protected]>
Force-pushed from 26b5133 to 4edf4ed
Pull Request Overview
This PR adds a SYCL kernel implementation for the merge_attn_states operation, which is used to combine partial attention results during the MLA (Multi-Head Latent Attention) chunked prefill stage. The implementation follows section 2.2 of the referenced paper (https://www.arxiv.org/pdf/2501.01005).
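For context, merging two partial attention results reduces to a numerically stable log-sum-exp recombination. Below is a minimal PyTorch sketch of that math under the shapes given in the kernel docstring ([n, h, d] outputs, [h, n] LSE tensors); the function name and argument order are illustrative assumptions, not the PR's actual API.

```python
import torch

def merge_attn_states_sketch(prefix_output: torch.Tensor,  # [n, h, d]
                             prefix_lse: torch.Tensor,     # [h, n]
                             suffix_output: torch.Tensor,  # [n, h, d]
                             suffix_lse: torch.Tensor):    # [h, n]
    # Stabilize by subtracting the elementwise max before exponentiating.
    max_lse = torch.maximum(prefix_lse, suffix_lse)        # [h, n]
    p_se = torch.exp(prefix_lse - max_lse)
    s_se = torch.exp(suffix_lse - max_lse)
    out_se = p_se + s_se                                   # [h, n]

    # Per-(head, token) weights, broadcast over the head_size dimension d.
    p_w = (p_se / out_se).transpose(0, 1).unsqueeze(-1)    # [n, h, 1]
    s_w = (s_se / out_se).transpose(0, 1).unsqueeze(-1)    # [n, h, 1]
    output = prefix_output * p_w + suffix_output * s_w     # [n, h, d]

    # Merged log-sum-exp, reusable when folding in further chunks.
    output_lse = torch.log(out_se) + max_lse               # [h, n]
    return output, output_lse
```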
Key Changes:
- Implements the merge_attn_states SYCL kernel with FP32, FP16, and BF16 support
- Adds comprehensive test coverage with performance benchmarking
- Integrates the kernel into the build system and torch bindings
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| csrc/attention/merge_attn_states.cpp | Core SYCL kernel implementation for merging attention states with 128-bit packed operations |
| csrc/ops.h | Function declaration for merge_attn_states |
| csrc/torch_bindings.cpp | PyTorch C++ extension binding registration |
| csrc/utils.h | Helper functions for type conversion between float and half/bfloat16 |
| tests/test_merge_attn_states.py | Comprehensive test suite with PyTorch reference implementation and performance comparison |
| tests/register_ops.py | Python wrapper for the SYCL kernel operation |
| CMakeLists.txt | Adds new source file to build configuration |
* into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
*
* @param output [n,h,d] The output tensor to store the merged attention states.
* @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
Copilot AI (Nov 10, 2025)
The dimension description for output_lse is incorrect. According to the code (line 90) and test file (line 141), output_lse should be [h,n] not [h,d].
Suggested change:
- * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
+ * @param output_lse [h,n] Optional tensor to store the log-sum-exp values.
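For reference, the shapes implied by the docstring and this comment: attention outputs are token-major, LSE tensors are head-major. An illustrative snippet (the sizes are made up):

```python
import torch

n, h, d = 1024, 32, 128          # NUM_TOKENS, NUM_HEADS, HEAD_SIZE
output = torch.empty(n, h, d)    # merged attention states: [n, h, d]
output_lse = torch.empty(h, n)   # merged log-sum-exp: [h, n], not [h, d]
```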
logger = logging.getLogger("vllm_xpu_kernel")

# Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
Copilot AI (Nov 10, 2025)
Corrected spelling of 'Implements' to 'Implementation of'.
Suggested change:
- # Naive PyTorch Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+ # Naive PyTorch Implementation of section 2.2 of https://www.arxiv.org/pdf/2501.01005
all_case_info: list[tuple] = []

#override pytest parameters when enable mini pytest
Copilot AI (Nov 10, 2025)
Missing article in comment. Should be 'when enabling mini pytest' or 'when mini pytest is enabled'.
Suggested change:
- #override pytest parameters when enable mini pytest
+ # override pytest parameters when enabling mini pytest
if output_lse is not None:
    output_lse = torch.log(out_se) + max_lse
Copilot AI (Nov 10, 2025)
The output_lse parameter is reassigned but never returned or used. This local assignment has no effect on the caller. The corrected value should be stored before returning it at line 49.
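A minimal sketch of the pattern this comment asks for, assuming the reference helper ends like the quoted snippet (variable names are modeled on it, not copied from the PR): the merged LSE has to be propagated to the caller rather than only rebound locally.

```python
import torch

# Illustrative tail of the reference implementation (a sketch, not the PR's code).
def _merge_tail(output, out_se, max_lse, output_lse):
    if output_lse is not None:
        # Rebinding the local name alone is invisible to the caller,
        # so the merged LSE must also flow into the return value.
        output_lse = torch.log(out_se) + max_lse
    return output, output_lse
```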
Purpose
This PR adds the merge_attn_states kernel, which will be used in the MLA chunked prefill stage.

Test Plan
UT&CI
Test Result
pass
(Optional) Documentation Update