Skip to content

Commit 4e222fe

Browse files
committed
Adds fit_params to StackingRegressor fit method + adds test
1 parent d332ad0 commit 4e222fe

File tree

3 files changed

+44
-7
lines changed

3 files changed

+44
-7
lines changed

mlxtend/classifier/stacking_classification.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def fit(self, X, y, **fit_params):
8888
y : array-like, shape = [n_samples] or [n_samples, n_outputs]
8989
Target values.
9090
fit_params : dict, optional
91-
Parameters to pass to the fit methods of the classifiers and
92-
meta_classifier.
91+
Parameters to pass to the fit methods of `classifiers` and
92+
`meta_classifier`.
9393
9494
Returns
9595
-------

mlxtend/regressor/stacking_regression.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(self, regressors, meta_regressor, verbose=0):
6565
_name_estimators([meta_regressor])}
6666
self.verbose = verbose
6767

68-
def fit(self, X, y):
68+
def fit(self, X, y, **fit_params):
6969
"""Learn weight coefficients from training data for each regressor.
7070
7171
Parameters
@@ -75,18 +75,25 @@ def fit(self, X, y):
7575
n_features is the number of features.
7676
y : array-like, shape = [n_samples] or [n_samples, n_targets]
7777
Target values.
78+
fit_params : dict, optional
79+
Parameters to pass to the fit methods of `regressors` and
80+
`meta_regressor`.
7881
7982
Returns
8083
-------
8184
self : object
8285
8386
"""
8487
self.regr_ = [clone(regr) for regr in self.regressors]
88+
self.named_regr_ = {key: value for key, value in
89+
_name_estimators(self.regr_)}
8590
self.meta_regr_ = clone(self.meta_regressor)
91+
self.named_meta_regr_ = {'meta-%s' % key: value for key, value in
92+
_name_estimators([self.meta_regr_])}
8693
if self.verbose > 0:
8794
print("Fitting %d regressors..." % (len(self.regressors)))
8895

89-
for regr in self.regr_:
96+
for name, regr in six.iteritems(self.named_regr_):
9097

9198
if self.verbose > 0:
9299
i = self.regr_.index(regr) + 1
@@ -100,10 +107,23 @@ def fit(self, X, y):
100107
if self.verbose > 1:
101108
print(_name_estimators((regr,))[0][1])
102109

103-
regr.fit(X, y)
110+
# Extract fit_params for regr
111+
regr_fit_params = {}
112+
for key, value in six.iteritems(fit_params):
113+
if name in key and 'meta-' not in key:
114+
regr_fit_params[key.replace(name+'__', '')] = value
115+
116+
regr.fit(X, y, **regr_fit_params)
104117

105118
meta_features = self._predict_meta_features(X)
106-
self.meta_regr_.fit(meta_features, y)
119+
# Extract fit_params for meta_regr_
120+
meta_fit_params = {}
121+
meta_regr_name = list(self.named_meta_regr_.keys())[0]
122+
for key, value in six.iteritems(fit_params):
123+
if meta_regr_name in key and 'meta-' in meta_regr_name:
124+
meta_fit_params[key.replace(meta_regr_name+'__', '')] = value
125+
self.meta_regr_.fit(meta_features, y, **meta_fit_params)
126+
107127
return self
108128

109129
@property

mlxtend/regressor/tests/test_stacking_regression.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import numpy as np
1212
from numpy.testing import assert_almost_equal
1313
from nose.tools import raises
14-
from sklearn.model_selection import GridSearchCV
14+
from sklearn.model_selection import GridSearchCV, cross_val_score
1515

1616
# Generating a sample dataset
1717
np.random.seed(1)
@@ -108,6 +108,23 @@ def test_gridsearch_numerate_regr():
108108
assert best == got
109109

110110

111+
def test_StackingRegressor_fit_params():
112+
lr = LinearRegression()
113+
svr_lin = SVR(kernel='linear')
114+
ridge = Ridge(random_state=1)
115+
svr_rbf = SVR(kernel='rbf')
116+
stregr = StackingRegressor(regressors=[svr_lin, lr, ridge],
117+
meta_regressor=svr_rbf)
118+
119+
fit_params = {'ridge__sample_weight': np.ones(X1.shape[0]),
120+
'svr__sample_weight': np.ones(X1.shape[0]),
121+
'meta-svr__sample_weight': np.ones(X1.shape[0])}
122+
123+
scores = cross_val_score(stregr, X1, y, cv=5, fit_params=fit_params)
124+
scores_mean = (round(scores.mean(), 1))
125+
assert scores_mean == 0.1
126+
127+
111128
def test_get_coeff():
112129
lr = LinearRegression()
113130
svr_lin = SVR(kernel='linear')

0 commit comments

Comments (0)