@@ -11,10 +11,63 @@ using namespace ONNX_NAMESPACE;
1111using namespace ::onnxruntime::common;
1212namespace onnxruntime {
1313
14+ namespace {
15+
16+ // Attention subgraph has 4 MatMul-Add pairs, that we want to skip here because AttentionFusion will handle it.
17+ // In such case, 3 of MatMul-Add pairs are following LN, the other one produces output which is added with LN's output.
18+ // Use two sets to remember such patterns we already met during the graph iteration so that we can skip them directly
19+ // if we go to other MatMul-Add pairs in the same pattern.
20+ struct AttentionPatternCache {
21+ bool IsAttentionPattern (const Graph& graph, const Node& matmul_node, const Node& add_node) {
22+ const Node* parent_node = graph.GetProducerNode (matmul_node.InputDefs ()[0 ]->Name ());
23+ if (attn_ln_nodes.count (parent_node) > 0 || attn_add_nodes.count (&add_node) > 0 ) {
24+ return true ;
25+ }
26+
27+ if (parent_node && parent_node->OpType () == " LayerNormalization" ) {
28+ unsigned int add_count = 0 ;
29+ unsigned int matmul_count = 0 ;
30+ unsigned int shape_count = 0 ;
31+ const Node* ln_add_node = nullptr ;
32+ for (auto it = parent_node->OutputNodesBegin (); it != parent_node->OutputNodesEnd (); ++it) {
33+ std::string op_type = (*it).OpType ();
34+ if (op_type == " Add" ) {
35+ ln_add_node = &(*it);
36+ add_count++;
37+ } else if (op_type == " MatMul" ) {
38+ matmul_count++;
39+ } else if (op_type == " Shape" ) {
40+ shape_count++;
41+ }
42+ }
43+
44+ if (add_count == 1 && matmul_count == 3 && shape_count == parent_node->GetOutputEdgesCount () - 4 ) {
45+ size_t index = ln_add_node->InputDefs ()[0 ]->Name () == parent_node->OutputDefs ()[0 ]->Name () ? 1 : 0 ;
46+ const Node* attn_add_node = graph.GetProducerNode (ln_add_node->InputDefs ()[index]->Name ());
47+ if (attn_add_node && attn_add_node->OpType () == " Add" ) {
48+ attn_ln_nodes.insert (parent_node);
49+ attn_add_nodes.insert (attn_add_node);
50+ return true ;
51+ }
52+ }
53+ }
54+
55+ return false ;
56+ }
57+
58+ std::unordered_set<const Node*> attn_ln_nodes;
59+ std::unordered_set<const Node*> attn_add_nodes;
60+ };
61+
62+ } // namespace
63+
1464Status MatMulAddFusion::ApplyImpl (Graph& graph, bool & modified, int graph_level, const logging::Logger& logger) const {
1565 GraphViewer graph_viewer (graph);
1666 const auto & node_topology_list = graph_viewer.GetNodesInTopologicalOrder ();
1767
68+ // Cache for skipping Attention subgraph pattern.
69+ AttentionPatternCache attn_pattern_cache;
70+
1871 for (auto node_index : node_topology_list) {
1972 auto * node_ptr = graph.GetNode (node_index);
2073 if (!node_ptr)
@@ -65,58 +118,133 @@ Status MatMulAddFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level,
65118 // Gemm only support Matrix, need to check the shape of MatMul and Add
66119 auto matmul_a_shape = matmul_input_defs[0 ]->Shape ();
67120 auto matmul_b_shape = matmul_input_defs[1 ]->Shape ();
68- if (nullptr == matmul_a_shape || nullptr == matmul_b_shape) {
121+ if (nullptr == matmul_a_shape || nullptr == matmul_b_shape || matmul_b_shape-> dim_size () != 2 ) {
69122 continue ;
70123 }
71124
72- if (2 != matmul_a_shape->dim_size () || 2 != matmul_b_shape->dim_size ()) {
73- // Gemm only support Matrix
74- continue ;
125+ bool need_reshape = matmul_a_shape->dim_size () != 2 ;
126+ const auto & dim_n = matmul_b_shape->dim (1 );
127+ InlinedVector<int64_t > shape_values;
128+ int64_t m = 0 , k = 0 , n = 0 ;
129+ if (need_reshape) {
130+ // Only check and skip Attention pattern here because normally input to Attention is 4D.
131+ if (attn_pattern_cache.IsAttentionPattern (graph, matmul_node, add_node)) {
132+ continue ;
133+ }
134+
135+ // Logically we can use Shape-Concat to produce shape input for Reshape, to keep it simple, we require
136+ // both inputs have concrete shape for now, we can add dynamic shape support in future.
137+ auto a_shape = utils::GetTensorShapeFromTensorShapeProto (*matmul_a_shape);
138+ if (a_shape.Size () == -1 ) {
139+ continue ;
140+ }
141+
142+ const auto & dim_k = matmul_b_shape->dim (0 );
143+ if (!utils::HasDimValue (dim_k) || !utils::HasDimValue (dim_n)) {
144+ continue ;
145+ }
146+
147+ shape_values = a_shape.AsShapeVector ();
148+ // If a_shape is 1D, m is 1 from SizeToDimension() with empty dimension interval.
149+ m = a_shape.SizeToDimension (a_shape.NumDimensions () - 1 );
150+ k = dim_k.dim_value ();
151+ n = dim_n.dim_value ();
75152 }
76153
77154 const auto & matmul_output = *matmul_node.OutputDefs ()[0 ];
78155
79156 auto matmul_output_name = matmul_output.Name ();
80157 auto gemm_input_defs = matmul_input_defs;
81- if (matmul_output_name == add_input_defs[0 ]->Name ()) {
82- // matmul output as Add_A, should use Add_B as input C for gemm
83- gemm_input_defs.push_back (add_input_defs[1 ]);
84- } else {
85- // matmul output as Add_B, should use Add_A as input C for gemm
86- gemm_input_defs.push_back (add_input_defs[0 ]);
87- }
158+ int bias_idx = matmul_output_name == add_input_defs[0 ]->Name () ? 1 : 0 ;
159+ gemm_input_defs.push_back (add_input_defs[bias_idx]);
88160
89161 // valid bias_shapes are (N) or (1, N) or (M, 1) or (M, N) as
90162 // GEMM only supports unidirectional broadcast on the bias input C
91163 if (!gemm_input_defs.back ()->Shape ()) {
92164 continue ;
93165 }
94166 const auto & bias_shape = *gemm_input_defs.back ()->Shape ();
95- const auto & M = matmul_output.Shape ()->dim ()[0 ];
96- const auto & N = matmul_output.Shape ()->dim ()[1 ];
97167 auto dim_has_value_1 = [](const TensorShapeProto_Dimension& dim) {
98168 return dim.has_dim_value () && dim.dim_value () == 1 ;
99169 };
100170
101- bool valid = ((bias_shape.dim_size () == 1 && bias_shape.dim ()[0 ] == N) ||
102- (bias_shape.dim_size () == 2 && dim_has_value_1 (bias_shape.dim ()[0 ]) && bias_shape.dim ()[1 ] == N) ||
103- (bias_shape.dim_size () == 2 && bias_shape.dim ()[0 ] == M &&
104- (dim_has_value_1 (bias_shape.dim ()[1 ]) || bias_shape.dim ()[1 ] == N)));
171+ bool valid = ((bias_shape.dim_size () == 1 && bias_shape.dim (0 ) == dim_n) ||
172+ (!need_reshape && bias_shape.dim_size () == 2 && dim_has_value_1 (bias_shape.dim (0 )) &&
173+ bias_shape.dim (1 ) == dim_n) ||
174+ (!need_reshape && bias_shape.dim_size () == 2 && bias_shape.dim (0 ) == matmul_a_shape->dim (0 ) &&
175+ (dim_has_value_1 (bias_shape.dim (1 )) || bias_shape.dim (1 ) == dim_n)));
105176 if (!valid) {
106177 continue ;
107178 }
108179
109- Node& gemm_node = graph.AddNode (graph.GenerateNodeName (matmul_node.Name () + " /MatMulAddFusion/" ),
110- " Gemm" ,
111- " fused Matmul and Add " + add_node.OpType (),
112- gemm_input_defs,
113- {});
180+ auto gemm_output_defs = add_node.MutableOutputDefs ();
181+ Node* input_node = nullptr ;
182+ Node* output_node = nullptr ;
183+ if (need_reshape) {
184+ auto add_reshape = [&](const InlinedVector<int64_t >& shape, Graph& graph, bool is_input) -> Node* {
185+ const std::string name = is_input ? " gemm_input" : " gemm_output" ;
186+ ONNX_NAMESPACE::TensorProto shape_initializer_proto;
187+ shape_initializer_proto.set_name (graph.GenerateNodeName (name + " _shape" ));
188+ shape_initializer_proto.add_dims (static_cast <int64_t >(shape.size ()));
189+ shape_initializer_proto.set_data_type (ONNX_NAMESPACE::TensorProto_DataType_INT64);
190+ shape_initializer_proto.set_raw_data (shape.data (), shape.size () * sizeof (int64_t ));
191+ NodeArg* shape_arg = &graph_utils::AddInitializer (graph, shape_initializer_proto);
192+ ONNX_NAMESPACE::TypeProto new_arg_type;
193+ const ONNX_NAMESPACE::TensorProto_DataType element_type = static_cast <ONNX_NAMESPACE::TensorProto_DataType>(
194+ gemm_input_defs[0 ]->TypeAsProto ()->tensor_type ().elem_type ());
195+ new_arg_type.mutable_tensor_type ()->set_elem_type (element_type);
196+ new_arg_type.mutable_tensor_type ()->mutable_shape ()->add_dim ()->set_dim_value (m);
197+ new_arg_type.mutable_tensor_type ()->mutable_shape ()->add_dim ()->set_dim_value (is_input ? k : n);
198+ NodeArg* new_arg = &graph.GetOrCreateNodeArg (graph.GenerateNodeArgName (name + " _reshape_arg" ), &new_arg_type);
199+ Node& reshape_node = graph.AddNode (graph.GenerateNodeName (name + " _reshape" ), " Reshape" , " Reshape for " + name,
200+ {is_input ? gemm_input_defs[0 ] : new_arg, shape_arg},
201+ {is_input ? new_arg : gemm_output_defs[0 ]});
202+ reshape_node.SetExecutionProviderType (matmul_node.GetExecutionProviderType ());
203+ return &reshape_node;
204+ };
205+
206+ input_node = add_reshape ({m, k}, graph, true );
207+ gemm_input_defs[0 ] = input_node->MutableOutputDefs ()[0 ];
208+ shape_values.back () = n;
209+ output_node = add_reshape (shape_values, graph, false );
210+ gemm_output_defs[0 ] = output_node->MutableInputDefs ()[0 ];
211+ }
114212
115- // Assign provider to this new node. Provider should be same as the provider for old node.
213+ Node& gemm_node = graph.AddNode (graph.GenerateNodeName (matmul_node.Name () + " /MatMulAddFusion" ), " Gemm" ,
214+ " fused Matmul and Add" , gemm_input_defs, gemm_output_defs);
116215 gemm_node.SetExecutionProviderType (matmul_node.GetExecutionProviderType ());
117216
118- // move output definitions and edges from act_node to gemm_node. delete gemm_node and act_node.
119- graph_utils::FinalizeNodeFusion (graph, {matmul_node, add_node}, gemm_node);
217+ if (need_reshape) {
218+ graph.AddEdge (input_node->Index (), gemm_node.Index (), 0 , 0 );
219+ graph.AddEdge (gemm_node.Index (), output_node->Index (), 0 , 0 );
220+ } else {
221+ input_node = &gemm_node;
222+ output_node = &gemm_node;
223+ }
224+
225+ auto matmul_input_edges = graph_utils::GraphEdge::GetNodeInputEdges (matmul_node);
226+ for (auto cur = matmul_input_edges.cbegin (), end = matmul_input_edges.cend (); cur != end; ++cur) {
227+ if (cur->dst_arg_index == 0 ) {
228+ graph.AddEdge (cur->src_node , input_node->Index (), cur->src_arg_index , 0 );
229+ } else if (cur->dst_arg_index == 1 ) {
230+ graph.AddEdge (cur->src_node , gemm_node.Index (), cur->src_arg_index , 1 );
231+ }
232+ }
233+
234+ graph_utils::GraphEdge::RemoveGraphEdges (graph, matmul_input_edges);
235+ auto add_input_edges = graph_utils::GraphEdge::GetNodeInputEdges (add_node);
236+ for (auto cur = add_input_edges.cbegin (), end = add_input_edges.cend (); cur != end; ++cur) {
237+ if (cur->dst_arg_index == bias_idx) {
238+ graph.AddEdge (cur->src_node , gemm_node.Index (), cur->src_arg_index , 2 );
239+ break ;
240+ }
241+ }
242+
243+ graph_utils::GraphEdge::RemoveGraphEdges (graph, add_input_edges);
244+ graph_utils::RemoveNodeOutputEdges (graph, matmul_node);
245+ graph_utils::ReplaceDownstreamNodeInput (graph, add_node, 0 , *output_node, 0 );
246+ graph.RemoveNode (matmul_node.Index ());
247+ graph.RemoveNode (add_node.Index ());
120248
121249 modified = true ;
122250 }
0 commit comments