Skip to content

Commit 0a331b7

Browse files
authored
Merge pull request #466 from jakobrunge/developer
Developer
2 parents 12151c1 + 702e70a commit 0a331b7

File tree

2 files changed

+60
-11
lines changed

2 files changed

+60
-11
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def run(self):
6161
# Run the setup
6262
setup(
6363
name="tigramite",
64-
version="5.2.8.5",
64+
version="5.2.9.1",
6565
packages=["tigramite", "tigramite.independence_tests", "tigramite.toymodels"],
6666
license="GNU General Public License v3.0",
6767
description="Tigramite causal inference for time series",

tigramite/independence_tests/cmiknn.py

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# License: GNU General Public License v3.0
66

77
from __future__ import print_function
8-
from scipy import special, spatial
8+
from scipy import special, spatial, stats
99
import numpy as np
1010
from numba import jit
1111
import warnings
@@ -82,6 +82,24 @@ class CMIknn(CondIndTest):
8282
Number of workers to use for parallel processing. If -1 is given
8383
all processors are used. Default: -1.
8484
85+
null_fit : {None, 'normal', 'gamma'}, optional (default: None)
86+
If None, the empirical surrogate distribution is used to compute
87+
the p-value (default behavior). If 'normal' or 'gamma', a parametric
88+
distribution is fit to the null samples:
89+
90+
* 'normal': Fit a Gaussian N(μ, σ²) to the null distribution.
91+
* 'gamma' : Fit a three-parameter Gamma(a, loc, scale).
92+
93+
This can reduce the number of required surrogate samples for
94+
significance testing, but may lead to miscalibrated p-values if
95+
the parametric family is a poor fit.
96+
97+
permute : {'Y', 'X'}, optional (default: 'X')
98+
Which variable to permute in the restricted shuffle test.
99+
- 'Y': shuffle Y within Z-neighborhoods. This is often
100+
preferable when Z is chosen as (approximate) parents of Y.
101+
- 'X': shuffle X within Z-neighborhoods (default).
102+
85103
model_selection_folds : int (optional, default = 3)
86104
Number of folds in cross-validation used in model selection.
87105
@@ -106,6 +124,8 @@ def __init__(self,
106124
transform='ranks',
107125
workers=-1,
108126
model_selection_folds=3,
127+
null_fit=None,
128+
permute='X',
109129
**kwargs):
110130
# Set the member variables
111131
self.knn = knn
@@ -116,6 +136,8 @@ def __init__(self,
116136
self.residual_based = False
117137
self.recycle_residuals = False
118138
self.workers = workers
139+
self.null_fit = null_fit
140+
self.permute = permute
119141
self.model_selection_folds = model_selection_folds
120142
# Call the parent constructor
121143
CondIndTest.__init__(self, significance=significance, **kwargs)
@@ -126,7 +148,10 @@ def __init__(self,
126148
else:
127149
print("knn = %s" % self.knn)
128150
print("shuffle_neighbors = %d\n" % self.shuffle_neighbors)
129-
151+
print(f"Restricted shuffle permutes: {self.permute}")
152+
if self.null_fit is not None:
153+
print("Using parametric null fit:", self.null_fit)
154+
130155
@jit(forceobj=True)
131156
def _get_nearest_neighbors(self, array, xyz, knn):
132157
"""Returns nearest neighbors according to Frenzel and Pompe (2007).
@@ -269,9 +294,10 @@ def get_shuffle_significance(self, array, xyz, value,
269294
For non-empty Z, overwrites get_shuffle_significance from the parent
270295
class which is a block shuffle test, which does not preserve
271296
dependencies of X and Y with Z. Here the parameter shuffle_neighbors is
272-
used to permute only those values :math:`x_i` and :math:`x_j` for which
273-
:math:`z_j` is among the nearest niehgbors of :math:`z_i`. If Z is
274-
empty, the block-shuffle test is used.
297+
used to permute only those values :math:`x_i` and :math:`x_j` (or, with permute='Y', :math:`y_i` and :math:`y_j`) for which
298+
:math:`z_j` is among the nearest neighbors of :math:`z_i`. If Z is
299+
empty, the block-shuffle test is used. In the paper X is permuted, but
300+
permuting Y is preferable when Z is chosen as (approximate) parents of Y.
275301
276302
Parameters
277303
----------
@@ -300,6 +326,7 @@ class which is a block shuffle test, which does not preserve
300326

301327
# max_neighbors = max(1, int(max_neighbor_ratio*T))
302328
x_indices = np.where(xyz == 0)[0]
329+
y_indices = np.where(xyz == 1)[0]
303330
z_indices = np.where(xyz == 2)[0]
304331

305332
if len(z_indices) > 0 and self.shuffle_neighbors < T:
@@ -361,8 +388,15 @@ class which is a block shuffle test, which does not preserve
361388
order=order)
362389

363390
array_shuffled = np.copy(array)
364-
for i in x_indices:
365-
array_shuffled[i] = array[i, restricted_permutation]
391+
if self.permute == 'X':
392+
for i in x_indices:
393+
array_shuffled[i] = array[i, restricted_permutation]
394+
else: # permute Y
395+
for i in y_indices:
396+
array_shuffled[i] = array[i, restricted_permutation]
397+
# array_shuffled = np.copy(array)
398+
# for i in x_indices:
399+
# array_shuffled[i] = array[i, restricted_permutation]
366400

367401
null_dist[sam] = self.get_dependence_measure(array_shuffled,
368402
xyz)
@@ -375,8 +409,19 @@ class which is a block shuffle test, which does not preserve
375409
sig_blocklength=self.sig_blocklength,
376410
verbosity=self.verbosity)
377411

378-
# pval = (null_dist >= value).mean()
379-
pval = float(np.sum(null_dist >= value) + 1) / (self.sig_samples + 1)
412+
if self.null_fit == 'normal':
413+
mu, sigma = stats.norm.fit(null_dist)
414+
pval = 1.0 - stats.norm.cdf(value, loc=mu, scale=sigma)
415+
elif self.null_fit == 'gamma':
416+
try:
417+
a, loc, scale = stats.gamma.fit(null_dist)
418+
pval = 1.0 - stats.gamma.cdf(value, a, loc=loc, scale=scale)
419+
except Exception as e:
420+
warnings.warn(f"Gamma fit failed, falling back to empirical: {e}")
421+
pval = float(np.sum(null_dist >= value) + 1) / (self.sig_samples + 1)
422+
else:
423+
# fallback: empirical Monte Carlo
424+
pval = float(np.sum(null_dist >= value) + 1) / (self.sig_samples + 1)
380425

381426
if return_null_dist:
382427
# Sort
@@ -586,7 +631,11 @@ def lin_f(x): return x
586631
}
587632
knn = 10
588633
maxlag = 1
589-
cmi = CMIknn(seed=seed, knn=knn, sig_samples=500)
634+
cmi = CMIknn(seed=seed, knn=knn,
635+
sig_samples=50,
636+
null_fit='gamma',
637+
permute='Y',
638+
)
590639

591640
realizations = 100
592641
realizations_data = toys.structural_causal_process_ensemble(realizations=realizations,

0 commit comments

Comments
 (0)