// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/qnn/ort_api.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/providers/qnn/builder/opbuilder/base_op_builder.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/qnn_utils.h"

namespace onnxruntime {
namespace qnn {

// The FusedMatMul contrib op is decomposed into a QNN MatMul with optional input transposes and alpha scaling.
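// transA/transB map to the QNN MatMul transpose_in0/transpose_in1 parameters. transBatchA/transBatchB
// have no direct QNN equivalent, so they are lowered to explicit Transpose nodes inserted before the
// MatMul. alpha != 1.0f is lowered to an ElementWiseMultiply against a static single-element tensor.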
class FusedMatMulOpBuilder : public BaseOpBuilder {
 public:
  FusedMatMulOpBuilder() : BaseOpBuilder("FusedMatMulOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(FusedMatMulOpBuilder);

 protected:
  Status ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, const logging::Logger& logger,
                       std::vector<std::string>& input_names, bool do_op_validation) const override ORT_MUST_USE_RESULT;

  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names, const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;

 private:
  Status ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper,
                             const NodeUnit& node_unit,
                             const logging::Logger& logger,
                             std::vector<std::string>& input_names) const ORT_MUST_USE_RESULT;

  Status GetFusedMatMulAttributes(const NodeUnit& node_unit,
                                  bool& transA,
                                  bool& transB,
                                  bool& transBatchA,
                                  bool& transBatchB,
                                  float& alpha) const ORT_MUST_USE_RESULT;

  Status ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper,
                              const NodeUnit& node_unit,
                              const std::vector<uint32_t>& perm,
                              std::vector<std::string>& param_tensor_names) const;

  void CreateBatchTransposePermVector(const std::vector<uint32_t>& input_shape,
                                      std::vector<uint32_t>& perm,
                                      bool trans_mat = false) const;

  Status HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper,
                              const NodeUnit& node_unit,
                              const TensorInfo& input_info,
                              const std::string& input_name,
                              std::string& transposed_name,
                              bool trans_mat,
                              bool do_op_validation) const;
};

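// Reads the com.microsoft FusedMatMul attributes. All of them are optional and default to the
// identity configuration: no transposes and alpha == 1.0f.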
Status FusedMatMulOpBuilder::GetFusedMatMulAttributes(const NodeUnit& node_unit,
                                                      bool& transA,
                                                      bool& transB,
                                                      bool& transBatchA,
                                                      bool& transBatchB,
                                                      float& alpha) const {
  NodeAttrHelper node_helper(node_unit);

  transA = node_helper.Get("transA", static_cast<int64_t>(0)) != 0;
  transB = node_helper.Get("transB", static_cast<int64_t>(0)) != 0;

  transBatchA = node_helper.Get("transBatchA", static_cast<int64_t>(0)) != 0;
  transBatchB = node_helper.Get("transBatchB", static_cast<int64_t>(0)) != 0;

  alpha = node_helper.Get("alpha", 1.0f);

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
                                           const logging::Logger& logger, std::vector<std::string>& input_names,
                                           bool /*do_op_validation*/) const {
  const auto& inputs = node_unit.Inputs();

  if (inputs.size() != 2) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "FusedMatMul requires exactly 2 inputs, got ", inputs.size());
  }

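  // Fetch tensor info up front so that unsupported inputs fail early during op validation; the same
  // info is re-queried in ProcessAttributesAndOutputs when the node is built.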
  TensorInfo input_info_0{};
  TensorInfo input_info_1{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[0], input_info_0));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(inputs[1], input_info_1));

  ORT_RETURN_IF_ERROR(ProcessMatMulInputs(qnn_model_wrapper, node_unit, logger, input_names));

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessMatMulInputs(QnnModelWrapper& qnn_model_wrapper,
                                                 const NodeUnit& node_unit,
                                                 const logging::Logger& logger,
                                                 std::vector<std::string>& input_names) const {
  const auto& inputs = node_unit.Inputs();

  // Inputs A and B are handled identically: reuse an existing tensor wrapper if one was already
  // registered for the name, otherwise create and register a new one.
  for (size_t input_idx = 0; input_idx < 2; ++input_idx) {
    const std::string& input_name = inputs[input_idx].node_arg.Name();
    if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) {
      LOGS(logger, VERBOSE) << "Tensor already added, skip it: " << input_name;
    } else {
      QnnTensorWrapper input_tensor;
      ORT_RETURN_IF_ERROR(qnn_model_wrapper.MakeTensorWrapper(inputs[input_idx], input_tensor));
      ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensor)),
                        "Failed to add input tensor ", input_name, ".");
    }
    input_names.emplace_back(input_name);
  }

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                         const NodeUnit& node_unit,
                                                         std::vector<std::string>&& input_names,
                                                         const logging::Logger& /*logger*/,
                                                         bool do_op_validation) const {
  bool transA = false;
  bool transB = false;
  bool transBatchA = false;
  bool transBatchB = false;
  float alpha = 1.0f;
  ORT_RETURN_IF_ERROR(GetFusedMatMulAttributes(node_unit, transA, transB, transBatchA, transBatchB, alpha));

  TensorInfo input_a_info{};
  TensorInfo input_b_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[0], input_a_info));
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Inputs()[1], input_b_info));

  std::vector<std::string> matmul_param_tensor_names;

  // Set the QNN MatMul transpose parameters for the last two dimensions.
  // When both transA and transBatchA are set, the batch-Transpose node inserted below already folds in
  // the MxK transpose, so transpose_in0 must be left at its default (false) to avoid transposing twice.
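  // How transA/transBatchA are lowered (transB/transBatchB are symmetric):
  //   transA=0, transBatchA=0: no-op (transpose_in0 = false)
  //   transA=1, transBatchA=0: transpose_in0 = true
  //   transA=0, transBatchA=1: Transpose node with perm [1, ..., rank-2, 0, rank-1]
  //   transA=1, transBatchA=1: Transpose node with perm [1, ..., rank-2, rank-1, 0]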
  if (!(transA && transBatchA)) {
    Qnn_Scalar_t transpose_a_scalar = QNN_SCALAR_INIT;
    transpose_a_scalar.dataType = QNN_DATATYPE_BOOL_8;
    transpose_a_scalar.bool8Value = transA ? 1 : 0;
    QnnParamWrapper transpose_a_param(node_unit.Index(), node_unit.Name(),
                                      QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN0, transpose_a_scalar);
    matmul_param_tensor_names.push_back(transpose_a_param.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(transpose_a_param));
  }

  // Same as above: leave transpose_in1 at its default when the batch-Transpose node for input B
  // already performs the KxN transpose.
  if (!(transB && transBatchB)) {
    Qnn_Scalar_t transpose_b_scalar = QNN_SCALAR_INIT;
    transpose_b_scalar.dataType = QNN_DATATYPE_BOOL_8;
    transpose_b_scalar.bool8Value = transB ? 1 : 0;
    QnnParamWrapper transpose_b_param(node_unit.Index(), node_unit.Name(),
                                      QNN_OP_MAT_MUL_PARAM_TRANSPOSE_IN1, transpose_b_scalar);
    matmul_param_tensor_names.push_back(transpose_b_param.GetParamTensorName());
    qnn_model_wrapper.AddParamWrapper(std::move(transpose_b_param));
  }

  // QNN MatMul does not support batch-dimension transposition directly, so additional Transpose
  // operations are inserted before the MatMul when transBatchA or transBatchB is set.
  std::string input_a_for_matmul = input_names[0];
  std::string input_b_for_matmul = input_names[1];

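  // Illustration: input A with shape [B, M, K] and transBatchA=1 is transposed to [M, B, K] here
  // (or, with transA=1 as well, to [M, K, B]); the MatMul then consumes the transposed tensor.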
  if (transBatchA && input_a_info.shape.size() > 2) {
    std::string transposed_a_name;
    ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_a_info,
                                             input_a_for_matmul, transposed_a_name, transA, do_op_validation));
    input_a_for_matmul = transposed_a_name;
  }

  if (transBatchB && input_b_info.shape.size() > 2) {
    std::string transposed_b_name;
    ORT_RETURN_IF_ERROR(HandleBatchTranspose(qnn_model_wrapper, node_unit, input_b_info,
                                             input_b_for_matmul, transposed_b_name, transB, do_op_validation));
    input_b_for_matmul = transposed_b_name;
  }

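  // Emit the QNN MatMul. With alpha == 1.0f it writes directly to the node's output; otherwise it
  // writes to an intermediate tensor that is scaled by alpha below.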
  const std::string& output_name = node_unit.Outputs()[0].node_arg.Name();
  TensorInfo output_info{};
  ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetTensorInfo(node_unit.Outputs()[0], output_info));

  if (alpha == 1.0f) {
    // When alpha is 1.0f, the MatMul output is the final output.
    Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ
                                                                                : QNN_TENSOR_TYPE_NATIVE;

    QnnTensorWrapper output_tensor(output_name,
                                   tensor_type,
                                   output_info.qnn_data_type,
                                   output_info.quant_param.Copy(),
                                   std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)),
                      "Failed to add final output tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_matmul"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_MAT_MUL,
                          {input_a_for_matmul, input_b_for_matmul},
                          {output_name},
                          std::move(matmul_param_tensor_names),
                          do_op_validation),
                      "Failed to create MatMul node for FusedMatMul.");
  } else {
    // When alpha is not 1.0f, the MatMul writes to an intermediate tensor and alpha scaling is
    // applied afterwards.
    std::string matmul_output_name = utils::GetUniqueName(node_unit.Name() + "_matmul_output");

    QnnTensorWrapper matmul_output_tensor(matmul_output_name,
                                          QNN_TENSOR_TYPE_NATIVE,
                                          output_info.qnn_data_type,
                                          QnnQuantParamsWrapper(),
                                          std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(matmul_output_tensor)),
                      "Failed to add MatMul output tensor.");

    Qnn_TensorType_t tensor_type = qnn_model_wrapper.IsGraphOutput(output_name) ? QNN_TENSOR_TYPE_APP_READ
                                                                                : QNN_TENSOR_TYPE_NATIVE;

    QnnTensorWrapper output_tensor(output_name,
                                   tensor_type,
                                   output_info.qnn_data_type,
                                   output_info.quant_param.Copy(),
                                   std::vector<uint32_t>(output_info.shape));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(output_tensor)),
                      "Failed to add output tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_matmul"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_MAT_MUL,
                          {input_a_for_matmul, input_b_for_matmul},
                          {matmul_output_name},
                          std::move(matmul_param_tensor_names),
                          do_op_validation),
                      "Failed to create MatMul node for FusedMatMul.");

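    // Encode alpha as a single-element static tensor; the ElementWiseMultiply below broadcasts it
    // across the MatMul output.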
    std::string alpha_tensor_name = utils::GetUniqueName(node_unit.Name() + "_alpha");
    std::vector<uint32_t> alpha_shape{1};
    Qnn_DataType_t alpha_qnn_data_type = output_info.qnn_data_type;
    std::vector<uint8_t> alpha_data;

    // The alpha tensor's data type must match the MatMul output data type for the element-wise multiply.
    // Only float32/float16 outputs are handled here; other (e.g., quantized) types would need their own
    // encoding, so reject them explicitly rather than writing raw float bytes into a mistyped tensor.
    if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_16) {
      alpha_data.resize(sizeof(MLFloat16));
      MLFloat16 alpha_fp16(alpha);
      memcpy(alpha_data.data(), &alpha_fp16.val, sizeof(MLFloat16));
    } else if (alpha_qnn_data_type == QNN_DATATYPE_FLOAT_32) {
      alpha_data.resize(sizeof(float));
      memcpy(alpha_data.data(), &alpha, sizeof(float));
    } else {
      return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
                             "FusedMatMul alpha scaling supports only float32/float16 outputs.");
    }

    QnnTensorWrapper alpha_tensor_wrapper(alpha_tensor_name,
                                          QNN_TENSOR_TYPE_STATIC,
                                          alpha_qnn_data_type,
                                          QnnQuantParamsWrapper(),
                                          std::move(alpha_shape),
                                          std::move(alpha_data));
    ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(alpha_tensor_wrapper)),
                      "Failed to add alpha tensor.");

    ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                          utils::GetUniqueName(node_unit.Name() + "_alpha_scale"),
                          QNN_OP_PACKAGE_NAME_QTI_AISW,
                          QNN_OP_ELEMENT_WISE_MULTIPLY,
                          {matmul_output_name, alpha_tensor_name},
                          {output_name},
                          {},
                          do_op_validation),
                      "Failed to create alpha scaling node for FusedMatMul.");
  }

  return Status::OK();
}

Status FusedMatMulOpBuilder::ProcessPermAttribute(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& node_unit,
                                                  const std::vector<uint32_t>& perm,
                                                  std::vector<std::string>& param_tensor_names) const {
  QnnParamWrapper transpose_param(node_unit.Index(), node_unit.Name(), QNN_OP_TRANSPOSE_PARAM_PERM,
                                  {static_cast<uint32_t>(perm.size())}, std::vector<uint32_t>(perm));
  param_tensor_names.push_back(transpose_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(transpose_param));

  return Status::OK();
}

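// Builds the permutation that rotates the leading batch dimension into the matrix dimensions,
// matching FusedMatMul's transBatch semantics. For example, with input shape [B0, B1, M, K]:
//   trans_mat=false: perm [1, 2, 0, 3] -> transposed shape [B1, M, B0, K]
//   trans_mat=true:  perm [1, 2, 3, 0] -> transposed shape [B1, M, K, B0]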
void FusedMatMulOpBuilder::CreateBatchTransposePermVector(const std::vector<uint32_t>& input_shape,
                                                          std::vector<uint32_t>& perm,
                                                          bool trans_mat) const {
  const size_t shape_size = input_shape.size();

  perm.clear();
  perm.reserve(shape_size);

  // 1. Dimensions 1 through shape_size - 2 keep their relative order (the leading dim is moved out).
  for (size_t i = 1; i < shape_size - 1; ++i) {
    perm.push_back(static_cast<uint32_t>(i));
  }

  // 2. Second-to-last output position: the last input dim if trans_mat, else the leading dim 0.
  perm.push_back(trans_mat ? static_cast<uint32_t>(shape_size - 1) : 0);

  // 3. Last output position: the leading dim 0 if trans_mat, else the last input dim.
  perm.push_back(trans_mat ? 0 : static_cast<uint32_t>(shape_size - 1));
}

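// Inserts a Transpose node that applies the batch permutation (and, when trans_mat is set, the
// trailing matrix transpose as well) to the given input, returning the transposed tensor's name.
// Note: quantization parameters are copied as-is, which is valid for per-tensor quantization;
// per-axis parameters would need their axis remapped to follow the permutation.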
Status FusedMatMulOpBuilder::HandleBatchTranspose(QnnModelWrapper& qnn_model_wrapper,
                                                  const NodeUnit& node_unit,
                                                  const TensorInfo& input_info,
                                                  const std::string& input_name,
                                                  std::string& transposed_name,
                                                  bool trans_mat,
                                                  bool do_op_validation) const {
  transposed_name = utils::GetUniqueName(node_unit.Name() + "_transposed_" +
                                         input_name.substr(input_name.find_last_of('/') + 1));

  // Create the perm vector for the batch transpose.
  std::vector<uint32_t> perm;
  CreateBatchTransposePermVector(input_info.shape, perm, trans_mat);

  std::vector<std::string> transpose_params;
  ORT_RETURN_IF_ERROR(ProcessPermAttribute(qnn_model_wrapper, node_unit, perm, transpose_params));

  // Compute the transposed shape directly from the permutation.
  std::vector<uint32_t> transposed_shape(input_info.shape.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    transposed_shape[i] = input_info.shape[perm[i]];
  }

  QnnTensorWrapper transposed_tensor(transposed_name,
                                     QNN_TENSOR_TYPE_NATIVE,
                                     input_info.qnn_data_type,
                                     input_info.quant_param.Copy(),
                                     std::move(transposed_shape));
  ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(transposed_tensor)),
                    "Failed to add transposed tensor.");

  ORT_RETURN_IF_NOT(qnn_model_wrapper.CreateQnnNode(
                        utils::GetUniqueName(node_unit.Name() + "_transpose_" +
                                             input_name.substr(input_name.find_last_of('/') + 1)),
                        QNN_OP_PACKAGE_NAME_QTI_AISW,
                        QNN_OP_TRANSPOSE,
                        {input_name},
                        {transposed_name},
                        std::move(transpose_params),
                        do_op_validation),
                    "Failed to create transpose node.");

  return Status::OK();
}

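// Registration entry point; the op builder factory is expected to route the com.microsoft
// FusedMatMul op type to this builder.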
void CreateFusedMatMulOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<FusedMatMulOpBuilder>());
}

}  // namespace qnn
}  // namespace onnxruntime