Skip to content

Commit 965c8d8

Browse files
authored
Make vecdot compliant to the Array API (#850)
1 parent 3174527 commit 965c8d8

File tree

3 files changed

+2
-3
lines changed

3 files changed

+2
-3
lines changed

ci/Numba-array-api-xfails.txt

-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack__]
4141
array_api_tests/test_has_names.py::test_has_names[array_method-__dlpack_device__]
4242
array_api_tests/test_has_names.py::test_has_names[array_method-__setitem__]
4343
array_api_tests/test_indexing_functions.py::test_take
44-
array_api_tests/test_linalg.py::test_vecdot
4544
array_api_tests/test_set_functions.py::test_unique_all
4645
array_api_tests/test_set_functions.py::test_unique_inverse
4746
array_api_tests/test_signatures.py::test_func_signature[unique_all]

sparse/numba_backend/_common.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3102,4 +3102,4 @@ def vecdot(x1, x2, /, *, axis=-1):
31023102
if np.issubdtype(x1.dtype, np.complexfloating):
31033103
x1 = np.conjugate(x1)
31043104

3105-
return np.sum(x1 * x2, axis=axis)
3105+
return np.sum(x1 * x2, axis=axis, dtype=np.result_type(x1, x2))

sparse/numba_backend/tests/test_coo.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1851,8 +1851,8 @@ def np_vecdot(x1, x2, /, *, axis=-1):
18511851
return np.sum(x1 * x2, axis=axis)
18521852

18531853
actual = sparse.vecdot(s1, s2, axis=axis)
1854+
assert s1.dtype == s2.dtype == actual.dtype
18541855
expected = np_vecdot(x1, x2, axis=axis)
1855-
18561856
np.testing.assert_allclose(actual.todense(), expected)
18571857

18581858

0 commit comments

Comments
 (0)