Skip to content

Commit cc001a0

Browse files
Backport PR #3994 on branch 1.12.x (fix: raise ValueError when combat batch has fewer than 2 cells) (#4001)
Co-authored-by: LiudengZhang <99156394+LiudengZhang@users.noreply.github.com>
1 parent fec6eb4 commit cc001a0

3 files changed

Lines changed: 35 additions & 6 deletions

File tree

docs/release-notes/3994.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{func}`scanpy.pp.combat` now raises a {class}`ValueError` when a batch contains fewer than 2 cells, instead of silently producing NaN values in the corrected data {smaller}`L Zhang`

src/scanpy/preprocessing/_combat.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,22 @@ def combat( # noqa: PLR0915
203203
sanitize_anndata(adata)
204204

205205
# construct a pandas DataFrame holding the batch annotation and any covariates
206-
model = adata.obs[[key, *(covariates if covariates else [])]]
207-
batch_info = model.groupby(key, observed=True).indices.values()
206+
model: pd.DataFrame = adata.obs[[key, *(covariates if covariates else [])]]
207+
batch_info = model.groupby(key, observed=True).indices
208208
n_batch = len(batch_info)
209-
n_batches = np.array([len(v) for v in batch_info])
209+
n_batches = np.array([len(v) for v in batch_info.values()])
210+
211+
# check for batches with fewer than 2 cells
212+
small_batches = [
213+
batch for batch, size in zip(batch_info, n_batches, strict=True) if size < 2
214+
]
215+
if small_batches:
216+
msg = (
217+
f"Batches {small_batches!r} have fewer than 2 cells. "
218+
"ComBat requires at least 2 cells per batch to estimate "
219+
"within-batch variance. Filter these batches before running combat."
220+
)
221+
raise ValueError(msg)
210222
n_array = float(sum(n_batches))
211223

212224
# standardize across genes using a pooled variance estimator
@@ -221,7 +233,9 @@ def combat( # noqa: PLR0915
221233
la.inv(batch_design.T @ batch_design) @ batch_design.T @ s_data.T
222234
).values
223235
# first estimate for the multiplicative batch effect
224-
delta_hat = [s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info]
236+
delta_hat = [
237+
s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info.values()
238+
]
225239

226240
# empirically fix the prior hyperparameters
227241
gamma_bar = gamma_hat.mean(axis=1)
@@ -234,7 +248,7 @@ def combat( # noqa: PLR0915
234248
# gamma star and delta star will be our empirical bayes (EB) estimators
235249
# for the additive and multiplicative batch effect per batch and cell
236250
gamma_star, delta_star = [], []
237-
for i, batch_idxs in enumerate(batch_info):
251+
for i, batch_idxs in enumerate(batch_info.values()):
238252
# temp stores our estimates for the batch effect parameters.
239253
# temp[0] is the additive batch effect
240254
# temp[1] is the multiplicative batch effect
@@ -258,7 +272,7 @@ def combat( # noqa: PLR0915
258272

259273
# we now apply the parametric adjustment to the standardized data from above
260274
# loop over all batches in the data
261-
for j, batch_idxs in enumerate(batch_info):
275+
for j, batch_idxs in enumerate(batch_info.values()):
262276
# we basically subtract the additive batch effect, rescale by the ratio
263277
# of multiplicative batch effect to pooled variance and add the overall gene
264278
# wise mean

tests/test_combat.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,20 @@ def test_combat_obs_names():
7575
assert_equal(a, b)
7676

7777

78+
def test_combat_single_cell_batch():
79+
"""Test that combat raises an error when a batch has fewer than 2 cells.
80+
81+
Regression test for https://github.com/scverse/scanpy/issues/1175
82+
"""
83+
adata = sc.datasets.blobs()
84+
# Create a batch where one category has only 1 cell
85+
batch = pd.Categorical(["single"] + ["other"] * (adata.n_obs - 1))
86+
adata.obs["batch"] = batch
87+
88+
with pytest.raises(ValueError, match="fewer than 2 cells"):
89+
sc.pp.combat(adata, key="batch")
90+
91+
7892
def test_silhouette():
7993
# this test checks whether combat can align data from several gaussians
8094
# it checks this by computing the silhouette coefficient in a pca embedding

0 commit comments

Comments
 (0)