Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion onnxruntime/core/optimizer/layer_norm_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -752,11 +752,37 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr
// Assign provider to this new node. Provider should be same as the provider for old node.
layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType());

// move input edges to add (first in list) across to the layer_norm_node.
// FinalizeNodeFusion moves every input edge of the first node by NodeArg name. Disconnect inputs
// that the replacement does not use, such as a Pow exponent produced by a mixed-precision Cast.
// Keep track of their producers so they can be removed if this fusion makes them dead.
InlinedVector<NodeIndex> unused_input_node_indices;
const auto first_node_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(nodes_to_remove.front().get());
for (const auto& input_edge : first_node_input_edges) {
const bool is_replacement_input =
std::any_of(layer_norm_input_defs.cbegin(), layer_norm_input_defs.cend(),
[&input_edge](const NodeArg* input) { return input->Name() == input_edge.arg_name; });
if (!is_replacement_input) {
unused_input_node_indices.push_back(input_edge.src_node);
graph.RemoveEdge(input_edge.src_node, input_edge.dst_node,
input_edge.src_arg_index, input_edge.dst_arg_index);
}
}

// move input edges from the first node in nodes_to_remove to layer_norm_node.
// move output definitions and output edges from mul_node (last in list) to layer_norm_node.
// remove all the other nodes.
graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node);

// Remove unused input producers and any newly dead upstream nodes only after their final consumer is
// fused. A producer can be shared by multiple matched subgraphs, so it must remain while it still has users.
for (const NodeIndex unused_input_node_index : unused_input_node_indices) {
Comment thread
the0cp marked this conversation as resolved.
Node* unused_input_node = graph.GetNode(unused_input_node_index);
if (unused_input_node != nullptr && unused_input_node->GetOutputEdgesCount() == 0 &&
!graph.NodeProducesGraphOutput(*unused_input_node)) {
graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *unused_input_node);
}
}

#ifdef ENABLE_TRAINING_CORE
// add one extra output def, so we have 2 output defs that match what gradient builder expected
layer_norm_node.MutableOutputDefs().push_back(
Expand Down
61 changes: 61 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,67 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) {
}
}

TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) {
auto build_test_case = [](ModelTestBuilder& builder) {
auto* pow_exponent_fp16 =
builder.MakeInitializer<MLFloat16>({}, {MLFloat16(2.0f)});
auto* pow_exponent = builder.MakeIntermediate();
builder.AddNode("Cast", {pow_exponent_fp16}, {pow_exponent})
.AddAttribute("to", static_cast<int64_t>(ONNX_NAMESPACE::TensorProto_DataType_FLOAT));

auto* epsilon = builder.MakeInitializer<float>({}, {1e-5f});
auto* scale = builder.MakeInitializer<float>({4}, {1.0f, 1.0f, 1.0f, 1.0f});

auto add_simplified_layer_norm = [&](NodeArg* input) {
auto* pow_out = builder.MakeIntermediate();
auto* reduce_mean_out = builder.MakeIntermediate();
auto* add_out = builder.MakeIntermediate();
auto* sqrt_out = builder.MakeIntermediate();
auto* div_out = builder.MakeIntermediate();
auto* output = builder.MakeOutput();

builder.AddNode("Pow", {input, pow_exponent}, {pow_out});
builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out})
.AddAttribute("axes", std::vector<int64_t>{-1});
builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out});
builder.AddNode("Sqrt", {add_out}, {sqrt_out});
builder.AddNode("Div", {input, sqrt_out}, {div_out});
builder.AddNode("Mul", {div_out, scale}, {output});
};

add_simplified_layer_norm(builder.MakeInput<float>({{2, 4}}));
add_simplified_layer_norm(builder.MakeInput<float>({{2, 4}}));
};

auto pre_graph_checker = [](Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
const auto cast_it = op_to_count.find("Cast");
TEST_RETURN_IF_NOT(cast_it != op_to_count.end() && cast_it->second == 1);
const auto pow_it = op_to_count.find("Pow");
TEST_RETURN_IF_NOT(pow_it != op_to_count.end() && pow_it->second == 2);
return Status::OK();
};

auto post_graph_checker = [](Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization");
TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() &&
simplified_layer_norm_it->second == 2);
TEST_RETURN_IF_NOT(op_to_count.find("Cast") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("ReduceMean") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("Add") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("Sqrt") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("Div") == op_to_count.end());
TEST_RETURN_IF_NOT(op_to_count.find("Mul") == op_to_count.end());
return Status::OK();
};

ASSERT_STATUS_OK(TestGraphTransformer(
build_test_case, 17, *logger_, std::make_unique<SimplifiedLayerNormFusion>(),
TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker));
}

// It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph
// To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly
TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) {
Expand Down
Loading