Skip to content

Commit db0900c

Browse files
authored
feat: support for rapids singlecell in scvi-tools (#3811)
1 parent 513012f commit db0900c

5 files changed

Lines changed: 36 additions & 14 deletions

File tree

CHANGELOG.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,20 @@
33
Starting from version 0.20.1, this format is based on [Keep a Changelog], and this project adheres
44
to [Semantic Versioning]. The full commit history is available in the [commit logs](https://github.com/scverse/scvi-tools/commits/).
55

6+
## Version 1.5
7+
8+
### 1.5.0 (2026-XX-XX)
9+
10+
#### Added
11+
12+
- Add support for rapids-singlecell, {pr}`3811`.
13+
14+
#### Fixed
15+
16+
#### Changed
17+
18+
#### Removed
19+
620
## Version 1.4
721

822
### 1.4.3 (2026-05-12)

Dockerfile

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
22
FROM python:3.12 AS base
33

4-
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/nvidia/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}
5-
64
RUN pip install --no-cache-dir uv
75

86
RUN uv pip install --system --no-cache torch torchvision torchaudio

pyproject.toml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"]
5959
editing = ["jupyter", "pre-commit"]
6060
dev = ["scvi-tools[editing,tests]"]
6161
test = ["scvi-tools[tests]"]
62-
cuda = ["torchvision", "torchaudio", "jax[cuda12]","mlx[cuda]"]
62+
cuda = ["torchvision", "torchaudio", "jax[cuda]", "mlx[cuda]"]
63+
cuda13 = ["torchvision", "torchaudio", "jax[cuda13]", "mlx[cuda13]"]
6364
tpu = ["torch_xla[tpu]"]
6465
metal = ["torchvision", "torchaudio", "jax-metal","mlx-metal"]
6566

@@ -79,8 +80,8 @@ docs = [
7980
]
8081
docsbuild = ["scvi-tools[docs,autotune,hub,jax,diagvi]","mlx"]
8182

82-
# scvi.autotune #TODO remove ray[tune] constraint once solved
83-
autotune = ["hyperopt>=0.2", "ray[tune]; python_version < '3.14'", "scib-metrics", "muon"]
83+
# scvi.autotune
84+
autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics", "muon"]
8485
# scvi.hub dependencies
8586
hub = ["huggingface_hub", "dvc[s3]", "boto3"]
8687
# scvi.data.add_dna_sequence
@@ -99,6 +100,9 @@ dataloaders = ["lamindb>=1.12.1", "cellxgene-census", "tiledbsoma", "tiledbsoma_
99100
diagvi = ["torch_geometric", "geomloss"]
100101
# for mlflow
101102
mlflow = ["mlflow","psutil","GPUtil","nvidia-ml-py"]
103+
# for rapids
104+
rapids = [ "cugraph>=24", "cuml>=24", "cupy-cuda12x", "rapids-singlecell[rapids]" ]
105+
rapids-cuda13 = [ "cugraph>=24", "cuml>=24", "cupy-cuda13x", "rapids-singlecell[rapids]" ]
102106

103107
optional = [
104108
"scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability,diagvi]",

src/scvi/criticism/_ppc.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -331,9 +331,17 @@ def differential_expression(
331331
sc.pp.normalize_total(adata_de, target_sum=cell_scale_factor)
332332
sc.pp.log1p(adata_de)
333333
if de_groupby is None:
334-
sc.tl.pca(adata_de)
335-
sc.pp.neighbors(adata_de)
336-
sc.tl.leiden(adata_de, key_added="leiden_scvi_criticism")
334+
try:
335+
import rapids_singlecell as rsc
336+
337+
print("RAPIDS SingleCell is installed and can be imported")
338+
rsc.pp.pca(adata_de)
339+
rsc.pp.neighbors(adata_de)
340+
rsc.tl.leiden(adata_de, key_added="leiden_scvi_criticism")
341+
except ImportError:
342+
sc.tl.pca(adata_de)
343+
sc.pp.neighbors(adata_de)
344+
sc.tl.leiden(adata_de, key_added="leiden_scvi_criticism")
337345
de_groupby = "leiden_scvi_criticism"
338346
with warnings.catch_warnings():
339347
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

src/scvi/external/resolvi/_model.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ def _prepare_data(
360360
if slice_key is not None:
361361
batch_key = slice_key
362362
try:
363-
import scanpy
363+
import scanpy as sc
364364
from sklearn.neighbors._base import _kneighbors_from_graph
365365
except ImportError as err:
366366
raise ImportError(
@@ -381,14 +381,12 @@ def _prepare_data(
381381
for index in indices:
382382
sub_data = adata[index].copy()
383383
try:
384-
import rapids_singlecell
384+
import rapids_singlecell as rsc
385385

386386
print("RAPIDS SingleCell is installed and can be imported")
387-
rapids_singlecell.pp.neighbors(
388-
sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep
389-
)
387+
rsc.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
390388
except ImportError:
391-
scanpy.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
389+
sc.pp.neighbors(sub_data, n_neighbors=n_neighbors + 5, use_rep=spatial_rep)
392390
distances = sub_data.obsp["distances"] ** 2
393391

394392
distance_neighbor[index, :], index_neighbor_batch = _kneighbors_from_graph(

0 commit comments

Comments
 (0)