Skip to content

Commit 2ef0643

Browse files
Merge pull request #29 from projekter/empty-product
Fix #28: Empty product
2 parents 5ea1658 + 4f168eb commit 2ef0643

File tree

8 files changed

+133
-79
lines changed

8 files changed

+133
-79
lines changed

sparse_dot_mkl/_dense_dense.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _dense_matmul(matrix_a, matrix_b, scalar=1., out=None, out_scalar=None):
4848
# The complex versions of these functions take void pointers instead of passed structs
4949
# So create a C struct if necessary to be passed by reference
5050
scalar = _mkl_scalar(scalar, complex_type, double_precision)
51-
out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision)
51+
out_scalar = _mkl_scalar(0 if out is None else out_scalar, complex_type, double_precision)
5252

5353
func(layout_a,
5454
111,
@@ -75,8 +75,13 @@ def _dense_dot_dense(matrix_a, matrix_b, cast=False, scalar=1., out=None, out_sc
7575
# Check for edge condition inputs which result in empty outputs
7676
if _empty_output_check(matrix_a, matrix_b):
7777
debug_print("Skipping multiplication because A (dot) B must yield an empty matrix")
78-
final_dtype = np.float64 if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32 else np.float32
79-
return _out_matrix((matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out)
78+
output_arr = _out_matrix((matrix_a.shape[0], matrix_b.shape[1]),
79+
_type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out)
80+
if out is None or (out_scalar is not None and not out_scalar):
81+
output_arr.fill(0)
82+
elif out_scalar is not None:
83+
output_arr *= out_scalar
84+
return output_arr
8085

8186
matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast)
8287

sparse_dot_mkl/_gram_matrix.py

Lines changed: 50 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _gram_matrix_sparse(
9393

9494

9595
# Dict keyed by ('double_precision_bool', 'complex_bool')
96-
_mkl_skryd_funcs = {
96+
_mkl_syrkd_funcs = {
9797
(False, False): MKL._mkl_sparse_s_syrkd,
9898
(True, False): MKL._mkl_sparse_d_syrkd,
9999
(False, True): MKL._mkl_sparse_c_syrkd,
@@ -129,19 +129,29 @@ def _gram_matrix_sparse_to_dense(
129129
_order_mkl_handle(sp_ref_a)
130130

131131
out_dtype = _output_dtypes[(double_prec, complex_type)]
132-
func = _mkl_skryd_funcs[(double_prec, complex_type)]
132+
func = _mkl_syrkd_funcs[(double_prec, complex_type)]
133133

134134
out_dim = matrix_a.shape[0 if aat else 1]
135135

136136
output_arr = _out_matrix(
137137
(out_dim, out_dim),
138138
out_dtype,
139-
order="C", out_arr=out)
139+
order="C",
140+
out_arr=out,
141+
initialize_zeros=True
142+
)
140143
_, output_ld = _get_numpy_layout(output_arr)
141144

142145
if _empty_output_check(matrix_a, matrix_a):
143146
_destroy_mkl_handle(sp_ref_a)
147+
if out is None or (out_scalar is not None and not out_scalar):
148+
output_arr.fill(0)
149+
elif out_scalar is not None:
150+
output_arr *= out_scalar
144151
return output_arr
152+
153+
if out is None:
154+
out_scalar = 0.
145155

146156
scalar = _mkl_scalar(scalar, complex_type, double_prec)
147157
out_scalar = _mkl_scalar(out_scalar, complex_type, double_prec)
@@ -165,7 +175,7 @@ def _gram_matrix_sparse_to_dense(
165175
# matrix. This stupid thing only happens with specific flags
166176
# I could probably leave it but it's pretty annoying
167177

168-
if not aat and out is None and not complex_type:
178+
if not aat and out is None:
169179
output_arr[np.tril_indices(output_arr.shape[0], k=-1)] = 0.0
170180

171181
return output_arr
@@ -208,6 +218,13 @@ def _gram_matrix_dense_to_dense(
208218
:rtype: numpy.ndarray
209219
"""
210220

221+
if aat and np.iscomplexobj(matrix_a):
222+
raise ValueError(
223+
"transpose=True with dense complex data currently "
224+
"fails with an Intel oneMKL ERROR: "
225+
"Parameter 3 was incorrect on entry to cblas_csyrk"
226+
)
227+
211228
# Get dimensions
212229
n, k = matrix_a.shape if aat else matrix_a.shape[::-1]
213230

@@ -223,15 +240,16 @@ def _gram_matrix_dense_to_dense(
223240
(n, n),
224241
out_dtype,
225242
order="C" if layout_a == LAYOUT_CODE_C else "F",
226-
out_arr=out
243+
out_arr=out,
244+
initialize_zeros=True
227245
)
228246

229247
# The complex versions of these functions take void pointers instead of
230248
# passed structs, so create a C struct if necessary to be passed by
231249
# reference
232250
scalar = _mkl_scalar(scalar, complex_type, double_precision)
233251
out_scalar = _mkl_scalar(out_scalar, complex_type, double_precision)
234-
252+
235253
func(
236254
layout_a,
237255
MKL_UPPER,
@@ -243,7 +261,7 @@ def _gram_matrix_dense_to_dense(
243261
ld_a,
244262
out_scalar if not complex_type else _ctypes.byref(scalar),
245263
output_arr,
246-
n,
264+
n
247265
)
248266

249267
return output_arr
@@ -279,6 +297,22 @@ def _gram_matrix(
279297
:rtype: scipy.sparse.csr_matrix, np.ndarray
280298
"""
281299

300+
if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)):
301+
raise ValueError(
302+
"gram_matrix requires sparse matrix to be CSR or CSC format"
303+
)
304+
elif is_csc(matrix) and not cast:
305+
raise ValueError(
306+
"gram_matrix cannot use a CSC matrix unless cast=True"
307+
)
308+
elif out is not None and not dense:
309+
raise ValueError(
310+
"out argument cannot be used with sparse (dot) sparse "
311+
"matrix multiplication"
312+
)
313+
elif out is not None and not isinstance(out, np.ndarray):
314+
raise ValueError("out argument must be dense")
315+
282316
# Check for edge condition inputs which result in empty outputs
283317
if _empty_output_check(matrix, matrix):
284318
debug_print(
@@ -290,25 +324,18 @@ def _gram_matrix(
290324
if transpose
291325
else (matrix.shape[0], matrix.shape[0])
292326
)
293-
output_func = _sps.csr_matrix if _sps.isspmatrix(matrix) else np.zeros
294-
return output_func(output_shape, dtype=matrix.dtype)
295-
296-
if np.iscomplexobj(matrix):
297-
raise ValueError(
298-
"gram_matrix_mkl does not support complex datatypes"
299-
)
327+
if out is None:
328+
output_func = np.zeros if dense else _sps.csr_matrix
329+
return output_func(output_shape, dtype=matrix.dtype)
330+
elif out_scalar is not None and not out_scalar:
331+
out.fill(0)
332+
elif out_scalar is not None:
333+
out *= out_scalar
334+
return out
300335

301336
matrix = _type_check(matrix, cast=cast)
302337

303-
if _sps.issparse(matrix) and not (is_csr(matrix) or is_csc(matrix)):
304-
raise ValueError(
305-
"gram_matrix requires sparse matrix to be CSR or CSC format"
306-
)
307-
elif is_csc(matrix) and not cast:
308-
raise ValueError(
309-
"gram_matrix cannot use a CSC matrix unless cast=True"
310-
)
311-
elif not _sps.issparse(matrix):
338+
if not _sps.issparse(matrix):
312339
return _gram_matrix_dense_to_dense(
313340
matrix,
314341
aat=transpose,
@@ -322,11 +349,6 @@ def _gram_matrix(
322349
out=out,
323350
out_scalar=out_scalar
324351
)
325-
elif out is not None:
326-
raise ValueError(
327-
"out argument cannot be used with sparse (dot) sparse "
328-
"matrix multiplication"
329-
)
330352
else:
331353
return _gram_matrix_sparse(
332354
matrix,

sparse_dot_mkl/_mkl_interface/_common.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -770,12 +770,14 @@ def _is_valid_dtype(matrix, complex_dtype=False, all_dtype=False):
770770
return matrix.dtype in NUMPY_FLOAT_DTYPES
771771

772772

773-
def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
773+
def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True, convert=True):
774774
"""
775775
Make sure that both matrices are single precision floats or both are
776776
double precision floats.
777777
If not, convert to double precision floats if cast is True,
778778
or raise an error if cast is False
779+
If convert is set to False, the resulting data type is returned without
780+
any conversion happening.
779781
"""
780782

781783
_n_complex = _np.iscomplexobj(matrix_a) + _np.iscomplexobj(matrix_b)
@@ -785,17 +787,17 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
785787

786788
# If there's no matrix B and matrix A is valid dtype, return it
787789
if matrix_b is None and _is_valid_dtype(matrix_a, all_dtype=True):
788-
return matrix_a
790+
return matrix_a if convert else matrix_a.dtype
789791

790792
# If matrix A is complex but not csingle or cdouble, and cast is True,
791793
# convert it to a cdouble
792794
elif matrix_b is None and cast and _n_complex == 1:
793-
return _cast_to(matrix_a, _np.cdouble)
795+
return _cast_to(matrix_a, _np.cdouble) if convert else _np.cdouble
794796

795797
# If matrix A is real but not float32 or float64, and cast is True,
796798
# convert it to a float64
797799
elif matrix_b is None and cast:
798-
return _cast_to(matrix_a, _np.float64)
800+
return _cast_to(matrix_a, _np.float64) if convert else _np.float64
799801

800802
# Raise an error - the dtype is invalid and cast is False
801803
elif matrix_b is None:
@@ -809,7 +811,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
809811
_is_valid_dtype(matrix_a, all_dtype=True) and
810812
matrix_a.dtype == matrix_b.dtype
811813
):
812-
return matrix_a, matrix_b
814+
return (matrix_a, matrix_b) if convert else matrix_a.dtype
813815

814816
# If neither matrix is complex and cast is True, convert to float64s
815817
# and return them
@@ -818,7 +820,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
818820
f"Recasting matrix data types {matrix_a.dtype} and "
819821
f"{matrix_b.dtype} to np.float64"
820822
)
821-
return _cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64)
823+
return (_cast_to(matrix_a, _np.float64), _cast_to(matrix_b, _np.float64)) if convert else _np.float64
822824

823825
# If both matrices are complex and cast is True, convert to cdoubles
824826
# and return them
@@ -827,7 +829,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
827829
f"Recasting matrix data types {matrix_a.dtype} and "
828830
f"{matrix_b.dtype} to _np.cdouble"
829831
)
830-
return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)
832+
return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble
831833

832834
# Cast reals and complex matrices together
833835
elif (
@@ -838,7 +840,7 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
838840
debug_print(
839841
f"Recasting matrix data type {matrix_b.dtype} to {matrix_a.dtype}"
840842
)
841-
return matrix_a, _cast_to(matrix_b, matrix_a.dtype)
843+
return (matrix_a, _cast_to(matrix_b, matrix_a.dtype)) if convert else matrix_a.dtype
842844

843845
elif (
844846
cast and
@@ -848,14 +850,14 @@ def _type_check(matrix_a, matrix_b=None, cast=False, allow_complex=True):
848850
debug_print(
849851
f"Recasting matrix data type {matrix_a.dtype} to {matrix_b.dtype}"
850852
)
851-
return _cast_to(matrix_a, matrix_b.dtype), matrix_b
853+
return (_cast_to(matrix_a, matrix_b.dtype), matrix_b) if convert else matrix_b.dtype
852854

853855
elif cast and _n_complex == 1:
854856
debug_print(
855857
f"Recasting matrix data type {matrix_a.dtype} and {matrix_b.dtype}"
856858
f" to np.cdouble"
857859
)
858-
return _cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)
860+
return (_cast_to(matrix_a, _np.cdouble), _cast_to(matrix_b, _np.cdouble)) if convert else _np.cdouble
859861

860862
# If cast is False, can't cast anything together
861863
elif not cast:
@@ -882,9 +884,16 @@ def _mkl_scalar(scalar, complex_type, double_precision):
882884
return float(scalar)
883885

884886

885-
def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False):
887+
def _out_matrix(
888+
shape,
889+
dtype,
890+
order="C",
891+
out_arr=None,
892+
out_t=False,
893+
initialize_zeros=False
894+
):
886895
"""
887-
Create an all-zero matrix or check to make sure that
896+
Create an undefined matrix or check to make sure that
888897
the provided output array matches
889898
890899
:param shape: Required output shape
@@ -904,8 +913,10 @@ def _out_matrix(shape, dtype, order="C", out_arr=None, out_t=False):
904913
out_t = False if out_t is None else out_t
905914

906915
# If there's no output array allocate a new array and return it
907-
if out_arr is None:
916+
if out_arr is None and initialize_zeros:
908917
return _np.zeros(shape, dtype=dtype, order=order)
918+
elif out_arr is None:
919+
return _np.ndarray(shape, dtype=dtype, order=order)
909920

910921
# Check and make sure the order is correct
911922
# Note 1d arrays have both flags set

sparse_dot_mkl/_sparse_dense.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _sparse_dense_matmul(
106106

107107
# Create a C struct if necessary to be passed
108108
scalar = _mkl_scalar(scalar, cplx, dbl)
109-
out_scalar = _mkl_scalar(out_scalar, cplx, dbl)
109+
out_scalar = _mkl_scalar(0 if out is None else out_scalar, cplx, dbl)
110110

111111
ret_val = func(
112112
11 if transpose else 10,
@@ -165,14 +165,15 @@ def _sparse_dot_dense(
165165
debug_print(
166166
"Skipping multiplication because A (dot) B must yield empty matrix"
167167
)
168-
final_dtype = (
169-
np.float64
170-
if matrix_a.dtype != matrix_b.dtype or matrix_a.dtype != np.float32
171-
else np.float32
172-
)
173-
return _out_matrix(
174-
(matrix_a.shape[0], matrix_b.shape[1]), final_dtype, out_arr=out
168+
output_arr = _out_matrix(
169+
(matrix_a.shape[0], matrix_b.shape[1]),
170+
_type_check(matrix_a, matrix_b, cast=cast, convert=False), out_arr=out
175171
)
172+
if out is None or (out_scalar is not None and not out_scalar):
173+
output_arr.fill(0)
174+
elif out_scalar is not None:
175+
output_arr *= out_scalar
176+
return output_arr
176177

177178
matrix_a, matrix_b = _type_check(matrix_a, matrix_b, cast=cast)
178179

sparse_dot_mkl/_sparse_sparse.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,16 +169,19 @@ def _sparse_dot_sparse(
169169
# Check for edge condition inputs which result in empty outputs
170170
if _empty_output_check(matrix_a, matrix_b):
171171

172+
final_dtype = _type_check(matrix_a, matrix_b, cast=cast, convert=False)
172173
if dense:
173-
return _out_matrix(
174+
output_arr = _out_matrix(
174175
(matrix_a.shape[0], matrix_b.shape[1]),
175-
matrix_a.dtype,
176+
final_dtype,
176177
out_arr=out
177178
)
179+
output_arr.fill(0)
180+
return output_arr
178181
else:
179182
return default_output(
180183
(matrix_a.shape[0], matrix_b.shape[1]),
181-
dtype=matrix_a.dtype
184+
dtype=final_dtype
182185
)
183186

184187
# Check dtypes

sparse_dot_mkl/_sparse_sypr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ def _sypr_sparse_A_dense_B(
6969
matrix_b,
7070
layout_b,
7171
ld_b,
72-
float(out_scalar) if a_scalar is not None else 1.0,
73-
float(out_scalar) if out_scalar is not None else 1.0,
72+
float(a_scalar) if a_scalar is not None else 1.0,
73+
0.0 if out is None else (float(out_scalar) if out_scalar is not None else 1.0),
7474
output_arr,
7575
output_layout,
7676
output_ld,

0 commit comments

Comments
 (0)