When there are no completed or feasible trials, return the first arm as the best point (facebook#3352)

Summary:
Pull Request resolved: facebook#3352

Context:

When there are no completed or feasible trials, `BestPointMixin._get_best_trial` will return None, even if there are trials that are running (potentially with available data) or early-stopped. This is problematic when we need to construct an inference trace with early-stopped trials, since it is plausible that all trials will either violate constraints or have stopped early (at least in benchmarks and in unit tests, maybe not in reality).
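For concreteness, here is a hedged, plain-Python sketch of the failure mode; the helper `append_to_inference_trace` is illustrative and not part of Ax or of this diff. An inference trace needs one recommended parameterization per step, so an empty recommendation leaves a hole in the trace:

```python
from ax.core.types import TParameterization


def append_to_inference_trace(
    trace: list[TParameterization],
    best_points_at_step: list[TParameterization],
) -> None:
    """Record the point that would be recommended at this step of a benchmark."""
    if not best_points_at_step:
        # Before this change, this branch was hit whenever no trial was both
        # COMPLETED and predicted to satisfy the outcome constraints.
        raise ValueError("No recommendation available for this step.")
    trace.append(best_points_at_step[0])
```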

I discovered this was an issue with an existing unit test while working on some related logic -- I think we've been silently sidestepping it until now.

Even broader context: Yes, the best-point utilities are flawed and it would be better to fix them. I think the right fix would be to raise an exception rather than silently return None, so there would still be a need for a change in the benchmarking logic if we want to ensure that some point is always recommended.

This PR:
* Has `BenchmarkMethod.get_best_parameters` return the first arm from the most recently created trial when no trials have completed or no point is predicted to satisfy all outcome constraints (see the sketch below)
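As a rough illustration (not part of this diff), the fallback can be composed from helpers that already appear in this commit: with three RUNNING trials and nothing COMPLETED, the recommended point is the first arm of the newest trial.

```python
from ax.core.trial_status import TrialStatus
from ax.modelbridge.factory import get_sobol
from ax.utils.testing.core_stubs import get_experiment_with_observations

# An experiment with a search space but no completed trials.
experiment = get_experiment_with_observations(observations=[])
sobol_generator = get_sobol(search_space=experiment.search_space)
for _ in range(3):
    trial = experiment.new_trial(generator_run=sobol_generator.gen(n=1))
    trial.run()

assert len(experiment.trials_by_status[TrialStatus.COMPLETED]) == 0
# The fallback used by get_best_parameters: first arm of the newest trial.
fallback = experiment.trials[max(experiment.trials)].arms[0].parameters
print(fallback)
```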

Reviewed By: saitcakmak

Differential Revision: D69488839

fbshipit-source-id: abb136b9add65510c740780fb321352963ba0d28
esantorella authored and facebook-github-bot committed Feb 14, 2025
1 parent 51a90db commit f27b4b8
Showing 2 changed files with 47 additions and 6 deletions.
17 changes: 14 additions & 3 deletions ax/benchmark/benchmark_method.py
@@ -12,6 +12,7 @@
     MultiObjectiveOptimizationConfig,
     OptimizationConfig,
 )
+from ax.core.trial_status import TrialStatus
 from ax.core.types import TParameterization
 from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy

@@ -101,9 +102,19 @@ def get_best_parameters(
         raise NotImplementedError(
             f"Currently only n_points=1 is supported. Got {n_points=}."
         )
+        if len(experiment.trials) == 0:
+            raise ValueError(
+                "Cannot identify a best point if experiment has no trials."
+            )
+
+        def _get_first_parameterization_from_last_trial() -> TParameterization:
+            return experiment.trials[max(experiment.trials)].arms[0].parameters

         # SOO, n=1 case.
         # Note: This has the same effect as Scheduler.get_best_parameters
+        if len(experiment.trials_by_status[TrialStatus.COMPLETED]) == 0:
+            return [_get_first_parameterization_from_last_trial()]
+
         result = BestPointMixin._get_best_trial(
             experiment=experiment,
             generation_strategy=self.generation_strategy,
@@ -113,7 +124,7 @@ def get_best_parameters(
         if result is None:
             # This can happen if no points are predicted to satisfy all outcome
             # constraints.
-            return []
-
-        i, params, prediction = none_throws(result)
+            params = _get_first_parameterization_from_last_trial()
+        else:
+            i, params, prediction = none_throws(result)
         return [params]
36 changes: 33 additions & 3 deletions ax/benchmark/tests/test_benchmark_method.py
@@ -13,7 +13,9 @@
 )
 from ax.benchmark.methods.sobol import get_sobol_generation_strategy
 from ax.core.experiment import Experiment
+from ax.modelbridge.factory import get_sobol
 from ax.utils.common.testutils import TestCase
+from ax.utils.testing.core_stubs import get_experiment_with_observations
 from pyre_extensions import none_throws


@@ -71,8 +73,36 @@ def test_get_best_parameters(self) -> None:
                 experiment=experiment, optimization_config=soo_config, n_points=2
             )
 
-        with self.subTest("Empty experiment"):
-            result = method.get_best_parameters(
+        with self.subTest("Empty experiment"), self.assertRaisesRegex(
+            ValueError, "Cannot identify a best point if experiment has no trials"
+        ):
+            method.get_best_parameters(
                 experiment=experiment, optimization_config=soo_config, n_points=1
             )
-            self.assertEqual(result, [])
+
+        with self.subTest("All constraints violated"):
+            experiment = get_experiment_with_observations(
+                observations=[[1, -1], [2, -1]],
+                constrained=True,
+            )
+            best_point = method.get_best_parameters(
+                n_points=1,
+                experiment=experiment,
+                optimization_config=none_throws(experiment.optimization_config),
+            )
+            self.assertEqual(len(best_point), 1)
+            self.assertEqual(best_point[0], experiment.trials[1].arms[0].parameters)
+
+        with self.subTest("No completed trials"):
+            experiment = get_experiment_with_observations(observations=[])
+            sobol_generator = get_sobol(search_space=experiment.search_space)
+            for _ in range(3):
+                trial = experiment.new_trial(generator_run=sobol_generator.gen(n=1))
+                trial.run()
+            best_point = method.get_best_parameters(
+                n_points=1,
+                experiment=experiment,
+                optimization_config=none_throws(experiment.optimization_config),
+            )
+            self.assertEqual(len(best_point), 1)
+            self.assertEqual(best_point[0], experiment.trials[2].arms[0].parameters)
