1616
1717@wrapt .decorator
1818def _loopable (wrapped , instance , args , kwargs ):
19- # Decorator for _fit method of Estimator classes to handle naive looping
19+ # Decorator for fit() method of Estimator classes to handle naive looping
2020 # over the 2nd dimension of y/v/n inputs, and reconstruction of outputs.
2121 n_iter = kwargs ['y' ].shape [1 ]
2222 if n_iter > 10 :
@@ -26,6 +26,7 @@ def _loopable(wrapped, instance, args, kwargs):
2626 "datasets. Consider using the DL, HE, or WLS estimators, "
2727 "which handle parallel datasets more efficiently."
2828 .format (n_iter ))
29+
2930 param_dicts = []
3031 for i in range (n_iter ):
3132 iter_kwargs = {'X' : kwargs ['X' ]}
@@ -35,41 +36,58 @@ def _loopable(wrapped, instance, args, kwargs):
3536 if 'n' in kwargs :
3637 n = kwargs ['n' ][:, i , None ] if kwargs ['n' ].shape [1 ] > 1 else kwargs ['n' ]
3738 iter_kwargs ['n' ] = n
38- param_dicts .append (wrapped (** iter_kwargs ))
39+ wrapped (** iter_kwargs )
40+ param_dicts .append (instance .params_ .copy ())
41+
3942 params = {}
4043 for k in param_dicts [0 ]:
4144 concat = np .stack ([pd [k ].squeeze () for pd in param_dicts ], axis = - 1 )
4245 params [k ] = np .atleast_2d (concat )
43- return params
46+
47+ instance .params_ = params
48+ return instance
4449
4550
4651class BaseEstimator (metaclass = ABCMeta ):
4752
53+ # A class-level mapping from Dataset attributes to fit() arguments. Used by
54+ # fit_dataset() for estimators that take non-standard arguments (e.g., 'z'
55+ # instead of 'y'). Keys are default Dataset attribute names (e.g., 'y') and
56+ # values are the target arg names in the estimator class's fit() method
57+ # (e.g., 'z').
58+ _dataset_attr_map = {}
59+
4860 @abstractmethod
49- def _fit (self ):
50- # Subclasses must implement _fit() method that directly takes arrays.
51- # The following named arguments are allowed, and will be automatically
52- # extracted from the Dataset instance:
53- # * y (estimates)
54- # * v (variances)
55- # * n (sample_sizes)
56- # * X (predictors)
61+ def fit (self , * args , ** kwargs ):
5762 pass
5863
59- def fit (self , dataset = None , ** kwargs ):
64+ def fit_dataset (self , dataset , * args , ** kwargs ):
65+ """ Applies the current estimator to the passed Dataset container.
6066
61- if dataset is not None :
62- kwargs = {}
63- spec = getfullargspec (self ._fit )
64- n_kw = len (spec .defaults ) if spec .defaults else 0
65- n_args = len (spec .args ) - n_kw - 1
66- for i , name in enumerate (spec .args [1 :]):
67- if i >= n_args :
68- kwargs [name ] = getattr (dataset , name , spec .defaults [i - n_args ])
69- else :
70- kwargs [name ] = getattr (dataset , name )
67+ A convenience interface that wraps fit() and automatically aligns the
68+ variables held in a Dataset with the required arguments.
7169
72- self .params_ = self ._fit (** kwargs )
70+ Args:
71+ dataset (Dataset): A PyMARE Dataset instance holding the data.
72+ args, kwargs: optional positional and keyword arguments to pass
73+ onto the fit() method.
74+ """
75+ all_kwargs = {}
76+ spec = getfullargspec (self .fit )
77+ n_kw = len (spec .defaults ) if spec .defaults else 0
78+ n_args = len (spec .args ) - n_kw - 1
79+
80+ for i , name in enumerate (spec .args [1 :]):
81+ # Check for remapped name
82+ attr_name = self ._dataset_attr_map .get (name , name )
83+ if i >= n_args :
84+ all_kwargs [name ] = getattr (dataset , attr_name ,
85+ spec .defaults [i - n_args ])
86+ else :
87+ all_kwargs [name ] = getattr (dataset , attr_name )
88+
89+ all_kwargs .update (kwargs )
90+ self .fit (* args , ** all_kwargs )
7391 self .dataset_ = dataset
7492
7593 return self
@@ -86,7 +104,7 @@ def get_v(self, dataset):
86104 Notes:
87105 This is equivalent to directly accessing `dataset.v` when variances
88106 are present, but affords a way of estimating v from sample size (n)
89- for any estimator that implicitly estimate a sigma^2 parameter.
107+ for any estimator that implicitly estimates a sigma^2 parameter.
90108 """
91109 if dataset .v is not None :
92110 return dataset .v
@@ -139,12 +157,13 @@ class WeightedLeastSquares(BaseEstimator):
139157 def __init__ (self , tau2 = 0. ):
140158 self .tau2 = tau2
141159
def fit(self, y, X, v=None):
    """Fit the weighted least-squares model to study-level data.

    Args:
        y: Array of point estimates. NOTE(review): the `_loopable`
            decorator elsewhere in this file indexes `y.shape[1]`, so a
            2-d (studies x datasets) layout is expected — confirm.
        X: Design matrix of predictors.
        v: Optional array of sampling variances, same shape as `y`.
            When None, unit variances are assumed (ordinary least
            squares up to the fixed `tau2`).

    Returns:
        self, with `params_` set to a dict holding the fixed-effect
        coefficients ('fe_params'), the (fixed) between-study variance
        ('tau2'), and the inverse covariance matrix ('inv_cov').
    """
    if v is None:
        # No variances supplied: weight all observations equally.
        v = np.ones_like(y)
    beta, inv_cov = weighted_least_squares(y, v, X, self.tau2,
                                           return_cov=True)
    # Store results on the instance and return self so that calls chain
    # (e.g., est.fit(...).params_), consistent with the other estimators.
    self.params_ = {'fe_params': beta, 'tau2': self.tau2,
                    'inv_cov': inv_cov}
    return self
148167
149168
150169class DerSimonianLaird (BaseEstimator ):
@@ -167,7 +186,7 @@ class DerSimonianLaird(BaseEstimator):
167186 identical for all iterates.
168187 """
169188
170- def _fit (self , y , v , X ):
189+ def fit (self , y , v , X ):
171190 k , p = X .shape
172191
173192 # Estimate initial betas with WLS, assuming tau^2=0
@@ -189,7 +208,8 @@ def _fit(self, y, v, X):
189208 # Re-estimate beta with tau^2 estimate
190209 beta_dl , inv_cov = weighted_least_squares (y , v , X , tau2 = tau_dl ,
191210 return_cov = True )
192- return {'fe_params' : beta_dl , 'tau2' : tau_dl , 'inv_cov' : inv_cov }
211+ self .params_ = {'fe_params' : beta_dl , 'tau2' : tau_dl , 'inv_cov' : inv_cov }
212+ return self
193213
194214
195215class Hedges (BaseEstimator ):
@@ -208,7 +228,7 @@ class Hedges(BaseEstimator):
208228 identical for all iterates.
209229 """
210230
211- def _fit (self , y , v , X ):
231+ def fit (self , y , v , X ):
212232 k , p = X .shape [:2 ]
213233 _unit_v = np .ones_like (y )
214234 beta , inv_cov = weighted_least_squares (y , _unit_v , X , return_cov = True )
@@ -217,7 +237,8 @@ def _fit(self, y, v, X):
217237 tau_ho = np .maximum (0 , tau_ho )
218238 # Estimate beta with tau^2 estimate
219239 beta_ho = weighted_least_squares (y , v , X , tau2 = tau_ho )
220- return {'fe_params' : beta_ho , 'tau2' : tau_ho , 'inv_cov' : inv_cov }
240+ self .params_ = {'fe_params' : beta_ho , 'tau2' : tau_ho , 'inv_cov' : inv_cov }
241+ return self
221242
222243
223244class VarianceBasedLikelihoodEstimator (BaseEstimator ):
@@ -255,9 +276,9 @@ def __init__(self, method='ml', **kwargs):
255276 self .kwargs = kwargs
256277
257278 @_loopable
258- def _fit (self , y , v , X ):
279+ def fit (self , y , v , X ):
259280 # use D-L estimate for initial values
260- est_DL = DerSimonianLaird ()._fit (y , v , X )
281+ est_DL = DerSimonianLaird ().fit (y , v , X ). params_
261282 beta = est_DL ['fe_params' ]
262283 tau2 = est_DL ['tau2' ]
263284
@@ -273,7 +294,8 @@ def _fit(self, y, v, X):
273294 beta , tau = res .x [:- 1 ], float (res .x [- 1 ])
274295 tau = np .max ([tau , 0 ])
275296 _ , inv_cov = weighted_least_squares (y , v , X , tau , True )
276- return {'fe_params' : beta [:, None ], 'tau2' : tau , 'inv_cov' : inv_cov }
297+ self .params_ = {'fe_params' : beta [:, None ], 'tau2' : tau , 'inv_cov' : inv_cov }
298+ return self
277299
278300 def _ml_nll (self , theta , y , v , X ):
279301 """ ML negative log-likelihood for meta-regression model. """
@@ -329,7 +351,7 @@ def __init__(self, method='ml', **kwargs):
329351 self .kwargs = kwargs
330352
331353 @_loopable
332- def _fit (self , y , n , X ):
354+ def fit (self , y , n , X ):
333355 if n .std () < np .sqrt (np .finfo (float ).eps ):
334356 raise ValueError ("Sample size-based likelihood estimator cannot "
335357 "work with all-equal sample sizes." )
@@ -353,8 +375,13 @@ def _fit(self, y, n, X):
353375 beta , sigma , tau = res .x [:- 2 ], float (res .x [- 2 ]), float (res .x [- 1 ])
354376 tau = np .max ([tau , 0 ])
355377 _ , inv_cov = weighted_least_squares (y , sigma / n , X , tau , True )
356- return {'fe_params' : beta [:, None ], 'sigma2' : np .array (sigma ), 'tau2' : tau ,
357- 'inv_cov' : inv_cov }
378+ self .params_ = {
379+ 'fe_params' : beta [:, None ],
380+ 'sigma2' : np .array (sigma ),
381+ 'tau2' : tau ,
382+ 'inv_cov' : inv_cov
383+ }
384+ return self
358385
359386 def _ml_nll (self , theta , y , n , X ):
360387 """ ML negative log-likelihood for meta-regression model. """
@@ -431,7 +458,7 @@ def compile(self):
431458 from pystan import StanModel
432459 self .model = StanModel (model_code = spec )
433460
434- def _fit (self , y , v , X , groups = None ):
461+ def fit (self , y , v , X , groups = None ):
435462 """Run the Stan sampler and return results.
436463
437464 Args:
@@ -479,7 +506,7 @@ def _fit(self, y, v, X, groups=None):
479506 }
480507
481508 self .result_ = self .model .sampling (data = data , ** self .sampling_kwargs )
482- return self . result_
509+ return self
483510
484511 def summary (self , ci = 95 ):
485512 if self .result_ is None :
0 commit comments