Skip to content

Commit 5e90c46

Browse files
committed
Fix misaligned addresses while reading tensor attributes from raw data buffers
1 parent 4295524 commit 5e90c46

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

onnxruntime/core/optimizer/common_subexpression_elimination.cc

Lines changed: 34 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,26 @@ bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, con
186186
const void* rhs_value = rhs_t.raw_data().data();
187187
switch (lhs_t.data_type()) {
188188
case onnx::TensorProto_DataType_FLOAT:
189-
return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
189+
{
190+
float lhs_float_value, rhs_float_value;
191+
std::memcpy(&lhs_float_value, lhs_value, sizeof(lhs_float_value));
192+
std::memcpy(&rhs_float_value, rhs_value, sizeof(rhs_float_value));
193+
return lhs_float_value == rhs_float_value;
194+
}
190195
case onnx::TensorProto_DataType_FLOAT16:
191-
return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
196+
{
197+
MLFloat16 lhs_float16_value, rhs_float16_value;
198+
std::memcpy(&lhs_float16_value, lhs_value, sizeof(lhs_float16_value));
199+
std::memcpy(&rhs_float16_value, rhs_value, sizeof(rhs_float16_value));
200+
return lhs_float16_value == rhs_float16_value;
201+
}
192202
case onnx::TensorProto_DataType_INT64:
193-
return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
203+
{
204+
int64_t lhs_int64_value, rhs_int64_value;
205+
std::memcpy(&lhs_int64_value, lhs_value, sizeof(lhs_int64_value));
206+
std::memcpy(&rhs_int64_value, rhs_value, sizeof(rhs_int64_value));
207+
return lhs_int64_value == rhs_int64_value;
208+
}
194209
default:
195210
break;
196211
}
@@ -240,19 +255,32 @@ std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
240255
std::size_t hash = 0;
241256
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
242257
int data_type = attr_t.data_type();
258+
const char* value = attr_t.raw_data().data();
243259
switch (data_type) {
244260
case onnx::TensorProto_DataType_FLOAT:
261+
{
262+
float float_value;
263+
std::memcpy(&float_value, value, sizeof(float_value));
245264
UpdateHash(data_type, hash);
246-
UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
265+
UpdateHash(float_value, hash);
247266
break;
267+
}
248268
case onnx::TensorProto_DataType_FLOAT16:
269+
{
270+
MLFloat16 float16_value;
271+
std::memcpy(&float16_value, value, sizeof(float16_value));
249272
UpdateHash(data_type, hash);
250-
UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
273+
UpdateHash(static_cast<float>(float16_value), hash);
251274
break;
275+
}
252276
case onnx::TensorProto_DataType_INT64:
277+
{
278+
int64_t int64_value;
279+
std::memcpy(&int64_value, value, sizeof(int64_value));
253280
UpdateHash(data_type, hash);
254-
UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
281+
UpdateHash(int64_value, hash);
255282
break;
283+
}
256284
default:
257285
break;
258286
}

0 commit comments

Comments
 (0)