Skip to content

Commit d6f2c51

Browse files
ilaykavpre-commit-ci[bot]flying-sheepCopilot
authored
perf: Combat perf improvements (#4070)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Philipp A. <flying-sheep@web.de> Co-authored-by: Copilot <copilot@github.com>
1 parent d15aede commit d6f2c51

2 files changed

Lines changed: 15 additions & 20 deletions

File tree

benchmarks/benchmarks/tools.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,9 @@ def time_rank_genes_groups(self) -> None:
4444

4545
def peakmem_rank_genes_groups(self) -> None:
4646
sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon")
47+
48+
def time_combat(self) -> None:
49+
sc.pp.combat(self.adata, key="bulk_labels")
50+
51+
def peakmem_combat(self) -> None:
52+
sc.pp.combat(self.adata, key="bulk_labels")

src/scanpy/preprocessing/_combat.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -108,25 +108,23 @@ def _standardize_data(
108108
# compute pooled variance estimator
109109
b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
110110
grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
111-
var_pooled = (data - np.dot(design, b_hat).T) ** 2
112-
var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))
111+
var_pooled = (
112+
(data - np.dot(design, b_hat).T).pow(2).to_numpy().mean(axis=1, keepdims=True)
113+
)
113114

114115
# Compute the means
115116
if np.sum(var_pooled == 0) > 0:
116117
print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.")
117-
stand_mean = np.dot(
118-
grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array)))
119-
)
120-
tmp = np.array(design.copy())
118+
tmp = design.to_numpy(copy=True)
121119
tmp[:, :n_batch] = 0
122-
stand_mean += np.dot(tmp, b_hat).T
120+
stand_mean = grand_mean[:, np.newaxis] + np.dot(tmp, b_hat).T
123121

124122
# need to be a bit careful with the zero variance genes
125123
# just set the zero variance genes to zero in the standardized data
126124
s_data = np.where(
127125
var_pooled == 0,
128126
0,
129-
((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))),
127+
(data - stand_mean) / np.sqrt(var_pooled),
130128
)
131129
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)
132130

@@ -218,7 +216,6 @@ def combat( # noqa: PLR0915
218216
"within-batch variance. Filter these batches before running combat."
219217
)
220218
raise ValueError(msg)
221-
n_array = float(sum(n_batches))
222219

223220
# standardize across genes using a pooled variance estimator
224221
logg.info("Standardizing Data across genes.\n")
@@ -276,16 +273,13 @@ def combat( # noqa: PLR0915
276273
# of multiplicative batch effect to pooled variance and add the overall gene
277274
# wise mean
278275
dsq = np.sqrt(delta_star[j, :])
279-
dsq = dsq.reshape((len(dsq), 1))
280-
denom = np.dot(dsq, np.ones((1, n_batches[j])))
281276
numer = np.array(
282277
bayesdata.iloc[:, batch_idxs]
283278
- np.dot(batch_design.iloc[batch_idxs], gamma_star).T
284279
)
285-
bayesdata.iloc[:, batch_idxs] = numer / denom
280+
bayesdata.iloc[:, batch_idxs] = numer / dsq[:, np.newaxis]
286281

287-
vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
288-
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean
282+
bayesdata = bayesdata * np.sqrt(var_pooled) + stand_mean
289283

290284
# put back into the adata object or return
291285
x = bayesdata.to_numpy().transpose()
@@ -348,12 +342,7 @@ def _it_sol(
348342
# in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
349343
while change > conv:
350344
g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old)
351-
sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones((
352-
1,
353-
s_data.shape[1],
354-
))
355-
sum2 = sum2**2
356-
sum2 = sum2.sum(axis=1)
345+
sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1)
357346
d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)
358347

359348
change = max(

0 commit comments

Comments
 (0)