From 058fdcacbc080661e413e7379ab7beda4c0f3f3a Mon Sep 17 00:00:00 2001 From: cane11 Date: Sun, 31 Dec 2023 01:03:46 -0800 Subject: [PATCH 01/12] Change setup for minified data add condscvi --- scvi/data/_constants.py | 1 + scvi/data/_utils.py | 3 +- scvi/model/_condscvi.py | 21 +-- scvi/model/_scanvi.py | 101 +------------- scvi/model/_scvi.py | 97 +------------- scvi/model/base/_base_model.py | 126 +++++++++++++++--- scvi/model/utils/_minification.py | 13 +- scvi/module/_vae.py | 4 +- scvi/module/_vaec.py | 53 ++++++-- scvi/module/base/_base_module.py | 2 +- scvi/utils/_decorators.py | 4 +- tests/model/test_models_with_minified_data.py | 30 ++++- 12 files changed, 213 insertions(+), 242 deletions(-) diff --git a/scvi/data/_constants.py b/scvi/data/_constants.py index 9efa664537..50c0f39a5f 100644 --- a/scvi/data/_constants.py +++ b/scvi/data/_constants.py @@ -37,6 +37,7 @@ class _ADATA_MINIFY_TYPE_NT(NamedTuple): LATENT_POSTERIOR: str = "latent_posterior_parameters" + ADD_POSTERIOR_PARAMETERS: str = "add_posterior_parameters" ADATA_MINIFY_TYPE = _ADATA_MINIFY_TYPE_NT() diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index 71bf342a4a..f8b2724744 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -324,8 +324,9 @@ def _get_adata_minify_type(adata: AnnData) -> Union[MinifiedDataType, None]: def _is_minified(adata: Union[AnnData, str]) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY + layer_key = REGISTRY_KEYS.X_KEY if isinstance(adata, AnnData): - return adata.uns.get(uns_key, None) is not None + return adata.layers.get(layer_key, adata.X).sum()==0 elif isinstance(adata, str): with h5py.File(adata) as fp: return uns_key in read_elem(fp["uns"]).keys() diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index 4e8c7a42b0..8ad23f92ca 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -12,7 +12,7 @@ from scvi.data import AnnDataManager from scvi.data.fields import CategoricalObsField, LayerField from scvi.model.base import ( - BaseModelClass, + BaseMinifiedModeModelClass, RNASeqMixin, UnsupervisedTrainingMixin, VAEMixin, @@ -24,7 +24,7 @@ logger = logging.getLogger(__name__) -class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): +class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseMinifiedModeModelClass): """Conditional version of single-cell Variational Inference, used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`. 
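The `ADD_POSTERIOR_PARAMETERS` entry above extends the NamedTuple of minify-type constants, and later commits in this series switch equality checks to `ADATA_MINIFY_TYPE.__contains__(...)`. A minimal sketch of why that membership test works (NamedTuple instances are ordinary tuples of their field values; the class below is illustrative, the real constants live in scvi/data/_constants.py):

    from typing import NamedTuple

    class _MinifyTypes(NamedTuple):
        LATENT_POSTERIOR: str = "latent_posterior_parameters"
        ADD_POSTERIOR_PARAMETERS: str = "add_posterior_parameters"

    MINIFY_TYPES = _MinifyTypes()

    # Tuple membership is checked against the field values, so both spellings agree.
    assert "add_posterior_parameters" in MINIFY_TYPES
    assert MINIFY_TYPES.__contains__("latent_posterior_parameters")
    assert not MINIFY_TYPES.__contains__("not_a_minify_type")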
Parameters @@ -60,6 +60,9 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass) """ _module_cls = VAEC + _LATENT_QZM = "_condscvi_latent_qzm" + _LATENT_QZV = "_condscvi_latent_qzv" + _OBSERVED_LIB_SIZE = "_condscvi_observed_lib_size" def __init__( self, @@ -140,19 +143,7 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra key = labels_state_registry.original_key mapping = labels_state_registry.categorical_mapping - scdl = self._make_data_loader(adata=adata, batch_size=p) - - mean = [] - var = [] - for tensors in scdl: - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] - out = self.module.inference(x, y) - mean_, var_ = out["qz"].loc, (out["qz"].scale ** 2) - mean += [mean_.cpu()] - var += [var_.cpu()] - - mean_cat, var_cat = torch.cat(mean).numpy(), torch.cat(var).numpy() + mean_cat, var_cat = self.get_latent_representation(adata, return_dist=True) for ct in range(self.summary_stats["n_labels"]): local_indices = np.where(adata.obs[key] == mapping[ct])[0] diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 35838b4509..9058cc1798 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -12,28 +12,21 @@ from anndata import AnnData from scvi import REGISTRY_KEYS, settings -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import ( - _ADATA_MINIFY_TYPE_UNS_KEY, _SETUP_ARGS_KEY, - ADATA_MINIFY_TYPE, ) from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LabelsWithUnlabeledObsField, LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.dataloaders import SemiSupervisedDataSplitter from scvi.model._utils import _init_library_size, get_max_epochs_heuristic -from scvi.model.utils import get_minified_adata_scrna from scvi.module import SCANVAE from scvi.train import SemiSupervisedTrainingPlan, TrainRunner from scvi.train._callbacks import SubSampleLabels @@ -43,10 +36,6 @@ from ._scvi import SCVI from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin -_SCANVI_LATENT_QZM = "_scanvi_latent_qzm" -_SCANVI_LATENT_QZV = "_scanvi_latent_qzv" -_SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -106,6 +95,9 @@ class SCANVI(RNASeqMixin, VAEMixin, ArchesMixin, BaseMinifiedModeModelClass): _module_cls = SCANVAE _training_plan_cls = SemiSupervisedTrainingPlan + _LATENT_QZM = "_scanvi_latent_qzm" + _LATENT_QZV = "_scanvi_latent_qzv" + _OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" def __init__( self, @@ -478,6 +470,8 @@ def setup_anndata( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ + + print("XXXXX", cls._latent_qzm) setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), @@ -498,90 +492,9 @@ def setup_anndata( # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) - - @staticmethod - def 
_get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the anndata fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCANVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCANVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCANVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = "X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ): - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is minified. - - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. 
- """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError( - "Cannot minify the data if `use_observed_lib_size` is False" - ) - - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCANVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCANVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCANVI_OBSERVED_LIB_SIZE] = np.squeeze( - np.asarray(counts.sum(axis=1)) - ) - self._update_adata_and_manager_post_minification( - minified_adata, minified_data_type - ) - self.module.minified_data_type = minified_data_type + cls.register_manager(adata_manager) \ No newline at end of file diff --git a/scvi/model/_scvi.py b/scvi/model/_scvi.py index c2459b1329..074eec3c1a 100644 --- a/scvi/model/_scvi.py +++ b/scvi/model/_scvi.py @@ -1,36 +1,25 @@ import logging from typing import Literal, Optional -import numpy as np from anndata import AnnData from scvi import REGISTRY_KEYS -from scvi._types import MinifiedDataType from scvi.data import AnnDataManager -from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( - BaseAnnDataField, CategoricalJointObsField, CategoricalObsField, LayerField, NumericalJointObsField, NumericalObsField, - ObsmField, - StringUnsField, ) from scvi.model._utils import _init_library_size from scvi.model.base import UnsupervisedTrainingMixin -from scvi.model.utils import get_minified_adata_scrna from scvi.module import VAE from scvi.utils import setup_anndata_dsp from .base import ArchesMixin, BaseMinifiedModeModelClass, RNASeqMixin, VAEMixin -_SCVI_LATENT_QZM = "_scvi_latent_qzm" -_SCVI_LATENT_QZV = "_scvi_latent_qzv" -_SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" - logger = logging.getLogger(__name__) @@ -96,6 +85,9 @@ class SCVI( """ _module_cls = VAE + _LATENT_QZM = "_scvi_latent_qzm" + _LATENT_QZV = "_scvi_latent_qzv" + _OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" def __init__( self, @@ -204,92 +196,11 @@ def setup_anndata( # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(adata_minify_type) + anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - @staticmethod - def _get_fields_for_adata_minification( - minified_data_type: MinifiedDataType, - ) -> list[BaseAnnDataField]: - """Return the anndata fields required for adata minification of the given minified_data_type.""" - if minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - fields = [ - ObsmField( - REGISTRY_KEYS.LATENT_QZM_KEY, - _SCVI_LATENT_QZM, - ), - ObsmField( - REGISTRY_KEYS.LATENT_QZV_KEY, - _SCVI_LATENT_QZV, - ), - NumericalObsField( - REGISTRY_KEYS.OBSERVED_LIB_SIZE, - _SCVI_OBSERVED_LIB_SIZE, - ), - ] - else: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - fields.append( - StringUnsField( - REGISTRY_KEYS.MINIFY_TYPE_KEY, - _ADATA_MINIFY_TYPE_UNS_KEY, - ), - ) - return fields - - def minify_adata( - self, - 
minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, - use_latent_qzm_key: str = "X_latent_qzm", - use_latent_qzv_key: str = "X_latent_qzv", - ) -> None: - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns - containing minified-adata type, and library size. - This also sets the appropriate property on the module to indicate that the adata is minified. - - Parameters - ---------- - minified_data_type - How to minify the data. Currently only supports `latent_posterior_parameters`. - If minified_data_type == `latent_posterior_parameters`: - - * the original count data is removed (`adata.X`, adata.raw, and any layers) - * the parameters of the latent representation of the original data is stored - * everything else is left untouched - use_latent_qzm_key - Key to use in `adata.obsm` where the latent qzm params are stored - use_latent_qzv_key - Key to use in `adata.obsm` where the latent qzv params are stored - - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior - # without removing the original counts. - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: - raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - - if self.module.use_observed_lib_size is False: - raise ValueError( - "Cannot minify the data if `use_observed_lib_size` is False" - ) - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type) - minified_adata.obsm[_SCVI_LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] - minified_adata.obsm[_SCVI_LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] - counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) - minified_adata.obs[_SCVI_OBSERVED_LIB_SIZE] = np.squeeze( - np.asarray(counts.sum(axis=1)) - ) - self._update_adata_and_manager_post_minification( - minified_adata, minified_data_type - ) - self.module.minified_data_type = minified_data_type diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index c29fda27d7..1691f261d7 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -19,15 +19,24 @@ from scvi.data import AnnDataManager from scvi.data._compat import registry_from_setup_dict from scvi.data._constants import ( + _ADATA_MINIFY_TYPE_UNS_KEY, _MODEL_NAME_KEY, _SCVI_UUID_KEY, _SETUP_ARGS_KEY, _SETUP_METHOD_NAME, + ADATA_MINIFY_TYPE, ) from scvi.data._utils import _assign_adata_uuid, _check_if_view, _get_adata_minify_type +from scvi.data.fields import ( + BaseAnnDataField, + NumericalObsField, + ObsmField, + StringUnsField, +) from scvi.dataloaders import AnnDataLoader from scvi.model._utils import parse_device_args from scvi.model.base._utils import _load_legacy_saved_files +from scvi.model.utils import get_minified_adata_scrna from scvi.utils import attrdict, setup_anndata_dsp from scvi.utils._docstrings import devices_dsp @@ -481,6 +490,101 @@ def _check_if_trained( else: raise RuntimeError(message) + def minify_adata( + self, + minified_data_type: MinifiedDataType = ADATA_MINIFY_TYPE.LATENT_POSTERIOR, + use_latent_qzm_key: str = "X_latent_qzm", + use_latent_qzv_key: str = "X_latent_qzv", + ): + """Minifies the model's adata. + + Minifies the adata, and registers new anndata fields: latent qzm, latent qzv, adata uns + containing minified-adata type, and library size. 
+ This also sets the appropriate property on the module to indicate that the adata is minified. + + Parameters + ---------- + minified_data_type + How to minify the data. Currently only supports `latent_posterior_parameters` and `add_posterior_parameters`,. + If minified_data_type == `latent_posterior_parameters`: + + * the original count data is removed (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + If minified_data_type == `add_posterior_parameters`: + + * the original count data is kept (`adata.X`, adata.raw, and any layers) + * the parameters of the latent representation of the original data is stored + * everything else is left untouched + use_latent_qzm_key + Key to use in `adata.obsm` where the latent qzm params are stored + use_latent_qzv_key + Key to use in `adata.obsm` where the latent qzv params are stored + + Notes + ----- + The modification is not done inplace -- instead the model is assigned a new (minified) + version of the adata. + """ + # TODO(adamgayoso): Add support for a scenario where we want to cache the latent posterior + if not ADATA_MINIFY_TYPE.__contains__(minified_data_type): + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + + if minified_data_type == ADATA_MINIFY_TYPE.ADD_POSTERIOR_PARAMETERS: + keep_count_data = True + elif minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + keep_count_data = False + + if not getattr(self.module, 'use_observed_lib_size', True): + raise ValueError( + "Cannot minify the data if `use_observed_lib_size` is False" + ) + + minified_adata = get_minified_adata_scrna(self.adata, minified_data_type, keep_count_data=keep_count_data) + minified_adata.obsm[self._LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] + minified_adata.obsm[self._LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] + counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) + minified_adata.obs[self._OBSERVED_LIB_SIZE] = np.squeeze( + np.asarray(counts.sum(axis=1)) + ) + self._update_adata_and_manager_post_minification( + minified_adata, minified_data_type + ) + self.module.minified_data_type = minified_data_type + + def _get_fields_for_adata_minification( + self, + minified_data_type: MinifiedDataType, + ) -> list[BaseAnnDataField]: + """Return the anndata fields required for adata minification of the given minified_data_type.""" + assert self._LATENT_QZM, NotImplementedError("Minified mode is not defined for model.") + if ADATA_MINIFY_TYPE.__contains__(minified_data_type): + fields = [ + ObsmField( + REGISTRY_KEYS.LATENT_QZM_KEY, + self._LATENT_QZM, + ), + ObsmField( + REGISTRY_KEYS.LATENT_QZV_KEY, + self._LATENT_QZV, + ), + NumericalObsField( + REGISTRY_KEYS.OBSERVED_LIB_SIZE, + self._OBSERVED_LIB_SIZE, + ), + ] + else: + raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") + fields.append( + StringUnsField( + REGISTRY_KEYS.MINIFY_TYPE_KEY, + _ADATA_MINIFY_TYPE_UNS_KEY, + ), + ) + return fields + + + @property def is_trained(self) -> bool: """Whether the model has been trained.""" @@ -893,28 +997,6 @@ def minified_data_type(self) -> MinifiedDataType | None: else None ) - @abstractmethod - def minify_adata( - self, - *args, - **kwargs, - ): - """Minifies the model's adata. - - Minifies the adata, and registers new anndata fields as required (can be model-specific). - This also sets the appropriate property on the module to indicate that the adata is minified. 
- - Notes - ----- - The modification is not done inplace -- instead the model is assigned a new (minified) - version of the adata. - """ - - @staticmethod - @abstractmethod - def _get_fields_for_adata_minification(minified_data_type: MinifiedDataType): - """Return the anndata fields required for adata minification of the given type.""" - def _update_adata_and_manager_post_minification( self, minified_adata: AnnOrMuData, minified_data_type: MinifiedDataType ): diff --git a/scvi/model/utils/_minification.py b/scvi/model/utils/_minification.py index cf84687bc5..d5ed1b479a 100644 --- a/scvi/model/utils/_minification.py +++ b/scvi/model/utils/_minification.py @@ -12,6 +12,7 @@ def get_minified_adata_scrna( adata: AnnData, minified_data_type: MinifiedDataType, + keep_count_data: bool = False, ) -> AnnData: """Returns a minified adata that works for most scrna models (such as SCVI, SCANVI). @@ -21,12 +22,18 @@ def get_minified_adata_scrna( Original adata, of which we to create a minified version. minified_data_type How to minify the data. + keep_count_data + Whether to keep the count data and only store additionally the latent posterior. """ - if minified_data_type != ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if not ADATA_MINIFY_TYPE.__contains__(minified_data_type): raise NotImplementedError(f"Unknown MinifiedDataType: {minified_data_type}") - all_zeros = csr_matrix(adata.X.shape) - layers = {layer: all_zeros.copy() for layer in adata.layers} + if not keep_count_data: + all_zeros = csr_matrix(adata.X.shape) + layers = {layer: all_zeros.copy() for layer in adata.layers} + else: + all_zeros = adata.X + layers = adata.layers bdata = AnnData( X=all_zeros, layers=layers, diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index c5c093f68e..b25fce88c5 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -245,7 +245,7 @@ def _get_inference_input( "cat_covs": cat_covs, } else: - if self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): qzm = tensors[REGISTRY_KEYS.LATENT_QZM_KEY] qzv = tensors[REGISTRY_KEYS.LATENT_QZV_KEY] observed_lib_size = tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE] @@ -356,7 +356,7 @@ def _regular_inference( @auto_move_data def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): - if self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): dist = Normal(qzm, qzv.sqrt()) # use dist.sample() rather than rsample because we aren't optimizing the z here untran_z = dist.sample() if n_samples == 1 else dist.sample((n_samples,)) diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 94a14a2c6e..d6b0bf920f 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -7,15 +7,16 @@ from scvi import REGISTRY_KEYS from scvi._types import Tunable +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, FCLayers torch.backends.cudnn.benchmark = True # Conditional VAE model -class VAEC(BaseModuleClass): +class VAEC(BaseMinifiedModeModuleClass): """Conditional Variational auto-encoder model. 
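The refactor above moves `minify_adata` and `_get_fields_for_adata_minification` onto the shared base class, and `get_minified_adata_scrna` gains a `keep_count_data` switch for the new minify type. A usage sketch of that flow on synthetic data, mirroring the tests later in this series (the single training epoch is only to keep the example fast):

    import scvi
    from scvi.data import synthetic_iid

    adata = synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata, batch_key="batch")
    model = scvi.model.SCVI(adata)
    model.train(max_epochs=1)

    # Cache the posterior parameters where minify_adata expects them.
    qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True)
    model.adata.obsm["X_latent_qzm"] = qzm
    model.adata.obsm["X_latent_qzv"] = qzv

    # Default: counts are replaced by an all-zero sparse matrix.
    model.minify_adata()

Passing `minify_adata("add_posterior_parameters")` instead keeps the counts and only caches the posterior parameters, which is what the updated normalized-expression test later in this series exercises.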
This is an implementation of the CondSCVI model @@ -114,13 +115,28 @@ def __init__( self.register_buffer("ct_weight", ct_weight) def _get_inference_input(self, tensors): - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] + if self.minified_data_type is None: + x = tensors[REGISTRY_KEYS.X_KEY] + y = tensors[REGISTRY_KEYS.LABELS_KEY] + input_dict = { + "x": x, + "y": y, + } + else: + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qzm = tensors[REGISTRY_KEYS.LATENT_QZM_KEY] + qzv = tensors[REGISTRY_KEYS.LATENT_QZV_KEY] + observed_lib_size = tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE] + input_dict = { + "qzm": qzm, + "qzv": qzv, + "observed_lib_size": observed_lib_size, + } + else: + raise NotImplementedError( + f"Unknown minified-data type: {self.minified_data_type}" + ) - input_dict = { - "x": x, - "y": y, - } return input_dict def _get_generative_input(self, tensors, inference_outputs): @@ -136,7 +152,7 @@ def _get_generative_input(self, tensors, inference_outputs): return input_dict @auto_move_data - def inference(self, x, y, n_samples=1): + def _regular_inference(self, x, y, n_samples=1): """High level inference method. Runs the inference (encoder) model. @@ -158,6 +174,25 @@ def inference(self, x, y, n_samples=1): outputs = {"z": z, "qz": qz, "library": library} return outputs + @auto_move_data + def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qz = Normal(qzm, qzv.sqrt()) + # use dist.sample() rather than rsample because we aren't optimizing the z here + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) + z = self.z_encoder.z_transformation(untran_z) + library = observed_lib_size + if n_samples > 1: + library = library.unsqueeze(0).expand( + (n_samples, library.size(0), library.size(1)) + ) + else: + raise NotImplementedError( + f"Unknown minified-data type: {self.minified_data_type}" + ) + outputs = {"z": z, "qz": qz, "library": library} + return outputs + @auto_move_data def generative(self, z, library, y): """Runs the generative model.""" diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index 2c483a0825..a8352b8c1a 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -301,7 +301,7 @@ def inference(self, *args, **kwargs): """ if ( self.minified_data_type is not None - and self.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + and ADATA_MINIFY_TYPE.__contains__(self.minified_data_type) ): return self._cached_inference(*args, **kwargs) else: diff --git a/scvi/utils/_decorators.py b/scvi/utils/_decorators.py index aa205a2da7..8618e866a0 100644 --- a/scvi/utils/_decorators.py +++ b/scvi/utils/_decorators.py @@ -1,13 +1,15 @@ from functools import wraps from typing import Callable +from scvi.data._constants import ADATA_MINIFY_TYPE + def unsupported_if_adata_minified(fn: Callable) -> Callable: """Decorator to raise an error if the model's `adata` is minified.""" @wraps(fn) def wrapper(self, *args, **kwargs): - if getattr(self, "minified_data_type", None) is not None: + if getattr(self, "minified_data_type", None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: raise ValueError( f"The {fn.__qualname__} function currently does not support minified data." 
) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 37b61ff34b..02c82ecd1f 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -5,7 +5,7 @@ from scvi.data import synthetic_iid from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified -from scvi.model import SCANVI, SCVI +from scvi.model import SCANVI, SCVI, CondSCVI, DestVI _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" @@ -156,6 +156,28 @@ def test_scanvi_with_minified_adata_n_samples(): run_test_for_model_with_minified_adata( SCANVI, n_samples=10, give_mean=True, use_size_factor=True ) + +def test_condscvi_with_minified_adata_one_sample(): + run_test_for_model_with_minified_adata(CondSCVI) + + +def test_condscvi_with_minified_adata_one_sample(): + run_test_for_model_with_minified_adata(CondSCVI, layer="data_layer") + + +def test_condscvi_with_minified_adata_one_sample(): + run_test_for_model_with_minified_adata(CondSCVI, n_samples=10, give_mean=True) + + +def test_condscvi_downstream(): + model, adata, _, adata_before_setup = prep_model(CondSCVI) + zm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + adata.obsm["X_latent_qzm"] = qzm + adata.obsm["X_latent_qzv"] = qzv + model.minify_adata() + model.get_vamp_prior() + scvi.model.DestVI.setup_anndata(adata_before_setup) + scvi.model.DestVI.from_rna_model(adata_before_setup, model, amortization="both", vamp_prior_p=10) def test_scanvi_from_scvi(save_path): @@ -203,8 +225,14 @@ def test_scvi_with_minified_adata_get_normalized_expression(): scvi.settings.seed = 1 exprs_orig = model.get_normalized_expression() + model.minify_adata('add_posterior_parameters') + assert model.minified_data_type == ADATA_MINIFY_TYPE.ADD_POSTERIOR_PARAMETERS + assert np.isfinite(model.adata.get_elbo()) + assert np.isfinite(model.adata.get_reconstruction_error()) + model.minify_adata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR scvi.settings.seed = 1 exprs_new = model.get_normalized_expression() From c876fd978e01725bc8dc9b6e132b89bff4f4518e Mon Sep 17 00:00:00 2001 From: cane11 Date: Fri, 5 Jan 2024 11:35:35 -0800 Subject: [PATCH 02/12] First draft destVI v2. --- scvi/data/fields/_dataframe_field.py | 15 ++++- scvi/model/_condscvi.py | 69 ++++++++++++++++++++- scvi/model/_destvi.py | 25 +++++++- scvi/module/_mrdeconv.py | 12 +++- scvi/module/_vaec.py | 92 +++++++++++++++++++++++----- 5 files changed, 192 insertions(+), 21 deletions(-) diff --git a/scvi/data/fields/_dataframe_field.py b/scvi/data/fields/_dataframe_field.py index c35770c4ea..3c36625ba6 100644 --- a/scvi/data/fields/_dataframe_field.py +++ b/scvi/data/fields/_dataframe_field.py @@ -36,9 +36,10 @@ def __init__( attr_key: Optional[str], field_type: Literal["obs", "var"] = None, required: bool = True, + is_empty: bool = False, ) -> None: super().__init__() - if required and attr_key is None: + if required and (attr_key is None or is_empty): raise ValueError( "`attr_key` cannot be `None` if `required=True`. Please provide an `attr_key`." 
) @@ -51,7 +52,7 @@ def __init__( self._registry_key = registry_key self._attr_key = attr_key - self._is_empty = attr_key is None + self._is_empty = is_empty or attr_key is None @property def registry_key(self) -> str: @@ -136,6 +137,8 @@ class CategoricalDataFrameField(BaseDataFrameField): Key to access the field in the AnnData obs or var mapping. If None, defaults to `registry_key`. field_type Type of field. Can be either "obs" or "var". + required + If False, allows for `attr_key is None` and marks the field as `is_empty`. """ CATEGORICAL_MAPPING_KEY = "categorical_mapping" @@ -146,14 +149,18 @@ def __init__( registry_key: str, attr_key: Optional[str], field_type: Literal["obs", "var"] = None, + required: bool = True, ) -> None: self.is_default = attr_key is None self._original_attr_key = attr_key or registry_key + is_empty = attr_key is None super().__init__( registry_key, f"_scvi_{registry_key}", field_type=field_type, + required=required, + is_empty=is_empty, ) self.count_stat_key = f"n_{self.registry_key}" @@ -237,12 +244,16 @@ def transfer_field( def get_summary_stats(self, state_registry: dict) -> dict: """Get summary stats.""" + if self.is_empty: + return {} categorical_mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] n_categories = len(np.unique(categorical_mapping)) return {self.count_stat_key: n_categories} def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: """View state registry.""" + if self.is_empty: + return None source_key = state_registry[self.ORIGINAL_ATTR_KEY] mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] t = rich.table.Table(title=f"{self.registry_key} State Registry") diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index 8ad23f92ca..ca717fcde6 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -4,13 +4,15 @@ import warnings import numpy as np +import pandas as pd import torch from anndata import AnnData from sklearn.cluster import KMeans from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager -from scvi.data.fields import CategoricalObsField, LayerField +from scvi.data._utils import _get_adata_minify_type, get_anndata_attribute +from scvi.data.fields import CategoricalJointObsField, CategoricalObsField, LayerField from scvi.model.base import ( BaseMinifiedModeModelClass, RNASeqMixin, @@ -78,6 +80,13 @@ def __init__( n_labels = self.summary_stats.n_labels n_vars = self.summary_stats.n_vars + n_cats_per_cov = ( + self.adata_manager.get_state_registry( + REGISTRY_KEYS.CAT_COVS_KEY + ).n_cats_per_key + if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry + else None + ) if weight_obs: ct_counts = np.unique( self.get_from_registry(adata, REGISTRY_KEYS.LABELS_KEY), @@ -88,6 +97,40 @@ def __init__( ct_prop = ct_prop / np.sum(ct_prop) ct_weight = 1.0 / ct_prop module_kwargs.update({"ct_weight": ct_weight}) + if 'fine_labels' in self.adata_manager.data_registry: + fine_labels = get_anndata_attribute( + adata, + self.adata_manager.data_registry.labels.attr_name, + '_scvi_fine_labels', + ) + coarse_labels = get_anndata_attribute( + adata, + self.adata_manager.data_registry.labels.attr_name, + '_scvi_labels' + ) + + df_ct = pd.DataFrame({ + 'fine_labels_key': fine_labels.ravel(), + 'coarse_labels_key': coarse_labels.ravel()}).drop_duplicates() + fine_labels_mapping = self.adata_manager.get_state_registry( + 'fine_labels' + ).categorical_mapping + coarse_labels_mapping = self.adata_manager.get_state_registry( + REGISTRY_KEYS.LABELS_KEY + ).categorical_mapping + + 
df_ct['fine_labels'] = fine_labels_mapping[df_ct['fine_labels_key']] + df_ct['coarse_labels'] = coarse_labels_mapping[df_ct['coarse_labels_key']] + + self.df_ct_name_dict = {} + self.df_ct_id_dict = {} + for i, row in df_ct.iterrows(): + count = len(df_ct.loc[:i][df_ct['coarse_labels'] == row['coarse_labels']]) - 1 + self.df_ct_name_dict[row['fine_labels']] = (row['coarse_labels'], row['coarse_labels_key'], count) + self.df_ct_id_dict[row['fine_labels_key']] = (row['coarse_labels'], row['coarse_labels_key'], count) + else: + self.df_ct_name_dict = None + self.df_ct_id_dict = None self.module = self._module_cls( n_input=n_vars, @@ -96,6 +139,8 @@ def __init__( n_latent=n_latent, n_layers=n_layers, dropout_rate=dropout_rate, + n_cats_per_cov=n_cats_per_cov, + df_ct_id_dict=self.df_ct_id_dict, **module_kwargs, ) self._model_summary_string = ( @@ -132,6 +177,14 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra adata = self._validate_anndata(adata) + if self.module.prior == "mog": + print('Using MoG') + return ( + self.module.prior_means, + torch.exp(self.module.prior_log_scales)**2 + 1e-4, + torch.nn.functional.softmax(self.module.prior_logits, dim=-1) + ) + # Extracting latent representation of adata including variances. mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) @@ -260,6 +313,8 @@ def setup_anndata( cls, adata: AnnData, labels_key: str | None = None, + fine_labels_key: str | None = None, + categorical_covariate_keys: list[str] | None = None, layer: str | None = None, **kwargs, ): @@ -269,13 +324,25 @@ def setup_anndata( ---------- %(param_adata)s %(param_labels_key)s + fine_labels_key + Key in `adata.obs` where fine-grained labels are stored. 
+ %(param_cat_cov_keys)s %(param_layer)s """ setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), + CategoricalObsField('fine_labels', fine_labels_key, required=False), + CategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys + ), ] + + # register new fields if the adata is minified + adata_minify_type = _get_adata_minify_type(adata) + if adata_minify_type is not None: + anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index 44c0b105ab..b2cb0808e8 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -11,7 +11,7 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import LayerField, NumericalObsField +from scvi.data.fields import CategoricalJointObsField, LayerField, NumericalObsField from scvi.model import CondSCVI from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.module import MRDeconv @@ -77,14 +77,28 @@ def __init__( n_hidden: int, n_latent: int, n_layers: int, + n_cats_per_cov: Sequence[int], dropout_decoder: float, l1_reg: float, + sc_covariate_registry: dict[str, list[str]], **module_kwargs, ): super().__init__(st_adata) + st_covariate_registry = dict(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings) + assert not set(sc_covariate_registry.keys()) == set(st_covariate_registry.keys()), ( + f'Spatial model has other covariates than single cell model, {set(sc_covariate_registry.keys()).symmetric_difference(st_covariate_registry.keys())}' + ) + for key, value in st_covariate_registry.items(): + assert not set(value).issubset(set(sc_covariate_registry[key])), ( + f'Spatial model has other covariates than single cell model, {set(sc_covariate_registry.keys()).symmetric_difference(st_covariate_registry.keys())}' + ) + self.adata.obsm['_scvi_extra_categorical_covs'][key] = self.adata.obsm['_scvi_extra_categorical_covs'][key].apply( + lambda x: sc_covariate_registry[key].index(value[x]) + ) self.module = self._module_cls( n_spots=st_adata.n_obs, n_labels=cell_type_mapping.shape[0], + n_cats_per_cov=n_cats_per_cov, decoder_state_dict=decoder_state_dict, px_decoder_state_dict=px_decoder_state_dict, px_r=px_r, @@ -131,6 +145,7 @@ def from_rna_model( mapping = sc_model.adata_manager.get_state_registry( REGISTRY_KEYS.LABELS_KEY ).categorical_mapping + dropout_decoder = sc_model.module.dropout_rate if vamp_prior_p is None: mean_vprior = None @@ -149,11 +164,13 @@ def from_rna_model( sc_model.module.n_hidden, sc_model.module.n_latent, sc_model.module.n_layers, + sc_model.module.n_cats_per_cov, mean_vprior=mean_vprior, var_vprior=var_vprior, mp_vprior=mp_vprior, dropout_decoder=dropout_decoder, l1_reg=l1_reg, + sc_covariate_registry=dict(sc_model.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings), **module_kwargs, ) @@ -380,6 +397,7 @@ def setup_anndata( cls, adata: AnnData, layer: str | None = None, + categorical_covariate_keys: Sequence[str] | None = None, **kwargs, ): """%(summary)s. @@ -388,6 +406,8 @@ def setup_anndata( ---------- %(param_adata)s %(param_layer)s + %(param_categorical_covariate_keys)s + Categorical covariate keys need to line up with single cell model. 
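The covariate handling in `__init__` above re-codes each spatial covariate column so its integer codes index into the single-cell model's categorical mapping. A toy illustration of that re-coding (category names invented):

    # Order of categories as registered by the single-cell (CondSCVI) model.
    sc_mapping = ["10x", "smartseq", "visium"]
    # Order of categories as registered by the spatial (DestVI) model.
    st_mapping = ["visium", "10x"]

    # Per-spot integer codes under the spatial mapping ...
    st_codes = [0, 1, 0]
    # ... re-coded so they point at the same category in the single-cell mapping.
    recoded = [sc_mapping.index(st_mapping[code]) for code in st_codes]
    assert recoded == [2, 0, 2]

This only works when every spatial category also exists in the single-cell registry, which is what the assertions in `__init__` are meant to guard.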
""" setup_method_args = cls._get_setup_method_args(**locals()) # add index for each cell (provided to pyro plate for correct minibatching) @@ -395,6 +415,9 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), + CategoricalJointObsField( + REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys + ), ] adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 3213df7424..3f889df429 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -82,6 +82,7 @@ def __init__( px_decoder_state_dict: OrderedDict, px_r: np.ndarray, dropout_decoder: float, + n_cats_per_cov: Optional[list] = None, dropout_amortization: float = 0.05, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, @@ -90,6 +91,7 @@ def __init__( l1_reg: Tunable[float] = 0.0, beta_reg: Tunable[float] = 5.0, eta_reg: Tunable[float] = 1e-4, + mode: Literal["mog", "normal"] = "normal", extra_encoder_kwargs: Optional[dict] = None, extra_decoder_kwargs: Optional[dict] = None, ): @@ -107,10 +109,12 @@ def __init__( self.eta_reg = eta_reg # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} + cat_list = [n_labels] + list([] if self.n_cats_per_cov is None else self.n_cats_per_cov) + self.decoder = FCLayers( n_in=n_latent, n_out=n_hidden, - n_cat_list=[n_labels], + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_decoder, @@ -153,6 +157,10 @@ def __init__( # create additional neural nets for amortization # within cell_type factor loadings _extra_encoder_kwargs = extra_encoder_kwargs or {} + if self.mode == "mog": + return_dist = 3 + else: + return_dist = 1 self.gamma_encoder = torch.nn.Sequential( FCLayers( n_in=self.n_genes, @@ -165,7 +173,7 @@ def __init__( use_batch_norm=False, **_extra_encoder_kwargs, ), - torch.nn.Linear(n_hidden, n_latent * n_labels), + torch.nn.Linear(n_hidden, return_dist * n_latent * n_labels), ) # cell type loadings self.V_encoder = torch.nn.Sequential( diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index d6b0bf920f..2654255fad 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -1,8 +1,9 @@ +from collections.abc import Iterable from typing import Optional import numpy as np import torch -from torch.distributions import Normal +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal from torch.distributions import kl_divergence as kl from scvi import REGISTRY_KEYS @@ -49,14 +50,19 @@ def __init__( self, n_input: int, n_labels: int = 0, + n_cats_per_cov: Optional[Iterable[int]] = None, n_hidden: Tunable[int] = 128, n_latent: Tunable[int] = 5, n_layers: Tunable[int] = 2, log_variational: bool = True, ct_weight: np.ndarray = None, dropout_rate: Tunable[float] = 0.05, + encode_covariates: bool = False, extra_encoder_kwargs: Optional[dict] = None, extra_decoder_kwargs: Optional[dict] = None, + prior: str = 'normal', + df_ct_id_dict: dict = None, + num_classes_mog: Optional[int] = 10, ): super().__init__() self.dispersion = "gene" @@ -64,12 +70,23 @@ def __init__( self.n_layers = n_layers self.n_hidden = n_hidden self.dropout_rate = dropout_rate + self.encode_covariates = encode_covariates self.log_variational = log_variational self.gene_likelihood = "nb" self.latent_distribution = "normal" # Automatically deactivate if useless self.n_batch = 0 self.n_labels = n_labels + self.prior = prior 
+ self.n_cats_per_cov = n_cats_per_cov + if df_ct_id_dict is not None: + self.num_classes_mog = max([v[2] for v in df_ct_id_dict.values()]) + 1 + mapping_mog = torch.tensor([v[2] for _, v in sorted(df_ct_id_dict.items())]) + self.register_buffer("mapping_mog", mapping_mog) + else: + self.num_classes_mog = num_classes_mog + cat_list = [n_labels] + list([] if self.n_cats_per_cov is None else self.n_cats_per_cov) + encoder_cat_list = cat_list if self.encode_covariates else [n_labels] # gene dispersion self.px_r = torch.nn.Parameter(torch.randn(n_input)) @@ -79,7 +96,7 @@ def __init__( self.z_encoder = Encoder( n_input, n_latent, - n_cat_list=[n_labels], + n_cat_list=encoder_cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, @@ -95,7 +112,7 @@ def __init__( self.decoder = FCLayers( n_in=n_latent, n_out=n_hidden, - n_cat_list=[n_labels], + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, @@ -113,14 +130,24 @@ def __init__( else: ct_weight = torch.ones((self.n_labels,), dtype=torch.float32) self.register_buffer("ct_weight", ct_weight) + if self.prior=='mog': + self.prior_means = torch.nn.Parameter( + 0.01 * torch.randn([n_labels, self.num_classes_mog, n_latent])) + self.prior_log_scales = torch.nn.Parameter( + torch.zeros([n_labels, self.num_classes_mog, n_latent])) + self.prior_logits = torch.nn.Parameter( + torch.ones([n_labels, self.num_classes_mog])) def _get_inference_input(self, tensors): + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None if self.minified_data_type is None: x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] input_dict = { "x": x, "y": y, + "cat_covs": cat_covs, } else: if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): @@ -131,6 +158,7 @@ def _get_inference_input(self, tensors): "qzm": qzm, "qzv": qzv, "observed_lib_size": observed_lib_size, + "cat_covs": cat_covs, } else: raise NotImplementedError( @@ -143,16 +171,19 @@ def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] y = tensors[REGISTRY_KEYS.LABELS_KEY] + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None input_dict = { "z": z, "library": library, "y": y, + "cat_covs": cat_covs, } return input_dict @auto_move_data - def _regular_inference(self, x, y, n_samples=1): + def _regular_inference(self, x, y, cat_covs=None, n_samples=1): """High level inference method. Runs the inference (encoder) model. 
@@ -161,8 +192,11 @@ def _regular_inference(self, x, y, n_samples=1): library = x.sum(1).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) - - qz, z = self.z_encoder(x_, y) + if cat_covs is not None and self.encode_covariates: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () + qz, z = self.z_encoder(x_, y, *categorical_input) if n_samples > 1: untran_z = qz.sample((n_samples,)) @@ -194,9 +228,13 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): return outputs @auto_move_data - def generative(self, z, library, y): + def generative(self, z, library, y, cat_covs=None): """Runs the generative model.""" - h = self.decoder(z, y) + if cat_covs is not None: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () + h = self.decoder(z, y, *categorical_input) px_scale = self.px_decoder(h) px_rate = library * px_scale px = NegativeBinomial(px_rate, logits=self.px_r) @@ -211,17 +249,41 @@ def loss( ): """Loss computation.""" x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] + y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long() qz = inference_outputs["qz"] px = generative_outputs["px"] - - mean = torch.zeros_like(qz.loc) - scale = torch.ones_like(qz.scale) - - kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) + fine_celltypes = tensors['fine_labels'].ravel().long() if 'fine_labels' in tensors.keys() else None + + if self.prior == "mog": + indexed_means = self.prior_means[y] + indexed_log_scales = self.prior_log_scales[y] + indexed_logits = self.prior_logits[y] + + # Assigns zero meaning equal weight to all unlabeled cells. Otherwise biases to sample from respective MoG. + if fine_celltypes is not None: + logits_input = torch.nn.functional.one_hot( + self.mapping_mog[fine_celltypes], self.num_classes_mog) + cats = Categorical(logits=10*logits_input + indexed_logits) + else: + cats = Categorical(logits=indexed_logits) + normal_dists = torch.distributions.Independent( + Normal( + indexed_means, + torch.exp(indexed_log_scales) + 1e-4 + ), + reinterpreted_batch_ndims=1 + ) + prior = MixtureSameFamily(cats, normal_dists) + u = qz.rsample(sample_shape=(30,)) + # (sample, n_obs, n_latent) -> (sample, n_obs,) + kl_divergence_z = - (prior.log_prob(u) - qz.log_prob(u).sum(-1)).mean(0) + else: + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) reconst_loss = -px.log_prob(x).sum(-1) - scaling_factor = self.ct_weight[y.long()[:, 0]] + scaling_factor = self.ct_weight[y] loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) return LossOutput( From a8ee394cdcb9256cda309cfec071b3b756f4422d Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Fri, 5 Jan 2024 12:23:37 -0800 Subject: [PATCH 03/12] scvi-hub fixes. --- scvi/data/_utils.py | 6 ++-- scvi/model/_scanvi.py | 5 ++- scvi/model/base/_rnamixin.py | 7 +++- scvi/model/base/_vaemixin.py | 3 ++ scvi/module/_vae.py | 3 +- tests/model/test_models_with_minified_data.py | 33 ++++++++++++++++--- 6 files changed, 45 insertions(+), 12 deletions(-) diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index f8b2724744..39eb30ffab 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -32,6 +32,7 @@ from scvi import REGISTRY_KEYS, settings from scvi._types import AnnOrMuData, MinifiedDataType +from scvi.data._constants import ADATA_MINIFY_TYPE from . 
import _constants @@ -324,12 +325,11 @@ def _get_adata_minify_type(adata: AnnData) -> Union[MinifiedDataType, None]: def _is_minified(adata: Union[AnnData, str]) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY - layer_key = REGISTRY_KEYS.X_KEY if isinstance(adata, AnnData): - return adata.layers.get(layer_key, adata.X).sum()==0 + return adata.uns.get(uns_key, None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR elif isinstance(adata, str): with h5py.File(adata) as fp: - return uns_key in read_elem(fp["uns"]).keys() + return read_elem(fp["uns"]).get(uns_key, None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR else: raise TypeError(f"Unsupported type: {type(adata)}") diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 9058cc1798..e3d5c539c0 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -457,7 +457,8 @@ def setup_anndata( continuous_covariate_keys: list[str] | None = None, **kwargs, ): - """%(summary)s. + """ + %(summary)s. Parameters ---------- @@ -470,8 +471,6 @@ def setup_anndata( %(param_cat_cov_keys)s %(param_cont_cov_keys)s """ - - print("XXXXX", cls._latent_qzm) setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 6f47abde1b..9b1e0a6880 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -16,6 +16,7 @@ from scvi import REGISTRY_KEYS, settings from scvi._types import Number +from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.distributions._utils import DistributionConcatenator, subset_distribution from scvi.model._utils import _get_batch_code_from_category, scrna_raw_counts_properties from scvi.module.base._decorators import _move_data_to_device @@ -37,6 +38,7 @@ def _get_transform_batch_gen_kwargs(self, batch): "Transforming batches is not implemented for this model." ) + @unsupported_if_adata_minified def _get_importance_weights( self, adata: AnnData | None, @@ -325,7 +327,7 @@ def differential_expression( mode: Literal["vanilla", "change"] = "change", delta: float = 0.25, batch_size: int | None = None, - all_stats: bool = True, + all_stats: bool | None = None, batch_correction: bool = False, batchid1: list[str] | None = None, batchid2: list[str] | None = None, @@ -385,6 +387,9 @@ def differential_expression( self.get_latent_representation if filter_outlier_cells else None ) + if all_stats is None: + all_stats = getattr(self, "minified_data_type", None)!=ADATA_MINIFY_TYPE.LATENT_POSTERIOR + result = _de_core( self.get_anndata_manager(adata, required=True), model_fn, diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py index 29506655af..ded10a8084 100644 --- a/scvi/model/base/_vaemixin.py +++ b/scvi/model/base/_vaemixin.py @@ -129,6 +129,9 @@ def get_reconstruction_error( Indices of cells in adata to use. If `None`, all cells are used. batch_size Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_mean + If False, return the marginal log likelihood for each observation. + Otherwise, return the mmean arginal log likelihood. 
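The tightened `_is_minified` check above treats an AnnData as minified only when the uns entry equals the `latent_posterior_parameters` type, so data stored with the new `add_posterior_parameters` type still validates as regular count data. A toy illustration of the check (the key name below is a stand-in for the real `_ADATA_MINIFY_TYPE_UNS_KEY` constant):

    import anndata as ad
    import numpy as np

    uns_key = "adata_minify_type"  # placeholder; the real constant lives in scvi/data/_constants.py

    adata = ad.AnnData(X=np.ones((3, 2), dtype=np.float32))
    assert adata.uns.get(uns_key, None) != "latent_posterior_parameters"  # counts present, not minified

    adata.uns[uns_key] = "add_posterior_parameters"
    assert adata.uns.get(uns_key, None) != "latent_posterior_parameters"  # still usable as count data

    adata.uns[uns_key] = "latent_posterior_parameters"
    assert adata.uns.get(uns_key, None) == "latent_posterior_parameters"  # counts were stripped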
""" adata = self._validate_anndata(adata) scdl = self._make_data_loader( diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index da133d1c63..c71a53adcf 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -361,6 +361,7 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): # use dist.sample() rather than rsample because we aren't optimizing the z here untran_z = dist.sample() if n_samples == 1 else dist.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) + qz = Normal(qzm, qzv.sqrt()) library = torch.log(observed_lib_size) if n_samples > 1: library = library.unsqueeze(0).expand( @@ -370,7 +371,7 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): raise NotImplementedError( f"Unknown minified-data type: {self.minified_data_type}" ) - outputs = {"z": z, "qz_m": qzm, "qz_v": qzv, "ql": None, "library": library} + outputs = {"z": z, "qz": qz, "ql": None, "library": library} return outputs @auto_move_data diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index 02c82ecd1f..a0cddd1722 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -33,6 +33,8 @@ def prep_model(cls=SCVI, layer=None, use_size_factor=False): } if cls == SCANVI: setup_kwargs["unlabeled_category"] = "unknown" + if cls == CondSCVI: + setup_kwargs.pop("batch_key") if use_size_factor: setup_kwargs["size_factor_key"] = "size_factor" cls.setup_anndata( @@ -89,7 +91,7 @@ def run_test_for_model_with_minified_adata( assert adata_orig.layers.keys() == model.adata.layers.keys() orig_obs_df = adata_orig.obs - obs_keys = _SCANVI_OBSERVED_LIB_SIZE if cls == SCANVI else _SCVI_OBSERVED_LIB_SIZE + obs_keys = model._OBSERVED_LIB_SIZE orig_obs_df[obs_keys] = adata_lib_size assert model.adata.obs.equals(orig_obs_df) assert model.adata.var_names.equals(adata_orig.var_names) @@ -101,6 +103,8 @@ def run_test_for_model_with_minified_adata( scvi.settings.seed = 1 keys = ["mean", "dispersions", "dropout"] + if cls == CondSCVI: + keys.remove("dropout") if n_samples == 1: params_latent = model.get_likelihood_parameters( n_samples=n_samples, give_mean=give_mean @@ -171,7 +175,7 @@ def test_condscvi_with_minified_adata_one_sample(): def test_condscvi_downstream(): model, adata, _, adata_before_setup = prep_model(CondSCVI) - zm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) + qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) adata.obsm["X_latent_qzm"] = qzm adata.obsm["X_latent_qzv"] = qzv model.minify_adata() @@ -227,8 +231,9 @@ def test_scvi_with_minified_adata_get_normalized_expression(): model.minify_adata('add_posterior_parameters') assert model.minified_data_type == ADATA_MINIFY_TYPE.ADD_POSTERIOR_PARAMETERS - assert np.isfinite(model.adata.get_elbo()) - assert np.isfinite(model.adata.get_reconstruction_error()) + assert np.isfinite(model.get_elbo()) + print('XXXX', model.get_reconstruction_error()) + assert np.isfinite(model.get_reconstruction_error()['reconstruction_loss']) model.minify_adata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR @@ -403,6 +408,26 @@ def test_scvi_with_minified_adata_get_latent_representation(): latent_repr_new = model.get_latent_representation() np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) + + +def test_scvi_with_minified_adata_differential_expression(): + model, _, _, _ = prep_model() + + scvi.settings.seed = 1 + qzm, qzv = 
model.get_latent_representation(give_mean=False, return_dist=True) + model.adata.obsm["X_latent_qzm"] = qzm + model.adata.obsm["X_latent_qzv"] = qzv + + scvi.settings.seed = 1 + latent_repr_orig = model.get_latent_representation() + + model.minify_adata() + assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + + scvi.settings.seed = 1 + latent_repr_new = model.get_latent_representation() + + np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) def test_scvi_with_minified_adata_posterior_predictive_sample(): From 422c313573c0b5bc1775995393d42c483c1ad60d Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Fri, 5 Jan 2024 12:26:07 -0800 Subject: [PATCH 04/12] Pre-commit --- scvi/data/_utils.py | 7 ++++-- scvi/model/_condscvi.py | 4 +++- scvi/model/_scanvi.py | 6 +++-- scvi/model/_scvi.py | 6 ++--- scvi/model/base/_base_model.py | 12 ++++++---- scvi/model/base/_rnamixin.py | 5 +++- scvi/module/base/_base_module.py | 5 ++-- scvi/utils/_decorators.py | 5 +++- tests/model/test_models_with_minified_data.py | 23 +++++++++++-------- 9 files changed, 45 insertions(+), 28 deletions(-) diff --git a/scvi/data/_utils.py b/scvi/data/_utils.py index 39eb30ffab..232733cf79 100644 --- a/scvi/data/_utils.py +++ b/scvi/data/_utils.py @@ -326,10 +326,13 @@ def _get_adata_minify_type(adata: AnnData) -> Union[MinifiedDataType, None]: def _is_minified(adata: Union[AnnData, str]) -> bool: uns_key = _constants._ADATA_MINIFY_TYPE_UNS_KEY if isinstance(adata, AnnData): - return adata.uns.get(uns_key, None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR + return adata.uns.get(uns_key, None) == ADATA_MINIFY_TYPE.LATENT_POSTERIOR elif isinstance(adata, str): with h5py.File(adata) as fp: - return read_elem(fp["uns"]).get(uns_key, None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR + return ( + read_elem(fp["uns"]).get(uns_key, None) + == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + ) else: raise TypeError(f"Unsupported type: {type(adata)}") diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index 8ad23f92ca..9e1bcc2f75 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -24,7 +24,9 @@ logger = logging.getLogger(__name__) -class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseMinifiedModeModelClass): +class CondSCVI( + RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseMinifiedModeModelClass +): """Conditional version of single-cell Variational Inference, used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`. 
Parameters diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index e3d5c539c0..93ba30dcba 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -491,9 +491,11 @@ def setup_anndata( # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) + anndata_fields += cls._get_fields_for_adata_minification( + cls, adata_minify_type + ) adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) adata_manager.register_fields(adata, **kwargs) - cls.register_manager(adata_manager) \ No newline at end of file + cls.register_manager(adata_manager) diff --git a/scvi/model/_scvi.py b/scvi/model/_scvi.py index 074eec3c1a..f84716d44d 100644 --- a/scvi/model/_scvi.py +++ b/scvi/model/_scvi.py @@ -196,11 +196,11 @@ def setup_anndata( # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: - anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) + anndata_fields += cls._get_fields_for_adata_minification( + cls, adata_minify_type + ) adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - - diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index 1691f261d7..41b2d2a19b 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -535,12 +535,14 @@ def minify_adata( elif minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR: keep_count_data = False - if not getattr(self.module, 'use_observed_lib_size', True): + if not getattr(self.module, "use_observed_lib_size", True): raise ValueError( "Cannot minify the data if `use_observed_lib_size` is False" ) - minified_adata = get_minified_adata_scrna(self.adata, minified_data_type, keep_count_data=keep_count_data) + minified_adata = get_minified_adata_scrna( + self.adata, minified_data_type, keep_count_data=keep_count_data + ) minified_adata.obsm[self._LATENT_QZM] = self.adata.obsm[use_latent_qzm_key] minified_adata.obsm[self._LATENT_QZV] = self.adata.obsm[use_latent_qzv_key] counts = self.adata_manager.get_from_registry(REGISTRY_KEYS.X_KEY) @@ -557,7 +559,9 @@ def _get_fields_for_adata_minification( minified_data_type: MinifiedDataType, ) -> list[BaseAnnDataField]: """Return the anndata fields required for adata minification of the given minified_data_type.""" - assert self._LATENT_QZM, NotImplementedError("Minified mode is not defined for model.") + assert self._LATENT_QZM, NotImplementedError( + "Minified mode is not defined for model." 
+ ) if ADATA_MINIFY_TYPE.__contains__(minified_data_type): fields = [ ObsmField( @@ -583,8 +587,6 @@ def _get_fields_for_adata_minification( ) return fields - - @property def is_trained(self) -> bool: """Whether the model has been trained.""" diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 9b1e0a6880..f9dbcb5686 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -388,7 +388,10 @@ def differential_expression( ) if all_stats is None: - all_stats = getattr(self, "minified_data_type", None)!=ADATA_MINIFY_TYPE.LATENT_POSTERIOR + all_stats = ( + getattr(self, "minified_data_type", None) + != ADATA_MINIFY_TYPE.LATENT_POSTERIOR + ) result = _de_core( self.get_anndata_manager(adata, required=True), diff --git a/scvi/module/base/_base_module.py b/scvi/module/base/_base_module.py index a8352b8c1a..8b21478cf3 100644 --- a/scvi/module/base/_base_module.py +++ b/scvi/module/base/_base_module.py @@ -299,9 +299,8 @@ def inference(self, *args, **kwargs): Branches off to regular or cached inference depending on whether we have a minified adata that contains the latent posterior parameters. """ - if ( - self.minified_data_type is not None - and ADATA_MINIFY_TYPE.__contains__(self.minified_data_type) + if self.minified_data_type is not None and ADATA_MINIFY_TYPE.__contains__( + self.minified_data_type ): return self._cached_inference(*args, **kwargs) else: diff --git a/scvi/utils/_decorators.py b/scvi/utils/_decorators.py index 8618e866a0..0c4390ae5d 100644 --- a/scvi/utils/_decorators.py +++ b/scvi/utils/_decorators.py @@ -9,7 +9,10 @@ def unsupported_if_adata_minified(fn: Callable) -> Callable: @wraps(fn) def wrapper(self, *args, **kwargs): - if getattr(self, "minified_data_type", None)==ADATA_MINIFY_TYPE.LATENT_POSTERIOR: + if ( + getattr(self, "minified_data_type", None) + == ADATA_MINIFY_TYPE.LATENT_POSTERIOR + ): raise ValueError( f"The {fn.__qualname__} function currently does not support minified data." 
) diff --git a/tests/model/test_models_with_minified_data.py b/tests/model/test_models_with_minified_data.py index a0cddd1722..4b3e8492e3 100644 --- a/tests/model/test_models_with_minified_data.py +++ b/tests/model/test_models_with_minified_data.py @@ -5,7 +5,7 @@ from scvi.data import synthetic_iid from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _is_minified -from scvi.model import SCANVI, SCVI, CondSCVI, DestVI +from scvi.model import SCANVI, SCVI, CondSCVI _SCVI_OBSERVED_LIB_SIZE = "_scvi_observed_lib_size" _SCANVI_OBSERVED_LIB_SIZE = "_scanvi_observed_lib_size" @@ -160,7 +160,8 @@ def test_scanvi_with_minified_adata_n_samples(): run_test_for_model_with_minified_adata( SCANVI, n_samples=10, give_mean=True, use_size_factor=True ) - + + def test_condscvi_with_minified_adata_one_sample(): run_test_for_model_with_minified_adata(CondSCVI) @@ -169,10 +170,10 @@ def test_condscvi_with_minified_adata_one_sample(): run_test_for_model_with_minified_adata(CondSCVI, layer="data_layer") -def test_condscvi_with_minified_adata_one_sample(): +def test_condscvi_with_minified_adata_n_samples(): run_test_for_model_with_minified_adata(CondSCVI, n_samples=10, give_mean=True) - + def test_condscvi_downstream(): model, adata, _, adata_before_setup = prep_model(CondSCVI) qzm, qzv = model.get_latent_representation(give_mean=False, return_dist=True) @@ -181,7 +182,9 @@ def test_condscvi_downstream(): model.minify_adata() model.get_vamp_prior() scvi.model.DestVI.setup_anndata(adata_before_setup) - scvi.model.DestVI.from_rna_model(adata_before_setup, model, amortization="both", vamp_prior_p=10) + scvi.model.DestVI.from_rna_model( + adata_before_setup, model, amortization="both", vamp_prior_p=10 + ) def test_scanvi_from_scvi(save_path): @@ -229,11 +232,11 @@ def test_scvi_with_minified_adata_get_normalized_expression(): scvi.settings.seed = 1 exprs_orig = model.get_normalized_expression() - model.minify_adata('add_posterior_parameters') + model.minify_adata("add_posterior_parameters") assert model.minified_data_type == ADATA_MINIFY_TYPE.ADD_POSTERIOR_PARAMETERS assert np.isfinite(model.get_elbo()) - print('XXXX', model.get_reconstruction_error()) - assert np.isfinite(model.get_reconstruction_error()['reconstruction_loss']) + print("XXXX", model.get_reconstruction_error()) + assert np.isfinite(model.get_reconstruction_error()["reconstruction_loss"]) model.minify_adata() assert model.minified_data_type == ADATA_MINIFY_TYPE.LATENT_POSTERIOR @@ -408,8 +411,8 @@ def test_scvi_with_minified_adata_get_latent_representation(): latent_repr_new = model.get_latent_representation() np.testing.assert_array_equal(latent_repr_new, latent_repr_orig) - - + + def test_scvi_with_minified_adata_differential_expression(): model, _, _, _ = prep_model() From 28ee4689e04571c5ead4c5655819bc33b49593af Mon Sep 17 00:00:00 2001 From: cane11 Date: Sat, 13 Jan 2024 13:31:48 -0800 Subject: [PATCH 05/12] DestVI experiments. 
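Usage note (not part of the diff below): with this change, CondSCVI.get_vamp_prior returns a dict rather than a tuple, and DestVI.from_rna_model unpacks its values. A minimal sketch of the intended consumption, assuming a single-cell AnnData `sc_adata` with a placeholder label column "cell_type"; the key names and shapes are taken from the updated docstring in this patch.

    import scvi

    scvi.model.CondSCVI.setup_anndata(sc_adata, labels_key="cell_type")
    sc_model = scvi.model.CondSCVI(sc_adata)
    sc_model.train()

    prior = sc_model.get_vamp_prior(sc_adata, p=10)
    mean_vprior = prior["mean_vprior"]        # (n_labels, p, n_latent)
    var_vprior = prior["var_vprior"]          # (n_labels, p, n_latent)
    weights_vprior = prior["weights_vprior"]  # (n_labels, p)
    # When scales_prior_n_samples is given, a "scales_vprior" entry of shape
    # (n_labels, p, n_genes) is added, obtained by sampling through the decoder.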
--- scvi/model/_condscvi.py | 147 +++++++++++++--------- scvi/model/_destvi.py | 80 ++++++++++-- scvi/module/_mrdeconv.py | 254 ++++++++++++++++++++++++++++++--------- scvi/module/_vaec.py | 2 +- 4 files changed, 352 insertions(+), 131 deletions(-) diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index ca717fcde6..e6ad14c2e4 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -149,7 +149,7 @@ def __init__( self.init_params_ = self._get_init_params(locals()) @torch.inference_mode() - def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarray: + def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior_n_samples: int | None = None, default_cat: list | None = None) -> np.ndarray: r"""Return an empirical prior over the cell-type specific latent space (vamp prior) that may be used for deconvolution. Parameters @@ -159,6 +159,10 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra AnnData object used to initialize the model. p number of clusters in kmeans clustering for cell-type sub-clustering for empirical prior + scales_prior_n_samples + return scales of negative binomial distribution for calculates prior means and variances using n_samples. + default_cat + default value for categorical covariates Returns ------- @@ -166,6 +170,10 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra (n_labels, p, D) array var_vprior (n_labels, p, D) array + weights_vprior + (n_labels, p) array + scales_vprior + (n_labels, p, G) array """ if self.is_trained_ is False: warnings.warn( @@ -178,71 +186,90 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra adata = self._validate_anndata(adata) if self.module.prior == "mog": - print('Using MoG') - return ( - self.module.prior_means, - torch.exp(self.module.prior_log_scales)**2 + 1e-4, - torch.nn.functional.softmax(self.module.prior_logits, dim=-1) - ) + results = { + "mean_vprior": self.module.prior_means, + "var_vprior": torch.exp(self.module.prior_log_scales)**2 + 1e-4, + "weights_vprior": torch.nn.functional.softmax(self.module.prior_logits, dim=-1) + } + else: - # Extracting latent representation of adata including variances. - mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) - var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) - mp_vprior = np.zeros((self.summary_stats.n_labels, p)) + # Extracting latent representation of adata including variances. 
+ mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) + var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) + mp_vprior = np.zeros((self.summary_stats.n_labels, p)) - labels_state_registry = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) - key = labels_state_registry.original_key - mapping = labels_state_registry.categorical_mapping - - mean_cat, var_cat = self.get_latent_representation(adata, return_dist=True) - - for ct in range(self.summary_stats["n_labels"]): - local_indices = np.where(adata.obs[key] == mapping[ct])[0] - n_local_indices = len(local_indices) - if "overclustering_vamp" not in adata.obs.columns: - if p < n_local_indices and p > 0: - overclustering_vamp = KMeans(n_clusters=p, n_init=30).fit_predict( - mean_cat[local_indices] - ) - else: - # Every cell is its own cluster - overclustering_vamp = np.arange(n_local_indices) - else: - overclustering_vamp = adata[local_indices, :].obs["overclustering_vamp"] - - keys, counts = np.unique(overclustering_vamp, return_counts=True) - - n_labels_overclustering = len(keys) - if n_labels_overclustering > p: - error_mess = """ - Given cell type specific clustering contains more clusters than vamp_prior_p. - Increase value of vamp_prior_p to largest number of cell type specific clusters.""" - - raise ValueError(error_mess) - - var_cluster = np.ones( - [ - n_labels_overclustering, - self.module.n_latent, - ] + labels_state_registry = self.adata_manager.get_state_registry( + REGISTRY_KEYS.LABELS_KEY ) - mean_cluster = np.zeros_like(var_cluster) + key = labels_state_registry.original_key + mapping = labels_state_registry.categorical_mapping + + mean_cat, var_cat = self.get_latent_representation(adata, return_dist=True) + + for ct in range(self.summary_stats["n_labels"]): + local_indices = np.where(adata.obs[key] == mapping[ct])[0] + n_local_indices = len(local_indices) + if "overclustering_vamp" not in adata.obs.columns: + if p < n_local_indices and p > 0: + overclustering_vamp = KMeans(n_clusters=p, n_init=30).fit_predict( + mean_cat[local_indices] + ) + else: + # Every cell is its own cluster + overclustering_vamp = np.arange(n_local_indices) + else: + overclustering_vamp = adata[local_indices, :].obs["overclustering_vamp"] - for index, cluster in enumerate(keys): - indices_curr = local_indices[np.where(overclustering_vamp == cluster)[0]] - var_cluster[index, :] = np.mean(var_cat[indices_curr], axis=0) + np.var( - mean_cat[indices_curr], axis=0 - ) - mean_cluster[index, :] = np.mean(mean_cat[indices_curr], axis=0) + keys, counts = np.unique(overclustering_vamp, return_counts=True) + + n_labels_overclustering = len(keys) + if n_labels_overclustering > p: + error_mess = """ + Given cell type specific clustering contains more clusters than vamp_prior_p. 
+ Increase value of vamp_prior_p to largest number of cell type specific clusters.""" - slicing = slice(n_labels_overclustering) - mean_vprior[ct, slicing, :] = mean_cluster - var_vprior[ct, slicing, :] = var_cluster - mp_vprior[ct, slicing] = counts / sum(counts) + raise ValueError(error_mess) - return mean_vprior, var_vprior, mp_vprior + var_cluster = np.ones( + [ + n_labels_overclustering, + self.module.n_latent, + ] + ) + mean_cluster = np.zeros_like(var_cluster) + + for index, cluster in enumerate(keys): + indices_curr = local_indices[np.where(overclustering_vamp == cluster)[0]] + var_cluster[index, :] = np.mean(var_cat[indices_curr], axis=0) + np.var( + mean_cat[indices_curr], axis=0 + ) + mean_cluster[index, :] = np.mean(mean_cat[indices_curr], axis=0) + + slicing = slice(n_labels_overclustering) + mean_vprior[ct, slicing, :] = mean_cluster + var_vprior[ct, slicing, :] = var_cluster + mp_vprior[ct, slicing] = counts / sum(counts) + results = { + "mean_vprior": mean_vprior, + "var_vprior": var_vprior, + "weights_vprior": mp_vprior + } + + if scales_prior_n_samples is not None: + scales_vprior = np.zeros((self.summary_stats.n_labels, p, self.summary_stats.n_vars)) + cat_covs = [ + torch.full([scales_prior_n_samples, 1], float(np.where(value==default_cat[ind])[0]) if default_cat else 0, device=self.module.device) + for ind, value in enumerate(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings.values())] + for ct in range(self.summary_stats["n_labels"]): + for cluster in range(p): + sampled_z = torch.distributions.Normal( + results['mean_vprior'][ct, cluster, :], torch.sqrt(results['var_vprior'][ct, cluster, :]) + ).sample([scales_prior_n_samples,]).to(self.module.device) + h = self.module.decoder(sampled_z, torch.full([scales_prior_n_samples, 1], ct, device=self.module.device), *cat_covs) + scales_vprior[ct, cluster, :] = self.module.px_decoder(h).mean(0).cpu() + results["scales_vprior"] = scales_vprior + + return results @devices_dsp.dedent def train( diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index b2cb0808e8..b859ca8413 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -84,17 +84,21 @@ def __init__( **module_kwargs, ): super().__init__(st_adata) - st_covariate_registry = dict(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings) - assert not set(sc_covariate_registry.keys()) == set(st_covariate_registry.keys()), ( - f'Spatial model has other covariates than single cell model, {set(sc_covariate_registry.keys()).symmetric_difference(st_covariate_registry.keys())}' - ) - for key, value in st_covariate_registry.items(): - assert not set(value).issubset(set(sc_covariate_registry[key])), ( - f'Spatial model has other covariates than single cell model, {set(sc_covariate_registry.keys()).symmetric_difference(st_covariate_registry.keys())}' + if sc_covariate_registry is not None: + st_covariate_registry = dict(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings) + sc_covariate_mappings = dict(sc_covariate_registry.mappings) + assert set(sc_covariate_mappings.keys()) == set(st_covariate_registry.keys()), ( + f'Spatial model has other covariates than single cell model, {set(sc_covariate_mappings.keys()).symmetric_difference(st_covariate_registry.keys())}' ) - self.adata.obsm['_scvi_extra_categorical_covs'][key] = self.adata.obsm['_scvi_extra_categorical_covs'][key].apply( - lambda x: sc_covariate_registry[key].index(value[x]) + for key, value in st_covariate_registry.items(): + assert 
set(value).issubset(set(sc_covariate_mappings[key])), ( + f'Spatial model has other covariates than single cell model, {set(value) - set(sc_covariate_mappings[key])}') + n_cats_per_cov = ( + sc_covariate_registry.n_cats_per_key + if sc_covariate_registry + else None ) + self.module = self._module_cls( n_spots=st_adata.n_obs, n_labels=cell_type_mapping.shape[0], @@ -153,7 +157,15 @@ def from_rna_model( else: mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior( sc_model.adata, p=vamp_prior_p + ).values() + + sc_covariate_registry = ( + sc_model.adata_manager.get_state_registry( + REGISTRY_KEYS.CAT_COVS_KEY ) + if REGISTRY_KEYS.CAT_COVS_KEY in sc_model.adata_manager.data_registry + else None + ) return cls( st_adata, @@ -170,7 +182,7 @@ def from_rna_model( mp_vprior=mp_vprior, dropout_decoder=dropout_decoder, l1_reg=l1_reg, - sc_covariate_registry=dict(sc_model.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings), + sc_covariate_registry=sc_covariate_registry, **module_kwargs, ) @@ -279,6 +291,49 @@ def get_gamma( ) return res + def get_latent_amortization( + self, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + return_numpy: bool = False, + ) -> np.ndarray | dict[str, pd.DataFrame]: + """Returns the amortized latent space for the spatial data. + + Parameters + ---------- + indices + Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. + batch_size + Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`. + return_numpy + if activated, will return a numpy array of shape is n_spots x n_latent x n_labels. + """ + self._check_if_trained() + + if self.module.n_latent_amortization is None or self.module.amortization in ["none"]: + ValueError('Get latent amortization is not defined if n_latent_amortization is None or no amortization is used') + + column_names = np.arange(self.module.n_latent_amortization) + index_names = self.adata.obs.index + + stdl = self._make_data_loader( + self.adata, indices=indices, batch_size=batch_size + ) + + amortization = [] + for tensors in stdl: + generative_inputs = self.module._get_generative_input(tensors, None) + x = generative_inputs["x"] + z_amortization = self.module.get_latent_amortization(x) + amortization += [z_amortization.cpu()] + + data = torch.cat(amortization).numpy() + column_names = np.arange(self.module.n_latent_amortization) + index_names = self.adata.obs.index + if indices is not None: + index_names = index_names[indices] + return pd.DataFrame(data=data, columns=column_names, index=index_names) + def get_scale_for_ct( self, label: str, @@ -310,11 +365,12 @@ def get_scale_for_ct( scale = [] for tensors in stdl: generative_inputs = self.module._get_generative_input(tensors, None) - x, ind_x = ( + x, ind_x, cat_covs = ( generative_inputs["x"], generative_inputs["ind_x"], + generative_inputs["cat_covs"], ) - px_scale = self.module.get_ct_specific_expression(x, ind_x, y) + px_scale = self.module.get_ct_specific_expression(x, ind_x, y, cat_covs) scale += [px_scale.cpu()] data = torch.cat(scale).numpy() diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 3f889df429..0f3a4188e5 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -4,6 +4,7 @@ import numpy as np import torch from torch.distributions import Normal +from torch.distributions import kl_divergence as kl from scvi import REGISTRY_KEYS from scvi._types import Tunable @@ -90,8 +91,9 @@ def __init__( amortization: 
Literal["none", "latent", "proportion", "both"] = "both", l1_reg: Tunable[float] = 0.0, beta_reg: Tunable[float] = 5.0, - eta_reg: Tunable[float] = 1e-4, - mode: Literal["mog", "normal"] = "normal", + eta_reg: Tunable[float] = 1e-7, + prior_mode: Literal["mog", "normal"] = "normal", + n_latent_amortization: Optional[int] = None, extra_encoder_kwargs: Optional[dict] = None, extra_decoder_kwargs: Optional[dict] = None, ): @@ -107,9 +109,11 @@ def __init__( self.l1_reg = l1_reg self.beta_reg = beta_reg self.eta_reg = eta_reg + self.prior_mode = prior_mode + self.n_latent_amortization = n_latent_amortization # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} - cat_list = [n_labels] + list([] if self.n_cats_per_cov is None else self.n_cats_per_cov) + cat_list = [n_labels] + list([] if n_cats_per_cov is None else n_cats_per_cov) self.decoder = FCLayers( n_in=n_latent, @@ -157,35 +161,55 @@ def __init__( # create additional neural nets for amortization # within cell_type factor loadings _extra_encoder_kwargs = extra_encoder_kwargs or {} - if self.mode == "mog": - return_dist = 3 + if self.prior_mode == "mog": + return_dist = self.p * n_labels * n_latent + self.p * n_labels else: - return_dist = 1 + return_dist = n_labels * n_latent + if self.n_latent_amortization is not None: + # Uses a combined latent space for proportions and gammas. + self.amortization_network = torch.nn.Sequential( + FCLayers( + n_in=self.n_genes, + n_out=n_hidden, + n_cat_list=None, + n_layers=1, + n_hidden=n_hidden, + dropout_rate=dropout_amortization, + use_layer_norm=True, + use_batch_norm=False, + ), + torch.nn.Linear(n_hidden, 2 * self.n_latent_amortization), + ) + n_layers = 2 + + else: + self.amortization_network = torch.nn.Identity() + n_latent_amortization = self.n_genes + n_layers = 2 self.gamma_encoder = torch.nn.Sequential( FCLayers( - n_in=self.n_genes, + n_in=n_latent_amortization, n_out=n_hidden, n_cat_list=None, - n_layers=2, + n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_amortization, use_layer_norm=True, use_batch_norm=False, - **_extra_encoder_kwargs, ), - torch.nn.Linear(n_hidden, return_dist * n_latent * n_labels), + torch.nn.Linear(n_hidden, return_dist), ) # cell type loadings self.V_encoder = torch.nn.Sequential( FCLayers( - n_in=self.n_genes, + n_in=n_latent_amortization, n_out=n_hidden, - n_layers=2, + n_cat_list=None, + n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_amortization, use_layer_norm=True, use_batch_norm=False, - **_extra_encoder_kwargs, ), torch.nn.Linear(n_hidden, n_labels + 1), ) @@ -197,8 +221,10 @@ def _get_inference_input(self, tensors): def _get_generative_input(self, tensors, inference_outputs): x = tensors[REGISTRY_KEYS.X_KEY] ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel() + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] + 2. 
if cat_key in tensors.keys() else None - input_dict = {"x": x, "ind_x": ind_x} + input_dict = {"x": x, "ind_x": ind_x, "cat_covs": cat_covs} return input_dict @auto_move_data @@ -207,7 +233,7 @@ def inference(self): return {} @auto_move_data - def generative(self, x, ind_x): + def generative(self, x, ind_x, cat_covs=None): """Build the deconvolution model for every cell in the minibatch.""" m = x.shape[0] library = torch.sum(x, dim=1, keepdim=True) @@ -215,35 +241,62 @@ def generative(self, x, ind_x): beta = torch.exp(self.beta) # n_genes eps = torch.nn.functional.softplus(self.eta) # n_genes x_ = torch.log(1 + x) - # subsample parameters + z_amortization_params = self.amortization_network(x_) + qz_amortization = Normal( + z_amortization_params[:, :self.n_latent_amortization], + torch.exp(z_amortization_params[:, self.n_latent_amortization:])+1e-4 + ) + z_amortization = qz_amortization.rsample() if self.amortization in ["both", "latent"]: - gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) + if self.prior_mode == "mog": + gamma_ = self.gamma_encoder(z_amortization) + proportion_modes_logits = torch.transpose( + gamma_[:, -self.p*self.n_labels:], 0, 1).reshape( + (self.p, self.n_labels, m) + ).transpose(1, 2) + proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) + gamma_ind = torch.transpose( + gamma_[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( + (self.p, self.n_latent, self.n_labels, -1) + ) + else: + gamma_ind = torch.transpose( + self.gamma_encoder(z_amortization), 0, 1).reshape( + (1, self.n_latent, self.n_labels, -1) + ) + proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=x.device) else: - gamma_ind = self.gamma[:, :, ind_x] # n_latent, n_labels, minibatch_size + gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0) # 1, n_latent, n_labels, minibatch_size + proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=x.device) if self.amortization in ["both", "proportion"]: - v_ind = self.V_encoder(x_) + v_ind = self.V_encoder(z_amortization) else: v_ind = self.V[:, ind_x].T # minibatch_size, labels + 1 v_ind = torch.nn.functional.softplus(v_ind) - # reshape and get gene expression value for all minibatch - gamma_ind = torch.transpose( - gamma_ind, 2, 0 - ) # minibatch_size, n_labels, n_latent - gamma_reshape = gamma_ind.reshape( - (-1, self.n_latent) - ) # minibatch_size * n_labels, n_latent + px_est = torch.zeros((x.shape[0], self.n_labels, self.n_genes), device=x.device) enum_label = ( torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) ) # minibatch_size * n_labels, 1 - h = self.decoder(gamma_reshape, enum_label.to(x.device)) - px_rate = self.px_decoder(h).reshape( - (m, self.n_labels, -1) - ) # (minibatch, n_labels, n_genes) + if cat_covs is not None: + categorical_input = [i.repeat_interleave(self.n_labels, dim=0) for i in torch.split(cat_covs, 1, dim=1)] + else: + categorical_input = () + + for mode in range(gamma_ind.shape[0]): + # reshape and get gene expression value for all minibatch + gamma_ind_ = torch.transpose( + gamma_ind[mode, ...], 2, 0 + ) # minibatch_size, n_labels, n_latent + gamma_reshape_ = gamma_ind_.reshape( + (-1, self.n_latent) + ) # minibatch_size * n_labels, n_latent + h = self.decoder(gamma_reshape_, enum_label.to(x.device), *categorical_input) + px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( + (m, self.n_labels, -1) + ) # (minibatch, n_labels, n_genes) # add the dummy cell 
type eps = eps.repeat((m, 1)).view( @@ -252,7 +305,7 @@ def generative(self, x, ind_x): # account for gene specific bias and add noise r_hat = torch.cat( - [beta.unsqueeze(0).unsqueeze(1) * px_rate, eps], dim=1 + [beta.unsqueeze(0).unsqueeze(1) * px_est, eps], dim=1 ) # M, n_labels + 1, n_genes # now combine them for convolution px_scale = torch.sum(v_ind.unsqueeze(2) * r_hat, dim=1) # batch_size, n_genes @@ -264,6 +317,9 @@ def generative(self, x, ind_x): "px_scale": px_scale, "gamma": gamma_ind, "v": v_ind, + "proportion_modes": proportion_modes, + "proportion_modes_logits": proportion_modes_logits, + "qz_amortization": qz_amortization, } def loss( @@ -273,6 +329,7 @@ def loss( generative_outputs, kl_weight: float = 1.0, n_obs: int = 1.0, + weighting_cross_entropy: float = 1.0 ): """Compute the loss.""" x = tensors[REGISTRY_KEYS.X_KEY] @@ -301,8 +358,8 @@ def loss( neg_log_likelihood_prior = -Normal(mean, scale).log_prob(gamma).sum(2).sum(1) else: # vampprior - # gamma is of shape n_latent, n_labels, minibatch_size - gamma = gamma.unsqueeze(1) # minibatch_size, 1, n_labels, n_latent + # gamma is of shape minibatch_size, 1, n_latent, n_labels + gamma = gamma.permute(3, 0, 2, 1) # minibatch_size, 1, n_labels, n_latent mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze( 0 ) # 1, p, n_labels, n_latent @@ -310,13 +367,28 @@ def loss( 0 ) # 1, p, n_labels, n_latent mp_vprior = torch.transpose(self.mp_vprior, 0, 1) # p, n_labels - pre_lse = ( - Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(mp_vprior) # minibatch, p, n_labels - # Pseudocount for numerical stability - log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch - # mean_vprior is of shape n_labels, p, n_latent + if self.prior_mode == "mog": + proportion_modes_logits = generative_outputs["proportion_modes_logits"] + proportion_modes = generative_outputs["proportion_modes"] + pre_lse = ( + Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) + ) + torch.log(proportion_modes).permute(1, 0, 2) # minibatch, p, n_labels + log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels + neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch + + neg_log_likelihood_prior += weighting_cross_entropy * torch.nn.functional.cross_entropy( + proportion_modes_logits.permute(1, 0, 2), mp_vprior.repeat(x.shape[0], 1, 1), reduction='none').sum(1) + else: + pre_lse = ( + Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) + ) + torch.log(mp_vprior) # minibatch, p, n_labels + log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels + neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch + + neg_log_likelihood_prior += kl( + generative_outputs["qz_amortization"], + Normal(torch.zeros([self.n_latent_amortization], device=x.device), torch.ones([self.n_latent_amortization], device=x.device)) + ).sum(dim=-1) # High v_sparsity_loss is detrimental early in training, scaling by kl_weight to increase over training epochs. 
loss = n_obs * ( @@ -350,7 +422,8 @@ def get_proportions(self, x=None, keep_noise=False) -> np.ndarray: if self.amortization in ["both", "proportion"]: # get estimated unadjusted proportions x_ = torch.log(1 + x) - res = torch.nn.functional.softplus(self.V_encoder(x_)) + z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] + res = torch.nn.functional.softplus(self.V_encoder(z_amortization)) else: res = ( torch.nn.functional.softplus(self.V).cpu().numpy().T @@ -375,7 +448,19 @@ def get_gamma(self, x: torch.Tensor = None) -> torch.Tensor: # get estimated unadjusted proportions if self.amortization in ["latent", "both"]: x_ = torch.log(1 + x) - gamma = self.gamma_encoder(x_) + z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] + gamma = self.gamma_encoder(z_amortization) + if self.prior_mode == "mog": + proportion_modes_logits = torch.transpose( + gamma[:, -self.p*self.n_labels:], 0, 1).reshape( + (self.p, 1, self.n_labels, x.shape[0]) + ) + proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) + gamma_ind = torch.transpose( + gamma[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( + (self.p, self.n_latent, self.n_labels, x.shape[0]) + ) + gamma = torch.sum(proportion_modes * gamma_ind, dim=0) return torch.transpose(gamma, 0, 1).reshape( (self.n_latent, self.n_labels, -1) ) # n_latent, n_labels, minibatch @@ -385,7 +470,7 @@ def get_gamma(self, x: torch.Tensor = None) -> torch.Tensor: @torch.inference_mode() @auto_move_data def get_ct_specific_expression( - self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None + self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None, cat_covs: torch.Tensor = None ): """Returns cell type specific gene expression at the queried spots. @@ -397,24 +482,77 @@ def get_ct_specific_expression( tensor of indices y integer for cell types + cat_covs + tensor of categorical covariates """ # cell-type specific gene expression, shape (minibatch, celltype, gene). 
beta = torch.exp(self.beta) # n_genes y_torch = (y * torch.ones_like(ind_x)).ravel() + if cat_covs is not None: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = () # obtain the relevant gammas if self.amortization in ["both", "latent"]: x_ = torch.log(1 + x) - gamma_ind = torch.transpose(self.gamma_encoder(x_), 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) + z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] + if self.prior_mode == "mog": + gamma_ = self.gamma_encoder(z_amortization) + proportion_modes_logits = torch.transpose( + gamma_[:, -self.p*self.n_labels:], 0, 1).reshape( + (self.p, 1, -1) + ).transpose(1, 2) + proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) + # shape (p, n_labels, minibatch_size) + gamma_ind = torch.transpose( + gamma_[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( + (self.p, self.n_latent, self.n_labels, -1) + ) + else: + gamma_ind = torch.transpose( + self.gamma_encoder(z_amortization), 0, 1).reshape( + (1, self.n_latent, self.n_labels, -1) + ) + proportion_modes = torch.ones((1, self.n_labels), device=x.device) else: - gamma_ind = self.gamma[:, :, ind_x] # n_latent, n_labels, minibatch_size - - # calculate cell type specific expression - gamma_select = gamma_ind[ - :, y_torch, torch.arange(ind_x.shape[0]) - ].T # minibatch_size, n_latent - h = self.decoder(gamma_select, y_torch.unsqueeze(1)) - px_scale = self.px_decoder(h) # (minibatch, n_genes) - px_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_scale - return px_ct # shape (minibatch, genes) + gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0) # 1, n_latent, n_labels, minibatch_size + proportion_modes = torch.ones((1, self.n_labels), device=x.device) + gamma_ind = gamma_ind[:, :, y, :] + proportion_modes = proportion_modes[:, y] + + px_est = torch.zeros((x.shape[0], self.n_genes), device=x.device) + for mode in range(gamma_ind.shape[0]): + # reshape and get gene expression value for all minibatch + gamma_ind_ = torch.transpose( + gamma_ind[mode, ...], 1, 0 + ) # minibatch_size, n_latent + h = self.decoder(gamma_ind_, y_torch.unsqueeze(1), *categorical_input) + px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( + (x.shape[0], -1) + ) # (minibatch, n_genes) + + px_scale_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_est + return px_scale_ct # shape (minibatch, genes) + + @torch.inference_mode() + @auto_move_data + def get_latent_amortization( + self, x: torch.Tensor = None + ): + """ + Returns cell type specific latent representation at the queried spots. + + Parameters + ---------- + x + tensor of data + ind_x + tensor of indices + y + integer for cell types + """ + # cell-type specific gene expression, shape (minibatch, celltype, gene). + x_ = torch.log(1 + x) + z_amortized = self.amortization_network(x_) + + return z_amortized # shape (minibatch, genes) diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 2654255fad..73d90002ec 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -238,7 +238,7 @@ def generative(self, z, library, y, cat_covs=None): px_scale = self.px_decoder(h) px_rate = library * px_scale px = NegativeBinomial(px_rate, logits=self.px_r) - return {"px": px} + return {"px": px, "px_scale": px_scale} def loss( self, From 1697cda3d8a8c63ab7118bdf30a15f1b5296f3b8 Mon Sep 17 00:00:00 2001 From: cane11 Date: Tue, 19 Mar 2024 23:36:17 -0700 Subject: [PATCH 06/12] Fixed model. 
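Workflow note (not part of the diff below): this patch replaces the extra categorical covariates in CondSCVI/DestVI with explicit batch keys and adds spot-level downstream getters. A rough end-to-end sketch under the assumption that `sc_adata` (single-cell reference) and `st_adata` (spatial data) exist; the column names "cell_type", "sample", "slide" and the label "B cells" are placeholders, while the keyword names come from the diff.

    import scvi

    # single-cell reference: cell-type labels plus an explicit batch covariate
    scvi.model.CondSCVI.setup_anndata(sc_adata, labels_key="cell_type", batch_key="sample")
    sc_model = scvi.model.CondSCVI(sc_adata)
    sc_model.train()

    # spatial data: its own batch key plus the matching single-cell batch column
    scvi.model.DestVI.setup_anndata(st_adata, batch_key="slide", sc_batch_key="sample")
    st_model = scvi.model.DestVI.from_rna_model(st_adata, sc_model, vamp_prior_p=15)
    st_model.train()

    props = st_model.get_proportions()                  # spots x cell types
    gamma = st_model.get_gamma(return_numpy=True)       # spots x n_latent x cell types
    b_expr = st_model.get_expression_for_ct("B cells")  # spot counts attributed to one type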
--- scvi/data/fields/_dataframe_field.py | 15 +- scvi/model/_condscvi.py | 25 ++- scvi/model/_destvi.py | 249 ++++++++++++++++++--------- scvi/model/_scanvi.py | 1 - scvi/module/_mrdeconv.py | 216 +++++++++-------------- scvi/module/_vaec.py | 44 +++-- 6 files changed, 284 insertions(+), 266 deletions(-) diff --git a/scvi/data/fields/_dataframe_field.py b/scvi/data/fields/_dataframe_field.py index 3c36625ba6..c35770c4ea 100644 --- a/scvi/data/fields/_dataframe_field.py +++ b/scvi/data/fields/_dataframe_field.py @@ -36,10 +36,9 @@ def __init__( attr_key: Optional[str], field_type: Literal["obs", "var"] = None, required: bool = True, - is_empty: bool = False, ) -> None: super().__init__() - if required and (attr_key is None or is_empty): + if required and attr_key is None: raise ValueError( "`attr_key` cannot be `None` if `required=True`. Please provide an `attr_key`." ) @@ -52,7 +51,7 @@ def __init__( self._registry_key = registry_key self._attr_key = attr_key - self._is_empty = is_empty or attr_key is None + self._is_empty = attr_key is None @property def registry_key(self) -> str: @@ -137,8 +136,6 @@ class CategoricalDataFrameField(BaseDataFrameField): Key to access the field in the AnnData obs or var mapping. If None, defaults to `registry_key`. field_type Type of field. Can be either "obs" or "var". - required - If False, allows for `attr_key is None` and marks the field as `is_empty`. """ CATEGORICAL_MAPPING_KEY = "categorical_mapping" @@ -149,18 +146,14 @@ def __init__( registry_key: str, attr_key: Optional[str], field_type: Literal["obs", "var"] = None, - required: bool = True, ) -> None: self.is_default = attr_key is None self._original_attr_key = attr_key or registry_key - is_empty = attr_key is None super().__init__( registry_key, f"_scvi_{registry_key}", field_type=field_type, - required=required, - is_empty=is_empty, ) self.count_stat_key = f"n_{self.registry_key}" @@ -244,16 +237,12 @@ def transfer_field( def get_summary_stats(self, state_registry: dict) -> dict: """Get summary stats.""" - if self.is_empty: - return {} categorical_mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] n_categories = len(np.unique(categorical_mapping)) return {self.count_stat_key: n_categories} def view_state_registry(self, state_registry: dict) -> Optional[rich.table.Table]: """View state registry.""" - if self.is_empty: - return None source_key = state_registry[self.ORIGINAL_ATTR_KEY] mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] t = rich.table.Table(title=f"{self.registry_key} State Registry") diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index e6ad14c2e4..ba76d129f1 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -78,15 +78,9 @@ def __init__( ): super().__init__(adata) + n_batch = self.summary_stats.n_batch n_labels = self.summary_stats.n_labels n_vars = self.summary_stats.n_vars - n_cats_per_cov = ( - self.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY - ).n_cats_per_key - if REGISTRY_KEYS.CAT_COVS_KEY in self.adata_manager.data_registry - else None - ) if weight_obs: ct_counts = np.unique( self.get_from_registry(adata, REGISTRY_KEYS.LABELS_KEY), @@ -112,6 +106,7 @@ def __init__( df_ct = pd.DataFrame({ 'fine_labels_key': fine_labels.ravel(), 'coarse_labels_key': coarse_labels.ravel()}).drop_duplicates() + print('YYYY', df_ct) fine_labels_mapping = self.adata_manager.get_state_registry( 'fine_labels' ).categorical_mapping @@ -134,12 +129,12 @@ def __init__( self.module = self._module_cls( n_input=n_vars, + n_batch=n_batch, 
n_labels=n_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, dropout_rate=dropout_rate, - n_cats_per_cov=n_cats_per_cov, df_ct_id_dict=self.df_ct_id_dict, **module_kwargs, ) @@ -339,9 +334,9 @@ def train( def setup_anndata( cls, adata: AnnData, + batch_key: str | None = None, labels_key: str | None = None, fine_labels_key: str | None = None, - categorical_covariate_keys: list[str] | None = None, layer: str | None = None, **kwargs, ): @@ -350,21 +345,23 @@ def setup_anndata( Parameters ---------- %(param_adata)s + %(param_batch_key)s %(param_labels_key)s fine_labels_key Key in `adata.obs` where fine-grained labels are stored. - %(param_cat_cov_keys)s %(param_layer)s """ setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), CategoricalObsField(REGISTRY_KEYS.LABELS_KEY, labels_key), - CategoricalObsField('fine_labels', fine_labels_key, required=False), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), ] + if fine_labels_key is not None: + anndata_fields.append( + CategoricalObsField('fine_labels', fine_labels_key + ) + ) # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index b859ca8413..5691365199 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -6,12 +6,18 @@ import numpy as np import pandas as pd +from scipy.sparse import csr_matrix import torch from anndata import AnnData from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager -from scvi.data.fields import CategoricalJointObsField, LayerField, NumericalObsField +from scvi.data.fields import ( + CategoricalJointObsField, + LayerField, + NumericalObsField, + CategoricalObsField, +) from scvi.model import CondSCVI from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.module import MRDeconv @@ -77,32 +83,24 @@ def __init__( n_hidden: int, n_latent: int, n_layers: int, - n_cats_per_cov: Sequence[int], + n_batch_sc: int, dropout_decoder: float, l1_reg: float, - sc_covariate_registry: dict[str, list[str]], + sc_batch_mapping: list[str], **module_kwargs, ): super().__init__(st_adata) - if sc_covariate_registry is not None: - st_covariate_registry = dict(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings) - sc_covariate_mappings = dict(sc_covariate_registry.mappings) - assert set(sc_covariate_mappings.keys()) == set(st_covariate_registry.keys()), ( - f'Spatial model has other covariates than single cell model, {set(sc_covariate_mappings.keys()).symmetric_difference(st_covariate_registry.keys())}' - ) - for key, value in st_covariate_registry.items(): - assert set(value).issubset(set(sc_covariate_mappings[key])), ( - f'Spatial model has other covariates than single cell model, {set(value) - set(sc_covariate_mappings[key])}') - n_cats_per_cov = ( - sc_covariate_registry.n_cats_per_key - if sc_covariate_registry - else None + if sc_batch_mapping is not None: + st_sc_batch_mapping = self.adata_manager.get_state_registry('batch_index_sc')['categorical_mapping'] + assert set(st_sc_batch_mapping).issubset(set(sc_batch_mapping)), ( + f'Spatial model has other covariates than single cell model, {set(st_sc_batch_mapping) - set(sc_batch_mapping)}' ) self.module = self._module_cls( n_spots=st_adata.n_obs, n_labels=cell_type_mapping.shape[0], - n_cats_per_cov=n_cats_per_cov, + 
n_batch=self.summary_stats.n_batch, + n_batch_sc=n_batch_sc, decoder_state_dict=decoder_state_dict, px_decoder_state_dict=px_decoder_state_dict, px_r=px_r, @@ -159,13 +157,11 @@ def from_rna_model( sc_model.adata, p=vamp_prior_p ).values() - sc_covariate_registry = ( + sc_batch_mapping = ( sc_model.adata_manager.get_state_registry( - REGISTRY_KEYS.CAT_COVS_KEY + REGISTRY_KEYS.BATCH_KEY ) - if REGISTRY_KEYS.CAT_COVS_KEY in sc_model.adata_manager.data_registry - else None - ) + )['categorical_mapping'] return cls( st_adata, @@ -176,19 +172,21 @@ def from_rna_model( sc_model.module.n_hidden, sc_model.module.n_latent, sc_model.module.n_layers, - sc_model.module.n_cats_per_cov, + sc_model.module.n_batch, mean_vprior=mean_vprior, var_vprior=var_vprior, mp_vprior=mp_vprior, dropout_decoder=dropout_decoder, l1_reg=l1_reg, - sc_covariate_registry=sc_covariate_registry, + sc_batch_mapping=sc_batch_mapping, **module_kwargs, ) + @torch.inference_mode() def get_proportions( self, keep_noise: bool = False, + normalize: bool = True, indices: Sequence[int] | None = None, batch_size: int | None = None, ) -> pd.DataFrame: @@ -200,6 +198,8 @@ def get_proportions( ---------- keep_noise whether to account for the noise term as a standalone cell type in the proportion estimate. + normalize + whether to normalize the proportions to sum to 1. indices Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. batch_size @@ -218,10 +218,10 @@ def get_proportions( ) prop_ = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - prop_local = self.module.get_proportions( - x=generative_inputs["x"], keep_noise=keep_noise - ) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + prop_local = self.module.generative(**generative_inputs)["v"] prop_ += [prop_local.cpu()] data = torch.cat(prop_).numpy() if indices: @@ -231,7 +231,11 @@ def get_proportions( logger.info( "No amortization for proportions, ignoring indices and returning results for the full data" ) - data = self.module.get_proportions(keep_noise=keep_noise) + data = torch.nn.functional.softplus(self.module.V).transpose(1, 0).detach().cpu().numpy() + if normalize: + data = data / data.sum(axis=1, keepdims=True) + if not keep_noise: + data = data[:, :-1] return pd.DataFrame( data=data, @@ -239,6 +243,7 @@ def get_proportions( index=index_names, ) + @torch.inference_mode() def get_gamma( self, indices: Sequence[int] | None = None, @@ -258,7 +263,7 @@ def get_gamma( """ self._check_if_trained() - column_names = np.arange(self.module.n_latent) + column_names = [str(i) for i in np.arange(self.module.n_latent)] index_names = self.adata.obs.index if self.module.amortization in ["both", "latent"]: @@ -267,8 +272,16 @@ def get_gamma( ) gamma_ = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - gamma_local = self.module.get_gamma(x=generative_inputs["x"]) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + generative_outputs = self.module.generative(**generative_inputs) + gamma_local = generative_outputs["gamma"] + if self.module.prior_mode == 'mog': + proportions_model_local = generative_outputs['proportion_modes'] + gamma_local = torch.einsum('pncm,pmc->ncm', gamma_local, 
proportions_model_local) + else: + gamma_local = gamma_local.squeeze(0) gamma_ += [gamma_local.cpu()] data = torch.cat(gamma_, dim=-1).numpy() if indices is not None: @@ -278,7 +291,7 @@ def get_gamma( logger.info( "No amortization for latent values, ignoring adata and returning results for the full data" ) - data = self.module.get_gamma() + data = self.module.gamma.detach().cpu().numpy() data = np.transpose(data, (2, 0, 1)) if return_numpy: @@ -291,49 +304,72 @@ def get_gamma( ) return res - def get_latent_amortization( + @torch.inference_mode() + def get_latent_representation( self, + adata: AnnData | None = None, indices: Sequence[int] | None = None, + give_mean: bool = True, + mc_samples: int = 5000, batch_size: int | None = None, - return_numpy: bool = False, - ) -> np.ndarray | dict[str, pd.DataFrame]: - """Returns the amortized latent space for the spatial data. + return_dist: bool = False, + ) -> np.ndarray (np.ndarray, np.ndarray): + """Return the latent representation for each cell. + + This is typically denoted as :math:`z_n`. Parameters ---------- + adata + AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the + AnnData object used to initialize the model. indices - Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. + Indices of cells in adata to use. If `None`, all cells are used. + give_mean + Give mean of distribution or sample from it. + mc_samples + For distributions with no closed-form mean (e.g., `logistic normal`), how many Monte Carlo + samples to take for computing mean. batch_size - Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`. - return_numpy - if activated, will return a numpy array of shape is n_spots x n_latent x n_labels. - """ - self._check_if_trained() - - if self.module.n_latent_amortization is None or self.module.amortization in ["none"]: - ValueError('Get latent amortization is not defined if n_latent_amortization is None or no amortization is used') - - column_names = np.arange(self.module.n_latent_amortization) - index_names = self.adata.obs.index + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_dist + Return (mean, variance) of distributions instead of just the mean. + If `True`, ignores `give_mean` and `mc_samples`. In the case of the latter, + `mc_samples` is used to compute the mean of a transformed distribution. + If `return_dist` is true the untransformed mean and variance are returned. - stdl = self._make_data_loader( - self.adata, indices=indices, batch_size=batch_size + Returns + ------- + Low-dimensional representation for each cell or a tuple containing its mean and variance. + """ + assert self.module.n_latent_amortization is not None, ( + "Model has no latent representation for amortized values." 
) + self._check_if_trained(warn=False) - amortization = [] - for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - x = generative_inputs["x"] - z_amortization = self.module.get_latent_amortization(x) - amortization += [z_amortization.cpu()] - - data = torch.cat(amortization).numpy() - column_names = np.arange(self.module.n_latent_amortization) - index_names = self.adata.obs.index - if indices is not None: - index_names = index_names[indices] - return pd.DataFrame(data=data, columns=column_names, index=index_names) + adata = self._validate_anndata(adata) + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) + latent = [] + latent_qzm = [] + latent_qzv = [] + for tensors in scdl: + inference_inputs = self.module._get_inference_input(tensors) + z, qz, _ = self.module.inference(**inference_inputs, n_samples=mc_samples).values() + if give_mean: + latent += [qz.loc.cpu()] + else: + latent += [z.cpu()] + latent_qzm += [qz.loc.cpu()] + latent_qzv += [qz.scale.square().cpu()] + return ( + (torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy()) + if return_dist + else torch.cat(latent).numpy() + ) + @torch.inference_mode() def get_scale_for_ct( self, label: str, @@ -356,21 +392,22 @@ def get_scale_for_ct( Pandas dataframe of gene_expression """ self._check_if_trained() + self._validate_anndata() + + cell_type_mapping_extended = list(self.cell_type_mapping) + ['noise'] - if label not in self.cell_type_mapping: + if label not in cell_type_mapping_extended: raise ValueError("Unknown cell type") - y = np.where(label == self.cell_type_mapping)[0][0] + y = cell_type_mapping_extended.index(label) stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) scale = [] for tensors in stdl: - generative_inputs = self.module._get_generative_input(tensors, None) - x, ind_x, cat_covs = ( - generative_inputs["x"], - generative_inputs["ind_x"], - generative_inputs["cat_covs"], - ) - px_scale = self.module.get_ct_specific_expression(x, ind_x, y, cat_covs) + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + px_scale = self.module.generative(**generative_inputs)["px_mu"][:, y, :] + scale += [px_scale.cpu()] data = torch.cat(scale).numpy() @@ -379,6 +416,62 @@ def get_scale_for_ct( if indices is not None: index_names = index_names[indices] return pd.DataFrame(data=data, columns=column_names, index=index_names) + + @torch.inference_mode() + def get_expression_for_ct( + self, + label: str, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + return_sparse_array: bool = False, + ) -> pd.DataFrame: + r"""Return the scaled parameter of the NB for every spot in queried cell types. + + Parameters + ---------- + label + cell type of interest + indices + Indices of cells in self.adata to use. If `None`, all cells are used. + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + return_sparse_array + If `True`, returns a sparse array instead of a dataframe. 
+ + Returns + ------- + Pandas dataframe of gene_expression + """ + self._check_if_trained() + cell_type_mapping_extended = list(self.cell_type_mapping) + ['noise'] + + if label not in cell_type_mapping_extended: + raise ValueError("Unknown cell type") + y = cell_type_mapping_extended.index(label) + + stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) + expression_ct = [] + for tensors in stdl: + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + generative_outputs = self.module.generative(**generative_inputs) + px_scale, proportions = generative_outputs['px_mu'], generative_outputs['v'] + px_scale = torch.einsum('mkl,mk->mkl', px_scale, proportions) + px_scale_proportions = px_scale[:, y, :]/px_scale.sum(dim=1) + x_ct = inference_inputs['x'].to(px_scale_proportions.device) * px_scale_proportions + expression_ct += [x_ct.cpu()] + + data = torch.cat(expression_ct).numpy() + if return_sparse_array: + data = csr_matrix(data.T) + return data + else: + column_names = self.adata.var.index + index_names = self.adata.obs.index + if indices is not None: + index_names = index_names[indices] + return pd.DataFrame(data=data, columns=column_names, index=index_names) @devices_dsp.dedent def train( @@ -453,6 +546,8 @@ def setup_anndata( cls, adata: AnnData, layer: str | None = None, + batch_key: str | None = None, + sc_batch_key: str | None = None, categorical_covariate_keys: Sequence[str] | None = None, **kwargs, ): @@ -462,7 +557,8 @@ def setup_anndata( ---------- %(param_adata)s %(param_layer)s - %(param_categorical_covariate_keys)s + %(param_batch_key)s + sc_batch_key: Categorical covariate keys need to line up with single cell model. 
""" setup_method_args = cls._get_setup_method_args(**locals()) @@ -471,9 +567,8 @@ def setup_anndata( anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), - CategoricalJointObsField( - REGISTRY_KEYS.CAT_COVS_KEY, categorical_covariate_keys - ), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + CategoricalObsField("batch_index_sc", sc_batch_key), ] adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index 9058cc1798..2f8eb362b9 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -471,7 +471,6 @@ def setup_anndata( %(param_cont_cov_keys)s """ - print("XXXXX", cls._latent_qzm) setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 0f3a4188e5..abade09c22 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -10,7 +10,7 @@ from scvi._types import Tunable from scvi.distributions import NegativeBinomial from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data -from scvi.nn import FCLayers +from scvi.nn import Encoder, FCLayers def identity(x): @@ -75,6 +75,7 @@ def __init__( self, n_spots: int, n_labels: int, + n_batch: int, n_hidden: Tunable[int], n_layers: Tunable[int], n_latent: Tunable[int], @@ -83,7 +84,7 @@ def __init__( px_decoder_state_dict: OrderedDict, px_r: np.ndarray, dropout_decoder: float, - n_cats_per_cov: Optional[list] = None, + n_batch_sc: Optional[list] = None, dropout_amortization: float = 0.05, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, @@ -98,7 +99,10 @@ def __init__( extra_decoder_kwargs: Optional[dict] = None, ): super().__init__() + if prior_mode == 'mog': + assert amortization in ["both", "latent"], "Amortization must be active for latent variables to use mixture of gaussians generation" self.n_spots = n_spots + self.n_batch = n_batch self.n_labels = n_labels self.n_hidden = n_hidden self.n_latent = n_latent @@ -113,7 +117,7 @@ def __init__( self.n_latent_amortization = n_latent_amortization # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} - cat_list = [n_labels] + list([] if n_cats_per_cov is None else n_cats_per_cov) + cat_list = [n_labels, n_batch_sc] self.decoder = FCLayers( n_in=n_latent, @@ -162,28 +166,33 @@ def __init__( # within cell_type factor loadings _extra_encoder_kwargs = extra_encoder_kwargs or {} if self.prior_mode == "mog": + print('Using mixture of gaussians for prior') return_dist = self.p * n_labels * n_latent + self.p * n_labels else: + print('Using normal prior') return_dist = n_labels * n_latent + print(f"return_dist: {return_dist}, {self.p}, {n_labels}, {n_latent}, {mean_vprior.shape}") if self.n_latent_amortization is not None: # Uses a combined latent space for proportions and gammas. 
- self.amortization_network = torch.nn.Sequential( - FCLayers( - n_in=self.n_genes, - n_out=n_hidden, - n_cat_list=None, - n_layers=1, - n_hidden=n_hidden, - dropout_rate=dropout_amortization, - use_layer_norm=True, - use_batch_norm=False, - ), - torch.nn.Linear(n_hidden, 2 * self.n_latent_amortization), + self.z_encoder = Encoder( + self.n_genes, + n_latent_amortization, + n_cat_list=[n_batch], + n_layers=n_layers, + n_hidden=n_hidden, + dropout_rate=dropout_amortization, + inject_covariates=True, + use_batch_norm=False, + use_layer_norm=True, + var_activation=torch.nn.functional.softplus, + return_dist=True, + **_extra_encoder_kwargs, ) - n_layers = 2 else: - self.amortization_network = torch.nn.Identity() + def identity(x, batch_index=None): + return x, Normal(x, scale=1e-6*torch.ones_like(x)) + self.z_encoder = identity n_latent_amortization = self.n_genes n_layers = 2 self.gamma_encoder = torch.nn.Sequential( @@ -215,42 +224,52 @@ def __init__( ) def _get_inference_input(self, tensors): - # we perform MAP here, so we just need to subsample the variables - return {} + x = tensors[REGISTRY_KEYS.X_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + + input_dict = {"x": x, "batch_index": batch_index} + return input_dict def _get_generative_input(self, tensors, inference_outputs): - x = tensors[REGISTRY_KEYS.X_KEY] + z = inference_outputs["z"] + library = inference_outputs["library"] ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel() - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] + 2. if cat_key in tensors.keys() else None + batch_index_sc = tensors['batch_index_sc'] + 2. - input_dict = {"x": x, "ind_x": ind_x, "cat_covs": cat_covs} + input_dict = {"z": z, "ind_x": ind_x, "library": library, "batch_index_sc": batch_index_sc} return input_dict @auto_move_data - def inference(self): - """Run the inference model.""" - return {} + def inference( + self, + x, + batch_index, + n_samples=1, + ): + """Runs the inference (encoder) model.""" + x_ = x + library = torch.log(x.sum(1)).unsqueeze(1) + x_ = torch.log(1 + x_) + if self.n_latent_amortization is not None: + qz, z = self.z_encoder(x_, batch_index) + else: + z = x_ + qz = Normal(x_, scale=1e-6*torch.ones_like(x_)) # dummy distribution + + outputs = {"z": z, "qz": qz, "library": library} + return outputs @auto_move_data - def generative(self, x, ind_x, cat_covs=None): + def generative(self, z, ind_x, library, batch_index_sc): """Build the deconvolution model for every cell in the minibatch.""" - m = x.shape[0] - library = torch.sum(x, dim=1, keepdim=True) + m = len(ind_x) # setup all non-linearities beta = torch.exp(self.beta) # n_genes - eps = torch.nn.functional.softplus(self.eta) # n_genes - x_ = torch.log(1 + x) - z_amortization_params = self.amortization_network(x_) - qz_amortization = Normal( - z_amortization_params[:, :self.n_latent_amortization], - torch.exp(z_amortization_params[:, self.n_latent_amortization:])+1e-4 - ) - z_amortization = qz_amortization.rsample() + eps = torch.nn.functional.softplus(self.eta) # n_genes if self.amortization in ["both", "latent"]: if self.prior_mode == "mog": - gamma_ = self.gamma_encoder(z_amortization) + gamma_ = self.gamma_encoder(z) proportion_modes_logits = torch.transpose( gamma_[:, -self.p*self.n_labels:], 0, 1).reshape( (self.p, self.n_labels, m) @@ -262,28 +281,25 @@ def generative(self, x, ind_x, cat_covs=None): ) else: gamma_ind = torch.transpose( - self.gamma_encoder(z_amortization), 0, 1).reshape( + self.gamma_encoder(z), 0, 1).reshape( (1, self.n_latent, 
self.n_labels, -1) ) - proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=x.device) + proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=z.device) else: gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0) # 1, n_latent, n_labels, minibatch_size - proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=x.device) + proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=z.device) if self.amortization in ["both", "proportion"]: - v_ind = self.V_encoder(z_amortization) + v_ind = self.V_encoder(z) else: v_ind = self.V[:, ind_x].T # minibatch_size, labels + 1 v_ind = torch.nn.functional.softplus(v_ind) - px_est = torch.zeros((x.shape[0], self.n_labels, self.n_genes), device=x.device) + px_est = torch.zeros((m, self.n_labels, self.n_genes), device=z.device) enum_label = ( torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) ) # minibatch_size * n_labels, 1 - if cat_covs is not None: - categorical_input = [i.repeat_interleave(self.n_labels, dim=0) for i in torch.split(cat_covs, 1, dim=1)] - else: - categorical_input = () + batch_index_sc_input = batch_index_sc.repeat_interleave(self.n_labels, dim=0) for mode in range(gamma_ind.shape[0]): # reshape and get gene expression value for all minibatch @@ -293,7 +309,7 @@ def generative(self, x, ind_x, cat_covs=None): gamma_reshape_ = gamma_ind_.reshape( (-1, self.n_latent) ) # minibatch_size * n_labels, n_latent - h = self.decoder(gamma_reshape_, enum_label.to(x.device), *categorical_input) + h = self.decoder(gamma_reshape_, enum_label.to(z.device), batch_index_sc_input) px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( (m, self.n_labels, -1) ) # (minibatch, n_labels, n_genes) @@ -310,16 +326,17 @@ def generative(self, x, ind_x, cat_covs=None): # now combine them for convolution px_scale = torch.sum(v_ind.unsqueeze(2) * r_hat, dim=1) # batch_size, n_genes px_rate = library * px_scale + px_mu = torch.exp(self.px_o) * r_hat return { "px_o": self.px_o, "px_rate": px_rate, + "px_mu": px_mu, "px_scale": px_scale, "gamma": gamma_ind, "v": v_ind, "proportion_modes": proportion_modes, "proportion_modes_logits": proportion_modes_logits, - "qz_amortization": qz_amortization, } def loss( @@ -329,7 +346,7 @@ def loss( generative_outputs, kl_weight: float = 1.0, n_obs: int = 1.0, - weighting_cross_entropy: float = 1.0 + weighting_cross_entropy: float = 1e-6, ): """Compute the loss.""" x = tensors[REGISTRY_KEYS.X_KEY] @@ -385,10 +402,11 @@ def loss( log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch - neg_log_likelihood_prior += kl( - generative_outputs["qz_amortization"], - Normal(torch.zeros([self.n_latent_amortization], device=x.device), torch.ones([self.n_latent_amortization], device=x.device)) - ).sum(dim=-1) + if self.n_latent_amortization is not None: + neg_log_likelihood_prior += kl( + inference_outputs["qz"], + Normal(torch.zeros([self.n_latent_amortization], device=x.device), torch.ones([self.n_latent_amortization], device=x.device)) + ).sum(dim=-1) # High v_sparsity_loss is detrimental early in training, scaling by kl_weight to increase over training epochs. 
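+        # The total objective assembled below is n_obs * [ mean over the minibatch of
+        # reconstruction NLL + kl_weight * (prior negative log-likelihood + v_sparsity_loss)
+        # ] plus the global eta/beta prior penalties.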
loss = n_obs * ( @@ -415,62 +433,12 @@ def sample( """Sample from the posterior.""" raise NotImplementedError("No sampling method for DestVI") - @torch.inference_mode() - @auto_move_data - def get_proportions(self, x=None, keep_noise=False) -> np.ndarray: - """Returns the loadings.""" - if self.amortization in ["both", "proportion"]: - # get estimated unadjusted proportions - x_ = torch.log(1 + x) - z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] - res = torch.nn.functional.softplus(self.V_encoder(z_amortization)) - else: - res = ( - torch.nn.functional.softplus(self.V).cpu().numpy().T - ) # n_spots, n_labels + 1 - # remove dummy cell type proportion values - if not keep_noise: - res = res[:, :-1] - # normalize to obtain adjusted proportions - res = res / res.sum(axis=1).reshape(-1, 1) - return res - - @torch.inference_mode() - @auto_move_data - def get_gamma(self, x: torch.Tensor = None) -> torch.Tensor: - """Returns the loadings. - - Returns - ------- - type - tensor - """ - # get estimated unadjusted proportions - if self.amortization in ["latent", "both"]: - x_ = torch.log(1 + x) - z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] - gamma = self.gamma_encoder(z_amortization) - if self.prior_mode == "mog": - proportion_modes_logits = torch.transpose( - gamma[:, -self.p*self.n_labels:], 0, 1).reshape( - (self.p, 1, self.n_labels, x.shape[0]) - ) - proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) - gamma_ind = torch.transpose( - gamma[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( - (self.p, self.n_latent, self.n_labels, x.shape[0]) - ) - gamma = torch.sum(proportion_modes * gamma_ind, dim=0) - return torch.transpose(gamma, 0, 1).reshape( - (self.n_latent, self.n_labels, -1) - ) # n_latent, n_labels, minibatch - else: - return self.gamma.cpu().numpy() # (n_latent, n_labels, n_spots) - + @torch.inference_mode() @auto_move_data def get_ct_specific_expression( - self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None, cat_covs: torch.Tensor = None + self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None, + batch_index: torch.Tensor = None, batch_index_sc: torch.Tensor = None ): """Returns cell type specific gene expression at the queried spots. @@ -482,20 +450,16 @@ def get_ct_specific_expression( tensor of indices y integer for cell types - cat_covs - tensor of categorical covariates + batch_index_sc + tensor of corresponding batch in single cell data for decoder """ # cell-type specific gene expression, shape (minibatch, celltype, gene). 
beta = torch.exp(self.beta) # n_genes y_torch = (y * torch.ones_like(ind_x)).ravel() - if cat_covs is not None: - categorical_input = torch.split(cat_covs, 1, dim=1) - else: - categorical_input = () # obtain the relevant gammas if self.amortization in ["both", "latent"]: x_ = torch.log(1 + x) - z_amortization = self.amortization_network(x_)[:, :self.n_latent_amortization] + z_amortization = self.amortization_network(x_, batch_index)[:, :self.n_latent_amortization] if self.prior_mode == "mog": gamma_ = self.gamma_encoder(z_amortization) proportion_modes_logits = torch.transpose( @@ -526,7 +490,7 @@ def get_ct_specific_expression( gamma_ind_ = torch.transpose( gamma_ind[mode, ...], 1, 0 ) # minibatch_size, n_latent - h = self.decoder(gamma_ind_, y_torch.unsqueeze(1), *categorical_input) + h = self.decoder(gamma_ind_, y_torch.unsqueeze(1), batch_index_sc.unsqueeze(1)) px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( (x.shape[0], -1) ) # (minibatch, n_genes) @@ -534,25 +498,3 @@ def get_ct_specific_expression( px_scale_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_est return px_scale_ct # shape (minibatch, genes) - @torch.inference_mode() - @auto_move_data - def get_latent_amortization( - self, x: torch.Tensor = None - ): - """ - Returns cell type specific latent representation at the queried spots. - - Parameters - ---------- - x - tensor of data - ind_x - tensor of indices - y - integer for cell types - """ - # cell-type specific gene expression, shape (minibatch, celltype, gene). - x_ = torch.log(1 + x) - z_amortized = self.amortization_network(x_) - - return z_amortized # shape (minibatch, genes) diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 73d90002ec..370b1b59fc 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -26,6 +26,8 @@ class VAEC(BaseMinifiedModeModuleClass): ---------- n_input Number of input genes + n_batch + Number of batches n_labels Number of labels n_hidden @@ -49,8 +51,8 @@ class VAEC(BaseMinifiedModeModuleClass): def __init__( self, n_input: int, + n_batch: int = 0, n_labels: int = 0, - n_cats_per_cov: Optional[Iterable[int]] = None, n_hidden: Tunable[int] = 128, n_latent: Tunable[int] = 5, n_layers: Tunable[int] = 2, @@ -75,17 +77,16 @@ def __init__( self.gene_likelihood = "nb" self.latent_distribution = "normal" # Automatically deactivate if useless - self.n_batch = 0 + self.n_batch = n_batch self.n_labels = n_labels self.prior = prior - self.n_cats_per_cov = n_cats_per_cov if df_ct_id_dict is not None: self.num_classes_mog = max([v[2] for v in df_ct_id_dict.values()]) + 1 mapping_mog = torch.tensor([v[2] for _, v in sorted(df_ct_id_dict.items())]) self.register_buffer("mapping_mog", mapping_mog) else: self.num_classes_mog = num_classes_mog - cat_list = [n_labels] + list([] if self.n_cats_per_cov is None else self.n_cats_per_cov) + cat_list = [n_labels, n_batch] encoder_cat_list = cat_list if self.encode_covariates else [n_labels] # gene dispersion @@ -136,18 +137,17 @@ def __init__( self.prior_log_scales = torch.nn.Parameter( torch.zeros([n_labels, self.num_classes_mog, n_latent])) self.prior_logits = torch.nn.Parameter( - torch.ones([n_labels, self.num_classes_mog])) + torch.zeros([n_labels, self.num_classes_mog])) def _get_inference_input(self, tensors): - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + y = tensors[REGISTRY_KEYS.LABELS_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] if self.minified_data_type is None: x = 
tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] input_dict = { "x": x, "y": y, - "cat_covs": cat_covs, + "batch_index": batch_index, } else: if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): @@ -158,7 +158,8 @@ def _get_inference_input(self, tensors): "qzm": qzm, "qzv": qzv, "observed_lib_size": observed_lib_size, - "cat_covs": cat_covs, + "y": y, + "batch_index": batch_index, } else: raise NotImplementedError( @@ -171,19 +172,18 @@ def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] y = tensors[REGISTRY_KEYS.LABELS_KEY] - cat_key = REGISTRY_KEYS.CAT_COVS_KEY - cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] input_dict = { "z": z, "library": library, "y": y, - "cat_covs": cat_covs, + "batch_index": batch_index, } return input_dict @auto_move_data - def _regular_inference(self, x, y, cat_covs=None, n_samples=1): + def _regular_inference(self, x, y, batch_index, n_samples=1): """High level inference method. Runs the inference (encoder) model. @@ -192,11 +192,11 @@ def _regular_inference(self, x, y, cat_covs=None, n_samples=1): library = x.sum(1).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) - if cat_covs is not None and self.encode_covariates: - categorical_input = torch.split(cat_covs, 1, dim=1) + if self.encode_covariates: + categorical_input = [y, batch_index] else: - categorical_input = () - qz, z = self.z_encoder(x_, y, *categorical_input) + categorical_input = [y] + qz, z = self.z_encoder(x_, *categorical_input) if n_samples > 1: untran_z = qz.sample((n_samples,)) @@ -228,13 +228,9 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): return outputs @auto_move_data - def generative(self, z, library, y, cat_covs=None): + def generative(self, z, library, y, batch_index): """Runs the generative model.""" - if cat_covs is not None: - categorical_input = torch.split(cat_covs, 1, dim=1) - else: - categorical_input = () - h = self.decoder(z, y, *categorical_input) + h = self.decoder(z, y, batch_index) px_scale = self.px_decoder(h) px_rate = library * px_scale px = NegativeBinomial(px_rate, logits=self.px_r) From 17c2ec49e43e302dfd14cfda03294b9f8abcce94 Mon Sep 17 00:00:00 2001 From: cane11 Date: Sun, 21 Apr 2024 16:58:25 -0700 Subject: [PATCH 07/12] Changed loss to get rid of sparse output --- scvi/model/_destvi.py | 59 ++++++++++++++++++++++++++++++++++++++++ scvi/module/_mrdeconv.py | 9 +++--- 2 files changed, 63 insertions(+), 5 deletions(-) diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index 5691365199..1316a9dbf2 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -242,6 +242,65 @@ def get_proportions( columns=column_names, index=index_names, ) + + @torch.inference_mode() + def get_fine_celltypes( + self, + sc_model: CondSCVI, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + return_numpy: bool = False, + ) -> np.ndarray | dict[str, pd.DataFrame]: + """Returns the estimated cell-type specific latent space for the spatial data. + + Parameters + ---------- + sc_model + trained CondSCVI model + indices + Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. + batch_size + Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`. + return_numpy + if activated, will return a numpy array of shape is n_spots x n_latent x n_labels. 
+ """ + self._check_if_trained() + + column_names = [str(i) for i in np.arange(self.module.n_latent)] + index_names = self.adata.obs.index + + if self.module.amortization in ["both", "latent"]: + stdl = self._make_data_loader( + adata=self.adata, indices=indices, batch_size=batch_size + ) + gamma_ = [] + proportions_modes_ = [] + for tensors in stdl: + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + generative_outputs = self.module.generative(**generative_inputs) + gamma_local = generative_outputs["gamma"] + if self.module.prior_mode == 'mog': + proportions_modes_local = generative_outputs['proportion_modes'] # pmc + gamma_local = gamma_local # pncm + else: + proportions_modes_local = torch.ones(gamma_local.shape[0], 1, 1) + gamma_local = gamma_local.squeeze(0) # pncm + gamma_ += [gamma_local.cpu()] + proportions_modes_ += [proportions_modes_local.cpu()] + + proportions_modes = torch.cat(proportions_modes_, dim=-1).numpy() + gamma = torch.cat(gamma_, dim=-1).numpy() + else: + if indices is not None: + logger.info( + "No amortization for latent values, ignoring adata and returning results for the full data" + ) + gamma = self.module.gamma.detach().cpu().numpy() + + sc_latent_distribution = sc_model.get_latent_representation(return_dist=True) + @torch.inference_mode() def get_gamma( diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index abade09c22..62852efd88 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -389,18 +389,18 @@ def loss( proportion_modes = generative_outputs["proportion_modes"] pre_lse = ( Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(proportion_modes).permute(1, 0, 2) # minibatch, p, n_labels + ) + torch.log(1e-3 + proportion_modes).permute(1, 0, 2) # minibatch, p, n_labels log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch + neg_log_likelihood_prior = - log_likelihood_prior.sum(1) # minibatch neg_log_likelihood_prior += weighting_cross_entropy * torch.nn.functional.cross_entropy( proportion_modes_logits.permute(1, 0, 2), mp_vprior.repeat(x.shape[0], 1, 1), reduction='none').sum(1) else: pre_lse = ( Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(mp_vprior) # minibatch, p, n_labels + ) + torch.log(1e-3 + mp_vprior) # minibatch, p, n_labels log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch + neg_log_likelihood_prior = - log_likelihood_prior.sum(1) # minibatch if self.n_latent_amortization is not None: neg_log_likelihood_prior += kl( @@ -434,7 +434,6 @@ def sample( raise NotImplementedError("No sampling method for DestVI") - @torch.inference_mode() @auto_move_data def get_ct_specific_expression( self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None, From cbc2511275740c1cb126d8d38558b5de8fb9a458 Mon Sep 17 00:00:00 2001 From: cane11 Date: Sat, 29 Jun 2024 04:09:56 -0700 Subject: [PATCH 08/12] Batch key added. 
--- scvi/data/_manager.py | 2 +- scvi/model/_condscvi.py | 206 +++++++++++----- scvi/model/_destvi.py | 179 ++++++-------- scvi/model/base/_archesmixin.py | 4 +- scvi/module/_mrdeconv.py | 418 ++++++++++++++++++-------------- scvi/module/_vaec.py | 124 +++++++--- 6 files changed, 553 insertions(+), 380 deletions(-) diff --git a/scvi/data/_manager.py b/scvi/data/_manager.py index 7e3e9b309a..f31c6d1ed1 100644 --- a/scvi/data/_manager.py +++ b/scvi/data/_manager.py @@ -212,7 +212,7 @@ def _add_field( # If empty, we skip registering the field. if not field.is_empty: # Transfer case: Source registry is used for validation and/or setup. - if source_registry is not None: + if source_registry is not None and field.registry_key in source_registry[_constants._FIELD_REGISTRIES_KEY]: field_registry[_constants._STATE_REGISTRY_KEY] = field.transfer_field( source_registry[_constants._FIELD_REGISTRIES_KEY][ field.registry_key diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index ba76d129f1..4458d15e8e 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -81,6 +81,11 @@ def __init__( n_batch = self.summary_stats.n_batch n_labels = self.summary_stats.n_labels n_vars = self.summary_stats.n_vars + if 'n_fine_labels' in self.summary_stats: + self.n_fine_labels = self.summary_stats.n_fine_labels + else: + self.n_fine_labels = None + self._set_indices_and_labels(adata) if weight_obs: ct_counts = np.unique( self.get_from_registry(adata, REGISTRY_KEYS.LABELS_KEY), @@ -91,51 +96,16 @@ def __init__( ct_prop = ct_prop / np.sum(ct_prop) ct_weight = 1.0 / ct_prop module_kwargs.update({"ct_weight": ct_weight}) - if 'fine_labels' in self.adata_manager.data_registry: - fine_labels = get_anndata_attribute( - adata, - self.adata_manager.data_registry.labels.attr_name, - '_scvi_fine_labels', - ) - coarse_labels = get_anndata_attribute( - adata, - self.adata_manager.data_registry.labels.attr_name, - '_scvi_labels' - ) - - df_ct = pd.DataFrame({ - 'fine_labels_key': fine_labels.ravel(), - 'coarse_labels_key': coarse_labels.ravel()}).drop_duplicates() - print('YYYY', df_ct) - fine_labels_mapping = self.adata_manager.get_state_registry( - 'fine_labels' - ).categorical_mapping - coarse_labels_mapping = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ).categorical_mapping - - df_ct['fine_labels'] = fine_labels_mapping[df_ct['fine_labels_key']] - df_ct['coarse_labels'] = coarse_labels_mapping[df_ct['coarse_labels_key']] - - self.df_ct_name_dict = {} - self.df_ct_id_dict = {} - for i, row in df_ct.iterrows(): - count = len(df_ct.loc[:i][df_ct['coarse_labels'] == row['coarse_labels']]) - 1 - self.df_ct_name_dict[row['fine_labels']] = (row['coarse_labels'], row['coarse_labels_key'], count) - self.df_ct_id_dict[row['fine_labels_key']] = (row['coarse_labels'], row['coarse_labels_key'], count) - else: - self.df_ct_name_dict = None - self.df_ct_id_dict = None self.module = self._module_cls( n_input=n_vars, n_batch=n_batch, n_labels=n_labels, + n_fine_labels=self.n_fine_labels, n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, dropout_rate=dropout_rate, - df_ct_id_dict=self.df_ct_id_dict, **module_kwargs, ) self._model_summary_string = ( @@ -144,7 +114,7 @@ def __init__( self.init_params_ = self._get_init_params(locals()) @torch.inference_mode() - def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior_n_samples: int | None = None, default_cat: list | None = None) -> np.ndarray: + def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> 
np.ndarray: r"""Return an empirical prior over the cell-type specific latent space (vamp prior) that may be used for deconvolution. Parameters @@ -154,10 +124,6 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior AnnData object used to initialize the model. p number of clusters in kmeans clustering for cell-type sub-clustering for empirical prior - scales_prior_n_samples - return scales of negative binomial distribution for calculates prior means and variances using n_samples. - default_cat - default value for categorical covariates Returns ------- @@ -167,13 +133,10 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior (n_labels, p, D) array weights_vprior (n_labels, p) array - scales_vprior - (n_labels, p, G) array """ if self.is_trained_ is False: warnings.warn( - "Trying to query inferred values from an untrained model. Please train " - "the model first.", + "Trying to query inferred values from an untrained model. Please train the model first.", UserWarning, stacklevel=settings.warnings_stacklevel, ) @@ -183,7 +146,7 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior if self.module.prior == "mog": results = { "mean_vprior": self.module.prior_means, - "var_vprior": torch.exp(self.module.prior_log_scales)**2 + 1e-4, + "var_vprior": torch.exp(self.module.prior_log_std)**2, "weights_vprior": torch.nn.functional.softmax(self.module.prior_logits, dim=-1) } else: @@ -250,21 +213,128 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10, scales_prior "weights_vprior": mp_vprior } - if scales_prior_n_samples is not None: - scales_vprior = np.zeros((self.summary_stats.n_labels, p, self.summary_stats.n_vars)) - cat_covs = [ - torch.full([scales_prior_n_samples, 1], float(np.where(value==default_cat[ind])[0]) if default_cat else 0, device=self.module.device) - for ind, value in enumerate(self.adata_manager.get_state_registry(REGISTRY_KEYS.CAT_COVS_KEY).mappings.values())] - for ct in range(self.summary_stats["n_labels"]): - for cluster in range(p): - sampled_z = torch.distributions.Normal( - results['mean_vprior'][ct, cluster, :], torch.sqrt(results['var_vprior'][ct, cluster, :]) - ).sample([scales_prior_n_samples,]).to(self.module.device) - h = self.module.decoder(sampled_z, torch.full([scales_prior_n_samples, 1], ct, device=self.module.device), *cat_covs) - scales_vprior[ct, cluster, :] = self.module.px_decoder(h).mean(0).cpu() - results["scales_vprior"] = scales_vprior - return results + + @torch.inference_mode() + def predict( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + soft: bool = False, + batch_size: int | None = None, + use_posterior_mean: bool = True, + ) -> np.ndarray | pd.DataFrame: + """Return cell label predictions. + + Parameters + ---------- + adata + AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + indices + Return probabilities for each class label. + soft + If True, returns per class probabilities + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + use_posterior_mean + If ``True``, uses the mean of the posterior distribution to predict celltype + labels. Otherwise, uses a sample from the posterior distribution - this + means that the predictions will be stochastic. 
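+
+        Returns
+        -------
+        Array of predicted fine cell-type labels, or a :class:`~pandas.DataFrame`
+        of per-class probabilities (cells x fine labels) if ``soft=True``.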
+ """ + adata = self._validate_anndata(adata) + + if indices is None: + indices = np.arange(adata.n_obs) + + scdl = self._make_data_loader( + adata=adata, + indices=indices, + batch_size=batch_size, + ) + y_pred = [] + for _, tensors in enumerate(scdl): + inference_input = self.module._get_inference_input(tensors) + qz = self.module.inference(**inference_input)['qz'] + if use_posterior_mean: + z = qz.loc + else: + z = qz.sample() + pred = self.module.classify( + z, + label_index=inference_input["y"], + ) + if self.module.classifier.logits: + pred = torch.nn.functional.softmax(pred, dim=-1) + if not soft: + pred = pred.argmax(dim=1) + y_pred.append(pred.detach().cpu()) + + y_pred = torch.cat(y_pred).numpy() + if not soft: + predictions = [] + for p in y_pred: + predictions.append(self._code_to_fine_label[p]) + + return np.array(predictions) + else: + n_labels = len(pred[0]) + pred = pd.DataFrame( + y_pred, + columns=self._fine_label_mapping[:n_labels], + index=adata.obs_names[indices], + ) + return pred + + @torch.inference_mode() + def confusion_coarse_celltypes( + self, + adata: AnnData | None = None, + indices: Sequence[int] | None = None, + batch_size: int | None = None, + ) -> np.ndarray | pd.DataFrame: + """Return likelihood ratios of switching coarse cell-types to inform whether resolution is to granular. + + Parameters + ---------- + adata + AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + indices + Return probabilities for each class label. + soft + If True, returns per class probabilities + batch_size + Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + use_posterior_mean + If ``True``, uses the mean of the posterior distribution to predict celltype + labels. Otherwise, uses a sample from the posterior distribution - this + means that the predictions will be stochastic. 
+ """ + adata = self._validate_anndata(adata) + + if indices is None: + indices = np.arange(adata.n_obs) + + scdl = self._make_data_loader( + adata=adata, + indices=indices, + batch_size=batch_size, + ) + # Iterate once over the data and computes the reconstruction error + keys = list(self._label_mapping) + ['original'] + log_lkl = {key: [] for key in keys} + for tensors in scdl: + loss_kwargs = {"kl_weight": 1} + _, _, losses = self.module(tensors, loss_kwargs=loss_kwargs) + log_lkl['original'] += [losses.reconstruction_loss] + for i in range(self.module.n_labels): + tensors_ = tensors + tensors_['y'] = torch.full_like(tensors['y'], i) + _, _, losses = self.module(tensors_, loss_kwargs=loss_kwargs) + log_lkl[keys[i]] += [losses.reconstruction_loss] + for key in keys: + log_lkl[key] = torch.stack(log_lkl[key]).detach().numpy() + + return log_lkl @devices_dsp.dedent def train( @@ -328,6 +398,22 @@ def train( plan_kwargs=plan_kwargs, **kwargs, ) + + def _set_indices_and_labels(self, adata: AnnData): + """Set indices for labeled and unlabeled cells.""" + labels_state_registry = self.adata_manager.get_state_registry( + REGISTRY_KEYS.LABELS_KEY + ) + self.original_label_key = labels_state_registry.original_key + self._label_mapping = labels_state_registry.categorical_mapping + self._code_to_label = dict(enumerate(self._label_mapping)) + if self.n_fine_labels is not None: + fine_labels_state_registry = self.adata_manager.get_state_registry( + 'fine_labels' + ) + self.original_fine_label_key = fine_labels_state_registry.original_key + self._fine_label_mapping = fine_labels_state_registry.categorical_mapping + self._code_to_fine_label = dict(enumerate(self._fine_label_mapping)) @classmethod @setup_anndata_dsp.dedent diff --git a/scvi/model/_destvi.py b/scvi/model/_destvi.py index 1316a9dbf2..a26123ce5a 100644 --- a/scvi/model/_destvi.py +++ b/scvi/model/_destvi.py @@ -13,16 +13,17 @@ from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager from scvi.data.fields import ( - CategoricalJointObsField, LayerField, NumericalObsField, CategoricalObsField, ) +from scvi.data._constants import _SETUP_ARGS_KEY from scvi.model import CondSCVI from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin from scvi.module import MRDeconv from scvi.utils import setup_anndata_dsp from scvi.utils._docstrings import devices_dsp +from scvi.model.base._archesmixin import _get_loaded_data logger = logging.getLogger(__name__) @@ -83,24 +84,15 @@ def __init__( n_hidden: int, n_latent: int, n_layers: int, - n_batch_sc: int, dropout_decoder: float, - l1_reg: float, - sc_batch_mapping: list[str], **module_kwargs, ): super().__init__(st_adata) - if sc_batch_mapping is not None: - st_sc_batch_mapping = self.adata_manager.get_state_registry('batch_index_sc')['categorical_mapping'] - assert set(st_sc_batch_mapping).issubset(set(sc_batch_mapping)), ( - f'Spatial model has other covariates than single cell model, {set(st_sc_batch_mapping) - set(sc_batch_mapping)}' - ) self.module = self._module_cls( n_spots=st_adata.n_obs, n_labels=cell_type_mapping.shape[0], n_batch=self.summary_stats.n_batch, - n_batch_sc=n_batch_sc, decoder_state_dict=decoder_state_dict, px_decoder_state_dict=px_decoder_state_dict, px_r=px_r, @@ -109,7 +101,6 @@ def __init__( n_layers=n_layers, n_hidden=n_hidden, dropout_decoder=dropout_decoder, - l1_reg=l1_reg, **module_kwargs, ) self.cell_type_mapping = cell_type_mapping @@ -122,7 +113,7 @@ def from_rna_model( st_adata: AnnData, sc_model: CondSCVI, vamp_prior_p: int = 15, - l1_reg: float 
= 0.0, + anndata_setup_kwargs: dict | None = None, **module_kwargs, ): """Alternate constructor for exploiting a pre-trained model on a RNA-seq dataset. @@ -130,38 +121,47 @@ def from_rna_model( Parameters ---------- st_adata - registered anndata object + anndata object will be registered sc_model - trained CondSCVI model + trained CondSCVI model or path to a trained model vamp_prior_p number of mixture parameter for VampPrior calculations l1_reg Scalar parameter indicating the strength of L1 regularization on cell type proportions. A value of 50 leads to sparser results. + anndata_setup_kwargs + Keyword args for :meth:`~scvi.model.DestVI.setup_anndata` **model_kwargs Keyword args for :class:`~scvi.model.DestVI` """ - decoder_state_dict = sc_model.module.decoder.state_dict() - px_decoder_state_dict = sc_model.module.px_decoder.state_dict() - px_r = sc_model.module.px_r.detach().cpu().numpy() - mapping = sc_model.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ).categorical_mapping - - dropout_decoder = sc_model.module.dropout_rate + attr_dict, var_names, load_state_dict = _get_loaded_data(sc_model) + registry = attr_dict.pop("registry_") + + decoder_state_dict = OrderedDict((i[8:], load_state_dict[i]) for i in load_state_dict.keys() if i.split('.')[0]=='decoder') + px_decoder_state_dict = OrderedDict((i[11:], load_state_dict[i]) for i in load_state_dict.keys() if i.split('.')[0]=='px_decoder') + px_r = load_state_dict['px_r'] + mapping = registry['field_registries']['labels']['state_registry']['categorical_mapping'] + + dropout_decoder = attr_dict['init_params_']['non_kwargs']['dropout_rate'] if vamp_prior_p is None: mean_vprior = None var_vprior = None + elif attr_dict['init_params_']['kwargs']['module_kwargs']['prior']=='mog': + mean_vprior = load_state_dict['prior_means'].clone().detach() + var_vprior = torch.exp(load_state_dict['prior_log_std'])**2 + mp_vprior = torch.nn.Softmax(dim=-1)(load_state_dict['prior_logits']) else: + assert sc_model is not str, "VampPrior requires loading CondSCVI model and providing it" mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior( sc_model.adata, p=vamp_prior_p ).values() - - sc_batch_mapping = ( - sc_model.adata_manager.get_state_registry( - REGISTRY_KEYS.BATCH_KEY - ) - )['categorical_mapping'] + + cls.setup_anndata( + st_adata, + source_registry=registry, + extend_categories=True, + **registry[_SETUP_ARGS_KEY], + ) return cls( st_adata, @@ -172,13 +172,10 @@ def from_rna_model( sc_model.module.n_hidden, sc_model.module.n_latent, sc_model.module.n_layers, - sc_model.module.n_batch, mean_vprior=mean_vprior, var_vprior=var_vprior, mp_vprior=mp_vprior, dropout_decoder=dropout_decoder, - l1_reg=l1_reg, - sc_batch_mapping=sc_batch_mapping, **module_kwargs, ) @@ -221,17 +218,13 @@ def get_proportions( inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) - prop_local = self.module.generative(**generative_inputs)["v"] + prop_local = self.module.generative(**generative_inputs)["v"].squeeze(0) prop_ += [prop_local.cpu()] data = torch.cat(prop_).numpy() if indices: index_names = index_names[indices] else: - if indices is not None: - logger.info( - "No amortization for proportions, ignoring indices and returning results for the full data" - ) - data = torch.nn.functional.softplus(self.module.V).transpose(1, 0).detach().cpu().numpy() + data = torch.nn.functional.softplus(self.module.V[indices, 
:]).transpose(1, 0).detach().cpu().numpy() if normalize: data = data / data.sum(axis=1, keepdims=True) if not keep_noise: @@ -242,14 +235,13 @@ def get_proportions( columns=column_names, index=index_names, ) - - @torch.inference_mode() + + @torch.inference_mode() def get_fine_celltypes( self, sc_model: CondSCVI, - indices: Sequence[int] | None = None, + indices=None, batch_size: int | None = None, - return_numpy: bool = False, ) -> np.ndarray | dict[str, pd.DataFrame]: """Returns the estimated cell-type specific latent space for the spatial data. @@ -261,46 +253,40 @@ def get_fine_celltypes( Indices of cells in adata to use. Only used if amortization. If `None`, all cells are used. batch_size Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`. - return_numpy - if activated, will return a numpy array of shape is n_spots x n_latent x n_labels. """ self._check_if_trained() - - column_names = [str(i) for i in np.arange(self.module.n_latent)] index_names = self.adata.obs.index - - if self.module.amortization in ["both", "latent"]: - stdl = self._make_data_loader( - adata=self.adata, indices=indices, batch_size=batch_size - ) - gamma_ = [] - proportions_modes_ = [] - for tensors in stdl: - inference_inputs = self.module._get_inference_input(tensors) - outputs = self.module.inference(**inference_inputs) - generative_inputs = self.module._get_generative_input(tensors, outputs) - generative_outputs = self.module.generative(**generative_inputs) - gamma_local = generative_outputs["gamma"] - if self.module.prior_mode == 'mog': - proportions_modes_local = generative_outputs['proportion_modes'] # pmc - gamma_local = gamma_local # pncm - else: - proportions_modes_local = torch.ones(gamma_local.shape[0], 1, 1) - gamma_local = gamma_local.squeeze(0) # pncm - gamma_ += [gamma_local.cpu()] - proportions_modes_ += [proportions_modes_local.cpu()] - - proportions_modes = torch.cat(proportions_modes_, dim=-1).numpy() - gamma = torch.cat(gamma_, dim=-1).numpy() - else: - if indices is not None: - logger.info( - "No amortization for latent values, ignoring adata and returning results for the full data" - ) - gamma = self.module.gamma.detach().cpu().numpy() - - sc_latent_distribution = sc_model.get_latent_representation(return_dist=True) - + stdl = self._make_data_loader( + adata=self.adata, indices=indices, batch_size=batch_size + ) + if sc_model.n_fine_labels is None: + raise RuntimeError('Single cell model does not contain fine labels. Please train the single-cell model with fine labels.') + predicted_fine_celltype_ = [] + for tensors in stdl: + inference_inputs = self.module._get_inference_input(tensors) + outputs = self.module.inference(**inference_inputs) + generative_inputs = self.module._get_generative_input(tensors, outputs) + generative_outputs = self.module.generative(**generative_inputs) + + gamma_local = generative_outputs["gamma"][0, ...].transpose(-2, -4) # c, n, p, m + proportions_modes_local = generative_outputs['proportion_modes'][0, ...] 
# pmc + n_modes, batch_size, n_celltypes = proportions_modes_local.shape + gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape(-1, self.module.n_latent) # m*p*c, n + proportions_modes_local_ = proportions_modes_local.permute((1, 0, 2)).flatten() # m*p*c + v_local = generative_outputs['v'][..., :-1].flatten().repeat_interleave(n_modes) # m*p*c + label = torch.arange(self.module.n_labels, device=gamma_local.device).repeat(batch_size).repeat_interleave(n_modes).unsqueeze(-1) # m*p*c, 1 + predicted_fine_celltype_local = v_local.unsqueeze(-1) * proportions_modes_local_.unsqueeze(-1) * torch.nn.functional.softmax( + sc_model.module.classify(gamma_local_, label), dim=-1) + predicted_fine_celltype_sum = predicted_fine_celltype_local.reshape(batch_size, n_celltypes*n_modes, sc_model.n_fine_labels).sum(1) + predicted_fine_celltype_.append(predicted_fine_celltype_sum.detach().cpu()) + predicted_fine_celltype = torch.cat(predicted_fine_celltype_, dim=0).numpy() + + pred = pd.DataFrame( + predicted_fine_celltype, + columns=sc_model._fine_label_mapping, + index=index_names, + ) + return pred @torch.inference_mode() def get_gamma( @@ -335,22 +321,18 @@ def get_gamma( outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) generative_outputs = self.module.generative(**generative_inputs) - gamma_local = generative_outputs["gamma"] + gamma_local = generative_outputs["gamma"].squeeze(0) if self.module.prior_mode == 'mog': - proportions_model_local = generative_outputs['proportion_modes'] + proportions_model_local = generative_outputs['proportion_modes'].squeeze(0) gamma_local = torch.einsum('pncm,pmc->ncm', gamma_local, proportions_model_local) else: - gamma_local = gamma_local.squeeze(0) + gamma_local = gamma_local.squeeze(0).squeeze(0) gamma_ += [gamma_local.cpu()] data = torch.cat(gamma_, dim=-1).numpy() if indices is not None: index_names = index_names[indices] else: - if indices is not None: - logger.info( - "No amortization for latent values, ignoring adata and returning results for the full data" - ) - data = self.module.gamma.detach().cpu().numpy() + data = self.module.gamma[indices, :, :].detach().cpu().numpy() data = np.transpose(data, (2, 0, 1)) if return_numpy: @@ -415,13 +397,15 @@ def get_latent_representation( latent_qzv = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) - z, qz, _ = self.module.inference(**inference_inputs, n_samples=mc_samples).values() + inference_outputs = self.module.inference(**inference_inputs, n_samples=mc_samples).values() + z = inference_outputs['z'][0, ...] 
+ qz = inference_outputs['qz'] if give_mean: - latent += [qz.loc.cpu()] + latent += [qz.loc[0, ...].cpu()] else: latent += [z.cpu()] - latent_qzm += [qz.loc.cpu()] - latent_qzv += [qz.scale.square().cpu()] + latent_qzm += [qz.loc[0, ...].cpu()] + latent_qzv += [qz.scale[0, ...].square().cpu()] return ( (torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy()) if return_dist @@ -465,7 +449,7 @@ def get_scale_for_ct( inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) - px_scale = self.module.generative(**generative_inputs)["px_mu"][:, y, :] + px_scale = self.module.generative(**generative_inputs)["px_mu"][0, :, y, :] scale += [px_scale.cpu()] @@ -515,10 +499,10 @@ def get_expression_for_ct( outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) generative_outputs = self.module.generative(**generative_inputs) - px_scale, proportions = generative_outputs['px_mu'], generative_outputs['v'] - px_scale = torch.einsum('mkl,mk->mkl', px_scale, proportions) - px_scale_proportions = px_scale[:, y, :]/px_scale.sum(dim=1) - x_ct = inference_inputs['x'].to(px_scale_proportions.device) * px_scale_proportions + px_scale, proportions = generative_outputs['px_mu'][0, ...], generative_outputs['v'][0, ...] + px_scale_expected = torch.einsum('mkl,mk->mkl', px_scale, proportions) + px_scale_proportions = px_scale_expected[:, y, :]/px_scale_expected.sum(dim=1) + x_ct = tensors['X'].to(px_scale_proportions.device) * px_scale_proportions expression_ct += [x_ct.cpu()] data = torch.cat(expression_ct).numpy() @@ -606,8 +590,6 @@ def setup_anndata( adata: AnnData, layer: str | None = None, batch_key: str | None = None, - sc_batch_key: str | None = None, - categorical_covariate_keys: Sequence[str] | None = None, **kwargs, ): """%(summary)s. @@ -617,8 +599,6 @@ def setup_anndata( %(param_adata)s %(param_layer)s %(param_batch_key)s - sc_batch_key: - Categorical covariate keys need to line up with single cell model. 
""" setup_method_args = cls._get_setup_method_args(**locals()) # add index for each cell (provided to pyro plate for correct minibatching) @@ -627,7 +607,6 @@ def setup_anndata( LayerField(REGISTRY_KEYS.X_KEY, layer, is_count_data=True), NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), - CategoricalObsField("batch_index_sc", sc_batch_key), ] adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index d5cefcf660..be0e2db99c 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -66,9 +66,9 @@ def load_query_data( freeze_dropout Whether to freeze dropout during training freeze_expression - Freeze neurons corersponding to expression in first layer + Freeze neurons corresponding to expression in first layer freeze_decoder_first_layer - Freeze neurons corersponding to first layer in decoder + Freeze neurons corresponding to first layer in decoder freeze_batchnorm_encoder Whether to freeze batchnorm weight and bias during training for encoder freeze_batchnorm_decoder diff --git a/scvi/module/_mrdeconv.py b/scvi/module/_mrdeconv.py index 62852efd88..017edb8dc6 100644 --- a/scvi/module/_mrdeconv.py +++ b/scvi/module/_mrdeconv.py @@ -3,7 +3,7 @@ import numpy as np import torch -from torch.distributions import Normal +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal from torch.distributions import kl_divergence as kl from scvi import REGISTRY_KEYS @@ -51,11 +51,7 @@ class MRDeconv(BaseModuleClass): Diagonal variance parameter for each component in the empirical prior over the latent space mp_vprior Mixture proportion in cell type sub-clustering of each component in the empirical prior - amortization - which of the latent variables to amortize inference over (gamma, proportions, both or none) - l1_reg - Scalar parameter indicating the strength of L1 regularization on cell type proportions. - A value of 50 leads to sparser results. + amortization beta_reg Scalar parameter indicating the strength of the variance penalty for the multiplicative offset in gene expression values (beta parameter). 
Default is 5 @@ -84,14 +80,16 @@ def __init__( px_decoder_state_dict: OrderedDict, px_r: np.ndarray, dropout_decoder: float, - n_batch_sc: Optional[list] = None, - dropout_amortization: float = 0.05, + augmentation: bool = False, + n_samples_augmentation: int = 1, + n_states_per_label: Tunable[int] = 1, + n_states_per_augmented_label: float = 1, + dropout_amortization: float = 0.03, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, mp_vprior: np.ndarray = None, amortization: Literal["none", "latent", "proportion", "both"] = "both", - l1_reg: Tunable[float] = 0.0, - beta_reg: Tunable[float] = 5.0, + beta_reg: Tunable[float] = 500.0, eta_reg: Tunable[float] = 1e-7, prior_mode: Literal["mog", "normal"] = "normal", n_latent_amortization: Optional[int] = None, @@ -106,21 +104,27 @@ def __init__( self.n_labels = n_labels self.n_hidden = n_hidden self.n_latent = n_latent + self.augmentation = augmentation + self.n_samples_augmentation = n_samples_augmentation + self.n_states_per_augmented_label = n_states_per_augmented_label self.dropout_decoder = dropout_decoder + self.n_states_per_label = n_states_per_label self.dropout_amortization = dropout_amortization self.n_genes = n_genes self.amortization = amortization - self.l1_reg = l1_reg self.beta_reg = beta_reg self.eta_reg = eta_reg self.prior_mode = prior_mode self.n_latent_amortization = n_latent_amortization # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} - cat_list = [n_labels, n_batch_sc] + cat_list = [n_labels] + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, {}) + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim + n_input_decoder = n_latent + batch_dim self.decoder = FCLayers( - n_in=n_latent, + n_in=n_input_decoder, n_out=n_hidden, n_cat_list=cat_list, n_layers=n_layers, @@ -130,17 +134,14 @@ def __init__( use_batch_norm=False, **_extra_decoder_kwargs, ) - self.px_decoder = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_genes), torch.nn.Softplus() - ) - # don't compute gradient for those parameters self.decoder.load_state_dict(decoder_state_dict) for param in self.decoder.parameters(): param.requires_grad = False + self.px_decoder = torch.nn.Linear(n_hidden, n_genes) self.px_decoder.load_state_dict(px_decoder_state_dict) for param in self.px_decoder.parameters(): param.requires_grad = False - self.register_buffer("px_o", torch.tensor(px_r)) + self.px_o = torch.nn.Parameter(px_r) # cell_type specific factor loadings self.V = torch.nn.Parameter(torch.randn(self.n_labels + 1, self.n_spots)) @@ -150,28 +151,34 @@ def __init__( torch.randn(n_latent, self.n_labels, self.n_spots) ) if mean_vprior is not None: - self.p = mean_vprior.shape[1] - self.register_buffer("mean_vprior", torch.tensor(mean_vprior)) - self.register_buffer("var_vprior", torch.tensor(var_vprior)) - self.register_buffer("mp_vprior", torch.tensor(mp_vprior)) + self.register_buffer("mean_vprior", mean_vprior) + self.register_buffer("var_vprior", var_vprior) + self.register_buffer("mp_vprior", mp_vprior) + cats = Categorical(probs=self.mp_vprior) + normal_dists = Independent( + Normal( + self.mean_vprior, + torch.sqrt(self.var_vprior) + 1e-4 + ), + reinterpreted_batch_ndims=1 + ) + self.qz_prior = MixtureSameFamily(cats, normal_dists) else: self.mean_vprior = None self.var_vprior = None # noise from data - self.eta = torch.nn.Parameter(torch.randn(self.n_genes)) + self.eta = torch.nn.Parameter(torch.zeros(self.n_genes)) # additive gene bias - self.beta = torch.nn.Parameter(0.01 * 
torch.randn(self.n_genes)) + self.beta = torch.nn.Parameter(torch.zeros(self.n_genes)) + print('beta is parameter') # create additional neural nets for amortization # within cell_type factor loadings _extra_encoder_kwargs = extra_encoder_kwargs or {} if self.prior_mode == "mog": - print('Using mixture of gaussians for prior') - return_dist = self.p * n_labels * n_latent + self.p * n_labels + return_dist = self.n_states_per_label * n_labels * n_latent + self.n_states_per_label * n_labels else: - print('Using normal prior') return_dist = n_labels * n_latent - print(f"return_dist: {return_dist}, {self.p}, {n_labels}, {n_latent}, {mean_vprior.shape}") if self.n_latent_amortization is not None: # Uses a combined latent space for proportions and gammas. self.z_encoder = Encoder( @@ -216,7 +223,7 @@ def identity(x, batch_index=None): n_cat_list=None, n_layers=n_layers, n_hidden=n_hidden, - dropout_rate=dropout_amortization, + dropout_rate=0, use_layer_norm=True, use_batch_norm=False, ), @@ -225,106 +232,174 @@ def identity(x, batch_index=None): def _get_inference_input(self, tensors): x = tensors[REGISTRY_KEYS.X_KEY] + m = x.shape[0] + n_samples = self.n_samples_augmentation + 1 + if self.augmentation and self.training: + with torch.no_grad(): + beta = torch.exp(self.beta) # n_genes + # beta = torch.cat([beta.view(1, 1, 1, -1), torch.ones_like(beta).view(1, 1, 1, -1).repeat(n_samples-1, 1, 1, 1)]) + prior_sampled = self.qz_prior.sample( + [n_samples, self.n_states_per_augmented_label, m]).reshape( + n_samples*self.n_states_per_augmented_label, -1, self.n_latent) + enum_label = ( + torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) + ) # m * n_labels, 1 + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, tensors[REGISTRY_KEYS.BATCH_KEY]) + batch_rep_input = batch_rep.repeat_interleave(self.n_labels, dim=0) + decoder_input = torch.cat([prior_sampled, batch_rep_input], dim=-1) + px_scale_augment_ = torch.nn.Softmax(dim=-1)(self.px_decoder(self.decoder(decoder_input, enum_label.to(x.device))) + beta.view(1, 1, -1)) + px_scale_augment = px_scale_augment_.reshape( + (n_samples*self.n_states_per_augmented_label, x.shape[0], self.n_labels, -1) + ) # (samples*states_per_cell, mi, n_labels, n_genes) + library = x.sum(-1).view(1, 1, m, 1, 1).repeat(n_samples, 1, 1, 1, 1) + library[1, ...] = library[1, ...] 
+ 50 + px_scale_augment = px_scale_augment.reshape(n_samples, self.n_states_per_augmented_label, m, self.n_labels, -1) # (samples, states_per_cell, m, n_labels, n_genes) + px_rate = library * px_scale_augment # (samples, states_per_cell, m, n_labels, n_genes) + ratios_ct_augmentation = torch.distributions.Dirichlet( + torch.zeros(self.n_states_per_augmented_label * self.n_labels) + 0.03).sample([n_samples, m]).to(x.device) + ratios_ct_augmentation = ratios_ct_augmentation.reshape(n_samples, m, self.n_states_per_augmented_label, self.n_labels).permute(0, 2, 1, 3) + augmentation_rate = torch.einsum('ilmk, ilmkg -> img', ratios_ct_augmentation, px_rate) # (samples, m, n_genes) + ratio_augmentation_ = torch.distributions.Beta(0.4, 0.5).sample([self.n_samples_augmentation-1, m]).unsqueeze(-1).to(x.device) + ratio_augmentation = torch.cat([torch.zeros((1, m, 1), device=x.device), torch.ones((1, m, 1), device=x.device), ratio_augmentation_], dim=0) + augmented_counts = NegativeBinomial( + augmentation_rate, logits=self.px_o + ).sample() # (samples*states_per_cell, m, n_labels, n_genes) + # print('TTTT1', augmentation_rate[1, ...].sum(-1).min(), augmented_counts[1, ...].sum(-1).min(), x.sum(-1).min(), x.shape) + x_augmented = ( + (1 - ratio_augmentation) * x.unsqueeze(0) + + ratio_augmentation * augmented_counts + ) + else: + x_augmented = x.unsqueeze(0) + prior_sampled = None + ratios_ct_augmentation = None + ratio_augmentation = None + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - input_dict = {"x": x, "batch_index": batch_index} + input_dict = { + "batch_index": batch_index, + "x_augmented": x_augmented, + "prior_sampled": prior_sampled, + "ratios_ct_augmentation": ratios_ct_augmentation, + "ratio_augmentation": ratio_augmentation} return input_dict def _get_generative_input(self, tensors, inference_outputs): z = inference_outputs["z"] library = inference_outputs["library"] ind_x = tensors[REGISTRY_KEYS.INDICES_KEY].long().ravel() - batch_index_sc = tensors['batch_index_sc'] + 2. 
+ batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - input_dict = {"z": z, "ind_x": ind_x, "library": library, "batch_index_sc": batch_index_sc} + input_dict = {"z": z, "ind_x": ind_x, "library": library, "batch_index": batch_index} return input_dict @auto_move_data def inference( self, - x, + x_augmented, batch_index, n_samples=1, + prior_sampled=None, + ratios_ct_augmentation=None, + ratio_augmentation=None, ): """Runs the inference (encoder) model.""" - x_ = x - library = torch.log(x.sum(1)).unsqueeze(1) + x_ = x_augmented + library = x_augmented.sum(-1).unsqueeze(-1) x_ = torch.log(1 + x_) if self.n_latent_amortization is not None: qz, z = self.z_encoder(x_, batch_index) else: z = x_ - qz = Normal(x_, scale=1e-6*torch.ones_like(x_)) # dummy distribution - - outputs = {"z": z, "qz": qz, "library": library} + qz = Normal(x_, scale=1e-6*torch.ones_like(x_)) + + outputs = { + "z": z, + "qz": qz, + "library": library, + "x_augmented": x_augmented, + "prior_sampled": prior_sampled, + "ratio_augmentation": ratio_augmentation, + "ratios_ct_augmentation": ratios_ct_augmentation, + } return outputs @auto_move_data - def generative(self, z, ind_x, library, batch_index_sc): + def generative(self, z, ind_x, library, batch_index): """Build the deconvolution model for every cell in the minibatch.""" m = len(ind_x) # setup all non-linearities beta = torch.exp(self.beta) # n_genes - eps = torch.nn.functional.softplus(self.eta) # n_genes + eps = torch.nn.functional.softmax(self.eta, dim=-1) # n_genes + if self.training and self.augmentation: + n_samples = (self.n_samples_augmentation + 1) + else: + n_samples = 1 if self.amortization in ["both", "latent"]: if self.prior_mode == "mog": gamma_ = self.gamma_encoder(z) proportion_modes_logits = torch.transpose( - gamma_[:, -self.p*self.n_labels:], 0, 1).reshape( - (self.p, self.n_labels, m) - ).transpose(1, 2) - proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) + gamma_[:, :, -self.n_states_per_label*self.n_labels:], -1, -2).reshape( + (n_samples, self.n_states_per_label, self.n_labels, m) + ).transpose(-1, -2) # n_samples, n_states_per_label, m, n_labels + proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=-3) gamma_ind = torch.transpose( - gamma_[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( - (self.p, self.n_latent, self.n_labels, -1) - ) + gamma_[:, :, :self.n_states_per_label*self.n_labels*self.n_latent], -1, -2).reshape( + (n_samples, self.n_states_per_label, self.n_latent, self.n_labels, m) + ) # n_samples, n_states_per_label, n_latent, n_labels, m else: gamma_ind = torch.transpose( self.gamma_encoder(z), 0, 1).reshape( - (1, self.n_latent, self.n_labels, -1) + (n_samples, 1, self.n_latent, self.n_labels, m) ) - proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=z.device) + proportion_modes_logits = proportion_modes = torch.ones( + (n_samples, 1, m, self.n_labels), device=z.device) else: - gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0) # 1, n_latent, n_labels, minibatch_size - proportion_modes_logits = proportion_modes = torch.ones((1, self.n_labels), device=z.device) + gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0).unsqueeze(0).repeat( + n_samples, 1, 1, 1, 1) # n_samples, n_latent, n_labels, m + proportion_modes_logits = proportion_modes = torch.ones((n_samples, 1, m, self.n_labels), device=z.device) if self.amortization in ["both", "proportion"]: v_ind = self.V_encoder(z) else: - v_ind = self.V[:, ind_x].T # minibatch_size, labels + 1 - v_ind = 
torch.nn.functional.softplus(v_ind) + v_ind = self.V[:, ind_x].T.unsqueeze(0).repeat( + n_samples, 1, 1) # n_samples, m, labels + 1 + v_ind = torch.nn.functional.softmax(v_ind, dim=-1) - px_est = torch.zeros((m, self.n_labels, self.n_genes), device=z.device) + px_est = torch.zeros((n_samples, m, self.n_labels, self.n_genes), device=z.device) enum_label = ( torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) - ) # minibatch_size * n_labels, 1 - batch_index_sc_input = batch_index_sc.repeat_interleave(self.n_labels, dim=0) + ) # m * n_labels, 1 + + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + batch_rep_input = batch_rep.repeat_interleave(self.n_labels, dim=0).unsqueeze(0).repeat(n_samples, 1, 1) - for mode in range(gamma_ind.shape[0]): + for mode in range(gamma_ind.shape[-4]): # reshape and get gene expression value for all minibatch gamma_ind_ = torch.transpose( - gamma_ind[mode, ...], 2, 0 - ) # minibatch_size, n_labels, n_latent + gamma_ind[:, mode, ...], -1, -3 + ) # n_samples, m, n_labels, n_latent gamma_reshape_ = gamma_ind_.reshape( - (-1, self.n_latent) - ) # minibatch_size * n_labels, n_latent - h = self.decoder(gamma_reshape_, enum_label.to(z.device), batch_index_sc_input) - px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( - (m, self.n_labels, -1) - ) # (minibatch, n_labels, n_genes) + (n_samples, -1, self.n_latent) + ) # n_samples, m * n_labels, n_latent + decoder_input_ = torch.cat([gamma_reshape_, batch_rep_input], dim=-1) + h = self.decoder(decoder_input_, enum_label.to(z.device)) + px_est += proportion_modes[:, mode, ...].unsqueeze(-1) * torch.nn.Softmax(dim=-1)(self.px_decoder(h).reshape( + (n_samples, m, self.n_labels, -1) + ) + beta.view(1, 1, 1, -1)) # (n_samples, m, n_labels, n_genes) # add the dummy cell type - eps = eps.repeat((m, 1)).view( - m, 1, -1 - ) # (M, 1, n_genes) <- this is the dummy cell type + eps = eps.unsqueeze(0).repeat(n_samples, m, 1).unsqueeze(-2) # (n_samples, m, 1, n_genes) <- this is the dummy cell type - # account for gene specific bias and add noise + # account for gene specific bias and add noise, take sample without augmentation. 
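+        # Mix the cell-type specific rates with the spot proportions:
+        # px_scale[s, m, g] = sum_c v[s, m, c] * r_hat[s, m, c, g], where the extra
+        # dummy component (eps) absorbs expression not explained by any cell type.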
r_hat = torch.cat( - [beta.unsqueeze(0).unsqueeze(1) * px_est, eps], dim=1 - ) # M, n_labels + 1, n_genes + [beta.view(1, 1, 1, -1) * px_est, eps], dim=-2 + ) # n_samples, m, n_labels + 1, n_genes + # now combine them for convolution - px_scale = torch.sum(v_ind.unsqueeze(2) * r_hat, dim=1) # batch_size, n_genes + px_scale = torch.sum(v_ind.unsqueeze(-1) * r_hat, dim=-2) # n_samples, m, n_genes px_rate = library * px_scale px_mu = torch.exp(self.px_o) * r_hat @@ -339,81 +414,126 @@ def generative(self, z, ind_x, library, batch_index_sc): "proportion_modes_logits": proportion_modes_logits, } + def _compute_cross_entropy(self, prob_true, prob_pred): + log_prob_pred = torch.log(prob_pred / prob_pred.sum(axis=-1, keepdim=True)) + prob_true = prob_true + 1e-20 + prob_true = prob_true / prob_true.sum(axis=-1, keepdim=True) + kl_div = torch.nn.functional.kl_div(log_prob_pred, prob_true, reduction='batchmean', log_target=False) + + return kl_div + def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0, - n_obs: int = 1.0, - weighting_cross_entropy: float = 1e-6, + ct_sparsity_weight: float = 0., + weighting_augmentation: float = 10., ): """Compute the loss.""" - x = tensors[REGISTRY_KEYS.X_KEY] + x_augmented = inference_outputs["x_augmented"] px_rate = generative_outputs["px_rate"] px_o = generative_outputs["px_o"] gamma = generative_outputs["gamma"] v = generative_outputs["v"] - - reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1) - + ratio_augmentation = inference_outputs['ratio_augmentation'] + ratios_ct_augmentation = inference_outputs['ratios_ct_augmentation'] + + n_samples = self.n_samples_augmentation + 1 + m = x_augmented.shape[1] + + if self.augmentation: + prior_sampled = inference_outputs['prior_sampled'].reshape(n_samples, self.n_states_per_augmented_label, x_augmented.shape[1], self.n_labels, self.n_latent) + mean_vprior = torch.cat( + [self.mean_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 1, 1), + prior_sampled.permute(0, 2, 3, 1, 4)], + dim=-2) # n_samples, m, n_labels, p+n_states_per_augmented_label, n_latent + var_vprior = torch.cat( + [self.var_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 1, 1), + torch.min(self.var_vprior, dim=-2).values.view( + 1, 1, self.n_labels, 1, self.n_latent).repeat(n_samples, m, 1, self.n_states_per_augmented_label, 1)], + dim=-2 + ) # n_samples, m, n_labels, p+n_states_per_augmented_label, n_latent + mp_vprior=torch.cat( + [(1- ratio_augmentation.unsqueeze(-1)) * self.mp_vprior.view(1, 1, self.n_labels, -1).repeat(n_samples, m, 1, 1), + ratio_augmentation.unsqueeze(-1) * ratios_ct_augmentation.permute(0, 2, 3, 1) + ], + dim=-1 + ) # n_samples, m, n_labels, p+n_states_per_augmented_label + else: + mean_vprior = self.mean_vprior.unsqueeze(0).unsqueeze(0) + var_vprior = self.var_vprior.unsqueeze(0).unsqueeze(0) + mp_vprior = self.mp_vprior.unsqueeze(0).unsqueeze(0) + + proportion_modes = generative_outputs["proportion_modes"] + reconst_loss = - NegativeBinomial(px_rate, logits=px_o).log_prob(x_augmented).sum(-1) # eta prior likelihood mean = torch.zeros_like(self.eta) scale = torch.ones_like(self.eta) glo_neg_log_likelihood_prior = ( -self.eta_reg * Normal(mean, scale).log_prob(self.eta).sum() ) - glo_neg_log_likelihood_prior += self.beta_reg * torch.var(self.beta) - - v_sparsity_loss = self.l1_reg * torch.abs(v).mean(1) + # beta loss + glo_neg_log_likelihood_prior += ( + -self.beta_reg * Normal(mean, scale).log_prob(self.beta).sum() + ) + if self.augmentation: + 
expected_proportions = ( + ratio_augmentation * torch.cat([ratios_ct_augmentation.sum(-3), torch.zeros([n_samples, m, 1]).to(v.device)], dim=-1) + + (1 - ratio_augmentation) * v[0, :, :] # unperturbed proportions + ) + #loss_augmentation = weighting_augmentation * torch.abs(v - expected_proportions).sum(-1) + loss_augmentation = 0 + for i in [1]: + loss_augmentation += weighting_augmentation * self._compute_cross_entropy(expected_proportions[i, ...].squeeze(0), v[i, ...].squeeze(0)) + else: + loss_augmentation = torch.tensor(0., device=x_augmented.device) # gamma prior likelihood if self.mean_vprior is None: # isotropic normal prior mean = torch.zeros_like(gamma) scale = torch.ones_like(gamma) - neg_log_likelihood_prior = -Normal(mean, scale).log_prob(gamma).sum(2).sum(1) + neg_log_likelihood_prior = - Normal(mean, scale).log_prob(gamma).sum(2).sum(1) + elif self.prior_mode == "mog": + # gamma is of shape n_samples, minibatch_size, 1, n_latent, n_labels + gamma = gamma.permute(1, 0, 4, 3, 2) # p, n_samples, minibatch_size, n_labels, n_latent + cats = Categorical(probs=mp_vprior) + normal_dists = Independent( + Normal( + mean_vprior, + var_vprior + ), + reinterpreted_batch_ndims=1 + ) + pre_lse = MixtureSameFamily(cats, normal_dists).log_prob(gamma) # p, n_samples, minibatch_size, n_labels + pre_lse = pre_lse.permute(1, 0, 2, 3) + log_likelihood_prior = torch.mul( + pre_lse, + proportion_modes + 1e-3 + ).sum(-3) # n_samples, minibatch, n_labels + neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) else: - # vampprior - # gamma is of shape minibatch_size, 1, n_latent, n_labels - gamma = gamma.permute(3, 0, 2, 1) # minibatch_size, 1, n_labels, n_latent - mean_vprior = torch.transpose(self.mean_vprior, 0, 1).unsqueeze( - 0 - ) # 1, p, n_labels, n_latent - var_vprior = torch.transpose(self.var_vprior, 0, 1).unsqueeze( - 0 - ) # 1, p, n_labels, n_latent - mp_vprior = torch.transpose(self.mp_vprior, 0, 1) # p, n_labels - if self.prior_mode == "mog": - proportion_modes_logits = generative_outputs["proportion_modes_logits"] - proportion_modes = generative_outputs["proportion_modes"] - pre_lse = ( - Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(1e-3 + proportion_modes).permute(1, 0, 2) # minibatch, p, n_labels - log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = - log_likelihood_prior.sum(1) # minibatch - - neg_log_likelihood_prior += weighting_cross_entropy * torch.nn.functional.cross_entropy( - proportion_modes_logits.permute(1, 0, 2), mp_vprior.repeat(x.shape[0], 1, 1), reduction='none').sum(1) - else: - pre_lse = ( - Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(3) - ) + torch.log(1e-3 + mp_vprior) # minibatch, p, n_labels - log_likelihood_prior = torch.logsumexp(pre_lse, 1) # minibatch, n_labels - neg_log_likelihood_prior = - log_likelihood_prior.sum(1) # minibatch - + gamma = gamma.permute(0, 4, 1, 3, 2) # n_samples, minibatch_size, 1, n_labels, n_latent + mean_vprior = torch.transpose(mean_vprior, -3, -2) # n_samples, m, p, n_labels, n_latent + var_vprior = torch.transpose(var_vprior, -3, -2) # n_samples, m, p, n_labels, n_latent + mp_vprior = torch.transpose(mp_vprior, -2, -1) # n_samples, m, p, n_labels + pre_lse = ( + Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(-1).squeeze(-3) + ) + torch.log(1e-3 + mp_vprior) # n_samples, minibatch, p, n_labels + log_likelihood_prior = torch.logsumexp(pre_lse, -2) # n_samples, minibatch, n_labels + 
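The vamp-prior branch just above scores each spot's latent against an empirical mixture: a Normal log-density per prior component, plus the log mixture weight, followed by a logsumexp over components. A compact sketch of that pattern under simplified shapes (no augmentation-sample dimension); names are illustrative.

import torch
from torch.distributions import Normal

m, p, n_labels, n_latent = 4, 5, 3, 6

gamma = torch.randn(m, 1, n_labels, n_latent)            # per-spot latents, broadcast over components
mean_vprior = torch.randn(1, p, n_labels, n_latent)      # component means
var_vprior = torch.rand(1, p, n_labels, n_latent)        # component variances
mp_vprior = torch.softmax(torch.randn(p, n_labels), 0)   # mixture weights per cell type

pre_lse = (
    Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4).log_prob(gamma).sum(-1)
    + torch.log(1e-3 + mp_vprior)
)                                                        # (m, p, n_labels)
log_likelihood_prior = torch.logsumexp(pre_lse, dim=1)   # (m, n_labels)
neg_log_likelihood_prior = -log_likelihood_prior.sum(-1) # (m,)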
neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) # n_samples, minibatch if self.n_latent_amortization is not None: neg_log_likelihood_prior += kl( inference_outputs["qz"], Normal(torch.zeros([self.n_latent_amortization], device=x.device), torch.ones([self.n_latent_amortization], device=x.device)) ).sum(dim=-1) - - # High v_sparsity_loss is detrimental early in training, scaling by kl_weight to increase over training epochs. - loss = n_obs * ( - torch.mean( - reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) - ) - + glo_neg_log_likelihood_prior + + v_sparsity_loss = ct_sparsity_weight * torch.distributions.Categorical(probs=v[0, :, :]).entropy() + + loss = torch.mean( + reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) + glo_neg_log_likelihood_prior + loss_augmentation ) return LossOutput( @@ -433,67 +553,3 @@ def sample( """Sample from the posterior.""" raise NotImplementedError("No sampling method for DestVI") - - @auto_move_data - def get_ct_specific_expression( - self, x: torch.Tensor = None, ind_x: torch.Tensor = None, y: int = None, - batch_index: torch.Tensor = None, batch_index_sc: torch.Tensor = None - ): - """Returns cell type specific gene expression at the queried spots. - - Parameters - ---------- - x - tensor of data - ind_x - tensor of indices - y - integer for cell types - batch_index_sc - tensor of corresponding batch in single cell data for decoder - """ - # cell-type specific gene expression, shape (minibatch, celltype, gene). - beta = torch.exp(self.beta) # n_genes - y_torch = (y * torch.ones_like(ind_x)).ravel() - # obtain the relevant gammas - if self.amortization in ["both", "latent"]: - x_ = torch.log(1 + x) - z_amortization = self.amortization_network(x_, batch_index)[:, :self.n_latent_amortization] - if self.prior_mode == "mog": - gamma_ = self.gamma_encoder(z_amortization) - proportion_modes_logits = torch.transpose( - gamma_[:, -self.p*self.n_labels:], 0, 1).reshape( - (self.p, 1, -1) - ).transpose(1, 2) - proportion_modes = torch.nn.functional.softmax(proportion_modes_logits, dim=0) - # shape (p, n_labels, minibatch_size) - gamma_ind = torch.transpose( - gamma_[:, :self.p*self.n_labels*self.n_latent], 0, 1).reshape( - (self.p, self.n_latent, self.n_labels, -1) - ) - else: - gamma_ind = torch.transpose( - self.gamma_encoder(z_amortization), 0, 1).reshape( - (1, self.n_latent, self.n_labels, -1) - ) - proportion_modes = torch.ones((1, self.n_labels), device=x.device) - else: - gamma_ind = self.gamma[:, :, ind_x].unsqueeze(0) # 1, n_latent, n_labels, minibatch_size - proportion_modes = torch.ones((1, self.n_labels), device=x.device) - gamma_ind = gamma_ind[:, :, y, :] - proportion_modes = proportion_modes[:, y] - - px_est = torch.zeros((x.shape[0], self.n_genes), device=x.device) - for mode in range(gamma_ind.shape[0]): - # reshape and get gene expression value for all minibatch - gamma_ind_ = torch.transpose( - gamma_ind[mode, ...], 1, 0 - ) # minibatch_size, n_latent - h = self.decoder(gamma_ind_, y_torch.unsqueeze(1), batch_index_sc.unsqueeze(1)) - px_est += proportion_modes[mode, ...].unsqueeze(-1) * self.px_decoder(h).reshape( - (x.shape[0], -1) - ) # (minibatch, n_genes) - - px_scale_ct = torch.exp(self.px_o).unsqueeze(0) * beta.unsqueeze(0) * px_est - return px_scale_ct # shape (minibatch, genes) - diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 370b1b59fc..bb703aabd8 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -5,19 +5,21 @@ import torch from torch.distributions import 
Categorical, Independent, MixtureSameFamily, Normal from torch.distributions import kl_divergence as kl +from torch.nn import functional as F from scvi import REGISTRY_KEYS +from ._classifier import Classifier from scvi._types import Tunable from scvi.data._constants import ADATA_MINIFY_TYPE from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseMinifiedModeModuleClass, LossOutput, auto_move_data -from scvi.nn import Encoder, FCLayers +from scvi.module.base import EmbeddingModuleMixin, BaseMinifiedModeModuleClass, LossOutput, auto_move_data +from scvi.nn import Encoder, FCLayers, DecoderSCVI torch.backends.cudnn.benchmark = True # Conditional VAE model -class VAEC(BaseMinifiedModeModuleClass): +class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Conditional Variational auto-encoder model. This is an implementation of the CondSCVI model @@ -53,6 +55,7 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, + n_fine_labels: Optional[int] = None, n_hidden: Tunable[int] = 128, n_latent: Tunable[int] = 5, n_layers: Tunable[int] = 2, @@ -62,8 +65,8 @@ def __init__( encode_covariates: bool = False, extra_encoder_kwargs: Optional[dict] = None, extra_decoder_kwargs: Optional[dict] = None, + linear_classifier: bool = True, prior: str = 'normal', - df_ct_id_dict: dict = None, num_classes_mog: Optional[int] = 10, ): super().__init__() @@ -79,15 +82,15 @@ def __init__( # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels + self.n_fine_labels = n_fine_labels self.prior = prior - if df_ct_id_dict is not None: - self.num_classes_mog = max([v[2] for v in df_ct_id_dict.values()]) + 1 - mapping_mog = torch.tensor([v[2] for _, v in sorted(df_ct_id_dict.items())]) - self.register_buffer("mapping_mog", mapping_mog) - else: - self.num_classes_mog = num_classes_mog - cat_list = [n_labels, n_batch] + self.num_classes_mog = num_classes_mog + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, {}) + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim + + cat_list = [n_labels] encoder_cat_list = cat_list if self.encode_covariates else [n_labels] + n_input_encoder += batch_dim * encode_covariates # gene dispersion self.px_r = torch.nn.Parameter(torch.randn(n_input)) @@ -95,7 +98,7 @@ def __init__( # z encoder goes from the n_input-dimensional data to an n_latent-d _extra_encoder_kwargs = {} self.z_encoder = Encoder( - n_input, + n_input_encoder, n_latent, n_cat_list=encoder_cat_list, n_layers=n_layers, @@ -107,11 +110,37 @@ def __init__( return_dist=True, **_extra_encoder_kwargs, ) + if n_fine_labels is not None: + cls_parameters = { + "n_layers": 0, + "n_hidden": 0, + "dropout_rate": dropout_rate, + "logits": True, + } + # linear mapping from latent space to a coarse-celltype aware space + self.linear_mapping = FCLayers( + n_in=n_latent, + n_out=n_hidden, + n_cat_list=[n_labels], + use_layer_norm=True, + dropout_rate=0.0, + ) + + self.classifier = Classifier( + n_hidden, + n_labels=n_fine_labels, + use_batch_norm=False, + use_layer_norm=True, + **cls_parameters, + ) + else: + self.classifier = None # decoder goes from n_latent-dimensional space to n_input-d data _extra_decoder_kwargs = {} + n_input_decoder = n_latent + batch_dim self.decoder = FCLayers( - n_in=n_latent, + n_in=n_input_decoder, n_out=n_hidden, n_cat_list=cat_list, n_layers=n_layers, @@ -122,9 +151,7 @@ def __init__( use_layer_norm=True, **_extra_decoder_kwargs, ) - self.px_decoder = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_input), 
torch.nn.Softplus() - ) + self.px_decoder = torch.nn.Linear(n_hidden, n_input) if ct_weight is not None: ct_weight = torch.tensor(ct_weight, dtype=torch.float32) @@ -134,7 +161,7 @@ def __init__( if self.prior=='mog': self.prior_means = torch.nn.Parameter( 0.01 * torch.randn([n_labels, self.num_classes_mog, n_latent])) - self.prior_log_scales = torch.nn.Parameter( + self.prior_log_std = torch.nn.Parameter( torch.zeros([n_labels, self.num_classes_mog, n_latent])) self.prior_logits = torch.nn.Parameter( torch.zeros([n_labels, self.num_classes_mog])) @@ -192,11 +219,9 @@ def _regular_inference(self, x, y, batch_index, n_samples=1): library = x.sum(1).unsqueeze(1) if self.log_variational: x_ = torch.log(1 + x_) - if self.encode_covariates: - categorical_input = [y, batch_index] - else: - categorical_input = [y] - qz, z = self.z_encoder(x_, *categorical_input) + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + encoder_input = torch.cat([x_, batch_rep], dim=-1) + qz, z = self.z_encoder(encoder_input, y) if n_samples > 1: untran_z = qz.sample((n_samples,)) @@ -226,12 +251,39 @@ def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): ) outputs = {"z": z, "qz": qz, "library": library} return outputs + + @auto_move_data + def classify( + self, + z: torch.Tensor, + label_index: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass through the encoder and classifier. + + Parameters + ---------- + z + Tensor of shape ``(n_obs, n_latent)``. + label_index + Tensor of shape ``(n_obs,)`` denoting label indices. + + Returns + ------- + Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label. + """ + if len(label_index.shape)==1: + label_index = label_index.unsqueeze(1) + classifier_latent = self.linear_mapping(z, label_index) + w_y = self.classifier(classifier_latent) + return w_y @auto_move_data def generative(self, z, library, y, batch_index): """Runs the generative model.""" - h = self.decoder(z, y, batch_index) - px_scale = self.px_decoder(h) + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + decoder_input = torch.cat([decoder_input, batch_rep], dim=-1) + h = self.decoder(decoder_input, y, batch_index) + px_scale = torch.nn.Softmax(dim=-1)(self.px_decoder(h)) px_rate = library * px_scale px = NegativeBinomial(px_rate, logits=self.px_r) return {"px": px, "px_scale": px_scale} @@ -242,30 +294,24 @@ def loss( inference_outputs, generative_outputs, kl_weight: float = 1.0, + classification_ratio = 5., ): """Loss computation.""" x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long() qz = inference_outputs["qz"] px = generative_outputs["px"] - fine_celltypes = tensors['fine_labels'].ravel().long() if 'fine_labels' in tensors.keys() else None + fine_labels = tensors['fine_labels'].ravel().long() if 'fine_labels' in tensors.keys() else None if self.prior == "mog": indexed_means = self.prior_means[y] - indexed_log_scales = self.prior_log_scales[y] + indexed_log_std = self.prior_log_std[y] indexed_logits = self.prior_logits[y] - - # Assigns zero meaning equal weight to all unlabeled cells. Otherwise biases to sample from respective MoG. 
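The fine-label head introduced above maps the posterior mean through a coarse-label-conditioned layer and scores fine labels; in the loss that follows, the resulting cross-entropy is added to the reconstruction term, weighted by classification_ratio. A toy sketch of that composition, where plain nn.Linear layers stand in for the FCLayers and Classifier modules and all shapes are made up:

import torch
from torch import nn
from torch.nn import functional as F

n_obs, n_latent, n_labels, n_fine_labels, n_hidden = 8, 5, 3, 7, 16
classification_ratio = 5.0

z = torch.randn(n_obs, n_latent)                 # posterior mean qz.loc
coarse = torch.randint(0, n_labels, (n_obs,))
fine = torch.randint(0, n_fine_labels, (n_obs,))

# coarse-label-aware mapping: concatenate a one-hot of the coarse label, as FCLayers does via n_cat_list
linear_mapping = nn.Linear(n_latent + n_labels, n_hidden)
classifier = nn.Linear(n_hidden, n_fine_labels)  # logits over fine labels

h = linear_mapping(torch.cat([z, F.one_hot(coarse, n_labels).float()], dim=-1))
logits = classifier(h)

reconst_loss = torch.zeros(n_obs)                # placeholder for -log p(x | z)
reconst_loss = reconst_loss + classification_ratio * F.cross_entropy(logits, fine, reduction="none")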
- if fine_celltypes is not None: - logits_input = torch.nn.functional.one_hot( - self.mapping_mog[fine_celltypes], self.num_classes_mog) - cats = Categorical(logits=10*logits_input + indexed_logits) - else: - cats = Categorical(logits=indexed_logits) - normal_dists = torch.distributions.Independent( + cats = Categorical(logits=indexed_logits) + normal_dists = Independent( Normal( indexed_means, - torch.exp(indexed_log_scales) + 1e-4 + torch.exp(indexed_log_std) + 1e-4 ), reinterpreted_batch_ndims=1 ) @@ -280,6 +326,12 @@ def loss( reconst_loss = -px.log_prob(x).sum(-1) scaling_factor = self.ct_weight[y] + + if self.classifier is not None: + fine_labels = fine_labels.view(-1) + logits = self.classify(qz.loc, label_index=tensors[REGISTRY_KEYS.LABELS_KEY]) # (n_obs, n_labels) + reconst_loss = reconst_loss + classification_ratio * F.cross_entropy(logits, fine_labels, reduction="none") + loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) return LossOutput( From f8a68a2b2a85d54af3eb06a0b312d157843d4101 Mon Sep 17 00:00:00 2001 From: cane11 Date: Wed, 10 Jul 2024 22:32:03 -0700 Subject: [PATCH 09/12] Updated DestVI --- scvi/model/base/_vaemixin.py | 216 -------------------- scvi/module/_vaec.py | 381 ----------------------------------- src/scvi/model/_condscvi.py | 14 +- src/scvi/model/_destvi.py | 58 +++--- src/scvi/model/_scanvi.py | 3 + src/scvi/model/_scvi.py | 1 + src/scvi/module/_mrdeconv.py | 245 ++++++++++++---------- src/scvi/module/_vaec.py | 356 +++++++++++++++++++++----------- 8 files changed, 420 insertions(+), 854 deletions(-) delete mode 100644 scvi/model/base/_vaemixin.py delete mode 100644 scvi/module/_vaec.py diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py deleted file mode 100644 index ded10a8084..0000000000 --- a/scvi/model/base/_vaemixin.py +++ /dev/null @@ -1,216 +0,0 @@ -import logging -from collections.abc import Sequence -from typing import Optional, Union - -import numpy as np -import torch -from anndata import AnnData - -from scvi.utils import unsupported_if_adata_minified - -from ._log_likelihood import compute_elbo, compute_reconstruction_error - -logger = logging.getLogger(__name__) - - -class VAEMixin: - """Univseral VAE methods.""" - - @torch.inference_mode() - @unsupported_if_adata_minified - def get_elbo( - self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, - ) -> float: - """Return the ELBO for the data. - - The ELBO is a lower bound on the log likelihood of the data used for optimization - of VAEs. Note, this is not the negative ELBO, higher is better. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - """ - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) - elbo = compute_elbo(self.module, scdl) - return -elbo - - @torch.inference_mode() - @unsupported_if_adata_minified - def get_marginal_ll( - self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - n_mc_samples: int = 1000, - batch_size: Optional[int] = None, - return_mean: Optional[bool] = True, - **kwargs, - ) -> Union[torch.Tensor, float]: - """Return the marginal LL for the data. 
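Because the mixture-of-Gaussians prior used in the VAEC loss above has no closed-form KL against the Normal posterior, the divergence is estimated by Monte Carlo: sample from q(z), score the samples under both distributions, and average. A small sketch of that estimator with toy shapes:

import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

n_obs, n_latent, n_components = 6, 5, 4

qz = Normal(torch.randn(n_obs, n_latent), torch.rand(n_obs, n_latent) + 0.1)

prior = MixtureSameFamily(
    Categorical(logits=torch.zeros(n_obs, n_components)),
    Independent(
        Normal(torch.randn(n_obs, n_components, n_latent),
               torch.ones(n_obs, n_components, n_latent)),
        1,
    ),
)

u = qz.rsample(sample_shape=(30,))                                       # (30, n_obs, n_latent)
kl_divergence_z = -(prior.log_prob(u) - qz.log_prob(u).sum(-1)).mean(0)  # (n_obs,)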
- - The computation here is a biased estimator of the marginal log likelihood of the data. - Note, this is not the negative log likelihood, higher is better. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - n_mc_samples - Number of Monte Carlo samples to use for marginal LL estimation. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_mean - If False, return the marginal log likelihood for each observation. - Otherwise, return the mmean arginal log likelihood. - """ - adata = self._validate_anndata(adata) - if indices is None: - indices = np.arange(adata.n_obs) - scdl = self._make_data_loader( - adata=adata, - indices=indices, - batch_size=batch_size, - shuffle=False, - ) - if hasattr(self.module, "marginal_ll"): - log_lkl = [] - for tensors in scdl: - log_lkl.append( - self.module.marginal_ll( - tensors, - n_mc_samples=n_mc_samples, - return_mean=return_mean, - **kwargs, - ) - ) - if not return_mean: - return torch.cat(log_lkl, 0) - else: - return np.mean(log_lkl) - else: - raise NotImplementedError( - "marginal_ll is not implemented for current model. " - "Please raise an issue on github if you need it." - ) - - @torch.inference_mode() - @unsupported_if_adata_minified - def get_reconstruction_error( - self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, - ) -> float: - r"""Return the reconstruction error for the data. - - This is typically written as :math:`p(x \mid z)`, the likelihood term given one posterior sample. - Note, this is not the negative likelihood, higher is better. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_mean - If False, return the marginal log likelihood for each observation. - Otherwise, return the mmean arginal log likelihood. - """ - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) - reconstruction_error = compute_reconstruction_error(self.module, scdl) - return reconstruction_error - - @torch.inference_mode() - def get_latent_representation( - self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - give_mean: bool = True, - mc_samples: int = 5000, - batch_size: Optional[int] = None, - return_dist: bool = False, - ) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray]]: - """Return the latent representation for each cell. - - This is typically denoted as :math:`z_n`. - - Parameters - ---------- - adata - AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the - AnnData object used to initialize the model. - indices - Indices of cells in adata to use. If `None`, all cells are used. - give_mean - Give mean of distribution or sample from it. - mc_samples - For distributions with no closed-form mean (e.g., `logistic normal`), how many Monte Carlo - samples to take for computing mean. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. 
- return_dist - Return (mean, variance) of distributions instead of just the mean. - If `True`, ignores `give_mean` and `mc_samples`. In the case of the latter, - `mc_samples` is used to compute the mean of a transformed distribution. - If `return_dist` is true the untransformed mean and variance are returned. - - Returns - ------- - Low-dimensional representation for each cell or a tuple containing its mean and variance. - """ - self._check_if_trained(warn=False) - - adata = self._validate_anndata(adata) - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) - latent = [] - latent_qzm = [] - latent_qzv = [] - for tensors in scdl: - inference_inputs = self.module._get_inference_input(tensors) - outputs = self.module.inference(**inference_inputs) - if "qz" in outputs: - qz = outputs["qz"] - else: - qz_m, qz_v = outputs["qz_m"], outputs["qz_v"] - qz = torch.distributions.Normal(qz_m, qz_v.sqrt()) - if give_mean: - # does each model need to have this latent distribution param? - if self.module.latent_distribution == "ln": - samples = qz.sample([mc_samples]) - z = torch.nn.functional.softmax(samples, dim=-1) - z = z.mean(dim=0) - else: - z = qz.loc - else: - z = outputs["z"] - - latent += [z.cpu()] - latent_qzm += [qz.loc.cpu()] - latent_qzv += [qz.scale.square().cpu()] - return ( - (torch.cat(latent_qzm).numpy(), torch.cat(latent_qzv).numpy()) - if return_dist - else torch.cat(latent).numpy() - ) diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py deleted file mode 100644 index bb703aabd8..0000000000 --- a/scvi/module/_vaec.py +++ /dev/null @@ -1,381 +0,0 @@ -from collections.abc import Iterable -from typing import Optional - -import numpy as np -import torch -from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal -from torch.distributions import kl_divergence as kl -from torch.nn import functional as F - -from scvi import REGISTRY_KEYS -from ._classifier import Classifier -from scvi._types import Tunable -from scvi.data._constants import ADATA_MINIFY_TYPE -from scvi.distributions import NegativeBinomial -from scvi.module.base import EmbeddingModuleMixin, BaseMinifiedModeModuleClass, LossOutput, auto_move_data -from scvi.nn import Encoder, FCLayers, DecoderSCVI - -torch.backends.cudnn.benchmark = True - - -# Conditional VAE model -class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): - """Conditional Variational auto-encoder model. - - This is an implementation of the CondSCVI model - - Parameters - ---------- - n_input - Number of input genes - n_batch - Number of batches - n_labels - Number of labels - n_hidden - Number of nodes per hidden layer - n_latent - Dimensionality of the latent space - n_layers - Number of hidden layers used for encoder and decoder NNs - log_variational - Log(data+1) prior to encoding for numerical stability. Not normalization. - ct_weight - Multiplicative weight for cell type specific latent space. - dropout_rate - Dropout rate for the encoder and decoder neural network. - extra_encoder_kwargs - Keyword arguments passed into :class:`~scvi.nn.Encoder`. - extra_decoder_kwargs - Keyword arguments passed into :class:`~scvi.nn.FCLayers`. 
- """ - - def __init__( - self, - n_input: int, - n_batch: int = 0, - n_labels: int = 0, - n_fine_labels: Optional[int] = None, - n_hidden: Tunable[int] = 128, - n_latent: Tunable[int] = 5, - n_layers: Tunable[int] = 2, - log_variational: bool = True, - ct_weight: np.ndarray = None, - dropout_rate: Tunable[float] = 0.05, - encode_covariates: bool = False, - extra_encoder_kwargs: Optional[dict] = None, - extra_decoder_kwargs: Optional[dict] = None, - linear_classifier: bool = True, - prior: str = 'normal', - num_classes_mog: Optional[int] = 10, - ): - super().__init__() - self.dispersion = "gene" - self.n_latent = n_latent - self.n_layers = n_layers - self.n_hidden = n_hidden - self.dropout_rate = dropout_rate - self.encode_covariates = encode_covariates - self.log_variational = log_variational - self.gene_likelihood = "nb" - self.latent_distribution = "normal" - # Automatically deactivate if useless - self.n_batch = n_batch - self.n_labels = n_labels - self.n_fine_labels = n_fine_labels - self.prior = prior - self.num_classes_mog = num_classes_mog - self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, {}) - batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim - - cat_list = [n_labels] - encoder_cat_list = cat_list if self.encode_covariates else [n_labels] - n_input_encoder += batch_dim * encode_covariates - - # gene dispersion - self.px_r = torch.nn.Parameter(torch.randn(n_input)) - - # z encoder goes from the n_input-dimensional data to an n_latent-d - _extra_encoder_kwargs = {} - self.z_encoder = Encoder( - n_input_encoder, - n_latent, - n_cat_list=encoder_cat_list, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - inject_covariates=True, - use_batch_norm=False, - use_layer_norm=True, - return_dist=True, - **_extra_encoder_kwargs, - ) - if n_fine_labels is not None: - cls_parameters = { - "n_layers": 0, - "n_hidden": 0, - "dropout_rate": dropout_rate, - "logits": True, - } - # linear mapping from latent space to a coarse-celltype aware space - self.linear_mapping = FCLayers( - n_in=n_latent, - n_out=n_hidden, - n_cat_list=[n_labels], - use_layer_norm=True, - dropout_rate=0.0, - ) - - self.classifier = Classifier( - n_hidden, - n_labels=n_fine_labels, - use_batch_norm=False, - use_layer_norm=True, - **cls_parameters, - ) - else: - self.classifier = None - - # decoder goes from n_latent-dimensional space to n_input-d data - _extra_decoder_kwargs = {} - n_input_decoder = n_latent + batch_dim - self.decoder = FCLayers( - n_in=n_input_decoder, - n_out=n_hidden, - n_cat_list=cat_list, - n_layers=n_layers, - n_hidden=n_hidden, - dropout_rate=dropout_rate, - inject_covariates=True, - use_batch_norm=False, - use_layer_norm=True, - **_extra_decoder_kwargs, - ) - self.px_decoder = torch.nn.Linear(n_hidden, n_input) - - if ct_weight is not None: - ct_weight = torch.tensor(ct_weight, dtype=torch.float32) - else: - ct_weight = torch.ones((self.n_labels,), dtype=torch.float32) - self.register_buffer("ct_weight", ct_weight) - if self.prior=='mog': - self.prior_means = torch.nn.Parameter( - 0.01 * torch.randn([n_labels, self.num_classes_mog, n_latent])) - self.prior_log_std = torch.nn.Parameter( - torch.zeros([n_labels, self.num_classes_mog, n_latent])) - self.prior_logits = torch.nn.Parameter( - torch.zeros([n_labels, self.num_classes_mog])) - - def _get_inference_input(self, tensors): - y = tensors[REGISTRY_KEYS.LABELS_KEY] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - if self.minified_data_type is None: - x = tensors[REGISTRY_KEYS.X_KEY] - input_dict = { - 
"x": x, - "y": y, - "batch_index": batch_index, - } - else: - if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): - qzm = tensors[REGISTRY_KEYS.LATENT_QZM_KEY] - qzv = tensors[REGISTRY_KEYS.LATENT_QZV_KEY] - observed_lib_size = tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE] - input_dict = { - "qzm": qzm, - "qzv": qzv, - "observed_lib_size": observed_lib_size, - "y": y, - "batch_index": batch_index, - } - else: - raise NotImplementedError( - f"Unknown minified-data type: {self.minified_data_type}" - ) - - return input_dict - - def _get_generative_input(self, tensors, inference_outputs): - z = inference_outputs["z"] - library = inference_outputs["library"] - y = tensors[REGISTRY_KEYS.LABELS_KEY] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - - input_dict = { - "z": z, - "library": library, - "y": y, - "batch_index": batch_index, - } - return input_dict - - @auto_move_data - def _regular_inference(self, x, y, batch_index, n_samples=1): - """High level inference method. - - Runs the inference (encoder) model. - """ - x_ = x - library = x.sum(1).unsqueeze(1) - if self.log_variational: - x_ = torch.log(1 + x_) - batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) - encoder_input = torch.cat([x_, batch_rep], dim=-1) - qz, z = self.z_encoder(encoder_input, y) - - if n_samples > 1: - untran_z = qz.sample((n_samples,)) - z = self.z_encoder.z_transformation(untran_z) - library = library.unsqueeze(0).expand( - (n_samples, library.size(0), library.size(1)) - ) - - outputs = {"z": z, "qz": qz, "library": library} - return outputs - - @auto_move_data - def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): - if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): - qz = Normal(qzm, qzv.sqrt()) - # use dist.sample() rather than rsample because we aren't optimizing the z here - untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) - z = self.z_encoder.z_transformation(untran_z) - library = observed_lib_size - if n_samples > 1: - library = library.unsqueeze(0).expand( - (n_samples, library.size(0), library.size(1)) - ) - else: - raise NotImplementedError( - f"Unknown minified-data type: {self.minified_data_type}" - ) - outputs = {"z": z, "qz": qz, "library": library} - return outputs - - @auto_move_data - def classify( - self, - z: torch.Tensor, - label_index: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Forward pass through the encoder and classifier. - - Parameters - ---------- - z - Tensor of shape ``(n_obs, n_latent)``. - label_index - Tensor of shape ``(n_obs,)`` denoting label indices. - - Returns - ------- - Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label. 
- """ - if len(label_index.shape)==1: - label_index = label_index.unsqueeze(1) - classifier_latent = self.linear_mapping(z, label_index) - w_y = self.classifier(classifier_latent) - return w_y - - @auto_move_data - def generative(self, z, library, y, batch_index): - """Runs the generative model.""" - batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) - decoder_input = torch.cat([decoder_input, batch_rep], dim=-1) - h = self.decoder(decoder_input, y, batch_index) - px_scale = torch.nn.Softmax(dim=-1)(self.px_decoder(h)) - px_rate = library * px_scale - px = NegativeBinomial(px_rate, logits=self.px_r) - return {"px": px, "px_scale": px_scale} - - def loss( - self, - tensors, - inference_outputs, - generative_outputs, - kl_weight: float = 1.0, - classification_ratio = 5., - ): - """Loss computation.""" - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long() - qz = inference_outputs["qz"] - px = generative_outputs["px"] - fine_labels = tensors['fine_labels'].ravel().long() if 'fine_labels' in tensors.keys() else None - - if self.prior == "mog": - indexed_means = self.prior_means[y] - indexed_log_std = self.prior_log_std[y] - indexed_logits = self.prior_logits[y] - cats = Categorical(logits=indexed_logits) - normal_dists = Independent( - Normal( - indexed_means, - torch.exp(indexed_log_std) + 1e-4 - ), - reinterpreted_batch_ndims=1 - ) - prior = MixtureSameFamily(cats, normal_dists) - u = qz.rsample(sample_shape=(30,)) - # (sample, n_obs, n_latent) -> (sample, n_obs,) - kl_divergence_z = - (prior.log_prob(u) - qz.log_prob(u).sum(-1)).mean(0) - else: - mean = torch.zeros_like(qz.loc) - scale = torch.ones_like(qz.scale) - kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) - - reconst_loss = -px.log_prob(x).sum(-1) - scaling_factor = self.ct_weight[y] - - if self.classifier is not None: - fine_labels = fine_labels.view(-1) - logits = self.classify(qz.loc, label_index=tensors[REGISTRY_KEYS.LABELS_KEY]) # (n_obs, n_labels) - reconst_loss = reconst_loss + classification_ratio * F.cross_entropy(logits, fine_labels, reduction="none") - - loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) - - return LossOutput( - loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z - ) - - @torch.inference_mode() - def sample( - self, - tensors, - n_samples=1, - ) -> np.ndarray: - r"""Generate observation samples from the posterior predictive distribution. - - The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. 
- - Parameters - ---------- - tensors - Tensors dict - n_samples - Number of required samples for each cell - - Returns - ------- - x_new : :py:class:`torch.Tensor` - tensor with shape (n_cells, n_genes, n_samples) - """ - inference_kwargs = {"n_samples": n_samples} - generative_outputs = self.forward( - tensors, - inference_kwargs=inference_kwargs, - compute_loss=False, - )[1] - - px_r = generative_outputs["px_r"] - px_rate = generative_outputs["px_rate"] - - dist = NegativeBinomial(px_rate, logits=px_r) - if n_samples > 1: - exprs = dist.sample().permute( - [1, 2, 0] - ) # Shape : (n_cells_batch, n_genes, n_samples) - else: - exprs = dist.sample() - - return exprs.cpu() diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index 88fc371685..afa43ff801 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -11,9 +11,9 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager from scvi.data._utils import _get_adata_minify_type, get_anndata_attribute -from scvi.data.fields import CategoricalJointObsField, CategoricalObsField, LayerField +from scvi.data.fields import CategoricalObsField, LabelsWithUnlabeledObsField, LayerField from scvi.model.base import ( - BaseMinifiedModeModelClass, + BaseModelClass, RNASeqMixin, UnsupervisedTrainingMixin, VAEMixin, @@ -25,9 +25,7 @@ logger = logging.getLogger(__name__) -class CondSCVI( - RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseMinifiedModeModelClass -): +class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass): """Conditional version of single-cell Variational Inference, used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`. Parameters @@ -434,7 +432,7 @@ def setup_anndata( labels_key: str | None = None, fine_labels_key: str | None = None, layer: str | None = None, - batch_key: str | None = None, + unlabeled_category: str = "unlabeled", **kwargs, ): """%(summary)s. @@ -446,6 +444,7 @@ def setup_anndata( %(param_labels_key)s fine_labels_key Key in `adata.obs` where fine-grained labels are stored. 
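A hedged usage sketch of the updated CondSCVI setup above, which now registers an optional fine-grained label column with an unlabeled placeholder; the toy dataset and the "subtype" column are made up for illustration.

import scvi

adata = scvi.data.synthetic_iid()                       # toy AnnData with a "labels" obs column
adata.obs["subtype"] = adata.obs["labels"].astype(str)  # hypothetical fine-grained annotation

scvi.model.CondSCVI.setup_anndata(
    adata,
    labels_key="labels",
    fine_labels_key="subtype",
    unlabeled_category="unknown",
)
model = scvi.model.CondSCVI(adata)
model.train(max_epochs=1)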
+ %(unlabeled_category)ss %(param_layer)s %(param_batch_key)s """ @@ -457,8 +456,7 @@ def setup_anndata( ] if fine_labels_key is not None: anndata_fields.append( - CategoricalObsField('fine_labels', fine_labels_key - ) + LabelsWithUnlabeledObsField('fine_labels', fine_labels_key, unlabeled_category), ) # register new fields if the adata is minified diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index 8fa33a63e3..a42d330ecd 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -83,7 +83,8 @@ def __init__( cell_type_mapping: np.ndarray, decoder_state_dict: OrderedDict, px_decoder_state_dict: OrderedDict, - px_r: np.ndarray, + px_r: torch.tensor, + per_ct_bias: torch.tensor, n_hidden: int, n_latent: int, n_layers: int, @@ -99,6 +100,7 @@ def __init__( decoder_state_dict=decoder_state_dict, px_decoder_state_dict=px_decoder_state_dict, px_r=px_r, + per_ct_bias=per_ct_bias, n_genes=st_adata.n_vars, n_latent=n_latent, n_layers=n_layers, @@ -107,6 +109,7 @@ def __init__( **module_kwargs, ) self.cell_type_mapping = cell_type_mapping + self.cell_type_mapping_extended = list(self.cell_type_mapping) + [f'additional_{i}' for i in range(self.module.add_celltypes)] self._model_summary_string = "DestVI Model" self.init_params_ = self._get_init_params(locals()) @@ -143,6 +146,7 @@ def from_rna_model( decoder_state_dict = OrderedDict((i[8:], load_state_dict[i]) for i in load_state_dict.keys() if i.split('.')[0]=='decoder') px_decoder_state_dict = OrderedDict((i[11:], load_state_dict[i]) for i in load_state_dict.keys() if i.split('.')[0]=='px_decoder') px_r = load_state_dict['px_r'] + per_ct_bias = load_state_dict['per_ct_bias'] mapping = registry['field_registries']['labels']['state_registry']['categorical_mapping'] dropout_decoder = attr_dict['init_params_']['non_kwargs']['dropout_rate'] @@ -158,11 +162,15 @@ def from_rna_model( mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior( sc_model.adata, p=vamp_prior_p ).values() + + if anndata_setup_kwargs is None: + anndata_setup_kwargs = {} cls.setup_anndata( st_adata, source_registry=registry, extend_categories=True, + **anndata_setup_kwargs, **registry[_SETUP_ARGS_KEY], ) @@ -172,6 +180,7 @@ def from_rna_model( decoder_state_dict, px_decoder_state_dict, px_r, + per_ct_bias, sc_model.module.n_hidden, sc_model.module.n_latent, sc_model.module.n_layers, @@ -185,19 +194,19 @@ def from_rna_model( @torch.inference_mode() def get_proportions( self, - keep_noise: bool = False, + keep_additional: bool = False, normalize: bool = True, indices: Sequence[int] | None = None, batch_size: int | None = None, ) -> pd.DataFrame: """Returns the estimated cell type proportion for the spatial data. - Shape is n_cells x n_labels OR n_cells x (n_labels + 1) if keep_noise. + Shape is n_cells x n_labels OR n_cells x (n_labels + add_celltypes) if keep_additional. Parameters ---------- - keep_noise - whether to account for the noise term as a standalone cell type in the proportion estimate. + keep_additional + whether to account for the additional cell-types as standalone cell types in the proportion estimate. normalize whether to normalize the proportions to sum to 1. 
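Downstream, the extended cell-type mapping means the estimated proportions can expose the additional cell types explicitly. A hedged sketch of the workflow, assuming a trained CondSCVI reference `sc_model` and a spatial AnnData `st_adata`; `from_rna_model` registers the spatial data itself, so no separate setup call is shown.

import scvi

# assumes: sc_model is a trained CondSCVI reference, st_adata holds spatial counts
st_model = scvi.model.DestVI.from_rna_model(st_adata, sc_model)
st_model.train(max_epochs=2500)

# n_spots x (n_labels + add_celltypes); drop keep_additional to keep only the reference cell types
proportions = st_model.get_proportions(keep_additional=True, normalize=True)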
indices @@ -211,8 +220,10 @@ def get_proportions( column_names = self.cell_type_mapping index_names = self.adata.obs.index - if keep_noise: - column_names = np.append(column_names, "noise_term") + if keep_additional: + column_names = list(self.cell_type_mapping_extended) + else: + column_names = list(self.cell_type_mapping) if self.module.amortization in ["both", "proportion"]: stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) @@ -221,17 +232,17 @@ def get_proportions( inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) - prop_local = self.module.generative(**generative_inputs)["v"].squeeze(0) + prop_local = self.module.generative(**generative_inputs)["v"][0, ...] prop_ += [prop_local.cpu()] - data = torch.cat(prop_).numpy() + data = torch.cat(prop_).detach().numpy() if indices: index_names = index_names[indices] else: - data = torch.nn.functional.softplus(self.module.V[indices, :]).transpose(1, 0).detach().cpu().numpy() + data = torch.nn.functional.softplus(self.module.V).transpose(0, 1).detach().cpu().numpy() + if not keep_additional: + data = data[:, :-self.module.add_celltypes] if normalize: data = data / data.sum(axis=1, keepdims=True) - if not keep_noise: - data = data[:, :-1] return pd.DataFrame( data=data, @@ -276,7 +287,7 @@ def get_fine_celltypes( n_modes, batch_size, n_celltypes = proportions_modes_local.shape gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape(-1, self.module.n_latent) # m*p*c, n proportions_modes_local_ = proportions_modes_local.permute((1, 0, 2)).flatten() # m*p*c - v_local = generative_outputs['v'][..., :-1].flatten().repeat_interleave(n_modes) # m*p*c + v_local = generative_outputs['v'][0, ..., :-self.module.add_celltypes].flatten().repeat_interleave(n_modes) # m*p*c label = torch.arange(self.module.n_labels, device=gamma_local.device).repeat(batch_size).repeat_interleave(n_modes).unsqueeze(-1) # m*p*c, 1 predicted_fine_celltype_local = v_local.unsqueeze(-1) * proportions_modes_local_.unsqueeze(-1) * torch.nn.functional.softmax( sc_model.module.classify(gamma_local_, label), dim=-1) @@ -324,19 +335,18 @@ def get_gamma( outputs = self.module.inference(**inference_inputs) generative_inputs = self.module._get_generative_input(tensors, outputs) generative_outputs = self.module.generative(**generative_inputs) - gamma_local = generative_outputs["gamma"].squeeze(0) + gamma_local = generative_outputs["gamma"][0, ...] if self.module.prior_mode == 'mog': - proportions_model_local = generative_outputs['proportion_modes'].squeeze(0) + proportions_model_local = generative_outputs['proportion_modes'][0, ...] 
gamma_local = torch.einsum('pncm,pmc->ncm', gamma_local, proportions_model_local) else: - gamma_local = gamma_local.squeeze(0).squeeze(0) + gamma_local = gamma_local[0, ...].squeeze(0) gamma_ += [gamma_local.cpu()] data = torch.cat(gamma_, dim=-1).numpy() if indices is not None: index_names = index_names[indices] else: - data = self.module.gamma[indices, :, :].detach().cpu().numpy() - + data = self.module.gamma.detach().cpu().numpy() data = np.transpose(data, (2, 0, 1)) if return_numpy: return data @@ -398,7 +408,7 @@ def get_latent_representation( latent_qzv = [] for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) - inference_outputs = self.module.inference(**inference_inputs, n_samples=mc_samples).values() + inference_outputs = self.module.inference(**inference_inputs, n_samples=mc_samples) z = inference_outputs['z'][0, ...] qz = inference_outputs['qz'] if give_mean: @@ -438,7 +448,7 @@ def get_scale_for_ct( self._check_if_trained() self._validate_anndata() - cell_type_mapping_extended = list(self.cell_type_mapping) + ['noise'] + cell_type_mapping_extended = list(self.cell_type_mapping) + [f'additional_{i}' for i in range(self.module.add_celltypes)] if label not in cell_type_mapping_extended: raise ValueError("Unknown cell type") @@ -487,11 +497,10 @@ def get_expression_for_ct( Pandas dataframe of gene_expression """ self._check_if_trained() - cell_type_mapping_extended = list(self.cell_type_mapping) + ['noise'] - if label not in cell_type_mapping_extended: + if label not in self.cell_type_mapping_extended: raise ValueError("Unknown cell type") - y = cell_type_mapping_extended.index(label) + y = self.cell_type_mapping_extended.index(label) stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) expression_ct = [] @@ -591,6 +600,7 @@ def setup_anndata( cls, adata: AnnData, layer: str | None = None, + smoothed_layer: str | None = None, batch_key: str | None = None, **kwargs, ): @@ -610,6 +620,8 @@ def setup_anndata( NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), ] + if smoothed_layer is not None: + anndata_fields.append(LayerField("x_smoothed", smoothed_layer, is_count_data=True)) adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/scvi/model/_scanvi.py b/src/scvi/model/_scanvi.py index 3d70623747..bf44e9663e 100644 --- a/src/scvi/model/_scanvi.py +++ b/src/scvi/model/_scanvi.py @@ -12,9 +12,12 @@ from anndata import AnnData from scvi import REGISTRY_KEYS, settings +from scvi._types import MinifiedDataType from scvi.data import AnnDataManager from scvi.data._constants import ( + _ADATA_MINIFY_TYPE_UNS_KEY, _SETUP_ARGS_KEY, + ADATA_MINIFY_TYPE, ) from scvi.data._utils import _get_adata_minify_type, _is_minified, get_anndata_attribute from scvi.data.fields import ( diff --git a/src/scvi/model/_scvi.py b/src/scvi/model/_scvi.py index 2b427a36cb..57be797b7d 100644 --- a/src/scvi/model/_scvi.py +++ b/src/scvi/model/_scvi.py @@ -9,6 +9,7 @@ from scvi import REGISTRY_KEYS, settings from scvi._types import MinifiedDataType from scvi.data import AnnDataManager +from scvi.data._constants import _ADATA_MINIFY_TYPE_UNS_KEY, ADATA_MINIFY_TYPE from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import ( CategoricalJointObsField, diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index 
849a764421..9d441b33c5 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -3,12 +3,12 @@ import numpy as np import torch -from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal +from torch.distributions import Categorical, Exponential, Independent, Laplace, MixtureSameFamily, Normal from torch.distributions import kl_divergence as kl from scvi import REGISTRY_KEYS from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data +from scvi.module.base import EmbeddingModuleMixin, BaseModuleClass, LossOutput, auto_move_data from scvi.nn import Encoder, FCLayers @@ -17,7 +17,7 @@ def identity(x): return x -class MRDeconv(BaseModuleClass): +class MRDeconv(EmbeddingModuleMixin, BaseModuleClass): """Model for multi-resolution deconvolution of spatial transriptomics. Parameters @@ -34,16 +34,26 @@ class MRDeconv(BaseModuleClass): Number of dimensions used in the latent variables n_genes Number of genes used in the decoder - dropout_decoder - Dropout rate for the decoder neural network (same dropout as in CondSCVI decoder) - dropout_amortization - Dropout rate for the amortization neural network + px_r + parameters for the px_r tensor in the CondSCVI model + per_ct_bias + estimates of per cell-type expression bias in the CondSCVI model decoder_state_dict state_dict from the decoder of the CondSCVI model px_decoder_state_dict state_dict from the px_decoder of the CondSCVI model - px_r - parameters for the px_r tensor in the CondSCVI model + dropout_decoder + Dropout rate for the decoder neural network (same dropout as in CondSCVI decoder) + dropout_amortization + Dropout rate for the amortization neural network + n_samples_augmentation + Number of samples used in the augmentation + n_states_per_label + Number of states per cell-type in each spot + eps_v + Epsilon value for each cell-type proportion used during training. + n_states_per_augmented_label + Number of states per cell-type in each spot during augmentation mean_vprior Mean parameter for each component in the empirical prior over the latent space var_vprior @@ -51,15 +61,13 @@ class MRDeconv(BaseModuleClass): mp_vprior Mixture proportion in cell type sub-clustering of each component in the empirical prior amortization - beta_reg - Scalar parameter indicating the strength of the variance penalty for - the multiplicative offset in gene expression values (beta parameter). Default is 5 - (setting to 0.5 might help if single cell reference and spatial assay are different - e.g. UMI vs non-UMI.) - eta_reg - Scalar parameter indicating the strength of the prior for - the noise term (eta parameter). Default is 1e-4. - (changing value is discouraged.) + prior_mode + Mode of the prior distribution for the latent space. + Either "mog" for mixture of gaussians or "normal" for normal distribution. + add_celltypes + Number of additional cell types compared to single cell data to add to the model + n_latent_amortization + Number of dimensions used in the latent variables for the amortization encoder neural network extra_encoder_kwargs Extra keyword arguments passed into :class:`~scvi.nn.FCLayers`. 
extra_decoder_kwargs @@ -71,26 +79,26 @@ def __init__( n_spots: int, n_labels: int, n_batch: int, - n_hidden: Tunable[int], - n_layers: Tunable[int], - n_latent: Tunable[int], + n_hidden: int, + n_layers: int, + n_latent: int, n_genes: int, decoder_state_dict: OrderedDict, px_decoder_state_dict: OrderedDict, - px_r: np.ndarray, + px_r: torch.tensor, + per_ct_bias: torch.tensor, dropout_decoder: float, + dropout_amortization: float = 0.03, augmentation: bool = False, n_samples_augmentation: int = 1, - n_states_per_label: Tunable[int] = 1, - n_states_per_augmented_label: float = 1, - dropout_amortization: float = 0.03, + n_states_per_label: int = 1, + eps_v: float = 2e-3, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, mp_vprior: np.ndarray = None, amortization: Literal["none", "latent", "proportion", "both"] = "both", - beta_reg: Tunable[float] = 500.0, - eta_reg: Tunable[float] = 1e-7, - prior_mode: Literal["mog", "normal"] = "normal", + prior_mode: Literal["mog", "normal"] = "mog", + add_celltypes: int = 1, n_latent_amortization: Optional[int] = None, extra_encoder_kwargs: Optional[dict] = None, extra_decoder_kwargs: Optional[dict] = None, @@ -105,20 +113,19 @@ def __init__( self.n_latent = n_latent self.augmentation = augmentation self.n_samples_augmentation = n_samples_augmentation - self.n_states_per_augmented_label = n_states_per_augmented_label self.dropout_decoder = dropout_decoder self.n_states_per_label = n_states_per_label self.dropout_amortization = dropout_amortization self.n_genes = n_genes self.amortization = amortization - self.beta_reg = beta_reg - self.eta_reg = eta_reg self.prior_mode = prior_mode + self.add_celltypes = add_celltypes + self.eps_v = eps_v self.n_latent_amortization = n_latent_amortization # unpack and copy parameters _extra_decoder_kwargs = extra_decoder_kwargs or {} cat_list = [n_labels] - self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, {}) + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch) batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim n_input_decoder = n_latent + batch_dim @@ -140,10 +147,11 @@ def __init__( self.px_decoder.load_state_dict(px_decoder_state_dict) for param in self.px_decoder.parameters(): param.requires_grad = False - self.px_o = torch.nn.Parameter(px_r) + self.px_r = torch.nn.Parameter(px_r) + self.register_buffer("per_ct_bias", per_ct_bias) # cell_type specific factor loadings - self.V = torch.nn.Parameter(torch.randn(self.n_labels + 1, self.n_spots)) + self.V = torch.nn.Parameter(torch.randn(self.n_labels + self.add_celltypes, self.n_spots)) # within cell_type factor loadings self.gamma = torch.nn.Parameter(torch.randn(n_latent, self.n_labels, self.n_spots)) @@ -164,10 +172,9 @@ def __init__( self.mean_vprior = None self.var_vprior = None # noise from data - self.eta = torch.nn.Parameter(torch.zeros(self.n_genes)) + self.eta = torch.nn.Parameter(torch.zeros(self.add_celltypes, self.n_genes)) # additive gene bias self.beta = torch.nn.Parameter(torch.zeros(self.n_genes)) - print('beta is parameter') # create additional neural nets for amortization # within cell_type factor loadings @@ -192,19 +199,17 @@ def __init__( return_dist=True, **_extra_encoder_kwargs, ) - else: def identity(x, batch_index=None): return x, Normal(x, scale=1e-6*torch.ones_like(x)) self.z_encoder = identity n_latent_amortization = self.n_genes - n_layers = 2 self.gamma_encoder = torch.nn.Sequential( FCLayers( n_in=n_latent_amortization, n_out=n_hidden, n_cat_list=None, - n_layers=n_layers, + n_layers=2, 
n_hidden=n_hidden, dropout_rate=dropout_amortization, use_layer_norm=True, @@ -224,50 +229,60 @@ def identity(x, batch_index=None): use_layer_norm=True, use_batch_norm=False, ), - torch.nn.Linear(n_hidden, n_labels + 1), + torch.nn.Linear(n_hidden, n_labels + self.add_celltypes), ) def _get_inference_input(self, tensors): x = tensors[REGISTRY_KEYS.X_KEY] + x_smoothed = tensors.get('x_smoothed', None) m = x.shape[0] - n_samples = self.n_samples_augmentation + 1 + if x_smoothed is not None: + n_samples = self.n_samples_augmentation + 2 + n_samples_observed = 2 + else: + n_samples = self.n_samples_augmentation + 1 + n_samples_observed = 1 + px_r = torch.exp(self.px_r) if self.augmentation and self.training: with torch.no_grad(): - beta = torch.exp(self.beta) # n_genes - # beta = torch.cat([beta.view(1, 1, 1, -1), torch.ones_like(beta).view(1, 1, 1, -1).repeat(n_samples-1, 1, 1, 1)]) prior_sampled = self.qz_prior.sample( - [n_samples, self.n_states_per_augmented_label, m]).reshape( - n_samples*self.n_states_per_augmented_label, -1, self.n_latent) + [n_samples, m]).reshape( + n_samples, -1, self.n_latent) enum_label = ( torch.arange(0, self.n_labels).repeat(m).view((-1, 1)) ) # m * n_labels, 1 batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, tensors[REGISTRY_KEYS.BATCH_KEY]) - batch_rep_input = batch_rep.repeat_interleave(self.n_labels, dim=0) + batch_rep_input = batch_rep.repeat_interleave(self.n_labels, dim=0).repeat(n_samples, 1, 1) decoder_input = torch.cat([prior_sampled, batch_rep_input], dim=-1) - px_scale_augment_ = torch.nn.Softmax(dim=-1)(self.px_decoder(self.decoder(decoder_input, enum_label.to(x.device))) + beta.view(1, 1, -1)) + px_scale_augment_ = torch.nn.Softmax(dim=-1)( + self.px_decoder(self.decoder(decoder_input, enum_label.to(x.device))) + self.per_ct_bias[enum_label.ravel()].unsqueeze(-3) + self.beta.view(1, 1, -1)) px_scale_augment = px_scale_augment_.reshape( - (n_samples*self.n_states_per_augmented_label, x.shape[0], self.n_labels, -1) - ) # (samples*states_per_cell, mi, n_labels, n_genes) - library = x.sum(-1).view(1, 1, m, 1, 1).repeat(n_samples, 1, 1, 1, 1) + (n_samples, x.shape[0], self.n_labels, -1) + ) # (samples, mi, n_labels, n_genes) + library = x.sum(-1).view(1, m, 1, 1).repeat(n_samples, 1, 1, 1) library[1, ...] = library[1, ...] 
+ 50 - px_scale_augment = px_scale_augment.reshape(n_samples, self.n_states_per_augmented_label, m, self.n_labels, -1) # (samples, states_per_cell, m, n_labels, n_genes) - px_rate = library * px_scale_augment # (samples, states_per_cell, m, n_labels, n_genes) + px_scale_augment = px_scale_augment.reshape(n_samples, m, self.n_labels, -1) # (samples, m, n_labels, n_genes) + px_rate = library * px_scale_augment # (samples, m, n_labels, n_genes) ratios_ct_augmentation = torch.distributions.Dirichlet( - torch.zeros(self.n_states_per_augmented_label * self.n_labels) + 0.03).sample([n_samples, m]).to(x.device) - ratios_ct_augmentation = ratios_ct_augmentation.reshape(n_samples, m, self.n_states_per_augmented_label, self.n_labels).permute(0, 2, 1, 3) - augmentation_rate = torch.einsum('ilmk, ilmkg -> img', ratios_ct_augmentation, px_rate) # (samples, m, n_genes) + torch.zeros(self.n_labels) + 0.03).sample([n_samples, m]).to(x.device) + ratios_ct_augmentation = ratios_ct_augmentation.reshape(n_samples, m, self.n_labels) + # sum over celltypes + augmentation_rate = torch.einsum('imc, imcg -> img', ratios_ct_augmentation, px_rate) # (samples, m, n_genes) ratio_augmentation_ = torch.distributions.Beta(0.4, 0.5).sample([self.n_samples_augmentation-1, m]).unsqueeze(-1).to(x.device) - ratio_augmentation = torch.cat([torch.zeros((1, m, 1), device=x.device), torch.ones((1, m, 1), device=x.device), ratio_augmentation_], dim=0) + ratio_augmentation = torch.cat([torch.zeros((n_samples_observed, m, 1), device=x.device), torch.ones((1, m, 1), device=x.device), ratio_augmentation_], dim=0) augmented_counts = NegativeBinomial( - augmentation_rate, logits=self.px_o - ).sample() # (samples*states_per_cell, m, n_labels, n_genes) - # print('TTTT1', augmentation_rate[1, ...].sum(-1).min(), augmented_counts[1, ...].sum(-1).min(), x.sum(-1).min(), x.shape) + mu=augmentation_rate, theta=px_r + ).sample() # (samples, m, n_genes) x_augmented = ( (1 - ratio_augmentation) * x.unsqueeze(0) + ratio_augmentation * augmented_counts ) + if x_smoothed is not None: + x_augmented[1, ...] 
= x_smoothed else: x_augmented = x.unsqueeze(0) + if x_smoothed is not None: + x_augmented = torch.cat([x_augmented, x_smoothed.unsqueeze(0)], dim=0) prior_sampled = None ratios_ct_augmentation = None ratio_augmentation = None @@ -327,12 +342,9 @@ def generative(self, z, ind_x, library, batch_index): """Build the deconvolution model for every cell in the minibatch.""" m = len(ind_x) # setup all non-linearities - beta = torch.exp(self.beta) # n_genes - eps = torch.nn.functional.softmax(self.eta, dim=-1) # n_genes - if self.training and self.augmentation: - n_samples = (self.n_samples_augmentation + 1) - else: - n_samples = 1 + eps = torch.nn.functional.softmax(self.eta, dim=-1) # add_celltypes, n_genes + px_r = torch.exp(self.px_r) + n_samples = z.size(0) if self.amortization in ["both", "latent"]: if self.prior_mode == "mog": @@ -385,26 +397,22 @@ def generative(self, z, ind_x, library, batch_index): h = self.decoder(decoder_input_, enum_label.to(z.device)) px_est += proportion_modes[:, mode, ...].unsqueeze(-1) * torch.nn.Softmax(dim=-1)(self.px_decoder(h).reshape( (n_samples, m, self.n_labels, -1) - ) + beta.view(1, 1, 1, -1)) # (n_samples, m, n_labels, n_genes) + ) + self.per_ct_bias[enum_label.ravel()].reshape(1, m, self.n_labels, -1) + self.beta.view(1, 1, 1, -1)) # (n_samples, m, n_labels, n_genes) - # add the dummy cell type - eps = eps.unsqueeze(0).repeat(n_samples, m, 1).unsqueeze(-2) # (n_samples, m, 1, n_genes) <- this is the dummy cell type - - # account for gene specific bias and add noise, take sample without augmentation. - r_hat = torch.cat( - [beta.view(1, 1, 1, -1) * px_est, eps], dim=-2 - ) # n_samples, m, n_labels + 1, n_genes + # add the additional cell types + eps = eps.unsqueeze(0).repeat(n_samples, m, 1, 1) # (n_samples, m, add_celltypes, n_genes) <- this is the dummy cell type + r_hat = torch.cat([px_est, eps], dim=-2) # n_samples, m, n_labels + add_celltypes, n_genes - # now combine them for convolution - px_scale = torch.sum(v_ind.unsqueeze(-1) * r_hat, dim=-2) # n_samples, m, n_genes + # now combine them for convolution. Add epsilon during training. + eps_v = self.eps_v if self.training else 0. 
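The augmentation branch a little further up mixes each observed spot with counts simulated from the single-cell decoder: a Dirichlet draw picks a synthetic cell-type composition, a Beta draw sets how much of the spot is replaced, and a negative binomial generates the synthetic counts. A toy sketch of that mixing with made-up shapes and rates; only the distributions and the convex combination mirror the hunk above.

import torch
from scvi.distributions import NegativeBinomial

m, n_labels, n_genes = 4, 3, 10

x = torch.randint(0, 20, (m, n_genes)).float()      # observed spot counts
px_rate = torch.rand(m, n_labels, n_genes) * 5      # per-cell-type expected counts per spot
px_r = torch.full((n_genes,), 2.0)                  # inverse dispersion

# synthetic composition per spot and how strongly to perturb each spot
ratios_ct = torch.distributions.Dirichlet(torch.full((n_labels,), 0.03)).sample((m,))  # (m, n_labels)
rho = torch.distributions.Beta(0.4, 0.5).sample((m, 1))                                # (m, 1)

augmentation_rate = torch.einsum("mc,mcg->mg", ratios_ct, px_rate)   # (m, n_genes)
augmented_counts = NegativeBinomial(mu=augmentation_rate, theta=px_r).sample()

x_augmented = (1 - rho) * x + rho * augmented_counts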
+ px_scale = torch.sum((v_ind.unsqueeze(-1) + eps_v) * r_hat, dim=-2) # n_samples, m, n_genes px_rate = library * px_scale - px_mu = torch.exp(self.px_o) * r_hat return { - "px_o": self.px_o, + "px_r": px_r, "px_rate": px_rate, - "px_mu": px_mu, "px_scale": px_scale, + "px_mu": r_hat, "gamma": gamma_ind, "v": v_ind, "proportion_modes": proportion_modes, @@ -427,66 +435,71 @@ def loss( kl_weight: float = 1.0, ct_sparsity_weight: float = 0., weighting_augmentation: float = 10., + weighting_smoothing: float = 0.1, + eta_reg: float = 1., + beta_reg: float = 1e-3, + reconst_weight: float = 1.0, ): """Compute the loss.""" x_augmented = inference_outputs["x_augmented"] px_rate = generative_outputs["px_rate"] - px_o = generative_outputs["px_o"] + px_r = generative_outputs["px_r"] gamma = generative_outputs["gamma"] v = generative_outputs["v"] ratio_augmentation = inference_outputs['ratio_augmentation'] ratios_ct_augmentation = inference_outputs['ratios_ct_augmentation'] - n_samples = self.n_samples_augmentation + 1 + if "x_smoothed" in tensors: + sample_fully_augmented = 2 + n_samples = self.n_samples_augmentation + 2 + else: + sample_fully_augmented = 1 + n_samples = self.n_samples_augmentation + 1 m = x_augmented.shape[1] if self.augmentation: - prior_sampled = inference_outputs['prior_sampled'].reshape(n_samples, self.n_states_per_augmented_label, x_augmented.shape[1], self.n_labels, self.n_latent) + prior_sampled = inference_outputs['prior_sampled'].reshape(n_samples, 1, x_augmented.shape[1], self.n_labels, self.n_latent) mean_vprior = torch.cat( [self.mean_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 1, 1), prior_sampled.permute(0, 2, 3, 1, 4)], - dim=-2) # n_samples, m, n_labels, p+n_states_per_augmented_label, n_latent + dim=-2) # n_samples, m, n_labels, p + 1, n_latent var_vprior = torch.cat( [self.var_vprior.unsqueeze(0).unsqueeze(-4).repeat(n_samples, m, 1, 1, 1), torch.min(self.var_vprior, dim=-2).values.view( - 1, 1, self.n_labels, 1, self.n_latent).repeat(n_samples, m, 1, self.n_states_per_augmented_label, 1)], + 1, 1, self.n_labels, 1, self.n_latent).repeat(n_samples, m, 1, 1, 1)], dim=-2 - ) # n_samples, m, n_labels, p+n_states_per_augmented_label, n_latent + ) # n_samples, m, n_labels, p + 1, n_latent + mp_vprior=torch.cat( [(1- ratio_augmentation.unsqueeze(-1)) * self.mp_vprior.view(1, 1, self.n_labels, -1).repeat(n_samples, m, 1, 1), - ratio_augmentation.unsqueeze(-1) * ratios_ct_augmentation.permute(0, 2, 3, 1) + ratio_augmentation.unsqueeze(-1) * ratios_ct_augmentation.unsqueeze(-1) ], dim=-1 - ) # n_samples, m, n_labels, p+n_states_per_augmented_label + ) # n_samples, m, n_labels, p + 1 else: mean_vprior = self.mean_vprior.unsqueeze(0).unsqueeze(0) var_vprior = self.var_vprior.unsqueeze(0).unsqueeze(0) mp_vprior = self.mp_vprior.unsqueeze(0).unsqueeze(0) proportion_modes = generative_outputs["proportion_modes"] - reconst_loss = - NegativeBinomial(px_rate, logits=px_o).log_prob(x_augmented).sum(-1) - # eta prior likelihood - mean = torch.zeros_like(self.eta) - scale = torch.ones_like(self.eta) - glo_neg_log_likelihood_prior = ( - -self.eta_reg * Normal(mean, scale).log_prob(self.eta).sum() - ) + #rounding errors for softmax lead to px_rate 0 which induces NaNs in the log_prob + reconst_loss = - NegativeBinomial(mu=px_rate , theta=px_r).log_prob(x_augmented).sum(-1).mean(0) + # beta prior likelihood + mean = torch.zeros_like(self.beta) + scale = torch.ones_like(self.beta) # beta loss - glo_neg_log_likelihood_prior += ( - -self.beta_reg * Normal(mean, 
scale).log_prob(self.beta).sum() + glo_neg_log_likelihood_prior = ( + -beta_reg * Normal(mean, scale).log_prob(self.beta).sum() ) + loss_augmentation = torch.tensor(0., device=x_augmented.device) if self.augmentation: - expected_proportions = ( - ratio_augmentation * torch.cat([ratios_ct_augmentation.sum(-3), torch.zeros([n_samples, m, 1]).to(v.device)], dim=-1) + - (1 - ratio_augmentation) * v[0, :, :] # unperturbed proportions - ) - #loss_augmentation = weighting_augmentation * torch.abs(v - expected_proportions).sum(-1) - loss_augmentation = 0 - for i in [1]: - loss_augmentation += weighting_augmentation * self._compute_cross_entropy(expected_proportions[i, ...].squeeze(0), v[i, ...].squeeze(0)) - else: - loss_augmentation = torch.tensor(0., device=x_augmented.device) - + expected_proportions = torch.cat([ratios_ct_augmentation, torch.zeros([n_samples, m, self.add_celltypes]).to(v.device)], dim=-1) + loss_augmentation += weighting_augmentation * self._compute_cross_entropy( + expected_proportions[sample_fully_augmented, ...].squeeze(0), v[sample_fully_augmented, ...].squeeze(0)) + if "x_smoothed" in tensors: + loss_augmentation += weighting_smoothing * self._compute_cross_entropy( + v[1, ...].squeeze(0), v[0, ...].squeeze(0)) + # gamma prior likelihood if self.mean_vprior is None: # isotropic normal prior @@ -508,9 +521,13 @@ def loss( pre_lse = pre_lse.permute(1, 0, 2, 3) log_likelihood_prior = torch.mul( pre_lse, - proportion_modes + 1e-3 + proportion_modes + 1e-6 ).sum(-3) # n_samples, minibatch, n_labels - neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) + neg_log_likelihood_prior = - torch.mul( + log_likelihood_prior, + v[:, :, :-self.add_celltypes] + 1e-3 + ).sum(-1) # n_samples, minibatch, n_labels + # neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) else: gamma = gamma.permute(0, 4, 1, 3, 2) # n_samples, minibatch_size, 1, n_labels, n_latent mean_vprior = torch.transpose(mean_vprior, -3, -2) # n_samples, m, p, n_labels, n_latent @@ -522,15 +539,18 @@ def loss( log_likelihood_prior = torch.logsumexp(pre_lse, -2) # n_samples, minibatch, n_labels neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) # n_samples, minibatch if self.n_latent_amortization is not None: - neg_log_likelihood_prior += kl( + neg_log_likelihood_prior += 1e-3 * kl( inference_outputs["qz"], - Normal(torch.zeros([self.n_latent_amortization], device=x.device), torch.ones([self.n_latent_amortization], device=x.device)) + Normal(torch.zeros([self.n_latent_amortization], device=x_augmented.device), torch.ones([self.n_latent_amortization], device=x_augmented.device)) ).sum(dim=-1) - v_sparsity_loss = ct_sparsity_weight * torch.distributions.Categorical(probs=v[0, :, :]).entropy() + v_sparsity_loss = ct_sparsity_weight * torch.distributions.Categorical( + probs=v[0, ...]).entropy() + v_sparsity_loss -= eta_reg * Exponential( + torch.ones_like(v[0, :, -self.add_celltypes:])).log_prob(v[0, :, -self.add_celltypes:]).sum() loss = torch.mean( - reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) + glo_neg_log_likelihood_prior + loss_augmentation + reconst_weight * reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss) + glo_neg_log_likelihood_prior + loss_augmentation ) return LossOutput( @@ -538,6 +558,7 @@ def loss( reconstruction_loss=reconst_loss, kl_local=neg_log_likelihood_prior, kl_global=glo_neg_log_likelihood_prior, + extra_metrics={"v_sparsity": v_sparsity_loss.mean(), "augmentation": loss_augmentation.mean()}, ) @torch.inference_mode() diff --git 
a/src/scvi/module/_vaec.py b/src/scvi/module/_vaec.py index 8f2053075b..131571657f 100644 --- a/src/scvi/module/_vaec.py +++ b/src/scvi/module/_vaec.py @@ -1,15 +1,24 @@ -from __future__ import annotations +from collections.abc import Iterable +from typing import Optional import numpy as np import torch -from torch.distributions import Distribution +from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal +from torch.distributions import kl_divergence as kl +from torch.nn import functional as F from scvi import REGISTRY_KEYS -from scvi.module._constants import MODULE_KEYS -from scvi.module.base import BaseModuleClass, auto_move_data +from ._classifier import Classifier +from scvi.data._constants import ADATA_MINIFY_TYPE +from scvi.distributions import NegativeBinomial +from scvi.module.base import EmbeddingModuleMixin, BaseMinifiedModeModuleClass, LossOutput, auto_move_data +from scvi.nn import Encoder, FCLayers, DecoderSCVI +torch.backends.cudnn.benchmark = True -class VAEC(BaseModuleClass): + +# Conditional VAE model +class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Conditional Variational auto-encoder model. This is an implementation of the CondSCVI model @@ -19,7 +28,7 @@ class VAEC(BaseModuleClass): n_input Number of input genes n_batch - Number of batches. If ``0``, no batch correction is performed. + Number of batches n_labels Number of labels n_hidden @@ -34,9 +43,6 @@ class VAEC(BaseModuleClass): Multiplicative weight for cell type specific latent space. dropout_rate Dropout rate for the encoder and decoder neural network. - encode_covariates - If ``True``, covariates are concatenated to gene expression prior to passing through - the encoder(s). Else, only gene expression is used. extra_encoder_kwargs Keyword arguments passed into :class:`~scvi.nn.Encoder`. 
extra_decoder_kwargs @@ -48,39 +54,51 @@ def __init__( n_input: int, n_batch: int = 0, n_labels: int = 0, + n_fine_labels: Optional[int] = None, n_hidden: int = 128, n_latent: int = 5, n_layers: int = 2, log_variational: bool = True, - ct_weight: np.ndarray | None = None, + ct_weight: np.ndarray = None, dropout_rate: float = 0.05, encode_covariates: bool = False, - extra_encoder_kwargs: dict | None = None, - extra_decoder_kwargs: dict | None = None, + extra_encoder_kwargs: Optional[dict] = None, + extra_decoder_kwargs: Optional[dict] = None, + linear_classifier: bool = True, + prior: str = 'normal', + num_classes_mog: Optional[int] = 10, ): - from scvi.nn import Encoder, FCLayers - super().__init__() self.dispersion = "gene" self.n_latent = n_latent self.n_layers = n_layers self.n_hidden = n_hidden self.dropout_rate = dropout_rate + self.encode_covariates = encode_covariates self.log_variational = log_variational self.gene_likelihood = "nb" self.latent_distribution = "normal" - self.encode_covariates = encode_covariates + # Automatically deactivate if useless self.n_batch = n_batch self.n_labels = n_labels - - if self.encode_covariates and self.n_batch < 1: - raise ValueError("`n_batch` must be greater than 0 if `encode_covariates` is `True`.") - + self.n_fine_labels = n_fine_labels + self.prior = prior + self.num_classes_mog = num_classes_mog + self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **{}) + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim + + cat_list = [n_labels] + n_input_encoder = n_input + batch_dim * encode_covariates + + # gene dispersion self.px_r = torch.nn.Parameter(torch.randn(n_input)) + + # z encoder goes from the n_input-dimensional data to an n_latent-d + _extra_encoder_kwargs = {} self.z_encoder = Encoder( - n_input, + n_input_encoder, n_latent, - n_cat_list=[n_labels] + ([n_batch] if n_batch > 0 and encode_covariates else []), + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, @@ -88,63 +106,109 @@ def __init__( use_batch_norm=False, use_layer_norm=True, return_dist=True, - **(extra_encoder_kwargs or {}), + **_extra_encoder_kwargs, ) + if n_fine_labels is not None: + cls_parameters = { + "n_layers": 0, + "n_hidden": 0, + "dropout_rate": dropout_rate, + "logits": True, + } + # linear mapping from latent space to a coarse-celltype aware space + self.linear_mapping = FCLayers( + n_in=n_latent, + n_out=n_hidden, + n_cat_list=[n_labels], + use_layer_norm=True, + dropout_rate=0.0, + ) + + self.classifier = Classifier( + n_hidden, + n_labels=n_fine_labels, + use_batch_norm=False, + use_layer_norm=True, + **cls_parameters, + ) + else: + self.classifier = None + # decoder goes from n_latent-dimensional space to n_input-d data + _extra_decoder_kwargs = {} + n_input_decoder = n_latent + batch_dim self.decoder = FCLayers( - n_in=n_latent, + n_in=n_input_decoder, n_out=n_hidden, - n_cat_list=[n_labels] + ([n_batch] if n_batch > 0 else []), + n_cat_list=cat_list, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=dropout_rate, inject_covariates=True, use_batch_norm=False, use_layer_norm=True, - **(extra_decoder_kwargs or {}), - ) - self.px_decoder = torch.nn.Sequential( - torch.nn.Linear(n_hidden, n_input), torch.nn.Softplus() + **_extra_decoder_kwargs, ) + self.px_decoder = torch.nn.Linear(n_hidden, n_input) + self.per_ct_bias = torch.nn.Parameter(torch.zeros(n_labels, n_input)) - self.register_buffer( - "ct_weight", - ( - torch.ones((self.n_labels,), dtype=torch.float32) - if ct_weight is None - else 
torch.tensor(ct_weight, dtype=torch.float32) - ), - ) - - def _get_inference_input( - self, tensors: dict[str, torch.Tensor] - ) -> dict[str, torch.Tensor | None]: - return { - MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY], - MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY], - MODULE_KEYS.BATCH_INDEX_KEY: tensors.get(REGISTRY_KEYS.BATCH_KEY, None), - } + if ct_weight is not None: + ct_weight = torch.tensor(ct_weight, dtype=torch.float32) + else: + ct_weight = torch.ones((self.n_labels,), dtype=torch.float32) + self.register_buffer("ct_weight", ct_weight) + if self.prior=='mog': + self.prior_means = torch.nn.Parameter( + torch.randn([n_labels, self.num_classes_mog, n_latent])) + self.prior_log_std = torch.nn.Parameter( + torch.zeros([n_labels, self.num_classes_mog, n_latent]) - 2.) + self.prior_logits = torch.nn.Parameter(torch.zeros([n_labels, self.num_classes_mog])) + + def _get_inference_input(self, tensors): + y = tensors[REGISTRY_KEYS.LABELS_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] + if self.minified_data_type is None: + x = tensors[REGISTRY_KEYS.X_KEY] + input_dict = { + "x": x, + "y": y, + "batch_index": batch_index, + } + else: + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qzm = tensors[REGISTRY_KEYS.LATENT_QZM_KEY] + qzv = tensors[REGISTRY_KEYS.LATENT_QZV_KEY] + observed_lib_size = tensors[REGISTRY_KEYS.OBSERVED_LIB_SIZE] + input_dict = { + "qzm": qzm, + "qzv": qzv, + "observed_lib_size": observed_lib_size, + "y": y, + "batch_index": batch_index, + } + else: + raise NotImplementedError( + f"Unknown minified-data type: {self.minified_data_type}" + ) + + return input_dict + + def _get_generative_input(self, tensors, inference_outputs): + z = inference_outputs["z"] + library = inference_outputs["library"] + y = tensors[REGISTRY_KEYS.LABELS_KEY] + batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - def _get_generative_input( - self, - tensors: dict[str, torch.Tensor], - inference_outputs: dict[str, torch.Tensor | Distribution], - ) -> dict[str, torch.Tensor]: - return { - MODULE_KEYS.Z_KEY: inference_outputs[MODULE_KEYS.Z_KEY], - MODULE_KEYS.LIBRARY_KEY: inference_outputs[MODULE_KEYS.LIBRARY_KEY], - MODULE_KEYS.Y_KEY: tensors[REGISTRY_KEYS.LABELS_KEY], - MODULE_KEYS.BATCH_INDEX_KEY: tensors.get(REGISTRY_KEYS.BATCH_KEY, None), + input_dict = { + "z": z, + "library": library, + "y": y, + "batch_index": batch_index, } + return input_dict @auto_move_data - def inference( - self, - x: torch.Tensor, - y: torch.Tensor, - batch_index: torch.Tensor | None = None, - n_samples: int = 1, - ) -> dict[str, torch.Tensor | Distribution]: + def _regular_inference(self, x, y, batch_index, n_samples=1): """High level inference method. Runs the inference (encoder) model. 
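For readers unfamiliar with the `prior == "mog"` branch, a standalone sketch (plain torch, illustrative shapes) of how the three parameter tensors defined just above describe a per-label mixture-of-Gaussians prior, and how the KL term is then estimated by Monte Carlo in `loss`:

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

n_labels, num_classes_mog, n_latent, n_obs = 5, 10, 8, 32
prior_means = torch.randn(n_labels, num_classes_mog, n_latent)
prior_log_std = torch.zeros(n_labels, num_classes_mog, n_latent) - 2.0
prior_logits = torch.zeros(n_labels, num_classes_mog)

y = torch.randint(0, n_labels, (n_obs,))          # coarse cell-type label per cell
cats = Categorical(logits=prior_logits[y])        # mixture weights for each cell's label
comps = Independent(Normal(prior_means[y], prior_log_std[y].exp() + 1e-4), 1)
prior = MixtureSameFamily(cats, comps)            # batch_shape (n_obs,), event_shape (n_latent,)

# Monte-Carlo KL(q(z|x) || prior): sample from q and average the log-density ratio.
qz = Normal(torch.randn(n_obs, n_latent), torch.ones(n_obs, n_latent))
u = qz.rsample((30,))                                           # (30, n_obs, n_latent)
kl_mc = (qz.log_prob(u).sum(-1) - prior.log_prob(u)).mean(0)    # (n_obs,)
```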
@@ -152,80 +216,144 @@ def inference( x_ = x library = x.sum(1).unsqueeze(1) if self.log_variational: - x_ = torch.log1p(x_) - - encoder_input = [x, y] - if batch_index is not None and self.encode_covariates: - encoder_input.append(batch_index) - - qz, z = self.z_encoder(*encoder_input) + x_ = torch.log(1 + x_) + if self.encode_covariates: + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + encoder_input = torch.cat([x_, batch_rep], dim=-1) + else: + encoder_input = x_ + qz, z = self.z_encoder(encoder_input, y) if n_samples > 1: untran_z = qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) - library = library.unsqueeze(0).expand((n_samples, library.size(0), library.size(1))) + library = library.unsqueeze(0).expand( + (n_samples, library.size(0), library.size(1)) + ) - return { - MODULE_KEYS.Z_KEY: z, - MODULE_KEYS.QZ_KEY: qz, - MODULE_KEYS.LIBRARY_KEY: library, - } + outputs = {"z": z, "qz": qz, "library": library} + return outputs @auto_move_data - def generative( + def _cached_inference(self, qzm, qzv, observed_lib_size, n_samples=1): + if ADATA_MINIFY_TYPE.__contains__(self.minified_data_type): + qz = Normal(qzm, qzv.sqrt()) + # use dist.sample() rather than rsample because we aren't optimizing the z here + untran_z = qz.sample() if n_samples == 1 else qz.sample((n_samples,)) + z = self.z_encoder.z_transformation(untran_z) + library = observed_lib_size + if n_samples > 1: + library = library.unsqueeze(0).expand( + (n_samples, library.size(0), library.size(1)) + ) + else: + raise NotImplementedError( + f"Unknown minified-data type: {self.minified_data_type}" + ) + outputs = {"z": z, "qz": qz, "library": library} + return outputs + + @auto_move_data + def classify( self, z: torch.Tensor, - library: torch.Tensor, - y: torch.Tensor, - batch_index: torch.Tensor | None = None, - ) -> dict[str, Distribution]: - """Runs the generative model.""" - from scvi.distributions import NegativeBinomial + label_index: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Forward pass through the encoder and classifier. - decoder_input = [z, y] - if batch_index is not None: - decoder_input.append(batch_index) + Parameters + ---------- + z + Tensor of shape ``(n_obs, n_latent)``. + label_index + Tensor of shape ``(n_obs,)`` denoting label indices. - h = self.decoder(*decoder_input) - px_scale = self.px_decoder(h) + Returns + ------- + Tensor of shape ``(n_obs, n_labels)`` denoting logit scores per label. 
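A toy version of the fine-label classification head described above, with plain `nn.Linear` layers and one-hot label injection standing in for `scvi.nn.FCLayers`/`Classifier` (these substitutions and the `clamp` are mine, added to keep the snippet runnable): the latent code is mapped into a coarse-label-aware space, scored against the fine labels, and cells whose fine label equals `n_fine_labels` (the unlabeled category) are masked out of the cross-entropy.

```python
import torch
import torch.nn.functional as F

n_obs, n_latent, n_labels, n_fine_labels, n_hidden = 16, 5, 3, 6, 32
linear_mapping = torch.nn.Linear(n_latent + n_labels, n_hidden)
classifier = torch.nn.Linear(n_hidden, n_fine_labels)

z = torch.randn(n_obs, n_latent)
y = torch.randint(0, n_labels, (n_obs,))                 # coarse labels
fine_y = torch.randint(0, n_fine_labels + 1, (n_obs,))   # n_fine_labels encodes "unlabeled"

h = linear_mapping(torch.cat([z, F.one_hot(y, n_labels).float()], dim=-1))
logits = classifier(h)

# Clamp only to keep targets in range; the mask then drops the unlabeled cells entirely.
ce = F.cross_entropy(logits, fine_y.clamp(max=n_fine_labels - 1), reduction="none")
mask = fine_y != n_fine_labels
classification_loss = torch.masked_select(ce, mask).mean()
```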
+ """ + if len(label_index.shape)==1: + label_index = label_index.unsqueeze(1) + classifier_latent = self.linear_mapping(z, label_index) + w_y = self.classifier(classifier_latent) + return w_y + + @auto_move_data + def generative(self, z, library, y, batch_index): + """Runs the generative model.""" + batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) + decoder_input = torch.cat([z, batch_rep], dim=-1) + h = self.decoder(decoder_input, y, batch_index) + px_scale = torch.nn.Softmax(dim=-1)(self.px_decoder(h) + self.per_ct_bias[y.ravel()]) px_rate = library * px_scale - return {MODULE_KEYS.PX_KEY: NegativeBinomial(px_rate, logits=self.px_r)} + px_r = torch.exp(self.px_r) + px = NegativeBinomial(mu=px_rate, theta=px_r) + return {"px": px, "px_scale": px_scale} def loss( self, - tensors: dict[str, torch.Tensor], - inference_outputs: dict[str, torch.Tensor | Distribution], - generative_outputs: dict[str, Distribution], + tensors, + inference_outputs, + generative_outputs, kl_weight: float = 1.0, + classification_ratio = 5., ): """Loss computation.""" - from torch.distributions import Normal - from torch.distributions import kl_divergence as kl - - from scvi.module.base import LossOutput - x = tensors[REGISTRY_KEYS.X_KEY] - y = tensors[REGISTRY_KEYS.LABELS_KEY] - qz = inference_outputs[MODULE_KEYS.QZ_KEY] - px = generative_outputs[MODULE_KEYS.PX_KEY] - - mean = torch.zeros_like(qz.loc) - scale = torch.ones_like(qz.scale) - - kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) + y = tensors[REGISTRY_KEYS.LABELS_KEY].ravel().long() + qz = inference_outputs["qz"] + px = generative_outputs["px"] + fine_labels = tensors['fine_labels'].ravel().long() if 'fine_labels' in tensors.keys() else None + + if self.prior == "mog": + indexed_means = self.prior_means[y] + indexed_log_std = self.prior_log_std[y] + indexed_logits = self.prior_logits[y] + cats = Categorical(logits=indexed_logits) + normal_dists = Independent( + Normal( + indexed_means, + torch.exp(indexed_log_std) + 1e-4 + ), + reinterpreted_batch_ndims=1 + ) + prior = MixtureSameFamily(cats, normal_dists) + u = qz.rsample(sample_shape=(30,)) + # (sample, n_obs, n_latent) -> (sample, n_obs,) + kl_divergence_z = (qz.log_prob(u).sum(-1) - prior.log_prob(u)).mean(0) + else: + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) reconst_loss = -px.log_prob(x).sum(-1) - scaling_factor = self.ct_weight[y.long()[:, 0]] - loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) - - return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z) + scaling_factor = self.ct_weight[y] + + if self.classifier is not None: + fine_labels = fine_labels.view(-1) + logits = self.classify(qz.loc, label_index=tensors[REGISTRY_KEYS.LABELS_KEY]) # (n_obs, n_labels) + classification_loss_ = F.cross_entropy(logits, fine_labels, reduction="none") + mask = (fine_labels != self.n_fine_labels) + classification_loss = classification_ratio * torch.masked_select(classification_loss_, mask).mean(0) + + loss = torch.mean(scaling_factor * (reconst_loss + classification_loss + kl_weight * kl_divergence_z)) + + return LossOutput( + loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_divergence_z, classification_loss=classification_loss, + logits=logits, true_labels=fine_labels + ) + else: + loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) + return LossOutput( + loss=loss, 
reconstruction_loss=reconst_loss, kl_local=kl_divergence_z + ) @torch.inference_mode() def sample( self, - tensors: dict[str, torch.Tensor], - n_samples: int = 1, - ) -> torch.Tensor: + tensors, + n_samples=1, + ) -> np.ndarray: r"""Generate observation samples from the posterior predictive distribution. The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`. @@ -242,8 +370,6 @@ def sample( x_new : :py:class:`torch.Tensor` tensor with shape (n_cells, n_genes, n_samples) """ - from scvi.distributions import NegativeBinomial - inference_kwargs = {"n_samples": n_samples} generative_outputs = self.forward( tensors, @@ -256,7 +382,9 @@ def sample( dist = NegativeBinomial(px_rate, logits=px_r) if n_samples > 1: - exprs = dist.sample().permute([1, 2, 0]) # Shape : (n_cells_batch, n_genes, n_samples) + exprs = dist.sample().permute( + [1, 2, 0] + ) # Shape : (n_cells_batch, n_genes, n_samples) else: exprs = dist.sample() From b5c2380b55a638bc088d41c7765964efc683d3c9 Mon Sep 17 00:00:00 2001 From: cane11 Date: Wed, 10 Jul 2024 22:32:25 -0700 Subject: [PATCH 10/12] Changed hyperparameter --- src/scvi/module/_mrdeconv.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index 9d441b33c5..116918284d 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -89,9 +89,9 @@ def __init__( per_ct_bias: torch.tensor, dropout_decoder: float, dropout_amortization: float = 0.03, - augmentation: bool = False, - n_samples_augmentation: int = 1, - n_states_per_label: int = 1, + augmentation: bool = True, + n_samples_augmentation: int = 2, + n_states_per_label: int = 3, eps_v: float = 2e-3, mean_vprior: np.ndarray = None, var_vprior: np.ndarray = None, @@ -433,12 +433,13 @@ def loss( inference_outputs, generative_outputs, kl_weight: float = 1.0, - ct_sparsity_weight: float = 0., - weighting_augmentation: float = 10., - weighting_smoothing: float = 0.1, + ct_sparsity_weight: float = 2., + weighting_augmentation: float = 100., + weighting_smoothing: float = 100., eta_reg: float = 1., - beta_reg: float = 1e-3, - reconst_weight: float = 1.0, + beta_reg: float = 1., + weighting_kl_latent: float = 1e-3, + reconst_weight: float = 3.0, ): """Compute the loss.""" x_augmented = inference_outputs["x_augmented"] @@ -539,7 +540,7 @@ def loss( log_likelihood_prior = torch.logsumexp(pre_lse, -2) # n_samples, minibatch, n_labels neg_log_likelihood_prior = - log_likelihood_prior.sum(-1) # n_samples, minibatch if self.n_latent_amortization is not None: - neg_log_likelihood_prior += 1e-3 * kl( + neg_log_likelihood_prior += weighting_kl_latent * kl( inference_outputs["qz"], Normal(torch.zeros([self.n_latent_amortization], device=x_augmented.device), torch.ones([self.n_latent_amortization], device=x_augmented.device)) ).sum(dim=-1) From 2533f12265a9e79a1de41a5bb53c84872ca0b52e Mon Sep 17 00:00:00 2001 From: Can Ergen Date: Sun, 3 Nov 2024 00:40:10 -0700 Subject: [PATCH 11/12] Test predict function --- src/scvi/model/_condscvi.py | 47 +++++++++++++++--------------------- tests/model/test_condscvi.py | 13 ++++++++++ 2 files changed, 32 insertions(+), 28 deletions(-) diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py index afa43ff801..9b08784658 100644 --- a/src/scvi/model/_condscvi.py +++ b/src/scvi/model/_condscvi.py @@ -10,7 +10,7 @@ from scvi import REGISTRY_KEYS, settings from scvi.data import AnnDataManager -from scvi.data._utils import _get_adata_minify_type, 
get_anndata_attribute +from scvi.data._utils import _get_adata_minify_type from scvi.data.fields import CategoricalObsField, LabelsWithUnlabeledObsField, LayerField from scvi.model.base import ( BaseModelClass, @@ -81,7 +81,7 @@ def __init__( n_batch = self.summary_stats.n_batch n_labels = self.summary_stats.n_labels n_vars = self.summary_stats.n_vars - if 'n_fine_labels' in self.summary_stats: + if "n_fine_labels" in self.summary_stats: self.n_fine_labels = self.summary_stats.n_fine_labels else: self.n_fine_labels = None @@ -153,19 +153,16 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra if self.module.prior == "mog": results = { "mean_vprior": self.module.prior_means, - "var_vprior": torch.exp(self.module.prior_log_std)**2, - "weights_vprior": torch.nn.functional.softmax(self.module.prior_logits, dim=-1) + "var_vprior": torch.exp(self.module.prior_log_std) ** 2, + "weights_vprior": torch.nn.functional.softmax(self.module.prior_logits, dim=-1), } else: - # Extracting latent representation of adata including variances. mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent)) var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent)) mp_vprior = np.zeros((self.summary_stats.n_labels, p)) - labels_state_registry = self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) key = labels_state_registry.original_key mapping = labels_state_registry.categorical_mapping @@ -217,16 +214,16 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra results = { "mean_vprior": mean_vprior, "var_vprior": var_vprior, - "weights_vprior": mp_vprior + "weights_vprior": mp_vprior, } return results - + @torch.inference_mode() def predict( self, adata: AnnData | None = None, - indices: Sequence[int] | None = None, + indices: list[int] | None = None, soft: bool = False, batch_size: int | None = None, use_posterior_mean: bool = True, @@ -261,7 +258,7 @@ def predict( y_pred = [] for _, tensors in enumerate(scdl): inference_input = self.module._get_inference_input(tensors) - qz = self.module.inference(**inference_input)['qz'] + qz = self.module.inference(**inference_input)["qz"] if use_posterior_mean: z = qz.loc else: @@ -291,8 +288,8 @@ def predict( index=adata.obs_names[indices], ) return pred - - @torch.inference_mode() + + @torch.inference_mode() def confusion_coarse_celltypes( self, adata: AnnData | None = None, @@ -327,15 +324,15 @@ def confusion_coarse_celltypes( batch_size=batch_size, ) # Iterate once over the data and computes the reconstruction error - keys = list(self._label_mapping) + ['original'] + keys = list(self._label_mapping) + ["original"] log_lkl = {key: [] for key in keys} for tensors in scdl: loss_kwargs = {"kl_weight": 1} _, _, losses = self.module(tensors, loss_kwargs=loss_kwargs) - log_lkl['original'] += [losses.reconstruction_loss] + log_lkl["original"] += [losses.reconstruction_loss] for i in range(self.module.n_labels): tensors_ = tensors - tensors_['y'] = torch.full_like(tensors['y'], i) + tensors_["y"] = torch.full_like(tensors["y"], i) _, _, losses = self.module(tensors_, loss_kwargs=loss_kwargs) log_lkl[keys[i]] += [losses.reconstruction_loss] for key in keys: @@ -406,19 +403,15 @@ def train( plan_kwargs=plan_kwargs, **kwargs, ) - + def _set_indices_and_labels(self, adata: AnnData): """Set indices for labeled and unlabeled cells.""" - labels_state_registry = 
self.adata_manager.get_state_registry( - REGISTRY_KEYS.LABELS_KEY - ) + labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key self._label_mapping = labels_state_registry.categorical_mapping self._code_to_label = dict(enumerate(self._label_mapping)) if self.n_fine_labels is not None: - fine_labels_state_registry = self.adata_manager.get_state_registry( - 'fine_labels' - ) + fine_labels_state_registry = self.adata_manager.get_state_registry("fine_labels") self.original_fine_label_key = fine_labels_state_registry.original_key self._fine_label_mapping = fine_labels_state_registry.categorical_mapping self._code_to_fine_label = dict(enumerate(self._fine_label_mapping)) @@ -456,15 +449,13 @@ def setup_anndata( ] if fine_labels_key is not None: anndata_fields.append( - LabelsWithUnlabeledObsField('fine_labels', fine_labels_key, unlabeled_category), + LabelsWithUnlabeledObsField("fine_labels", fine_labels_key, unlabeled_category), ) # register new fields if the adata is minified adata_minify_type = _get_adata_minify_type(adata) if adata_minify_type is not None: anndata_fields += cls._get_fields_for_adata_minification(cls, adata_minify_type) - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/tests/model/test_condscvi.py b/tests/model/test_condscvi.py index a082fd07fc..9e5002595c 100644 --- a/tests/model/test_condscvi.py +++ b/tests/model/test_condscvi.py @@ -1,4 +1,5 @@ import os +import numpy as np import pytest @@ -26,6 +27,18 @@ def test_condscvi_batch_key( model.save(model_path, overwrite=True, save_anndata=False) model = CondSCVI.load(model_path, adata=adata) +def test_condscvi_fine_celltype( + save_path: str, +): + adata = synthetic_iid(n_batches=5, n_labels=5) + adata.obs['fine_labels'] = [i+str(np.random.randint(2)) for i in adata.obs['labels']] + CondSCVI.setup_anndata(adata, batch_key="batch", labels_key="labels", fine_labels_key="fine_labels") + model = CondSCVI(adata, encode_covariates=True) + + model.train(max_epochs=1) + model.predict() + model.predict(adata=adata) + model.predict(adata, soft=True, use_posterior_mean=False) def test_condscvi_batch_key_compat_load(save_path: str): adata = synthetic_iid() From 306938504942258a3fef237009a961d9be34db81 Mon Sep 17 00:00:00 2001 From: ori-kron-wis Date: Mon, 9 Jun 2025 13:55:29 +0300 Subject: [PATCH 12/12] updates --- CHANGELOG.md | 2 +- src/scvi/data/_manager.py | 5 +- src/scvi/dataloaders/_data_splitting.py | 2 +- src/scvi/dataloaders/_semi_dataloader.py | 2 +- src/scvi/external/resolvi/_model.py | 2 +- src/scvi/model/_condscvi.py | 433 +++++++++++------------ src/scvi/model/_destvi.py | 408 ++++++++++----------- src/scvi/model/base/_training_mixin.py | 2 +- src/scvi/module/_mrdeconv.py | 9 +- src/scvi/module/_vaec.py | 40 ++- tests/model/test_destvi.py | 2 +- 11 files changed, 457 insertions(+), 450 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 742ee8618c..111cc421d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,7 +24,7 @@ to [Semantic Versioning]. Full commit history is available in the #### Changed -- Update model {class}`scvi.model.DestVI` with fine cell-type classifier {pr}`33XX`. +- Update model {class}`scvi.model.DestVI` with fine cell-type classifier {pr}`3380`. 
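For context, a usage sketch of the workflow this changelog entry refers to, assembled from the test added earlier in this series; treat the exact API as provisional, since the final commit reworks parts of it (the `fine_labels_key` argument and the fine-label `predict` come from this patch).

```python
import numpy as np
from scvi.data import synthetic_iid
from scvi.model import CondSCVI

# Mirrors tests/model/test_condscvi.py::test_condscvi_fine_celltype from this series.
adata = synthetic_iid(n_batches=5, n_labels=5)
adata.obs["fine_labels"] = [lab + str(np.random.randint(2)) for lab in adata.obs["labels"]]

CondSCVI.setup_anndata(
    adata, batch_key="batch", labels_key="labels", fine_labels_key="fine_labels"
)
model = CondSCVI(adata, encode_covariates=True)
model.train(max_epochs=1)

fine_calls = model.predict(adata)              # hard fine-label assignments
fine_probs = model.predict(adata, soft=True)   # per-fine-label probabilities as a DataFrame
```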
 #### Removed

diff --git a/src/scvi/data/_manager.py b/src/scvi/data/_manager.py
index 1c7ddb8cdf..8bfa6e3cdf 100644
--- a/src/scvi/data/_manager.py
+++ b/src/scvi/data/_manager.py
@@ -216,10 +216,7 @@ def _add_field(
         # If empty, we skip registering the field.
         if not field.is_empty:
             # Transfer case: Source registry is used for validation and/or setup.
-            if (
-                source_registry is not None
-                and field.registry_key in source_registry[_constants._FIELD_REGISTRIES_KEY]
-            ):
+            if source_registry is not None:
                 field_registry[_constants._STATE_REGISTRY_KEY] = field.transfer_field(
                     source_registry[_constants._FIELD_REGISTRIES_KEY][field.registry_key][
                         _constants._STATE_REGISTRY_KEY
diff --git a/src/scvi/dataloaders/_data_splitting.py b/src/scvi/dataloaders/_data_splitting.py
index 72c6bf8ca4..2b86fc4067 100644
--- a/src/scvi/dataloaders/_data_splitting.py
+++ b/src/scvi/dataloaders/_data_splitting.py
@@ -413,7 +413,7 @@ def __init__(
             labels_state_registry.original_key,
             mod_key=getattr(self.adata_manager.data_registry.labels, "mod_key", None),
         ).ravel()
-        self.unlabeled_category = labels_state_registry.unlabeled_category
+        self.unlabeled_category = getattr(labels_state_registry, "unlabeled_category", None)
         self._unlabeled_indices = np.argwhere(labels == self.unlabeled_category).ravel()
         self._labeled_indices = np.argwhere(labels != self.unlabeled_category).ravel()
diff --git a/src/scvi/dataloaders/_semi_dataloader.py b/src/scvi/dataloaders/_semi_dataloader.py
index 384112d245..622d8a78ea 100644
--- a/src/scvi/dataloaders/_semi_dataloader.py
+++ b/src/scvi/dataloaders/_semi_dataloader.py
@@ -66,7 +66,7 @@ def __init__(
         # save a nested list of the indices per labeled category
         self.labeled_locs = []
         for label in np.unique(labels):
-            if label != labels_state_registry.unlabeled_category:
+            if label != getattr(labels_state_registry, "unlabeled_category", None):
                 label_loc_idx = np.where(labels[indices] == label)[0]
                 label_loc = self.indices[label_loc_idx]
                 self.labeled_locs.append(label_loc)
diff --git a/src/scvi/external/resolvi/_model.py b/src/scvi/external/resolvi/_model.py
index 3bb24abd5c..25730a13ac 100644
--- a/src/scvi/external/resolvi/_model.py
+++ b/src/scvi/external/resolvi/_model.py
@@ -657,7 +657,7 @@ def _set_indices_and_labels(self):
         """Set indices for labeled and unlabeled cells."""
         labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY)
         self.original_label_key = labels_state_registry.original_key
-        self.unlabeled_category_ = labels_state_registry.unlabeled_category
+        self.unlabeled_category_ = getattr(labels_state_registry, "unlabeled_category", None)
         labels = get_anndata_attribute(
             self.adata,
diff --git a/src/scvi/model/_condscvi.py b/src/scvi/model/_condscvi.py
index ef02b9c962..9cce7da4ab 100644
--- a/src/scvi/model/_condscvi.py
+++ b/src/scvi/model/_condscvi.py
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
 import numpy as np
-import pandas as pd
 import torch

 from scvi import REGISTRY_KEYS, settings
@@ -15,24 +14,23 @@ from scvi.model.base import (
     BaseModelClass,
     RNASeqMixin,
-    UnsupervisedTrainingMixin,
+    SemisupervisedTrainingMixin,
     VAEMixin,
 )
 from scvi.module import VAEC
 from scvi.utils import setup_anndata_dsp
-from scvi.utils._docstrings import devices_dsp

 if TYPE_CHECKING:
-    from collections.abc import Sequence
-
     from anndata import AnnData

 logger = logging.getLogger(__name__)


-class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
+class CondSCVI(RNASeqMixin, VAEMixin, SemisupervisedTrainingMixin, BaseModelClass):
"""Conditional version of single-cell Variational Inference. + Used for multi-resolution deconvolution of spatial transcriptomics data :cite:p:`Lopez22`. + Parameters ---------- adata @@ -64,13 +62,9 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass) See further usage examples in the following tutorial: 1. :doc:`/tutorials/notebooks/spatial/DestVI_tutorial` - """ _module_cls = VAEC - _LATENT_QZM = "_condscvi_latent_qzm" - _LATENT_QZV = "_condscvi_latent_qzv" - _OBSERVED_LIB_SIZE = "_condscvi_observed_lib_size" def __init__( self, @@ -84,14 +78,9 @@ def __init__( ): super().__init__(adata) - n_batch = self.summary_stats.n_batch - n_labels = self.summary_stats.n_labels - n_vars = self.summary_stats.n_vars - if "n_fine_labels" in self.summary_stats: - self.n_fine_labels = self.summary_stats.n_fine_labels - else: - self.n_fine_labels = None - self._set_indices_and_labels(adata) + self.n_labels = self.summary_stats.n_labels + self.n_vars = self.summary_stats.n_vars + self._set_indices_and_labels() if weight_obs: ct_counts = np.unique( self.get_from_registry(adata, REGISTRY_KEYS.LABELS_KEY), @@ -104,10 +93,10 @@ def __init__( module_kwargs.update({"ct_weight": ct_weight}) self.module = self._module_cls( - n_input=n_vars, - n_batch=n_batch, - n_labels=n_labels, - n_fine_labels=self.n_fine_labels, + n_input=self.n_vars, + n_batch=getattr(self.summary_stats, "n_batch", 0), + n_labels=self.n_labels, + n_fine_labels=getattr(self.summary_stats, "n_fine_labels", None), n_hidden=n_hidden, n_latent=n_latent, n_layers=n_layers, @@ -227,203 +216,206 @@ def get_vamp_prior(self, adata: AnnData | None = None, p: int = 10) -> np.ndarra return results - @torch.inference_mode() - def predict( - self, - adata: AnnData | None = None, - indices: list[int] | None = None, - soft: bool = False, - batch_size: int | None = None, - use_posterior_mean: bool = True, - ) -> np.ndarray | pd.DataFrame: - """Return cell label predictions. - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. - indices - Return probabilities for each class label. - soft - If True, returns per class probabilities - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - use_posterior_mean - If ``True``, uses the mean of the posterior distribution to predict celltype - labels. Otherwise, uses a sample from the posterior distribution - this - means that the predictions will be stochastic. 
- """ - adata = self._validate_anndata(adata) - - if indices is None: - indices = np.arange(adata.n_obs) - - scdl = self._make_data_loader( - adata=adata, - indices=indices, - batch_size=batch_size, - ) - y_pred = [] - for _, tensors in enumerate(scdl): - inference_input = self.module._get_inference_input(tensors) - qz = self.module.inference(**inference_input)["qz"] - if use_posterior_mean: - z = qz.loc - else: - z = qz.sample() - pred = self.module.classify( - z, - label_index=inference_input["y"], - ) - if self.module.classifier.logits: - pred = torch.nn.functional.softmax(pred, dim=-1) - if not soft: - pred = pred.argmax(dim=1) - y_pred.append(pred.detach().cpu()) - - y_pred = torch.cat(y_pred).numpy() - if not soft: - predictions = [] - for p in y_pred: - predictions.append(self._code_to_fine_label[p]) - - return np.array(predictions) - else: - n_labels = len(pred[0]) - pred = pd.DataFrame( - y_pred, - columns=self._fine_label_mapping[:n_labels], - index=adata.obs_names[indices], - ) - return pred - - @torch.inference_mode() - def confusion_coarse_celltypes( - self, - adata: AnnData | None = None, - indices: Sequence[int] | None = None, - batch_size: int | None = None, - ) -> np.ndarray | pd.DataFrame: - """Return likelihood ratios to assess if switch coarse cell-type resolution is too granular - - Parameters - ---------- - adata - AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. - indices - Return probabilities for each class label. - soft - If True, returns per class probabilities - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - use_posterior_mean - If ``True``, uses the mean of the posterior distribution to predict celltype - labels. Otherwise, uses a sample from the posterior distribution - this - means that the predictions will be stochastic. - - """ - adata = self._validate_anndata(adata) - - if indices is None: - indices = np.arange(adata.n_obs) - - scdl = self._make_data_loader( - adata=adata, - indices=indices, - batch_size=batch_size, - ) - # Iterate once over the data and computes the reconstruction error - keys = list(self._label_mapping) + ["original"] - log_lkl = {key: [] for key in keys} - for tensors in scdl: - loss_kwargs = {"kl_weight": 1} - _, _, losses = self.module(tensors, loss_kwargs=loss_kwargs) - log_lkl["original"] += [losses.reconstruction_loss] - for i in range(self.module.n_labels): - tensors_ = tensors - tensors_["y"] = torch.full_like(tensors["y"], i) - _, _, losses = self.module(tensors_, loss_kwargs=loss_kwargs) - log_lkl[keys[i]] += [losses.reconstruction_loss] - for key in keys: - log_lkl[key] = torch.stack(log_lkl[key]).detach().numpy() - - return log_lkl - - @devices_dsp.dedent - def train( - self, - max_epochs: int = 300, - lr: float = 0.001, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - train_size: float = 1, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - datasplitter_kwargs: dict | None = None, - plan_kwargs: dict | None = None, - **kwargs, - ): - """Trains the model using MAP inference. - - Parameters - ---------- - max_epochs - Number of epochs to train for - lr - Learning rate for optimization. - %(param_accelerator)s - %(param_devices)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. 
- shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **kwargs - Other keyword args for :class:`~scvi.train.Trainer`. - """ - update_dict = { - "lr": lr, - } - if plan_kwargs is not None: - plan_kwargs.update(update_dict) - else: - plan_kwargs = update_dict - super().train( - max_epochs=max_epochs, - accelerator=accelerator, - devices=devices, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - batch_size=batch_size, - datasplitter_kwargs=datasplitter_kwargs, - plan_kwargs=plan_kwargs, - **kwargs, - ) - - def _set_indices_and_labels(self, adata: AnnData): - """Set indices for labeled and unlabeled cells.""" - labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) - self.original_label_key = labels_state_registry.original_key - self._label_mapping = labels_state_registry.categorical_mapping - self._code_to_label = dict(enumerate(self._label_mapping)) - if self.n_fine_labels is not None: - fine_labels_state_registry = self.adata_manager.get_state_registry("fine_labels") - self.original_fine_label_key = fine_labels_state_registry.original_key - self._fine_label_mapping = fine_labels_state_registry.categorical_mapping - self._code_to_fine_label = dict(enumerate(self._fine_label_mapping)) + # @torch.inference_mode() + # def predict( + # self, + # adata: AnnData | None = None, + # indices: list[int] | None = None, + # soft: bool = False, + # batch_size: int | None = None, + # use_posterior_mean: bool = True, + # ) -> np.ndarray | pd.DataFrame: + # """Return cell label predictions. + # + # Parameters + # ---------- + # adata + # AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + # indices + # Return probabilities for each class label. + # soft + # If True, returns per class probabilities + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # use_posterior_mean + # If ``True``, uses the mean of the posterior distribution to predict celltype + # labels. Otherwise, uses a sample from the posterior distribution - this + # means that the predictions will be stochastic. 
+ # """ + # + # adata = self._validate_anndata(adata) + # + # if indices is None: + # indices = np.arange(adata.n_obs) + # + # scdl = self._make_data_loader( + # adata=adata, + # indices=indices, + # batch_size=batch_size, + # ) + # y_pred = [] + # for _, tensors in enumerate(scdl): + # inference_input = self.module._get_inference_input(tensors) + # qz = self.module.inference(**inference_input)["qz"] + # if use_posterior_mean: + # z = qz.loc + # else: + # z = qz.sample() + # pred = self.module.classify( + # z, + # label_index=inference_input["y"], + # ) + # if self.module.classifier.logits: + # pred = torch.nn.functional.softmax(pred, dim=-1) + # if not soft: + # pred = pred.argmax(dim=1) + # y_pred.append(pred.detach().cpu()) + # + # y_pred = torch.cat(y_pred).numpy() + # if not soft: + # predictions = [] + # for p in y_pred: + # predictions.append(self._code_to_fine_label[p]) + # + # return np.array(predictions) + # else: + # n_labels = len(pred[0]) + # pred = pd.DataFrame( + # y_pred, + # columns=self._fine_label_mapping[:n_labels], + # index=adata.obs_names[indices], + # ) + # return pred + + # @torch.inference_mode() + # def confusion_coarse_celltypes( + # self, + # adata: AnnData | None = None, + # indices: Sequence[int] | None = None, + # batch_size: int | None = None, + # ) -> np.ndarray | pd.DataFrame: + # """Return likelihood ratios to assess if switch coarse cell-type resolution + # + # is too granular + # + # Parameters + # ---------- + # adata + # AnnData object that has been registered via :meth:`~scvi.model.SCANVI.setup_anndata`. + # indices + # Return probabilities for each class label. + # soft + # If True, returns per class probabilities + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # use_posterior_mean + # If ``True``, uses the mean of the posterior distribution to predict celltype + # labels. Otherwise, uses a sample from the posterior distribution - this + # means that the predictions will be stochastic. + # + # """ + # adata = self._validate_anndata(adata) + # + # if indices is None: + # indices = np.arange(adata.n_obs) + # + # scdl = self._make_data_loader( + # adata=adata, + # indices=indices, + # batch_size=batch_size, + # ) + # # Iterate once over the data and computes the reconstruction error + # keys = list(self._label_mapping) + ["original"] + # log_lkl = {key: [] for key in keys} + # for tensors in scdl: + # loss_kwargs = {"kl_weight": 1} + # _, _, losses = self.module(tensors, loss_kwargs=loss_kwargs) + # log_lkl["original"] += [losses.reconstruction_loss] + # for i in range(self.module.n_labels): + # tensors_ = tensors + # tensors_["y"] = torch.full_like(tensors["y"], i) + # _, _, losses = self.module(tensors_, loss_kwargs=loss_kwargs) + # log_lkl[keys[i]] += [losses.reconstruction_loss] + # for key in keys: + # log_lkl[key] = torch.stack(log_lkl[key]).detach().numpy() + # + # return log_lkl + + # @devices_dsp.dedent + # def train( + # self, + # max_epochs: int = 300, + # lr: float = 0.001, + # accelerator: str = "auto", + # devices: int | list[int] | str = "auto", + # train_size: float = 1, + # validation_size: float | None = None, + # shuffle_set_split: bool = True, + # batch_size: int = 128, + # datasplitter_kwargs: dict | None = None, + # plan_kwargs: dict | None = None, + # **kwargs, + # ): + # """Trains the model using MAP inference. + # + # Parameters + # ---------- + # max_epochs + # Number of epochs to train for + # lr + # Learning rate for optimization. 
+ # %(param_accelerator)s + # %(param_devices)s + # train_size + # Size of training set in the range [0.0, 1.0]. + # validation_size + # Size of the test set. If `None`, defaults to 1 - `train_size`. If + # `train_size + validation_size < 1`, the remaining cells belong to a test set. + # shuffle_set_split + # Whether to shuffle indices before splitting. If `False`, the val, train, and test set + # are split in the sequential order of the data according to `validation_size` and + # `train_size` percentages. + # batch_size + # Minibatch size to use during training. + # datasplitter_kwargs + # Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. + # plan_kwargs + # Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to + # `train()` will overwrite values present in `plan_kwargs`, when appropriate. + # **kwargs + # Other keyword args for :class:`~scvi.train.Trainer`. + # """ + # update_dict = { + # "lr": lr, + # } + # if plan_kwargs is not None: + # plan_kwargs.update(update_dict) + # else: + # plan_kwargs = update_dict + # super().train( + # max_epochs=max_epochs, + # accelerator=accelerator, + # devices=devices, + # train_size=train_size, + # validation_size=validation_size, + # shuffle_set_split=shuffle_set_split, + # batch_size=batch_size, + # datasplitter_kwargs=datasplitter_kwargs, + # plan_kwargs=plan_kwargs, + # **kwargs, + # ) + + # def _set_indices_and_labels(self, adata: AnnData): + # """Set indices for labeled and unlabeled cells.""" + # labels_state_registry = self.adata_manager.get_state_registry(REGISTRY_KEYS.LABELS_KEY) + # self.original_label_key = labels_state_registry.original_key + # self._label_mapping = labels_state_registry.categorical_mapping + # self._code_to_label = dict(enumerate(self._label_mapping)) + # if self.n_fine_labels is not None: + # fine_labels_state_registry = self.adata_manager.get_state_registry("fine_labels") + # self.original_fine_label_key = fine_labels_state_registry.original_key + # self._fine_label_mapping = fine_labels_state_registry.categorical_mapping + # self._code_to_fine_label = dict(enumerate(self._fine_label_mapping)) @classmethod @setup_anndata_dsp.dedent @@ -446,9 +438,8 @@ def setup_anndata( %(param_labels_key)s fine_labels_key Key in `adata.obs` where fine-grained labels are stored. - %(unlabeled_category)ss %(param_layer)s - %(param_batch_key)s + %(param_unlabeled_category)s """ setup_method_args = cls._get_setup_method_args(**locals()) anndata_fields = [ diff --git a/src/scvi/model/_destvi.py b/src/scvi/model/_destvi.py index edb9595065..6b2af34138 100644 --- a/src/scvi/model/_destvi.py +++ b/src/scvi/model/_destvi.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd import torch -from scipy.sparse import csr_matrix from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager @@ -21,7 +20,6 @@ from scvi.model.base._archesmixin import _get_loaded_data from scvi.module import MRDeconv from scvi.utils import setup_anndata_dsp -from scvi.utils._docstrings import devices_dsp if TYPE_CHECKING: from collections.abc import Sequence @@ -144,10 +142,10 @@ def from_rna_model( A value of 50 leads to sparser results. 
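As background for the decoder hand-off performed just below, a small sketch (hypothetical module names) of extracting a trained submodule's weights by key prefix and loading them into a fresh module, which is essentially what `from_rna_model` does with the CondSCVI decoder state dict:

```python
from collections import OrderedDict

import torch

# Source model with a "decoder" submodule whose weights we want to reuse.
src = torch.nn.Sequential(OrderedDict(decoder=torch.nn.Linear(5, 10), other=torch.nn.Linear(3, 3)))

# Keep only keys under the "decoder." prefix and strip the prefix.
decoder_state_dict = OrderedDict(
    (k[len("decoder."):], v) for k, v in src.state_dict().items() if k.startswith("decoder.")
)

# Load the transferred weights into a freshly constructed module of the same shape.
dst_decoder = torch.nn.Linear(5, 10)
dst_decoder.load_state_dict(decoder_state_dict)
```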
anndata_setup_kwargs Keyword args for :meth:`~scvi.model.DestVI.setup_anndata` - **model_kwargs - Keyword args for :class:`~scvi.model.DestVI` + **module_kwargs + Keyword args for :class:`~scvi.model.MRDeconv` """ - attr_dict, var_names, load_state_dict = _get_loaded_data(sc_model) + attr_dict, var_names, load_state_dict, _ = _get_loaded_data(sc_model) registry = attr_dict.pop("registry_") decoder_state_dict = OrderedDict( @@ -270,80 +268,80 @@ def get_proportions( index=index_names, ) - @torch.inference_mode() - def get_fine_celltypes( - self, - sc_model: CondSCVI, - indices=None, - batch_size: int | None = None, - ) -> np.ndarray | dict[str, pd.DataFrame]: - """Returns the estimated cell-type specific latent space for the spatial data. - - Parameters - ---------- - sc_model - trained CondSCVI model - indices - Indices of cells in adata to use. Only used if amortization. - If `None`, all cells are used. - batch_size - Minibatch size for data loading into model. Only used if amortization. - Defaults to `scvi.settings.batch_size`. - """ - self._check_if_trained() - index_names = self.adata.obs.index - stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) - if sc_model.n_fine_labels is None: - raise RuntimeError( - "Single cell model does not contain fine labels. " - "Please train the single-cell model with fine labels." - ) - predicted_fine_celltype_ = [] - for tensors in stdl: - inference_inputs = self.module._get_inference_input(tensors) - outputs = self.module.inference(**inference_inputs) - generative_inputs = self.module._get_generative_input(tensors, outputs) - generative_outputs = self.module.generative(**generative_inputs) - - gamma_local = generative_outputs["gamma"][0, ...].transpose(-2, -4) # c, n, p, m - proportions_modes_local = generative_outputs["proportion_modes"][0, ...] # pmc - n_modes, batch_size, n_celltypes = proportions_modes_local.shape - gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape( - -1, self.module.n_latent - ) # m*p*c, n - proportions_modes_local_ = proportions_modes_local.permute( - (1, 0, 2) - ).flatten() # m*p*c - v_local = ( - generative_outputs["v"][0, ..., : -self.module.add_celltypes] - .flatten() - .repeat_interleave(n_modes) - ) # m*p*c - label = ( - torch.arange(self.module.n_labels, device=gamma_local.device) - .repeat(batch_size) - .repeat_interleave(n_modes) - .unsqueeze(-1) - ) # m*p*c, 1 - predicted_fine_celltype_local = ( - v_local.unsqueeze(-1) - * proportions_modes_local_.unsqueeze(-1) - * torch.nn.functional.softmax( - sc_model.module.classify(gamma_local_, label), dim=-1 - ) - ) - predicted_fine_celltype_sum = predicted_fine_celltype_local.reshape( - batch_size, n_celltypes * n_modes, sc_model.n_fine_labels - ).sum(1) - predicted_fine_celltype_.append(predicted_fine_celltype_sum.detach().cpu()) - predicted_fine_celltype = torch.cat(predicted_fine_celltype_, dim=0).numpy() - - pred = pd.DataFrame( - predicted_fine_celltype, - columns=sc_model._fine_label_mapping, - index=index_names, - ) - return pred + # @torch.inference_mode() + # def get_fine_celltypes( + # self, + # sc_model: CondSCVI, + # indices=None, + # batch_size: int | None = None, + # ) -> np.ndarray | dict[str, pd.DataFrame]: + # """Returns the estimated cell-type specific latent space for the spatial data. + # + # Parameters + # ---------- + # sc_model + # trained CondSCVI model + # indices + # Indices of cells in adata to use. Only used if amortization. + # If `None`, all cells are used. 
+ # batch_size + # Minibatch size for data loading into model. Only used if amortization. + # Defaults to `scvi.settings.batch_size`. + # """ + # self._check_if_trained() + # index_names = self.adata.obs.index + # stdl = self._make_data_loader(adata=self.adata, indices=indices, batch_size=batch_size) + # if sc_model.n_fine_labels is None: + # raise RuntimeError( + # "Single cell model does not contain fine labels. " + # "Please train the single-cell model with fine labels." + # ) + # predicted_fine_celltype_ = [] + # for tensors in stdl: + # inference_inputs = self.module._get_inference_input(tensors) + # outputs = self.module.inference(**inference_inputs) + # generative_inputs = self.module._get_generative_input(tensors, outputs) + # generative_outputs = self.module.generative(**generative_inputs) + # + # gamma_local = generative_outputs["gamma"][0, ...].transpose(-2, -4) # c, n, p, m + # proportions_modes_local = generative_outputs["proportion_modes"][0, ...] # pmc + # n_modes, batch_size, n_celltypes = proportions_modes_local.shape + # gamma_local_ = gamma_local.permute((3, 2, 0, 1)).reshape( + # -1, self.module.n_latent + # ) # m*p*c, n + # proportions_modes_local_ = proportions_modes_local.permute( + # (1, 0, 2) + # ).flatten() # m*p*c + # v_local = ( + # generative_outputs["v"][0, ..., : -self.module.add_celltypes] + # .flatten() + # .repeat_interleave(n_modes) + # ) # m*p*c + # label = ( + # torch.arange(self.module.n_labels, device=gamma_local.device) + # .repeat(batch_size) + # .repeat_interleave(n_modes) + # .unsqueeze(-1) + # ) # m*p*c, 1 + # predicted_fine_celltype_local = ( + # v_local.unsqueeze(-1) + # * proportions_modes_local_.unsqueeze(-1) + # * torch.nn.functional.softmax( + # sc_model.module.classify(gamma_local_, label), dim=-1 + # ) + # ) + # predicted_fine_celltype_sum = predicted_fine_celltype_local.reshape( + # batch_size, n_celltypes * n_modes, sc_model.n_fine_labels + # ).sum(1) + # predicted_fine_celltype_.append(predicted_fine_celltype_sum.detach().cpu()) + # predicted_fine_celltype = torch.cat(predicted_fine_celltype_, dim=0).numpy() + # + # pred = pd.DataFrame( + # predicted_fine_celltype, + # columns=sc_model._fine_label_mapping, + # index=index_names, + # ) + # return pred @torch.inference_mode() def get_gamma( @@ -516,131 +514,131 @@ def get_scale_for_ct( index_names = index_names[indices] return pd.DataFrame(data=data, columns=column_names, index=index_names) - @torch.inference_mode() - def get_expression_for_ct( - self, - label: str, - indices: Sequence[int] | None = None, - batch_size: int | None = None, - return_sparse_array: bool = False, - ) -> pd.DataFrame: - r"""Return the scaled parameter of the NB for every spot in queried cell types. - - Parameters - ---------- - label - cell type of interest - indices - Indices of cells in self.adata to use. If `None`, all cells are used. - batch_size - Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. - return_sparse_array - If `True`, returns a sparse array instead of a dataframe. 
- - Returns - ------- - Pandas dataframe of gene_expression - """ - self._check_if_trained() - - if label not in self.cell_type_mapping_extended: - raise ValueError("Unknown cell type") - y = self.cell_type_mapping_extended.index(label) - - stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) - expression_ct = [] - for tensors in stdl: - inference_inputs = self.module._get_inference_input(tensors) - outputs = self.module.inference(**inference_inputs) - generative_inputs = self.module._get_generative_input(tensors, outputs) - generative_outputs = self.module.generative(**generative_inputs) - px_scale, proportions = ( - generative_outputs["px_mu"][0, ...], - generative_outputs["v"][0, ...], - ) - px_scale_expected = torch.einsum("mkl,mk->mkl", px_scale, proportions) - px_scale_proportions = px_scale_expected[:, y, :] / px_scale_expected.sum(dim=1) - x_ct = tensors["X"].to(px_scale_proportions.device) * px_scale_proportions - expression_ct += [x_ct.cpu()] - - data = torch.cat(expression_ct).numpy() - if return_sparse_array: - data = csr_matrix(data.T) - return data - else: - column_names = self.adata.var.index - index_names = self.adata.obs.index - if indices is not None: - index_names = index_names[indices] - return pd.DataFrame(data=data, columns=column_names, index=index_names) - - @devices_dsp.dedent - def train( - self, - max_epochs: int = 2000, - lr: float = 0.003, - accelerator: str = "auto", - devices: int | list[int] | str = "auto", - train_size: float = 1.0, - validation_size: float | None = None, - shuffle_set_split: bool = True, - batch_size: int = 128, - n_epochs_kl_warmup: int = 200, - datasplitter_kwargs: dict | None = None, - plan_kwargs: dict | None = None, - **kwargs, - ): - """Trains the model using MAP inference. - - Parameters - ---------- - max_epochs - Number of epochs to train for - lr - Learning rate for optimization. - %(param_accelerator)s - %(param_devices)s - train_size - Size of training set in the range [0.0, 1.0]. - validation_size - Size of the test set. If `None`, defaults to 1 - `train_size`. If - `train_size + validation_size < 1`, the remaining cells belong to a test set. - shuffle_set_split - Whether to shuffle indices before splitting. If `False`, the val, train, and test set - are split in the sequential order of the data according to `validation_size` and - `train_size` percentages. - batch_size - Minibatch size to use during training. - n_epochs_kl_warmup - number of epochs needed to reach unit kl weight in the elbo - datasplitter_kwargs - Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. - plan_kwargs - Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to - `train()` will overwrite values present in `plan_kwargs`, when appropriate. - **kwargs - Other keyword args for :class:`~scvi.train.Trainer`. 
- """ - update_dict = { - "lr": lr, - "n_epochs_kl_warmup": n_epochs_kl_warmup, - } - if plan_kwargs is not None: - plan_kwargs.update(update_dict) - else: - plan_kwargs = update_dict - super().train( - max_epochs=max_epochs, - accelerator=accelerator, - devices=devices, - train_size=train_size, - validation_size=validation_size, - shuffle_set_split=shuffle_set_split, - batch_size=batch_size, - datasplitter_kwargs=datasplitter_kwargs, - plan_kwargs=plan_kwargs, - **kwargs, - ) + # @torch.inference_mode() + # def get_expression_for_ct( + # self, + # label: str, + # indices: Sequence[int] | None = None, + # batch_size: int | None = None, + # return_sparse_array: bool = False, + # ) -> pd.DataFrame: + # r"""Return the scaled parameter of the NB for every spot in queried cell types. + # + # Parameters + # ---------- + # label + # cell type of interest + # indices + # Indices of cells in self.adata to use. If `None`, all cells are used. + # batch_size + # Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`. + # return_sparse_array + # If `True`, returns a sparse array instead of a dataframe. + # + # Returns + # ------- + # Pandas dataframe of gene_expression + # """ + # self._check_if_trained() + # + # if label not in self.cell_type_mapping_extended: + # raise ValueError("Unknown cell type") + # y = self.cell_type_mapping_extended.index(label) + # + # stdl = self._make_data_loader(self.adata, indices=indices, batch_size=batch_size) + # expression_ct = [] + # for tensors in stdl: + # inference_inputs = self.module._get_inference_input(tensors) + # outputs = self.module.inference(**inference_inputs) + # generative_inputs = self.module._get_generative_input(tensors, outputs) + # generative_outputs = self.module.generative(**generative_inputs) + # px_scale, proportions = ( + # generative_outputs["px_mu"][0, ...], + # generative_outputs["v"][0, ...], + # ) + # px_scale_expected = torch.einsum("mkl,mk->mkl", px_scale, proportions) + # px_scale_proportions = px_scale_expected[:, y, :] / px_scale_expected.sum(dim=1) + # x_ct = tensors["X"].to(px_scale_proportions.device) * px_scale_proportions + # expression_ct += [x_ct.cpu()] + # + # data = torch.cat(expression_ct).numpy() + # if return_sparse_array: + # data = csr_matrix(data.T) + # return data + # else: + # column_names = self.adata.var.index + # index_names = self.adata.obs.index + # if indices is not None: + # index_names = index_names[indices] + # return pd.DataFrame(data=data, columns=column_names, index=index_names) + + # @devices_dsp.dedent + # def train( + # self, + # max_epochs: int = 2000, + # lr: float = 0.003, + # accelerator: str = "auto", + # devices: int | list[int] | str = "auto", + # train_size: float = 1.0, + # validation_size: float | None = None, + # shuffle_set_split: bool = True, + # batch_size: int = 128, + # n_epochs_kl_warmup: int = 200, + # datasplitter_kwargs: dict | None = None, + # plan_kwargs: dict | None = None, + # **kwargs, + # ): + # """Trains the model using MAP inference. + # + # Parameters + # ---------- + # max_epochs + # Number of epochs to train for + # lr + # Learning rate for optimization. + # %(param_accelerator)s + # %(param_devices)s + # train_size + # Size of training set in the range [0.0, 1.0]. + # validation_size + # Size of the test set. If `None`, defaults to 1 - `train_size`. If + # `train_size + validation_size < 1`, the remaining cells belong to a test set. + # shuffle_set_split + # Whether to shuffle indices before splitting. 
If `False`, the val, train, and test set + # are split in the sequential order of the data according to `validation_size` and + # `train_size` percentages. + # batch_size + # Minibatch size to use during training. + # n_epochs_kl_warmup + # number of epochs needed to reach unit kl weight in the elbo + # datasplitter_kwargs + # Additional keyword arguments passed into :class:`~scvi.dataloaders.DataSplitter`. + # plan_kwargs + # Keyword args for :class:`~scvi.train.TrainingPlan`. Keyword arguments passed to + # `train()` will overwrite values present in `plan_kwargs`, when appropriate. + # **kwargs + # Other keyword args for :class:`~scvi.train.Trainer`. + # """ + # update_dict = { + # "lr": lr, + # "n_epochs_kl_warmup": n_epochs_kl_warmup, + # } + # if plan_kwargs is not None: + # plan_kwargs.update(update_dict) + # else: + # plan_kwargs = update_dict + # super().train( + # max_epochs=max_epochs, + # accelerator=accelerator, + # devices=devices, + # train_size=train_size, + # validation_size=validation_size, + # shuffle_set_split=shuffle_set_split, + # batch_size=batch_size, + # datasplitter_kwargs=datasplitter_kwargs, + # plan_kwargs=plan_kwargs, + # **kwargs, + # ) @classmethod @setup_anndata_dsp.dedent @@ -658,6 +656,8 @@ def setup_anndata( ---------- %(param_adata)s %(param_layer)s + smoothed_layer + param that... %(param_batch_key)s """ setup_method_args = cls._get_setup_method_args(**locals()) diff --git a/src/scvi/model/base/_training_mixin.py b/src/scvi/model/base/_training_mixin.py index fd7e61af58..707636b4ca 100644 --- a/src/scvi/model/base/_training_mixin.py +++ b/src/scvi/model/base/_training_mixin.py @@ -172,7 +172,7 @@ def _set_indices_and_labels(self, datamodule=None): """Set indices for labeled and unlabeled cells.""" labels_state_registry = self.get_state_registry(REGISTRY_KEYS.LABELS_KEY) self.original_label_key = labels_state_registry.original_key - self.unlabeled_category_ = labels_state_registry.unlabeled_category + self.unlabeled_category_ = getattr(labels_state_registry, "unlabeled_category", None) if datamodule is None: self.labels_ = get_anndata_attribute( diff --git a/src/scvi/module/_mrdeconv.py b/src/scvi/module/_mrdeconv.py index 8e3687c381..66662011fe 100644 --- a/src/scvi/module/_mrdeconv.py +++ b/src/scvi/module/_mrdeconv.py @@ -14,7 +14,12 @@ from scvi import REGISTRY_KEYS from scvi.distributions import NegativeBinomial -from scvi.module.base import BaseModuleClass, EmbeddingModuleMixin, LossOutput, auto_move_data +from scvi.module.base import ( + BaseMinifiedModeModuleClass, + EmbeddingModuleMixin, + LossOutput, + auto_move_data, +) from scvi.nn import Encoder, FCLayers @@ -23,7 +28,7 @@ def identity(x): return x -class MRDeconv(EmbeddingModuleMixin, BaseModuleClass): +class MRDeconv(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Model for multi-resolution deconvolution of spatial transriptomics. 
Parameters diff --git a/src/scvi/module/_vaec.py b/src/scvi/module/_vaec.py index eb7ee949f8..b28d2836ea 100644 --- a/src/scvi/module/_vaec.py +++ b/src/scvi/module/_vaec.py @@ -1,6 +1,6 @@ import numpy as np import torch -from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal +from torch.distributions import Categorical, Distribution, Independent, MixtureSameFamily, Normal from torch.distributions import kl_divergence as kl from torch.nn import functional as F @@ -20,7 +20,6 @@ torch.backends.cudnn.benchmark = True -# Conditional VAE model class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): """Conditional Variational auto-encoder model. @@ -31,7 +30,7 @@ class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): n_input Number of input genes n_batch - Number of batches + Number of batches. If ``0``, no batch correction is performed. n_labels Number of labels n_hidden @@ -46,6 +45,9 @@ class VAEC(EmbeddingModuleMixin, BaseMinifiedModeModuleClass): Multiplicative weight for cell type specific latent space. dropout_rate Dropout rate for the encoder and decoder neural network. + encode_covariates + If ``True``, covariates are concatenated to gene expression prior to passing through + the encoder(s). Else, only gene expression is used. extra_encoder_kwargs Keyword arguments passed into :class:`~scvi.nn.Encoder`. extra_decoder_kwargs @@ -62,7 +64,7 @@ def __init__( n_latent: int = 5, n_layers: int = 2, log_variational: bool = True, - ct_weight: np.ndarray = None, + ct_weight: np.ndarray | None = None, dropout_rate: float = 0.05, encode_covariates: bool = False, extra_encoder_kwargs: dict | None = None, @@ -88,6 +90,10 @@ def __init__( self.prior = prior self.num_classes_mog = num_classes_mog self.init_embedding(REGISTRY_KEYS.BATCH_KEY, n_batch, **{}) + + if self.encode_covariates and self.n_batch < 1: + raise ValueError("`n_batch` must be greater than 0 if `encode_covariates` is `True`.") + batch_dim = self.get_embedding(REGISTRY_KEYS.BATCH_KEY).embedding_dim cat_list = [n_labels] @@ -109,7 +115,7 @@ def __init__( use_batch_norm=False, use_layer_norm=True, return_dist=True, - **_extra_encoder_kwargs, + **(extra_encoder_kwargs or {}), ) if n_fine_labels is not None: cls_parameters = { @@ -150,7 +156,7 @@ def __init__( inject_covariates=True, use_batch_norm=False, use_layer_norm=True, - **_extra_decoder_kwargs, + **(extra_decoder_kwargs or {}), ) self.px_decoder = torch.nn.Linear(n_hidden, n_input) self.per_ct_bias = torch.nn.Parameter(torch.zeros(n_labels, n_input)) @@ -219,7 +225,7 @@ def _regular_inference(self, x, y, batch_index, n_samples=1): x_ = x library = x.sum(1).unsqueeze(1) if self.log_variational: - x_ = torch.log(1 + x_) + x_ = torch.log1p(x_) if self.encode_covariates: batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) encoder_input = torch.cat([x_, batch_rep], dim=-1) @@ -278,7 +284,14 @@ def classify( return w_y @auto_move_data - def generative(self, z, library, y, batch_index): + def generative( + self, + z: torch.Tensor, + library: torch.Tensor, + y: torch.Tensor, + batch_index: torch.Tensor | None = None, + transform_batch: torch.Tensor | None = None, + ) -> dict[str, Distribution]: """Runs the generative model.""" batch_rep = self.compute_embedding(REGISTRY_KEYS.BATCH_KEY, batch_index) decoder_input = torch.cat([z, batch_rep], dim=-1) @@ -286,15 +299,16 @@ def generative(self, z, library, y, batch_index): px_scale = torch.nn.Softmax(dim=-1)(self.px_decoder(h) + self.per_ct_bias[y.ravel()]) px_rate = library * 
px_scale px_r = torch.exp(self.px_r) - px = NegativeBinomial(mu=px_rate, theta=px_r) - return {"px": px, "px_scale": px_scale} + px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) + return {"px": px} def loss( self, - tensors, - inference_outputs, - generative_outputs, + tensors: dict[str, torch.Tensor], + inference_outputs: dict[str, torch.Tensor | Distribution], + generative_outputs: dict[str, Distribution], kl_weight: float = 1.0, + labelled_tensors: dict[str, torch.Tensor] | None = None, classification_ratio=5.0, ): """Loss computation.""" diff --git a/tests/model/test_destvi.py b/tests/model/test_destvi.py index 0bcff17ab8..c98becf89e 100644 --- a/tests/model/test_destvi.py +++ b/tests/model/test_destvi.py @@ -13,7 +13,7 @@ def test_destvi(): dataset = synthetic_iid(n_labels=n_labels) dataset.obs["overclustering_vamp"] = list(range(dataset.n_obs)) CondSCVI.setup_anndata(dataset, labels_key="labels") - sc_model = CondSCVI(dataset, n_latent=n_latent, n_layers=n_layers) + sc_model = CondSCVI(dataset, n_latent=n_latent, n_layers=n_layers, prior="mog") sc_model.train(1, train_size=1) sc_model.get_normalized_expression(dataset)
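
A minimal end-to-end sketch of the workflow this patch exercises, assuming the existing public APIs (`CondSCVI.setup_anndata`, `DestVI.setup_anndata`, `DestVI.from_rna_model`, `DestVI.get_proportions`). The only new argument shown is `prior="mog"`, taken from the test change above; the explicit `DestVI.setup_anndata` call follows the released API, since whether `from_rna_model` now performs that setup itself via `anndata_setup_kwargs` is a detail of this patch.

from scvi.data import synthetic_iid
from scvi.model import CondSCVI, DestVI

# Toy data; a real run would use a single-cell reference and a separate spatial AnnData.
adata = synthetic_iid(n_labels=5)

# Single-cell reference model with the mixture-of-Gaussians prior used in the updated test.
CondSCVI.setup_anndata(adata, labels_key="labels")
sc_model = CondSCVI(adata, n_latent=2, n_layers=2, prior="mog")
sc_model.train(max_epochs=1, train_size=1)

# Spatial model initialized from the trained reference; extra keyword args are now
# forwarded to the MRDeconv module (module_kwargs) rather than to DestVI itself.
DestVI.setup_anndata(adata)
st_model = DestVI.from_rna_model(adata, sc_model)
st_model.train(max_epochs=1)
print(st_model.get_proportions().head())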
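
The `generative` change in `src/scvi/module/_vaec.py` stops returning `px_scale` next to the distribution and instead attaches it to the `NegativeBinomial`, which `scvi.distributions.NegativeBinomial` supports through its `scale` argument. A small sketch of the pattern (shapes are arbitrary):

import torch
from scvi.distributions import NegativeBinomial

px_scale = torch.softmax(torch.randn(4, 10), dim=-1)  # normalized expression per cell
library = torch.full((4, 1), 1000.0)                   # observed library sizes
px_rate = library * px_scale                           # negative binomial mean
px_r = torch.ones(10)                                  # per-gene inverse dispersion

px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale)
# Downstream code can read the normalized expression back from the distribution object.
assert torch.allclose(px.scale, px_scale)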
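
The `_set_indices_and_labels` tweak in `_training_mixin.py` makes the unlabeled category optional, so models whose labels field was registered without one no longer raise an `AttributeError`. The pattern, illustrated with hypothetical stand-ins for the labels state registry:

from types import SimpleNamespace

# Stand-ins for a labels state registry with and without an unlabeled category.
with_unlabeled = SimpleNamespace(original_key="labels", unlabeled_category="Unknown")
without_unlabeled = SimpleNamespace(original_key="labels")

print(getattr(with_unlabeled, "unlabeled_category", None))     # -> "Unknown"
print(getattr(without_unlabeled, "unlabeled_category", None))  # -> None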