22# Distributed under the MIT software license
33
44import heapq
5+ import inspect
56import json
67import logging
78import os
7879_log = logging .getLogger (__name__ )
7980
8081
82+ _CALLBACK_TYPES = {
83+ "progress" : {"bag" , "stage" , "step" , "term" , "metric" },
84+ "exam" : {"bag" , "stage" , "step" , "term" , "gain" },
85+ }
86+ _CallbackSpec = Callable [..., bool ] | tuple [Callable [..., bool ], ...]
87+
88+
89+ def _classify_callback (callback ):
90+ if not callable (callback ):
91+ msg = "callback must be a callable or a tuple of callables"
92+ _log .error (msg )
93+ raise ValueError (msg )
94+
95+ try :
96+ param_names = set (inspect .signature (callback ).parameters )
97+ except (TypeError , ValueError ) as exc :
98+ msg = "callback must have an inspectable signature"
99+ _log .error (msg )
100+ raise ValueError (msg ) from exc
101+
102+ for name , params in _CALLBACK_TYPES .items ():
103+ if params == param_names :
104+ return name
105+
106+ msg = (
107+ "callback must accept either the progress signature "
108+ "(*, bag, stage, step, term, metric) or the examination signature "
109+ "(*, bag, stage, step, term, gain)"
110+ )
111+ _log .error (msg )
112+ raise ValueError (msg )
113+
114+
115+ def _normalize_callbacks (callback ):
116+ if callback is None :
117+ return None , None
118+
119+ callbacks = callback if isinstance (callback , tuple ) else (callback ,)
120+
121+ progress_callback = None
122+ exam_callback = None
123+ for callback_item in callbacks :
124+ callback_type = _classify_callback (callback_item )
125+ if callback_type == "progress" :
126+ if progress_callback is not None :
127+ msg = "callback tuple cannot contain more than one progress callback"
128+ _log .error (msg )
129+ raise ValueError (msg )
130+ progress_callback = callback_item
131+ else :
132+ if exam_callback is not None :
133+ msg = "callback tuple cannot contain more than one examination callback"
134+ _log .error (msg )
135+ raise ValueError (msg )
136+ exam_callback = callback_item
137+
138+ return progress_callback , exam_callback
139+
140+
81141class EBMExplanation (FeatureValueExplanation ):
82142 """Visualizes specifically for EBM."""
83143
@@ -851,7 +911,8 @@ def fit(
851911 interaction_smoothing_rounds = 0
852912 early_stopping_rounds = 0
853913 early_stopping_tolerance = 0.0
854- callback = None
914+ progress_callback = None
915+ exam_callback = None
855916 min_samples_leaf = 0
856917 min_hessian = 0.0
857918 reg_alpha = 0.0
@@ -879,7 +940,7 @@ def fit(
879940 interaction_smoothing_rounds = self .interaction_smoothing_rounds
880941 early_stopping_rounds = self .early_stopping_rounds
881942 early_stopping_tolerance = self .early_stopping_tolerance
882- callback = self .callback
943+ progress_callback , exam_callback = _normalize_callbacks ( self .callback )
883944 min_samples_leaf = self .min_samples_leaf
884945 min_hessian = self .min_hessian
885946 reg_alpha = self .reg_alpha
@@ -1018,7 +1079,8 @@ def fit(
10181079 shared ,
10191080 )
10201081
1021- with nullcontext () if callback is None else SharedMemoryManager () as smm :
1082+ has_callback = progress_callback is not None or exam_callback is not None
1083+ with nullcontext () if not has_callback else SharedMemoryManager () as smm :
10221084 stop_flag : npt .NDArray [np .bool_ ] | None
10231085 if smm is not None :
10241086 shm = smm .SharedMemory (size = 1 )
@@ -1034,7 +1096,8 @@ def fit(
10341096 shm_name = shm_name ,
10351097 bag_idx = idx ,
10361098 stage = 0 ,
1037- callback = callback ,
1099+ progress_callback = progress_callback ,
1100+ exam_callback = exam_callback ,
10381101 dataset = (
10391102 shared .name if shared .name is not None else shared .dataset
10401103 ),
@@ -1274,7 +1337,8 @@ def fit(
12741337 shm_name = shm_name ,
12751338 bag_idx = idx ,
12761339 stage = 1 ,
1277- callback = callback ,
1340+ progress_callback = progress_callback ,
1341+ exam_callback = exam_callback ,
12781342 dataset = (
12791343 shared .name
12801344 if shared .name is not None
@@ -1386,7 +1450,8 @@ def fit(
13861450 shm_name = None ,
13871451 bag_idx = 0 ,
13881452 stage = - 1 ,
1389- callback = None ,
1453+ progress_callback = None ,
1454+ exam_callback = None ,
13901455 dataset = shared .dataset ,
13911456 intercept_rounds = develop .get_option ("n_intercept_rounds_final" ),
13921457 intercept_learning_rate = develop .get_option (
@@ -3312,15 +3377,15 @@ class EBMModel(BaseEBM):
33123377 tradeoff for the ensemble of models --- not the individual models --- a small
33133378 amount of overfitting of the individual models can improve the accuracy of
33143379 the ensemble as a whole.
3315- callback : Optional[Callable[..., bool]], default=None
3316- A user-defined function invoked after each progressing boosting step. Must use
3317- keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3318- If it returns True, boosting is stopped immediately .
3319- The callback receives: ``bag`` (int) the outer bag index,
3320- ``stage`` (int) the boosting stage (0=mains, 1=pairs),
3321- ``step`` (int) the number of boosting steps completed ,
3322- ``term`` (int) the index of the term that was just boosted,
3323- and ``metric`` (float) the current validation metric .
3380+ callback : Optional[Union[ Callable[..., bool], tuple[Callable[..., bool], ...] ]], default=None
3381+ A user-defined callback or tuple of callbacks invoked during boosting.
3382+ A progress callback is invoked after each progressing boosting step and must use
3383+ keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)`` .
3384+ An examination callback is invoked whenever a term is examined and its gain is
3385+ calculated, and must use keyword-only arguments:
3386+ ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True ,
3387+ boosting is stopped immediately. A tuple can contain at most one progress callback
3388+ and one examination callback .
33243389 min_samples_leaf : int, default=4
33253390 Minimum number of samples allowed in the leaves.
33263391 min_hessian : float, default=0.0
@@ -3437,7 +3502,7 @@ def __init__(
34373502 max_rounds : int | None = 50000 ,
34383503 early_stopping_rounds : int | None = 100 ,
34393504 early_stopping_tolerance : float | None = 1e-5 ,
3440- callback : Callable [..., bool ] | None = None ,
3505+ callback : _CallbackSpec | None = None ,
34413506 # Trees
34423507 min_samples_leaf : int | None = 4 ,
34433508 min_hessian : float | None = 0.0 ,
@@ -3577,15 +3642,15 @@ class EBMClassifier(EBMClassifierMixin, EBMModel):
35773642 tradeoff for the ensemble of models --- not the individual models --- a small
35783643 amount of overfitting of the individual models can improve the accuracy of
35793644 the ensemble as a whole.
3580- callback : Optional[Callable[..., bool]], default=None
3581- A user-defined function invoked after each progressing boosting step. Must use
3582- keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3583- If it returns True, boosting is stopped immediately .
3584- The callback receives: ``bag`` (int) the outer bag index,
3585- ``stage`` (int) the boosting stage (0=mains, 1=pairs),
3586- ``step`` (int) the number of boosting steps completed ,
3587- ``term`` (int) the index of the term that was just boosted,
3588- and ``metric`` (float) the current validation metric .
3645+ callback : Optional[Union[ Callable[..., bool], tuple[Callable[..., bool], ...] ]], default=None
3646+ A user-defined callback or tuple of callbacks invoked during boosting.
3647+ A progress callback is invoked after each progressing boosting step and must use
3648+ keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)`` .
3649+ An examination callback is invoked whenever a term is examined and its gain is
3650+ calculated, and must use keyword-only arguments:
3651+ ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True ,
3652+ boosting is stopped immediately. A tuple can contain at most one progress callback
3653+ and one examination callback .
35893654 min_samples_leaf : int, default=4
35903655 Minimum number of samples allowed in the leaves.
35913656 min_hessian : float, default=1e-4
@@ -3761,7 +3826,7 @@ def __init__(
37613826 max_rounds : int | None = 50000 ,
37623827 early_stopping_rounds : int | None = 100 ,
37633828 early_stopping_tolerance : float | None = 1e-5 ,
3764- callback : Callable [..., bool ] | None = None ,
3829+ callback : _CallbackSpec | None = None ,
37653830 # Trees
37663831 min_samples_leaf : int | None = 4 ,
37673832 min_hessian : float | None = 1e-4 ,
@@ -3903,15 +3968,15 @@ class EBMRegressor(EBMRegressorMixin, EBMModel):
39033968 tradeoff for the ensemble of models --- not the individual models --- a small
39043969 amount of overfitting of the individual models can improve the accuracy of
39053970 the ensemble as a whole.
3906- callback : Optional[Callable[..., bool]], default=None
3907- A user-defined function invoked after each progressing boosting step. Must use
3908- keyword-only arguments: ``def my_callback(*, bag, stage, step, term, metric)``.
3909- If it returns True, boosting is stopped immediately .
3910- The callback receives: ``bag`` (int) the outer bag index,
3911- ``stage`` (int) the boosting stage (0=mains, 1=pairs),
3912- ``step`` (int) the number of boosting steps completed ,
3913- ``term`` (int) the index of the term that was just boosted,
3914- and ``metric`` (float) the current validation metric .
3971+ callback : Optional[Union[ Callable[..., bool], tuple[Callable[..., bool], ...] ]], default=None
3972+ A user-defined callback or tuple of callbacks invoked during boosting.
3973+ A progress callback is invoked after each progressing boosting step and must use
3974+ keyword-only arguments: ``def progress_cb(*, bag, stage, step, term, metric)`` .
3975+ An examination callback is invoked whenever a term is examined and its gain is
3976+ calculated, and must use keyword-only arguments:
3977+ ``def exam_cb(*, bag, stage, step, term, gain)``. If any callback returns True ,
3978+ boosting is stopped immediately. A tuple can contain at most one progress callback
3979+ and one examination callback .
39153980 min_samples_leaf : int, default=4
39163981 Minimum number of samples allowed in the leaves.
39173982 min_hessian : float, default=0.0
@@ -4091,7 +4156,7 @@ def __init__(
40914156 max_rounds : int | None = 50000 ,
40924157 early_stopping_rounds : int | None = 100 ,
40934158 early_stopping_tolerance : float | None = 1e-5 ,
4094- callback : Callable [..., bool ] | None = None ,
4159+ callback : _CallbackSpec | None = None ,
40954160 # Trees
40964161 min_samples_leaf : int | None = 4 ,
40974162 min_hessian : float | None = 0.0 ,
0 commit comments