Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions .codecov.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,6 @@ flags:
paths:
- src/scvi/
carryforward: false
nonjax:
paths:
- src/scvi/
carryforward: true
cuda:
paths:
- src/scvi/
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/test_linux_jax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ jobs:
run: |
python -m pip install --upgrade pip wheel uv
python -m uv pip install --system "scvi-tools[tests] @ ."
python -m uv pip install --system "jax"
python -m uv pip install --system "flax"
python -m uv pip install --system "numpyro"
python -m uv pip install --system "optax"

- name: Run pytest
env:
Expand Down
76 changes: 0 additions & 76 deletions .github/workflows/test_linux_nonjax.yml

This file was deleted.

1 change: 0 additions & 1 deletion .github/workflows/test_macos_mps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ jobs:
python -m pip install --upgrade pip wheel uv
python -m pip install "scvi-tools[tests]"
python -m pip install mlx
python -m pip install jax-metal
python -m pip install mlx-metal
python -m pip install coverage
python -m pip install pytest
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo
#### Changed

- Update SCVI-Tools Hub models, {pr}`3733`.
- Removed Jax support from SCVI-Tools, {pr}`37xx`.

#### Removed

Expand Down
8 changes: 0 additions & 8 deletions docs/api/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ Parameterizable probability distributions.
distributions.NegativeBinomial
distributions.NegativeBinomialMixture
distributions.ZeroInflatedNegativeBinomial
distributions.JaxNegativeBinomialMeanDisp
distributions.BetaBinomial

```
Expand Down Expand Up @@ -158,8 +157,6 @@ Existing module classes with respective generative and inference procedures.
module.VAE
module.VAEC
module.AmortizedLDAPyroModule
module.JaxVAE

```

## External module
Expand All @@ -182,12 +179,10 @@ Module classes in the external API with respective generative and inference proc
external.contrastivevi.ContrastiveDataSplitter
external.stereoscope.RNADeconv
external.stereoscope.SpatialDeconv
external.tangram.TangramMapper
external.scbasset.ScBassetModule
external.contrastivevi.ContrastiveVAE
external.velovi.VELOVAE
external.mrvi.MRVAE
external.mrvi_jax.JaxMRVAE
external.mrvi_torch.TorchMRVAE
external.methylvi.METHYLVAE
external.methylvi.METHYLANVAE
Expand Down Expand Up @@ -218,11 +213,9 @@ These classes should be used to construct module classes that define generative
module.base.BaseMinifiedModeModuleClass
module.base.SupervisedModuleClass
module.base.PyroBaseModuleClass
module.base.JaxBaseModuleClass
module.base.EmbeddingModuleMixin
module.base.LossOutput
module.base.auto_move_data

```

## Neural networks
Expand Down Expand Up @@ -272,7 +265,6 @@ TrainingPlans define train/test/val optimization steps for modules.
train.SemiSupervisedAdversarialTrainingPlan
train.LowLevelPyroTrainingPlan
train.PyroTrainingPlan
train.JaxTrainingPlan
train.Trainer
train.TrainingPlan
train.TrainRunner
Expand Down
4 changes: 0 additions & 4 deletions docs/api/user.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ import scvi
model.TOTALVI
model.MULTIVI
model.AmortizedLDA
model.JaxSCVI
model.mlxSCVI
```

Expand All @@ -55,14 +54,12 @@ import scvi
external.SpatialStereoscope
external.SOLO
external.SCAR
external.Tangram
external.SCBASSET
external.ContrastiveVI
external.POISSONVI
external.VELOVI
external.MRVI
external.TorchMRVI
external.JaxMRVI
external.METHYLVI
external.METHYLANVI
external.Decipher
Expand Down Expand Up @@ -145,7 +142,6 @@ Here we maintain a few package specific utilities for feature selection, etc.
train.PyroTrainingPlanConfig
train.LowLevelPyroTrainingPlanConfig
train.ClassifierTrainingPlanConfig
train.JaxTrainingPlanConfig
train.TrainerConfig
```

Expand Down
6 changes: 2 additions & 4 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ In this case we recommend installing PyTorch and JAX _before_ installing scvi-to
Please follow the respective installation instructions for [PyTorch](https://pytorch.org/get-started/locally/) and
[JAX](https://jax.readthedocs.io/en/latest/installation.html) compatible with your system and device type.

Note: starting v1.5, Jax is no longer supported in scvi-tools.

## Optional dependencies

scvi-tools is installed in its lightest form by default.
Expand All @@ -89,7 +91,6 @@ It has many optional dependencies which expand its capabilities:
- _parallel_ - for parallelization engine
- _interpretability_ - for supervised models interpretability
- _dataloaders_ - for custom dataloaders use
- _jax_ - for Jax support
- _mlflow_ - for MLflow support
- _tests_ - in order to be able to perform tests
- _editing_ - for code editing
Expand All @@ -101,9 +102,6 @@ It has many optional dependencies which expand its capabilities:
The easiest way to install this is with `pip`.
To install capability X run: _pip install scvi-tools[X]_

You can install several capabilities together, e.g:
To install scvi-tools with JAX support for GPU on Ubuntu: _pip install scvi-tools[cuda,jax]_

To install all tutorial dependencies:

```bash
Expand Down
2 changes: 1 addition & 1 deletion docs/user_guide/models/scvi.md
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ distributions of the latent variables, retrieving the likelihood parameters (of
The standard {class}`~scvi.model.SCVI` class uses PyTorch as its computational backend.
For users who prefer a different framework or are running on hardware where another backend offers better performance, two experimental alternatives are available:

- **JAX** – {class}`~scvi.model.JaxSCVI` is a JAX-based implementation of scVI. It can be substantially faster than the PyTorch implementation on CPUs (e.g., comparable to PyTorch on a GPU on a multi-core machine) and works on any platform supported by JAX.
- **JAX** – {class}`~scvi.model.JaxSCVI` is a JAX-based implementation of scVI. It can be substantially faster than the PyTorch implementation on CPUs (e.g., comparable to PyTorch on a GPU on a multi-core machine) and works on any platform supported by JAX. This version is deprecated starting v1.5.
- **MLX (Apple Silicon)** – {class}`~scvi.model.mlxSCVI` is an MLX-based implementation optimized for Apple Silicon (M-series) chips via the [MLX](https://ml-explore.github.io/mlx/) framework. It is only available on macOS with Apple Silicon.

Both alternatives expose the same high-level API (e.g., `setup_anndata`, `train`, `get_latent_representation`, `save`, `load`) as {class}`~scvi.model.SCVI`, though they may have reduced feature sets compared to the full PyTorch implementation.
Expand Down
4 changes: 4 additions & 0 deletions docs/user_guide/models/tangram.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Tangram

:::{note}
This model is deprecated starting v1.5.
:::

**Tangram** {cite:p}`Biancalani21` (Python class {class}`~scvi.external.Tangram`) maps single-cell RNA-seq data to spatial data, permitting deconvolution of cell types in spatial data like Visium.

This is a reimplementation of Tangram, which can originally be found [here](https://github.com/broadinstitute/Tangram).
Expand Down
1 change: 0 additions & 1 deletion docs/user_guide/use_case/training_configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,5 @@ Use the plan config that matches the training plan behind your model:
models with adversarial mixing.
- `PyroTrainingPlanConfig` / `LowLevelPyroTrainingPlanConfig` → Pyro‑based models.
- `ClassifierTrainingPlanConfig` → classifier training plans.
- `JaxTrainingPlanConfig` → Jax training plans.

If you’re unsure, you can still use `plan_kwargs` exactly as before.
15 changes: 6 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"]
editing = ["jupyter", "pre-commit"]
dev = ["scvi-tools[editing,tests]"]
test = ["scvi-tools[tests]"]
cuda = ["torchvision", "torchaudio", "jax[cuda12]","mlx[cuda]"]
cuda = ["torchvision", "torchaudio","mlx[cuda]"]
tpu = ["torch_xla[tpu]"]
metal = ["torchvision", "torchaudio", "jax-metal","mlx-metal"]
metal = ["torchvision", "torchaudio","mlx-metal"]

docs = [
"docutils>=0.8,!=0.18.*,!=0.19.*", # see https://github.com/scverse/cookiecutter-scverse/pull/205
Expand All @@ -77,10 +77,10 @@ docs = [
"myst-nb",
"sphinx-autodoc-typehints",
]
docsbuild = ["scvi-tools[docs,autotune,hub,jax]","mlx"]
docsbuild = ["scvi-tools[docs,autotune,hub]","mlx"]

# scvi.autotune #TODO remove ray[tune] constraint once solved
autotune = ["hyperopt>=0.2", "ray[tune]; python_version < '3.14'", "scib-metrics", "muon"]
# scvi.autotune
autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics", "muon"]
# scvi.hub dependencies
hub = ["huggingface_hub", "dvc[s3]", "boto3"]
# scvi.data.add_dna_sequence
Expand All @@ -91,15 +91,13 @@ file_sharing = ["pooch","gdown","readfcs","fcswrite"]
parallel = ["dask[array]", "zarr"]
# for models interpretability
interpretability = ["captum", "shap", "decoupler"]
# for jax support
jax = ["jax", "jaxlib", "optax", "numpyro", "flax"]
# for custom dataloders
dataloaders = ["lamindb>=1.12.1", "cellxgene-census", "tiledbsoma", "tiledbsoma_ml", "torchdata"]
# for mlflow
mlflow = ["mlflow","psutil","GPUtil","nvidia-ml-py"]

optional = [
"scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability]",
"scvi-tools[autotune,mlflow,hub,file_sharing,regseq,parallel,interpretability]",
"igraph","leidenalg","pynndescent",
]
tutorials = [
Expand Down Expand Up @@ -138,7 +136,6 @@ markers = [
"autotune: mark tests that are used to check ray autotune capabilities",
"custom dataloaders: mark tests that are used to check different custom data loaders",
"dataloader: mark tests that are used to check data loaders",
"jax: mark test as jax related",
"mlflow: mark test for mlflow",
]

Expand Down
1 change: 0 additions & 1 deletion src/scvi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

settings.verbosity = logging.INFO

# Jax sets the root logger, this prevents double output.
scvi_logger = logging.getLogger("scvi")
scvi_logger.propagate = False

Expand Down
36 changes: 0 additions & 36 deletions src/scvi/_settings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING

Expand Down Expand Up @@ -42,9 +41,6 @@ class ScviConfig:

>>> scvi.settings.num_threads = 2

To prevent Jax from preallocating GPU memory on start (default)

>>> scvi.settings.jax_preallocate_gpu_memory = False
"""

def __init__(
Expand All @@ -56,7 +52,6 @@ def __init__(
logging_dir: str = "./scvi_log/",
dl_num_workers: int = 0,
dl_persistent_workers: bool = False,
jax_preallocate_gpu_memory: bool = False,
warnings_stacklevel: int = 2,
mlflow_set_tracking_uri: str = "",
mlflow_set_experiment: str = "mlflow_experiment",
Expand All @@ -71,7 +66,6 @@ def __init__(
self.dl_num_workers = dl_num_workers
self.dl_persistent_workers = dl_persistent_workers
self._num_threads = None
self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory
self.verbosity = verbosity
self.mlflow_set_tracking_uri = mlflow_set_tracking_uri
self.mlflow_set_experiment = mlflow_set_experiment
Expand Down Expand Up @@ -157,12 +151,6 @@ def seed(self, seed: int | None = None):
else:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Ensure deterministic CUDA operations for Jax
# (see https://github.com/google/jax/issues/13672)
if "XLA_FLAGS" not in os.environ:
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true"
else:
os.environ["XLA_FLAGS"] += " --xla_gpu_deterministic_ops=true"
seed_everything(seed)
self._seed = seed

Expand Down Expand Up @@ -219,30 +207,6 @@ def reset_logging_handler(self):
ch.setFormatter(formatter)
scvi_logger.addHandler(ch)

@property
def jax_preallocate_gpu_memory(self):
"""Jax GPU memory allocation settings.

If False, Jax will only preallocate GPU memory it needs.
If float in (0, 1), Jax will preallocate GPU memory to that
fraction of the GPU memory.
"""
return self._jax_gpu

@jax_preallocate_gpu_memory.setter
def jax_preallocate_gpu_memory(self, value: float | bool):
# see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation
if value is False:
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
elif isinstance(value, float):
if value >= 1 or value <= 0:
raise ValueError("Need to use a value between 0 and 1")
# format is ".XX"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)[1:4]
else:
raise ValueError("value not understood, need bool or float in (0, 1)")
self._jax_gpu = value

@property
def mlflow_set_tracking_uri(self) -> str:
"""Setting the MLFlow tracking URI. Setting it will cause to also use it"""
Expand Down
Loading
Loading