Skip to content

Commit e0369a7

Browse files
committed
add __sklearn_tags__
1 parent bd005ec commit e0369a7

File tree

6 files changed

+389
-1
lines changed

6 files changed

+389
-1
lines changed

python/interpret-core/interpret/glassbox/_aplr.py

+65
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# Distributed under the MIT software license
33
from typing import Dict, List, Optional, Tuple
44
from warnings import warn
5+
from dataclasses import dataclass, field
6+
from typing import Optional
57

68
import numpy as np
79
import pandas as pd
@@ -23,6 +25,57 @@
2325
IntMatrix = np.ndarray
2426

2527

28+
@dataclass
29+
class APLRInputTags:
30+
one_d_array: bool = False
31+
two_d_array: bool = True
32+
three_d_array: bool = False
33+
sparse: bool = False
34+
categorical: bool = False
35+
string: bool = False
36+
dict: bool = False
37+
positive_only: bool = False
38+
allow_nan: bool = False
39+
pairwise: bool = False
40+
41+
42+
@dataclass
43+
class APLRTargetTags:
44+
required: bool = True
45+
one_d_labels: bool = True
46+
two_d_labels: bool = False
47+
positive_only: bool = False
48+
multi_output: bool = False
49+
single_output: bool = True
50+
51+
52+
@dataclass
53+
class APLRClassifierTags:
54+
poor_score: bool = False
55+
multi_class: bool = True
56+
multi_label: bool = False
57+
58+
59+
@dataclass
60+
class APLRRegressorTags:
61+
poor_score: bool = False
62+
63+
64+
@dataclass
65+
class APLRTags:
66+
estimator_type: Optional[str] = None
67+
target_tags: APLRTargetTags = field(default_factory=APLRTargetTags)
68+
transformer_tags: None = None
69+
classifier_tags: Optional[APLRClassifierTags] = None
70+
regressor_tags: Optional[APLRRegressorTags] = None
71+
array_api_support: bool = False
72+
no_validation: bool = False
73+
non_deterministic: bool = False
74+
requires_fit: bool = True
75+
_skip_test: bool = False
76+
input_tags: APLRInputTags = field(default_factory=APLRInputTags)
77+
78+
2679
class APLRRegressor(RegressorMixin, ExplainerMixin):
2780
available_explanations = ["local", "global"]
2881
explainer_type = "model"
@@ -409,6 +462,12 @@ def explain_local(
409462
selector=selector,
410463
)
411464

465+
def __sklearn_tags__(self):
466+
tags = APLRTags()
467+
tags.estimator_type = "regressor"
468+
tags.regressor_tags = APLRRegressorTags()
469+
return tags
470+
412471

413472
def calculate_densities(X: FloatMatrix) -> Tuple[List[List[int]], List[List[float]]]:
414473
bin_counts: List[List[int]] = []
@@ -816,6 +875,12 @@ def explain_local(
816875
selector=selector,
817876
)
818877

878+
def __sklearn_tags__(self):
879+
tags = APLRTags()
880+
tags.estimator_type = "classifier"
881+
tags.classifier_tags = APLRClassifierTags()
882+
return tags
883+
819884

820885
class APLRExplanation(FeatureValueExplanation):
821886
"""Visualizes specifically for APLR."""

python/interpret-core/interpret/glassbox/_decisiontree.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import logging
55
from abc import abstractmethod
66
from copy import deepcopy
7+
from dataclasses import dataclass, field
8+
from typing import Optional
79

810
import numpy as np
911
from sklearn.base import ClassifierMixin, RegressorMixin, is_classifier
@@ -219,6 +221,57 @@ def _weight_nodes_feature(self, nodes, feature_name):
219221
return new_nodes
220222

221223

224+
@dataclass
225+
class TreeInputTags:
226+
one_d_array: bool = False
227+
two_d_array: bool = True
228+
three_d_array: bool = False
229+
sparse: bool = True
230+
categorical: bool = False
231+
string: bool = True
232+
dict: bool = True
233+
positive_only: bool = False
234+
allow_nan: bool = True
235+
pairwise: bool = False
236+
237+
238+
@dataclass
239+
class TreeTargetTags:
240+
required: bool = True
241+
one_d_labels: bool = True
242+
two_d_labels: bool = False
243+
positive_only: bool = False
244+
multi_output: bool = False
245+
single_output: bool = True
246+
247+
248+
@dataclass
249+
class TreeClassifierTags:
250+
poor_score: bool = False
251+
multi_class: bool = True
252+
multi_label: bool = False
253+
254+
255+
@dataclass
256+
class TreeRegressorTags:
257+
poor_score: bool = False
258+
259+
260+
@dataclass
261+
class TreeTags:
262+
estimator_type: Optional[str] = None
263+
target_tags: TreeTargetTags = field(default_factory=TreeTargetTags)
264+
transformer_tags: None = None
265+
classifier_tags: Optional[TreeClassifierTags] = None
266+
regressor_tags: Optional[TreeRegressorTags] = None
267+
array_api_support: bool = True
268+
no_validation: bool = False
269+
non_deterministic: bool = False
270+
requires_fit: bool = True
271+
_skip_test: bool = False
272+
input_tags: TreeInputTags = field(default_factory=TreeInputTags)
273+
274+
222275
class BaseShallowDecisionTree:
223276
"""Shallow Decision Tree (low depth).
224277
@@ -280,7 +333,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
280333
X, n_samples = preclean_X(X, self.feature_names, self.feature_types, len(y))
281334

282335
X, self.feature_names_in_, self.feature_types_in_ = unify_data(
283-
X, n_samples, self.feature_names, self.feature_types, False, 0
336+
X, n_samples, self.feature_names, self.feature_types, True, 0
284337
)
285338

286339
model = self._model()
@@ -540,6 +593,9 @@ def recur(i, depth=0):
540593
recur(0)
541594
return nodes, edges
542595

596+
def __sklearn_tags__(self):
597+
return TreeTags()
598+
543599

544600
class RegressionTree(BaseShallowDecisionTree, RegressorMixin, ExplainerMixin):
545601
"""Regression tree with shallow depth."""
@@ -583,6 +639,12 @@ def fit(self, X, y, sample_weight=None, check_input=True):
583639
check_input=check_input,
584640
)
585641

642+
def __sklearn_tags__(self):
643+
tags = super().__sklearn_tags__()
644+
tags.estimator_type = "regressor"
645+
tags.regressor_tags = TreeRegressorTags()
646+
return tags
647+
586648

587649
class ClassificationTree(BaseShallowDecisionTree, ClassifierMixin, ExplainerMixin):
588650
"""Classification tree with shallow depth."""
@@ -644,3 +706,9 @@ def predict_proba(self, X):
644706
)
645707

646708
return self._model().predict_proba(X)
709+
710+
def __sklearn_tags__(self):
711+
tags = super().__sklearn_tags__()
712+
tags.estimator_type = "classifier"
713+
tags.classifier_tags = TreeClassifierTags()
714+
return tags

python/interpret-core/interpret/glassbox/_ebm/_ebm.py

+79
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from math import ceil, isnan
1212
from typing import Dict, List, Mapping, Optional, Sequence, Tuple, Union
1313
from warnings import warn
14+
from dataclasses import dataclass, field
1415

1516
import numpy as np
1617
from sklearn.base import (
@@ -276,6 +277,57 @@ def _clean_exclude(exclude, feature_map):
276277
return ret
277278

278279

280+
@dataclass
281+
class EbmInputTags:
282+
one_d_array: bool = False
283+
two_d_array: bool = True
284+
three_d_array: bool = False
285+
sparse: bool = True
286+
categorical: bool = True
287+
string: bool = True
288+
dict: bool = True
289+
positive_only: bool = False
290+
allow_nan: bool = True
291+
pairwise: bool = False
292+
293+
294+
@dataclass
295+
class EbmTargetTags:
296+
required: bool = True
297+
one_d_labels: bool = True
298+
two_d_labels: bool = False
299+
positive_only: bool = False
300+
multi_output: bool = False
301+
single_output: bool = True
302+
303+
304+
@dataclass
305+
class EbmClassifierTags:
306+
poor_score: bool = False
307+
multi_class: bool = True
308+
multi_label: bool = False
309+
310+
311+
@dataclass
312+
class EbmRegressorTags:
313+
poor_score: bool = False
314+
315+
316+
@dataclass
317+
class EbmTags:
318+
estimator_type: Optional[str] = None
319+
target_tags: EbmTargetTags = field(default_factory=EbmTargetTags)
320+
transformer_tags: None = None
321+
classifier_tags: Optional[EbmClassifierTags] = None
322+
regressor_tags: Optional[EbmRegressorTags] = None
323+
array_api_support: bool = True
324+
no_validation: bool = False
325+
non_deterministic: bool = False
326+
requires_fit: bool = True
327+
_skip_test: bool = False
328+
input_tags: EbmInputTags = field(default_factory=EbmInputTags)
329+
330+
279331
class EBMModel(BaseEstimator):
280332
"""Base class for all EBMs."""
281333

@@ -2627,6 +2679,9 @@ def _more_tags(self):
26272679
],
26282680
}
26292681

2682+
def __sklearn_tags__(self):
2683+
return EbmTags()
2684+
26302685

26312686
class ExplainableBoostingClassifier(EBMModel, ClassifierMixin, ExplainerMixin):
26322687
r"""An Explainable Boosting Classifier.
@@ -2977,6 +3032,12 @@ def predict(self, X, init_score=None):
29773032
# multiclass
29783033
return self.classes_[np.argmax(scores, axis=1)]
29793034

3035+
def __sklearn_tags__(self):
3036+
tags = super().__sklearn_tags__()
3037+
tags.estimator_type = "classifier"
3038+
tags.classifier_tags = EbmClassifierTags()
3039+
return tags
3040+
29803041

29813042
class ExplainableBoostingRegressor(EBMModel, RegressorMixin, ExplainerMixin):
29823043
r"""An Explainable Boosting Regressor.
@@ -3293,6 +3354,12 @@ def predict(self, X, init_score=None):
32933354
scores = self._predict_score(X, init_score)
32943355
return inv_link(scores, self.link_, self.link_param_)
32953356

3357+
def __sklearn_tags__(self):
3358+
tags = super().__sklearn_tags__()
3359+
tags.estimator_type = "regressor"
3360+
tags.regressor_tags = EbmRegressorTags()
3361+
return tags
3362+
32963363

32973364
class DPExplainableBoostingClassifier(EBMModel, ClassifierMixin, ExplainerMixin):
32983365
r"""Differentially Private Explainable Boosting Classifier.
@@ -3554,6 +3621,12 @@ def predict(self, X, init_score=None):
35543621
# multiclass
35553622
return self.classes_[np.argmax(scores, axis=1)]
35563623

3624+
def __sklearn_tags__(self):
3625+
tags = super().__sklearn_tags__()
3626+
tags.estimator_type = "classifier"
3627+
tags.classifier_tags = EbmClassifierTags()
3628+
return tags
3629+
35573630

35583631
class DPExplainableBoostingRegressor(EBMModel, RegressorMixin, ExplainerMixin):
35593632
r"""Differentially Private Explainable Boosting Regressor.
@@ -3791,3 +3864,9 @@ def predict(self, X, init_score=None):
37913864
"""
37923865
scores = self._predict_score(X, init_score)
37933866
return inv_link(scores, self.link_, self.link_param_)
3867+
3868+
def __sklearn_tags__(self):
3869+
tags = super().__sklearn_tags__()
3870+
tags.estimator_type = "regressor"
3871+
tags.regressor_tags = EbmRegressorTags()
3872+
return tags

0 commit comments

Comments
 (0)