[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface
#30684
base: main
Conversation
Signed-off-by: Isotr0py <[email protected]>
Code Review
This pull request refactors the MultiHeadAttention class used by multimodal encoders into a new MMEncoderAttention class, moving its definition to vllm/attention/layers/mm_encoder_attention.py and removing it from vllm/attention/layer.py. All usages and imports of MultiHeadAttention across the affected model implementations (e.g., AIMV2, BLIP, CLIP, GLM4V, Idefics2, InternViT, MLlama4, Molmo, SigLIP, Step3-VL, Whisper) and their test files have been updated to MMEncoderAttention. The new class integrates the Flash Attention backend selection logic directly and removes the redundant reshape_qkv_to_3d method.
However, a review comment flags a critical issue in torch_sdpa_wrapper in vllm/attention/ops/vit_attn_wrappers.py: torch.split is applied on the sequence-length dimension (dim=1) as if the inputs were packed tensors, while they are actually passed as batched tensors. This causes a dimension mismatch and will raise an error; the reviewer suggests splitting along the batch dimension (dim=0) or using an alternative approach for handling batched inputs with SDPA in variable-length attention.
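To make the scope of the migration concrete, here is a minimal sketch of how a vision-encoder attention block would adopt the new class. The import path follows the PR description above; the class name VisionBlockAttention, the constructor arguments, and the forward signature are assumptions for illustration, not the exact vLLM API.

```python
import torch
import torch.nn as nn

# New location per this PR; previously MultiHeadAttention lived in
# vllm/attention/layer.py.
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention


class VisionBlockAttention(nn.Module):
    """Hypothetical ViT attention block using the new encoder attention."""

    def __init__(self, hidden_size: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.qkv = nn.Linear(hidden_size, 3 * hidden_size)
        self.proj = nn.Linear(hidden_size, hidden_size)
        # Assumed constructor arguments; the legacy MultiHeadAttention took a
        # similar (num_heads, head_size, scale) shape, and the new class is
        # described as selecting the attention backend (e.g. Flash Attention)
        # internally.
        self.attn = MMEncoderAttention(self.num_heads, self.head_dim,
                                       self.head_dim**-0.5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        out = self.attn(q, k, v)
        return self.proj(out)
```

Under these assumptions, model files that previously imported MultiHeadAttention from vllm.attention.layer would only need to swap the import and the class name in this pattern.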
q_chunks = torch.split(q, lens, dim=1)
k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1)

batch_size, q_len, _, _ = q.shape
if cu_seqlens is None:
    cu_seqlens = torch.arange(
        0, (batch_size + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device
    )
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
The logic in torch_sdpa_wrapper for handling batched inputs appears to be incorrect. The torch.split on dim=1 (lines 104-106) assumes the input tensors q, k, v are packed (i.e., shape (1, total_tokens, ...)), but they are passed as batched tensors of shape (batch_size, seq_len, ...). torch.split will therefore fail because sum(lens) does not match q.shape[1].
Additionally, the new block at lines 109-112 for the case where cu_seqlens is None is also flawed: it derives lens from a cu_seqlens constructed for a uniform batch, so torch.split still fails for the same reason.
To fix this, you should probably split along the batch dimension (dim=0) or use a different approach to handle batched inputs with SDPA for varlen attention.
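As a rough illustration of the dim=0 suggestion, the following is a minimal sketch (not the actual wrapper code) that runs SDPA per sample by iterating over the batch dimension. The (batch_size, seq_len, num_heads, head_dim) layout, the right-padded input assumption, and the function name are all assumptions for this sketch.

```python
import torch
import torch.nn.functional as F


def sdpa_varlen_batched(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        seq_lens: list[int]) -> torch.Tensor:
    # Assumed layout: q, k, v are (batch_size, seq_len, num_heads, head_dim),
    # right-padded so that sample i only has seq_lens[i] valid tokens.
    outputs = []
    for i, n in enumerate(seq_lens):
        # Slice out the valid tokens of one sample and move heads before the
        # sequence dim, since F.scaled_dot_product_attention expects
        # (batch, heads, seq, head_dim).
        q_i = q[i, :n].transpose(0, 1).unsqueeze(0)
        k_i = k[i, :n].transpose(0, 1).unsqueeze(0)
        v_i = v[i, :n].transpose(0, 1).unsqueeze(0)
        out_i = F.scaled_dot_product_attention(q_i, k_i, v_i)
        # Restore (seq, heads, head_dim) and re-pad to the full sequence length.
        out_i = out_i.squeeze(0).transpose(0, 1)
        pad = q.shape[1] - n
        if pad:
            out_i = F.pad(out_i, (0, 0, 0, 0, 0, pad))
        outputs.append(out_i)
    return torch.stack(outputs, dim=0)
```

A per-sample loop avoids the shape mismatch entirely, at the cost of one SDPA call per image; the other option is to pack the valid tokens into a single (1, total_tokens, ...) tensor before splitting, which would keep the existing dim=1 split valid.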
Signed-off-by: Isotr0py <[email protected]>
Purpose
Migrate MultiHeadAttention usage to new MMEncoderAttention.
Test Plan
Test Result
Tests should pass.
Essential Elements of an Effective PR Description Checklist
supported_models.md and examples for a new model.