3
3
from dataclasses import asdict , dataclass
4
4
from enum import Enum
5
5
from functools import partial
6
- from typing import Any , Dict , List , Optional , Union
6
+ from typing import Any , Callable , Dict , List , Optional , Union
7
7
8
8
import matplotlib
9
9
import matplotlib .pyplot as plt
14
14
from plottable import ColumnDefinition , Table
15
15
from plottable .cmap import normed_cmap
16
16
from plottable .plots import bar
17
- from pynndescent import NNDescent
18
17
from sklearn .preprocessing import MinMaxScaler
19
18
from tqdm import tqdm
20
19
21
20
import scib_metrics
21
+ from scib_metrics .nearest_neighbors import NeighborsOutput , pynndescent
22
22
23
23
Kwargs = Dict [str , Any ]
24
24
MetricType = Union [bool , Kwargs ]
@@ -156,8 +156,17 @@ def __init__(
156
156
"Batch correction" : self ._batch_correction_metrics ,
157
157
}
158
158
159
- def prepare (self ) -> None :
160
- """Prepare the data for benchmarking."""
159
+ def prepare (self , neighbor_computer : Optional [Callable [[np .ndarray , int ], NeighborsOutput ]] = None ) -> None :
160
+ """Prepare the data for benchmarking.
161
+
162
+ Parameters
163
+ ----------
164
+ neighbor_computer
165
+ Function that computes the neighbors of the data. If `None`, the neighbors will be computed
166
+ with :func:`~scib_metrics.utils.nearest_neighbors.pynndescent`. The function should take as input
167
+ the data and the number of neighbors to compute and return a :class:`~scib_metrics.utils.nearest_neighbors.NeighborsOutput`
168
+ object.
169
+ """
161
170
# Compute PCA
162
171
if self ._pre_integrated_embedding_obsm_key is None :
163
172
# This is how scib does it
@@ -173,24 +182,13 @@ def prepare(self) -> None:
173
182
174
183
# Compute neighbors
175
184
for ad in tqdm (self ._emb_adatas .values (), desc = "Computing neighbors" ):
176
- # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
177
- # which is used by scanpy under the hood
178
- n_trees = min (64 , 5 + int (round ((ad .X .shape [0 ]) ** 0.5 / 20.0 )))
179
- n_iters = max (5 , int (round (np .log2 (ad .X .shape [0 ]))))
180
- max_candidates = 60
181
-
182
- knn_search_index = NNDescent (
183
- ad .X ,
184
- n_neighbors = max (self ._neighbor_values ),
185
- random_state = 0 ,
186
- low_memory = True ,
187
- n_jobs = self ._n_jobs ,
188
- compressed = False ,
189
- n_trees = n_trees ,
190
- n_iters = n_iters ,
191
- max_candidates = max_candidates ,
192
- )
193
- indices , distances = knn_search_index .neighbor_graph
185
+ if neighbor_computer is not None :
186
+ neigh_output = neighbor_computer (ad .X , max (self ._neighbor_values ))
187
+ else :
188
+ neigh_output = pynndescent (
189
+ ad .X , n_neighbors = max (self ._neighbor_values ), random_state = 0 , n_jobs = self ._n_jobs
190
+ )
191
+ indices , distances = neigh_output .indices , neigh_output .distances
194
192
for n in self ._neighbor_values :
195
193
sp_distances , sp_conns = sc .neighbors ._compute_connectivities_umap (
196
194
indices [:, :n ], distances [:, :n ], ad .n_obs , n_neighbors = n
0 commit comments