
Commit 530a1fb

[QNN EP] Add BFloat16 dtype support in QNN EP (#26987)
### Description

- The QNN NPU backend supports the BFloat16 dtype for many operators.
- The QNN EP adds a new session option, "htp_bf16_enable", which lets users request that a Float32 graph be processed in BFloat16 precision.
- When the user specifies "htp_bf16_enable", the QNN EP lowers the incoming Float32 ORT graph into a BFloat16 QNN graph.
- The ORT CPU fallback still receives Float32 partitions.
- The lowered QNN graph still accepts float32 inputs, outputs, and constant initializers; the QNN EP inserts Cast operators to perform the necessary precision switches.

### Motivation and Context

- This enables computing accuracy-sensitive float32 models in bfloat16 precision on the Qualcomm NPU accelerator, improving inference time relative to computing in float32 precision.

---------

Co-authored-by: Ashwath Shankarnarayan <ashwshan@qti.qualcomm.com>
1 parent 744e7fe commit 530a1fb
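For illustration, a minimal sketch of how an application might request BF16 lowering when creating a session with the QNN EP. The "htp_bf16_enable" key comes from this change; the backend path and the "1" enable value are assumptions based on how other QNN HTP provider options are typically passed, not part of this diff.

```cpp
#include <onnxruntime_cxx_api.h>
#include <string>
#include <unordered_map>

int main() {
  Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "qnn_bf16_demo");
  Ort::SessionOptions session_options;

  // Provider options for the QNN EP (values here are illustrative assumptions).
  std::unordered_map<std::string, std::string> qnn_options;
  qnn_options["backend_path"] = "QnnHtp.dll";  // HTP (NPU) backend library
  qnn_options["htp_bf16_enable"] = "1";        // request FP32 -> BF16 lowering (this PR's option)

  session_options.AppendExecutionProvider("QNN", qnn_options);

  // The model keeps float32 inputs/outputs; the QNN EP inserts the Cast ops internally.
  Ort::Session session(env, ORT_TSTR("model.onnx"), session_options);
  return 0;
}
```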

File tree: 7 files changed (+745, -9 lines)

7 files changed

+745
-9
lines changed

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h

Lines changed: 2 additions & 0 deletions
@@ -212,6 +212,8 @@ class QnnBackendManager : public std::enable_shared_from_this<QnnBackendManager>
   void SetQnnBackendType(uint32_t backend_id);
   QnnBackendType GetQnnBackendType() { return qnn_backend_type_; }
 
+  uint32_t GetSocModel() const { return soc_model_; }
+
   const std::string& GetSdkVersion() { return sdk_build_version_; }
 
   Status DestroyHTPPowerConfigID(uint32_t htp_power_config_id);
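The Cast ops this change inserts switch tensors between float32 and bfloat16. As background (not part of this diff): bfloat16 keeps float32's sign bit and 8-bit exponent but only 7 mantissa bits, so it is effectively the upper 16 bits of the IEEE-754 float32 encoding, preserving float32's dynamic range at reduced precision. A minimal standalone sketch of the two conversions (simple truncation; actual hardware casts typically round to nearest even):

```cpp
#include <cstdint>
#include <cstring>

// float32 -> bfloat16 by truncating the low 16 mantissa bits (round-toward-zero).
uint16_t Fp32ToBf16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  return static_cast<uint16_t>(bits >> 16);
}

// bfloat16 -> float32 is exact: pad the low 16 bits with zeros.
float Bf16ToFp32(uint16_t bf16) {
  uint32_t bits = static_cast<uint32_t>(bf16) << 16;
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```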

onnxruntime/core/providers/qnn/builder/qnn_model_wrapper.cc

Lines changed: 282 additions & 7 deletions
@@ -222,6 +222,187 @@ Status QnnModelWrapper::ValidateQnnNode(const std::string& node_name,
   return Status::OK();
 }
 
+bool QnnModelWrapper::CreateBF16CastTensor(const std::string& tensor_name,
+                                           const std::vector<uint32_t>& shape,
+                                           Qnn_TensorType_t tensor_type) {
+  QnnTensorWrapper bf16_tensor(tensor_name, tensor_type, QNN_DATATYPE_BFLOAT_16,
+                               QnnQuantParamsWrapper(), std::vector<uint32_t>(shape));
+  if (!AddTensorWrapper(std::move(bf16_tensor))) {
+    LOGS(logger_, ERROR) << "BF16: Failed to add tensor: " << tensor_name;
+    return false;
+  }
+  return true;
+}
+
+bool QnnModelWrapper::ProcessBF16InputConversion(const std::string& qnn_node_name,
+                                                 const std::vector<std::string>& input_names,
+                                                 std::vector<std::string>& converted_input_names,
+                                                 std::vector<QnnOpProperty>& cast_ops_to_add) {
+  ORT_UNUSED_PARAMETER(qnn_node_name);
+
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    const auto& input_name = input_names[i];
+
+    auto it = model_tensors_map_.find(input_name);
+    if (it == model_tensors_map_.end()) {
+      LOGS(logger_, ERROR) << "BF16: Input tensor not found: " << input_name;
+      return false;
+    }
+
+    auto& tensor_wrapper = it->second;
+    Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType();
+    Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType();
+    bool is_graph_input_or_init = IsGraphInput(input_name) || IsConstantInput(input_name) || IsGraphOutput(input_name);
+
+    if (is_graph_input_or_init && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Insert Cast node for FP32 graph inputs/initializers: FP32 -> BF16
+      std::string cast_output_name = input_name + "_bf16_intermediate";
+
+      if (!IsQnnTensorWrapperExist(cast_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+
+        if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << input_name << " -> " << cast_output_name;
+
+        QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                              std::vector<std::string>{input_name},
+                              std::vector<std::string>{cast_output_name},
+                              std::vector<std::string>{});
+        cast_ops_to_add.push_back(std::move(cast_op));
+      }
+      converted_input_names.push_back(cast_output_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Convert intermediate FP32 tensors to BF16 directly
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      converted_input_names.push_back(input_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_STATIC && !IsConstantInput(input_name) && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Initializers that are created in QNN and are not present in ONNX
+      std::string cast_output_name = input_name + "_bf16_intermediate";
+      if (!IsQnnTensorWrapperExist(cast_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+        if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op for static tensor " << input_name << " -> " << cast_output_name;
+        QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                              std::vector<std::string>{input_name},
+                              std::vector<std::string>{cast_output_name},
+                              std::vector<std::string>{});
+        cast_ops_to_add.push_back(std::move(cast_op));
+      }
+      converted_input_names.push_back(cast_output_name);
+    } else {
+      converted_input_names.push_back(input_name);
+    }
+  }
+
+  return true;
+}
+
+bool QnnModelWrapper::ProcessBF16OutputConversion(const std::string& qnn_node_name,
+                                                  const std::vector<std::string>& output_names,
+                                                  std::vector<std::string>& converted_output_names,
+                                                  std::vector<std::pair<std::string, std::string>>& graph_output_cast_ops) {
+  ORT_UNUSED_PARAMETER(qnn_node_name);
+
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    const auto& output_name = output_names[i];
+
+    auto it = model_tensors_map_.find(output_name);
+    if (it == model_tensors_map_.end()) {
+      continue;
+    }
+    auto& tensor_wrapper = it->second;
+    Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType();
+    Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType();
+
+    if (IsGraphOutput(output_name) &&
+        (tensor_dtype == QNN_DATATYPE_FLOAT_32 || tensor_dtype == QNN_DATATYPE_BFLOAT_16)) {
+      // For FP32 graph outputs, insert Cast node to convert BF16 back to FP32
+      std::string bf16_output_name = utils::GetUniqueName(output_name, "_bf16_intermediate");
+
+      if (!IsQnnTensorWrapperExist(bf16_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+
+        if (!CreateBF16CastTensor(bf16_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << bf16_output_name << " -> " << output_name;
+        graph_output_cast_ops.push_back({bf16_output_name, output_name});
+      }
+      converted_output_names.push_back(bf16_output_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Convert intermediate FP32 tensors to BF16 directly
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      converted_output_names.push_back(output_name);
+    } else {
+      converted_output_names.push_back(output_name);
+    }
+  }
+
+  return true;
+}
+
+bool QnnModelWrapper::ApplyBF16ConversionForValidation(const std::vector<std::string>& input_names,
+                                                       const std::vector<std::string>& output_names,
+                                                       std::vector<std::string>& validation_input_names,
+                                                       std::vector<std::string>& validation_output_names) {
+  // Temporarily convert FP32 tensors to BF16 for validation
+  for (const auto& input_name : input_names) {
+    auto it = model_tensors_map_.find(input_name);
+    if (it == model_tensors_map_.end()) {
+      LOGS(logger_, ERROR) << "BF16: Validation failed - input tensor not found: " << input_name;
+      return false;
+    }
+
+    auto& tensor_wrapper = it->second;
+    if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) {
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+    }
+    validation_input_names.push_back(input_name);
+  }
+
+  for (const auto& output_name : output_names) {
+    auto it = model_tensors_map_.find(output_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      }
+    }
+    validation_output_names.push_back(output_name);
+  }
+
+  return true;
+}
+
+void QnnModelWrapper::RestoreFP32AfterValidation(const std::vector<std::string>& input_names,
+                                                 const std::vector<std::string>& output_names) {
+  // Restore FP32 data types after validation
+  for (const auto& input_name : input_names) {
+    auto it = model_tensors_map_.find(input_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32);
+      }
+    }
+  }
+
+  for (const auto& output_name : output_names) {
+    auto it = model_tensors_map_.find(output_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32);
+      }
+    }
+  }
+}
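CreateQnnNode (below) pairs ApplyBF16ConversionForValidation with RestoreFP32AfterValidation through an RAII helper, BF16ConversionGuard, whose definition is not part of this excerpt. A minimal sketch of what such a guard plausibly looks like, inferred only from how it is constructed below; the actual definition in the PR may differ:

```cpp
// Plausible sketch only - the real BF16ConversionGuard is defined elsewhere in this PR.
class BF16ConversionGuard {
 public:
  BF16ConversionGuard(QnnModelWrapper* wrapper,
                      const std::vector<std::string>& input_names,
                      const std::vector<std::string>& output_names)
      : wrapper_(wrapper), input_names_(input_names), output_names_(output_names) {}

  // Flip the temporarily-BF16 tensors back to FP32 when the validation scope ends,
  // on both the success path and any early return.
  ~BF16ConversionGuard() {
    wrapper_->RestoreFP32AfterValidation(input_names_, output_names_);
  }

  BF16ConversionGuard(const BF16ConversionGuard&) = delete;
  BF16ConversionGuard& operator=(const BF16ConversionGuard&) = delete;

 private:
  QnnModelWrapper* wrapper_;
  std::vector<std::string> input_names_;
  std::vector<std::string> output_names_;
};
```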
+
 bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
                                     const std::string& package_name,
                                     const std::string& qnn_node_type,
@@ -233,15 +414,31 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
   std::vector<Qnn_Tensor_t> input_tensors;
   std::vector<Qnn_Tensor_t> output_tensors;
   std::vector<Qnn_Param_t> params;
-  if (!CreateQnnInputOutputTensors(qnn_node_name, input_names, input_tensors, do_op_validation)) {
-    return false;
-  }
 
-  if (!CreateQnnInputOutputTensors(qnn_node_name, output_names, output_tensors, do_op_validation)) {
-    return false;
+  // Apply BF16 conversion for validation if enabled
+  std::vector<std::string> validation_input_names;
+  std::vector<std::string> validation_output_names;
+
+  // Use RAII guard for BF16 conversion to ensure cleanup
+  std::unique_ptr<BF16ConversionGuard> bf16_guard;
+
+  if (IsBF16ConversionEnabled()) {
+    LOGS(logger_, VERBOSE) << "[BF16] Validation with BF16 conversion enabled";
+    if (!ApplyBF16ConversionForValidation(input_names, output_names, validation_input_names, validation_output_names)) {
+      LOGS(logger_, ERROR) << "[BF16] ApplyBF16ConversionForValidation failed for node: " << qnn_node_name;
+      return false;
+    }
+    // Create the guard after successful conversion
+    bf16_guard = std::make_unique<BF16ConversionGuard>(this, input_names, output_names);
+  } else {
+    validation_input_names = input_names;
+    validation_output_names = output_names;
   }
 
-  if (!CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) {
+  // Create tensors for validation
+  if (!CreateQnnInputOutputTensors(qnn_node_name, validation_input_names, input_tensors, do_op_validation) ||
+      !CreateQnnInputOutputTensors(qnn_node_name, validation_output_names, output_tensors, do_op_validation) ||
+      !CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) {
     return false;
   }
 
@@ -257,28 +454,106 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
 
     std::string error_msg;
     bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg);
+
     if (!rt) {
       // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more
       // specific validation error (instead of "failed to add node").
       LOGS(logger_, WARNING) << error_msg;
     }
     return rt;
   } else {
+    // Standard execution - just add the node to the op list
     QnnOpProperty qnn_op(qnn_node_name, package_name, qnn_node_type,
                          std::move(input_names), std::move(output_names), std::move(param_tensor_names));
     qnn_op_property_list_.push_back(std::move(qnn_op));
     return true;
   }
 }
 
+bool QnnModelWrapper::ProcessBF16Conversions(std::vector<QnnOpProperty>& final_ops) {
+  std::vector<QnnOpProperty> processed_ops;
+  std::vector<QnnOpProperty> input_cast_ops;
+
+  for (const auto& op_property : qnn_op_property_list_) {
+    // Make copies of the strings to avoid reference invalidation
+    std::string qnn_node_name = op_property.GetNodeName();
+    std::string package_name = op_property.GetPackageName();
+    std::string qnn_node_type = op_property.GetNodeType();
+    std::vector<std::string> input_names = op_property.GetInputNames();
+    std::vector<std::string> output_names = op_property.GetOutputNames();
+    std::vector<std::string> param_tensor_names = op_property.GetParamTensorNames();
+
+    LOGS(logger_, VERBOSE) << "[BF16] Processing node for BF16 conversion: " << qnn_node_name;
+
+    std::vector<std::string> converted_input_names;
+    std::vector<std::string> converted_output_names;
+    std::vector<std::pair<std::string, std::string>> graph_output_cast_ops;
+
+    if (!ProcessBF16InputConversion(qnn_node_name, input_names, converted_input_names, input_cast_ops)) {
+      LOGS(logger_, ERROR) << "[BF16] ProcessBF16InputConversion failed for node: " << qnn_node_name;
+      return false;
+    }
+
+    if (!ProcessBF16OutputConversion(qnn_node_name, output_names, converted_output_names, graph_output_cast_ops)) {
+      LOGS(logger_, ERROR) << "[BF16] ProcessBF16OutputConversion failed for node: " << qnn_node_name;
+      return false;
+    }
+
+    // Add the main node with BF16-converted tensor names
+    LOGS(logger_, VERBOSE) << "[BF16] Adding main node with converted tensors: " << qnn_node_name;
+    processed_ops.emplace_back(std::move(qnn_node_name), std::move(package_name), std::move(qnn_node_type),
+                               std::move(converted_input_names), std::move(converted_output_names),
+                               std::move(param_tensor_names));
+
+    // Add Cast operations for graph outputs to convert BF16 back to FP32
+    LOGS(logger_, VERBOSE) << "[BF16] Adding " << graph_output_cast_ops.size() << " output cast operations";
+    for (size_t i = 0; i < graph_output_cast_ops.size(); ++i) {
+      const auto& [bf16_name, fp32_name] = graph_output_cast_ops[i];
+      std::string cast_node_name = bf16_name;
+      LOGS(logger_, VERBOSE) << "[BF16] Adding output Cast op[" << i << "]: " << cast_node_name
+                             << " (" << bf16_name << " -> " << fp32_name << ")";
+
+      processed_ops.emplace_back(std::move(cast_node_name), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                                 std::vector<std::string>{bf16_name},
+                                 std::vector<std::string>{fp32_name},
+                                 std::vector<std::string>{});
+    }
+  }
+
+  // Prepend input cast ops to the beginning of processed_ops
+  final_ops.reserve(input_cast_ops.size() + processed_ops.size());
+
+  for (auto& cast_op : input_cast_ops) {
+    final_ops.push_back(std::move(cast_op));
+  }
+
+  for (auto& op : processed_ops) {
+    final_ops.push_back(std::move(op));
+  }
+
+  return true;
+}
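Taken together, for a small FP32 graph consisting of a single Add with graph inputs X and Y and graph output Z, ProcessBF16Conversions rewrites the op list roughly as follows (names are illustrative; the real intermediate names come from the "_bf16_intermediate" suffix and GetUniqueName logic above):

- Cast(X -> X_bf16_intermediate) and Cast(Y -> Y_bf16_intermediate), prepended for the FP32 graph inputs
- Add(X_bf16_intermediate, Y_bf16_intermediate) -> Z_bf16_intermediate, the original node rewired to BF16 tensors
- Cast(Z_bf16_intermediate -> Z), appended so the graph output stays FP32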
+
 bool QnnModelWrapper::ComposeQnnGraph(bool build_json_qnn_graph) {
   LOGS(logger_, VERBOSE) << "Compose Qnn Graph.";
   // ORT_RETURN_IF(qnn_op_property_list_.empty(), "Empty Qnn op list, no graph to compose.");
   if (qnn_op_property_list_.empty()) {
     return false;
   }
 
-  for (const auto& op_property : qnn_op_property_list_) {
+  // Determine which ops to process
+  const std::vector<QnnOpProperty>* ops_to_process = &qnn_op_property_list_;
+  std::vector<QnnOpProperty> bf16_processed_ops;
+
+  if (IsBF16ConversionEnabled()) {
+    if (!ProcessBF16Conversions(bf16_processed_ops)) {
+      return false;
+    }
+    ops_to_process = &bf16_processed_ops;
+  }
+
+  // Create QNN graph ops from the op properties
+  for (const auto& op_property : *ops_to_process) {
     std::vector<Qnn_Tensor_t> input_tensors;
     std::vector<Qnn_Tensor_t> output_tensors;
     std::vector<Qnn_Param_t> params;
