Commit 0a478c0
[Shape Inference] Fix GQA shape inference for present outputs (#27250)
### Description
When using a pre-allocated KV cache with `freeDimensionOverrides`, shape
inference for the `present_key` and `present_value` outputs failed silently.
Downstream graph operations then received tensors with unknown dynamic shapes,
leading to unexpected fallbacks in execution providers such as WebNN, which
currently does not support dynamic shapes.
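As context, a minimal native-API sketch of the triggering configuration (the dimension name `batch_size` and the model file name are illustrative assumptions; in onnxruntime-web the same override is passed as the `freeDimensionOverrides` session option, which the C++ wrapper exposes as `AddFreeDimensionOverrideByName`):

```cpp
// Sketch only: pin one free dimension by name while the KV-cache sequence
// dimension stays symbolic -- the situation in which present_key/present_value
// shapes used to be lost. Dimension and file names are assumptions.
#include <onnxruntime_cxx_api.h>

int main() {
  Ort::Env env;
  Ort::SessionOptions options;
  options.AddFreeDimensionOverrideByName("batch_size", 1);  // override a named free dimension
  Ort::Session session(env, ORT_TSTR("gqa_model.onnx"), options);
  return 0;
}
```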
### Motivation and Context
**Root cause**:
In `BaseGroupQueryAttentionTypeAndShapeInference()`, the shape inference
logic for `use_max_past_present_buffer == -1` only propagated shapes
when BOTH conditions were met:
1. `total_sequence_length_value` was a concrete value (> 0)
2. `past_dims[2]` had a concrete dimension value
When either condition failed (e.g., with `freeDimensionOverrides`, which
results in a dynamic `past_sequence_length`), the present output shapes were
left uninitialized, as illustrated in the sketch below.
Additionally, when `past_key`/`past_value` are not provided
(prefill/first-token mode), no shape inference was performed for the present
outputs at all.
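A minimal, self-contained sketch of that gate (a simplified model rather than the actual ORT source; `Dim` stands in for an ONNX shape dimension, which is either a concrete value or symbolic):

```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Simplified stand-in for a shape dimension: concrete value or symbolic/unknown.
struct Dim {
  std::optional<int64_t> value;
};

int main() {
  // Assumed past_key layout: (batch, kv_num_heads, past_sequence_length, head_size),
  // here with a dynamic past_sequence_length (the freeDimensionOverrides scenario).
  std::vector<Dim> past_dims = {{1}, {8}, {std::nullopt}, {64}};
  int64_t total_sequence_length_value = -1;  // not a known constant either

  std::vector<Dim> present_dims;  // stays empty unless the gate below passes
  if (total_sequence_length_value > 0 && past_dims[2].value) {
    // Only in this branch did the old logic build and assign a present shape
    // (sequence length derived from the two concrete values).
    present_dims = past_dims;
  }
  // With either condition unmet, present_key/present_value received no shape at all.
  std::cout << "present shape inferred: " << std::boolalpha
            << !present_dims.empty() << "\n";
}
```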
**Fix** (a simplified code sketch follows the list):
1. For `use_max_past_present_buffer == -1`:
- Always construct and propagate `present_shape`
- Compute `present_sequence_length = max(past_sequence_length,
total_sequence_length)` when both values are concrete
- Fall back to copying `past_key`'s sequence dimension when the exact value
cannot be computed
2. Add a new else-if branch to handle prefill mode (no `past_key`/`past_value`
input):
- Infer `head_size` from the query shape and the `num_heads`/`kv_num_heads` attributes
- Handle both separate Q/K/V and packed QKV input formats
- Construct present shape from query dims, `kv_num_heads`, and
`total_sequence_length` or `kv_sequence_length`
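A minimal, self-contained sketch of the fixed logic in both branches (again a simplified model, not the literal `BaseGroupQueryAttentionTypeAndShapeInference()` code; the packed-QKV hidden size `(num_heads + 2 * kv_num_heads) * head_size` is an assumed layout, and attribute/input plumbing is reduced to plain parameters):

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Simplified stand-in for a shape dimension: concrete value or symbolic/unknown.
struct Dim {
  std::optional<int64_t> value;
};
using Shape = std::vector<Dim>;

// Branch 1 (use_max_past_present_buffer == -1, past_key provided): always
// return a 4-D present shape instead of bailing out.
Shape InferPresentWithPast(const Shape& past_key_dims,             // (batch, kv_num_heads, past_seq, head_size)
                           int64_t total_sequence_length_value) {  // <= 0 when not a known constant
  Shape present = {past_key_dims[0], past_key_dims[1], Dim{}, past_key_dims[3]};
  const Dim& past_seq = past_key_dims[2];
  if (total_sequence_length_value > 0 && past_seq.value) {
    // Both lengths concrete: present_sequence_length = max(past, total).
    present[2].value = std::max(*past_seq.value, total_sequence_length_value);
  } else {
    // Otherwise copy past_key's (possibly symbolic) sequence dimension so the
    // output still gets a usable shape instead of staying uninitialized.
    present[2] = past_seq;
  }
  return present;
}

// Branch 2 (prefill / first token: no past_key or past_value input).
Shape InferPresentPrefill(const Shape& query_dims,  // (batch, seq, hidden); hidden is Q-only or packed QKV
                          int64_t num_heads, int64_t kv_num_heads, bool packed_qkv,
                          const Dim& total_or_kv_sequence_length) {
  Dim head_size;  // stays symbolic if the query hidden dimension is not concrete
  if (query_dims[2].value) {
    const int64_t hidden = *query_dims[2].value;
    // Assumed layouts: packed QKV hidden = (num_heads + 2 * kv_num_heads) * head_size,
    // separate Q hidden = num_heads * head_size.
    head_size.value = packed_qkv ? hidden / (num_heads + 2 * kv_num_heads)
                                 : hidden / num_heads;
  }
  // present: (batch, kv_num_heads, total_sequence_length or kv_sequence_length, head_size)
  return {query_dims[0], Dim{kv_num_heads}, total_or_kv_sequence_length, head_size};
}

int main() {
  // Prefill example: packed QKV with num_heads = 32, kv_num_heads = 8, hidden = 6144.
  Shape query = {{1}, {128}, {6144}};
  Shape present = InferPresentPrefill(query, 32, 8, /*packed_qkv=*/true, Dim{});
  std::cout << "inferred head_size: " << *present[3].value << "\n";  // 6144 / (32 + 2*8) = 128
}
```

The behavioral change that matters here is that a present shape is always produced: the sequence dimension degrades gracefully from `max(past_sequence_length, total_sequence_length)` to a copied symbolic dimension, and the prefill branch no longer skips inference entirely.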
1 file changed (+58 −11 lines)