@@ -172,12 +172,9 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
172172}
173173
174174// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
175- // Currently support float, float16 and int64 data types.
176175bool AreScalarTensorAttributeEqual (const ONNX_NAMESPACE::TensorProto& lhs_t , const ONNX_NAMESPACE::TensorProto& rhs_t ) {
177176 if (!(utils::HasDataType (lhs_t ) && utils::HasDataType (rhs_t ) && lhs_t .data_type () == rhs_t .data_type () &&
178- (lhs_t .data_type () == onnx::TensorProto_DataType_FLOAT ||
179- lhs_t .data_type () == onnx::TensorProto_DataType_FLOAT16 ||
180- lhs_t .data_type () == onnx::TensorProto_DataType_INT64) &&
177+ lhs_t .data_type () != onnx::TensorProto_DataType_STRING &&
181178 lhs_t .dims_size () == 1 && rhs_t .dims_size () == 1 && lhs_t .dims ()[0 ] == 1 && rhs_t .dims ()[0 ] == 1 )) {
182179 return false ;
183180 }
@@ -227,24 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
227224 return false ;
228225}
229226
230- // Support scalar float/int64/fp16 tensor attribute only for now.
227+ // Support scalar tensor attribute only for now.
231228std::size_t GetTensorAttributeHash (const ONNX_NAMESPACE::TensorProto& attr_t ) {
232229 std::size_t hash = 0 ;
233230 if (utils::HasDataType (attr_t ) && attr_t .dims_size () == 1 && attr_t .dims ()[0 ] == 1 ) {
234231 int data_type = attr_t .data_type ();
235- switch (data_type) {
236- case onnx::TensorProto_DataType_FLOAT:
237- case onnx::TensorProto_DataType_FLOAT16:
238- case onnx::TensorProto_DataType_INT64: {
239- std::vector<uint8_t > unpacked_tensor;
240- if (utils::UnpackInitializerData (attr_t , unpacked_tensor).IsOK ()) {
241- UpdateHash (data_type, hash);
242- UpdateHashWithContainer (unpacked_tensor, hash);
243- }
244- break ;
245- }
246- default :
247- break ;
232+ ORT_ENFORCE (data_type != onnx::TensorProto_DataType_STRING, " Unexpected tensor string type" );
233+ std::vector<uint8_t > unpacked_tensor;
234+ if (utils::UnpackInitializerData (attr_t , unpacked_tensor).IsOK ()) {
235+ UpdateHash (data_type, hash);
236+ UpdateHashWithContainer (unpacked_tensor, hash);
248237 }
249238 }
250239 return hash;
0 commit comments