Skip to content

Commit d332ad0

Browse files
committed
Adds fit_param test for StackingClassifier
1 parent 8b6332f commit d332ad0

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

mlxtend/classifier/stacking_classification.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,17 @@ def fit(self, X, y, **fit_params):
130130

131131
meta_features = self._predict_meta_features(X)
132132
# Extract fit_params for meta_clf_
133-
meta_clf_fit_params = {}
133+
meta_fit_params = {}
134134
meta_clf_name = list(self.named_meta_clf_.keys())[0]
135135
for key, value in six.iteritems(fit_params):
136136
if meta_clf_name in key and 'meta-' in meta_clf_name:
137-
meta_clf_fit_params[key.replace(meta_clf_name+'__', '')] = value
137+
meta_fit_params[key.replace(meta_clf_name+'__', '')] = value
138138

139139
if not self.use_features_in_secondary:
140-
self.meta_clf_.fit(meta_features, y, **meta_clf_fit_params)
140+
self.meta_clf_.fit(meta_features, y, **meta_fit_params)
141141
else:
142142
self.meta_clf_.fit(np.hstack((X, meta_features)), y,
143-
**meta_clf_fit_params)
143+
**meta_fit_params)
144144

145145
return self
146146

mlxtend/classifier/tests/test_stacking_classifier.py

+23
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,29 @@ def test_StackingClassifier_proba_concat_1():
7979
assert scores_mean == 0.93, scores_mean
8080

8181

82+
def test_StackingClassifier_fit_params():
83+
np.random.seed(123)
84+
meta = LogisticRegression()
85+
clf1 = RandomForestClassifier()
86+
clf2 = GaussianNB()
87+
sclf = StackingClassifier(classifiers=[clf1, clf2],
88+
meta_classifier=meta)
89+
n_samples = X.shape[0]
90+
fit_params = {
91+
'randomforestclassifier__sample_weight': np.ones(n_samples),
92+
'meta-logisticregression__sample_weight': np.arange(n_samples)
93+
}
94+
95+
scores = cross_val_score(sclf,
96+
X,
97+
y,
98+
cv=5,
99+
scoring='accuracy',
100+
fit_params=fit_params)
101+
scores_mean = (round(scores.mean(), 2))
102+
assert scores_mean == 0.95
103+
104+
82105
def test_StackingClassifier_avg_vs_concat():
83106
np.random.seed(123)
84107
lr1 = LogisticRegression()

0 commit comments

Comments
 (0)