Skip to content

Commit 3d5d17f

Browse files
sdaultonfacebook-github-bot
authored andcommitted
use most recent trial if no SQ data for target trial in TransformToNewSQ (facebook#3225)
Summary: see title. This ensures that status_quo_data_by_trial contains the target trial index by default. Differential Revision: D67875128
1 parent 5be6f17 commit 3d5d17f

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

ax/modelbridge/transforms/tests/test_transform_to_new_sq.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,14 @@ def setUp(self) -> None:
7070
t.mark_completed()
7171
self.data = self.exp.fetch_data()
7272

73+
self._refresh_modelbridge()
74+
75+
def _refresh_modelbridge(self) -> None:
7376
self.modelbridge = ModelBridge(
7477
search_space=self.exp.search_space,
7578
model=Model(),
7679
experiment=self.exp,
77-
data=self.data,
80+
data=self.exp.lookup_data(),
7881
status_quo_name="status_quo",
7982
)
8083

@@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
141144
obs2 = tf.transform_observations(obs)
142145
self.assertEqual(obs, obs2)
143146

144-
def test_taget_trial_index(self) -> None:
147+
def test_target_trial_index(self) -> None:
145148
sobol = get_sobol(search_space=self.exp.search_space)
146-
self.exp.new_batch_trial(generator_run=sobol.gen(2))
149+
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
147150
t = self.exp.trials[1]
148151
t = assert_is_instance(t, BatchTrial)
149152
t.mark_running(no_runner_required=True)
150153
self.exp.attach_data(
151154
get_branin_data_batch(batch=assert_is_instance(t, BatchTrial))
152155
)
153156

157+
self._refresh_modelbridge()
158+
154159
observations = observations_from_data(
155160
experiment=self.exp,
156161
data=self.exp.lookup_data(),
@@ -166,12 +171,12 @@ def test_taget_trial_index(self) -> None:
166171

167172
with mock.patch(
168173
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
169-
return_value=10,
174+
return_value=0,
170175
):
171176
t = TransformToNewSQ(
172177
search_space=self.exp.search_space,
173178
observations=observations,
174179
modelbridge=self.modelbridge,
175180
)
176181

177-
self.assertEqual(t.default_trial_idx, 10)
182+
self.assertEqual(t.default_trial_idx, 0)

ax/modelbridge/transforms/transform_to_new_sq.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def __init__(
7373
target_trial_index = get_target_trial_index(
7474
experiment=modelbridge._experiment
7575
)
76+
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
77+
if target_trial_index not in trials_indices_with_sq_data:
78+
target_trial_index = max(trials_indices_with_sq_data)
7679

7780
if target_trial_index is not None:
7881
self.default_trial_idx: int = assert_is_instance(

0 commit comments

Comments
 (0)