3535 clip_lfc_factory ,
3636 encode_categories ,
3737 get_n_latent_heuristic ,
38- impute_with_neighbors ,
38+ impute_cats_with_neighbors ,
39+ impute_expr_with_neighbors ,
3940 validate_expression_range ,
4041 validate_marker ,
42+ validate_layer_key ,
4143 validate_obs_keys ,
4244 validate_obsm_keys ,
4345)
@@ -204,6 +206,10 @@ def __init__(
204206 REGISTRY_KEYS .SAMPLE_KEY
205207 ).original_key
206208
209+ self .batch_key = self .adata_manager .get_state_registry (
210+ REGISTRY_KEYS .BATCH_KEY
211+ ).original_key
212+
207213 self ._model_summary_string = ( # noqa: UP032
208214
209215 "CytoVI Model with the following params: \n n_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
@@ -309,11 +315,9 @@ def train(
309315 devices : Union [int , list [int ], str ] = "auto" ,
310316 train_size : float = 0.9 ,
311317 validation_size : Optional [float ] = None ,
312- shuffle_set_split : bool = True ,
313318 batch_size : int = 128 ,
314319 early_stopping : bool = True ,
315320 check_val_every_n_epoch : Optional [int ] = None ,
316- # reduce_lr_on_plateau: bool = True,
317321 n_steps_kl_warmup : Union [int , None ] = None ,
318322 n_epochs_kl_warmup : Union [int , None ] = 400 ,
319323 adversarial_classifier : Optional [bool ] = None ,
@@ -370,18 +374,10 @@ def train(
370374 if adversarial_classifier is None :
371375 adversarial_classifier = self ._use_adversarial_classifier
372376
373- # n_steps_kl_warmup = (
374- # n_steps_kl_warmup
375- # if n_steps_kl_warmup is not None
376- # else int(0.75 * self.adata.n_obs)
377- # )
378- # if reduce_lr_on_plateau:
379- # check_val_every_n_epoch = 1
380377
381378 update_dict = {
382379 "lr" : lr ,
383380 "adversarial_classifier" : adversarial_classifier ,
384- # "reduce_lr_on_plateau": reduce_lr_on_plateau,
385381 "n_epochs_kl_warmup" : n_epochs_kl_warmup ,
386382 "n_steps_kl_warmup" : n_steps_kl_warmup ,
387383 }
@@ -396,7 +392,6 @@ def train(
396392 self .adata_manager ,
397393 train_size = train_size ,
398394 validation_size = validation_size ,
399- # shuffle_set_split=shuffle_set_split,
400395 batch_size = batch_size ,
401396 )
402397 training_plan = self ._training_plan_cls (self .module , ** plan_kwargs )
@@ -405,7 +400,6 @@ def train(
405400 training_plan = training_plan ,
406401 data_splitter = data_splitter ,
407402 max_epochs = max_epochs ,
408- # use_gpu=use_gpu,
409403 accelerator = accelerator ,
410404 devices = devices ,
411405 early_stopping = early_stopping ,
@@ -751,7 +745,7 @@ def differential_expression(
751745 idx1 ,
752746 idx2 ,
753747 all_stats ,
754- scrna_raw_counts_properties , # modify the extended stats summary and include the GMM
748+ scrna_raw_counts_properties ,
755749 col_names ,
756750 mode ,
757751 batchid1 ,
@@ -768,104 +762,110 @@ def differential_expression(
def get_aggregated_posterior(
    self,
    adata: AnnData = None,
    sample: Union[int, str] = None,
    indices: Sequence[int] = None,
    batch_size: int = None,
    dof: float | None = 3.0,
) -> dist.Distribution:
    """Compute the aggregated posterior over the latent representations.

    The result is a mixture with one component per selected cell and equal
    mixture weights. Each component is parameterised by the location and
    scale returned by ``get_latent_representation`` for that cell.

    Parameters
    ----------
    adata
        AnnData object to use. Defaults to the AnnData object used to initialize the model.
    sample
        Value of ``adata.obs[self.sample_key]`` to filter on. If ``None``, no
        sample filter is applied.
    indices
        Indices of cells to use. If ``None``, all cells of ``adata`` are used.
        When ``sample`` is also given, the intersection of both selections is used.
    batch_size
        Batch size to use for computing the latent representation.
    dof
        Degrees of freedom for the Student's t-distribution components.
        If ``None``, components are Normal.

    Returns
    -------
    A mixture distribution of the aggregated posterior.
    """
    self._check_if_trained(warn=False)
    adata = self._validate_anndata(adata)

    if indices is None:
        # Default to every cell of the AnnData that is actually queried below,
        # not of ``self.adata``; the two differ when a new AnnData is passed.
        indices = np.arange(adata.n_obs)
    if sample is not None:
        sample_positions = np.where(adata.obs[self.sample_key] == sample)[0]
        indices = np.intersect1d(np.array(indices), sample_positions)

    dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
    # NOTE(review): the second returned array is used directly as the component
    # scale; confirm that ``get_latent_representation(return_dist=True)`` yields
    # a standard deviation rather than a variance.
    qu_loc, qu_scale = self.get_latent_representation(
        batch_size=batch_size, return_dist=True, dataloader=dataloader, give_mean=True
    )

    # Transpose so that the cell axis is the rightmost dimension, which is the
    # axis MixtureSameFamily mixes over.
    qu_loc = torch.tensor(qu_loc, device=self.device).T
    qu_scale = torch.tensor(qu_scale, device=self.device).T

    if dof is None:
        components = dist.Normal(qu_loc, qu_scale)
    else:
        components = dist.StudentT(dof, qu_loc, qu_scale)

    # Identical logits give every cell the same mixture weight.
    mixture_weights = dist.Categorical(
        logits=torch.ones(qu_loc.shape[1], device=qu_loc.device)
    )
    return dist.MixtureSameFamily(mixture_weights, components)
809812
def differential_abundance(
    self,
    adata: AnnData | None = None,
    batch_size: int = 128,
    downsample_cells: int | None = None,
    dof: float | None = None,
) -> pd.DataFrame:
    """Compute per-sample log probabilities of every cell's latent representation.

    For each unique sample an aggregated posterior is built from the cells of
    that sample (see ``get_aggregated_posterior``). The latent representation
    of every cell in ``adata`` is then scored under it, summing the log
    probability over the latent dimensions. Differences between columns of the
    result can be used as log-ratios between samples; this method itself
    returns the raw log probabilities.

    Parameters
    ----------
    adata
        The data object to compute the differential abundance for.
        Defaults to the AnnData object used to initialize the model.
    batch_size
        Minibatch size for computing the latent representation and for
        scoring cells under each aggregated posterior.
    downsample_cells
        Maximum number of cells per sample used to build its aggregated
        posterior. Samples with more cells are randomly subsampled without
        replacement via ``np.random.choice``.
    dof
        Degrees of freedom for the Student's t-distribution components of the
        aggregated posterior. If ``None``, components are Normal.

    Returns
    -------
    DataFrame of shape (n_cells, n_samples) containing the log probabilities
    for each cell across samples. The rows correspond to cell names from
    ``adata.obs_names``, and the columns correspond to unique sample identifiers.
    """
    adata = self._validate_anndata(adata)

    # Embed the same AnnData whose obs/obs_names are used below, so that rows
    # of the result line up with ``adata.obs_names`` even when ``adata`` is not
    # the object the model was set up with.
    zs = self.get_latent_representation(
        adata=adata, batch_size=batch_size, return_dist=False, give_mean=True
    )

    sample_labels = adata.obs[self.sample_key]
    unique_samples = sample_labels.unique()
    dataloader = torch.utils.data.DataLoader(zs, batch_size=batch_size)

    log_probs = []
    for sample_name in tqdm(unique_samples):
        indices = np.where(sample_labels == sample_name)[0]
        if downsample_cells is not None and downsample_cells < indices.shape[0]:
            indices = np.random.choice(indices, downsample_cells, replace=False)

        ap = self.get_aggregated_posterior(adata=adata, indices=indices, dof=dof)

        # Score all cells minibatch-wise; summing over the last axis combines
        # the per-dimension log probabilities into one value per cell.
        sample_log_probs = [
            ap.log_prob(z_rep.to(self.device)).sum(-1, keepdim=True)
            for z_rep in dataloader
        ]
        log_probs.append(torch.cat(sample_log_probs, dim=0).cpu().numpy())

    # One column per sample, in the order of ``unique_samples``.
    log_probs = np.concatenate(log_probs, 1)
    return pd.DataFrame(
        data=log_probs, index=adata.obs_names.to_numpy(), columns=unique_samples
    )
867866
868867
868+
869869 def impute_categories_from_reference (
870870 self ,
871871 adata_reference : AnnData ,
@@ -929,14 +929,122 @@ def impute_categories_from_reference(
929929 rep_query = adata_query .obsm [use_rep ]
930930
931931 # Impute missing categories for the query data
932- imputed_query_cat_indices , uncertainty = impute_with_neighbors (
932+ imputed_query_cat_indices , uncertainty = impute_cats_with_neighbors (
933933 rep_query , rep_ref , cat_encoded_ref , n_neighbors = n_neighbors , compute_uncertainty = return_uncertainty
934934 )
935935
936936 # Convert imputed indices back to category labels
937- imputed_query_cat = ohe .inverse_transform (np .eye (n_cats )[imputed_query_cat_indices ])
937+ imputed_query_cat = ohe .inverse_transform (np .eye (n_cats )[imputed_query_cat_indices ]). reshape ( - 1 )
938938
939939 if return_uncertainty :
940940 return imputed_query_cat , uncertainty
941941 else :
942942 return imputed_query_cat
943+
944+
945+
def impute_rna_from_reference(
    self,
    reference_batch: str,
    adata_rna: AnnData,
    layer_key: str,
    use_rep: Optional[str] = None,
    n_neighbors: int = 20,
    compute_uncertainty: bool = False,
    return_query_only: bool = False,
) -> AnnData:
    """Impute expression of a missing modality for query cells from a reference batch.

    Cells of the model's AnnData (``self.adata``) whose batch equals
    ``reference_batch`` form the reference; all other cells form the query.
    Expression values for the query cells are imputed from the reference
    cells' values in ``adata_rna`` using nearest neighbors in a shared
    representation.

    Parameters
    ----------
    reference_batch
        Value of ``self.adata.obs[self.batch_key]`` identifying the reference batch.
    adata_rna
        AnnData holding the expression data to impute from. It must contain
        every reference cell (matched by ``obs_names``).
    layer_key
        Key in ``adata_rna.layers`` holding the reference expression data.
    use_rep
        Key in ``self.adata.obsm`` to use as the representation space.
        If ``None``, ``CYTOVI_DEFAULT_REP`` is used; when that key is missing
        it is computed with ``get_latent_representation`` and stored in
        ``self.adata.obsm``.
    n_neighbors
        Number of nearest neighbors to use for imputation.
    compute_uncertainty
        Forwarded to ``impute_expr_with_neighbors``.
    return_query_only
        If ``True``, return only the imputed query cells.

    Returns
    -------
    AnnData with the expression stored both in ``X`` and in
    ``layers[layer_key]``. If ``return_query_only`` is ``True`` it contains
    only the query cells. Otherwise it contains all cells of ``self.adata``
    in their original order, with measured values for reference cells and
    imputed values for query cells.

    Raises
    ------
    ValueError
        If no cell belongs to ``reference_batch``, or if some reference cells
        are missing from ``adata_rna``.

    Notes
    -----
    TODO(review): the uncertainty returned by ``impute_expr_with_neighbors``
    is currently not part of the return value.
    """
    adata = self.adata
    batch_key = self.batch_key

    # Resolve the default representation before validating it, so the
    # validator always sees a concrete key.
    if use_rep is None:
        if CYTOVI_DEFAULT_REP not in adata.obsm.keys():
            adata.obsm[CYTOVI_DEFAULT_REP] = self.get_latent_representation()
        use_rep = CYTOVI_DEFAULT_REP

    # validate input
    validate_obsm_keys(adata, use_rep)
    validate_layer_key(adata_rna, layer_key)

    # retrieve reference and query cell names
    reference_indices = adata.obs_names[adata.obs[batch_key] == reference_batch]
    query_indices = adata.obs_names[adata.obs[batch_key] != reference_batch]

    if len(reference_indices) == 0:
        raise ValueError(
            f"No cells found for reference batch {reference_batch!r} in `adata.obs[{batch_key!r}]`."
        )

    # The reference expression is read from `adata_rna`, so every reference
    # cell has to be present there.
    if not reference_indices.isin(adata_rna.obs_names).all():
        raise ValueError("Some reference indices are not present in `adata_rna`.")

    # Get representations and reference expression
    adata_query = adata[query_indices, :]
    rep_ref = adata[reference_indices, :].obsm[use_rep]
    rep_query = adata_query.obsm[use_rep]
    expr_data_ref = adata_rna[reference_indices, :].layers[layer_key]

    # Impute expression in query; the uncertainty is not used further (see Notes).
    imputed_expr_query, _uncertainty = impute_expr_with_neighbors(
        rep_query,
        rep_ref,
        expr_data_ref,
        n_neighbors=n_neighbors,
        compute_uncertainty=compute_uncertainty,
    )

    if return_query_only:
        return AnnData(
            X=imputed_expr_query,
            obs=adata_query.obs,
            obsm=adata_query.obsm,
            var=adata_rna.var,
            layers={layer_key: imputed_expr_query},
        )

    # Stack reference (measured) and query (imputed) expression, then use a
    # temporary AnnData to restore the original cell order by name.
    expr_comb = np.concatenate([expr_data_ref, imputed_expr_query], axis=0)
    obs_comb = adata.obs.loc[np.concatenate([reference_indices, query_indices]), :]
    adata_combined = AnnData(X=expr_comb, obs=obs_comb, var=adata_rna.var)
    expr_ordered = adata_combined[adata.obs_names].X

    return AnnData(
        X=expr_ordered,
        obs=adata.obs,
        var=adata_rna.var,
        obsm=adata.obsm,
        # Separate copy so that X and the layer do not share memory.
        layers={layer_key: expr_ordered.copy()},
    )
0 commit comments