Skip to content

Commit

Permalink
use most recent trial if no SQ data for target trial in TransformToNe…
Browse files Browse the repository at this point in the history
…wSQ (facebook#3225)

Summary:

see title. This ensures that status_quo_data_by_trial contains the target trial index by default.

Differential Revision: D67875128
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 16, 2025
1 parent 5be6f17 commit 3d5d17f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
15 changes: 10 additions & 5 deletions ax/modelbridge/transforms/tests/test_transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ def setUp(self) -> None:
t.mark_completed()
self.data = self.exp.fetch_data()

self._refresh_modelbridge()

def _refresh_modelbridge(self) -> None:
self.modelbridge = ModelBridge(
search_space=self.exp.search_space,
model=Model(),
experiment=self.exp,
data=self.data,
data=self.exp.lookup_data(),
status_quo_name="status_quo",
)

Expand Down Expand Up @@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
obs2 = tf.transform_observations(obs)
self.assertEqual(obs, obs2)

def test_taget_trial_index(self) -> None:
def test_target_trial_index(self) -> None:
sobol = get_sobol(search_space=self.exp.search_space)
self.exp.new_batch_trial(generator_run=sobol.gen(2))
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
t = self.exp.trials[1]
t = assert_is_instance(t, BatchTrial)
t.mark_running(no_runner_required=True)
self.exp.attach_data(
get_branin_data_batch(batch=assert_is_instance(t, BatchTrial))
)

self._refresh_modelbridge()

observations = observations_from_data(
experiment=self.exp,
data=self.exp.lookup_data(),
Expand All @@ -166,12 +171,12 @@ def test_taget_trial_index(self) -> None:

with mock.patch(
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
return_value=10,
return_value=0,
):
t = TransformToNewSQ(
search_space=self.exp.search_space,
observations=observations,
modelbridge=self.modelbridge,
)

self.assertEqual(t.default_trial_idx, 10)
self.assertEqual(t.default_trial_idx, 0)
3 changes: 3 additions & 0 deletions ax/modelbridge/transforms/transform_to_new_sq.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ def __init__(
target_trial_index = get_target_trial_index(
experiment=modelbridge._experiment
)
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
if target_trial_index not in trials_indices_with_sq_data:
target_trial_index = max(trials_indices_with_sq_data)

if target_trial_index is not None:
self.default_trial_idx: int = assert_is_instance(
Expand Down

0 comments on commit 3d5d17f

Please sign in to comment.