@@ -203,10 +203,22 @@ def combat( # noqa: PLR0915
203203 sanitize_anndata (adata )
204204
205205 # construct a pandas DataFrame of the batch annotation (plus any covariates)
206- model = adata .obs [[key , * (covariates if covariates else [])]]
207- batch_info = model .groupby (key , observed = True ).indices . values ()
206+ model : pd . DataFrame = adata .obs [[key , * (covariates if covariates else [])]]
207+ batch_info = model .groupby (key , observed = True ).indices
208208 n_batch = len (batch_info )
209- n_batches = np .array ([len (v ) for v in batch_info ])
209+ n_batches = np .array ([len (v ) for v in batch_info .values ()])
210+
211+ # check for batches with fewer than 2 cells
212+ small_batches = [
213+ batch for batch , size in zip (batch_info , n_batches , strict = True ) if size < 2
214+ ]
215+ if small_batches :
216+ msg = (
217+ f"Batches { small_batches !r} have fewer than 2 cells. "
218+ "ComBat requires at least 2 cells per batch to estimate "
219+ "within-batch variance. Filter these batches before running combat."
220+ )
221+ raise ValueError (msg )
210222 n_array = float (sum (n_batches ))
211223
212224 # standardize across genes using a pooled variance estimator
@@ -221,7 +233,9 @@ def combat( # noqa: PLR0915
221233 la .inv (batch_design .T @ batch_design ) @ batch_design .T @ s_data .T
222234 ).values
223235 # first estimate for the multiplicative batch effect
224- delta_hat = [s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info ]
236+ delta_hat = [
237+ s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info .values ()
238+ ]
225239
226240 # empirically fix the prior hyperparameters
227241 gamma_bar = gamma_hat .mean (axis = 1 )
@@ -234,7 +248,7 @@ def combat( # noqa: PLR0915
234248 # gamma star and delta star will be our empirical bayes (EB) estimators
235249 # for the additive and multiplicative batch effect per batch and gene
236250 gamma_star , delta_star = [], []
237- for i , batch_idxs in enumerate (batch_info ):
251+ for i , batch_idxs in enumerate (batch_info . values () ):
238252 # temp stores our estimates for the batch effect parameters.
239253 # temp[0] is the additive batch effect
240254 # temp[1] is the multiplicative batch effect
@@ -258,7 +272,7 @@ def combat( # noqa: PLR0915
258272
259273 # we now apply the parametric adjustment to the standardized data from above
260274 # loop over all batches in the data
261- for j , batch_idxs in enumerate (batch_info ):
275+ for j , batch_idxs in enumerate (batch_info . values () ):
262276 # we basically subtract the additive batch effect, rescale by the ratio
263277 # of multiplicative batch effect to pooled variance and add the overall gene
264278 # wise mean
0 commit comments