Commit cb29a62

[Shape Inference] Fix GroupQueryAttention shape inference for present outputs
Issue:

When using a pre-allocated KV cache with freeDimensionOverrides, shape inference for the present_key and present_value outputs failed silently. Downstream graph operations then received tensors with unknown dynamic shapes, leading to unexpected fallbacks in execution providers such as WebNN, which currently does not support dynamic shapes.

Root cause:

In BaseGroupQueryAttentionTypeAndShapeInference(), the shape inference logic for use_max_past_present_buffer == -1 only propagated shapes when BOTH of the following held:

1. total_sequence_length_value was a concrete value (> 0)
2. past_dims[2] had a concrete dimension value

When either condition failed (e.g., with freeDimensionOverrides, which results in a dynamic past_sequence_length), the present output shapes were left uninitialized. Additionally, when past_key is not provided (prefill/first-token mode), no shape inference was performed for the present outputs at all.

Fix:

1. For use_max_past_present_buffer == -1:
   - Always construct and propagate present_shape.
   - Compute present_sequence_length = max(past_sequence_length, total_sequence_length) when both values are concrete.
   - Fall back to copying past_key's sequence dimension when the exact value cannot be computed.
2. Add a new else-if branch to handle prefill mode (no past_key input):
   - Infer head_size from the query shape and the num_heads/kv_num_heads attributes.
   - Handle both separate Q/K/V and packed QKV input formats.
   - Construct the present shape from the query dims, kv_num_heads, and total_sequence_length or kv_sequence_length.
1 parent a5dc0f9 commit cb29a62
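For reference, a minimal standalone sketch of the decode-path arithmetic described in the fix, assuming a hypothetical GQA configuration (not taken from the commit, and not ONNX Runtime API code): the present shape is (batch_size, kv_num_heads, present_sequence_length, head_size), with present_sequence_length = max(past_sequence_length, total_sequence_length) when both are concrete.

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical GQA configuration: 32 query heads, 8 KV heads.
  const int64_t batch_size = 1;
  const int64_t num_heads = 32;
  const int64_t kv_num_heads = 8;
  const int64_t hidden_size = 4096;                   // query is (batch_size, sequence_length, num_heads * head_size)
  const int64_t head_size = hidden_size / num_heads;  // 128

  // Decode step: the KV cache holds 511 tokens and one new token is appended.
  const int64_t past_sequence_length = 511;
  const int64_t total_sequence_length = 512;
  const int64_t present_sequence_length = std::max(past_sequence_length, total_sequence_length);

  // Expected present_key/present_value shape: (1, 8, 512, 128)
  std::cout << "(" << batch_size << ", " << kv_num_heads << ", "
            << present_sequence_length << ", " << head_size << ")\n";
  return 0;
}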

File tree

1 file changed: +58 -11 lines changed


onnxruntime/core/graph/contrib_ops/bert_defs.cc

Lines changed: 58 additions & 11 deletions
@@ -322,23 +322,24 @@ void BaseGroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext
           updateOutputShape(ctx, 2, present_shape);
         }
       } else if (use_max_past_present_buffer == -1) {
+        // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
+        ONNX_NAMESPACE::TensorShapeProto present_shape;
+        *present_shape.add_dim() = past_dims[0];  // batch_size
+        *present_shape.add_dim() = past_dims[1];  // kv_num_heads
         if (total_sequence_length_value > 0 && past_dims[2].has_dim_value()) {
           // present_sequence_length = max(past_sequence_length, total_sequence_length)
           const int64_t present_sequence_length = total_sequence_length_value > past_dims[2].dim_value()
                                                       ? total_sequence_length_value
                                                       : past_dims[2].dim_value();
-
-          ONNX_NAMESPACE::TensorShapeProto present_shape;
-          for (auto& dim : past_dims) {
-            *present_shape.add_dim() = dim;
-          }
-
-          // shape of present key/value is (batch_size, kv_num_heads, present_sequence_length, head_size)
-          present_shape.mutable_dim(2)->set_dim_value(present_sequence_length);
-
-          updateOutputShape(ctx, 1, present_shape);
-          updateOutputShape(ctx, 2, present_shape);
+          present_shape.add_dim()->set_dim_value(present_sequence_length);
+        } else {
+          // Cannot compute exact present_sequence_length, copy from past_key (may be dynamic)
+          *present_shape.add_dim() = past_dims[2];
         }
+        *present_shape.add_dim() = past_dims[3];  // head_size
+
+        updateOutputShape(ctx, 1, present_shape);
+        updateOutputShape(ctx, 2, present_shape);
       }

       if (output_qk_index >= 0) {
@@ -370,6 +371,52 @@
           }
         }
       }
+    } else if (hasInputShape(ctx, 0)) {
+      // Handle the case when past_key/past_value is not provided (first token/prefill mode).
+      // We still need to infer present_key/present_value output shapes from query and attributes.
+      auto& query_shape = getInputShape(ctx, 0);
+      auto& query_dims = query_shape.dim();
+
+      int64_t num_heads = getAttribute(ctx, "num_heads", 0);
+      int64_t kv_num_heads = getAttribute(ctx, "kv_num_heads", 0);
+
+      if (num_heads > 0 && kv_num_heads > 0 && query_dims.size() == 3 && query_dims[2].has_dim_value()) {
+        int64_t hidden_size = query_dims[2].dim_value();
+        int64_t head_size = 0;
+
+        if (hasInputShape(ctx, 2)) {
+          // query shape is (batch_size, sequence_length, num_heads * head_size)
+          head_size = hidden_size / num_heads;
+        } else {
+          // Packed QKV: query shape is (batch_size, sequence_length, (num_heads + 2 * kv_num_heads) * head_size)
+          head_size = hidden_size / (num_heads + 2 * kv_num_heads);
+        }
+
+        if (head_size > 0) {
+          // Determine present_sequence_length from total_sequence_length or kv_sequence_length
+          int64_t present_sequence_length = 0;
+          if (total_sequence_length_value > 0) {
+            present_sequence_length = total_sequence_length_value;
+          } else if (kv_sequence_length > 0) {
+            present_sequence_length = kv_sequence_length;
+          }
+
+          // present key/value shape is (batch_size, kv_num_heads, present_sequence_length, head_size)
+          ONNX_NAMESPACE::TensorShapeProto present_shape;
+          *present_shape.add_dim() = query_dims[0];  // batch_size
+          present_shape.add_dim()->set_dim_value(kv_num_heads);
+          if (present_sequence_length > 0) {
+            present_shape.add_dim()->set_dim_value(present_sequence_length);
+          } else {
+            // Fallback: use query sequence_length (dim 1) as present_sequence_length for prefill
+            *present_shape.add_dim() = query_dims[1];
+          }
+          present_shape.add_dim()->set_dim_value(head_size);
+
+          updateOutputShape(ctx, 1, present_shape);
+          updateOutputShape(ctx, 2, present_shape);
+        }
+      }
     }
   }
 }
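As a companion to the new prefill branch above, a standalone sketch (hypothetical values, not ONNX Runtime API code) of how head_size is recovered from a packed QKV query and how the present sequence length falls back to the query's sequence_length when neither total_sequence_length nor kv_sequence_length is concrete.

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical prefill call with packed QKV and no past_key/past_value inputs.
  const int64_t batch_size = 1;
  const int64_t sequence_length = 128;  // prompt length
  const int64_t num_heads = 32;
  const int64_t kv_num_heads = 8;

  // Packed QKV: the last query dimension is (num_heads + 2 * kv_num_heads) * head_size.
  const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * 128;        // 6144
  const int64_t head_size = hidden_size / (num_heads + 2 * kv_num_heads);  // 128

  // Neither total_sequence_length nor kv_sequence_length is concrete here,
  // so the branch falls back to the query's sequence_length (dim 1).
  const int64_t present_sequence_length = sequence_length;

  // Expected present_key/present_value shape: (1, 8, 128, 128)
  std::cout << "(" << batch_size << ", " << kv_num_heads << ", "
            << present_sequence_length << ", " << head_size << ")\n";
  return 0;
}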
