Skip to content

Commit 93618fb

Browse files
committed
Re-Added to_numpy()
1 parent 6c6940f commit 93618fb

3 files changed

Lines changed: 38 additions & 46 deletions

File tree

benchmarks/benchmarks/preprocessing_log.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,11 @@
99
from typing import TYPE_CHECKING
1010

1111
import anndata as ad
12+
import numpy as np
1213

1314
import scanpy as sc
1415

15-
from ._utils import get_dataset, param_skipper
16+
from ._utils import get_dataset, param_skipper, pbmc3k
1617

1718
if TYPE_CHECKING:
1819
from ._utils import Dataset, KeyX
@@ -72,3 +73,26 @@ def time_scale(self, *_) -> None:
7273

7374
def peakmem_scale(self, *_) -> None:
7475
sc.pp.scale(self.adata, max_value=10)
76+
77+
78+
class CombatSuite:
79+
"""Benchmark combat batch correction."""
80+
81+
def setup_cache(self) -> None:
82+
adata = pbmc3k()
83+
sc.pp.highly_variable_genes(adata, n_top_genes=500)
84+
adata = adata[:, adata.var["highly_variable"]].copy()
85+
sc.pp.scale(adata, max_value=10)
86+
# assign cells to 3 batches deterministically
87+
rng = np.random.default_rng(0)
88+
adata.obs["batch"] = rng.choice(["A", "B", "C"], size=adata.n_obs)
89+
adata.write_h5ad("adata_combat.h5ad")
90+
91+
def setup(self) -> None:
92+
self.adata = ad.read_h5ad("adata_combat.h5ad")
93+
94+
def time_combat(self) -> None:
95+
sc.pp.combat(self.adata, key="batch")
96+
97+
def peakmem_combat(self) -> None:
98+
sc.pp.combat(self.adata, key="batch")

benchmarks/benchmarks/tools.py

Lines changed: 1 addition & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import scanpy as sc
1111

12-
from ._utils import pbmc3k, pbmc68k_reduced
12+
from ._utils import pbmc68k_reduced
1313

1414

1515
class ToolsSuite: # noqa: D101
@@ -44,28 +44,3 @@ def time_rank_genes_groups(self) -> None:
4444

4545
def peakmem_rank_genes_groups(self) -> None:
4646
sc.tl.rank_genes_groups(self.adata, "bulk_labels", method="wilcoxon")
47-
48-
49-
class CombatSuite:
50-
"""Benchmark combat batch correction."""
51-
52-
def setup_cache(self) -> None:
53-
import numpy as np
54-
55-
adata = pbmc3k()
56-
sc.pp.highly_variable_genes(adata, n_top_genes=500)
57-
adata = adata[:, adata.var["highly_variable"]].copy()
58-
sc.pp.scale(adata, max_value=10)
59-
# assign cells to 3 batches deterministically
60-
np.random.seed(0)
61-
adata.obs["batch"] = np.random.choice(["A", "B", "C"], size=adata.n_obs)
62-
adata.write_h5ad("adata_combat.h5ad")
63-
64-
def setup(self) -> None:
65-
self.adata = ad.read_h5ad("adata_combat.h5ad")
66-
67-
def time_combat(self) -> None:
68-
sc.pp.combat(self.adata, key="batch")
69-
70-
def peakmem_combat(self) -> None:
71-
sc.pp.combat(self.adata, key="batch")

src/scanpy/preprocessing/_combat.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -105,21 +105,17 @@ def _standardize_data(
105105

106106
design = _design_matrix(model, batch_key, batch_levels)
107107

108-
# use numpy .values extraction only once to avoid pandas overhead
109-
design_arr = design.values
110108
# compute pooled variance estimator
111-
b_hat = np.dot(
112-
np.dot(la.inv(np.dot(design_arr.T, design_arr)), design_arr.T), data.values.T
113-
)
109+
b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
114110
grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
115-
var_pooled = (data.values - np.dot(design_arr, b_hat).T) ** 2
111+
var_pooled = np.asarray((data - np.dot(design, b_hat).T) ** 2)
116112
var_pooled = np.mean(var_pooled, axis=1, keepdims=True)
117113

118114
# Compute the means
119115
if np.sum(var_pooled == 0) > 0:
120116
print(f"Found {np.sum(var_pooled == 0)} genes with zero variance.")
121-
stand_mean = grand_mean[:, np.newaxis]
122-
tmp = design_arr.copy()
117+
stand_mean = np.asarray(grand_mean)[:, np.newaxis]
118+
tmp = np.array(design.copy())
123119
tmp[:, :n_batch] = 0
124120
stand_mean = stand_mean + np.dot(tmp, b_hat).T
125121

@@ -128,7 +124,7 @@ def _standardize_data(
128124
s_data = np.where(
129125
var_pooled == 0,
130126
0,
131-
(data.values - stand_mean) / np.sqrt(var_pooled),
127+
(np.asarray(data) - stand_mean) / np.sqrt(var_pooled),
132128
)
133129
s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)
134130

@@ -272,27 +268,24 @@ def combat( # noqa: PLR0915
272268

273269
# we now apply the parametric adjustment to the standardized data from above
274270
# loop over all batches in the data
275-
bayesdata_arr = bayesdata.to_numpy(copy=True)
276-
batch_design_arr = batch_design.values
277271
for j, batch_idxs in enumerate(batch_info.values()):
278272
# we basically subtract the additive batch effect, rescale by the ratio
279273
# of multiplicative batch effect to pooled variance and add the overall gene
280274
# wise mean
281275
dsq = np.sqrt(delta_star[j, :])
282-
numer = (
283-
bayesdata_arr[:, batch_idxs]
284-
- np.dot(batch_design_arr[batch_idxs], gamma_star).T
276+
numer = np.array(
277+
bayesdata.iloc[:, batch_idxs]
278+
- np.dot(batch_design.iloc[batch_idxs], gamma_star).T
285279
)
286-
bayesdata_arr[:, batch_idxs] = numer / dsq[:, np.newaxis]
280+
bayesdata.iloc[:, batch_idxs] = numer / dsq[:, np.newaxis]
287281

288-
bayesdata_arr = bayesdata_arr * np.sqrt(var_pooled) + stand_mean
282+
bayesdata = bayesdata * np.sqrt(var_pooled) + stand_mean
289283

290284
# put back into the adata object or return
291-
x = bayesdata.to_numpy().transpose()
292285
if inplace:
293-
adata.X = bayesdata_arr.T
286+
adata.X = bayesdata.to_numpy().transpose()
294287
return None
295-
return bayesdata_arr.T
288+
return bayesdata.to_numpy().transpose()
296289

297290

298291
def _it_sol(

0 commit comments

Comments
 (0)