@@ -108,25 +108,23 @@ def _standardize_data(
108108 # compute pooled variance estimator
109109 b_hat = np .dot (np .dot (la .inv (np .dot (design .T , design )), design .T ), data .T )
110110 grand_mean = np .dot ((n_batches / n_array ).T , b_hat [:n_batch , :])
111- var_pooled = (data - np .dot (design , b_hat ).T ) ** 2
112- var_pooled = np .dot (var_pooled , np .ones ((int (n_array ), 1 )) / int (n_array ))
111+ var_pooled = (
112+ (data - np .dot (design , b_hat ).T ).pow (2 ).to_numpy ().mean (axis = 1 , keepdims = True )
113+ )
113114
114115 # Compute the means
115116 if np .sum (var_pooled == 0 ) > 0 :
116117 print (f"Found { np .sum (var_pooled == 0 )} genes with zero variance." )
117- stand_mean = np .dot (
118- grand_mean .T .reshape ((len (grand_mean ), 1 )), np .ones ((1 , int (n_array )))
119- )
120- tmp = np .array (design .copy ())
118+ tmp = design .to_numpy (copy = True )
121119 tmp [:, :n_batch ] = 0
122- stand_mean += np .dot (tmp , b_hat ).T
120+ stand_mean = grand_mean [:, np . newaxis ] + np .dot (tmp , b_hat ).T
123121
124122 # need to be a bit careful with the zero variance genes
125123 # just set the zero variance genes to zero in the standardized data
126124 s_data = np .where (
127125 var_pooled == 0 ,
128126 0 ,
129- (( data - stand_mean ) / np .dot ( np . sqrt (var_pooled ), np . ones (( 1 , int ( n_array )))) ),
127+ (data - stand_mean ) / np .sqrt (var_pooled ),
130128 )
131129 s_data = pd .DataFrame (s_data , index = data .index , columns = data .columns )
132130
@@ -219,7 +217,6 @@ def combat( # noqa: PLR0915
219217 "within-batch variance. Filter these batches before running combat."
220218 )
221219 raise ValueError (msg )
222- n_array = float (sum (n_batches ))
223220
224221 # standardize across genes using a pooled variance estimator
225222 logg .info ("Standardizing Data across genes.\n " )
@@ -277,16 +274,13 @@ def combat( # noqa: PLR0915
277274 # of multiplicative batch effect to pooled variance and add the overall gene
278275 # wise mean
279276 dsq = np .sqrt (delta_star [j , :])
280- dsq = dsq .reshape ((len (dsq ), 1 ))
281- denom = np .dot (dsq , np .ones ((1 , n_batches [j ])))
282277 numer = np .array (
283278 bayesdata .iloc [:, batch_idxs ]
284279 - np .dot (batch_design .iloc [batch_idxs ], gamma_star ).T
285280 )
286- bayesdata .iloc [:, batch_idxs ] = numer / denom
281+ bayesdata .iloc [:, batch_idxs ] = numer / dsq [:, np . newaxis ]
287282
288- vpsq = np .sqrt (var_pooled ).reshape ((len (var_pooled ), 1 ))
289- bayesdata = bayesdata * np .dot (vpsq , np .ones ((1 , int (n_array )))) + stand_mean
283+ bayesdata = bayesdata * np .sqrt (var_pooled ) + stand_mean
290284
291285 # put back into the adata object or return
292286 if inplace :
@@ -348,12 +342,7 @@ def _it_sol(
348342 # in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
349343 while change > conv :
350344 g_new = (t2 * n * g_hat + d_old * g_bar ) / (t2 * n + d_old )
351- sum2 = s_data - g_new .reshape ((g_new .shape [0 ], 1 )) @ np .ones ((
352- 1 ,
353- s_data .shape [1 ],
354- ))
355- sum2 = sum2 ** 2
356- sum2 = sum2 .sum (axis = 1 )
345+ sum2 = ((s_data - g_new [:, np .newaxis ]) ** 2 ).sum (axis = 1 )
357346 d_new = (0.5 * sum2 + b ) / (n / 2.0 + a - 1.0 )
358347
359348 change = max (
0 commit comments