Skip to content

Commit c45fb8a

Browse files
author
minfhong-quic
committed
[QNN-EP] Translate FP-to-Bool Cast by NotEqual.
HTP currently does not support FP-to-Bool Cast due to some backend limitations. To unblock CLIP models, replace such a Cast with NotEqual (compared against a static zero input) to achieve the same functionality. Test: add UT test cases for FP32/FP16-to-Bool casts.
1 parent cda0d14 commit c45fb8a

File tree

2 files changed

+108
-8
lines changed

2 files changed

+108
-8
lines changed

Diff for: onnxruntime/core/providers/qnn/builder/opbuilder/cast_op_builder.cc

+65-4
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include <memory>
45
#include <string>
6+
#include <utility>
57
#include <vector>
68

79
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
@@ -29,8 +31,63 @@ class CastOpBuilder : public BaseOpBuilder {
2931
std::vector<std::string>&& input_names,
3032
const logging::Logger& logger,
3133
bool do_op_validation) const override ORT_MUST_USE_RESULT;
34+
35+
private:
36+
// QNN HTP currently does not support casting FP16/FP32 to Bool, and thus such Cast will be replaced by NotEqual with
37+
// an additional static input 0.f to achieve identical functionality.
38+
bool IsFpToBoolCast(const NodeUnit& node_unit) const;
39+
Status ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
40+
const NodeUnit& node_unit,
41+
std::vector<std::string>& input_names,
42+
const logging::Logger& logger) const;
3243
};
3344

45+
bool CastOpBuilder::IsFpToBoolCast(const NodeUnit& node_unit) const {
46+
const auto* input_type_proto = node_unit.Inputs()[0].node_arg.TypeAsProto();
47+
const auto* output_type_proto = node_unit.Outputs()[0].node_arg.TypeAsProto();
48+
49+
Qnn_DataType_t input_qnn_dtype = QNN_DATATYPE_UNDEFINED;
50+
Qnn_DataType_t output_qnn_dtype = QNN_DATATYPE_UNDEFINED;
51+
52+
if (utils::GetQnnDataType(false, input_type_proto, input_qnn_dtype) != Status::OK() ||
53+
utils::GetQnnDataType(false, output_type_proto, output_qnn_dtype) != Status::OK()) {
54+
return false;
55+
}
56+
57+
return ((input_qnn_dtype == QNN_DATATYPE_FLOAT_16 || input_qnn_dtype == QNN_DATATYPE_FLOAT_32) &&
58+
output_qnn_dtype == QNN_DATATYPE_BOOL_8);
59+
}
60+
61+
Status CastOpBuilder::ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
62+
const NodeUnit& node_unit,
63+
std::vector<std::string>& input_names,
64+
const logging::Logger& logger) const {
65+
const auto& input = node_unit.Inputs()[0];
66+
if (input.quant_param.has_value()) {
67+
return Status::OK();
68+
}
69+
70+
// Build additional static input with value 0.
71+
const std::string& input_name = utils::GetNodeName(node_unit) + "_notequal_zero";
72+
73+
Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED;
74+
const auto* type_proto = input.node_arg.TypeAsProto();
75+
ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));
76+
77+
QnnTensorWrapper input_tensor_wrapper(input_name,
78+
QNN_TENSOR_TYPE_STATIC,
79+
qnn_data_type,
80+
QnnQuantParamsWrapper(),
81+
std::move(std::vector<uint32_t>{1}),
82+
std::move(std::vector<uint8_t>(utils::GetElementSizeByType(qnn_data_type), 0)));
83+
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor_wrapper)),
84+
"Failed to add additional input tensor for QNN Cast node that will be replaced by NotEqual.");
85+
input_names.push_back(input_name);
86+
87+
LOGS(logger, VERBOSE) << "FP-to-Bool Cast node " << utils::GetNodeName(node_unit) << " is replaced by NotEqual.";
88+
return Status::OK();
89+
}
90+
3491
Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
3592
const NodeUnit& node_unit,
3693
const logging::Logger& logger,
@@ -47,7 +104,8 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
47104
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
48105
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
49106
input_names.push_back(input_name);
50-
return Status::OK();
107+
return IsFpToBoolCast(node_unit) ? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger) :
108+
Status::OK();
51109
}
52110

53111
std::vector<uint8_t> unpacked_tensor;
@@ -75,7 +133,8 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
75133
"Failed to add input tensor for QNN Cast node.");
76134
input_names.push_back(input_name);
77135

78-
return Status::OK();
136+
return IsFpToBoolCast(node_unit) ? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger) :
137+
Status::OK();
79138
}
80139

81140
Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
@@ -110,14 +169,16 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
110169
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
111170
"Failed to add output tensor for QNN Cast node.");
112171

172+
const std::string qnn_op_type = IsFpToBoolCast(node_unit) ? QNN_OP_ELEMENT_WISE_NOT_EQUAL :
173+
GetQnnOpType(node_unit.OpType());
113174
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
114175
QNN_OP_PACKAGE_NAME_QTI_AISW,
115-
GetQnnOpType(node_unit.OpType()),
176+
qnn_op_type,
116177
std::move(input_names),
117178
{output_name},
118179
{},
119180
do_op_validation),
120-
"Failed to create QNN Cast node.");
181+
"Failed to create " + qnn_op_type + " node.");
121182

122183
return Status::OK();
123184
}

Diff for: onnxruntime/test/providers/qnn/cast_test.cc

+43-4
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55

66
#include <string>
77
#include <unordered_map>
8+
#include <vector>
89

9-
#include "test/optimizer/qdq_test_utils.h"
10-
#include "test/providers/qnn/qnn_test_utils.h"
10+
#include "gtest/gtest.h"
1111

12+
#include "core/framework/float16.h"
1213
#include "core/graph/onnx_protobuf.h"
13-
14-
#include "gtest/gtest.h"
14+
#include "test/optimizer/qdq_test_utils.h"
15+
#include "test/providers/qnn/qnn_test_utils.h"
1516

1617
namespace onnxruntime {
1718
namespace test {
@@ -67,6 +68,29 @@ static void RunCastOpTest(const std::vector<int64_t>& shape, ONNX_NAMESPACE::Ten
6768
expected_ep_assignment);
6869
}
6970

71+
static void RunCastFP16HTPTest(const std::vector<int64_t>& shape,
72+
ONNX_NAMESPACE::TensorProto_DataType dst_type,
73+
ExpectedEPNodeAssignment expected_ep_assignment) {
74+
ProviderOptions provider_options;
75+
#if defined(_WIN32)
76+
provider_options["backend_path"] = "QnnHtp.dll";
77+
#else
78+
provider_options["backend_path"] = "libQnnHtp.so";
79+
#endif
80+
81+
auto testcase = [shape, dst_type](ModelTestBuilder& builder) {
82+
auto input_def_fp = TestInputDef(shape, false, static_cast<float>(0), static_cast<float>(20));
83+
auto input_def = ConvertToFP16InputDef(input_def_fp);
84+
auto input = MakeTestInput<MLFloat16>(builder, input_def);
85+
86+
auto* output = builder.MakeOutput();
87+
Node& cast_node = builder.AddNode("Cast", {input}, {output});
88+
cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
89+
};
90+
91+
RunQnnModelTest(testcase, provider_options, /* opset */13, expected_ep_assignment);
92+
}
93+
7094
//
7195
// CPU tests:
7296
//
@@ -125,6 +149,21 @@ TEST_F(QnnHTPBackendTests, TestCastInt32ToInt64HTP) {
125149
RunCastOpTest<int32_t>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
126150
ExpectedEPNodeAssignment::All, true);
127151
}
152+
153+
// Cast float to bool on HTP.
154+
TEST_F(QnnHTPBackendTests, TestCastFloatToBoolHTP) {
155+
RunCastOpTest<float>({3, 3},
156+
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
157+
ExpectedEPNodeAssignment::All,
158+
true);
159+
}
160+
161+
// Cast float16 to bool on HTP.
162+
TEST_F(QnnHTPBackendTests, TestCastFloat16ToBoolHTP) {
163+
RunCastFP16HTPTest({3, 3},
164+
ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
165+
ExpectedEPNodeAssignment::All);
166+
}
128167
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
129168

130169
} // namespace test

0 commit comments

Comments
 (0)