Skip to content

Commit 9436132

Browse files
committed
correlation filter for enrich_vs_all
1 parent b75b623 commit 9436132

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

grassp/preprocessing/enrichment.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .simple import aggregate_samples
99
import scipy.stats as stats
1010
import numpy as np
11+
import warnings
1112

1213

1314
def calculate_enrichment_vs_untagged(
@@ -152,44 +153,55 @@ def calculate_enrichment_vs_all(
152153

153154
data = adata.copy()
154155

155-
if covariates is None:
156-
covariates = data.var.columns[data.var.columns.str.startswith("covariate_")]
157-
else:
158-
# Check that all covariates are in the data
159-
for c in covariates:
160-
if c not in data.var.columns:
161-
raise ValueError(f"Covariate {c} not found in data.var.columns")
156+
# if covariates is None:
157+
# covariates = data.var.columns[data.var.columns.str.startswith("covariate_")]
158+
# else:
159+
# Check that all covariates are in the data
160+
for c in covariates:
161+
if c not in data.var.columns:
162+
raise ValueError(f"Covariate {c} not found in data.var.columns")
162163

164+
if not isinstance(covariates, list):
165+
covariates = [covariates]
163166
# Create aggregated data with the desired output shape
164-
grouping_columns = [subcellular_enrichment_column] + covariates.tolist()
167+
grouping_columns = [subcellular_enrichment_column] + covariates
165168
# Create a temporary column that contains the experimental conditions
166169
data.var["_experimental_condition"] = data.var[grouping_columns].apply(
167170
lambda x: "_".join(x.dropna().astype(str)),
168171
axis=1,
169172
)
170173

171-
data_aggr = aggregate_samples(data, grouping_columns=grouping_columns)
174+
data_aggr = aggregate_samples(
175+
data, grouping_columns=grouping_columns, keep_raw=False
176+
)
172177
data_aggr.var_names = data_aggr.var_names.str.replace(r"_\d+", "", regex=True)
173178

174179
if original_intensities_key is not None:
175180
data_aggr.layers[original_intensities_key] = data_aggr.X
176181
data_aggr.layers["pvals"] = np.zeros_like(data_aggr.X)
182+
data_aggr.var["enriched_vs"] = ""
177183

178-
corr_matrix = np.corrcoef(data_aggr.X.T)
184+
intensities = data_aggr.X.copy()
185+
corr_matrix = np.corrcoef(intensities.T)
179186

180187
for experimental_condition in data_aggr.var["_experimental_condition"].unique():
181-
mask = data_aggr.var["_experimental_condition"] != experimental_condition
182-
intensities_control = data_aggr[:, mask].X
183-
intensities_ip = data_aggr[:, ~mask].X
184-
intensities_ip = data[
185-
:, data.var["_experimental_condition"] == experimental_condition
186-
].X
188+
mask = data_aggr.var["_experimental_condition"] == experimental_condition
189+
corr_mat_sub = corr_matrix[mask, :].mean(axis=0)
190+
control_mask = ~mask & (corr_mat_sub < correlation_threshold)
191+
if control_mask.sum() < 10:
192+
warnings.warn(
193+
f"Less than 10 ({control_mask.sum()}) control samples found for condition: {experimental_condition}"
194+
)
195+
intensities_control = intensities[:, control_mask]
196+
intensities_ip = intensities[:, mask]
187197
scores, pv = stats.ttest_ind(intensities_ip.T, intensities_control.T)
188198
lfc = np.median(intensities_ip, axis=1) - np.median(intensities_control, axis=1)
189199
aggr_mask = data_aggr.var["_experimental_condition"] == experimental_condition
190200
data_aggr.layers["pvals"][:, aggr_mask] = pv[:, None]
191201
data_aggr[:, aggr_mask].X = lfc[:, None]
192-
202+
data_aggr.var.loc[aggr_mask, "enriched_vs"] = ",".join(
203+
data_aggr.var_names[control_mask]
204+
)
193205
data_aggr.var.drop(columns=["_experimental_condition"], inplace=True)
194206
if keep_raw:
195207
data_aggr.raw = data.copy()

0 commit comments

Comments
 (0)