Bug: unexpected error from COO.from_numpy when using idx_dtype kwarg #810
Open
Description
sparse version checks
- I checked that this issue has not already been reported in the list of issues.
- I have confirmed this bug exists on the latest version of sparse.
- I have confirmed this bug exists on the main branch of sparse.
Describe the bug
If you create a multidimensional sparse array with idx_dtype=np.uint8 from a NumPy array whose size (total number of elements) is larger than 256, but whose max(shape) is smaller than 256, you get an unexpected ValueError.
Steps or code to reproduce the bug
This code reproduces the error:
import numpy as np
import sparse
x = np.empty((25, 25))  # same error with x = np.zeros((25, 25))
idx_dtype = np.uint8
assert max(x.shape) < 256
sparse.COO.from_numpy(x, idx_dtype=idx_dtype)
Expected results
I would have expected no error, and an output with the correct shape and the correct idx_dtype.
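To make the expectation concrete, here is a minimal NumPy-only sketch (it does not use sparse at all) of per-axis coordinates for a (25, 25) array. The two non-zero entries are made up for illustration. Since every index along either axis is at most 24, the coordinates fit into np.uint8 without loss.

```python
import numpy as np

# Illustrative input: a (25, 25) array with two made-up non-zero entries.
x = np.zeros((25, 25))
x[3, 7] = 1.0
x[24, 24] = 2.0

idx_dtype = np.uint8

# Per-axis coordinates of the non-zero entries, one row per dimension.
coords = np.stack(np.nonzero(x))

# The largest index along any axis is 24, well below the uint8 limit of 255.
assert coords.max() <= np.iinfo(idx_dtype).max

# Casting is therefore lossless.
coords_small = coords.astype(idx_dtype)
data = x[tuple(coords_small)]

print(coords_small.dtype, coords_small.shape)
```

This only illustrates that the N-dimensional coordinates themselves are representable in the requested dtype; it is not a claim about how sparse should build them internally.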
Actual results
{
"name": "ValueError",
"message": "cannot cast array with shape (625,) to dtype <class 'numpy.uint8'>.",
"stack": "---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[17], line 6
4 x = np.empty((25, 25))
5 idx_dtype = np.uint8
----> 6 sparse.COO.from_numpy(x, idx_dtype=idx_dtype)
7 x.shape, x.size
File /usr/local/lib/python3.11/site-packages/sparse/_coo/core.py:400, in COO.from_numpy(cls, x, fill_value, idx_dtype)
398 coords = np.atleast_2d(np.flatnonzero(~equivalent(x, fill_value)))
399 data = x.ravel()[tuple(coords)]
--> 400 return cls(
401 coords,
402 data,
403 shape=x.size,
404 has_duplicates=False,
405 sorted=True,
406 fill_value=fill_value,
407 idx_dtype=idx_dtype,
408 ).reshape(x.shape)
File /usr/local/lib/python3.11/site-packages/sparse/_coo/core.py:272, in COO.__init__(self, coords, data, shape, has_duplicates, sorted, prune, cache, fill_value, idx_dtype)
270 if idx_dtype:
271 if not can_store(idx_dtype, max(shape)):
--> 272 raise ValueError(
273 \"cannot cast array with shape {} to dtype {}.\".format(
274 shape, idx_dtype
275 )
276 )
277 self.coords = self.coords.astype(idx_dtype)
279 if self.shape:
ValueError: cannot cast array with shape (625,) to dtype <class 'numpy.uint8'>."
}
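Judging from the traceback, from_numpy first builds a flattened array with shape=x.size (here 625) and only afterwards reshapes it, so the max(shape) check in __init__ sees 625 rather than 25. The sketch below mimics that comparison with NumPy only. The helper fits_in_dtype is a hypothetical stand-in written for this illustration, not the library's can_store.

```python
import numpy as np


def fits_in_dtype(idx_dtype, largest_extent):
    """Hypothetical stand-in for the library's check: can an integer
    dtype represent the value `largest_extent`?"""
    return largest_extent <= np.iinfo(idx_dtype).max


shape = (25, 25)

# Shape of the intermediate flattened array, as seen in the traceback.
flat_shape = (int(np.prod(shape)),)

# The check against the flattened shape fails: 625 > 255.
print(fits_in_dtype(np.uint8, max(flat_shape)))  # False

# The same check against the final shape would pass: 25 <= 255.
print(fits_in_dtype(np.uint8, max(shape)))  # True
```

If this reading is right, applying idx_dtype only after the reshape to x.shape (or validating against max(x.shape)) would avoid the error, but that is an assumption about a possible fix rather than a tested patch.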
Please describe your system.
- OS and version: Codespace from Github (Linux)
- sparse version 0.14.0
- NumPy version 1.23.5
- Numba version: not used
Relevant log output
No response