Skip to content

Commit 853b340

Browse files
saitcakmak authored and facebook-github-bot committed
Make objective direction checks more strict (#2382)
Summary: Pull Request resolved: #2382 With this change, at least one of `minimize` or `lower_is_better` must be specified. If both are specified, they must match. Added a json storage helper & updated SQA storage helpers for deserializing previously saved objectives in a backwards compatible manner, resolving the conflicts in favor of `minimize`. Reviewed By: mpolson64 Differential Revision: D56315542 fbshipit-source-id: 5936fc3ecf5ee88ab80011ed095f98e097d16c21
1 parent 6c55c0f commit 853b340

20 files changed

+822
-838
lines changed

ax/core/objective.py

+17-26
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any, Iterable, List, Optional, Tuple
1414

1515
from ax.core.metric import Metric
16+
from ax.exceptions.core import UserInputError
1617
from ax.utils.common.base import SortableBase
1718
from ax.utils.common.logger import get_logger
1819
from ax.utils.common.typeutils import not_none
@@ -34,36 +35,27 @@ def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None:
3435
metric: The metric to be optimized.
3536
minimize: If True, minimize metric. If None, will be set based on the
3637
`lower_is_better` property of the metric (if that is not specified,
37-
will raise a DeprecationWarning).
38+
will raise a `UserInputError`).
3839
3940
"""
4041
lower_is_better = metric.lower_is_better
4142
if minimize is None:
4243
if lower_is_better is None:
43-
warnings.warn(
44-
f"Defaulting to `minimize=False` for metric {metric.name} not "
45-
+ "specifying `lower_is_better` property. This is a wild guess. "
46-
+ "Specify either `lower_is_better` on the metric, or specify "
47-
+ "`minimize` explicitly. This will become an error in the future.",
48-
DeprecationWarning,
44+
raise UserInputError(
45+
f"Metric {metric.name} does not specify `lower_is_better` "
46+
"and `minimize` is not specified. At least one of these "
47+
"must be specified."
4948
)
50-
minimize = False
5149
else:
5250
minimize = lower_is_better
53-
if lower_is_better is not None:
54-
if lower_is_better and not minimize:
55-
warnings.warn(
56-
f"Attempting to maximize metric {metric.name} with property "
57-
"`lower_is_better=True`."
58-
)
59-
elif not lower_is_better and minimize:
60-
warnings.warn(
61-
f"Attempting to minimize metric {metric.name} with property "
62-
"`lower_is_better=False`."
63-
)
64-
self._metric = metric
65-
# pyre-fixme[4]: Attribute must be annotated.
66-
self.minimize = not_none(minimize)
51+
elif lower_is_better is not None and lower_is_better != minimize:
52+
raise UserInputError(
53+
f"Metric {metric.name} specifies {lower_is_better=}, "
54+
"which doesn't match the specified optimization direction "
55+
f"{minimize=}."
56+
)
57+
self._metric: Metric = metric
58+
self.minimize: bool = not_none(minimize)
6759

6860
@property
6961
def metric(self) -> Metric:
@@ -130,18 +122,17 @@ def __init__(
130122
"as input to `MultiObjective` constructor."
131123
)
132124
metrics = extra_kwargs["metrics"]
133-
minimize = extra_kwargs.get("minimize", False)
125+
minimize = extra_kwargs.get("minimize", None)
134126
warnings.warn(
135127
"Passing `metrics` and `minimize` as input to the `MultiObjective` "
136128
"constructor will soon be deprecated. Instead, pass a list of "
137129
"`objectives`. This will become an error in the future.",
138130
DeprecationWarning,
131+
stacklevel=2,
139132
)
140133
objectives = []
141134
for metric in metrics:
142-
lower_is_better = metric.lower_is_better or False
143-
_minimize = not lower_is_better if minimize else lower_is_better
144-
objectives.append(Objective(metric=metric, minimize=_minimize))
135+
objectives.append(Objective(metric=metric, minimize=minimize))
145136

146137
# pyre-fixme[4]: Attribute must be annotated.
147138
self._objectives = not_none(objectives)

ax/core/tests/test_objective.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ax.core.metric import Metric
1010
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
11+
from ax.exceptions.core import UserInputError
1112
from ax.utils.common.testutils import TestCase
1213

1314

@@ -20,7 +21,7 @@ def setUp(self) -> None:
2021
"m3": Metric(name="m3", lower_is_better=False),
2122
}
2223
self.objectives = {
23-
"o1": Objective(metric=self.metrics["m1"]),
24+
"o1": Objective(metric=self.metrics["m1"], minimize=True),
2425
"o2": Objective(metric=self.metrics["m2"], minimize=True),
2526
"o3": Objective(metric=self.metrics["m3"], minimize=False),
2627
}
@@ -37,6 +38,12 @@ def setUp(self) -> None:
3738
)
3839

3940
def test_Init(self) -> None:
41+
with self.assertRaisesRegex(UserInputError, "does not specify"):
42+
Objective(metric=self.metrics["m1"]),
43+
with self.assertRaisesRegex(
44+
UserInputError, "doesn't match the specified optimization direction"
45+
):
46+
Objective(metric=self.metrics["m2"], minimize=False)
4047
with self.assertRaises(ValueError):
4148
ScalarizedObjective(
4249
metrics=[self.metrics["m1"], self.metrics["m2"]], weights=[1.0]
@@ -51,14 +58,6 @@ def test_Init(self) -> None:
5158
metrics=[self.metrics["m1"], self.metrics["m2"]],
5259
minimize=False,
5360
)
54-
with self.assertWarnsRegex(
55-
DeprecationWarning, "Defaulting to `minimize=False`"
56-
):
57-
Objective(metric=self.metrics["m1"])
58-
with self.assertWarnsRegex(UserWarning, "Attempting to maximize"):
59-
Objective(Metric(name="m4", lower_is_better=True), minimize=False)
60-
with self.assertWarnsRegex(UserWarning, "Attempting to minimize"):
61-
Objective(Metric(name="m4", lower_is_better=False), minimize=True)
6261
self.assertEqual(
6362
self.objective.get_unconstrainable_metrics(), [self.metrics["m1"]]
6463
)
@@ -70,15 +69,15 @@ def test_MultiObjective(self) -> None:
7069

7170
self.assertEqual(self.multi_objective.metrics, list(self.metrics.values()))
7271
minimizes = [obj.minimize for obj in self.multi_objective.objectives]
73-
self.assertEqual(minimizes, [False, True, False])
72+
self.assertEqual(minimizes, [True, True, False])
7473
weights = [mw[1] for mw in self.multi_objective.objective_weights]
7574
self.assertEqual(weights, [1.0, 1.0, 1.0])
7675
self.assertEqual(self.multi_objective.clone(), self.multi_objective)
7776
self.assertEqual(
7877
str(self.multi_objective),
7978
(
8079
"MultiObjective(objectives="
81-
'[Objective(metric_name="m1", minimize=False), '
80+
'[Objective(metric_name="m1", minimize=True), '
8281
'Objective(metric_name="m2", minimize=True), '
8382
'Objective(metric_name="m3", minimize=False)])'
8483
),
@@ -89,19 +88,26 @@ def test_MultiObjective(self) -> None:
8988
)
9089

9190
def test_MultiObjectiveBackwardsCompatibility(self) -> None:
92-
multi_objective = MultiObjective(
93-
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]]
94-
)
91+
metrics = [
92+
Metric(name="m1", lower_is_better=False),
93+
self.metrics["m2"],
94+
self.metrics["m3"],
95+
]
96+
multi_objective = MultiObjective(metrics=metrics)
9597
minimizes = [obj.minimize for obj in multi_objective.objectives]
96-
self.assertEqual(multi_objective.metrics, list(self.metrics.values()))
98+
self.assertEqual(multi_objective.metrics, metrics)
9799
self.assertEqual(minimizes, [False, True, False])
98100

99101
multi_objective_min = MultiObjective(
100-
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]],
102+
metrics=[
103+
Metric(name="m1"),
104+
Metric(name="m2"),
105+
Metric(name="m3", lower_is_better=True),
106+
],
101107
minimize=True,
102108
)
103109
minimizes = [obj.minimize for obj in multi_objective_min.objectives]
104-
self.assertEqual(minimizes, [True, False, True])
110+
self.assertEqual(minimizes, [True, True, True])
105111

106112
def test_ScalarizedObjective(self) -> None:
107113
with self.assertRaises(NotImplementedError):

ax/core/tests/test_optimization_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def setUp(self) -> None:
279279
"o2": Objective(metric=self.metrics["m2"], minimize=False),
280280
"o3": Objective(metric=self.metrics["m3"], minimize=False),
281281
}
282-
self.objective = Objective(metric=self.metrics["m1"], minimize=False)
282+
self.objective = Objective(metric=self.metrics["m1"], minimize=True)
283283
self.multi_objective = MultiObjective(
284284
objectives=[self.objectives["o1"], self.objectives["o2"]]
285285
)

ax/core/tests/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def setUp(self) -> None:
158158
self.data = Data(df=self.df)
159159

160160
self.optimization_config = OptimizationConfig(
161-
objective=Objective(metric=Metric(name="a")),
161+
objective=Objective(metric=Metric(name="a"), minimize=False),
162162
outcome_constraints=[
163163
OutcomeConstraint(
164164
metric=Metric(name="b"),

ax/modelbridge/tests/test_base_modelbridge.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_ModelBridge(
156156
observation_features=[get_observation1trans().features], weights=[2]
157157
),
158158
)
159-
oc = OptimizationConfig(objective=Objective(metric=Metric(name="test_metric")))
159+
oc = get_optimization_config_no_constraints()
160160
modelbridge._set_kwargs_to_save(
161161
model_key="TestModel", model_kwargs={}, bridge_kwargs={}
162162
)
@@ -322,7 +322,7 @@ def warn_and_return_mock_obs(
322322
fit_tracking_metrics=False,
323323
)
324324
new_oc = OptimizationConfig(
325-
objective=Objective(metric=Metric(name="test_metric2"))
325+
objective=Objective(metric=Metric(name="test_metric2"), minimize=False),
326326
)
327327
with self.assertRaisesRegex(UnsupportedError, "fit_tracking_metrics"):
328328
modelbridge.gen(n=1, optimization_config=new_oc)

ax/modelbridge/tests/test_cross_validation.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
345345

346346
# Test single objective
347347
optimization_config = OptimizationConfig(
348-
objective=Objective(metric=Metric("a"))
348+
objective=Objective(metric=Metric("a"), minimize=True)
349349
)
350350
has_good_fit = has_good_opt_config_model_fit(
351351
optimization_config=optimization_config,
@@ -355,7 +355,12 @@ def test_HasGoodOptConfigModelFit(self) -> None:
355355

356356
# Test multi objective
357357
optimization_config = MultiObjectiveOptimizationConfig(
358-
objective=MultiObjective(metrics=[Metric("a"), Metric("b")])
358+
objective=MultiObjective(
359+
objectives=[
360+
Objective(Metric("a"), minimize=False),
361+
Objective(Metric("b"), minimize=False),
362+
]
363+
)
359364
)
360365
has_good_fit = has_good_opt_config_model_fit(
361366
optimization_config=optimization_config,
@@ -365,7 +370,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
365370

366371
# Test constraints
367372
optimization_config = OptimizationConfig(
368-
objective=Objective(metric=Metric("a")),
373+
objective=Objective(metric=Metric("a"), minimize=False),
369374
outcome_constraints=[
370375
OutcomeConstraint(metric=Metric("b"), op=ComparisonOp.GEQ, bound=0.1)
371376
],

ax/modelbridge/tests/test_torch_modelbridge.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
get_branin_experiment,
4646
get_branin_search_space,
4747
get_experiment_with_observations,
48+
get_optimization_config_no_constraints,
4849
get_search_space_for_range_value,
4950
)
5051
from ax.utils.testing.mock import fast_botorch_optimize
@@ -363,9 +364,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
363364
observation_features=[
364365
ObservationFeatures(parameters={"x": 1.0, "y": 2.0})
365366
],
366-
optimization_config=OptimizationConfig(
367-
objective=Objective(metric=Metric(name="test_metric"))
368-
),
367+
optimization_config=get_optimization_config_no_constraints(),
369368
)
370369

371370
self.assertEqual(acqf_vals, [5.0])
@@ -392,9 +391,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
392391
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
393392
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
394393
],
395-
optimization_config=OptimizationConfig(
396-
objective=Objective(metric=Metric(name="test_metric"))
397-
),
394+
optimization_config=get_optimization_config_no_constraints(),
398395
)
399396
t.transform_observation_features.assert_any_call(
400397
[ObservationFeatures(parameters={"x": 1.0, "y": 2.0})],
@@ -418,9 +415,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
418415
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
419416
]
420417
],
421-
optimization_config=OptimizationConfig(
422-
objective=Objective(metric=Metric(name="test_metric"))
423-
),
418+
optimization_config=get_optimization_config_no_constraints(),
424419
)
425420
t.transform_observation_features.assert_any_call(
426421
[

ax/modelbridge/tests/test_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,9 @@ def test_extract_outcome_constraints(self) -> None:
118118
def test_extract_objective_thresholds(self) -> None:
119119
outcomes = ["m1", "m2", "m3", "m4"]
120120
objective = MultiObjective(
121-
objectives=[Objective(metric=Metric(name)) for name in outcomes[:3]]
121+
objectives=[
122+
Objective(metric=Metric(name), minimize=False) for name in outcomes[:3]
123+
]
122124
)
123125
objective_thresholds = [
124126
ObjectiveThreshold(
@@ -160,7 +162,7 @@ def test_extract_objective_thresholds(self) -> None:
160162
self.assertTrue(np.isnan(obj_t[-2:]).all())
161163

162164
# Fails if a threshold does not have a corresponding metric.
163-
objective2 = Objective(Metric("m1"))
165+
objective2 = Objective(Metric("m1"), minimize=False)
164166
with self.assertRaisesRegex(ValueError, "corresponding metrics"):
165167
extract_objective_thresholds(
166168
objective_thresholds=objective_thresholds,

ax/modelbridge/transforms/tests/test_derelativize_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_DerelativizeTransform(
103103
)
104104

105105
# Test with no relative constraints
106-
objective = Objective(Metric("c"))
106+
objective = Objective(Metric("c"), minimize=True)
107107
oc = OptimizationConfig(
108108
objective=objective,
109109
outcome_constraints=[
@@ -301,7 +301,7 @@ def test_Errors(self) -> None:
301301
observations=[],
302302
)
303303
oc = OptimizationConfig(
304-
objective=Objective(Metric("c")),
304+
objective=Objective(Metric("c"), minimize=False),
305305
outcome_constraints=[
306306
OutcomeConstraint(Metric("a"), ComparisonOp.LEQ, bound=2, relative=True)
307307
],

ax/modelbridge/transforms/tests/test_winsorize_transform.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ def test_relative_constraints(
578578
RangeParameter("y", ParameterType.FLOAT, 0, 20),
579579
]
580580
)
581-
objective = Objective(Metric("c"))
581+
objective = Objective(Metric("c"), minimize=False)
582582

583583
# Test with relative constraint, in-design status quo
584584
oc = OptimizationConfig(

ax/service/tests/scheduler_test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def setUp(self) -> None:
321321
self.branin_experiment_no_impl_runner_or_metrics = Experiment(
322322
search_space=get_branin_search_space(),
323323
optimization_config=OptimizationConfig(
324-
objective=Objective(metric=Metric(name="branin"))
324+
objective=Objective(metric=Metric(name="branin"), minimize=False)
325325
),
326326
name="branin_experiment_no_impl_runner_or_metrics",
327327
)

ax/service/tests/test_report_utils.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -560,11 +560,11 @@ def test_get_metric_name_pairs(self) -> None:
560560
exp._optimization_config = MultiObjectiveOptimizationConfig(
561561
objective=MultiObjective(
562562
objectives=[
563-
Objective(metric=Metric("m0")),
564-
Objective(metric=Metric("m1")),
565-
Objective(metric=Metric("m2")),
566-
Objective(metric=Metric("m3")),
567-
Objective(metric=Metric("m4")),
563+
Objective(metric=Metric("m0"), minimize=False),
564+
Objective(metric=Metric("m1"), minimize=False),
565+
Objective(metric=Metric("m2"), minimize=False),
566+
Objective(metric=Metric("m3"), minimize=False),
567+
Objective(metric=Metric("m4"), minimize=False),
568568
]
569569
)
570570
)
@@ -1052,9 +1052,9 @@ def test_compare_to_baseline_moo(self) -> None:
10521052
optimization_config = MultiObjectiveOptimizationConfig(
10531053
objective=MultiObjective(
10541054
objectives=[
1055-
Objective(metric=Metric("m0")),
1055+
Objective(metric=Metric("m0"), minimize=False),
10561056
Objective(metric=Metric("m1"), minimize=True),
1057-
Objective(metric=Metric("m3")),
1057+
Objective(metric=Metric("m3"), minimize=False),
10581058
]
10591059
)
10601060
)

ax/service/utils/report_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
)
2626

2727
import gpytorch
28-
2928
import numpy as np
3029
import pandas as pd
3130
import plotly.graph_objects as go
@@ -140,7 +139,7 @@ def _get_objective_trace_plot(
140139
plot_objective_value_vs_trial_index(
141140
exp_df=exp_df,
142141
metric_colname=metric_name,
143-
minimize=(
142+
minimize=not_none(
144143
optimization_config.objective.minimize
145144
if optimization_config.objective.metric.name == metric_name
146145
else experiment.metrics[metric_name].lower_is_better

0 commit comments

Comments (0)