Skip to content

Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference#29268

Open
titaiwangms wants to merge 8 commits into
microsoft:mainfrom
titaiwangms:attention-optional-outputs-shapeinf
Open

Fix optional-output guard in DecoderAttention/MultiHeadAttention shape inference#29268
titaiwangms wants to merge 8 commits into
microsoft:mainfrom
titaiwangms:attention-optional-outputs-shapeinf

Conversation

@titaiwangms

@titaiwangms titaiwangms commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Description

The DecoderAttention and MultiHeadAttention shape-inference functions guarded
population of their optional present_key (output 1) and present_value (output 2)
outputs with getNumOutputs() > 1, but then write output index 2. present_key and
present_value are produced as a both-or-neither pair, so this requires all three
outputs (> 2) to be present before populating them — matching the existing
BaseGroupQueryAttention (>= 3) and EmbedLayerNorm guards.

It also adds an output-index range check in InferenceContextImpl::getOutputType so an
output index beyond the declared output count fails inference cleanly instead of
indexing past the end of the outputs container, mirroring the existing
DataPropagationContextImpl::getOutputType and getInputType behavior.

Motivation and Context

A model that declares fewer outputs than the optional present outputs could previously
drive shape inference to access an output index that was not declared. This makes the
guard consistent with the other attention-family contrib ops.

Changes

  • onnxruntime/core/graph/contrib_ops/bert_defs.cc — require all present outputs before
    populating present_key/present_value in DecoderAttention and MultiHeadAttention.
  • onnxruntime/core/graph/graph.cc — add an output-index range check in
    InferenceContextImpl::getOutputType.
  • onnxruntime/test/contrib_ops/attention_optional_outputs_shape_inference_test.cc
    regression tests covering omitted optional present outputs, the 3-output positive
    cases, and the MHA/DMMHA two-output cases.
  • Adds a contrib-op shape-inference output-index safety skill doc plus a one-line
    coding-convention note.

Co-authored-by: Copilot 223556219+Copilot@users.noreply.github.com

titaiwangms and others added 7 commits June 23, 2026 23:49
DecoderAttention and MultiHeadAttention shape-inference functions guarded
population of present_key (output 1) and present_value (output 2) with
getNumOutputs() > 1, but write output index 2. present_key and present_value
are produced as a both-or-neither pair, so require all three outputs (> 2)
before populating them, matching BaseGroupQueryAttention (>= 3) and the
EmbedLayerNorm guard. Also add a bounds check in
InferenceContextImpl::getOutputType so an out-of-range output index fails
inference cleanly instead of indexing past the end, mirroring
DataPropagationContextImpl and getInputType.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…outputs omitted

Cover DecoderAttention, MultiHeadAttention and DecoderMaskedMultiHeadAttention
nodes declared with exactly two outputs (present_key kept, present_value
omitted). Each test builds the node and asserts Graph::Resolve() shape inference
completes cleanly. Tests are execution-provider independent and throw-free, so
they run on the default CPU build and in no-exception (ORT_NO_EXCEPTIONS) builds.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Extend the optional-present-output regression suite with cases that declare all
three outputs for DecoderAttention, MultiHeadAttention and
DecoderMaskedMultiHeadAttention and assert the present_key/present_value branch
still runs and infers their element types. Together with the two-output cases
this pins the output-count guard to exactly three.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Use each op's actual output names (DecoderAttention: new_key_cache /
new_value_cache; MultiHeadAttention: present_key / present_value), align the
three guards to a consistent '// has <names> outputs' phrasing, and note that
the two optional cache outputs are produced as a pair, so they are present only
when the node declares more than two outputs. Comment-only; no logic change.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The MultiHeadAttention and DecoderMaskedMultiHeadAttention two-output cases only
passed a query input, so the present-output branch (which references output index
2) was never entered and the tests could not detect a regression there. Supply
shaped past_key / past_value (and past_sequence_length for MHA, buffer sharing for
DMMHA) so the branch is exercised while only two outputs are declared, matching the
DecoderAttention case which already reached that path.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…ng-convention note

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

@tianleiwu tianleiwu left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct, well-scoped fix. The > 1 guards in DecoderAttentionTypeAndShapeInference and MultiHeadAttentionTypeAndShapeInference (also used by DecoderMaskedMultiHeadAttention via the shared function at bert_defs.cc:1085/1190) populated output indices 1 and 2, but > 1 only guarantees indices 0–1, so a schema-valid 2-output node drove an out-of-range write to index 2. Raising to > 2 matches the both-or-neither present_key/present_value (and new_key_cache/new_value_cache) semantics and the existing BaseGroupQueryAttention (>= 3) / EmbedLayerNorm (> 2) exemplars.

Swept the siblings: no other > 1-guarded block writes index 2 — shape_inference_functions.cc:221 and contrib_defs.cc:3770 write only index 1, and PagedAttention (bert_defs.cc:1449) gates with an inner != 3 fail.

The InferenceContextImpl::getOutputType bounds check is solid defense-in-depth — it mirrors getInputType's .at() and the sibling DataPropagationContextImpl checks and turns a latent out-of-range operator[] (UB) into a clean fail_type_inference, consistent under both exceptions and ORT_NO_EXCEPTIONS.

Tests are driven through Model + Graph::Resolve() (hitting the real InferenceContextImpl sink), are non-vacuous (the type-inference block writes index 2 unconditionally inside the guard, so the negative cases would OOB pre-fix), and the positive all-output cases pin the guard against over-restriction.

One behavioral note (non-blocking): for a degenerate 2-output MHA node (out + present_key only), present_key's element type is no longer propagated. This is acceptable since the prior path was UB anyway and kv-cache usage always emits both outputs.

Only a nitpick below; otherwise LGTM.

Comment thread .agents/skills/contrib-op-shape-inference-memory-safety/SKILL.md Outdated
Comment thread .agents/skills/contrib-op-shape-inference-memory-safety/SKILL.md Outdated
@titaiwangms titaiwangms enabled auto-merge (squash) June 26, 2026 00:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants