
[Feature Request] Shape inference for GroupQueryAttention Op #23189

@peishenyan

Description


Describe the feature request

For the WebNN EP, the graph builder does not accept inputs or outputs with dynamic shapes, so after FreeDimensionOverride all shapes/dims are expected to be static.
A shape inference function for the GroupQueryAttention op already exists: BaseGroupQueryAttentionTypeAndShapeInference() in onnxruntime/core/graph/contrib_ops/bert_defs.cc. However, the use_max_past_present_buffer parameter is always hard-coded to -1, as in the following code:

void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
  // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
  constexpr int use_max_past_present_buffer = -1;
  BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
}

So I was wondering whether it would be possible to pass an argument/flag that gives this function a chance to actually perform shape inference, at least when the shared buffer is used by some EPs.
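
For illustration only, one possible shape of the change could be an overload that lets the caller supply the flag instead of hard-coding it. The signature and plumbing below are assumptions on my side, not the existing API:

void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
                                              int past_key_index,
                                              int use_max_past_present_buffer) {
  // Hypothetical overload: the caller chooses the buffer mode instead of the hard-coded -1.
  //   1  = present KV reuses the max-length past buffer (shapes stay static)
  //   0  = present KV grows with the newly appended tokens
  //  -1  = unknown; present KV dims are left dynamic, as today
  BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
}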

Describe scenario use case

When an EP uses a shared buffer for the key/value cache, it could pass the flag/argument to set use_max_past_present_buffer to 1, which would enable shape inference for GroupQueryAttention ops.
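
As a rough sketch of what that flag would unlock, assuming the shared-buffer case where present_key/present_value take the same static shape as the max-length past_key/past_value inputs (the propagateShapeFromInputToOutput helper is the standard ONNX shape inference utility; the surrounding logic and indices are illustrative):

// Illustrative only: with a shared KV buffer, the present outputs can simply
// inherit the past inputs' static shapes inside the base inference function.
if (use_max_past_present_buffer == 1) {
  // present_key (output 1) <- past_key (input past_key_index)
  ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1);
  // present_value (output 2) <- past_value (input past_key_index + 1)
  ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index + 1, 2);
}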
