Torch's SDPA doesn't require V to have the same head dimension as Q and K; the docs even note this by using distinct dimensions E and Ev. This works because, by the time V is multiplied in, the head dimension has already been contracted away and only an L x L attention matrix per head remains.
In [22]: import torch, torch.nn.functional as F, xformers.ops as xops, flash_attn
In [23]: qk = torch.randn(4, 4, 4, 8).bfloat16().cuda()
In [24]: v = torch.randn(4, 4, 4, 16).bfloat16().cuda()
In [25]: F.scaled_dot_product_attention(qk, qk, v).shape
Out[25]: torch.Size([4, 4, 4, 16])
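For reference, a minimal hand-rolled sketch (plain PyTorch, purely illustrative shapes) that makes the shape argument explicit: the head dimension E of Q and K is contracted away in QK^T, leaving an L x L matrix per head, so V only has to agree on the sequence length and is free to carry its own head dimension Ev.

```python
import torch
import torch.nn.functional as F

B, H, L, E, Ev = 4, 4, 4, 8, 16          # Ev deliberately different from E
q = torch.randn(B, H, L, E)
k = torch.randn(B, H, L, E)
v = torch.randn(B, H, L, Ev)

# (B, H, L, L): E is contracted away here, so V's head dim never meets Q/K's
attn = ((q @ k.transpose(-2, -1)) / E**0.5).softmax(dim=-1)

# (B, H, L, Ev): only V's head dimension appears in the output
out = attn @ v
assert out.shape == (B, H, L, Ev)

# SDPA with distinct E and Ev gives the same result
ref = F.scaled_dot_product_attention(q, k, v)
torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)
```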
The same holds for xformers; their docs use K and Kv for the two head dimensions.
In [26]: xops.memory_efficient_attention(qk, qk,v).shape
Out[26]: torch.Size([4, 4, 4, 16])
However, flash-attention 2 [2.4.2] requires the head dimension of V to match that of Q and K.
In [27]: flash_attn.flash_attn_func(qk,qk,v)....
RuntimeError: v must have shape (batch_size, seqlen_k, num_heads_k, head_size_og)
(As documented, it requires all tensors to have the same headdim per head; the error message just uses a different name for it than the documentation.)
Can this be relaxed to allow a different head_size for V, or does the implementation depend on the head dimensions matching?
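In the meantime, a possible workaround (my own sketch, not something the flash-attn docs prescribe): zero-pad the head dimension of Q and K up to V's head dimension and pass the original softmax scale explicitly; the extra zero columns leave QK^T unchanged, so the attention weights are identical. This assumes flash-attn 2.x's flash_attn_func(q, k, v, softmax_scale=...) signature and its (batch_size, seqlen, num_heads, head_size) layout.

```python
import math
import torch
import torch.nn.functional as F
import flash_attn

B, L, H, E, Ev = 4, 4, 4, 8, 16           # illustrative sizes, Ev > E
q = torch.randn(B, L, H, E, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, L, H, E, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, L, H, Ev, dtype=torch.bfloat16, device="cuda")

# Zero-pad Q and K's head dim to Ev; zero columns do not change Q @ K^T.
q_pad = F.pad(q, (0, Ev - E))
k_pad = F.pad(k, (0, Ev - E))

# Keep the scale of the original head dim E; the default would use 1/sqrt(Ev).
out = flash_attn.flash_attn_func(q_pad, k_pad, v, softmax_scale=1.0 / math.sqrt(E))
print(out.shape)  # torch.Size([4, 4, 4, 16])
```

This trades extra compute and memory on the padded head dimension for compatibility, which is why native support for a separate V head size would still be useful.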