@@ -70,11 +70,14 @@ def setUp(self) -> None:
70
70
t .mark_completed ()
71
71
self .data = self .exp .fetch_data ()
72
72
73
+ self ._refresh_modelbridge ()
74
+
75
+ def _refresh_modelbridge (self ) -> None :
73
76
self .modelbridge = ModelBridge (
74
77
search_space = self .exp .search_space ,
75
78
model = Model (),
76
79
experiment = self .exp ,
77
- data = self .data ,
80
+ data = self .exp . lookup_data () ,
78
81
status_quo_name = "status_quo" ,
79
82
)
80
83
@@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
141
144
obs2 = tf .transform_observations (obs )
142
145
self .assertEqual (obs , obs2 )
143
146
144
- def test_taget_trial_index (self ) -> None :
147
+ def test_target_trial_index (self ) -> None :
145
148
sobol = get_sobol (search_space = self .exp .search_space )
146
- self .exp .new_batch_trial (generator_run = sobol .gen (2 ))
149
+ self .exp .new_batch_trial (generator_run = sobol .gen (2 ), optimize_for_power = True )
147
150
t = self .exp .trials [1 ]
148
151
t = assert_is_instance (t , BatchTrial )
149
152
t .mark_running (no_runner_required = True )
150
153
self .exp .attach_data (
151
154
get_branin_data_batch (batch = assert_is_instance (t , BatchTrial ))
152
155
)
153
156
157
+ self ._refresh_modelbridge ()
158
+
154
159
observations = observations_from_data (
155
160
experiment = self .exp ,
156
161
data = self .exp .lookup_data (),
@@ -166,12 +171,12 @@ def test_taget_trial_index(self) -> None:
166
171
167
172
with mock .patch (
168
173
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index" ,
169
- return_value = 10 ,
174
+ return_value = 0 ,
170
175
):
171
176
t = TransformToNewSQ (
172
177
search_space = self .exp .search_space ,
173
178
observations = observations ,
174
179
modelbridge = self .modelbridge ,
175
180
)
176
181
177
- self .assertEqual (t .default_trial_idx , 10 )
182
+ self .assertEqual (t .default_trial_idx , 0 )
0 commit comments