|
13 | 13 |
|
14 | 14 |
|
15 | 15 | def boost( |
| 16 | + stop_flag, |
| 17 | + bag_idx, |
| 18 | + callback, |
16 | 19 | dataset, |
17 | 20 | intercept_rounds, |
18 | 21 | intercept_learning_rate, |
@@ -53,6 +56,7 @@ def boost( |
53 | 56 | try: |
54 | 57 | develop._develop_options = develop_options # restore these in this process |
55 | 58 | step_idx = 0 |
| 59 | + cur_metric = np.nan |
56 | 60 |
|
57 | 61 | _log.info("Start boosting") |
58 | 62 | native = Native.get_native_singleton() |
@@ -134,12 +138,10 @@ def boost( |
134 | 138 | make_progress = False |
135 | 139 | if cyclic_state >= 1.0 or smoothing_rounds > 0: |
136 | 140 | # if cyclic_state is above 1.0 we make progress |
137 | | - step_idx += 1 |
138 | 141 | make_progress = True |
139 | 142 | else: |
140 | 143 | # greedy |
141 | 144 | make_progress = True |
142 | | - step_idx += 1 |
143 | 145 | _, _, term_idx = heapq.heappop(heap) |
144 | 146 |
|
145 | 147 | contains_nominals = any(nominals[i] for i in term_features[term_idx]) |
@@ -253,6 +255,8 @@ def boost( |
253 | 255 | booster.set_term_update(term_idx, noisy_update_tensor) |
254 | 256 |
|
255 | 257 | if make_progress: |
| 258 | + step_idx += 1 |
| 259 | + |
256 | 260 | cur_metric = booster.apply_term_update() |
257 | 261 | # if early_stopping_tolerance is negative then keep accepting |
258 | 262 | # model updates as they get worse past the minimum. We might |
@@ -294,6 +298,16 @@ def boost( |
294 | 298 | if min_prev_metric - modified_tolerance <= circular.min(): |
295 | 299 | break |
296 | 300 |
|
| 301 | + if stop_flag is not None and stop_flag.value: |
| 302 | + break |
| 303 | + |
| 304 | + if callback is not None: |
| 305 | + is_done = callback(bag_idx, step_idx, make_progress, cur_metric) |
| 306 | + if is_done: |
| 307 | + if stop_flag is not None: |
| 308 | + stop_flag.value = True |
| 309 | + break |
| 310 | + |
297 | 311 | state_idx = state_idx + 1 |
298 | 312 | if len(term_features) <= state_idx: |
299 | 313 | if smoothing_rounds > 0: |
|
0 commit comments