Skip to content

Commit 5dacc67

Browse files
lordy5, pre-commit-ci[bot], ori-kron-wis, claude, github-actions[bot]
authored
fix: differential abundance and aggregated posterior computations (#3789)
Fixes:
- Use the original AnnData for the aggregated posterior computation, not the subset of the AnnData. All cells in a sample should be considered when computing the aggregated posterior (up to num_cells_posterior).
- Update the tests accordingly.
- Correctly pass scales, not variances, into torch distributions. get_latent_representation returns variances, not scales, while the aggregated posterior code previously expected scales, and so accidentally passed variances into torch distributions (Normal and Student's t expect scales instead).

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ori Kronfeld <ori.kronfeld@weizmann.ac.il>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: claude <claude@users.noreply.github.com>
1 parent db0900c commit 5dacc67

2 files changed

Lines changed: 21 additions & 10 deletions

File tree

src/scvi/model/base/_vaemixin.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -401,12 +401,13 @@ def get_aggregated_posterior(
401401
indices = np.arange(adata.n_obs)
402402

403403
dataloader = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
404-
qu_loc, qu_scale = self.get_latent_representation(
404+
qu_loc, qu_var = self.get_latent_representation(
405405
batch_size=batch_size, return_dist=True, dataloader=dataloader, give_mean=True
406406
)
407407

408408
qu_loc = torch.tensor(qu_loc, device=self.device) # (n_cells, n_latent_u)
409-
qu_scale = torch.tensor(qu_scale, device=self.device)
409+
qu_var = torch.tensor(qu_var, device=self.device)
410+
qu_scale = torch.sqrt(qu_var)
410411

411412
if dof is None:
412413
components = dist.Normal(qu_loc, qu_scale)
@@ -421,6 +422,7 @@ def get_aggregated_posterior(
421422
def differential_abundance(
422423
self,
423424
adata: AnnOrMuData | None = None,
425+
adata_sub: AnnOrMuData | None = None,
424426
sample_key: str | None = None,
425427
batch_size: int = 128,
426428
num_cells_posterior: int | None = None,
@@ -434,8 +436,13 @@ def differential_abundance(
434436
Parameters
435437
----------
436438
adata
439+
The full data object used to compute each aggregated posterior.
440+
Defaults to the AnnData object used to initialize the model.
441+
adata_sub
437442
The data object to compute the differential abundance for.
438-
For very large datasets, this should be a subset of the original data object.
443+
For very large datasets, this should be used to pass in a subset of the full data
444+
object. The aggregated posteriors are still computed from the full data object.
445+
The resulting log_probs matrix is stored in adata_sub.obsm
439446
sample_key
440447
Key for the sample covariate.
441448
batch_size
@@ -451,14 +458,17 @@ def differential_abundance(
451458
from tqdm import tqdm
452459

453460
adata = self._validate_anndata(adata)
461+
if adata_sub is None:
462+
adata_sub = adata
463+
else:
464+
adata_sub = self._validate_anndata(adata_sub)
454465

455-
# In case user passes in a subset of model's anndata
456-
adata_dataloader = self._make_data_loader(adata=adata, batch_size=batch_size)
466+
adata_dataloader = self._make_data_loader(adata=adata_sub, batch_size=batch_size)
457467
us = self.get_latent_representation(
458468
batch_size=batch_size, dataloader=adata_dataloader, give_mean=True
459469
)
460470
dataloader = torch.utils.data.DataLoader(us, batch_size=batch_size)
461-
unique_samples = adata.obs[sample_key].unique()
471+
unique_samples = adata_sub.obs[sample_key].unique()
462472

463473
log_probs = []
464474
for sample_name in tqdm(unique_samples):
@@ -476,6 +486,7 @@ def differential_abundance(
476486
log_probs.append(torch.cat(log_probs_, axis=0).cpu().numpy())
477487

478488
log_probs = np.array(log_probs).T
479-
log_probs_df = pd.DataFrame(data=log_probs, index=adata.obs_names, columns=unique_samples)
480-
481-
adata.obsm["da_log_probs"] = log_probs_df
489+
log_probs_df = pd.DataFrame(
490+
data=log_probs, index=adata_sub.obs_names, columns=unique_samples
491+
)
492+
adata_sub.obsm["da_log_probs"] = log_probs_df

tests/model/test_differential_abundance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def test_differential_abundance(model: VAEMixin, adata: AnnData, mdata: MuData,
160160

161161
subset_indices = np.random.choice(adata.n_obs, adata.n_obs // 2, replace=False)
162162
adata_subset = adata[subset_indices, :].copy()
163-
model.differential_abundance(adata_subset, **da_kwargs)
163+
model.differential_abundance(adata, adata_subset, **da_kwargs)
164164
assert isinstance(adata_subset.obsm["da_log_probs"], pd.DataFrame)
165165

166166

0 commit comments

Comments
 (0)