Skip to content

Commit 5be0175

Browse files
committed
Data type independent tensor attribute hashing and comparison
1 parent 12a2a75 commit 5be0175

File tree

1 file changed

+7
-18
lines changed

1 file changed

+7
-18
lines changed

onnxruntime/core/optimizer/common_subexpression_elimination.cc

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -172,12 +172,9 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
172172
}
173173

174174
// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
175-
// Currently support float, float16 and int64 data types.
176175
bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
177176
if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
178-
(lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
179-
lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
180-
lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
177+
lhs_t.data_type() != onnx::TensorProto_DataType_STRING &&
181178
lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1)) {
182179
return false;
183180
}
@@ -227,24 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
227224
return false;
228225
}
229226

230-
// Support scalar float/int64/fp16 tensor attribute only for now.
227+
// Support scalar tensor attribute only for now.
231228
std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
232229
std::size_t hash = 0;
233230
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1) {
234231
int data_type = attr_t.data_type();
235-
switch (data_type) {
236-
case onnx::TensorProto_DataType_FLOAT:
237-
case onnx::TensorProto_DataType_FLOAT16:
238-
case onnx::TensorProto_DataType_INT64: {
239-
std::vector<uint8_t> unpacked_tensor;
240-
if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
241-
UpdateHash(data_type, hash);
242-
UpdateHashWithContainer(unpacked_tensor, hash);
243-
}
244-
break;
245-
}
246-
default:
247-
break;
232+
ORT_ENFORCE(data_type != onnx::TensorProto_DataType_STRING, "Unexpected tensor string type");
233+
std::vector<uint8_t> unpacked_tensor;
234+
if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
235+
UpdateHash(data_type, hash);
236+
UpdateHashWithContainer(unpacked_tensor, hash);
248237
}
249238
}
250239
return hash;

0 commit comments

Comments
 (0)