@@ -143,14 +143,16 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
143
143
144
144
self .matrix_type = matrix_type
145
145
146
- indptr = np .asarray (A .indptr ) # double check it's a numpy array
146
+ A = self ._validate_csr_matrix (A )
147
+
148
+ max_a_ind_itemsize = max (A .indptr .itemsize , A .indices .itemsize )
147
149
mkl_int_size = get_mkl_int_size ()
148
150
mkl_int64_size = get_mkl_int64_size ()
149
151
150
- target_int_size = mkl_int_size if indptr . itemsize <= mkl_int_size else mkl_int64_size
152
+ target_int_size = mkl_int_size if max_a_ind_itemsize <= mkl_int_size else mkl_int64_size
151
153
self ._ind_dtype = np .dtype (f"i{ target_int_size } " )
152
154
153
- data , indptr , indices = self ._validate_matrix (A )
155
+ data , indptr , indices = self ._validate_matrix_dtypes (A )
154
156
self ._data = data
155
157
self ._indptr = indptr
156
158
self ._indices = indices
@@ -185,7 +187,9 @@ def refactor(self, A):
185
187
raise TypeError ("A is not a sparse matrix." )
186
188
if A .shape != self .shape :
187
189
raise ValueError ("A is not the same size as the previous matrix." )
188
- data , indptr , indices = self ._validate_matrix (A )
190
+
191
+ A = self ._validate_csr_matrix (A )
192
+ data , indptr , indices = self ._validate_matrix_dtypes (A )
189
193
if len (data ) != len (self ._data ):
190
194
raise ValueError ("new A matrix does not have the same number of non zeros." )
191
195
@@ -284,21 +288,24 @@ def iparm(self):
284
288
"""
285
289
return np .array (self ._handle .iparm )
286
290
287
- def _validate_matrix (self , mat ):
288
-
291
+ def _validate_csr_matrix (self , mat ):
289
292
if self .matrix_type in [- 2 , 2 , - 4 , 4 , 6 ]:
290
- # Symmetric matrices must have only the upper triangle
291
- if sp .isspmatrix_csc (mat ):
292
- mat = mat .T # Transpose to get a CSR matrix since it's symmetric
293
+ # only grab the upper triangle.
293
294
mat = sp .triu (mat , format = 'csr' )
294
295
295
- if not (sp .isspmatrix_csr (mat )):
296
- warnings .warn ("Converting %s matrix to CSR format."
297
- % mat .__class__ .__name__ , PardisoTypeConversionWarning )
296
+ if mat .format != 'csr' :
297
+ warnings .warn (
298
+ "Converting %s matrix to CSR format." % A .__class__ .__name__ ,
299
+ PardisoTypeConversionWarning ,
300
+ stacklevel = 3
301
+ )
298
302
mat = mat .tocsr ()
303
+
299
304
mat .sort_indices ()
300
305
mat .sum_duplicates ()
306
+ return mat
301
307
308
+ def _validate_matrix_dtypes (self , mat ):
302
309
data = np .require (mat .data , self ._data_dtype , requirements = "C" )
303
310
indptr = np .require (mat .indptr , self ._ind_dtype , requirements = "C" )
304
311
indices = np .require (mat .indices , self ._ind_dtype , requirements = "C" )
0 commit comments