
Conversation

@ader47 ader47 commented Dec 14, 2025

What this PR does / why we need it?

  • Replace torch.where() with the in-place masked_fill_() (see the sketch after this list)
  • Replace nested PCP/DCP Python loops with fully vectorized tensor operations
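A minimal sketch of the first change; the tensor names and shapes here are illustrative only, not the actual MLA v1 code (in the real code the mask comes from the decode metadata):

```python
import torch

# Illustrative shapes only.
softmax_lse = torch.randn(4, 8, 1)
lse_mask = torch.rand(4, 1, 1) > 0.5

# Before: torch.where allocates a fresh tensor for the result.
reference = torch.where(lse_mask, -torch.inf, softmax_lse)

# After: masked_fill_ writes -inf into the existing tensor in place,
# avoiding the extra allocation.
softmax_lse.masked_fill_(lse_mask, -torch.inf)

assert torch.equal(softmax_lse, reference)
```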

Does this PR introduce any user-facing change?

How was this patch tested?

@github-actions

👋 Hi! Thank you for contributing to the vLLM Ascend project. The following points will speed up your PR merge:

  • A PR should do only one thing; smaller PRs enable faster reviews.
  • Every PR should include unit tests and end-to-end tests to ensure it works and is not broken by future PRs.
  • Write the commit message by filling out the PR description to help reviewers and future developers understand.

If CI fails, you can run linting and testing checks locally according to Contributing and Testing.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant performance optimizations to the MLA v1 attention mechanism for Ascend NPUs. The changes replace torch.where with the in-place masked_fill_ operation and vectorize the PCP/DCP logic to eliminate Python loops and list manipulations. The refactoring in _npu_attention_update and _process_attn_out_lse correctly uses vectorized tensor operations and should yield a noticeable performance improvement. The logic appears sound and well aligned with the goal of optimizing performance; overall, this is a solid improvement.
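To make the vectorization concrete, here is a generic sketch of merging per-rank partial attention outputs and log-sum-exp values in one batched step instead of looping over ranks in Python. The function name and tensor layout are assumptions for illustration, not the actual signatures of _npu_attention_update or _process_attn_out_lse:

```python
import torch

def merge_partial_attention(attn_outs: torch.Tensor,
                            lses: torch.Tensor) -> torch.Tensor:
    """Combine partial attention results from N context-parallel ranks.

    attn_outs: [N, B, H, D]  per-rank partial attention outputs
    lses:      [N, B, H]     per-rank log-sum-exp of the softmax denominator
    returns:   [B, H, D]     merged attention output
    """
    # One batched log-sum-exp over the rank dimension replaces a Python loop.
    global_lse = torch.logsumexp(lses, dim=0)                 # [B, H]
    # Per-rank rescaling weights; ranks masked to -inf contribute zero.
    weights = torch.exp(lses - global_lse).unsqueeze(-1)      # [N, B, H, 1]
    weights = torch.nan_to_num(weights, nan=0.0)
    return (weights * attn_outs).sum(dim=0)                   # [B, H, D]
```

Every step operates on the full [N, B, H, D] stack at once, which is the kind of loop elimination the review describes.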

@ader47 ader47 force-pushed the optimize-mla-cp branch 4 times, most recently from 75b57e3 to 81e1b59 on December 15, 2025 at 01:20
@github-actions

This pull request has conflicts; please resolve them before we can evaluate the pull request.

```python
softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
) -> torch.Tensor:
out_lse_mask = decode_meta.batch_seq_mask[:, None, None].bool()
attn_output.masked_fill_(out_lse_mask, 0)
```
Collaborator


Actually, masked_fill is implemented via torch.where, so I prefer to keep using torch.where. Please refer to https://github.com/pytorch/pytorch/blob/main/torch/_refs/__init__.py#L5935
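For reference, the decomposition linked above expresses masked_fill as (roughly) a where over the mask, so both spellings in the diff produce identical values; the practical difference is whether a new output tensor is allocated on the eager path. A quick check with made-up shapes:

```python
import torch

def ref_masked_fill(a: torch.Tensor, mask: torch.Tensor, value: float) -> torch.Tensor:
    # Roughly how the linked reference decomposition rewrites masked_fill.
    fill = torch.as_tensor(value, dtype=a.dtype, device=a.device)
    return torch.where(mask, fill, a)

attn_output = torch.randn(2, 3, 4)
out_lse_mask = torch.rand(2, 1, 1) > 0.5

a = ref_masked_fill(attn_output, out_lse_mask, 0.0)
b = attn_output.clone().masked_fill_(out_lse_mask, 0)
assert torch.equal(a, b)
```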
