Commit cb29a62
[Shape Inference] Fix GroupQueryAttention shape inference for present outputs
Issue:
When using a pre-allocated KV cache with freeDimensionOverrides, shape
inference for the present_key and present_value outputs failed silently.
Downstream graph operations then received tensors with unknown dynamic
shapes, leading to unexpected fallbacks in execution providers such as
WebNN (which currently does not support dynamic shapes).
Root cause:
In BaseGroupQueryAttentionTypeAndShapeInference(), the shape inference
logic for use_max_past_present_buffer == -1 only propagated shapes when
BOTH conditions were met:
1. total_sequence_length_value was a concrete value (> 0)
2. past_dims[2] had a concrete dimension value
When either condition failed (e.g., with freeDimensionOverrides, which
results in a dynamic past_sequence_length), the present output shapes
were left uninitialized.
Additionally, when past_key is not provided (prefill/first-token mode),
no shape inference was performed for present outputs at all.
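The faulty gate can be sketched as follows. This is a hypothetical Python model of the C++ logic in BaseGroupQueryAttentionTypeAndShapeInference(), not the actual implementation; the function and parameter names are illustrative.

```python
# Hedged Python sketch of the pre-fix gate (the real code is C++ in
# onnxruntime); function and parameter names are illustrative.

def infer_present_seq_len_buggy(total_sequence_length_value, past_seq_value):
    """Old behavior: only infer when BOTH values are concrete; otherwise
    give up silently, leaving the present shapes uninitialized."""
    if (total_sequence_length_value is not None
            and total_sequence_length_value > 0
            and past_seq_value is not None):
        return max(past_seq_value, total_sequence_length_value)
    return None  # shape inference silently skipped

# With freeDimensionOverrides, past_sequence_length is symbolic (None here),
# so inference bails out even though total_sequence_length is known:
print(infer_present_seq_len_buggy(128, None))  # -> None
```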
Fix:
1. For use_max_past_present_buffer == -1:
- Always construct and propagate present_shape
- Compute present_sequence_length = max(past_sequence_length,
total_sequence_length) when both values are concrete
- Fall back to copying past_key's sequence dimension when exact
value cannot be computed
2. Add new else-if branch to handle prefill mode (no past_key input):
- Infer head_size from query shape and num_heads/kv_num_heads attrs
- Handle both separate Q/K/V and packed QKV input formats
- Construct present shape from query dims, kv_num_heads, and
total_sequence_length or kv_sequence_length

1 parent a5dc0f9
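Both parts of the fix can be sketched in Python (the real change is C++; the symbolic-dimension handling and names below are illustrative assumptions, and the prefill helper assumes the separate, non-packed Q/K/V input format):

```python
def infer_present_shape_fixed(batch_dim, kv_num_heads, head_size,
                              past_seq_dim=None, total_seq_value=None):
    """Fixed branch for use_max_past_present_buffer == -1: always build a
    present shape. Use max(past, total) when both are concrete; otherwise
    fall back to copying past_key's (possibly symbolic) sequence dim."""
    if isinstance(past_seq_dim, int) and total_seq_value is not None \
            and total_seq_value > 0:
        seq_dim = max(past_seq_dim, total_seq_value)
    else:
        seq_dim = past_seq_dim if past_seq_dim is not None else total_seq_value
    return [batch_dim, kv_num_heads, seq_dim, head_size]


def infer_prefill_present_shape(query_shape, num_heads, kv_num_heads,
                                total_seq_value):
    """New prefill branch (no past_key input): derive head_size from the
    query's hidden dimension and the num_heads attribute, then build the
    present shape from query dims, kv_num_heads, and the sequence length."""
    batch_dim, _, hidden = query_shape  # (batch, seq, num_heads * head_size)
    head_size = hidden // num_heads
    return [batch_dim, kv_num_heads, total_seq_value, head_size]
```

For the packed QKV format, the hidden dimension instead spans (num_heads + 2 * kv_num_heads) * head_size, so head_size would be derived by dividing by that sum rather than by num_heads.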
1 file changed: +58 -11