@@ -172,29 +172,18 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
 }
 
 // Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
-// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto.
 bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
   if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
-        (lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
-         lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
-        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
-        utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
+        lhs_t.data_type() != onnx::TensorProto_DataType_STRING &&
+        lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1)) {
     return false;
   }
-  const void* lhs_value = lhs_t.raw_data().data();
-  const void* rhs_value = rhs_t.raw_data().data();
-  switch (lhs_t.data_type()) {
-    case onnx::TensorProto_DataType_FLOAT:
-      return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
-    case onnx::TensorProto_DataType_FLOAT16:
-      return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
-    case onnx::TensorProto_DataType_INT64:
-      return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
-    default:
-      break;
+  std::vector<uint8_t> unpacked_lhs_tensor, unpacked_rhs_tensor;
+  if (!utils::UnpackInitializerData(lhs_t, unpacked_lhs_tensor).IsOK() ||
+      !utils::UnpackInitializerData(rhs_t, unpacked_rhs_tensor).IsOK()) {
+    return false;
   }
-  return false;
+  return unpacked_lhs_tensor == unpacked_rhs_tensor;
 }
 
 bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
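For reference, the idea behind this hunk in standalone form: instead of whitelisting float/float16/int64 and requiring raw_data, unpack whatever storage the scalar initializer uses into bytes and compare the bytes. The sketch below is only an illustration of that approach; UnpackScalarBytes is a toy helper covering raw_data plus the float/int64 typed fields, not ORT's utils::UnpackInitializerData, which additionally handles external data and every non-string element type.

    #include <cstring>
    #include <vector>
    #include "onnx/onnx_pb.h"

    // Toy stand-in for utils::UnpackInitializerData: copy a scalar initializer's
    // payload into bytes, whether it is stored as raw_data or in a typed field.
    static bool UnpackScalarBytes(const onnx::TensorProto& t, std::vector<uint8_t>& out) {
      if (t.has_raw_data()) {
        out.assign(t.raw_data().begin(), t.raw_data().end());
        return true;
      }
      if (t.data_type() == onnx::TensorProto_DataType_FLOAT && t.float_data_size() == 1) {
        float v = t.float_data(0);
        out.resize(sizeof(v));
        std::memcpy(out.data(), &v, sizeof(v));
        return true;
      }
      if (t.data_type() == onnx::TensorProto_DataType_INT64 && t.int64_data_size() == 1) {
        int64_t v = t.int64_data(0);
        out.resize(sizeof(v));
        std::memcpy(out.data(), &v, sizeof(v));
        return true;
      }
      return false;  // other storage layouts are not handled in this sketch
    }

    // Scalar equality mirrors the new AreScalarTensorAttributeEqual: same type,
    // shape [1], non-string, then a plain byte comparison of the unpacked payloads.
    static bool ScalarTensorsEqual(const onnx::TensorProto& lhs, const onnx::TensorProto& rhs) {
      if (lhs.data_type() != rhs.data_type() ||
          lhs.data_type() == onnx::TensorProto_DataType_STRING ||
          lhs.dims_size() != 1 || rhs.dims_size() != 1 ||
          lhs.dims(0) != 1 || rhs.dims(0) != 1) {
        return false;
      }
      std::vector<uint8_t> lhs_bytes, rhs_bytes;
      return UnpackScalarBytes(lhs, lhs_bytes) && UnpackScalarBytes(rhs, rhs_bytes) && lhs_bytes == rhs_bytes;
    }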
@@ -235,26 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
   return false;
 }
 
-// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
+// Support scalar tensor attribute only for now.
 std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
   std::size_t hash = 0;
-  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
+  if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1) {
     int data_type = attr_t.data_type();
-    switch (data_type) {
-      case onnx::TensorProto_DataType_FLOAT:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
-        break;
-      case onnx::TensorProto_DataType_FLOAT16:
-        UpdateHash(data_type, hash);
-        UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
-        break;
-      case onnx::TensorProto_DataType_INT64:
-        UpdateHash(data_type, hash);
-        UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
-        break;
-      default:
-        break;
+    ORT_ENFORCE(data_type != onnx::TensorProto_DataType_STRING, "Unexpected tensor string type");
+    std::vector<uint8_t> unpacked_tensor;
+    if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
+      UpdateHash(data_type, hash);
+      UpdateHashWithContainer(unpacked_tensor, hash);
     }
   }
   return hash;
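For context on the new hashing path: the data type tag and the unpacked bytes are folded into one running hash. Below is a minimal boost-style sketch of that combining step; hash_combine and HashScalarTensor are illustrative stand-ins under that assumption, not the UpdateHash/UpdateHashWithContainer helpers defined elsewhere in this file.

    #include <cstddef>
    #include <cstdint>
    #include <functional>
    #include <vector>

    // Boost-style hash_combine: fold one value into a running seed.
    template <typename T>
    void hash_combine(std::size_t& seed, const T& value) {
      seed ^= std::hash<T>{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
    }

    // Fold the data type tag and every unpacked byte into the hash, mirroring
    // UpdateHash(data_type, hash) followed by UpdateHashWithContainer(unpacked_tensor, hash).
    std::size_t HashScalarTensor(int data_type, const std::vector<uint8_t>& unpacked_tensor) {
      std::size_t hash = 0;
      hash_combine(hash, data_type);
      for (uint8_t byte : unpacked_tensor) {
        hash_combine(hash, byte);
      }
      return hash;
    }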