[Perf] Optimize torch.where and vectorize PCP/DCP loops in mla_v1.py #5003
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:
- If CI fails, you can run linting and testing checks locally according to Contributing and Testing.
Code Review
This pull request introduces significant performance optimizations to the MLA v1 attention mechanism for Ascend NPUs. The changes primarily focus on replacing torch.where with the more efficient in-place masked_fill_ operation and vectorizing the PCP/DCP logic to eliminate Python loops and list manipulations. The refactoring in _npu_attention_update and _process_attn_out_lse correctly uses vectorized tensor operations, which should result in a noticeable performance improvement. The logic appears sound, and the changes are well-aligned with the goal of optimizing performance. Overall, this is a solid improvement.
Force-pushed from 75b57e3 to 81e1b59
This pull request has conflicts; please resolve them before we can evaluate the pull request.
vllm_ascend/attention/mla_v1.py (Outdated)
```python
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
) -> torch.Tensor:
out_lse_mask = decode_meta.batch_seq_mask[:, None, None].bool()
attn_output.masked_fill_(out_lse_mask, 0)
```
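As a sanity check on the substitution shown in the diff, a small sketch (tensor shapes and mask values are made up) confirming that the in-place `masked_fill_` produces the same values as the original `torch.where` call:

```python
import torch

batch_seq_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
softmax_lse = torch.randn(4, 8, 1)
attn_output = torch.randn(4, 8, 16)

lse_mask = batch_seq_mask[:, None, None]  # broadcasts over trailing dims

# Original: allocates a fresh output tensor.
expected = torch.where(lse_mask, -torch.inf, softmax_lse)

# Replacement: fills the masked positions in place, no new allocation.
softmax_lse.masked_fill_(lse_mask, -torch.inf)
attn_output.masked_fill_(lse_mask, 0)

assert torch.equal(softmax_lse, expected)
```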
Actually, masked_fill is implemented with torch.where; I prefer to keep using torch.where. Please refer to https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L5935
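Whether the in-place form actually wins in eager mode is easy to measure (the `_refs` file linked above holds PyTorch's reference decompositions; eager kernels may be implemented separately, so measuring is worthwhile). A micro-benchmark sketch with `torch.utils.benchmark`; shapes are arbitrary, and results will differ on NPU versus CPU/GPU:

```python
import torch
import torch.utils.benchmark as benchmark

x = torch.randn(64, 128, 512)
mask = torch.rand(64, 1, 1) > 0.5

# Out-of-place: allocates a new output tensor each call.
t_where = benchmark.Timer(
    stmt="torch.where(mask, -torch.inf, x)",
    globals={"torch": torch, "mask": mask, "x": x},
)
# In-place: overwrites masked positions in the existing tensor.
t_fill = benchmark.Timer(
    stmt="x.masked_fill_(mask, -torch.inf)",
    globals={"mask": mask, "x": x},
)

print(t_where.timeit(1000))
print(t_fill.timeit(1000))
```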
Force-pushed from 81e1b59 to d43cabc
Signed-off-by: F.Liu <[email protected]>
What this PR does / why we need it?
Does this PR introduce any user-facing change?
How was this patch tested?