diff --git a/docs/api/developer.md b/docs/api/developer.md index 511265ab33..7b5fdd4e8d 100644 --- a/docs/api/developer.md +++ b/docs/api/developer.md @@ -122,8 +122,6 @@ These classes should be used to construct user-facing model classes. model.base.UnsupervisedTrainingMixin model.base.PyroSviTrainMixin model.base.PyroSampleMixin - model.base.PyroJitGuideWarmup - model.base.PyroModelGuideWarmup model.base.DifferentialComputation model.base.EmbeddingMixin ``` diff --git a/pyproject.toml b/pyproject.toml index 362e7c63da..c66a8d4734 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,7 +87,7 @@ loompy = ["loompy>=3.0.6"] scanpy = ["scanpy>=1.6"] optional = [ - "scvi-tools[autotune,aws,criticism,hub,loompy,pymde,regseq,scanpy]" + "scvi-tools[autotune,aws,hub,loompy,pymde,regseq,scanpy]" ] tutorials = [ "cell2location", diff --git a/scvi/model/base/__init__.py b/scvi/model/base/__init__.py index e8573f8d53..466b542a46 100644 --- a/scvi/model/base/__init__.py +++ b/scvi/model/base/__init__.py @@ -4,10 +4,9 @@ from ._embedding_mixin import EmbeddingMixin from ._jaxmixin import JaxTrainingMixin from ._pyromixin import ( - PyroJitGuideWarmup, - PyroModelGuideWarmup, PyroSampleMixin, PyroSviTrainMixin, + setup_pyro_model, ) from ._rnamixin import RNASeqMixin from ._training_mixin import UnsupervisedTrainingMixin @@ -21,8 +20,7 @@ "UnsupervisedTrainingMixin", "PyroSviTrainMixin", "PyroSampleMixin", - "PyroJitGuideWarmup", - "PyroModelGuideWarmup", + "setup_pyro_model", "DifferentialComputation", "JaxTrainingMixin", "BaseMinifiedModeModelClass", diff --git a/scvi/model/base/_pyromixin.py b/scvi/model/base/_pyromixin.py index a4bc1308e0..446b9a2e4e 100755 --- a/scvi/model/base/_pyromixin.py +++ b/scvi/model/base/_pyromixin.py @@ -5,7 +5,6 @@ import numpy as np import torch -from lightning.pytorch.callbacks import Callback from pyro import poutine from scvi import settings @@ -18,59 +17,17 @@ logger = logging.getLogger(__name__) -class PyroJitGuideWarmup(Callback): - """A 
callback to warmup a Pyro guide. +def setup_pyro_model(dataloader, pl_module): +    """Way to warmup Pyro Model and Guide in an automated way. - This helps initialize all the relevant parameters by running - one minibatch through the Pyro model. + Setup occurs before any device movement, so params are initialized on CPU. """ - - def __init__(self, dataloader: AnnDataLoader = None) -> None: - super().__init__() - self.dataloader = dataloader - - def on_train_start(self, trainer, pl_module): - """Way to warmup Pyro Guide in an automated way. - - Also device agnostic. - """ - # warmup guide for JIT - pyro_guide = pl_module.module.guide - if self.dataloader is None: - dl = trainer.datamodule.train_dataloader() - else: - dl = self.dataloader - for tensors in dl: - tens = {k: t.to(pl_module.device) for k, t in tensors.items()} - args, kwargs = pl_module.module._get_fn_args_from_batch(tens) - pyro_guide(*args, **kwargs) - break - - -class PyroModelGuideWarmup(Callback): - """A callback to warmup a Pyro guide and model. - - This helps initialize all the relevant parameters by running - one minibatch through the Pyro model. This warmup occurs on the CPU. - """ - - def __init__(self, dataloader: AnnDataLoader) -> None: - super().__init__() - self.dataloader = dataloader - - def setup(self, trainer, pl_module, stage=None): - """Way to warmup Pyro Model and Guide in an automated way. - - Setup occurs before any device movement, so params are iniitalized on CPU. 
- """ - if stage == "fit": - pyro_guide = pl_module.module.guide - dl = self.dataloader - for tensors in dl: - tens = {k: t.to(pl_module.device) for k, t in tensors.items()} - args, kwargs = pl_module.module._get_fn_args_from_batch(tens) - pyro_guide(*args, **kwargs) - break + for tensors in dataloader: + tens = {k: t.to(pl_module.device) for k, t in tensors.items()} + args, kwargs = pl_module.module._get_fn_args_from_batch(tens) + pl_module.module.guide(*args, **kwargs) + pl_module.module.model(*args, **kwargs) + break class PyroSviTrainMixin: @@ -177,7 +134,14 @@ def train( if "callbacks" not in trainer_kwargs.keys(): trainer_kwargs["callbacks"] = [] - trainer_kwargs["callbacks"].append(PyroJitGuideWarmup()) + + # Initialise pyro model with data + from copy import copy + + dl = copy(data_splitter) + dl.setup() + dl = dl.train_dataloader() + setup_pyro_model(dl, training_plan) runner = self._train_runner_cls( self, @@ -202,8 +166,9 @@ def _get_one_posterior_sample( self, args, kwargs, - return_sites: list | None = None, + return_sites: list = None, return_observed: bool = False, + exclude_vars: list = None, ): """Get one sample from posterior distribution. @@ -225,6 +190,11 @@ def _get_one_posterior_sample( if isinstance(self.module.guide, poutine.messenger.Messenger): # This already includes trace-replay behavior. 
sample = self.module.guide(*args, **kwargs) + # include and exclude requested sites + if return_sites is not None: + sample = {k: v for k, v in sample.items() if k in return_sites} + if exclude_vars is not None: + sample = {k: v for k, v in sample.items() if k not in exclude_vars} else: guide_trace = poutine.trace(self.module.guide).get_trace(*args, **kwargs) model_trace = poutine.trace(poutine.replay(self.module.model, guide_trace)).get_trace( @@ -235,6 +205,9 @@ def _get_one_posterior_sample( for name, site in model_trace.nodes.items() if ( (site["type"] == "sample") # sample statement + and not ( + name in exclude_vars if exclude_vars is not None else False + ) # exclude variables and ( (return_sites is None) or (name in return_sites) ) # selected in return_sites list @@ -261,6 +234,7 @@ def _get_posterior_samples( num_samples: int = 1000, return_sites: list | None = None, return_observed: bool = False, + exclude_vars: list | None = None, show_progress: bool = True, ): """Get many (num_samples=N) samples from posterior distribution. @@ -284,7 +258,11 @@ def _get_posterior_samples( dictionary {variable_name: [array with samples in 0 dimension]} """ samples = self._get_one_posterior_sample( - args, kwargs, return_sites=return_sites, return_observed=return_observed + args, + kwargs, + return_sites=return_sites, + return_observed=return_observed, + exclude_vars=exclude_vars, ) samples = {k: [v] for k, v in samples.items()} @@ -296,7 +274,11 @@ def _get_posterior_samples( ): # generate new sample samples_ = self._get_one_posterior_sample( - args, kwargs, return_sites=return_sites, return_observed=return_observed + args, + kwargs, + return_sites=return_sites, + return_observed=return_observed, + exclude_vars=exclude_vars, ) # add new sample @@ -365,6 +347,47 @@ def _get_obs_plate_sites( return obs_plate + def _get_valid_sites( + self, + args: list, + kwargs: dict, + return_observed: bool = False, + ): + """Automatically guess which model sites should be sampled. 
+ + Parameters + ---------- + args + Arguments to the model. + kwargs + Keyword arguments to the model. + return_observed + Record samples of observed variables. + + Returns + ------- + List with keys corresponding to site names. + """ + # find plate dimension + trace = poutine.trace(self.module.model).get_trace(*args, **kwargs) + valid_sites = [ + name + for name, site in trace.nodes.items() + if ( + (site["type"] == "sample") # sample statement + and ( + ( + (not site.get("is_observed", True)) or return_observed + ) # don't save observed unless requested + or (site.get("infer", False).get("_deterministic", False)) + ) # unless it is deterministic + and not isinstance( + site.get("fn", None), poutine.subsample_messenger._Subsample + ) # don't save plates + ) + ] + return valid_sites + @devices_dsp.dedent def _posterior_samples_minibatch( self, @@ -415,13 +438,17 @@ def _posterior_samples_minibatch( self.to_device(device) if i == 0: - return_observed = getattr(sample_kwargs, "return_observed", False) + # get observation plate sites + return_observed = sample_kwargs.get("return_observed", False) obs_plate_sites = self._get_obs_plate_sites( args, kwargs, return_observed=return_observed ) if len(obs_plate_sites) == 0: # if no local variables - don't sample break + # get valid sites & filter local sites + valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) + obs_plate_sites = {k: v for k, v in obs_plate_sites.items() if k in valid_sites} obs_plate_dim = list(obs_plate_sites.values())[0] sample_kwargs_obs_plate = sample_kwargs.copy() @@ -449,10 +476,10 @@ def _posterior_samples_minibatch( i += 1 # sample global parameters + valid_sites = self._get_valid_sites(args, kwargs, return_observed=return_observed) + valid_sites = [v for v in valid_sites if v not in obs_plate_sites.keys()] + sample_kwargs["return_sites"] = valid_sites global_samples = self._get_posterior_samples(args, kwargs, **sample_kwargs) - global_samples = { - k: v for k, v in 
global_samples.items() if k not in list(obs_plate_sites.keys()) - } for k in global_samples.keys(): samples[k] = global_samples[k] @@ -471,6 +498,7 @@ def sample_posterior( batch_size: int | None = None, return_observed: bool = False, return_samples: bool = False, + exclude_vars: list | None = None, summary_fun: dict[str, Callable] | None = None, ): """Summarise posterior distribution. @@ -531,6 +559,7 @@ def sample_posterior( num_samples=num_samples, return_sites=return_sites, return_observed=return_observed, + exclude_vars=exclude_vars, ) param_names = list(samples.keys()) diff --git a/tests/model/test_pyro.py b/tests/model/test_pyro.py index 56f10f8d92..4c48fe8b13 100644 --- a/tests/model/test_pyro.py +++ b/tests/model/test_pyro.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +from copy import copy import numpy as np import pyro @@ -18,10 +19,9 @@ from scvi.dataloaders import AnnDataLoader from scvi.model.base import ( BaseModelClass, - PyroJitGuideWarmup, - PyroModelGuideWarmup, PyroSampleMixin, PyroSviTrainMixin, + setup_pyro_model, ) from scvi.module.base import PyroBaseModuleClass from scvi.nn import DecoderSCVI, Encoder @@ -194,11 +194,11 @@ def test_pyro_bayesian_regression_low_level( model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) plan = LowLevelPyroTrainingPlan(model) plan.n_obs_training = len(train_dl.indices) + setup_pyro_model(copy(train_dl), plan) trainer = Trainer( accelerator=accelerator, devices=devices, max_epochs=2, - callbacks=[PyroModelGuideWarmup(train_dl)], ) trainer.fit(plan, train_dl) # 100 features @@ -220,6 +220,7 @@ def test_pyro_bayesian_regression(accelerator: str, devices: list | str | int, s model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) plan = PyroTrainingPlan(model) plan.n_obs_training = len(train_dl.indices) + setup_pyro_model(copy(train_dl), plan) trainer = Trainer( accelerator=accelerator, devices=devices, @@ -285,11 +286,11 @@ def 
test_pyro_bayesian_regression_jit( model = BayesianRegressionModule(in_features=adata.shape[1], out_features=1) plan = PyroTrainingPlan(model, loss_fn=pyro.infer.JitTrace_ELBO()) plan.n_obs_training = len(train_dl.indices) + setup_pyro_model(copy(train_dl), plan) trainer = Trainer( accelerator=accelerator, devices=devices, max_epochs=2, - callbacks=[PyroJitGuideWarmup(train_dl)], ) trainer.fit(plan, train_dl) @@ -415,6 +416,8 @@ def test_pyro_bayesian_train_sample_mixin_with_local(): adata.n_obs, 1, ) + # test that observed variables are excluded + assert "obs" not in samples["posterior_samples"].keys() def test_pyro_bayesian_train_sample_mixin_with_local_full_data():