Skip to content

Commit d6fb040

Browse files
authored
fix: LC2ST kwarg defaults type hints (#1565)
* fix kwarg typing * fix if else logic * fix init bugs
1 parent a1a4077 commit d6fb040

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

sbi/diagnostics/lc2st.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ def __init__(
123123
'or a valid scikit-learn classifier class.'
124124
)
125125
assert issubclass(classifier, BaseEstimator), (
126-
"classifier must be a subclass of sklearn's BaseEstimator"
126+
"classier must either be a string or a subclass of BaseEstimator."
127127
)
128128
self.clf_class = classifier
129129

130-
self.clf_kwargs = classifier_kwargs
131-
if self.clf_kwargs is None:
130+
# for MLPClassifier, set default parameters
131+
if classifier_kwargs is None:
132132
if self.clf_class == MLPClassifier:
133133
ndim = thetas.shape[-1]
134134
self.clf_kwargs = {
@@ -140,7 +140,7 @@ def __init__(
140140
"n_iter_no_change": 50,
141141
}
142142
else:
143-
self.clf_kwargs = {}
143+
self.clf_kwargs: Dict[str, Any] = {}
144144

145145
# initialize classifiers, will be set after training
146146
self.trained_clfs = None
@@ -269,7 +269,7 @@ def train_on_observed_data(
269269
if seed is not None:
270270
if "random_state" in self.clf_kwargs:
271271
print("WARNING: changing the random state of the classifier.")
272-
self.clf_kwargs["random_state"] = seed # type: ignore
272+
self.clf_kwargs["random_state"] = seed
273273

274274
# train the classifier
275275
trained_clfs = self._train(

0 commit comments

Comments
 (0)