File tree Expand file tree Collapse file tree 1 file changed +5
-5
lines changed
Expand file tree Collapse file tree 1 file changed +5
-5
lines changed Original file line number Diff line number Diff line change @@ -123,12 +123,12 @@ def __init__(
123123 'or a valid scikit-learn classifier class.'
124124 )
125125 assert issubclass (classifier , BaseEstimator ), (
126- "classifier must be a subclass of sklearn's BaseEstimator"
126+ "classier must either be a string or a subclass of BaseEstimator. "
127127 )
128128 self .clf_class = classifier
129129
130- self . clf_kwargs = classifier_kwargs
131- if self . clf_kwargs is None :
130+ # for MLPClassifier, set default parameters
131+ if classifier_kwargs is None :
132132 if self .clf_class == MLPClassifier :
133133 ndim = thetas .shape [- 1 ]
134134 self .clf_kwargs = {
@@ -140,7 +140,7 @@ def __init__(
140140 "n_iter_no_change" : 50 ,
141141 }
142142 else :
143- self .clf_kwargs = {}
143+ self .clf_kwargs : Dict [ str , Any ] = {}
144144
145145 # initialize classifiers, will be set after training
146146 self .trained_clfs = None
@@ -269,7 +269,7 @@ def train_on_observed_data(
269269 if seed is not None :
270270 if "random_state" in self .clf_kwargs :
271271 print ("WARNING: changing the random state of the classifier." )
272- self .clf_kwargs ["random_state" ] = seed # type: ignore
272+ self .clf_kwargs ["random_state" ] = seed
273273
274274 # train the classifier
275275 trained_clfs = self ._train (
You can’t perform that action at this time.
0 commit comments