Skip to content

Commit 0df5dbc

Browse files
authored
Fix misaligned addresses while reading tensor attributes from raw data buffers (#27312)
### Description
Explicitly copy tensor attribute values out of raw data buffers instead of dereferencing pointers into those buffers without checking for proper memory alignment.

### Motivation and Context
Fixes #27311. With some (rare) models, ONNXRuntime built for the 32-bit Android armeabi-v7a architecture crashes with a bus error. Raw data buffers for tensor attributes are of type `char*` and have no alignment guarantees for the values read from them. In fact, accessing an object at a misaligned address is undefined behavior in C++ in general (not only on Android armeabi-v7a), even though many modern platforms tolerate it.
1 parent 735e69a commit 0df5dbc

File tree

1 file changed

+14
-35
lines changed

1 file changed

+14
-35
lines changed

onnxruntime/core/optimizer/common_subexpression_elimination.cc

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -172,29 +172,18 @@ bool AreRangesEqual(const Range& lhs, const Range& rhs) {
172172
}
173173

174174
// Check if two tensor attributes are equal scalar tensors, mainly to support ConstantOfShape Op.
175-
// Currently support float, float16 and int64 data types, and requires the data are raw data in TensorProto.
176175
bool AreScalarTensorAttributeEqual(const ONNX_NAMESPACE::TensorProto& lhs_t, const ONNX_NAMESPACE::TensorProto& rhs_t) {
177176
if (!(utils::HasDataType(lhs_t) && utils::HasDataType(rhs_t) && lhs_t.data_type() == rhs_t.data_type() &&
178-
(lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT ||
179-
lhs_t.data_type() == onnx::TensorProto_DataType_FLOAT16 ||
180-
lhs_t.data_type() == onnx::TensorProto_DataType_INT64) &&
181-
lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1 &&
182-
utils::HasRawData(lhs_t) && utils::HasRawData(rhs_t))) {
177+
lhs_t.data_type() != onnx::TensorProto_DataType_STRING &&
178+
lhs_t.dims_size() == 1 && rhs_t.dims_size() == 1 && lhs_t.dims()[0] == 1 && rhs_t.dims()[0] == 1)) {
183179
return false;
184180
}
185-
const void* lhs_value = lhs_t.raw_data().data();
186-
const void* rhs_value = rhs_t.raw_data().data();
187-
switch (lhs_t.data_type()) {
188-
case onnx::TensorProto_DataType_FLOAT:
189-
return *reinterpret_cast<const float*>(lhs_value) == *reinterpret_cast<const float*>(rhs_value);
190-
case onnx::TensorProto_DataType_FLOAT16:
191-
return *reinterpret_cast<const MLFloat16*>(lhs_value) == *reinterpret_cast<const MLFloat16*>(rhs_value);
192-
case onnx::TensorProto_DataType_INT64:
193-
return *reinterpret_cast<const int64_t*>(lhs_value) == *reinterpret_cast<const int64_t*>(rhs_value);
194-
default:
195-
break;
181+
std::vector<uint8_t> unpacked_lhs_tensor, unpacked_rhs_tensor;
182+
if (!utils::UnpackInitializerData(lhs_t, unpacked_lhs_tensor).IsOK() ||
183+
!utils::UnpackInitializerData(rhs_t, unpacked_rhs_tensor).IsOK()) {
184+
return false;
196185
}
197-
return false;
186+
return unpacked_lhs_tensor == unpacked_rhs_tensor;
198187
}
199188

200189
bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::AttributeProto& rhs) {
@@ -235,26 +224,16 @@ bool AreEqual(const ONNX_NAMESPACE::AttributeProto& lhs, const ONNX_NAMESPACE::A
235224
return false;
236225
}
237226

238-
// Support scalar float/int64/fp16 tensor attribute only for now, and requires data is raw data in TensorProto.
227+
// Support scalar tensor attribute only for now.
239228
std::size_t GetTensorAttributeHash(const ONNX_NAMESPACE::TensorProto& attr_t) {
240229
std::size_t hash = 0;
241-
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1 && utils::HasRawData(attr_t)) {
230+
if (utils::HasDataType(attr_t) && attr_t.dims_size() == 1 && attr_t.dims()[0] == 1) {
242231
int data_type = attr_t.data_type();
243-
switch (data_type) {
244-
case onnx::TensorProto_DataType_FLOAT:
245-
UpdateHash(data_type, hash);
246-
UpdateHash(*reinterpret_cast<const float*>(attr_t.raw_data().data()), hash);
247-
break;
248-
case onnx::TensorProto_DataType_FLOAT16:
249-
UpdateHash(data_type, hash);
250-
UpdateHash(static_cast<float>(*reinterpret_cast<const MLFloat16*>(attr_t.raw_data().data())), hash);
251-
break;
252-
case onnx::TensorProto_DataType_INT64:
253-
UpdateHash(data_type, hash);
254-
UpdateHash(*reinterpret_cast<const int64_t*>(attr_t.raw_data().data()), hash);
255-
break;
256-
default:
257-
break;
232+
ORT_ENFORCE(data_type != onnx::TensorProto_DataType_STRING, "Unexpected tensor string type");
233+
std::vector<uint8_t> unpacked_tensor;
234+
if (utils::UnpackInitializerData(attr_t, unpacked_tensor).IsOK()) {
235+
UpdateHash(data_type, hash);
236+
UpdateHashWithContainer(unpacked_tensor, hash);
258237
}
259238
}
260239
return hash;

0 commit comments

Comments
 (0)