Skip to content

Commit

Permalink
use most recent trial if no SQ data for target trial in TransformToNe…
Browse files Browse the repository at this point in the history
…wSQ (#3225)

Summary:

see title. This ensures that status_quo_data_by_trial contains the target trial index by default.

Differential Revision: D67875128
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 16, 2025
1 parent 5be6f17 commit f88252d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ax/modelbridge/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ def setUp(self) -> None:
t.mark_completed()
self.data = self.exp.fetch_data()

self._refresh_modelbridge()

def _refresh_modelbridge(self) -> None:
self.modelbridge = ModelBridge(
search_space=self.exp.search_space,
model=Model(),
experiment=self.exp,
data=self.data,
data=self.exp.lookup_data(),
status_quo_name="status_quo",
)

Expand Down Expand Up @@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
obs2 = tf.transform_observations(obs)
self.assertEqual(obs, obs2)

def test_taget_trial_index(self) -> None:
def test_target_trial_index(self) -> None:
sobol = get_sobol(search_space=self.exp.search_space)
self.exp.new_batch_trial(generator_run=sobol.gen(2))
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
t = self.exp.trials[1]
t = assert_is_instance(t, BatchTrial)
t.mark_running(no_runner_required=True)
self.exp.attach_data(
get_branin_data_batch(batch=assert_is_instance(t, BatchTrial))
)

self._refresh_modelbridge()

observations = observations_from_data(
experiment=self.exp,
data=self.exp.lookup_data(),
Expand All @@ -166,12 +171,12 @@ def test_taget_trial_index(self) -> None:

with mock.patch(
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
return_value=10,
return_value=0,
):
t = TransformToNewSQ(
search_space=self.exp.search_space,
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 10)
self.assertEqual(t.default_trial_idx, 0)
3 changes: 3 additions & 0 deletions ax/modelbridge/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(
target_trial_index = get_target_trial_index(
experiment=modelbridge._experiment
)
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
if target_trial_index not in trials_indices_with_sq_data:
target_trial_index = max(trials_indices_with_sq_data)

if target_trial_index is not None:
self.default_trial_idx: int = assert_is_instance(
Expand Down

0 comments on commit f88252d

Please sign in to comment.