Skip to content

Commit 7804b5c

Browse files
[QNN EP] Reshape Transpose Fusion for 6D tensors (#26338)
### Description - Added qnn node group for pattern Reshape -> Transpose -> Reshape with reshape node giving output as rank 6. - Added support for QDQ pattern and added unit test to verify 6D tensor offload on NPU with this pattern ### Motivation and Context - Ensure Reshape Transpose Reshape pattern gets offloaded to the NPU since QNN Does not have support for 6D tensors.
1 parent 510dd14 commit 7804b5c

File tree

4 files changed

+603
-2
lines changed

4 files changed

+603
-2
lines changed

onnxruntime/core/providers/qnn/builder/qnn_node_group/qnn_node_group.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "core/providers/qnn/builder/qnn_node_group/udo_fusion.h"
2222
#include "core/providers/qnn/builder/qnn_node_group/lpbqgemm_fusion.h"
2323
#include "core/providers/qnn/builder/qnn_node_group/lpbqmatmul_fusion.h"
24+
#include "core/providers/qnn/builder/qnn_node_group/reshape_transpose_rank5.h"
2425

2526
#include "core/providers/qnn/builder/qnn_utils.h"
2627
#include "core/providers/qnn/ort_api.h"
@@ -82,6 +83,7 @@ static std::unordered_map<std::string, std::vector<FusionFunc>> fusions = {
8283
{"Gemm", {LowPowerBlockQuantizedGemmFusion::TryFusion, ReshapeGemmFusion::TryFusion}},
8384
{"Mul", {ScaleSoftmaxFusion::TryFusion}},
8485
{"Cast", {CastLoneQFusion::TryFusion}},
86+
{"Reshape", {Rank6ToRank5Fusion::TryFusion}},
8587
{"Transpose", {ChannelShuffleFusion::TryFusion}}};
8688

8789
void registerUDO(const std::string& node_type, const std::string& op_package) {
@@ -117,8 +119,10 @@ static std::unique_ptr<IQnnNodeGroup> TryQnnFusions(
117119
const std::unordered_map<const Node*, const NodeUnit*>& node_to_node_unit,
118120
const std::unordered_map<const NodeUnit*, const IQnnNodeGroup*>& node_unit_to_qnn_node_group,
119121
const logging::Logger& logger) {
120-
// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings
121-
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode && starting_node_unit.OpType() != "MatMul") {
122+
// For now, all fusions involve standalone node units (i.e., no wrapping DQ/Q nodes) except MatMul w/ LPBQ encodings and Reshape
123+
if (starting_node_unit.UnitType() != NodeUnit::Type::SingleNode &&
124+
starting_node_unit.OpType() != "MatMul" &&
125+
starting_node_unit.OpType() != "Reshape") {
122126
return nullptr;
123127
}
124128

0 commit comments

Comments
 (0)