Skip to content

Commit c27576b

Browse files
AnneBeyerlucyleeow
andauthored
FIX n_classes in DecisionBoundaryDisplay with custom estimators (scikit-learn#33202)
Co-authored-by: Lucy Liu <jliu176@gmail.com>
1 parent 9c36752 commit c27576b

4 files changed

Lines changed: 136 additions & 16 deletions

File tree

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- In :class:`inspection.DecisionBoundaryDisplay`, `n_classes` is now inferred more
2+
robustly from the estimator. If it fails for custom estimators, a comprehensive error
3+
message is shown.
4+
By :user:`Anne Beyer <AnneBeyer>`.

examples/cluster/plot_inductive_clustering.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
import matplotlib.pyplot as plt
2727

28-
from sklearn.base import BaseEstimator, clone
28+
from sklearn.base import BaseEstimator, ClusterMixin, clone
2929
from sklearn.cluster import AgglomerativeClustering
3030
from sklearn.datasets import make_blobs
3131
from sklearn.ensemble import RandomForestClassifier
@@ -50,7 +50,7 @@ def _classifier_has(attr):
5050
)
5151

5252

53-
class InductiveClusterer(BaseEstimator):
53+
class InductiveClusterer(ClusterMixin, BaseEstimator):
5454
def __init__(self, clusterer, classifier):
5555
self.clusterer = clusterer
5656
self.classifier = classifier
@@ -60,6 +60,7 @@ def fit(self, X, y=None):
6060
self.classifier_ = clone(self.classifier)
6161
y = self.clusterer_.fit_predict(X)
6262
self.classifier_.fit(X, y)
63+
self.labels_ = y
6364
return self
6465

6566
@available_if(_classifier_has("predict"))
@@ -122,7 +123,12 @@ def plot_scatter(X, color, alpha=0.5):
122123

123124
# Plotting decision regions
124125
DecisionBoundaryDisplay.from_estimator(
125-
inductive_learner, X, response_method="predict", alpha=0.4, ax=ax
126+
inductive_learner,
127+
X,
128+
response_method="predict",
129+
multiclass_colors="viridis",
130+
alpha=0.4,
131+
ax=ax,
126132
)
127133
plt.title("Classify unknown instances")
128134

sklearn/inspection/_plot/decision_boundary.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from sklearn.utils._optional_dependencies import check_matplotlib_support
1313
from sklearn.utils._response import _get_response_values
1414
from sklearn.utils._set_output import _get_adapter_from_container
15+
from sklearn.utils.multiclass import type_of_target
1516
from sklearn.utils.validation import (
1617
_is_arraylike_not_scalar,
1718
_num_features,
@@ -576,6 +577,31 @@ def from_estimator(
576577
encoder.classes_ = estimator.classes_
577578
response = encoder.transform(response)
578579

580+
# infer n_classes from the estimator
581+
if (
582+
class_of_interest is not None
583+
or is_regressor(estimator)
584+
or is_outlier_detector(estimator)
585+
):
586+
n_classes = 2
587+
elif is_classifier(estimator) and hasattr(estimator, "classes_"):
588+
n_classes = len(estimator.classes_)
589+
elif is_clusterer(estimator) and hasattr(estimator, "labels_"):
590+
n_classes = len(np.unique(estimator.labels_))
591+
else:
592+
target_type = type_of_target(response)
593+
if target_type in ("binary", "continuous"):
594+
n_classes = 2
595+
elif target_type == "multiclass":
596+
n_classes = len(np.unique(response))
597+
else:
598+
raise ValueError(
599+
"Number of classes or labels cannot be inferred from "
600+
f"{estimator.__class__.__name__}. Please make sure your estimator "
601+
"follows scikit-learn's estimator API as described here: "
602+
"https://scikit-learn.org/stable/developers/develop.html#rolling-your-own-estimator"
603+
)
604+
579605
if response.ndim == 1:
580606
response = response.reshape(*xx0.shape)
581607
else:
@@ -591,17 +617,6 @@ def from_estimator(
591617
else:
592618
response = response.reshape(*xx0.shape, response.shape[-1])
593619

594-
if (
595-
class_of_interest is not None
596-
or is_regressor(estimator)
597-
or is_outlier_detector(estimator)
598-
):
599-
n_classes = 2
600-
elif is_classifier(estimator):
601-
n_classes = len(estimator.classes_)
602-
elif is_clusterer(estimator):
603-
n_classes = len(np.unique(estimator.labels_))
604-
605620
if xlabel is None:
606621
xlabel = X.columns[0] if hasattr(X, "columns") else ""
607622

sklearn/inspection/_plot/tests/test_boundary_decision_display.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from sklearn.inspection import DecisionBoundaryDisplay
2020
from sklearn.inspection._plot.decision_boundary import _check_boundary_response_method
2121
from sklearn.linear_model import LogisticRegression
22-
from sklearn.preprocessing import scale
22+
from sklearn.pipeline import Pipeline
23+
from sklearn.preprocessing import StandardScaler, scale
2324
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
2425
from sklearn.utils._testing import (
2526
_convert_container,
@@ -703,6 +704,34 @@ def test_multiclass_colors_cmap(
703704
# assert len(disp.surface_.levels) >= disp.n_classes
704705

705706

707+
# estimator classes for non-regression test cases for issue #33194
708+
class CustomBinaryEstimator(BaseEstimator):
709+
def fit(self, X, y):
710+
self.fitted_ = True
711+
return self
712+
713+
def predict(self, X):
714+
return np.arange(X.shape[0]) % 2
715+
716+
717+
class CustomMulticlassEstimator(BaseEstimator):
718+
def fit(self, X, y):
719+
self.fitted_ = True
720+
return self
721+
722+
def predict(self, X):
723+
return np.arange(X.shape[0]) % 7
724+
725+
726+
class CustomContinuousEstimator(BaseEstimator):
727+
def fit(self, X, y):
728+
self.fitted_ = True
729+
return self
730+
731+
def predict(self, X):
732+
return np.arange(X.shape[0]) * 0.5
733+
734+
706735
@pytest.mark.parametrize(
707736
"estimator, n_blobs, expected_n_classes",
708737
[
@@ -712,12 +741,57 @@ def test_multiclass_colors_cmap(
712741
(KMeans(n_clusters=2, random_state=0), 2, 2),
713742
(DecisionTreeRegressor(random_state=0), 7, 2),
714743
(IsolationForest(random_state=0), 7, 2),
744+
(CustomBinaryEstimator(), 2, 2),
745+
(CustomMulticlassEstimator(), 7, 7),
746+
(CustomContinuousEstimator(), 7, 2),
747+
(
748+
Pipeline(
749+
[
750+
("scale", StandardScaler()),
751+
("dt", DecisionTreeClassifier(random_state=0)),
752+
]
753+
),
754+
7,
755+
7,
756+
),
757+
# non-regression test case for issue #33194
758+
(
759+
Pipeline(
760+
[
761+
("scale", StandardScaler()),
762+
("kmeans", KMeans(n_clusters=7, random_state=0)),
763+
]
764+
),
765+
7,
766+
7,
767+
),
768+
(
769+
Pipeline(
770+
[
771+
("scale", StandardScaler()),
772+
("reg", DecisionTreeRegressor(random_state=0)),
773+
]
774+
),
775+
7,
776+
2,
777+
),
778+
(
779+
Pipeline(
780+
[
781+
("scale", StandardScaler()),
782+
("kmeans", IsolationForest(random_state=0)),
783+
]
784+
),
785+
7,
786+
2,
787+
),
715788
],
716789
)
717790
def test_n_classes_attribute(pyplot, estimator, n_blobs, expected_n_classes):
718791
"""Check that `n_classes` is set correctly.
719792
720-
Introduced in https://github.com/scikit-learn/scikit-learn/pull/33015"""
793+
Introduced in https://github.com/scikit-learn/scikit-learn/pull/33015.
794+
"""
721795

722796
X, y = make_blobs(n_samples=150, centers=n_blobs, n_features=2, random_state=42)
723797
clf = estimator.fit(X, y)
@@ -731,6 +805,27 @@ def test_n_classes_attribute(pyplot, estimator, n_blobs, expected_n_classes):
731805
assert disp_coi.n_classes == 2
732806

733807

808+
def test_n_classes_raises_if_not_inferrable(pyplot):
809+
"""Check behaviour if `n_classes` can't be inferred.
810+
811+
Non-regression test for issue #33194.
812+
"""
813+
814+
class CustomUnknownEstimator(BaseEstimator):
815+
def fit(self, X, y):
816+
self.fitted_ = True
817+
return self
818+
819+
def predict(self, X):
820+
return np.array(0)
821+
822+
X, y = load_iris_2d_scaled()
823+
est = CustomUnknownEstimator().fit(X, y)
824+
msg = "Number of classes or labels cannot be inferred from CustomUnknownEstimator"
825+
with pytest.raises(ValueError, match=msg):
826+
DecisionBoundaryDisplay.from_estimator(est, X, response_method="predict")
827+
828+
734829
def test_cmap_and_colors_logic(pyplot):
735830
"""Check the handling logic for `cmap` and `colors`."""
736831
X, y = load_iris_2d_scaled()

0 commit comments

Comments
 (0)