@@ -52,7 +52,7 @@ COOTensor::COOTensor(const ITensor *tensor, size_t sparse_dim) : SparseTensor(te
5252 " argument must be in [1,%zu] range. %zu is given" ,
5353 dim (), sparse_dim);
5454
55- const uint8_t *data = tensor->buffer ();
55+ const uint8_t *data = tensor->buffer () + info-> offset_first_element_in_bytes () ;
5656 const size_t dense_dims = dense_dim ();
5757 const auto is_nonzero = make_is_nonzero_predicate (info->data_type ());
5858
@@ -75,18 +75,18 @@ COOTensor::COOTensor(const ITensor *tensor, size_t sparse_dim) : SparseTensor(te
7575 const size_t slice_size = step * element_size;
7676
7777 size_t value_byte_size = 0 ;
78- size_t indices_bytes = 0 ;
7978 for (size_t i = 0 ; i < max_iter; i++)
8079 {
8180 const size_t offset = i * slice_size;
8281 if (has_non_zero_elements (const_cast <uint8_t *>(data + offset), slice_size, element_size, is_nonzero))
8382 {
8483 value_byte_size += slice_size;
85- indices_bytes += dim () * sizeof (int32_t );
8684 }
8785 }
8886
89- _allocator.init (coo_tensor_info (info), value_byte_size, indices_bytes);
87+ // Indices are stored in _indices (host vector); no index data is written into
88+ // the allocator buffer yet, so pass 0 for indices_bytes.
89+ _allocator.init (coo_tensor_info (info), value_byte_size, 0 );
9090 _allocator.allocate ();
9191
9292 for (size_t i = 0 ; i < max_iter; i++)
@@ -181,7 +181,7 @@ std::unique_ptr<ITensor> COOTensor::to_dense()
181181 for (size_t j = 0 ; j < dense_vol; ++j)
182182 {
183183 const void *value_ptr = block_ptr + j * element_size;
184- uint8_t *base_ptr = tensor->buffer () + final_offset + j * element_size;
184+ uint8_t *base_ptr = tensor->buffer () + first_elem_offset + final_offset + j * element_size;
185185
186186 std::memcpy (base_ptr, value_ptr, element_size);
187187 }
@@ -211,13 +211,14 @@ const uint8_t *COOTensor::get_value(Coordinates coords) const
211211 for (size_t i = 0 ; i < _indices.size (); ++i)
212212 {
213213 const Coordinates &c = _indices[i];
214- bool match = false ;
214+ bool match = true ;
215215
216216 for (size_t d = 0 ; d < coords.num_dimensions (); ++d)
217217 {
218- if (c[d] = = coords[d])
218+ if (c[d] ! = coords[d])
219219 {
220- match = true ;
220+ match = false ;
221+ break ;
221222 }
222223 }
223224 if (match)
0 commit comments