Skip to content

Commit ef25311

Browse files
maltekuehlgrst
andauthored
Add permutation test (#726)
* Add permutation test * Fix test kwargs update * Add n_jobs argument and change test to check for significance agreement with both TTest and Wilcoxontest and add seed * Simplify and generalize compare_groups by adding most important permutation arguments, passing others through kwargs * Make permutation_test argument optional but raise warning if not provided * Make test case a bit stricter again for significant values, enable returning statistic from tests and fix bug where the permutation_test was not applied * Remove unnecessary import * Remove parallelization and return statistic and p-value everywhere * Remove parallelization and return statistic and p-value everywhere * Remove parallelization and return statistic and p-value everywhere * Fix docstring and examples of permutation test * Simplify permutation test with callable only * Default on user facing function only Co-authored-by: Gregor Sturm <[email protected]> * Undo last commit Set default value for test_statistic parameter. * Actually revert --------- Co-authored-by: Gregor Sturm <[email protected]>
1 parent c836bef commit ef25311

File tree

6 files changed

+217
-20
lines changed

6 files changed

+217
-20
lines changed

docs/api/tools_index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ Pertpy enables differential gene expression tests through a common interface tha
1919
tools.EdgeR
2020
tools.WilcoxonTest
2121
tools.TTest
22+
tools.PermutationTest
2223
tools.Statsmodels
2324
```
2425

pertpy/tools/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def __getattr__(name: str):
3434
raise ImportError(
3535
"Extra dependencies required: toytree, ete4. Please install with: pip install toytree ete4"
3636
) from None
37-
elif name in ["EdgeR", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
37+
elif name in ["EdgeR", "PermutationTest", "PyDESeq2", "Statsmodels", "TTest", "WilcoxonTest"]:
3838
module = import_module("pertpy.tools._differential_gene_expression")
3939
return getattr(module, name)
4040
elif name == "Scgen":
@@ -63,6 +63,7 @@ def __dir__():
6363
"PyDESeq2",
6464
"WilcoxonTest",
6565
"TTest",
66+
"PermutationTest",
6667
"Statsmodels",
6768
"DistanceTest",
6869
"Distance",

pertpy/tools/_differential_gene_expression/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,8 @@
33
from importlib.util import find_spec
44

55
from ._base import LinearModelBase, MethodBase
6-
from ._dge_comparison import DGEEVAL
76
from ._edger import EdgeR
8-
from ._simple_tests import SimpleComparisonBase, TTest, WilcoxonTest
7+
from ._simple_tests import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest
98

109

1110
def __getattr__(name: str):
@@ -57,4 +56,5 @@ def _get_available_methods():
5756
"SimpleComparisonBase",
5857
"WilcoxonTest",
5958
"TTest",
59+
"PermutationTest",
6060
]

pertpy/tools/_differential_gene_expression/_pydeseq2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def fit(self, **kwargs) -> pd.DataFrame:
4040
**kwargs: Keyword arguments specific to DeseqDataSet(), except for `n_cpus` which will use all available CPUs minus one if the argument is not passed.
4141
"""
4242
try:
43-
usable_cpus = len(os.sched_getaffinity(0))
43+
usable_cpus = len(os.sched_getaffinity(0)) # type: ignore # os.sched_getaffinity is not available on Windows and macOS
4444
except AttributeError:
4545
usable_cpus = os.cpu_count()
4646

pertpy/tools/_differential_gene_expression/_simple_tests.py

Lines changed: 164 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import warnings
44
from abc import abstractmethod
5-
from collections.abc import Mapping, Sequence
5+
from collections.abc import Callable, Mapping, Sequence
66
from types import MappingProxyType
77

88
import numpy as np
@@ -33,7 +33,7 @@ def fdr_correction(
3333
class SimpleComparisonBase(MethodBase):
3434
@staticmethod
3535
@abstractmethod
36-
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
36+
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
3737
"""Perform a statistical test between values in x0 and x1.
3838
3939
If `paired` is True, x0 and x1 must be of the same length and ordered such that
@@ -44,6 +44,10 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
4444
x1: Array with values to compare.
4545
paired: Indicates whether to perform a paired test
4646
**kwargs: kwargs passed to the test function
47+
48+
Returns:
49+
A dictionary metric -> value.
50+
This allows to return values for different metrics (e.g. p-value + test statistic).
4751
"""
4852
...
4953

@@ -71,16 +75,16 @@ def _compare_single_group(
7175
x0 = x0.tocsc()
7276
x1 = x1.tocsc()
7377

74-
res = []
78+
res: list[dict[str, float]] = []
7579
for var in tqdm(self.adata.var_names):
7680
tmp_x0 = x0[:, self.adata.var_names == var]
7781
tmp_x0 = np.asarray(tmp_x0.todense()).flatten() if issparse(tmp_x0) else tmp_x0.flatten()
7882
tmp_x1 = x1[:, self.adata.var_names == var]
7983
tmp_x1 = np.asarray(tmp_x1.todense()).flatten() if issparse(tmp_x1) else tmp_x1.flatten()
80-
pval = self._test(tmp_x0, tmp_x1, paired, **kwargs)
84+
test_result = self._test(tmp_x0, tmp_x1, paired, **kwargs)
8185
mean_x0 = np.mean(tmp_x0)
8286
mean_x1 = np.mean(tmp_x1)
83-
res.append({"variable": var, "p_value": pval, "log_fc": np.log2(mean_x1) - np.log2(mean_x0)})
87+
res.append({"variable": var, "log_fc": np.log2(mean_x1) - np.log2(mean_x0), **test_result})
8488
return pd.DataFrame(res).sort_values("p_value")
8589

8690
@classmethod
@@ -97,6 +101,20 @@ def compare_groups(
97101
fit_kwargs: Mapping = MappingProxyType({}),
98102
test_kwargs: Mapping = MappingProxyType({}),
99103
) -> DataFrame:
104+
"""Perform a comparison between groups.
105+
106+
Args:
107+
adata: Data with observations to compare.
108+
column: Column in `adata.obs` that contains the groups to compare.
109+
baseline: Reference group.
110+
groups_to_compare: Groups to compare against the baseline. If None, all other groups
111+
are compared.
112+
paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
113+
mask: Mask to apply to the data.
114+
layer: Layer to use for the comparison.
115+
fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
116+
test_kwargs: Additional kwargs passed to the test function.
117+
"""
100118
if len(fit_kwargs):
101119
warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2)
102120
paired = paired_by is not None
@@ -144,19 +162,150 @@ class WilcoxonTest(SimpleComparisonBase):
144162
"""
145163

146164
@staticmethod
147-
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
148-
if paired:
149-
return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue
150-
else:
151-
return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue
165+
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
166+
"""Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test."""
167+
test_result = scipy.stats.wilcoxon(x0, x1, **kwargs) if paired else scipy.stats.mannwhitneyu(x0, x1, **kwargs)
168+
169+
return {
170+
"p_value": test_result.pvalue,
171+
"statistic": test_result.statistic,
172+
}
152173

153174

154175
class TTest(SimpleComparisonBase):
155176
"""Perform a unpaired or paired T-test."""
156177

157178
@staticmethod
158-
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
159-
if paired:
160-
return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue
161-
else:
162-
return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue
179+
def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> dict[str, float]:
180+
test_result = scipy.stats.ttest_rel(x0, x1, **kwargs) if paired else scipy.stats.ttest_ind(x0, x1, **kwargs)
181+
182+
return {
183+
"p_value": test_result.pvalue,
184+
"statistic": test_result.statistic,
185+
}
186+
187+
188+
class PermutationTest(SimpleComparisonBase):
189+
"""Perform a permutation test.
190+
191+
The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through
192+
random permutations of the data and repeated generation of the test statistic.
193+
194+
For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For
195+
unpaired tests, all observations are permuted independently.
196+
197+
The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and
198+
have been randomly assigned to one of the samples.
199+
200+
The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the
201+
same underlying distribution and that their assignment to a sample is random.
202+
"""
203+
204+
@classmethod
205+
def compare_groups(
206+
cls,
207+
adata: AnnData,
208+
column: str,
209+
baseline: str,
210+
groups_to_compare: str | Sequence[str],
211+
*,
212+
paired_by: str | None = None,
213+
mask: str | None = None,
214+
layer: str | None = None,
215+
n_permutations: int = 1000,
216+
test_statistic: Callable[[np.ndarray, np.ndarray], float] = lambda x, y: np.log2(np.mean(y) + 1e-8)
217+
- np.log2(np.mean(x) + 1e-8),
218+
fit_kwargs: Mapping = MappingProxyType({}),
219+
test_kwargs: Mapping = MappingProxyType({}),
220+
) -> DataFrame:
221+
"""Perform a permutation test comparison between groups.
222+
223+
Args:
224+
adata: Data with observations to compare.
225+
column: Column in `adata.obs` that contains the groups to compare.
226+
baseline: Reference group.
227+
groups_to_compare: Groups to compare against the baseline. If None, all other groups
228+
are compared.
229+
paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
230+
mask: Mask to apply to the data.
231+
layer: Layer to use for the comparison.
232+
n_permutations: Number of permutations to perform.
233+
test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic.
234+
Defaults to log2 fold change with pseudocount: log2(mean(x1) + 1e-8) - log2(mean(x0) + 1e-8).
235+
The callable should have signature: test_statistic(x0, x1) -> float.
236+
fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
237+
test_kwargs: Additional kwargs passed to the permutation test function (not the test statistic). The
238+
permutation test function is `scipy.stats.permutation_test`, so please refer to its documentation for
239+
available options. Note that `test_statistic` and `n_permutations` are set by this function and should
240+
not be provided here.
241+
242+
Examples:
243+
>>> # Difference in means (log fold change)
244+
>>> PermutationTest.compare_groups(
245+
... adata,
246+
... column="condition",
247+
... baseline="A",
248+
... groups_to_compare="B",
249+
... test_statistic=lambda x, y: np.log2(np.mean(y)) - np.log2(np.mean(x)),
250+
... n_permutations=1000,
251+
... test_kwargs={"rng": 0},
252+
... )
253+
"""
254+
enhanced_test_kwargs = dict(test_kwargs)
255+
enhanced_test_kwargs.update({"test_statistic": test_statistic, "n_permutations": n_permutations})
256+
257+
return super().compare_groups(
258+
adata=adata,
259+
column=column,
260+
baseline=baseline,
261+
groups_to_compare=groups_to_compare,
262+
paired_by=paired_by,
263+
mask=mask,
264+
layer=layer,
265+
fit_kwargs=fit_kwargs,
266+
test_kwargs=enhanced_test_kwargs,
267+
)
268+
269+
@staticmethod
270+
def _test(
271+
x0: np.ndarray,
272+
x1: np.ndarray,
273+
paired: bool,
274+
test_statistic: Callable[[np.ndarray, np.ndarray], float] = lambda x, y: np.log2(np.mean(y) + 1e-8)
275+
- np.log2(np.mean(x) + 1e-8),
276+
n_permutations: int = 1000,
277+
**kwargs,
278+
) -> dict[str, float]:
279+
"""Perform a permutation test.
280+
281+
This function uses a simple test statistic function to compute p-values through permutations.
282+
283+
Args:
284+
x0: Array with baseline values.
285+
x1: Array with values to compare.
286+
paired: Whether to perform a paired test.
287+
test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic. Please refer to
288+
the examples below for usage. The callable should have signature: test_statistic(x0, x1) -> float.
289+
n_permutations: Number of permutations to perform.
290+
**kwargs: Additional kwargs passed to scipy.stats.permutation_test.
291+
292+
Examples:
293+
>>> # Difference in means (log fold change)
294+
>>> PermutationTest._test(x0, x1, paired=False)
295+
>>>
296+
>>> # Difference in medians
297+
>>> median_diff = lambda x, y: np.median(y) - np.median(x)
298+
>>> PermutationTest._test(x0, x1, paired=False, test_statistic=median_diff)
299+
"""
300+
test_result = scipy.stats.permutation_test(
301+
[x0, x1],
302+
statistic=lambda x0_perm, x1_perm: test_statistic(x0_perm, x1_perm),
303+
n_resamples=n_permutations,
304+
permutation_type=("samples" if paired else "independent"),
305+
**kwargs,
306+
)
307+
308+
return {
309+
"p_value": test_result.pvalue,
310+
"statistic": test_result.statistic,
311+
}

tests/tools/_differential_gene_expression/test_simple_tests.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
if find_spec("formulaic_contrasts") is None or find_spec("formulaic") is None:
99
pytestmark = pytest.mark.skip(reason="formulaic_contrasts and formulaic not available")
1010

11-
from pertpy.tools._differential_gene_expression import SimpleComparisonBase, TTest, WilcoxonTest
11+
from pertpy.tools._differential_gene_expression import PermutationTest, SimpleComparisonBase, TTest, WilcoxonTest
1212

1313

1414
@pytest.mark.parametrize(
@@ -67,6 +67,52 @@ def test_t(test_adata_minimal, paired_by, expected):
6767
assert actual[gene] == pytest.approx(expected[gene], abs=0.02)
6868

6969

70+
@pytest.mark.parametrize(
71+
"paired_by,expected",
72+
[
73+
pytest.param(
74+
None,
75+
{"gene1": {"p_value": 2.13e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.96, "log_fc": -0.016}},
76+
id="unpaired",
77+
),
78+
pytest.param(
79+
"pairing",
80+
{"gene1": {"p_value": 1.63e-26, "log_fc": -5.14}, "gene2": {"p_value": 0.85, "log_fc": -0.016}},
81+
id="paired",
82+
),
83+
],
84+
)
85+
def test_permutation(test_adata_minimal, paired_by, expected):
86+
"""Test that permutation test gives the correct values.
87+
88+
Reference values have been computed in R using wilcox.test
89+
"""
90+
# Test with different simple test statistics
91+
test_statistics = [
92+
lambda x, y: np.log2(np.mean(y)) - np.log2(np.mean(x)), # log fold change between means (default)
93+
lambda x, y: np.mean(y) - np.mean(x), # mean difference
94+
lambda x, y: np.max(y) - np.max(x), # max difference
95+
]
96+
97+
for test_stat in test_statistics:
98+
res_df = PermutationTest.compare_groups(
99+
adata=test_adata_minimal,
100+
column="condition",
101+
baseline="A",
102+
groups_to_compare="B",
103+
test_statistic=test_stat,
104+
paired_by=paired_by,
105+
n_permutations=1000,
106+
test_kwargs={"rng": 0},
107+
)
108+
assert isinstance(res_df, DataFrame), "PermutationTest.compare_groups should return a DataFrame"
109+
actual = res_df.loc[:, ["variable", "p_value", "log_fc"]].set_index("variable").to_dict(orient="index")
110+
for gene in expected:
111+
assert (actual[gene]["p_value"] < 0.05) == (expected[gene]["p_value"] < 0.05)
112+
if actual[gene]["p_value"] < 0.05:
113+
assert actual[gene] == pytest.approx(expected[gene], abs=0.02)
114+
115+
70116
@pytest.mark.parametrize("seed", range(10))
71117
def test_simple_comparison_pairing(test_adata_minimal, seed):
72118
"""Test that paired samples are properly matched in a paired test"""

0 commit comments

Comments
 (0)