
Commit 03c6c2e

[QNN] MatMulAddFusion and Reshape Related Fusion (#22494)
QNN EP relies on the Gemm op to map onto QNN's FullyConnected op when running a model, which is much faster than MatMul+Add. This PR fuses MatMul+Add whenever MatMul's second input is a 2D initializer, regardless of the rank of the first input; if the first input is not a 2D tensor, Reshape nodes are inserted around the resulting Gemm. On QNN EP, memory is allocated for each activation tensor, so Reshape/Squeeze/Unsqueeze are not no-ops, and this PR also adds fusions that remove redundant Reshape nodes. For some QNN AI Hub models on specific devices, the graph fails to finalize at execution time unless those Reshape nodes are removed; with them removed it runs fine.

Average inference time, with the change | without the change:

swin_tiny: 12.8077 ms | 23.956 ms
swin_base: 27.0639 ms | 57.6608 ms
convnext_tiny: 3.42956 ms | 16.1848 ms
openai_clip_CLIPTextEncoder: 5.96104 ms | 220.406 ms
openai_clip_CLIPImageEncoder: 41.8206 ms | 919.712 ms

NOTE: the current change skips the Attention pattern because fusing it would otherwise break AttentionFusion. Ideally AttentionFusion should be adjusted to support the Gemm pattern, but that requires big changes; we may do it in the future, for example when we want to run transformer models on QNN: since QNN has no Attention op, we would still want to fuse MatMul+Add inside the Attention pattern so the QNN side can use FullyConnected.

---------

Co-authored-by: adrianlizarraga <adlizarraga@microsoft.com>
1 parent 60d25b2 commit 03c6c2e

33 files changed: +943 −393 lines
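To make the target pattern concrete, the following is a minimal, illustrative Python sketch (not part of this commit; the file names, tensor names, and shapes are made up). It builds a 3D-input MatMul whose second input is a 2D weight initializer, followed by an Add, then asks onnxruntime to dump its L1-optimized graph. On a build that includes this change, the dumped graph is expected to show the MatMul+Add replaced by a Gemm wrapped in Reshape nodes (because the activation is 3D); on older builds the pattern is left as-is.

# Illustrative only: build MatMul([2, 8, 16] x [16, 32]) + Add(bias[32]) and dump ORT's optimized graph.
# File and tensor names here are made up for the example.
import numpy as np
import onnx
from onnx import TensorProto, helper, numpy_helper
import onnxruntime as ort

X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 8, 16])  # 3D activation, so reshapes are needed
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 8, 32])
W = numpy_helper.from_array(np.random.rand(16, 32).astype(np.float32), "W")  # 2D weight initializer
B = numpy_helper.from_array(np.random.rand(32).astype(np.float32), "B")      # bias with shape (N)

graph = helper.make_graph(
    [helper.make_node("MatMul", ["X", "W"], ["XW"]),
     helper.make_node("Add", ["XW", "B"], ["Y"])],
    "matmul_add_example", [X], [Y], initializer=[W, B])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 17)])
onnx.save(model, "matmul_add.onnx")

so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC  # L1 transformers only
so.optimized_model_filepath = "matmul_add_optimized.onnx"
ort.InferenceSession("matmul_add.onnx", so, providers=["CPUExecutionProvider"])

# Inspect which ops survived optimization (Gemm, possibly wrapped in Reshape, on builds with this change).
print([n.op_type for n in onnx.load("matmul_add_optimized.onnx").graph.node])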

onnxruntime/core/optimizer/matmul_add_fusion.cc

Lines changed: 153 additions & 25 deletions
@@ -11,10 +11,63 @@ using namespace ONNX_NAMESPACE;
 using namespace ::onnxruntime::common;
 namespace onnxruntime {
 
+namespace {
+
+// Attention subgraph has 4 MatMul-Add pairs, that we want to skip here because AttentionFusion will handle it.
+// In such case, 3 of MatMul-Add pairs are following LN, the other one produces output which is added with LN's output.
+// Use two sets to remember such patterns we already met during the graph iteration so that we can skip them directly
+// if we go to other MatMul-Add pairs in the same pattern.
+struct AttentionPatternCache {
+  bool IsAttentionPattern(const Graph& graph, const Node& matmul_node, const Node& add_node) {
+    const Node* parent_node = graph.GetProducerNode(matmul_node.InputDefs()[0]->Name());
+    if (attn_ln_nodes.count(parent_node) > 0 || attn_add_nodes.count(&add_node) > 0) {
+      return true;
+    }
+
+    if (parent_node && parent_node->OpType() == "LayerNormalization") {
+      unsigned int add_count = 0;
+      unsigned int matmul_count = 0;
+      unsigned int shape_count = 0;
+      const Node* ln_add_node = nullptr;
+      for (auto it = parent_node->OutputNodesBegin(); it != parent_node->OutputNodesEnd(); ++it) {
+        std::string op_type = (*it).OpType();
+        if (op_type == "Add") {
+          ln_add_node = &(*it);
+          add_count++;
+        } else if (op_type == "MatMul") {
+          matmul_count++;
+        } else if (op_type == "Shape") {
+          shape_count++;
+        }
+      }
+
+      if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount() - 4) {
+        size_t index = ln_add_node->InputDefs()[0]->Name() == parent_node->OutputDefs()[0]->Name() ? 1 : 0;
+        const Node* attn_add_node = graph.GetProducerNode(ln_add_node->InputDefs()[index]->Name());
+        if (attn_add_node && attn_add_node->OpType() == "Add") {
+          attn_ln_nodes.insert(parent_node);
+          attn_add_nodes.insert(attn_add_node);
+          return true;
+        }
+      }
+    }
+
+    return false;
+  }
+
+  std::unordered_set<const Node*> attn_ln_nodes;
+  std::unordered_set<const Node*> attn_add_nodes;
+};
+
+}  // namespace
+
 Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
   GraphViewer graph_viewer(graph);
   const auto& node_topology_list = graph_viewer.GetNodesInTopologicalOrder();
 
+  // Cache for skipping Attention subgraph pattern.
+  AttentionPatternCache attn_pattern_cache;
+
   for (auto node_index : node_topology_list) {
     auto* node_ptr = graph.GetNode(node_index);
     if (!node_ptr)
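As a rough illustration of the heuristic in the hunk above (simplified Python, not ORT code): a MatMul-Add pair is treated as part of an Attention subgraph when the MatMul's input is produced by a LayerNormalization whose consumers are exactly three MatMuls, one Add, and otherwise only Shape nodes. The real check additionally verifies that the Add is a residual Add fed by another Add, counts output edges rather than op types, and caches the nodes it has already classified.

# Simplified illustration of the Attention-pattern check (not ORT code).
def looks_like_attention_entry(consumer_op_types):
    """consumer_op_types: op types of every node consuming the LayerNormalization output."""
    matmuls = consumer_op_types.count("MatMul")
    adds = consumer_op_types.count("Add")
    shapes = consumer_op_types.count("Shape")
    return matmuls == 3 and adds == 1 and matmuls + adds + shapes == len(consumer_op_types)

print(looks_like_attention_entry(["MatMul", "MatMul", "MatMul", "Add", "Shape", "Shape"]))  # True
print(looks_like_attention_entry(["MatMul", "Add"]))                                        # False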
@@ -65,58 +118,133 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
     // Gemm only support Matrix, need to check the shape of MatMul and Add
     auto matmul_a_shape = matmul_input_defs[0]->Shape();
     auto matmul_b_shape = matmul_input_defs[1]->Shape();
-    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
+    if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape->dim_size() != 2) {
       continue;
     }
 
-    if (2 != matmul_a_shape->dim_size() || 2 != matmul_b_shape->dim_size()) {
-      // Gemm only support Matrix
-      continue;
+    bool need_reshape = matmul_a_shape->dim_size() != 2;
+    const auto& dim_n = matmul_b_shape->dim(1);
+    InlinedVector<int64_t> shape_values;
+    int64_t m = 0, k = 0, n = 0;
+    if (need_reshape) {
+      // Only check and skip Attention pattern here because normally input to Attention is 4D.
+      if (attn_pattern_cache.IsAttentionPattern(graph, matmul_node, add_node)) {
+        continue;
+      }
+
+      // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require
+      // both inputs have concrete shape for now, we can add dynamic shape support in future.
+      auto a_shape = utils::GetTensorShapeFromTensorShapeProto(*matmul_a_shape);
+      if (a_shape.Size() == -1) {
+        continue;
+      }
+
+      const auto& dim_k = matmul_b_shape->dim(0);
+      if (!utils::HasDimValue(dim_k) || !utils::HasDimValue(dim_n)) {
+        continue;
+      }
+
+      shape_values = a_shape.AsShapeVector();
+      // If a_shape is 1D, m is 1 from SizeToDimension() with empty dimension interval.
+      m = a_shape.SizeToDimension(a_shape.NumDimensions() - 1);
+      k = dim_k.dim_value();
+      n = dim_n.dim_value();
     }
 
     const auto& matmul_output = *matmul_node.OutputDefs()[0];
 
     auto matmul_output_name = matmul_output.Name();
     auto gemm_input_defs = matmul_input_defs;
-    if (matmul_output_name == add_input_defs[0]->Name()) {
-      // matmul output as Add_A, should use Add_B as input C for gemm
-      gemm_input_defs.push_back(add_input_defs[1]);
-    } else {
-      // matmul output as Add_B, should use Add_A as input C for gemm
-      gemm_input_defs.push_back(add_input_defs[0]);
-    }
+    int bias_idx = matmul_output_name == add_input_defs[0]->Name() ? 1 : 0;
+    gemm_input_defs.push_back(add_input_defs[bias_idx]);
 
     // valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
     // GEMM only supports unidirectional broadcast on the bias input C
     if (!gemm_input_defs.back()->Shape()) {
       continue;
     }
     const auto& bias_shape = *gemm_input_defs.back()->Shape();
-    const auto& M = matmul_output.Shape()->dim()[0];
-    const auto& N = matmul_output.Shape()->dim()[1];
     auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) {
       return dim.has_dim_value() && dim.dim_value() == 1;
     };
 
-    bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim()[0] == N) ||
-                  (bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim()[0]) && bias_shape.dim()[1] == N) ||
-                  (bias_shape.dim_size() == 2 && bias_shape.dim()[0] == M &&
-                   (dim_has_value_1(bias_shape.dim()[1]) || bias_shape.dim()[1] == N)));
+    bool valid = ((bias_shape.dim_size() == 1 && bias_shape.dim(0) == dim_n) ||
+                  (!need_reshape && bias_shape.dim_size() == 2 && dim_has_value_1(bias_shape.dim(0)) &&
+                   bias_shape.dim(1) == dim_n) ||
+                  (!need_reshape && bias_shape.dim_size() == 2 && bias_shape.dim(0) == matmul_a_shape->dim(0) &&
+                   (dim_has_value_1(bias_shape.dim(1)) || bias_shape.dim(1) == dim_n)));
     if (!valid) {
       continue;
     }
 
-    Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion/"),
-                                    "Gemm",
-                                    "fused Matmul and Add " + add_node.OpType(),
-                                    gemm_input_defs,
-                                    {});
+    auto gemm_output_defs = add_node.MutableOutputDefs();
+    Node* input_node = nullptr;
+    Node* output_node = nullptr;
+    if (need_reshape) {
+      auto add_reshape = [&](const InlinedVector<int64_t>& shape, Graph& graph, bool is_input) -> Node* {
+        const std::string name = is_input ? "gemm_input" : "gemm_output";
+        ONNX_NAMESPACE::TensorProto shape_initializer_proto;
+        shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_shape"));
+        shape_initializer_proto.add_dims(static_cast<int64_t>(shape.size()));
+        shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+        shape_initializer_proto.set_raw_data(shape.data(), shape.size() * sizeof(int64_t));
+        NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
+        ONNX_NAMESPACE::TypeProto new_arg_type;
+        const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast<ONNX_NAMESPACE::TensorProto_DataType>(
+            gemm_input_defs[0]->TypeAsProto()->tensor_type().elem_type());
+        new_arg_type.mutable_tensor_type()->set_elem_type(element_type);
+        new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(m);
+        new_arg_type.mutable_tensor_type()->mutable_shape()->add_dim()->set_dim_value(is_input ? k : n);
+        NodeArg* new_arg = &graph.GetOrCreateNodeArg(graph.GenerateNodeArgName(name + "_reshape_arg"), &new_arg_type);
+        Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_reshape"), "Reshape", "Reshape for " + name,
+                                           {is_input ? gemm_input_defs[0] : new_arg, shape_arg},
+                                           {is_input ? new_arg : gemm_output_defs[0]});
+        reshape_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
+        return &reshape_node;
+      };
+
+      input_node = add_reshape({m, k}, graph, true);
+      gemm_input_defs[0] = input_node->MutableOutputDefs()[0];
+      shape_values.back() = n;
+      output_node = add_reshape(shape_values, graph, false);
+      gemm_output_defs[0] = output_node->MutableInputDefs()[0];
+    }
 
-    // Assign provider to this new node. Provider should be same as the provider for old node.
+    Node& gemm_node = graph.AddNode(graph.GenerateNodeName(matmul_node.Name() + "/MatMulAddFusion"), "Gemm",
+                                    "fused Matmul and Add", gemm_input_defs, gemm_output_defs);
     gemm_node.SetExecutionProviderType(matmul_node.GetExecutionProviderType());
 
-    // move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
-    graph_utils::FinalizeNodeFusion(graph, {matmul_node, add_node}, gemm_node);
+    if (need_reshape) {
+      graph.AddEdge(input_node->Index(), gemm_node.Index(), 0, 0);
+      graph.AddEdge(gemm_node.Index(), output_node->Index(), 0, 0);
+    } else {
+      input_node = &gemm_node;
+      output_node = &gemm_node;
+    }
+
+    auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(matmul_node);
+    for (auto cur = matmul_input_edges.cbegin(), end = matmul_input_edges.cend(); cur != end; ++cur) {
+      if (cur->dst_arg_index == 0) {
+        graph.AddEdge(cur->src_node, input_node->Index(), cur->src_arg_index, 0);
+      } else if (cur->dst_arg_index == 1) {
+        graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 1);
+      }
+    }
+
+    graph_utils::GraphEdge::RemoveGraphEdges(graph, matmul_input_edges);
+    auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(add_node);
+    for (auto cur = add_input_edges.cbegin(), end = add_input_edges.cend(); cur != end; ++cur) {
+      if (cur->dst_arg_index == bias_idx) {
+        graph.AddEdge(cur->src_node, gemm_node.Index(), cur->src_arg_index, 2);
+        break;
+      }
+    }
+
+    graph_utils::GraphEdge::RemoveGraphEdges(graph, add_input_edges);
+    graph_utils::RemoveNodeOutputEdges(graph, matmul_node);
+    graph_utils::ReplaceDownstreamNodeInput(graph, add_node, 0, *output_node, 0);
+    graph.RemoveNode(matmul_node.Index());
+    graph.RemoveNode(add_node.Index());
 
     modified = true;
   }
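The reshape handling added above boils down to simple shape arithmetic: flatten every leading dimension of A into M, take K and N from the 2D initializer B, run Gemm on [M, K] x [K, N], and reshape the result back to A's leading dimensions with the last dimension replaced by N. A small illustrative Python sketch of that bookkeeping (the function name is made up; this is not ORT code):

# Illustrative shape bookkeeping for the Reshape-Gemm-Reshape wrapping.
from math import prod

def gemm_reshape_shapes(a_shape, b_shape):
    """a_shape: concrete shape of MatMul input A (any rank >= 1); b_shape: [K, N] of the 2D initializer."""
    k, n = b_shape
    assert a_shape[-1] == k, "inner dimensions must match"
    m = prod(a_shape[:-1])               # empty product is 1, which covers the 1D case noted in the diff
    pre_reshape = [m, k]                 # shape fed to the Reshape in front of Gemm
    post_reshape = a_shape[:-1] + [n]    # shape fed to the Reshape after Gemm
    return pre_reshape, post_reshape

print(gemm_reshape_shapes([2, 8, 16], [16, 32]))  # ([16, 16], [2, 8, 32])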

onnxruntime/core/optimizer/reshape_fusion.cc

Lines changed: 51 additions & 0 deletions
@@ -48,6 +48,8 @@ Status ReshapeFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, c
       fused_count++;
       LOGS(logger, INFO) << "Fused reshape node: " << reshape.OutputDefs()[0]->Name();
       modified = true;
+    } else if (ReshapeFusion::FuseContiguousReshapes(reshape, graph)) {
+      modified = true;
     }
   }
 
@@ -452,4 +454,53 @@ bool ReshapeFusion::Fuse_Subgraph(Node& reshape, Graph& graph, const logging::Lo
   return true;
 }
 
+bool ReshapeFusion::FuseContiguousReshapes(Node& reshape, Graph& graph) {
+  InlinedVector<std::reference_wrapper<Node>> contiguous_reshapes{reshape};
+  InlinedVector<int64_t> shape_value;
+  while (true) {
+    Node& curr_node = contiguous_reshapes.back();
+    if (graph.NodeProducesGraphOutput(curr_node) || curr_node.GetOutputEdgesCount() != 1) {
+      break;
+    }
+
+    Node* next_node = graph.GetNode(curr_node.OutputNodesBegin()->Index());
+    if (next_node->OpType() != "Reshape" && next_node->OpType() != "Squeeze" && next_node->OpType() != "Unsqueeze") {
+      break;
+    }
+
+    auto shape = next_node->OutputDefs()[0]->Shape();
+    if (!shape) {
+      break;
+    }
+
+    auto tensor_shape = utils::GetTensorShapeFromTensorShapeProto(*shape);
+    if (tensor_shape.Size() == -1) {
+      break;
+    }
+
+    shape_value = tensor_shape.AsShapeVector();
+    contiguous_reshapes.emplace_back(*next_node);
+  }
+
+  if (contiguous_reshapes.size() < 2) {
+    return false;
+  }
+
+  const std::string& name = contiguous_reshapes[0].get().Name();
+  ONNX_NAMESPACE::TensorProto shape_initializer_proto;
+  shape_initializer_proto.set_name(graph.GenerateNodeName(name + "_new_shape"));
+  shape_initializer_proto.add_dims(static_cast<int64_t>(shape_value.size()));
+  shape_initializer_proto.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
+  shape_initializer_proto.set_raw_data(shape_value.data(), shape_value.size() * sizeof(int64_t));
+  NodeArg* shape_arg = &graph_utils::AddInitializer(graph, shape_initializer_proto);
+  Node& reshape_node = graph.AddNode(graph.GenerateNodeName(name + "_new_reshape"), "Reshape", "Reshape for " + name,
+                                     {contiguous_reshapes[0].get().MutableInputDefs()[0], shape_arg},
+                                     {contiguous_reshapes.back().get().MutableOutputDefs()[0]});
+  reshape_node.SetExecutionProviderType(contiguous_reshapes[0].get().GetExecutionProviderType());
+
+  graph_utils::FinalizeNodeFusion(graph, contiguous_reshapes, reshape_node);
+
+  return true;
+}
+
 }  // namespace onnxruntime
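Conceptually, FuseContiguousReshapes walks a chain of Reshape/Squeeze/Unsqueeze nodes whose output shapes are fully concrete and replaces the whole chain with a single Reshape to the last shape. A simplified illustrative Python sketch of that walk (not ORT code; the real implementation also requires each intermediate node to have a single consumer and to not produce a graph output):

# Simplified illustration of the chain walk in FuseContiguousReshapes.
# Each element of `chain` is (op_type, concrete_output_shape or None); the chain starts at a Reshape.
RESHAPE_LIKE = {"Reshape", "Squeeze", "Unsqueeze"}

def collapse_reshape_chain(chain):
    """Return (nodes_fused, final_shape) for the longest fusable prefix, or None if nothing to fuse."""
    assert chain and chain[0][0] == "Reshape"
    fused, final_shape = 1, None
    for op_type, out_shape in chain[1:]:
        if op_type not in RESHAPE_LIKE or out_shape is None or -1 in out_shape:
            break  # stop at the first non-reshape-like node or non-concrete shape
        final_shape = out_shape
        fused += 1
    return (fused, final_shape) if fused >= 2 else None

# Reshape -> Squeeze -> Unsqueeze shuffling the same 512 elements: one Reshape to [1, 16, 32] suffices.
print(collapse_reshape_chain([("Reshape", [16, 1, 32]), ("Squeeze", [16, 32]), ("Unsqueeze", [1, 16, 32])]))
# -> (3, [1, 16, 32])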

onnxruntime/core/optimizer/reshape_fusion.h

Lines changed: 5 additions & 0 deletions
@@ -27,6 +27,11 @@ class ReshapeFusion : public GraphTransformer {
   static bool Is_One_Element_Input(const Node& cur_node, int index);
   static bool Is_One_Element_Output_Subgraph(Graph& graph, const NodeArg& root_input, const Node& concat,
                                              int index, gsl::span<const int64_t> shape_value, const logging::Logger& logger);
+
+  // Remove contiguous Reshape/Squeeze/Unsqueeze if the shape info is concrete.
+  // For some EP, such reshape Ops are not no-op, such as QNN EP, memory is allocated for each output,
+  // so this fusion can help to reduce memory usage on such devices.
+  static bool FuseContiguousReshapes(Node& reshape, Graph& graph);
 };
 
 }  // namespace onnxruntime
