
model: use F.scaled_dot_product_attention for ViT-family attention #548

Open

rebel-junamsong wants to merge 1 commit into main from model/vit-family-use-sdpa

Conversation

Contributor

@rebel-junamsong rebel-junamsong commented Apr 17, 2026

Summary

Replace manual matmul → softmax → matmul attention in ViT-family vision modules with F.scaled_dot_product_attention. This routes the op through rbln-compiler's dedicated SDPA custom-op path (mapped to __paged_normal_*_sdpa_{5d,6d}_ kernels) instead of the generic relay fallback, improving numerical stability under compDtype=bfloat.
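For context, the per-module change is roughly the following (a minimal PyTorch sketch with made-up shapes; the actual modules carry their own projections, reshapes, and masks):

```python
import torch
import torch.nn.functional as F

# Hypothetical 4D shapes: [batch, heads, seq_len, head_dim]
q = torch.randn(1, 16, 256, 64)
k = torch.randn(1, 16, 256, 64)
v = torch.randn(1, 16, 256, 64)

# Before: manual attention, lowered through the generic relay fallback.
scale = q.shape[-1] ** -0.5
attn = (q @ k.transpose(-2, -1)) * scale
out_manual = attn.softmax(dim=-1) @ v

# After: one fused op, lowered through the SDPA custom-op path.
out_sdpa = F.scaled_dot_product_attention(q, k, v)

# Numerically equivalent in eager mode (SDPA defaults to 1/sqrt(head_dim) scaling).
assert torch.allclose(out_manual, out_sdpa, atol=1e-5)
```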

Affected models

  • qwen2_vl: VisionAttention
  • qwen2_5_vl: Qwen2_5_VLVisionFullAttention, Qwen2_5_VLVisionWindowAttention
  • pixtral: PixtralAttention

Qwen3VLVisionAttention already uses SDPA. GroundingDINO variants use non-standard attention (dual-softmax / grid_sample) and are not applicable.

4D vs 5D tensor rank

Each model keeps its original Q/K/V rank — no forced unification. rbln-compiler dispatches to different kernels based on rank:

| Input rank | Internal layout | Kernel |
| --- | --- | --- |
| 4D [B, H, L, D] | NHCW64c | __paged_normal_*_sdpa_5d_ |
| 5D [B, H, 1, L, D] | NDHCW64c | __paged_normal_*_sdpa_6d_ |

The SDPA converter path in rebel.core.custom_converter._pt_attention has no rank assertion, so both 4D and 5D work.
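As a sanity check in eager PyTorch, both ranks go through the same call (F.scaled_dot_product_attention treats all leading dimensions as batch dimensions); the kernel names in the comments just restate the table above, and the actual dispatch happens inside rbln-compiler at compile time:

```python
import torch
import torch.nn.functional as F

B, H, L, D = 1, 16, 256, 64

# 4D inputs (Qwen2-VL, Qwen2.5-VL): lowered to __paged_normal_*_sdpa_5d_
x4 = torch.randn(B, H, L, D)
out4 = F.scaled_dot_product_attention(x4, x4, x4)
assert out4.shape == (B, H, L, D)

# 5D inputs (Pixtral): lowered to __paged_normal_*_sdpa_6d_
x5 = torch.randn(B, H, 1, L, D)
out5 = F.scaled_dot_product_attention(x5, x5, x5)
assert out5.shape == (B, H, 1, L, D)
```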

| Model | Rank | Kernel |
| --- | --- | --- |
| Qwen2-VL | 4D | _5d_ |
| Qwen2.5-VL Full / Window | 4D | _5d_ |
| Pixtral | 5D | _6d_ |
| Qwen3-VL (unchanged) | 4D | _5d_ |

Test plan

  • Qwen2-VL / ColQwen2 retrieval — compile + accuracy
  • Qwen2.5-VL image/video — compile + correlation vs native (see the sketch after this list)
  • Pixtral vision encoder — compile + accuracy
  • Qwen3-VL regression (shares compiler path)
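For the correlation checks, a helper along these lines is what is meant (a hypothetical sketch, not the actual test harness):

```python
import torch

def output_correlation(native: torch.Tensor, compiled: torch.Tensor) -> float:
    # Pearson correlation between flattened outputs; values close to 1.0
    # mean the SDPA path tracks the native attention path under bfloat.
    a = native.flatten().float() - native.float().mean()
    b = compiled.flatten().float() - compiled.float().mean()
    return float(a @ b / (a.norm() * b.norm()))
```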

Replace manual matmul+softmax+matmul in ViT-family vision attention
with F.scaled_dot_product_attention so the rbln-compiler lowers it
through the dedicated SDPA custom op path (mapped to the
__paged_normal_*_sdpa_{5d,6d}_ kernels) instead of the generic
relay fallback.

Affected models: Qwen2-VL, Qwen2.5-VL (full and window), Pixtral.

Co-Authored-By: Claude Opus 4.6 (1M context) <[email protected]>
