Skip to content

Commit 4f81989

Browse files
committed
Expand test coverage and remove unnecessary lines.
1 parent 3cc7bd3 commit 4f81989

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

sparse/numba_backend/_common.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -1610,11 +1610,9 @@ def eye(N, M=None, k=0, dtype=float, format="coo", *, device=None, **kwargs):
16101610
return zeros((N, M), dtype=dtype, format=format, device=device)
16111611

16121612
if k > 0:
1613-
data_length = builtins.max(builtins.min(data_length, M - k), 0)
16141613
n_coords = np.arange(data_length, dtype=np.intp)
16151614
m_coords = n_coords + k
16161615
elif k < 0:
1617-
data_length = builtins.max(builtins.min(data_length, N + k), 0)
16181616
m_coords = np.arange(data_length, dtype=np.intp)
16191617
n_coords = m_coords - k
16201618
else:
@@ -1888,10 +1886,7 @@ def can_cast(from_: SparseArray, to: np.dtype, /, *, casting: str = "safe") -> b
18881886
--------
18891887
- [`numpy.can_cast`][] : NumPy equivalent function
18901888
"""
1891-
try:
1892-
from_ = np.dtype(from_)
1893-
except TypeError:
1894-
from_ = from_.dtype
1889+
from_ = np.dtype(from_)
18951890

18961891
return np.can_cast(from_, to, casting=casting)
18971892

sparse/numba_backend/tests/test_coo.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1203,11 +1203,21 @@ def test_initialization(ndim, rng):
12031203
@pytest.mark.parametrize("N, M", [(4, None), (4, 10), (10, 4), (0, 10)])
12041204
def test_eye(N, M):
12051205
m = M or N
1206-
for k in [0, N - 2, N + 2, m - 2, m + 2]:
1206+
for k in [0, N - 2, N + 2, m - 2, m + 2, np.iinfo(np.intp).min]:
12071207
assert_eq(sparse.eye(N, M=M, k=k), np.eye(N, M=M, k=k))
12081208
assert_eq(sparse.eye(N, M=M, k=k, dtype="i4"), np.eye(N, M=M, k=k, dtype="i4"))
12091209

12101210

1211+
@pytest.mark.parametrize("from_", [np.int8, np.int64, np.float32, np.float64, np.complex64, np.complex128])
1212+
@pytest.mark.parametrize("to", [np.int8, np.int64, np.float32, np.float64, np.complex64, np.complex128])
1213+
@pytest.mark.parametrize("casting", ["no", "safe", "same_kind"])
1214+
def test_can_cast(from_, to, casting):
1215+
assert sparse.can_cast(sparse.zeros((2, 2), dtype=from_), to, casting=casting) == np.can_cast(
1216+
np.zeros((2, 2), dtype=from_), to, casting=casting
1217+
)
1218+
assert sparse.can_cast(from_, to, casting=casting) == np.can_cast(from_, to, casting=casting)
1219+
1220+
12111221
@pytest.mark.parametrize("funcname", ["ones", "zeros"])
12121222
def test_ones_zeros(funcname):
12131223
sp_func = getattr(sparse, funcname)

0 commit comments

Comments
 (0)