49 changes: 14 additions & 35 deletions onnxruntime/core/optimizer/common_subexpression_elimination.cc
@@ -172,29 +172,18 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
 }
 
 // Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
-// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto.
 bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
   if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
-        (lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
-        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
-        utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
+        lhs_t.data_type() != onnx::TensorProto_DataType_STRING &&
+        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1)) {
     return false;
   }
-  const void* lhs_value = lhs_t.raw_data().data();
-  const void* rhs_value = rhs_t.raw_data().data();
-  switch (lhs_t.data_type()) {
-    case onnx::TensorProto_DataType_FLOAT:
-      return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
-    case onnx::TensorProto_DataType_FLOAT16:
-      return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
-    case onnx::TensorProto_DataType_INT64:
-      return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
-    default:
-      break;
+  std::vector<uint8_t> unpacked_lhs_tensor, unpacked_rhs_tensor;
+  if (!utils::UnpackInitializerData(lhs_t, unpacked_lhs_tensor).IsOK() ||
+      !utils::UnpackInitializerData(rhs_t, unpacked_rhs_tensor).IsOK()) {
+    return false;
   }
-  return false;
+  return unpacked_lhs_tensor == unpacked_rhs_tensor;
 }
 
 bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
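Note on the hunk above: the new comparison path is storage-agnostic. `utils::UnpackInitializerData` normalizes a scalar initializer to its byte representation whether the value sits in `raw_data` or in a typed field such as `float_data`, so any non-string scalar type can now be compared, not just float/float16/int64 with raw data. A minimal sketch of that behavior follows; the include paths and the `ScalarsEqual` wrapper are illustrative assumptions, not part of this PR.

```cpp
// Sketch only: assumes the onnx protobuf headers and onnxruntime's
// utils::UnpackInitializerData (as called in the diff above) are available.
#include <cstdint>
#include <vector>

#include "core/framework/tensorprotoutils.h"  // assumed location of utils::UnpackInitializerData
#include "onnx/onnx_pb.h"

namespace {

// Hypothetical helper mirroring the new AreScalarTensorAttributeEqual body.
bool ScalarsEqual(const ONNX_NAMESPACE::TensorProto& lhs, const ONNX_NAMESPACE::TensorProto& rhs) {
  std::vector<uint8_t> lhs_bytes, rhs_bytes;
  if (!onnxruntime::utils::UnpackInitializerData(lhs, lhs_bytes).IsOK() ||
      !onnxruntime::utils::UnpackInitializerData(rhs, rhs_bytes).IsOK()) {
    return false;
  }
  return lhs_bytes == rhs_bytes;  // byte-wise comparison, works for any non-string type
}

}  // namespace

int main() {
  const float value = 1.5f;

  // Scalar stored as raw bytes.
  ONNX_NAMESPACE::TensorProto raw;
  raw.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
  raw.add_dims(1);
  raw.set_raw_data(&value, sizeof(value));

  // Same scalar stored in the typed float_data field.
  ONNX_NAMESPACE::TensorProto typed;
  typed.set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
  typed.add_dims(1);
  typed.add_float_data(value);

  // The old raw_data-only check rejected the typed form; after unpacking,
  // both forms yield the same bytes and compare equal.
  return ScalarsEqual(raw, typed) ? 0 : 1;
}
```

Comparing unpacked bytes also removes the per-type `reinterpret_cast` switch, so additional data types are supported without further changes to this function.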
@@ -235,26 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
   return false;
 }
 
-// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
+// Support scalar tensor attribute only for now.
 std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
   std::size_t hash = 0;
-  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
+  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1) {
     int data_type = attr_t.data_type();
-    switch (data_type) {
-      case onnx::TensorProto_DataType_FLOAT:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
-        break;
-      case onnx::TensorProto_DataType_FLOAT16:
-        UpdateHash(data_type, hash);
-        UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
-        break;
-      case onnx::TensorProto_DataType_INT64:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
-        break;
-      default:
-        break;
+    ORT_ENFORCE(data_type != onnx::TensorProto_DataType_STRING, "Unexpected tensor string type");
+    std::vector<uint8_t> unpacked_tensor;
+    if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
+      UpdateHash(data_type, hash);
+      UpdateHashWithContainer(unpacked_tensor, hash);
     }
   }
   return hash;
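The hashing side mirrors the equality change: the data type and the unpacked byte container are folded into the hash instead of type-specific `reinterpret_cast`s over `raw_data`. The `UpdateHash` / `UpdateHashWithContainer` helpers are defined elsewhere in this file and are not part of this hunk; the sketch below is an assumed `hash_combine`-style approximation of the idea, not the actual implementation.

```cpp
// Sketch only: UpdateHashSketch / UpdateHashWithContainerSketch are assumed
// stand-ins for the helpers used in the diff above.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <vector>

template <typename T>
void UpdateHashSketch(const T& value, std::size_t& hash) {
  // Fold std::hash of the value into the running hash (boost::hash_combine style).
  hash ^= std::hash<T>{}(value) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
}

void UpdateHashWithContainerSketch(const std::vector<uint8_t>& bytes, std::size_t& hash) {
  UpdateHashSketch(bytes.size(), hash);  // include the length so prefixes don't collide
  for (uint8_t byte : bytes) {
    UpdateHashSketch(byte, hash);        // fold each unpacked byte into the hash
  }
}
```

Whatever the exact combining scheme, the invariant that matters for the CSE map is that `GetTensorAttributeHash` stays consistent with `AreScalarTensorAttributeEqual`: both now operate on the data type plus the same unpacked bytes, so attributes that compare equal also hash equally.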