@@ -222,6 +222,187 @@ Status QnnModelWrapper::ValidateQnnNode(const std::string& node_name,
   return Status::OK();
 }

+bool QnnModelWrapper::CreateBF16CastTensor(const std::string& tensor_name,
+                                           const std::vector<uint32_t>& shape,
+                                           Qnn_TensorType_t tensor_type) {
+  QnnTensorWrapper bf16_tensor(tensor_name, tensor_type, QNN_DATATYPE_BFLOAT_16,
+                               QnnQuantParamsWrapper(), std::vector<uint32_t>(shape));
+  if (!AddTensorWrapper(std::move(bf16_tensor))) {
+    LOGS(logger_, ERROR) << "BF16: Failed to add tensor: " << tensor_name;
+    return false;
+  }
+  return true;
+}
+
+bool QnnModelWrapper::ProcessBF16InputConversion(const std::string& qnn_node_name,
+                                                 const std::vector<std::string>& input_names,
+                                                 std::vector<std::string>& converted_input_names,
+                                                 std::vector<QnnOpProperty>& cast_ops_to_add) {
+  ORT_UNUSED_PARAMETER(qnn_node_name);
+
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    const auto& input_name = input_names[i];
+
+    auto it = model_tensors_map_.find(input_name);
+    if (it == model_tensors_map_.end()) {
+      LOGS(logger_, ERROR) << "BF16: Input tensor not found: " << input_name;
+      return false;
+    }
+
+    auto& tensor_wrapper = it->second;
+    Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType();
+    Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType();
+    bool is_graph_input_or_init = IsGraphInput(input_name) || IsConstantInput(input_name) || IsGraphOutput(input_name);
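+    // Note: despite its name, this flag also covers graph outputs consumed by this
+    // node, presumably because graph-boundary tensors must keep their FP32 type.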
+
+    if (is_graph_input_or_init && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Insert Cast node for FP32 graph inputs/initializers: FP32 -> BF16
+      std::string cast_output_name = input_name + "_bf16_intermediate";
+
+      if (!IsQnnTensorWrapperExist(cast_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+
+        if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << input_name << " -> " << cast_output_name;
+
+        QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                              std::vector<std::string>{input_name},
+                              std::vector<std::string>{cast_output_name},
+                              std::vector<std::string>{});
+        cast_ops_to_add.push_back(std::move(cast_op));
+      }
+      converted_input_names.push_back(cast_output_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Intermediate FP32 tensors never cross the graph boundary, so retype them
+      // to BF16 in place instead of inserting a Cast.
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      converted_input_names.push_back(input_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_STATIC && !IsConstantInput(input_name) && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Static tensors created inside QNN that have no ONNX initializer counterpart
+      std::string cast_output_name = input_name + "_bf16_intermediate";
+      if (!IsQnnTensorWrapperExist(cast_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+        if (!CreateBF16CastTensor(cast_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op for static tensor " << input_name << " -> " << cast_output_name;
+        QnnOpProperty cast_op(cast_output_name, QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                              std::vector<std::string>{input_name},
+                              std::vector<std::string>{cast_output_name},
+                              std::vector<std::string>{});
+        cast_ops_to_add.push_back(std::move(cast_op));
+      }
+      converted_input_names.push_back(cast_output_name);
+    } else {
+      converted_input_names.push_back(input_name);
+    }
+  }
+
+  return true;
+}
+
+bool QnnModelWrapper::ProcessBF16OutputConversion(const std::string& qnn_node_name,
+                                                  const std::vector<std::string>& output_names,
+                                                  std::vector<std::string>& converted_output_names,
+                                                  std::vector<std::pair<std::string, std::string>>& graph_output_cast_ops) {
+  ORT_UNUSED_PARAMETER(qnn_node_name);
+
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    const auto& output_name = output_names[i];
+
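+    // Unlike input conversion, a missing tensor wrapper is skipped here rather
+    // than treated as an error.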
+    auto it = model_tensors_map_.find(output_name);
+    if (it == model_tensors_map_.end()) {
+      continue;
+    }
+    auto& tensor_wrapper = it->second;
+    Qnn_DataType_t tensor_dtype = tensor_wrapper.GetTensorDataType();
+    Qnn_TensorType_t tensor_type = tensor_wrapper.GetTensorType();
+
+    if (IsGraphOutput(output_name) &&
+        (tensor_dtype == QNN_DATATYPE_FLOAT_32 || tensor_dtype == QNN_DATATYPE_BFLOAT_16)) {
+      // Route graph outputs through a BF16 intermediate, then insert a Cast node
+      // to convert back to FP32 at the graph boundary.
+      std::string bf16_output_name = utils::GetUniqueName(output_name, "_bf16_intermediate");
+
+      if (!IsQnnTensorWrapperExist(bf16_output_name)) {
+        std::vector<uint32_t> shape = tensor_wrapper.GetTensorDims();
+
+        if (!CreateBF16CastTensor(bf16_output_name, shape, QNN_TENSOR_TYPE_NATIVE)) {
+          return false;
+        }
+        LOGS(logger_, VERBOSE) << "BF16: Adding Cast op " << bf16_output_name << " -> " << output_name;
+        graph_output_cast_ops.push_back({bf16_output_name, output_name});
+      }
+      converted_output_names.push_back(bf16_output_name);
+    } else if (tensor_type == QNN_TENSOR_TYPE_NATIVE && tensor_dtype == QNN_DATATYPE_FLOAT_32) {
+      // Convert intermediate FP32 tensors to BF16 in place
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      converted_output_names.push_back(output_name);
+    } else {
+      converted_output_names.push_back(output_name);
+    }
+  }
+
+  return true;
+}
+
+bool QnnModelWrapper::ApplyBF16ConversionForValidation(const std::vector<std::string>& input_names,
+                                                       const std::vector<std::string>& output_names,
+                                                       std::vector<std::string>& validation_input_names,
+                                                       std::vector<std::string>& validation_output_names) {
+  // Temporarily convert FP32 tensors to BF16 for validation
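+  // The caller is responsible for restoring FP32 afterwards; see
+  // RestoreFP32AfterValidation and the BF16ConversionGuard used in CreateQnnNode.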
+  for (const auto& input_name : input_names) {
+    auto it = model_tensors_map_.find(input_name);
+    if (it == model_tensors_map_.end()) {
+      LOGS(logger_, ERROR) << "BF16: Validation failed - input tensor not found: " << input_name;
+      return false;
+    }
+
+    auto& tensor_wrapper = it->second;
+    if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) {
+      SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+    }
+    validation_input_names.push_back(input_name);
+  }
+
+  for (const auto& output_name : output_names) {
+    auto it = model_tensors_map_.find(output_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_FLOAT_32) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_BFLOAT_16);
+      }
+    }
+    validation_output_names.push_back(output_name);
+  }
+
+  return true;
+}
+
+void QnnModelWrapper::RestoreFP32AfterValidation(const std::vector<std::string>& input_names,
+                                                 const std::vector<std::string>& output_names) {
+  // Restore FP32 data types after validation
+  for (const auto& input_name : input_names) {
+    auto it = model_tensors_map_.find(input_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32);
+      }
+    }
+  }
+
+  for (const auto& output_name : output_names) {
+    auto it = model_tensors_map_.find(output_name);
+    if (it != model_tensors_map_.end()) {
+      auto& tensor_wrapper = it->second;
+      if (tensor_wrapper.GetTensorDataType() == QNN_DATATYPE_BFLOAT_16) {
+        SetQnnTensorDataType(tensor_wrapper.GetQnnTensor(), QNN_DATATYPE_FLOAT_32);
+      }
+    }
+  }
+}
+
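+// A minimal sketch of the RAII guard used in CreateQnnNode below. Its definition
+// is not part of this hunk; the shape shown here is an assumption based on how
+// the guard is constructed and on the cleanup comment:
+//
+//   class BF16ConversionGuard {
+//    public:
+//     BF16ConversionGuard(QnnModelWrapper* wrapper,
+//                         const std::vector<std::string>& inputs,
+//                         const std::vector<std::string>& outputs)
+//         : wrapper_(wrapper), inputs_(inputs), outputs_(outputs) {}
+//     ~BF16ConversionGuard() { wrapper_->RestoreFP32AfterValidation(inputs_, outputs_); }
+//
+//    private:
+//     QnnModelWrapper* wrapper_;
+//     std::vector<std::string> inputs_;
+//     std::vector<std::string> outputs_;
+//   };
+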
 bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
                                     const std::string& package_name,
                                     const std::string& qnn_node_type,
@@ -233,15 +414,31 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,
   std::vector<Qnn_Tensor_t> input_tensors;
   std::vector<Qnn_Tensor_t> output_tensors;
   std::vector<Qnn_Param_t> params;
-  if (!CreateQnnInputOutputTensors(qnn_node_name, input_names, input_tensors, do_op_validation)) {
-    return false;
-  }

-  if (!CreateQnnInputOutputTensors(qnn_node_name, output_names, output_tensors, do_op_validation)) {
-    return false;
+  // Apply BF16 conversion for validation if enabled
+  std::vector<std::string> validation_input_names;
+  std::vector<std::string> validation_output_names;
+
+  // Use RAII guard for BF16 conversion to ensure cleanup
+  std::unique_ptr<BF16ConversionGuard> bf16_guard;
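+  // Note: declared at function scope so the restore runs when CreateQnnNode returns.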
+
+  if (IsBF16ConversionEnabled()) {
+    LOGS(logger_, VERBOSE) << "[BF16] Validation with BF16 conversion enabled";
+    if (!ApplyBF16ConversionForValidation(input_names, output_names, validation_input_names, validation_output_names)) {
+      LOGS(logger_, ERROR) << "[BF16] ApplyBF16ConversionForValidation failed for node: " << qnn_node_name;
+      return false;
+    }
+    // Create the guard after successful conversion
+    bf16_guard = std::make_unique<BF16ConversionGuard>(this, input_names, output_names);
+  } else {
+    validation_input_names = input_names;
+    validation_output_names = output_names;
   }

-  if (!CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) {
+  // Create tensors for validation
+  if (!CreateQnnInputOutputTensors(qnn_node_name, validation_input_names, input_tensors, do_op_validation) ||
+      !CreateQnnInputOutputTensors(qnn_node_name, validation_output_names, output_tensors, do_op_validation) ||
+      !CreateQnnParamTensors(qnn_node_name, param_tensor_names, params, do_op_validation)) {
     return false;
   }

@@ -257,28 +454,106 @@ bool QnnModelWrapper::CreateQnnNode(const std::string& qnn_node_name,

     std::string error_msg;
     bool rt = op_config_wrapper.QnnGraphOpValidation(qnn_interface_, backend_handle_, error_msg);
+
     if (!rt) {
       // TODO(adrianlizarraga): Return a Status with the error message so that aggregated logs show a more
       // specific validation error (instead of "failed to add node").
       LOGS(logger_, WARNING) << error_msg;
     }
     return rt;
   } else {
+    // Standard execution - just add the node to the op list
     QnnOpProperty qnn_op(qnn_node_name, package_name, qnn_node_type,
                          std::move(input_names), std::move(output_names), std::move(param_tensor_names));
     qnn_op_property_list_.push_back(std::move(qnn_op));
     return true;
   }
 }

+bool QnnModelWrapper::ProcessBF16Conversions(std::vector<QnnOpProperty>& final_ops) {
+  std::vector<QnnOpProperty> processed_ops;
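+  // Cast ops for graph inputs are collected separately so they can be emitted
+  // ahead of every consumer in the final op list.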
+  std::vector<QnnOpProperty> input_cast_ops;
+
+  for (const auto& op_property : qnn_op_property_list_) {
+    // Make copies of the strings to avoid reference invalidation
+    std::string qnn_node_name = op_property.GetNodeName();
+    std::string package_name = op_property.GetPackageName();
+    std::string qnn_node_type = op_property.GetNodeType();
+    std::vector<std::string> input_names = op_property.GetInputNames();
+    std::vector<std::string> output_names = op_property.GetOutputNames();
+    std::vector<std::string> param_tensor_names = op_property.GetParamTensorNames();
+
+    LOGS(logger_, VERBOSE) << "[BF16] Processing node for BF16 conversion: " << qnn_node_name;
+
+    std::vector<std::string> converted_input_names;
+    std::vector<std::string> converted_output_names;
+    std::vector<std::pair<std::string, std::string>> graph_output_cast_ops;
+
+    if (!ProcessBF16InputConversion(qnn_node_name, input_names, converted_input_names, input_cast_ops)) {
+      LOGS(logger_, ERROR) << "[BF16] ProcessBF16InputConversion failed for node: " << qnn_node_name;
+      return false;
+    }
+
+    if (!ProcessBF16OutputConversion(qnn_node_name, output_names, converted_output_names, graph_output_cast_ops)) {
+      LOGS(logger_, ERROR) << "[BF16] ProcessBF16OutputConversion failed for node: " << qnn_node_name;
+      return false;
+    }
+
+    // Add the main node with BF16-converted tensor names
+    LOGS(logger_, VERBOSE) << "[BF16] Adding main node with converted tensors: " << qnn_node_name;
+    processed_ops.emplace_back(std::move(qnn_node_name), std::move(package_name), std::move(qnn_node_type),
+                               std::move(converted_input_names), std::move(converted_output_names),
+                               std::move(param_tensor_names));
+
+    // Add Cast operations for graph outputs to convert BF16 back to FP32
+    LOGS(logger_, VERBOSE) << "[BF16] Adding " << graph_output_cast_ops.size() << " output cast operations";
+    for (size_t i = 0; i < graph_output_cast_ops.size(); ++i) {
+      const auto& [bf16_name, fp32_name] = graph_output_cast_ops[i];
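+      // Reuse the unique BF16 tensor name as the name of the Cast node.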
+      std::string cast_node_name = bf16_name;
+      LOGS(logger_, VERBOSE) << "[BF16] Adding output Cast op[" << i << "]: " << cast_node_name
+                             << " (" << bf16_name << " -> " << fp32_name << ")";
+
+      processed_ops.emplace_back(std::move(cast_node_name), QNN_OP_PACKAGE_NAME_QTI_AISW, QNN_OP_CAST,
+                                 std::vector<std::string>{bf16_name},
+                                 std::vector<std::string>{fp32_name},
+                                 std::vector<std::string>{});
+    }
+  }
+
+  // Prepend the input cast ops to final_ops so that each Cast producing a BF16
+  // input tensor appears before the ops that consume it.
+  final_ops.reserve(input_cast_ops.size() + processed_ops.size());
+
+  for (auto& cast_op : input_cast_ops) {
+    final_ops.push_back(std::move(cast_op));
+  }
+
+  for (auto& op : processed_ops) {
+    final_ops.push_back(std::move(op));
+  }
+
+  return true;
+}
+
 bool QnnModelWrapper::ComposeQnnGraph(bool build_json_qnn_graph) {
   LOGS(logger_, VERBOSE) << "Compose Qnn Graph.";
   // ORT_RETURN_IF(qnn_op_property_list_.empty(), "Empty Qnn op list, no graph to compose.");
   if (qnn_op_property_list_.empty()) {
     return false;
   }

-  for (const auto& op_property : qnn_op_property_list_) {
+  // Determine which ops to process
+  const std::vector<QnnOpProperty>* ops_to_process = &qnn_op_property_list_;
+  std::vector<QnnOpProperty> bf16_processed_ops;
+
+  if (IsBF16ConversionEnabled()) {
+    if (!ProcessBF16Conversions(bf16_processed_ops)) {
+      return false;
+    }
+    ops_to_process = &bf16_processed_ops;
+  }
+
+  // Create QNN graph ops from the op properties
+  for (const auto& op_property : *ops_to_process) {
     std::vector<Qnn_Tensor_t> input_tensors;
     std::vector<Qnn_Tensor_t> output_tensors;
     std::vector<Qnn_Param_t> params;