
Commit f27b4b8

esantorella authored and facebook-github-bot committed
When there are no completed or feasible trials, return the first arm as the best point (facebook#3352)
Summary:
Pull Request resolved: facebook#3352

Context: When there are no completed or feasible trials, `BestPointMixin._get_best_trial` will return None, even if there are trials that are running (potentially with available data) or early-stopped. This is problematic when we need to construct an inference trace with early-stopped trials, since it is plausible that all trials will either violate constraints or have stopped early (at least in benchmarks and in unit tests, maybe not in reality). I discovered this was an issue with an existing unit test while working on some related logic -- I think we've been silently sidestepping it until now.

Even broader context: Yes, the best-point utilities are flawed and it would be better to fix them. I think the right fix would be an exception, so there would still be a need for a change in benchmarking logic if we want to ensure that some point is always recommended.

This PR:
* Has `BenchmarkMethod.get_best_parameters` return the first arm from the most recently created trial if no trials are completed and satisfy constraints

Reviewed By: saitcakmak

Differential Revision: D69488839

fbshipit-source-id: abb136b9add65510c740780fb321352963ba0d28
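For intuition, here is a minimal sketch of the fallback described above. `fallback_best_parameters` is a hypothetical helper name, not part of this PR, but its body mirrors the `_get_first_parameterization_from_last_trial` closure introduced in the diff below:

```python
from ax.core.experiment import Experiment
from ax.core.types import TParameterization


def fallback_best_parameters(experiment: Experiment) -> TParameterization:
    """Hypothetical standalone version of the fallback in this PR:
    recommend the first arm of the most recently created trial."""
    if len(experiment.trials) == 0:
        # Mirrors the new ValueError raised in BenchmarkMethod.get_best_parameters.
        raise ValueError("Cannot identify a best point if experiment has no trials.")
    # experiment.trials is a dict keyed by trial index, so max() over the keys
    # picks out the most recently created trial.
    last_trial_index = max(experiment.trials)
    return experiment.trials[last_trial_index].arms[0].parameters
```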
1 parent 51a90db commit f27b4b8

File tree

2 files changed: +47 −6 lines

ax/benchmark/benchmark_method.py

14 additions & 3 deletions

@@ -12,6 +12,7 @@
     MultiObjectiveOptimizationConfig,
     OptimizationConfig,
 )
+from ax.core.trial_status import TrialStatus
 from ax.core.types import TParameterization
 from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy

@@ -101,9 +102,19 @@ def get_best_parameters(
             raise NotImplementedError(
                 f"Currently only n_points=1 is supported. Got {n_points=}."
             )
+        if len(experiment.trials) == 0:
+            raise ValueError(
+                "Cannot identify a best point if experiment has no trials."
+            )
+
+        def _get_first_parameterization_from_last_trial() -> TParameterization:
+            return experiment.trials[max(experiment.trials)].arms[0].parameters

         # SOO, n=1 case.
         # Note: This has the same effect as Scheduler.get_best_parameters
+        if len(experiment.trials_by_status[TrialStatus.COMPLETED]) == 0:
+            return [_get_first_parameterization_from_last_trial()]
+
         result = BestPointMixin._get_best_trial(
             experiment=experiment,
             generation_strategy=self.generation_strategy,
@@ -113,7 +124,7 @@ def get_best_parameters(
         if result is None:
             # This can happen if no points are predicted to satisfy all outcome
             # constraints.
-            return []
-
-        i, params, prediction = none_throws(result)
+            params = _get_first_parameterization_from_last_trial()
+        else:
+            i, params, prediction = none_throws(result)
         return [params]
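To summarize the new control flow, a sketch of the caller's view (not code from the PR; `method` and `experiment` stand in for a configured `BenchmarkMethod` and an Ax `Experiment`, as in the tests below):

```python
from pyre_extensions import none_throws

# 1) No trials at all -> ValueError ("Cannot identify a best point ...").
# 2) No COMPLETED trials (e.g. all running or early-stopped) -> the first arm
#    of the most recently created trial is returned.
# 3) _get_best_trial returns None (no point predicted to satisfy all outcome
#    constraints) -> same fallback instead of the old empty list.
best = method.get_best_parameters(
    experiment=experiment,
    optimization_config=none_throws(experiment.optimization_config),
    n_points=1,  # only n_points=1 is currently supported
)
assert len(best) == 1  # callers can now rely on a non-empty recommendation
```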

ax/benchmark/tests/test_benchmark_method.py

33 additions & 3 deletions

@@ -13,7 +13,9 @@
 )
 from ax.benchmark.methods.sobol import get_sobol_generation_strategy
 from ax.core.experiment import Experiment
+from ax.modelbridge.factory import get_sobol
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_experiment_with_observations
 from pyre_extensions import none_throws


@@ -71,8 +73,36 @@ def test_get_best_parameters(self) -> None:
             experiment=experiment, optimization_config=soo_config, n_points=2
         )

-        with self.subTest("Empty experiment"):
-            result = method.get_best_parameters(
+        with self.subTest("Empty experiment"), self.assertRaisesRegex(
+            ValueError, "Cannot identify a best point if experiment has no trials"
+        ):
+            method.get_best_parameters(
                 experiment=experiment, optimization_config=soo_config, n_points=1
             )
-        self.assertEqual(result, [])
+
+        with self.subTest("All constraints violated"):
+            experiment = get_experiment_with_observations(
+                observations=[[1, -1], [2, -1]],
+                constrained=True,
+            )
+            best_point = method.get_best_parameters(
+                n_points=1,
+                experiment=experiment,
+                optimization_config=none_throws(experiment.optimization_config),
+            )
+            self.assertEqual(len(best_point), 1)
+            self.assertEqual(best_point[0], experiment.trials[1].arms[0].parameters)
+
+        with self.subTest("No completed trials"):
+            experiment = get_experiment_with_observations(observations=[])
+            sobol_generator = get_sobol(search_space=experiment.search_space)
+            for _ in range(3):
+                trial = experiment.new_trial(generator_run=sobol_generator.gen(n=1))
+                trial.run()
+            best_point = method.get_best_parameters(
+                n_points=1,
+                experiment=experiment,
+                optimization_config=none_throws(experiment.optimization_config),
+            )
+            self.assertEqual(len(best_point), 1)
+            self.assertEqual(best_point[0], experiment.trials[2].arms[0].parameters)
