-
Notifications
You must be signed in to change notification settings - Fork 39
Add permutation test #726
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add permutation test #726
Changes from 3 commits
f99e5ac
66f228e
74e3ec1
4fdd140
64b7114
af7ca5d
325f1db
25d1689
1b6c65e
3e81976
14736b3
5873b87
676b4f0
442b603
8ae69ce
ebc30fb
1336ddb
95d7da4
e2c53fb
412ab3b
52d2d58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
| import scipy.stats | ||
| import statsmodels | ||
| from anndata import AnnData | ||
| from joblib import Parallel, delayed | ||
| from pandas.core.api import DataFrame as DataFrame | ||
| from scipy.sparse import diags, issparse | ||
| from tqdm.auto import tqdm | ||
|
|
@@ -152,11 +153,154 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | |
|
|
||
|
|
||
| class TTest(SimpleComparisonBase): | ||
| """Perform a unpaired or paired T-test""" | ||
| """Perform a unpaired or paired T-test.""" | ||
|
|
||
| @staticmethod | ||
| def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float: | ||
| if paired: | ||
| return scipy.stats.ttest_rel(x0, x1, **kwargs).pvalue | ||
| else: | ||
| return scipy.stats.ttest_ind(x0, x1, **kwargs).pvalue | ||
|
|
||
|
|
||
| class PermutationTest(SimpleComparisonBase): | ||
| """Perform a permutation test. | ||
|
|
||
| The permutation test relies on another test (e.g. WilcoxonTest) to perform the actual comparison | ||
| based on permuted data. The p-value is then calculated based on the distribution of the test | ||
| statistic under the null hypothesis. | ||
|
|
||
| For paired tests, each paired observation is permuted together and distributed randoml between | ||
| the two groups. For unpaired tests, all observations are permuted independently. | ||
|
|
||
| The null hypothesis for the unpaired test is that all observations come from the same underlying | ||
| distribution and have been randomly assigned to one of the samples. | ||
|
|
||
| The null hypothesis for the paired permutation test is that the observations within each pair are | ||
| drawn from the same underlying distribution and that their assignment to a sample is random. | ||
| """ | ||
|
|
||
| @staticmethod | ||
| def _test( | ||
| x0: np.ndarray, | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| x1: np.ndarray, | ||
| paired: bool, | ||
Zethson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| test: type["SimpleComparisonBase"] = WilcoxonTest, | ||
| n_permutations: int = 100, | ||
| seed: int = 0, | ||
| **kwargs, | ||
| ) -> float: | ||
| """Perform a permutation test. | ||
|
|
||
| Args: | ||
| x0: Array with baseline values. | ||
| x1: Array with values to compare. | ||
| paired: Indicates whether to perform a paired test | ||
maltekuehl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| test: The test to use for the actual comparison. | ||
| n_permutations: Number of permutations to perform. | ||
| **kwargs: kwargs passed to the test function | ||
maltekuehl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| """ | ||
|
|
||
| def call_test(x0, x1, **kwargs): | ||
| """Perform the actual test.""" | ||
| return test._test(x0, x1, paired, **kwargs) | ||
|
|
||
| if paired: | ||
maltekuehl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return scipy.stats.permutation_test( | ||
| [x0, x1], | ||
| statistic=call_test, | ||
| n_resamples=n_permutations, | ||
| permutation_type="samples", | ||
| rng=seed, | ||
| **kwargs, | ||
| ).pvalue | ||
| else: | ||
| return scipy.stats.permutation_test( | ||
| [x0, x1], | ||
| statistic=call_test, | ||
| n_resamples=n_permutations, | ||
| permutation_type="independent", | ||
| rng=seed, | ||
| **kwargs, | ||
| ).pvalue | ||
|
|
||
| @classmethod | ||
| def compare_groups( | ||
| cls, | ||
| adata: AnnData, | ||
| column: str, | ||
| baseline: str, | ||
| groups_to_compare: str | Sequence[str], | ||
| test: type["SimpleComparisonBase"] = WilcoxonTest, | ||
| n_permutations: int = 100, | ||
| n_jobs: int = -1, | ||
| seed: int = 0, | ||
|
||
| *, | ||
| paired_by: str | None = None, | ||
| mask: str | None = None, | ||
| layer: str | None = None, | ||
| fit_kwargs: Mapping = MappingProxyType({}), | ||
| test_kwargs: Mapping = MappingProxyType({}), | ||
| ) -> DataFrame: | ||
| """Perform a comparison between groups using a permutation test. | ||
|
|
||
| Args: | ||
| adata: Annotated data object. | ||
| column: Column in `adata.obs` that contains the groups to compare. | ||
grst marked this conversation as resolved.
Show resolved
Hide resolved
maltekuehl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| baseline: Reference group. | ||
Zethson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| groups_to_compare: Groups to compare against the baseline. | ||
| test: The test to use for the actual comparison after permutation. Default is TTest. | ||
maltekuehl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| n_permutations: Number of permutations to perform. | ||
| n_jobs: Number of parallel jobs to use. | ||
| paired_by: Column in `adata.obs` to use for pairing. | ||
| mask: Mask to apply to the data. | ||
| layer: Layer to use for the comparison. | ||
| fit_kwargs: Additional kwargs passed to the test function. | ||
maltekuehl marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| test_kwargs: Additional kwargs passed to the test function. | ||
| """ | ||
| if len(fit_kwargs): | ||
|
||
| warnings.warn("fit_kwargs not used for simple tests.", UserWarning, stacklevel=2) | ||
| paired = paired_by is not None | ||
| model = cls(adata, mask=mask, layer=layer) | ||
| if groups_to_compare is None: | ||
| # compare against all other | ||
| groups_to_compare = sorted(set(model.adata.obs[column]) - {baseline}) | ||
| if isinstance(groups_to_compare, str): | ||
| groups_to_compare = [groups_to_compare] | ||
|
|
||
| def _get_idx(column, value): | ||
| mask = model.adata.obs[column] == value | ||
| if paired: | ||
| dummies = pd.get_dummies(model.adata.obs[paired_by], sparse=True).sparse.to_coo().tocsr() | ||
| if not np.all(np.sum(dummies, axis=0) == 2): | ||
| raise ValueError("Pairing is only possible with exactly two values per group") | ||
| # Use matrix multiplication to only retreive those dummy entries that are associated with the current `value`. | ||
| # Convert to COO matrix to get rows/cols | ||
| # row indices refers to the indices of rows that have `column == value` (equivalent to np.where(mask)[0]) | ||
| # col indices refers to the numeric index of each "pair" in obs_names | ||
| ind_mat = diags(mask.values, dtype=bool) @ dummies | ||
| if not np.all(np.sum(ind_mat, axis=0) == 1): | ||
| raise ValueError("Pairing is only possible with exactly two values per group") | ||
| ind_mat = ind_mat.tocoo() | ||
| return ind_mat.row[np.argsort(ind_mat.col)] | ||
| else: | ||
| return np.where(mask)[0] | ||
|
|
||
| test_kwargs_mutable = dict(test_kwargs) | ||
| test_kwargs_mutable.update({"test": test, "n_permutations": n_permutations, "seed": seed}) | ||
|
|
||
| res_dfs = [] | ||
| baseline_idx = _get_idx(column, baseline) | ||
|
|
||
| comparison_indices = [_get_idx(column, group_to_compare) for group_to_compare in groups_to_compare] | ||
| res_dfs = Parallel(n_jobs=n_jobs)( | ||
| delayed(model._compare_single_group)(baseline_idx, comparison_idx, paired=paired, **test_kwargs_mutable) | ||
| for comparison_idx in comparison_indices | ||
| ) | ||
| res_dfs = [ | ||
| df.assign( | ||
| comparison=f"{group_to_compare}_vs_{baseline if baseline is not None else 'rest'}", | ||
| ) | ||
| for df, group_to_compare in zip(res_dfs, groups_to_compare, strict=False) | ||
| ] | ||
| return fdr_correction(pd.concat(res_dfs)) | ||
Uh oh!
There was an error while loading. Please reload this page.