Skip to content

Commit 2d3e321

Browse files
committed
add callback for tracking boosting progress
1 parent ecff7ce commit 2d3e321

File tree

5 files changed

+414
-329
lines changed

5 files changed

+414
-329
lines changed

python/interpret-core/interpret/glassbox/_ebm/_boost.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313

1414

1515
def boost(
16+
stop_flag,
17+
bag_idx,
18+
callback,
1619
dataset,
1720
intercept_rounds,
1821
intercept_learning_rate,
@@ -53,6 +56,7 @@ def boost(
5356
try:
5457
develop._develop_options = develop_options # restore these in this process
5558
step_idx = 0
59+
cur_metric = np.nan
5660

5761
_log.info("Start boosting")
5862
native = Native.get_native_singleton()
@@ -134,12 +138,10 @@ def boost(
134138
make_progress = False
135139
if cyclic_state >= 1.0 or smoothing_rounds > 0:
136140
# if cyclic_state is above 1.0 we make progress
137-
step_idx += 1
138141
make_progress = True
139142
else:
140143
# greedy
141144
make_progress = True
142-
step_idx += 1
143145
_, _, term_idx = heapq.heappop(heap)
144146

145147
contains_nominals = any(nominals[i] for i in term_features[term_idx])
@@ -253,6 +255,8 @@ def boost(
253255
booster.set_term_update(term_idx, noisy_update_tensor)
254256

255257
if make_progress:
258+
step_idx += 1
259+
256260
cur_metric = booster.apply_term_update()
257261
# if early_stopping_tolerance is negative then keep accepting
258262
# model updates as they get worse past the minimum. We might
@@ -294,6 +298,16 @@ def boost(
294298
if min_prev_metric - modified_tolerance <= circular.min():
295299
break
296300

301+
if stop_flag is not None and stop_flag.value:
302+
break
303+
304+
if callback is not None:
305+
is_done = callback(bag_idx, step_idx, make_progress, cur_metric)
306+
if is_done:
307+
if stop_flag is not None:
308+
stop_flag.value = True
309+
break
310+
297311
state_idx = state_idx + 1
298312
if len(term_features) <= state_idx:
299313
if smoothing_rounds > 0:

0 commit comments

Comments
 (0)