Skip to content

Commit 8abe542

Browse files
committed
Adds fit_params to StackingCVClassifier fit method
1 parent 4e222fe commit 8abe542

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

mlxtend/classifier/stacking_cv_classification.py

+33-9
Original file line numberDiff line numberDiff line change
@@ -111,29 +111,36 @@ def __init__(self, classifiers, meta_classifier,
111111
self.stratify = stratify
112112
self.shuffle = shuffle
113113

114-
def fit(self, X, y, groups=None):
114+
def fit(self, X, y, groups=None, **fit_params):
115115
""" Fit ensemble classifiers and the meta-classifier.
116116
117117
Parameters
118118
----------
119119
X : numpy array, shape = [n_samples, n_features]
120120
Training vectors, where n_samples is the number of samples and
121121
n_features is the number of features.
122-
123122
y : numpy array, shape = [n_samples]
124123
Target values.
125-
126124
groups : numpy array/None, shape = [n_samples]
127125
The group that each sample belongs to. This is used by specific
128126
folding strategies such as GroupKFold()
127+
fit_params : dict, optional
128+
Parameters to pass to the fit methods of `classifiers` and
129+
`meta_classifier`. Note that only fit parameters for `classifiers`
130+
that are the same for each cross-validation split are supported
131+
(e.g. `sample_weight` is not currently supported).
129132
130133
Returns
131134
-------
132135
self : object
133136
134137
"""
135138
self.clfs_ = [clone(clf) for clf in self.classifiers]
139+
self.named_clfs_ = {key: value for key, value in
140+
_name_estimators(self.clfs_)}
136141
self.meta_clf_ = clone(self.meta_classifier)
142+
self.named_meta_clf_ = {'meta-%s' % key: value for key, value in
143+
_name_estimators([self.meta_clf_])}
137144
if self.verbose > 0:
138145
print("Fitting %d classifiers..." % (len(self.classifiers)))
139146

@@ -144,8 +151,23 @@ def fit(self, X, y, groups=None):
144151
final_cv.shuffle = self.shuffle
145152
skf = list(final_cv.split(X, y, groups))
146153

154+
# Get fit_params for each classifier in self.named_clfs_
155+
named_clfs_fit_params = {}
156+
for name, clf in six.iteritems(self.named_clfs_):
157+
clf_fit_params = {}
158+
for key, value in six.iteritems(fit_params):
159+
if name in key and 'meta-' not in key:
160+
clf_fit_params[key.replace(name+'__', '')] = value
161+
named_clfs_fit_params[name] = clf_fit_params
162+
# Get fit_params for self.named_meta_clf_
163+
meta_fit_params = {}
164+
meta_clf_name = list(self.named_meta_clf_.keys())[0]
165+
for key, value in six.iteritems(fit_params):
166+
if meta_clf_name in key and 'meta-' in meta_clf_name:
167+
meta_fit_params[key.replace(meta_clf_name+'__', '')] = value
168+
147169
all_model_predictions = np.array([]).reshape(len(y), 0)
148-
for model in self.clfs_:
170+
for name, model in six.iteritems(self.named_clfs_):
149171

150172
if self.verbose > 0:
151173
i = self.clfs_.index(model) + 1
@@ -172,7 +194,8 @@ def fit(self, X, y, groups=None):
172194
((num + 1), final_cv.get_n_splits()))
173195

174196
try:
175-
model.fit(X[train_index], y[train_index])
197+
model.fit(X[train_index], y[train_index],
198+
**named_clfs_fit_params[name])
176199
except TypeError as e:
177200
raise TypeError(str(e) + '\nPlease check that X and y'
178201
'are NumPy arrays. If X and y are lists'
@@ -215,16 +238,17 @@ def fit(self, X, y, groups=None):
215238
X[test_index]))
216239

217240
# Fit the base models correctly this time using ALL the training set
218-
for model in self.clfs_:
219-
model.fit(X, y)
241+
for name, model in six.iteritems(self.named_clfs_):
242+
model.fit(X, y, **named_clfs_fit_params[name])
220243

221244
# Fit the secondary model
222245
if not self.use_features_in_secondary:
223-
self.meta_clf_.fit(all_model_predictions, reordered_labels)
246+
self.meta_clf_.fit(all_model_predictions, reordered_labels,
247+
**meta_fit_params)
224248
else:
225249
self.meta_clf_.fit(np.hstack((reordered_features,
226250
all_model_predictions)),
227-
reordered_labels)
251+
reordered_labels, **meta_fit_params)
228252

229253
return self
230254

mlxtend/classifier/tests/test_stacking_cv_classifier.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from mlxtend.classifier import StackingCVClassifier
99

1010
import pandas as pd
11-
from sklearn.linear_model import LogisticRegression
11+
from sklearn.linear_model import LogisticRegression, SGDClassifier
1212
from sklearn.naive_bayes import GaussianNB
1313
from sklearn.ensemble import RandomForestClassifier
1414
from sklearn.neighbors import KNeighborsClassifier
@@ -61,6 +61,29 @@ def test_StackingClassifier_proba():
6161
assert scores_mean == 0.93
6262

6363

64+
def test_StackingClassifier_fit_params():
    """Check that fit_params are routed to named base and meta estimators.

    Keys are prefixed with the estimator's generated name
    (e.g. 'sgdclassifier__...' for a base estimator and
    'meta-logisticregression__...' for the meta-classifier).
    """
    np.random.seed(123)
    meta_clf = LogisticRegression()
    base_rf = RandomForestClassifier()
    base_sgd = SGDClassifier(random_state=2)
    stack = StackingCVClassifier(classifiers=[base_rf, base_sgd],
                                 meta_classifier=meta_clf,
                                 shuffle=False)
    # One fit-parameter for a base estimator, one for the meta-classifier.
    params = {
        'sgdclassifier__intercept_init': np.unique(y),
        'meta-logisticregression__sample_weight': np.full(X.shape[0], 2)
    }

    cv_scores = cross_val_score(stack,
                                X,
                                y,
                                cv=5,
                                scoring='accuracy',
                                fit_params=params)
    assert round(cv_scores.mean(), 2) == 0.86
85+
86+
6487
def test_gridsearch():
6588
np.random.seed(123)
6689
meta = LogisticRegression()

0 commit comments

Comments
 (0)