55# License: GNU General Public License v3.0
66
77from __future__ import print_function
8- from scipy import special , spatial
8+ from scipy import special , spatial , stats
99import numpy as np
1010from numba import jit
1111import warnings
@@ -82,6 +82,24 @@ class CMIknn(CondIndTest):
8282 Number of workers to use for parallel processing. If -1 is given
8383 all processors are used. Default: -1.
8484
85+ null_fit : {None, 'normal', 'gamma'}, optional (default: None)
86+ If None, the empirical surrogate distribution is used to compute
87+ the p-value (default behavior). If 'normal' or 'gamma', a parametric
88+ distribution is fit to the null samples:
89+
90+ * 'normal': Fit a Gaussian N(μ, σ²) to the null distribution.
91+ * 'gamma' : Fit a three-parameter Gamma(a, loc, scale).
92+
93+ This can reduce the number of required surrogate samples for
94+ significance testing, but may lead to miscalibrated p-values if
95+ the parametric family is a poor fit.
96+
97+ permute : {'Y', 'X'}, optional (default: 'X')
98+ Which variable to permute in the restricted shuffle test.
99+ - 'Y': shuffle Y within Z-neighborhoods (default). This is often
100+ preferable when Z is chosen as (approximate) parents of Y.
101+ - 'X': shuffle X within Z-neighborhoods.
102+
85103 model_selection_folds : int (optional, default = 3)
86104 Number of folds in cross-validation used in model selection.
87105
@@ -106,6 +124,8 @@ def __init__(self,
106124 transform = 'ranks' ,
107125 workers = - 1 ,
108126 model_selection_folds = 3 ,
127+ null_fit = None ,
128+ permute = 'X' ,
109129 ** kwargs ):
110130 # Set the member variables
111131 self .knn = knn
@@ -116,6 +136,8 @@ def __init__(self,
116136 self .residual_based = False
117137 self .recycle_residuals = False
118138 self .workers = workers
139+ self .null_fit = null_fit
140+ self .permute = permute
119141 self .model_selection_folds = model_selection_folds
120142 # Call the parent constructor
121143 CondIndTest .__init__ (self , significance = significance , ** kwargs )
@@ -126,7 +148,10 @@ def __init__(self,
126148 else :
127149 print ("knn = %s" % self .knn )
128150 print ("shuffle_neighbors = %d\n " % self .shuffle_neighbors )
129-
151+ print (f"Restricted shuffle permutes: { self .permute } " )
152+ if self .null_fit is not None :
153+ print ("Using parametric null fit:" , self .null_fit )
154+
130155 @jit (forceobj = True )
131156 def _get_nearest_neighbors (self , array , xyz , knn ):
132157 """Returns nearest neighbors according to Frenzel and Pompe (2007).
@@ -269,9 +294,10 @@ def get_shuffle_significance(self, array, xyz, value,
269294 For non-empty Z, overwrites get_shuffle_significance from the parent
270295 class which is a block shuffle test, which does not preserve
271296 dependencies of X and Y with Z. Here the parameter shuffle_neighbors is
272- used to permute only those values :math:`x_i` and :math:`x_j` for which
273- :math:`z_j` is among the nearest niehgbors of :math:`z_i`. If Z is
274- empty, the block-shuffle test is used.
297+ used to permute only those values :math:`y_i` and :math:`y_j` for which
298+ :math:`z_j` is among the nearest neighbors of :math:`z_i`. If Z is
299+ empty, the block-shuffle test is used. In the paper X is permuted, but
300+ permuting Y is preferable when Z is chosen as (approximate) parents of Y.
275301
276302 Parameters
277303 ----------
@@ -300,6 +326,7 @@ class which is a block shuffle test, which does not preserve
300326
301327 # max_neighbors = max(1, int(max_neighbor_ratio*T))
302328 x_indices = np .where (xyz == 0 )[0 ]
329+ y_indices = np .where (xyz == 1 )[0 ]
303330 z_indices = np .where (xyz == 2 )[0 ]
304331
305332 if len (z_indices ) > 0 and self .shuffle_neighbors < T :
@@ -361,8 +388,15 @@ class which is a block shuffle test, which does not preserve
361388 order = order )
362389
363390 array_shuffled = np .copy (array )
364- for i in x_indices :
365- array_shuffled [i ] = array [i , restricted_permutation ]
391+ if self .permute == 'X' :
392+ for i in x_indices :
393+ array_shuffled [i ] = array [i , restricted_permutation ]
394+ else : # permute Y
395+ for i in y_indices :
396+ array_shuffled [i ] = array [i , restricted_permutation ]
397+ # array_shuffled = np.copy(array)
398+ # for i in x_indices:
399+ # array_shuffled[i] = array[i, restricted_permutation]
366400
367401 null_dist [sam ] = self .get_dependence_measure (array_shuffled ,
368402 xyz )
@@ -375,8 +409,19 @@ class which is a block shuffle test, which does not preserve
375409 sig_blocklength = self .sig_blocklength ,
376410 verbosity = self .verbosity )
377411
378- # pval = (null_dist >= value).mean()
379- pval = float (np .sum (null_dist >= value ) + 1 ) / (self .sig_samples + 1 )
412+ if self .null_fit == 'normal' :
413+ mu , sigma = stats .norm .fit (null_dist )
414+ pval = 1.0 - stats .norm .cdf (value , loc = mu , scale = sigma )
415+ elif self .null_fit == 'gamma' :
416+ try :
417+ a , loc , scale = stats .gamma .fit (null_dist )
418+ pval = 1.0 - stats .gamma .cdf (value , a , loc = loc , scale = scale )
419+ except Exception as e :
420+ warnings .warn (f"Gamma fit failed, falling back to empirical: { e } " )
421+ pval = float (np .sum (null_dist >= value ) + 1 ) / (self .sig_samples + 1 )
422+ else :
423+ # fallback: empirical Monte Carlo
424+ pval = float (np .sum (null_dist >= value ) + 1 ) / (self .sig_samples + 1 )
380425
381426 if return_null_dist :
382427 # Sort
@@ -586,7 +631,11 @@ def lin_f(x): return x
586631 }
587632 knn = 10
588633 maxlag = 1
589- cmi = CMIknn (seed = seed , knn = knn , sig_samples = 500 )
634+ cmi = CMIknn (seed = seed , knn = knn ,
635+ sig_samples = 50 ,
636+ null_fit = 'gamma' ,
637+ permut = 'Y' ,
638+ )
590639
591640 realizations = 100
592641 realizations_data = toys .structural_causal_process_ensemble (realizations = realizations ,
0 commit comments