@@ -27,6 +27,8 @@ class BlendSearch(Searcher):
2727 '''
2828
2929 cost_attr = "time_total_s" # cost attribute in result
30+ lagrange = '_lagrange' # suffix for lagrange-modified metric
31+ penalty = 1e+10 # penalty term for constraints
3032
3133 def __init__ (self ,
3234 metric : Optional [str ] = None ,
@@ -106,6 +108,11 @@ def __init__(self,
106108 self ._metric , self ._mode = metric , mode
107109 init_config = low_cost_partial_config or {}
108110 self ._points_to_evaluate = points_to_evaluate or []
111+ self ._config_constraints = config_constraints
112+ self ._metric_constraints = metric_constraints
113+ if self ._metric_constraints :
114+ # metric modified by lagrange
115+ metric += self .lagrange
109116 if global_search_alg is not None :
110117 self ._gs = global_search_alg
111118 elif getattr (self , '__name__' , None ) != 'CFO' :
@@ -115,8 +122,6 @@ def __init__(self,
115122 self ._ls = LocalSearch (
116123 init_config , metric , mode , cat_hp_cost , space ,
117124 prune_attr , min_resource , max_resource , reduction_factor , seed )
118- self ._config_constraints = config_constraints
119- self ._metric_constraints = metric_constraints
120125 self ._init_search ()
121126
122127 def set_search_properties (self ,
@@ -131,6 +136,11 @@ def set_search_properties(self,
131136 else :
132137 if metric :
133138 self ._metric = metric
139+ if self ._metric_constraints :
140+ # metric modified by lagrange
141+ metric += self .lagrange
142+ # TODO: don't change metric for global search methods that
143+ # can handle constraints already
134144 if mode :
135145 self ._mode = mode
136146 self ._ls .set_search_properties (metric , mode , config )
@@ -156,6 +166,13 @@ def _init_search(self):
156166 self ._gs_admissible_max = self ._ls_bound_max .copy ()
157167 self ._result = {} # config_signature: tuple -> result: Dict
158168 self ._deadline = np .inf
169+ if self ._metric_constraints :
170+ self ._metric_constraint_satisfied = False
171+ self ._metric_constraint_penalty = [
172+ self .penalty for _ in self ._metric_constraints ]
173+ else :
174+ self ._metric_constraint_satisfied = True
175+ self ._metric_constraint_penalty = None
159176
160177 def save (self , checkpoint_path : str ):
161178 save_object = self
@@ -182,6 +199,8 @@ def restore(self, checkpoint_path: str):
182199 self ._ls = state ._ls
183200 self ._config_constraints = state ._config_constraints
184201 self ._metric_constraints = state ._metric_constraints
202+ self ._metric_constraint_satisfied = state ._metric_constraint_satisfied
203+ self ._metric_constraint_penalty = state ._metric_constraint_penalty
185204
186205 def restore_from_dir (self , checkpoint_dir : str ):
187206 super .restore_from_dir (checkpoint_dir )
@@ -190,10 +209,11 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
190209 error : bool = False ):
191210 ''' search thread updater and cleaner
192211 '''
212+ metric_constraint_satisfied = True
193213 if result and not error and self ._metric_constraints :
194- # accout for metric constraints if any
214+ # account for metric constraints if any
195215 objective = result [self ._metric ]
196- for constraint in self ._metric_constraints :
216+ for i , constraint in enumerate ( self ._metric_constraints ) :
197217 metric_constraint , sign , threshold = constraint
198218 value = result .get (metric_constraint )
199219 if value :
@@ -202,8 +222,16 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
202222 violation = (value - threshold ) * sign_op
203223 if violation > 0 :
204224 # add penalty term to the metric
205- objective += 1e+10 * violation * self ._ls .metric_op
206- result [self ._metric ] = objective
225+ objective += self ._metric_constraint_penalty [
226+ i ] * violation * self ._ls .metric_op
227+ metric_constraint_satisfied = False
228+ if self ._metric_constraint_penalty [i ] < self .penalty :
229+ self ._metric_constraint_penalty [i ] += violation
230+ result [self ._metric + self .lagrange ] = objective
231+ if metric_constraint_satisfied and not self ._metric_constraint_satisfied :
232+ # found a feasible point
233+ self ._metric_constraint_penalty = [1 for _ in self ._metric_constraints ]
234+ self ._metric_constraint_satisfied |= metric_constraint_satisfied
207235 thread_id = self ._trial_proposed_by .get (trial_id )
208236 if thread_id in self ._search_thread_pool :
209237 self ._search_thread_pool [thread_id ].on_trial_complete (
@@ -219,10 +247,13 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
219247 else : # add to result cache
220248 self ._result [self ._ls .config_signature (config )] = result
221249 # update target metric if improved
222- objective = result [self ._metric ]
250+ objective = result [
251+ self ._metric + self .lagrange ] if self ._metric_constraints \
252+ else result [self ._metric ]
223253 if (objective - self ._metric_target ) * self ._ls .metric_op < 0 :
224254 self ._metric_target = objective
225- if not thread_id and self ._create_condition (result ):
255+ if not thread_id and metric_constraint_satisfied \
256+ and self ._create_condition (result ):
226257 # thread creator
227258 self ._search_thread_pool [self ._thread_count ] = SearchThread (
228259 self ._ls .mode ,
@@ -233,6 +264,9 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
233264 self ._thread_count += 1
234265 self ._update_admissible_region (
235266 config , self ._ls_bound_min , self ._ls_bound_max )
267+ elif thread_id and not self ._metric_constraint_satisfied :
268+ # no point has been found to satisfy metric constraint
269+ self ._expand_admissible_region ()
236270 # reset admissible region to ls bounding box
237271 self ._gs_admissible_min .update (self ._ls_bound_min )
238272 self ._gs_admissible_max .update (self ._ls_bound_max )
@@ -306,6 +340,8 @@ def on_trial_result(self, trial_id: str, result: Dict):
306340 thread_id = self ._trial_proposed_by [trial_id ]
307341 if thread_id not in self ._search_thread_pool :
308342 return
343+ if result and self ._metric_constraints :
344+ result [self ._metric + self .lagrange ] = result [self ._metric ]
309345 self ._search_thread_pool [thread_id ].on_trial_result (trial_id , result )
310346
311347 def suggest (self , trial_id : str ) -> Optional [Dict ]:
0 commit comments