Skip to content

Commit 64a0f26

Browse files
committed
use sub-generators
1 parent 1e43b2a commit 64a0f26

File tree

11 files changed

+176
-122
lines changed

11 files changed

+176
-122
lines changed

hatch.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ overrides.matrix.deps.python = [
3636
{ if = [ "low-vers" ], value = "3.12" },
3737
]
3838
overrides.matrix.deps.extra-dependencies = [
39+
{ if = [ "stable" ], value = "scipy>=1.17" },
3940
{ if = [ "pre" ], value = "anndata @ git+https://github.com/scverse/anndata.git" },
4041
{ if = [ "pre" ], value = "pandas>=3rc0" },
4142
]

src/scanpy/_utils/random.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,19 @@ def wrap_global(
108108
np.random.seed(arg)
109109
return _FakeRandomGen(arg, np.random.RandomState(np.random.get_bit_generator()))
110110

111+
def spawn(self, n_children: int) -> list[Self]:
112+
"""Return `self` `n_children` times.
113+
114+
In a real generator, the spawned children are independent,
115+
but for backwards compatibility we return the same instance.
116+
"""
117+
return [self] * n_children
118+
111119
@classmethod
112120
def _delegate(cls) -> None:
113121
names = dict(integers="randint")
114122
for name, meth in np.random.Generator.__dict__.items():
115-
if name.startswith("_") or not callable(meth):
123+
if name.startswith("_") or not callable(meth) or name in cls.__dict__:
116124
continue
117125

118126
def mk_wrapper(name: str, meth):
@@ -129,11 +137,11 @@ def wrapper(self: _FakeRandomGen, *args, **kwargs):
129137
_FakeRandomGen._delegate()
130138

131139

132-
def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator:
133-
"""Re-apply legacy `random_state` semantics when `rng` is a `_FakeRandomGen`.
140+
def _if_legacy_apply_global(rng: np.random.Generator, /) -> np.random.Generator:
141+
"""Wrap the global legacy RNG if `rng` is a `_FakeRandomGen`.
134142
135-
This resets the global legacy RNG from the original `_arg` and returns a
136-
generator which continues drawing from the same internal state.
143+
This is used where our code used to call `np.random.seed()`.
144+
It’s a no-op if `rng` is not a `_FakeRandomGen`.
137145
"""
138146
if not isinstance(rng, _FakeRandomGen):
139147
return rng
@@ -142,20 +150,20 @@ def _if_legacy_apply_global(rng: np.random.Generator) -> np.random.Generator:
142150

143151

144152
def _legacy_random_state(
145-
rng: SeedLike | RNGLike | None, *, always_state: bool = False
153+
rng: SeedLike | RNGLike | None, /, *, always_state: bool = False
146154
) -> _LegacyRandom:
147155
"""Convert a np.random.Generator into a legacy `random_state` argument.
148156
149157
If `rng` is already a `_FakeRandomGen`, return its original `_arg` attribute.
150158
"""
151159
if isinstance(rng, _FakeRandomGen):
152160
return rng._state if always_state else rng._arg
153-
rng = np.random.default_rng(rng)
154-
return np.random.RandomState(rng.bit_generator.spawn(1)[0])
161+
[bitgen] = np.random.default_rng(rng).bit_generator.spawn(1)
162+
return np.random.RandomState(bitgen)
155163

156164

157165
def _accepts_legacy_random_state[**P, R](
158-
random_state_default: _LegacyRandom,
166+
random_state_default: _LegacyRandom, /
159167
) -> Callable[[Callable[P, R]], Callable[P, R]]:
160168
"""Make a function accept `random_state: _LegacyRandom` and pass it as `rng`.
161169

src/scanpy/neighbors/__init__.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010

1111
import numpy as np
1212
import scipy
13+
from packaging.version import Version
1314
from scipy import sparse
1415

1516
from .. import _utils
1617
from .. import logging as logg
17-
from .._compat import CSBase, CSRBase, SpBase, warn
18+
from .._compat import CSBase, CSRBase, SpBase, pkg_version, warn
1819
from .._settings import settings
1920
from .._utils import NeighborsView, _doc_params, get_literal_vals
2021
from .._utils.random import (
@@ -46,9 +47,8 @@
4647
# TODO: make `type` when https://github.com/sphinx-doc/sphinx/pull/13508 is released
4748
RPForestDict: TypeAlias = Mapping[str, Mapping[str, np.ndarray]] # noqa: UP040
4849

49-
N_DCS: int = 15 # default number of diffusion components
50-
# Backwards compat, constants should be defined in only one place.
51-
N_PCS: int = settings.N_PCS
50+
51+
SCIPY_1_17 = pkg_version("scipy") >= Version("1.17")
5252

5353

5454
class KwdsForTransformer(TypedDict):
@@ -208,6 +208,10 @@ def neighbors( # noqa: PLR0913
208208
:doc:`/how-to/knn-transformers`
209209
210210
"""
211+
meta_random_state = (
212+
dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {}
213+
)
214+
211215
if distances is None:
212216
if metric is None:
213217
metric = "euclidean"
@@ -235,9 +239,8 @@ def neighbors( # noqa: PLR0913
235239
if p.name in {"use_rep", "knn", "n_pcs", "metric_kwds"}
236240
if params[p.name] != p.default
237241
}
238-
if not isinstance(rng, _FakeRandomGen) or rng._arg != 0:
242+
if meta_random_state.get("random_state") != 0: # rng or random_state was passed
239243
ignored.add("rng/random_state")
240-
rng = _FakeRandomGen(0)
241244
if ignored:
242245
warn(
243246
f"Parameter(s) ignored if `distances` is given: {ignored}",
@@ -270,8 +273,8 @@ def neighbors( # noqa: PLR0913
270273
key_added,
271274
n_neighbors=neighbors_.n_neighbors,
272275
method=method,
273-
random_state=_legacy_random_state(rng),
274276
metric=metric,
277+
**meta_random_state,
275278
**({} if not metric_kwds else dict(metric_kwds=metric_kwds)),
276279
**({} if use_rep is None else dict(use_rep=use_rep)),
277280
**({} if n_pcs is None else dict(n_pcs=n_pcs)),
@@ -849,15 +852,13 @@ def compute_transitions(self, *, density_normalize: bool = True) -> None:
849852
self._transitions_sym = self.Z @ conn_norm @ self.Z
850853
logg.info(" finished", time=start)
851854

852-
@_accepts_legacy_random_state(0)
853855
def compute_eigen(
854856
self,
855857
*,
856858
n_comps: int = 15,
857-
sym: bool | None = None,
858859
sort: Literal["decrease", "increase"] = "decrease",
859860
rng: np.random.Generator,
860-
):
861+
) -> None:
861862
"""Compute eigen decomposition of transition matrix.
862863
863864
Parameters
@@ -886,6 +887,9 @@ def compute_eigen(
886887
plotting.
887888
888889
"""
890+
[rng_init, rng_eigsh] = np.random.default_rng(rng).spawn(2)
891+
del rng
892+
889893
np.set_printoptions(precision=10)
890894
if self._transitions_sym is None:
891895
msg = "Run `.compute_transitions` first."
@@ -903,9 +907,14 @@ def compute_eigen(
903907
matrix = matrix.astype(np.float64)
904908

905909
# Setting the random initial vector
906-
v0 = rng.standard_normal(matrix.shape[0])
910+
v0 = rng_init.standard_normal(matrix.shape[0])
907911
evals, evecs = sparse.linalg.eigsh(
908-
matrix, k=n_comps, which=which, ncv=ncv, v0=v0
912+
matrix,
913+
k=n_comps,
914+
which=which,
915+
ncv=ncv,
916+
v0=v0,
917+
**(dict(rng=rng_eigsh) if SCIPY_1_17 else {}),
909918
)
910919
evals, evecs = evals.astype(np.float32), evecs.astype(np.float32)
911920
if sort == "decrease":

src/scanpy/preprocessing/_pca/_compat.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
from ..._utils.random import RNGLike, SeedLike
2323

2424

25+
SCIPY_1_15 = pkg_version("scikit-learn") >= Version("1.5.0rc1")
26+
27+
2528
@_accepts_legacy_random_state(None)
2629
def _pca_compat_sparse(
2730
x: CSBase,
@@ -33,7 +36,11 @@ def _pca_compat_sparse(
3336
) -> tuple[NDArray[np.floating], PCA]:
3437
"""Sparse PCA for scikit-learn <1.4."""
3538
rng = np.random.default_rng(rng)
36-
random_init = rng.uniform(size=np.min(x.shape))
39+
# this exists only to be stored in our PCA container object
40+
random_state_meta = _legacy_random_state(rng)
41+
[rng_init, rng_svds] = rng.spawn(2)
42+
del rng
43+
3744
x = check_array(x, accept_sparse=["csr", "csc"])
3845

3946
if mu is None:
@@ -55,11 +62,15 @@ def rmat_op(v: NDArray[np.floating]):
5562
rmatmat=rmat_op,
5663
)
5764

58-
u, s, v = svds(linop, solver=solver, k=n_pcs, v0=random_init)
59-
# u_based_decision was changed in https://github.com/scikit-learn/scikit-learn/pull/27491
60-
u, v = svd_flip(
61-
u, v, u_based_decision=pkg_version("scikit-learn") < Version("1.5.0rc1")
65+
random_init = rng_init.uniform(size=np.min(x.shape))
66+
kw = (
67+
dict(rng=rng_svds)
68+
if SCIPY_1_15
69+
else dict(random_state=_legacy_random_state(rng_svds))
6270
)
71+
u, s, v = svds(linop, solver=solver, k=n_pcs, v0=random_init, **kw)
72+
# u_based_decision was changed in https://github.com/scikit-learn/scikit-learn/pull/27491
73+
u, v = svd_flip(u, v, u_based_decision=not SCIPY_1_15)
6374
idx = np.argsort(-s)
6475
v = v[idx, :]
6576

@@ -71,9 +82,7 @@ def rmat_op(v: NDArray[np.floating]):
7182

7283
from sklearn.decomposition import PCA
7384

74-
pca = PCA(
75-
n_components=n_pcs, svd_solver=solver, random_state=_legacy_random_state(rng)
76-
)
85+
pca = PCA(n_components=n_pcs, svd_solver=solver, random_state=random_state_meta)
7786
pca.explained_variance_ = ev
7887
pca.explained_variance_ratio_ = ev_ratio
7988
pca.components_ = v

src/scanpy/preprocessing/_scrublet/__init__.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from ... import logging as logg
1212
from ... import preprocessing as pp
13-
from ..._utils.random import _accepts_legacy_random_state, _legacy_random_state
13+
from ..._utils.random import _accepts_legacy_random_state, _FakeRandomGen
1414
from ...get import _get_obs_rep
1515
from . import pipeline
1616
from .core import Scrublet
@@ -177,10 +177,12 @@ def scrublet( # noqa: PLR0913
177177

178178
adata_obs = adata.copy()
179179

180-
def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
180+
def _run_scrublet(
181+
ad_obs: AnnData, ad_sim: AnnData | None, *, rng: np.random.Generator
182+
):
183+
rng_sim, rng_call = rng.spawn(2)
181184
# With no adata_sim we assume the regular use case, starting with raw
182185
# counts and simulating doublets
183-
184186
if ad_sim is None:
185187
pp.filter_genes(ad_obs, min_cells=3)
186188
pp.filter_cells(ad_obs, min_genes=3)
@@ -207,7 +209,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
207209
layer="raw",
208210
sim_doublet_ratio=sim_doublet_ratio,
209211
synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling,
210-
rng=rng,
212+
rng=rng_sim,
211213
)
212214
del ad_obs.layers["raw"]
213215
if log_transform:
@@ -232,7 +234,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
232234
knn_dist_metric=knn_dist_metric,
233235
get_doublet_neighbor_parents=get_doublet_neighbor_parents,
234236
threshold=threshold,
235-
rng=rng,
237+
rng=rng_call,
236238
verbose=verbose,
237239
)
238240

@@ -249,12 +251,14 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
249251
# Run Scrublet independently on batches and return just the
250252
# scrublet-relevant parts of the objects to add to the input object
251253
batches = np.unique(adata.obs[batch_key])
254+
sub_rngs = rng.spawn(len(batches))
252255
scrubbed = [
253256
_run_scrublet(
254257
adata_obs[adata_obs.obs[batch_key] == batch].copy(),
255258
adata_sim,
259+
rng=sub_rng,
256260
)
257-
for batch in batches
261+
for batch, sub_rng in zip(batches, sub_rngs, strict=True)
258262
]
259263
scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed]).astype(
260264
adata.obs.dtypes
@@ -274,7 +278,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
274278
adata.uns["scrublet"]["batched_by"] = batch_key
275279

276280
else:
277-
scrubbed = _run_scrublet(adata_obs, adata_sim)
281+
scrubbed = _run_scrublet(adata_obs, adata_sim, rng=rng)
278282

279283
# Copy outcomes to input object from our processed version
280284
adata.obs["doublet_score"] = scrubbed["obs"]["doublet_score"]
@@ -385,6 +389,12 @@ def _scrublet_call_doublets( # noqa: PLR0913
385389
Dictionary of Scrublet parameters
386390
387391
"""
392+
meta_random_state = (
393+
dict(random_state=rng._arg) if isinstance(rng, _FakeRandomGen) else {}
394+
)
395+
rng_scrub, rng_pca = rng.spawn(2)
396+
del rng
397+
388398
# Estimate n_neighbors if not provided, and create scrublet object.
389399

390400
if n_neighbors is None:
@@ -398,7 +408,7 @@ def _scrublet_call_doublets( # noqa: PLR0913
398408
n_neighbors=n_neighbors,
399409
expected_doublet_rate=expected_doublet_rate,
400410
stdev_doublet_rate=stdev_doublet_rate,
401-
rng=rng,
411+
rng=rng_scrub,
402412
)
403413

404414
# Ensure normalised matrix sparseness as Scrublet does
@@ -424,13 +434,11 @@ def _scrublet_call_doublets( # noqa: PLR0913
424434

425435
if mean_center:
426436
logg.info("Embedding transcriptomes using PCA...")
427-
pipeline.pca(
428-
scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=scrub._rng
429-
)
437+
pipeline.pca(scrub, n_prin_comps=n_prin_comps, svd_solver="arpack", rng=rng_pca)
430438
else:
431439
logg.info("Embedding transcriptomes using Truncated SVD...")
432440
pipeline.truncated_svd(
433-
scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=scrub._rng
441+
scrub, n_prin_comps=n_prin_comps, algorithm="arpack", rng=rng_pca
434442
)
435443

436444
# Score the doublets
@@ -463,7 +471,7 @@ def _scrublet_call_doublets( # noqa: PLR0913
463471
.get("sim_doublet_ratio", None)
464472
),
465473
"n_neighbors": n_neighbors,
466-
"random_state": _legacy_random_state(rng),
474+
**meta_random_state,
467475
},
468476
}
469477

0 commit comments

Comments (0)