
Is it possible to relax the V shape requirement to allow a different head dim than q/k? #753

Closed
@Maykeye

Description

Torch's SDPA doesn't require V to have the same head dimension as the other inputs; the docs even note this by using different dimensions E and Ev, because by the time V is multiplied in, the head dimension is already gone and only an L x L attention matrix remains.

In [23]: qk = torch.randn(4, 4, 4, 8).bfloat16().cuda()  

In [24]: v = torch.randn(4, 4, 4, 16).bfloat16().cuda()

In [25]: F.scaled_dot_product_attention(qk, qk, v).shape
Out[25]: torch.Size([4, 4, 4, 16])

Same with xformers; their docs use K and Kv.

In [26]: xops.memory_efficient_attention(qk, qk, v).shape
Out[26]: torch.Size([4, 4, 4, 16])
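
(For reference, a hand-rolled sketch of the same computation, continuing from the qk and v tensors above and not part of the original issue: the L x L score matrix only involves q/k's head dim E, while v's head dim Ev only sets the width of the output.)

import math
scores = (qk @ qk.transpose(-2, -1)) / math.sqrt(qk.shape[-1])  # (4, 4, 4, 4): L x L per head, Ev never appears
out = scores.softmax(dim=-1) @ v                                # (4, 4, 4, 16): batch x heads x L x Ev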

However, FlashAttention 2 (2.4.2) requires the head dimensions to match.

In [27]: flash_attn.flash_attn_func(qk, qk, v)
...
RuntimeError: v must have shape (batch_size, seqlen_k, num_heads_k, head_size_og)

(As documented, it requires all tensors to have the same headdim; note that the error message uses a different name than the documentation.)

Can this be relaxed to allow a different head_size for V, or does the implementation depend on the head dimensions matching?
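
(A possible stopgap in the meantime, a hedged sketch rather than an official flash-attn recipe, assuming the flash-attn 2.x flash_attn_func(q, k, v, softmax_scale=...) signature: zero-pad q/k/v to a common head dim and pass softmax_scale explicitly so the padding doesn't change the scores, then drop the padded output channels.)

import math
import torch
import torch.nn.functional as F
from flash_attn import flash_attn_func

q = torch.randn(4, 4, 4, 8, dtype=torch.bfloat16, device="cuda")
k = torch.randn(4, 4, 4, 8, dtype=torch.bfloat16, device="cuda")
v = torch.randn(4, 4, 4, 16, dtype=torch.bfloat16, device="cuda")

d = max(q.shape[-1], v.shape[-1])
# Zero columns contribute nothing to q @ k^T, so the attention scores are unchanged.
q_pad = F.pad(q, (0, d - q.shape[-1]))
k_pad = F.pad(k, (0, d - k.shape[-1]))
v_pad = F.pad(v, (0, d - v.shape[-1]))

# Keep the original 1/sqrt(8) scale; flash-attn would otherwise infer 1/sqrt(16) from the padded head dim.
out = flash_attn_func(q_pad, k_pad, v_pad, softmax_scale=1.0 / math.sqrt(q.shape[-1]))
out = out[..., : v.shape[-1]]  # drop padded output channels (a no-op here, since v is already the widest)

This pads compute and memory up to the larger head dim, so it's only a workaround, not a substitute for native support of different q/k and v head dims.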
