Skip to content

Commit 746a05b

Browse files
Backport PR #4070 on branch 1.12.x (perf: Combat perf improvements) (#4085)
Co-authored-by: Ilay Kavitzky <ilay.kavitzky@gmail.com>
1 parent a4ebf32 commit 746a05b

2 files changed

Lines changed: 15 additions & 20 deletions

File tree

benchmarks/benchmarks/tools.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ def time_rank_genes_groups(self) -> None:
4444

4545
def peakmem_rank_genes_groups(self) -> None:
4646
sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon")
47+
48+
def time_combat(self) -> None:
49+
sc.pp.combat(self.adata, key="bulk_labels")
50+
51+
def peakmem_combat(self) -> None:
52+
sc.pp.combat(self.adata, key="bulk_labels")

src/scanpy/preprocessing/_combat.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,23 @@ def _standardize_data(
108108
# compute pooled variance estimator
109109
b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
110110
grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
111-
var_pooled = (data - np.dot(design, b_hat).T) ** 2
112-
var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))
111+
var_pooled = (
112+
(data - np.dot(design, b_hat).T).pow(2).to_numpy().mean(axis=1, keepdims=True)
113+
)
113114

114115
# Compute the means
115116
if np.sum(var_pooled == 0) > 0:
116117
print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.")
117-
stand_mean = np.dot(
118-
grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array)))
119-
)
120-
tmp = np.array(design.copy())
118+
tmp = design.to_numpy(copy=True)
121119
tmp[:, :n_batch] = 0
122-
stand_mean += np.dot(tmp, b_hat).T
120+
stand_mean = grand_mean[:, np.newaxis] + np.dot(tmp, b_hat).T
123121

124122
# need to be a bit careful with the zero variance genes
125123
# just set the zero variance genes to zero in the standardized data
126124
s_data = np.where(
127125
var_pooled == 0,
128126
0,
129-
((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))),
127+
(data - stand_mean) / np.sqrt(var_pooled),
130128
)
131129
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)
132130

@@ -219,7 +217,6 @@ def combat( # noqa: PLR0915
219217
"within-batch variance. Filter these batches before running combat."
220218
)
221219
raise ValueError(msg)
222-
n_array = float(sum(n_batches))
223220

224221
# standardize across genes using a pooled variance estimator
225222
logg.info("Standardizing Data across genes.\n")
@@ -277,16 +274,13 @@ def combat( # noqa: PLR0915
277274
# of multiplicative batch effect to pooled variance and add the overall gene
278275
# wise mean
279276
dsq = np.sqrt(delta_star[j, :])
280-
dsq = dsq.reshape((len(dsq), 1))
281-
denom = np.dot(dsq, np.ones((1, n_batches[j])))
282277
numer = np.array(
283278
bayesdata.iloc[:, batch_idxs]
284279
- np.dot(batch_design.iloc[batch_idxs], gamma_star).T
285280
)
286-
bayesdata.iloc[:, batch_idxs] = numer / denom
281+
bayesdata.iloc[:, batch_idxs] = numer / dsq[:, np.newaxis]
287282

288-
vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
289-
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean
283+
bayesdata = bayesdata * np.sqrt(var_pooled) + stand_mean
290284

291285
# put back into the adata object or return
292286
if inplace:
@@ -348,12 +342,7 @@ def _it_sol(
348342
# in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
349343
while change > conv:
350344
g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old)
351-
sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones((
352-
1,
353-
s_data.shape[1],
354-
))
355-
sum2 = sum2**2
356-
sum2 = sum2.sum(axis=1)
345+
sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1)
357346
d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
358347

359348
change = max(

0 commit comments

Comments
 (0)