diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index fec706b4ae9a4..c2a361fc082b6 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1905,6 +1905,14 @@ struct GroupNode { intermediate_args.insert(arg); } + if (node->GetOutputEdgesCount() == 0) { + for (const NodeArg* arg : node->OutputDefs()) { + output_args.push_back(arg); + } + + continue; + } + for (auto output_edge_it = node->OutputEdgesBegin(); output_edge_it != node->OutputEdgesEnd(); ++output_edge_it) { const Node* output_node = &output_edge_it->GetNode(); @@ -2012,7 +2020,8 @@ void FindBranchGraph( const InlinedVector& branch_graph_input_nodes, const InlinedVector& backward_node_in_degree, InlinedVector& branch_graph, - InlinedVector>& branch_subgraph_consumers) { + InlinedVector>& branch_subgraph_consumers, + InlinedVector& branch_subgraph_outputs) { // Loop through the branch_graph_input_nodes to find the branch subgraphs by its output edges in BFS, // and find the maximum self_contained subgraph taking the branch_graph_input_nodes as input nodes. std::queue to_visit_queue; @@ -2044,11 +2053,20 @@ void FindBranchGraph( // At this point, branch_graph is a big subgraph that contains all the nodes that are purely // triggered by the branch_graph_input_nodes, other graph input/initializers and leaf nodes (for example Constant). for (const Node* n : branch_graph) { + if (n->GetOutputEdgesCount() == 0) { + // In case the node connect to graph outputs or nothings, append all outputs as the branch subgraph outputs. + for (auto output_def : n->OutputDefs()) { + branch_subgraph_outputs.push_back(output_def); + } + continue; + } + for (auto output_it = n->OutputEdgesBegin(); output_it != n->OutputEdgesEnd(); ++output_it) { const Node* output_node = &output_it->GetNode(); const size_t dest_in_port = output_it->GetDstArgIndex(); if (std::find(branch_graph.begin(), branch_graph.end(), output_node) == branch_graph.end()) { branch_subgraph_consumers.push_back({output_node, dest_in_port}); + branch_subgraph_outputs.push_back(n->OutputDefs()[output_it->GetSrcArgIndex()]); } } } @@ -2056,18 +2074,20 @@ void FindBranchGraph( void TagNodeToAssociatedOutputs(const Graph* graph, const InlinedHashSet& nodes_to_execute_before_yieldop, - const InlinedVector>& branch_subgraph_consumers, + const InlinedVector& branch_subgraph_outputs, const InlinedVector& branch_graph, InlinedVector& group_node_collection, InlinedHashMap& output_arg_to_grouped_node) { - // Reverse DFS from branch graph outputs (e.g. branch_subgraph_consumers) to tag each nodes: + // Reverse DFS from branch graph outputs (e.g. branch_subgraph_outputs) to tag each nodes: // If one node N contributes to a graph output A, then we will tag A to N. // If the node N contributes to multiple graph outputs A, B, C, then we will tag the A, B, C to N. InlinedHashMap> node_to_its_associated_outputs; node_to_its_associated_outputs.reserve(branch_graph.size()); - for (const auto& consumer : branch_subgraph_consumers) { - const NodeArg* output_arg = consumer.first->InputDefs()[consumer.second]; + InlinedHashSet handled_branch_subgraph_end_nodes; + for (const auto& output_arg : branch_subgraph_outputs) { const Node* end_node = graph->GetProducerNode(output_arg->Name()); + handled_branch_subgraph_end_nodes.insert(end_node); + InlinedVector end_nodes{end_node}; graph->ReverseDFSFrom( end_nodes, @@ -2097,6 +2117,7 @@ void TagNodeToAssociatedOutputs(const Graph* graph, group_node_collection.reserve(associated_outputs_to_nodes.size()); for (auto& [associated_outputs, nodes] : associated_outputs_to_nodes) { group_node_collection.push_back(nodes); + // Flatten the key into NodeArg* for better search. GroupNode& grouped_node = group_node_collection.back(); for (const auto& output_arg : grouped_node.output_args) { @@ -2184,6 +2205,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, InlinedVector branch_graph_input_nodes; branch_graph_input_nodes.reserve(num_of_backward_nodes); + PrepareToFindBranchGraph(this, nodes_to_execute_before_yieldop, branch_graph_input_nodes, @@ -2193,17 +2215,19 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, InlinedVector branch_graph; branch_graph.reserve(num_of_backward_nodes); InlinedVector> branch_subgraph_consumers; + InlinedVector branch_subgraph_outputs; FindBranchGraph(branch_graph_input_nodes, backward_node_in_degree, branch_graph, - branch_subgraph_consumers); + branch_subgraph_consumers, + branch_subgraph_outputs); // Cluster the nodes in the branch_graph based on the associated outputs. InlinedVector group_node_collection; InlinedHashMap output_arg_to_grouped_node; TagNodeToAssociatedOutputs(this, nodes_to_execute_before_yieldop, - branch_subgraph_consumers, + branch_subgraph_outputs, branch_graph, group_node_collection, output_arg_to_grouped_node); @@ -2247,9 +2271,27 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, // For the group nodes that are not outputted, we need to output them. // Hitting this code path means some nodes are consuming outputs of forward nodes, and their outputs // are not used by main branch backward nodes. - for (const auto& [output_arg, grouped_node] : output_arg_to_grouped_node) { + InlinedVector> + left_output_arg_to_grouped_node_vector; // To ensure deterministic order. + left_output_arg_to_grouped_node_vector.reserve(output_arg_to_grouped_node.size()); + for (auto& [output_arg, grouped_node] : output_arg_to_grouped_node) { if (!grouped_node->is_outputted) { - OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order); + left_output_arg_to_grouped_node_vector.push_back({output_arg, grouped_node}); + } + } + + if (!left_output_arg_to_grouped_node_vector.empty()) { + // Sort to ensure deterministic order. + std::sort(left_output_arg_to_grouped_node_vector.begin(), left_output_arg_to_grouped_node_vector.end(), + [](const std::pair& a, const std::pair& b) { + return a.first->Name() < b.first->Name(); + }); + for (const auto& pair : left_output_arg_to_grouped_node_vector) { + const NodeArg* output_arg = pair.first; + GroupNode* grouped_node = pair.second; + if (!grouped_node->is_outputted) { + OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order); + } } } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 590d18be91bb2..ff10765741bbe 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -2291,6 +2291,85 @@ TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_MultiLayerRec } } +TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_SubgraphGeneratingNodeHavingNoConsumers) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ \ + node_1 (Identity) \ Identity + | | \_____graph_output_0 + node_4 (Identity) | + | | + YieldOp recompute_node_1 + \ / \ + node_1_grad (Merge) Identity + | | + graph_output_1 + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32); + auto& graph_output_0_identity = graph.GetOrCreateNodeArg("graphoutput_0_identity_out_1", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32); + auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32); + auto& output_arg5 = graph.GetOrCreateNodeArg("node_yield_out_1", &tensor_int32); + auto& output_arg6 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + graph.AddNode("graph_output_0_identity", "Identity_Fake", "graph output 0 identity", {&output_arg0}, {&graph_output_0_identity}); + graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2}); + + auto& graph_output1_identity = graph.GetOrCreateNodeArg("graphoutput_1_identity_out_1", &tensor_int32); + graph.AddNode("graph_output_1_identity", "Identity_Fake", "graph output 1 identity", {&output_arg2}, {&graph_output1_identity}); + + graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4}); + + ONNX_NAMESPACE::AttributeProto full_shape_outputs; + const std::string attribute_name = "full_shape_outputs"; + full_shape_outputs.set_name(attribute_name); + full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); + full_shape_outputs.add_ints(static_cast(0)); + NodeAttributes attributes({{attribute_name, full_shape_outputs}}); + + graph.AddNode("node_yield", "YieldOp", "node yield", {&output_arg4}, {&output_arg5}, &attributes, kMSDomain); + graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg5, &output_arg2}, {&output_arg6}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // MEMORY_EFFICIENT order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT); + const std::vector expected_order = + { + "node_0", + "node_1", + "node_4", + "node_yield", + "recompute_node_1", + "node_1_grad", + "graph_output_0_identity", + "graph_output_1_identity", + }; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_order[i]) + << "MEMORY_EFFICIENT based execution order is wrong. expected node is " << expected_order[i] + << " but got " << node->Name(); + } + } +} + #endif } // namespace test