Skip to content

Commit 61aff5c

Browse files
Copilotjustinchuby
andcommitted
Fix outer scope initializer type checking by using IsOuterScopeValue
Co-authored-by: justinchuby <11205048+justinchuby@users.noreply.github.com>
1 parent d8e68d1 commit 61aff5c

File tree

2 files changed

+76
-1
lines changed

2 files changed

+76
-1
lines changed

onnxruntime/core/graph/graph.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2822,7 +2822,7 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
28222822
// if we're building a graph we permit outer scope node args to have no type
28232823
// as the 'real' Resolve at runtime will have type inferencing
28242824
auto is_outer_scope_nodearg = [this](const std::string& name) {
2825-
return outer_scope_node_arg_names_.find(name) != outer_scope_node_arg_names_.cend();
2825+
return resolve_context_.IsOuterScopeValue(name);
28262826
};
28272827

28282828
// <k> index used to navigate node->InputDefs().

onnxruntime/test/ir/graph_test.cc

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2238,6 +2238,81 @@ TEST(GraphGetOrtValueInitializerTest, ReturnsOrtValueFromOuterScope) {
22382238
EXPECT_EQ(t.Shape().Size(), kTensorSize);
22392239
}
22402240

2241+
TEST(GraphTest, OuterScopeInitializerTypeInference) {
2242+
// This test verifies that initializers from outer scope are properly recognized
2243+
// during type inference, even without explicit value_info in the subgraph
2244+
2245+
// Create parent graph with initializer
2246+
Model parent_model("ParentModel", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), {}, {},
2247+
DefaultLoggingManager().DefaultLogger());
2248+
Graph& parent_graph = parent_model.MainGraph();
2249+
2250+
const std::string outer_init_name = "outer_init";
2251+
// Create a simple TensorProto initializer in parent graph
2252+
TensorProto tensor_proto;
2253+
tensor_proto.set_name(outer_init_name);
2254+
tensor_proto.set_data_type(TensorProto_DataType_FLOAT);
2255+
tensor_proto.add_dims(1);
2256+
tensor_proto.add_float_data(42.0f);
2257+
2258+
parent_graph.AddInitializedTensor(tensor_proto);
2259+
2260+
// Create a node in parent graph that will be the parent node for the subgraph
2261+
TypeProto tensor_type;
2262+
tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_BOOL);
2263+
auto& condition_arg = parent_graph.GetOrCreateNodeArg("condition", &tensor_type);
2264+
2265+
tensor_type.mutable_tensor_type()->set_elem_type(TensorProto_DataType_FLOAT);
2266+
auto& output_arg = parent_graph.GetOrCreateNodeArg("output", &tensor_type);
2267+
2268+
NodeArg* inputs[] = {&condition_arg};
2269+
NodeArg* outputs[] = {&output_arg};
2270+
2271+
// Create an "If" node that will contain the subgraph
2272+
auto& parent_node = parent_graph.AddNode("if_node", "If", "test if node", inputs, outputs);
2273+
2274+
// Create subgraph
2275+
GraphProto subgraph_proto;
2276+
subgraph_proto.set_name("TestSubgraph");
2277+
Graph subgraph(parent_model, &subgraph_proto, parent_graph.DomainToVersionMap(),
2278+
parent_model.IrVersion(), nullptr, &parent_graph, &parent_node,
2279+
DefaultLoggingManager().DefaultLogger(), false);
2280+
2281+
// Add a node in the subgraph that uses the outer scope initializer
2282+
// Deliberately do NOT add a value_info for the outer_initializer in the subgraph
2283+
auto& outer_init_nodearg = subgraph.GetOrCreateNodeArg(outer_init_name, nullptr); // No type info
2284+
auto& subgraph_output = subgraph.GetOrCreateNodeArg("subgraph_out", &tensor_type);
2285+
2286+
NodeArg* sub_inputs[] = {&outer_init_nodearg};
2287+
NodeArg* sub_outputs[] = {&subgraph_output};
2288+
2289+
// Create an Identity node that uses the outer scope initializer
2290+
auto& sub_node = subgraph.AddNode("identity_node", "Identity", "uses outer scope init",
2291+
sub_inputs, sub_outputs);
2292+
2293+
// Set the subgraph as an attribute of the parent node
2294+
subgraph_proto.set_name("then_branch");
2295+
*(subgraph_proto.add_node()) = sub_node.ToProto();
2296+
subgraph_proto.add_output()->set_name("subgraph_out");
2297+
2298+
parent_node.AddAttribute("then_branch", subgraph_proto);
2299+
parent_node.AddAttribute("else_branch", subgraph_proto); // Same for simplicity
2300+
2301+
// Add the initializer name to the parent node's implicit input defs
2302+
NodeArg* outer_init_nodearg_parent = parent_graph.GetNodeArg(outer_init_name);
2303+
ASSERT_NE(outer_init_nodearg_parent, nullptr);
2304+
{
2305+
// Test hack to tweak an internal structure.
2306+
auto& node_wrapper = static_cast<NodeWrapper&>(parent_node);
2307+
node_wrapper.MutableDefinitions().implicit_input_defs.push_back(outer_init_nodearg_parent);
2308+
}
2309+
2310+
// With the fix, this should NOT fail with "does not have type information set by parent node"
2311+
// The outer scope initializer should be properly recognized
2312+
Status result = parent_graph.Resolve();
2313+
EXPECT_TRUE(result.IsOK()) << "Graph resolution failed: " << result.ErrorMessage();
2314+
}
2315+
22412316
TEST_F(GraphTest, AddInitializedOrtValueWithExternalData) {
22422317
Model model("TestAddInitializedOrtValue", false, *logger_);
22432318
Graph& graph = model.MainGraph();

0 commit comments

Comments
 (0)