From addcc4c4b296148e617435ac7ac917d278fa7a79 Mon Sep 17 00:00:00 2001 From: pengwa Date: Mon, 6 May 2024 17:25:23 +0800 Subject: [PATCH] Fix missing node during mem efficient topo sort (#20497) ### Fix missing node during mem efficient topo sort Some nodes are not cusumed by the backward path, they are also not generating graph outputs. We missed those nodes, so this PR fix that and add related tests. A side note: we should remove those nodes that are not used for computing any graph outputs in a graph transformer. (TODO) ### Motivation and Context --- onnxruntime/core/graph/graph.cc | 60 +++++++++++++++++++---- onnxruntime/test/ir/graph_test.cc | 79 +++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+), 9 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index fec706b4ae9a4..c2a361fc082b6 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -1905,6 +1905,14 @@ struct GroupNode { intermediate_args.insert(arg); } + if (node->GetOutputEdgesCount() == 0) { + for (const NodeArg* arg : node->OutputDefs()) { + output_args.push_back(arg); + } + + continue; + } + for (auto output_edge_it = node->OutputEdgesBegin(); output_edge_it != node->OutputEdgesEnd(); ++output_edge_it) { const Node* output_node = &output_edge_it->GetNode(); @@ -2012,7 +2020,8 @@ void FindBranchGraph( const InlinedVector& branch_graph_input_nodes, const InlinedVector& backward_node_in_degree, InlinedVector& branch_graph, - InlinedVector>& branch_subgraph_consumers) { + InlinedVector>& branch_subgraph_consumers, + InlinedVector& branch_subgraph_outputs) { // Loop through the branch_graph_input_nodes to find the branch subgraphs by its output edges in BFS, // and find the maximum self_contained subgraph taking the branch_graph_input_nodes as input nodes. std::queue to_visit_queue; @@ -2044,11 +2053,20 @@ void FindBranchGraph( // At this point, branch_graph is a big subgraph that contains all the nodes that are purely // triggered by the branch_graph_input_nodes, other graph input/initializers and leaf nodes (for example Constant). for (const Node* n : branch_graph) { + if (n->GetOutputEdgesCount() == 0) { + // In case the node connect to graph outputs or nothings, append all outputs as the branch subgraph outputs. + for (auto output_def : n->OutputDefs()) { + branch_subgraph_outputs.push_back(output_def); + } + continue; + } + for (auto output_it = n->OutputEdgesBegin(); output_it != n->OutputEdgesEnd(); ++output_it) { const Node* output_node = &output_it->GetNode(); const size_t dest_in_port = output_it->GetDstArgIndex(); if (std::find(branch_graph.begin(), branch_graph.end(), output_node) == branch_graph.end()) { branch_subgraph_consumers.push_back({output_node, dest_in_port}); + branch_subgraph_outputs.push_back(n->OutputDefs()[output_it->GetSrcArgIndex()]); } } } @@ -2056,18 +2074,20 @@ void FindBranchGraph( void TagNodeToAssociatedOutputs(const Graph* graph, const InlinedHashSet& nodes_to_execute_before_yieldop, - const InlinedVector>& branch_subgraph_consumers, + const InlinedVector& branch_subgraph_outputs, const InlinedVector& branch_graph, InlinedVector& group_node_collection, InlinedHashMap& output_arg_to_grouped_node) { - // Reverse DFS from branch graph outputs (e.g. branch_subgraph_consumers) to tag each nodes: + // Reverse DFS from branch graph outputs (e.g. branch_subgraph_outputs) to tag each nodes: // If one node N contributes to a graph output A, then we will tag A to N. // If the node N contributes to multiple graph outputs A, B, C, then we will tag the A, B, C to N. InlinedHashMap> node_to_its_associated_outputs; node_to_its_associated_outputs.reserve(branch_graph.size()); - for (const auto& consumer : branch_subgraph_consumers) { - const NodeArg* output_arg = consumer.first->InputDefs()[consumer.second]; + InlinedHashSet handled_branch_subgraph_end_nodes; + for (const auto& output_arg : branch_subgraph_outputs) { const Node* end_node = graph->GetProducerNode(output_arg->Name()); + handled_branch_subgraph_end_nodes.insert(end_node); + InlinedVector end_nodes{end_node}; graph->ReverseDFSFrom( end_nodes, @@ -2097,6 +2117,7 @@ void TagNodeToAssociatedOutputs(const Graph* graph, group_node_collection.reserve(associated_outputs_to_nodes.size()); for (auto& [associated_outputs, nodes] : associated_outputs_to_nodes) { group_node_collection.push_back(nodes); + // Flatten the key into NodeArg* for better search. GroupNode& grouped_node = group_node_collection.back(); for (const auto& output_arg : grouped_node.output_args) { @@ -2184,6 +2205,7 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, InlinedVector branch_graph_input_nodes; branch_graph_input_nodes.reserve(num_of_backward_nodes); + PrepareToFindBranchGraph(this, nodes_to_execute_before_yieldop, branch_graph_input_nodes, @@ -2193,17 +2215,19 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, InlinedVector branch_graph; branch_graph.reserve(num_of_backward_nodes); InlinedVector> branch_subgraph_consumers; + InlinedVector branch_subgraph_outputs; FindBranchGraph(branch_graph_input_nodes, backward_node_in_degree, branch_graph, - branch_subgraph_consumers); + branch_subgraph_consumers, + branch_subgraph_outputs); // Cluster the nodes in the branch_graph based on the associated outputs. InlinedVector group_node_collection; InlinedHashMap output_arg_to_grouped_node; TagNodeToAssociatedOutputs(this, nodes_to_execute_before_yieldop, - branch_subgraph_consumers, + branch_subgraph_outputs, branch_graph, group_node_collection, output_arg_to_grouped_node); @@ -2247,9 +2271,27 @@ void Graph::MemoryEfficientTopologicalSort(const Node* yield_op, // For the group nodes that are not outputted, we need to output them. // Hitting this code path means some nodes are consuming outputs of forward nodes, and their outputs // are not used by main branch backward nodes. - for (const auto& [output_arg, grouped_node] : output_arg_to_grouped_node) { + InlinedVector> + left_output_arg_to_grouped_node_vector; // To ensure deterministic order. + left_output_arg_to_grouped_node_vector.reserve(output_arg_to_grouped_node.size()); + for (auto& [output_arg, grouped_node] : output_arg_to_grouped_node) { if (!grouped_node->is_outputted) { - OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order); + left_output_arg_to_grouped_node_vector.push_back({output_arg, grouped_node}); + } + } + + if (!left_output_arg_to_grouped_node_vector.empty()) { + // Sort to ensure deterministic order. + std::sort(left_output_arg_to_grouped_node_vector.begin(), left_output_arg_to_grouped_node_vector.end(), + [](const std::pair& a, const std::pair& b) { + return a.first->Name() < b.first->Name(); + }); + for (const auto& pair : left_output_arg_to_grouped_node_vector) { + const NodeArg* output_arg = pair.first; + GroupNode* grouped_node = pair.second; + if (!grouped_node->is_outputted) { + OutputGroupedNodes(this, output_arg, output_arg_to_grouped_node, node_orders, topo_order); + } } } diff --git a/onnxruntime/test/ir/graph_test.cc b/onnxruntime/test/ir/graph_test.cc index 590d18be91bb2..ff10765741bbe 100644 --- a/onnxruntime/test/ir/graph_test.cc +++ b/onnxruntime/test/ir/graph_test.cc @@ -2291,6 +2291,85 @@ TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_MultiLayerRec } } +TEST_F(GraphTest, GraphConstruction_MemoryEfficientTopologicalSort_SubgraphGeneratingNodeHavingNoConsumers) { + Model model("graph_1", false, *logger_); + auto& graph = model.MainGraph(); + + /* + | + node_0 (Identity) + / \ \ + node_1 (Identity) \ Identity + | | \_____graph_output_0 + node_4 (Identity) | + | | + YieldOp recompute_node_1 + \ / \ + node_1_grad (Merge) Identity + | | + graph_output_1 + */ + + TypeProto tensor_int32; + tensor_int32.mutable_tensor_type()->set_elem_type(TensorProto_DataType_INT32); + tensor_int32.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(1); + + auto& input_arg0 = graph.GetOrCreateNodeArg("node_0_in_1", &tensor_int32); + auto& output_arg0 = graph.GetOrCreateNodeArg("node_0_out_1", &tensor_int32); + auto& graph_output_0_identity = graph.GetOrCreateNodeArg("graphoutput_0_identity_out_1", &tensor_int32); + auto& output_arg1 = graph.GetOrCreateNodeArg("node_1_out_1", &tensor_int32); + auto& output_arg2 = graph.GetOrCreateNodeArg("node_2_out_1", &tensor_int32); + auto& output_arg4 = graph.GetOrCreateNodeArg("node_4_out_1", &tensor_int32); + auto& output_arg5 = graph.GetOrCreateNodeArg("node_yield_out_1", &tensor_int32); + auto& output_arg6 = graph.GetOrCreateNodeArg("node_5_out_1", &tensor_int32); + + graph.AddNode("node_0", "Identity_Fake", "node 0", {&input_arg0}, {&output_arg0}); + graph.AddNode("node_1", "Identity_Fake", "node 1", {&output_arg0}, {&output_arg1}); + graph.AddNode("graph_output_0_identity", "Identity_Fake", "graph output 0 identity", {&output_arg0}, {&graph_output_0_identity}); + graph.AddNode("recompute_node_1", "Identity_Fake", "recompute node 1", {&output_arg0}, {&output_arg2}); + + auto& graph_output1_identity = graph.GetOrCreateNodeArg("graphoutput_1_identity_out_1", &tensor_int32); + graph.AddNode("graph_output_1_identity", "Identity_Fake", "graph output 1 identity", {&output_arg2}, {&graph_output1_identity}); + + graph.AddNode("node_4", "Identity_Fake", "node 4", {&output_arg1}, {&output_arg4}); + + ONNX_NAMESPACE::AttributeProto full_shape_outputs; + const std::string attribute_name = "full_shape_outputs"; + full_shape_outputs.set_name(attribute_name); + full_shape_outputs.set_type(ONNX_NAMESPACE::AttributeProto::INTS); + full_shape_outputs.add_ints(static_cast(0)); + NodeAttributes attributes({{attribute_name, full_shape_outputs}}); + + graph.AddNode("node_yield", "YieldOp", "node yield", {&output_arg4}, {&output_arg5}, &attributes, kMSDomain); + graph.AddNode("node_1_grad", "Merge_Fake", "node_1 gradient", {&output_arg5, &output_arg2}, {&output_arg6}); + + auto status = graph.Resolve(); + EXPECT_TRUE(status.IsOK()) << status.ErrorMessage(); + GraphViewer graph_viewer(graph); + + // MEMORY_EFFICIENT order + { + auto& order = graph_viewer.GetNodesInTopologicalOrder(ExecutionOrder::MEMORY_EFFICIENT); + const std::vector expected_order = + { + "node_0", + "node_1", + "node_4", + "node_yield", + "recompute_node_1", + "node_1_grad", + "graph_output_0_identity", + "graph_output_1_identity", + }; + for (size_t i = 0; i < order.size(); ++i) { + auto node = graph.GetNode(order[i]); + EXPECT_TRUE(node->Name() == expected_order[i]) + << "MEMORY_EFFICIENT based execution order is wrong. expected node is " << expected_order[i] + << " but got " << node->Name(); + } + } +} + #endif } // namespace test