Skip to content

[QNN-EP] Translate FP-to-Bool Cast by NotEqual. #24466

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
Expand Down Expand Up @@ -29,8 +31,63 @@ class CastOpBuilder : public BaseOpBuilder {
std::vector<std::string>&& input_names,
const logging::Logger& logger,
bool do_op_validation) const override ORT_MUST_USE_RESULT;

private:
  // QNN HTP currently does not support casting FP16/FP32 to Bool, and thus such Cast will be replaced by NotEqual with
  // an additional static input 0.f to achieve the identical functionality.
bool IsFpToBoolCast(const NodeUnit& node_unit) const;
Status ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>& input_names,
const logging::Logger& logger) const;
};

bool CastOpBuilder::IsFpToBoolCast(const NodeUnit& node_unit) const {
const auto* input_type_proto = node_unit.Inputs()[0].node_arg.TypeAsProto();
const auto* output_type_proto = node_unit.Outputs()[0].node_arg.TypeAsProto();

Qnn_DataType_t input_qnn_dtype = QNN_DATATYPE_UNDEFINED;
Qnn_DataType_t output_qnn_dtype = QNN_DATATYPE_UNDEFINED;

if (utils::GetQnnDataType(false, input_type_proto, input_qnn_dtype) != Status::OK() ||
utils::GetQnnDataType(false, output_type_proto, output_qnn_dtype) != Status::OK()) {
return false;
}

return ((input_qnn_dtype == QNN_DATATYPE_FLOAT_16 || input_qnn_dtype == QNN_DATATYPE_FLOAT_32) &&
output_qnn_dtype == QNN_DATATYPE_BOOL_8);
}

// Appends a static all-zero scalar tensor to input_names so that an FP-to-Bool Cast(x)
// can be emitted as NotEqual(x, 0). No-op for quantized inputs, which take the regular
// Cast path.
//
// Fixes vs. previous revision:
//  - Removed std::move() around prvalue temporaries (pessimizing-move: it blocks copy
//    elision and is flagged by clang-tidy performance-move-const-arg).
//  - Hold the generated tensor name by value instead of a const reference bound to a
//    temporary; by-value is equally cheap here and cannot dangle if the code is reshuffled.
Status CastOpBuilder::ProcessExtraInputForNotEqual(QnnModelWrapper& qnn_model_wrapper,
                                                   const NodeUnit& node_unit,
                                                   std::vector<std::string>& input_names,
                                                   const logging::Logger& logger) const {
  const auto& input = node_unit.Inputs()[0];
  if (input.quant_param.has_value()) {
    return Status::OK();
  }

  // Build the additional static input with value 0.
  const std::string input_name = utils::GetNodeName(node_unit) + "_notequal_zero";

  Qnn_DataType_t qnn_data_type = QNN_DATATYPE_UNDEFINED;
  const auto* type_proto = input.node_arg.TypeAsProto();
  ORT_RETURN_IF_ERROR(utils::GetQnnDataType(false, type_proto, qnn_data_type));

  // Shape {1} with all raw bytes zero, i.e. 0.0 in the input's floating-point type.
  QnnTensorWrapper input_tensor_wrapper(input_name,
                                        QNN_TENSOR_TYPE_STATIC,
                                        qnn_data_type,
                                        QnnQuantParamsWrapper(),
                                        std::vector<uint32_t>{1},
                                        std::vector<uint8_t>(utils::GetElementSizeByType(qnn_data_type), 0));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor_wrapper)),
                    "Failed to add additional input tensor for QNN Cast node that will be replaced by NotEqual.");
  input_names.push_back(input_name);

  LOGS(logger, VERBOSE) << "FP-to-Bool Cast node " << utils::GetNodeName(node_unit) << " is replaced by NotEqual.";
  return Status::OK();
}

Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
const logging::Logger& logger,
Expand All @@ -47,7 +104,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
input_names.push_back(input_name);
return Status::OK();
return IsFpToBoolCast(node_unit)
? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger)
: Status::OK();
}

std::vector<uint8_t> unpacked_tensor;
Expand Down Expand Up @@ -75,7 +134,9 @@ Status CastOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
"Failed to add input tensor for QNN Cast node.");
input_names.push_back(input_name);

return Status::OK();
return IsFpToBoolCast(node_unit)
? ProcessExtraInputForNotEqual(qnn_model_wrapper, node_unit, input_names, logger)
: Status::OK();
}

Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
Expand Down Expand Up @@ -110,14 +171,17 @@ Status CastOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra
ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensorwrapper)),
"Failed to add output tensor for QNN Cast node.");

const std::string qnn_op_type = IsFpToBoolCast(node_unit)
? QNN_OP_ELEMENT_WISE_NOT_EQUAL
: GetQnnOpType(node_unit.OpType());
ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(utils::GetNodeName(node_unit),
QNN_OP_PACKAGE_NAME_QTI_AISW,
GetQnnOpType(node_unit.OpType()),
qnn_op_type,
std::move(input_names),
{output_name},
{},
do_op_validation),
"Failed to create QNN Cast node.");
"Failed to create " + qnn_op_type + " node.");

return Status::OK();
}
Expand Down
49 changes: 45 additions & 4 deletions onnxruntime/test/providers/qnn/cast_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@

#include <string>
#include <unordered_map>
#include <vector>

#include "test/optimizer/qdq_test_utils.h"
#include "test/providers/qnn/qnn_test_utils.h"
#include "gtest/gtest.h"

#include "core/framework/float16.h"
#include "core/graph/onnx_protobuf.h"

#include "gtest/gtest.h"
#include "test/optimizer/qdq_test_utils.h"
#include "test/providers/qnn/qnn_test_utils.h"

namespace onnxruntime {
namespace test {
Expand Down Expand Up @@ -67,6 +68,31 @@ static void RunCastOpTest(const std::vector<int64_t>& shape, ONNX_NAMESPACE::Ten
expected_ep_assignment);
}

#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)
// Builds a single-node model that casts an fp16 input of the given shape to dst_type,
// runs it on the QNN HTP backend, and verifies the expected EP node assignment.
static void RunCastFP16HTPTest(const std::vector<int64_t>& shape,
                               ONNX_NAMESPACE::TensorProto_DataType dst_type,
                               ExpectedEPNodeAssignment expected_ep_assignment) {
  ProviderOptions provider_options;
#if defined(_WIN32)
  provider_options["backend_path"] = "QnnHtp.dll";
#else
  provider_options["backend_path"] = "libQnnHtp.so";
#endif

  auto build_model = [shape, dst_type](ModelTestBuilder& builder) {
    // Define an fp32 input over [0, 20] and convert the definition to fp16.
    auto fp32_def = TestInputDef(shape, false, 0.0f, 20.0f);
    auto fp16_def = ConvertToFP16InputDef(fp32_def);
    auto fp16_input = MakeTestInput<MLFloat16>(builder, fp16_def);

    auto* cast_output = builder.MakeOutput();
    Node& cast_node = builder.AddNode("Cast", {fp16_input}, {cast_output});
    cast_node.AddAttribute("to", static_cast<int64_t>(dst_type));
  };

  RunQnnModelTest(build_model, provider_options, /* opset */ 13, expected_ep_assignment);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

//
// CPU tests:
//
Expand Down Expand Up @@ -125,6 +151,21 @@ TEST_F(QnnHTPBackendTests, TestCastInt32ToInt64HTP) {
RunCastOpTest<int32_t>({3, 3}, ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
ExpectedEPNodeAssignment::All, true);
}

// Cast float to bool on HTP. Per the builder's FP-to-Bool handling, this Cast is
// translated to a NotEqual op with a static zero input; the whole graph should still
// be assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, TestCastFloatToBoolHTP) {
  RunCastOpTest<float>({3, 3},
                       ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
                       ExpectedEPNodeAssignment::All,
                       true);
}

// Cast float16 to bool on HTP. Like the fp32 case, this is expected to be lowered to
// NotEqual(x, 0) and fully assigned to the QNN EP.
TEST_F(QnnHTPBackendTests, TestCastFloat16ToBoolHTP) {
  RunCastFP16HTPTest({3, 3},
                     ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_BOOL,
                     ExpectedEPNodeAssignment::All);
}
#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__)

} // namespace test
Expand Down
Loading