1010
1111from ... import logging as logg
1212from ... import preprocessing as pp
13- from ..._utils .random import _accepts_legacy_random_state , _legacy_random_state
13+ from ..._utils .random import _accepts_legacy_random_state , _FakeRandomGen
1414from ...get import _get_obs_rep
1515from . import pipeline
1616from .core import Scrublet
@@ -177,10 +177,12 @@ def scrublet( # noqa: PLR0913
177177
178178 adata_obs = adata .copy ()
179179
180- def _run_scrublet (ad_obs : AnnData , ad_sim : AnnData | None = None ):
180+ def _run_scrublet (
181+ ad_obs : AnnData , ad_sim : AnnData | None , * , rng : np .random .Generator
182+ ):
183+ rng_sim , rng_call = rng .spawn (2 )
181184 # With no adata_sim we assume the regular use case, starting with raw
182185 # counts and simulating doublets
183-
184186 if ad_sim is None :
185187 pp .filter_genes (ad_obs , min_cells = 3 )
186188 pp .filter_cells (ad_obs , min_genes = 3 )
@@ -207,7 +209,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
207209 layer = "raw" ,
208210 sim_doublet_ratio = sim_doublet_ratio ,
209211 synthetic_doublet_umi_subsampling = synthetic_doublet_umi_subsampling ,
210- rng = rng ,
212+ rng = rng_sim ,
211213 )
212214 del ad_obs .layers ["raw" ]
213215 if log_transform :
@@ -232,7 +234,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
232234 knn_dist_metric = knn_dist_metric ,
233235 get_doublet_neighbor_parents = get_doublet_neighbor_parents ,
234236 threshold = threshold ,
235- rng = rng ,
237+ rng = rng_call ,
236238 verbose = verbose ,
237239 )
238240
@@ -249,12 +251,14 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
249251 # Run Scrublet independently on batches and return just the
250252 # scrublet-relevant parts of the objects to add to the input object
251253 batches = np .unique (adata .obs [batch_key ])
254+ sub_rngs = rng .spawn (len (batches ))
252255 scrubbed = [
253256 _run_scrublet (
254257 adata_obs [adata_obs .obs [batch_key ] == batch ].copy (),
255258 adata_sim ,
259+ rng = sub_rng ,
256260 )
257- for batch in batches
261+ for batch , sub_rng in zip ( batches , sub_rngs , strict = True )
258262 ]
259263 scrubbed_obs = pd .concat ([scrub ["obs" ] for scrub in scrubbed ]).astype (
260264 adata .obs .dtypes
@@ -274,7 +278,7 @@ def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
274278 adata .uns ["scrublet" ]["batched_by" ] = batch_key
275279
276280 else :
277- scrubbed = _run_scrublet (adata_obs , adata_sim )
281+ scrubbed = _run_scrublet (adata_obs , adata_sim , rng = rng )
278282
279283 # Copy outcomes to input object from our processed version
280284 adata .obs ["doublet_score" ] = scrubbed ["obs" ]["doublet_score" ]
@@ -385,6 +389,12 @@ def _scrublet_call_doublets( # noqa: PLR0913
385389 Dictionary of Scrublet parameters
386390
387391 """
392+ meta_random_state = (
393+ dict (random_state = rng ._arg ) if isinstance (rng , _FakeRandomGen ) else {}
394+ )
395+ rng_scrub , rng_pca = rng .spawn (2 )
396+ del rng
397+
388398 # Estimate n_neighbors if not provided, and create scrublet object.
389399
390400 if n_neighbors is None :
@@ -398,7 +408,7 @@ def _scrublet_call_doublets( # noqa: PLR0913
398408 n_neighbors = n_neighbors ,
399409 expected_doublet_rate = expected_doublet_rate ,
400410 stdev_doublet_rate = stdev_doublet_rate ,
401- rng = rng ,
411+ rng = rng_scrub ,
402412 )
403413
404414 # Ensure normalised matrix sparseness as Scrublet does
@@ -424,13 +434,11 @@ def _scrublet_call_doublets( # noqa: PLR0913
424434
425435 if mean_center :
426436 logg .info ("Embedding transcriptomes using PCA..." )
427- pipeline .pca (
428- scrub , n_prin_comps = n_prin_comps , svd_solver = "arpack" , rng = scrub ._rng
429- )
437+ pipeline .pca (scrub , n_prin_comps = n_prin_comps , svd_solver = "arpack" , rng = rng_pca )
430438 else :
431439 logg .info ("Embedding transcriptomes using Truncated SVD..." )
432440 pipeline .truncated_svd (
433- scrub , n_prin_comps = n_prin_comps , algorithm = "arpack" , rng = scrub . _rng
441+ scrub , n_prin_comps = n_prin_comps , algorithm = "arpack" , rng = rng_pca
434442 )
435443
436444 # Score the doublets
@@ -463,7 +471,7 @@ def _scrublet_call_doublets( # noqa: PLR0913
463471 .get ("sim_doublet_ratio" , None )
464472 ),
465473 "n_neighbors" : n_neighbors ,
466- "random_state" : _legacy_random_state ( rng ) ,
474+ ** meta_random_state ,
467475 },
468476 }
469477
0 commit comments