
Commit 4fd19c9

Merge pull request #50 from neurostuff/better-dataset-handling
[REF] Improved Dataset fitting
2 parents c7e4fcf + 2f190dd commit 4fd19c9

10 files changed: +133 additions, -79 deletions

README.md

Lines changed: 17 additions & 2 deletions
@@ -60,10 +60,25 @@ from pymare.estimators import VarianceBasedLikelihoodEstimator
 dataset = Dataset(y, v, X)
 # Estimator class for likelihood-based methods when variances are known
 estimator = VarianceBasedLikelihoodEstimator(method='REML')
-# All estimators accept a `Dataset` instance as the first argument to `.fit()`
-estimator.fit(dataset)
+# All estimators expose a fit_dataset() method that takes a `Dataset`
+# instance as the first (and usually only) argument.
+estimator.fit_dataset(dataset)
 # Post-fitting we can obtain a MetaRegressionResults instance via .summary()
 results = estimator.summary()
 # Print summary of results as a pandas DataFrame
 print(results.to_df())
 ```
+
+And if we want to be even more explicit, we can avoid the `Dataset` abstraction
+entirely (though we'll lose some convenient validation checks):
+
+```python
+estimator = VarianceBasedLikelihoodEstimator(method='REML')
+
+# X must be 2-d; this is one of the things the Dataset implicitly handles.
+X = X[:, None]
+
+estimator.fit(y, v, X)
+
+results = estimator.summary()
+```
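Since every estimator's `fit()`/`fit_dataset()` now returns `self` (see the `pymare/estimators/estimators.py` changes below), the two steps above can also be chained. A minimal sketch, reusing the `y`, `v`, and `X` arrays from the snippet and assuming the top-level `Dataset` import:

```python
from pymare import Dataset
from pymare.estimators import VarianceBasedLikelihoodEstimator

dataset = Dataset(y, v, X)
results = (VarianceBasedLikelihoodEstimator(method='REML')
           .fit_dataset(dataset)
           .summary())
print(results.to_df())
```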

examples/02_meta-analysis/plot_run_meta-analysis.py

Lines changed: 1 addition & 1 deletion
@@ -29,6 +29,6 @@
 # Datasets can also be created from pandas DataFrames
 # ---------------------------------------------------
 dataset = core.Dataset(v=v, X=X, y=y, n=n)
-est = estimators.WeightedLeastSquares().fit(dataset)
+est = estimators.WeightedLeastSquares().fit_dataset(dataset)
 results = est.summary()
 print(results.to_df())

pymare/core.py

Lines changed: 1 addition & 1 deletion
@@ -144,5 +144,5 @@ def meta_regression(y=None, v=None, X=None, n=None, data=None, X_names=None,
 
     # Get estimates
     est = est_cls(**kwargs)
-    est.fit(data)
+    est.fit_dataset(data)
     return est.summary()
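For reference, `meta_regression()` is the one-shot convenience wrapper around this code path. A sketch, reusing the arrays from the README example and assuming the top-level import and default settings:

```python
from pymare import meta_regression

# Builds a Dataset internally, fits the estimator via fit_dataset(),
# and returns the summary in one call.
results = meta_regression(y=y, v=v, X=X)
print(results.to_df())
```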

pymare/estimators/combination.py

Lines changed: 18 additions & 6 deletions
@@ -29,17 +29,18 @@ def p_value(self, z, *args, **kwargs):
     def _z_to_p(self, z):
         return ss.norm.sf(z)
 
-    def _fit(self, y, *args, **kwargs):
+    def fit(self, z, *args, **kwargs):
         if self.mode == 'concordant':
             ose = self.__class__(mode='directed')
-            p1 = ose.p_value(y, *args, **kwargs)
-            p2 = ose.p_value(-y, *args, **kwargs)
+            p1 = ose.p_value(z, *args, **kwargs)
+            p2 = ose.p_value(-z, *args, **kwargs)
             p = np.minimum(1, 2 * np.minimum(p1, p2))
         else:
             if self.mode == 'undirected':
-                y = np.abs(y)
-            p = self.p_value(y, *args, **kwargs)
-        return {'p': p}
+                z = np.abs(z)
+            p = self.p_value(z, *args, **kwargs)
+        self.params_ = {'p': p}
+        return self
 
     def summary(self):
         if not hasattr(self, 'params_'):
@@ -85,6 +86,13 @@ class StoufferCombinationTest(CombinationTest):
     (3) This estimator does not support meta-regression; any moderators
         passed in to fit() as the X array will be ignored.
     """
+
+    # Maps Dataset attributes onto fit() args; see BaseEstimator for details.
+    _dataset_attr_map = {'z': 'y', 'w': 'v'}
+
+    def fit(self, z, w=None):
+        return super().fit(z, w=w)
+
     def p_value(self, z, w=None):
         if w is None:
             w = np.ones_like(z)
@@ -128,6 +136,10 @@ class FisherCombinationTest(CombinationTest):
     (3) This estimator does not support meta-regression; any moderators
         passed in to fit() as the X array will be ignored.
     """
+
+    # Maps Dataset attributes onto fit() args; see BaseEstimator for details.
+    _dataset_attr_map = {'z': 'y'}
+
     def p_value(self, z):
         p = self._z_to_p(z)
         chi2 = -2 * np.log(p).sum(0)
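A short sketch of what the new `_dataset_attr_map` hooks buy us: a generic `Dataset` stores z-scores under its `y` attribute (and optional weights under `v`), and `fit_dataset()` remaps them onto `fit()`'s `z` and `w` arguments. The values below are fabricated, and the import paths mirror the test suite:

```python
import numpy as np
from pymare import Dataset
from pymare.estimators import StoufferCombinationTest

z = np.array([[2.1], [0.7], [-0.2], [4.1], [3.8]])  # fabricated per-study z-scores

dset = Dataset(y=z)  # the z-scores live in the Dataset's generic `y` slot
est = StoufferCombinationTest(mode='directed').fit_dataset(dset)
print(est.summary().p)  # combined p-value
```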

pymare/estimators/estimators.py

Lines changed: 65 additions & 38 deletions
@@ -16,7 +16,7 @@
 
 @wrapt.decorator
 def _loopable(wrapped, instance, args, kwargs):
-    # Decorator for _fit method of Estimator classes to handle naive looping
+    # Decorator for fit() method of Estimator classes to handle naive looping
     # over the 2nd dimension of y/v/n inputs, and reconstruction of outputs.
     n_iter = kwargs['y'].shape[1]
     if n_iter > 10:
@@ -26,6 +26,7 @@ def _loopable(wrapped, instance, args, kwargs):
            "datasets. Consider using the DL, HE, or WLS estimators, "
            "which handle parallel datasets more efficiently."
            .format(n_iter))
+
     param_dicts = []
     for i in range(n_iter):
         iter_kwargs = {'X': kwargs['X']}
@@ -35,41 +36,58 @@ def _loopable(wrapped, instance, args, kwargs):
         if 'n' in kwargs:
             n = kwargs['n'][:, i, None] if kwargs['n'].shape[1] > 1 else kwargs['n']
             iter_kwargs['n'] = n
-        param_dicts.append(wrapped(**iter_kwargs))
+        wrapped(**iter_kwargs)
+        param_dicts.append(instance.params_.copy())
+
     params = {}
     for k in param_dicts[0]:
         concat = np.stack([pd[k].squeeze() for pd in param_dicts], axis=-1)
         params[k] = np.atleast_2d(concat)
-    return params
+
+    instance.params_ = params
+    return instance
 
 
 class BaseEstimator(metaclass=ABCMeta):
 
+    # A class-level mapping from Dataset attributes to fit() arguments. Used by
+    # fit_dataset() for estimators that take non-standard arguments (e.g., 'z'
+    # instead of 'y'). Keys are the target arg names in the estimator class's
+    # fit() method (e.g., 'z') and values are the default Dataset attribute
+    # names they map onto (e.g., 'y').
+    _dataset_attr_map = {}
+
     @abstractmethod
-    def _fit(self):
-        # Subclasses must implement _fit() method that directly takes arrays.
-        # The following named arguments are allowed, and will be automatically
-        # extracted from the Dataset instance:
-        # * y (estimates)
-        # * v (variances)
-        # * n (sample_sizes)
-        # * X (predictors)
+    def fit(self, *args, **kwargs):
         pass
 
-    def fit(self, dataset=None, **kwargs):
+    def fit_dataset(self, dataset, *args, **kwargs):
+        """ Applies the current estimator to the passed Dataset container.
 
-        if dataset is not None:
-            kwargs = {}
-            spec = getfullargspec(self._fit)
-            n_kw = len(spec.defaults) if spec.defaults else 0
-            n_args = len(spec.args) - n_kw - 1
-            for i, name in enumerate(spec.args[1:]):
-                if i >= n_args:
-                    kwargs[name] = getattr(dataset, name, spec.defaults[i - n_args])
-                else:
-                    kwargs[name] = getattr(dataset, name)
+        A convenience interface that wraps fit() and automatically aligns the
+        variables held in a Dataset with the required arguments.
 
-        self.params_ = self._fit(**kwargs)
+        Args:
+            dataset (Dataset): A PyMARE Dataset instance holding the data.
+            args, kwargs: optional positional and keyword arguments to pass
+                onto the fit() method.
+        """
+        all_kwargs = {}
+        spec = getfullargspec(self.fit)
+        n_kw = len(spec.defaults) if spec.defaults else 0
+        n_args = len(spec.args) - n_kw - 1
+
+        for i, name in enumerate(spec.args[1:]):
+            # Check for remapped name
+            attr_name = self._dataset_attr_map.get(name, name)
+            if i >= n_args:
+                all_kwargs[name] = getattr(dataset, attr_name,
+                                           spec.defaults[i - n_args])
+            else:
+                all_kwargs[name] = getattr(dataset, attr_name)
+
+        all_kwargs.update(kwargs)
+        self.fit(*args, **all_kwargs)
         self.dataset_ = dataset
 
         return self
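To make the introspection concrete, here is a standalone toy sketch (not part of the commit) of the information `fit_dataset()` works from. The class and values are invented; only `getfullargspec` and the `_dataset_attr_map` convention come from the code above:

```python
from inspect import getfullargspec


class ToyEstimator:
    # fit() wants `z`, which a Dataset stores under `y` (the same convention
    # as StoufferCombinationTest's {'z': 'y'} mapping).
    _dataset_attr_map = {'z': 'y'}

    def fit(self, z, w=None):
        self.params_ = {'z': z, 'w': w}
        return self


spec = getfullargspec(ToyEstimator.fit)
print(spec.args)      # ['self', 'z', 'w'] -> names fit_dataset() iterates over
print(spec.defaults)  # (None,) -> `w` is optional; a missing Dataset
                      #           attribute falls back to this default
```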
@@ -86,7 +104,7 @@ def get_v(self, dataset):
         Notes:
             This is equivalent to directly accessing `dataset.v` when variances
             are present, but affords a way of estimating v from sample size (n)
-            for any estimator that implicitly estimate a sigma^2 parameter.
+            for any estimator that implicitly estimates a sigma^2 parameter.
         """
         if dataset.v is not None:
             return dataset.v
@@ -139,12 +157,13 @@ class WeightedLeastSquares(BaseEstimator):
     def __init__(self, tau2=0.):
         self.tau2 = tau2
 
-    def _fit(self, y, X, v=None):
+    def fit(self, y, X, v=None):
         if v is None:
             v = np.ones_like(y)
         beta, inv_cov = weighted_least_squares(y, v, X, self.tau2,
                                                return_cov=True)
-        return {'fe_params': beta, 'tau2': self.tau2, 'inv_cov': inv_cov}
+        self.params_ = {'fe_params': beta, 'tau2': self.tau2, 'inv_cov': inv_cov}
+        return self
 
 
 class DerSimonianLaird(BaseEstimator):
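The same pattern from the array-based side: `fit()` now stores `params_` on the instance and returns `self`, so it chains directly into `summary()`. A sketch with fabricated inputs (note the `fit(y, X, v=None)` argument order, which differs from the likelihood estimators):

```python
import numpy as np
from pymare.estimators import WeightedLeastSquares

y = np.array([[-1.0], [0.5], [0.5], [0.5], [1.0], [1.0]])  # fabricated estimates
v = np.array([[1.0], [1.5], [2.4], [0.5], [1.0], [4.0]])   # fabricated variances
X = np.ones((6, 1))                                        # intercept-only design

results = WeightedLeastSquares().fit(y, X, v).summary()
print(results.to_df())
```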
@@ -167,7 +186,7 @@
         identical for all iterates.
     """
 
-    def _fit(self, y, v, X):
+    def fit(self, y, v, X):
         k, p = X.shape
 
         # Estimate initial betas with WLS, assuming tau^2=0
@@ -189,7 +208,8 @@ def _fit(self, y, v, X):
         # Re-estimate beta with tau^2 estimate
         beta_dl, inv_cov = weighted_least_squares(y, v, X, tau2=tau_dl,
                                                   return_cov=True)
-        return {'fe_params': beta_dl, 'tau2': tau_dl, 'inv_cov': inv_cov}
+        self.params_ = {'fe_params': beta_dl, 'tau2': tau_dl, 'inv_cov': inv_cov}
+        return self
 
 
 class Hedges(BaseEstimator):
@@ -208,7 +228,7 @@
         identical for all iterates.
     """
 
-    def _fit(self, y, v, X):
+    def fit(self, y, v, X):
         k, p = X.shape[:2]
         _unit_v = np.ones_like(y)
         beta, inv_cov = weighted_least_squares(y, _unit_v, X, return_cov=True)
@@ -217,7 +237,8 @@
         tau_ho = np.maximum(0, tau_ho)
         # Estimate beta with tau^2 estimate
         beta_ho = weighted_least_squares(y, v, X, tau2=tau_ho)
-        return {'fe_params': beta_ho, 'tau2': tau_ho, 'inv_cov': inv_cov}
+        self.params_ = {'fe_params': beta_ho, 'tau2': tau_ho, 'inv_cov': inv_cov}
+        return self
 
 
 class VarianceBasedLikelihoodEstimator(BaseEstimator):
@@ -255,9 +276,9 @@ def __init__(self, method='ml', **kwargs):
         self.kwargs = kwargs
 
     @_loopable
-    def _fit(self, y, v, X):
+    def fit(self, y, v, X):
         # use D-L estimate for initial values
-        est_DL = DerSimonianLaird()._fit(y, v, X)
+        est_DL = DerSimonianLaird().fit(y, v, X).params_
         beta = est_DL['fe_params']
         tau2 = est_DL['tau2']
 
@@ -273,7 +294,8 @@ def _fit(self, y, v, X):
         beta, tau = res.x[:-1], float(res.x[-1])
         tau = np.max([tau, 0])
         _, inv_cov = weighted_least_squares(y, v, X, tau, True)
-        return {'fe_params': beta[:, None], 'tau2': tau, 'inv_cov': inv_cov}
+        self.params_ = {'fe_params': beta[:, None], 'tau2': tau, 'inv_cov': inv_cov}
+        return self
 
     def _ml_nll(self, theta, y, v, X):
         """ ML negative log-likelihood for meta-regression model. """
@@ -329,7 +351,7 @@ def __init__(self, method='ml', **kwargs):
         self.kwargs = kwargs
 
     @_loopable
-    def _fit(self, y, n, X):
+    def fit(self, y, n, X):
         if n.std() < np.sqrt(np.finfo(float).eps):
             raise ValueError("Sample size-based likelihood estimator cannot "
                              "work with all-equal sample sizes.")
@@ -353,8 +375,13 @@ def _fit(self, y, n, X):
         beta, sigma, tau = res.x[:-2], float(res.x[-2]), float(res.x[-1])
         tau = np.max([tau, 0])
         _, inv_cov = weighted_least_squares(y, sigma / n, X, tau, True)
-        return {'fe_params': beta[:, None], 'sigma2': np.array(sigma), 'tau2': tau,
-                'inv_cov': inv_cov}
+        self.params_ = {
+            'fe_params': beta[:, None],
+            'sigma2': np.array(sigma),
+            'tau2': tau,
+            'inv_cov': inv_cov
+        }
+        return self
 
     def _ml_nll(self, theta, y, n, X):
         """ ML negative log-likelihood for meta-regression model. """
@@ -431,7 +458,7 @@ def compile(self):
         from pystan import StanModel
         self.model = StanModel(model_code=spec)
 
-    def _fit(self, y, v, X, groups=None):
+    def fit(self, y, v, X, groups=None):
         """Run the Stan sampler and return results.
 
         Args:
@@ -479,7 +506,7 @@ def _fit(self, y, v, X, groups=None):
         }
 
         self.result_ = self.model.sampling(data=data, **self.sampling_kwargs)
-        return self.result_
+        return self
 
     def summary(self, ci=95):
         if self.result_ is None:
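The `_loopable` changes above preserve the parallel-datasets behavior: when `y` (and `v` or `n`) carry a second dimension, the decorated `fit()` loops over columns and stacks each iteration's `params_`. A sketch with fabricated parallel data:

```python
import numpy as np
from pymare.estimators import VarianceBasedLikelihoodEstimator

rng = np.random.default_rng(0)
k, n_datasets = 8, 3
y = rng.normal(size=(k, n_datasets))             # three parallel datasets
v = rng.uniform(0.5, 1.5, size=(k, n_datasets))  # matching variances
X = np.ones((k, 1))                              # shared design matrix

est = VarianceBasedLikelihoodEstimator(method='ml').fit(y, v, X)
print(est.params_['tau2'])  # one tau^2 estimate per parallel dataset
```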

pymare/results.py

Lines changed: 6 additions & 6 deletions
@@ -177,7 +177,7 @@ def permutation_test(self, n_perm=1000):
             y_perm = np.repeat(y[:, None], n_perm, axis=1)
 
             # for v, we might actually be working with n, depending on estimator
-            has_v = 'v' in getfullargspec(self.estimator._fit).args[1:]
+            has_v = 'v' in getfullargspec(self.estimator.fit).args[1:]
             v = self.dataset.v[:, i] if has_v else self.dataset.n[:, i]
 
             v_perm = np.repeat(v[:, None], n_perm, axis=1)
@@ -203,7 +203,7 @@
             # Pass parameters, remembering that v may actually be n
             kwargs = {'y': y_perm, 'X': self.dataset.X}
             kwargs['v' if has_v else 'n'] = v_perm
-            params = self.estimator._fit(**kwargs)
+            params = self.estimator.fit(**kwargs).params_
 
             fe_obs = fe_stats['est'][:, i]
             if fe_obs.ndim == 1:
@@ -304,10 +304,10 @@
                 y_perm *= signs
 
                 # Some combination tests can handle weights (passed as v)
-                kwargs = {'y': y_perm}
-                if 'v' in getfullargspec(est._fit).args:
-                    kwargs['v'] = self.dataset.v
-                params = est._fit(**kwargs)
+                kwargs = {'z': y_perm}
+                if 'w' in getfullargspec(est.fit).args:
+                    kwargs['w'] = self.dataset.v
+                params = est.fit(**kwargs).params_
 
                 p_obs = self.z[i]
                 if p_obs.ndim == 1:
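These are the internals behind the public `permutation_test()` method on the results objects; the calling convention is unchanged. A sketch, assuming a fitted estimator and `dataset` as in the README example:

```python
est = VarianceBasedLikelihoodEstimator(method='REML').fit_dataset(dataset)
results = est.summary()

# Re-fits the estimator on permuted data n_perm times to build an
# empirical null for the fixed-effect estimates.
perm_results = results.permutation_test(n_perm=1000)
```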

pymare/tests/test_combination_tests.py

Lines changed: 2 additions & 2 deletions
@@ -27,15 +27,15 @@
 
 @pytest.mark.parametrize("Cls,data,mode,expected", _params)
 def test_combination_test(Cls, data, mode, expected):
-    results = Cls(mode)._fit(data)
+    results = Cls(mode).fit(data).params_
     z = ss.norm.isf(results['p'])
     assert np.allclose(z, expected, atol=1e-5)
 
 
 @pytest.mark.parametrize("Cls,data,mode,expected", _params)
 def test_combination_test_from_dataset(Cls, data, mode, expected):
     dset = Dataset(y=data)
-    est = Cls(mode).fit(dset)
+    est = Cls(mode).fit_dataset(dset)
     results = est.summary()
     z = ss.norm.isf(results.p)
     assert np.allclose(z, expected, atol=1e-5)
