
Clean up flex_attention integration, extend_kv=False case #34

@mseeger

Description


Is your feature request related to a problem? Please describe.

flex_attention supports grouped-query attention (GQA), but for reasons we do not yet understand, we get errors when using it. With extend_kv=True, GQA is not used. Instead, key and value have to be extended to the full number of query heads, which temporarily takes considerable extra GPU memory.
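To put a number on the extra memory, here is a back-of-the-envelope sketch. It assumes that with extend_kv=True, key and value of shape (batch, n_query_groups, seq_len, head_size) are materialized as new tensors with n_head heads (for example via repeat_interleave). The configuration (32 query heads sharing 8 KV heads, 32k tokens, head size 128, bf16) and the names n_head and n_query_groups are illustrative assumptions, not taken from this issue.

```python
def kv_bytes(batch, n_kv_heads, seq_len, head_size, bytes_per_elem=2):
    """Bytes of one key tensor plus one value tensor, each of shape
    (batch, n_kv_heads, seq_len, head_size)."""
    return 2 * batch * n_kv_heads * seq_len * head_size * bytes_per_elem


# Hypothetical config: 32 query heads sharing 8 KV heads, bf16 (2 bytes/element).
batch, n_head, n_query_groups, seq_len, head_size = 1, 32, 8, 32768, 128

# Key/value as GQA would read them (no expansion).
grouped = kv_bytes(batch, n_query_groups, seq_len, head_size)
# Temporary expanded copies, assuming extend_kv=True materializes n_head heads.
extended = kv_bytes(batch, n_head, seq_len, head_size)

print(f"grouped:  {grouped / 2**20:.0f} MiB")   # grouped:  128 MiB
print(f"extended: {extended / 2**20:.0f} MiB")  # extended: 512 MiB
```

The overhead grows with n_head / n_query_groups and linearly with sequence length, which is why it matters most for long contexts.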

Describe the solution you'd like
We'd like to always use GQA and remove the extend_kv option. The current workaround for extend_kv=False is unusual and potentially brittle (see comments).

We'd like to get to the bottom of this: why does GQA not always work? Can the issue be reproduced with a minimal example?
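One way to start on a minimal reproduction is to pin down what GQA should compute. The sketch below is a pure-Python reference (no PyTorch, no causal mask, made-up sizes). It assumes the usual head mapping in which query head h reads KV head h // (n_head // n_kv), which matches expanding key and value with repeat_interleave along the head dimension. The helper names (gqa, mha, expand_kv_heads) are hypothetical. In an actual PyTorch repro, the analogous check would compare flex_attention(q, k, v, enable_gqa=True) against flex_attention on the expanded key and value (assuming PyTorch 2.5 or later, where flex_attention accepts enable_gqa), noting which shapes or settings raise the errors.

```python
import math
import random


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def attend(q, keys, values):
    """Scaled dot-product attention for a single query vector (no causal mask)."""
    scale = 1.0 / math.sqrt(len(q))
    weights = softmax([scale * sum(a * b for a, b in zip(q, k)) for k in keys])
    return [sum(w * v[d] for w, v in zip(weights, values)) for d in range(len(values[0]))]


def mha(queries, keys, values):
    """queries, keys, values: [n_head][seq_len][head_size]; head h uses KV head h."""
    return [[attend(q, keys[h], values[h]) for q in queries[h]] for h in range(len(queries))]


def gqa(queries, keys, values):
    """keys, values have n_kv heads; query head h reads KV head h // group (no copies)."""
    n_head, n_kv = len(queries), len(keys)
    assert n_head % n_kv == 0, "n_head must be a multiple of the number of KV heads"
    group = n_head // n_kv
    return [[attend(q, keys[h // group], values[h // group]) for q in queries[h]]
            for h in range(n_head)]


def expand_kv_heads(kv, n_head):
    """Mimic repeat_interleave on the head dim: [k0, k1] -> [k0, k0, k1, k1] for n_head=4.
    Here list entries share references; in PyTorch, repeat_interleave allocates new memory."""
    group = n_head // len(kv)
    return [kv[h // group] for h in range(n_head)]


def rand_tensor(rng, heads, seq_len, head_size):
    return [[[rng.gauss(0.0, 1.0) for _ in range(head_size)] for _ in range(seq_len)]
            for _ in range(heads)]


rng = random.Random(0)
n_head, n_kv, seq_len, head_size = 4, 2, 3, 8
q = rand_tensor(rng, n_head, seq_len, head_size)
k = rand_tensor(rng, n_kv, seq_len, head_size)
v = rand_tensor(rng, n_kv, seq_len, head_size)

out_gqa = gqa(q, k, v)
out_ext = mha(q, expand_kv_heads(k, n_head), expand_kv_heads(v, n_head))
max_diff = max(abs(a - b)
               for ha, hb in zip(out_gqa, out_ext)
               for ra, rb in zip(ha, hb)
               for a, b in zip(ra, rb))
print("max abs difference:", max_diff)  # prints 0.0: both paths do identical arithmetic
```

If a PyTorch version of this check shows the expanded-KV flex_attention call working while the enable_gqa=True call fails or disagrees, that would narrow the problem to the GQA code path rather than the head mapping itself.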

Labels: enhancement (New feature or request)