@@ -59,7 +59,8 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"]
5959editing = [" jupyter" , " pre-commit" ]
6060dev = [" scvi-tools[editing,tests]" ]
6161test = [" scvi-tools[tests]" ]
62- cuda = [" torchvision" , " torchaudio" , " jax[cuda12]" ," mlx[cuda]" ]
62+ cuda = [" torchvision" , " torchaudio" , " jax[cuda]" , " mlx[cuda]" ]
63+ cuda13 = [" torchvision" , " torchaudio" , " jax[cuda13]" , " mlx[cuda13]" ]
6364tpu = [" torch_xla[tpu]" ]
6465metal = [" torchvision" , " torchaudio" , " jax-metal" ," mlx-metal" ]
6566
@@ -79,8 +80,8 @@ docs = [
7980]
8081docsbuild = [" scvi-tools[docs,autotune,hub,jax,diagvi]" ," mlx" ]
8182
82- # scvi.autotune #TODO remove ray[tune] constraint once solved
83- autotune = [" hyperopt>=0.2" , " ray[tune]; python_version < '3.14' " , " scib-metrics" , " muon" ]
83+ # scvi.autotune
84+ autotune = [" hyperopt>=0.2" , " ray[tune]" , " scib-metrics" , " muon" ]
8485# scvi.hub dependencies
8586hub = [" huggingface_hub" , " dvc[s3]" , " boto3" ]
8687# scvi.data.add_dna_sequence
@@ -99,6 +100,9 @@ dataloaders = ["lamindb>=1.12.1", "cellxgene-census", "tiledbsoma", "tiledbsoma_
99100diagvi = [" torch_geometric" , " geomloss" ]
100101# for mlflow
101102mlflow = [" mlflow" ," psutil" ," GPUtil" ," nvidia-ml-py" ]
103+ # for rapids
104+ rapids = [ " cugraph>=24" , " cuml>=24" , " cupy-cuda12x" , " rapids-singlecell[rapids]" ]
105+ rapids-cuda13 = [ " cugraph>=24" , " cuml>=24" , " cupy-cuda13x" , " rapids-singlecell[rapids]" ]
102106
103107optional = [
104108 " scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability,diagvi]" ,
0 commit comments