Skip to content

Commit 883d509

Browse files
committed
Fix issues and add tests
1 parent a70ac2f commit 883d509

File tree

3 files changed

+379
-71
lines changed

3 files changed

+379
-71
lines changed

onnxruntime/core/framework/tensorprotoutils.cc

Lines changed: 97 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -1744,42 +1744,41 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor) {
17441744
static Status CopySparseData(size_t n_sparse_elements,
17451745
const ONNX_NAMESPACE::TensorProto& indices,
17461746
const std::filesystem::path& model_path,
1747-
gsl::span<const int64_t>
1748-
dims,
1749-
std::function<void(size_t from_idx, size_t to_idx)>
1750-
copier) {
1747+
gsl::span<const int64_t> dense_dims,
1748+
int64_t dense_elements,
1749+
std::function<void(size_t from_idx, size_t to_idx)> copier) {
17511750
Status status = Status::OK();
17521751
TensorShape indices_shape(indices.dims().data(), indices.dims().size());
1753-
const auto elements = narrow<size_t>(indices_shape.Size());
1752+
const auto nnz_elements = narrow<size_t>(indices_shape.Size());
17541753

17551754
std::vector<int64_t> indices_values; // used for conversion of smaller size indices
17561755
std::vector<uint8_t> unpack_buffer;
17571756
gsl::span<const int64_t> indices_data;
1758-
const bool has_raw_data = indices.has_raw_data();
1757+
const bool has_raw_data = utils::HasRawData(indices);
17591758
switch (indices.data_type()) {
17601759
case ONNX_NAMESPACE::TensorProto_DataType_INT64:
17611760
if (has_raw_data) {
1762-
ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int64_t)),
1761+
ORT_RETURN_IF_NOT(indices.raw_data().size() == (nnz_elements * sizeof(int64_t)),
17631762
"Sparse Indices raw data size does not match expected.");
17641763
ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
17651764
indices_data = ReinterpretAsSpan<const int64_t>(gsl::make_span(unpack_buffer));
17661765
} else {
1767-
ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast<int64_t>(elements),
1766+
ORT_RETURN_IF_NOT(indices.int64_data_size() == static_cast<int64_t>(nnz_elements),
17681767
"Sparse indices int64 data size does not match expected");
1769-
indices_data = gsl::make_span(indices.int64_data().data(), elements);
1768+
indices_data = gsl::make_span(indices.int64_data().data(), nnz_elements);
17701769
}
17711770
break;
17721771
case ONNX_NAMESPACE::TensorProto_DataType_INT32: {
17731772
if (has_raw_data) {
1774-
ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int32_t)),
1773+
ORT_RETURN_IF_NOT(indices.raw_data().size() == (nnz_elements * sizeof(int32_t)),
17751774
"Sparse Indices raw data size does not match expected.");
17761775
ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
17771776
auto int32_span = ReinterpretAsSpan<const int32_t>(gsl::make_span(unpack_buffer));
17781777
indices_values.insert(indices_values.cend(), int32_span.begin(), int32_span.end());
17791778
unpack_buffer.clear();
17801779
unpack_buffer.shrink_to_fit();
17811780
} else {
1782-
ORT_RETURN_IF_NOT(indices.int32_data_size() == static_cast<int64_t>(elements),
1781+
ORT_RETURN_IF_NOT(indices.int32_data_size() == static_cast<int64_t>(nnz_elements),
17831782
"Sparse indices int32 data size does not match expected");
17841783
indices_values.insert(indices_values.cend(), indices.int32_data().cbegin(), indices.int32_data().cend());
17851784
}
@@ -1788,7 +1787,7 @@ static Status CopySparseData(size_t n_sparse_elements,
17881787
}
17891788
case ONNX_NAMESPACE::TensorProto_DataType_INT16: {
17901789
if (has_raw_data) {
1791-
ORT_RETURN_IF_NOT(indices.raw_data().size() == (elements * sizeof(int16_t)),
1790+
ORT_RETURN_IF_NOT(indices.raw_data().size() == (nnz_elements * sizeof(int16_t)),
17921791
"Sparse Indices raw data size does not match expected.");
17931792
ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
17941793
auto int16_span = ReinterpretAsSpan<const int16_t>(gsl::make_span(unpack_buffer));
@@ -1804,7 +1803,7 @@ static Status CopySparseData(size_t n_sparse_elements,
18041803
}
18051804
case ONNX_NAMESPACE::TensorProto_DataType_INT8: {
18061805
if (has_raw_data) {
1807-
ORT_RETURN_IF_NOT(indices.raw_data().size() == elements,
1806+
ORT_RETURN_IF_NOT(indices.raw_data().size() == nnz_elements,
18081807
"Sparse Indices raw data size does not match expected.");
18091808
ORT_RETURN_IF_ERROR(UnpackInitializerData(indices, model_path, unpack_buffer));
18101809
auto int8_span = ReinterpretAsSpan<const int8_t>(gsl::make_span(unpack_buffer));
@@ -1824,24 +1823,29 @@ static Status CopySparseData(size_t n_sparse_elements,
18241823
"Invalid SparseTensor indices. Should one of the following types: int8, int16, int32 or int64");
18251824
}
18261825

1827-
if (indices_shape.NumDimensions() == 1) {
1826+
const auto indices_rank = indices_shape.NumDimensions();
1827+
if (indices_rank == 1) {
18281828
// flattened indexes
18291829
for (size_t i = 0; i < n_sparse_elements; ++i) {
1830-
copier(i, narrow<size_t>(indices_data[i]));
1830+
const auto idx = indices_data[i];
1831+
ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements,
1832+
"Sparse index is out of bounds. Got:", idx, " expected to be in [0, ", dense_elements, ")");
1833+
1834+
copier(i, narrow<size_t>(idx));
18311835
}
1832-
} else if (indices_shape.NumDimensions() == 2) {
1836+
} else if (indices_rank == 2) {
18331837
// entries in format {NNZ, rank}
1834-
ORT_ENFORCE(indices_shape[1] > 0 && static_cast<size_t>(indices_shape[1]) == dims.size());
1838+
ORT_ENFORCE(indices_shape[1] > 0 && static_cast<size_t>(indices_shape[1]) == dense_dims.size());
18351839
auto rank = static_cast<size_t>(indices_shape[1]);
18361840
auto cur_index = indices_data.begin();
1837-
std::vector<size_t> multipliers;
1841+
InlinedVector<size_t> multipliers;
18381842
multipliers.resize(rank);
18391843

18401844
// calculate sum of inner dimension elements for each dimension.
18411845
// e.g. if shape {2,3,4}, the result should be {3*4, 4, 1}
18421846
multipliers[rank - 1] = 1;
18431847
for (auto r = rank - 1; r > 0; --r) {
1844-
multipliers[r - 1] = SafeInt<size_t>(dims[r]) * multipliers[r];
1848+
multipliers[r - 1] = SafeInt<size_t>(dense_dims[r]) * multipliers[r];
18451849
}
18461850

18471851
// calculate the offset for the entry
@@ -1852,6 +1856,9 @@ static Status CopySparseData(size_t n_sparse_elements,
18521856
for (size_t j = 0; j < rank; ++j) {
18531857
idx += SafeInt<int64_t>(cur_index[j]) * multipliers[j];
18541858
}
1859+
ORT_RETURN_IF_NOT(idx >= 0 && idx < dense_elements,
1860+
"Sparse index is out of bounds. Got:", static_cast<int64_t>(idx),
1861+
" expected to be in [0, ", dense_elements, ")");
18551862

18561863
copier(i, static_cast<size_t>(idx));
18571864
cur_index += rank;
@@ -1860,7 +1867,7 @@ static Status CopySparseData(size_t n_sparse_elements,
18601867
ORT_ENFORCE(cur_index == indices_data.end());
18611868
} else {
18621869
status = ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
1863-
"Invalid SparseTensor indices. Should be rank 0 or 1. Got:", indices_shape);
1870+
"Invalid SparseTensor indices shape. Expected be rank 1 or 2. Got:", indices_shape);
18641871
}
18651872

18661873
return status;
@@ -1871,26 +1878,46 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
18711878
ONNX_NAMESPACE::TensorProto& dense) {
18721879
Status status = Status::OK();
18731880

1881+
const auto& indices = sparse.indices();
1882+
const auto indices_rank = indices.dims_size();
1883+
if (indices_rank != 1 && indices_rank != 2) {
1884+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
1885+
"Indices should be rank 1 or 2 for supported COO format. Got:", indices_rank);
1886+
}
1887+
18741888
const auto& sparse_values = sparse.values();
1889+
const auto values_rank = sparse_values.dims_size();
1890+
if (values_rank != 1) {
1891+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
1892+
"Values should be rank 1 for COO format. Got:", values_rank);
1893+
}
1894+
18751895
auto type = sparse_values.data_type();
18761896
dense.set_data_type(type);
18771897
*dense.mutable_name() = sparse_values.name();
18781898

18791899
SafeInt<size_t> n_sparse_elements = 1;
18801900
for (auto dim : sparse_values.dims()) {
1901+
if (dim <= 0) {
1902+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
1903+
"Sparse values tensor dims expected to be positive. Got:", dim);
1904+
}
18811905
n_sparse_elements *= dim;
18821906
}
18831907

1884-
SafeInt<size_t> n_dense_elements = 1;
1908+
SafeInt<int64_t> dense_elements = 1;
18851909
for (auto dim : sparse.dims()) {
1886-
n_dense_elements *= dim;
1910+
if (dim <= 0) {
1911+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH,
1912+
"Sparse tensor dense dims expected to be positive. Got:", dim);
1913+
}
1914+
dense_elements *= dim;
18871915
dense.add_dims(dim);
18881916
}
18891917

1890-
const auto& indices = sparse.indices();
1891-
auto dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());
1918+
const auto dense_dims = gsl::make_span<const int64_t>(dense.dims().data(), dense.dims().size());
18921919

1893-
if (type != TensorProto_DataType_STRING) {
1920+
if (type != ONNX_NAMESPACE::TensorProto_DataType_STRING) {
18941921
auto ml_data = DataTypeImpl::TensorTypeFromONNXEnum(type)->GetElementType();
18951922
size_t element_size = ml_data->Size();
18961923

@@ -1901,56 +1928,55 @@ common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseT
19011928

19021929
// by putting the data into a std::string we can avoid a copy as set_raw_data can do a std::move
19031930
// into the TensorProto.
1904-
std::string dense_data_storage(n_dense_elements * element_size, 0);
1905-
if (n_sparse_elements > 0) {
1906-
void* dense_data = dense_data_storage.data();
1907-
1908-
switch (element_size) {
1909-
case 1: {
1910-
status = CopySparseData(
1911-
n_sparse_elements, indices, model_path, dims, [sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1912-
static_cast<uint8_t*>(dense_data)[to_idx] = static_cast<const uint8_t*>(sparse_data)[from_idx];
1913-
});
1914-
1915-
break;
1916-
}
1917-
case 2: {
1918-
status = CopySparseData(n_sparse_elements, indices, model_path, dims,
1919-
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1920-
const auto* src = static_cast<const uint16_t*>(sparse_data) + from_idx;
1921-
auto* dst = static_cast<uint16_t*>(dense_data) + to_idx;
1922-
memcpy(dst, src, sizeof(uint16_t));
1923-
});
1924-
1925-
break;
1926-
}
1927-
case 4: {
1928-
status = CopySparseData(n_sparse_elements, indices, model_path, dims,
1929-
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1930-
const auto* src = static_cast<const uint32_t*>(sparse_data) + from_idx;
1931-
auto* dst = static_cast<uint32_t*>(dense_data) + to_idx;
1932-
memcpy(dst, src, sizeof(uint32_t));
1933-
});
1934-
1935-
break;
1936-
}
1937-
case 8: {
1938-
status = CopySparseData(n_sparse_elements, indices, model_path, dims,
1939-
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1940-
const auto* src = static_cast<const uint64_t*>(sparse_data) + from_idx;
1941-
auto* dst = static_cast<uint64_t*>(dense_data) + to_idx;
1942-
memcpy(dst, src, sizeof(uint64_t));
1943-
});
1944-
break;
1945-
}
1931+
std::string dense_data_storage(narrow<size_t>(dense_elements) * element_size, 0);
1932+
void* dense_data = dense_data_storage.data();
19461933

1947-
default:
1948-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Element_size of: ", element_size, " is not supported.",
1949-
" type: ", type);
1934+
switch (element_size) {
1935+
case 1: {
1936+
status = CopySparseData(
1937+
n_sparse_elements, indices, model_path, dense_dims, dense_elements,
1938+
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1939+
static_cast<uint8_t*>(dense_data)[to_idx] = static_cast<const uint8_t*>(sparse_data)[from_idx];
1940+
});
1941+
1942+
break;
1943+
}
1944+
case 2: {
1945+
status = CopySparseData(n_sparse_elements, indices, model_path, dense_dims, dense_elements,
1946+
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1947+
const auto* src = static_cast<const uint16_t*>(sparse_data) + from_idx;
1948+
auto* dst = static_cast<uint16_t*>(dense_data) + to_idx;
1949+
memcpy(dst, src, sizeof(uint16_t));
1950+
});
1951+
1952+
break;
19501953
}
1954+
case 4: {
1955+
status = CopySparseData(n_sparse_elements, indices, model_path, dense_dims, dense_elements,
1956+
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1957+
const auto* src = static_cast<const uint32_t*>(sparse_data) + from_idx;
1958+
auto* dst = static_cast<uint32_t*>(dense_data) + to_idx;
1959+
memcpy(dst, src, sizeof(uint32_t));
1960+
});
19511961

1952-
ORT_RETURN_IF_ERROR(status);
1962+
break;
1963+
}
1964+
case 8: {
1965+
status = CopySparseData(n_sparse_elements, indices, model_path, dense_dims, dense_elements,
1966+
[sparse_data, dense_data](size_t from_idx, size_t to_idx) {
1967+
const auto* src = static_cast<const uint64_t*>(sparse_data) + from_idx;
1968+
auto* dst = static_cast<uint64_t*>(dense_data) + to_idx;
1969+
memcpy(dst, src, sizeof(uint64_t));
1970+
});
1971+
break;
1972+
}
1973+
1974+
default:
1975+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Element_size of: ", element_size, " is not supported.",
1976+
" type: ", type);
19531977
}
1978+
1979+
ORT_RETURN_IF_ERROR(status);
19541980
utils::SetRawDataInTensorProto(dense, std::move(dense_data_storage));
19551981
} else {
19561982
// No request for std::string

onnxruntime/core/framework/tensorprotoutils.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,14 @@ void MakeCpuTensorCopy(const Tensor& src_tensor, Tensor& dst_tensor);
253253
// If the SparseTensorProto contains external data then it loads the data and converts to dense tensor proto
254254
// The resulting TensorProto will contain the data as raw data.
255255
// model_path is used for constructing full path for external_data
256+
257+
// The function supports only COO format with 1D or 2D indices. Values shape is expected to be 1D.
258+
// The function does not support sparse tensors with 2D indices or other formats like CSR/CSC.
259+
/// </summary>
260+
/// <param name="sparse"></param>
261+
/// <param name="model_path">model path is only used if there are references to external data.</param>
262+
/// <param name="dense">The resulting dense tensor proto.</param>
263+
/// <returns>Status</returns>
256264
common::Status SparseTensorProtoToDenseTensorProto(const ONNX_NAMESPACE::SparseTensorProto& sparse,
257265
const std::filesystem::path& model_path,
258266
ONNX_NAMESPACE::TensorProto& dense);

0 commit comments

Comments
 (0)