Commit 00cd7c6

[QNN EP] Add FusedMatMul operator support
Add support for the FusedMatMul operator in the QNN execution provider. FusedMatMul is a contrib operator in the Microsoft domain that performs matrix multiplication with optional transposition of either input and optional scaling of the result by an alpha attribute.

Implementation details:
- Added a FusedMatMulOpBuilder class that lowers FusedMatMul onto QNN ops:
  1. A QNN MatMul node, with transA/transB mapped to the transpose_in0/transpose_in1 parameters
  2. Explicit Transpose nodes inserted before the MatMul when transBatchA/transBatchB request batch-dimension transposition
  3. An ElementWiseMultiply against a static scalar tensor when alpha != 1.0
- Handles the transA, transB, transBatchA, transBatchB, and alpha attributes
- Supports higher-rank tensors and different data types

Added comprehensive tests:
- Basic functionality tests with various configurations
- Tests for both CPU and HTP backends
- QDQ (quantize-dequantize) tests for 8-bit and 16-bit precision
1 parent 3874516 commit 00cd7c6
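
To make the operator's semantics concrete, here is a minimal standalone sketch (not part of the commit) of the reference computation for 2-D inputs: Y = alpha * op(A) * op(B), where op() optionally transposes its argument. The helper and variable names are illustrative only; transBatchA/transBatchB, which only apply to tensors of rank greater than 2, are omitted for brevity.

// Illustrative reference for FusedMatMul with 2-D inputs (hypothetical helper, not commit code):
// Y = alpha * op(A) * op(B), where op() transposes when the corresponding flag is set.
#include <cstddef>
#include <iostream>
#include <vector>

std::vector<float> FusedMatMulRef(const std::vector<float>& a, size_t a_rows, size_t a_cols,
                                  const std::vector<float>& b, size_t b_rows, size_t b_cols,
                                  bool transA, bool transB, float alpha) {
  const size_t m = transA ? a_cols : a_rows;  // rows of op(A)
  const size_t k = transA ? a_rows : a_cols;  // shared inner dimension
  const size_t n = transB ? b_rows : b_cols;  // cols of op(B)
  std::vector<float> y(m * n, 0.0f);
  for (size_t i = 0; i < m; ++i) {
    for (size_t j = 0; j < n; ++j) {
      float sum = 0.0f;
      for (size_t p = 0; p < k; ++p) {
        const float a_val = transA ? a[p * a_cols + i] : a[i * a_cols + p];
        const float b_val = transB ? b[j * b_cols + p] : b[p * b_cols + j];
        sum += a_val * b_val;
      }
      y[i * n + j] = alpha * sum;  // alpha scaling applied to the matmul result
    }
  }
  return y;
}

int main() {
  // A is 3x2 with transA=1, so op(A) is 2x3; B is 3x2; Y is 2x2 scaled by alpha=0.5.
  std::vector<float> a = {1, 2, 3, 4, 5, 6};
  std::vector<float> b = {1, 0, 0, 1, 1, 1};
  for (float v : FusedMatMulRef(a, 3, 2, b, 3, 2, /*transA=*/true, /*transB=*/false, 0.5f)) {
    std::cout << v << " ";  // prints: 3 4 4 5
  }
  std::cout << "\n";
  return 0;
}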

6 files changed: +735 −20 lines changed


onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
@@ -227,6 +227,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
   {
     CreateInverseOpBuilder("Inverse", *this);
   }
+
+  {
+    CreateFusedMatMulOpBuilder("FusedMatMul", *this);
+  }
 }
 
 const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type) {

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ void CreateThresholdedReluOpBuilder(const std::string& op_type, OpBuilderRegistr
 void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+void CreateFusedMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 }  // namespace qnn
 }  // namespace onnxruntime
Lines changed: 365 additions & 0 deletions
@@ -0,0 +1,365 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/ort_api.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

// FusedMatMul operator is decomposed into MatMul with optional transposition and alpha scaling.
class FusedMatMulOpBuilder : public BaseOpBuilder {
 public:
  FusedMatMulOpBuilder() : BaseOpBuilder("FusedMatMulOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedMatMulOpBuilder);

 protected:
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger,
                       std::vector<std::string>& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT;

  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names, const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;

 private:
  Status ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper,
                             const NodeUnit& node_unit,
                             const logging::Logger& logger,
                             std::vector<std::string>& input_names) const ORT_MUST_USE_RESULT;

  Status GetFusedMatMulAttributes(const NodeUnit& node_unit,
                                  bool& transA,
                                  bool& transB,
                                  bool& transBatchA,
                                  bool& transBatchB,
                                  float& alpha) const ORT_MUST_USE_RESULT;

  Status ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper,
                              const NodeUnit& node_unit,
                              const std::vector<uint32_t>& perm,
                              std::vector<std::string>& param_tensor_names) const;

  void CreateBatchTransposePermVector(const std::vector<uint32_t>& input_shape, std::vector<uint32_t>& perm, bool trans_mat = false) const;

  Status HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper,
                              const NodeUnit& node_unit,
                              const TensorInfo& input_info,
                              const std::string& input_name,
                              std::string& transposed_name,
                              bool trans_mat,
                              bool do_op_validation) const;
};

Status FusedMatMulOpBuilder::GetFusedMatMulAttributes(const NodeUnit& node_unit,
                                                      bool& transA,
                                                      bool& transB,
                                                      bool& transBatchA,
                                                      bool& transBatchB,
                                                      float& alpha) const {
  NodeAttrHelper node_helper(node_unit);

  transA = node_helper.Get("transA", static_cast<int64_t>(0)) != 0;
  transB = node_helper.Get("transB", static_cast<int64_t>(0)) != 0;

  transBatchA = node_helper.Get("transBatchA", static_cast<int64_t>(0)) != 0;
  transBatchB = node_helper.Get("transBatchB", static_cast<int64_t>(0)) != 0;

  alpha = node_helper.Get("alpha", 1.0f);

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                           const logging::Logger& logger, std::vector<std::string>& input_names,
                                           bool /*do_op_validation*/) const {
  const auto& inputs = node_unit.Inputs();

  if (inputs.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "FusedMatMul requires exactly 2 inputs, got ", inputs.size());
  }

  TensorInfo input_info_0{};
  TensorInfo input_info_1{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info_0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info_1));

  ORT_RETURN_IF_ERROR(ProcessMatMulInputs(qnn_model_wrapper, node_unit, logger, input_names));

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper,
                                                 const NodeUnit& node_unit,
                                                 const logging::Logger& logger,
                                                 std::vector<std::string>& input_names) const {
  const auto& inputs = node_unit.Inputs();

  // Process input A
  const std::string& input_a_name = inputs[0].node_arg.Name();
  if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_a_name)) {
    LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_a_name;
  } else {
    QnnTensorWrapper input_a_tensor;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[0], input_a_tensor));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_a_tensor)), "Failed to add input A tensor.");
  }
  input_names.emplace_back(input_a_name);

  // Process input B
  const std::string& input_b_name = inputs[1].node_arg.Name();
  if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_b_name)) {
    LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_b_name;
  } else {
    QnnTensorWrapper input_b_tensor;
    ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[1], input_b_tensor));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_b_tensor)), "Failed to add input B tensor.");
  }
  input_names.emplace_back(input_b_name);

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                         const NodeUnit& node_unit,
                                                         std::vector<std::string>&& input_names,
                                                         const logging::Logger& /*logger*/,
                                                         bool do_op_validation) const {
  bool transA = false;
  bool transB = false;
  bool transBatchA = false;
  bool transBatchB = false;
  float alpha = 1.0f;
  ORT_RETURN_IF_ERROR(GetFusedMatMulAttributes(node_unit, transA, transB, transBatchA, transBatchB, alpha));

  TensorInfo input_a_info{};
  TensorInfo input_b_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_a_info));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[1], input_b_info));

  std::vector<std::string> matmul_param_tensor_names;

  // Set transpose parameters for last two dimensions
  // Skip using transpose_in0 param when both transA and transBatchA are present
  // Only use transpose_in0 when transA is present and transBatchA is not present
  if (!(transA && transBatchA)) {
    Qnn_Scalar_t transpose_a_scalar = QNN_SCALAR_INIT;
    transpose_a_scalar.dataType = QNN_DATATYPE_BOOL_8;
    transpose_a_scalar.bool8Value = transA ? 1 : 0;
    QnnParamWrapper transpose_a_param(node_unit.Index(), node_unit.Name(),
                                      QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, transpose_a_scalar);
    matmul_param_tensor_names.push_back(transpose_a_param.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(transpose_a_param));
  }

  // Skip using transpose_in1 param when both transB and transBatchB are present
  // Only use transpose_in1 when transB is present and transBatchB is not present
  if (!(transB && transBatchB)) {
    Qnn_Scalar_t transpose_b_scalar = QNN_SCALAR_INIT;
    transpose_b_scalar.dataType = QNN_DATATYPE_BOOL_8;
    transpose_b_scalar.bool8Value = transB ? 1 : 0;
    QnnParamWrapper transpose_b_param(node_unit.Index(), node_unit.Name(),
                                      QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, transpose_b_scalar);
    matmul_param_tensor_names.push_back(transpose_b_param.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(transpose_b_param));
  }

  // QNN doesn't directly support batch dimension transposition in MatMul
  // We need to insert additional transpose operations before the MatMul if transBatchA or transBatchB is true
  std::string input_a_for_matmul = input_names[0];
  std::string input_b_for_matmul = input_names[1];

  if (transBatchA && input_a_info.shape.size() > 2) {
    std::string transposed_a_name;
    ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_a_info,
                                             input_a_for_matmul, transposed_a_name, transA, do_op_validation));
    input_a_for_matmul = transposed_a_name;
  }

  if (transBatchB && input_b_info.shape.size() > 2) {
    std::string transposed_b_name;
    ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_b_info,
                                             input_b_for_matmul, transposed_b_name, transB, do_op_validation));
    input_b_for_matmul = transposed_b_name;
  }

  const std::string& output_name = node_unit.Outputs()[0].node_arg.Name();
  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));

  if (alpha == 1.0f) {
    // When alpha is 1.0f, MatMul output is the final output
    Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;

    QnnTensorWrapper output_tensor(output_name,
                                   tensor_type,
                                   output_info.qnn_data_type,
                                   output_info.quant_param.Copy(),
                                   std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)),
                      "Failed to add final output tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_matmul"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_MAT_MUL,
                          {input_a_for_matmul, input_b_for_matmul},
                          {output_name},
                          std::move(matmul_param_tensor_names),
                          do_op_validation),
                      "Failed to create MatMul node for FusedMatMul.");
  } else {
    // When alpha is not 1.0f, we need an intermediate tensor for MatMul output
    // and then apply alpha scaling
    std::string matmul_output_name = utils::GetUniqueName(node_unit.Name() + "_matmul_output");

    QnnTensorWrapper matmul_output_tensor(matmul_output_name,
                                          QNN_TENSOR_TYPE_NATIVE,
                                          output_info.qnn_data_type,
                                          QnnQuantParamsWrapper(),
                                          std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(matmul_output_tensor)),
                      "Failed to add MatMul output tensor.");

    Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;

    QnnTensorWrapper output_tensor(output_name,
                                   tensor_type,
                                   output_info.qnn_data_type,
                                   output_info.quant_param.Copy(),
                                   std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)),
                      "Failed to add output tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_matmul"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_MAT_MUL,
                          {input_a_for_matmul, input_b_for_matmul},
                          {matmul_output_name},
                          std::move(matmul_param_tensor_names),
                          do_op_validation),
                      "Failed to create MatMul node for FusedMatMul.");

    std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha");
    std::vector<uint32_t> alpha_shape{1};
    Qnn_DataType_t alpha_qnn_data_type = output_info.qnn_data_type;
    std::vector<uint8_t> alpha_data;

    // The alpha tensor data type should match the MatMul output data type for element-wise multiply
    if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) {
      alpha_data.resize(sizeof(MLFloat16));
      MLFloat16 alpha_fp16(alpha);
      memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16));
    } else {
      alpha_data.resize(sizeof(float));
      memcpy(alpha_data.data(), &alpha, sizeof(float));
    }

    QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name,
                                          QNN_TENSOR_TYPE_STATIC,
                                          alpha_qnn_data_type,
                                          QnnQuantParamsWrapper(),
                                          std::move(alpha_shape),
                                          std::move(alpha_data));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)),
                      "Failed to add alpha tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_alpha_scale"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_ELEMENT_WISE_MULTIPLY,
                          {matmul_output_name, alpha_tensor_name},
                          {output_name},
                          {},
                          do_op_validation),
                      "Failed to create alpha scaling node for FusedMatMul.");
  }

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& node_unit,
                                                  const std::vector<uint32_t>& perm,
                                                  std::vector<std::string>& param_tensor_names) const {
  QnnParamWrapper transpose_param(node_unit.Index(), node_unit.Name(), QNN_OP_TRANSPOSE_PARAM_PERM,
                                  {static_cast<uint32_t>(perm.size())}, std::vector<uint32_t>(perm));
  param_tensor_names.push_back(transpose_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(transpose_param));

  return Status::OK();
}

void FusedMatMulOpBuilder::CreateBatchTransposePermVector(const std::vector<uint32_t>& input_shape,
                                                          std::vector<uint32_t>& perm,
                                                          bool trans_mat) const {
  const size_t shape_size = input_shape.size();

  perm.clear();
  perm.reserve(shape_size);

  // 1. Add batch dimensions (1 to shape_size-2)
  for (size_t i = 1; i < shape_size - 1; ++i) {
    perm.push_back(static_cast<uint32_t>(i));
  }

  // 2. Add the second-to-last dimension based on trans_mat
  perm.push_back(trans_mat ? static_cast<uint32_t>(shape_size - 1) : 0);

  // 3. Add the last dimension based on trans_mat
  perm.push_back(trans_mat ? 0 : static_cast<uint32_t>(shape_size - 1));
}

Status FusedMatMulOpBuilder::HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& node_unit,
                                                  const TensorInfo& input_info,
                                                  const std::string& input_name,
                                                  std::string& transposed_name,
                                                  bool trans_mat,
                                                  bool do_op_validation) const {
  transposed_name = utils::GetUniqueName(node_unit.Name() + "_transposed_" + input_name.substr(input_name.find_last_of('/') + 1));

  // Create perm vector for batch transpose
  std::vector<uint32_t> perm;
  CreateBatchTransposePermVector(input_info.shape, perm, trans_mat);

  std::vector<std::string> transpose_params;
  ORT_RETURN_IF_ERROR(ProcessPermAttribute(qnn_model_wrapper, node_unit, perm, transpose_params));

  // Calculate transposed shape directly using the permutation
  std::vector<uint32_t> transposed_shape(input_info.shape.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    transposed_shape[i] = input_info.shape[perm[i]];
  }

  QnnTensorWrapper transposed_tensor(transposed_name,
                                     QNN_TENSOR_TYPE_NATIVE,
                                     input_info.qnn_data_type,
                                     input_info.quant_param.Copy(),
                                     std::move(transposed_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(transposed_tensor)),
                    "Failed to add transposed tensor.");

  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                        utils::GetUniqueName(node_unit.Name() + "_transpose_" + input_name.substr(input_name.find_last_of('/') + 1)),
                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                        QNN_OP_TRANSPOSE,
                        {input_name},
                        {transposed_name},
                        std::move(transpose_params),
                        do_op_validation),
                    "Failed to create transpose node.");

  return Status::OK();
}

void CreateFusedMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<FusedMatMulOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime
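
One subtle piece of the builder above is the permutation used for the batch-dimension transposes: the batch dimensions shift forward, the leading dimension is moved to a matrix axis, and the regular transA/transB transpose is folded into the same permutation when trans_mat is set. The following is a standalone sketch of that logic (mirroring CreateBatchTransposePermVector; the helper name is illustrative only) evaluated for a rank-4 input.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative copy of the perm-vector construction used for batch transposes.
std::vector<uint32_t> BatchTransposePerm(size_t rank, bool trans_mat) {
  std::vector<uint32_t> perm;
  for (size_t i = 1; i < rank - 1; ++i) {
    perm.push_back(static_cast<uint32_t>(i));  // batch dimensions move to the front
  }
  perm.push_back(trans_mat ? static_cast<uint32_t>(rank - 1) : 0);  // new second-to-last axis
  perm.push_back(trans_mat ? 0 : static_cast<uint32_t>(rank - 1));  // new last axis
  return perm;
}

int main() {
  for (bool trans_mat : {false, true}) {
    std::cout << "rank-4 perm, trans_mat=" << trans_mat << ":";
    for (uint32_t p : BatchTransposePerm(4, trans_mat)) {
      std::cout << " " << p;  // {1 2 0 3} without trans_mat, {1 2 3 0} with trans_mat
    }
    std::cout << "\n";
  }
  return 0;
}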

onnxruntime/test/contrib_ops/fused_matmul_op_test.cc

Lines changed: 1 addition & 1 deletion
@@ -213,7 +213,7 @@ void RunFusedMatMulTest(const char* op_name, int32_t opset_version = 7, bool tra
     test.AddOutput<T>("Y", t.expected_dims, t.expected_vals);
 
     // Disable OpenVINO, TensorRT because of unsupported data type
-    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider, kQnnExecutionProvider});
   }
 }
