Commit 0a478c0

[Shape Inference] Fix GQA shape inference for present outputs (#27250)
### Description

When using a pre-allocated KV cache with `freeDimensionOverrides`, shape inference for the `present_key` and `present_value` outputs failed silently. Downstream graph operations then received tensors with unknown dynamic shapes, causing unexpected fallbacks in execution providers such as WebNN (which currently doesn't support dynamic shapes).

### Motivation and Context

**Root cause**: In `BaseGroupQueryAttentionTypeAndShapeInference()`, the shape inference logic for `use_max_past_present_buffer == -1` only propagated shapes when BOTH conditions were met:

1. `total_sequence_length_value` was a concrete value (> 0)
2. `past_dims[2]` had a concrete dimension value

When either condition failed (e.g., using `freeDimensionOverrides`, which results in a dynamic `past_sequence_length`), the present output shapes were left uninitialized. Additionally, when `past_key`/`past_value` was not provided (prefill/first-token mode), no shape inference was performed for the present outputs at all.

**Fix**:

1. For `use_max_past_present_buffer == -1`:
   - Always construct and propagate `present_shape`
   - Compute `present_sequence_length = max(past_sequence_length, total_sequence_length)` when both values are concrete
   - Fall back to copying `past_key`'s sequence dimension when the exact value cannot be computed
2. Add a new else-if branch to handle prefill mode (no `past_key`/`past_value` input):
   - Infer `head_size` from the query shape and the `num_heads`/`kv_num_heads` attributes
   - Handle both separate Q/K/V and packed QKV input formats
   - Construct the present shape from the query dims, `kv_num_heads`, and `total_sequence_length` or `kv_sequence_length`
1 parent b214734 commit 0a478c0

File tree

1 file changed (+58, -11 lines)


onnxruntime/core/graph/contrib_ops/bert_defs.cc

Lines changed: 58 additions & 11 deletions
```diff
@@ -322,23 +322,24 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
       updateOutputShape(ctx, 2, present_shape);
     }
   } else if (use_max_past_present_buffer == -1) {
+    // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
+    ONNX_NAMESPACE::TensorShapeProto present_shape;
+    *present_shape.add_dim() = past_dims[0];  // batch_size
+    *present_shape.add_dim() = past_dims[1];  // kv_num_heads
     if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) {
       // present_sequence_length = max(past_sequence_length, total_sequence_length)
       const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value()
                                                   ? total_sequence_length_value
                                                   : past_dims[2].dim_value();
-
-      ONNX_NAMESPACE::TensorShapeProto present_shape;
-      for (auto& dim : past_dims) {
-        *present_shape.add_dim() = dim;
-      }
-
-      // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
-      present_shape.mutable_dim(2)->set_dim_value(present_sequence_length);
-
-      updateOutputShape(ctx, 1, present_shape);
-      updateOutputShape(ctx, 2, present_shape);
+      present_shape.add_dim()->set_dim_value(present_sequence_length);
+    } else {
+      // Cannot compute exact present_sequence_length, copy from past_key (may be dynamic)
+      *present_shape.add_dim() = past_dims[2];
     }
+    *present_shape.add_dim() = past_dims[3];  // head_size
+
+    updateOutputShape(ctx, 1, present_shape);
+    updateOutputShape(ctx, 2, present_shape);
   }

   if (output_qk_index >= 0) {
@@ -370,6 +371,52 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceConte
       }
     }
   }
+  } else if (hasInputShape(ctx, 0)) {
+    // Handle the case when past_key/past_value is not provided (first token/prefill mode).
+    // We still need to infer present_key/present_value output shapes from query and attributes.
+    auto& query_shape = getInputShape(ctx, 0);
+    auto& query_dims = query_shape.dim();
+
+    int64_t num_heads = getAttribute(ctx, "num_heads", 0);
+    int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
+
+    if (num_heads > 0 && kv_num_heads > 0 && query_dims.size() == 3 && query_dims[2].has_dim_value()) {
+      int64_t hidden_size = query_dims[2].dim_value();
+      int64_t head_size = 0;
+
+      if (hasInputShape(ctx, 2)) {
+        // query shape is (batch_size, sequence_length, num_heads * head_size)
+        head_size = hidden_size / num_heads;
+      } else {
+        // Packed QKV: query shape is (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
+        head_size = hidden_size / (num_heads + 2 * kv_num_heads);
+      }
+
+      if (head_size > 0) {
+        // Determine present_sequence_length from total_sequence_length or kv_sequence_length
+        int64_t present_sequence_length = 0;
+        if (total_sequence_length_value > 0) {
+          present_sequence_length = total_sequence_length_value;
+        } else if (kv_sequence_length > 0) {
+          present_sequence_length = kv_sequence_length;
+        }
+
+        // present key/value shape is (batch_size, kv_num_heads, present_sequence_length, head_size)
+        ONNX_NAMESPACE::TensorShapeProto present_shape;
+        *present_shape.add_dim() = query_dims[0];  // batch_size
+        present_shape.add_dim()->set_dim_value(kv_num_heads);
+        if (present_sequence_length > 0) {
+          present_shape.add_dim()->set_dim_value(present_sequence_length);
+        } else {
+          // Fallback: use query sequence_length (dim 1) as present_sequence_length for prefill
+          *present_shape.add_dim() = query_dims[1];
+        }
+        present_shape.add_dim()->set_dim_value(head_size);
+
+        updateOutputShape(ctx, 1, present_shape);
+        updateOutputShape(ctx, 2, present_shape);
+      }
+    }
   }
  }
 }
```
