Skip to content

Commit a44ac07

Browse files
committed
Add more tests
1 parent 361f22f commit a44ac07

File tree

2 files changed

+107
-4
lines changed

2 files changed

+107
-4
lines changed

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,7 +1768,7 @@ static Status CopySparseData(const std::string& name,
17681768
ORT_RETURN_IF_NOT(indices.int64_data_size() == indices_elements,
17691769
"Sparse tensor: ", name, " indices int64 data size does not match expected: ",
17701770
indices_elements);
1771-
indices_data = gsl::make_span(indices.int64_data().data(), indices_elements);
1771+
indices_data = gsl::make_span(indices.int64_data().data(), narrow<size_t>(indices_elements));
17721772
}
17731773
break;
17741774
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
@@ -1976,9 +1976,12 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
19761976
std::string dense_data_storage(narrow<size_t>(dense_elements) * element_size, 0);
19771977
if (nnz_elements > 0) {
19781978
// need to read in sparse data first as it could be in a type specific field, in raw data, or in external data
1979-
std::vector<uint8_t> sparse_data_storage;
1980-
ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, sparse_data_storage));
1981-
void* sparse_data = sparse_data_storage.data();
1979+
std::vector<uint8_t> values_data;
1980+
ORT_RETURN_IF_ERROR(UnpackInitializerData(sparse_values, model_path, values_data));
1981+
ORT_RETURN_IF_NOT(values_data.size() == static_cast<size_t>(nnz_elements) * element_size,
1982+
"Sparse tensor: ", name, " values data size does not match expected: ",
1983+
static_cast<size_t>(nnz_elements) * element_size);
1984+
void* sparse_data = values_data.data();
19821985
void* dense_data = dense_data_storage.data();
19831986

19841987
switch (element_size) {

onnxruntime/test/framework/sparse_kernels_test.cc

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,6 +2085,32 @@ TEST(SparseTensorConversionTests, SparseTensorProtoToDense_OutOfBounds_Rank2) {
20852085
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor index is out of bounds"));
20862086
}
20872087

2088+
TEST(SparseTensorConversionTests, SparseTensorProtoToDense_OutOfBounds_Rank2_Dim1) {
2089+
// Dense Shape [2, 2]
2090+
// Index [0, 2] -> 2 is out of bounds for the 2nd dimension (size 2)
2091+
ONNX_NAMESPACE::SparseTensorProto sparse;
2092+
sparse.mutable_values()->set_name("test_tensor_dim1_oob");
2093+
sparse.add_dims(2);
2094+
sparse.add_dims(2);
2095+
2096+
auto* val = sparse.mutable_values();
2097+
val->add_dims(1);
2098+
val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
2099+
val->add_float_data(1.0f);
2100+
2101+
auto* ind = sparse.mutable_indices();
2102+
ind->add_dims(1); // NNZ=1
2103+
ind->add_dims(2); // Rank=2
2104+
ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
2105+
ind->add_int64_data(0);
2106+
ind->add_int64_data(2); // Out of bounds for dim 1
2107+
2108+
ONNX_NAMESPACE::TensorProto dense;
2109+
auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense);
2110+
EXPECT_FALSE(status.IsOK());
2111+
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor_dim1_oob index is out of bounds"));
2112+
}
2113+
20882114
TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InvalidValuesRank) {
20892115
ONNX_NAMESPACE::SparseTensorProto sparse;
20902116
sparse.mutable_values()->set_name("test_tensor");
@@ -2150,6 +2176,80 @@ TEST(SparseTensorConversionTests, SparseTensorProtoToDense_NegativeDenseShape) {
21502176
EXPECT_FALSE(status.IsOK());
21512177
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor dense dims expected to be non-negative"));
21522178
}
2179+
2180+
TEST(SparseTensorConversionTests, SparseTensorProtoToDense_InvalidValuesRank_Zero) {
2181+
ONNX_NAMESPACE::SparseTensorProto sparse;
2182+
sparse.mutable_values()->set_name("test_tensor_val_rank_0");
2183+
sparse.add_dims(10);
2184+
2185+
auto* val = sparse.mutable_values();
2186+
// No dims added -> Rank 0 (Scalar)
2187+
val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
2188+
val->add_float_data(1.0f);
2189+
2190+
auto* ind = sparse.mutable_indices();
2191+
ind->add_dims(1);
2192+
ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
2193+
ind->add_int64_data(0);
2194+
2195+
ONNX_NAMESPACE::TensorProto dense;
2196+
auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense);
2197+
EXPECT_FALSE(status.IsOK());
2198+
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("Sparse tensor: test_tensor_val_rank_0 values should be rank 1"));
2199+
}
2200+
2201+
TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ValuesSizeMismatch) {
2202+
// Case where the actual data in 'values' doesn't match the dimension specified in 'values'
2203+
ONNX_NAMESPACE::SparseTensorProto sparse;
2204+
sparse.mutable_values()->set_name("test_tensor_val_size_mismatch");
2205+
sparse.add_dims(10);
2206+
2207+
auto* val = sparse.mutable_values();
2208+
val->add_dims(2); // Claiming 2 elements
2209+
val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
2210+
val->add_float_data(1.0f);
2211+
// Only added 1 element, this should fail during UnpackInitializerData or subsequent checks depending on where it's caught
2212+
// Note: UnpackTensor checks if size matches.
2213+
2214+
auto* ind = sparse.mutable_indices();
2215+
ind->add_dims(2);
2216+
ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
2217+
ind->add_int64_data(0);
2218+
ind->add_int64_data(1);
2219+
2220+
ONNX_NAMESPACE::TensorProto dense;
2221+
auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense);
2222+
EXPECT_FALSE(status.IsOK());
2223+
// The error comes from UnpackTensor usually
2224+
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("data size"));
2225+
}
2226+
2227+
TEST(SparseTensorConversionTests, SparseTensorProtoToDense_ValuesSizeMismatch_RawData) {
2228+
// Case where raw data size doesn't match the shape size * element size
2229+
ONNX_NAMESPACE::SparseTensorProto sparse;
2230+
sparse.mutable_values()->set_name("test_tensor_val_size_mismatch_raw");
2231+
sparse.add_dims(10);
2232+
2233+
auto* val = sparse.mutable_values();
2234+
val->add_dims(2); // Claiming 2 elements
2235+
val->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
2236+
2237+
// 1 float is 4 bytes. We provide 4 bytes, but claim 2 elements (8 bytes needed).
2238+
float raw_val = 1.0f;
2239+
val->set_raw_data(&raw_val, sizeof(float));
2240+
2241+
auto* ind = sparse.mutable_indices();
2242+
ind->add_dims(2);
2243+
ind->set_data_type(ONNX_NAMESPACE::TensorProto_DataType_INT64);
2244+
ind->add_int64_data(0);
2245+
ind->add_int64_data(1);
2246+
2247+
ONNX_NAMESPACE::TensorProto dense;
2248+
auto status = utils::SparseTensorProtoToDenseTensorProto(sparse, {}, dense);
2249+
EXPECT_FALSE(status.IsOK());
2250+
EXPECT_THAT(status.ErrorMessage(), testing::HasSubstr("values data size does not match expected"));
2251+
}
2252+
21532253
#endif // !defined(DISABLE_SPARSE_TENSORS)
21542254
} // namespace test
21552255
} // namespace onnxruntime

0 commit comments

Comments
 (0)