@@ -30,7 +30,7 @@ class ClipOpBuilder : public BaseOpBuilder {
3030 bool do_op_validation) const override ORT_MUST_USE_RESULT;
3131
3232 private:
33- Status ExplictOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ;
33+ Status ExplicitOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ;
3434};
3535
3636static Status ProcessClipMinMax (QnnModelWrapper& qnn_model_wrapper,
@@ -41,56 +41,112 @@ static Status ProcessClipMinMax(QnnModelWrapper& qnn_model_wrapper,
4141 ORT_RETURN_IF_ERROR (qnn_model_wrapper.GetTensorInfo (input, input_info));
4242 assert (input_info.is_initializer ); // Checked by ExplicitOpCheck().
4343 ORT_RETURN_IF_ERROR (qnn_model_wrapper.UnpackInitializerData (*input_info.initializer_tensor , val_bytes));
44- switch (input_info.qnn_data_type ) {
45- case QNN_DATATYPE_INT_8: {
46- float_value = static_cast <float >(*reinterpret_cast <int8_t *>(val_bytes.data ()));
47- break ;
48- }
49- case QNN_DATATYPE_INT_16: {
50- float_value = static_cast <float >(*reinterpret_cast <int16_t *>(val_bytes.data ()));
51- break ;
52- }
53- case QNN_DATATYPE_INT_32: {
54- float_value = static_cast <float >(*reinterpret_cast <int32_t *>(val_bytes.data ()));
55- break ;
56- }
57- case QNN_DATATYPE_INT_64: {
58- float_value = static_cast <float >(*reinterpret_cast <int64_t *>(val_bytes.data ()));
59- break ;
60- }
61- case QNN_DATATYPE_UINT_8: {
62- float_value = static_cast <float >(*val_bytes.data ());
63- break ;
64- }
65- case QNN_DATATYPE_UINT_16: {
66- float_value = static_cast <float >(*reinterpret_cast <uint16_t *>(val_bytes.data ()));
67- break ;
68- }
69- case QNN_DATATYPE_UINT_32: {
70- float_value = static_cast <float >(*reinterpret_cast <uint32_t *>(val_bytes.data ()));
71- break ;
72- }
73- case QNN_DATATYPE_UINT_64: {
74- float_value = static_cast <float >(*reinterpret_cast <uint64_t *>(val_bytes.data ()));
75- break ;
76- }
77- case QNN_DATATYPE_FLOAT_16: {
78- MLFloat16 fp16_value = *reinterpret_cast <const MLFloat16*>(val_bytes.data ());
79- float_value = fp16_value.ToFloat ();
80- break ;
44+
45+ // If the input is quantized, we need to dequantize it
46+ if (input.quant_param .has_value ()) {
47+ ORT_RETURN_IF_NOT (input_info.quant_param .IsPerTensor (),
48+ " Clip's min/max must use per-tensor quantization" );
49+ const Qnn_QuantizeParams_t& quant_param = input_info.quant_param .Get ();
50+
51+ switch (input_info.qnn_data_type ) {
52+ case QNN_DATATYPE_SFIXED_POINT_8: {
53+ int8_t quantized_value = *reinterpret_cast <int8_t *>(val_bytes.data ());
54+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
55+ quant_param.scaleOffsetEncoding .scale ,
56+ static_cast <double >(quantized_value)));
57+ break ;
58+ }
59+ case QNN_DATATYPE_SFIXED_POINT_16: {
60+ int16_t quantized_value = *reinterpret_cast <int16_t *>(val_bytes.data ());
61+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
62+ quant_param.scaleOffsetEncoding .scale ,
63+ static_cast <double >(quantized_value)));
64+ break ;
65+ }
66+ case QNN_DATATYPE_SFIXED_POINT_32: {
67+ int32_t quantized_value = *reinterpret_cast <int32_t *>(val_bytes.data ());
68+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
69+ quant_param.scaleOffsetEncoding .scale ,
70+ static_cast <double >(quantized_value)));
71+ break ;
72+ }
73+ case QNN_DATATYPE_UFIXED_POINT_8: {
74+ uint8_t quantized_value = *val_bytes.data ();
75+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
76+ quant_param.scaleOffsetEncoding .scale ,
77+ static_cast <double >(quantized_value)));
78+ break ;
79+ }
80+ case QNN_DATATYPE_UFIXED_POINT_16: {
81+ uint16_t quantized_value = *reinterpret_cast <uint16_t *>(val_bytes.data ());
82+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
83+ quant_param.scaleOffsetEncoding .scale ,
84+ static_cast <double >(quantized_value)));
85+ break ;
86+ }
87+ case QNN_DATATYPE_UFIXED_POINT_32: {
88+ uint32_t quantized_value = *reinterpret_cast <uint32_t *>(val_bytes.data ());
89+ float_value = static_cast <float >(utils::Dequantize (quant_param.scaleOffsetEncoding .offset ,
90+ quant_param.scaleOffsetEncoding .scale ,
91+ static_cast <double >(quantized_value)));
92+ break ;
93+ }
94+ default :
95+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Quantized min/max input data type not supported." );
8196 }
82- case QNN_DATATYPE_FLOAT_32: {
83- float_value = *reinterpret_cast <const float *>(val_bytes.data ());
84- break ;
97+ } else {
98+ // Non-quantized input, just cast to float
99+ switch (input_info.qnn_data_type ) {
100+ case QNN_DATATYPE_INT_8: {
101+ float_value = static_cast <float >(*reinterpret_cast <int8_t *>(val_bytes.data ()));
102+ break ;
103+ }
104+ case QNN_DATATYPE_INT_16: {
105+ float_value = static_cast <float >(*reinterpret_cast <int16_t *>(val_bytes.data ()));
106+ break ;
107+ }
108+ case QNN_DATATYPE_INT_32: {
109+ float_value = static_cast <float >(*reinterpret_cast <int32_t *>(val_bytes.data ()));
110+ break ;
111+ }
112+ case QNN_DATATYPE_INT_64: {
113+ float_value = static_cast <float >(*reinterpret_cast <int64_t *>(val_bytes.data ()));
114+ break ;
115+ }
116+ case QNN_DATATYPE_UINT_8: {
117+ float_value = static_cast <float >(*val_bytes.data ());
118+ break ;
119+ }
120+ case QNN_DATATYPE_UINT_16: {
121+ float_value = static_cast <float >(*reinterpret_cast <uint16_t *>(val_bytes.data ()));
122+ break ;
123+ }
124+ case QNN_DATATYPE_UINT_32: {
125+ float_value = static_cast <float >(*reinterpret_cast <uint32_t *>(val_bytes.data ()));
126+ break ;
127+ }
128+ case QNN_DATATYPE_UINT_64: {
129+ float_value = static_cast <float >(*reinterpret_cast <uint64_t *>(val_bytes.data ()));
130+ break ;
131+ }
132+ case QNN_DATATYPE_FLOAT_16: {
133+ MLFloat16 fp16_value = *reinterpret_cast <const MLFloat16*>(val_bytes.data ());
134+ float_value = fp16_value.ToFloat ();
135+ break ;
136+ }
137+ case QNN_DATATYPE_FLOAT_32: {
138+ float_value = *reinterpret_cast <const float *>(val_bytes.data ());
139+ break ;
140+ }
141+ default :
142+ return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " Non-quantized min/max input data type not supported." );
85143 }
86- default :
87- return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " min/max input data type not supported." );
88144 }
89145
90146 return Status::OK ();
91147}
92148
93- Status ClipOpBuilder::ExplictOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
149+ Status ClipOpBuilder::ExplicitOpCheck (QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
94150 if (node_unit.Inputs ().size () > 1 ) {
95151 const auto & min_input_name = node_unit.Inputs ()[1 ].node_arg .Name ();
96152 if (!min_input_name.empty () && !qnn_model_wrapper.IsConstantInput (min_input_name)) {
@@ -112,7 +168,7 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper,
112168 std::vector<std::string>& input_names,
113169 bool do_op_validation) const {
114170 if (do_op_validation) {
115- ORT_RETURN_IF_ERROR (ExplictOpCheck (qnn_model_wrapper, node_unit));
171+ ORT_RETURN_IF_ERROR (ExplicitOpCheck (qnn_model_wrapper, node_unit));
116172 }
117173
118174 return ProcessInput (qnn_model_wrapper, node_unit.Inputs ()[0 ], logger, input_names);