Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference#29268
Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference#29268titaiwangms wants to merge 8 commits into
Conversation
DecoderAttention and MultiHeadAttention shape-inference functions guarded population of present_key (output 1) and present_value (output 2) with getNumOutputs() > 1, but write output index 2. present_key and present_value are produced as a both-or-neither pair, so require all three outputs (> 2) before populating them, matching BaseGroupQueryAttention (>= 3) and the EmbedLayerNorm guard. Also add a bounds check in InferenceContextImpl::getOutputType so an out-of-range output index fails inference cleanly instead of indexing past the end, mirroring DataPropagationContextImpl and getInputType. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…outputs omitted Cover DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention nodes declared with exactly two outputs (present_key kept, present_value omitted). Each test builds the node and asserts Graph::Resolve() shape inference completes cleanly. Tests are execution-provider independent and throw-free, so they run on the default CPU build and in no-exception (ORT_NO_EXCEPTIONS) builds. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Extend the optional-present-output regression suite with cases that declare all three outputs for DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention and assert the present_key/present_value branch still runs and infers their element types. Together with the two-output cases this pins the output-count guard to exactly three. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use each op's actual output names (DecoderAttention: new_key_cache / new_value_cache; MultiHeadAttention: present_key / present_value), align the three guards to a consistent '// has <names> outputs' phrasing, and note that the two optional cache outputs are produced as a pair, so they are present only when the node declares more than two outputs. Comment-only; no logic change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The MultiHeadAttention and DecoderMaskedMultiHeadAttention two-output cases only passed a query input, so the present-output branch (which references output index 2) was never entered and the tests could not detect a regression there. Supply shaped past_key / past_value (and past_sequence_length for MHA, buffer sharing for DMMHA) so the branch is exercised while only two outputs are declared, matching the DecoderAttention case which already reached that path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ng-convention note Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
tianleiwu
left a comment
There was a problem hiding this comment.
Correct, well-scoped fix. The > 1 guards in DecoderAttentionTypeAndShapeInference and MultiHeadAttentionTypeAndShapeInference (also used by DecoderMaskedMultiHeadAttention via the shared function at bert_defs.cc:1085/1190) populated output indices 1 and 2, but > 1 only guarantees indices 0–1, so a schema-valid 2-output node drove an out-of-range write to index 2. Raising to > 2 matches the both-or-neither present_key/present_value (and new_key_cache/new_value_cache) semantics and the existing BaseGroupQueryAttention (>= 3) / EmbedLayerNorm (> 2) exemplars.
Swept the siblings: no other > 1-guarded block writes index 2 — shape_inference_functions.cc:221 and contrib_defs.cc:3770 write only index 1, and PagedAttention (bert_defs.cc:1449) gates with an inner != 3 fail.
The InferenceContextImpl::getOutputType bounds check is solid defense-in-depth — it mirrors getInputType's .at() and the sibling DataPropagationContextImpl checks and turns a latent out-of-range operator[] (UB) into a clean fail_type_inference, consistent under both exceptions and ORT_NO_EXCEPTIONS.
Tests are driven through Model + Graph::Resolve() (hitting the real InferenceContextImpl sink), are non-vacuous (the type-inference block writes index 2 unconditionally inside the guard, so the negative cases would OOB pre-fix), and the positive all-output cases pin the guard against over-restriction.
One behavioral note (non-blocking): for a degenerate 2-output MHA node (out + present_key only), present_key's element type is no longer propagated. This is acceptable since the prior path was UB anyway and kv-cache usage always emits both outputs.
Only a nitpick below; otherwise LGTM.
Description
The
DecoderAttentionandMultiHeadAttentionshape-inference functions guardedpopulation of their optional
present_key(output 1) andpresent_value(output 2)outputs with
getNumOutputs() > 1, but then write output index 2.present_keyandpresent_valueare produced as a both-or-neither pair, so this requires all threeoutputs (
> 2) to be present before populating them — matching the existingBaseGroupQueryAttention(>= 3) andEmbedLayerNormguards.It also adds an output-index range check in
InferenceContextImpl::getOutputTypeso anoutput index beyond the declared output count fails inference cleanly instead of
indexing past the end of the outputs container, mirroring the existing
DataPropagationContextImpl::getOutputTypeandgetInputTypebehavior.Motivation and Context
A model that declares fewer outputs than the optional present outputs could previously
drive shape inference to access an output index that was not declared. This makes the
guard consistent with the other attention-family contrib ops.
Changes
onnxruntime/core/graph/contrib_ops/bert_defs.cc— require all present outputs beforepopulating
present_key/present_valueinDecoderAttentionandMultiHeadAttention.onnxruntime/core/graph/graph.cc— add an output-index range check inInferenceContextImpl::getOutputType.onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc—regression tests covering omitted optional present outputs, the 3-output positive
cases, and the MHA/DMMHA two-output cases.
coding-convention note.
Co-authored-by: Copilot 223556219+Copilot@users.noreply.github.com