diff --git a/.codecov.yaml b/.codecov.yaml index 956ffea779..4c92dda86a 100644 --- a/.codecov.yaml +++ b/.codecov.yaml @@ -21,10 +21,6 @@ flags: paths: - src/scvi/ carryforward: false - nonjax: - paths: - - src/scvi/ - carryforward: true cuda: paths: - src/scvi/ diff --git a/.github/workflows/test_linux_autotune.yml b/.github/workflows/test_linux_autotune.yml index a24d3e1f2a..32d138b527 100644 --- a/.github/workflows/test_linux_autotune.yml +++ b/.github/workflows/test_linux_autotune.yml @@ -55,7 +55,7 @@ jobs: python -m pip install --upgrade pip wheel uv python -m pip install nvidia-nccl-cu13 python -m pip install setuptools==70.0.0 - python -m uv pip install --system "scvi-tools[tests, cuda] @ ." + python -m uv pip install --system "scvi-tools[tests, autotune, cuda] @ ." - name: Run pytest env: diff --git a/.github/workflows/test_linux_jax.yml b/.github/workflows/test_linux_jax.yml deleted file mode 100644 index 9f24a429e7..0000000000 --- a/.github/workflows/test_linux_jax.yml +++ /dev/null @@ -1,75 +0,0 @@ -name: test (jax) - -on: - pull_request: - branches: [main, "[0-9]+.[0-9]+.x"] - types: [labeled, synchronize, opened] - schedule: - - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - test: - # if PR has label "jax" or "all tests" or if scheduled or manually triggered or on push - if: >- - ( - contains(github.event.pull_request.labels.*.name, 'jax') || - contains(github.event.pull_request.labels.*.name, 'all tests') || - (contains(github.event_name, 'schedule') && github.repository == 'scverse/scvi-tools') || - contains(github.event_name, 'workflow_dispatch') || - contains(github.event_name, 'push') - ) - - runs-on: ${{ matrix.os }} - - defaults: - run: - shell: bash -e {0} # -e to fail on error - - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest] - python: ["3.13"] - - permissions: - id-token: write - - name: 
unit - - env: - OS: ${{ matrix.os }} - PYTHON: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: "pip" - cache-dependency-path: "**/pyproject.toml" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip wheel uv - python -m uv pip install --system "scvi-tools[tests] @ ." - - - name: Run pytest - env: - MPLBACKEND: agg - PLATFORM: ${{ matrix.os }} - DISPLAY: :42 - COLUMNS: 120 - run: | - coverage run -m pytest -v --color=yes --jax - coverage report - - - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - flags: jax diff --git a/.github/workflows/test_linux_mlflow.yml b/.github/workflows/test_linux_mlflow.yml index 255e8bb05e..037d0e4ae3 100644 --- a/.github/workflows/test_linux_mlflow.yml +++ b/.github/workflows/test_linux_mlflow.yml @@ -54,7 +54,7 @@ jobs: run: | python -m pip install --upgrade pip wheel uv python -m pip install nvidia-nccl-cu12 - python -m uv pip install --system "scvi-tools[cuda,tests] @ ." + python -m uv pip install --system "scvi-tools[cuda,tests,mlflow] @ ." 
- name: Run pytest env: diff --git a/.github/workflows/test_linux_nonjax.yml b/.github/workflows/test_linux_nonjax.yml deleted file mode 100644 index e7464ef7ce..0000000000 --- a/.github/workflows/test_linux_nonjax.yml +++ /dev/null @@ -1,76 +0,0 @@ -name: test (nonJax) - -on: - push: - branches: [main, "[0-9]+.[0-9]+.x"] - pull_request: - schedule: - - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - -jobs: - test: - # if PR has label "nonjax" or if scheduled or manually triggered - if: >- - ( - contains(github.event.pull_request.labels.*.name, 'nonJAX') || - contains(github.event.pull_request.labels.*.name, 'all tests') || - (contains(github.event_name, 'schedule') && github.repository == 'scverse/scvi-tools') || - contains(github.event_name, 'workflow_dispatch') - ) - - runs-on: ${{ matrix.os }} - - defaults: - run: - shell: bash -e {0} # -e to fail on error - - strategy: - fail-fast: false - matrix: - os: [ubuntu-latest] - python: ["3.13"] - - name: integration - - env: - OS: ${{ matrix.os }} - PYTHON: ${{ matrix.python }} - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python }} - cache: "pip" - cache-dependency-path: "**/pyproject.toml" - - - name: Install dependencies - run: | - python -m pip install --upgrade pip wheel uv - python -m uv pip install --system "pytest" - python -m uv pip install --system "pytest-pretty" - python -m uv pip install --system "coverage" - python -m uv pip install --system "igraph" - python -m uv pip install --system "leidenalg" - python -m uv pip install --system "scvi-tools[autotune,hub,mlflow,file_sharing,regseq,parallel,interpretability,diagvi] @ ." 
- - - name: Run pytest - env: - MPLBACKEND: agg - PLATFORM: ${{ matrix.os }} - DISPLAY: :42 - COLUMNS: 120 - run: | - coverage run -m pytest -v --color=yes --ignore=tests/model/test_jaxscvi.py --ignore=tests/external/mrvi_jax/test_jaxmrvi_components.py --ignore=tests/external/tangram/test_tangram.py - coverage report - - - uses: codecov/codecov-action@v4 - with: - token: ${{ secrets.CODECOV_TOKEN }} - flags: nonjax diff --git a/.github/workflows/test_macos_mps.yml b/.github/workflows/test_macos_mps.yml index ee7380707f..ae90227ada 100644 --- a/.github/workflows/test_macos_mps.yml +++ b/.github/workflows/test_macos_mps.yml @@ -57,7 +57,6 @@ jobs: python -m pip install --upgrade pip wheel uv python -m pip install "scvi-tools[tests]" python -m pip install mlx - python -m pip install jax-metal python -m pip install mlx-metal python -m pip install coverage python -m pip install pytest diff --git a/CHANGELOG.md b/CHANGELOG.md index 286464ab8a..d7891b95b3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,7 @@ to [Semantic Versioning]. The full commit history is available in the [commit lo #### Changed - Update SCVI-Tools Hub models, {pr}`3733`. +- Removed Jax support from SCVI-Tools, {pr}`37xx`. #### Removed diff --git a/docs/api/developer.md b/docs/api/developer.md index a0fa13866f..7f2e166822 100644 --- a/docs/api/developer.md +++ b/docs/api/developer.md @@ -100,7 +100,6 @@ Parameterizable probability distributions. distributions.NegativeBinomial distributions.NegativeBinomialMixture distributions.ZeroInflatedNegativeBinomial - distributions.JaxNegativeBinomialMeanDisp distributions.BetaBinomial distributions.Normal distributions.Log1pNormal @@ -162,8 +161,6 @@ Existing module classes with respective generative and inference procedures. 
module.VAE module.VAEC module.AmortizedLDAPyroModule - module.JaxVAE - ``` ## External module @@ -186,12 +183,10 @@ Module classes in the external API with respective generative and inference proc external.contrastivevi.ContrastiveDataSplitter external.stereoscope.RNADeconv external.stereoscope.SpatialDeconv - external.tangram.TangramMapper external.scbasset.ScBassetModule external.contrastivevi.ContrastiveVAE external.velovi.VELOVAE external.mrvi.MRVAE - external.mrvi_jax.JaxMRVAE external.mrvi_torch.TorchMRVAE external.methylvi.METHYLVAE external.methylvi.METHYLANVAE @@ -223,11 +218,9 @@ These classes should be used to construct module classes that define generative module.base.BaseMinifiedModeModuleClass module.base.SupervisedModuleClass module.base.PyroBaseModuleClass - module.base.JaxBaseModuleClass module.base.EmbeddingModuleMixin module.base.LossOutput module.base.auto_move_data - ``` ## Neural networks @@ -277,7 +270,6 @@ TrainingPlans define train/test/val optimization steps for modules. train.SemiSupervisedAdversarialTrainingPlan train.LowLevelPyroTrainingPlan train.PyroTrainingPlan - train.JaxTrainingPlan train.Trainer train.TrainingPlan train.TrainRunner diff --git a/docs/api/user.md b/docs/api/user.md index f7a80125b5..1c1c91976d 100644 --- a/docs/api/user.md +++ b/docs/api/user.md @@ -33,7 +33,6 @@ import scvi model.TOTALVI model.MULTIVI model.AmortizedLDA - model.JaxSCVI model.mlxSCVI ``` @@ -55,14 +54,12 @@ import scvi external.SpatialStereoscope external.SOLO external.SCAR - external.Tangram external.SCBASSET external.ContrastiveVI external.POISSONVI external.VELOVI external.MRVI external.TorchMRVI - external.JaxMRVI external.METHYLVI external.METHYLANVI external.Decipher @@ -146,7 +143,6 @@ Here we maintain a few package specific utilities for feature selection, etc. 
train.PyroTrainingPlanConfig train.LowLevelPyroTrainingPlanConfig train.ClassifierTrainingPlanConfig - train.JaxTrainingPlanConfig train.TrainerConfig ``` diff --git a/docs/installation.md b/docs/installation.md index dd4a14fb25..c11c0dbe51 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -77,6 +77,8 @@ In this case we recommend installing PyTorch and JAX _before_ installing scvi-to Please follow the respective installation instructions for [PyTorch](https://pytorch.org/get-started/locally/) and [JAX](https://jax.readthedocs.io/en/latest/installation.html) compatible with your system and device type. +Note: starting with v1.5, JAX is no longer supported in scvi-tools. + ## Optional dependencies scvi-tools is installed in its lightest form by default. @@ -89,7 +91,6 @@ It has many optional dependencies which expand its capabilities: - _parallel_ - for parallelization engine - _interpretability_ - for supervised models interpretability - _dataloaders_ - for custom dataloaders use -- _jax_ - for Jax support - _mlflow_ - for MLflow support - _tests_ - in order to be able to perform tests - _editing_ - for code editing @@ -101,9 +102,6 @@ It has many optional dependencies which expand its capabilities: The easiest way to install this is with `pip`. To install capability X run: _pip install scvi-tools[X]_ -You can install several capabilities together, e.g: -To install scvi-tools with JAX support for GPU on Ubuntu: _pip install scvi-tools[cuda,jax]_ - To install all tutorial dependencies: ```bash diff --git a/docs/user_guide/models/scvi.md b/docs/user_guide/models/scvi.md index 7e2a87d7b1..4a47eda63c 100644 --- a/docs/user_guide/models/scvi.md +++ b/docs/user_guide/models/scvi.md @@ -185,7 +185,7 @@ distributions of the latent variables, retrieving the likelihood parameters (of The standard {class}`~scvi.model.SCVI` class uses PyTorch as its computational backend.
For users who prefer a different framework or are running on hardware where another backend offers better performance, two experimental alternatives are available: -- **JAX** – {class}`~scvi.model.JaxSCVI` is a JAX-based implementation of scVI. It can be substantially faster than the PyTorch implementation on CPUs (e.g., comparable to PyTorch on a GPU on a multi-core machine) and works on any platform supported by JAX. +- **JAX** – {class}`~scvi.model.JaxSCVI` is a JAX-based implementation of scVI. It can be substantially faster than the PyTorch implementation on CPUs (e.g., comparable to PyTorch on a GPU on a multi-core machine) and works on any platform supported by JAX. This implementation is no longer supported as of v1.5. - **MLX (Apple Silicon)** – {class}`~scvi.model.mlxSCVI` is an MLX-based implementation optimized for Apple Silicon (M-series) chips via the [MLX](https://ml-explore.github.io/mlx/) framework. It is only available on macOS with Apple Silicon. Both alternatives expose the same high-level API (e.g., `setup_anndata`, `train`, `get_latent_representation`, `save`, `load`) as {class}`~scvi.model.SCVI`, though they may have reduced feature sets compared to the full PyTorch implementation. diff --git a/docs/user_guide/models/tangram.md b/docs/user_guide/models/tangram.md index 1509da845e..fff5d74b2d 100644 --- a/docs/user_guide/models/tangram.md +++ b/docs/user_guide/models/tangram.md @@ -1,5 +1,9 @@ # Tangram +:::{note} +This model is no longer supported as of v1.5. +::: + **Tangram** {cite:p}`Biancalani21` (Python class {class}`~scvi.external.Tangram`) maps single-cell RNA-seq data to spatial data, permitting deconvolution of cell types in spatial data like Visium. This is a reimplementation of Tangram, which can originally be found [here](https://github.com/broadinstitute/Tangram).
diff --git a/docs/user_guide/use_case/training_configuration.md b/docs/user_guide/use_case/training_configuration.md index 30b53ce8fe..745aedd7b3 100644 --- a/docs/user_guide/use_case/training_configuration.md +++ b/docs/user_guide/use_case/training_configuration.md @@ -53,6 +53,5 @@ Use the plan config that matches the training plan behind your model: models with adversarial mixing. - `PyroTrainingPlanConfig` / `LowLevelPyroTrainingPlanConfig` → Pyro‑based models. - `ClassifierTrainingPlanConfig` → classifier training plans. -- `JaxTrainingPlanConfig` → Jax training plans. If you’re unsure, you can still use `plan_kwargs` exactly as before. diff --git a/pyproject.toml b/pyproject.toml index 943c181f55..29e481a6b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,10 +59,10 @@ tests = ["pytest", "pytest-pretty", "coverage", "scvi-tools[optional]"] editing = ["jupyter", "pre-commit"] dev = ["scvi-tools[editing,tests]"] test = ["scvi-tools[tests]"] -cuda = ["torchvision", "torchaudio", "jax[cuda]"] -cuda13 = ["torchvision", "torchaudio", "jax[cuda13]"] +cuda = ["torchvision", "torchaudio"] +cuda13 = ["torchvision", "torchaudio"] tpu = ["torch_xla[tpu]"] -metal = ["torchvision", "torchaudio", "jax-metal","mlx-metal"] +metal = ["torchvision", "torchaudio", "mlx-metal"] docs = [ "docutils>=0.8,!=0.18.*,!=0.19.*", # see https://github.com/scverse/cookiecutter-scverse/pull/205 @@ -78,7 +78,7 @@ docs = [ "myst-nb", "sphinx-autodoc-typehints", ] -docsbuild = ["scvi-tools[docs,autotune,hub,jax,diagvi]","mlx"] +docsbuild = ["scvi-tools[docs,autotune,hub,diagvi]","mlx"] # scvi.autotune autotune = ["hyperopt>=0.2", "ray[tune]", "scib-metrics", "muon"] @@ -105,7 +105,7 @@ rapids = [ "cugraph>=24", "cuml>=24", "cupy-cuda12x", "rapids-singlecell[rapids] rapids-cuda13 = [ "cugraph>=24", "cuml>=24", "cupy-cuda13x", "rapids-singlecell[rapids]" ] optional = [ - "scvi-tools[autotune,mlflow,hub,jax,file_sharing,regseq,parallel,interpretability,diagvi]", + 
"scvi-tools[autotune,mlflow,hub,file_sharing,regseq,parallel,interpretability,diagvi]", "igraph","leidenalg","pynndescent", ] tutorials = [ @@ -144,7 +144,6 @@ markers = [ "autotune: mark tests that are used to check ray autotune capabilities", "custom dataloaders: mark tests that are used to check different custom data loaders", "dataloader: mark tests that are used to check data loaders", - "jax: mark test as jax related", "mlflow: mark test for mlflow", ] diff --git a/src/scvi/__init__.py b/src/scvi/__init__.py index b3b42e80a9..0a58d65225 100644 --- a/src/scvi/__init__.py +++ b/src/scvi/__init__.py @@ -16,7 +16,6 @@ settings.verbosity = logging.INFO -# Jax sets the root logger, this prevents double output. scvi_logger = logging.getLogger("scvi") scvi_logger.propagate = False diff --git a/src/scvi/_settings.py b/src/scvi/_settings.py index 80b778f7ab..64ca515106 100644 --- a/src/scvi/_settings.py +++ b/src/scvi/_settings.py @@ -1,7 +1,6 @@ from __future__ import annotations import logging -import os from pathlib import Path from typing import TYPE_CHECKING @@ -42,9 +41,6 @@ class ScviConfig: >>> scvi.settings.num_threads = 2 - To prevent Jax from preallocating GPU memory on start (default) - - >>> scvi.settings.jax_preallocate_gpu_memory = False """ def __init__( @@ -56,7 +52,6 @@ def __init__( logging_dir: str = "./scvi_log/", dl_num_workers: int = 0, dl_persistent_workers: bool = False, - jax_preallocate_gpu_memory: bool = False, warnings_stacklevel: int = 2, mlflow_set_tracking_uri: str = "", mlflow_set_experiment: str = "mlflow_experiment", @@ -71,7 +66,6 @@ def __init__( self.dl_num_workers = dl_num_workers self.dl_persistent_workers = dl_persistent_workers self._num_threads = None - self.jax_preallocate_gpu_memory = jax_preallocate_gpu_memory self.verbosity = verbosity self.mlflow_set_tracking_uri = mlflow_set_tracking_uri self.mlflow_set_experiment = mlflow_set_experiment @@ -157,12 +151,6 @@ def seed(self, seed: int | None = None): else: 
torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - # Ensure deterministic CUDA operations for Jax - # (see https://github.com/google/jax/issues/13672) - if "XLA_FLAGS" not in os.environ: - os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops=true" - else: - os.environ["XLA_FLAGS"] += " --xla_gpu_deterministic_ops=true" seed_everything(seed) self._seed = seed @@ -219,30 +207,6 @@ def reset_logging_handler(self): ch.setFormatter(formatter) scvi_logger.addHandler(ch) - @property - def jax_preallocate_gpu_memory(self): - """Jax GPU memory allocation settings. - - If False, Jax will only preallocate GPU memory it needs. - If float in (0, 1), Jax will preallocate GPU memory to that - fraction of the GPU memory. - """ - return self._jax_gpu - - @jax_preallocate_gpu_memory.setter - def jax_preallocate_gpu_memory(self, value: float | bool): - # see https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html#gpu-memory-allocation - if value is False: - os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" - elif isinstance(value, float): - if value >= 1 or value <= 0: - raise ValueError("Need to use a value between 0 and 1") - # format is ".XX" - os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(value)[1:4] - else: - raise ValueError("value not understood, need bool or float in (0, 1)") - self._jax_gpu = value - @property def mlflow_set_tracking_uri(self) -> str: """Setting the MLFlow tracking URI. 
Setting it will cause to also use it""" diff --git a/src/scvi/_types.py b/src/scvi/_types.py index ab8c823fa5..d6a828dabc 100644 --- a/src/scvi/_types.py +++ b/src/scvi/_types.py @@ -6,15 +6,8 @@ import mudata import torch -from scvi.utils import is_package_installed - Number = int | float AnnOrMuData = anndata.AnnData | mudata.MuData -if is_package_installed("jax"): - import jax.numpy as jnp - - Tensor = torch.Tensor | jnp.ndarray -else: - Tensor = torch.Tensor +Tensor = torch.Tensor LossRecord = dict[str, Tensor] | Tensor MinifiedDataType = Literal["latent_posterior_parameters"] diff --git a/src/scvi/distributions/__init__.py b/src/scvi/distributions/__init__.py index 3a18821368..973bde1aad 100644 --- a/src/scvi/distributions/__init__.py +++ b/src/scvi/distributions/__init__.py @@ -22,17 +22,3 @@ "ZeroInflatedLogNormal", "ZeroInflatedGamma", ] - - -def __getattr__(name: str): - """ - Lazily provide object. If optional deps are missing, raise a helpful ImportError - - only when object is actually requested. 
- """ - if name == "JaxNegativeBinomialMeanDisp": - error_on_missing_dependencies("jax", "numpyro") - from ._negative_binomial import JaxNegativeBinomialMeanDisp as _JaxNegativeBinomialMeanDisp - - return _JaxNegativeBinomialMeanDisp - raise AttributeError(f"module {__name__!r} has no attribute {name}") diff --git a/src/scvi/distributions/_negative_binomial.py b/src/scvi/distributions/_negative_binomial.py index cbb93fcb4a..e1efe27089 100644 --- a/src/scvi/distributions/_negative_binomial.py +++ b/src/scvi/distributions/_negative_binomial.py @@ -14,15 +14,9 @@ ) from scvi import settings -from scvi.utils import is_package_installed from ._constraints import optional_constraint -try: - import jax.numpy as jnp -except ImportError: - jnp = None - def torch_lgamma_mps(x: torch.Tensor) -> torch.Tensor: """Used in Mac Mx devices while broadcasting a tensor @@ -103,13 +97,13 @@ def log_zinb_positive( def log_nb_positive( - x: torch.Tensor | (jnp.ndarray if jnp else torch.Tensor), - mu: torch.Tensor | (jnp.ndarray if jnp else torch.Tensor), - theta: torch.Tensor | (jnp.ndarray if jnp else torch.Tensor), + x: torch.Tensor, + mu: torch.Tensor, + theta: torch.Tensor, eps: float = 1e-8, log_fn: callable = torch.log, lgamma_fn: callable = torch.lgamma, -) -> torch.Tensor | (jnp.ndarray if jnp else torch.Tensor): +) -> torch.Tensor: """Log likelihood (scalar) of a minibatch according to a nb model. 
Parameters @@ -721,60 +715,3 @@ def __repr__(self) -> str: ] ) return self.__class__.__name__ + "(" + args_string + ")" - - -if is_package_installed("numpyro") and is_package_installed("jax"): - import numpyro.distributions as dist - - class JaxNegativeBinomialMeanDisp(dist.NegativeBinomial2): - """Negative binomial parameterized by mean and inverse dispersion.""" - - import jax.numpy as jnp - from numpyro.distributions import constraints as numpyro_constraints - from numpyro.distributions.util import validate_sample - - arg_constraints = { - "mean": numpyro_constraints.positive, - "inverse_dispersion": numpyro_constraints.positive, - } - support = numpyro_constraints.nonnegative_integer - - def __init__( - self, - mean: jnp.ndarray, - inverse_dispersion: jnp.ndarray, - validate_args: bool | None = None, - eps: float = 1e-8, - ): - from numpyro.distributions.util import promote_shapes - - self._inverse_dispersion, self._mean = promote_shapes(inverse_dispersion, mean) - self._eps = eps - super().__init__(mean, inverse_dispersion, validate_args=validate_args) - - @property - def mean(self) -> jnp.ndarray: - return self._mean - - @property - def inverse_dispersion(self) -> jnp.ndarray: - return self._inverse_dispersion - - @validate_sample - def log_prob(self, value) -> jnp.ndarray: - """Log probability.""" - import jax - import jax.numpy as jnp - - # theta is inverse_dispersion - theta = self._inverse_dispersion - mu = self._mean - eps = self._eps - return log_nb_positive( - value, - mu, - theta, - eps=eps, - log_fn=jnp.log, - lgamma_fn=jax.scipy.special.gammaln, - ) diff --git a/src/scvi/external/__init__.py b/src/scvi/external/__init__.py index 4621b23dad..837c6fb0e8 100644 --- a/src/scvi/external/__init__.py +++ b/src/scvi/external/__init__.py @@ -46,35 +46,3 @@ "CYTOVI", "DIAGVI", ] - - -def __getattr__(name: str): - """ - Lazily provide object. If optional deps are missing, raise a helpful ImportError - - only when object is actually requested. 
- """ - if name == "JaxMRVI": - warnings.warn( - "In order to use the Jax version of MRVI make sure to install scvi-tools[jax]", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from .mrvi_jax import JaxMRVI as _JaxMRVI - - return _JaxMRVI - - if name == "Tangram": - warnings.warn( - "In order to use the TANGRAM make sure to install scvi-tools[jax]", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from .tangram import Tangram as _Tangram - - return _Tangram - raise AttributeError(f"module {__name__!r} has no attribute {name}") diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index f312e7852d..34bc9c3f34 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -21,14 +21,14 @@ from scvi._types import AnnOrMuData -Backend = Literal["torch", "jax", None] +Backend = Literal["torch", None] class MRVI(BaseMinifiedModeModelClass): """ Multi-resolution Variational Inference (MrVI). - This is a convenience wrapper that instantiates the Torch or JAX + This is a convenience wrapper that instantiates the Torch implementation based on `backend` and returns that instance. Parameters @@ -36,7 +36,7 @@ class MRVI(BaseMinifiedModeModelClass): adata AnnData object that has been registered via the appropriate `setup_anndata`. backend - Which backend to use: "torch" or "jax". + Which backend to use: "torch" registry (Torch-only) Registry dict for loading from saved state. **model_kwargs @@ -45,7 +45,7 @@ class MRVI(BaseMinifiedModeModelClass): Notes ----- - When setup anndata with `backend="torch"`, this returns an instance of `TorchMRVI`. - - When setup anndata with `backend="jax"`, this returns an instance of `JaxMRVI`. + - The JAX version is deprecated starting v1.5. 
""" def __new__( @@ -58,7 +58,7 @@ def __new__( ): if backend is not None: warnings.warn( - "backend parameter is ignored from version 1.4.1", + "backend parameter is ignored from version 1.4.3", UserWarning, stacklevel=settings.warnings_stacklevel, ) @@ -66,17 +66,9 @@ def __new__( raise ValueError("MRVI requires adata or registry to infer backend.") try: model = TorchMRVI(adata=adata, registry=registry, **model_kwargs) - model_name = "TorchMRVI" - except (ValueError, KeyError): - model_name = "JaxMRVI" - if model_name == "TorchMRVI": - return model - if model_name == "JaxMRVI": - from scvi.external.mrvi_jax import JaxMRVI - - return JaxMRVI(adata=adata, **model_kwargs) - else: - raise ValueError("Unknown backend. Use 'torch' or 'jax' MRVI.") + except Exception: + raise + return model @classmethod @setup_anndata_dsp.dedent @@ -100,7 +92,7 @@ def setup_anndata( %(param_batch_key)s %(param_labels_key)s backend - Which backend to use: "torch" or "jax". + Which backend to use: "torch". **kwargs Additional keyword arguments passed into :meth:`~scvi.data.AnnDataManager.register_fields`. @@ -120,24 +112,8 @@ def setup_anndata( labels_key=labels_key, **kwargs, ) - elif backend == "jax": - from scvi.external.mrvi_jax import JaxMRVI - - warnings.warn( - "MRVI model is being setup with JAX backend", - UserWarning, - stacklevel=settings.warnings_stacklevel, - ) - JaxMRVI.setup_anndata( - adata=adata, - layer=layer, - sample_key=sample_key, - batch_key=batch_key, - labels_key=labels_key, - **kwargs, - ) else: - raise ValueError(f"Unknown backend '{backend}'. Use 'torch' or 'jax'.") + raise ValueError(f"Unknown backend '{backend}'. 
Use 'torch'") @classmethod @devices_dsp.dedent @@ -202,43 +178,21 @@ def load( backup_url=backup_url, datamodule=datamodule, ) - elif model_name == "JaxMRVI" or model_name == "MRVI": - from scvi.external.mrvi_jax import JaxMRVI - - warnings.warn( - "MRVI model is being loaded with JAX backend", - UserWarning, - stacklevel=settings.warnings_stacklevel, - ) - - return JaxMRVI.load( - dir_path, - adata=adata, - accelerator=accelerator, - device=device, - prefix=prefix, - backup_url=backup_url, - datamodule=datamodule, - allowed_classes_names_list=[ - "MRVI" - ], # allowing old JAX MRVI models to be loaded TODO: need to change in v1.5 - ) else: - raise ValueError("Unknown backend . Use 'torch' or 'jax' MRVI.") + raise ValueError("Unknown backend . Use 'torch' MRVI.") def differential_expression(self, *args, **kwargs): """Perform differential expression analysis. - Delegates to the underlying :class:`~scvi.external.TorchMRVI` or - :class:`~scvi.external.JaxMRVI` instance returned by the constructor. + Delegates to the underlying :class:`~scvi.external.TorchMRVI` + instance returned by the constructor. See Also -------- :meth:`~scvi.external.TorchMRVI.differential_expression` """ raise NotImplementedError( - "Call differential_expression on the TorchMRVI or JaxMRVI instance " - "returned by MRVI(...)." + "Call differential_expression on the TorchMRVI instance returned by MRVI(...)." 
) diff --git a/src/scvi/external/mrvi_jax/__init__.py b/src/scvi/external/mrvi_jax/__init__.py deleted file mode 100644 index de3ebe1401..0000000000 --- a/src/scvi/external/mrvi_jax/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from scvi.utils import error_on_missing_dependencies - -error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - -from ._model import JaxMRVI # noqa: E402 -from ._module import JaxMRVAE # noqa: E402 - -__all__ = ["JaxMRVI", "JaxMRVAE"] diff --git a/src/scvi/external/mrvi_jax/_components.py b/src/scvi/external/mrvi_jax/_components.py deleted file mode 100644 index 906dd59e7f..0000000000 --- a/src/scvi/external/mrvi_jax/_components.py +++ /dev/null @@ -1,312 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -import flax.linen as nn -import jax -import jax.numpy as jnp -import numpyro.distributions as dist - -if TYPE_CHECKING: - from collections.abc import Callable - from typing import Any, Literal - -PYTORCH_DEFAULT_SCALE = 1 / 3 - - -class Dense(nn.DenseGeneral): - """Dense layer. - - Uses a custom initializer for the kernel to replicate the default PyTorch behavior. - """ - - def __init__(self, *args, **kwargs): - from flax.linen.initializers import variance_scaling - - _kwargs = {"kernel_init": variance_scaling(PYTORCH_DEFAULT_SCALE, "fan_in", "uniform")} - _kwargs.update(kwargs) - - super().__init__(*args, **_kwargs) - - -class ResnetBlock(nn.Module): - """Resnet block. - - Consists of the following operations: - - 1. :class:`~flax.linen.Dense` - 2. :class:`~flax.linen.LayerNorm` - 3. Activation function specified by ``internal_activation`` - 4. Skip connection if ``n_in`` is equal to ``n_hidden``, otherwise a :class:`~flax.linen.Dense` - layer is applied to the input before the skip connection to match features. - 5. :class:`~flax.linen.Dense` - 6. :class:`~flax.linen.LayerNorm` - 7. 
Activation function specified by ``output_activation`` - - Parameters - ---------- - n_out - Number of output units. - n_hidden - Number of hidden units. - internal_activation - Activation function to use after the first :class:`~flax.linen.Dense` layer. - output_activation - Activation function to use after the last :class:`~flax.linen.Dense` layer. - training - Whether the model is in training mode. - """ - - n_out: int - n_hidden: int = 128 - internal_activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.relu - output_activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.relu - training: bool | None = None - - @nn.compact - def __call__(self, inputs: jax.typing.ArrayLike, training: bool | None = None) -> jax.Array: - training = nn.merge_param("training", self.training, training) - h = Dense(self.n_hidden)(inputs) - h = nn.LayerNorm()(h) - h = self.internal_activation(h) - if inputs.shape[-1] != self.n_hidden: - h = h + Dense(self.n_hidden)(inputs) - else: - h = h + inputs - h = Dense(self.n_out)(h) - h = nn.LayerNorm()(h) - return self.output_activation(h) - - -class MLP(nn.Module): - """Multi-layer perceptron with resnet blocks. - - Applies ``n_layers`` :class:`~ResnetBlock` blocks to the input, followed by a - :class:`~flax.linen.Dense` layer to project to the output dimension. - - Parameters - ---------- - n_out - Number of output units. - n_hidden - Number of hidden units. - n_layers - Number of resnet blocks. - activation - Activation function to use. - training - Whether the model is in training mode. 
- """ - - n_out: int - n_hidden: int = 128 - n_layers: int = 1 - activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.relu - training: bool | None = None - - @nn.compact - def __call__(self, inputs: jax.typing.ArrayLike, training: bool | None = None) -> jax.Array: - training = nn.merge_param("training", self.training, training) - h = inputs - for _ in range(self.n_layers): - h = ResnetBlock( - n_out=self.n_hidden, - internal_activation=self.activation, - output_activation=self.activation, - )(h, training=training) - return Dense(self.n_out)(h) - - -class NormalDistOutputNN(nn.Module): - """Fully-connected neural net parameterizing a normal distribution. - - Applies ``n_layers`` :class:`~ResnetBlock` blocks to the input, followed by a - :class:`~flax.linen.Dense` layer for the mean and a :class:`~flax.linen.Dense` and - :func:`~flax.linen.softplus` layer for the scale. - - Parameters - ---------- - n_out - Number of output units. - n_hidden - Number of hidden units. - n_layers - Number of resnet blocks. - scale_eps - Numerical stability constant added to the scale of the normal distribution. - """ - - n_out: int - n_hidden: int = 128 - n_layers: int = 1 - scale_eps: float = 1e-5 - training: bool | None = None - - @nn.compact - def __call__(self, inputs: jax.typing.ArrayLike, training: bool | None = None) -> dist.Normal: - training = nn.merge_param("training", self.training, training) - h = inputs - for _ in range(self.n_layers): - h = ResnetBlock(n_out=self.n_hidden)(h, training=training) - mean = Dense(self.n_out)(h) - scale = nn.Sequential([Dense(self.n_out), nn.softplus])(h) - return dist.Normal(mean, scale + self.scale_eps) - - -class ConditionalNormalization(nn.Module): - """Condition-specific normalization. - - Applies either batch normalization or layer normalization to the input, followed by - condition-specific scaling (``gamma``) and shifting (``beta``). - - Parameters - ---------- - n_features - Number of features. 
- n_conditions - Number of conditions. - training - Whether the model is in training mode. - normalization_type - Type of normalization to apply. Must be one of ``"batch", "layer"``. - """ - - n_features: int - n_conditions: int - training: bool | None = None - normalization_type: Literal["batch", "layer"] = "layer" - - @staticmethod - def _gamma_initializer() -> jax.nn.initializers.Initializer: - def init(key: jax.random.KeyArray, shape: tuple, dtype: Any = jnp.float_) -> jax.Array: - weights = jax.random.normal(key, shape, dtype) * 0.02 + 1 - return weights - - return init - - @staticmethod - def _beta_initializer() -> jax.nn.initializers.Initializer: - def init(key: jax.random.KeyArray, shape: tuple, dtype: Any = jnp.float_) -> jax.Array: - del key - weights = jnp.zeros(shape, dtype=dtype) - return weights - - return init - - @nn.compact - def __call__( - self, - x: jax.typing.ArrayLike, - condition: jax.typing.ArrayLike, - training: bool | None = None, - ) -> jax.Array: - training = nn.merge_param("training", self.training, training) - - if self.normalization_type == "batch": - x = nn.BatchNorm(use_bias=False, use_scale=False)(x, use_running_average=not training) - elif self.normalization_type == "layer": - x = nn.LayerNorm(use_bias=False, use_scale=False)(x) - else: - raise ValueError("`normalization_type` must be one of ['batch', 'layer'].") - - cond_int = condition.squeeze(-1).astype(int) - gamma = nn.Embed( - self.n_conditions, - self.n_features, - embedding_init=self._gamma_initializer(), - name="gamma_conditional", - )(cond_int) - beta = nn.Embed( - self.n_conditions, - self.n_features, - embedding_init=self._beta_initializer(), - name="beta_conditional", - )(cond_int) - - return gamma * x + beta - - -class AttentionBlock(nn.Module): - """Attention block consisting of multi-head self-attention and MLP. - - Parameters - ---------- - query_dim - Dimension of the query input. - out_dim - Dimension of the output. 
- outerprod_dim - Dimension of the outer product. - n_channels - Number of channels. - n_heads - Number of heads. - dropout_rate - Dropout rate. - n_hidden_mlp - Number of hidden units in the MLP. - n_layers_mlp - Number of layers in the MLP. - training - Whether the model is in training mode. - stop_gradients_mlp - Whether to stop gradients through the MLP. - activation - Activation function to use. - """ - - query_dim: int - out_dim: int - outerprod_dim: int = 16 - n_channels: int = 4 - n_heads: int = 2 - dropout_rate: float = 0.0 - n_hidden_mlp: int = 32 - n_layers_mlp: int = 1 - training: bool | None = None - stop_gradients_mlp: bool = False - activation: Callable[[jax.Array], jax.Array] = nn.gelu - - @nn.compact - def __call__( - self, - query_embed: jax.typing.ArrayLike, - kv_embed: jax.typing.ArrayLike, - training: bool | None = None, - ) -> jax.Array: - training = nn.merge_param("training", self.training, training) - has_mc_samples = query_embed.ndim == 3 - - query_embed_stop = ( - query_embed if not self.stop_gradients_mlp else jax.lax.stop_gradient(query_embed) - ) - query_for_att = nn.DenseGeneral((self.outerprod_dim, 1), use_bias=False)(query_embed_stop) - kv_for_att = nn.DenseGeneral((self.outerprod_dim, 1), use_bias=False)(kv_embed) - eps = nn.MultiHeadDotProductAttention( - num_heads=self.n_heads, - qkv_features=self.n_channels * self.n_heads, - out_features=self.n_channels, - dropout_rate=self.dropout_rate, - use_bias=True, - )(inputs_q=query_for_att, inputs_kv=kv_for_att, deterministic=not training) - - if not has_mc_samples: - eps = jnp.reshape(eps, (eps.shape[0], eps.shape[1] * eps.shape[2])) - else: - eps = jnp.reshape(eps, (eps.shape[0], eps.shape[1], eps.shape[2] * eps.shape[3])) - - eps_ = MLP( - n_out=self.outerprod_dim, - n_hidden=self.n_hidden_mlp, - training=training, - activation=self.activation, - )(inputs=eps) - inputs = jnp.concatenate([query_embed, eps_], axis=-1) - residual = MLP( - n_out=self.out_dim, - n_hidden=self.n_hidden_mlp, 
- n_layers=self.n_layers_mlp, - training=training, - activation=self.activation, - )(inputs=inputs) - return residual diff --git a/src/scvi/external/mrvi_jax/_model.py b/src/scvi/external/mrvi_jax/_model.py deleted file mode 100644 index 65cac5f6fc..0000000000 --- a/src/scvi/external/mrvi_jax/_model.py +++ /dev/null @@ -1,1661 +0,0 @@ -from __future__ import annotations - -import logging -import warnings -from typing import TYPE_CHECKING - -import jax -import jax.numpy as jnp -import numpy as np -import xarray as xr -from tqdm import tqdm - -from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager, fields -from scvi.external.mrvi._types import MRVIReduction -from scvi.external.mrvi_jax._module import JaxMRVAE -from scvi.external.mrvi_jax._utils import rowwise_max_excluding_diagonal -from scvi.model.base import BaseModelClass, JaxTrainingMixin -from scvi.train._config import merge_kwargs -from scvi.utils import setup_anndata_dsp -from scvi.utils._docstrings import devices_dsp - -if TYPE_CHECKING: - from typing import Literal - - import numpy.typing as npt - from anndata import AnnData - from numpyro.distributions import Distribution - -logger = logging.getLogger(__name__) - -DEFAULT_TRAIN_KWARGS = { - "max_epochs": 100, - "early_stopping": True, - "early_stopping_patience": 15, - "check_val_every_n_epoch": 1, - "batch_size": 256, - "train_size": 0.9, - "plan_kwargs": { - "lr": 2e-3, - "n_epochs_kl_warmup": 20, - "max_norm": 40, - "eps": 1e-8, - "weight_decay": 1e-8, - }, -} - - -class JaxMRVI(JaxTrainingMixin, BaseModelClass): - """Multi-resolution Variational Inference (MrVI) :cite:p:`Boyeau24`. - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi.external.JaxMRVI.setup_anndata`. - n_latent - Dimensionality of the latent space for ``z``. - n_latent_u - Dimensionality of the latent space for ``u``. - encoder_n_hidden - Number of nodes per hidden layer in the encoder. 
- encoder_n_layers - Number of hidden layers in the encoder. - z_u_prior - Whether to use a prior for ``z_u``. - z_u_prior_scale - Scale of the prior for the difference between ``z`` and ``u``. - u_prior_scale - Scale of the prior for ``u``. - u_prior_mixture - Whether to use a mixture model for the ``u`` prior. - u_prior_mixture_k - Number of components in the mixture model for the ``u`` prior. - learn_z_u_prior_scale - Whether to learn the scale of the ``z`` and ``u`` difference prior during training. - laplace_scale - Scale parameter for the Laplace distribution in the decoder. - scale_observations - Whether to scale loss by the number of observations per sample. - px_kwargs - Keyword args for :class:`~scvi.external.mrvi_jax._module.DecoderZXAttention`. - qz_kwargs - Keyword args for :class:`~scvi.external.mrvi_jax._module.EncoderUZ`. - qu_kwargs - Keyword args for :class:`~scvi.external.mrvi_jax._module.EncoderXU`. - - Notes - ----- - This implementation of MRVI in JAX is deprecated and will be unsupported in v1.5. - Please use the torch implementation of MRVI - - See further usage examples in the following tutorial: - - 1. :doc:`/tutorials/notebooks/scrna/MrVI_tutorial` - - See the user guide for this model: - - 1. :doc:`/user_guide/models/mrvi_jax` - - See Also - -------- - :class:`~scvi.external.mrvi_jax.JaxMRVAE` - """ - - def __init__(self, adata: AnnData, **model_kwargs): - super().__init__(adata) - - warnings.warn( - "You are using the Jax Version of MrVI, which is the default. Starting v1.5, " - "This class will still be usable but won't be as actively maintained as the PyTorch " - "implementation of MRVI, which will become the default one. 
We recommend to train " - "your MRVI with torch backend by stating MRVI.setup_anndata(adata...,backend='torch'", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - n_sample = self.summary_stats.n_sample - n_batch = self.summary_stats.n_batch - n_labels = self.summary_stats.n_labels - - self.update_sample_info(adata) - self.sample_key = self.adata_manager.get_state_registry( - REGISTRY_KEYS.SAMPLE_KEY - ).original_key - self.sample_order = self.adata_manager.get_state_registry( - REGISTRY_KEYS.SAMPLE_KEY - ).categorical_mapping - - self.n_obs_per_sample = jnp.array( - adata.obs._scvi_sample.value_counts().sort_index().values - ) - self.backend = "jax" - - self.module = JaxMRVAE( - n_input=self.summary_stats.n_vars, - n_sample=n_sample, - n_batch=n_batch, - n_labels=n_labels, - n_obs_per_sample=self.n_obs_per_sample, - **model_kwargs, - ) - self.init_params_ = self._get_init_params(locals()) - - def to_device(self, device): - pass - - def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]: - return_1d = isinstance(n_sets, int) - if return_1d: - n_sets_1d = n_sets - else: - n_sets_1d = np.prod(n_sets) - rngs_list = [self.module.rngs for _ in range(n_sets_1d)] - # Combine the list of RNG dicts into a single list. This is necessary for vmap/map. - rngs = { - required_rng: jnp.concatenate( - [rngs_dict[required_rng][None] for rngs_dict in rngs_list], axis=0 - ) - for required_rng in self.module.required_rngs - } - if not return_1d: - # Reshaping the random keys to the desired shape in - # the case of multiple sets. - rngs = { - key: random_key.reshape(n_sets + random_key.shape[1:]) - for (key, random_key) in rngs.items() - } - return rngs - - @classmethod - @setup_anndata_dsp.dedent - def setup_anndata( - cls, - adata: AnnData, - layer: str | None = None, - sample_key: str | None = None, - batch_key: str | None = None, - labels_key: str | None = None, - **kwargs, - ): - """%(summary)s. 
- - Parameters - ---------- - %(param_adata)s - %(param_layer)s - %(param_sample_key)s - %(param_batch_key)s - %(param_labels_key)s - **kwargs - Additional keyword arguments passed into - :meth:`~scvi.data.AnnDataManager.register_fields`. - """ - setup_method_args = cls._get_setup_method_args(**locals()) - # Add the index for batched computation of local statistics. - adata.obs["_indices"] = np.arange(adata.n_obs).astype(int) - anndata_fields = [ - fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), - fields.CategoricalObsField(REGISTRY_KEYS.SAMPLE_KEY, sample_key), - fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), - ] - - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) - - @devices_dsp.dedent - def train( - self, - max_epochs: int | None = None, - accelerator: str | None = "auto", - devices: int | list[int] | str = "auto", - train_size: float | None = None, - validation_size: float | None = None, - batch_size: int = 128, - early_stopping: bool = False, - plan_kwargs: dict | None = None, - **trainer_kwargs, - ): - """Train the model. - - Parameters - ---------- - max_epochs - Maximum number of epochs to train the model. The actual number of epochs may be less if - early stopping is enabled. If ``None``, defaults to a heuristic based on - :func:`~scvi.model.get_max_epochs_heuristic`. - %(param_accelerator)s - %(param_devices)s - train_size - Size of the training set in the range ``[0.0, 1.0]``. - validation_size - Size of the validation set. If ``None``, defaults to ``1 - train_size``. If - ``train_size + validation_size < 1``, the remaining cells belong to a test set. - batch_size - Minibatch size to use during training. - early_stopping - Perform early stopping. 
Additional arguments can be passed in through ``**kwargs``. - See :class:`~scvi.train.Trainer` for further options. - plan_kwargs - Additional keyword arguments passed into :class:`~scvi.train.JaxTrainingPlan`. - **trainer_kwargs - Additional keyword arguments passed into :class:`~scvi.train.Trainer`. - """ - from copy import deepcopy - - train_kwargs = { - "max_epochs": max_epochs, - "accelerator": accelerator, - "devices": devices, - "train_size": train_size, - "validation_size": validation_size, - "batch_size": batch_size, - "early_stopping": early_stopping, - **trainer_kwargs, - } - train_kwargs = dict(deepcopy(DEFAULT_TRAIN_KWARGS), **train_kwargs) - plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan") - train_kwargs["plan_kwargs"] = dict( - deepcopy(DEFAULT_TRAIN_KWARGS["plan_kwargs"]), **plan_kwargs - ) - - super().train(**train_kwargs) - - def get_latent_representation( - self, - adata: AnnData | None = None, - indices: npt.ArrayLike | None = None, - batch_size: int | None = None, - use_mean: bool = True, - give_z: bool = False, - ) -> npt.NDArray: - """Compute the latent representation of the data. - - Parameters - ---------- - adata - AnnData object to use. Defaults to the AnnData object used to initialize the model. - indices - Indices of cells to use. - batch_size - Batch size to use for computing the latent representation. - use_mean - Whether to use the mean of the distribution as the latent representation. - give_z - Whether to return the z latent representation or the u latent representation. - - Returns - ------- - The latent representation of the data. 
- """ - self._check_if_trained(warn=False) - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True - ) - - us = [] - zs = [] - jit_inference_fn = self.module.get_jit_inference_fn( - inference_kwargs={"use_mean": use_mean} - ) - for array_dict in tqdm(scdl): - outputs = jit_inference_fn(self.module.rngs, array_dict) - - if give_z: - zs.append(jax.device_get(outputs["z"])) - else: - us.append(jax.device_get(outputs["u"])) - - if give_z: - return np.array(jnp.concatenate(zs, axis=0)) - else: - return np.array(jnp.concatenate(us, axis=0)) - - def compute_local_statistics( - self, - reductions: list[MRVIReduction], - adata: AnnData | None = None, - indices: npt.ArrayLike | None = None, - batch_size: int | None = None, - use_vmap: Literal["auto", True, False] = "auto", - norm: str = "l2", - mc_samples: int = 10, - ) -> xr.Dataset: - """Compute local statistics from counterfactual sample representations. - - Local statistics are reductions over either the local counterfactual latent representations - or the resulting local sample distance matrices. For a large number of cells and/or - samples, this method can avoid scalability issues by grouping over cell-level covariates. - - Parameters - ---------- - reductions - List of reductions to compute over local counterfactual sample representations. - adata - AnnData object to use. - indices - Indices of cells to use. - batch_size - Batch size to use for computing the local statistics. - use_vmap - Whether to use vmap to compute the local statistics. If "auto", vmap will be used if - the number of samples is less than 500. - norm - Norm to use for computing the distances. - mc_samples - Number of Monte Carlo samples to use for computing the local statistics. Only applies - if using sampled representations. 
- """ - from functools import partial - - from scvi.external.mrvi._types import _parse_local_statistics_requirements - - use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500 - - if not reductions or len(reductions) == 0: - raise ValueError("At least one reduction must be provided.") - - adata = self.adata if adata is None else adata - self._check_if_trained(warn=False) - # Hack to ensure new AnnDatas have indices. - adata.obs["_indices"] = np.arange(adata.n_obs).astype(int) - - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True - ) - n_sample = self.summary_stats.n_sample - - reqs = _parse_local_statistics_requirements(reductions) - - vars_in = {"params": self.module.params, **self.module.state} - - @partial(jax.jit, static_argnames=["use_mean", "mc_samples"]) - def mapped_inference_fn( - stacked_rngs: dict[str, jax.random.KeyArray], - x: jax.typing.ArrayLike, - sample_index: jax.typing.ArrayLike, - cf_sample: jax.typing.ArrayLike, - use_mean: bool, - mc_samples: int | None = None, - ): - # TODO: use `self.module.get_jit_inference_fn` when it supports traced values. 
- def inference_fn( - rngs, - cf_sample, - ): - return self.module.apply( - vars_in, - rngs=rngs, - method=self.module.inference, - x=x, - sample_index=sample_index, - cf_sample=cf_sample, - use_mean=use_mean, - mc_samples=mc_samples, - )["z"] - - if use_vmap: - return jax.vmap(inference_fn, in_axes=(0, 0), out_axes=-2)( - stacked_rngs, - cf_sample, - ) - else: - - def per_sample_inference_fn(pair): - rngs, cf_sample = pair - return inference_fn(rngs, cf_sample) - - return jax.lax.transpose( - jax.lax.map(per_sample_inference_fn, (stacked_rngs, cf_sample)), - (1, 0, 2), - ) - - ungrouped_data_arrs = {} - grouped_data_arrs = {} - for ur in reqs.ungrouped_reductions: - ungrouped_data_arrs[ur.name] = [] - for gr in reqs.grouped_reductions: - grouped_data_arrs[gr.name] = {} # Will map group category to running the group sum. - - for array_dict in tqdm(scdl): - indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() - n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)) - inf_inputs = self.module._get_inference_input( - array_dict, - ) - stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) - cell_names = adata.obs_names[indices].values - - # OK to use stacked rngs here since there is no stochasticity for mean rep. - if reqs.needs_mean_representations: - try: - mean_zs_ = mapped_inference_fn( - stacked_rngs=stacked_rngs, - x=jnp.array(inf_inputs["x"]), - sample_index=jnp.array(inf_inputs["sample_index"]), - cf_sample=jnp.array(cf_sample), - use_mean=True, - ) - except jax.errors.JaxRuntimeError as e: - if use_vmap: - raise RuntimeError( - "JAX ran out of memory. Try setting use_vmap=False." 
- ) from e - else: - raise e - - mean_zs = xr.DataArray( - np.array(mean_zs_), - dims=["cell_name", "sample", "latent_dim"], - coords={ - "cell_name": cell_names, - "sample": self.sample_order, - }, - name="sample_representations", - ) - if reqs.needs_sampled_representations: - sampled_zs_ = mapped_inference_fn( - stacked_rngs=stacked_rngs, - x=jnp.array(inf_inputs["x"]), - sample_index=jnp.array(inf_inputs["sample_index"]), - cf_sample=jnp.array(cf_sample), - use_mean=False, - mc_samples=mc_samples, - ) # (n_mc_samples, n_cells, n_samples, n_latent) - sampled_zs_ = sampled_zs_.transpose((1, 0, 2, 3)) - sampled_zs = xr.DataArray( - np.array(sampled_zs_), - dims=["cell_name", "mc_sample", "sample", "latent_dim"], - coords={ - "cell_name": cell_names, - "sample": self.sample_order, - }, - name="sample_representations", - ) - - if reqs.needs_mean_distances: - mean_dists = self._compute_distances_from_representations( - mean_zs_, cell_names, norm=norm, return_numpy=True - ) - - if reqs.needs_sampled_distances or reqs.needs_normalized_distances: - sampled_dists = self._compute_distances_from_representations( - sampled_zs_, cell_names, norm=norm, return_numpy=True - ) - - if reqs.needs_normalized_distances: - if norm != "l2": - raise ValueError( - f"Norm must be 'l2' when using normalized distances. Got {norm}." 
- ) - ( - normalization_means, - normalization_vars, - ) = self._compute_local_baseline_dists( - array_dict, mc_samples=mc_samples - ) # both are shape (n_cells,) - normalization_means = normalization_means.reshape(-1, 1, 1, 1) - normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) - normalized_dists = ( - (sampled_dists - normalization_means) / (normalization_vars**0.5) - ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) - - # Compute each reduction - for r in reductions: - if r.input == "mean_representations": - inputs = mean_zs - elif r.input == "sampled_representations": - inputs = sampled_zs - elif r.input == "mean_distances": - inputs = mean_dists - elif r.input == "sampled_distances": - inputs = sampled_dists - elif r.input == "normalized_distances": - inputs = normalized_dists - else: - raise ValueError(f"Unknown reduction input: {r.input}") - - outputs = r.fn(inputs) - - if r.group_by is not None: - group_by = adata.obs[r.group_by].iloc[indices] - group_by_cats = group_by.unique() - for cat in group_by_cats: - cat_summed_outputs = outputs.sel( - cell_name=cell_names[group_by == cat] - ).sum(dim="cell_name") - cat_summed_outputs = cat_summed_outputs.assign_coords( - {f"{r.group_by}_name": cat} - ) - if cat not in grouped_data_arrs[r.name]: - grouped_data_arrs[r.name][cat] = cat_summed_outputs - else: - grouped_data_arrs[r.name][cat] += cat_summed_outputs - else: - ungrouped_data_arrs[r.name].append(outputs) - - # Combine all outputs. 
- final_data_arrs = {} - for ur_name, ur_outputs in ungrouped_data_arrs.items(): - final_data_arrs[ur_name] = xr.concat(ur_outputs, dim="cell_name") - - for gr in reqs.grouped_reductions: - group_by = adata.obs[gr.group_by] - group_by_counts = group_by.value_counts() - averaged_grouped_data_arrs = [] - for cat, count in group_by_counts.items(): - averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count) - final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name") - final_data_arrs[gr.name] = final_data_arr - - return xr.Dataset(data_vars=final_data_arrs) - - def _compute_local_baseline_dists( - self, batch: dict, mc_samples: int = 250 - ) -> tuple[npt.NDArray, npt.NDArray]: - """ - Approximate the distributions used as baselines for normalizing the local sample distances. - - Approximates the means and variances of the Euclidean distance between two samples of - the z latent representation for the original sample for each cell in ``adata``. - - Reference: https://www.overleaf.com/read/mhdxcrknzxpm. - - Parameters - ---------- - batch - Batch of data to compute the local sample representation for. - mc_samples - Number of Monte Carlo samples to use for computing the local baseline distributions. 
- """ - mc_samples_per_cell = ( - mc_samples * 2 - ) # need double for pairs of samples to compute distance between - jit_inference_fn = self.module.get_jit_inference_fn( - inference_kwargs={"use_mean": False, "mc_samples": mc_samples_per_cell} - ) - - outputs = jit_inference_fn(self.module.rngs, batch) - - # figure out how to compute dists here - z = outputs["z"] - first_half_z, second_half_z = z[:mc_samples], z[mc_samples:] - l2_dists = jnp.sqrt(jnp.sum((first_half_z - second_half_z) ** 2, axis=2)).T - - return np.array(jnp.mean(l2_dists, axis=1)), np.array(jnp.var(l2_dists, axis=1)) - - def _compute_distances_from_representations( - self, - reps: jax.typing.ArrayLike, - cell_names: jax.typing.ArrayLike, - norm: Literal["l2", "l1", "linf"] = "l2", - return_numpy: bool = True, - ) -> xr.DataArray: - if norm not in ("l2", "l1", "linf"): - raise ValueError(f"`norm` {norm} not supported") - - @jax.jit - def _compute_distance(rep: jax.typing.ArrayLike): - delta_mat = jnp.expand_dims(rep, 0) - jnp.expand_dims(rep, 1) - if norm == "l2": - res = delta_mat**2 - res = jnp.sqrt(res.sum(-1)) - elif norm == "l1": - res = jnp.abs(delta_mat).sum(-1) - elif norm == "linf": - res = jnp.abs(delta_mat).max(-1) - return res - - if reps.ndim == 3: - dists = jax.vmap(_compute_distance)(reps) - if return_numpy: - dists = np.array(dists) - return xr.DataArray( - dists, - dims=["cell_name", "sample_x", "sample_y"], - coords={ - "cell_name": cell_names, - "sample_x": self.sample_order, - "sample_y": self.sample_order, - }, - name="sample_distances", - ) - else: - # Case with sampled representations - dists = jax.vmap(jax.vmap(_compute_distance))(reps) - if return_numpy: - dists = np.array(dists) - return xr.DataArray( - dists, - dims=["cell_name", "mc_sample", "sample_x", "sample_y"], - coords={ - "cell_name": cell_names, - "mc_sample": np.arange(reps.shape[1]), - "sample_x": self.sample_order, - "sample_y": self.sample_order, - }, - name="sample_distances", - ) - - def 
get_local_sample_representation( - self, - adata: AnnData | None = None, - indices: npt.ArrayLike | None = None, - batch_size: int = 256, - use_mean: bool = True, - use_vmap: Literal["auto", True, False] = "auto", - ) -> xr.DataArray: - """Compute the local sample representation of the cells in the ``adata`` object. - - For each cell, it returns a matrix of size ``(n_sample, n_features)``. - - Parameters - ---------- - adata - AnnData object to use for computing the local sample representation. - batch_size - Batch size to use for computing the local sample representation. - use_mean - Whether to use the mean of the latent representation as the local sample - representation. - use_vmap - Whether to use vmap for computing the local sample representation. - Disabling vmap can be useful if running out of memory on a GPU. - """ - reductions = [ - MRVIReduction( - name="sample_representations", - input="mean_representations" if use_mean else "sampled_representations", - fn=lambda x: x, - group_by=None, - ) - ] - return self.compute_local_statistics( - reductions, - adata=adata, - indices=indices, - batch_size=batch_size, - use_vmap=use_vmap, - ).sample_representations - - def get_local_sample_distances( - self, - adata: AnnData | None = None, - batch_size: int = 256, - use_mean: bool = True, - normalize_distances: bool = False, - use_vmap: Literal["auto", True, False] = "auto", - groupby: list[str] | str | None = None, - keep_cell: bool = True, - norm: str = "l2", - mc_samples: int = 10, - ) -> xr.Dataset: - """Compute local sample distances. - - Computes cell-specific distances between samples, of size ``(n_sample, n_sample)``, - stored as a Dataset, with variable name ``"cell"``, of size - ``(n_cell, n_sample, n_sample)``. If in addition, ``groupby`` is provided, distances are - also aggregated by group. In this case, the group-specific distances - via the group name key. 
- - Parameters - ---------- - adata - AnnData object to use for computing the local sample representation. - batch_size - Batch size to use for computing the local sample representation. - use_mean - Whether to use the mean of the latent representation as the local sample - representation. - normalize_distances - Whether to normalize the local sample distances. Normalizes by the standard deviation - of the original intra-sample distances. Only works with ``use_mean=False``. - use_vmap - Whether to use vmap for computing the local sample representation. Disabling vmap can - be useful if running out of memory on a GPU. - groupby - List of categorical keys or single key of the anndata that is used to group the cells. - keep_cell - Whether to keep the original cell sample-sample distance matrices. - norm - Norm to use for computing the local sample distances. - mc_samples - Number of Monte Carlo samples to use for computing the local sample distances. Only - relevant if ``use_mean=False``. - """ - use_vmap = "auto" if use_vmap == "auto" else use_vmap - - input = "mean_distances" if use_mean else "sampled_distances" - if normalize_distances: - if use_mean: - warnings.warn( - "Normalizing distances uses sampled distances. 
Ignoring ``use_mean``.", - UserWarning, - stacklevel=2, - ) - input = "normalized_distances" - if groupby and not isinstance(groupby, list): - groupby = [groupby] - - reductions = [] - if not keep_cell and not groupby: - raise ValueError("Undefined computation because not keep_cell and no groupby.") - if keep_cell: - reductions.append( - MRVIReduction( - name="cell", - input=input, - fn=lambda x: x, - ) - ) - if groupby: - for groupby_key in groupby: - reductions.append( - MRVIReduction( - name=groupby_key, - input=input, - group_by=groupby_key, - ) - ) - return self.compute_local_statistics( - reductions, - adata=adata, - batch_size=batch_size, - use_vmap=use_vmap, - norm=norm, - mc_samples=mc_samples, - ) - - def get_aggregated_posterior( - self, - adata: AnnData | None = None, - sample: str | int | None = None, - indices: npt.ArrayLike | None = None, - batch_size: int = 256, - ) -> Distribution: - """Computes the aggregated posterior over the ``u`` latent representations. - - For the specified samples, it computes the aggregated posterior over the ``u`` latent - representations. Returns a NumPyro MixtureSameFamily distribution. - - Parameters - ---------- - adata - AnnData object to use. Defaults to the AnnData object used to initialize the model. - sample - Name or index of the sample to filter on. If ``None``, uses all cells. - indices - Indices of cells to use. - batch_size - Batch size to use for computing the latent representation. - - Returns - ------- - A mixture distribution of the aggregated posterior. 
- """ - from numpyro.distributions import ( - Categorical, - MixtureSameFamily, - MultivariateNormal, - ) - - self._check_if_trained(warn=False) - adata = self._validate_anndata(adata) - if indices is None: - indices = np.arange(adata.n_obs) - if sample is not None: - indices = np.intersect1d( - np.array(indices), np.where(adata.obs[self.sample_key] == sample)[0] - ) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True - ) - - qu_locs = [] - qu_scales = [] - jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True}) - for array_dict in scdl: - outputs = jit_inference_fn(self.module.rngs, array_dict) - - qu_locs.append(outputs["qu"].loc) - qu_scales.append(outputs["qu"].scale) - - qu_loc = jnp.concatenate(qu_locs, axis=0) # n_cells x n_latent_u - qu_scale = jnp.concatenate(qu_scales, axis=0) # n_cells x n_latent_u - # Use MultivariateNormal with diagonal covariance instead of Normal.to_event(1) - # because numpyro MixtureSameFamily requires ParameterFreeConstraint support - scale_tril = jax.vmap(jnp.diag)(qu_scale) - return MixtureSameFamily( - Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]), - MultivariateNormal(qu_loc, scale_tril=scale_tril), - ) - - def differential_abundance( - self, - adata: AnnData | None = None, - sample_cov_keys: list[str] | None = None, - sample_subset: list[str] | None = None, - compute_log_enrichment: bool = False, - omit_original_sample: bool = True, - batch_size: int = 128, - ) -> xr.Dataset: - """Compute the differential abundance between samples. - - Computes the logarithm of the ratio of the probabilities of each sample conditioned on the - estimated aggregate posterior distribution of each cell. - - Parameters - ---------- - adata - The data object to compute the differential abundance for. - If not given, the data object stored in the model is used. - sample_cov_keys - Keys for covariates (batch, etc.) 
that should also be taken into account - when computing the differential abundance. At the moment, only discrete covariates are - supported. - sample_subset - Only computes differential abundance for these sample labels. - compute_log_enrichment - Whether to compute the log enrichment scores for each covariate value. - omit_original_sample - If true, each cell's sample-of-origin is discarded to compute aggregate posteriors. - Only relevant if sample_cov_keys is not None. - batch_size - Minibatch size for computing the differential abundance. - - Returns - ------- - A dataset with data variables: - - * ``"log_probs"``: Array of shape ``(n_cells, n_samples)`` containing the log probabilities - for each cell across samples. - * ``"{cov_key}_log_probs"``: For each key in ``sample_cov_keys``, an array of shape - ``(n_cells, _cov_values)`` containing the log probabilities for each cell across - covariate values. - """ - from pandas import DataFrame - from scipy.special import logsumexp - - adata = self._validate_anndata(adata) - - if sample_cov_keys is not None: - for key in sample_cov_keys: - n_cov_values = len(adata.obs[key].unique()) - n_samples = len(adata.obs[self.sample_key].unique()) - if n_cov_values > n_samples / 2: - warnings.warn( - f"The covariate '{key}' does not seem to refer to a discrete key. 
" - f"It has {n_cov_values} unique values, which exceeds one half of the " - f"total samples ({n_samples}).", - UserWarning, - stacklevel=2, - ) - - us = self.get_latent_representation( - adata, use_mean=True, give_z=False, batch_size=batch_size - ) - - log_probs = [] - unique_samples = adata.obs[self.sample_key].unique() - for sample_name in tqdm(unique_samples): - ap = self.get_aggregated_posterior( - adata=adata, sample=sample_name, batch_size=batch_size - ) - n_splits = max(adata.n_obs // batch_size, 1) - log_probs_ = [] - for u_rep in np.array_split(us, n_splits): - log_probs_.append(jax.device_get(ap.log_prob(u_rep))[..., np.newaxis]) - - log_probs.append(np.concatenate(log_probs_, axis=0)) # (n_cells, 1) - - log_probs = np.concatenate(log_probs, 1) - - coords = { - "cell_name": adata.obs_names.to_numpy(), - "sample": unique_samples, - } - data_vars = { - "log_probs": (["cell_name", "sample"], log_probs), - } - log_probs_arr = xr.Dataset(data_vars, coords=coords) - - if sample_cov_keys is None or len(sample_cov_keys) == 0: - return log_probs_arr - - def aggregate_log_probs(log_probs, samples, omit_original_sample=False): - sample_log_probs = log_probs.loc[ - {"sample": samples} - ].values # (n_cells, n_samples_in_group) - if omit_original_sample: - sample_one_hot = np.zeros((adata.n_obs, len(samples))) - for i, sample in enumerate(samples): - sample_one_hot[adata.obs[self.sample_key] == sample, i] = 1 - log_probs_no_original = np.where( - sample_one_hot, -np.inf, sample_log_probs - ) # virtually discards samples-of-origin from aggregate posteriors - return logsumexp(log_probs_no_original, axis=1) - np.log( - (1 - sample_one_hot).sum(axis=1) - ) - else: - return logsumexp(sample_log_probs, axis=1) - np.log(sample_log_probs.shape[1]) - - sample_cov_log_probs_map = {} - sample_cov_log_enrichs_map = {} - for sample_cov_key in sample_cov_keys: - sample_cov_unique_values = self.sample_info[sample_cov_key].unique() - per_val_log_probs = {} - per_val_log_enrichs = {} 
- for sample_cov_value in sample_cov_unique_values: - cov_samples = ( - self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value] - )[self.sample_key].to_numpy() - if sample_subset is not None: - cov_samples = np.intersect1d(cov_samples, np.array(sample_subset)) - if len(cov_samples) == 0: - continue - - val_log_probs = aggregate_log_probs( - log_probs_arr.log_probs, - cov_samples, - omit_original_sample=omit_original_sample, - ) - per_val_log_probs[sample_cov_value] = val_log_probs - - if compute_log_enrichment: - rest_samples = np.setdiff1d(unique_samples, cov_samples) - if len(rest_samples) == 0: - warnings.warn( - f"All samples have {sample_cov_key}={sample_cov_value}. Skipping log " - "enrichment computation.", - UserWarning, - stacklevel=2, - ) - continue - rest_val_log_probs = aggregate_log_probs( - log_probs_arr.log_probs, - rest_samples, - omit_original_sample=omit_original_sample, - ) - enrichment_scores = val_log_probs - rest_val_log_probs - per_val_log_enrichs[sample_cov_value] = enrichment_scores - sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs) - if compute_log_enrichment and len(per_val_log_enrichs) > 0: - sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict( - per_val_log_enrichs - ) - - coords = { - "cell_name": adata.obs_names.to_numpy(), - "sample": unique_samples, - **{ - sample_cov_key: sample_cov_log_probs.columns - for sample_cov_key, sample_cov_log_probs in sample_cov_log_probs_map.items() - }, - } - data_vars = { - "log_probs": (["cell_name", "sample"], log_probs), - **{ - f"{sample_cov_key}_log_probs": ( - ["cell_name", sample_cov_key], - sample_cov_log_probs.values, - ) - for sample_cov_key, sample_cov_log_probs in sample_cov_log_probs_map.items() - }, - } - if compute_log_enrichment: - data_vars.update( - { - f"{sample_key}_log_enrichs": ( - ["cell_name", sample_key], - sample_log_enrichs.values, - ) - for sample_key, sample_log_enrichs in sample_cov_log_enrichs_map.items() - } - 
def get_outlier_cell_sample_pairs(
    self,
    adata: AnnData | None = None,
    subsample_size: int | None = 5_000,
    quantile_threshold: float = 0.05,
    admissibility_threshold: float = 0.0,
    batch_size: int = 256,
) -> xr.Dataset:
    """Compute admissibility scores for cell-sample pairs.

    The mean latent representation ``u`` is computed for every cell. For each sample, an
    aggregated posterior is built from (a subsample of) that sample's cells, and every cell
    is scored by the maximum component log-probability it obtains under that aggregated
    posterior. A single global threshold is then derived from the within-sample scores
    (each cell scored against the other cells of its own sample) and used to decide
    whether a cell-sample pair is admissible.

    Parameters
    ----------
    adata
        AnnData object containing the cells for which to compute the outlier cell-sample pairs.
    subsample_size
        Number of cells to use from each sample to approximate the posterior. If None, uses all
        of the available cells.
    quantile_threshold
        Quantile of the within-sample log probabilities to use as a baseline for admissibility.
    admissibility_threshold
        Threshold for admissibility. Cell-sample pairs with admissibility below this threshold
        are considered outliers.
    batch_size
        Size of the batch to use for computing outlier cell-sample pairs.

    Returns
    -------
    A dataset indexed by ``cell_name`` and ``sample`` containing:

    * ``"log_probs"``: maximum component log-probability of each cell under each sample's
      aggregated posterior.
    * ``"log_ratios"``: ``log_probs`` minus the global within-sample threshold.
    * ``"is_admissible"``: whether ``log_ratios`` exceeds ``admissibility_threshold``.

    Notes
    -----
    The latent representation is written to ``adata.obsm["U"]`` of the validated AnnData.
    """
    adata = self._validate_anndata(adata)
    us = self.get_latent_representation(adata, use_mean=True, give_z=False)
    adata.obsm["U"] = us

    # At least one split is required: ``np.array_split`` rejects 0 sections, which would
    # otherwise happen whenever the dataset holds fewer cells than ``batch_size``.
    n_splits = max(adata.n_obs // batch_size, 1)

    log_probs = []
    threshs = []
    unique_samples = adata.obs[self.sample_key].unique()
    for sample_name in tqdm(unique_samples):
        sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0]
        if subsample_size is not None and sample_idxs.shape[0] > subsample_size:
            sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False)
        adata_s = adata[sample_idxs]

        ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)

        # Within-sample scores: each cell of the sample against the other cells of the
        # same sample (the diagonal, i.e. the cell's own component, is excluded).
        in_max_comp_log_probs = ap.component_distribution.log_prob(
            np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)  # (n_cells_ap, 1, n_latent_dim)
        )  # (n_cells_ap, n_cells_ap)
        log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs)

        # Scores of every cell in the dataset against this sample, computed in batches.
        log_probs_ = []
        for u_rep in np.array_split(adata.obsm["U"], n_splits):
            comp_log_probs = ap.component_distribution.log_prob(
                np.expand_dims(u_rep, ap.mixture_dim)  # (n_cells_batch, 1, n_latent_dim)
            )  # (n_cells_batch, n_cells_ap)
            log_probs_.append(
                jax.device_get(comp_log_probs.max(axis=1, keepdims=True))  # (n_cells_batch, 1)
            )
        log_probs_ = np.concatenate(log_probs_, axis=0)  # (n_cells, 1)

        threshs.append(np.array(log_probs_s))
        log_probs.append(np.array(log_probs_))

    # One threshold shared by all samples, taken from the pooled within-sample scores.
    threshs_all = np.concatenate(threshs)
    global_thresh = np.quantile(threshs_all, q=quantile_threshold)

    log_probs = np.concatenate(log_probs, 1)  # (n_cells, n_samples)
    log_ratios = log_probs - global_thresh

    coords = {
        "cell_name": adata.obs_names.to_numpy(),
        "sample": unique_samples,
    }
    data_vars = {
        "log_probs": (["cell_name", "sample"], log_probs),
        "log_ratios": (
            ["cell_name", "sample"],
            log_ratios,
        ),
        "is_admissible": (
            ["cell_name", "sample"],
            log_ratios > admissibility_threshold,
        ),
    }
    return xr.Dataset(data_vars, coords=coords)
differential_expression( - self, - adata: AnnData | None = None, - sample_cov_keys: list[str] | None = None, - sample_subset: list[str] | None = None, - batch_size: int = 128, - use_vmap: Literal["auto", True, False] = "auto", - normalize_design_matrix: bool = True, - add_batch_specific_offsets: bool = False, - mc_samples: int = 100, - store_lfc: bool = False, - store_lfc_metadata_subset: list[str] | None = None, - store_baseline: bool = False, - eps_lfc: float = 1e-4, - filter_inadmissible_samples: bool = False, - lambd: float = 0.0, - delta: float | None = 0.3, - **filter_samples_kwargs, - ) -> xr.Dataset: - """Compute cell-specific multivariate differential expression. - - For every cell, we first compute all counterfactual cell-state shifts, defined as - ``e_d = z_d - u``, where ``z_d`` is the latent representation of the cell for sample ``d`` - and ``u`` is the sample-unaware latent representation. Then, we fit a linear model in each - cell of the form: ``e_d = X_d * beta_d + iid gaussian noise``. - - Parameters - ---------- - sample_cov_keys - List of sample covariates to consider for the multivariate analysis. - These keys should be present in ``adata.obs``. - adata - AnnData object to use for computing the local sample representation. - If ``None``, the analysis is performed on all cells in the dataset. - sample_subset - Optional list of samples to consider for the multivariate analysis. - If ``None``, all samples are considered. - batch_size - Batch size to use for computing the local sample representation. - use_vmap - Whether to use vmap for computing the local sample representation. - normalize_design_matrix - Whether to normalize the design matrix. - add_batch_specific_offsets - Whether to offset the design matrix by adding batch-specific offsets to the design - matrix. Setting this option to True is recommended when considering multi-site - datasets. - mc_samples - How many MC samples should be taken for computing betas. 
- store_lfc - Whether to store the log-fold changes in the module. - Storing log-fold changes is memory-intensive and may require specifying - a smaller set of cells to analyze, e.g., by specifying ``adata``. - store_lfc_metadata_subset - Specifies a subset of metadata for which log-fold changes are computed. - These keys must be a subset of ``sample_cov_keys``. - Only applies when ``store_lfc=True``. - store_baseline - Whether to store the expression in the module if logfoldchanges are computed. - eps_lfc - Epsilon to add to the log-fold changes to avoid detecting genes with low expression. - filter_inadmissible_samples - Whether to filter out-of-distribution samples prior to performing the analysis. - lambd - Regularization parameter for the linear model. - delta - LFC threshold used to compute posterior DE probabilities. - If None does not compute them to save memory consumption. - filter_samples_kwargs - Keyword arguments to pass to :meth:`~scvi.external.MRVI.get_outlier_cell_sample_pairs`. - - Returns - ------- - A dataset containing the results of the differential expression analysis: - - * ``"beta"``: Coefficients for each covariate across cells and latent dimensions. - * ``"effect_size"``: Effect sizes for each covariate across cells. - * ``"pvalue"``: P-values for each covariate across cells. - * ``"padj"``: Adjusted P-values for each covariate across cells using the - Benjamini-Hochberg procedure. - * ``"lfc"``: Log fold changes for each covariate across cells and genes, if ``store_lfc`` - is ``True``. - * ``"lfc_std"``: Standard deviation of log fold changes, if ``store_lfc`` is ``True`` and - ``delta`` is not ``None``. - * ``"pde"``: Posterior DE probabilities, if ``store_lfc`` is ``True`` and ``delta`` is not - ``None``. - * ``"baseline_expression"``: Baseline expression levels for each covariate across cells and - genes, if ``store_baseline`` is ``True``. 
- * ``"n_samples"``: Number of admissible samples for each cell, if - ``filter_inadmissible_samples`` is ``True``. - """ - from functools import partial - - from scipy.stats import false_discovery_control - - use_vmap = use_vmap if use_vmap != "auto" else self.summary_stats.n_sample < 500 - - if sample_cov_keys is None: - # Hack: kept as kwarg to maintain the order of arguments. - raise ValueError("Must assign `sample_cov_keys`") - adata = self.adata if adata is None else adata - self._check_if_trained(warn=False) - # Hack to ensure new AnnDatas have indices and indices have correct dimensions. - if adata is not None: - adata.obs["_indices"] = np.arange(adata.n_obs).astype(int) - - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=None, batch_size=batch_size, iter_ndarray=True - ) - n_sample = self.summary_stats.n_sample - vars_in = {"params": self.module.params, **self.module.state} - - sample_mask = ( - np.isin(self.sample_order, sample_subset) - if sample_subset is not None - else np.ones(n_sample, dtype=bool) - ) - sample_mask = np.array(sample_mask) - sample_order = self.sample_order[sample_mask] - n_samples_kept = sample_mask.sum() - - if filter_inadmissible_samples: - admissible_samples = self.get_outlier_cell_sample_pairs( - adata=adata, **filter_samples_kwargs - )["is_admissible"].loc[{"sample": sample_order}] - assert (admissible_samples.sample == sample_order).all() - admissible_samples = admissible_samples.values - else: - admissible_samples = np.ones((adata.n_obs, n_samples_kept), dtype=bool) - n_admissible_samples = admissible_samples.sum(1) - - ( - Xmat, - Xmat_names, - covariates_require_lfc, - offset_indices, - ) = self._construct_design_matrix( - sample_cov_keys=sample_cov_keys, - sample_mask=sample_mask, - normalize_design_matrix=normalize_design_matrix, - add_batch_specific_offsets=add_batch_specific_offsets, - store_lfc=store_lfc, - store_lfc_metadata_subset=store_lfc_metadata_subset, - ) - 
add_batch_specific_offsets = offset_indices is not None - n_covariates = Xmat.shape[1] - - @partial(jax.jit, backend="cpu") - def process_design_matrix( - admissible_samples_dmat: jax.typing.ArrayLike, - Xmat: jax.typing.ArrayLike, - ) -> tuple[jax.Array, jax.Array]: - xtmx = jnp.einsum("ak,nkl,lm->nam", Xmat.T, admissible_samples_dmat, Xmat) - xtmx += lambd * jnp.eye(n_covariates) - - prefactor = jnp.real(jax.vmap(jax.scipy.linalg.sqrtm)(xtmx)) - inv_ = jax.vmap(jnp.linalg.pinv)(xtmx) - Amat = jnp.einsum("nab,bc,ncd->nad", inv_, Xmat.T, admissible_samples_dmat) - return Amat, prefactor - - @partial(jax.jit, static_argnames=["use_mean", "mc_samples"]) - def mapped_inference_fn( - stacked_rngs: dict[str, jax.random.KeyArray], - x: jax.typing.ArrayLike, - sample_index: jax.typing.ArrayLike, - cf_sample: jax.typing.ArrayLike, - Amat: jax.typing.ArrayLike, - prefactor: jax.typing.ArrayLike, - n_samples_per_cell: int, - admissible_samples_mat: jax.typing.ArrayLike, - use_mean: bool, - mc_samples: int, - rngs_de=None, - ): - def inference_fn( - rngs, - cf_sample, - ): - return self.module.apply( - vars_in, - rngs=rngs, - method=self.module.inference, - x=x, - sample_index=sample_index, - cf_sample=cf_sample, - use_mean=use_mean, - mc_samples=mc_samples, - )["eps"] - - if use_vmap: - eps_ = jax.vmap(inference_fn, in_axes=(0, 0), out_axes=-2)( - stacked_rngs, - cf_sample, - ) - else: - - def per_sample_inference_fn(pair): - rngs, cf_sample = pair - return inference_fn(rngs, cf_sample) - - # eps_ has shape (mc_samples, n_cells, n_samples, n_latent) - eps_ = jax.lax.transpose( - jax.lax.map(per_sample_inference_fn, (stacked_rngs, cf_sample)), - (1, 2, 0, 3), - ) - eps_std = eps_.std(axis=2, keepdims=True) - eps_mean = eps_.mean(axis=2, keepdims=True) - - eps = (eps_ - eps_mean) / (1e-6 + eps_std) # over samples - # MLE for betas - betas = jnp.einsum("nks,ansd->ankd", Amat, eps) - - # Statistical tests - betas_norm = jnp.einsum("ankd,nkl->anld", betas, prefactor) - ts = 
(betas_norm**2).mean(axis=0).sum(axis=-1) - pvals = 1 - jnp.nan_to_num( - jax.scipy.stats.chi2.cdf(ts, df=n_samples_per_cell[:, None]), nan=0.0 - ) - - betas = betas * eps_std - - lfc_mean = None - lfc_std = None - pde = None - if store_lfc: - betas_ = betas.transpose((0, 2, 1, 3)) - eps_mean_ = eps_mean.transpose((0, 2, 1, 3)) - betas_covariates = betas_[:, covariates_require_lfc, :, :] - - def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): - extra_eps += batch_offset_eps - - return self.module.apply( - vars_in, - rngs=rngs_de, - method=self.module.compute_h_from_x_eps, - x=x, - extra_eps=extra_eps, - sample_index=sample_index, - batch_index=batch_index_cf, - cf_sample=None, - mc_samples=None, # mc_samples also taken for eps. vmap over mc_samples - ) - - batch_index_ = jnp.arange(self.summary_stats.n_batch)[:, None] - batch_index_ = jnp.repeat(batch_index_, repeats=n_cells, axis=1)[ - ..., None - ] # (n_batch, n_cells, 1) - betas_null = jnp.zeros_like(betas_covariates) - - if add_batch_specific_offsets: - batch_weights = jnp.einsum( - "nd,db->nb", admissible_samples_mat, Xmat[:, offset_indices] - ).mean(0) - betas_offset_ = betas_[:, offset_indices, :, :] + eps_mean_ - else: - batch_weights = (1.0 / self.summary_stats.n_batch) * jnp.ones( - self.summary_stats.n_batch - ) - mc_samples, _, n_cells_, n_latent = betas_covariates.shape - betas_offset_ = ( - jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)) - + eps_mean_ - ) - # batch_offset shape (mc_samples, n_batch, n_cells, n_latent) - - f_ = jax.vmap( - h_inference_fn, in_axes=(0, None, 0), out_axes=0 - ) # fn over MC samples - f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates - f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches - h_fn = jax.jit(f_) - - x_1 = h_fn(betas_covariates, batch_index_, betas_offset_) - x_0 = h_fn(betas_null, batch_index_, betas_offset_) - - lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc) - lfc_mean = 
jnp.average(lfcs.mean(1), weights=batch_weights, axis=0) - if delta is not None: - lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0)) - pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0) - - if store_baseline: - baseline_expression = x_1.mean(1) - else: - baseline_expression = None - return { - "beta": betas.mean(0), - "effect_size": ts, - "pvalue": pvals, - "lfc_mean": lfc_mean, - "lfc_std": lfc_std, - "pde": pde, - "baseline_expression": baseline_expression, - } - - beta = [] - effect_size = [] - pvalue = [] - lfc = [] - lfc_std = [] - pde = [] - baseline_expression = [] - for array_dict in tqdm(scdl): - indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() - n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to( - (np.where(sample_mask)[0])[:, None, None], (n_samples_kept, n_cells, 1) - ) - inf_inputs = self.module._get_inference_input( - array_dict, - ) - stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) - - rngs_de = self.module.rngs if store_lfc else None - admissible_samples_mat = jnp.array(admissible_samples[indices]) - n_samples_per_cell = admissible_samples_mat.sum(axis=1) - admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype( - float - ) # (n_cells, n_samples, n_samples) - # element nij is 1 if sample i is admissible and i=j for cell n - Amat, prefactor = process_design_matrix(admissible_samples_dmat, Xmat) - Amat = jax.device_put(Amat, self.device) - prefactor = jax.device_put(prefactor, self.device) - - try: - res = mapped_inference_fn( - stacked_rngs=stacked_rngs, - x=jnp.array(inf_inputs["x"]), - sample_index=jnp.array(inf_inputs["sample_index"]), - cf_sample=jnp.array(cf_sample), - Amat=Amat, - prefactor=prefactor, - n_samples_per_cell=n_samples_per_cell, - admissible_samples_mat=admissible_samples_mat, - use_mean=False, - rngs_de=rngs_de, - mc_samples=mc_samples, - ) - except jax.errors.JaxRuntimeError as e: - if use_vmap: - raise RuntimeError("JAX 
ran out of memory. Try setting use_vmap=False.") from e - else: - raise e - - beta.append(np.array(res["beta"])) - effect_size.append(np.array(res["effect_size"])) - pvalue.append(np.array(res["pvalue"])) - if store_lfc: - lfc.append(np.array(res["lfc_mean"])) - if delta is not None: - lfc_std.append(np.array(res["lfc_std"])) - pde.append(np.array(res["pde"])) - if store_baseline: - baseline_expression.append(np.array(res["baseline_expression"])) - beta = np.concatenate(beta, axis=0) - effect_size = np.concatenate(effect_size, axis=0) - pvalue = np.concatenate(pvalue, axis=0) - pvalue_shape = pvalue.shape - padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape) - - coords = { - "cell_name": (("cell_name"), adata.obs_names), - "covariate": (("covariate"), Xmat_names), - "latent_dim": (("latent_dim"), np.arange(beta.shape[2])), - "gene": (("gene"), adata.var_names), - } - data_vars = { - "beta": ( - ["cell_name", "covariate", "latent_dim"], - beta, - ), - "effect_size": ( - ["cell_name", "covariate"], - effect_size, - ), - "pvalue": ( - ["cell_name", "covariate"], - pvalue, - ), - "padj": ( - ["cell_name", "covariate"], - padj, - ), - } - if filter_inadmissible_samples: - data_vars["n_samples"] = ( - ["cell_name"], - n_admissible_samples, - ) - if store_lfc: - if store_lfc_metadata_subset is None and not add_batch_specific_offsets: - coords_lfc = ["covariate", "cell_name", "gene"] - else: - coords_lfc = ["covariate_sub", "cell_name", "gene"] - coords["covariate_sub"] = ( - ("covariate_sub"), - Xmat_names[covariates_require_lfc], - ) - lfc = np.concatenate(lfc, axis=1) - data_vars["lfc"] = (coords_lfc, lfc) - if delta is not None: - lfc_std = np.concatenate(lfc_std, axis=1) - pde = np.concatenate(pde, axis=1) - data_vars["lfc_std"] = (coords_lfc, lfc_std) - data_vars["pde"] = (coords_lfc, pde) - - if store_baseline: - baseline_expression = np.concatenate(baseline_expression, axis=1) - data_vars["baseline_expression"] = ( - ["covariate", 
def _construct_design_matrix(
    self,
    sample_cov_keys: list[str],
    sample_mask: npt.ArrayLike,
    normalize_design_matrix: bool,
    add_batch_specific_offsets: bool,
    store_lfc: bool,
    store_lfc_metadata_subset: list[str] | None = None,
) -> tuple[jax.Array, npt.NDArray, jax.Array, jax.Array | None]:
    """Construct a design matrix of samples and covariates.

    Starting from a list of sample covariate keys, construct a design matrix of samples and
    covariates. Categorical and string covariates are one-hot encoded (dropping the first
    level); all other covariates are used as a single numeric column.

    Parameters
    ----------
    sample_cov_keys
        List of sample metadata to use as covariates.
    sample_mask
        Mask of admissible samples. Must have the same length as the number of samples in the
        dataset.
    normalize_design_matrix
        Whether the design matrix should be 0-1 normalized. This is useful to ensure that the
        beta coefficients are comparable across covariates.
    add_batch_specific_offsets
        Whether the design matrix should be offset. If True, the matrix includes batch-specific
        offsets. This ensures that we can learn perturbation effects that do not depend on
        batch effects.
    store_lfc
        Whether log-fold changes will be computed. If False, no column is flagged as
        requiring LFCs.
    store_lfc_metadata_subset
        Subset of ``sample_cov_keys`` whose columns require LFCs. If None, every column
        derived from ``sample_cov_keys`` is flagged. Only used when ``store_lfc`` is True.

    Returns
    -------
    A tuple consisting of:

    1. The design matrix
    2. Names for each column in the design matrix
    3. A mask specifying which coefficients from the design matrix require computing LFCs.
    4. Indices of the columns of the design matrix that correspond to batch offsets, or
       None if no offsets were added.
    """
    from pandas import CategoricalDtype, Series, get_dummies
    from pandas.api.types import is_string_dtype

    Xmat = []
    Xmat_names = []
    Xmat_dim_to_key = []
    sample_info = self.sample_info.iloc[sample_mask]
    for sample_cov_key in tqdm(sample_cov_keys):
        cov = sample_info[sample_cov_key]
        is_categorical = isinstance(cov.dtype, CategoricalDtype)
        if is_categorical or is_string_dtype(cov):
            if not is_categorical:
                # The ``.cat`` accessor only exists on categorical columns, so string
                # columns must be converted before unused categories can be dropped.
                cov = cov.astype("category")
            cov = cov.cat.remove_unused_categories()
            cov = get_dummies(cov, drop_first=True)
            cov_names = np.array([f"{sample_cov_key}_{col}" for col in cov.columns])
            cov = cov.values
        else:
            cov_names = np.array([sample_cov_key])
            cov = cov.values[:, None]
        n_covs = cov.shape[1]
        Xmat.append(cov)
        Xmat_names.append(cov_names)
        # Remember which covariate key produced each column of the design matrix.
        Xmat_dim_to_key.append([sample_cov_key] * n_covs)
    Xmat_names = np.concatenate(Xmat_names)
    Xmat = np.concatenate(Xmat, axis=1).astype(np.float32)
    Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key)

    if normalize_design_matrix:
        # The small constant avoids a division by zero for constant columns.
        Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0))

    offset_indices = None
    if add_batch_specific_offsets:
        cov = sample_info["_scvi_batch"]
        if cov.nunique() == self.summary_stats.n_batch:
            # One-hot encode the batch of every sample and prepend it to the design matrix.
            cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values]
            cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)]
            Xmat = np.concatenate([cov, Xmat], axis=1)
            Xmat_names = np.concatenate([np.array(cov_names), Xmat_names])
            Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key])

            # Retrieve indices of offset covariates in the right order
            offset_indices = (
                Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values
            )
            offset_indices = jnp.array(offset_indices)
        else:
            warnings.warn(
                """
                Number of batches in sample_info does not match number of batches in
                summary_stats. `add_batch_specific_offsets=True` assumes that samples are
                not shared across batches. Setting `add_batch_specific_offsets=False`...
                """,
                stacklevel=2,
            )

    Xmat = jnp.array(Xmat)
    if store_lfc:
        lfc_keys = (
            store_lfc_metadata_subset if store_lfc_metadata_subset is not None else sample_cov_keys
        )
        # Offset columns are keyed by their own names, so they are never selected here.
        covariates_require_lfc = np.isin(Xmat_dim_to_key, lfc_keys)
    else:
        covariates_require_lfc = np.zeros(len(Xmat_names), dtype=bool)
    covariates_require_lfc = jnp.array(covariates_require_lfc)

    return Xmat, Xmat_names, covariates_require_lfc, offset_indices
class DecoderZXAttention(nn.Module):
    """Attention-based decoder.

    Parameters
    ----------
    n_in
        Number of input features.
    n_out
        Number of output features.
    n_batch
        Number of batches.
    n_latent_sample
        Number of latent samples.
    h_activation
        Activation function for the output layer.
    n_channels
        Number of channels in the attention block.
    n_heads
        Number of heads in the attention block.
    dropout_rate
        Dropout rate.
    stop_gradients
        Whether to stop gradients to ``z``.
    stop_gradients_mlp
        Whether to stop gradients to the MLP in the attention block.
    training
        Whether the model is in training mode. Note that ``__call__`` forwards only its own
        ``training`` argument to the attention block; this field is not read there.
    n_hidden
        Number of hidden units in the MLP.
    n_layers
        Number of layers in the MLP.
    low_dim_batch
        Whether to use low-dimensional batch embeddings.
    activation
        Activation function for the MLP.
    """

    n_in: int
    n_out: int
    n_batch: int
    n_latent_sample: int = 16
    h_activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.softmax
    n_channels: int = 4
    n_heads: int = 2
    dropout_rate: float = 0.1
    stop_gradients: bool = False
    stop_gradients_mlp: bool = False
    # Declared once, at its original position, so the dataclass field order is unchanged.
    training: bool | None = None
    n_hidden: int = 32
    n_layers: int = 1
    low_dim_batch: bool = True
    activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.gelu

    @nn.compact
    def __call__(
        self,
        z: jax.typing.ArrayLike,
        batch_covariate: jax.typing.ArrayLike,
        size_factor: jax.typing.ArrayLike,
        training: bool | None = None,
    ) -> NegativeBinomial:
        """Decode ``z`` into a negative binomial distribution over the output features.

        Parameters
        ----------
        z
            Latent representation. A 3-dimensional input is treated as carrying a leading
            Monte Carlo sample axis.
        batch_covariate
            Batch index of every cell; cast to ``int`` and flattened.
        size_factor
            Multiplicative factor applied to the activated decoder output to obtain the mean.
        training
            Forwarded to the attention block.

        Returns
        -------
        Negative binomial distribution whose mean is ``h_activation(mu) * size_factor`` and
        whose inverse dispersion is the exponential of the learned ``px_r`` parameter.
        """
        has_mc_samples = z.ndim == 3
        z_stop = z if not self.stop_gradients else jax.lax.stop_gradient(z)
        z_ = nn.LayerNorm(name="u_ln")(z_stop)

        batch_covariate = batch_covariate.astype(int).flatten()

        if self.n_batch >= 2:
            batch_embed = nn.Embed(
                self.n_batch, self.n_latent_sample, embedding_init=_normal_initializer
            )(batch_covariate)  # (batch, n_latent_sample)
            batch_embed = nn.LayerNorm(name="batch_embed_ln")(batch_embed)
            if has_mc_samples:
                # Repeat the batch embedding along the Monte Carlo sample axis.
                batch_embed = jnp.tile(batch_embed, (z_.shape[0], 1, 1))

            # The residual is added either before (low-dim) or after the output projection.
            res_dim = self.n_in if self.low_dim_batch else self.n_out

            residual = AttentionBlock(
                query_dim=self.n_in,
                out_dim=res_dim,
                outerprod_dim=self.n_latent_sample,
                n_channels=self.n_channels,
                n_heads=self.n_heads,
                dropout_rate=self.dropout_rate,
                n_hidden_mlp=self.n_hidden,
                n_layers_mlp=self.n_layers,
                stop_gradients_mlp=self.stop_gradients_mlp,
                training=training,
                activation=self.activation,
            )(query_embed=z_, kv_embed=batch_embed)

            if self.low_dim_batch:
                mu = nn.Dense(self.n_out)(z + residual)
            else:
                mu = nn.Dense(self.n_out)(z) + residual
        else:
            # With fewer than two batches there is no batch effect to model.
            mu = nn.Dense(self.n_out)(z_)
        mu = self.h_activation(mu)
        return NegativeBinomial(
            mean=mu * size_factor,
            inverse_dispersion=jnp.exp(self.param("px_r", jax.random.normal, (self.n_out,))),
        )
class EncoderXU(nn.Module):
    """Encoder mapping observations ``x`` to the distribution of ``u``.

    The input is ``log1p``-transformed and passed through two blocks of dense layer,
    sample-conditional normalization and activation. A learned per-sample embedding is
    added to the resulting features before they are fed to the output network.

    Parameters
    ----------
    n_latent
        Number of latent variables.
    n_sample
        Number of samples.
    n_hidden
        Number of hidden units in the MLP.
    n_layers
        Number of layers in the MLP.
    activation
        Activation function for the MLP.
    training
        Whether the model is in training mode.
    """

    n_latent: int
    n_sample: int
    n_hidden: int
    n_layers: int = 1
    activation: Callable[[jax.typing.ArrayLike], jax.Array] = nn.gelu
    training: bool | None = None

    @nn.compact
    def __call__(
        self,
        x: jax.typing.ArrayLike,
        sample_covariate: jax.typing.ArrayLike,
        training: bool | None = None,
    ) -> dist.Normal:
        from scvi.external.mrvi_jax._components import (
            ConditionalNormalization,
            NormalDistOutputNN,
        )

        # Resolve the training flag from the module attribute and the call argument.
        is_training = nn.merge_param("training", self.training, training)

        hidden = jnp.log1p(x)
        for _ in range(2):
            projected = Dense(self.n_hidden)(hidden)
            normalized = ConditionalNormalization(self.n_hidden, self.n_sample)(
                projected, sample_covariate, training=is_training
            )
            hidden = self.activation(normalized)

        # Per-sample additive effect, looked up from the integer sample index.
        sample_ids = sample_covariate.squeeze(-1).astype(int)
        sample_embedding = nn.Embed(
            self.n_sample, self.n_hidden, embedding_init=_normal_initializer
        )(sample_ids)

        output_net = NormalDistOutputNN(self.n_latent, self.n_hidden, self.n_layers)
        return output_net(hidden + sample_embedding, training=is_training)
- u_prior_mixture - Whether to use a mixture of Gaussians prior for ``u``. - u_prior_mixture_k - Number of mixture components to use for the mixture of Gaussians prior on ``u``. - learn_z_u_prior_scale - Whether to learn the scale parameter of the prior distribution of ``z`` given ``u``. - scale_observations - Whether to scale the loss associated with each observation by the total number of - observations linked to the associated sample. - px_kwargs - Keyword arguments for the generative model. - qz_kwargs - Keyword arguments for the inference model from ``u`` to ``z``. - qu_kwargs - Keyword arguments for the inference model from ``x`` to ``u``. - training - Whether the model is in training mode. - n_obs_per_sample - Number of observations per sample. - """ - - n_input: int - n_sample: int - n_batch: int - n_labels: int - n_latent: int = 30 - n_latent_u: int = 10 - encoder_n_hidden: int = 128 - encoder_n_layers: int = 2 - z_u_prior: bool = True - z_u_prior_scale: float = 0.0 - u_prior_scale: float = 0.0 - u_prior_mixture: bool = True - u_prior_mixture_k: int = 20 - learn_z_u_prior_scale: bool = False - scale_observations: bool = False - px_kwargs: dict | None = None - qz_kwargs: dict | None = None - qu_kwargs: dict | None = None - training: bool = True - n_obs_per_sample: jax.typing.ArrayLike | None = None - - def setup(self): - px_kwargs = DEFAULT_PX_KWARGS.copy() - if self.px_kwargs is not None: - px_kwargs.update(self.px_kwargs) - - qz_kwargs = DEFAULT_QZ_ATTENTION_KWARGS.copy() - if self.qz_kwargs is not None: - qz_kwargs.update(self.qz_kwargs) - - qu_kwargs = DEFAULT_QU_KWARGS.copy() - if self.qu_kwargs is not None: - qu_kwargs.update(self.qu_kwargs) - - is_isomorphic_uz = self.n_latent == self.n_latent_u - n_latent_u = None if is_isomorphic_uz else self.n_latent_u - - if self.n_latent < self.n_latent_u: - warnings.warn( - "The number of latent variables for `z` is set to less than the number of latent " - "variables for `u`.", - UserWarning, - 
stacklevel=settings.warnings_stacklevel, - ) - - # Generative model - px_cls = DecoderZXAttention - self.px = px_cls( - self.n_latent, - self.n_input, - self.n_batch, - **px_kwargs, - ) - - qz_cls = EncoderUZ - self.qz = qz_cls( - self.n_latent, - self.n_sample, - n_latent_u=n_latent_u, - **qz_kwargs, - ) - - # Inference model - self.qu = EncoderXU( - n_latent=self.n_latent if is_isomorphic_uz else n_latent_u, - n_sample=self.n_sample, - n_hidden=self.encoder_n_hidden, - n_layers=self.encoder_n_layers, - **qu_kwargs, - ) - self.backend = "jax" - - if self.learn_z_u_prior_scale: - self.pz_scale = self.param("pz_scale", nn.initializers.zeros, (self.n_latent,)) - else: - self.pz_scale = self.z_u_prior_scale - - if self.u_prior_mixture: - if self.n_labels > 1: - u_prior_mixture_k = self.n_labels - else: - u_prior_mixture_k = self.u_prior_mixture_k - u_dim = self.n_latent_u if self.n_latent_u is not None else self.n_latent - self.u_prior_logits = self.param( - "u_prior_logits", nn.initializers.zeros, (u_prior_mixture_k,) - ) - self.u_prior_means = self.param( - "u_prior_means", jax.random.normal, (u_prior_mixture_k, u_dim) - ) - self.u_prior_scales = self.param( - "u_prior_scales", nn.initializers.zeros, (u_prior_mixture_k, u_dim) - ) - - @property - def required_rngs(self): - return ("params", "u", "dropout", "eps") - - def _get_inference_input(self, tensors: dict[str, jax.typing.ArrayLike]) -> dict[str, Any]: - x = tensors[REGISTRY_KEYS.X_KEY] - sample_index = tensors[REGISTRY_KEYS.SAMPLE_KEY] - return {"x": x, "sample_index": sample_index} - - def inference( - self, - x: jax.typing.ArrayLike, - sample_index: jax.typing.ArrayLike, - mc_samples: int | None = None, - cf_sample: jax.typing.ArrayLike | None = None, - use_mean: bool = False, - ) -> dict[str, jax.Array | dist.Distribution]: - """Latent variable inference.""" - qu = self.qu(x, sample_index, training=self.training) - if use_mean: - u = qu.mean - else: - u_rng = self.make_rng("u") - sample_shape = 
(mc_samples,) if mc_samples is not None else () - u = qu.rsample(u_rng, sample_shape=sample_shape) - - sample_index_cf = sample_index if cf_sample is None else cf_sample - - z_base, eps = self.qz(u, sample_index_cf, training=self.training) - qeps_ = eps - - qeps = None - if qeps_.shape[-1] == 2 * self.n_latent: - loc_, scale_ = qeps_[..., : self.n_latent], qeps_[..., self.n_latent :] - qeps = dist.Normal(loc_, nn.softplus(scale_) + 1e-3) - eps = qeps.mean if use_mean else qeps.rsample(self.make_rng("eps")) - z = z_base + eps - library = jnp.log(x.sum(1, keepdims=True)) - - return { - "qu": qu, - "qeps": qeps, - "eps": eps, - "u": u, - "z": z, - "z_base": z_base, - "library": library, - } - - def _get_generative_input( - self, - tensors: dict[str, jax.typing.ArrayLike], - inference_outputs: dict[str, jax.Array | dist.Distribution], - ) -> dict[str, jax.Array]: - z = inference_outputs["z"] - library = inference_outputs["library"] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - label_index = tensors[REGISTRY_KEYS.LABELS_KEY] - return { - "z": z, - "library": library, - "batch_index": batch_index, - "label_index": label_index, - } - - def generative( - self, - z: jax.typing.ArrayLike, - library: jax.typing.ArrayLike, - batch_index: jax.typing.ArrayLike, - label_index: jax.typing.ArrayLike, - ) -> dict[str, jax.Array | dist.Distribution]: - """Generative model.""" - library_exp = jnp.exp(library) - px = self.px( - z, - batch_index, - size_factor=library_exp, - training=self.training, - ) - h = px.mean / library_exp - - if self.u_prior_mixture: - offset = ( - 10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0 - ) - cats = dist.Categorical(logits=self.u_prior_logits + offset) - # Use MultivariateNormal with diagonal covariance instead of Normal.to_event(1) - # because numpyro MixtureSameFamily requires ParameterFreeConstraint support - scales = jnp.exp(self.u_prior_scales) - scale_tril = jax.vmap(jnp.diag)(scales) - normal_dists = 
dist.MultivariateNormal(self.u_prior_means, scale_tril=scale_tril) - pu = dist.MixtureSameFamily(cats, normal_dists) - else: - pu = dist.Normal(0, jnp.exp(self.u_prior_scale)) - return {"px": px, "pu": pu, "h": h} - - def loss( - self, - tensors: dict[str, jax.typing.ArrayLike], - inference_outputs: dict[str, jax.Array | dist.Distribution], - generative_outputs: dict[str, jax.Array | dist.Distribution], - kl_weight: float = 1.0, - ) -> LossOutput: - """Compute the loss function value.""" - reconstruction_loss = ( - -generative_outputs["px"].log_prob(tensors[REGISTRY_KEYS.X_KEY]).sum(-1) - ) - - if self.u_prior_mixture: - kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]).sum( - -1 - ) - generative_outputs["pu"].log_prob(inference_outputs["u"]) - else: - kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1) - - kl_z = 0.0 - eps = inference_outputs["z"] - inference_outputs["z_base"] - if self.z_u_prior: - peps = dist.Normal(0, jnp.exp(self.pz_scale)) - kl_z = -peps.log_prob(eps).sum(-1) - - weighted_kl_local = kl_weight * (kl_u + kl_z) - loss = reconstruction_loss + weighted_kl_local - - if self.scale_observations: - sample_index = tensors[REGISTRY_KEYS.SAMPLE_KEY].flatten().astype(int) - prefactors = self.n_obs_per_sample[sample_index] - loss = loss / prefactors - - loss = jnp.mean(loss) - - return LossOutput( - loss=loss, - reconstruction_loss=reconstruction_loss, - kl_local=(kl_u + kl_z), - ) - - def compute_h_from_x_eps( - self, - x: jax.typing.ArrayLike, - sample_index: jax.typing.ArrayLike, - batch_index: jax.typing.ArrayLike, - extra_eps: float, - cf_sample: jax.typing.ArrayLike | None = None, - mc_samples: int = 10, - ): - """Compute normalized gene expression from observations using predefined eps""" - library = 7.0 * jnp.ones_like(sample_index) # placeholder has no effect on the value of h. 
- inference_outputs = self.inference( - x, sample_index, mc_samples=mc_samples, cf_sample=cf_sample, use_mean=False - ) - generative_inputs = { - "z": inference_outputs["z_base"] + extra_eps, - "library": library, - "batch_index": batch_index, - "label_index": jnp.zeros([x.shape[0], 1]), - } - generative_outputs = self.generative(**generative_inputs) - return generative_outputs["h"] diff --git a/src/scvi/external/mrvi_jax/_utils.py b/src/scvi/external/mrvi_jax/_utils.py deleted file mode 100644 index bad170400f..0000000000 --- a/src/scvi/external/mrvi_jax/_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - -from jax import jit - -if TYPE_CHECKING: - from jax import Array - from jax.typing import ArrayLike - - -@jit -def rowwise_max_excluding_diagonal(matrix: ArrayLike) -> Array: - """Get the rowwise maximum of a matrix excluding the diagonal.""" - import jax.numpy as jnp - - assert matrix.ndim == 2 - num_cols = matrix.shape[1] - mask = (1 - jnp.eye(num_cols)).astype(bool) - return (jnp.where(mask, matrix, -jnp.inf)).max(axis=1) - - -def simple_reciprocal(w: ArrayLike, eps: float = 1e-6) -> Array: - """Convert distances to similarities via a reciprocal.""" - return 1.0 / (w + eps) diff --git a/src/scvi/external/mrvi_torch/_model.py b/src/scvi/external/mrvi_torch/_model.py index 2e72999d30..69b7ed9845 100644 --- a/src/scvi/external/mrvi_torch/_model.py +++ b/src/scvi/external/mrvi_torch/_model.py @@ -110,7 +110,8 @@ class TorchMRVI( Notes ----- This implementation of MRVI is in PyTorch. - This will become the default version in v1.5 for MRVI. + This will become the default version in v1.4.3 for MRVI. + The JAX version is deprecated starting v1.5. See further usage examples in the following tutorial: @@ -118,7 +119,7 @@ class TorchMRVI( See the user guide for this model: - 1. :doc:`/user_guide/models/mrvi_jax` + 1. 
:doc:`/user_guide/models/mrvi` See Also -------- @@ -131,8 +132,9 @@ def __init__(self, adata: AnnData | None = None, registry: dict | None = None, * super().__init__(adata, registry) warnings.warn( - "You are using the Torch Version of MrVI, starting v1.5, This will become the " - "default version of MrVI instead of the Jax backend version.", + "You are using the Torch Version of MrVI, starting v1.4.3, " + "This will become the default version of MrVI instead of the Jax backend version, " + "which will be removed in v1.5", DeprecationWarning, stacklevel=settings.warnings_stacklevel, ) @@ -1700,7 +1702,7 @@ def update_sample_info(self, adata): -------- >>> import scanpy as sc >>> from scvi.external import MRVI - >>> MRVI.setup_anndata(adata, sample_key="sample_id", backend="torch") + >>> MRVI.setup_anndata(adata, sample_key="sample_id") >>> model = MRVI(adata) >>> model.train() >>> # Update sample info with new covariates diff --git a/src/scvi/external/tangram/__init__.py b/src/scvi/external/tangram/__init__.py deleted file mode 100644 index 476250d7c1..0000000000 --- a/src/scvi/external/tangram/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from scvi.utils import error_on_missing_dependencies - -error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - -from ._model import Tangram # noqa: E402 -from ._module import TangramMapper # noqa: E402 - -__all__ = ["Tangram", "TangramMapper"] diff --git a/src/scvi/external/tangram/_model.py b/src/scvi/external/tangram/_model.py deleted file mode 100644 index 32173cb208..0000000000 --- a/src/scvi/external/tangram/_model.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -import logging -import warnings -from typing import TYPE_CHECKING - -import flax -import jax -import numpy as np -import pandas as pd -import scipy -from anndata import AnnData - -from scvi import settings -from scvi.data import AnnDataManager, AnnDataManagerValidationCheck, fields -from scvi.external.tangram._module 
import TANGRAM_REGISTRY_KEYS, TangramMapper -from scvi.model._utils import parse_device_args -from scvi.model.base import BaseModelClass -from scvi.train import JaxTrainingPlan -from scvi.train._config import merge_kwargs -from scvi.utils import setup_anndata_dsp, track -from scvi.utils._docstrings import devices_dsp - -if TYPE_CHECKING: - from typing import Literal - - import jax.numpy as jnp - from jaxlib.xla_extension import Device - from mudata import MuData - -logger = logging.getLogger(__name__) - - -def _asarray(x: np.ndarray, device: Device) -> jnp.ndarray: - return jax.device_put(x, device=device) - - -class Tangram(BaseModelClass): - """Reimplementation of Tangram :cite:p:`Biancalani21`. - - Maps single-cell RNA-seq data to spatial data. Original implementation: - https://github.com/broadinstitute/Tangram. - - Currently the "cells" and "constrained" modes are implemented. - - Parameters - ---------- - mdata - MuData object that has been registered via :meth:`~scvi.external.Tangram.setup_mudata`. - constrained - Whether to use the constrained version of Tangram instead of cells mode. - target_count - The number of cells to be filtered. Necessary when `constrained` is True. 
- **model_kwargs - Keyword args for :class:`~scvi.external.tangram.TangramMapper` - - Examples - -------- - >>> from scvi.external import Tangram - >>> ad_sc = anndata.read_h5ad(path_to_sc_anndata) - >>> ad_sp = anndata.read_h5ad(path_to_sp_anndata) - >>> markers = pd.read_csv(path_to_markers, index_col=0) # genes to use for mapping - >>> mdata = mudata.MuData( - { - "sp_full": ad_sp, - "sc_full": ad_sc, - "sp": ad_sp[:, markers].copy(), - "sc": ad_sc[:, markers].copy() - } - ) - >>> modalities = {"density_prior_key": "sp", "sc_layer": "sc", "sp_layer": "sp"} - >>> Tangram.setup_mudata( - mdata, density_prior_key="rna_count_based_density", modalities=modalities - ) - >>> tangram = Tangram(sc_adata) - >>> tangram.train() - >>> ad_sc.obsm["tangram_mapper"] = tangram.get_mapper_matrix() - >>> ad_sp.obsm["tangram_cts"] = tangram.project_cell_annotations( - ad_sc, ad_sp, ad_sc.obsm["tangram_mapper"], ad_sc.obs["labels"] - ) - >>> projected_ad_sp = tangram.project_genes(ad_sc, ad_sp, ad_sc.obsm["tangram_mapper"]) - - Notes - ----- - See further usage examples in the following tutorials: - - 1. :doc:`/tutorials/notebooks/spatial/tangram_scvi_tools` - """ - - def __init__( - self, - sc_adata: AnnData, - constrained: bool = False, - target_count: int | None = None, - **model_kwargs, - ): - warnings.warn( - "Tangram is a spatial transcriptomics model that will be moved to the " - "scvi-tools spatial companion package `scviva-tools` starting in scvi-tools v1.5 and " - "will change its backend to pyTorch over Jax. 
" - "It will be deprecated from scvi-tools in v1.6.", - FutureWarning, - stacklevel=settings.warnings_stacklevel, - ) - super().__init__(sc_adata) - self.n_obs_sc = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.SC_KEY).shape[0] - self.n_obs_sp = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.SP_KEY).shape[0] - - if constrained and target_count is None: - raise ValueError("Please specify `target_count` when using constrained Tangram.") - has_density_prior = not self.adata_manager.fields[-1].is_empty - if has_density_prior: - prior = self.adata_manager.get_from_registry(TANGRAM_REGISTRY_KEYS.DENSITY_KEY) - if np.abs(prior.ravel().sum() - 1) > 1e-3: - raise ValueError("Density prior must sum to 1. Please normalize the prior.") - - self.module = TangramMapper( - n_obs_sc=self.n_obs_sc, - n_obs_sp=self.n_obs_sp, - lambda_d=1.0 if has_density_prior else 0.0, - constrained=constrained, - target_count=target_count, - **model_kwargs, - ) - self._model_summary_string = ( - f"TangramMapper Model with params: \nn_obs_sc: {self.n_obs_sc}, " - "n_obs_sp: {self.n_obs_sp}" - ) - self.init_params_ = self._get_init_params(locals()) - - def get_mapper_matrix(self) -> np.ndarray: - """Return the mapping matrix. - - Returns - ------- - Mapping matrix of shape (n_obs_sp, n_obs_sc) - """ - return jax.device_get(jax.nn.softmax(self.module.params["mapper_unconstrained"], axis=1)) - - @devices_dsp.dedent - def train( - self, - max_epochs: int = 1000, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - lr: float = 0.1, - plan_kwargs: dict | None = None, - ): - """Train the model. - - Parameters - ---------- - max_epochs - Number of passes through the dataset. - %(param_accelerator)s - %(param_devices)s - lr - Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). - Specifying optimiser via plan_kwargs overrides this choice of lr. - plan_kwargs - Keyword args for :class:`~scvi.train.JaxTrainingPlan`. 
Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - """ - update_dict = { - "optim_kwargs": { - "learning_rate": lr, - "eps": 1e-8, - "weight_decay": 0, - } - } - plan_kwargs = merge_kwargs(None, plan_kwargs, name="plan") - plan_kwargs.update(update_dict) - - _, _, device = parse_device_args( - accelerator, - devices, - return_device="jax", - validate_single_device=True, - ) - try: - self.module.to(device) - logger.info( - f"Jax module moved to {device}." - "Note: Pytorch lightning will show GPU is not being used for the Trainer." - ) - except RuntimeError: - logger.debug("No GPU available to Jax.") - - tensor_dict = self._get_tensor_dict(device=device) - training_plan = JaxTrainingPlan(self.module, **plan_kwargs) - module_init = self.module.init(self.module.rngs, tensor_dict) - state, params = flax.core.pop(module_init, "params") - training_plan.set_train_state(params, state) - train_step_fn = JaxTrainingPlan.jit_training_step - pbar = track(range(max_epochs), style="tqdm", description="Training") - history = pd.DataFrame(index=np.arange(max_epochs), columns=["loss"]) - for i in pbar: - self.module.train_state, loss, _ = train_step_fn( - self.module.train_state, tensor_dict, self.module.rngs - ) - loss = jax.device_get(loss) - history.iloc[i] = loss - pbar.set_description(f"Training... Loss: {loss}") - self.history_ = {} - self.history_["loss"] = history - self.module.eval() - - @classmethod - @setup_anndata_dsp.dedent - def setup_mudata( - cls, - mdata: MuData, - density_prior_key: str | Literal["rna_count_based", "uniform"] | None = "rna_count_based", - sc_layer: str | None = None, - sp_layer: str | None = None, - modalities: dict[str, str] | None = None, - **kwargs, - ): - """%(summary)s. - - Parameters - ---------- - mdata - MuData with scRNA and spatial modalities. - sc_layer - Layer key in scRNA modality to use for training. - sp_layer - Layer key in spatial modality to use for training. 
- density_prior_key - Key in spatial modality obs for density prior. - modalities - Mapping from `setup_mudata` param name to modality in mdata. - """ - setup_method_args = cls._get_setup_method_args(**locals()) - - if modalities is None: - raise ValueError("Modalities cannot be None.") - modalities = cls._create_modalities_attr_dict(modalities, setup_method_args) - - mudata_fields = [ - fields.MuDataLayerField( - TANGRAM_REGISTRY_KEYS.SC_KEY, - sc_layer, - mod_key=modalities.sc_layer, - is_count_data=False, - mod_required=True, - ), - fields.MuDataLayerField( - TANGRAM_REGISTRY_KEYS.SP_KEY, - sp_layer, - mod_key=modalities.sp_layer, - is_count_data=False, - mod_required=True, - ), - fields.MuDataNumericalObsField( - TANGRAM_REGISTRY_KEYS.DENSITY_KEY, - density_prior_key, - mod_key=modalities.density_prior_key, - required=False, - mod_required=True, - ), - ] - adata_manager = AnnDataManager( - fields=mudata_fields, - setup_method_args=setup_method_args, - validation_checks=AnnDataManagerValidationCheck(check_fully_paired_mudata=False), - ) - adata_manager.register_fields(mdata, **kwargs) - sc_state = adata_manager.get_state_registry(TANGRAM_REGISTRY_KEYS.SC_KEY) - sp_state = adata_manager.get_state_registry(TANGRAM_REGISTRY_KEYS.SP_KEY) - # Need to access the underlying AnnData field to get these attributes - if not ( - pd.Index(sc_state[fields.LayerField.COLUMN_NAMES_KEY]).equals( - sp_state[fields.LayerField.COLUMN_NAMES_KEY] - ), - ): - raise ValueError( - "The column names of the spatial and single-cell layers must be the same." - ) - cls.register_manager(adata_manager) - - @classmethod - def setup_anndata(cls): - """Not implemented, use `setup_mudata`.""" - raise NotImplementedError("Use `setup_mudata` to setup a MuData object for training.") - - def _get_tensor_dict( - self, - device: Device, - ) -> dict[str, jnp.ndarray]: - """Get training data for Tangram model. - - Tangram does not minibatch, so we just make a dictionary of - jnp arrays here. 
- """ - tensor_dict = {} - for key in TANGRAM_REGISTRY_KEYS: - try: - tensor_dict[key] = self.adata_manager.get_from_registry(key) - # When density is missing - except KeyError: - continue - if scipy.sparse.issparse(tensor_dict[key]): - tensor_dict[key] = tensor_dict[key].toarray() - elif isinstance(tensor_dict[key], pd.DataFrame): - tensor_dict[key] = tensor_dict[key].values - else: - tensor_dict[key] = tensor_dict[key] - tensor_dict[key] = _asarray(tensor_dict[key], device=device) - - return tensor_dict - - @staticmethod - def project_cell_annotations( - adata_sc: AnnData, adata_sp: AnnData, mapper: np.ndarray, labels: pd.Series - ) -> pd.DataFrame: - """Project cell annotations to spatial data. - - Parameters - ---------- - adata_sc - AnnData object with single-cell data. - adata_sp - AnnData object with spatial data. - mapper - Mapping from single-cell to spatial data. - labels - Cell annotations to project. - - Returns - ------- - Projected annotations as a :class:`pd.DataFrame` with shape (n_sp, n_labels). - """ - if len(labels) != adata_sc.shape[0]: - raise ValueError( - "The number of labels must match the number of cells in the sc AnnData object." - ) - cell_type_df = pd.get_dummies(labels) - projection = mapper.T @ cell_type_df.values - return pd.DataFrame( - index=adata_sp.obs_names, columns=cell_type_df.columns, data=projection - ) - - @staticmethod - def project_genes(adata_sc: AnnData, adata_sp: AnnData, mapper: np.ndarray) -> AnnData: - """Project gene expression to spatial data. - - Parameters - ---------- - adata_sc - AnnData object with single-cell data. - adata_sp - AnnData object with spatial data. - mapper - Mapping from single-cell to spatial data. - - Returns - ------- - :class:`anndata.AnnData` object with projected gene expression. 
- """ - adata_ge = AnnData( - X=mapper.T @ adata_sc.X, - obs=adata_sp.obs, - var=adata_sc.var, - uns=adata_sc.uns, - ) - return adata_ge diff --git a/src/scvi/external/tangram/_module.py b/src/scvi/external/tangram/_module.py deleted file mode 100644 index a2c16da3ff..0000000000 --- a/src/scvi/external/tangram/_module.py +++ /dev/null @@ -1,160 +0,0 @@ -from typing import NamedTuple - -import jax -import jax.numpy as jnp - -from scvi.module._jaxvae import LossOutput -from scvi.module.base import JaxBaseModuleClass, flax_configure - - -class _TANGRAM_REGISTRY_KEYS_NT(NamedTuple): - SC_KEY: str = "X" - SP_KEY: str = "Y" - DENSITY_KEY: str = "DENSITY" - - -TANGRAM_REGISTRY_KEYS = _TANGRAM_REGISTRY_KEYS_NT() - -EPS = 1e-8 - - -def _cosine_similarity_vectors(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - return jnp.dot(x, y) / (jnp.maximum(jnp.linalg.norm(x) * jnp.linalg.norm(y), EPS)) - - -def _density_criterion(log_y_pred: jnp.ndarray, y_true: jnp.ndarray) -> jnp.ndarray: - # Kl divergence between the predicted and true distributions - log_y_true = jnp.log(y_true + EPS) - return (y_true * (log_y_true - log_y_pred)).sum() - - -@flax_configure -class TangramMapper(JaxBaseModuleClass): - """Tangram Mapper Model.""" - - n_obs_sc: int - n_obs_sp: int - lambda_g1: float = 1.0 - lambda_d: float = 0.0 - lambda_g2: float = 0.0 - lambda_r: float = 0.0 - lambda_count: float = 1.0 - lambda_f_reg: float = 1.0 - constrained: bool = False - target_count: int | None = None - training: bool = True - - def setup(self): - """Setup model.""" - self.mapper_unconstrained = self.param( - "mapper_unconstrained", - lambda rng, shape: jax.random.normal(rng, shape), - (self.n_obs_sc, self.n_obs_sp), - ) - - if self.constrained: - self.filter_unconstrained = self.param( - "filter_unconstrained", - lambda rng, shape: jax.random.normal(rng, shape), - (self.n_obs_sc, 1), - ) - - @property - def required_rngs(self): - return ("params",) - - def _get_inference_input(self, tensors: dict[str, 
jnp.ndarray]): - """Get input for inference.""" - return {} - - def inference(self) -> dict: - """Run inference model.""" - return {} - - def _get_generative_input( - self, - tensors: dict[str, jnp.ndarray], - inference_outputs: dict[str, jnp.ndarray], - ): - return {} - - def generative(self) -> dict: - """No generative model here.""" - return {} - - def loss( - self, - tensors, - inference_outputs, - generative_outputs, - ): - """Compute loss.""" - sp = tensors[TANGRAM_REGISTRY_KEYS.SP_KEY] - sc = tensors[TANGRAM_REGISTRY_KEYS.SC_KEY] - mapper = jax.nn.softmax(self.mapper_unconstrained, axis=1) - - if self.constrained: - filter = jax.nn.sigmoid(self.filter_unconstrained) - mapper_filtered = mapper * filter - - if self.lambda_d > 0: - density = tensors[TANGRAM_REGISTRY_KEYS.DENSITY_KEY].ravel() - if self.constrained: - d_pred = jnp.log(mapper_filtered.sum(axis=0) / (filter.sum())) - else: - d_pred = jnp.log(mapper.sum(axis=0) / mapper.shape[0]) - density_term = self.lambda_d * _density_criterion(d_pred, density) - else: - density_term = 0 - - if self.constrained: - sc = sc * filter - - g_pred = mapper.transpose() @ sc - - # Expression term - if self.lambda_g1 > 0: - cosine_similarity_0 = jax.vmap(_cosine_similarity_vectors, in_axes=1) - gv_term = self.lambda_g1 * cosine_similarity_0(sp, g_pred).mean() - else: - gv_term = 0 - if self.lambda_g2 > 0: - cosine_similarity_1 = jax.vmap(_cosine_similarity_vectors, in_axes=0) - vg_term = self.lambda_g1 * cosine_similarity_1(sp, g_pred).mean() - vg_term = self.lambda_g2 * vg_term - else: - vg_term = 0 - - expression_term = gv_term + vg_term - - # Regularization terms - if self.lambda_r > 0: - regularizer_term = self.lambda_r * (jnp.log(mapper) * mapper).sum() - else: - regularizer_term = 0 - - if self.lambda_count > 0 and self.constrained: - if self.target_count is None: - raise ValueError("target_count must be set if in constrained mode.") - count_term = self.lambda_count * jnp.abs(filter.sum() - self.target_count) - 
else: - count_term = 0 - - if self.lambda_f_reg > 0 and self.constrained: - f_reg_t = filter - jnp.square(filter) - f_reg = self.lambda_f_reg * f_reg_t.sum() - else: - f_reg = 0 - - # Total loss - total_loss = -expression_term - regularizer_term + count_term + f_reg - total_loss = total_loss + density_term - - return LossOutput( - loss=total_loss, - n_obs_minibatch=sp.shape[0], - extra_metrics={ - "expression_term": expression_term, - "regularizer_term": regularizer_term, - }, - ) diff --git a/src/scvi/model/__init__.py b/src/scvi/model/__init__.py index 7957627c64..6ded5bddcc 100644 --- a/src/scvi/model/__init__.py +++ b/src/scvi/model/__init__.py @@ -37,17 +37,6 @@ def __getattr__(name: str): only when object is actually requested. """ - if name == "JaxSCVI": - warnings.warn( - "In order to use the Jax version of SCVI make sure to install scvi-tools[jax]", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._jaxscvi import JaxSCVI as _JaxSCVI - - return _JaxSCVI if name == "mlxSCVI": warnings.warn( "In order to use the MLX version of SCVI make sure to install mlx", diff --git a/src/scvi/model/_jaxscvi.py b/src/scvi/model/_jaxscvi.py deleted file mode 100644 index 70c6e267a0..0000000000 --- a/src/scvi/model/_jaxscvi.py +++ /dev/null @@ -1,194 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -import jax.numpy as jnp - -from scvi import REGISTRY_KEYS -from scvi.data import AnnDataManager -from scvi.data.fields import ( - CategoricalJointObsField, - CategoricalObsField, - LayerField, - NumericalJointObsField, -) -from scvi.module import JaxVAE -from scvi.utils import setup_anndata_dsp - -from .base import BaseModelClass, JaxTrainingMixin - -if TYPE_CHECKING: - from collections.abc import Sequence - from typing import Literal - - import numpy as np - from anndata import AnnData - -logger = logging.getLogger(__name__) 
- - -class JaxSCVI(JaxTrainingMixin, BaseModelClass): - """single-cell Variational Inference :cite:p:`Lopez18`, but with JAX. - - This implementation is in a very experimental state. API is completely subject to change. - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi.model.JaxSCVI.setup_anndata`. - n_hidden - Number of nodes per hidden layer. - n_latent - Dimensionality of the latent space. - dropout_rate - Dropout rate for neural networks. - gene_likelihood - One of: - - * ``'nb'`` - Negative binomial distribution - * ``'poisson'`` - Poisson distribution - **model_kwargs - Keyword args for :class:`~scvi.module.JaxVAE` - - Examples - -------- - >>> adata = anndata.read_h5ad(path_to_anndata) - >>> scvi.model.JaxSCVI.setup_anndata(adata, batch_key="batch") - >>> vae = scvi.model.JaxSCVI(adata) - >>> vae.train() - >>> adata.obsm["X_scVI"] = vae.get_latent_representation() - """ - - _module_cls = JaxVAE - - def __init__( - self, - adata: AnnData, - n_hidden: int = 128, - n_latent: int = 10, - dropout_rate: float = 0.1, - gene_likelihood: Literal["nb", "poisson"] = "nb", - **model_kwargs, - ): - super().__init__(adata) - - n_batch = self.summary_stats.n_batch - n_cats_per_cov = ( - self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) - - self.module = self._module_cls( - n_input=self.summary_stats.n_vars, - n_batch=n_batch, - n_hidden=n_hidden, - n_latent=n_latent, - dropout_rate=dropout_rate, - gene_likelihood=gene_likelihood, - n_continuous_cov=self.summary_stats.get("n_extra_continuous_covs", 0), - n_cats_per_cov=tuple(n_cats_per_cov) if n_cats_per_cov is not None else (), - **model_kwargs, - ) - - self._model_summary_string = "" - self.init_params_ = self._get_init_params(locals()) - - @classmethod - @setup_anndata_dsp.dedent - def setup_anndata( - cls, - adata: AnnData, - layer: str | None = None, - 
batch_key: str | None = None, - labels_key: str | None = None, - categorical_covariate_keys: list[str] | None = None, - continuous_covariate_keys: list[str] | None = None, - **kwargs, - ): - """%(summary)s. - - Parameters - ---------- - %(param_adata)s - %(param_layer)s - %(param_batch_key)s - %(param_labels_key)s - %(param_cat_cov_keys)s - %(param_cont_cov_keys)s - """ - setup_method_args = cls._get_setup_method_args(**locals()) - anndata_fields = [ - LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), - CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys), - NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, continuous_covariate_keys), - ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) - adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) - - def get_latent_representation( - self, - adata: AnnData | None = None, - indices: Sequence[int] | None = None, - give_mean: bool = True, - n_samples: int = 1, - batch_size: int | None = None, - ) -> np.ndarray: - r"""Return the latent representation for each cell. - - This is denoted as :math:`z_n` in our manuscripts. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - give_mean - Whether to return the mean of the posterior distribution or a sample. - n_samples - Number of samples to use for computing the latent representation. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 
- - Returns - ------- - latent_representation : np.ndarray - Low-dimensional representation for each cell - """ - self._check_if_trained(warn=False) - - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size, iter_ndarray=True - ) - - jit_inference_fn = self.module.get_jit_inference_fn( - inference_kwargs={"n_samples": n_samples} - ) - latent = [] - for array_dict in scdl: - out = jit_inference_fn(self.module.rngs, array_dict) - if give_mean: - z = out["qz"].mean - else: - z = out["z"] - latent.append(z) - concat_axis = 0 if ((n_samples == 1) or give_mean) else 1 - latent = jnp.concatenate(latent, axis=concat_axis) - - return self.module.as_numpy_array(latent) - - def to_device(self, device): - """Move model to device. No-op for JAX models (device placement is handled by JAX).""" - pass - - @property - def device(self): - """The current device that the module's params are on.""" - return self.module.device diff --git a/src/scvi/model/_utils.py b/src/scvi/model/_utils.py index 385bb91b9a..83c4e43e06 100644 --- a/src/scvi/model/_utils.py +++ b/src/scvi/model/_utils.py @@ -76,7 +76,7 @@ def get_max_epochs_heuristic( def parse_device_args( accelerator: str = "auto", devices: int | list[int] | str = "auto", - return_device: Literal["torch", "jax"] | None = None, + return_device: Literal["torch"] | None = None, validate_single_device: bool = False, ): """Parses device-related arguments. 
@@ -88,9 +88,9 @@ def parse_device_args( %(param_return_device)s %(param_validate_single_device)s """ - valid = [None, "torch", "jax"] + valid = [None, "torch"] if return_device not in valid: - return ValueError(f"`return_device` must be one of {valid}") + raise ValueError(f"`return_device` must be one of {valid}") _validate_single_device = validate_single_device and devices != "auto" cond1 = isinstance(devices, list) and len(devices) > 1 @@ -124,7 +124,7 @@ def parse_device_args( ) elif _accelerator == "mps" and accelerator != "auto": warnings.warn( - "`accelerator` has been set to `mps`. Please note that not all PyTorch/Jax " + "`accelerator` has been set to `mps`. Please note that not all PyTorch " "operations are supported with this backend. as a result, some models might be slower " "and less accurate than usual. Please verify your analysis!" "Refer to https://github.com/pytorch/pytorch/issues/77764 for more details.", @@ -168,18 +168,8 @@ def parse_device_args( else: device = torch.device(f"{_accelerator}:{device_idx}") return _accelerator, _devices, device - elif return_device == "jax" and is_package_installed("jax"): - import jax - - device = jax.devices("cpu")[0] - if _accelerator != "cpu": - if _accelerator == "mps": - device = jax.devices("METAL")[device_idx] # MPS-JAX - else: - device = jax.devices(_accelerator)[device_idx] - return _accelerator, _devices, device else: - raise ImportError("Please install jax to use this functionality.") + raise ImportError("Please install gpu support to use this functionality.") def scrna_raw_counts_properties( diff --git a/src/scvi/model/base/__init__.py b/src/scvi/model/base/__init__.py index be74ad191c..f35209ff13 100644 --- a/src/scvi/model/base/__init__.py +++ b/src/scvi/model/base/__init__.py @@ -44,17 +44,6 @@ def __getattr__(name: str): only when object is actually requested. 
""" - if name == "JaxTrainingMixin": - warnings.warn( - "In order to use the JaxTrainingMixin make sure to install scvi-tools[jax]", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._jaxmixin import JaxTrainingMixin as _JaxTrainingMixin - - return _JaxTrainingMixin if name == "MlxTrainingMixin": warnings.warn( "In order to use the MlxTrainingMixin make sure to install mlx", diff --git a/src/scvi/model/base/_jaxmixin.py b/src/scvi/model/base/_jaxmixin.py deleted file mode 100644 index 4dc64b6ee8..0000000000 --- a/src/scvi/model/base/_jaxmixin.py +++ /dev/null @@ -1,175 +0,0 @@ -from __future__ import annotations - -import logging -import warnings -from typing import TYPE_CHECKING - -from scvi import settings -from scvi.dataloaders import DataSplitter -from scvi.model._utils import get_max_epochs_heuristic, parse_device_args -from scvi.train import JaxModuleInit, JaxTrainingPlan, TrainRunner -from scvi.train._config import merge_kwargs -from scvi.utils._docstrings import devices_dsp - -if TYPE_CHECKING: - from scvi.train._config import KwargsLike - -logger = logging.getLogger(__name__) - - -class JaxTrainingMixin: - """General purpose train method for Jax-backed modules.""" - - _data_splitter_cls = DataSplitter - _training_plan_cls = JaxTrainingPlan - _train_runner_cls = TrainRunner - - @staticmethod - def _resolve_jax_n_devices(accelerator: str, devices) -> int: - """Resolve the number of JAX devices to use based on user arguments.""" - import jax - - if accelerator == "cpu": - return 1 - available = jax.local_device_count() - if devices == "auto" or devices == -1: - return available - if isinstance(devices, int): - return min(devices, available) - if isinstance(devices, (list, tuple)): - return min(len(devices), available) - if isinstance(devices, str) and devices.isdigit(): - return min(int(devices), available) - return 1 - - @devices_dsp.dedent - def train( 
- self, - max_epochs: int | None = None, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - train_size: float | None = None, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - datasplitter_kwargs: dict | None = None, - plan_config: KwargsLike | None = None, - plan_kwargs: KwargsLike | None = None, - trainer_config: KwargsLike | None = None, - **trainer_kwargs, - ): - """Train the model. - - Parameters - ---------- - max_epochs - Number of passes through the dataset. If `None`, defaults to - `np.min([round((20000 / n_cells) * 400), 400])` - %(param_accelerator)s - %(param_devices)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - lr - Learning rate to use during training. - datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.JaxTrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - plan_config - Configuration object or mapping used to build :class:`~scvi.train.JaxTrainingPlan`. - Values in ``plan_kwargs`` and explicit arguments take precedence. - trainer_config - Configuration object or mapping used to build :class:`~scvi.train.Trainer`. Values in - ``trainer_kwargs`` and explicit arguments take precedence. - **trainer_kwargs - Other keyword args for :class:`~scvi.train.Trainer`. 
- """ - if max_epochs is None: - max_epochs = get_max_epochs_heuristic(self.adata.n_obs) - - n_devices = self._resolve_jax_n_devices(accelerator, devices) - - _, _, device = parse_device_args( - accelerator, - devices, - return_device="jax", - validate_single_device=False, - ) - try: - self.module.to(device) - logger.info( - f"Jax module moved to {device}." - "Note: Pytorch lightning will show GPU is not being used for the Trainer." - ) - except RuntimeError: - logger.debug("No GPU available to Jax.") - - if n_devices > 1: - logger.info(f"JAX multi-GPU training with {n_devices} devices.") - - datasplitter_kwargs = datasplitter_kwargs or {} - - # For multi-GPU, ensure batch size is divisible by n_devices - effective_batch_size = batch_size or settings.batch_size - if n_devices > 1 and effective_batch_size % n_devices != 0: - effective_batch_size = (effective_batch_size // n_devices + 1) * n_devices - logger.info( - f"Adjusted batch size to {effective_batch_size} for even sharding " - f"across {n_devices} devices." - ) - - # For multi-GPU, drop incomplete training batches to avoid sharding issues - if n_devices > 1: - datasplitter_kwargs.setdefault("drop_last", True) - - data_splitter = self._data_splitter_cls( - self.adata_manager, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - batch_size=effective_batch_size, - iter_ndarray=True, - **datasplitter_kwargs, - ) - plan_kwargs = merge_kwargs(plan_config, plan_kwargs, name="plan") - - self.training_plan = self._training_plan_cls(self.module, **plan_kwargs) - self.training_plan.n_devices = n_devices - if "callbacks" not in trainer_kwargs.keys(): - trainer_kwargs["callbacks"] = [] - trainer_kwargs["callbacks"].append(JaxModuleInit()) - - # Ignore Pytorch Lightning warnings for Jax workarounds. 
- with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, module=r"pytorch_lightning.*") - runner = self._train_runner_cls( - self, - training_plan=self.training_plan, - data_splitter=data_splitter, - max_epochs=max_epochs, - accelerator="cpu", - devices="auto", - trainer_config=trainer_config, - **trainer_kwargs, - ) - runner() - - # After training, unreplicate state if multi-GPU was used - if n_devices > 1 and self.module.train_state is not None: - from flax.jax_utils import unreplicate - - self.module.train_state = unreplicate(self.module.train_state) - self.training_plan.n_devices = 1 - - self.is_trained_ = True - self.module.eval() diff --git a/src/scvi/module/__init__.py b/src/scvi/module/__init__.py index c60d18769a..fd4b7b7a9e 100644 --- a/src/scvi/module/__init__.py +++ b/src/scvi/module/__init__.py @@ -34,17 +34,6 @@ def __getattr__(name: str): only when object is actually requested. """ - if name == "JaxVAE": - warnings.warn( - "In order to use the JaxVAE make sure to install scvi-tools[jax]", - DeprecationWarning, - stacklevel=settings.warnings_stacklevel, - ) - - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._jaxvae import JaxVAE as _JaxVAE - - return _JaxVAE if name == "MlxVAE": warnings.warn( "In order to use the MlxVAE make sure to install mlx", diff --git a/src/scvi/module/_jaxvae.py b/src/scvi/module/_jaxvae.py deleted file mode 100644 index d9b7fd380e..0000000000 --- a/src/scvi/module/_jaxvae.py +++ /dev/null @@ -1,405 +0,0 @@ -from collections.abc import Iterable -from dataclasses import field - -import flax -import jax -import jax.numpy as jnp -import numpyro.distributions as dist -from flax import linen as nn -from flax.linen.initializers import variance_scaling - -from scvi import REGISTRY_KEYS -from scvi._types import LossRecord, Tensor -from scvi.distributions import JaxNegativeBinomialMeanDisp as NegativeBinomial -from scvi.module.base import JaxBaseModuleClass, 
flax_configure - - -@flax.struct.dataclass -class LossOutput: - """Loss signature for Jax models. - - This class provides an organized way to record the model loss, as well as - the components of the ELBO. This may also be used in MLE, MAP, EM methods. - The loss is used for backpropagation during inference. The other parameters - are used for logging/early stopping during inference. - - Parameters - ---------- - loss - Tensor with loss for minibatch. Should be one dimensional with one value. - Note that loss should be in an array/tensor and not a float. - reconstruction_loss - Reconstruction loss for each observation in the minibatch. If a tensor, converted to - a dictionary with key "reconstruction_loss" and value as tensor. - kl_local - KL divergence associated with each observation in the minibatch. If a tensor, converted to - a dictionary with key "kl_local" and value as tensor. - kl_global - Global KL divergence term. Should be one dimensional with one value. If a tensor, converted - to a dictionary with key "kl_global" and value as tensor. - classification_loss - Classification loss. - logits - Logits for classification. - true_labels - True labels for classification. - extra_metrics - Additional metrics can be passed as arrays/tensors or dictionaries of - arrays/tensors. - n_obs_minibatch - Number of observations in the minibatch. If None, will be inferred from - the shape of the reconstruction_loss tensor. - - - Examples - -------- - >>> loss_output = LossOutput( - ... loss=loss, - ... reconstruction_loss=reconstruction_loss, - ... kl_local=kl_local, - ... extra_metrics={"x": scalar_tensor_x, "y": scalar_tensor_y}, - ... 
) - """ - - loss: LossRecord - reconstruction_loss: LossRecord | None = None - kl_local: LossRecord | None = None - kl_global: LossRecord | None = None - classification_loss: LossRecord | None = None - logits: Tensor | None = None - true_labels: Tensor | None = None - extra_metrics: dict[str, Tensor] | None = field(default_factory=dict) - n_obs_minibatch: int | None = None - reconstruction_loss_sum: Tensor = field(default=None) - kl_local_sum: Tensor = field(default=None) - kl_global_sum: Tensor = field(default=None) - - def __post_init__(self): - object.__setattr__(self, "loss", self.dict_sum(self.loss)) - - if self.n_obs_minibatch is None and self.reconstruction_loss is None: - raise ValueError("Must provide either n_obs_minibatch or reconstruction_loss") - - default = 0 * self.loss - if self.reconstruction_loss is None: - object.__setattr__(self, "reconstruction_loss", default) - if self.kl_local is None: - object.__setattr__(self, "kl_local", default) - if self.kl_global is None: - object.__setattr__(self, "kl_global", default) - - object.__setattr__(self, "reconstruction_loss", self._as_dict("reconstruction_loss")) - object.__setattr__(self, "kl_local", self._as_dict("kl_local")) - object.__setattr__(self, "kl_global", self._as_dict("kl_global")) - object.__setattr__( - self, - "reconstruction_loss_sum", - self.dict_sum(self.reconstruction_loss).sum(), - ) - object.__setattr__(self, "kl_local_sum", self.dict_sum(self.kl_local).sum()) - object.__setattr__(self, "kl_global_sum", self.dict_sum(self.kl_global)) - - if self.reconstruction_loss is not None and self.n_obs_minibatch is None: - rec_loss = self.reconstruction_loss - object.__setattr__(self, "n_obs_minibatch", list(rec_loss.values())[0].shape[0]) - - if self.classification_loss is not None and ( - self.logits is None or self.true_labels is None - ): - raise ValueError( - "Must provide `logits` and `true_labels` if `classification_loss` is provided." 
- ) - - @staticmethod - def dict_sum(dictionary: dict[str, Tensor] | Tensor): - """Sum over elements of a dictionary.""" - if isinstance(dictionary, dict): - return sum(dictionary.values()) - else: - return dictionary - - @property - def extra_metrics_keys(self) -> Iterable[str]: - """Keys for extra metrics.""" - return self.extra_metrics.keys() - - def _as_dict(self, attr_name: str): - attr = getattr(self, attr_name) - if isinstance(attr, dict): - return attr - else: - return {attr_name: attr} - - -class Dense(nn.Dense): - """Jax dense layer.""" - - def __init__(self, *args, **kwargs): - # scale set to reimplement pytorch init - scale = 1 / 3 - kernel_init = variance_scaling(scale, "fan_in", "uniform") - # bias init can't see input shape so don't include here - kwargs.update({"kernel_init": kernel_init}) - super().__init__(*args, **kwargs) - - -class FlaxEncoder(nn.Module): - """Encoder for Jax VAE.""" - - n_input: int - n_latent: int - n_hidden: int - dropout_rate: int - training: bool | None = None - - def setup(self): - """Setup encoder.""" - self.dense1 = Dense(self.n_hidden) - self.dense2 = Dense(self.n_hidden) - self.dense3 = Dense(self.n_latent) - self.dense4 = Dense(self.n_latent) - - self.batchnorm1 = nn.BatchNorm(momentum=0.9) - self.batchnorm2 = nn.BatchNorm(momentum=0.9) - self.dropout1 = nn.Dropout(self.dropout_rate) - self.dropout2 = nn.Dropout(self.dropout_rate) - - def __call__(self, x: jnp.ndarray, training: bool | None = None): - """Forward pass.""" - training = nn.merge_param("training", self.training, training) - is_eval = not training - - h = self.dense1(x) - h = self.batchnorm1(h, use_running_average=is_eval) - h = nn.relu(h) - h = self.dropout1(h, deterministic=is_eval) - h = self.dense2(h) - h = self.batchnorm2(h, use_running_average=is_eval) - h = nn.relu(h) - h = self.dropout2(h, deterministic=is_eval) - - mean = self.dense3(h) - log_var = self.dense4(h) - - return mean, jnp.exp(log_var) - - -class FlaxDecoder(nn.Module): - """Decoder for 
Jax VAE.""" - - n_input: int - dropout_rate: float - n_hidden: int - n_latent_input: int = 0 # Total input dim (z + covariates); 0 means use dense1 as-is - training: bool | None = None - - def setup(self): - """Setup decoder.""" - self.dense1 = Dense(self.n_hidden) - self.dense2 = Dense(self.n_hidden) - self.dense3 = Dense(self.n_hidden) - self.dense4 = Dense(self.n_hidden) - self.dense5 = Dense(self.n_input) - - self.batchnorm1 = nn.BatchNorm(momentum=0.9) - self.batchnorm2 = nn.BatchNorm(momentum=0.9) - self.dropout1 = nn.Dropout(self.dropout_rate) - self.dropout2 = nn.Dropout(self.dropout_rate) - - self.disp = self.param( - "disp", lambda rng, shape: jax.random.normal(rng, shape), (self.n_input, 1) - ) - - def __call__(self, z: jnp.ndarray, batch: jnp.ndarray, training: bool | None = None): - """Forward pass.""" - # TODO(adamgayoso): Test this - training = nn.merge_param("training", self.training, training) - is_eval = not training - - h = self.dense1(z) - h += self.dense2(batch) - - h = self.batchnorm1(h, use_running_average=is_eval) - h = nn.relu(h) - h = self.dropout1(h, deterministic=is_eval) - h = self.dense3(h) - # skip connection - h += self.dense4(batch) - h = self.batchnorm2(h, use_running_average=is_eval) - h = nn.relu(h) - h = self.dropout2(h, deterministic=is_eval) - h = self.dense5(h) - return h, self.disp.ravel() - - -@flax_configure -class JaxVAE(JaxBaseModuleClass): - """Variational autoencoder model.""" - - n_input: int - n_batch: int - n_hidden: int = 128 - n_latent: int = 30 - dropout_rate: float = 0.0 - n_layers: int = 1 - gene_likelihood: str = "nb" - eps: float = 1e-8 - training: bool = True - n_continuous_cov: int = 0 - n_cats_per_cov: tuple[int, ...] 
= () - - def setup(self): - """Setup model.""" - n_cat_cov_total = sum(self.n_cats_per_cov) - self._n_cat_cov_total = n_cat_cov_total - - n_input_encoder = self.n_input + self.n_continuous_cov + n_cat_cov_total - self.encoder = FlaxEncoder( - n_input=n_input_encoder, - n_latent=self.n_latent, - n_hidden=self.n_hidden, - dropout_rate=self.dropout_rate, - ) - - n_latent_input = self.n_latent + self.n_continuous_cov + n_cat_cov_total - self.decoder = FlaxDecoder( - n_input=self.n_input, - dropout_rate=0.0, - n_hidden=self.n_hidden, - n_latent_input=n_latent_input, - ) - - @property - def required_rngs(self): - return ("params", "dropout", "z") - - def _encode_covariates(self, cat_covs: jnp.ndarray | None) -> jnp.ndarray | None: - """One-hot encode categorical covariates and concatenate them.""" - if cat_covs is None or len(self.n_cats_per_cov) == 0: - return None - one_hots = [] - for i, n_cats in enumerate(self.n_cats_per_cov): - cat_col = cat_covs[:, i].astype(jnp.int32) - oh = jax.nn.one_hot(cat_col, n_cats) - one_hots.append(oh) - return jnp.concatenate(one_hots, axis=-1) - - def _get_inference_input(self, tensors: dict[str, jnp.ndarray]): - """Get input for inference.""" - x = tensors[REGISTRY_KEYS.X_KEY] - cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None) - cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None) - - input_dict = {"x": x, "cont_covs": cont_covs, "cat_covs": cat_covs} - return input_dict - - def inference( - self, - x: jnp.ndarray, - cont_covs: jnp.ndarray | None = None, - cat_covs: jnp.ndarray | None = None, - n_samples: int = 1, - ) -> dict: - """Run inference model.""" - encoder_input = jnp.log1p(x) - if cont_covs is not None: - encoder_input = jnp.concatenate([encoder_input, cont_covs], axis=-1) - cat_oh = self._encode_covariates(cat_covs) - if cat_oh is not None: - encoder_input = jnp.concatenate([encoder_input, cat_oh], axis=-1) - - mean, var = self.encoder(encoder_input, training=self.training) - stddev = jnp.sqrt(var) + self.eps 
- - qz = dist.Normal(mean, stddev) - z_rng = self.make_rng("z") - sample_shape = () if n_samples == 1 else (n_samples,) - z = qz.rsample(z_rng, sample_shape=sample_shape) - - return {"qz": qz, "z": z} - - def _get_generative_input( - self, - tensors: dict[str, jnp.ndarray], - inference_outputs: dict[str, jnp.ndarray], - ): - """Get input for generative model.""" - x = tensors[REGISTRY_KEYS.X_KEY] - z = inference_outputs["z"] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - cont_covs = tensors.get(REGISTRY_KEYS.CONT_COVS_KEY, None) - cat_covs = tensors.get(REGISTRY_KEYS.CAT_COVS_KEY, None) - - input_dict = { - "x": x, - "z": z, - "batch_index": batch_index, - "cont_covs": cont_covs, - "cat_covs": cat_covs, - } - return input_dict - - def generative( - self, - x, - z, - batch_index, - cont_covs: jnp.ndarray | None = None, - cat_covs: jnp.ndarray | None = None, - ) -> dict: - """Run generative model.""" - # one hot adds an extra dimension - batch = jax.nn.one_hot(batch_index, self.n_batch).squeeze(-2) - - # When n_samples > 1, z has shape (n_samples, batch, latent). - # cont_covs, cat_oh, and batch are 2D (batch, dim) and must be - # broadcast to (n_samples, batch, dim) before concatenation. 
- n_samples = z.shape[0] if z.ndim == 3 else None - if n_samples is not None: - batch = jnp.broadcast_to(batch[jnp.newaxis], (n_samples, *batch.shape)) - - decoder_input = z - if cont_covs is not None: - if n_samples is not None: - cont_covs = jnp.broadcast_to(cont_covs[jnp.newaxis], (n_samples, *cont_covs.shape)) - decoder_input = jnp.concatenate([decoder_input, cont_covs], axis=-1) - cat_oh = self._encode_covariates(cat_covs) - if cat_oh is not None: - if n_samples is not None: - cat_oh = jnp.broadcast_to(cat_oh[jnp.newaxis], (n_samples, *cat_oh.shape)) - decoder_input = jnp.concatenate([decoder_input, cat_oh], axis=-1) - rho_unnorm, disp = self.decoder(decoder_input, batch, training=self.training) - disp_ = jnp.exp(disp) - rho = jax.nn.softmax(rho_unnorm, axis=-1) - total_count = x.sum(-1)[:, jnp.newaxis] - mu = total_count * rho - - if self.gene_likelihood == "nb": - disp_ = jnp.exp(disp) - px = NegativeBinomial(mean=mu, inverse_dispersion=disp_) - else: - px = dist.Poisson(mu) - - return {"px": px, "rho": rho} - - def loss( - self, - tensors, - inference_outputs, - generative_outputs, - kl_weight: float = 1.0, - ): - """Compute loss.""" - x = tensors[REGISTRY_KEYS.X_KEY] - px = generative_outputs["px"] - qz = inference_outputs["qz"] - reconst_loss = -px.log_prob(x).sum(-1) - kl_divergence_z = dist.kl_divergence(qz, dist.Normal(0, 1)).sum(-1) - - kl_local_for_warmup = kl_divergence_z - weighted_kl_local = kl_weight * kl_local_for_warmup - - loss = jnp.mean(reconst_loss + weighted_kl_local) - - kl_local = kl_divergence_z - return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local) diff --git a/src/scvi/module/base/__init__.py b/src/scvi/module/base/__init__.py index 2533938519..f91449c19e 100644 --- a/src/scvi/module/base/__init__.py +++ b/src/scvi/module/base/__init__.py @@ -23,26 +23,3 @@ "MogPrior", "VampPrior", ] - - -def __getattr__(name: str): - """Lazily provide object. 
If optional deps are missing, raise a helpful ImportError - - only when object is actually requested. - """ - if name == "flax_configure": - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._decorators import flax_configure as _flax_configure - - return _flax_configure - if name == "JaxBaseModuleClass": - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._base_module import JaxBaseModuleClass as _JaxBaseModuleClass - - return _JaxBaseModuleClass - if name == "TrainStateWithState": - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._base_module import TrainStateWithState as _TrainStateWithState - - return _TrainStateWithState - raise AttributeError(f"module {__name__!r} has no attribute {name}") diff --git a/src/scvi/module/base/_base_module.py b/src/scvi/module/base/_base_module.py index c3906ccec8..69b2862e31 100644 --- a/src/scvi/module/base/_base_module.py +++ b/src/scvi/module/base/_base_module.py @@ -4,22 +4,19 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING -import numpy as np import pyro import torch from torch import nn from torch.nn import functional as F -from scvi import REGISTRY_KEYS, settings +from scvi import REGISTRY_KEYS from scvi.data import _constants -from scvi.utils import is_package_installed from ._decorators import auto_move_data from ._pyro import AutoMoveDataPredictive if TYPE_CHECKING: from collections.abc import Callable, Iterable - from typing import Any from pyro.infer.predictive import Predictive @@ -444,302 +441,6 @@ def forward(self, *args, **kwargs): return self.model(*args, **kwargs) -if is_package_installed("jax") and is_package_installed("flax"): - import flax - from flax.training import train_state - - class TrainStateWithState(train_state.TrainState): - """TrainState with state attribute.""" - - state: dict[str, Any] - - class JaxBaseModuleClass(flax.linen.Module): - """Abstract class for Jax-based 
scvi-tools modules. - - The :class:`~scvi.module.base.JaxBaseModuleClass` provides an interface for Jax-backed - modules consistent with the :class:`~scvi.module.base.BaseModuleClass`. - - Any subclass must have a `training` parameter in its constructor, as well as - use the `@flax_configure` decorator. - - Children of :class:`~scvi.module.base.JaxBaseModuleClass` should - use the instance attribute ``self.training`` to appropriately modify - the behavior of the model whether it is in training or evaluation mode. - """ - - if TYPE_CHECKING: - import jax.numpy as jnp - from jaxlib.xla_extension import Device - from numpyro.distributions import Distribution - - def configure(self) -> None: - """Add necessary attrs.""" - from scvi.utils._jax import device_selecting_PRNGKey - - self.training = None - self.train_state = None - self.seed = settings.seed if settings.seed is not None else 0 - self.seed_rng = device_selecting_PRNGKey()(self.seed) - self._set_rngs() - - @abstractmethod - def setup(self): - """Flax setup method. - - With scvi-tools we prefer to use the setup parameterization of - flax.linen Modules. This tends the interface to be more like - PyTorch. More about this can be found here: - - https://flax.readthedocs.io/en/latest/design_notes/setup_or_nncompact.html - """ - - @property - @abstractmethod - def required_rngs(self): - """Returns a tuple of rng sequence names required for this Flax module.""" - return ("params",) - - def __call__( - self, - tensors: dict[str, jnp.ndarray], - get_inference_input_kwargs: dict | None = None, - get_generative_input_kwargs: dict | None = None, - inference_kwargs: dict | None = None, - generative_kwargs: dict | None = None, - loss_kwargs: dict | None = None, - compute_loss=True, - ) -> tuple[jnp.ndarray, jnp.ndarray] | tuple[jnp.ndarray, jnp.ndarray, LossOutput]: - """Forward pass through the network. 
- - Parameters - ---------- - tensors - tensors to pass through - get_inference_input_kwargs - Keyword args for ``_get_inference_input()`` - get_generative_input_kwargs - Keyword args for ``_get_generative_input()`` - inference_kwargs - Keyword args for ``inference()`` - generative_kwargs - Keyword args for ``generative()`` - loss_kwargs - Keyword args for ``loss()`` - compute_loss - Whether to compute loss on forward pass. This adds - another return value. - """ - return _generic_forward( - self, - tensors, - inference_kwargs, - generative_kwargs, - loss_kwargs, - get_inference_input_kwargs, - get_generative_input_kwargs, - compute_loss, - ) - - @abstractmethod - def _get_inference_input(self, tensors: dict[str, jnp.ndarray], **kwargs): - """Parse tensors dictionary for inference-related values.""" - - @abstractmethod - def _get_generative_input( - self, - tensors: dict[str, jnp.ndarray], - inference_outputs: dict[str, jnp.ndarray], - **kwargs, - ): - """Parse tensors dictionary for generative related values.""" - - @abstractmethod - def inference( - self, - *args, - **kwargs, - ) -> dict[str, jnp.ndarray | Distribution]: - """Run the recognition model. - - In the case of variational inference, this function will perform steps related to - computing variational distribution parameters. In a VAE, this will involve running - data through encoder networks. - - This function should return a dictionary with str keys and :class:`~jnp.ndarray` values - """ - - @abstractmethod - def generative(self, *args, **kwargs) -> dict[str, jnp.ndarray | Distribution]: - """Run the generative model. - - This function should return the parameters associated with the likelihood of the data. - This is typically written as :math:`p(x|z)`. - - This function should return a dictionary with str keys and :class:`~jnp.ndarray` values - """ - - @abstractmethod - def loss(self, *args, **kwargs) -> LossOutput: - """Compute the loss for a minibatch of data. 
- - This function uses the outputs of the inference and generative functions to compute - a loss. This many optionally include other penalty terms, which should be computed here - - This function should return an object of type :class:`~scvi.module.base.LossOutput`. - """ - - @property - def device(self): - devices = self.seed_rng.devices() - if len(devices) > 1: - raise RuntimeError("Module rng on multiple devices.") - return next(iter(devices)) - - def train(self): - """Switch to train mode. Emulates Pytorch's interface.""" - self.training = True - - def eval(self): - """Switch to evaluation mode. Emulates Pytorch's interface.""" - self.training = False - - @property - def rngs(self) -> dict[str, jnp.ndarray]: - """Dictionary of RNGs mapping required RNG name to RNG values. - - Calls ``self._split_rngs()`` resulting in newly generated RNGs on - every reference to ``self.rngs``. - """ - return self._split_rngs() - - def _set_rngs(self): - """Creates RNGs split off of the seed RNG for each RNG required by the module.""" - from jax import random - - required_rngs = self.required_rngs - rng_keys = random.split(self.seed_rng, num=len(required_rngs) + 1) - self.seed_rng, module_rngs = rng_keys[0], rng_keys[1:] - self._rngs = {k: module_rngs[i] for i, k in enumerate(required_rngs)} - - def _split_rngs(self): - """Regenerates the current set of RNGs and returns newly split RNGs. - - Importantly, this method does not reuse RNGs in future references to ``self.rngs``. 
- """ - from jax import random - - new_rngs = {} - ret_rngs = {} - for k, v in self._rngs.items(): - new_rngs[k], ret_rngs[k] = random.split(v) - self._rngs = new_rngs - return ret_rngs - - @property - def params(self) -> dict[str, Any]: - self._check_train_state_is_not_none() - return self.train_state.params - - @property - def state(self) -> dict[str, Any]: - self._check_train_state_is_not_none() - return self.train_state.state - - def state_dict(self) -> dict[str, Any]: - """Returns a serialized version of the train state as a dictionary.""" - self._check_train_state_is_not_none() - return flax.serialization.to_state_dict(self.train_state) - - def load_state_dict(self, state_dict: dict[str, Any]): - """Load a state dictionary into a train state.""" - if self.train_state is None: - raise RuntimeError( - "Train state is not set. Train for one iteration prior to loading state dict." - ) - self.train_state = flax.serialization.from_state_dict(self.train_state, state_dict) - - def to(self, device: Device): - """Move the module to a device.""" - import jax - - if device is not self.device: - if self.train_state is not None: - self.train_state = jax.tree_util.tree_map( - lambda x: jax.device_put(x, device), self.train_state - ) - - self.seed_rng = jax.device_put(self.seed_rng, device) - self._rngs = jax.device_put(self._rngs, device) - - def _check_train_state_is_not_none(self): - if self.train_state is None: - raise RuntimeError("Train state is not set. Module has not been trained.") - - def as_bound(self) -> JaxBaseModuleClass: - """Module bound with parameters learned from training.""" - return self.bind( - {"params": self.params, **self.state}, - rngs=self.rngs, - ) - - def get_jit_inference_fn( - self, - get_inference_input_kwargs: dict[str, Any] | None = None, - inference_kwargs: dict[str, Any] | None = None, - ) -> Callable[[dict[str, jnp.ndarray], dict[str, jnp.ndarray]], dict[str, jnp.ndarray]]: - """Create a method to run inference using the bound module. 
- - Parameters - ---------- - get_inference_input_kwargs - Keyword arguments to pass to subclass `_get_inference_input` - inference_kwargs - Keyword arguments for subclass `inference` method - - Returns - ------- - A callable taking rngs and array_dict as input and returning the output - of the `inference` method. This callable runs `_get_inference_input`. - """ - import jax - - vars_in = {"params": self.params, **self.state} - get_inference_input_kwargs = _get_dict_if_none(get_inference_input_kwargs) - inference_kwargs = _get_dict_if_none(inference_kwargs) - - @jax.jit - def _run_inference(rngs, array_dict): - module = self.clone() - inference_input = module._get_inference_input(array_dict) - out = module.apply( - vars_in, - rngs=rngs, - method=module.inference, - **inference_input, - **inference_kwargs, - ) - return out - - return _run_inference - - @staticmethod - def on_load(model, **kwargs): - """Callback function run in :meth:`~scvi.model.base.BaseModelClass.load`. - - Run one training step prior to loading state dict in order to initialize params. 
- """ - old_history = model.history_.copy() - model.train(max_steps=1) - model.history_ = old_history - - @staticmethod - def as_numpy_array(x: jnp.ndarray): - """Converts a jax device array to a numpy array.""" - import jax - - return np.array(jax.device_get(x)) - - def _generic_forward( module, tensors, @@ -750,7 +451,7 @@ def _generic_forward( get_generative_input_kwargs, compute_loss, ): - """Core of the forward call shared by PyTorch- and Jax-based modules.""" + """Core of the forward call shared by PyTorch modules.""" inference_kwargs = _get_dict_if_none(inference_kwargs) generative_kwargs = _get_dict_if_none(generative_kwargs) loss_kwargs = _get_dict_if_none(loss_kwargs) diff --git a/src/scvi/module/base/_decorators.py b/src/scvi/module/base/_decorators.py index 2510fd1e54..e3e959e476 100644 --- a/src/scvi/module/base/_decorators.py +++ b/src/scvi/module/base/_decorators.py @@ -5,8 +5,6 @@ import torch from torch.nn import Module -from scvi.utils import is_package_installed - def auto_move_data(fn: Callable) -> Callable: """Decorator for :class:`~torch.nn.Module` methods to move data to correct device. 
@@ -116,21 +114,3 @@ def _apply_to_collection( # data is neither of dtype, nor a collection return data - - -if is_package_installed("flax"): - import flax.linen as nn - - def flax_configure(cls: nn.Module) -> Callable: - """Decorator to raise an error if a boolean `training` param is missing in the call.""" - original_init = cls.__init__ - - @wraps(original_init) - def init(self, *args, **kwargs): - self.configure() - original_init(self, *args, **kwargs) - if not isinstance(self.training, bool): - raise ValueError("Custom sublclasses must have a training parameter.") - - cls.__init__ = init - return cls diff --git a/src/scvi/train/__init__.py b/src/scvi/train/__init__.py index 124eaaafec..687ca91e1a 100644 --- a/src/scvi/train/__init__.py +++ b/src/scvi/train/__init__.py @@ -8,7 +8,6 @@ from ._config import ( AdversarialTrainingPlanConfig, ClassifierTrainingPlanConfig, - JaxTrainingPlanConfig, KwargsConfig, LowLevelPyroTrainingPlanConfig, PyroTrainingPlanConfig, @@ -54,7 +53,6 @@ "SaveCheckpoint", "ScibCallback", "METRIC_KEYS", - "JaxTrainingPlanConfig", "KwargsConfig", ] @@ -64,16 +62,6 @@ def __getattr__(name: str): only when object is actually requested. 
""" - if name == "JaxModuleInit": - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._callbacks import JaxModuleInit as _JaxModuleInit - - return _JaxModuleInit - if name == "JaxTrainingPlan": - error_on_missing_dependencies("flax", "jax", "jaxlib", "optax", "numpyro") - from ._trainingplans import JaxTrainingPlan as _JaxTrainingPlan - - return _JaxTrainingPlan if name == "MlxTrainingPlan": error_on_missing_dependencies("mlx") from ._trainingplans import MlxTrainingPlan as _MlxTrainingPlan diff --git a/src/scvi/train/_callbacks.py b/src/scvi/train/_callbacks.py index dcb5958be4..178ec1fc42 100644 --- a/src/scvi/train/_callbacks.py +++ b/src/scvi/train/_callbacks.py @@ -17,12 +17,10 @@ from scvi import settings from scvi.model.base import BaseModelClass from scvi.model.base._save_load import _load_saved_files -from scvi.utils import dependencies if TYPE_CHECKING: import lightning.pytorch as pl - from scvi.dataloaders import AnnDataLoader MetricCallable = Callable[[BaseModelClass], float] @@ -267,35 +265,6 @@ def teardown( print(self.early_stopping_reason) -class JaxModuleInit(Callback): - """A callback to initialize the Jax-based module.""" - - def __init__(self, dataloader: AnnDataLoader = None) -> None: - super().__init__() - self.dataloader = dataloader - - @dependencies("flax") - def on_train_start(self, trainer, pl_module): - import flax - - module = pl_module.module - if self.dataloader is None: - dl = trainer.datamodule.train_dataloader() - else: - dl = self.dataloader - module_init = module.init(module.rngs, next(iter(dl))) - state, params = flax.core.pop(module_init, "params") - pl_module.set_train_state(params, state) - - # Multi-GPU: replicate train state across devices - n_devices = getattr(pl_module, "n_devices", 1) - if n_devices > 1: - from flax.jax_utils import replicate - - log.info(f"JAX multi-GPU: replicating state across {n_devices} devices.") - pl_module.module.train_state = 
replicate(pl_module.module.train_state) - - class ScibCallback(Callback): """A callback to initialize the Scib-Metrics autotune module.""" diff --git a/src/scvi/train/_config.py b/src/scvi/train/_config.py index c5f336a2ab..9121f5646e 100644 --- a/src/scvi/train/_config.py +++ b/src/scvi/train/_config.py @@ -323,35 +323,6 @@ def to_kwargs(self) -> dict[str, Any]: } -@dataclass -class JaxTrainingPlanConfig: - """Config for :class:`~scvi.train.JaxTrainingPlan`.""" - - optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam" - optimizer_creator: Callable[[], Any] | None = None - lr: float = 1e-3 - weight_decay: float = 1e-6 - eps: float = 0.01 - max_norm: float | None = None - n_steps_kl_warmup: int | None = None - n_epochs_kl_warmup: int | None = 400 - loss_kwargs: dict[str, Any] = field(default_factory=dict) - - def to_kwargs(self) -> dict[str, Any]: - kwargs = { - "optimizer": self.optimizer, - "optimizer_creator": self.optimizer_creator, - "lr": self.lr, - "weight_decay": self.weight_decay, - "eps": self.eps, - "max_norm": self.max_norm, - "n_steps_kl_warmup": self.n_steps_kl_warmup, - "n_epochs_kl_warmup": self.n_epochs_kl_warmup, - } - kwargs.update(self.loss_kwargs) - return kwargs - - @dataclass class TrainerConfig: """Config for :class:`~scvi.train.Trainer`.""" diff --git a/src/scvi/train/_trainingplans.py b/src/scvi/train/_trainingplans.py index 8373d9d41f..e7f668558c 100644 --- a/src/scvi/train/_trainingplans.py +++ b/src/scvi/train/_trainingplans.py @@ -1,7 +1,6 @@ import warnings from collections import OrderedDict from collections.abc import Callable, Iterable -from functools import partial from inspect import signature from typing import Any, Literal @@ -518,7 +517,7 @@ def configure_optimizers(self): @property def kl_weight(self): - """Scaling factor on KL divergence during training. 
Consider Jax""" + """Scaling factor on KL divergence during training""" klw = _compute_kl_weight( self.current_epoch, self.global_step, @@ -527,9 +526,7 @@ def kl_weight(self): self.max_kl_weight, self.min_kl_weight, ) - return ( - klw if type(self).__name__ == "JaxTrainingPlan" else torch.tensor(klw).to(self.device) - ) + return torch.tensor(klw).to(self.device) class AdversarialTrainingPlan(TrainingPlan): @@ -1614,359 +1611,6 @@ def configure_optimizers(self): return optimizer -if is_package_installed("jax") and is_package_installed("optax") and is_package_installed("flax"): - import jax - import jax.numpy as jnp - import optax - - from scvi.module.base._base_module import JaxBaseModuleClass, TrainStateWithState - - JaxOptimizerCreator = Callable[[], optax.GradientTransformation] - - class JaxTrainingPlan(TrainingPlan): - """Lightning module task to train Jax scvi-tools modules. - - Parameters - ---------- - module - An instance of :class:`~scvi.module.base.JaxBaseModuleClass`. - optimizer - One of "Adam", "AdamW", or "Custom", which requires a custom - optimizer creator callable to be passed via `optimizer_creator`. - optimizer_creator - A callable returning a :class:`~optax.GradientTransformation`. - This allows using any optax optimizer with custom hyperparameters. - lr - Learning rate used for optimization when `optimizer_creator` is None. - weight_decay - Weight decay used in optimization, when `optimizer_creator` is None. - eps - eps used for optimization, when `optimizer_creator` is None. - max_norm - Max global norm of gradients for gradient clipping. - n_steps_kl_warmup - Number of training steps (minibatches) to scale weight on KL divergences from - `min_kl_weight` to `max_kl_weight`. Only activated when `n_epochs_kl_warmup` is - set to None. - n_epochs_kl_warmup - Number of epochs to scale weight on KL divergences from `min_kl_weight` to - `max_kl_weight`. Overrides `n_steps_kl_warmup` when both are not `None`. 
- """ - - def __init__( - self, - module: JaxBaseModuleClass, - *, - optimizer: Literal["Adam", "AdamW", "Custom"] = "Adam", - optimizer_creator: JaxOptimizerCreator | None = None, - lr: float = 1e-3, - weight_decay: float = 1e-6, - eps: float = 0.01, - max_norm: float | None = None, - n_steps_kl_warmup: int | None = None, - n_epochs_kl_warmup: int | None = 400, - **loss_kwargs, - ): - super().__init__( - module=module, - lr=lr, - weight_decay=weight_decay, - eps=eps, - optimizer=optimizer, - optimizer_creator=optimizer_creator, - n_steps_kl_warmup=n_steps_kl_warmup, - n_epochs_kl_warmup=n_epochs_kl_warmup, - **loss_kwargs, - ) - self.max_norm = max_norm - self.automatic_optimization = False - self._dummy_param = torch.nn.Parameter(torch.Tensor([0.0])) - self.n_devices = 1 - self._pmap_train_fn = None - self._pmap_val_fn = None - - def get_optimizer_creator(self) -> JaxOptimizerCreator: - """Get the optimizer creator for the model.""" - clip_by = ( - optax.clip_by_global_norm(self.max_norm) if self.max_norm else optax.identity() - ) - if self.optimizer_name == "Adam": - # Replicates PyTorch Adam defaults - optim = optax.chain( - clip_by, - optax.add_decayed_weights(weight_decay=self.weight_decay), - optax.adam(self.lr, eps=self.eps), - ) - elif self.optimizer_name == "AdamW": - optim = optax.chain( - clip_by, - optax.clip_by_global_norm(self.max_norm), - optax.adamw(self.lr, eps=self.eps, weight_decay=self.weight_decay), - ) - elif self.optimizer_name == "Custom": - optim = self._optimizer_creator - else: - raise ValueError("Optimizer not understood.") - - return lambda: optim - - def set_train_state(self, params, state=None): - """Set the state of the module.""" - if self.module.train_state is not None: - return - optimizer = self.get_optimizer_creator()() - train_state = TrainStateWithState.create( - apply_fn=self.module.apply, - params=params, - tx=optimizer, - state=state, - ) - self.module.train_state = train_state - - @staticmethod - def 
_create_pmap_training_step(): - """Create a pmap'd training step function for multi-GPU.""" - - def pmap_train(state, batch, rngs, loss_kwargs): - def loss_fn(params): - vars_in = {"params": params, **state.state} - outputs, new_model_state = state.apply_fn( - vars_in, - batch, - rngs=rngs, - mutable=list(state.state.keys()), - loss_kwargs=loss_kwargs, - ) - loss_output = outputs[2] - return loss_output.loss, (loss_output, new_model_state) - - (loss, (loss_output, new_model_state)), grads = jax.value_and_grad( - loss_fn, has_aux=True - )(state.params) - grads = jax.lax.pmean(grads, axis_name="devices") - loss = jax.lax.pmean(loss, axis_name="devices") - new_model_state = jax.lax.pmean(new_model_state, axis_name="devices") - new_state = state.apply_gradients(grads=grads, state=new_model_state) - return new_state, loss, loss_output - - return jax.pmap(pmap_train, axis_name="devices", donate_argnums=(0,)) - - @staticmethod - def _create_pmap_validation_step(): - """Create a pmap'd validation step function for multi-GPU.""" - - def pmap_val(state, batch, rngs, loss_kwargs): - vars_in = {"params": state.params, **state.state} - outputs = state.apply_fn(vars_in, batch, rngs=rngs, loss_kwargs=loss_kwargs) - loss_output = outputs[2] - loss = jax.lax.pmean(loss_output.loss, axis_name="devices") - return loss, loss_output - - return jax.pmap(pmap_val, axis_name="devices") - - @staticmethod - def _shard_batch(batch, n_devices): - """Shard a batch across devices for pmap. - - If the batch size is not divisible by ``n_devices``, the last sample - is repeated to pad the batch to a divisible size. 
- """ - devices = jax.local_devices()[:n_devices] - - def _shard(x): - x = np.asarray(x) if not isinstance(x, np.ndarray) else x - remainder = x.shape[0] % n_devices - if remainder != 0: - pad_size = n_devices - remainder - pad = np.repeat(x[-1:], pad_size, axis=0) - x = np.concatenate([x, pad], axis=0) - chunks = np.reshape(x, (n_devices, x.shape[0] // n_devices, *x.shape[1:])) - return jax.device_put_sharded(list(chunks), devices) - - return jax.tree_util.tree_map(_shard, batch) - - @staticmethod - def _replicate_rngs(rngs, n_devices): - """Create per-device RNGs by splitting each RNG key.""" - import jax.random as random - - replicated = {} - for name, rng in rngs.items(): - keys = random.split(rng, n_devices) - # Place each key on the corresponding device for pmap - replicated[name] = jax.device_put_sharded( - list(keys), jax.local_devices()[:n_devices] - ) - return replicated - - @staticmethod - def _replicate_kwargs(kwargs, n_devices): - """Replicate kwargs values as JAX arrays with leading device dimension.""" - devices = jax.local_devices()[:n_devices] - replicated = {} - for k, v in kwargs.items(): - arr = jnp.asarray(v) - replicated[k] = jax.device_put_sharded([arr] * n_devices, devices) - return replicated - - @staticmethod - def _unshard_value(value): - """Take the value from the first device (for pmean'd scalars).""" - return jax.tree_util.tree_map( - lambda x: x[0] if hasattr(x, "ndim") and x.ndim > 0 else x, - value, - ) - - @staticmethod - @jax.jit - def jit_training_step( - state: TrainStateWithState, - batch: dict[str, np.ndarray], - rngs: dict[str, jnp.ndarray], - **kwargs, - ): - """Jit training step.""" - - def loss_fn(params): - # state can't be passed here - vars_in = {"params": params, **state.state} - outputs, new_model_state = state.apply_fn( - vars_in, batch, rngs=rngs, mutable=list(state.state.keys()), **kwargs - ) - loss_output = outputs[2] - loss = loss_output.loss - return loss, (loss_output, new_model_state) - - (loss, (loss_output, 
new_model_state)), grads = jax.value_and_grad( - loss_fn, has_aux=True - )(state.params) - new_state = state.apply_gradients(grads=grads, state=new_model_state) - return new_state, loss, loss_output - - def training_step(self, batch, batch_idx): - """Training step for Jax.""" - if "kl_weight" in self.loss_kwargs: - self.loss_kwargs.update({"kl_weight": self.kl_weight}) - self.module.train() - - if self.n_devices > 1: - if self._pmap_train_fn is None: - self._pmap_train_fn = self._create_pmap_training_step() - sharded_batch = self._shard_batch(batch, self.n_devices) - rngs = self._replicate_rngs(self.module.rngs, self.n_devices) - rep_kwargs = self._replicate_kwargs(self.loss_kwargs, self.n_devices) - self.module.train_state, _, loss_output = self._pmap_train_fn( - self.module.train_state, - sharded_batch, - rngs, - rep_kwargs, - ) - loss_output = self._unshard_value(loss_output) - else: - self.module.train_state, _, loss_output = self.jit_training_step( - self.module.train_state, - batch, - self.module.rngs, - loss_kwargs=self.loss_kwargs, - ) - - loss_output = jax.tree_util.tree_map( - lambda x: torch.tensor(jax.device_get(x)), - loss_output, - ) - self.log( - "train_loss", - loss_output.loss, - on_step=self.on_step, - on_epoch=self.on_epoch, - batch_size=loss_output.n_obs_minibatch, - prog_bar=True, - ) - self.compute_and_log_metrics(loss_output, self.train_metrics, "train") - # Update the dummy optimizer to update the global step - _opt = self.optimizers() - _opt.step() - - @partial(jax.jit, static_argnums=(0,)) - def jit_validation_step( - self, - state: TrainStateWithState, - batch: dict[str, np.ndarray], - rngs: dict[str, jnp.ndarray], - **kwargs, - ): - """Jit validation step.""" - vars_in = {"params": state.params, **state.state} - outputs = self.module.apply(vars_in, batch, rngs=rngs, **kwargs) - loss_output = outputs[2] - - return loss_output - - def validation_step(self, batch, batch_idx): - """Validation step for Jax.""" - self.module.eval() - - if 
self.n_devices > 1: - if self._pmap_val_fn is None: - self._pmap_val_fn = self._create_pmap_validation_step() - sharded_batch = self._shard_batch(batch, self.n_devices) - rngs = self._replicate_rngs(self.module.rngs, self.n_devices) - rep_kwargs = self._replicate_kwargs(self.loss_kwargs, self.n_devices) - _, loss_output = self._pmap_val_fn( - self.module.train_state, - sharded_batch, - rngs, - rep_kwargs, - ) - loss_output = self._unshard_value(loss_output) - else: - loss_output = self.jit_validation_step( - self.module.train_state, - batch, - self.module.rngs, - loss_kwargs=self.loss_kwargs, - ) - - loss_output = jax.tree_util.tree_map( - lambda x: torch.tensor(jax.device_get(x)), - loss_output, - ) - self.log( - "validation_loss", - loss_output.loss, - on_step=self.on_step, - on_epoch=self.on_epoch, - batch_size=loss_output.n_obs_minibatch, - ) - self.compute_and_log_metrics(loss_output, self.val_metrics, "validation") - - @staticmethod - def transfer_batch_to_device(batch, device, dataloader_idx): - """Bypass Pytorch Lightning device management.""" - return batch - - def configure_optimizers(self): - """Shim optimizer for PyTorch Lightning. - - PyTorch Lightning wants to take steps on an optimizer - returned by this function in order to increment the global - step count. See PyTorch Lightning optimizer manual loop. - - Here we provide a shim optimizer that we can take steps on - at minimal computational cost in order to keep Lightning happy :). 
- """ - return torch.optim.Adam([self._dummy_param]) - - def optimizer_step(self, *args, **kwargs): - pass - - def backward(self, *args, **kwargs): - pass - - def forward(self, *args, **kwargs): - pass - - if is_package_installed("mlx"): import mlx diff --git a/src/scvi/utils/__init__.py b/src/scvi/utils/__init__.py index 1ba2f47100..81425b98b6 100644 --- a/src/scvi/utils/__init__.py +++ b/src/scvi/utils/__init__.py @@ -2,7 +2,6 @@ from ._decorators import unsupported_if_adata_minified from ._dependencies import dependencies, error_on_missing_dependencies, is_package_installed from ._docstrings import de_dsp, setup_anndata_dsp -from ._jax import device_selecting_PRNGKey from ._mlflow import mlflow_log_artifact, mlflow_log_table, mlflow_log_text, mlflow_logger from ._track import track @@ -11,7 +10,6 @@ "setup_anndata_dsp", "de_dsp", "attrdict", - "device_selecting_PRNGKey", "unsupported_if_adata_minified", "mlflow_logger", "mlflow_log_artifact", diff --git a/src/scvi/utils/_docstrings.py b/src/scvi/utils/_docstrings.py index 1963cab7cb..91cd517760 100644 --- a/src/scvi/utils/_docstrings.py +++ b/src/scvi/utils/_docstrings.py @@ -232,8 +232,8 @@ param_return_device = """\ return_device Returns the first or only device as determined by `accelerator` and `devices`. 
- Depending on the value, will either return a PyTorch device (`"torch"`), a Jax - device (`"jax"`), or neither (`None`).""" + Depending on the value, will either return a PyTorch device (`"torch"`), + or neither (`None`).""" param_validate_single_device = """\ validate_single_device diff --git a/src/scvi/utils/_jax.py b/src/scvi/utils/_jax.py deleted file mode 100644 index a4b58c0e3f..0000000000 --- a/src/scvi/utils/_jax.py +++ /dev/null @@ -1,22 +0,0 @@ -from collections.abc import Callable - -from scvi.utils import dependencies - - -@dependencies("jax") -def device_selecting_PRNGKey(use_cpu: bool = True) -> Callable: - """Returns a PRNGKey that is either on CPU or GPU.""" - # if key is generated on CPU, model params will be on CPU - import jax - from jax import random - - if use_cpu is True: - - def key(i: int): - return jax.device_put(random.PRNGKey(i), jax.devices("cpu")[0]) - else: - # dummy function - def key(i: int): - return random.PRNGKey(i) - - return key diff --git a/tests/conftest.py b/tests/conftest.py index f530fff417..8c84a4134b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,12 +44,6 @@ def pytest_addoption(parser): default=False, help="Run tests that are optional.", ) - parser.addoption( - "--jax", - action="store_true", - default=False, - help="Run tests that are Jax adopted.", - ) parser.addoption( "--accelerator", action="store", @@ -124,18 +118,6 @@ def pytest_collection_modifyitems(config, items): elif run_optional and ("optional" not in item.keywords): item.add_marker(skip_non_optional) - run_jax = config.getoption("--jax") - skip_jax = pytest.mark.skip(reason="need --jax option to run") - skip_non_jax = pytest.mark.skip(reason="test not having a pytest.mark.jax decorator") - for item in items: - # All tests marked with `pytest.mark.jax` get skipped unless - # `--jax` passed - if not run_jax and ("jax" in item.keywords): - item.add_marker(skip_jax) - # Skip all tests not marked with `pytest.mark.jax` if `--jax` passed - elif 
run_jax and ("jax" not in item.keywords): - item.add_marker(skip_non_jax) - run_private = config.getoption("--private") skip_private = pytest.mark.skip(reason="need --private option to run") skip_non_private = pytest.mark.skip(reason="test not having a pytest.mark.private decorator") diff --git a/tests/external/mrvi_jax/mrvi_model/model.pt b/tests/external/mrvi_jax/mrvi_model/model.pt deleted file mode 100644 index 3c7433ae1d..0000000000 Binary files a/tests/external/mrvi_jax/mrvi_model/model.pt and /dev/null differ diff --git a/tests/external/mrvi_jax/mrvi_model_old_jax/model.pt b/tests/external/mrvi_jax/mrvi_model_old_jax/model.pt deleted file mode 100644 index 234a49ce66..0000000000 Binary files a/tests/external/mrvi_jax/mrvi_model_old_jax/model.pt and /dev/null differ diff --git a/tests/external/mrvi_jax/test_jaxmrvi_components.py b/tests/external/mrvi_jax/test_jaxmrvi_components.py deleted file mode 100644 index d835214e05..0000000000 --- a/tests/external/mrvi_jax/test_jaxmrvi_components.py +++ /dev/null @@ -1,81 +0,0 @@ -import flax.linen as nn -import jax -import jax.numpy as jnp -import pytest - -from scvi.external.mrvi_jax._components import ( - MLP, - AttentionBlock, - ConditionalNormalization, - Dense, - NormalDistOutputNN, - ResnetBlock, -) - - -@pytest.mark.jax -def test_jaxmrvi_dense(): - key = jax.random.PRNGKey(0) - x = jnp.ones((20, 10)) - dense = Dense(10) - params = dense.init(key, x) - output = dense.apply(params, x) - assert output.shape == (20, 10) - - -@pytest.mark.jax -def test_jaxmrvi_resnetblock(): - key = jax.random.PRNGKey(0) - x = jnp.ones((20, 10)) - block = ResnetBlock(n_out=30, n_hidden=128, training=True) - params = block.init(key, x) - output = block.apply(params, x, mutable=["batch_stats"]) - assert output[0].shape == (20, 30) - - -@pytest.mark.jax -def test_jaxmrvi_normalnn(): - key = jax.random.PRNGKey(0) - key, subkey = jax.random.split(key) - x = jnp.ones((20, 10)) - nn = NormalDistOutputNN(n_out=30, n_hidden=128, n_layers=3, 
training=True) - params = nn.init(key, x) - output = nn.apply(params, x, mutable=["batch_stats"]) - assert output[0].batch_shape == (20, 30) - - -@pytest.mark.jax -def test_jaxmrvi_mlp(): - key = jax.random.PRNGKey(0) - x = jnp.ones((20, 10)) - mlp = MLP(n_out=30, n_hidden=128, n_layers=3, activation=nn.relu, training=True) - params = mlp.init(key, x) - output = mlp.apply(params, x, mutable=["batch_stats"]) - assert output[0].shape == (20, 30) - - -@pytest.mark.jax -def test_jaxmrvi_conditionalbatchnorm1d(): - key = jax.random.PRNGKey(0) - x = jnp.ones((20, 10)) - y = jnp.ones((20, 1)) - conditionalbatchnorm1d = ConditionalNormalization( - n_features=10, n_conditions=3, normalization_type="batch", training=True - ) - params = conditionalbatchnorm1d.init(key, x, y) - output = conditionalbatchnorm1d.apply(params, x, y, mutable=["batch_stats"]) - assert output[0].shape == (20, 10) - - -@pytest.mark.jax -def test_jaxmrvi_attention(): - key = jax.random.PRNGKey(0) - q_vals = jnp.ones((30, 10)) - kv_vals = jnp.ones((30, 10)) - mod = AttentionBlock(query_dim=20, out_dim=40, training=True) - params = mod.init(key, q_vals, kv_vals) - mod.apply(params, q_vals, kv_vals, mutable=["batch_stats"]) - q_vals_3d = jnp.ones((3, 30, 10)) - kv_vals_3d = jnp.ones((3, 30, 10)) - r = mod.apply(params, q_vals_3d, kv_vals_3d, mutable=["batch_stats"]) - assert r[0].shape == (3, 30, 40) diff --git a/tests/external/mrvi_jax/test_jaxmrvi_model.py b/tests/external/mrvi_jax/test_jaxmrvi_model.py deleted file mode 100644 index ef5b99f42b..0000000000 --- a/tests/external/mrvi_jax/test_jaxmrvi_model.py +++ /dev/null @@ -1,258 +0,0 @@ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -import numpy as np -import pytest - -from scvi.data import synthetic_iid -from scvi.external import MRVI - -if TYPE_CHECKING: - from typing import Any - - from anndata import AnnData - - -@pytest.fixture(scope="session") -def adata(): - adata = synthetic_iid() - adata.obs.index.name = 
"cell_id" - adata.obs["sample"] = np.random.choice(15, size=adata.shape[0]) - adata.obs["sample_str"] = [chr(i + ord("a")) for i in adata.obs["sample"]] - meta1 = np.random.randint(0, 2, size=15) - adata.obs["meta1"] = meta1[adata.obs["sample"].values] - meta2 = np.random.randn(15) - adata.obs["meta2"] = meta2[adata.obs["sample"].values] - adata.obs["cont_cov"] = np.random.normal(0, 1, size=adata.shape[0]) - adata.obs["meta1_cat"] = "CAT_" + adata.obs["meta1"].astype(str) - adata.obs["meta1_cat"] = adata.obs["meta1_cat"].astype("category") - adata.obs.loc[:, "disjoint_batch"] = (adata.obs.loc[:, "sample"] <= 6).replace( - {True: "batch_0", False: "batch_1"} - ) - adata.obs["dummy_batch"] = 1 - return adata - - -@pytest.fixture(scope="session") -def model(adata: AnnData): - MRVI.setup_anndata(adata, sample_key="sample_str", batch_key="batch", backend="jax") - model = MRVI(adata) - model.train(1, train_size=0.5) - - return model - - -@pytest.mark.jax -def test_jaxmrvi(model: MRVI, adata: AnnData, save_path: str): - model.get_local_sample_distances(batch_size=16) - model.get_local_sample_distances(normalize_distances=True, batch_size=16) - model.get_latent_representation(give_z=False) - model.get_latent_representation(give_z=True) - - model_path = os.path.join(save_path, "mrvi_model") - model.save(model_path, save_anndata=False, overwrite=True) - model = MRVI.load(model_path, adata=adata) - model.train(1) - # a jax model from prev version - should work! 
- model = MRVI.load("tests/external/mrvi_jax/mrvi_model_old_jax", adata=adata) - model.train(1) - - -@pytest.mark.jax -@pytest.mark.parametrize( - ("setup_kwargs", "de_kwargs"), - [ - ( - {"sample_key": "sample_str", "batch_key": "batch"}, - [ - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - "add_batch_specific_offsets": True, - }, - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - "add_batch_specific_offsets": True, - "filter_inadmissible_samples": True, - }, - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - "add_batch_specific_offsets": False, - }, - ], - ), - ( - {"sample_key": "sample_str", "batch_key": "dummy_batch"}, - [ - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - }, - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - "lambd": 1e-1, - }, - { - "sample_cov_keys": ["meta1_cat", "meta2", "cont_cov"], - "store_lfc": True, - "filter_inadmissible_samples": True, - }, - ], - ), - ], -) -def test_jaxmrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str, Any]): - for de_kwarg in de_kwargs: - model.differential_expression(**de_kwarg) - - -@pytest.mark.jax -@pytest.mark.parametrize( - "sample_key", - ["sample", "sample_str"], -) -@pytest.mark.parametrize( - "da_kwargs", - [ - {"sample_cov_keys": ["meta1_cat"]}, - {"sample_cov_keys": ["meta1_cat", "batch"]}, - {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": False}, - {"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True}, - {"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True}, - ], -) -def test_jaxmrvi_da(model, sample_key, da_kwargs): - model.differential_abundance(**da_kwargs) - - -@pytest.mark.jax -@pytest.mark.parametrize( - "model_kwargs", - [ - {"qz_kwargs": {"use_map": False}}, - { - "qz_kwargs": {"use_map": False}, - "px_kwargs": {"low_dim_batch": False}, - "u_prior_mixture": 
True, - }, - { - "qz_kwargs": { - "use_map": False, - "stop_gradients": False, - "stop_gradients_mlp": True, - }, - "px_kwargs": { - "low_dim_batch": False, - "stop_gradients": False, - "stop_gradients_mlp": True, - }, - "z_u_prior": False, - }, - { - "qz_kwargs": {"use_map": False}, - "px_kwargs": {"low_dim_batch": True}, - "learn_z_u_prior_scale": True, - }, - ], -) -def test_jaxmrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): - MRVI.setup_anndata( - adata, - sample_key="sample_str", - batch_key="batch", - backend="jax", - ) - model = MRVI(adata, n_latent=10, scale_observations=True, **model_kwargs) - model.train(2, train_size=0.5) - - model_path = os.path.join(save_path, "mrvi_model") - model.save(model_path, save_anndata=False, overwrite=True) - model = MRVI.load(model_path, adata=adata) - - -@pytest.mark.jax -def test_jaxmrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): - sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] - sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) - - model_path = os.path.join(save_path, "mrvi_model") - model.save(model_path, save_anndata=False, overwrite=True) - model = MRVI.load(model_path, adata=adata) - - -@pytest.mark.jax -def test_jaxmrvi_shrink_u(adata: AnnData, save_path: str): - MRVI.setup_anndata( - adata, - sample_key="sample_str", - batch_key="batch", - backend="jax", - ) - model = MRVI(adata, n_latent=10, n_latent_u=5) - model.train(2, train_size=0.5) - model.get_local_sample_distances(batch_size=16) - - assert model.get_latent_representation().shape == ( - adata.shape[0], - 5, - ) - - model_path = os.path.join(save_path, "mrvi_model") - model.save(model_path, save_anndata=False, overwrite=True) - model = MRVI.load(model_path, adata=adata) - - -@pytest.fixture -def adata_stratifications(): - adata = synthetic_iid() - adata.obs["sample"] = np.random.choice(15, 
size=adata.shape[0]) - adata.obs["sample_str"] = [chr(i + ord("a")) for i in adata.obs["sample"]] - meta1 = np.random.randint(0, 2, size=15) - adata.obs["meta1"] = meta1[adata.obs["sample"].values] - meta2 = np.random.randn(15) - adata.obs["meta2"] = meta2[adata.obs["sample"].values] - adata.obs["cont_cov"] = np.random.normal(0, 1, size=adata.shape[0]) - adata.obs.loc[:, "label_2"] = np.random.choice(2, size=adata.shape[0]) - return adata - - -@pytest.mark.jax -def test_jaxmrvi_stratifications(adata_stratifications: AnnData, save_path: str): - MRVI.setup_anndata( - adata_stratifications, - sample_key="sample_str", - batch_key="batch", - backend="jax", - ) - model = MRVI(adata_stratifications, n_latent=10) - model.train(2, train_size=0.5) - - dists = model.get_local_sample_distances(groupby=["labels", "label_2"], batch_size=16) - cell_dists = dists["cell"] - assert cell_dists.shape == (adata_stratifications.shape[0], 15, 15) - ct_dists = dists["labels"] - assert ct_dists.shape == (3, 15, 15) - assert np.allclose(ct_dists[0].values, ct_dists[0].values.T, atol=1e-6) - ct_dists = dists["label_2"] - assert ct_dists.shape == (2, 15, 15) - assert np.allclose(ct_dists[0].values, ct_dists[0].values.T, atol=1e-6) - - adata_stratifications_sub = adata_stratifications[ - adata_stratifications.obs["labels"] == "label_0" - ] - sub_dists = model.get_local_sample_distances( - adata=adata_stratifications_sub, groupby=["labels"], batch_size=16 - ) - assert sub_dists["cell"].shape == (adata_stratifications_sub.shape[0], 15, 15) - assert sub_dists["labels"].shape == (1, 15, 15) - assert np.allclose(sub_dists["labels"][0].values, sub_dists["labels"][0].values.T, atol=1e-6) - - model_path = os.path.join(save_path, "mrvi_model") - model.save(model_path, save_anndata=False, overwrite=True) - model = MRVI.load(model_path, adata=adata_stratifications) diff --git a/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py b/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py 
deleted file mode 100644 index f4f3a10e80..0000000000 --- a/tests/external/mrvi_torch/test_torchmrvi_jax_equivalence.py +++ /dev/null @@ -1,508 +0,0 @@ -"""Numerical equivalence tests between JAX and PyTorch MRVI implementations. - -Transfers weights from JAX to PyTorch and compares forward pass outputs and -gradients. Requires JAX — skipped automatically if not installed. -""" - -from __future__ import annotations - -import os - -import numpy as np -import pytest -import torch - -os.environ.setdefault("JAX_DEFAULT_MATMUL_PRECISION", "float32") -os.environ.setdefault("JAX_PLATFORMS", "cpu") -torch.backends.cuda.matmul.allow_tf32 = False -torch.backends.cudnn.allow_tf32 = False -torch.set_float32_matmul_precision("highest") - -from scvi import REGISTRY_KEYS # noqa: E402 -from scvi.external.mrvi_torch._module import TorchMRVAE # noqa: E402 - -ATOL_SIMPLE = 1e-5 -ATOL_ATTN = 5e-4 - - -# ── weight transfer helpers ────────────────────────────────────────────────── - - -def _t(arr): - return torch.tensor(np.array(arr)).float() - - -def _dense(jax_p, torch_linear): - torch_linear.weight.data = _t(jax_p["kernel"].T) - if "bias" in jax_p: - torch_linear.bias.data = _t(jax_p["bias"]) - - -def _ln(jax_p, torch_ln): - if "scale" in jax_p: - torch_ln.weight.data = _t(jax_p["scale"]) - if "bias" in jax_p: - torch_ln.bias.data = _t(jax_p["bias"]) - - -def _embed(jax_p, torch_emb): - torch_emb.weight.data = _t(jax_p["embedding"]) - - -def _resnet_block(jp, tb): - _dense(jp["Dense_0"], tb.fc1) - _ln(jp["LayerNorm_0"], tb.layer_norm1) - if tb.fc_match is not None: - _dense(jp["Dense_1"], tb.fc_match) - _dense(jp["Dense_2"], tb.fc2) - else: - _dense(jp["Dense_1"], tb.fc2) - _ln(jp["LayerNorm_1"], tb.layer_norm2) - - -def _mlp(jp, tm): - for i in range(len(tm.resnet_blocks)): - _resnet_block(jp[f"ResnetBlock_{i}"], tm.resnet_blocks[i]) - _dense(jp["Dense_0"], tm.fc) - - -def _normal_dist_nn(jp, tn): - for i, rb in enumerate(tn.resnet_blocks): - _resnet_block(jp[f"ResnetBlock_{i}"], 
rb) - _dense(jp["Dense_0"], tn.fc_mean) - _dense(jp["Dense_1"], tn.fc_scale[0]) - - -def _attention_block(jp, tab): - n_heads, depth, n_channels = tab.n_heads, tab.depth_per_head, tab.n_channels - tab.query_proj.weight.data = _t(jp["DenseGeneral_0"]["kernel"][:, :, 0].T) - tab.kv_proj.weight.data = _t(jp["DenseGeneral_1"]["kernel"][:, :, 0].T) - mha = jp["MultiHeadDotProductAttention_0"] - for name, proj in [("query", tab.q_proj), ("key", tab.k_proj), ("value", tab.v_proj)]: - k = np.array(mha[name]["kernel"]) - b = np.array(mha[name]["bias"]) - proj.weight.data = _t(k.reshape(1, n_heads * depth).T) - proj.bias.data = _t(b.reshape(-1)) - ok = np.array(mha["out"]["kernel"]) - ob = np.array(mha["out"]["bias"]) - tab.out_proj.weight.data = _t(ok.reshape(n_heads * depth, n_channels).T) - tab.out_proj.bias.data = _t(ob) - _mlp(jp["MLP_0"], tab.mlp_eps) - _mlp(jp["MLP_1"], tab.mlp_residual) - - -def transfer_all_weights(jax_params, torch_module): - p = jax_params - tqu = torch_module.qu - _dense(p["qu"]["Dense_0"], tqu.fc1) - _dense(p["qu"]["Dense_1"], tqu.fc2) - _embed( - p["qu"]["ConditionalNormalization_0"]["gamma_conditional"], - tqu.conditional_norm1.gamma_embedding, - ) - _embed( - p["qu"]["ConditionalNormalization_0"]["beta_conditional"], - tqu.conditional_norm1.beta_embedding, - ) - _embed( - p["qu"]["ConditionalNormalization_1"]["gamma_conditional"], - tqu.conditional_norm2.gamma_embedding, - ) - _embed( - p["qu"]["ConditionalNormalization_1"]["beta_conditional"], - tqu.conditional_norm2.beta_embedding, - ) - _embed(p["qu"]["Embed_0"], tqu.sample_embed) - _normal_dist_nn(p["qu"]["NormalDistOutputNN_0"], tqu.output_nn) - - tqz = torch_module.qz - _ln(p["qz"]["u_ln"], tqz.layer_norm) - _embed(p["qz"]["Embed_0"], tqz.embedding) - _ln(p["qz"]["sample_embed_ln"], tqz.layer_norm_embed) - _attention_block(p["qz"]["AttentionBlock_0"], tqz.attention_block) - if tqz.fc is not None: - _dense(p["qz"]["Dense_0"], tqz.fc) - - tpx = torch_module.px - _ln(p["px"]["u_ln"], 
tpx.layer_norm) - _embed(p["px"]["Embed_0"], tpx.batch_embedding) - _ln(p["px"]["batch_embed_ln"], tpx.layer_norm_batch_embed) - _attention_block(p["px"]["AttentionBlock_0"], tpx.attention_block) - _dense(p["px"]["Dense_0"], tpx.fc) - tpx.px_r.data = _t(p["px"]["px_r"]) - - if torch_module.u_prior_mixture: - torch_module.u_prior_logits.data = _t(p["u_prior_logits"]) - torch_module.u_prior_means.data = _t(p["u_prior_means"]) - torch_module.u_prior_scales.data = _t(p["u_prior_scales"]) - - if hasattr(torch_module, "pz_scale"): - if isinstance(torch_module.pz_scale, torch.nn.Parameter): - torch_module.pz_scale.data = _t(p["pz_scale"]) - else: - pz_val = ( - np.array(p["pz_scale"]) if "pz_scale" in p else np.zeros(torch_module.n_latent) - ) - torch_module.pz_scale.copy_(_t(pz_val)) - - -# ── comparison helpers ─────────────────────────────────────────────────────── - - -def _compare(name, jax_val, torch_val, atol): - j = np.array(jax_val) - t = ( - torch_val.detach().cpu().numpy() - if isinstance(torch_val, torch.Tensor) - else np.array(torch_val) - ) - diff = np.max(np.abs(j - t)) - assert diff < atol, f"{name}: max_diff={diff:.2e} exceeds atol={atol:.0e}" - - -def _compare_grad(name, jax_grad_val, torch_param, atol): - j = np.array(jax_grad_val) - assert torch_param.grad is not None, f"{name}: PyTorch grad is None" - t = torch_param.grad.detach().cpu().numpy() - diff = np.max(np.abs(j - t)) - assert diff < atol, f"grad {name}: max_diff={diff:.2e} exceeds atol={atol:.0e}" - - -def _compare_dense_grad(name, jg, tl, atol): - _compare_grad(f"{name}.w", np.array(jg["kernel"]).T, tl.weight, atol) - if "bias" in jg and tl.bias is not None: - _compare_grad(f"{name}.b", jg["bias"], tl.bias, atol) - - -def _compare_ln_grad(name, jg, tl, atol): - if "scale" in jg: - _compare_grad(f"{name}.w", jg["scale"], tl.weight, atol) - if "bias" in jg: - _compare_grad(f"{name}.b", jg["bias"], tl.bias, atol) - - -def _compare_embed_grad(name, jg, te, atol): - _compare_grad(f"{name}.w", 
jg["embedding"], te.weight, atol) - - -def _compare_resnet_block_grad(name, jg, tb, atol): - _compare_dense_grad(f"{name}.fc1", jg["Dense_0"], tb.fc1, atol) - _compare_ln_grad(f"{name}.ln1", jg["LayerNorm_0"], tb.layer_norm1, atol) - if tb.fc_match is not None: - _compare_dense_grad(f"{name}.fc_match", jg["Dense_1"], tb.fc_match, atol) - _compare_dense_grad(f"{name}.fc2", jg["Dense_2"], tb.fc2, atol) - else: - _compare_dense_grad(f"{name}.fc2", jg["Dense_1"], tb.fc2, atol) - _compare_ln_grad(f"{name}.ln2", jg["LayerNorm_1"], tb.layer_norm2, atol) - - -def _compare_mlp_grad(name, jg, tm, atol): - for i in range(len(tm.resnet_blocks)): - _compare_resnet_block_grad( - f"{name}.rb{i}", jg[f"ResnetBlock_{i}"], tm.resnet_blocks[i], atol - ) - _compare_dense_grad(f"{name}.fc", jg["Dense_0"], tm.fc, atol) - - -def _compare_attention_block_grad(name, jg, tab, atol): - n_heads, depth, n_channels = tab.n_heads, tab.depth_per_head, tab.n_channels - _compare_grad( - f"{name}.qproj.w", - np.array(jg["DenseGeneral_0"]["kernel"])[:, :, 0].T, - tab.query_proj.weight, - atol, - ) - _compare_grad( - f"{name}.kvproj.w", - np.array(jg["DenseGeneral_1"]["kernel"])[:, :, 0].T, - tab.kv_proj.weight, - atol, - ) - mha_g = jg["MultiHeadDotProductAttention_0"] - for qkv, proj in [("query", tab.q_proj), ("key", tab.k_proj), ("value", tab.v_proj)]: - _compare_grad( - f"{name}.{qkv}.w", - np.array(mha_g[qkv]["kernel"]).reshape(1, n_heads * depth).T, - proj.weight, - atol, - ) - _compare_grad(f"{name}.{qkv}.b", np.array(mha_g[qkv]["bias"]).reshape(-1), proj.bias, atol) - _compare_grad( - f"{name}.out.w", - np.array(mha_g["out"]["kernel"]).reshape(n_heads * depth, n_channels).T, - tab.out_proj.weight, - atol, - ) - _compare_grad(f"{name}.out.b", np.array(mha_g["out"]["bias"]), tab.out_proj.bias, atol) - _compare_mlp_grad(f"{name}.mlp_eps", jg["MLP_0"], tab.mlp_eps, atol) - _compare_mlp_grad(f"{name}.mlp_res", jg["MLP_1"], tab.mlp_residual, atol) - - -def compare_all_gradients(jax_grads, 
torch_mod, atol_s, atol_a): - g = jax_grads - tqu = torch_mod.qu - _compare_dense_grad("qu.fc1", g["qu"]["Dense_0"], tqu.fc1, atol_s) - _compare_dense_grad("qu.fc2", g["qu"]["Dense_1"], tqu.fc2, atol_s) - for i, cn in enumerate([tqu.conditional_norm1, tqu.conditional_norm2]): - _compare_embed_grad( - f"qu.cn{i}.gamma", - g["qu"][f"ConditionalNormalization_{i}"]["gamma_conditional"], - cn.gamma_embedding, - atol_s, - ) - _compare_embed_grad( - f"qu.cn{i}.beta", - g["qu"][f"ConditionalNormalization_{i}"]["beta_conditional"], - cn.beta_embedding, - atol_s, - ) - _compare_embed_grad("qu.sample_embed", g["qu"]["Embed_0"], tqu.sample_embed, atol_s) - nn_g = g["qu"]["NormalDistOutputNN_0"] - for i, rb in enumerate(tqu.output_nn.resnet_blocks): - _compare_resnet_block_grad(f"qu.nn.rb{i}", nn_g[f"ResnetBlock_{i}"], rb, atol_s) - _compare_dense_grad("qu.nn.fc_mean", nn_g["Dense_0"], tqu.output_nn.fc_mean, atol_s) - _compare_dense_grad("qu.nn.fc_scale", nn_g["Dense_1"], tqu.output_nn.fc_scale[0], atol_s) - - tqz = torch_mod.qz - _compare_ln_grad("qz.u_ln", g["qz"]["u_ln"], tqz.layer_norm, atol_a) - _compare_embed_grad("qz.embed", g["qz"]["Embed_0"], tqz.embedding, atol_a) - _compare_ln_grad("qz.embed_ln", g["qz"]["sample_embed_ln"], tqz.layer_norm_embed, atol_a) - _compare_attention_block_grad( - "qz.attn", g["qz"]["AttentionBlock_0"], tqz.attention_block, atol_a - ) - if tqz.fc is not None: - _compare_dense_grad("qz.fc", g["qz"]["Dense_0"], tqz.fc, atol_a) - - tpx = torch_mod.px - _compare_ln_grad("px.u_ln", g["px"]["u_ln"], tpx.layer_norm, atol_a) - _compare_embed_grad("px.batch_embed", g["px"]["Embed_0"], tpx.batch_embedding, atol_a) - _compare_ln_grad("px.batch_ln", g["px"]["batch_embed_ln"], tpx.layer_norm_batch_embed, atol_a) - _compare_attention_block_grad( - "px.attn", g["px"]["AttentionBlock_0"], tpx.attention_block, atol_a - ) - _compare_dense_grad("px.fc", g["px"]["Dense_0"], tpx.fc, atol_a) - _compare_grad("px.px_r", g["px"]["px_r"], tpx.px_r, atol_a) - - if 
torch_mod.u_prior_mixture: - _compare_grad("u_prior_logits", g["u_prior_logits"], torch_mod.u_prior_logits, atol_s) - _compare_grad("u_prior_means", g["u_prior_means"], torch_mod.u_prior_means, atol_s) - _compare_grad("u_prior_scales", g["u_prior_scales"], torch_mod.u_prior_scales, atol_s) - - -# ── shared test data ───────────────────────────────────────────────────────── - -_N_INPUT, _N_SAMPLE, _N_BATCH, _N_LABELS = 100, 5, 2, 1 -_N_LATENT, _N_LATENT_U = 30, 10 -_BS = 16 - - -def _make_test_data(): - np.random.seed(0) - return { - "x": (np.abs(np.random.randn(_BS, _N_INPUT)) + 0.1).astype(np.float32), - "sample": np.random.randint(0, _N_SAMPLE, (_BS, 1)).astype(np.float32), - "batch": np.random.randint(0, _N_BATCH, (_BS, 1)).astype(np.float32), - "label": np.zeros((_BS, 1), dtype=np.float32), - } - - -def _init_jax_module(data): - jax = pytest.importorskip("jax") - jnp = pytest.importorskip("jax.numpy") - JaxMRVAE = pytest.importorskip("scvi.external.mrvi_jax._module").JaxMRVAE - - jax.config.update("jax_default_matmul_precision", "float32") - - jax_mod = JaxMRVAE( - n_input=_N_INPUT, - n_sample=_N_SAMPLE, - n_batch=_N_BATCH, - n_labels=_N_LABELS, - n_latent=_N_LATENT, - n_latent_u=_N_LATENT_U, - training=False, - ) - tensors = { - REGISTRY_KEYS.X_KEY: jnp.array(data["x"]), - REGISTRY_KEYS.SAMPLE_KEY: jnp.array(data["sample"]), - REGISTRY_KEYS.BATCH_KEY: jnp.array(data["batch"]), - REGISTRY_KEYS.LABELS_KEY: jnp.array(data["label"]), - } - key = jax.random.PRNGKey(42) - keys = jax.random.split(key, 4) - rngs = {"params": keys[0], "u": keys[1], "dropout": keys[2], "eps": keys[3]} - variables = jax_mod.init(rngs, tensors) - return jax_mod, variables, tensors, rngs - - -def _init_torch_module(jax_params): - torch_mod = TorchMRVAE( - n_input=_N_INPUT, - n_sample=_N_SAMPLE, - n_batch=_N_BATCH, - n_labels=_N_LABELS, - n_latent=_N_LATENT, - n_latent_u=_N_LATENT_U, - ) - torch_mod.eval() - transfer_all_weights(jax_params, torch_mod) - return torch_mod - - -# ── tests 
──────────────────────────────────────────────────────────────────── - - -@pytest.mark.jax -def test_forward_pass_equivalence(): - """Init JAX, transfer weights to PyTorch, compare all inference/generative/loss outputs.""" - pytest.importorskip("jax") - jnp = pytest.importorskip("jax.numpy") - - data = _make_test_data() - jax_mod, variables, tensors_jax, rngs = _init_jax_module(data) - torch_mod = _init_torch_module(variables["params"]) - - # Inference - jax_inf = jax_mod.apply( - variables, - x=jnp.array(data["x"]), - sample_index=jnp.array(data["sample"]), - use_mean=True, - method=jax_mod.inference, - rngs=rngs, - ) - with torch.no_grad(): - torch_inf = torch_mod.inference( - x=torch.tensor(data["x"]), sample_index=torch.tensor(data["sample"]), use_mean=True - ) - - for key in ("u", "z", "z_base", "library"): - atol = ATOL_SIMPLE if key in ("u", "library") else ATOL_ATTN - _compare(key, jax_inf[key], torch_inf[key], atol) - _compare("qu.mean", jax_inf["qu"].mean, torch_inf["qu"].mean, ATOL_SIMPLE) - _compare("qu.scale", jax_inf["qu"].scale, torch_inf["qu"].scale, ATOL_SIMPLE) - - # Generative - jax_gen = jax_mod.apply( - variables, - z=jax_inf["z"], - library=jax_inf["library"], - batch_index=jnp.array(data["batch"]), - label_index=jnp.array(data["label"]), - method=jax_mod.generative, - rngs=rngs, - ) - with torch.no_grad(): - torch_gen = torch_mod.generative( - z=torch_inf["z"], - library=torch_inf["library"], - batch_index=torch.tensor(data["batch"]), - label_index=torch.tensor(data["label"]), - ) - _compare("h", jax_gen["h"], torch_gen["h"], ATOL_ATTN) - _compare("px.mean", jax_gen["px"].mean, torch_gen["px"].mean, ATOL_ATTN) - - # Loss - tensors_torch = { - REGISTRY_KEYS.X_KEY: torch.tensor(data["x"]), - REGISTRY_KEYS.SAMPLE_KEY: torch.tensor(data["sample"]), - REGISTRY_KEYS.BATCH_KEY: torch.tensor(data["batch"]), - REGISTRY_KEYS.LABELS_KEY: torch.tensor(data["label"]), - } - - jax_loss = jax_mod.apply( - variables, - tensors=tensors_jax, - 
inference_outputs=jax_inf, - generative_outputs=jax_gen, - kl_weight=1.0, - method=jax_mod.loss, - rngs=rngs, - ) - with torch.no_grad(): - torch_loss = torch_mod.loss( - tensors=tensors_torch, - inference_outputs=torch_inf, - generative_outputs=torch_gen, - kl_weight=1.0, - ) - - def _extract(val): - return sum(val.values()) if isinstance(val, dict) else val - - _compare( - "recon_loss", - _extract(jax_loss.reconstruction_loss), - _extract(torch_loss.reconstruction_loss), - ATOL_ATTN, - ) - _compare("kl_local", _extract(jax_loss.kl_local), _extract(torch_loss.kl_local), ATOL_ATTN) - _compare("total_loss", jax_loss.loss, torch_loss.loss, ATOL_ATTN) - - -@pytest.mark.jax -def test_gradient_equivalence(): - """Compare jax.grad vs loss.backward() for all parameters.""" - jax = pytest.importorskip("jax") - jnp = pytest.importorskip("jax.numpy") - - data = _make_test_data() - jax_mod, variables, tensors_jax, rngs = _init_jax_module(data) - - def jax_loss_fn(params): - v = {"params": params} - inf = jax_mod.apply( - v, - x=jnp.array(data["x"]), - sample_index=jnp.array(data["sample"]), - use_mean=True, - method=jax_mod.inference, - rngs=rngs, - ) - gen = jax_mod.apply( - v, - z=inf["z"], - library=inf["library"], - batch_index=jnp.array(data["batch"]), - label_index=jnp.array(data["label"]), - method=jax_mod.generative, - rngs=rngs, - ) - loss = jax_mod.apply( - v, - tensors=tensors_jax, - inference_outputs=inf, - generative_outputs=gen, - kl_weight=1.0, - method=jax_mod.loss, - rngs=rngs, - ) - return loss.loss - - jax_grads = jax.grad(jax_loss_fn)(variables["params"]) - - torch_mod = _init_torch_module(variables["params"]) - torch_mod.zero_grad() - torch_inf = torch_mod.inference( - x=torch.tensor(data["x"]), sample_index=torch.tensor(data["sample"]), use_mean=True - ) - torch_gen = torch_mod.generative( - z=torch_inf["z"], - library=torch_inf["library"], - batch_index=torch.tensor(data["batch"]), - label_index=torch.tensor(data["label"]), - ) - tensors_torch = { - 
REGISTRY_KEYS.X_KEY: torch.tensor(data["x"]), - REGISTRY_KEYS.SAMPLE_KEY: torch.tensor(data["sample"]), - REGISTRY_KEYS.BATCH_KEY: torch.tensor(data["batch"]), - REGISTRY_KEYS.LABELS_KEY: torch.tensor(data["label"]), - } - torch_mod.loss( - tensors=tensors_torch, - inference_outputs=torch_inf, - generative_outputs=torch_gen, - kl_weight=1.0, - ).loss.backward() - - compare_all_gradients(jax_grads, torch_mod, ATOL_SIMPLE, ATOL_ATTN) diff --git a/tests/external/mrvi_torch/test_torchmrvi_model.py b/tests/external/mrvi_torch/test_torchmrvi_model.py index 6abb2f8c79..a2fb8fb8c2 100644 --- a/tests/external/mrvi_torch/test_torchmrvi_model.py +++ b/tests/external/mrvi_torch/test_torchmrvi_model.py @@ -37,7 +37,7 @@ def adata(): @pytest.fixture(scope="session") def model(adata: AnnData): - MRVI.setup_anndata(adata, sample_key="sample_str", batch_key="batch", backend="torch") + MRVI.setup_anndata(adata, sample_key="sample_str", batch_key="batch") model = MRVI(adata) model.train(max_epochs=1, train_size=0.5) @@ -46,9 +46,7 @@ def model(adata: AnnData): @pytest.fixture(scope="session") def model2(adata: AnnData): - MRVI.setup_anndata( - adata, sample_key="sample_str", batch_key="batch", backend="torch", labels_key="labels" - ) + MRVI.setup_anndata(adata, sample_key="sample_str", batch_key="batch", labels_key="labels") model = MRVI(adata) model.train(max_epochs=1, train_size=0.5) @@ -288,7 +286,6 @@ def test_torchMRVI_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], sa adata, sample_key="sample_str", batch_key="batch", - backend="torch", ) model = MRVI(adata, n_latent=10, scale_observations=True, **model_kwargs) model.train(max_epochs=1, train_size=0.5) @@ -315,7 +312,6 @@ def test_torchMRVI_shrink_u(adata: AnnData, save_path: str): adata, sample_key="sample_str", batch_key="batch", - backend="torch", ) model = MRVI(adata, n_latent=10, n_latent_u=5) model.train(max_epochs=1, train_size=0.5) @@ -350,7 +346,6 @@ def test_torchMRVI_stratifications(adata_stratifications: 
AnnData, save_path: st adata_stratifications, sample_key="sample_str", batch_key="batch", - backend="torch", ) model = MRVI(adata_stratifications, n_latent=10) model.train(max_epochs=1, train_size=0.5) diff --git a/tests/external/tangram/test_tangram.py b/tests/external/tangram/test_tangram.py deleted file mode 100644 index 10f12c3b88..0000000000 --- a/tests/external/tangram/test_tangram.py +++ /dev/null @@ -1,72 +0,0 @@ -import mudata -import numpy as np -import pytest - -from scvi.data import synthetic_iid -from scvi.external import Tangram - -modalities = {"density_prior_key": "sp", "sc_layer": "sc", "sp_layer": "sp"} - - -def _get_mdata(sparse_format: str | None = None): - dataset1 = synthetic_iid(batch_size=100, sparse_format=sparse_format) - dataset2 = dataset1[-25:].copy() - dataset1 = dataset1[:-25].copy() - mdata = mudata.MuData({"sc": dataset1, "sp": dataset2}) - ad_sp = mdata.mod["sp"] - rna_count_per_spot = np.asarray(ad_sp.X.sum(axis=1)).squeeze() - ad_sp.obs["rna_count_based_density"] = rna_count_per_spot / np.sum(rna_count_per_spot) - ad_sp.obs["bad_prior"] = np.random.uniform(size=ad_sp.n_obs) - return mdata - - -@pytest.mark.jax -@pytest.mark.parametrize( - ("density_prior_key", "constrained"), - [ - (None, False), - ("rna_count_based_density", False), - ("rna_count_based_density", True), - ], -) -def test_tangram(density_prior_key, constrained): - mdata = _get_mdata() - Tangram.setup_mudata( - mdata, - density_prior_key=density_prior_key, - modalities=modalities, - ) - if constrained: - target_count = 2 - else: - target_count = None - model = Tangram(mdata, constrained=constrained, target_count=target_count) - model.train(max_epochs=1) - mdata.mod["sc"].obsm["mapper"] = model.get_mapper_matrix() - model.project_cell_annotations( - mdata.mod["sc"], - mdata.mod["sp"], - mdata.mod["sc"].obsm["mapper"], - mdata.mod["sc"].obs.labels, - ) - model.project_genes(mdata.mod["sc"], mdata.mod["sp"], mdata.mod["sc"].obsm["mapper"]) - - -@pytest.mark.jax -def 
test_tangram_errors(): - mdata = _get_mdata() - Tangram.setup_mudata( - mdata, - density_prior_key="rna_count_based_density", - modalities=modalities, - ) - with pytest.raises(ValueError): - Tangram(mdata, constrained=True, target_count=None) - - Tangram.setup_mudata( - mdata, - density_prior_key="bad_prior", - modalities=modalities, - ) - with pytest.raises(ValueError): - Tangram(mdata) diff --git a/tests/model/test_jaxscvi.py b/tests/model/test_jaxscvi.py deleted file mode 100644 index 3f852ae8d8..0000000000 --- a/tests/model/test_jaxscvi.py +++ /dev/null @@ -1,220 +0,0 @@ -from unittest import mock - -import numpy as np -import pytest - -from scvi.data import synthetic_iid -from scvi.model import JaxSCVI -from scvi.train import JaxTrainingPlan -from scvi.utils import attrdict - - -@pytest.mark.jax -@pytest.mark.parametrize("n_latent", [5]) -def test_jax_scvi(n_latent: int): - adata = synthetic_iid() - JaxSCVI.setup_anndata( - adata, - batch_key="batch", - ) - model = JaxSCVI(adata, n_latent=n_latent) - model.train(2, train_size=0.5, check_val_every_n_epoch=1) - model.get_latent_representation() - - model = JaxSCVI(adata, n_latent=n_latent, gene_likelihood="poisson") - model.train(1, train_size=0.5) - z1 = model.get_latent_representation(give_mean=True, n_samples=1) - assert z1.ndim == 2 - z2 = model.get_latent_representation(give_mean=False, n_samples=15) - assert z2.ndim == 3 - assert z2.shape[0] == 15 - - -@pytest.mark.jax -@pytest.mark.parametrize("n_latent", [5]) -@pytest.mark.parametrize("dropout_rate", [0.1]) -def test_jax_scvi_training(n_latent: int, dropout_rate: float): - from flax import linen as nn - - adata = synthetic_iid() - JaxSCVI.setup_anndata( - adata, - batch_key="batch", - ) - - model = JaxSCVI(adata, n_latent=n_latent, dropout_rate=dropout_rate) - assert model.module.training - - with mock.patch("scvi.module._jaxvae.nn.Dropout", wraps=nn.Dropout) as mock_dropout_cls: - mock_dropout = mock.Mock() - mock_dropout.side_effect = lambda h, 
**_kwargs: h - mock_dropout_cls.return_value = mock_dropout - model.train(1, train_size=0.5, check_val_every_n_epoch=1) - - assert not model.module.training - mock_dropout_cls.assert_called() - mock_dropout.assert_has_calls( - 12 * [mock.call(mock.ANY, deterministic=False)] - + 8 * [mock.call(mock.ANY, deterministic=True)] - ) - - -@pytest.mark.jax -@pytest.mark.parametrize("n_latent", [5]) -def test_jax_scvi_save_load(save_path: str, n_latent: int): - adata = synthetic_iid() - JaxSCVI.setup_anndata( - adata, - batch_key="batch", - ) - model = JaxSCVI(adata, n_latent=n_latent) - model.train(2, train_size=0.5, check_val_every_n_epoch=1) - z1 = model.get_latent_representation(adata) - model.save(save_path, overwrite=True, save_anndata=True) - model.view_setup_args(save_path) - model = JaxSCVI.load(save_path) - model.get_latent_representation() - - # Load with mismatched genes. - tmp_adata = synthetic_iid( - n_genes=200, - ) - with pytest.raises(ValueError): - JaxSCVI.load(save_path, adata=tmp_adata) - - # Load with different batches. 
- tmp_adata = synthetic_iid() - tmp_adata.obs["batch"] = tmp_adata.obs["batch"].cat.rename_categories(["batch_2", "batch_3"]) - with pytest.raises(ValueError): - JaxSCVI.load(save_path, adata=tmp_adata) - - model = JaxSCVI.load(save_path, adata=adata) - assert "batch" in model.adata_manager.data_registry - assert model.adata_manager.data_registry.batch == attrdict( - {"attr_name": "obs", "attr_key": "_scvi_batch"} - ) - assert model.is_trained is True - - z2 = model.get_latent_representation() - np.testing.assert_array_equal(z1, z2) - - -@pytest.mark.jax -def test_jax_scvi_history(): - """Test that JaxSCVI logs unsuffixed history keys.""" - adata = synthetic_iid() - JaxSCVI.setup_anndata(adata, batch_key="batch") - model = JaxSCVI(adata, n_latent=5) - model.train(2, train_size=0.5, check_val_every_n_epoch=1) - - assert "train_loss" in model.history, ( - f"Expected 'train_loss' in history, got keys: {list(model.history.keys())}" - ) - assert "train_loss_epoch" not in model.history - assert "validation_loss" in model.history - assert "elbo_train" in model.history - - -@pytest.mark.multigpu -@pytest.mark.jax -def test_jax_scvi_multigpu(): - """Test JaxSCVI with multiple GPUs using pmap.""" - import jax - - n_devices = jax.local_device_count() - assert n_devices > 1, f"Need >1 device for multi-GPU test, got {n_devices}" - - adata = synthetic_iid() - JaxSCVI.setup_anndata(adata, batch_key="batch") - - model = JaxSCVI(adata, n_latent=5) - model.train( - 2, - train_size=0.5, - check_val_every_n_epoch=1, - batch_size=128, - accelerator="gpu", - devices="auto", - ) - - assert model.is_trained - assert "train_loss" in model.history - assert "validation_loss" in model.history - - z = model.get_latent_representation() - assert z.shape == (adata.n_obs, 5) - - -@pytest.mark.multigpu -@pytest.mark.jax -def test_jax_scvi_single_gpu_explicit(): - """Test JaxSCVI with devices=1 on a multi-GPU machine uses single-GPU path.""" - import jax - - n_devices = jax.local_device_count() - 
assert n_devices > 1, f"Need >1 device for this test, got {n_devices}" - - adata = synthetic_iid() - JaxSCVI.setup_anndata(adata, batch_key="batch") - - model = JaxSCVI(adata, n_latent=5) - model.train( - 2, - train_size=0.5, - check_val_every_n_epoch=1, - batch_size=128, - accelerator="gpu", - devices=1, - ) - - # Should have used single-device path (no pmap) - assert model.training_plan.n_devices == 1 - assert model.is_trained - z = model.get_latent_representation() - assert z.shape == (adata.n_obs, 5) - - -@pytest.mark.jax -def test_loss_args_jax(): - """Test that self._loss_args is set correctly.""" - adata = synthetic_iid() - JaxSCVI.setup_anndata(adata) - jax_vae = JaxSCVI(adata) - jax_tp = JaxTrainingPlan(jax_vae.module) - - loss_args = [ - "tensors", - "inference_outputs", - "generative_outputs", - "kl_weight", - ] - assert len(jax_tp._loss_args) == len(loss_args) - for arg in loss_args: - assert arg in jax_tp._loss_args - - -@pytest.mark.jax -def test_multiple_covariates_jaxscvi(): - """Test that JaxSCVI can handle multiple categorical and continuous covariates.""" - adata = synthetic_iid() - adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],)) - adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],)) - adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],)) - adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],)) - - JaxSCVI.setup_anndata( - adata, - batch_key="batch", - labels_key="labels", - continuous_covariate_keys=["cont1", "cont2"], - categorical_covariate_keys=["cat1", "cat2"], - ) - m = JaxSCVI(adata) - m.train(1) - z1 = m.get_latent_representation(give_mean=True, n_samples=1) - assert z1.ndim == 2 - # n_samples > 1 triggers the 3-D z path in generative; covariates must be - # broadcast to match the sample dimension before concatenation. - z2 = m.get_latent_representation(give_mean=False, n_samples=5) - assert z2.ndim == 3 - assert z2.shape[0] == 5