Skip to content

Commit b206363

Browse files
authored
metric constraint (#90)
* penalty change * metric modification * catboost init
1 parent 0925e2b commit b206363

3 files changed

Lines changed: 55 additions & 8 deletions

File tree

flaml/automl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,7 @@ def custom_metric(X_test, y_test, estimator, labels,
922922
# set up learner search space
923923
for estimator_name in estimator_list:
924924
estimator_class = self._state.learner_classes[estimator_name]
925+
estimator_class.init()
925926
self._search_states[estimator_name] = SearchState(
926927
learner_class=estimator_class,
927928
data_size=self._state.data_size, task=self._state.task,

flaml/model.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,11 @@ def cost_relative2lgbm(cls):
163163
'''[optional method] relative cost compared to lightgbm'''
164164
return 1.0
165165

166+
@classmethod
167+
def init(cls):
168+
'''[optional method] initialize the class'''
169+
pass
170+
166171

167172
class SKLearnEstimator(BaseEstimator):
168173

@@ -632,6 +637,11 @@ def size(cls, config):
632637
def cost_relative2lgbm(cls):
633638
return 15
634639

640+
@classmethod
641+
def init(cls):
642+
CatBoostEstimator._time_per_iter = None
643+
CatBoostEstimator._train_size = 0
644+
635645
def __init__(
636646
self, task='binary:logistic', n_jobs=1,
637647
n_estimators=8192, learning_rate=0.1, early_stopping_rounds=4, **params

flaml/searcher/blendsearch.py

Lines changed: 44 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class BlendSearch(Searcher):
2727
'''
2828

2929
cost_attr = "time_total_s" # cost attribute in result
30+
lagrange = '_lagrange' # suffix for lagrange-modified metric
31+
penalty = 1e+10 # penalty term for constraints
3032

3133
def __init__(self,
3234
metric: Optional[str] = None,
@@ -106,6 +108,11 @@ def __init__(self,
106108
self._metric, self._mode = metric, mode
107109
init_config = low_cost_partial_config or {}
108110
self._points_to_evaluate = points_to_evaluate or []
111+
self._config_constraints = config_constraints
112+
self._metric_constraints = metric_constraints
113+
if self._metric_constraints:
114+
# metric modified by lagrange
115+
metric += self.lagrange
109116
if global_search_alg is not None:
110117
self._gs = global_search_alg
111118
elif getattr(self, '__name__', None) != 'CFO':
@@ -115,8 +122,6 @@ def __init__(self,
115122
self._ls = LocalSearch(
116123
init_config, metric, mode, cat_hp_cost, space,
117124
prune_attr, min_resource, max_resource, reduction_factor, seed)
118-
self._config_constraints = config_constraints
119-
self._metric_constraints = metric_constraints
120125
self._init_search()
121126

122127
def set_search_properties(self,
@@ -131,6 +136,11 @@ def set_search_properties(self,
131136
else:
132137
if metric:
133138
self._metric = metric
139+
if self._metric_constraints:
140+
# metric modified by lagrange
141+
metric += self.lagrange
142+
# TODO: don't change metric for global search methods that
143+
# can handle constraints already
134144
if mode:
135145
self._mode = mode
136146
self._ls.set_search_properties(metric, mode, config)
@@ -156,6 +166,13 @@ def _init_search(self):
156166
self._gs_admissible_max = self._ls_bound_max.copy()
157167
self._result = {} # config_signature: tuple -> result: Dict
158168
self._deadline = np.inf
169+
if self._metric_constraints:
170+
self._metric_constraint_satisfied = False
171+
self._metric_constraint_penalty = [
172+
self.penalty for _ in self._metric_constraints]
173+
else:
174+
self._metric_constraint_satisfied = True
175+
self._metric_constraint_penalty = None
159176

160177
def save(self, checkpoint_path: str):
161178
save_object = self
@@ -182,6 +199,8 @@ def restore(self, checkpoint_path: str):
182199
self._ls = state._ls
183200
self._config_constraints = state._config_constraints
184201
self._metric_constraints = state._metric_constraints
202+
self._metric_constraint_satisfied = state._metric_constraint_satisfied
203+
self._metric_constraint_penalty = state._metric_constraint_penalty
185204

186205
def restore_from_dir(self, checkpoint_dir: str):
187206
super.restore_from_dir(checkpoint_dir)
@@ -190,10 +209,11 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
190209
error: bool = False):
191210
''' search thread updater and cleaner
192211
'''
212+
metric_constraint_satisfied = True
193213
if result and not error and self._metric_constraints:
194-
# accout for metric constraints if any
214+
# account for metric constraints if any
195215
objective = result[self._metric]
196-
for constraint in self._metric_constraints:
216+
for i, constraint in enumerate(self._metric_constraints):
197217
metric_constraint, sign, threshold = constraint
198218
value = result.get(metric_constraint)
199219
if value:
@@ -202,8 +222,16 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
202222
violation = (value - threshold) * sign_op
203223
if violation > 0:
204224
# add penalty term to the metric
205-
objective += 1e+10 * violation * self._ls.metric_op
206-
result[self._metric] = objective
225+
objective += self._metric_constraint_penalty[
226+
i] * violation * self._ls.metric_op
227+
metric_constraint_satisfied = False
228+
if self._metric_constraint_penalty[i] < self.penalty:
229+
self._metric_constraint_penalty[i] += violation
230+
result[self._metric + self.lagrange] = objective
231+
if metric_constraint_satisfied and not self._metric_constraint_satisfied:
232+
# found a feasible point
233+
self._metric_constraint_penalty = [1 for _ in self._metric_constraints]
234+
self._metric_constraint_satisfied |= metric_constraint_satisfied
207235
thread_id = self._trial_proposed_by.get(trial_id)
208236
if thread_id in self._search_thread_pool:
209237
self._search_thread_pool[thread_id].on_trial_complete(
@@ -219,10 +247,13 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
219247
else: # add to result cache
220248
self._result[self._ls.config_signature(config)] = result
221249
# update target metric if improved
222-
objective = result[self._metric]
250+
objective = result[
251+
self._metric + self.lagrange] if self._metric_constraints \
252+
else result[self._metric]
223253
if (objective - self._metric_target) * self._ls.metric_op < 0:
224254
self._metric_target = objective
225-
if not thread_id and self._create_condition(result):
255+
if not thread_id and metric_constraint_satisfied \
256+
and self._create_condition(result):
226257
# thread creator
227258
self._search_thread_pool[self._thread_count] = SearchThread(
228259
self._ls.mode,
@@ -233,6 +264,9 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
233264
self._thread_count += 1
234265
self._update_admissible_region(
235266
config, self._ls_bound_min, self._ls_bound_max)
267+
elif thread_id and not self._metric_constraint_satisfied:
268+
# no point has been found to satisfy metric constraint
269+
self._expand_admissible_region()
236270
# reset admissible region to ls bounding box
237271
self._gs_admissible_min.update(self._ls_bound_min)
238272
self._gs_admissible_max.update(self._ls_bound_max)
@@ -306,6 +340,8 @@ def on_trial_result(self, trial_id: str, result: Dict):
306340
thread_id = self._trial_proposed_by[trial_id]
307341
if thread_id not in self._search_thread_pool:
308342
return
343+
if result and self._metric_constraints:
344+
result[self._metric + self.lagrange] = result[self._metric]
309345
self._search_thread_pool[thread_id].on_trial_result(trial_id, result)
310346

311347
def suggest(self, trial_id: str) -> Optional[Dict]:

0 commit comments

Comments
 (0)