@@ -202,10 +202,22 @@ def combat( # noqa: PLR0915
202202 sanitize_anndata (adata )
203203
204204 # construct a pandas series of the batch annotation
205- model = adata .obs [[key , * (covariates if covariates else [])]]
206- batch_info = model .groupby (key , observed = True ).indices . values ()
205+ model : pd . DataFrame = adata .obs [[key , * (covariates if covariates else [])]]
206+ batch_info = model .groupby (key , observed = True ).indices
207207 n_batch = len (batch_info )
208- n_batches = np .array ([len (v ) for v in batch_info ])
208+ n_batches = np .array ([len (v ) for v in batch_info .values ()])
209+
210+ # check for batches with fewer than 2 cells
211+ small_batches = [
212+ batch for batch , size in zip (batch_info , n_batches , strict = True ) if size < 2
213+ ]
214+ if small_batches :
215+ msg = (
216+ f"Batches { small_batches !r} have fewer than 2 cells. "
217+ "ComBat requires at least 2 cells per batch to estimate "
218+ "within-batch variance. Filter these batches before running combat."
219+ )
220+ raise ValueError (msg )
209221 n_array = float (sum (n_batches ))
210222
211223 # standardize across genes using a pooled variance estimator
@@ -220,7 +232,9 @@ def combat( # noqa: PLR0915
220232 la .inv (batch_design .T @ batch_design ) @ batch_design .T @ s_data .T
221233 ).values
222234 # first estimate for the multiplicative batch effect
223- delta_hat = [s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info ]
235+ delta_hat = [
236+ s_data .iloc [:, batch_idxs ].var (axis = 1 ) for batch_idxs in batch_info .values ()
237+ ]
224238
225239 # empirically fix the prior hyperparameters
226240 gamma_bar = gamma_hat .mean (axis = 1 )
@@ -233,7 +247,7 @@ def combat( # noqa: PLR0915
233247 # gamma star and delta star will be our empirical bayes (EB) estimators
234248 # for the additive and multiplicative batch effect per batch and cell
235249 gamma_star , delta_star = [], []
236- for i , batch_idxs in enumerate (batch_info ):
250+ for i , batch_idxs in enumerate (batch_info . values () ):
237251 # temp stores our estimates for the batch effect parameters.
238252 # temp[0] is the additive batch effect
239253 # temp[1] is the multiplicative batch effect
@@ -257,7 +271,7 @@ def combat( # noqa: PLR0915
257271
258272 # we now apply the parametric adjustment to the standardized data from above
259273 # loop over all batches in the data
260- for j , batch_idxs in enumerate (batch_info ):
274+ for j , batch_idxs in enumerate (batch_info . values () ):
261275 # we basically subtract the additive batch effect, rescale by the ratio
262276 # of multiplicative batch effect to pooled variance and add the overall gene
263277 # wise mean
0 commit comments