Skip to content

Commit 7465939

Browse files
authored
Add an inflation factor to correct for multiple contrasts in Stouffer's combination test (#117)
* Add correction term for multiple contrasts in Stouffer's combination test * Update RTD yml * Update .readthedocs.yml * Update combination.py * Update .readthedocs.yml * Update setup.cfg * Update testing.yml * Update testing.yml * Run black * Make sure solutions and symbols match * Update combination.py
1 parent c38b0bb commit 7465939

File tree

2 files changed

+105
-8
lines changed

2 files changed

+105
-8
lines changed

pymare/estimators/combination.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,17 +111,77 @@ class StoufferCombinationTest(CombinationTest):
111111
"""
112112

113113
# Maps Dataset attributes onto fit() args; see BaseEstimator for details.
114-
_dataset_attr_map = {"z": "y", "w": "v"}
115-
116-
def fit(self, z, w=None):
117-
"""Fit the estimator to z-values, optionally with weights."""
118-
return super().fit(z, w=w)
119-
120-
def p_value(self, z, w=None):
114+
_dataset_attr_map = {"z": "y", "w": "n", "g": "v"}
115+
116+
def _inflation_term(self, z, w, g):
117+
"""Calculate the variance inflation term for each group.
118+
119+
This term is used to adjust the variance of the combined z-score when
120+
multiple sample come from the same study.
121+
122+
Parameters
123+
----------
124+
z : :obj:`numpy.ndarray` of shape (n, d)
125+
Array of z-values.
126+
w : :obj:`numpy.ndarray` of shape (n, d)
127+
Array of weights.
128+
g : :obj:`numpy.ndarray` of shape (n, d)
129+
Array of group labels.
130+
131+
Returns
132+
-------
133+
sigma : float
134+
The variance inflation term.
135+
"""
136+
# Only center if the samples are not all the same, to prevent division by zero
137+
# when calculating the correlation matrix.
138+
# This centering is problematic for N=2
139+
all_samples_same = np.all(np.equal(z, z[0]), axis=0).all()
140+
z = z if all_samples_same else z - z.mean(0)
141+
142+
# Use the value from one feature, as all features have the same groups and weights
143+
groups = g[:, 0]
144+
weights = w[:, 0]
145+
146+
# Loop over groups
147+
unique_groups = np.unique(groups)
148+
149+
sigma = 0
150+
for group in unique_groups:
151+
group_indices = np.where(groups == group)[0]
152+
group_z = z[group_indices]
153+
154+
# For groups with only one sample the contribution to the summand is 0
155+
n_samples = len(group_indices)
156+
if n_samples < 2:
157+
continue
158+
159+
# Calculate the within group correlation matrix and sum the non-diagonal elements
160+
corr = np.corrcoef(group_z, rowvar=True)
161+
upper_indices = np.triu_indices(n_samples, k=1)
162+
non_diag_corr = corr[upper_indices]
163+
w_i, w_j = weights[upper_indices[0]], weights[upper_indices[1]]
164+
165+
sigma += (2 * w_i * w_j * non_diag_corr).sum()
166+
167+
return sigma
168+
169+
def fit(self, z, w=None, g=None):
170+
"""Fit the estimator to z-values, optionally with weights and groups."""
171+
return super().fit(z, w=w, g=g)
172+
173+
def p_value(self, z, w=None, g=None):
121174
"""Calculate p-values."""
122175
if w is None:
123176
w = np.ones_like(z)
124-
cz = (z * w).sum(0) / np.sqrt((w**2).sum(0))
177+
178+
# Calculate the variance inflation term, sum of non-diagonal elements of sigma.
179+
sigma = self._inflation_term(z, w, g) if g is not None else 0
180+
181+
# The sum of diagonal elements of sigma is given by (w**2).sum(0).
182+
variance = (w**2).sum(0) + sigma
183+
184+
cz = (z * w).sum(0) / np.sqrt(variance)
125185
return ss.norm.sf(cz)
126186

127187

pymare/tests/test_combination_tests.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,40 @@ def test_combination_test_from_dataset(Cls, data, mode, expected):
4242
results = est.summary()
4343
z = ss.norm.isf(results.p)
4444
assert np.allclose(z, expected, atol=1e-5)
45+
46+
47+
def test_stouffer_adjusted():
48+
"""Test StoufferCombinationTest with weights and groups."""
49+
# Test with weights and groups
50+
data = np.array(
51+
[
52+
[2.1, 0.7, -0.2, 4.1, 3.8],
53+
[1.1, 0.2, 0.4, 1.3, 1.5],
54+
[-0.6, -1.6, -2.3, -0.8, -4.0],
55+
[2.5, 1.7, 2.1, 2.3, 2.5],
56+
[3.1, 2.7, 3.1, 3.3, 3.5],
57+
[3.6, 3.2, 3.6, 3.8, 4.0],
58+
]
59+
)
60+
weights = np.tile(np.array([4, 3, 4, 10, 15, 10]), (data.shape[1], 1)).T
61+
groups = np.tile(np.array([0, 0, 1, 2, 2, 2]), (data.shape[1], 1)).T
62+
63+
results = StoufferCombinationTest("directed").fit(z=data, w=weights, g=groups).params_
64+
z = ss.norm.isf(results["p"])
65+
66+
z_expected = np.array([5.00088912, 3.70356943, 4.05465924, 5.4633001, 5.18927878])
67+
assert np.allclose(z, z_expected, atol=1e-5)
68+
69+
# Test with weights and no groups. Limiting cases.
70+
# Limiting case 1: all correlations are one.
71+
n_maps_l1 = 5
72+
common_sample = np.array([2.1, 0.7, -0.2])
73+
data_l1 = np.tile(common_sample, (n_maps_l1, 1))
74+
groups_l1 = np.tile(np.array([0, 0, 0, 0, 0]), (data_l1.shape[1], 1)).T
75+
76+
results_l1 = StoufferCombinationTest("directed").fit(z=data_l1, g=groups_l1).params_
77+
z_l1 = ss.norm.isf(results_l1["p"])
78+
79+
sigma_l1 = n_maps_l1 * (n_maps_l1 - 1) # Expected inflation term
80+
z_expected_l1 = n_maps_l1 * common_sample / np.sqrt(n_maps_l1 + sigma_l1)
81+
assert np.allclose(z_l1, z_expected_l1, atol=1e-5)

0 commit comments

Comments
 (0)