Skip to content

Commit add6952

Browse files
author
minfhong-quic
committed
[QNN-EP] Translate FP-to-Bool Cast by NotEqual.
HTP currently does not support FP-to-Bool Cast due to some backend limitations. To unblock CLIP models, replace such a Cast with NotEqual (against a constant 0) to achieve the same functionality. Tests: add unit test cases covering FP32-to-Bool and FP16-to-Bool.
1 parent cda0d14 commit add6952

File tree

2 files changed

+113
-8
lines changed

2 files changed

+113
-8
lines changed

Diff for: onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc

+68-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include <memory>
45
#include <string>
6+
#include <utility>
57
#include <vector>
68

79
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
@@ -29,8 +31,63 @@ class CastOpBuilder : public BaseOpBuilder {
2931
std::vector<std::string>&& input_names,
3032
const logging::Logger& logger,
3133
bool do_op_validation) const override ORT_MUST_USE_RESULT;
34+
35+
private:
36+
// QNN HTP currently does not support casting FP16/FP32 to Bool, and thus such Cast will be replaced by NotEqual with
37+
// an additional static input 0.f to achieve identical functionality.
38+
bool IsFpToBoolCast(const NodeUnit& node_unit) const;
39+
Status ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
40+
const NodeUnit& node_unit,
41+
std::vector<std::string>& input_names,
42+
const logging::Logger& logger) const;
3243
};
3344

45+
bool CastOpBuilder::IsFpToBoolCast(const NodeUnit& node_unit) const {
46+
const auto* input_type_proto = node_unit.Inputs()[0].node_arg.TypeAsProto();
47+
const auto* output_type_proto = node_unit.Outputs()[0].node_arg.TypeAsProto();
48+
49+
Qnn_DataType_t input_qnn_dtype = QNN_DATATYPE_UNDEFINED;
50+
Qnn_DataType_t output_qnn_dtype = QNN_DATATYPE_UNDEFINED;
51+
52+
if (utils::GetQnnDataType(false, input_type_proto, input_qnn_dtype) != Status::OK() ||
53+
utils::GetQnnDataType(false, output_type_proto, output_qnn_dtype) != Status::OK()) {
54+
return false;
55+
}
56+
57+
return ((input_qnn_dtype == QNN_DATATYPE_FLOAT_16 || input_qnn_dtype == QNN_DATATYPE_FLOAT_32) &&
58+
output_qnn_dtype == QNN_DATATYPE_BOOL_8);
59+
}
60+
61+
// Appends the extra static input required when an FP-to-Bool Cast is lowered to
// ElementWiseNotEqual: a scalar zero of the same floating-point type as the
// Cast's input, so the NotEqual computes (x != 0), matching Cast-to-bool
// semantics. No-op for quantized inputs, which are not lowered this way.
//
// @param qnn_model_wrapper  Model wrapper that owns the QNN tensors.
// @param node_unit          The Cast node being translated.
// @param input_names        In/out list of QNN input tensor names; the new
//                           static input's name is appended on success.
// @param logger             Logger for verbose diagnostics.
// @return Status::OK() on success, or an error if the input's QNN data type
//         cannot be resolved or the tensor cannot be added.
Status CastOpBuilder::ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
                                                   const NodeUnit& node_unit,
                                                   std::vector<std::string>& input_names,
                                                   const logging::Logger& logger) const {
  const auto& input = node_unit.Inputs()[0];
  if (input.quant_param.has_value()) {
    // Quantized Cast is not replaced by NotEqual; nothing extra to add.
    return Status::OK();
  }

  // Build the additional static input with value 0 (scalar, shape {1}).
  // NOTE: a plain value (not a const-ref bound to a temporary) is used here.
  const std::string input_name = utils::GetNodeName(node_unit) + "_notequal_zero";

  Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED;
  const auto* type_proto = input.node_arg.TypeAsProto();
  ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));

  // A zero-filled byte buffer of one element's size encodes 0 for both FP16
  // and FP32. Prvalue temporaries bind to rvalue parameters directly; wrapping
  // them in std::move is redundant.
  QnnTensorWrapper input_tensor_wrapper(input_name,
                                        QNN_TENSOR_TYPE_STATIC,
                                        qnn_data_type,
                                        QnnQuantParamsWrapper(),
                                        std::vector<uint32_t>{1},
                                        std::vector<uint8_t>(utils::GetElementSizeByType(qnn_data_type), 0));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor_wrapper)),
                    "Failed to add additional input tensor for QNN Cast node that will be replaced by NotEqual.");
  input_names.push_back(input_name);

  LOGS(logger, VERBOSE) << "FP-to-Bool Cast node " << utils::GetNodeName(node_unit) << " is replaced by NotEqual.";
  return Status::OK();
}
90+
3491
Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
3592
const NodeUnit& node_unit,
3693
const logging::Logger& logger,
@@ -47,7 +104,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
47104
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
48105
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
49106
input_names.push_back(input_name);
50-
return Status::OK();
107+
return IsFpToBoolCast(node_unit)
108+
? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger)
109+
: Status::OK();
51110
}
52111

53112
std::vector<uint8_t> unpacked_tensor;
@@ -75,7 +134,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
75134
"Failed to add input tensor for QNN Cast node.");
76135
input_names.push_back(input_name);
77136

78-
return Status::OK();
137+
return IsFpToBoolCast(node_unit)
138+
? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger)
139+
: Status::OK();
79140
}
80141

81142
Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
@@ -110,14 +171,17 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
110171
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
111172
"Failed to add output tensor for QNN Cast node.");
112173

174+
const std::string qnn_op_type = IsFpToBoolCast(node_unit)
175+
? QNN_OP_ELEMENT_WISE_NOT_EQUAL
176+
: GetQnnOpType(node_unit.OpType());
113177
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
114178
QNN_OP_PACKAGE_NAME_QTI_AISW,
115-
GetQnnOpType(node_unit.OpType()),
179+
qnn_op_type,
116180
std::move(input_names),
117181
{output_name},
118182
{},
119183
do_op_validation),
120-
"Failed to create QNN Cast node.");
184+
"Failed to create " + qnn_op_type + " node.");
121185

122186
return Status::OK();
123187
}

Diff for: onnxruntime/test/providers/qnn/cast_test.cc

+45-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
#include <string>
77
#include <unordered_map>
8+
#include <vector>
89

9-
#include "test/optimizer/qdq_test_utils.h"
10-
#include "test/providers/qnn/qnn_test_utils.h"
10+
#include "gtest/gtest.h"
1111

12+
#include "core/framework/float16.h"
1213
#include "core/graph/onnx_protobuf.h"
13-
14-
#include "gtest/gtest.h"
14+
#include "test/optimizer/qdq_test_utils.h"
15+
#include "test/providers/qnn/qnn_test_utils.h"
1516

1617
namespace onnxruntime {
1718
namespace test {
@@ -67,6 +68,31 @@ static void RunCastOpTest(const std::vector<int64_t>& shape, ONNX_NAMESPACE::Ten
6768
expected_ep_assignment);
6869
}
6970

71+
#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
72+
static void RunCastFP16HTPTest(const std::vector<int64_t>& shape,
73+
ONNX_NAMESPACE::TensorProto_DataType dst_type,
74+
ExpectedEPNodeAssignment expected_ep_assignment) {
75+
ProviderOptions provider_options;
76+
#if defined(_WIN32)
77+
provider_options["backend_path"] = "QnnHtp.dll";
78+
#else
79+
provider_options["backend_path"] = "libQnnHtp.so";
80+
#endif
81+
82+
auto testcase = [shape, dst_type](ModelTestBuilder& builder) {
83+
auto input_def_fp = TestInputDef(shape, false, static_cast<float>(0), static_cast<float>(20));
84+
auto input_def = ConvertToFP16InputDef(input_def_fp);
85+
auto input = MakeTestInput<MLFloat16>(builder, input_def);
86+
87+
auto* output = builder.MakeOutput();
88+
Node& cast_node = builder.AddNode("Cast", {input}, {output});
89+
cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
90+
};
91+
92+
RunQnnModelTest(testcase, provider_options, /* opset */ 13, expected_ep_assignment);
93+
}
94+
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
95+
7096
//
7197
// CPU tests:
7298
//
@@ -125,6 +151,21 @@ TEST_F(QnnHTPBackendTests, TestCastInt32ToInt64HTP) {
125151
RunCastOpTest<int32_t>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
126152
ExpectedEPNodeAssignment::All, true);
127153
}
154+
155+
// Cast float to bool on HTP.
156+
TEST_F(QnnHTPBackendTests, TestCastFloatToBoolHTP) {
157+
RunCastOpTest<float>({3, 3},
158+
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
159+
ExpectedEPNodeAssignment::All,
160+
true);
161+
}
162+
163+
// Cast float16 to bool on HTP.
164+
TEST_F(QnnHTPBackendTests, TestCastFloat16ToBoolHTP) {
165+
RunCastFP16HTPTest({3, 3},
166+
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
167+
ExpectedEPNodeAssignment::All);
168+
}
128169
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
129170

130171
} // namespace test

0 commit comments

Comments
 (0)