forked from scverse/scanpy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path_combat.py
More file actions
377 lines (310 loc) · 12.1 KB
/
_combat.py
File metadata and controls
377 lines (310 loc) · 12.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from numpy import linalg as la
from .. import logging as logg
from .._compat import CSBase
from .._utils import sanitize_anndata
if TYPE_CHECKING:
from collections.abc import Collection
from anndata import AnnData
def _design_matrix(
    model: pd.DataFrame, batch_key: str, batch_levels: Collection[str]
) -> pd.DataFrame:
    """Compute a simple design matrix.

    The batch indicator columns come first, followed by the encoding of any
    non-numerical covariates and finally the numerical covariates, which are
    copied over unchanged.

    Parameters
    ----------
    model
        Contains the batch annotation
    batch_key
        Name of the batch column
    batch_levels
        Levels of the batch annotation

    Returns
    -------
    The design matrix for the regression problem
    """
    import patsy

    # NOTE: the formula refers to `batch_levels` by name; patsy is expected to
    # resolve it from this function's scope, so the parameter must not be renamed.
    design = patsy.dmatrix(
        f"~ 0 + C(Q('{batch_key}'), levels=batch_levels)",
        model,
        return_type="dataframe",
    )
    model = model.drop([batch_key], axis=1)
    numerical_covariates = model.select_dtypes("number").columns.values

    logg.info(f"Found {design.shape[1]} batches\n")

    # every remaining non-numerical column is treated as a categorical factor
    numerical_names = set(numerical_covariates)
    other_cols = [c for c in model.columns.values if c not in numerical_names]

    if other_cols:
        col_repr = " + ".join(f"Q('{x}')" for x in other_cols)
        factor_matrix = patsy.dmatrix(
            f"~ 0 + {col_repr}", model[other_cols], return_type="dataframe"
        )
        design = pd.concat((design, factor_matrix), axis=1)
        logg.info(f"Found {len(other_cols)} categorical variables:")
        logg.info(f"\t{', '.join(other_cols)}\n")

    # `numerical_covariates` is always an array, so test for emptiness rather
    # than for `None`; otherwise an empty "Found 0 numerical variables" is logged.
    if len(numerical_covariates) > 0:
        logg.info(f"Found {len(numerical_covariates)} numerical variables:")
        logg.info(f"\t{', '.join(numerical_covariates)}\n")

        for n_c in numerical_covariates:
            design[n_c] = model[n_c]

    return design
def _standardize_data(
    model: pd.DataFrame, data: pd.DataFrame, batch_key: str
) -> tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray]:
    """Standardize the data per gene.

    The aim here is to make mean and variance be comparable across batches.

    Parameters
    ----------
    model
        Contains the batch annotation
    data
        Contains the Data (one row per gene, one column per cell)
    batch_key
        Name of the batch column in the model matrix

    Returns
    -------
    s_data
        Standardized Data
    design
        Batch assignment as one-hot encodings
    var_pooled
        Pooled variance per gene
    stand_mean
        Gene-wise mean
    """
    # compute the design matrix
    batch_items = model.groupby(batch_key, observed=True).groups.items()
    batch_levels, batch_info = zip(*batch_items, strict=True)
    n_batch = len(batch_info)
    n_batches = np.array([len(v) for v in batch_info])
    n_array = float(sum(n_batches))
    n_cells = int(n_array)

    design = _design_matrix(model, batch_key, batch_levels)

    # compute pooled variance estimator
    # b_hat holds the least-squares coefficients of the design for every gene
    b_hat = np.dot(np.dot(la.inv(np.dot(design.T, design)), design.T), data.T)
    # batch-size weighted average of the per-batch coefficients
    grand_mean = np.dot((n_batches / n_array).T, b_hat[:n_batch, :])
    var_pooled = (data - np.dot(design, b_hat).T) ** 2
    var_pooled = np.dot(var_pooled, np.ones((n_cells, 1)) / n_cells)

    # Compute the means
    n_zero_var = int(np.sum(var_pooled == 0))
    if n_zero_var > 0:
        # report through the module logger like every other message in this file
        logg.warning(f"Found {n_zero_var} genes with zero variance.")
    stand_mean = np.dot(
        grand_mean.T.reshape((len(grand_mean), 1)), np.ones((1, n_cells))
    )
    # add the covariate contribution: zero out the batch columns so that only
    # the remaining design columns contribute to the mean
    tmp = np.array(design.copy())
    tmp[:, :n_batch] = 0
    stand_mean += np.dot(tmp, b_hat).T

    # need to be a bit careful with the zero variance genes
    # just set the zero variance genes to zero in the standardized data
    s_data = np.where(
        var_pooled == 0,
        0,
        ((data - stand_mean) / np.dot(np.sqrt(var_pooled), np.ones((1, n_cells)))),
    )
    s_data = pd.DataFrame(s_data, index=data.index, columns=data.columns)

    return s_data, design, var_pooled, stand_mean
def combat(  # noqa: PLR0915
    adata: AnnData,
    key: str = "batch",
    *,
    covariates: Collection[str] | None = None,
    inplace: bool = True,
) -> np.ndarray | None:
    """ComBat function for batch effect correction :cite:p:`Johnson2006,Leek2012,Pedersen2012`.

    Corrects for batch effects by fitting linear models, gains statistical power
    via an EB framework where information is borrowed across genes.
    This uses the implementation `combat.py`_ :cite:p:`Pedersen2012`.

    .. _combat.py: https://github.com/brentp/combat.py

    .. array-support:: pp.combat

    Parameters
    ----------
    adata
        Annotated data matrix
    key
        Key to a categorical annotation from :attr:`~anndata.AnnData.obs`
        that will be used for batch effect removal.
    covariates
        Additional covariates besides the batch variable such as adjustment
        variables or biological condition. This parameter refers to the design
        matrix `X` in Equation 2.1 in :cite:t:`Johnson2006` and to the `mod` argument in
        the original combat function in the sva R package.
        Note that not including covariates may introduce bias or lead to the
        removal of biological signal in unbalanced designs.
    inplace
        Whether to replace adata.X or to return the corrected data

    Returns
    -------
    Returns :class:`numpy.ndarray` if `inplace=False`, else returns `None` and sets the following field in the `adata` object:

    `adata.X` : :class:`numpy.ndarray` (dtype `float`)
        Corrected data matrix.

    Raises
    ------
    ValueError
        If `key` or a covariate is missing from `adata.obs`, if `key` is also
        listed as a covariate, if covariates are not unique, if the batch or
        covariate columns contain missing values, or if a batch has fewer
        than 2 cells.
    """
    # check the input
    if key not in adata.obs:
        msg = f"Could not find the key {key!r} in adata.obs"
        raise ValueError(msg)

    if covariates is not None:
        cov_exist = np.isin(covariates, adata.obs.columns)
        if np.any(~cov_exist):
            missing_cov = np.array(covariates)[~cov_exist].tolist()
            msg = f"Could not find the covariate(s) {missing_cov!r} in adata.obs"
            raise ValueError(msg)

        if key in covariates:
            msg = "Batch key and covariates cannot overlap"
            raise ValueError(msg)

        if len(covariates) != len(set(covariates)):
            msg = "Covariates must be unique"
            raise ValueError(msg)

    # only works on dense matrices so far
    x = adata.X.toarray().T if isinstance(adata.X, CSBase) else adata.X.T
    # genes as rows, cells as columns
    data = pd.DataFrame(data=x, index=adata.var_names, columns=adata.obs_names)

    sanitize_anndata(adata)

    # construct a pandas data frame of the batch annotation and the covariates
    model: pd.DataFrame = adata.obs[[key, *(covariates if covariates else [])]]

    # Cells without a batch label are dropped by the group-by below and missing
    # covariate values end up in the design matrix; both lead to shape errors
    # or NaN results further down, so fail early with a clear message.
    cols_with_na = model.columns[model.isna().any().to_numpy()].tolist()
    if cols_with_na:
        msg = (
            f"Found missing values in adata.obs column(s) {cols_with_na!r}. "
            "ComBat requires a batch label and covariate values for every cell."
        )
        raise ValueError(msg)

    # maps each batch level to the positional indices of its cells
    batch_info = model.groupby(key, observed=True).indices
    n_batch = len(batch_info)
    n_batches = np.array([len(v) for v in batch_info.values()])

    # check for batches with fewer than 2 cells
    small_batches = [
        batch for batch, size in zip(batch_info, n_batches, strict=True) if size < 2
    ]
    if small_batches:
        msg = (
            f"Batches {small_batches!r} have fewer than 2 cells. "
            "ComBat requires at least 2 cells per batch to estimate "
            "within-batch variance. Filter these batches before running combat."
        )
        raise ValueError(msg)

    n_array = float(sum(n_batches))

    # standardize across genes using a pooled variance estimator
    logg.info("Standardizing Data across genes.\n")
    s_data, design, var_pooled, stand_mean = _standardize_data(model, data, key)

    # fitting the parameters on the standardized data
    logg.info("Fitting L/S model and finding priors\n")
    # the first `n_batch` design columns are the batch indicators
    batch_design = design[design.columns[:n_batch]]
    # first estimate of the additive batch effect
    gamma_hat = (
        la.inv(batch_design.T @ batch_design) @ batch_design.T @ s_data.T
    ).values
    # first estimate for the multiplicative batch effect
    delta_hat = [
        s_data.iloc[:, batch_idxs].var(axis=1) for batch_idxs in batch_info.values()
    ]

    # empirically fix the prior hyperparameters
    gamma_bar = gamma_hat.mean(axis=1)
    t2 = gamma_hat.var(axis=1)
    # a_prior and b_prior are the priors on lambda and theta from Johnson and Li (2006)
    a_prior = list(map(_aprior, delta_hat))
    b_prior = list(map(_bprior, delta_hat))

    logg.info("Finding parametric adjustments\n")
    # gamma star and delta star will be our empirical bayes (EB) estimators
    # for the additive and multiplicative batch effect per batch and gene
    gamma_star, delta_star = [], []
    for i, batch_idxs in enumerate(batch_info.values()):
        # gamma is the additive, delta the multiplicative batch effect
        # estimated for batch i
        gamma, delta = _it_sol(
            s_data.iloc[:, batch_idxs].values,
            gamma_hat[i],
            delta_hat[i].values,
            g_bar=gamma_bar[i],
            t2=t2[i],
            a=a_prior[i],
            b=b_prior[i],
        )

        gamma_star.append(gamma)
        delta_star.append(delta)

    logg.info("Adjusting data\n")
    # NOTE: this is an alias, the standardized data is adjusted in place
    bayesdata = s_data
    gamma_star = np.array(gamma_star)
    delta_star = np.array(delta_star)

    # we now apply the parametric adjustment to the standardized data from above
    # loop over all batches in the data
    for j, batch_idxs in enumerate(batch_info.values()):
        # we basically subtract the additive batch effect, rescale by the ratio
        # of multiplicative batch effect to pooled variance and add the overall gene
        # wise mean
        dsq = np.sqrt(delta_star[j, :])
        dsq = dsq.reshape((len(dsq), 1))
        denom = np.dot(dsq, np.ones((1, n_batches[j])))
        numer = np.array(
            bayesdata.iloc[:, batch_idxs]
            - np.dot(batch_design.iloc[batch_idxs], gamma_star).T
        )
        bayesdata.iloc[:, batch_idxs] = numer / denom

    # undo the standardization: rescale by the pooled standard deviation and
    # add the gene-wise mean back
    vpsq = np.sqrt(var_pooled).reshape((len(var_pooled), 1))
    bayesdata = bayesdata * np.dot(vpsq, np.ones((1, int(n_array)))) + stand_mean

    # put back into the adata object or return
    corrected = bayesdata.values.transpose()
    if inplace:
        adata.X = corrected
        return None
    return corrected
def _it_sol(
    s_data: np.ndarray,
    g_hat: np.ndarray,
    d_hat: np.ndarray,
    *,
    g_bar: float,
    t2: float,
    a: float,
    b: float,
    conv: float = 0.0001,
) -> tuple[np.ndarray, np.ndarray]:
    """Iteratively compute the conditional posterior means for gamma and delta.

    gamma is an estimator for the additive batch effect, delta is an estimator
    for the multiplicative batch effect. We use an EB framework to estimate these
    two. Analytical expressions exist for both parameters, which however depend on each other.
    We therefore iteratively evaluate these two expressions until convergence is reached.

    Parameters
    ----------
    s_data
        Contains the standardized Data
    g_hat
        Initial guess for gamma
    d_hat
        Initial guess for delta
    g_bar, t2, a, b
        Hyperparameters
    conv: float, optional (default: `0.0001`)
        convergence criterion: iteration stops once the largest relative
        change of gamma and delta between two passes is at most this value

    Returns
    -------
    gamma
        estimated value for gamma
    delta
        estimated value for delta
    """  # noqa: D401
    # number of non-missing observations per row of `s_data`
    n = (1 - np.isnan(s_data)).sum(axis=1)
    g_old = g_hat.copy()
    d_old = d_hat.copy()
    # returned as-is should the loop body never run (e.g. `conv >= 1`)
    g_new, d_new = g_old, d_old

    change = 1
    # we place a normally distributed prior on gamma and an inverse gamma prior on delta
    # in the loop, gamma and delta are updated together. they depend on each other. we iterate until convergence.
    while change > conv:
        g_new = (t2 * n * g_hat + d_old * g_bar) / (t2 * n + d_old)
        # squared residuals around the current gamma, summed per row
        sum2 = ((s_data - g_new[:, np.newaxis]) ** 2).sum(axis=1)
        d_new = (0.5 * sum2 + b) / (n / 2.0 + a - 1.0)

        # NOTE(review): the denominators are signed, so rows with a negative
        # previous estimate yield a negative ratio and never drive `change`
        # -- confirm this is intended.
        with np.errstate(divide="ignore", invalid="ignore"):
            g_change = np.abs(g_new - g_old) / g_old
            d_change = np.abs(d_new - d_old) / d_old
        # An estimate that is exactly zero and stays zero gives 0/0 = NaN (this
        # happens for zero-variance genes). A NaN would make `change > conv`
        # False and stop the iteration for *all* rows after a single pass, so
        # NaN ratios count as "no change" instead. A change away from exactly
        # zero gives inf and keeps the loop going.
        g_change = np.where(np.isnan(g_change), 0.0, g_change)
        d_change = np.where(np.isnan(d_change), 0.0, d_change)
        change = max(g_change.max(), d_change.max())

        g_old, d_old = g_new, d_new

    return g_new, d_new
def _aprior(delta_hat):
    """Return the hyperparameter ``a`` derived from the moments of `delta_hat`.

    With ``m`` the mean and ``s2`` the variance of `delta_hat` (both obtained
    from the object's own ``mean()`` / ``var()`` methods), the result is
    ``(2 * s2 + m**2) / s2``.
    """
    mean, variance = delta_hat.mean(), delta_hat.var()
    numerator = 2 * variance + mean**2
    return numerator / variance
def _bprior(delta_hat):
    """Return the hyperparameter ``b`` derived from the moments of `delta_hat`.

    With ``m`` the mean and ``s2`` the variance of `delta_hat` (both obtained
    from the object's own ``mean()`` / ``var()`` methods), the result is
    ``(m * s2 + m**3) / s2``.
    """
    mean, variance = delta_hat.mean(), delta_hat.var()
    numerator = mean * variance + mean**3
    return numerator / variance