Skip to content

Commit 1bff1c7

Browse files
committed
fix umap and big data neighbors
1 parent dbc0400 commit 1bff1c7

2 files changed

Lines changed: 29 additions & 11 deletions

File tree

src/rapids_singlecell/preprocessing/_neighbors/__init__.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import cupy as cp
99
import numpy as np
1010
from cupyx.scipy import sparse as cp_sparse
11+
from scipy import sparse as sc_sparse
1112

1213
from rapids_singlecell.preprocessing._neighbors._algorithms._all_neighbors import (
1314
_all_neighbors_knn,
@@ -263,9 +264,16 @@ def neighbors(
263264

264265
n_nonzero = n_obs * n_neighbors
265266
rowptr = cp.arange(0, n_nonzero + 1, n_neighbors)
266-
distances = cp_sparse.csr_matrix(
267-
(cp.ravel(knn_dist), cp.ravel(knn_indices), rowptr), shape=(n_obs, n_obs)
268-
)
267+
if n_nonzero >= np.iinfo(np.int32).max:
268+
distances = sc_sparse.csr_matrix(
269+
(cp.ravel(knn_dist).get(), cp.ravel(knn_indices).get(), rowptr.get()),
270+
shape=(n_obs, n_obs),
271+
)
272+
else:
273+
distances = cp_sparse.csr_matrix(
274+
(cp.ravel(knn_dist), cp.ravel(knn_indices), rowptr), shape=(n_obs, n_obs)
275+
)
276+
distances = distances.get()
269277

270278
connectivities = _get_connectivities(
271279
n_neighbors=n_neighbors,
@@ -275,8 +283,10 @@ def neighbors(
275283
knn_indices=knn_indices,
276284
knn_dist=knn_dist,
277285
)
278-
connectivities = connectivities.tocsr().get()
279-
distances = distances.get()
286+
if connectivities.nnz >= np.iinfo(np.int32).max:
287+
connectivities = connectivities.get().tocsr()
288+
else:
289+
connectivities = connectivities.tocsr().get()
280290
if key_added is None:
281291
key_added = "neighbors"
282292
conns_key = "connectivities"
@@ -471,9 +481,17 @@ def bbknn(
471481

472482
n_nonzero = n_obs * total_neighbors
473483
rowptr = cp.arange(0, n_nonzero + 1, total_neighbors)
474-
distances = cp_sparse.csr_matrix(
475-
(cp.ravel(knn_dist), cp.ravel(knn_indices), rowptr), shape=(n_obs, n_obs)
476-
)
484+
if rowptr.max() >= np.iinfo(np.int32).max:
485+
distances = sc_sparse.csr_matrix(
486+
(cp.ravel(knn_dist).get(), cp.ravel(knn_indices).get(), rowptr.get()),
487+
shape=(n_obs, n_obs),
488+
)
489+
else:
490+
distances = cp_sparse.csr_matrix(
491+
(cp.ravel(knn_dist), cp.ravel(knn_indices), rowptr), shape=(n_obs, n_obs)
492+
)
493+
distances = distances.get()
494+
477495
connectivities = _get_connectivities(
478496
total_neighbors,
479497
n_obs=n_obs,
@@ -490,7 +508,6 @@ def bbknn(
490508
connectivities = _trimming(connectivities, trim)
491509

492510
connectivities = connectivities.get()
493-
distances = distances.get()
494511
if key_added is None:
495512
key_added = "neighbors"
496513
conns_key = "connectivities"

src/rapids_singlecell/tools/_umap.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,10 @@ def umap(
159159
n_epochs = (
160160
500 if maxiter is None else maxiter
161161
) # 0 is not a valid value for rapids, unlike original umap
162-
162+
if neighbors["connectivities"].nnz > np.iinfo(np.int32).max:
163+
use_umap = True
163164
n_obs = adata.shape[0]
164-
if parse_version(cuml.__version__) < parse_version("24.10"):
165+
if parse_version(cuml.__version__) < parse_version("24.10") or use_umap:
165166
# `simplicial_set_embedding` is bugged in cuml<24.10. This is why we use `UMAP` instead.
166167
n_neighbors = neigh_params["n_neighbors"]
167168
if neigh_params.get("method") == "rapids":

0 commit comments

Comments
 (0)