-
Notifications
You must be signed in to change notification settings - Fork 39
Add permutation test #726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add permutation test #726
Changes from 11 commits
f99e5ac
66f228e
74e3ec1
4fdd140
64b7114
af7ca5d
325f1db
25d1689
1b6c65e
3e81976
14736b3
5873b87
676b4f0
442b603
8ae69ce
ebc30fb
1336ddb
95d7da4
e2c53fb
412ab3b
52d2d58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,14 +2,16 @@ | |
|
|
||
| import warnings | ||
| from abc import abstractmethod | ||
| from collections.abc import Mapping, Sequence | ||
| from collections.abc import Callable, Mapping, Sequence | ||
| from types import MappingProxyType | ||
|
|
||
| import numpy as np | ||
| import pandas as pd | ||
| import scipy.stats | ||
| import statsmodels | ||
| from anndata import AnnData | ||
| from joblib import Parallel, delayed | ||
| from lamin_utils import logger | ||
| from pandas.core.api import DataFrame as DataFrame | ||
| from scipy.sparse import diags, issparse | ||
| from tqdm.auto import tqdm | ||
|
|
@@ -94,9 +96,28 @@ def compare_groups( | |
| paired_by: str | None = None, | ||
| mask: str | None = None, | ||
| layer: str | None = None, | ||
| n_permutations: int = 1000, | ||
| permutation_test: type["SimpleComparisonBase"] | None = None, | ||
| fit_kwargs: Mapping = MappingProxyType({}), | ||
| test_kwargs: Mapping = MappingProxyType({}), | ||
| n_jobs: int = -1, | ||
| ) -> DataFrame: | ||
| """Perform a comparison between groups. | ||
|
|
||
| Args: | ||
| adata (AnnData): Data with observations to compare. | ||
| column (str): Column in `adata.obs` that contains the groups to compare. | ||
| baseline (str): Reference group. | ||
| groups_to_compare (str | Sequence[str]): Groups to compare against the baseline. If None, all other groups are compared. | ||
| paired_by (str | None): Column in `adata.obs` to use for pairing. If None, an unpaired test is performed. | ||
| mask (str | None): Mask to apply to the data. | ||
| layer (str | None): Layer to use for the comparison. | ||
| n_permutations (int): Number of permutations to perform if a permutation test is used. | ||
| permutation_test (type[SimpleComparisonBase] | None): Test to use after permutation if a permutation test is used. | ||
| fit_kwargs (Mapping): Not used for simple tests. | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| test_kwargs (Mapping): Additional kwargs passed to the test function. | ||
| n_jobs (int): Number of parallel jobs to use. | ||
| """ | ||
| if len(fit_kwargs): | ||
| warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) | ||
| paired = paired_by is not None | ||
|
|
@@ -127,13 +148,24 @@ def _get_idx(column, value): | |
|
|
||
| res_dfs = [] | ||
| baseline_idx = _get_idx(column, baseline) | ||
| for group_to_compare in groups_to_compare: | ||
| comparison_idx = _get_idx(column, group_to_compare) | ||
| res_dfs.append( | ||
| model._compare_single_group(baseline_idx, comparison_idx, paired=paired, **test_kwargs).assign( | ||
| comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}" | ||
| ) | ||
|
|
||
| if permutation_test: | ||
| test_kwargs = dict(test_kwargs) | ||
| test_kwargs.update({"test": permutation_test, "n_permutations": n_permutations}) | ||
| elif permutation_test is None and cls.__name__ == "PermutationTest": | ||
| logger.warning("No permutation test specified. Using WilcoxonTest as default.") | ||
|
|
||
| comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] | ||
|
||
| res_dfs = Parallel(n_jobs=n_jobs)( | ||
| delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs) | ||
| for comparison_idx in comparison_indices | ||
| ) | ||
| res_dfs = [ | ||
| df.assign( | ||
| comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", | ||
| ) | ||
| for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) | ||
| ] | ||
| return fdr_correction(pd.concat(res_dfs)) | ||
|
|
||
|
|
||
|
|
@@ -144,19 +176,100 @@ class WilcoxonTest(SimpleComparisonBase): | |
| """ | ||
|
|
||
| @staticmethod | ||
| def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
| def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: | ||
| if paired: | ||
| return scipy.stats.wilcoxon(x0, x1, **kwargs).pvalue | ||
| return scipy.stats.wilcoxon(x0, x1, **kwargs).__getattribute__(return_attribute) | ||
|
||
| else: | ||
| return scipy.stats.mannwhitneyu(x0, x1, **kwargs).pvalue | ||
| return scipy.stats.mannwhitneyu(x0, x1, **kwargs).__getattribute__(return_attribute) | ||
|
|
||
|
|
||
| class TTest(SimpleComparisonBase): | ||
| """Perform a unpaired or paired T-test""" | ||
| """Perform a unpaired or paired T-test.""" | ||
|
|
||
| @staticmethod | ||
| def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
| def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, return_attribute: str = "pvalue", **kwargs) -> float: | ||
| if paired: | ||
| return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue | ||
| return scipy.stats.ttest_rel(x0, x1, **kwargs).__getattribute__(return_attribute) | ||
| else: | ||
| return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue | ||
| return scipy.stats.ttest_ind(x0, x1, **kwargs).__getattribute__(return_attribute) | ||
|
|
||
|
|
||
| class PermutationTest(SimpleComparisonBase): | ||
| """Perform a permutation test. | ||
|
|
||
| The permutation test relies on another test (e.g. WilcoxonTest) to perform the actual comparison | ||
| based on permuted data. The p-value is then calculated based on the distribution of the test | ||
| statistic under the null hypothesis. | ||
|
|
||
| For paired tests, each paired observation is permuted together and distributed randomly between | ||
| the two groups. For unpaired tests, all observations are permuted independently. | ||
|
|
||
| The null hypothesis for the unpaired test is that all observations come from the same underlying | ||
| distribution and have been randomly assigned to one of the samples. | ||
|
|
||
| The null hypothesis for the paired permutation test is that the observations within each pair are | ||
| drawn from the same underlying distribution and that their assignment to a sample is random. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def _test( | ||
| x0: np.ndarray, | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x1: np.ndarray, | ||
| paired: bool, | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| test: type["SimpleComparisonBase"] | Callable = WilcoxonTest, | ||
| n_permutations: int = 1000, | ||
| return_attribute: str = "pvalue", | ||
| **kwargs, | ||
| ) -> float: | ||
| """Perform a permutation test. | ||
|
|
||
| This function relies on another test (e.g. WilcoxonTest) to generate a test statistic for each permutation. | ||
|
|
||
| .. code-block:: python | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from pertpy.tools import PermutationTest, WilcoxonTest | ||
|
|
||
| # Using rank-sum statistic | ||
| p_value = PermutationTest._test(x0, x1, paired=True, test=WilcoxonTest, n_permutations=1000, rng=0) | ||
|
|
||
|
|
||
| # Using a custom test statistic | ||
| def compare_means(x0, x1, paired): | ||
| # paired logic not implemented here | ||
| return np.mean(x1) - np.mean(x0) | ||
|
|
||
|
|
||
| p_value = PermutationTest._test(x0, x1, paired=False, test=compare_means, n_permutations=1000, rng=0) | ||
|
|
||
| Args: | ||
| x0: Array with baseline values. | ||
| x1: Array with values to compare. | ||
| paired: Whether to perform a paired test | ||
| test: The class or function to generate the test statistic from permuted data. | ||
| n_permutations: Number of permutations to perform. | ||
| return_attribute: Attribute to return from the test statistic. | ||
| **kwargs: kwargs passed to the permutation test function, not the test function after permutation. | ||
| """ | ||
| if test is PermutationTest: | ||
| raise ValueError( | ||
| "The `test` argument cannot be `PermutationTest`. Use a base test like `WilcoxonTest` or `TTest`." | ||
| ) | ||
|
|
||
| def call_test(data_baseline, data_comparison, axis: int | None = None, **kwargs): | ||
|
||
| """Perform the actual test.""" | ||
| # Setting the axis allows the operation to be vectorized | ||
| if axis is not None: | ||
| kwargs.update({"axis": axis}) | ||
|
|
||
| if not hasattr(test, "_test"): | ||
| return test(data_baseline, data_comparison, paired, **kwargs) | ||
|
|
||
| return test._test(data_baseline, data_comparison, paired, return_attribute="statistic", **kwargs) | ||
|
|
||
| return scipy.stats.permutation_test( | ||
| [x0, x1], | ||
| statistic=call_test, | ||
| n_resamples=n_permutations, | ||
| permutation_type=("samples" if paired else "independent"), | ||
| vectorized=hasattr(test, "_test"), | ||
| **kwargs, | ||
| ).__getattribute__(return_attribute) | ||
Uh oh!
There was an error while loading. Please reload this page.