Skip to content

Commit 1b1bda7

Browse files
Merge pull request #12 from florianingelfinger/dev
Added tests, RNA imputation function, performance fixes
2 parents 0bce496 + d741718 commit 1b1bda7

File tree

7 files changed

+401
-127
lines changed

7 files changed

+401
-127
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ urls.Source = "https://github.com/florianingelfinger/CytoVI"
2020
urls.Home-page = "https://github.com/florianingelfinger/CytoVI"
2121
dependencies = [
2222
"anndata",
23+
"scvi-tools>=1.2",
24+
"pynndescent",
2325
# for debug logging (referenced from the issue template)
2426
"session-info"
2527
]

src/cytovi/_constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@ class _REGISTRY_KEYS_NT(NamedTuple):
2121
REGISTRY_KEYS = _REGISTRY_KEYS_NT()
2222

2323
CYTOVI_DEFAULT_REP = 'X_CytoVI'
24+
CYTOVI_SCATTER_FEATS = ("FSC", "Fsc", "fsc", "SSC", "Ssc", "ssc")

src/cytovi/_model.py

Lines changed: 194 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,11 @@
3535
clip_lfc_factory,
3636
encode_categories,
3737
get_n_latent_heuristic,
38-
impute_with_neighbors,
38+
impute_cats_with_neighbors,
39+
impute_expr_with_neighbors,
3940
validate_expression_range,
4041
validate_marker,
42+
validate_layer_key,
4143
validate_obs_keys,
4244
validate_obsm_keys,
4345
)
@@ -204,6 +206,10 @@ def __init__(
204206
REGISTRY_KEYS.SAMPLE_KEY
205207
).original_key
206208

209+
self.batch_key = self.adata_manager.get_state_registry(
210+
REGISTRY_KEYS.BATCH_KEY
211+
).original_key
212+
207213
self._model_summary_string = ( # noqa: UP032
208214

209215
"CytoVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: "
@@ -309,11 +315,9 @@ def train(
309315
devices: Union[int, list[int], str] = "auto",
310316
train_size: float = 0.9,
311317
validation_size: Optional[float] = None,
312-
shuffle_set_split: bool = True,
313318
batch_size: int = 128,
314319
early_stopping: bool = True,
315320
check_val_every_n_epoch: Optional[int] = None,
316-
# reduce_lr_on_plateau: bool = True,
317321
n_steps_kl_warmup: Union[int, None] = None,
318322
n_epochs_kl_warmup: Union[int, None] = 400,
319323
adversarial_classifier: Optional[bool] = None,
@@ -370,18 +374,10 @@ def train(
370374
if adversarial_classifier is None:
371375
adversarial_classifier = self._use_adversarial_classifier
372376

373-
# n_steps_kl_warmup = (
374-
# n_steps_kl_warmup
375-
# if n_steps_kl_warmup is not None
376-
# else int(0.75 * self.adata.n_obs)
377-
# )
378-
# if reduce_lr_on_plateau:
379-
# check_val_every_n_epoch = 1
380377

381378
update_dict = {
382379
"lr": lr,
383380
"adversarial_classifier": adversarial_classifier,
384-
# "reduce_lr_on_plateau": reduce_lr_on_plateau,
385381
"n_epochs_kl_warmup": n_epochs_kl_warmup,
386382
"n_steps_kl_warmup": n_steps_kl_warmup,
387383
}
@@ -396,7 +392,6 @@ def train(
396392
self.adata_manager,
397393
train_size=train_size,
398394
validation_size=validation_size,
399-
# shuffle_set_split=shuffle_set_split,
400395
batch_size=batch_size,
401396
)
402397
training_plan = self._training_plan_cls(self.module, **plan_kwargs)
@@ -405,7 +400,6 @@ def train(
405400
training_plan=training_plan,
406401
data_splitter=data_splitter,
407402
max_epochs=max_epochs,
408-
# use_gpu=use_gpu,
409403
accelerator=accelerator,
410404
devices=devices,
411405
early_stopping=early_stopping,
@@ -751,7 +745,7 @@ def differential_expression(
751745
idx1,
752746
idx2,
753747
all_stats,
754-
scrna_raw_counts_properties, # modify the extended stats summary and include the GMM
748+
scrna_raw_counts_properties,
755749
col_names,
756750
mode,
757751
batchid1,
@@ -768,104 +762,110 @@ def differential_expression(
768762
def get_aggregated_posterior(
    self,
    adata: AnnData = None,
    sample: Union[int, str] = None,
    indices: Sequence[int] = None,
    batch_size: int = None,
    dof: float | None = 3.,
) -> dist.Distribution:
    """Compute the aggregated posterior over the ``u`` latent representations.

    The aggregated posterior is a uniform mixture over the per-cell latent
    posteriors of the selected cells.

    Parameters
    ----------
    adata
        AnnData object to use. Defaults to the AnnData object used to initialize the model.
    sample
        Name or index of the sample to filter on. If ``None``, uses all cells.
    indices
        Indices of cells to use.
    batch_size
        Batch size to use for computing the latent representation.
    dof
        Degrees of freedom for the Student's t-distribution components. If ``None``,
        components are Normal.

    Returns
    -------
    A mixture distribution of the aggregated posterior.
    """
    self._check_if_trained(warn=False)
    adata = self._validate_anndata(adata)

    if indices is None:
        # BUGFIX: use the (possibly user-supplied) `adata`, not `self.adata`,
        # so the default index range matches the object filtered below.
        indices = np.arange(adata.n_obs)
    if sample is not None:
        indices = np.intersect1d(
            np.array(indices), np.where(adata.obs[self.sample_key] == sample)[0]
        )

    dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
    qu_loc, qu_scale = self.get_latent_representation(
        batch_size=batch_size, return_dist=True, dataloader=dataloader, give_mean=True
    )

    # Transpose so cells form the rightmost (mixture-component) dimension.
    # `as_tensor` avoids a redundant copy/warning when inputs are already tensors.
    qu_loc = torch.as_tensor(qu_loc, device=self.device).T
    qu_scale = torch.as_tensor(qu_scale, device=self.device).T

    if dof is None:
        components = dist.Normal(qu_loc, qu_scale)
    else:
        components = dist.StudentT(dof, qu_loc, qu_scale)
    # Equal logits -> uniform mixture weights over the selected cells.
    return dist.MixtureSameFamily(
        dist.Categorical(logits=torch.ones(qu_loc.shape[1], device=qu_loc.device)),
        components,
    )
809812

810813
def differential_abundance(
    self,
    adata: AnnData | None = None,
    batch_size: int = 128,
    downsample_cells: int | None = None,
    dof: float | None = None,
) -> pd.DataFrame:
    """Compute the differential abundance between samples.

    Computes the logarithm of the ratio of the probabilities of each sample conditioned on the
    estimated aggregate posterior distribution of each cell.

    Parameters
    ----------
    adata
        The data object to compute the differential abundance for.
    batch_size
        Minibatch size for computing the differential abundance.
    downsample_cells
        Number of cells to subset to before computing the differential abundance.
    dof
        Degrees of freedom for the Student's t-distribution components for the aggregated
        posterior. If ``None``, components are Normal.

    Returns
    -------
    DataFrame of shape (n_cells, n_samples) containing the log probabilities
    for each cell across samples. The rows correspond to cell names from `adata.obs_names`,
    and the columns correspond to unique sample identifiers.
    """
    adata = self._validate_anndata(adata)

    zs = self.get_latent_representation(
        batch_size=batch_size, return_dist=False, give_mean=True
    )

    unique_samples = adata.obs[self.sample_key].unique()
    # One dataloader over ALL cells; it is re-iterated once per sample below.
    dataloader = torch.utils.data.DataLoader(zs, batch_size=batch_size)
    log_probs = []
    for sample_name in tqdm(unique_samples):
        indices = np.where(adata.obs[self.sample_key] == sample_name)[0]
        if downsample_cells is not None and downsample_cells < indices.shape[0]:
            indices = np.random.choice(indices, downsample_cells, replace=False)

        # Aggregated posterior built from this sample's cells only.
        ap = self.get_aggregated_posterior(adata=adata, indices=indices, dof=dof)
        log_probs_ = []
        # Pure read-only scoring: disable autograd bookkeeping.
        with torch.inference_mode():
            for z_rep in dataloader:
                z_rep = z_rep.to(self.device)
                # Use torch-native kwargs (`keepdim`, not the numpy alias `keepdims`).
                log_probs_.append(ap.log_prob(z_rep).sum(-1, keepdim=True))
        log_probs.append(torch.cat(log_probs_, dim=0).cpu().numpy())

    log_probs = np.concatenate(log_probs, 1)
    log_probs_df = pd.DataFrame(
        data=log_probs, index=adata.obs_names.to_numpy(), columns=unique_samples
    )
    return log_probs_df
867866

868867

868+
869869
def impute_categories_from_reference(
870870
self,
871871
adata_reference: AnnData,
@@ -929,14 +929,122 @@ def impute_categories_from_reference(
929929
rep_query = adata_query.obsm[use_rep]
930930

931931
# Impute missing categories for the query data
932-
imputed_query_cat_indices, uncertainty = impute_with_neighbors(
932+
imputed_query_cat_indices, uncertainty = impute_cats_with_neighbors(
933933
rep_query, rep_ref, cat_encoded_ref, n_neighbors=n_neighbors, compute_uncertainty=return_uncertainty
934934
)
935935

936936
# Convert imputed indices back to category labels
937-
imputed_query_cat = ohe.inverse_transform(np.eye(n_cats)[imputed_query_cat_indices])
937+
imputed_query_cat = ohe.inverse_transform(np.eye(n_cats)[imputed_query_cat_indices]).reshape(-1)
938938

939939
if return_uncertainty:
940940
return imputed_query_cat, uncertainty
941941
else:
942942
return imputed_query_cat
943+
944+
945+
946+
def impute_rna_from_reference(
    self,
    reference_batch: str,
    adata_rna: AnnData,
    layer_key: str,
    use_rep: Optional[str] = None,
    n_neighbors: int = 20,
    compute_uncertainty: bool = False,
    return_query_only: bool = False,
):
    """
    Impute expression data from a missing modality for the query dataset based on a reference
    dataset using a shared representation.

    Parameters
    ----------
    reference_batch : str
        Identifier for the reference batch in ``adata.obs[batch_key]``.
    adata_rna : AnnData
        Annotated data matrix containing the expression data to impute from.
    layer_key : str
        Key in the `.layers` attribute of `adata_rna` for the reference expression data.
    use_rep : str, optional
        Key in the `.obsm` attribute to use as the representation space (e.g., latent space).
        If `None`, defaults to `X_CytoVI` (computed on the fly if absent).
    n_neighbors : int, optional (default: 20)
        Number of nearest neighbors to use for imputation.
    compute_uncertainty : bool, optional (default: False)
        If `True`, also computes and returns the uncertainty of the imputation.
    return_query_only : bool, optional (default: False)
        If `True`, return only the imputed query dataset as an AnnData object.

    Returns
    -------
    AnnData
        Imputed AnnData object. If `return_query_only` is `True`, only the query dataset is
        returned. If `compute_uncertainty` is `True`, a `(adata, uncertainty)` tuple is returned.
    """
    adata = self.adata
    batch_key = self.batch_key

    # Resolve the representation BEFORE validating it: `use_rep` defaults to None,
    # and validating None against `.obsm` would reject the documented default path.
    if use_rep is None:
        if CYTOVI_DEFAULT_REP not in adata.obsm.keys():
            adata.obsm[CYTOVI_DEFAULT_REP] = self.get_latent_representation()
        use_rep = CYTOVI_DEFAULT_REP

    # validate input
    validate_obsm_keys(adata, use_rep)
    validate_layer_key(adata_rna, layer_key)

    # retrieve reference and query indices
    reference_indices = adata.obs_names[adata.obs[batch_key] == reference_batch]
    query_indices = adata.obs_names[adata.obs[batch_key] != reference_batch]

    # All reference cells must be present in `adata_rna`, since their expression is
    # taken from it below. Set lookup keeps this O(1) per cell.
    rna_obs_names = set(adata_rna.obs_names)
    if not all(idx in rna_obs_names for idx in reference_indices):
        raise ValueError("Some reference cells are not present in `adata_rna`.")

    # Get representations and reference expression
    rep_ref = adata[reference_indices, :].obsm[use_rep]
    rep_query = adata[query_indices, :].obsm[use_rep]
    expr_data_ref = adata_rna[reference_indices, :].layers[layer_key]

    # Impute expression in query via nearest neighbors in the shared representation
    imputed_expr_query, uncertainty = impute_expr_with_neighbors(
        rep_query, rep_ref, expr_data_ref, n_neighbors=n_neighbors, compute_uncertainty=compute_uncertainty
    )

    # create anndata for imputed query dataset
    adata_imputed_query = AnnData(
        X=imputed_expr_query,
        obs=adata[query_indices, :].obs,
        obsm=adata[query_indices, :].obsm,
        var=adata_rna.var,
        layers={layer_key: imputed_expr_query},
    )

    if return_query_only:
        # BUGFIX: previously the computed uncertainty was silently discarded.
        if compute_uncertainty:
            return adata_imputed_query, uncertainty
        return adata_imputed_query

    # assemble new anndata with measured reference + imputed query expression
    expr_comb = np.concatenate([expr_data_ref, imputed_expr_query], axis=0)
    obs_comb = adata.obs.loc[np.concatenate([reference_indices, query_indices]), :]

    adata_combined = AnnData(
        X=expr_comb,
        obs=obs_comb,
        var=adata_rna.var,
    )

    # restore original cell order and carry over metadata/representations
    adata_imputed = AnnData(
        X=adata_combined[adata.obs_names].X,
        obs=adata.obs,
        var=adata_rna.var,
        obsm=adata.obsm,
        layers={layer_key: adata_combined[adata.obs_names].X},
    )

    if compute_uncertainty:
        return adata_imputed, uncertainty
    return adata_imputed

0 commit comments

Comments
 (0)