Is your feature request related to a problem? Please describe.
flex_attention supports grouped query attention (GQA), but we get errors when using it, for reasons we do not yet understand. If extend_kv=True, GQA is not used, but key and value then have to be extended, which temporarily takes a significant amount of extra GPU memory.
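To give a rough sense of the temporary memory cost, here is a back-of-envelope sketch. It assumes key and value each have shape (batch, n_kv_heads, seq_len, head_dim) and that extending them means materializing copies with n_query_heads heads alongside the originals. The helper name `kv_bytes` and the shapes below are made up for illustration and are not taken from our code.

```python
def kv_bytes(batch, n_heads, seq_len, head_dim, bytes_per_elem=2):
    """Bytes held by key and value together, each of shape
    (batch, n_heads, seq_len, head_dim). bytes_per_elem=2 assumes
    a 16-bit dtype such as bfloat16."""
    return 2 * batch * n_heads * seq_len * head_dim * bytes_per_elem


# Hypothetical configuration: 32 query heads sharing 8 key/value heads.
batch, n_query_heads, n_kv_heads = 1, 32, 8
seq_len, head_dim = 32768, 128

# What GQA needs: key and value with n_kv_heads heads.
compact = kv_bytes(batch, n_kv_heads, seq_len, head_dim)
# What extend_kv=True would allocate on top: copies with n_query_heads heads.
extended = kv_bytes(batch, n_query_heads, seq_len, head_dim)

mib = 1024 ** 2
print(f"compact key/value:  {compact / mib} MiB")   # 128.0 MiB
print(f"extended key/value: {extended / mib} MiB")  # 512.0 MiB
```

Under these assumptions the extended copies are larger than the originals by a factor of n_query_heads / n_kv_heads, and that cost is paid per attention call.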
Describe the solution you'd like
We'd like to be able to use GQA always, and eliminate extend_kv. The current solution for extend_kv=False is pretty weird, and potentially brittle (see comments).
We'd like to get to the bottom of this. Why does GQA not just always work? Can the issue be reproduced with a simple example?