1
1
// Copyright (c) Microsoft Corporation. All rights reserved.
2
2
// Licensed under the MIT License.
3
3
4
+ #include < memory>
4
5
#include < string>
6
+ #include < utility>
5
7
#include < vector>
6
8
7
9
#include " core/providers/qnn/builder/opbuilder/base_op_builder.h"
@@ -29,8 +31,63 @@ class CastOpBuilder : public BaseOpBuilder {
29
31
std::vector<std::string>&& input_names,
30
32
const logging::Logger& logger,
31
33
bool do_op_validation) const override ORT_MUST_USE_RESULT;
34
+
35
+ private:
36
+ // QNN HTP currently does not support casting FP16/FP32 to Bool, and thus such Cast will be replaced by NotEqual with
37
+ // an additional static input 0.f to achieve the identical functionality.
38
+ bool IsFpToBoolCast (const NodeUnit& node_unit) const ;
39
+ Status ProcessExtraInputForNotEqual (QnnModelWrapper& qnn_model_wrapper,
40
+ const NodeUnit& node_unit,
41
+ std::vector<std::string>& input_names,
42
+ const logging::Logger& logger) const ;
32
43
};
33
44
45
+ bool CastOpBuilder::IsFpToBoolCast (const NodeUnit& node_unit) const {
46
+ const auto * input_type_proto = node_unit.Inputs ()[0 ].node_arg .TypeAsProto ();
47
+ const auto * output_type_proto = node_unit.Outputs ()[0 ].node_arg .TypeAsProto ();
48
+
49
+ Qnn_DataType_t input_qnn_dtype = QNN_DATATYPE_UNDEFINED;
50
+ Qnn_DataType_t output_qnn_dtype = QNN_DATATYPE_UNDEFINED;
51
+
52
+ if (utils::GetQnnDataType (false , input_type_proto, input_qnn_dtype) != Status::OK () ||
53
+ utils::GetQnnDataType (false , output_type_proto, output_qnn_dtype) != Status::OK ()) {
54
+ return false ;
55
+ }
56
+
57
+ return ((input_qnn_dtype == QNN_DATATYPE_FLOAT_16 || input_qnn_dtype == QNN_DATATYPE_FLOAT_32) &&
58
+ output_qnn_dtype == QNN_DATATYPE_BOOL_8);
59
+ }
60
+
61
+ Status CastOpBuilder::ProcessExtraInputForNotEqual (QnnModelWrapper& qnn_model_wrapper,
62
+ const NodeUnit& node_unit,
63
+ std::vector<std::string>& input_names,
64
+ const logging::Logger& logger) const {
65
+ const auto & input = node_unit.Inputs ()[0 ];
66
+ if (input.quant_param .has_value ()) {
67
+ return Status::OK ();
68
+ }
69
+
70
+ // Build additional static input with value 0.
71
+ const std::string& input_name = utils::GetNodeName (node_unit) + " _notequal_zero" ;
72
+
73
+ Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED;
74
+ const auto * type_proto = input.node_arg .TypeAsProto ();
75
+ ORT_RETURN_IF_ERROR (utils::GetQnnDataType (false , type_proto, qnn_data_type));
76
+
77
+ QnnTensorWrapper input_tensor_wrapper (input_name,
78
+ QNN_TENSOR_TYPE_STATIC,
79
+ qnn_data_type,
80
+ QnnQuantParamsWrapper (),
81
+ std::move (std::vector<uint32_t >{1 }),
82
+ std::move (std::vector<uint8_t >(utils::GetElementSizeByType (qnn_data_type), 0 )));
83
+ ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (input_tensor_wrapper)),
84
+ " Failed to add additional input tensor for QNN Cast node that will be replaced by NotEqual." );
85
+ input_names.push_back (input_name);
86
+
87
+ LOGS (logger, VERBOSE) << " FP-to-Bool Cast node " << utils::GetNodeName (node_unit) << " is replaced by NotEqual." ;
88
+ return Status::OK ();
89
+ }
90
+
34
91
Status CastOpBuilder::ProcessInputs (QnnModelWrapper& qnn_model_wrapper,
35
92
const NodeUnit& node_unit,
36
93
const logging::Logger& logger,
@@ -47,7 +104,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
47
104
if (qnn_model_wrapper.IsQnnTensorWrapperExist (input_name)) {
48
105
LOGS (logger, VERBOSE) << " Tensor already added, skip it: " << input_name;
49
106
input_names.push_back (input_name);
50
- return Status::OK ();
107
+ return IsFpToBoolCast (node_unit)
108
+ ? ProcessExtraInputForNotEqual (qnn_model_wrapper, node_unit, input_names, logger)
109
+ : Status::OK ();
51
110
}
52
111
53
112
std::vector<uint8_t > unpacked_tensor;
@@ -75,7 +134,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
75
134
" Failed to add input tensor for QNN Cast node." );
76
135
input_names.push_back (input_name);
77
136
78
- return Status::OK ();
137
+ return IsFpToBoolCast (node_unit)
138
+ ? ProcessExtraInputForNotEqual (qnn_model_wrapper, node_unit, input_names, logger)
139
+ : Status::OK ();
79
140
}
80
141
81
142
Status CastOpBuilder::ProcessAttributesAndOutputs (QnnModelWrapper& qnn_model_wrapper,
@@ -110,14 +171,17 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
110
171
ORT_RETURN_IF_NOT (qnn_model_wrapper.AddTensorWrapper (std::move (output_tensorwrapper)),
111
172
" Failed to add output tensor for QNN Cast node." );
112
173
174
+ const std::string qnn_op_type = IsFpToBoolCast (node_unit)
175
+ ? QNN_OP_ELEMENT_WISE_NOT_EQUAL
176
+ : GetQnnOpType (node_unit.OpType ());
113
177
ORT_RETURN_IF_NOT (qnn_model_wrapper.CreateQnnNode (utils::GetNodeName (node_unit),
114
178
QNN_OP_PACKAGE_NAME_QTI_AISW,
115
- GetQnnOpType (node_unit. OpType ()) ,
179
+ qnn_op_type ,
116
180
std::move (input_names),
117
181
{output_name},
118
182
{},
119
183
do_op_validation),
120
- " Failed to create QNN Cast node." );
184
+ " Failed to create " + qnn_op_type + " node." );
121
185
122
186
return Status::OK ();
123
187
}
0 commit comments