Skip to content

Commit 8ef072a

Browse files
authored
Merge pull request #19 from jcapriot/python3.13
Python3.13
2 parents 24fa212 + 09f0463 commit 8ef072a

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

.github/workflows/python-package-conda.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ jobs:
2323
fail-fast: False
2424
matrix:
2525
os: [ubuntu-latest, macos-12, windows-latest]
26-
python-version: ["3.10", "3.11", "3.12"]
26+
python-version: ["3.10", "3.11", "3.12", "3.13"]
2727
mkl-version: ['2023', '2024']
2828
include:
2929
- os: ubuntu-latest
@@ -39,7 +39,7 @@ jobs:
3939
uses: conda-incubator/setup-miniconda@v3
4040
with:
4141
python-version: ${{ matrix.python-version }}
42-
channels: defaults
42+
channels: conda-forge, defaults
4343
channel-priority: true
4444
activate-environment: dev
4545

pydiso/mkl_solver.py

+19-12
Original file line numberDiff line numberDiff line change
@@ -143,14 +143,16 @@ def __init__(self, A, matrix_type=None, factor=True, verbose=False):
143143

144144
self.matrix_type = matrix_type
145145

146-
indptr = np.asarray(A.indptr) # double check it's a numpy array
146+
A = self._validate_csr_matrix(A)
147+
148+
max_a_ind_itemsize = max(A.indptr.itemsize, A.indices.itemsize)
147149
mkl_int_size = get_mkl_int_size()
148150
mkl_int64_size = get_mkl_int64_size()
149151

150-
target_int_size = mkl_int_size if indptr.itemsize <= mkl_int_size else mkl_int64_size
152+
target_int_size = mkl_int_size if max_a_ind_itemsize <= mkl_int_size else mkl_int64_size
151153
self._ind_dtype = np.dtype(f"i{target_int_size}")
152154

153-
data, indptr, indices = self._validate_matrix(A)
155+
data, indptr, indices = self._validate_matrix_dtypes(A)
154156
self._data = data
155157
self._indptr = indptr
156158
self._indices = indices
@@ -185,7 +187,9 @@ def refactor(self, A):
185187
raise TypeError("A is not a sparse matrix.")
186188
if A.shape != self.shape:
187189
raise ValueError("A is not the same size as the previous matrix.")
188-
data, indptr, indices = self._validate_matrix(A)
190+
191+
A = self._validate_csr_matrix(A)
192+
data, indptr, indices = self._validate_matrix_dtypes(A)
189193
if len(data) != len(self._data):
190194
raise ValueError("new A matrix does not have the same number of non zeros.")
191195

@@ -284,21 +288,24 @@ def iparm(self):
284288
"""
285289
return np.array(self._handle.iparm)
286290

287-
def _validate_matrix(self, mat):
288-
291+
def _validate_csr_matrix(self, mat):
289292
if self.matrix_type in [-2, 2, -4, 4, 6]:
290-
# Symmetric matrices must have only the upper triangle
291-
if sp.isspmatrix_csc(mat):
292-
mat = mat.T # Transpose to get a CSR matrix since it's symmetric
293+
# only grab the upper triangle.
293294
mat = sp.triu(mat, format='csr')
294295

295-
if not (sp.isspmatrix_csr(mat)):
296-
warnings.warn("Converting %s matrix to CSR format."
297-
% mat.__class__.__name__, PardisoTypeConversionWarning)
296+
if mat.format != 'csr':
297+
warnings.warn(
298+
"Converting %s matrix to CSR format."% A.__class__.__name__,
299+
PardisoTypeConversionWarning,
300+
stacklevel=3
301+
)
298302
mat = mat.tocsr()
303+
299304
mat.sort_indices()
300305
mat.sum_duplicates()
306+
return mat
301307

308+
def _validate_matrix_dtypes(self, mat):
302309
data = np.require(mat.data, self._data_dtype, requirements="C")
303310
indptr = np.require(mat.indptr, self._ind_dtype, requirements="C")
304311
indices = np.require(mat.indices, self._ind_dtype, requirements="C")

0 commit comments

Comments
 (0)