
Commit 9d108d0

[QNN EP] Add QuickGELU operator support for QNN provider (#27034)
### Description
Add support for the QuickGELU operator in the QNN provider:
- Implement QuickGeluOpBuilder to handle QuickGELU operations
- Add registration for QuickGELU in op_builder_factory
- Add comprehensive tests for CPU and HTP backends
- Support both float and quantized (QDQ) versions

### Motivation and Context
QNN has no operator that maps directly to QuickGelu, so it is decomposed as x * sigmoid(alpha * x). This keeps the whole model on HTP and improves inference time.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 29d9b2f commit 9d108d0
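For reference, the decomposition this commit lowers to QNN ops computes QuickGELU(x) = x * sigmoid(alpha * x), with alpha defaulting to 1.702. The following scalar sketch is an illustration only, not code from this commit; `quick_gelu_ref` is a hypothetical helper name:

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical scalar reference for the decomposition the builder emits:
// QuickGELU(x) = x * sigmoid(alpha * x). The 1.702 default matches the
// builder's fallback for the "alpha" attribute.
static float quick_gelu_ref(float x, float alpha = 1.702f) {
  const float sig = 1.0f / (1.0f + std::exp(-alpha * x));  // Sigmoid node
  return x * sig;                                          // final Mul node
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
    std::printf("QuickGELU(%+.2f) = %+.6f\n", x, quick_gelu_ref(x));
  }
  return 0;
}
```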

File tree: 4 files changed (+349, -0 lines)


onnxruntime/core/providers/qnn/builder/op_builder_factory.cc

Lines changed: 4 additions & 0 deletions
```diff
@@ -212,6 +212,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
     CreateGatherNDOpBuilder("GatherND", *this);
   }
 
+  {
+    CreateQuickGeluOpBuilder("QuickGelu", *this);
+  }
+
   {
     CreateModOpBuilder("Mod", *this);
   }
```

onnxruntime/core/providers/qnn/builder/op_builder_factory.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -126,6 +126,7 @@ void CreateThresholdedReluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateSTFTOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 void CreateInverseOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
+void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 
 void CreateMatMulNBitsOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
 void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
```
Lines changed: 142 additions & 0 deletions
@@ -0,0 +1,142 @@ (new file)

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

class QuickGeluOpBuilder : public BaseOpBuilder {
 public:
  QuickGeluOpBuilder() : BaseOpBuilder("QuickGeluOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QuickGeluOpBuilder);

 protected:
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

Status QuickGeluOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                       const NodeUnit& node_unit,
                                                       std::vector<std::string>&& input_names,
                                                       const logging::Logger& logger,
                                                       bool do_op_validation) const {
  LOGS(logger, VERBOSE) << "Processing QuickGelu operator: " << node_unit.Name();

  const std::string& input_name = input_names[0];
  const auto& outputs = node_unit.Outputs();
  const std::string& output_name = outputs[0].node_arg.Name();

  NodeAttrHelper node_helper(node_unit);
  float alpha = node_helper.Get("alpha", 1.702f);

  TensorInfo input_info = {};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_info));

  // Skip alpha multiplication when alpha is 1.0 to reduce accumulated error
  constexpr float alpha_epsilon = 1e-6f;
  const bool skip_alpha_mul = std::abs(alpha - 1.0f) < alpha_epsilon;

  std::string sigmoid_input_name;
  std::string sigmoid_output_name = utils::GetUniqueName(node_unit.Name() + "_sigmoid");

  if (skip_alpha_mul) {
    sigmoid_input_name = input_name;
  } else {
    const std::string alpha_mul_output_name = utils::GetUniqueName(node_unit.Name() + "_alpha_mul");
    sigmoid_input_name = alpha_mul_output_name;

    // The alpha tensor data type should match the input data type for element-wise multiply
    std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha");
    std::vector<uint32_t> alpha_shape{1};
    Qnn_DataType_t alpha_qnn_data_type = input_info.qnn_data_type;
    std::vector<uint8_t> alpha_data;

    if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) {
      alpha_data.resize(sizeof(MLFloat16));
      MLFloat16 alpha_fp16(alpha);
      memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16));
    } else {
      alpha_data.resize(sizeof(float));
      memcpy(alpha_data.data(), &alpha, sizeof(float));
    }

    QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name,
                                          QNN_TENSOR_TYPE_STATIC,
                                          alpha_qnn_data_type,
                                          QnnQuantParamsWrapper(),
                                          std::move(alpha_shape),
                                          std::move(alpha_data));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)), "Failed to add alpha tensor.");

    QnnTensorWrapper alpha_mul_output_tensor_wrapper(alpha_mul_output_name,
                                                     QNN_TENSOR_TYPE_NATIVE,
                                                     input_info.qnn_data_type,
                                                     QnnQuantParamsWrapper(),
                                                     std::vector<uint32_t>(input_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_mul_output_tensor_wrapper)),
                      "Failed to add alpha_mul_output tensor.");

    // Step 1: Create Mul node for alpha * x
    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_alpha_mul"),
                                                      QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                      QNN_OP_ELEMENT_WISE_MULTIPLY,
                                                      {alpha_tensor_name, input_name},
                                                      {alpha_mul_output_name},
                                                      {},
                                                      do_op_validation),
                      "Failed to create alpha_mul node.");
  }

  QnnTensorWrapper sigmoid_output_tensor_wrapper(sigmoid_output_name,
                                                 QNN_TENSOR_TYPE_NATIVE,
                                                 input_info.qnn_data_type,
                                                 QnnQuantParamsWrapper(),
                                                 std::vector<uint32_t>(input_info.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(sigmoid_output_tensor_wrapper)),
                    "Failed to add sigmoid_output tensor.");

  Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ : QNN_TENSOR_TYPE_NATIVE;
  QnnTensorWrapper output_tensor_wrapper(output_name,
                                         tensor_type,
                                         input_info.qnn_data_type,
                                         input_info.quant_param.Copy(),
                                         std::vector<uint32_t>(input_info.shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor_wrapper)),
                    "Failed to add output tensor.");

  // Step 2: Create Sigmoid node for sigmoid(alpha * x) or sigmoid(x)
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_sigmoid"),
                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    QNN_OP_SIGMOID,
                                                    {sigmoid_input_name},
                                                    {sigmoid_output_name},
                                                    {},
                                                    do_op_validation),
                    "Failed to create sigmoid node.");

  // Step 3: Create Mul node for x * sigmoid(alpha * x) or x * sigmoid(x)
  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetUniqueName(node_unit.Name() + "_final_mul"),
                                                    QNN_OP_PACKAGE_NAME_QTI_AISW,
                                                    QNN_OP_ELEMENT_WISE_MULTIPLY,
                                                    {input_name, sigmoid_output_name},
                                                    {output_name},
                                                    {},
                                                    do_op_validation),
                    "Failed to create final_mul node.");

  return Status::OK();
}

void CreateQuickGeluOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<QuickGeluOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime
```
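Note on the lowering: when alpha is within 1e-6 of 1.0, the builder skips the leading multiply and emits only Sigmoid followed by ElementWiseMultiply; otherwise it emits the full three-node chain (Mul for alpha * x, then Sigmoid, then the final Mul), with the static alpha tensor created in the input's QNN data type (FP16 or FP32) so the element-wise multiply type-checks.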
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@ (new file)

```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#if !defined(ORT_MINIMAL_BUILD)

#include <string>
#include "core/graph/constants.h"
#include "test/providers/qnn/qnn_test_utils.h"

#include "gtest/gtest.h"

namespace onnxruntime {
namespace test {

// Runs a model with a QuickGelu operator on the QNN CPU backend. Checks the graph node assignment
// and that inference outputs for QNN EP and CPU EP match.
template <typename DataType>
static void RunQuickGeluTest(const TestInputDef<DataType>& input_def,
                             float alpha,
                             ExpectedEPNodeAssignment expected_ep_assignment,
                             const std::string& backend_name = "cpu",
                             float fp32_abs_err = 5e-3f) {
  ProviderOptions provider_options;
  provider_options["backend_type"] = backend_name;

  if (backend_name == "htp") {
    provider_options["enable_htp_fp16_precision"] = "1";
  }

  auto model_builder = [input_def, alpha](ModelTestBuilder& builder) {
    NodeArg* input = MakeTestInput<DataType>(builder, input_def);
    auto* output = builder.MakeOutput();

    Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain);
    node.AddAttribute("alpha", alpha);
  };

  RunQnnModelTest(model_builder,
                  provider_options,
                  13,  // opset version for contrib ops
                  expected_ep_assignment,
                  fp32_abs_err);
}

// Tests the accuracy of a QDQ QuickGelu model on QNN EP by comparing to CPU EP.
template <typename QType>
static void RunQDQQuickGeluTest(const TestInputDef<float>& input_def,
                                float alpha,
                                ExpectedEPNodeAssignment expected_ep_assignment,
                                const std::string& backend_name = "htp",
                                bool use_contrib_qdq = false) {
  ProviderOptions provider_options;
  provider_options["backend_type"] = backend_name;
  provider_options["offload_graph_io_quantization"] = "0";

  GetTestModelFn model_builder_fn = [input_def, alpha](ModelTestBuilder& builder) {
    NodeArg* input = MakeTestInput<float>(builder, input_def);
    auto* output = builder.MakeOutput();

    Node& node = builder.AddNode("QuickGelu", {input}, {output}, kMSDomain);
    node.AddAttribute("alpha", alpha);
  };

  GetTestQDQModelFn<QType> qdq_model_builder_fn =
      [input_def, alpha, use_contrib_qdq](ModelTestBuilder& builder,
                                          std::vector<QuantParams<QType>>& output_qparams) {
        NodeArg* input = MakeTestInput<float>(builder, input_def);
        QuantParams<QType> input_qparams = GetTestInputQuantParams<QType>(input_def);
        NodeArg* input_after_qdq = AddQDQNodePair<QType>(builder, input, input_qparams.scale,
                                                         input_qparams.zero_point, use_contrib_qdq);

        // QuickGelu -> op_output
        auto* op_output = builder.MakeIntermediate();
        Node& node = builder.AddNode("QuickGelu", {input_after_qdq}, {op_output}, kMSDomain);
        node.AddAttribute("alpha", alpha);

        // op_output -> Q -> DQ -> output
        AddQDQNodePairWithOutputAsGraphOutput<QType>(builder, op_output, output_qparams[0].scale,
                                                     output_qparams[0].zero_point, use_contrib_qdq);
      };

  TestQDQModelAccuracy(model_builder_fn,
                       qdq_model_builder_fn,
                       provider_options,
                       13,  // opset version for contrib ops
                       expected_ep_assignment,
                       QDQTolerance(5e-3f));
}

//
// CPU tests:
//

// Test QuickGelu with default alpha value (1.0)
TEST_F(QnnCPUBackendTests, QuickGelu_Default_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          1.0f,  // alpha
                          ExpectedEPNodeAssignment::All);
}

// Test QuickGelu with custom alpha value
TEST_F(QnnCPUBackendTests, QuickGelu_Custom_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          1.702f,  // alpha
                          ExpectedEPNodeAssignment::All);
}

// Test QuickGelu with negative alpha value
TEST_F(QnnCPUBackendTests, QuickGelu_Negative_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          -1.702f,  // alpha
                          ExpectedEPNodeAssignment::All);
}

#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
//
// HTP tests:
//

TEST_F(QnnHTPBackendTests, QuickGelu_Default_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          1.0f,
                          ExpectedEPNodeAssignment::All,
                          "htp",
                          0.01f);
}

// Test QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Custom_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          1.702f,  // alpha
                          ExpectedEPNodeAssignment::All,
                          "htp");
}

// Test QuickGelu with negative alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Negative_Alpha) {
  RunQuickGeluTest<float>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                          -1.702f,  // alpha
                          ExpectedEPNodeAssignment::All,
                          "htp");
}

TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Default_Alpha) {
  RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
                              1.0f,
                              ExpectedEPNodeAssignment::All,
                              "htp",
                              0.01f);
}

// Test QuickGelu with float16 inputs and custom alpha on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Custom_Alpha) {
  RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
                              1.702f,  // alpha
                              ExpectedEPNodeAssignment::All,
                              "htp");
}

// Test QuickGelu with float16 inputs and negative alpha on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_Float16_Negative_Alpha) {
  RunQuickGeluTest<MLFloat16>(ConvertToFP16InputDef(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48))),
                              -1.702f,  // alpha
                              ExpectedEPNodeAssignment::All,
                              "htp");
}

// Test 8-bit QDQ QuickGelu with default alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Default_Alpha) {
  RunQDQQuickGeluTest<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                               1.0f,  // alpha
                               ExpectedEPNodeAssignment::All);
}

// Test 8-bit QDQ QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U8_Custom_Alpha) {
  RunQDQQuickGeluTest<uint8_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                               1.702f,  // alpha
                               ExpectedEPNodeAssignment::All);
}

// Test 16-bit QDQ QuickGelu with default alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Default_Alpha) {
  RunQDQQuickGeluTest<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                                1.0f,  // alpha
                                ExpectedEPNodeAssignment::All,
                                "htp",
                                true);  // Use com.microsoft Q/DQ ops
}

// Test 16-bit QDQ QuickGelu with custom alpha value on HTP
TEST_F(QnnHTPBackendTests, QuickGelu_QDQ_U16_Custom_Alpha) {
  RunQDQQuickGeluTest<uint16_t>(TestInputDef<float>({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)),
                                1.702f,  // alpha
                                ExpectedEPNodeAssignment::All,
                                "htp",
                                true);  // Use com.microsoft Q/DQ ops
}

#endif  // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

}  // namespace test
}  // namespace onnxruntime
#endif  // !defined(ORT_MINIMAL_BUILD)
```
