Skip to content

Commit 68bc41c

Browse files
sdaultonfacebook-github-bot
authored andcommitted
Remove inferring reference point in Scheduler for GSS (facebook#2383)
Summary: Pull Request resolved: facebook#2383 We should be inferring the reference point in the GSS when needed. This moves the call to infer the reference point to inside of the GSS. Reviewed By: Balandat Differential Revision: D56227773 fbshipit-source-id: 6a63844fd3d083b12b4f693a6f9a66b9fd89b00b
1 parent cbb6cda commit 68bc41c

File tree

5 files changed

+85
-123
lines changed

5 files changed

+85
-123
lines changed

ax/global_stopping/strategies/improvement.py

+25-11
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
from ax.core.types import ComparisonOp
2323
from ax.global_stopping.strategies.base import BaseGlobalStoppingStrategy
2424
from ax.modelbridge.modelbridge_utils import observed_hypervolume
25-
from ax.plot.pareto_utils import get_tensor_converter_model
26-
from ax.service.utils.best_point import fill_missing_thresholds_from_nadir
25+
from ax.plot.pareto_utils import (
26+
get_tensor_converter_model,
27+
infer_reference_point_from_experiment,
28+
)
2729
from ax.utils.common.logger import get_logger
2830
from ax.utils.common.typeutils import checked_cast, not_none
2931

@@ -78,6 +80,7 @@ def __init__(
7880
self.window_size = window_size
7981
self.improvement_bar = improvement_bar
8082
self.hv_by_trial: Dict[int, float] = {}
83+
self._inferred_objective_thresholds: Optional[List[ObjectiveThreshold]] = None
8184

8285
def _should_stop_optimization(
8386
self,
@@ -104,6 +107,9 @@ def _should_stop_optimization(
104107
when computing hv of the pareto front against. This is used only in the
105108
MOO setting. If not specified, the objective thresholds on the
106109
experiment's optimization config will be used for the purpose.
110+
If no thresholds are provided, they are automatically inferred. They are
111+
only inferred once for each instance of the strategy (i.e. inferred
112+
thresholds don't update with additional data).
107113
kwargs: Unused.
108114
109115
Returns:
@@ -138,10 +144,25 @@ def _should_stop_optimization(
138144
return stop, message
139145

140146
if isinstance(experiment.optimization_config, MultiObjectiveOptimizationConfig):
147+
if objective_thresholds is None:
148+
# self._inferred_objective_thresholds is cached and only computed once.
149+
if self._inferred_objective_thresholds is None:
150+
# only infer reference point if there is data on the experiment.
151+
data = experiment.fetch_data()
152+
if not data.df.empty:
153+
# We infer the nadir reference point to be used by the GSS.
154+
self._inferred_objective_thresholds = (
155+
infer_reference_point_from_experiment(
156+
experiment=experiment, data=data
157+
)
158+
)
159+
# TODO: move this out into a separate infer_objective_thresholds
160+
# instance method or property that handles the caching.
161+
objective_thresholds = self._inferred_objective_thresholds
141162
return self._should_stop_moo(
142163
experiment=experiment,
143164
trial_to_check=trial_to_check,
144-
objective_thresholds=objective_thresholds,
165+
objective_thresholds=not_none(objective_thresholds),
145166
)
146167
else:
147168
return self._should_stop_single_objective(
@@ -152,7 +173,7 @@ def _should_stop_moo(
152173
self,
153174
experiment: Experiment,
154175
trial_to_check: int,
155-
objective_thresholds: Optional[List[ObjectiveThreshold]] = None,
176+
objective_thresholds: List[ObjectiveThreshold],
156177
) -> Tuple[bool, str]:
157178
"""
158179
This is the "should_stop_optimization" method of this class, specialized
@@ -186,13 +207,6 @@ def _should_stop_moo(
186207
data_df_reference = data_df[data_df["trial_index"] <= reference_trial_index]
187208
data_df = data_df[data_df["trial_index"] <= trial_to_check]
188209

189-
optimization_config = checked_cast(
190-
MultiObjectiveOptimizationConfig, experiment.optimization_config
191-
).clone_with_args(objective_thresholds=objective_thresholds)
192-
objective_thresholds = fill_missing_thresholds_from_nadir(
193-
experiment=experiment, optimization_config=optimization_config
194-
)
195-
196210
# Computing or retrieving HV at "window_size" iteration before
197211
if reference_trial_index in self.hv_by_trial:
198212
hv_reference = self.hv_by_trial[reference_trial_index]

ax/global_stopping/tests/test_strategies.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def test_multi_objective(self) -> None:
269269
]
270270
exp = self._create_multi_objective_experiment(metric_values=metric_values)
271271
gss = ImprovementGlobalStoppingStrategy(
272-
min_trials=3, window_size=3, improvement_bar=0.1
272+
min_trials=3, window_size=3, improvement_bar=0.3
273273
)
274274
stop, message = gss.should_stop_optimization(experiment=exp, trial_to_check=4)
275275
self.assertFalse(stop)
@@ -280,9 +280,9 @@ def test_multi_objective(self) -> None:
280280

281281
self.assertEqual(
282282
message,
283-
"The improvement in hypervolume in the past 3 trials (=0.000) is "
284-
"less than improvement_bar (=0.1) times the hypervolume at the "
285-
"start of the window (=0.055).",
283+
"The improvement in hypervolume in the past 3 trials (=0.289) is "
284+
"less than improvement_bar (=0.3) times the hypervolume at the "
285+
"start of the window (=0.104).",
286286
)
287287

288288
# Now we select a very far custom reference point against which the pareto front

ax/plot/pareto_utils.py

+55-23
Original file line numberDiff line numberDiff line change
@@ -6,24 +6,20 @@
66

77
# pyre-strict
88

9-
import copy
109
from copy import deepcopy
1110
from itertools import combinations
1211
from logging import Logger
13-
from typing import cast, Dict, List, NamedTuple, Optional, Tuple, Union
12+
from typing import Dict, List, NamedTuple, Optional, Tuple, Union
1413

1514
import numpy as np
1615
import torch
1716
from ax.core.batch_trial import BatchTrial
1817
from ax.core.data import Data
1918
from ax.core.experiment import Experiment
2019
from ax.core.metric import Metric
21-
from ax.core.objective import ScalarizedObjective
20+
from ax.core.objective import MultiObjective, ScalarizedObjective
2221
from ax.core.observation import ObservationFeatures
23-
from ax.core.optimization_config import (
24-
MultiObjectiveOptimizationConfig,
25-
OptimizationConfig,
26-
)
22+
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
2723
from ax.core.outcome_constraint import (
2824
ComparisonOp,
2925
ObjectiveThreshold,
@@ -43,6 +39,7 @@
4339
from ax.models.torch.posterior_mean import get_PosteriorMean
4440
from ax.models.torch_base import TorchModel
4541
from ax.utils.common.logger import get_logger
42+
from ax.utils.common.typeutils import checked_cast
4643
from ax.utils.stats.statstools import relativize
4744
from botorch.utils.multi_objective import is_non_dominated
4845
from botorch.utils.multi_objective.hypervolume import infer_reference_point
@@ -615,11 +612,24 @@ def infer_reference_point_from_experiment(
615612
# when calculating the Pareto front. Also, defining a multiplier to turn all
616613
# the objectives to be maximized. Note that the multiplier at this point
617614
# contains 0 for outcome_constraint metrics, but this will be dropped later.
618-
dummy_rp = copy.deepcopy(
619-
experiment.optimization_config.objective_thresholds # pyre-ignore
615+
opt_config = checked_cast(
616+
MultiObjectiveOptimizationConfig, experiment.optimization_config
620617
)
618+
inferred_rp = _get_objective_thresholds(optimization_config=opt_config)
621619
multiplier = [0] * len(objective_orders)
622-
for ot in dummy_rp:
620+
if len(opt_config.objective_thresholds) > 0:
621+
inferred_rp = deepcopy(opt_config.objective_thresholds)
622+
else:
623+
inferred_rp = []
624+
for objective in checked_cast(MultiObjective, opt_config.objective).objectives:
625+
ot = ObjectiveThreshold(
626+
metric=objective.metric,
627+
bound=0.0, # dummy value
628+
op=ComparisonOp.LEQ if objective.minimize else ComparisonOp.GEQ,
629+
relative=False,
630+
)
631+
inferred_rp.append(ot)
632+
for ot in inferred_rp:
623633
# In the following, we find the index of the objective in
624634
# `objective_orders`. If there is an objective that does not exist
625635
# in `obs_data`, a ValueError is raised.
@@ -640,12 +650,10 @@ def infer_reference_point_from_experiment(
640650
modelbridge=mb_reference,
641651
observation_features=obs_feats,
642652
observation_data=obs_data,
643-
objective_thresholds=dummy_rp,
653+
objective_thresholds=inferred_rp,
644654
use_model_predictions=False,
645655
)
646-
647656
if len(frontier_observations) == 0:
648-
opt_config = cast(OptimizationConfig, mb_reference._optimization_config)
649657
outcome_constraints = opt_config._outcome_constraints
650658
if len(outcome_constraints) == 0:
651659
raise RuntimeError(
@@ -665,10 +673,11 @@ def infer_reference_point_from_experiment(
665673
modelbridge=mb_reference,
666674
observation_features=obs_feats,
667675
observation_data=obs_data,
668-
objective_thresholds=dummy_rp,
676+
objective_thresholds=inferred_rp,
669677
use_model_predictions=False,
670678
)
671-
opt_config._outcome_constraints = outcome_constraints # restoring constraints
679+
# restoring constraints
680+
opt_config._outcome_constraints = outcome_constraints
672681

673682
# Need to reshuffle columns of `f` and `obj_w` to be consistent
674683
# with objective_orders.
@@ -698,15 +707,38 @@ def infer_reference_point_from_experiment(
698707
x for (i, x) in enumerate(objective_orders) if multiplier[i] != 0
699708
]
700709

701-
# Constructing the objective thresholds.
702-
# NOTE: This assumes that objective_thresholds is already initialized.
703-
nadir_objective_thresholds = copy.deepcopy(
704-
experiment.optimization_config.objective_thresholds
705-
)
706-
707-
for obj_threshold in nadir_objective_thresholds:
710+
for obj_threshold in inferred_rp:
708711
obj_threshold.bound = rp[
709712
objective_orders_reduced.index(obj_threshold.metric.name)
710713
].item()
714+
return inferred_rp
711715

712-
return nadir_objective_thresholds
716+
717+
def _get_objective_thresholds(
718+
optimization_config: MultiObjectiveOptimizationConfig,
719+
) -> List[ObjectiveThreshold]:
720+
"""Get objective thresholds for an optimization config.
721+
722+
This will return objective thresholds with dummy values if there are
723+
no objective thresholds on the optimization config.
724+
725+
Args:
726+
optimization_config: Optimization config.
727+
728+
Returns:
729+
List of objective thresholds.
730+
"""
731+
if optimization_config.objective_thresholds is not None:
732+
return deepcopy(optimization_config.objective_thresholds)
733+
objective_thresholds = []
734+
for objective in checked_cast(
735+
MultiObjective, optimization_config.objective
736+
).objectives:
737+
ot = ObjectiveThreshold(
738+
metric=objective.metric,
739+
bound=0.0, # dummy value
740+
op=ComparisonOp.LEQ if objective.minimize else ComparisonOp.GEQ,
741+
relative=False,
742+
)
743+
objective_thresholds.append(ot)
744+
return objective_thresholds

ax/service/scheduler.py

+1-24
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
MultiObjectiveOptimizationConfig,
4242
OptimizationConfig,
4343
)
44-
from ax.core.outcome_constraint import ObjectiveThreshold
4544
from ax.core.runner import Runner
4645
from ax.core.types import TModelPredictArm, TParameterization
4746

@@ -59,7 +58,6 @@
5958
)
6059
from ax.modelbridge.base import ModelBridge
6160
from ax.modelbridge.generation_strategy import GenerationStrategy
62-
from ax.plot.pareto_utils import infer_reference_point_from_experiment
6361
from ax.service.utils.best_point_mixin import BestPointMixin
6462
from ax.service.utils.scheduler_options import SchedulerOptions, TrialType
6563
from ax.service.utils.with_db_settings_base import DBSettings, WithDBSettingsBase
@@ -202,9 +200,6 @@ class Scheduler(WithDBSettingsBase, BestPointMixin):
202200
# applications where the user wants to run the optimization loop to exhaust
203201
# the declared number of trials.
204202
__ignore_global_stopping_strategy: bool = False
205-
# In MOO cases, the following will be populated by an inferred reference point
206-
# for pareto front after a certain number of completed trials.
207-
__inferred_reference_point: Optional[List[ObjectiveThreshold]] = None
208203
# Default kwargs passed when fetching data if not overridden on `SchedulerOptions`
209204
DEFAULT_FETCH_KWARGS = {
210205
"overwrite_existing_data": True,
@@ -444,26 +439,8 @@ def completion_criterion(self) -> Tuple[bool, str]:
444439
and self.options.global_stopping_strategy is not None
445440
):
446441
gss = not_none(self.options.global_stopping_strategy)
447-
if (
448-
self.experiment.is_moo_problem
449-
and self.__inferred_reference_point is None
450-
and len(self.experiment.trials_by_status[TrialStatus.COMPLETED])
451-
>= gss.min_trials
452-
):
453-
# only infer reference point if there is data on the experiment.
454-
data = self.experiment.fetch_data()
455-
if not data.df.empty:
456-
# We infer the nadir reference point to be used by the GSS.
457-
self.__inferred_reference_point = (
458-
infer_reference_point_from_experiment(
459-
self.experiment,
460-
data=data,
461-
)
462-
)
463-
464442
stop_optimization, global_stopping_msg = gss.should_stop_optimization(
465-
experiment=self.experiment,
466-
objective_thresholds=self.__inferred_reference_point,
443+
experiment=self.experiment
467444
)
468445
if stop_optimization:
469446
return True, global_stopping_msg

ax/service/tests/scheduler_test_utils.py

-61
Original file line numberDiff line numberDiff line change
@@ -793,67 +793,6 @@ def test_run_preattached_trials_only(self) -> None:
793793
all(t.completed_successfully for t in scheduler.experiment.trials.values())
794794
)
795795

796-
def test_inferring_reference_point(self) -> None:
797-
init_test_engine_and_session_factory(force_init=True)
798-
experiment = get_branin_experiment_with_multi_objective()
799-
experiment.runner = self.runner
800-
gs = self._get_generation_strategy_strategy_for_test(
801-
experiment=experiment,
802-
generation_strategy=self.sobol_GS_no_parallelism,
803-
)
804-
805-
scheduler = Scheduler(
806-
experiment=experiment,
807-
generation_strategy=gs,
808-
options=SchedulerOptions(
809-
# Stops the optimization after 5 trials.
810-
global_stopping_strategy=DummyGlobalStoppingStrategy(
811-
min_trials=2, trial_to_stop=5
812-
),
813-
),
814-
db_settings=self.db_settings,
815-
)
816-
817-
with patch(
818-
"ax.service.scheduler.infer_reference_point_from_experiment"
819-
) as mock_infer_rp:
820-
scheduler.run_n_trials(max_trials=10)
821-
mock_infer_rp.assert_called_once()
822-
823-
def test_inferring_reference_point_no_data(self) -> None:
824-
init_test_engine_and_session_factory(force_init=True)
825-
experiment = get_branin_experiment_with_multi_objective()
826-
experiment.runner = self.runner
827-
gs = self._get_generation_strategy_strategy_for_test(
828-
experiment=experiment,
829-
generation_strategy=self.sobol_GS_no_parallelism,
830-
)
831-
832-
scheduler = Scheduler(
833-
experiment=experiment,
834-
generation_strategy=gs,
835-
options=SchedulerOptions(
836-
# Stops the optimization after 5 trials.
837-
global_stopping_strategy=DummyGlobalStoppingStrategy(
838-
min_trials=0,
839-
trial_to_stop=5,
840-
),
841-
),
842-
db_settings=self.db_settings,
843-
)
844-
empty_data = Data(
845-
df=pd.DataFrame(
846-
columns=["metric_name", "arm_name", "trial_index", "mean", "sem"]
847-
)
848-
)
849-
with patch(
850-
"ax.service.scheduler.infer_reference_point_from_experiment"
851-
) as mock_infer_rp, patch.object(
852-
scheduler.experiment, "fetch_data", return_value=empty_data
853-
):
854-
scheduler.run_n_trials(max_trials=1)
855-
mock_infer_rp.assert_not_called()
856-
857796
def test_global_stopping(self) -> None:
858797
gs = self._get_generation_strategy_strategy_for_test(
859798
experiment=self.branin_experiment,

0 commit comments

Comments
 (0)