@@ -217,18 +217,22 @@ def sample_relative(
217217 self ._trial_ids_for_trust_region [id ].append (trial ._trial_id )
218218 return {}
219219
220- # todo(sawa3030): no trial might be get if it takes time to evaluate objective function
220+ states = (TrialState .COMPLETE ,)
221+ all_trials = study ._get_trials (deepcopy = False , states = states , use_cache = True )
222+
221223 for id in range (self ._n_trust_region ):
222224 if self ._cached_params_by_tr [id ] is not None :
223225 continue
224226
225- states = (TrialState .COMPLETE ,)
226- all_trials = study ._get_trials (deepcopy = False , states = states , use_cache = True )
227227 trials = []
228228 for t in all_trials :
229229 if t ._trial_id in self ._trial_ids_for_trust_region [id ]:
230230 trials .append (t )
231231
232+ if len (trials ) < self ._n_startup_trials :
233+ self ._trial_ids_for_trust_region [id ].append (trial ._trial_id )
234+ return {}
235+
232236 internal_search_space = gp_search_space .SearchSpace (search_space )
233237 normalized_params = internal_search_space .get_normalized_params (trials )
234238
@@ -322,13 +326,13 @@ def after_trial(
322326 state : TrialState ,
323327 values : Sequence [float ] | None ,
324328 ) -> None :
325- assert values is not None # todo(sawa3030): handle this case
326- assert len (values ) == 1
327329 for id in range (self ._n_trust_region ):
328330 if trial ._trial_id in self ._trial_ids_for_trust_region [id ]:
329331 self ._cached_params_by_tr [id ] = None
330332 self ._cached_acqf_by_tr [id ] = None
331- self ._count_and_adjust_trust_region_length (id , values , study .direction )
333+ if values is not None :
334+ assert len (values ) == 1
335+ self ._count_and_adjust_trust_region_length (id , values , study .direction )
332336 break
333337
334338 self ._independent_sampler .after_trial (study , trial , state , values )
0 commit comments