Skip to content

Commit 1696069

Browse files
committed
Update README
1 parent 6f88249 commit 1696069

File tree

2 files changed

+32
-18
lines changed

2 files changed

+32
-18
lines changed

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,13 @@ It will return a dense array X.
9696
Refer to the pardiso documentation for detailed description of options.
9797
Consider this wrapper to be experimental.
9898

99+
#### SciPy Classes
100+
101+
`csr_array`, `csr_matrix`, `csc_array`, `csc_matrix`, `bsr_array`, `bsr_matrix`
102+
103+
Scipy sparse classes where `__matmul__` and `__rmatmul__` have been replaced to use MKL
104+
for matrix math
105+
99106
#### Service Functions
100107

101108
Several service functions are available and can be imported from the base `sparse_dot_mkl` package.

sparse_dot_mkl/tests/test_scipy_classes.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import unittest
22
import numpy.testing as npt
3+
import scipy as sp
34
import scipy.sparse as sps
45
from types import MethodType
56

@@ -77,28 +78,34 @@ def test_matmul_fail(self):
7778
with self.assertRaises(ValueError):
7879
b @ a
7980

80-
m1 = MATRIX_1.copy()
81-
m2 = MATRIX_2.copy()
82-
83-
install_wire(m1)
84-
install_wire(m2)
85-
install_wire(a)
86-
install_wire(b)
81+
# Following tests dont work with old scipy
82+
if (
83+
(int(sp.__version__.split('.')[1]) > 1) or
84+
(int(sp.__version__.split('.')[1]) > 13)
85+
):
86+
87+
m1 = MATRIX_1.copy()
88+
m2 = MATRIX_2.copy()
89+
90+
install_wire(m1)
91+
install_wire(m2)
92+
install_wire(a)
93+
install_wire(b)
94+
# SCIPY
95+
with self.assertRaises(TripError):
96+
m1 @ m2
8797

88-
# SCIPY
89-
with self.assertRaises(TripError):
90-
m1 @ m2
98+
# SCIPY CSR_MATRIX USES RMATMUL DUNNO WHY
99+
if self.arr != csr_matrix:
100+
with self.assertRaises(TripError):
101+
m1 @ b
91102

92-
# SCIPY CSR_MATRIX USES RMATMUL DUNNO WHY
93-
if self.arr != csr_matrix:
94-
with self.assertRaises(TripError):
95-
m1 @ b
103+
# MKL
104+
a @ m2
96105

97-
# MKL
98-
a @ m2
106+
# MKL
107+
a @ b
99108

100-
# MKL
101-
a @ b
102109

103110
class TestCSRMat(TestCSR):
104111
arr = csr_matrix

0 commit comments

Comments
 (0)