Skip to content

Commit 81b1c6c

Browse files
committed
address edge cases
1 parent 9bbad4c commit 81b1c6c

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

package/samplers/turbo/sampler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -217,18 +217,22 @@ def sample_relative(
217217
self._trial_ids_for_trust_region[id].append(trial._trial_id)
218218
return {}
219219

220-
# todo(sawa3030): no trial might be get if it takes time to evaluate objective function
220+
states = (TrialState.COMPLETE,)
221+
all_trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
222+
221223
for id in range(self._n_trust_region):
222224
if self._cached_params_by_tr[id] is not None:
223225
continue
224226

225-
states = (TrialState.COMPLETE,)
226-
all_trials = study._get_trials(deepcopy=False, states=states, use_cache=True)
227227
trials = []
228228
for t in all_trials:
229229
if t._trial_id in self._trial_ids_for_trust_region[id]:
230230
trials.append(t)
231231

232+
if len(trials) < self._n_startup_trials:
233+
self._trial_ids_for_trust_region[id].append(trial._trial_id)
234+
return {}
235+
232236
internal_search_space = gp_search_space.SearchSpace(search_space)
233237
normalized_params = internal_search_space.get_normalized_params(trials)
234238

@@ -322,13 +326,13 @@ def after_trial(
322326
state: TrialState,
323327
values: Sequence[float] | None,
324328
) -> None:
325-
assert values is not None # todo(sawa3030): handle this case
326-
assert len(values) == 1
327329
for id in range(self._n_trust_region):
328330
if trial._trial_id in self._trial_ids_for_trust_region[id]:
329331
self._cached_params_by_tr[id] = None
330332
self._cached_acqf_by_tr[id] = None
331-
self._count_and_adjust_trust_region_length(id, values, study.direction)
333+
if values is not None:
334+
assert len(values) == 1
335+
self._count_and_adjust_trust_region_length(id, values, study.direction)
332336
break
333337

334338
self._independent_sampler.after_trial(study, trial, state, values)

0 commit comments

Comments
 (0)