22
33import warnings
44from abc import abstractmethod
5- from collections .abc import Mapping , Sequence
5+ from collections .abc import Callable , Mapping , Sequence
66from types import MappingProxyType
77
88import numpy as np
@@ -33,7 +33,7 @@ def fdr_correction(
3333class SimpleComparisonBase (MethodBase ):
3434 @staticmethod
3535 @abstractmethod
36- def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> float :
36+ def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> dict [ str , float ] :
3737 """Perform a statistical test between values in x0 and x1.
3838
3939 If `paired` is True, x0 and x1 must be of the same length and ordered such that
@@ -44,6 +44,10 @@ def _test(x0: np.ndarray, x1: np.ndarray, paired: bool, **kwargs) -> float:
4444 x1: Array with values to compare.
4545 paired: Indicates whether to perform a paired test
4646 **kwargs: kwargs passed to the test function
47+
48+ Returns:
49+ A dictionary metric -> value.
50+ This allows to return values for different metrics (e.g. p-value + test statistic).
4751 """
4852 ...
4953
@@ -71,16 +75,16 @@ def _compare_single_group(
7175 x0 = x0 .tocsc ()
7276 x1 = x1 .tocsc ()
7377
74- res = []
78+ res : list [ dict [ str , float ]] = []
7579 for var in tqdm (self .adata .var_names ):
7680 tmp_x0 = x0 [:, self .adata .var_names == var ]
7781 tmp_x0 = np .asarray (tmp_x0 .todense ()).flatten () if issparse (tmp_x0 ) else tmp_x0 .flatten ()
7882 tmp_x1 = x1 [:, self .adata .var_names == var ]
7983 tmp_x1 = np .asarray (tmp_x1 .todense ()).flatten () if issparse (tmp_x1 ) else tmp_x1 .flatten ()
80- pval = self ._test (tmp_x0 , tmp_x1 , paired , ** kwargs )
84+ test_result = self ._test (tmp_x0 , tmp_x1 , paired , ** kwargs )
8185 mean_x0 = np .mean (tmp_x0 )
8286 mean_x1 = np .mean (tmp_x1 )
83- res .append ({"variable" : var , "p_value" : pval , " log_fc" : np .log2 (mean_x1 ) - np .log2 (mean_x0 )})
87+ res .append ({"variable" : var , "log_fc" : np .log2 (mean_x1 ) - np .log2 (mean_x0 ), ** test_result })
8488 return pd .DataFrame (res ).sort_values ("p_value" )
8589
8690 @classmethod
@@ -97,6 +101,20 @@ def compare_groups(
97101 fit_kwargs : Mapping = MappingProxyType ({}),
98102 test_kwargs : Mapping = MappingProxyType ({}),
99103 ) -> DataFrame :
104+ """Perform a comparison between groups.
105+
106+ Args:
107+ adata: Data with observations to compare.
108+ column: Column in `adata.obs` that contains the groups to compare.
109+ baseline: Reference group.
110+ groups_to_compare: Groups to compare against the baseline. If None, all other groups
111+ are compared.
112+ paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
113+ mask: Mask to apply to the data.
114+ layer: Layer to use for the comparison.
115+ fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
116+ test_kwargs: Additional kwargs passed to the test function.
117+ """
100118 if len (fit_kwargs ):
101119 warnings .warn ("fit_kwargs not used for simple tests." , UserWarning , stacklevel = 2 )
102120 paired = paired_by is not None
@@ -144,19 +162,150 @@ class WilcoxonTest(SimpleComparisonBase):
144162 """
145163
146164 @staticmethod
147- def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> float :
148- if paired :
149- return scipy .stats .wilcoxon (x0 , x1 , ** kwargs ).pvalue
150- else :
151- return scipy .stats .mannwhitneyu (x0 , x1 , ** kwargs ).pvalue
165+ def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> dict [str , float ]:
166+ """Perform an unpaired or paired Wilcoxon/Mann-Whitney-U test."""
167+ test_result = scipy .stats .wilcoxon (x0 , x1 , ** kwargs ) if paired else scipy .stats .mannwhitneyu (x0 , x1 , ** kwargs )
168+
169+ return {
170+ "p_value" : test_result .pvalue ,
171+ "statistic" : test_result .statistic ,
172+ }
152173
153174
154175class TTest (SimpleComparisonBase ):
155176 """Perform a unpaired or paired T-test."""
156177
157178 @staticmethod
158- def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> float :
159- if paired :
160- return scipy .stats .ttest_rel (x0 , x1 , ** kwargs ).pvalue
161- else :
162- return scipy .stats .ttest_ind (x0 , x1 , ** kwargs ).pvalue
179+ def _test (x0 : np .ndarray , x1 : np .ndarray , paired : bool , ** kwargs ) -> dict [str , float ]:
180+ test_result = scipy .stats .ttest_rel (x0 , x1 , ** kwargs ) if paired else scipy .stats .ttest_ind (x0 , x1 , ** kwargs )
181+
182+ return {
183+ "p_value" : test_result .pvalue ,
184+ "statistic" : test_result .statistic ,
185+ }
186+
187+
188+ class PermutationTest (SimpleComparisonBase ):
189+ """Perform a permutation test.
190+
191+ The permutation test relies on another test statistic (e.g. t-statistic or your own) to obtain a p-value through
192+ random permutations of the data and repeated generation of the test statistic.
193+
194+ For paired tests, each paired observation is permuted together and distributed randomly between the two groups. For
195+ unpaired tests, all observations are permuted independently.
196+
197+ The null hypothesis for the unpaired test is that all observations come from the same underlying distribution and
198+ have been randomly assigned to one of the samples.
199+
200+ The null hypothesis for the paired permutation test is that the observations within each pair are drawn from the
201+ same underlying distribution and that their assignment to a sample is random.
202+ """
203+
204+ @classmethod
205+ def compare_groups (
206+ cls ,
207+ adata : AnnData ,
208+ column : str ,
209+ baseline : str ,
210+ groups_to_compare : str | Sequence [str ],
211+ * ,
212+ paired_by : str | None = None ,
213+ mask : str | None = None ,
214+ layer : str | None = None ,
215+ n_permutations : int = 1000 ,
216+ test_statistic : Callable [[np .ndarray , np .ndarray ], float ] = lambda x , y : np .log2 (np .mean (y ) + 1e-8 )
217+ - np .log2 (np .mean (x ) + 1e-8 ),
218+ fit_kwargs : Mapping = MappingProxyType ({}),
219+ test_kwargs : Mapping = MappingProxyType ({}),
220+ ) -> DataFrame :
221+ """Perform a permutation test comparison between groups.
222+
223+ Args:
224+ adata: Data with observations to compare.
225+ column: Column in `adata.obs` that contains the groups to compare.
226+ baseline: Reference group.
227+ groups_to_compare: Groups to compare against the baseline. If None, all other groups
228+ are compared.
229+ paired_by: Column in `adata.obs` to use for pairing. If None, an unpaired test is performed.
230+ mask: Mask to apply to the data.
231+ layer: Layer to use for the comparison.
232+ n_permutations: Number of permutations to perform.
233+ test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic.
234+ Defaults to log2 fold change with pseudocount: log2(mean(x1) + 1e-8) - log2(mean(x0) + 1e-8).
235+ The callable should have signature: test_statistic(x0, x1) -> float.
236+ fit_kwargs: Unused argument for compatibility with the `MethodBase` interface, do not specify.
237+ test_kwargs: Additional kwargs passed to the permutation test function (not the test statistic). The
238+ permutation test function is `scipy.stats.permutation_test`, so please refer to its documentation for
239+ available options. Note that `test_statistic` and `n_permutations` are set by this function and should
240+ not be provided here.
241+
242+ Examples:
243+ >>> # Difference in means (log fold change)
244+ >>> PermutationTest.compare_groups(
245+ ... adata,
246+ ... column="condition",
247+ ... baseline="A",
248+ ... groups_to_compare="B",
249+ ... test_statistic=lambda x, y: np.log2(np.mean(y)) - np.log2(np.mean(x)),
250+ ... n_permutations=1000,
251+ ... test_kwargs={"rng": 0},
252+ ... )
253+ """
254+ enhanced_test_kwargs = dict (test_kwargs )
255+ enhanced_test_kwargs .update ({"test_statistic" : test_statistic , "n_permutations" : n_permutations })
256+
257+ return super ().compare_groups (
258+ adata = adata ,
259+ column = column ,
260+ baseline = baseline ,
261+ groups_to_compare = groups_to_compare ,
262+ paired_by = paired_by ,
263+ mask = mask ,
264+ layer = layer ,
265+ fit_kwargs = fit_kwargs ,
266+ test_kwargs = enhanced_test_kwargs ,
267+ )
268+
269+ @staticmethod
270+ def _test (
271+ x0 : np .ndarray ,
272+ x1 : np .ndarray ,
273+ paired : bool ,
274+ test_statistic : Callable [[np .ndarray , np .ndarray ], float ] = lambda x , y : np .log2 (np .mean (y ) + 1e-8 )
275+ - np .log2 (np .mean (x ) + 1e-8 ),
276+ n_permutations : int = 1000 ,
277+ ** kwargs ,
278+ ) -> dict [str , float ]:
279+ """Perform a permutation test.
280+
281+ This function uses a simple test statistic function to compute p-values through permutations.
282+
283+ Args:
284+ x0: Array with baseline values.
285+ x1: Array with values to compare.
286+ paired: Whether to perform a paired test.
287+ test_statistic: A callable that takes two arrays (x0, x1) and returns a float statistic. Please refer to
288+ the examples below for usage. The callable should have signature: test_statistic(x0, x1) -> float.
289+ n_permutations: Number of permutations to perform.
290+ **kwargs: Additional kwargs passed to scipy.stats.permutation_test.
291+
292+ Examples:
293+ >>> # Difference in means (log fold change)
294+ >>> PermutationTest._test(x0, x1, paired=False)
295+ >>>
296+ >>> # Difference in medians
297+ >>> median_diff = lambda x, y: np.median(y) - np.median(x)
298+ >>> PermutationTest._test(x0, x1, paired=False, test_statistic=median_diff)
299+ """
300+ test_result = scipy .stats .permutation_test (
301+ [x0 , x1 ],
302+ statistic = lambda x0_perm , x1_perm : test_statistic (x0_perm , x1_perm ),
303+ n_resamples = n_permutations ,
304+ permutation_type = ("samples" if paired else "independent" ),
305+ ** kwargs ,
306+ )
307+
308+ return {
309+ "p_value" : test_result .pvalue ,
310+ "statistic" : test_result .statistic ,
311+ }
0 commit comments