Skip to content

Commit 16a55a5

Browse files
committed
tentative fix for #649
1 parent 2bf1e02 commit 16a55a5

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

CHANGES.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,4 +199,7 @@ v<2.0.2>, <07/06/2024> -- Complete of removing all Tensorflow and Keras code.
199199
v<2.0.2>, <07/21/2024> -- Add DevNet.
200200
v<2.0.3>, <09/06/2024> -- Add Reject Option in Unsupervised Anomaly Detection (#605).
201201
v<2.0.3>, <12/20/2024> -- Massive documentation polish.
202-
v<2.0.5>, <09/04/2025> -- Finally, add the auto model selector (#616).
202+
v<2.0.4>, <04/29/2025> -- Mistakenly we skipped 2.0.4.
203+
v<2.0.5>, <04/29/2025> -- Add wheel for better installation.
204+
v<2.0.6>, <09/04/2025> -- Finally, add the auto model selector (#616).
205+
v<2.0.6>, <12/01/2025> -- Pre-caution for new sklearn break change(#649).

pyod/models/base.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from numpy import percentile
1515
from scipy.special import erf
1616
from scipy.stats import binom
17+
from sklearn.base import BaseEstimator
1718
from sklearn.metrics import roc_auc_score
1819
from sklearn.preprocessing import MinMaxScaler
1920
from sklearn.utils import deprecated
@@ -25,7 +26,7 @@
2526
from ..utils.utility import precision_n_scores
2627

2728

28-
class BaseDetector(metaclass=abc.ABCMeta):
29+
class BaseDetector(BaseEstimator, metaclass=abc.ABCMeta):
2930
"""Abstract class for all outlier detection algorithms.
3031
3132
@@ -711,3 +712,16 @@ def __repr__(self):
711712
class_name = self.__class__.__name__
712713
return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
713714
offset=len(class_name), ),)
715+
716+
717+
def __sklearn_tags__(self):
718+
"""Return sklearn-style Tags for compatibility with scikit-learn >= 1.8.
719+
720+
We mark all PyOD detectors as 'outlier_detector' so that utilities
721+
such as sklearn.utils._tags.get_tags and is_outlier_detector work.
722+
"""
723+
tags = super().__sklearn_tags__()
724+
# match sklearn's OutlierMixin
725+
tags.estimator_type = "outlier_detector"
726+
return tags
727+

pyod/test/test_base.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@
1515
from pyod.models.base import BaseDetector
1616
from pyod.utils.data import generate_data
1717

18+
from sklearn.utils.validation import check_is_fitted
19+
20+
try:
21+
# Present in newer sklearn versions (1.3+, 1.4+, 1.8rc1, etc.)
22+
from sklearn.utils._tags import get_tags
23+
except ImportError:
24+
get_tags = None
25+
1826

1927
# Check sklearn\tests\test_base
2028
# A few test classes
@@ -207,6 +215,24 @@ def test_get_params(self):
207215
def tearDown(self):
208216
pass
209217

218+
@unittest.skipIf(get_tags is None, "sklearn is too old and does not provide get_tags")
219+
def test_sklearn_tags_and_check_is_fitted(self):
220+
# Dummy1 is a BaseDetector subclass
221+
clf = Dummy1()
222+
223+
# 1. get_tags should work and identify it as outlier_detector
224+
tags = get_tags(clf)
225+
assert_equal(tags.estimator_type, "outlier_detector")
226+
227+
# 2. Simulate a fitted model
228+
clf.labels_ = np.zeros(10)
229+
clf.decision_scores_ = np.zeros(10)
230+
clf.threshold_ = 0.5
231+
232+
# This should NOT raise AttributeError on sklearn >= 1.8
233+
check_is_fitted(clf, ['decision_scores_', 'threshold_', 'labels_'])
234+
235+
210236

211237
if __name__ == '__main__':
212238
unittest.main()

0 commit comments

Comments
 (0)