88import cupy as cp
99import numpy as np
1010from cupyx .scipy import sparse as cp_sparse
11+ from scipy import sparse as sc_sparse
1112
1213from rapids_singlecell .preprocessing ._neighbors ._algorithms ._all_neighbors import (
1314 _all_neighbors_knn ,
@@ -263,9 +264,16 @@ def neighbors(
263264
264265 n_nonzero = n_obs * n_neighbors
265266 rowptr = cp .arange (0 , n_nonzero + 1 , n_neighbors )
266- distances = cp_sparse .csr_matrix (
267- (cp .ravel (knn_dist ), cp .ravel (knn_indices ), rowptr ), shape = (n_obs , n_obs )
268- )
267+ if n_nonzero >= np .iinfo (np .int32 ).max :
268+ distances = sc_sparse .csr_matrix (
269+ (cp .ravel (knn_dist ).get (), cp .ravel (knn_indices ).get (), rowptr .get ()),
270+ shape = (n_obs , n_obs ),
271+ )
272+ else :
273+ distances = cp_sparse .csr_matrix (
274+ (cp .ravel (knn_dist ), cp .ravel (knn_indices ), rowptr ), shape = (n_obs , n_obs )
275+ )
276+ distances = distances .get ()
269277
270278 connectivities = _get_connectivities (
271279 n_neighbors = n_neighbors ,
@@ -275,8 +283,10 @@ def neighbors(
275283 knn_indices = knn_indices ,
276284 knn_dist = knn_dist ,
277285 )
278- connectivities = connectivities .tocsr ().get ()
279- distances = distances .get ()
286+ if connectivities .nnz >= np .iinfo (np .int32 ).max :
287+ connectivities = connectivities .get ().tocsr ()
288+ else :
289+ connectivities = connectivities .tocsr ().get ()
280290 if key_added is None :
281291 key_added = "neighbors"
282292 conns_key = "connectivities"
@@ -471,9 +481,17 @@ def bbknn(
471481
472482 n_nonzero = n_obs * total_neighbors
473483 rowptr = cp .arange (0 , n_nonzero + 1 , total_neighbors )
474- distances = cp_sparse .csr_matrix (
475- (cp .ravel (knn_dist ), cp .ravel (knn_indices ), rowptr ), shape = (n_obs , n_obs )
476- )
484+ if rowptr .max () >= np .iinfo (np .int32 ).max :
485+ distances = sc_sparse .csr_matrix (
486+ (cp .ravel (knn_dist ).get (), cp .ravel (knn_indices ).get (), rowptr .get ()),
487+ shape = (n_obs , n_obs ),
488+ )
489+ else :
490+ distances = cp_sparse .csr_matrix (
491+ (cp .ravel (knn_dist ), cp .ravel (knn_indices ), rowptr ), shape = (n_obs , n_obs )
492+ )
493+ distances = distances .get ()
494+
477495 connectivities = _get_connectivities (
478496 total_neighbors ,
479497 n_obs = n_obs ,
@@ -490,7 +508,6 @@ def bbknn(
490508 connectivities = _trimming (connectivities , trim )
491509
492510 connectivities = connectivities .get ()
493- distances = distances .get ()
494511 if key_added is None :
495512 key_added = "neighbors"
496513 conns_key = "connectivities"
0 commit comments