Skip to content

Commit e86ff2e

Browse files
committed
add check for qnn tensor dynamic shape
1 parent f373035 commit e86ff2e

File tree

5 files changed

+32
-0
lines changed

5 files changed

+32
-0
lines changed

onnxruntime/core/providers/qnn/builder/qnn_def.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,20 @@ const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor)
394394
ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version);
395395
}
396396

397+
uint8_t* GetQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& qnn_tensor) {
398+
if (QNN_TENSOR_VERSION_1 == qnn_tensor.version) {
399+
return nullptr; // not present in v1
400+
}
401+
402+
#ifdef QNN_TENSOR_V2_INIT
403+
if (QNN_TENSOR_VERSION_2 == qnn_tensor.version) {
404+
return qnn_tensor.v2.isDynamicDimensions;
405+
}
406+
#endif // QNN_TENSOR_V2_INIT
407+
408+
ORT_THROW("QNN tensor version not supported, QNN tensor version: ", qnn_tensor.version);
409+
}
410+
397411
Status CompareQnnQuantParams(const Qnn_QuantizeParams_t& qparam0, const Qnn_QuantizeParams_t& qparam1,
398412
float& scale_diff, int32_t& offset_diff) {
399413
scale_diff = 0.0f;

onnxruntime/core/providers/qnn/builder/qnn_def.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ uint32_t* GetQnnTensorDims(const Qnn_Tensor_t& qnn_tensor);
126126
const Qnn_ClientBuffer_t& GetQnnTensorClientBuf(const Qnn_Tensor_t& qnn_tensor);
127127
Qnn_MemHandle_t GetQnnTensorMemHandle(const Qnn_Tensor_t& qnn_tensor);
128128
const Qnn_QuantizeParams_t& GetQnnTensorQParams(const Qnn_Tensor_t& qnn_tensor);
129+
uint8_t* GetQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& qnn_tensor);
129130

130131
/**
131132
* Compares two sets of quantization parameters. Sets the parameters `scale_diff` and `offset_diff`

onnxruntime/core/providers/qnn/builder/qnn_model.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,9 @@ Status QnnModel::SetupTensors(std::vector<QnnTensorInfo>& qnn_tensor_infos,
322322
qnn_tensor_infos.resize(tensor_count);
323323

324324
for (auto& tensor_wrapper : tensor_wrappers) {
325+
ORT_RETURN_IF(utils::QnnTensorHasDynamicShape(tensor_wrapper.GetQnnTensor()),
326+
"QNN tensor (", tensor_wrapper.GetName(), ") has dynamic shape. This is not supported yet.");
327+
325328
const size_t length = utils::GetQnnTensorDataSizeInBytes(tensor_wrapper.GetTensorDims(),
326329
tensor_wrapper.GetTensorDataType());
327330
const auto& tensor_name = tensor_wrapper.GetName();

onnxruntime/core/providers/qnn/builder/qnn_utils.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#include "core/providers/qnn/builder/qnn_utils.h"
55

6+
#include <algorithm>
67
#include <functional>
78
#include <map>
89
#include <numeric>
@@ -71,6 +72,17 @@ size_t GetQnnTensorDataSizeInBytes(gsl::span<const uint32_t> shape, Qnn_DataType
7172
return std::accumulate(shape.begin(), shape.end(), data_length, std::multiplies<>{});
7273
}
7374

75+
bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor) {
76+
const uint8_t* is_dynamic_dimensions = GetQnnTensorIsDynamicDimensions(tensor);
77+
if (is_dynamic_dimensions == nullptr) {
78+
return false;
79+
}
80+
81+
const auto rank = GetQnnTensorRank(tensor);
82+
return std::any_of(is_dynamic_dimensions, is_dynamic_dimensions + rank,
83+
[](uint8_t is_dynamic_dimension) { return is_dynamic_dimension != 0; });
84+
}
85+
7486
std::ostream& operator<<(std::ostream& out, const Qnn_Scalar_t& scalar) {
7587
switch (scalar.dataType) {
7688
case QNN_DATATYPE_INT_8:

onnxruntime/core/providers/qnn/builder/qnn_utils.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ size_t GetElementSizeByType(ONNXTensorElementDataType elem_type);
2929

3030
size_t GetQnnTensorDataSizeInBytes(gsl::span<const uint32_t> shape, Qnn_DataType_t element_data_type);
3131

32+
bool QnnTensorHasDynamicShape(const Qnn_Tensor_t& tensor);
33+
3234
// TODO: make these work with Wrappers?
3335
std::ostream& operator<<(std::ostream& out, const Qnn_Param_t& qnn_param);
3436
std::ostream& operator<<(std::ostream& out, const Qnn_Tensor_t& tensor);

0 commit comments

Comments
 (0)