diff --git a/CHANGELOG.md b/CHANGELOG.md index e4f8dd1fa2..e48fdd7261 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ to [Semantic Versioning]. Full commit history is available in the #### Changed +- Update model {class}`scvi.model.DestVI` with fine cell-type classifier {pr}`3380`. + #### Removed ### 1.3.3 (2025-07-23) diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py index 72c6bf8ca4..2b86fc4067 100644 --- a/src/scvi/dataloaders/_data_splitting.py +++ b/src/scvi/dataloaders/_data_splitting.py @@ -413,7 +413,7 @@ def __init__( labels_state_registry.original_key, mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None), ).ravel() - self.unlabeled_category = labels_state_registry.unlabeled_category + self.unlabeled_category = getattr(labels_state_registry, "unlabeled_category", None) self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category).ravel() self._labeled_indices = np.argwhere(labels != self.unlabeled_category).ravel() diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py index 384112d245..622d8a78ea 100644 --- a/src/scvi/dataloaders/_semi_dataloader.py +++ b/src/scvi/dataloaders/_semi_dataloader.py @@ -66,7 +66,7 @@ def __init__( # save a nested list of the indices per labeled category self.labeled_locs = [] for label in np.unique(labels): - if label != labels_state_registry.unlabeled_category: + if label != getattr(labels_state_registry, "unlabeled_category", None): label_loc_idx = np.where(labels[indices] == label)[0] label_loc = self.indices[label_loc_idx] self.labeled_locs.append(label_loc) diff --git a/src/scvi/external/resolvi/_model.py b/src/scvi/external/resolvi/_model.py index 3bb24abd5c..25730a13ac 100644 --- a/src/scvi/external/resolvi/_model.py +++ b/src/scvi/external/resolvi/_model.py @@ -657,7 +657,7 @@ def _set_indices_and_labels(self): """Set indices for labeled and unlabeled cells.""" labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key - self.unlabeled_category_ = labels_state_registry.unlabeled_category + self.unlabeled_category_ = getattr(labels_state_registry, "unlabeled_category", None) labels = get_anndata_attribute( self.adata, diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index 45af751ff2..9cce7da4ab 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -8,16 +8,17 @@ import torch from scvi import REGISTRY_KEYS, settings -from scvi.data import AnnDataManager, fields +from scvi.data import AnnDataManager +from scvi.data._utils import _get_adata_minify_type +from scvi.data.fields import CategoricalObsField, LabelsWithUnlabeledObsField, LayerField from scvi.model.base import ( BaseModelClass, RNASeqMixin, - UnsupervisedTrainingMixin, + SemisupervisedTrainingMixin, VAEMixin, ) from scvi.module import VAEC from scvi.utils import setup_anndata_dsp -from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: from anndata import AnnData @@ -25,7 +26,7 @@ logger = logging.getLogger(__name__) -class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): +class CondSCVI(RNASeqMixin, VAEMixin, SemisupervisedTrainingMixin, BaseModelClass): """Conditional version of single-cell Variational Inference. Used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`. 
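Editorial note on the repeated `getattr(labels_state_registry, "unlabeled_category", None)` change in the data splitter, semi-supervised dataloader, and RESOLVI hunks above (and in the training mixin below): when labels are registered with a plain `CategoricalObsField`, the state registry simply has no `unlabeled_category` attribute, so direct attribute access raises. A minimal sketch of the fallback behaviour, using a stand-in namespace rather than the real scvi-tools registry classes:

from types import SimpleNamespace

import numpy as np

# Stand-in for a labels state registry produced by a plain CategoricalObsField:
# it carries the mapping but has no `unlabeled_category` attribute at all.
registry = SimpleNamespace(original_key="cell_type", categorical_mapping=np.array(["B", "T"]))

# Old pattern: registry.unlabeled_category -> AttributeError for this registry.
# New pattern: fall back to None, so every cell is treated as labeled.
unlabeled_category = getattr(registry, "unlabeled_category", None)

labels = np.array(["B", "T", "B", "T"])
labeled_indices = np.argwhere(labels != unlabeled_category).ravel()
unlabeled_indices = np.argwhere(labels == unlabeled_category).ravel()
print(len(labeled_indices), len(unlabeled_indices))  # 4 0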
@@ -77,6 +78,9 @@ def __init__( ): super().__init__(adata) + self.n_labels = self.summary_stats.n_labels + self.n_vars = self.summary_stats.n_vars + self._set_indices_and_labels() if weight_obs: ct_counts = np.unique( self.get_from_registry(adata, REGISTRY_KEYS.LABELS_KEY), @@ -89,9 +93,10 @@ def __init__( module_kwargs.update({"ct_weight": ct_weight}) self.module = self._module_cls( - n_input=self.summary_stats.n_vars, + n_input=self.n_vars, n_batch=getattr(self.summary_stats, "n_batch", 0), - n_labels=self.summary_stats.n_labels, + n_labels=self.n_labels, + n_fine_labels=getattr(self.summary_stats, "n_fine_labels", None), n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, @@ -126,162 +131,302 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra (n_labels, p, D) array var_vprior (n_labels, p, D) array + weights_vprior + (n_labels, p) array """ from sklearn.cluster import KMeans if self.is_trained_ is False: warnings.warn( - "Trying to query inferred values from an untrained model. Please train " - "the model first.", + "Trying to query inferred values from an untrained model. " + "Please train the model first.", UserWarning, stacklevel=settings.warnings_stacklevel, ) adata = self._validate_anndata(adata) - # Extracting latent representation of adata including variances. - mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) - var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) - mp_vprior = np.zeros((self.summary_stats.n_labels, p)) - - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - key = labels_state_registry.original_key - mapping = labels_state_registry.categorical_mapping - - scdl = self._make_data_loader(adata=adata, batch_size=p) - - mean = [] - var = [] - for tensors in scdl: - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] - batch_index = tensors.get(REGISTRY_KEYS.BATCH_KEY, None) - out = self.module.inference(x, y, batch_index=batch_index) - mean_, var_ = out["qz"].loc, (out["qz"].scale ** 2) - mean += [mean_.cpu()] - var += [var_.cpu()] - - mean_cat, var_cat = torch.cat(mean).numpy(), torch.cat(var).numpy() - - for ct in range(self.summary_stats["n_labels"]): - local_indices = np.where(adata.obs[key] == mapping[ct])[0] - n_local_indices = len(local_indices) - if "overclustering_vamp" not in adata.obs.columns: - if p < n_local_indices and p > 0: - overclustering_vamp = KMeans(n_clusters=p, n_init=30).fit_predict( - mean_cat[local_indices] - ) + if self.module.prior == "mog": + results = { + "mean_vprior": self.module.prior_means, + "var_vprior": torch.exp(self.module.prior_log_std) ** 2, + "weights_vprior": torch.nn.functional.softmax(self.module.prior_logits, dim=-1), + } + else: + # Extracting latent representation of adata including variances. 
+ mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) + var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) + mp_vprior = np.zeros((self.summary_stats.n_labels, p)) + + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + key = labels_state_registry.original_key + mapping = labels_state_registry.categorical_mapping + + mean_cat, var_cat = self.get_latent_representation(adata, return_dist=True) + + for ct in range(self.summary_stats["n_labels"]): + local_indices = np.where(adata.obs[key] == mapping[ct])[0] + n_local_indices = len(local_indices) + if "overclustering_vamp" not in adata.obs.columns: + if p < n_local_indices and p > 0: + overclustering_vamp = KMeans(n_clusters=p, n_init=30).fit_predict( + mean_cat[local_indices] + ) + else: + # Every cell is its own cluster + overclustering_vamp = np.arange(n_local_indices) else: - # Every cell is its own cluster - overclustering_vamp = np.arange(n_local_indices) - else: - overclustering_vamp = adata[local_indices, :].obs["overclustering_vamp"] - - keys, counts = np.unique(overclustering_vamp, return_counts=True) - - n_labels_overclustering = len(keys) - if n_labels_overclustering > p: - error_mess = """ - Given cell type specific clustering contains more clusters than vamp_prior_p. - Increase value of vamp_prior_p to largest number of cell type specific - clusters.""" - - raise ValueError(error_mess) - - var_cluster = np.ones( - [ - n_labels_overclustering, - self.module.n_latent, - ] - ) - mean_cluster = np.zeros_like(var_cluster) + overclustering_vamp = adata[local_indices, :].obs["overclustering_vamp"] - for index, cluster in enumerate(keys): - indices_curr = local_indices[np.where(overclustering_vamp == cluster)[0]] - var_cluster[index, :] = np.mean(var_cat[indices_curr], axis=0) + np.var( - mean_cat[indices_curr], axis=0 - ) - mean_cluster[index, :] = np.mean(mean_cat[indices_curr], axis=0) + keys, counts = np.unique(overclustering_vamp, return_counts=True) - slicing = slice(n_labels_overclustering) - mean_vprior[ct, slicing, :] = mean_cluster - var_vprior[ct, slicing, :] = var_cluster - mp_vprior[ct, slicing] = counts / sum(counts) + n_labels_overclustering = len(keys) + if n_labels_overclustering > p: + error_mess = """ + Given cell type specific clustering contains more clusters than + vamp_prior_p. Increase value of vamp_prior_p to largest number of cell type + specific clusters.""" - return mean_vprior, var_vprior, mp_vprior + raise ValueError(error_mess) - @devices_dsp.dedent - def train( - self, - max_epochs: int = 300, - lr: float = 0.001, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - train_size: float = 1, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - datasplitter_kwargs: dict | None = None, - plan_kwargs: dict | None = None, - **kwargs, - ): - """Trains the model using MAP inference. + var_cluster = np.ones( + [ + n_labels_overclustering, + self.module.n_latent, + ] + ) + mean_cluster = np.zeros_like(var_cluster) - Parameters - ---------- - max_epochs - Number of epochs to train for - lr - Learning rate for optimization. - %(param_accelerator)s - %(param_devices)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. 
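Editorial note on `get_vamp_prior` as rewritten above: it now returns a dict with keys `mean_vprior`, `var_vprior`, and `weights_vprior`, filled either directly from the module's mixture-of-Gaussians parameters (when `prior == "mog"`) or by over-clustering per-cell posterior moments with KMeans. A minimal sketch of both paths on synthetic tensors; shapes are illustrative and only torch, numpy, and scikit-learn are assumed:

import numpy as np
import torch
from sklearn.cluster import KMeans

n_labels, p, n_latent, n_cells = 3, 10, 5, 200

# --- "mog" path: stored prior parameters are re-expressed as moments and weights.
prior_means = torch.randn(n_labels, p, n_latent)
prior_log_std = torch.randn(n_labels, p, n_latent)
prior_logits = torch.randn(n_labels, p)
results_mog = {
    "mean_vprior": prior_means,
    "var_vprior": torch.exp(prior_log_std) ** 2,
    "weights_vprior": torch.nn.functional.softmax(prior_logits, dim=-1),
}

# --- fallback path: over-cluster posterior moments of one cell type and aggregate.
rng = np.random.default_rng(0)
mean_cat = rng.normal(size=(n_cells, n_latent))             # qz.loc per cell
var_cat = rng.uniform(0.1, 0.5, size=(n_cells, n_latent))   # qz.scale**2 per cell
clusters = KMeans(n_clusters=p, n_init=30).fit_predict(mean_cat)
keys, counts = np.unique(clusters, return_counts=True)
mean_cluster = np.stack([mean_cat[clusters == c].mean(0) for c in keys])
# Law of total variance: mean of within-cell variances plus variance of the cell means.
var_cluster = np.stack(
    [var_cat[clusters == c].mean(0) + mean_cat[clusters == c].var(0) for c in keys]
)
weights = counts / counts.sum()
print(results_mog["weights_vprior"].shape, mean_cluster.shape, var_cluster.shape, weights.sum())

Both paths feed `DestVI.from_rna_model`, which forwards `mean_vprior`, `var_vprior`, and `mp_vprior` to `MRDeconv` as the empirical prior over the per-cell-type latent variables.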
If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **kwargs - Other keyword args for :class:`~scvi.train.Trainer`. - """ - update_dict = { - "lr": lr, - } - if plan_kwargs is not None: - plan_kwargs.update(update_dict) - else: - plan_kwargs = update_dict - super().train( - max_epochs=max_epochs, - accelerator=accelerator, - devices=devices, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - batch_size=batch_size, - datasplitter_kwargs=datasplitter_kwargs, - plan_kwargs=plan_kwargs, - **kwargs, - ) + for index, cluster in enumerate(keys): + indices_curr = local_indices[np.where(overclustering_vamp == cluster)[0]] + var_cluster[index, :] = np.mean(var_cat[indices_curr], axis=0) + np.var( + mean_cat[indices_curr], axis=0 + ) + mean_cluster[index, :] = np.mean(mean_cat[indices_curr], axis=0) + + slicing = slice(n_labels_overclustering) + mean_vprior[ct, slicing, :] = mean_cluster + var_vprior[ct, slicing, :] = var_cluster + mp_vprior[ct, slicing] = counts / sum(counts) + results = { + "mean_vprior": mean_vprior, + "var_vprior": var_vprior, + "weights_vprior": mp_vprior, + } + + return results + + # @torch.inference_mode() + # def predict( + # self, + # adata: AnnData | None = None, + # indices: list[int] | None = None, + # soft: bool = False, + # batch_size: int | None = None, + # use_posterior_mean: bool = True, + # ) -> np.ndarray | pd.DataFrame: + # """Return cell label predictions. + # + # Parameters + # ---------- + # adata + # AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + # indices + # Return probabilities for each class label. + # soft + # If True, returns per class probabilities + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # use_posterior_mean + # If ``True``, uses the mean of the posterior distribution to predict celltype + # labels. Otherwise, uses a sample from the posterior distribution - this + # means that the predictions will be stochastic. 
+ # """ + # + # adata = self._validate_anndata(adata) + # + # if indices is None: + # indices = np.arange(adata.n_obs) + # + # scdl = self._make_data_loader( + # adata=adata, + # indices=indices, + # batch_size=batch_size, + # ) + # y_pred = [] + # for _, tensors in enumerate(scdl): + # inference_input = self.module._get_inference_input(tensors) + # qz = self.module.inference(**inference_input)["qz"] + # if use_posterior_mean: + # z = qz.loc + # else: + # z = qz.sample() + # pred = self.module.classify( + # z, + # label_index=inference_input["y"], + # ) + # if self.module.classifier.logits: + # pred = torch.nn.functional.softmax(pred, dim=-1) + # if not soft: + # pred = pred.argmax(dim=1) + # y_pred.append(pred.detach().cpu()) + # + # y_pred = torch.cat(y_pred).numpy() + # if not soft: + # predictions = [] + # for p in y_pred: + # predictions.append(self._code_to_fine_label[p]) + # + # return np.array(predictions) + # else: + # n_labels = len(pred[0]) + # pred = pd.DataFrame( + # y_pred, + # columns=self._fine_label_mapping[:n_labels], + # index=adata.obs_names[indices], + # ) + # return pred + + # @torch.inference_mode() + # def confusion_coarse_celltypes( + # self, + # adata: AnnData | None = None, + # indices: Sequence[int] | None = None, + # batch_size: int | None = None, + # ) -> np.ndarray | pd.DataFrame: + # """Return likelihood ratios to assess if switch coarse cell-type resolution + # + # is too granular + # + # Parameters + # ---------- + # adata + # AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + # indices + # Return probabilities for each class label. + # soft + # If True, returns per class probabilities + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # use_posterior_mean + # If ``True``, uses the mean of the posterior distribution to predict celltype + # labels. Otherwise, uses a sample from the posterior distribution - this + # means that the predictions will be stochastic. + # + # """ + # adata = self._validate_anndata(adata) + # + # if indices is None: + # indices = np.arange(adata.n_obs) + # + # scdl = self._make_data_loader( + # adata=adata, + # indices=indices, + # batch_size=batch_size, + # ) + # # Iterate once over the data and computes the reconstruction error + # keys = list(self._label_mapping) + ["original"] + # log_lkl = {key: [] for key in keys} + # for tensors in scdl: + # loss_kwargs = {"kl_weight": 1} + # _, _, losses = self.module(tensors, loss_kwargs=loss_kwargs) + # log_lkl["original"] += [losses.reconstruction_loss] + # for i in range(self.module.n_labels): + # tensors_ = tensors + # tensors_["y"] = torch.full_like(tensors["y"], i) + # _, _, losses = self.module(tensors_, loss_kwargs=loss_kwargs) + # log_lkl[keys[i]] += [losses.reconstruction_loss] + # for key in keys: + # log_lkl[key] = torch.stack(log_lkl[key]).detach().numpy() + # + # return log_lkl + + # @devices_dsp.dedent + # def train( + # self, + # max_epochs: int = 300, + # lr: float = 0.001, + # accelerator: str = "auto", + # devices: int | list[int] | str = "auto", + # train_size: float = 1, + # validation_size: float | None = None, + # shuffle_set_split: bool = True, + # batch_size: int = 128, + # datasplitter_kwargs: dict | None = None, + # plan_kwargs: dict | None = None, + # **kwargs, + # ): + # """Trains the model using MAP inference. + # + # Parameters + # ---------- + # max_epochs + # Number of epochs to train for + # lr + # Learning rate for optimization. 
+ # %(param_accelerator)s + # %(param_devices)s + # train_size + # Size of training set in the range [0.0, 1.0]. + # validation_size + # Size of the test set. If `None`, defaults to 1 - `train_size`. If + # `train_size + validation_size < 1`, the remaining cells belong to a test set. + # shuffle_set_split + # Whether to shuffle indices before splitting. If `False`, the val, train, and test set + # are split in the sequential order of the data according to `validation_size` and + # `train_size` percentages. + # batch_size + # Minibatch size to use during training. + # datasplitter_kwargs + # Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. + # plan_kwargs + # Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to + # `train()` will overwrite values present in `plan_kwargs`, when appropriate. + # **kwargs + # Other keyword args for :class:`~scvi.train.Trainer`. + # """ + # update_dict = { + # "lr": lr, + # } + # if plan_kwargs is not None: + # plan_kwargs.update(update_dict) + # else: + # plan_kwargs = update_dict + # super().train( + # max_epochs=max_epochs, + # accelerator=accelerator, + # devices=devices, + # train_size=train_size, + # validation_size=validation_size, + # shuffle_set_split=shuffle_set_split, + # batch_size=batch_size, + # datasplitter_kwargs=datasplitter_kwargs, + # plan_kwargs=plan_kwargs, + # **kwargs, + # ) + + # def _set_indices_and_labels(self, adata: AnnData): + # """Set indices for labeled and unlabeled cells.""" + # labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + # self.original_label_key = labels_state_registry.original_key + # self._label_mapping = labels_state_registry.categorical_mapping + # self._code_to_label = dict(enumerate(self._label_mapping)) + # if self.n_fine_labels is not None: + # fine_labels_state_registry = self.adata_manager.get_state_registry("fine_labels") + # self.original_fine_label_key = fine_labels_state_registry.original_key + # self._fine_label_mapping = fine_labels_state_registry.categorical_mapping + # self._code_to_fine_label = dict(enumerate(self._fine_label_mapping)) @classmethod @setup_anndata_dsp.dedent def setup_anndata( cls, adata: AnnData, + batch_key: str | None = None, labels_key: str | None = None, + fine_labels_key: str | None = None, layer: str | None = None, - batch_key: str | None = None, + unlabeled_category: str = "unlabeled", **kwargs, ): """%(summary)s. @@ -289,16 +434,28 @@ def setup_anndata( Parameters ---------- %(param_adata)s + %(param_batch_key)s %(param_labels_key)s + fine_labels_key + Key in `adata.obs` where fine-grained labels are stored. 
%(param_layer)s - %(param_batch_key)s + %(param_unlabeled_category)s """ setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ - fields.LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), - fields.CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - fields.CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), ] + if fine_labels_key is not None: + anndata_fields.append( + LabelsWithUnlabeledObsField("fine_labels", fine_labels_key, unlabeled_category), + ) + + # register new fields if the adata is minified + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None: + anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index fb5f5a9add..6b2af34138 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from collections import OrderedDict from typing import TYPE_CHECKING import numpy as np @@ -9,14 +10,18 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import LayerField, NumericalObsField +from scvi.data._constants import _SETUP_ARGS_KEY +from scvi.data.fields import ( + CategoricalObsField, + LayerField, + NumericalObsField, +) from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin +from scvi.model.base._archesmixin import _get_loaded_data from scvi.module import MRDeconv from scvi.utils import setup_anndata_dsp -from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: - from collections import OrderedDict from collections.abc import Sequence from anndata import AnnData @@ -81,30 +86,35 @@ def __init__( cell_type_mapping: np.ndarray, decoder_state_dict: OrderedDict, px_decoder_state_dict: OrderedDict, - px_r: np.ndarray, + px_r: torch.tensor, + per_ct_bias: torch.tensor, n_hidden: int, n_latent: int, n_layers: int, dropout_decoder: float, - l1_reg: float, **module_kwargs, ): super().__init__(st_adata) + self.module = self._module_cls( n_spots=st_adata.n_obs, n_labels=cell_type_mapping.shape[0], + n_batch=self.summary_stats.n_batch, decoder_state_dict=decoder_state_dict, px_decoder_state_dict=px_decoder_state_dict, px_r=px_r, + per_ct_bias=per_ct_bias, n_genes=st_adata.n_vars, n_latent=n_latent, n_layers=n_layers, n_hidden=n_hidden, dropout_decoder=dropout_decoder, - l1_reg=l1_reg, **module_kwargs, ) self.cell_type_mapping = cell_type_mapping + self.cell_type_mapping_extended = list(self.cell_type_mapping) + [ + f"additional_{i}" for i in range(self.module.add_celltypes) + ] self._model_summary_string = "DestVI Model" self.init_params_ = self._get_init_params(locals()) @@ -114,7 +124,7 @@ def from_rna_model( st_adata: AnnData, sc_model: CondSCVI, vamp_prior_p: int = 15, - l1_reg: float = 0.0, + anndata_setup_kwargs: dict | None = None, **module_kwargs, ): """Alternate constructor for exploiting a pre-trained model on a RNA-seq dataset. 
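Editorial note before the `from_rna_model` hunk below: since the constructor now registers the spatial AnnData itself (via `setup_anndata` with the source registry) and accepts either a fitted `CondSCVI` object or a path to a saved model, a typical call chain on this branch looks roughly as follows. This is a hedged usage sketch against the signatures shown in this diff, run on synthetic data; keys and epoch counts are illustrative, and the exact CondSCVI defaults (for example, the latent prior) may require extra constructor kwargs:

import numpy as np
from anndata import AnnData

from scvi.model import CondSCVI, DestVI

rng = np.random.default_rng(0)
sc_adata = AnnData(rng.poisson(1.0, size=(300, 100)).astype(np.float32))
sc_adata.obs["cell_type"] = rng.choice(["B", "T", "Myeloid"], size=300)
sc_adata.obs["cell_subtype"] = rng.choice(["naive", "memory", "unlabeled"], size=300)
st_adata = AnnData(rng.poisson(2.0, size=(50, 100)).astype(np.float32))

# Single-cell reference: coarse labels plus optional fine labels for the new classifier.
CondSCVI.setup_anndata(
    sc_adata,
    labels_key="cell_type",
    fine_labels_key="cell_subtype",      # new in this diff; cells may carry the unlabeled category
    unlabeled_category="unlabeled",
)
sc_model = CondSCVI(sc_adata)
sc_model.train(max_epochs=1)             # one epoch only, to keep the sketch cheap

# Spatial model: st_adata is registered inside from_rna_model using the reference registry.
st_model = DestVI.from_rna_model(
    st_adata,
    sc_model,                            # or a path to a saved CondSCVI model
    vamp_prior_p=15,
    anndata_setup_kwargs={},             # forwarded to DestVI.setup_anndata
)
st_model.train(max_epochs=1)
proportions = st_model.get_proportions(keep_additional=True, normalize=True)
print(proportions.columns.tolist())      # coarse cell types plus "additional_0", ...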
@@ -122,31 +132,62 @@ def from_rna_model( Parameters ---------- st_adata - registered anndata object + anndata object will be registered sc_model - trained CondSCVI model + trained CondSCVI model or path to a trained model vamp_prior_p number of mixture parameter for VampPrior calculations l1_reg Scalar parameter indicating the strength of L1 regularization on cell type proportions. A value of 50 leads to sparser results. - **model_kwargs - Keyword args for :class:`~scvi.model.DestVI` + anndata_setup_kwargs + Keyword args for :meth:`~scvi.model.DestVI.setup_anndata` + **module_kwargs + Keyword args for :class:`~scvi.model.MRDeconv` """ - decoder_state_dict = sc_model.module.decoder.state_dict() - px_decoder_state_dict = sc_model.module.px_decoder.state_dict() - px_r = sc_model.module.px_r.detach().cpu().numpy() - mapping = sc_model.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ).categorical_mapping - dropout_decoder = sc_model.module.dropout_rate + attr_dict, var_names, load_state_dict, _ = _get_loaded_data(sc_model) + registry = attr_dict.pop("registry_") + + decoder_state_dict = OrderedDict( + (i[8:], load_state_dict[i]) + for i in load_state_dict.keys() + if i.split(".")[0] == "decoder" + ) + px_decoder_state_dict = OrderedDict( + (i[11:], load_state_dict[i]) + for i in load_state_dict.keys() + if i.split(".")[0] == "px_decoder" + ) + px_r = load_state_dict["px_r"] + per_ct_bias = load_state_dict["per_ct_bias"] + mapping = registry["field_registries"]["labels"]["state_registry"]["categorical_mapping"] + + dropout_decoder = attr_dict["init_params_"]["non_kwargs"]["dropout_rate"] if vamp_prior_p is None: mean_vprior = None var_vprior = None + elif attr_dict["init_params_"]["kwargs"]["module_kwargs"]["prior"] == "mog": + mean_vprior = load_state_dict["prior_means"].clone().detach() + var_vprior = torch.exp(load_state_dict["prior_log_std"]) ** 2 + mp_vprior = torch.nn.Softmax(dim=-1)(load_state_dict["prior_logits"]) else: + assert sc_model is not str, ( + "VampPrior requires loading CondSCVI model and providing it" + ) mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior( sc_model.adata, p=vamp_prior_p - ) + ).values() + + if anndata_setup_kwargs is None: + anndata_setup_kwargs = {} + + cls.setup_anndata( + st_adata, + source_registry=registry, + extend_categories=True, + **anndata_setup_kwargs, + **registry[_SETUP_ARGS_KEY], + ) return cls( st_adata, @@ -154,6 +195,7 @@ def from_rna_model( decoder_state_dict, px_decoder_state_dict, px_r, + per_ct_bias, sc_model.module.n_hidden, sc_model.module.n_latent, sc_model.module.n_layers, @@ -161,25 +203,28 @@ def from_rna_model( var_vprior=var_vprior, mp_vprior=mp_vprior, dropout_decoder=dropout_decoder, - l1_reg=l1_reg, **module_kwargs, ) + @torch.inference_mode() def get_proportions( self, - keep_noise: bool = False, + keep_additional: bool = False, + normalize: bool = True, indices: Sequence[int] | None = None, batch_size: int | None = None, ) -> pd.DataFrame: """Returns the estimated cell type proportion for the spatial data. - Shape is n_cells x n_labels OR n_cells x (n_labels + 1) if keep_noise. + Shape is n_cells x n_labels OR n_cells x (n_labels + add_celltypes) if keep_additional. Parameters ---------- - keep_noise - whether to account for the noise term as a standalone cell type in the proportion - estimate. + keep_additional + whether to account for the additional cell-types as standalone cell types + in the proportion estimate. + normalize + whether to normalize the proportions to sum to 1. 
indices Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. @@ -191,28 +236,31 @@ def get_proportions( column_names = self.cell_type_mapping index_names = self.adata.obs.index - if keep_noise: - column_names = np.append(column_names, "noise_term") + if keep_additional: + column_names = list(self.cell_type_mapping_extended) + else: + column_names = list(self.cell_type_mapping) if self.module.amortization in ["both", "proportion"]: stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) prop_ = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - prop_local = self.module.get_proportions( - x=generative_inputs["x"], keep_noise=keep_noise - ) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + prop_local = self.module.generative(**generative_inputs)["v"][0, ...] prop_ += [prop_local.cpu()] - data = torch.cat(prop_).numpy() + data = torch.cat(prop_).detach().numpy() if indices: index_names = index_names[indices] else: - if indices is not None: - logger.info( - "No amortization for proportions, ignoring indices and returning results for " - "the full data" - ) - data = self.module.get_proportions(keep_noise=keep_noise) + data = ( + torch.nn.functional.softplus(self.module.V).transpose(0, 1).detach().cpu().numpy() + ) + if not keep_additional: + data = data[:, : -self.module.add_celltypes] + if normalize: + data = data / data.sum(axis=1, keepdims=True) return pd.DataFrame( data=data, @@ -220,6 +268,82 @@ def get_proportions( index=index_names, ) + # @torch.inference_mode() + # def get_fine_celltypes( + # self, + # sc_model: CondSCVI, + # indices=None, + # batch_size: int | None = None, + # ) -> np.ndarray | dict[str, pd.DataFrame]: + # """Returns the estimated cell-type specific latent space for the spatial data. + # + # Parameters + # ---------- + # sc_model + # trained CondSCVI model + # indices + # Indices of cells in adata to use. Only used if amortization. + # If `None`, all cells are used. + # batch_size + # Minibatch size for data loading into model. Only used if amortization. + # Defaults to `scvi.settings.batch_size`. + # """ + # self._check_if_trained() + # index_names = self.adata.obs.index + # stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) + # if sc_model.n_fine_labels is None: + # raise RuntimeError( + # "Single cell model does not contain fine labels. " + # "Please train the single-cell model with fine labels." + # ) + # predicted_fine_celltype_ = [] + # for tensors in stdl: + # inference_inputs = self.module._get_inference_input(tensors) + # outputs = self.module.inference(**inference_inputs) + # generative_inputs = self.module._get_generative_input(tensors, outputs) + # generative_outputs = self.module.generative(**generative_inputs) + # + # gamma_local = generative_outputs["gamma"][0, ...].transpose(-2, -4) # c, n, p, m + # proportions_modes_local = generative_outputs["proportion_modes"][0, ...] 
# pmc + # n_modes, batch_size, n_celltypes = proportions_modes_local.shape + # gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape( + # -1, self.module.n_latent + # ) # m*p*c, n + # proportions_modes_local_ = proportions_modes_local.permute( + # (1, 0, 2) + # ).flatten() # m*p*c + # v_local = ( + # generative_outputs["v"][0, ..., : -self.module.add_celltypes] + # .flatten() + # .repeat_interleave(n_modes) + # ) # m*p*c + # label = ( + # torch.arange(self.module.n_labels, device=gamma_local.device) + # .repeat(batch_size) + # .repeat_interleave(n_modes) + # .unsqueeze(-1) + # ) # m*p*c, 1 + # predicted_fine_celltype_local = ( + # v_local.unsqueeze(-1) + # * proportions_modes_local_.unsqueeze(-1) + # * torch.nn.functional.softmax( + # sc_model.module.classify(gamma_local_, label), dim=-1 + # ) + # ) + # predicted_fine_celltype_sum = predicted_fine_celltype_local.reshape( + # batch_size, n_celltypes * n_modes, sc_model.n_fine_labels + # ).sum(1) + # predicted_fine_celltype_.append(predicted_fine_celltype_sum.detach().cpu()) + # predicted_fine_celltype = torch.cat(predicted_fine_celltype_, dim=0).numpy() + # + # pred = pd.DataFrame( + # predicted_fine_celltype, + # columns=sc_model._fine_label_mapping, + # index=index_names, + # ) + # return pred + + @torch.inference_mode() def get_gamma( self, indices: Sequence[int] | None = None, @@ -241,27 +365,31 @@ def get_gamma( """ self._check_if_trained() - column_names = np.arange(self.module.n_latent) + column_names = [str(i) for i in np.arange(self.module.n_latent)] index_names = self.adata.obs.index if self.module.amortization in ["both", "latent"]: stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) gamma_ = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - gamma_local = self.module.get_gamma(x=generative_inputs["x"]) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + generative_outputs = self.module.generative(**generative_inputs) + gamma_local = generative_outputs["gamma"][0, ...] + if self.module.prior_mode == "mog": + proportions_model_local = generative_outputs["proportion_modes"][0, ...] + gamma_local = torch.einsum( + "pncm,pmc->ncm", gamma_local, proportions_model_local + ) + else: + gamma_local = gamma_local[0, ...].squeeze(0) gamma_ += [gamma_local.cpu()] data = torch.cat(gamma_, dim=-1).numpy() if indices is not None: index_names = index_names[indices] else: - if indices is not None: - logger.info( - "No amortization for latent values, ignoring adata and returning results for " - "the full data" - ) - data = self.module.get_gamma() - + data = self.module.gamma.detach().cpu().numpy() data = np.transpose(data, (2, 0, 1)) if return_numpy: return data @@ -271,6 +399,72 @@ def get_gamma( res[ct] = pd.DataFrame(data=data[:, :, i], columns=column_names, index=index_names) return res + @torch.inference_mode() + def get_latent_representation( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + give_mean: bool = True, + mc_samples: int = 5000, + batch_size: int | None = None, + return_dist: bool = False, + ) -> np.ndarray(np.ndarray, np.ndarray): + """Return the latent representation for each cell. + + This is typically denoted as :math:`z_n`. + + Parameters + ---------- + adata + AnnData object with equivalent structure to initial AnnData. 
If `None`, defaults to the + AnnData object used to initialize the model. + indices + Indices of cells in adata to use. If `None`, all cells are used. + give_mean + Give mean of distribution or sample from it. + mc_samples + For distributions with no closed-form mean (e.g., `logistic normal`), + how many Monte Carlo samples to take for computing mean. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_dist + Return (mean, variance) of distributions instead of just the mean. + If `True`, ignores `give_mean` and `mc_samples`. In the case of the latter, + `mc_samples` is used to compute the mean of a transformed distribution. + If `return_dist` is true the untransformed mean and variance are returned. + + Returns + ------- + Low-dimensional representation for each cell or a tuple containing its mean and variance. + """ + assert self.module.n_latent_amortization is not None, ( + "Model has no latent representation for amortized values." + ) + self._check_if_trained(warn=False) + + adata = self._validate_anndata(adata) + scdl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size) + latent = [] + latent_qzm = [] + latent_qzv = [] + for tensors in scdl: + inference_inputs = self.module._get_inference_input(tensors) + inference_outputs = self.module.inference(**inference_inputs, n_samples=mc_samples) + z = inference_outputs["z"][0, ...] + qz = inference_outputs["qz"] + if give_mean: + latent += [qz.loc[0, ...].cpu()] + else: + latent += [z.cpu()] + latent_qzm += [qz.loc[0, ...].cpu()] + latent_qzv += [qz.scale[0, ...].square().cpu()] + return ( + (torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy()) + if return_dist + else torch.cat(latent).numpy() + ) + + @torch.inference_mode() def get_scale_for_ct( self, label: str, @@ -293,20 +487,24 @@ def get_scale_for_ct( Pandas dataframe of gene_expression """ self._check_if_trained() + self._validate_anndata() + + cell_type_mapping_extended = list(self.cell_type_mapping) + [ + f"additional_{i}" for i in range(self.module.add_celltypes) + ] - if label not in self.cell_type_mapping: + if label not in cell_type_mapping_extended: raise ValueError("Unknown cell type") - y = np.where(label == self.cell_type_mapping)[0][0] + y = cell_type_mapping_extended.index(label) stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) scale = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - x, ind_x = ( - generative_inputs["x"], - generative_inputs["ind_x"], - ) - px_scale = self.module.get_ct_specific_expression(x, ind_x, y) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + px_scale = self.module.generative(**generative_inputs)["px_mu"][0, :, y, :] + scale += [px_scale.cpu()] data = torch.cat(scale).numpy() @@ -316,73 +514,131 @@ def get_scale_for_ct( index_names = index_names[indices] return pd.DataFrame(data=data, columns=column_names, index=index_names) - @devices_dsp.dedent - def train( - self, - max_epochs: int = 2000, - lr: float = 0.003, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - train_size: float = 1.0, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - n_epochs_kl_warmup: int = 200, - datasplitter_kwargs: dict | None = None, - plan_kwargs: dict | None = None, - **kwargs, - ): - 
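Editorial note for the label lookup just below: the extended mapping appends the learned "additional" cell types after the reference cell types, so a label resolves to an index along the cell-type axis of `px_mu`. A tiny illustration with made-up values:

cell_type_mapping = ["B", "T", "Myeloid"]            # from the CondSCVI label registry
add_celltypes = 1
cell_type_mapping_extended = cell_type_mapping + [f"additional_{i}" for i in range(add_celltypes)]
y = cell_type_mapping_extended.index("additional_0")  # -> 3
# px_mu has shape (n_samples, minibatch, n_labels + add_celltypes, n_genes);
# px_mu[0, :, y, :] is the expression scale attributed to that additional cell type.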
"""Trains the model using MAP inference. - - Parameters - ---------- - max_epochs - Number of epochs to train for - lr - Learning rate for optimization. - %(param_accelerator)s - %(param_devices)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - n_epochs_kl_warmup - number of epochs needed to reach unit kl weight in the elbo - datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **kwargs - Other keyword args for :class:`~scvi.train.Trainer`. - """ - update_dict = { - "lr": lr, - "n_epochs_kl_warmup": n_epochs_kl_warmup, - } - if plan_kwargs is not None: - plan_kwargs.update(update_dict) - else: - plan_kwargs = update_dict - super().train( - max_epochs=max_epochs, - accelerator=accelerator, - devices=devices, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - batch_size=batch_size, - datasplitter_kwargs=datasplitter_kwargs, - plan_kwargs=plan_kwargs, - **kwargs, - ) + # @torch.inference_mode() + # def get_expression_for_ct( + # self, + # label: str, + # indices: Sequence[int] | None = None, + # batch_size: int | None = None, + # return_sparse_array: bool = False, + # ) -> pd.DataFrame: + # r"""Return the scaled parameter of the NB for every spot in queried cell types. + # + # Parameters + # ---------- + # label + # cell type of interest + # indices + # Indices of cells in self.adata to use. If `None`, all cells are used. + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # return_sparse_array + # If `True`, returns a sparse array instead of a dataframe. 
+ # + # Returns + # ------- + # Pandas dataframe of gene_expression + # """ + # self._check_if_trained() + # + # if label not in self.cell_type_mapping_extended: + # raise ValueError("Unknown cell type") + # y = self.cell_type_mapping_extended.index(label) + # + # stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) + # expression_ct = [] + # for tensors in stdl: + # inference_inputs = self.module._get_inference_input(tensors) + # outputs = self.module.inference(**inference_inputs) + # generative_inputs = self.module._get_generative_input(tensors, outputs) + # generative_outputs = self.module.generative(**generative_inputs) + # px_scale, proportions = ( + # generative_outputs["px_mu"][0, ...], + # generative_outputs["v"][0, ...], + # ) + # px_scale_expected = torch.einsum("mkl,mk->mkl", px_scale, proportions) + # px_scale_proportions = px_scale_expected[:, y, :] / px_scale_expected.sum(dim=1) + # x_ct = tensors["X"].to(px_scale_proportions.device) * px_scale_proportions + # expression_ct += [x_ct.cpu()] + # + # data = torch.cat(expression_ct).numpy() + # if return_sparse_array: + # data = csr_matrix(data.T) + # return data + # else: + # column_names = self.adata.var.index + # index_names = self.adata.obs.index + # if indices is not None: + # index_names = index_names[indices] + # return pd.DataFrame(data=data, columns=column_names, index=index_names) + + # @devices_dsp.dedent + # def train( + # self, + # max_epochs: int = 2000, + # lr: float = 0.003, + # accelerator: str = "auto", + # devices: int | list[int] | str = "auto", + # train_size: float = 1.0, + # validation_size: float | None = None, + # shuffle_set_split: bool = True, + # batch_size: int = 128, + # n_epochs_kl_warmup: int = 200, + # datasplitter_kwargs: dict | None = None, + # plan_kwargs: dict | None = None, + # **kwargs, + # ): + # """Trains the model using MAP inference. + # + # Parameters + # ---------- + # max_epochs + # Number of epochs to train for + # lr + # Learning rate for optimization. + # %(param_accelerator)s + # %(param_devices)s + # train_size + # Size of training set in the range [0.0, 1.0]. + # validation_size + # Size of the test set. If `None`, defaults to 1 - `train_size`. If + # `train_size + validation_size < 1`, the remaining cells belong to a test set. + # shuffle_set_split + # Whether to shuffle indices before splitting. If `False`, the val, train, and test set + # are split in the sequential order of the data according to `validation_size` and + # `train_size` percentages. + # batch_size + # Minibatch size to use during training. + # n_epochs_kl_warmup + # number of epochs needed to reach unit kl weight in the elbo + # datasplitter_kwargs + # Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. + # plan_kwargs + # Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to + # `train()` will overwrite values present in `plan_kwargs`, when appropriate. + # **kwargs + # Other keyword args for :class:`~scvi.train.Trainer`. 
+ # """ + # update_dict = { + # "lr": lr, + # "n_epochs_kl_warmup": n_epochs_kl_warmup, + # } + # if plan_kwargs is not None: + # plan_kwargs.update(update_dict) + # else: + # plan_kwargs = update_dict + # super().train( + # max_epochs=max_epochs, + # accelerator=accelerator, + # devices=devices, + # train_size=train_size, + # validation_size=validation_size, + # shuffle_set_split=shuffle_set_split, + # batch_size=batch_size, + # datasplitter_kwargs=datasplitter_kwargs, + # plan_kwargs=plan_kwargs, + # **kwargs, + # ) @classmethod @setup_anndata_dsp.dedent @@ -390,6 +646,8 @@ def setup_anndata( cls, adata: AnnData, layer: str | None = None, + smoothed_layer: str | None = None, + batch_key: str | None = None, **kwargs, ): """%(summary)s. @@ -398,6 +656,9 @@ def setup_anndata( ---------- %(param_adata)s %(param_layer)s + smoothed_layer + param that... + %(param_batch_key)s """ setup_method_args = cls._get_setup_method_args(**locals()) # add index for each cell (provided to pyro plate for correct minibatching) @@ -405,7 +666,10 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), ] + if smoothed_layer is not None: + anndata_fields.append(LayerField("x_smoothed", smoothed_layer, is_count_data=True)) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/base/_archesmixin.py b/src/scvi/model/base/_archesmixin.py index 2f612aca9f..9f2b596525 100644 --- a/src/scvi/model/base/_archesmixin.py +++ b/src/scvi/model/base/_archesmixin.py @@ -83,9 +83,9 @@ def load_query_data( freeze_dropout Whether to freeze dropout during training freeze_expression - Freeze neurons corersponding to expression in first layer + Freeze neurons corresponding to expression in first layer freeze_decoder_first_layer - Freeze neurons corersponding to first layer in decoder + Freeze neurons corresponding to first layer in decoder freeze_batchnorm_encoder Whether to freeze batchnorm weight and bias during training for encoder freeze_batchnorm_decoder diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index fd7e61af58..707636b4ca 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -172,7 +172,7 @@ def _set_indices_and_labels(self, datamodule=None): """Set indices for labeled and unlabeled cells.""" labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key - self.unlabeled_category_ = labels_state_registry.unlabeled_category + self.unlabeled_category_ = getattr(labels_state_registry, "unlabeled_category", None) if datamodule is None: self.labels_ = get_anndata_attribute( diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index e0d6c40e3f..66662011fe 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -1,20 +1,26 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING +from collections import OrderedDict +from typing import Literal +import numpy as np import torch -from torch.distributions import Normal +from torch.distributions import ( + Categorical, + Exponential, + Independent, + MixtureSameFamily, + Normal, +) +from torch.distributions import kl_divergence as kl from scvi import 
REGISTRY_KEYS from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data -from scvi.nn import FCLayers - -if TYPE_CHECKING: - from collections import OrderedDict - from typing import Literal - - import numpy as np +from scvi.module.base import ( + BaseMinifiedModeModuleClass, + EmbeddingModuleMixin, + LossOutput, + auto_move_data, +) +from scvi.nn import Encoder, FCLayers def identity(x): @@ -22,7 +28,7 @@ def identity(x): return x -class MRDeconv(BaseModuleClass): +class MRDeconv(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Model for multi-resolution deconvolution of spatial transriptomics. Parameters @@ -39,36 +45,41 @@ class MRDeconv(BaseModuleClass): Number of dimensions used in the latent variables n_genes Number of genes used in the decoder - dropout_decoder - Dropout rate for the decoder neural network (same dropout as in CondSCVI decoder) - dropout_amortization - Dropout rate for the amortization neural network + px_r + parameters for the px_r tensor in the CondSCVI model + per_ct_bias + estimates of per cell-type expression bias in the CondSCVI model decoder_state_dict state_dict from the decoder of the CondSCVI model px_decoder_state_dict state_dict from the px_decoder of the CondSCVI model - px_r - parameters for the px_r tensor in the CondSCVI model + dropout_decoder + Dropout rate for the decoder neural network (same dropout as in CondSCVI decoder) + dropout_amortization + Dropout rate for the amortization neural network + n_samples_augmentation + Number of samples used in the augmentation + n_states_per_label + Number of states per cell-type in each spot + eps_v + Epsilon value for each cell-type proportion used during training. + n_states_per_augmented_label + Number of states per cell-type in each spot during augmentation mean_vprior Mean parameter for each component in the empirical prior over the latent space var_vprior Diagonal variance parameter for each component in the empirical prior over the latent space mp_vprior Mixture proportion in cell type sub-clustering of each component in the empirical prior - amortization - which of the latent variables to amortize inference over (gamma, proportions, both or none) - l1_reg - Scalar parameter indicating the strength of L1 regularization on cell type proportions. - A value of 50 leads to sparser results. - beta_reg - Scalar parameter indicating the strength of the variance penalty for - the multiplicative offset in gene expression values (beta parameter). Default is 5 - (setting to 0.5 might help if single cell reference and spatial assay are different - e.g. UMI vs non-UMI.) - eta_reg - Scalar parameter indicating the strength of the prior for - the noise term (eta parameter). Default is 1e-4. - (changing value is discouraged.) + amortization + prior_mode + Mode of the prior distribution for the latent space. + Either "mog" for mixture of gaussians or "normal" for normal distribution. + add_celltypes + Number of additional cell types compared to single cell data to add to the model + n_latent_amortization + Number of dimensions used in the latent variables + for the amortization encoder neural network extra_encoder_kwargs Extra keyword arguments passed into :class:`~scvi.nn.FCLayers`. 
extra_decoder_kwargs @@ -79,43 +90,64 @@ def __init__( self, n_spots: int, n_labels: int, + n_batch: int, n_hidden: int, n_layers: int, n_latent: int, n_genes: int, decoder_state_dict: OrderedDict, px_decoder_state_dict: OrderedDict, - px_r: np.ndarray, + px_r: torch.tensor, + per_ct_bias: torch.tensor, dropout_decoder: float, - dropout_amortization: float = 0.05, + dropout_amortization: float = 0.03, + augmentation: bool = True, + n_samples_augmentation: int = 2, + n_states_per_label: int = 3, + eps_v: float = 2e-3, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, mp_vprior: np.ndarray = None, amortization: Literal["none", "latent", "proportion", "both"] = "both", - l1_reg: float = 0.0, - beta_reg: float = 5.0, - eta_reg: float = 1e-4, + prior_mode: Literal["mog", "normal"] = "mog", + add_celltypes: int = 1, + n_latent_amortization: int | None = None, extra_encoder_kwargs: dict | None = None, extra_decoder_kwargs: dict | None = None, ): super().__init__() + if prior_mode == "mog": + assert amortization in ["both", "latent"], ( + "Amortization must be active for latent variables to use mixture " + "of gaussians generation" + ) self.n_spots = n_spots + self.n_batch = n_batch self.n_labels = n_labels self.n_hidden = n_hidden self.n_latent = n_latent + self.augmentation = augmentation + self.n_samples_augmentation = n_samples_augmentation self.dropout_decoder = dropout_decoder + self.n_states_per_label = n_states_per_label self.dropout_amortization = dropout_amortization self.n_genes = n_genes self.amortization = amortization - self.l1_reg = l1_reg - self.beta_reg = beta_reg - self.eta_reg = eta_reg + self.prior_mode = prior_mode + self.add_celltypes = add_celltypes + self.eps_v = eps_v + self.n_latent_amortization = n_latent_amortization # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} + cat_list = [n_labels] + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch) + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim + n_input_decoder = n_latent + batch_dim + self.decoder = FCLayers( - n_in=n_latent, + n_in=n_input_decoder, n_out=n_hidden, - n_cat_list=[n_labels], + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_decoder, @@ -123,42 +155,74 @@ def __init__( use_batch_norm=False, **_extra_decoder_kwargs, ) - self.px_decoder = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_genes), torch.nn.Softplus() - ) - # don't compute gradient for those parameters self.decoder.load_state_dict(decoder_state_dict) for param in self.decoder.parameters(): param.requires_grad = False + self.px_decoder = torch.nn.Linear(n_hidden, n_genes) self.px_decoder.load_state_dict(px_decoder_state_dict) for param in self.px_decoder.parameters(): param.requires_grad = False - self.register_buffer("px_o", torch.tensor(px_r, dtype=torch.float32)) + self.px_r = torch.nn.Parameter(px_r) + self.register_buffer("per_ct_bias", per_ct_bias) # cell_type specific factor loadings - self.V = torch.nn.Parameter(torch.randn(self.n_labels + 1, self.n_spots)) + self.V = torch.nn.Parameter(torch.randn(self.n_labels + self.add_celltypes, self.n_spots)) # within cell_type factor loadings self.gamma = torch.nn.Parameter(torch.randn(n_latent, self.n_labels, self.n_spots)) if mean_vprior is not None: - self.p = mean_vprior.shape[1] - self.register_buffer("mean_vprior", torch.tensor(mean_vprior, dtype=torch.float32)) - self.register_buffer("var_vprior", torch.tensor(var_vprior, dtype=torch.float32)) - self.register_buffer("mp_vprior", 
torch.tensor(mp_vprior, dtype=torch.float32)) + self.register_buffer("mean_vprior", mean_vprior) + self.register_buffer("var_vprior", var_vprior) + self.register_buffer("mp_vprior", mp_vprior) + cats = Categorical(probs=self.mp_vprior) + normal_dists = Independent( + Normal(self.mean_vprior, torch.sqrt(self.var_vprior) + 1e-4), + reinterpreted_batch_ndims=1, + ) + self.qz_prior = MixtureSameFamily(cats, normal_dists) else: self.mean_vprior = None self.var_vprior = None # noise from data - self.eta = torch.nn.Parameter(torch.randn(self.n_genes)) + self.eta = torch.nn.Parameter(torch.zeros(self.add_celltypes, self.n_genes)) # additive gene bias - self.beta = torch.nn.Parameter(0.01 * torch.randn(self.n_genes)) + self.beta = torch.nn.Parameter(torch.zeros(self.n_genes)) # create additional neural nets for amortization # within cell_type factor loadings _extra_encoder_kwargs = extra_encoder_kwargs or {} + if self.prior_mode == "mog": + return_dist = ( + self.n_states_per_label * n_labels * n_latent + self.n_states_per_label * n_labels + ) + else: + return_dist = n_labels * n_latent + if self.n_latent_amortization is not None: + # Uses a combined latent space for proportions and gammas. + self.z_encoder = Encoder( + self.n_genes, + n_latent_amortization, + n_cat_list=[n_batch], + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_amortization, + inject_covariates=True, + use_batch_norm=False, + use_layer_norm=True, + var_activation=torch.nn.functional.softplus, + return_dist=True, + **_extra_encoder_kwargs, + ) + else: + + def identity(x, batch_index=None): + return x, Normal(x, scale=1e-6 * torch.ones_like(x)) + + self.z_encoder = identity + n_latent_amortization = self.n_genes self.gamma_encoder = torch.nn.Sequential( FCLayers( - n_in=self.n_genes, + n_in=n_latent_amortization, n_out=n_hidden, n_cat_list=None, n_layers=2, @@ -166,127 +230,355 @@ def __init__( dropout_rate=dropout_amortization, use_layer_norm=True, use_batch_norm=False, - **_extra_encoder_kwargs, ), - torch.nn.Linear(n_hidden, n_latent * n_labels), + torch.nn.Linear(n_hidden, return_dist), ) # cell type loadings self.V_encoder = torch.nn.Sequential( FCLayers( - n_in=self.n_genes, + n_in=n_latent_amortization, n_out=n_hidden, - n_layers=2, + n_cat_list=None, + n_layers=n_layers, n_hidden=n_hidden, - dropout_rate=dropout_amortization, + dropout_rate=0, use_layer_norm=True, use_batch_norm=False, - **_extra_encoder_kwargs, ), - torch.nn.Linear(n_hidden, n_labels + 1), + torch.nn.Linear(n_hidden, n_labels + self.add_celltypes), ) def _get_inference_input(self, tensors): - # we perform MAP here, so we just need to subsample the variables - return {} + x = tensors[REGISTRY_KEYS.X_KEY] + x_smoothed = tensors.get("x_smoothed", None) + m = x.shape[0] + if x_smoothed is not None: + n_samples = self.n_samples_augmentation + 2 + n_samples_observed = 2 + else: + n_samples = self.n_samples_augmentation + 1 + n_samples_observed = 1 + px_r = torch.exp(self.px_r) + if self.augmentation and self.training: + with torch.no_grad(): + prior_sampled = self.qz_prior.sample([n_samples, m]).reshape( + n_samples, -1, self.n_latent + ) + enum_label = ( + torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) + ) # m * n_labels, 1 + batch_rep = self.compute_embedding( + REGISTRY_KEYS.BATCH_KEY, tensors[REGISTRY_KEYS.BATCH_KEY] + ) + batch_rep_input = batch_rep.repeat_interleave(self.n_labels, dim=0).repeat( + n_samples, 1, 1 + ) + decoder_input = torch.cat([prior_sampled, batch_rep_input], dim=-1) + px_scale_augment_ = torch.nn.Softmax(dim=-1)( 
+ self.px_decoder(self.decoder(decoder_input, enum_label.to(x.device))) + + self.per_ct_bias[enum_label.ravel()].unsqueeze(-3) + + self.beta.view(1, 1, -1) + ) + px_scale_augment = px_scale_augment_.reshape( + (n_samples, x.shape[0], self.n_labels, -1) + ) # (samples, mi, n_labels, n_genes) + library = x.sum(-1).view(1, m, 1, 1).repeat(n_samples, 1, 1, 1) + library[1, ...] = library[1, ...] + 50 + px_scale_augment = px_scale_augment.reshape( + n_samples, m, self.n_labels, -1 + ) # (samples, m, n_labels, n_genes) + px_rate = library * px_scale_augment # (samples, m, n_labels, n_genes) + ratios_ct_augmentation = ( + torch.distributions.Dirichlet(torch.zeros(self.n_labels) + 0.03) + .sample([n_samples, m]) + .to(x.device) + ) + ratios_ct_augmentation = ratios_ct_augmentation.reshape( + n_samples, m, self.n_labels + ) + # sum over celltypes + augmentation_rate = torch.einsum( + "imc, imcg -> img", ratios_ct_augmentation, px_rate + ) # (samples, m, n_genes) + ratio_augmentation_ = ( + torch.distributions.Beta(0.4, 0.5) + .sample([self.n_samples_augmentation - 1, m]) + .unsqueeze(-1) + .to(x.device) + ) + ratio_augmentation = torch.cat( + [ + torch.zeros((n_samples_observed, m, 1), device=x.device), + torch.ones((1, m, 1), device=x.device), + ratio_augmentation_, + ], + dim=0, + ) + augmented_counts = NegativeBinomial( + mu=augmentation_rate, theta=px_r + ).sample() # (samples, m, n_genes) + x_augmented = (1 - ratio_augmentation) * x.unsqueeze( + 0 + ) + ratio_augmentation * augmented_counts + if x_smoothed is not None: + x_augmented[1, ...] = x_smoothed + else: + x_augmented = x.unsqueeze(0) + if x_smoothed is not None: + x_augmented = torch.cat([x_augmented, x_smoothed.unsqueeze(0)], dim=0) + prior_sampled = None + ratios_ct_augmentation = None + ratio_augmentation = None + + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + + input_dict = { + "batch_index": batch_index, + "x_augmented": x_augmented, + "prior_sampled": prior_sampled, + "ratios_ct_augmentation": ratios_ct_augmentation, + "ratio_augmentation": ratio_augmentation, + } + return input_dict def _get_generative_input(self, tensors, inference_outputs): - x = tensors[REGISTRY_KEYS.X_KEY] + z = inference_outputs["z"] + library = inference_outputs["library"] ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel() + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - batch_index = None # tensors[REGISTRY_KEYS.BATCH_KEY] - - input_dict = {"x": x, "ind_x": ind_x, "batch_index": batch_index} + input_dict = {"z": z, "ind_x": ind_x, "library": library, "batch_index": batch_index} return input_dict @auto_move_data - def inference(self): - """Run the inference model.""" - return {} + def inference( + self, + x_augmented, + batch_index, + n_samples=1, + prior_sampled=None, + ratios_ct_augmentation=None, + ratio_augmentation=None, + ): + """Runs the inference (encoder) model.""" + x_ = x_augmented + library = x_augmented.sum(-1).unsqueeze(-1) + x_ = torch.log(1 + x_) + if self.n_latent_amortization is not None: + qz, z = self.z_encoder(x_, batch_index) + else: + z = x_ + qz = Normal(x_, scale=1e-6 * torch.ones_like(x_)) + + outputs = { + "z": z, + "qz": qz, + "library": library, + "x_augmented": x_augmented, + "prior_sampled": prior_sampled, + "ratio_augmentation": ratio_augmentation, + "ratios_ct_augmentation": ratios_ct_augmentation, + } + return outputs @auto_move_data - def generative(self, x, ind_x, batch_index=None, transform_batch: torch.Tensor | None = None): + def generative(self, z, ind_x, library, batch_index): """Build the deconvolution 
model for every cell in the minibatch.""" - m = x.shape[0] - library = torch.sum(x, dim=1, keepdim=True) + m = len(ind_x) # setup all non-linearities - beta = torch.exp(self.beta) # n_genes - eps = torch.nn.functional.softplus(self.eta) # n_genes - x_ = torch.log(1 + x) - # subsample parameters - - # if transform_batch is not None: - # batch_index = torch.ones_like(batch_index) * transform_batch + eps = torch.nn.functional.softmax(self.eta, dim=-1) # add_celltypes, n_genes + px_r = torch.exp(self.px_r) + n_samples = z.size(0) if self.amortization in ["both", "latent"]: - gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) + if self.prior_mode == "mog": + gamma_ = self.gamma_encoder(z) + proportion_modes_logits = ( + torch.transpose( + gamma_[:, :, -self.n_states_per_label * self.n_labels :], -1, -2 + ) + .reshape((n_samples, self.n_states_per_label, self.n_labels, m)) + .transpose(-1, -2) + ) # n_samples, n_states_per_label, m, n_labels + proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=-3) + gamma_ind = torch.transpose( + gamma_[:, :, : self.n_states_per_label * self.n_labels * self.n_latent], -1, -2 + ).reshape( + (n_samples, self.n_states_per_label, self.n_latent, self.n_labels, m) + ) # n_samples, n_states_per_label, n_latent, n_labels, m + else: + gamma_ind = torch.transpose(self.gamma_encoder(z), 0, 1).reshape( + (n_samples, 1, self.n_latent, self.n_labels, m) + ) + proportion_modes_logits = proportion_modes = torch.ones( + (n_samples, 1, m, self.n_labels), device=z.device + ) else: - gamma_ind = self.gamma[:, :, ind_x] # n_latent, n_labels, minibatch_size + gamma_ind = ( + self.gamma[:, :, ind_x].unsqueeze(0).unsqueeze(0).repeat(n_samples, 1, 1, 1, 1) + ) # n_samples, n_latent, n_labels, m + proportion_modes_logits = proportion_modes = torch.ones( + (n_samples, 1, m, self.n_labels), device=z.device + ) if self.amortization in ["both", "proportion"]: - v_ind = self.V_encoder(x_) + v_ind = self.V_encoder(z) else: - v_ind = self.V[:, ind_x].T # minibatch_size, labels + 1 - v_ind = torch.nn.functional.softplus(v_ind) - - # reshape and get gene expression value for all minibatch - gamma_ind = torch.transpose(gamma_ind, 2, 0) # minibatch_size, n_labels, n_latent - gamma_reshape = gamma_ind.reshape( - (-1, self.n_latent) - ) # minibatch_size * n_labels, n_latent - enum_label = ( - torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) - ) # minibatch_size * n_labels, 1 - h = self.decoder(gamma_reshape, enum_label.to(x.device)) - px_rate = self.px_decoder(h).reshape( - (m, self.n_labels, -1) - ) # (minibatch, n_labels, n_genes) - - # add the dummy cell type - eps = eps.repeat((m, 1)).view(m, 1, -1) # (M, 1, n_genes) <- this is the dummy cell type - - # account for gene specific bias and add noise - r_hat = torch.cat( - [beta.unsqueeze(0).unsqueeze(1) * px_rate, eps], dim=1 - ) # M, n_labels + 1, n_genes - # now combine them for convolution - px_scale = torch.sum(v_ind.unsqueeze(2) * r_hat, dim=1) # batch_size, n_genes + v_ind = ( + self.V[:, ind_x].T.unsqueeze(0).repeat(n_samples, 1, 1) + ) # n_samples, m, labels + 1 + v_ind = torch.nn.functional.softmax(v_ind, dim=-1) + + px_est = torch.zeros((n_samples, m, self.n_labels, self.n_genes), device=z.device) + enum_label = torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) # m * n_labels, 1 + + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + batch_rep_input = ( + batch_rep.repeat_interleave(self.n_labels, 
dim=0).unsqueeze(0).repeat(n_samples, 1, 1) + ) + + for mode in range(gamma_ind.shape[-4]): + # reshape and get gene expression value for all minibatch + gamma_ind_ = torch.transpose( + gamma_ind[:, mode, ...], -1, -3 + ) # n_samples, m, n_labels, n_latent + gamma_reshape_ = gamma_ind_.reshape( + (n_samples, -1, self.n_latent) + ) # n_samples, m * n_labels, n_latent + decoder_input_ = torch.cat([gamma_reshape_, batch_rep_input], dim=-1) + h = self.decoder(decoder_input_, enum_label.to(z.device)) + px_est += proportion_modes[:, mode, ...].unsqueeze(-1) * torch.nn.Softmax(dim=-1)( + self.px_decoder(h).reshape((n_samples, m, self.n_labels, -1)) + + self.per_ct_bias[enum_label.ravel()].reshape(1, m, self.n_labels, -1) + + self.beta.view(1, 1, 1, -1) + ) # (n_samples, m, n_labels, n_genes) + + # add the additional cell types + eps = eps.unsqueeze(0).repeat( + n_samples, m, 1, 1 + ) # (n_samples, m, add_celltypes, n_genes) <- this is the dummy cell type + r_hat = torch.cat([px_est, eps], dim=-2) # n_samples, m, n_labels + add_celltypes, n_genes + + # now combine them for convolution. Add epsilon during training. + eps_v = self.eps_v if self.training else 0.0 + px_scale = torch.sum( + (v_ind.unsqueeze(-1) + eps_v) * r_hat, dim=-2 + ) # n_samples, m, n_genes px_rate = library * px_scale return { - "px_o": self.px_o, + "px_r": px_r, "px_rate": px_rate, "px_scale": px_scale, + "px_mu": r_hat, "gamma": gamma_ind, "v": v_ind, - "batch_index": batch_index, + "proportion_modes": proportion_modes, + "proportion_modes_logits": proportion_modes_logits, } + def _compute_cross_entropy(self, prob_true, prob_pred): + log_prob_pred = torch.log(prob_pred / prob_pred.sum(axis=-1, keepdim=True)) + prob_true = prob_true + 1e-20 + prob_true = prob_true / prob_true.sum(axis=-1, keepdim=True) + kl_div = torch.nn.functional.kl_div( + log_prob_pred, prob_true, reduction="batchmean", log_target=False + ) + + return kl_div + def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0, - n_obs: int = 1.0, + ct_sparsity_weight: float = 2.0, + weighting_augmentation: float = 100.0, + weighting_smoothing: float = 100.0, + eta_reg: float = 1.0, + beta_reg: float = 1.0, + weighting_kl_latent: float = 1e-3, + reconst_weight: float = 3.0, ): """Compute the loss.""" - x = tensors[REGISTRY_KEYS.X_KEY] + x_augmented = inference_outputs["x_augmented"] px_rate = generative_outputs["px_rate"] - px_o = generative_outputs["px_o"] + px_r = generative_outputs["px_r"] gamma = generative_outputs["gamma"] v = generative_outputs["v"] + ratio_augmentation = inference_outputs["ratio_augmentation"] + ratios_ct_augmentation = inference_outputs["ratios_ct_augmentation"] - reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1) - - # eta prior likelihood - mean = torch.zeros_like(self.eta) - scale = torch.ones_like(self.eta) - glo_neg_log_likelihood_prior = -self.eta_reg * Normal(mean, scale).log_prob(self.eta).sum() - glo_neg_log_likelihood_prior += self.beta_reg * torch.var(self.beta) + if "x_smoothed" in tensors: + sample_fully_augmented = 2 + n_samples = self.n_samples_augmentation + 2 + else: + sample_fully_augmented = 1 + n_samples = self.n_samples_augmentation + 1 + m = x_augmented.shape[1] - v_sparsity_loss = self.l1_reg * torch.abs(v).mean(1) + if self.augmentation: + prior_sampled = inference_outputs["prior_sampled"].reshape( + n_samples, 1, x_augmented.shape[1], self.n_labels, self.n_latent + ) + mean_vprior = torch.cat( + [ + self.mean_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 
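Editor's note: `_compute_cross_entropy` above is effectively a KL divergence between two proportion vectors, computed with `F.kl_div` on normalised inputs. A small self-contained illustration with toy proportions only:

import torch
import torch.nn.functional as F

prob_pred = torch.tensor([[0.70, 0.20, 0.10], [0.30, 0.30, 0.40]])   # predicted proportions v
prob_true = torch.tensor([[1.00, 0.00, 0.00], [0.00, 0.50, 0.50]])   # expected proportions
log_prob_pred = torch.log(prob_pred / prob_pred.sum(-1, keepdim=True))
prob_true = prob_true + 1e-20
prob_true = prob_true / prob_true.sum(-1, keepdim=True)
penalty = F.kl_div(log_prob_pred, prob_true, reduction="batchmean", log_target=False)
print(penalty)  # grows as the predicted mixture drifts from the expected one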
1, 1), + prior_sampled.permute(0, 2, 3, 1, 4), + ], + dim=-2, + ) # n_samples, m, n_labels, p + 1, n_latent + var_vprior = torch.cat( + [ + self.var_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 1, 1), + torch.min(self.var_vprior, dim=-2) + .values.view(1, 1, self.n_labels, 1, self.n_latent) + .repeat(n_samples, m, 1, 1, 1), + ], + dim=-2, + ) # n_samples, m, n_labels, p + 1, n_latent + + mp_vprior = torch.cat( + [ + (1 - ratio_augmentation.unsqueeze(-1)) + * self.mp_vprior.view(1, 1, self.n_labels, -1).repeat(n_samples, m, 1, 1), + ratio_augmentation.unsqueeze(-1) * ratios_ct_augmentation.unsqueeze(-1), + ], + dim=-1, + ) # n_samples, m, n_labels, p + 1 + else: + mean_vprior = self.mean_vprior.unsqueeze(0).unsqueeze(0) + var_vprior = self.var_vprior.unsqueeze(0).unsqueeze(0) + mp_vprior = self.mp_vprior.unsqueeze(0).unsqueeze(0) + + proportion_modes = generative_outputs["proportion_modes"] + # rounding errors for softmax lead to px_rate 0 which induces NaNs in the log_prob + reconst_loss = ( + -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x_augmented).sum(-1).mean(0) + ) + # beta prior likelihood + mean = torch.zeros_like(self.beta) + scale = torch.ones_like(self.beta) + # beta loss + glo_neg_log_likelihood_prior = -beta_reg * Normal(mean, scale).log_prob(self.beta).sum() + loss_augmentation = torch.tensor(0.0, device=x_augmented.device) + if self.augmentation: + expected_proportions = torch.cat( + [ + ratios_ct_augmentation, + torch.zeros([n_samples, m, self.add_celltypes]).to(v.device), + ], + dim=-1, + ) + loss_augmentation += weighting_augmentation * self._compute_cross_entropy( + expected_proportions[sample_fully_augmented, ...].squeeze(0), + v[sample_fully_augmented, ...].squeeze(0), + ) + if "x_smoothed" in tensors: + loss_augmentation += weighting_smoothing * self._compute_cross_entropy( + v[1, ...].squeeze(0), v[0, ...].squeeze(0) + ) # gamma prior likelihood if self.mean_vprior is None: @@ -294,30 +586,67 @@ def loss( mean = torch.zeros_like(gamma) scale = torch.ones_like(gamma) neg_log_likelihood_prior = -Normal(mean, scale).log_prob(gamma).sum(2).sum(1) + elif self.prior_mode == "mog": + # gamma is of shape n_samples, minibatch_size, 1, n_latent, n_labels + gamma = gamma.permute( + 1, 0, 4, 3, 2 + ) # p, n_samples, minibatch_size, n_labels, n_latent + cats = Categorical(probs=mp_vprior) + normal_dists = Independent( + Normal(mean_vprior, var_vprior), reinterpreted_batch_ndims=1 + ) + pre_lse = MixtureSameFamily(cats, normal_dists).log_prob( + gamma + ) # p, n_samples, minibatch_size, n_labels + pre_lse = pre_lse.permute(1, 0, 2, 3) + log_likelihood_prior = torch.mul(pre_lse, proportion_modes + 1e-6).sum( + -3 + ) # n_samples, minibatch, n_labels + neg_log_likelihood_prior = -torch.mul( + log_likelihood_prior, v[:, :, : -self.add_celltypes] + 1e-3 + ).sum(-1) # n_samples, minibatch, n_labels + # neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) else: - # vampprior - # gamma is of shape n_latent, n_labels, minibatch_size - gamma = gamma.unsqueeze(1) # minibatch_size, 1, n_labels, n_latent - mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze( - 0 - ) # 1, p, n_labels, n_latent - var_vprior = torch.transpose(self.var_vprior, 0, 1).unsqueeze( - 0 - ) # 1, p, n_labels, n_latent - mp_vprior = torch.transpose(self.mp_vprior, 0, 1) # p, n_labels + gamma = gamma.permute( + 0, 4, 1, 3, 2 + ) # n_samples, minibatch_size, 1, n_labels, n_latent + mean_vprior = torch.transpose( + mean_vprior, -3, -2 + ) # n_samples, m, p, n_labels, n_latent + var_vprior 
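Editor's note: the `mog` branch above scores the per-cell-type latent variables under a mixture-of-Gaussians prior built with `MixtureSameFamily`. A minimal sketch with toy shapes (5 labels, 3 components, 4 latent dimensions); the component scales here are generic standard deviations:

import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

n_labels, p, n_latent = 5, 3, 4
mean_vprior = torch.randn(n_labels, p, n_latent)              # component means per cell type
std_vprior = torch.rand(n_labels, p, n_latent) + 0.1          # component scales per cell type
mp_vprior = torch.softmax(torch.randn(n_labels, p), dim=-1)   # mixture weights per cell type

prior = MixtureSameFamily(
    Categorical(probs=mp_vprior),
    Independent(Normal(mean_vprior, std_vprior), reinterpreted_batch_ndims=1),
)
gamma = torch.randn(n_labels, n_latent)   # one latent vector per cell type
print(prior.log_prob(gamma).shape)        # (n_labels,) log-likelihood per label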
= torch.transpose(var_vprior, -3, -2) # n_samples, m, p, n_labels, n_latent + mp_vprior = torch.transpose(mp_vprior, -2, -1) # n_samples, m, p, n_labels pre_lse = ( - Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(mp_vprior) # minibatch, p, n_labels - # Pseudocount for numerical stability - log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch - # mean_vprior is of shape n_labels, p, n_latent - - # High v_sparsity_loss is detrimental early in training, scaling by kl_weight to increase - # over training epochs. - loss = n_obs * ( - torch.mean(reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss)) + Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4) + .log_prob(gamma) + .sum(-1) + .squeeze(-3) + ) + torch.log(1e-3 + mp_vprior) # n_samples, minibatch, p, n_labels + log_likelihood_prior = torch.logsumexp(pre_lse, -2) # n_samples, minibatch, n_labels + neg_log_likelihood_prior = -log_likelihood_prior.sum(-1) # n_samples, minibatch + if self.n_latent_amortization is not None: + neg_log_likelihood_prior += weighting_kl_latent * kl( + inference_outputs["qz"], + Normal( + torch.zeros([self.n_latent_amortization], device=x_augmented.device), + torch.ones([self.n_latent_amortization], device=x_augmented.device), + ), + ).sum(dim=-1) + + v_sparsity_loss = ( + ct_sparsity_weight * torch.distributions.Categorical(probs=v[0, ...]).entropy() + ) + v_sparsity_loss -= ( + eta_reg + * Exponential(torch.ones_like(v[0, :, -self.add_celltypes :])) + .log_prob(v[0, :, -self.add_celltypes :]) + .sum() + ) + + loss = torch.mean( + reconst_weight * reconst_loss + + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) + glo_neg_log_likelihood_prior + + loss_augmentation ) return LossOutput( @@ -325,6 +654,10 @@ def loss( reconstruction_loss=reconst_loss, kl_local=neg_log_likelihood_prior, kl_global=glo_neg_log_likelihood_prior, + extra_metrics={ + "v_sparsity": v_sparsity_loss.mean(), + "augmentation": loss_augmentation.mean(), + }, ) @torch.inference_mode() @@ -336,77 +669,3 @@ def sample( ): """Sample from the posterior.""" raise NotImplementedError("No sampling method for DestVI") - - @torch.inference_mode() - @auto_move_data - def get_proportions(self, x=None, keep_noise=False) -> np.ndarray: - """Returns the loadings.""" - if self.amortization in ["both", "proportion"]: - # get estimated unadjusted proportions - x_ = torch.log(1 + x) - res = torch.nn.functional.softplus(self.V_encoder(x_)) - else: - res = torch.nn.functional.softplus(self.V).cpu().numpy().T # n_spots, n_labels + 1 - # remove dummy cell type proportion values - if not keep_noise: - res = res[:, :-1] - # normalize to obtain adjusted proportions - res = res / res.sum(axis=1).reshape(-1, 1) - return res - - @torch.inference_mode() - @auto_move_data - def get_gamma(self, x: torch.Tensor = None) -> torch.Tensor: - """Returns the loadings. 
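Editor's note: the sparsity penalty above combines the entropy of the predicted proportions with an Exponential prior on the extra `add_celltypes` slots; the weights below match the defaults in the `loss` signature (`ct_sparsity_weight=2.0`, `eta_reg=1.0`). A toy illustration:

import torch
from torch.distributions import Categorical, Exponential

v = torch.tensor([[0.90, 0.05, 0.03, 0.02],   # confident spot -> low entropy
                  [0.25, 0.25, 0.25, 0.25]])  # diffuse spot   -> high entropy
add_celltypes = 1
entropy_term = 2.0 * Categorical(probs=v).entropy()           # per-spot entropy penalty
dummy = v[:, -add_celltypes:]
dummy_term = -1.0 * Exponential(torch.ones_like(dummy)).log_prob(dummy).sum()
print(entropy_term, dummy_term)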
- - Returns - ------- - type - tensor - """ - # get estimated unadjusted proportions - if self.amortization in ["latent", "both"]: - x_ = torch.log(1 + x) - gamma = self.gamma_encoder(x_) - return torch.transpose(gamma, 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) # n_latent, n_labels, minibatch - else: - return self.gamma.cpu().numpy() # (n_latent, n_labels, n_spots) - - @torch.inference_mode() - @auto_move_data - def get_ct_specific_expression( - self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None - ): - """Returns cell type specific gene expression at the queried spots. - - Parameters - ---------- - x - tensor of data - ind_x - tensor of indices - y - integer for cell types - """ - # cell-type specific gene expression, shape (minibatch, celltype, gene). - beta = torch.exp(self.beta) # n_genes - y_torch = (y * torch.ones_like(ind_x)).ravel() - # obtain the relevant gammas - if self.amortization in ["both", "latent"]: - x_ = torch.log(1 + x) - gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) - else: - gamma_ind = self.gamma[:, :, ind_x] # n_latent, n_labels, minibatch_size - - # calculate cell type specific expression - gamma_select = gamma_ind[ - :, y_torch, torch.arange(ind_x.shape[0]) - ].T # minibatch_size, n_latent - h = self.decoder(gamma_select, y_torch.unsqueeze(1)) - px_scale = self.px_decoder(h) # (minibatch, n_genes) - px_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_scale - return px_ct # shape (minibatch, genes) diff --git a/src/scvi/module/_vaec.py b/src/scvi/module/_vaec.py index e20b58d3e6..b28d2836ea 100644 --- a/src/scvi/module/_vaec.py +++ b/src/scvi/module/_vaec.py @@ -1,19 +1,26 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING - +import numpy as np import torch +from torch.distributions import Categorical, Distribution, Independent, MixtureSameFamily, Normal +from torch.distributions import kl_divergence as kl +from torch.nn import functional as F from scvi import REGISTRY_KEYS -from scvi.module._constants import MODULE_KEYS -from scvi.module.base import BaseModuleClass, auto_move_data +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.distributions import NegativeBinomial +from scvi.module.base import ( + BaseMinifiedModeModuleClass, + EmbeddingModuleMixin, + LossOutput, + auto_move_data, +) +from scvi.nn import Encoder, FCLayers -if TYPE_CHECKING: - import numpy as np - from torch.distributions import Distribution +from ._classifier import Classifier +torch.backends.cudnn.benchmark = True -class VAEC(BaseModuleClass): + +class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Conditional Variational auto-encoder model. 
This is an implementation of the CondSCVI model @@ -52,6 +59,7 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, + n_fine_labels: int | None = None, n_hidden: int = 128, n_latent: int = 5, n_layers: int = 2, @@ -61,30 +69,45 @@ def __init__( encode_covariates: bool = False, extra_encoder_kwargs: dict | None = None, extra_decoder_kwargs: dict | None = None, + linear_classifier: bool = True, + prior: str = "normal", + num_classes_mog: int | None = 10, ): - from scvi.nn import Encoder, FCLayers - super().__init__() self.dispersion = "gene" self.n_latent = n_latent self.n_layers = n_layers self.n_hidden = n_hidden self.dropout_rate = dropout_rate + self.encode_covariates = encode_covariates self.log_variational = log_variational self.gene_likelihood = "nb" self.latent_distribution = "normal" - self.encode_covariates = encode_covariates + # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels + self.n_fine_labels = n_fine_labels + self.prior = prior + self.num_classes_mog = num_classes_mog + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **{}) if self.encode_covariates and self.n_batch < 1: raise ValueError("`n_batch` must be greater than 0 if `encode_covariates` is `True`.") + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim + + cat_list = [n_labels] + n_input_encoder = n_input + batch_dim * encode_covariates + + # gene dispersion self.px_r = torch.nn.Parameter(torch.randn(n_input)) + + # z encoder goes from the n_input-dimensional data to an n_latent-d + _extra_encoder_kwargs = {} self.z_encoder = Encoder( - n_input, + n_input_encoder, n_latent, - n_cat_list=[n_labels] + ([n_batch] if n_batch > 0 and encode_covariates else []), + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, @@ -94,11 +117,39 @@ def __init__( return_dist=True, **(extra_encoder_kwargs or {}), ) + if n_fine_labels is not None: + cls_parameters = { + "n_layers": 0, + "n_hidden": 0, + "dropout_rate": dropout_rate, + "logits": True, + } + # linear mapping from latent space to a coarse-celltype aware space + self.linear_mapping = FCLayers( + n_in=n_latent, + n_out=n_hidden, + n_cat_list=[n_labels], + use_layer_norm=True, + dropout_rate=0.0, + ) + + self.classifier = Classifier( + n_hidden, + n_labels=n_fine_labels, + use_batch_norm=False, + use_layer_norm=True, + **cls_parameters, + ) + else: + self.classifier = None + # decoder goes from n_latent-dimensional space to n_input-d data + _extra_decoder_kwargs = {} + n_input_decoder = n_latent + batch_dim self.decoder = FCLayers( - n_in=n_latent, + n_in=n_input_decoder, n_out=n_hidden, - n_cat_list=[n_labels] + ([n_batch] if n_batch > 0 else []), + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, @@ -107,48 +158,66 @@ def __init__( use_layer_norm=True, **(extra_decoder_kwargs or {}), ) - self.px_decoder = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_input), torch.nn.Softplus() - ) + self.px_decoder = torch.nn.Linear(n_hidden, n_input) + self.per_ct_bias = torch.nn.Parameter(torch.zeros(n_labels, n_input)) - self.register_buffer( - "ct_weight", - ( - torch.ones((self.n_labels,), dtype=torch.float32) - if ct_weight is None - else torch.tensor(ct_weight, dtype=torch.float32) - ), - ) - - def _get_inference_input( - self, tensors: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor | None]: - return { - MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], - MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY], - 
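Editor's note: in the rewritten `VAEC`, the batch covariate is embedded via `EmbeddingModuleMixin` and concatenated to the log-transformed counts before the encoder, rather than being passed as a one-hot categorical covariate. A minimal sketch; the embedding dimension (5) is illustrative only, the real value comes from `get_embedding(...).embedding_dim`:

import torch

n_genes, n_cells, n_batches, batch_dim = 100, 8, 3, 5   # batch_dim is illustrative only
batch_embedding = torch.nn.Embedding(num_embeddings=n_batches, embedding_dim=batch_dim)
x = torch.rand(n_cells, n_genes) * 10
batch_index = torch.randint(0, n_batches, (n_cells,))
encoder_input = torch.cat([torch.log1p(x), batch_embedding(batch_index)], dim=-1)
print(encoder_input.shape)  # (n_cells, n_genes + batch_dim)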
MODULE_KEYS.BATCH_INDEX_KEY: tensors.get(REGISTRY_KEYS.BATCH_KEY, None), - } + if ct_weight is not None: + ct_weight = torch.tensor(ct_weight, dtype=torch.float32) + else: + ct_weight = torch.ones((self.n_labels,), dtype=torch.float32) + self.register_buffer("ct_weight", ct_weight) + if self.prior == "mog": + self.prior_means = torch.nn.Parameter( + torch.randn([n_labels, self.num_classes_mog, n_latent]) + ) + self.prior_log_std = torch.nn.Parameter( + torch.zeros([n_labels, self.num_classes_mog, n_latent]) - 2.0 + ) + self.prior_logits = torch.nn.Parameter(torch.zeros([n_labels, self.num_classes_mog])) + + def _get_inference_input(self, tensors): + y = tensors[REGISTRY_KEYS.LABELS_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + if self.minified_data_type is None: + x = tensors[REGISTRY_KEYS.X_KEY] + input_dict = { + "x": x, + "y": y, + "batch_index": batch_index, + } + else: + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qzm = tensors[REGISTRY_KEYS.LATENT_QZM_KEY] + qzv = tensors[REGISTRY_KEYS.LATENT_QZV_KEY] + observed_lib_size = tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE] + input_dict = { + "qzm": qzm, + "qzv": qzv, + "observed_lib_size": observed_lib_size, + "y": y, + "batch_index": batch_index, + } + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") + + return input_dict + + def _get_generative_input(self, tensors, inference_outputs): + z = inference_outputs["z"] + library = inference_outputs["library"] + y = tensors[REGISTRY_KEYS.LABELS_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - def _get_generative_input( - self, - tensors: dict[str, torch.Tensor], - inference_outputs: dict[str, torch.Tensor | Distribution], - ) -> dict[str, torch.Tensor]: - return { - MODULE_KEYS.Z_KEY: inference_outputs[MODULE_KEYS.Z_KEY], - MODULE_KEYS.LIBRARY_KEY: inference_outputs[MODULE_KEYS.LIBRARY_KEY], - MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY], - MODULE_KEYS.BATCH_INDEX_KEY: tensors.get(REGISTRY_KEYS.BATCH_KEY, None), + input_dict = { + "z": z, + "library": library, + "y": y, + "batch_index": batch_index, } + return input_dict @auto_move_data - def inference( - self, - x: torch.Tensor, - y: torch.Tensor, - batch_index: torch.Tensor | None = None, - n_samples: int = 1, - ) -> dict[str, torch.Tensor | Distribution]: + def _regular_inference(self, x, y, batch_index, n_samples=1): """High level inference method. Runs the inference (encoder) model. 
@@ -157,23 +226,62 @@ def inference( library = x.sum(1).unsqueeze(1) if self.log_variational: x_ = torch.log1p(x_) - - encoder_input = [x_, y] - if batch_index is not None and self.encode_covariates: - encoder_input.append(batch_index) - - qz, z = self.z_encoder(*encoder_input) + if self.encode_covariates: + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + encoder_input = torch.cat([x_, batch_rep], dim=-1) + else: + encoder_input = x_ + qz, z = self.z_encoder(encoder_input, y) if n_samples > 1: untran_z = qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1))) - return { - MODULE_KEYS.Z_KEY: z, - MODULE_KEYS.QZ_KEY: qz, - MODULE_KEYS.LIBRARY_KEY: library, - } + outputs = {"z": z, "qz": qz, "library": library} + return outputs + + @auto_move_data + def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qz = Normal(qzm, qzv.sqrt()) + # use dist.sample() rather than rsample because we aren't optimizing the z here + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) + z = self.z_encoder.z_transformation(untran_z) + library = observed_lib_size + if n_samples > 1: + library = library.unsqueeze(0).expand( + (n_samples, library.size(0), library.size(1)) + ) + else: + raise NotImplementedError(f"Unknown minified-data type: {self.minified_data_type}") + outputs = {"z": z, "qz": qz, "library": library} + return outputs + + @auto_move_data + def classify( + self, + z: torch.Tensor, + label_index: torch.Tensor | None = None, + ) -> torch.Tensor: + """Forward pass through the encoder and classifier. + + Parameters + ---------- + z + Tensor of shape ``(n_obs, n_latent)``. + label_index + Tensor of shape ``(n_obs,)`` denoting label indices. + + Returns + ------- + Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label. 
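Editor's note: `_cached_inference` above skips the encoder when the AnnData is minified and instead rebuilds the posterior from the stored `qzm`/`qzv` and samples `z` from it (`sample`, not `rsample`, since nothing is optimised here). A toy sketch:

import torch
from torch.distributions import Normal

n_obs, n_latent = 4, 6
qzm = torch.randn(n_obs, n_latent)          # stored posterior means
qzv = torch.rand(n_obs, n_latent) + 0.01    # stored posterior variances
qz = Normal(qzm, qzv.sqrt())
z = qz.sample()                             # draw the latent representation without the encoder
print(z.shape)                              # (n_obs, n_latent)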
+ """ + if len(label_index.shape) == 1: + label_index = label_index.unsqueeze(1) + classifier_latent = self.linear_mapping(z, label_index) + w_y = self.classifier(classifier_latent) + return w_y @auto_move_data def generative( @@ -185,19 +293,14 @@ def generative( transform_batch: torch.Tensor | None = None, ) -> dict[str, Distribution]: """Runs the generative model.""" - from scvi.distributions import NegativeBinomial - - decoder_input = [z, y] - if transform_batch is not None: - batch_index = torch.ones_like(batch_index) * transform_batch - - if batch_index is not None: - decoder_input.append(batch_index) - - h = self.decoder(*decoder_input) - px_scale = self.px_decoder(h) + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + decoder_input = torch.cat([z, batch_rep], dim=-1) + h = self.decoder(decoder_input, y, batch_index) + px_scale = torch.nn.Softmax(dim=-1)(self.px_decoder(h) + self.per_ct_bias[y.ravel()]) px_rate = library * px_scale - return {MODULE_KEYS.PX_KEY: NegativeBinomial(px_rate, logits=self.px_r)} + px_r = torch.exp(self.px_r) + px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) + return {"px": px} def loss( self, @@ -205,35 +308,74 @@ def loss( inference_outputs: dict[str, torch.Tensor | Distribution], generative_outputs: dict[str, Distribution], kl_weight: float = 1.0, + labelled_tensors: dict[str, torch.Tensor] | None = None, + classification_ratio=5.0, ): """Loss computation.""" - from torch.distributions import Normal - from torch.distributions import kl_divergence as kl - - from scvi.module.base import LossOutput - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] - qz = inference_outputs[MODULE_KEYS.QZ_KEY] - px = generative_outputs[MODULE_KEYS.PX_KEY] - - mean = torch.zeros_like(qz.loc) - scale = torch.ones_like(qz.scale) + y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long() + qz = inference_outputs["qz"] + px = generative_outputs["px"] + fine_labels = ( + tensors["fine_labels"].ravel().long() if "fine_labels" in tensors.keys() else None + ) - kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) + if self.prior == "mog": + indexed_means = self.prior_means[y] + indexed_log_std = self.prior_log_std[y] + indexed_logits = self.prior_logits[y] + cats = Categorical(logits=indexed_logits) + normal_dists = Independent( + Normal(indexed_means, torch.exp(indexed_log_std) + 1e-4), + reinterpreted_batch_ndims=1, + ) + prior = MixtureSameFamily(cats, normal_dists) + u = qz.rsample(sample_shape=(30,)) + # (sample, n_obs, n_latent) -> (sample, n_obs,) + kl_divergence_z = (qz.log_prob(u).sum(-1) - prior.log_prob(u)).mean(0) + else: + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) reconst_loss = -px.log_prob(x).sum(-1) - scaling_factor = self.ct_weight[y.long()[:, 0]] - loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) - - return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z) + scaling_factor = self.ct_weight[y] + + if self.classifier is not None: + fine_labels = fine_labels.view(-1) + logits = self.classify( + qz.loc, label_index=tensors[REGISTRY_KEYS.LABELS_KEY] + ) # (n_obs, n_labels) + classification_loss_ = F.cross_entropy(logits, fine_labels, reduction="none") + mask = fine_labels != self.n_fine_labels + classification_loss = classification_ratio * torch.masked_select( + classification_loss_, mask + ).mean(0) + + loss = torch.mean( + scaling_factor * (reconst_loss + 
classification_loss + kl_weight * kl_divergence_z) + ) + + return LossOutput( + loss=loss, + reconstruction_loss=reconst_loss, + kl_local=kl_divergence_z, + classification_loss=classification_loss, + logits=logits, + true_labels=fine_labels, + ) + else: + loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z + ) @torch.inference_mode() def sample( self, - tensors: dict[str, torch.Tensor], - n_samples: int = 1, - ) -> torch.Tensor: + tensors, + n_samples=1, + ) -> np.ndarray: r"""Generate observation samples from the posterior predictive distribution. The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. @@ -257,7 +399,10 @@ def sample( compute_loss=False, )[1] - dist = generative_outputs[MODULE_KEYS.PX_KEY] + px_r = generative_outputs["px_r"] + px_rate = generative_outputs["px_rate"] + + dist = NegativeBinomial(px_rate, logits=px_r) if n_samples > 1: exprs = dist.sample().permute([1, 2, 0]) # Shape : (n_cells_batch, n_genes, n_samples) else: diff --git a/tests/model/test_condscvi.py b/tests/model/test_condscvi.py index bf78a24a8a..2c15b9d926 100644 --- a/tests/model/test_condscvi.py +++ b/tests/model/test_condscvi.py @@ -1,5 +1,6 @@ import os +import numpy as np import pytest from scvi.data import synthetic_iid @@ -32,6 +33,22 @@ def test_condscvi_batch_key( model = CondSCVI.load(model_path, adata=adata) +def test_condscvi_fine_celltype( + save_path: str, +): + adata = synthetic_iid(n_batches=5, n_labels=5) + adata.obs["fine_labels"] = [i + str(np.random.randint(2)) for i in adata.obs["labels"]] + CondSCVI.setup_anndata( + adata, batch_key="batch", labels_key="labels", fine_labels_key="fine_labels" + ) + model = CondSCVI(adata, encode_covariates=True) + + model.train(max_epochs=1) + model.predict() + model.predict(adata=adata) + model.predict(adata, soft=True, use_posterior_mean=False) + + def test_condscvi_batch_key_compat_load(save_path: str): adata = synthetic_iid(n_batches=1, n_labels=5) model = CondSCVI.load("tests/test_data/condscvi_pre_batch", adata=adata) diff --git a/tests/model/test_destvi.py b/tests/model/test_destvi.py index 0bcff17ab8..c98becf89e 100644 --- a/tests/model/test_destvi.py +++ b/tests/model/test_destvi.py @@ -13,7 +13,7 @@ def test_destvi(): dataset = synthetic_iid(n_labels=n_labels) dataset.obs["overclustering_vamp"] = list(range(dataset.n_obs)) CondSCVI.setup_anndata(dataset, labels_key="labels") - sc_model = CondSCVI(dataset, n_latent=n_latent, n_layers=n_layers) + sc_model = CondSCVI(dataset, n_latent=n_latent, n_layers=n_layers, prior="mog") sc_model.train(1, train_size=1) sc_model.get_normalized_expression(dataset) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 29413dbd7d..f012d0ee04 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -9,7 +9,7 @@ from scvi.data import synthetic_iid from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified -from scvi.model import SCANVI, SCVI +from scvi.model import SCANVI, SCVI, CondSCVI from scvi.model.base import BaseMinifiedModeModelClass if TYPE_CHECKING: @@ -449,3 +449,28 @@ def test_scvi_with_minified_adata_get_feature_correlation_matrix(): ) assert_approx_equal(fcm_new, fcm_orig) + + +def test_condscvi_with_minified_adata_one_sample(): + run_test_for_model_with_minified_adata(CondSCVI) + 
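Editor's note: the fine-label classification term above is a cross entropy on the classifier logits with unlabeled cells masked out; the loss treats the index equal to `n_fine_labels` as the unlabeled category. The sketch below computes the same quantity by masking before the loss call:

import torch
import torch.nn.functional as F

n_fine_labels = 4
logits = torch.randn(6, n_fine_labels)            # classifier output per cell
fine_labels = torch.tensor([0, 2, 4, 1, 4, 3])    # 4 == n_fine_labels -> unlabeled
mask = fine_labels != n_fine_labels
classification_loss = F.cross_entropy(logits[mask], fine_labels[mask], reduction="none").mean()
print(classification_loss)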
+ +def test_condscvi_with_minified_adata_one_sample_with_spec_layer(): + run_test_for_model_with_minified_adata(CondSCVI, layer="data_layer") + + +def test_condscvi_with_minified_adata_n_samples(): + run_test_for_model_with_minified_adata(CondSCVI, n_samples=10, give_mean=True) + + +def test_condscvi_downstream(): + model, adata, _, adata_before_setup = prep_model(CondSCVI) + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + adata.obsm["X_latent_qzm"] = qzm + adata.obsm["X_latent_qzv"] = qzv + model.minify_adata() + model.get_vamp_prior() + scvi.model.DestVI.setup_anndata(adata_before_setup) + scvi.model.DestVI.from_rna_model( + adata_before_setup, model, amortization="both", vamp_prior_p=10 + )
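Editor's note: the downstream test above exercises the intended workflow: fit CondSCVI on a single-cell reference, optionally minify it, then initialise DestVI from it. A minimal end-to-end sketch using the same synthetic data helper as the tests; epoch counts are illustrative only:

from scvi.data import synthetic_iid
from scvi.model import CondSCVI, DestVI

sc_adata = synthetic_iid(n_labels=5)
CondSCVI.setup_anndata(sc_adata, labels_key="labels")
sc_model = CondSCVI(sc_adata, prior="mog")
sc_model.train(max_epochs=1)

st_adata = synthetic_iid(n_labels=5)
DestVI.setup_anndata(st_adata)
st_model = DestVI.from_rna_model(st_adata, sc_model, amortization="both", vamp_prior_p=10)
st_model.train(max_epochs=1)
proportions = st_model.get_proportions()   # per-spot cell-type proportions
print(proportions.shape)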