
Conversation

@Isotr0py (Member) commented on Dec 15, 2025

Purpose

Refactor the multimodal encoder MultiHeadAttention into a dedicated MMEncoderAttention layer (see the review summary below for details).

Test Plan

pytest -s -v tests/kernels/attention/test_attention.py
pytest -s -v tests/kernels/attention/test_mha_attn.py

Test Result

Tests should pass.


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results.
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify bot added the llama (Related to Llama models), v1, and tpu (Related to Google TPUs) labels on Dec 15, 2025
Signed-off-by: Isotr0py <[email protected]>

@gemini-code-assist bot left a comment

Code Review

This pull request refactors the MultiHeadAttention class used for multimodal encoders into a new MMEncoderAttention class, moving its definition to vllm/attention/layers/mm_encoder_attention.py and removing it from vllm/attention/layer.py. All usages and imports of MultiHeadAttention across the affected model implementations (e.g., AIMV2, BLIP, CLIP, GLM4V, Idefics2, InternViT, MLlama4, MoLMo, SigLIP, Step3-VL, Whisper) and their test files have been updated to use MMEncoderAttention. The MMEncoderAttention class now integrates the Flash Attention backend selection logic directly and drops the redundant reshape_qkv_to_3d method.

However, a review comment points out a critical issue in torch_sdpa_wrapper in vllm/attention/ops/vit_attn_wrappers.py: torch.split is applied on the sequence-length dimension (dim=1), which assumes packed tensors, while the inputs are batched. This causes a dimension mismatch and will lead to errors; the reviewer suggests splitting along the batch dimension (dim=0) or using an alternative approach for handling batched inputs with SDPA in variable-length attention.
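For orientation, here is a minimal sketch of what the model-side update described above could look like. The import path follows the new module named in the summary, while VisionAttentionBlock and the MMEncoderAttention constructor arguments are illustrative assumptions, not the exact code from this PR.

```python
import torch
import torch.nn as nn

# Previously: from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention


class VisionAttentionBlock(nn.Module):
    """Hypothetical vision-encoder block illustrating the rename."""

    def __init__(self, num_heads: int, head_dim: int) -> None:
        super().__init__()
        # Drop-in style replacement: only the class name and import path change;
        # the argument list here is an assumption, not the PR's exact signature.
        self.attn = MMEncoderAttention(num_heads, head_dim, scale=head_dim**-0.5)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        return self.attn(q, k, v)
```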

Comment on lines 103 to 112
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)

batch_size, q_len, _, _ = q.shape
if cu_seqlens is None:
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
    )
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):

critical

The logic in torch_sdpa_wrapper for handling batched inputs appears to be incorrect. The torch.split on dim=1 (lines 104-106) assumes that the input tensors q, k, v are packed (i.e., shape (1, total_tokens, ...)), but they are passed as batched tensors of shape (batch_size, seq_len, ...). This will cause torch.split to fail because sum(lens) will not match q.shape[1].

Additionally, the new block at lines 109-112 for when cu_seqlens is None is also flawed. It computes lens from a cu_seqlens that is constructed for a uniform batch, so torch.split will still fail for the same reason.

To fix this, you should probably split along the batch dimension (dim=0) or use a different approach to handle batched inputs with SDPA for varlen attention.
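
For concreteness, here is a minimal, self-contained sketch of the suggested direction (splitting along the batch dimension instead of dim=1) using plain PyTorch SDPA. This is not the actual torch_sdpa_wrapper implementation; the helper name and tensor layout are assumptions.

```python
import torch
import torch.nn.functional as F


def sdpa_over_batch(
    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float | None = None
) -> torch.Tensor:
    """Hypothetical helper: run SDPA per sample by splitting on the batch dim.

    q, k, v are batched tensors of shape (batch_size, seq_len, num_heads, head_dim).
    """
    outputs = []
    for q_i, k_i, v_i in zip(
        torch.split(q, 1, dim=0),  # split along the batch dimension (dim=0)
        torch.split(k, 1, dim=0),
        torch.split(v, 1, dim=0),
    ):
        # F.scaled_dot_product_attention expects (batch, heads, seq, head_dim).
        out_i = F.scaled_dot_product_attention(
            q_i.transpose(1, 2),
            k_i.transpose(1, 2),
            v_i.transpose(1, 2),
            scale=scale,
        )
        outputs.append(out_i.transpose(1, 2))
    return torch.cat(outputs, dim=0)
```

For a uniform-length batch like this, a single SDPA call over the whole batch would give the same result; the per-sample loop only becomes necessary once sequence lengths differ (e.g., when driven by cu_seqlens).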

Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>