
Commit 925b90b

[QNN-EP] Update gather op input tensor cast logic. (#26835)
### Description
The Gather op builder consulted the ONNX graph when deciding whether to insert a `Cast->int32` on the indices input. However, the input tensor is created by QNN and may already have been cast to int32. The resulting type mismatch caused a redundant Cast to be added. This PR changes the Gather op builder to check the QNN tensor's data type before adding the int64->int32 cast.

### Motivation and Context
This prevents the QNN Gather op builder from inserting a redundant `Cast->int32`.

---------

Signed-off-by: Mu-Chein Hsu <quic_muchhsu@quicinc.com>
Co-authored-by: Edward Chen <18449977+edgchen1@users.noreply.github.com>
1 parent 9ce61be commit 925b90b

1 file changed, +4 -1 lines changed

onnxruntime/core/providers/qnn/builder/opbuilder/gather_op_builder.cc

Lines changed: 4 additions & 1 deletion
@@ -160,8 +160,11 @@ static Status ProcessIndicesInput(QnnModelWrapper& qnn_model_wrapper,
   }
 
   // Insert QNN Cast op to convert dynamic indices from int64 to int32.
+  const auto& input_tensorwrapper = qnn_model_wrapper.GetQnnTensorWrapper(indices_tensor_name);
+
   std::string indices_casted_name{indices_tensor_name};
-  if (indices_info.qnn_data_type == QNN_DATATYPE_INT_64) {
+  // Check QNN Tensor data type.
+  if (input_tensorwrapper.GetTensorDataType() == QNN_DATATYPE_INT_64) {
     assert(!indices_info.is_initializer);
     indices_casted_name += "_int32";
     if (qnn_model_wrapper.IsQnnTensorWrapperExist(indices_casted_name)) {
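
The decision changed by this hunk can be illustrated in isolation. The following is a minimal sketch, not the QNN EP implementation: `DataType`, `MockTensorRegistry`, and `ResolveIndicesName` are hypothetical stand-ins invented for illustration, and registering a `_int32` entry merely stands in for adding a Cast node. It assumes only what the commit message states, namely that the registered tensor may already be int32 even when the graph-level type is int64.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for illustration only; these are NOT the real QNN EP types.
enum class DataType { Int32, Int64 };

struct MockTensorRegistry {
  // Maps tensor name -> data type of the tensor as actually registered.
  std::unordered_map<std::string, DataType> tensors;

  bool Exists(const std::string& name) const { return tensors.count(name) != 0; }
  DataType TypeOf(const std::string& name) const { return tensors.at(name); }
};

// Returns the name of the tensor that Gather should consume as indices.
// The decision is based on the registered tensor's type, not on the type
// recorded in the original graph, so an already-int32 tensor is used as-is.
std::string ResolveIndicesName(MockTensorRegistry& registry,
                               const std::string& indices_name) {
  std::string casted_name{indices_name};
  if (registry.TypeOf(indices_name) == DataType::Int64) {
    casted_name += "_int32";
    if (!registry.Exists(casted_name)) {
      // Stands in for inserting a Cast node that produces the int32 tensor.
      registry.tensors.emplace(casted_name, DataType::Int32);
    }
  }
  return casted_name;
}
```

With an indices tensor already registered as int32, `ResolveIndicesName` returns the name unchanged and registers nothing new; with an int64 tensor, it returns the `_int32` name and registers the cast output at most once.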
