Skip to content

Commit 7a95e42

Browse files
committed
add random_state for reproducibility
1 parent 178a629 commit 7a95e42

File tree

1 file changed

+62
-6
lines changed

1 file changed

+62
-6
lines changed

multispaeti/_multispati_pca.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
TransformerMixin,
1313
)
1414
from sklearn.preprocessing import normalize
15+
from sklearn.utils import check_random_state
1516
from sklearn.utils.validation import check_array, check_is_fitted
1617

1718
T = TypeVar("T", bound=np.number)
@@ -24,6 +25,40 @@
2425
_Connectivity: TypeAlias = np.ndarray | _Csr
2526

2627

28+
# adapted from scikit-learn
29+
# https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/utils/_arpack.py
30+
def _init_arpack_v0(
31+
size: int, random_state: int | np.random.RandomState | None
32+
) -> NDArray[np.float64]:
33+
"""
34+
Initialize the starting vector for iteration in ARPACK functions.
35+
36+
Initialize a ndarray with values sampled from the uniform distribution on
37+
[-1, 1]. This initialization model has been chosen to be consistent with
38+
the ARPACK one as another initialization can lead to convergence issues.
39+
40+
Parameters
41+
----------
42+
size : int
43+
The size of the eigenvalue vector to be initialized.
44+
45+
random_state : int, RandomState instance or None
46+
The seed of the pseudo random number generator used to generate a
47+
uniform distribution. If int, random_state is the seed used by the
48+
random number generator; If RandomState instance, random_state is the
49+
random number generator; If None, the random number generator is the
50+
RandomState instance used by `np.random`.
51+
52+
Returns
53+
-------
54+
v0 : ndarray of shape (size,)
55+
The initialized vector.
56+
"""
57+
random_state = check_random_state(random_state)
58+
v0 = random_state.uniform(-1, 1, size)
59+
return v0
60+
61+
2762
class MultispatiPCA(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstimator):
2863
"""
2964
MULTISPATI-PCA
@@ -55,6 +90,10 @@ class MultispatiPCA(ClassNamePrefixFeaturesOutMixin, TransformerMixin, BaseEstim
5590
Whether to center `X` if it is a sparse array. By default sparse `X` will not be
5691
centered as this requires transforming it to a dense array, potentially raising
5792
out-of-memory errors.
93+
random_state : int | numpy.random.RandomState | None
94+
Used when the `X` is sparse and `center_sparse` is False and for Moran's I
95+
bound estimation.
96+
Pass an int for reproducible results across multiple function calls.
5897
5998
Attributes
6099
----------
@@ -99,10 +138,12 @@ def __init__(
99138
*,
100139
connectivity: _Connectivity | None = None,
101140
center_sparse: bool = False,
141+
random_state: int | np.random.RandomState | None = None,
102142
):
103143
self.n_components = n_components
104144
self.connectivity = connectivity
105145
self.center_sparse = center_sparse
146+
self.random_state = random_state
106147

107148
@staticmethod
108149
def _validate_connectivity(W: _Connectivity, n: int):
@@ -231,18 +272,25 @@ def remove_zero_eigenvalues(
231272
H = (X.T @ (W + W.T) @ X) / (2 * n)
232273
# TODO handle sparse based on density?
233274
if issparse(H):
275+
v0 = _init_arpack_v0(H.shape[0], self.random_state)
234276
match self._n_components:
235277
case None:
236278
eig_val, eig_vec = sparse_linalg.eigsh(
237-
H, k=min(n, d) - 1, which="LM"
279+
H, k=min(n, d) - 1, which="LM", v0=v0
238280
)
239281
case (n_pos, 0) | int(n_pos):
240-
eig_val, eig_vec = sparse_linalg.eigsh(H, k=n_pos, which="LA")
282+
eig_val, eig_vec = sparse_linalg.eigsh(
283+
H, k=n_pos, which="LA", v0=v0
284+
)
241285
case (0, n_neg):
242-
eig_val, eig_vec = sparse_linalg.eigsh(H, k=n_neg, which="SA")
286+
eig_val, eig_vec = sparse_linalg.eigsh(
287+
H, k=n_neg, which="SA", v0=v0
288+
)
243289
case (n_pos, n_neg):
244290
n_comp = min(2 * max(n_neg, n_pos), min(n, d))
245-
eig_val, eig_vec = sparse_linalg.eigsh(H, k=n_comp, which="BE")
291+
eig_val, eig_vec = sparse_linalg.eigsh(
292+
H, k=n_comp, which="BE", v0=v0
293+
)
246294
component_indices = self._get_component_indices(
247295
n_comp, n_pos, n_neg
248296
)
@@ -399,8 +447,9 @@ def double_center(W: np.ndarray | csr_array) -> np.ndarray:
399447
if not issparse(W) or not sparse_approx:
400448
W = double_center(W)
401449

450+
v0 = _init_arpack_v0(W.shape[0], self.random_state)
402451
eigen_values = s * sparse_linalg.eigsh(
403-
W, k=2, which="BE", return_eigenvectors=False
452+
W, k=2, which="BE", return_eigenvectors=False, v0=v0
404453
)
405454

406455
I_0 = -1 / (n_sample - 1)
@@ -416,6 +465,7 @@ def multispati_pca(
416465
*,
417466
connectivity: _Connectivity | None = None,
418467
center_sparse: bool = False,
468+
random_state: int | np.random.RandomState | None = None,
419469
) -> tuple[np.ndarray, np.ndarray]:
420470
"""
421471
Calculate MULTISPATI-PCA and return the transformed data matrix and components.
@@ -445,14 +495,20 @@ def multispati_pca(
445495
Whether to center `X` if it is a sparse array. By default sparse `X` will not be
446496
centered as this requires transforming it to a dense array, potentially raising
447497
out-of-memory errors.
498+
random_state : int | numpy.random.RandomState | None
499+
Used when the `X` is sparse and `center_sparse` is False.
500+
Pass an int for reproducible results across multiple function calls.
448501
449502
Returns
450503
-------
451504
X_transformed : numpy.ndarray
452505
components : numpy.ndarray
453506
"""
454507
ms_pca = MultispatiPCA(
455-
n_components, connectivity=connectivity, center_sparse=center_sparse
508+
n_components,
509+
connectivity=connectivity,
510+
center_sparse=center_sparse,
511+
random_state=random_state,
456512
)
457513

458514
X_tr = ms_pca._fit(X, return_transform=True, stats=False)

0 commit comments

Comments
 (0)