From d23c1f2fca96ae7f3df1dd2dcc64a6382fccd4c2 Mon Sep 17 00:00:00 2001 From: Matthew Sinclair Date: Mon, 9 Feb 2026 11:35:32 -0800 Subject: [PATCH 1/2] [Fix] Illegal memory access in GetOutputIndex with optional outputs * When a non-trailing output is optional, its ValueInfo may be nullptr. The current implementation attempts to dereference the nullptr due to a missing check. * This bug is exposed if an EP calls `ValueInfo_GetValueProducer` on a value info whose producer has non-trailing optional outputs. * This commit adds the required check and a unit test. --- onnxruntime/core/graph/ep_api_types.cc | 4 ++ onnxruntime/test/ep_graph/test_ep_graph.cc | 7 ++ .../skip_simplified_layer_normalization.onnx | Bin 0 -> 3935 bytes .../skip_simplified_layer_normalization.py | 65 ++++++++++++++++++ 4 files changed, 76 insertions(+) create mode 100644 onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx create mode 100644 onnxruntime/test/testdata/skip_simplified_layer_normalization.py diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index d30c7cd74a76a..5abaca4389c49 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -351,6 +351,10 @@ static Status GetOutputIndex(const EpNode& producer_node, gsl::span outputs = producer_node.GetOutputsSpan(); for (size_t i = 0; i < outputs.size(); i++) { + if (outputs[i] == nullptr) { // outputs == nullptr means the output is optional + continue; + } + if (outputs[i]->GetName() == value_info_name) { index = i; found = true; diff --git a/onnxruntime/test/ep_graph/test_ep_graph.cc b/onnxruntime/test/ep_graph/test_ep_graph.cc index 055b2551328d9..a15f36014e232 100644 --- a/onnxruntime/test/ep_graph/test_ep_graph.cc +++ b/onnxruntime/test/ep_graph/test_ep_graph.cc @@ -180,6 +180,13 @@ TEST(EpGraphTest, CheckModelExternalInitializers) { CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); } +TEST(EpGraphTest, CheckModelOptionalIntermediateNodeOutputs) { + auto test_graph = TestGraph::Load(ORT_TSTR("testdata/skip_simplified_layer_normalization.onnx")); + ASSERT_NE(test_graph, nullptr) << "Failed to load test model"; + + CheckGraphCApi(test_graph->GetGraphViewer(), test_graph->GetOrtGraph()); +} + static void RunConvQDQExtIni(const ORTCHAR_T* model_path, std::vector& output_data) { auto memory_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); Ort::SessionOptions sess_options; diff --git a/onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx b/onnxruntime/test/testdata/skip_simplified_layer_normalization.onnx new file mode 100644 index 0000000000000000000000000000000000000000..a9adf07ab4a6969a803b1d59bce0172ac9d2789f GIT binary patch literal 3935 zcmdX@^+~KuE%M7R%2iSTaf35+3vx2kGE-CF5{WsPRf#2;`FUD`TW+oTq7w4yy2+80u!@Vf6AVaHPhNFR9fKjPIhJm5MeiV;} z!Dt#7O#`E8U^ESkrh(BkFq#HN)4*sN7)=ACX<+!L0VhZwHA)8Ba}^Ti669bM;^kuE yU<6_ZE@og)F-jWJp~WhhB+CWs;9}F8q{W4;&x_S4Cl)RS0YN80F4WGf03!gX?fKRK literal 0 HcmV?d00001 diff --git a/onnxruntime/test/testdata/skip_simplified_layer_normalization.py b/onnxruntime/test/testdata/skip_simplified_layer_normalization.py new file mode 100644 index 0000000000000..5d040231a7c36 --- /dev/null +++ b/onnxruntime/test/testdata/skip_simplified_layer_normalization.py @@ -0,0 +1,65 @@ +from onnx import TensorProto, checker, helper, save, shape_inference + +batch_size = 1 +seq_len = 64 +hidden_size = 896 + +input_vi = helper.make_tensor_value_info( + name="input", + elem_type=TensorProto.FLOAT, + shape=[batch_size, seq_len, hidden_size], +) + +skip_vi = helper.make_tensor_value_info( + name="skip", + elem_type=TensorProto.FLOAT, + shape=[batch_size, seq_len, hidden_size], +) + +output_vi = helper.make_tensor_value_info( + name="output", + elem_type=TensorProto.FLOAT, + shape=[batch_size, seq_len, hidden_size], +) + +input_skip_bias_sum_vi = helper.make_tensor_value_info( + name="input_skip_bias_sum", + elem_type=TensorProto.FLOAT, + shape=[batch_size, seq_len, hidden_size], +) + +gamma_init = helper.make_tensor( + name="gamma", + data_type=TensorProto.FLOAT, + dims=[hidden_size], + vals=[1] * hidden_size +) + +node = helper.make_node( + op_type="SkipSimplifiedLayerNormalization", + inputs=["input", "skip", "gamma"], + outputs=["output", "", "", "input_skip_bias_sum"], + domain="com.microsoft", + epsilon=1e-6, + name="SkipLayerNorm", +) + +graph = helper.make_graph( + nodes=[node], + name="SkipSimplifiedLayerNormGraph", + inputs=[input_vi, skip_vi], + outputs=[output_vi, input_skip_bias_sum_vi], + initializer=[gamma_init], +) + +model = helper.make_model( + graph, + opset_imports=[ + helper.make_operatorsetid("", 17), + helper.make_operatorsetid("com.microsoft", 1), + ], +) + +model = shape_inference.infer_shapes(model) +checker.check_model(model, True) +save(model, "skip_simplified_layer_normalization.onnx") From fcf63f75d61d55c131c89d2bba59d02828b5545c Mon Sep 17 00:00:00 2001 From: Matthew Sinclair Date: Fri, 13 Feb 2026 15:55:52 -0800 Subject: [PATCH 2/2] Address review comment --- onnxruntime/core/graph/ep_api_types.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/graph/ep_api_types.cc b/onnxruntime/core/graph/ep_api_types.cc index 5abaca4389c49..c469873ba5c06 100644 --- a/onnxruntime/core/graph/ep_api_types.cc +++ b/onnxruntime/core/graph/ep_api_types.cc @@ -327,7 +327,8 @@ static Status GetInputIndices(const EpNode& consumer_node, [&found, &value_info_name, &indices](gsl::span input_value_infos, bool is_implicit) -> void { for (size_t i = 0; i < input_value_infos.size(); i++) { - if (input_value_infos[i] == nullptr) { // input_value_info == nullptr means the input is optional + if (input_value_infos[i] == nullptr) { + // This is a missing optional input. Skip it. continue; } if (input_value_infos[i]->GetName() == value_info_name) { @@ -351,7 +352,8 @@ static Status GetOutputIndex(const EpNode& producer_node, gsl::span outputs = producer_node.GetOutputsSpan(); for (size_t i = 0; i < outputs.size(); i++) { - if (outputs[i] == nullptr) { // outputs == nullptr means the output is optional + if (outputs[i] == nullptr) { + // This is a missing optional output. Skip it. continue; }