Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3994.fix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{func}`scanpy.pp.combat` now raises a {class}`ValueError` when a batch contains fewer than 2 cells, instead of silently producing NaN values in the corrected data {smaller}`L Zhang`
26 changes: 20 additions & 6 deletions src/scanpy/preprocessing/_combat.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,22 @@ def combat( # noqa: PLR0915
sanitize_anndata(adata)

# construct a pandas series of the batch annotation
model = adata.obs[[key, *(covariates if covariates else [])]]
batch_info = model.groupby(key, observed=True).indices.values()
model: pd.DataFrame = adata.obs[[key, *(covariates if covariates else [])]]
batch_info = model.groupby(key, observed=True).indices
n_batch = len(batch_info)
n_batches = np.array([len(v) for v in batch_info])
n_batches = np.array([len(v) for v in batch_info.values()])

# check for batches with fewer than 2 cells
small_batches = [
batch for batch, size in zip(batch_info, n_batches, strict=True) if size < 2
]
if small_batches:
msg = (
f"Batches {small_batches!r} have fewer than 2 cells. "
"ComBat requires at least 2 cells per batch to estimate "
"within-batch variance. Filter these batches before running combat."
)
raise ValueError(msg)
n_array = float(sum(n_batches))

# standardize across genes using a pooled variance estimator
Expand All @@ -220,7 +232,9 @@ def combat( # noqa: PLR0915
la.inv(batch_design.T @ batch_design) @ batch_design.T @ s_data.T
).values
# first estimate for the multiplicative batch effect
delta_hat = [s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info]
delta_hat = [
s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info.values()
]

# empirically fix the prior hyperparameters
gamma_bar = gamma_hat.mean(axis=1)
Expand All @@ -233,7 +247,7 @@ def combat( # noqa: PLR0915
# gamma star and delta star will be our empirical bayes (EB) estimators
# for the additive and multiplicative batch effect per batch and cell
gamma_star, delta_star = [], []
for i, batch_idxs in enumerate(batch_info):
for i, batch_idxs in enumerate(batch_info.values()):
# temp stores our estimates for the batch effect parameters.
# temp[0] is the additive batch effect
# temp[1] is the multiplicative batch effect
Expand All @@ -257,7 +271,7 @@ def combat( # noqa: PLR0915

# we now apply the parametric adjustment to the standardized data from above
# loop over all batches in the data
for j, batch_idxs in enumerate(batch_info):
for j, batch_idxs in enumerate(batch_info.values()):
# we basically subtract the additive batch effect, rescale by the ratio
# of multiplicative batch effect to pooled variance and add the overall gene
# wise mean
Expand Down
14 changes: 14 additions & 0 deletions tests/test_combat.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,20 @@ def test_combat_obs_names():
assert_equal(a, b)


def test_combat_single_cell_batch():
    """Check that combat rejects a batch containing fewer than 2 cells.

    Regression test for https://github.com/scverse/scanpy/issues/1175
    """
    adata = sc.datasets.blobs()
    # Exactly one cell is labelled "single"; every remaining cell is "other",
    # so the "single" batch is too small for a within-batch variance estimate.
    labels = ["single", *["other"] * (adata.n_obs - 1)]
    adata.obs["batch"] = pd.Categorical(labels)

    with pytest.raises(ValueError, match="fewer than 2 cells"):
        sc.pp.combat(adata, key="batch")


def test_silhouette():
# this test checks whether combat can align data from several gaussians
# it checks this by computing the silhouette coefficient in a pca embedding
Expand Down
Loading