Skip to content
Merged
6 changes: 6 additions & 0 deletions benchmarks/benchmarks/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ def time_rank_genes_groups(self) -> None:

def peakmem_rank_genes_groups(self) -> None:
sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon")

def time_combat(self) -> None:
sc.pp.combat(self.adata, key="bulk_labels")

def peakmem_combat(self) -> None:
sc.pp.combat(self.adata, key="bulk_labels")
29 changes: 9 additions & 20 deletions src/scanpy/preprocessing/_combat.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,25 +108,23 @@ def _standardize_data(
# compute pooled variance estimator
b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
var_pooled = (data - np.dot(design, b_hat).T) ** 2
var_pooled = np.dot(var_pooled, np.ones((int(n_array), 1)) / int(n_array))
var_pooled = (
(data - np.dot(design, b_hat).T).pow(2).to_numpy().mean(axis=1, keepdims=True)
)

# Compute the means
if np.sum(var_pooled == 0) > 0:
print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.")
stand_mean = np.dot(
grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, int(n_array)))
)
tmp = np.array(design.copy())
tmp = design.to_numpy(copy=True)
tmp[:, :n_batch] = 0
stand_mean += np.dot(tmp, b_hat).T
stand_mean = grand_mean[:, np.newaxis] + np.dot(tmp, b_hat).T

# need to be a bit careful with the zero variance genes
# just set the zero variance genes to zero in the standardized data
s_data = np.where(
var_pooled == 0,
0,
((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, int(n_array))))),
(data - stand_mean) / np.sqrt(var_pooled),
)
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)

Expand Down Expand Up @@ -218,7 +216,6 @@ def combat( # noqa: PLR0915
"within-batch variance. Filter these batches before running combat."
)
raise ValueError(msg)
n_array = float(sum(n_batches))

# standardize across genes using a pooled variance estimator
logg.info("Standardizing Data across genes.\n")
Expand Down Expand Up @@ -276,16 +273,13 @@ def combat( # noqa: PLR0915
# of multiplicative batch effect to pooled variance and add the overall gene
# wise mean
dsq = np.sqrt(delta_star[j, :])
dsq = dsq.reshape((len(dsq), 1))
denom = np.dot(dsq, np.ones((1, n_batches[j])))
numer = np.array(
bayesdata.iloc[:, batch_idxs]
- np.dot(batch_design.iloc[batch_idxs], gamma_star).T
)
bayesdata.iloc[:, batch_idxs] = numer / denom
bayesdata.iloc[:, batch_idxs] = numer / dsq[:, np.newaxis]

vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean
bayesdata = bayesdata * np.sqrt(var_pooled) + stand_mean

# put back into the adata object or return
x = bayesdata.to_numpy().transpose()
Expand Down Expand Up @@ -348,12 +342,7 @@ def _it_sol(
# in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
while change > conv:
g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old)
sum2 = s_data - g_new.reshape((g_new.shape[0], 1)) @ np.ones((
1,
s_data.shape[1],
))
sum2 = sum2**2
sum2 = sum2.sum(axis=1)
sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1)
d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

change = max(
Expand Down
Loading