diff --git a/onnxruntime/core/optimizer/common_subexpression_elimination.cc b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
index 8f78f9c3b6cc7..3c373296f442b 100644
--- a/onnxruntime/core/optimizer/common_subexpression_elimination.cc
+++ b/onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -172,29 +172,18 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
 }
 
 // Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
-// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto.
 bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
   if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
-        (lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
-        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
-        utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
+        lhs_t.data_type() != onnx::TensorProto_DataType_STRING &&
+        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1)) {
     return false;
   }
-  const void* lhs_value = lhs_t.raw_data().data();
-  const void* rhs_value = rhs_t.raw_data().data();
-  switch (lhs_t.data_type()) {
-    case onnx::TensorProto_DataType_FLOAT:
-      return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
-    case onnx::TensorProto_DataType_FLOAT16:
-      return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
-    case onnx::TensorProto_DataType_INT64:
-      return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
-    default:
-      break;
+  std::vector<uint8_t> unpacked_lhs_tensor, unpacked_rhs_tensor;
+  if (!utils::UnpackInitializerData(lhs_t, unpacked_lhs_tensor).IsOK() ||
+      !utils::UnpackInitializerData(rhs_t, unpacked_rhs_tensor).IsOK()) {
+    return false;
   }
-  return false;
+  return unpacked_lhs_tensor == unpacked_rhs_tensor;
 }
 
 bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
@@ -235,26 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
   return false;
 }
 
-// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
+// Support scalar tensor attribute only for now.
 std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
   std::size_t hash = 0;
-  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
+  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1) {
     int data_type = attr_t.data_type();
-    switch (data_type) {
-      case onnx::TensorProto_DataType_FLOAT:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
-        break;
-      case onnx::TensorProto_DataType_FLOAT16:
-        UpdateHash(data_type, hash);
-        UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
-        break;
-      case onnx::TensorProto_DataType_INT64:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
-        break;
-      default:
-        break;
+    ORT_ENFORCE(data_type != onnx::TensorProto_DataType_STRING, "Unexpected tensor string type");
+    std::vector<uint8_t> unpacked_tensor;
+    if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
+      UpdateHash(data_type, hash);
+      UpdateHashWithContainer(unpacked_tensor, hash);
     }
   }
   return hash;
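
Note: the change above replaces the per-type scalar comparison and hashing with a generic path that unpacks each tensor attribute to raw bytes via utils::UnpackInitializerData and then compares or hashes those bytes. The sketch below is a minimal standalone illustration of the invariant this relies on, not onnxruntime code: the UpdateHash and UpdateHashWithContainer bodies are placeholder hash-combine helpers, and the byte vectors stand in for unpacked int64 scalars. The point is that the hash consumes exactly the bytes used for equality, so two attributes that compare equal always land in the same bucket of the CSE deduplication map.

// Standalone sketch (assumed helpers, not onnxruntime's implementation):
// scalar tensor equality and hashing reduced to the unpacked byte form.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Placeholder hash-combine, playing the role of UpdateHash in the diff.
void UpdateHash(std::size_t value, std::size_t& hash) {
  hash ^= value + 0x9e3779b9 + (hash << 6) + (hash >> 2);
}

// Hash every byte of the unpacked tensor (role of UpdateHashWithContainer).
void UpdateHashWithContainer(const std::vector<uint8_t>& bytes, std::size_t& hash) {
  for (uint8_t b : bytes) {
    UpdateHash(b, hash);
  }
}

int main() {
  // Stand-ins for UnpackInitializerData output of two ConstantOfShape
  // "value" attributes, each holding the int64 scalar 1 (little-endian).
  const std::vector<uint8_t> lhs = {1, 0, 0, 0, 0, 0, 0, 0};
  const std::vector<uint8_t> rhs = {1, 0, 0, 0, 0, 0, 0, 0};

  const bool equal = (lhs == rhs);  // byte-wise equality, as in the new code

  const int data_type = 7;  // TensorProto_DataType_INT64
  std::size_t lhs_hash = 0, rhs_hash = 0;
  UpdateHash(static_cast<std::size_t>(data_type), lhs_hash);
  UpdateHashWithContainer(lhs, lhs_hash);
  UpdateHash(static_cast<std::size_t>(data_type), rhs_hash);
  UpdateHashWithContainer(rhs, rhs_hash);

  std::cout << "equal=" << equal << " same_hash=" << (lhs_hash == rhs_hash) << "\n";
  return 0;
}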