Skip to content

Commit 59f7388

Browse files
feat: add tuple support and exam_cb for callbacks (#665)
* feat: add tuple support and exam_cb for callbacks Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> * test: cover callback validation branches Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> * fix: address callback tuple review feedback Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> --------- Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com> Co-authored-by: Paul Koch <46825734+paulbkoch@users.noreply.github.com>
1 parent e6d20f9 commit 59f7388

3 files changed

Lines changed: 345 additions & 42 deletions

File tree

python/interpret-core/interpret/glassbox/_ebm.py

Lines changed: 101 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Distributed under the MIT software license
33

44
import heapq
5+
import inspect
56
import json
67
import logging
78
import os
@@ -78,6 +79,65 @@
7879
_log = logging.getLogger(__name__)
7980

8081

82+
_CALLBACK_TYPES = {
83+
"progress": {"bag", "stage", "step", "term", "metric"},
84+
"exam": {"bag", "stage", "step", "term", "gain"},
85+
}
86+
_CallbackSpec = Callable[..., bool] | tuple[Callable[..., bool], ...]
87+
88+
89+
def _classify_callback(callback):
90+
if not callable(callback):
91+
msg = "callback must be a callable or a tuple of callables"
92+
_log.error(msg)
93+
raise ValueError(msg)
94+
95+
try:
96+
param_names = set(inspect.signature(callback).parameters)
97+
except (TypeError, ValueError) as exc:
98+
msg = "callback must have an inspectable signature"
99+
_log.error(msg)
100+
raise ValueError(msg) from exc
101+
102+
for name, params in _CALLBACK_TYPES.items():
103+
if params == param_names:
104+
return name
105+
106+
msg = (
107+
"callback must accept either the progress signature "
108+
"(*, bag, stage, step, term, metric) or the examination signature "
109+
"(*, bag, stage, step, term, gain)"
110+
)
111+
_log.error(msg)
112+
raise ValueError(msg)
113+
114+
115+
def _normalize_callbacks(callback):
116+
if callback is None:
117+
return None, None
118+
119+
callbacks = callback if isinstance(callback, tuple) else (callback,)
120+
121+
progress_callback = None
122+
exam_callback = None
123+
for callback_item in callbacks:
124+
callback_type = _classify_callback(callback_item)
125+
if callback_type == "progress":
126+
if progress_callback is not None:
127+
msg = "callback tuple cannot contain more than one progress callback"
128+
_log.error(msg)
129+
raise ValueError(msg)
130+
progress_callback = callback_item
131+
else:
132+
if exam_callback is not None:
133+
msg = "callback tuple cannot contain more than one examination callback"
134+
_log.error(msg)
135+
raise ValueError(msg)
136+
exam_callback = callback_item
137+
138+
return progress_callback, exam_callback
139+
140+
81141
class EBMExplanation(FeatureValueExplanation):
82142
"""Visualizes specifically for EBM."""
83143

@@ -851,7 +911,8 @@ def fit(
851911
interaction_smoothing_rounds = 0
852912
early_stopping_rounds = 0
853913
early_stopping_tolerance = 0.0
854-
callback = None
914+
progress_callback = None
915+
exam_callback = None
855916
min_samples_leaf = 0
856917
min_hessian = 0.0
857918
reg_alpha = 0.0
@@ -879,7 +940,7 @@ def fit(
879940
interaction_smoothing_rounds = self.interaction_smoothing_rounds
880941
early_stopping_rounds = self.early_stopping_rounds
881942
early_stopping_tolerance = self.early_stopping_tolerance
882-
callback = self.callback
943+
progress_callback, exam_callback = _normalize_callbacks(self.callback)
883944
min_samples_leaf = self.min_samples_leaf
884945
min_hessian = self.min_hessian
885946
reg_alpha = self.reg_alpha
@@ -1018,7 +1079,8 @@ def fit(
10181079
shared,
10191080
)
10201081

1021-
with nullcontext() if callback is None else SharedMemoryManager() as smm:
1082+
has_callback = progress_callback is not None or exam_callback is not None
1083+
with nullcontext() if not has_callback else SharedMemoryManager() as smm:
10221084
stop_flag: npt.NDArray[np.bool_] | None
10231085
if smm is not None:
10241086
shm = smm.SharedMemory(size=1)
@@ -1034,7 +1096,8 @@ def fit(
10341096
shm_name=shm_name,
10351097
bag_idx=idx,
10361098
stage=0,
1037-
callback=callback,
1099+
progress_callback=progress_callback,
1100+
exam_callback=exam_callback,
10381101
dataset=(
10391102
shared.name if shared.name is not None else shared.dataset
10401103
),
@@ -1274,7 +1337,8 @@ def fit(
12741337
shm_name=shm_name,
12751338
bag_idx=idx,
12761339
stage=1,
1277-
callback=callback,
1340+
progress_callback=progress_callback,
1341+
exam_callback=exam_callback,
12781342
dataset=(
12791343
shared.name
12801344
if shared.name is not None
@@ -1386,7 +1450,8 @@ def fit(
13861450
shm_name=None,
13871451
bag_idx=0,
13881452
stage=-1,
1389-
callback=None,
1453+
progress_callback=None,
1454+
exam_callback=None,
13901455
dataset=shared.dataset,
13911456
intercept_rounds=develop.get_option("n_intercept_rounds_final"),
13921457
intercept_learning_rate=develop.get_option(
@@ -3312,15 +3377,15 @@ class EBMModel(BaseEBM):
33123377
tradeoff for the ensemble of models --- not the individual models --- a small
33133378
amount of overfitting of the individual models can improve the accuracy of
33143379
the ensemble as a whole.
3315-
callback : Optional[Callable[..., bool]], default=None
3316-
A user-defined function invoked after each progressing boosting step. Must use
3317-
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3318-
If it returns True, boosting is stopped immediately.
3319-
The callback receives: ``bag`` (int) the outer bag index,
3320-
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3321-
``step`` (int) the number of boosting steps completed,
3322-
``term`` (int) the index of the term that was just boosted,
3323-
and ``metric`` (float) the current validation metric.
3380+
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
3381+
A user-defined callback or tuple of callbacks invoked during boosting.
3382+
A progress callback is invoked after each progressing boosting step and must use
3383+
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
3384+
An examination callback is invoked whenever a term is examined and its gain is
3385+
calculated, and must use keyword-only arguments:
3386+
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
3387+
boosting is stopped immediately. A tuple can contain at most one progress callback
3388+
and one examination callback.
33243389
min_samples_leaf : int, default=4
33253390
Minimum number of samples allowed in the leaves.
33263391
min_hessian : float, default=0.0
@@ -3437,7 +3502,7 @@ def __init__(
34373502
max_rounds: int | None = 50000,
34383503
early_stopping_rounds: int | None = 100,
34393504
early_stopping_tolerance: float | None = 1e-5,
3440-
callback: Callable[..., bool] | None = None,
3505+
callback: _CallbackSpec | None = None,
34413506
# Trees
34423507
min_samples_leaf: int | None = 4,
34433508
min_hessian: float | None = 0.0,
@@ -3577,15 +3642,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel):
35773642
tradeoff for the ensemble of models --- not the individual models --- a small
35783643
amount of overfitting of the individual models can improve the accuracy of
35793644
the ensemble as a whole.
3580-
callback : Optional[Callable[..., bool]], default=None
3581-
A user-defined function invoked after each progressing boosting step. Must use
3582-
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3583-
If it returns True, boosting is stopped immediately.
3584-
The callback receives: ``bag`` (int) the outer bag index,
3585-
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3586-
``step`` (int) the number of boosting steps completed,
3587-
``term`` (int) the index of the term that was just boosted,
3588-
and ``metric`` (float) the current validation metric.
3645+
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
3646+
A user-defined callback or tuple of callbacks invoked during boosting.
3647+
A progress callback is invoked after each progressing boosting step and must use
3648+
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
3649+
An examination callback is invoked whenever a term is examined and its gain is
3650+
calculated, and must use keyword-only arguments:
3651+
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
3652+
boosting is stopped immediately. A tuple can contain at most one progress callback
3653+
and one examination callback.
35893654
min_samples_leaf : int, default=4
35903655
Minimum number of samples allowed in the leaves.
35913656
min_hessian : float, default=1e-4
@@ -3761,7 +3826,7 @@ def __init__(
37613826
max_rounds: int | None = 50000,
37623827
early_stopping_rounds: int | None = 100,
37633828
early_stopping_tolerance: float | None = 1e-5,
3764-
callback: Callable[..., bool] | None = None,
3829+
callback: _CallbackSpec | None = None,
37653830
# Trees
37663831
min_samples_leaf: int | None = 4,
37673832
min_hessian: float | None = 1e-4,
@@ -3903,15 +3968,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
39033968
tradeoff for the ensemble of models --- not the individual models --- a small
39043969
amount of overfitting of the individual models can improve the accuracy of
39053970
the ensemble as a whole.
3906-
callback : Optional[Callable[..., bool]], default=None
3907-
A user-defined function invoked after each progressing boosting step. Must use
3908-
keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3909-
If it returns True, boosting is stopped immediately.
3910-
The callback receives: ``bag`` (int) the outer bag index,
3911-
``stage`` (int) the boosting stage (0=mains, 1=pairs),
3912-
``step`` (int) the number of boosting steps completed,
3913-
``term`` (int) the index of the term that was just boosted,
3914-
and ``metric`` (float) the current validation metric.
3971+
callback : Optional[Union[Callable[..., bool], tuple[Callable[..., bool], ...]]], default=None
3972+
A user-defined callback or tuple of callbacks invoked during boosting.
3973+
A progress callback is invoked after each progressing boosting step and must use
3974+
keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)``.
3975+
An examination callback is invoked whenever a term is examined and its gain is
3976+
calculated, and must use keyword-only arguments:
3977+
``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True,
3978+
boosting is stopped immediately. A tuple can contain at most one progress callback
3979+
and one examination callback.
39153980
min_samples_leaf : int, default=4
39163981
Minimum number of samples allowed in the leaves.
39173982
min_hessian : float, default=0.0
@@ -4091,7 +4156,7 @@ def __init__(
40914156
max_rounds: int | None = 50000,
40924157
early_stopping_rounds: int | None = 100,
40934158
early_stopping_tolerance: float | None = 1e-5,
4094-
callback: Callable[..., bool] | None = None,
4159+
callback: _CallbackSpec | None = None,
40954160
# Trees
40964161
min_samples_leaf: int | None = 4,
40974162
min_hessian: float | None = 0.0,

python/interpret-core/interpret/glassbox/_ebm_core/_boost.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def boost(
2929
shm_name,
3030
bag_idx,
3131
stage,
32-
callback,
32+
progress_callback,
33+
exam_callback,
3334
dataset,
3435
intercept_rounds,
3536
intercept_learning_rate,
@@ -264,6 +265,22 @@ def boost(
264265
# penalize nominals a bit because they benefit from sorting categories
265266
avg_gain *= gain_scale
266267

268+
if stop_flag is not None and stop_flag[0]:
269+
break
270+
271+
if exam_callback is not None:
272+
is_done = exam_callback(
273+
bag=bag_idx,
274+
stage=stage,
275+
step=step_idx,
276+
term=term_idx,
277+
gain=avg_gain,
278+
)
279+
if is_done:
280+
if stop_flag is not None:
281+
stop_flag[0] = True
282+
break
283+
267284
gainkey = (-avg_gain, native.generate_seed(rng), term_idx)
268285
if not make_progress and (
269286
bestkey is None or gainkey < bestkey
@@ -365,11 +382,8 @@ def boost(
365382
):
366383
break
367384

368-
if stop_flag is not None and stop_flag[0]:
369-
break
370-
371-
if callback is not None:
372-
is_done = callback(
385+
if progress_callback is not None:
386+
is_done = progress_callback(
373387
bag=bag_idx,
374388
stage=stage,
375389
step=step_idx,

0 commit comments

Comments
 (0)