[Feature] Add SFA MLA prolog v3 path #10294
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

This pull request has conflicts, please resolve those before we can evaluate the pull request.
Summary of Changes
Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates a new SFA MLA prolog v3 path, designed to improve performance for models that use W8A8 dynamic weights while keeping the KV cache and query non-quantized. The changes add a new configuration option to enable this feature, alongside the backend logic needed to process weights and inputs according to the specified quantization modes.
Code Review
Suggested PR Title:
[Attention][Feature] Support SFA preprocessing with mla_prolog_v3
Suggested PR Summary:
### What this PR does / why we need it?
This PR introduces support for SFA preprocessing using `mla_prolog_v3` in the Ascend backend. It adds the configuration option `enable_sfa_prolog_v3` (controlled via the `VLLM_ASCEND_ENABLE_SFA_PROLOG_V3` environment variable) and implements the corresponding weight processing, input formatting, and preprocessing steps in the SFA attention implementation.
Feedback:
An issue was identified in `vllm_ascend/attention/sfa_v1.py` where the private API `torch_npu._npu_reshape_and_cache` is called directly. It is recommended to use the `DeviceOperator.reshape_and_cache` abstraction instead and slice the inputs to `attn_metadata.num_actual_tokens` to ensure compatibility and correct shape alignment.
### Does this PR introduce _any_ user-facing change?
Yes, it introduces a new environment variable `VLLM_ASCEND_ENABLE_SFA_PROLOG_V3` and a configuration option `enable_sfa_prolog_v3` to enable SFA preprocessing with `mla_prolog_v3`.
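
As a rough illustration of how a config option and an environment variable might interact, here is a minimal sketch. The helper name `resolve_enable_sfa_prolog_v3`, the dict-based config, the accepted truthy strings, and the precedence (explicit config value first, environment variable as fallback, `False` otherwise) are all assumptions made for this example; the actual precedence in `vllm_ascend` may differ.

```python
import os

ENV_NAME = "VLLM_ASCEND_ENABLE_SFA_PROLOG_V3"


def resolve_enable_sfa_prolog_v3(additional_config, environ=None):
    """Hypothetical resolver, not the project's real implementation.

    Assumed precedence: an explicit `enable_sfa_prolog_v3` entry in the
    config dict wins; otherwise the environment variable is consulted;
    otherwise the feature stays disabled.
    """
    environ = os.environ if environ is None else environ
    if "enable_sfa_prolog_v3" in additional_config:
        return bool(additional_config["enable_sfa_prolog_v3"])
    # Assumed truthy spellings for the environment variable.
    return environ.get(ENV_NAME, "0").strip().lower() in ("1", "true")


# Fallback: nothing set anywhere, so the feature is off.
default_value = resolve_enable_sfa_prolog_v3({}, {})
# Fallback to the environment variable when the config is silent.
from_env = resolve_enable_sfa_prolog_v3({}, {ENV_NAME: "1"})
# Override: an explicit config value beats the environment variable.
overridden = resolve_enable_sfa_prolog_v3(
    {"enable_sfa_prolog_v3": False}, {ENV_NAME: "1"}
)
```

Passing `environ` explicitly keeps the sketch testable without mutating the real process environment.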
### How was this patch tested?
The patch was tested with unit tests in `tests/ut/test_ascend_config.py` verifying the configuration fallback and override behavior.

```python
torch_npu._npu_reshape_and_cache(
    key=k_nope,
    value=k_pe,
    key_cache=kv_cache[0],
    value_cache=kv_cache[1],
    slot_indices=slot_mapping,
)
```
Calling the private/internal API `torch_npu._npu_reshape_and_cache` directly is discouraged as it can lead to compatibility issues across different PyTorch/CANN versions. Additionally, passing un-sliced `k_nope`, `k_pe`, and `slot_mapping` can cause shape mismatches or incorrect cache writing if there is padding or if the gathered tensor size does not align with the un-sliced slot mapping.

Please use the established `DeviceOperator.reshape_and_cache` abstraction instead, and slice the inputs to `attn_metadata.num_actual_tokens` to ensure shape alignment and correctness, matching the pattern used in `_all_gather_and_cache_dsa_cp_kv`.
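
To make the padding concern concrete, here is a toy sketch in pure Python. `reshape_and_cache_toy` is a hypothetical stand-in that only mimics the "write row i into slot `slot_mapping[i]`" idea; it is not the `torch_npu` kernel or the `DeviceOperator` method. The padding slot value of `0` is also an assumption chosen for illustration.

```python
def reshape_and_cache_toy(key, value, key_cache, value_cache, slot_mapping):
    """Toy stand-in: write row i of key/value into cache slot slot_mapping[i]."""
    if not (len(key) == len(value) == len(slot_mapping)):
        raise ValueError("key, value and slot_mapping must have the same length")
    for row_k, row_v, slot in zip(key, value, slot_mapping):
        key_cache[slot] = row_k
        value_cache[slot] = row_v


def fresh_cache(num_slots):
    """Return a cache whose slot 0 already holds data from an earlier request."""
    cache = [None] * num_slots
    cache[0] = [99.0]
    return cache


num_actual_tokens = 2
k_nope = [[1.0], [2.0], [0.0], [0.0]]    # last two rows are padding
k_pe = [[10.0], [20.0], [0.0], [0.0]]    # last two rows are padding
slot_mapping = [5, 6, 0, 0]              # assumed padding slot value: 0

# Un-sliced call: the padding rows are written too and clobber slot 0.
kc_bad, vc_bad = fresh_cache(8), fresh_cache(8)
reshape_and_cache_toy(k_nope, k_pe, kc_bad, vc_bad, slot_mapping)

# Sliced call: only the real tokens are written, slot 0 is left alone.
n = num_actual_tokens
kc_ok, vc_ok = fresh_cache(8), fresh_cache(8)
reshape_and_cache_toy(k_nope[:n], k_pe[:n], kc_ok, vc_ok, slot_mapping[:n])
```

In the un-sliced case the existing entry in slot 0 is replaced by a padding row, while the sliced call leaves it untouched.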
```python
DeviceOperator.reshape_and_cache(
    key=k_nope[: attn_metadata.num_actual_tokens],
    value=k_pe[: attn_metadata.num_actual_tokens],
    key_cache=kv_cache[0],
    value_cache=kv_cache[1],
    slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens],
)
```

bf002e6 to 4913778
4913778 to 03d36db
Signed-off-by: ZYang6263 <zy626375@gmail.com>
6228f48 to 25b7e53
Summary
Adds an SFA-specific `mla_prolog_v3` path and extends it to the packed int8 KV-cache case consumed by `torch_npu.npu_kv_quant_sparse_flash_attention`.

This path is intentionally independent from `enable_fa_quant` and the existing `mla_v1.py` flow. It is enabled with:

Changes
- `main` (`ab065ffb`) and resolved the `sfa_v1.py` conflict.
- `enable_sfa_prolog_v3` / `VLLM_ASCEND_ENABLE_SFA_PROLOG_V3` config plumbing.
- `enable_sfa_kv_quant_sparse_attention` / `VLLM_ASCEND_ENABLE_SFA_KV_QUANT_SPARSE_ATTENTION` config plumbing.
- `mla_prolog_v3` weight handles from W8A8 dynamic loaded weights in `sfa_v1.py`.
- `kv_cache_quant_mode=0`.
- `kv_cache_quant_mode=3`, `ckvkr_repo_mode=1`, `quant_scale_repo_mode=1`, and `tile_size=128`.
- `kv_cache[0]` is packed int8 cache, `kv_cache[1]` is an empty `kr_cache`, and DSA indexer cache remains in later tuple entries.
- `torch_npu.npu_kv_quant_sparse_flash_attention` when the packed int8 cache path is enabled.
- `docs/source/developer_guide/Design_Documents/sfa_mla_prolog_v3_kv_quant.md`.

Validation
- `git diff --check`
- `python`, `py`, or `uv` on PATH.

Notes
The packed cache dimension is computed as:

For the common MLA shape this is `512 + 64 * 2 + 4 * 4 = 656`.
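
The arithmetic in this note can be reproduced with a short sketch. The reading of each term is an assumption, not something stated in the PR: 512 int8 bytes for the latent KV part, 64 rope elements stored at 2 bytes each, and one 4-byte scale per tile of 128 latent elements (512 / 128 = 4 tiles, matching `tile_size=128`). The function and parameter names are hypothetical.

```python
def packed_cache_dim(kv_lora_rank=512, rope_dim=64, tile_size=128,
                     rope_bytes=2, scale_bytes=4):
    """Hypothetical helper reproducing 512 + 64 * 2 + 4 * 4 = 656.

    Assumed layout (not confirmed by the PR):
      - kv_lora_rank int8 values, 1 byte each
      - rope_dim values at rope_bytes bytes each
      - one scale of scale_bytes bytes per tile of tile_size latent values
    """
    if kv_lora_rank % tile_size != 0:
        raise ValueError("kv_lora_rank must be a multiple of tile_size")
    num_tiles = kv_lora_rank // tile_size
    return kv_lora_rank + rope_dim * rope_bytes + num_tiles * scale_bytes


common_mla_dim = packed_cache_dim()
```

With the default arguments this evaluates to 656, the value quoted for the common MLA shape.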