Skip to content

Commit e29e8ba

Browse files
sdaultonfacebook-github-bot
authored andcommitted
use most recent trial if no SQ data for target trial in TransformToNewSQ (#3225)
Summary: Pull Request resolved: #3225 see title. This ensures that status_quo_data_by_trial contains the target trial index by default. Reviewed By: danielcohenlive Differential Revision: D67875128 fbshipit-source-id: 89a48f2812b9ae84a7ac3496de9f89961adce178
1 parent b2d01c1 commit e29e8ba

File tree

2 files changed

+31
-4
lines changed

2 files changed

+31
-4
lines changed

ax/modelbridge/transforms/tests/test_transform_to_new_sq.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,14 @@ def setUp(self) -> None:
7070
t.mark_completed()
7171
self.data = self.exp.fetch_data()
7272

73+
self._refresh_modelbridge()
74+
75+
def _refresh_modelbridge(self) -> None:
7376
self.modelbridge = ModelBridge(
7477
search_space=self.exp.search_space,
7578
model=Model(),
7679
experiment=self.exp,
77-
data=self.data,
80+
data=self.exp.lookup_data(),
7881
status_quo_name="status_quo",
7982
)
8083

@@ -141,16 +144,18 @@ def test_single_trial_is_not_transformed(self) -> None:
141144
obs2 = tf.transform_observations(obs)
142145
self.assertEqual(obs, obs2)
143146

144-
def test_taget_trial_index(self) -> None:
147+
def test_target_trial_index(self) -> None:
145148
sobol = get_sobol(search_space=self.exp.search_space)
146-
self.exp.new_batch_trial(generator_run=sobol.gen(2))
149+
self.exp.new_batch_trial(generator_run=sobol.gen(2), optimize_for_power=True)
147150
t = self.exp.trials[1]
148151
t = assert_is_instance(t, BatchTrial)
149152
t.mark_running(no_runner_required=True)
150153
self.exp.attach_data(
151154
get_branin_data_batch(batch=assert_is_instance(t, BatchTrial))
152155
)
153156

157+
self._refresh_modelbridge()
158+
154159
observations = observations_from_data(
155160
experiment=self.exp,
156161
data=self.exp.lookup_data(),
@@ -164,6 +169,18 @@ def test_taget_trial_index(self) -> None:
164169

165170
self.assertEqual(t.default_trial_idx, 1)
166171

172+
with mock.patch(
173+
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
174+
return_value=0,
175+
):
176+
t = TransformToNewSQ(
177+
search_space=self.exp.search_space,
178+
observations=observations,
179+
modelbridge=self.modelbridge,
180+
)
181+
182+
self.assertEqual(t.default_trial_idx, 0)
183+
# test falling back to latest trial with SQ data
167184
with mock.patch(
168185
"ax.modelbridge.transforms.transform_to_new_sq.get_target_trial_index",
169186
return_value=10,
@@ -174,4 +191,4 @@ def test_taget_trial_index(self) -> None:
174191
modelbridge=self.modelbridge,
175192
)
176193

177-
self.assertEqual(t.default_trial_idx, 10)
194+
self.assertEqual(t.default_trial_idx, 1)

ax/modelbridge/transforms/transform_to_new_sq.py

+10
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from __future__ import annotations
1010

1111
from collections.abc import Callable
12+
from logging import Logger
1213

1314
from math import sqrt
1415
from typing import TYPE_CHECKING
@@ -22,12 +23,14 @@
2223
from ax.core.utils import get_target_trial_index
2324
from ax.modelbridge.transforms.relativize import BaseRelativize, get_metric_index
2425
from ax.models.types import TConfig
26+
from ax.utils.common.logger import get_logger
2527
from ax.utils.stats.statstools import relativize, unrelativize
2628
from pyre_extensions import assert_is_instance, none_throws
2729

2830
if TYPE_CHECKING:
2931
# import as module to make sphinx-autodoc-typehints happy
3032
from ax import modelbridge as modelbridge_module # noqa F401
33+
logger: Logger = get_logger(__name__)
3134

3235

3336
class TransformToNewSQ(BaseRelativize):
@@ -73,6 +76,13 @@ def __init__(
7376
target_trial_index = get_target_trial_index(
7477
experiment=modelbridge._experiment
7578
)
79+
trials_indices_with_sq_data = self.status_quo_data_by_trial.keys()
80+
if target_trial_index not in trials_indices_with_sq_data:
81+
target_trial_index = max(trials_indices_with_sq_data)
82+
logger.info(
83+
"No SQ data for target trial. Failing back to "
84+
f"{target_trial_index}."
85+
)
7686

7787
if target_trial_index is not None:
7888
self.default_trial_idx: int = assert_is_instance(

0 commit comments

Comments
 (0)