@@ -111,29 +111,36 @@ def __init__(self, classifiers, meta_classifier,
111
111
self .stratify = stratify
112
112
self .shuffle = shuffle
113
113
114
- def fit (self , X , y , groups = None ):
114
+ def fit (self , X , y , groups = None , ** fit_params ):
115
115
""" Fit ensemble classifers and the meta-classifier.
116
116
117
117
Parameters
118
118
----------
119
119
X : numpy array, shape = [n_samples, n_features]
120
120
Training vectors, where n_samples is the number of samples and
121
121
n_features is the number of features.
122
-
123
122
y : numpy array, shape = [n_samples]
124
123
Target values.
125
-
126
124
groups : numpy array/None, shape = [n_samples]
127
125
The group that each sample belongs to. This is used by specific
128
126
folding strategies such as GroupKFold()
127
+ fit_params : dict, optional
128
+ Parameters to pass to the fit methods of `classifiers` and
129
+ `meta_classifier`. Note that only fit parameters for `classifiers`
130
+ that are the same for each cross-validation split are supported
131
+ (e.g. `sample_weight` is not currently supported).
129
132
130
133
Returns
131
134
-------
132
135
self : object
133
136
134
137
"""
135
138
self .clfs_ = [clone (clf ) for clf in self .classifiers ]
139
+ self .named_clfs_ = {key : value for key , value in
140
+ _name_estimators (self .clfs_ )}
136
141
self .meta_clf_ = clone (self .meta_classifier )
142
+ self .named_meta_clf_ = {'meta-%s' % key : value for key , value in
143
+ _name_estimators ([self .meta_clf_ ])}
137
144
if self .verbose > 0 :
138
145
print ("Fitting %d classifiers..." % (len (self .classifiers )))
139
146
@@ -144,8 +151,23 @@ def fit(self, X, y, groups=None):
144
151
final_cv .shuffle = self .shuffle
145
152
skf = list (final_cv .split (X , y , groups ))
146
153
154
+ # Get fit_params for each classifier in self.named_clfs_
155
+ named_clfs_fit_params = {}
156
+ for name , clf in six .iteritems (self .named_clfs_ ):
157
+ clf_fit_params = {}
158
+ for key , value in six .iteritems (fit_params ):
159
+ if name in key and 'meta-' not in key :
160
+ clf_fit_params [key .replace (name + '__' , '' )] = value
161
+ named_clfs_fit_params [name ] = clf_fit_params
162
+ # Get fit_params for self.named_meta_clf_
163
+ meta_fit_params = {}
164
+ meta_clf_name = list (self .named_meta_clf_ .keys ())[0 ]
165
+ for key , value in six .iteritems (fit_params ):
166
+ if meta_clf_name in key and 'meta-' in meta_clf_name :
167
+ meta_fit_params [key .replace (meta_clf_name + '__' , '' )] = value
168
+
147
169
all_model_predictions = np .array ([]).reshape (len (y ), 0 )
148
- for model in self .clfs_ :
170
+ for name , model in six . iteritems ( self .named_clfs_ ) :
149
171
150
172
if self .verbose > 0 :
151
173
i = self .clfs_ .index (model ) + 1
@@ -172,7 +194,8 @@ def fit(self, X, y, groups=None):
172
194
((num + 1 ), final_cv .get_n_splits ()))
173
195
174
196
try :
175
- model .fit (X [train_index ], y [train_index ])
197
+ model .fit (X [train_index ], y [train_index ],
198
+ ** named_clfs_fit_params [name ])
176
199
except TypeError as e :
177
200
raise TypeError (str (e ) + '\n Please check that X and y'
178
201
'are NumPy arrays. If X and y are lists'
@@ -215,16 +238,17 @@ def fit(self, X, y, groups=None):
215
238
X [test_index ]))
216
239
217
240
# Fit the base models correctly this time using ALL the training set
218
- for model in self .clfs_ :
219
- model .fit (X , y )
241
+ for name , model in six . iteritems ( self .named_clfs_ ) :
242
+ model .fit (X , y , ** named_clfs_fit_params [ name ] )
220
243
221
244
# Fit the secondary model
222
245
if not self .use_features_in_secondary :
223
- self .meta_clf_ .fit (all_model_predictions , reordered_labels )
246
+ self .meta_clf_ .fit (all_model_predictions , reordered_labels ,
247
+ ** meta_fit_params )
224
248
else :
225
249
self .meta_clf_ .fit (np .hstack ((reordered_features ,
226
250
all_model_predictions )),
227
- reordered_labels )
251
+ reordered_labels , ** meta_fit_params )
228
252
229
253
return self
230
254
0 commit comments