Description
Describe the feature request
For the WebNN EP, the graph builder does not accept inputs or outputs with dynamic shapes, so after FreeDimensionOverride all shapes / dims are expected to be static.
A shape inference function for the GroupQueryAttention op already exists: BaseGroupQueryAttentionTypeAndShapeInference() in onnxruntime/core/graph/contrib_ops/bert_defs.cc. However, its use_max_past_present_buffer parameter is hard-coded to -1 in every case, as in the following code:
onnxruntime/onnxruntime/core/graph/contrib_ops/bert_defs.cc
Lines 319 to 323 in 81cd6ea
    void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int past_key_index) {
      // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not
      constexpr int use_max_past_present_buffer = -1;
      BaseGroupQueryAttentionTypeAndShapeInference(ctx, past_key_index, use_max_past_present_buffer);
    }
So I was wondering whether it is possible to pass an argument/flag that gives this function a chance to perform shape inference, at least when a shared buffer is used by some EPs.
Describe scenario use case
When an EP uses a shared buffer for the key / value cache, it passes the flag/argument to set use_max_past_present_buffer to 1, which enables shape inference for GroupQueryAttention ops.
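
To make the request concrete, here is a minimal standalone sketch of how such a flag could drive the inferred shape of the present key / value outputs. It does not use the ONNX Runtime API: `Dim`, `Shape`, and `InferPresentShape` are hypothetical names introduced only for illustration, and the meaning given to the three flag values (1 = shared max-length buffer, 0 = no sharing, -1 = unknown) is my assumption about the intended semantics, not something confirmed here.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in for a tensor shape: each dim is either a known static
// value or unknown (std::nullopt), e.g. an unresolved free dimension.
using Dim = std::optional<int64_t>;
using Shape = std::vector<Dim>;

// Illustrative model only, not the ONNX Runtime implementation.
// Assumed flag semantics:
//    1 -> past and present share one max-length buffer: present shape == past shape
//    0 -> no sharing: present sequence length = past length + new KV length
//   -1 -> unknown: no present shape is inferred
// past_shape is assumed to be (batch, kv_num_heads, sequence_length, head_size).
std::optional<Shape> InferPresentShape(const Shape& past_shape,
                                       Dim kv_sequence_length,
                                       int use_max_past_present_buffer) {
  if (past_shape.size() != 4) {
    return std::nullopt;  // unexpected rank, nothing to infer
  }
  if (use_max_past_present_buffer == 1) {
    // Shared buffer: the present output reuses the past buffer's shape as-is.
    return past_shape;
  }
  if (use_max_past_present_buffer == 0) {
    // Separate buffers: both lengths must be static to compute the total.
    if (!past_shape[2].has_value() || !kv_sequence_length.has_value()) {
      return std::nullopt;
    }
    Shape present = past_shape;
    present[2] = *past_shape[2] + *kv_sequence_length;
    return present;
  }
  // -1 (or any other value): leave the present shape uninferred.
  return std::nullopt;
}
```

Under these assumptions, an EP that declares a shared buffer would get a fully static present shape whenever the past shape is static, which is what the WebNN graph builder needs. With -1, as today, nothing is inferred.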