-
Notifications
You must be signed in to change notification settings - Fork 46
Expand file tree
/
Copy path__init__.py
More file actions
603 lines (514 loc) · 23.4 KB
/
__init__.py
File metadata and controls
603 lines (514 loc) · 23.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING
import cupy as cp
import numpy as np
import pandas as pd
from anndata import AnnData
from cupyx.scipy import sparse
from scanpy import logging as logg
from scanpy._compat import old_positionals
from scanpy.get import _get_obs_rep
from rapids_singlecell import preprocessing as pp
from rapids_singlecell._compat import DaskArray
from rapids_singlecell.preprocessing._utils import _check_gpu_X
from . import pipeline
from .core import Scrublet
if TYPE_CHECKING:
from rapids_singlecell._utils import AnyRandom
from rapids_singlecell.preprocessing._neighbors import _Metrics
@old_positionals(
    "batch_key",
    "sim_doublet_ratio",
    "expected_doublet_rate",
    "stdev_doublet_rate",
    "synthetic_doublet_umi_subsampling",
    "knn_dist_metric",
    "normalize_variance",
    "log_transform",
    "mean_center",
    "n_prin_comps",
    "use_approx_neighbors",
    "get_doublet_neighbor_parents",
    "n_neighbors",
    "threshold",
    "verbose",
    "copy",
    "random_state",
)
def scrublet(
    adata: AnnData,
    adata_sim: AnnData | None = None,
    *,
    batch_key: str | None = None,
    sim_doublet_ratio: float = 2.0,
    expected_doublet_rate: float = 0.05,
    stdev_doublet_rate: float = 0.02,
    synthetic_doublet_umi_subsampling: float = 1.0,
    knn_dist_metric: _Metrics = "euclidean",
    normalize_variance: bool = True,
    log_transform: bool = False,
    mean_center: bool = True,
    n_prin_comps: int = 30,
    use_approx_neighbors: bool = True,
    get_doublet_neighbor_parents: bool = False,
    n_neighbors: int | None = None,
    threshold: float | None = None,
    verbose: bool = True,
    copy: bool = False,
    random_state: AnyRandom = 0,
) -> AnnData | None:
    """\
    Predict doublets using Scrublet.

    Predict cell doublets using a nearest-neighbor classifier of observed
    transcriptomes and simulated doublets. Works best if the input is a raw
    (unnormalized) counts matrix from a single sample or a collection of
    similar samples from the same experiment.
    This function is a wrapper around functions that pre-process using rapids-singlecell
    and directly call functions of Scrublet(). You may also undertake your own
    preprocessing, simulate doublets with
    :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, and run the core scrublet
    function :func:`~rapids_singlecell.pp.scrublet` with ``adata_sim`` set.

    Scrublet can also be run with a `dask array` if a batch key is provided.
    Please make sure that each batch can fit into memory. In addition to that
    scrublet will not return the full scrublet results, but only the
    `doublet score` and `predicted doublet`, not the full `.uns['scrublet']`.
    `adata_sim` is not supported for `dask arrays`.

    Parameters
    ----------
    adata
        The annotated data matrix of shape ``n_obs`` × ``n_vars``. Rows
        correspond to cells and columns to genes. Expected to be un-normalised
        where adata_sim is not supplied, in which case doublets will be
        simulated and pre-processing applied to both objects. If adata_sim is
        supplied, this should be the observed transcriptomes processed
        consistently (filtering, transform, normalisation, hvg) with adata_sim.
    adata_sim
        (Advanced use case) Optional AnnData object generated by
        :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, with same number of vars
        as adata. This should have been built from the observed data after
        filtering genes and cells and selecting highly-variable genes.
        The supplied object is left untouched; an internal copy is processed.
        Not supported for dask arrays: it is ignored (with a warning) there.
    batch_key
        Optional :attr:`~anndata.AnnData.obs` column name discriminating between batches.
    sim_doublet_ratio
        Number of doublets to simulate relative to the number of observed
        transcriptomes.
    expected_doublet_rate
        Where adata_sim not supplied, the estimated doublet rate for the
        experiment.
    stdev_doublet_rate
        Where adata_sim not supplied, uncertainty in the expected doublet rate.
    synthetic_doublet_umi_subsampling
        Where adata_sim not supplied, rate for sampling UMIs when creating
        synthetic doublets. If 1.0, each doublet is created by simply adding
        the UMI counts from two randomly sampled observed transcriptomes. For
        values less than 1, the UMI counts are added and then randomly sampled
        at the specified rate.
    knn_dist_metric
        Distance metric used when finding nearest neighbors. For list of
        valid values, see the documentation :class:`cuml.neighbors.NearestNeighbors`.
    normalize_variance
        If True, normalize the data such that each gene has a variance of 1.
        :class:`cuml.decomposition.tsvd.TruncatedSVD` will be used for dimensionality
        reduction, unless `mean_center` is True.
    log_transform
        Whether to use :func:`~rapids_singlecell.pp.log1p` to log-transform the data
        prior to PCA.
    mean_center
        If True, center the data such that each gene has a mean of 0.
        :class:`cuml.decomposition.pca.PCA` will be used for dimensionality
        reduction.
    n_prin_comps
        Number of principal components used to embed the transcriptomes prior
        to k-nearest-neighbor graph construction.
    use_approx_neighbors
        Does not affect the results, just here to stay consistent with :func:`scanpy.pp.scrublet`.
    get_doublet_neighbor_parents
        If True, return (in .uns) the parent transcriptomes that generated the
        doublet neighbors of each observed transcriptome. This information can
        be used to infer the cell states that generated a given doublet state.
    n_neighbors
        Number of neighbors used to construct the KNN graph of observed
        transcriptomes and simulated doublets. If ``None``, this is
        automatically set to ``np.round(0.5 * np.sqrt(n_obs))``.
    threshold
        Doublet score threshold for calling a transcriptome a doublet. If
        `None`, this is set automatically by looking for the minimum between
        the two modes of the `doublet_scores_sim_` histogram. It is best
        practice to check the threshold visually using the
        `doublet_scores_sim_` histogram and/or based on co-localization of
        predicted doublets in a 2-D embedding.
    verbose
        If :data:`True`, log progress updates.
    copy
        If :data:`True`, return a copy of the input ``adata`` with Scrublet results
        added. Otherwise, Scrublet results are added in place.
    random_state
        Initial state for doublet simulation and nearest neighbors.

    Returns
    -------
    if ``copy=True`` it returns or else adds fields to ``adata``. Those fields:

    ``.obs['doublet_score']``
        Doublet scores for each observed transcriptome
    ``.obs['predicted_doublet']``
        Boolean indicating predicted doublet status
    ``.uns['scrublet']['doublet_scores_sim']``
        Doublet scores for each simulated doublet transcriptome
    ``.uns['scrublet']['doublet_parents']``
        Pairs of ``.obs_names`` used to generate each simulated doublet
        transcriptome
    ``.uns['scrublet']['parameters']``
        Dictionary of Scrublet parameters

    With ``batch_key`` and in-memory data, the per-batch ``.uns`` entries are
    stored under ``.uns['scrublet']['batches']`` and the key itself under
    ``.uns['scrublet']['batched_by']``. With dask input only
    ``.uns['scrublet']['batched_by']`` is written.

    Raises
    ------
    ValueError
        If ``batch_key`` is not a column of ``adata.obs``, or if ``adata.X`` is
        a dask array and no ``batch_key`` is given.

    See also
    --------
    :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`: Run Scrublet's doublet
        simulation separately for advanced usage.
    :func:`~scanpy.pl.scrublet_score_distribution`: Plot histogram of doublet
        scores for observed transcriptomes and simulated doublets.
    """
    if copy:
        adata = adata.copy()

    start = logg.info("Running Scrublet")

    def _run_scrublet(ad_obs: AnnData, ad_sim: AnnData | None = None):
        """Pre-process one (batch of) observed cells and call doublets on it."""
        # Dask-backed chunks are materialised before any processing.
        if isinstance(ad_obs.X, DaskArray):
            ad_obs.X = ad_obs.X.compute()

        # With no adata_sim we assume the regular use case, starting with raw
        # counts and simulating doublets
        if ad_sim is None:
            pp.filter_genes(ad_obs, min_cells=3, verbose=False)
            pp.filter_cells(ad_obs, min_genes=3, verbose=False)

            # Doublet simulation will be based on the un-normalised counts, but on the
            # selection of genes following normalisation and variability filtering. So
            # we need to save the raw and subset at the same time.
            ad_obs.layers["raw"] = ad_obs.X.copy()
            pp.normalize_total(ad_obs)

            # HVG process needs log'd data.
            ad_obs.layers["logged"] = pp.log1p(ad_obs, inplace=False)
            pp.highly_variable_genes(ad_obs, layer="logged")
            del ad_obs.layers["logged"]
            ad_obs = ad_obs[:, ad_obs.var["highly_variable"]].copy()

            # Simulate the doublets based on the raw expressions from the normalised
            # and filtered object.
            ad_sim = scrublet_simulate_doublets(
                ad_obs,
                layer="raw",
                sim_doublet_ratio=sim_doublet_ratio,
                synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling,
                random_seed=random_state,
            )
            # NOTE(review): block nesting was reconstructed from statement order;
            # the log transform is taken to apply only on this internally
            # simulated path -- confirm against the canonical file.
            if log_transform:
                pp.log1p(ad_obs)
                pp.log1p(ad_sim)
            del ad_obs.layers["raw"]
        else:
            # Work on a private copy of the caller-supplied simulation: it is
            # normalised in place below and `_scrublet_call_doublets` deletes
            # its `.X`, which would otherwise corrupt the user's object and
            # leave nothing to process for the next batch.
            ad_sim = ad_sim.copy()

        # Now normalise simulated and observed in the same way
        pp.normalize_total(ad_obs, target_sum=1e6)
        pp.normalize_total(ad_sim, target_sum=1e6)

        ad_obs = _scrublet_call_doublets(
            adata_obs=ad_obs,
            adata_sim=ad_sim,
            n_neighbors=n_neighbors,
            expected_doublet_rate=expected_doublet_rate,
            stdev_doublet_rate=stdev_doublet_rate,
            mean_center=mean_center,
            normalize_variance=normalize_variance,
            n_prin_comps=n_prin_comps,
            use_approx_neighbors=use_approx_neighbors,
            knn_dist_metric=knn_dist_metric,
            get_doublet_neighbor_parents=get_doublet_neighbor_parents,
            threshold=threshold,
            random_state=random_state,
            verbose=verbose,
        )

        return {"obs": ad_obs.obs, "uns": ad_obs.uns["scrublet"]}

    _check_gpu_X(adata.X, allow_dask=True)

    if batch_key is not None:
        if batch_key not in adata.obs.columns:
            raise ValueError(
                "`batch_key` must be a column of .obs in the input annData object."
            )

        # Run Scrublet independently on batches and return just the
        # scrublet-relevant parts of the objects to add to the input object
        if isinstance(adata.X, DaskArray):
            if adata_sim is not None:
                # The per-chunk worker below always simulates its own doublets,
                # so make the dropped argument visible instead of silent.
                warnings.warn(
                    "`adata_sim` is not supported for dask arrays and is ignored; "
                    "doublets are simulated separately for each batch.",
                    UserWarning,
                    stacklevel=2,
                )

            # Define function to process each batch chunk
            def _process_batch_chunk(X_chunk):
                """Process a single batch chunk through Scrublet."""
                batch_adata = AnnData(X_chunk)
                batch_results = _run_scrublet(batch_adata, None)
                return np.array(
                    batch_results["obs"][["doublet_score", "predicted_doublet"]]
                ).astype(np.float64)

            # Get batch information and sort data by batch
            batch_codes = adata.obs[batch_key].astype("category").cat.codes
            sort_indices = np.argsort(batch_codes)
            X_sorted = adata.X[sort_indices]

            # Calculate chunk sizes based on batch sizes, so that every dask
            # chunk holds exactly one batch
            batch_sizes = np.bincount(batch_codes.iloc[sort_indices])
            X_rechunked = X_sorted.rechunk((tuple(batch_sizes), adata.X.shape[1]))

            # Process all batches in parallel using map_blocks
            batch_results = X_rechunked.map_blocks(
                _process_batch_chunk,
                meta=np.array([], dtype=np.float64),
                dtype=np.float64,
                chunks=(X_rechunked.chunks[0], 2),
            )

            # Convert results to DataFrame and restore original order
            results_df = pd.DataFrame(
                batch_results.compute(), columns=["doublet_score", "predicted_doublet"]
            )
            final_results = results_df.iloc[np.argsort(sort_indices)]

            # Update the original AnnData object with results
            adata.obs["doublet_score"] = final_results["doublet_score"].values
            adata.obs["predicted_doublet"] = final_results[
                "predicted_doublet"
            ].values.astype(bool)
            adata.uns["scrublet"] = {"batched_by": batch_key}
        else:
            batches = np.unique(adata.obs[batch_key])
            scrubbed = [
                _run_scrublet(
                    adata[adata.obs[batch_key] == batch].copy(),
                    adata_sim,
                )
                for batch in batches
            ]
            scrubbed_obs = pd.concat([scrub["obs"] for scrub in scrubbed])

            # Now reset the obs to get the scrublet scores
            adata.obs = scrubbed_obs.loc[adata.obs_names.values]

            # Save the .uns from each batch separately
            adata.uns["scrublet"] = {}
            adata.uns["scrublet"]["batches"] = dict(
                zip(batches, [scrub["uns"] for scrub in scrubbed])
            )
            adata.uns["scrublet"]["batched_by"] = batch_key
    else:
        # Reject unsupported input before paying for a full copy of the object.
        if isinstance(adata.X, DaskArray):
            raise ValueError(
                "Dask arrays are not supported for Scrublet without a batch key. Please provide a batch key."
            )
        scrubbed = _run_scrublet(adata.copy(), adata_sim)

        # Copy outcomes to input object from our processed version
        adata.obs["doublet_score"] = scrubbed["obs"]["doublet_score"]
        adata.obs["predicted_doublet"] = scrubbed["obs"]["predicted_doublet"]
        adata.uns["scrublet"] = scrubbed["uns"]

    logg.info(" Scrublet finished", time=start)

    return adata if copy else None
def _scrublet_call_doublets(
    adata_obs: AnnData,
    adata_sim: AnnData,
    *,
    n_neighbors: int | None = None,
    expected_doublet_rate: float = 0.05,
    stdev_doublet_rate: float = 0.02,
    mean_center: bool = True,
    normalize_variance: bool = True,
    n_prin_comps: int = 30,
    use_approx_neighbors: bool = True,
    knn_dist_metric: _Metrics = "euclidean",
    get_doublet_neighbor_parents: bool = False,
    threshold: float | None = None,
    random_state: AnyRandom = 0,
    verbose: bool = True,
) -> AnnData:
    """\
    Core function for predicting doublets using Scrublet.

    Scores observed transcriptomes against simulated doublets with a
    nearest-neighbor classifier and calls doublets from those scores.

    Parameters
    ----------
    adata_obs
        The annotated data matrix of shape ``n_obs`` × ``n_vars``. Rows
        correspond to cells and columns to genes. Should be normalised with
        :func:`~rapids_singlecell.pp.normalize_total` and filtered to include only highly
        variable genes. Its ``.X`` is deleted by this function.
    adata_sim
        AnnData object generated by
        :func:`~rapids_singlecell.pp.scrublet_simulate_doublets`, with same number of vars
        as adata_obs. Must carry ``.obsm['doublet_parents']``. Its ``.X`` is
        deleted by this function.
    n_neighbors
        Number of neighbors used to construct the KNN graph of observed
        transcriptomes and simulated doublets. If ``None``, this is
        set to ``int(round(0.5 * np.sqrt(n_obs)))``.
    expected_doublet_rate
        The estimated doublet rate for the experiment.
    stdev_doublet_rate
        Uncertainty in the expected doublet rate.
    mean_center
        If True, the data are mean-centered (z-scored when
        ``normalize_variance`` is also True) and embedded with PCA; otherwise
        a truncated SVD is used for the embedding.
    normalize_variance
        If True, the per-gene variance is normalised before embedding.
    n_prin_comps
        Number of components used to embed the transcriptomes prior
        to k-nearest-neighbor graph construction.
    use_approx_neighbors
        Forwarded to the doublet scoring step; kept for consistency with
        :func:`scanpy.pp.scrublet`.
    knn_dist_metric
        Distance metric forwarded to the doublet scoring step. For the list of
        valid values, see the documentation of cuml.neighbors.NearestNeighbors.
    get_doublet_neighbor_parents
        If True, also store the parent transcriptomes that generated the
        doublet neighbors of each observed transcriptome.
    threshold
        Doublet score threshold for calling a transcriptome a doublet,
        forwarded to ``Scrublet.call_doublets``.
    random_state
        Initial state handed to the Scrublet object.
    verbose
        If :data:`True`, log progress updates.

    Returns
    -------
    ``adata_obs`` with the following fields added:

    ``.obs['doublet_score']``
        Doublet scores for each observed transcriptome
    ``.obs['predicted_doublet']``
        Predicted doublet status; ``False`` everywhere if no threshold was found
    ``.uns['scrublet']['doublet_scores_sim']``
        Doublet scores for each simulated doublet transcriptome
    ``.uns['scrublet']['doublet_parents']``
        ``adata_sim.obsm['doublet_parents']``
    ``.uns['scrublet']['parameters']``
        Dictionary of Scrublet parameters
    ``.uns['scrublet']['threshold']``
        The threshold used, only when one was determined
    ``.uns['scrublet']['doublet_neighbor_parents']``
        Only when ``get_doublet_neighbor_parents`` is True
    """
    # Derive the neighbourhood size from the number of cells when unset.
    if n_neighbors is None:
        n_cells = adata_obs.shape[0]
        n_neighbors = int(round(0.5 * np.sqrt(n_cells)))

    # Note: Scrublet() will sparse adata_obs.X if it's not already, but this
    # matrix won't get used if we pre-set the normalised slots.
    scrub = Scrublet(
        adata_obs.X,
        n_neighbors=n_neighbors,
        expected_doublet_rate=expected_doublet_rate,
        stdev_doublet_rate=stdev_doublet_rate,
        random_state=random_state,
    )
    del scrub._counts_obs

    # Ensure normalised matrix sparseness as Scrublet does
    # https://github.com/swolock/scrublet/blob/67f8ecbad14e8e1aa9c89b43dac6638cebe38640/src/scrublet/scrublet.py#L100
    scrub._counts_obs_norm = sparse.csc_matrix(adata_obs.X)
    scrub._counts_sim_norm = sparse.csc_matrix(adata_sim.X)
    # The matrices now live on the Scrublet object; drop the AnnData copies.
    del adata_obs.X, adata_sim.X

    scrub.doublet_parents_ = adata_sim.obsm["doublet_parents"]

    # Scrublet-specific scaling followed by the embedding. Scrublet fits to the
    # observed matrix and decomposes both observed and simulated based on that
    # fit, so its own pipeline is used rather than a generic PCA wrapper.
    if mean_center:
        if normalize_variance:
            pipeline.zscore(scrub)
        else:
            pipeline.mean_center(scrub)
        logg.info("Embedding transcriptomes using PCA...")
        pipeline.pca(scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state)
    else:
        if normalize_variance:
            pipeline.normalize_variance(scrub)
        logg.info("Embedding transcriptomes using Truncated SVD...")
        pipeline.truncated_svd(
            scrub, n_prin_comps=n_prin_comps, random_state=scrub._random_state
        )

    # Score the doublets, then call them
    scrub.calculate_doublet_scores(
        use_approx_neighbors=use_approx_neighbors,
        distance_metric=knn_dist_metric,
        get_doublet_neighbor_parents=get_doublet_neighbor_parents,
    )
    scrub.call_doublets(threshold=threshold, verbose=verbose)

    # Store results in AnnData for return
    adata_obs.obs["doublet_score"] = scrub.doublet_scores_obs_

    # The simulation ratio is only known if the simulated object recorded it.
    sim_params = adata_sim.uns.get("scrublet", {}).get("parameters", {})
    parameters = {
        "expected_doublet_rate": expected_doublet_rate,
        "sim_doublet_ratio": sim_params.get("sim_doublet_ratio", None),
        "n_neighbors": n_neighbors,
        "random_state": random_state,
    }
    adata_obs.uns["scrublet"] = {
        "doublet_scores_sim": scrub.doublet_scores_sim_,
        "doublet_parents": adata_sim.obsm["doublet_parents"],
        "parameters": parameters,
    }

    # If threshold hasn't been located successfully then we couldn't make any
    # predictions. The user will get a warning from Scrublet, but we need to
    # set the boolean so that any downstream filtering on
    # predicted_doublet=False doesn't incorrectly filter cells. The user can
    # still use this object to generate the plot and derive a threshold
    # manually.
    if not hasattr(scrub, "threshold_"):
        adata_obs.obs["predicted_doublet"] = False
    else:
        adata_obs.uns["scrublet"]["threshold"] = scrub.threshold_
        adata_obs.obs["predicted_doublet"] = scrub.predicted_doublets_

    if get_doublet_neighbor_parents:
        adata_obs.uns["scrublet"]["doublet_neighbor_parents"] = (
            scrub.doublet_neighbor_parents_
        )

    return adata_obs
@old_positionals(
    "layer", "sim_doublet_ratio", "synthetic_doublet_umi_subsampling", "random_seed"
)
def scrublet_simulate_doublets(
    adata: AnnData,
    *,
    layer: str | None = None,
    sim_doublet_ratio: float = 2.0,
    synthetic_doublet_umi_subsampling: float = 1.0,
    random_seed: AnyRandom = 0,
) -> AnnData:
    """
    Simulate doublets by adding the counts of random observed transcriptome pairs.

    Parameters
    ----------
    adata
        The annotated data matrix of shape ``n_obs`` × ``n_vars``. Rows
        correspond to cells and columns to genes. Genes should have been
        filtered for expression and variability, and the object should contain
        raw expression of the same dimensions.
    layer
        Layer of ``adata`` where the raw values are stored. Forwarded to
        ``scanpy.get._get_obs_rep``; ``None`` presumably selects ``.X``
        (confirm against scanpy).
    sim_doublet_ratio
        Number of doublets to simulate relative to the number of observed
        transcriptomes.
    synthetic_doublet_umi_subsampling
        Rate for sampling UMIs when creating synthetic doublets. If 1.0,
        each doublet is created by simply adding the UMIs from two randomly
        sampled observed transcriptomes. For values less than 1, the
        UMI counts are added and then randomly sampled at the specified
        rate.
    random_seed
        Random state handed to the Scrublet object that runs the simulation.

    Returns
    -------
    A new :class:`~anndata.AnnData` with the simulated doublets in ``.X`` and:

    ``.obs['n_counts']``
        Total counts of each simulated doublet
    ``.obsm['doublet_parents']``
        Parent pairs used to generate each simulated doublet transcriptome
    ``.uns['scrublet']['parameters']``
        Dictionary holding ``sim_doublet_ratio``

    See also
    --------
    :func:`~rapids_singlecell.pp.scrublet`: Main way of running Scrublet, runs
        preprocessing, doublet simulation (this function) and calling.
    :func:`~scanpy.pl.scrublet_score_distribution`: Plot histogram of doublet
        scores for observed transcriptomes and simulated doublets.
    """
    # Run the simulation on the requested representation of the counts.
    simulator = Scrublet(_get_obs_rep(adata, layer=layer), random_state=random_seed)
    simulator.simulate_doublets(
        sim_doublet_ratio=sim_doublet_ratio,
        synthetic_doublet_umi_subsampling=synthetic_doublet_umi_subsampling,
    )

    # Package the simulated counts and their provenance into a fresh AnnData.
    simulated = AnnData(simulator._counts_sim)
    simulated.obs["n_counts"] = simulator._total_counts_sim.get()
    simulated.obsm["doublet_parents"] = simulator.doublet_parents_
    simulated.uns["scrublet"] = {
        "parameters": {"sim_doublet_ratio": sim_doublet_ratio},
    }
    return simulated