Skip to content

Commit ecac750

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Make objective direction checks more strict (#2382)
Summary: With this change, at least one of `minimize` or `lower_is_better` must be specified. If both are specified, they must match. Added a json storage helper & updated SQA storage helpers for deserializing previously saved objectives in a backwards compatible manner, resolving the conflicts in favor of `minimize`. Differential Revision: D56315542
1 parent 5b38f2a commit ecac750

20 files changed

+822
-846
lines changed

ax/core/objective.py

+17-26
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from typing import Any, Iterable, List, Optional, Tuple
1414

1515
from ax.core.metric import Metric
16+
from ax.exceptions.core import UserInputError
1617
from ax.utils.common.base import SortableBase
1718
from ax.utils.common.logger import get_logger
1819
from ax.utils.common.typeutils import not_none
@@ -34,36 +35,27 @@ def __init__(self, metric: Metric, minimize: Optional[bool] = None) -> None:
3435
metric: The metric to be optimized.
3536
minimize: If True, minimize metric. If None, will be set based on the
3637
`lower_is_better` property of the metric (if that is not specified,
37-
will raise a DeprecationWarning).
38+
will raise a `UserInputError`).
3839
3940
"""
4041
lower_is_better = metric.lower_is_better
4142
if minimize is None:
4243
if lower_is_better is None:
43-
warnings.warn(
44-
f"Defaulting to `minimize=False` for metric {metric.name} not "
45-
+ "specifying `lower_is_better` property. This is a wild guess. "
46-
+ "Specify either `lower_is_better` on the metric, or specify "
47-
+ "`minimize` explicitly. This will become an error in the future.",
48-
DeprecationWarning,
44+
raise UserInputError(
45+
f"Metric {metric.name} does not specify `lower_is_better` "
46+
"and `minimize` is not specified. At least one of these "
47+
"must be specified."
4948
)
50-
minimize = False
5149
else:
5250
minimize = lower_is_better
53-
if lower_is_better is not None:
54-
if lower_is_better and not minimize:
55-
warnings.warn(
56-
f"Attempting to maximize metric {metric.name} with property "
57-
"`lower_is_better=True`."
58-
)
59-
elif not lower_is_better and minimize:
60-
warnings.warn(
61-
f"Attempting to minimize metric {metric.name} with property "
62-
"`lower_is_better=False`."
63-
)
64-
self._metric = metric
65-
# pyre-fixme[4]: Attribute must be annotated.
66-
self.minimize = not_none(minimize)
51+
elif lower_is_better is not None and lower_is_better != minimize:
52+
raise UserInputError(
53+
f"Metric {metric.name} specifies {lower_is_better=}, "
54+
"which doesn't match the specified optimization direction "
55+
f"{minimize=}."
56+
)
57+
self._metric: Metric = metric
58+
self.minimize: bool = not_none(minimize)
6759

6860
@property
6961
def metric(self) -> Metric:
@@ -130,18 +122,17 @@ def __init__(
130122
"as input to `MultiObjective` constructor."
131123
)
132124
metrics = extra_kwargs["metrics"]
133-
minimize = extra_kwargs.get("minimize", False)
125+
minimize = extra_kwargs.get("minimize", None)
134126
warnings.warn(
135127
"Passing `metrics` and `minimize` as input to the `MultiObjective` "
136128
"constructor will soon be deprecated. Instead, pass a list of "
137129
"`objectives`. This will become an error in the future.",
138130
DeprecationWarning,
131+
stacklevel=2,
139132
)
140133
objectives = []
141134
for metric in metrics:
142-
lower_is_better = metric.lower_is_better or False
143-
_minimize = not lower_is_better if minimize else lower_is_better
144-
objectives.append(Objective(metric=metric, minimize=_minimize))
135+
objectives.append(Objective(metric=metric, minimize=minimize))
145136

146137
# pyre-fixme[4]: Attribute must be annotated.
147138
self._objectives = not_none(objectives)

ax/core/tests/test_objective.py

+23-25
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77
# pyre-strict
88

9-
import warnings
10-
119
from ax.core.metric import Metric
1210
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
11+
from ax.exceptions.core import UserInputError
1312
from ax.utils.common.testutils import TestCase
1413

1514

@@ -21,7 +20,7 @@ def setUp(self) -> None:
2120
"m3": Metric(name="m3", lower_is_better=False),
2221
}
2322
self.objectives = {
24-
"o1": Objective(metric=self.metrics["m1"]),
23+
"o1": Objective(metric=self.metrics["m1"], minimize=True),
2524
"o2": Objective(metric=self.metrics["m2"], minimize=True),
2625
"o3": Objective(metric=self.metrics["m3"], minimize=False),
2726
}
@@ -38,6 +37,12 @@ def setUp(self) -> None:
3837
)
3938

4039
def test_Init(self) -> None:
40+
with self.assertRaisesRegex(UserInputError, "does not specify"):
41+
Objective(metric=self.metrics["m1"]),
42+
with self.assertRaisesRegex(
43+
UserInputError, "doesn't match the specified optimization direction"
44+
):
45+
Objective(metric=self.metrics["m2"], minimize=False)
4146
with self.assertRaises(ValueError):
4247
ScalarizedObjective(
4348
metrics=[self.metrics["m1"], self.metrics["m2"]], weights=[1.0]
@@ -52,20 +57,6 @@ def test_Init(self) -> None:
5257
metrics=[self.metrics["m1"], self.metrics["m2"]],
5358
minimize=False,
5459
)
55-
warnings.resetwarnings()
56-
warnings.simplefilter("always", append=True)
57-
with warnings.catch_warnings(record=True) as ws:
58-
Objective(metric=self.metrics["m1"])
59-
self.assertTrue(any(issubclass(w.category, DeprecationWarning) for w in ws))
60-
self.assertTrue(
61-
any("Defaulting to `minimize=False`" in str(w.message) for w in ws)
62-
)
63-
with warnings.catch_warnings(record=True) as ws:
64-
Objective(Metric(name="m4", lower_is_better=True), minimize=False)
65-
self.assertTrue(any("Attempting to maximize" in str(w.message) for w in ws))
66-
with warnings.catch_warnings(record=True) as ws:
67-
Objective(Metric(name="m4", lower_is_better=False), minimize=True)
68-
self.assertTrue(any("Attempting to minimize" in str(w.message) for w in ws))
6960
self.assertEqual(
7061
self.objective.get_unconstrainable_metrics(), [self.metrics["m1"]]
7162
)
@@ -77,15 +68,15 @@ def test_MultiObjective(self) -> None:
7768

7869
self.assertEqual(self.multi_objective.metrics, list(self.metrics.values()))
7970
minimizes = [obj.minimize for obj in self.multi_objective.objectives]
80-
self.assertEqual(minimizes, [False, True, False])
71+
self.assertEqual(minimizes, [True, True, False])
8172
weights = [mw[1] for mw in self.multi_objective.objective_weights]
8273
self.assertEqual(weights, [1.0, 1.0, 1.0])
8374
self.assertEqual(self.multi_objective.clone(), self.multi_objective)
8475
self.assertEqual(
8576
str(self.multi_objective),
8677
(
8778
"MultiObjective(objectives="
88-
'[Objective(metric_name="m1", minimize=False), '
79+
'[Objective(metric_name="m1", minimize=True), '
8980
'Objective(metric_name="m2", minimize=True), '
9081
'Objective(metric_name="m3", minimize=False)])'
9182
),
@@ -96,19 +87,26 @@ def test_MultiObjective(self) -> None:
9687
)
9788

9889
def test_MultiObjectiveBackwardsCompatibility(self) -> None:
99-
multi_objective = MultiObjective(
100-
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]]
101-
)
90+
metrics = [
91+
Metric(name="m1", lower_is_better=False),
92+
self.metrics["m2"],
93+
self.metrics["m3"],
94+
]
95+
multi_objective = MultiObjective(metrics=metrics)
10296
minimizes = [obj.minimize for obj in multi_objective.objectives]
103-
self.assertEqual(multi_objective.metrics, list(self.metrics.values()))
97+
self.assertEqual(multi_objective.metrics, metrics)
10498
self.assertEqual(minimizes, [False, True, False])
10599

106100
multi_objective_min = MultiObjective(
107-
metrics=[self.metrics["m1"], self.metrics["m2"], self.metrics["m3"]],
101+
metrics=[
102+
Metric(name="m1"),
103+
Metric(name="m2"),
104+
Metric(name="m3", lower_is_better=True),
105+
],
108106
minimize=True,
109107
)
110108
minimizes = [obj.minimize for obj in multi_objective_min.objectives]
111-
self.assertEqual(minimizes, [True, False, True])
109+
self.assertEqual(minimizes, [True, True, True])
112110

113111
def test_ScalarizedObjective(self) -> None:
114112
with self.assertRaises(NotImplementedError):

ax/core/tests/test_optimization_config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def setUp(self) -> None:
277277
"o2": Objective(metric=self.metrics["m2"], minimize=False),
278278
"o3": Objective(metric=self.metrics["m3"], minimize=False),
279279
}
280-
self.objective = Objective(metric=self.metrics["m1"], minimize=False)
280+
self.objective = Objective(metric=self.metrics["m1"], minimize=True)
281281
self.multi_objective = MultiObjective(
282282
objectives=[self.objectives["o1"], self.objectives["o2"]]
283283
)

ax/core/tests/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def setUp(self) -> None:
158158
self.data = Data(df=self.df)
159159

160160
self.optimization_config = OptimizationConfig(
161-
objective=Objective(metric=Metric(name="a")),
161+
objective=Objective(metric=Metric(name="a"), minimize=False),
162162
outcome_constraints=[
163163
OutcomeConstraint(
164164
metric=Metric(name="b"),

ax/modelbridge/tests/test_base_modelbridge.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def test_ModelBridge(
156156
observation_features=[get_observation1trans().features], weights=[2]
157157
),
158158
)
159-
oc = OptimizationConfig(objective=Objective(metric=Metric(name="test_metric")))
159+
oc = get_optimization_config_no_constraints()
160160
modelbridge._set_kwargs_to_save(
161161
model_key="TestModel", model_kwargs={}, bridge_kwargs={}
162162
)
@@ -322,7 +322,7 @@ def warn_and_return_mock_obs(
322322
fit_tracking_metrics=False,
323323
)
324324
new_oc = OptimizationConfig(
325-
objective=Objective(metric=Metric(name="test_metric2"))
325+
objective=Objective(metric=Metric(name="test_metric2"), minimize=False),
326326
)
327327
with self.assertRaisesRegex(UnsupportedError, "fit_tracking_metrics"):
328328
modelbridge.gen(n=1, optimization_config=new_oc)

ax/modelbridge/tests/test_cross_validation.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
344344

345345
# Test single objective
346346
optimization_config = OptimizationConfig(
347-
objective=Objective(metric=Metric("a"))
347+
objective=Objective(metric=Metric("a"), minimize=True)
348348
)
349349
has_good_fit = has_good_opt_config_model_fit(
350350
optimization_config=optimization_config,
@@ -354,7 +354,12 @@ def test_HasGoodOptConfigModelFit(self) -> None:
354354

355355
# Test multi objective
356356
optimization_config = MultiObjectiveOptimizationConfig(
357-
objective=MultiObjective(metrics=[Metric("a"), Metric("b")])
357+
objective=MultiObjective(
358+
objectives=[
359+
Objective(Metric("a"), minimize=False),
360+
Objective(Metric("b"), minimize=False),
361+
]
362+
)
358363
)
359364
has_good_fit = has_good_opt_config_model_fit(
360365
optimization_config=optimization_config,
@@ -364,7 +369,7 @@ def test_HasGoodOptConfigModelFit(self) -> None:
364369

365370
# Test constraints
366371
optimization_config = OptimizationConfig(
367-
objective=Objective(metric=Metric("a")),
372+
objective=Objective(metric=Metric("a"), minimize=False),
368373
outcome_constraints=[
369374
OutcomeConstraint(metric=Metric("b"), op=ComparisonOp.GEQ, bound=0.1)
370375
],

ax/modelbridge/tests/test_torch_modelbridge.py

+4-9
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
get_branin_experiment,
4646
get_branin_search_space,
4747
get_experiment_with_observations,
48+
get_optimization_config_no_constraints,
4849
get_search_space_for_range_value,
4950
)
5051
from ax.utils.testing.mock import fast_botorch_optimize
@@ -363,9 +364,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
363364
observation_features=[
364365
ObservationFeatures(parameters={"x": 1.0, "y": 2.0})
365366
],
366-
optimization_config=OptimizationConfig(
367-
objective=Objective(metric=Metric(name="test_metric"))
368-
),
367+
optimization_config=get_optimization_config_no_constraints(),
369368
)
370369

371370
self.assertEqual(acqf_vals, [5.0])
@@ -392,9 +391,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
392391
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
393392
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
394393
],
395-
optimization_config=OptimizationConfig(
396-
objective=Objective(metric=Metric(name="test_metric"))
397-
),
394+
optimization_config=get_optimization_config_no_constraints(),
398395
)
399396
t.transform_observation_features.assert_any_call(
400397
[ObservationFeatures(parameters={"x": 1.0, "y": 2.0})],
@@ -418,9 +415,7 @@ def test_evaluate_acquisition_function(self, _, mock_torch_model: Mock) -> None:
418415
ObservationFeatures(parameters={"x": 1.0, "y": 2.0}),
419416
]
420417
],
421-
optimization_config=OptimizationConfig(
422-
objective=Objective(metric=Metric(name="test_metric"))
423-
),
418+
optimization_config=get_optimization_config_no_constraints(),
424419
)
425420
t.transform_observation_features.assert_any_call(
426421
[

ax/modelbridge/tests/test_utils.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def test_extract_outcome_constraints(self) -> None:
117117
def test_extract_objective_thresholds(self) -> None:
118118
outcomes = ["m1", "m2", "m3", "m4"]
119119
objective = MultiObjective(
120-
objectives=[Objective(metric=Metric(name)) for name in outcomes[:3]]
120+
objectives=[
121+
Objective(metric=Metric(name), minimize=False) for name in outcomes[:3]
122+
]
121123
)
122124
objective_thresholds = [
123125
ObjectiveThreshold(
@@ -159,7 +161,7 @@ def test_extract_objective_thresholds(self) -> None:
159161
self.assertTrue(np.isnan(obj_t[-2:]).all())
160162

161163
# Fails if a threshold does not have a corresponding metric.
162-
objective2 = Objective(Metric("m1"))
164+
objective2 = Objective(Metric("m1"), minimize=False)
163165
with self.assertRaisesRegex(ValueError, "corresponding metrics"):
164166
extract_objective_thresholds(
165167
objective_thresholds=objective_thresholds,

ax/modelbridge/transforms/tests/test_derelativize_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def test_DerelativizeTransform(
102102
)
103103

104104
# Test with no relative constraints
105-
objective = Objective(Metric("c"))
105+
objective = Objective(Metric("c"), minimize=True)
106106
oc = OptimizationConfig(
107107
objective=objective,
108108
outcome_constraints=[
@@ -300,7 +300,7 @@ def test_Errors(self) -> None:
300300
observations=[],
301301
)
302302
oc = OptimizationConfig(
303-
objective=Objective(Metric("c")),
303+
objective=Objective(Metric("c"), minimize=False),
304304
outcome_constraints=[
305305
OutcomeConstraint(Metric("a"), ComparisonOp.LEQ, bound=2, relative=True)
306306
],

ax/modelbridge/transforms/tests/test_winsorize_transform.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def test_relative_constraints(
581581
RangeParameter("y", ParameterType.FLOAT, 0, 20),
582582
]
583583
)
584-
objective = Objective(Metric("c"))
584+
objective = Objective(Metric("c"), minimize=False)
585585

586586
# Test with relative constraint, in-design status quo
587587
oc = OptimizationConfig(

ax/service/tests/scheduler_test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def setUp(self) -> None:
321321
self.branin_experiment_no_impl_runner_or_metrics = Experiment(
322322
search_space=get_branin_search_space(),
323323
optimization_config=OptimizationConfig(
324-
objective=Objective(metric=Metric(name="branin"))
324+
objective=Objective(metric=Metric(name="branin"), minimize=False)
325325
),
326326
name="branin_experiment_no_impl_runner_or_metrics",
327327
)

ax/service/tests/test_report_utils.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -560,11 +560,11 @@ def test_get_metric_name_pairs(self) -> None:
560560
exp._optimization_config = MultiObjectiveOptimizationConfig(
561561
objective=MultiObjective(
562562
objectives=[
563-
Objective(metric=Metric("m0")),
564-
Objective(metric=Metric("m1")),
565-
Objective(metric=Metric("m2")),
566-
Objective(metric=Metric("m3")),
567-
Objective(metric=Metric("m4")),
563+
Objective(metric=Metric("m0"), minimize=False),
564+
Objective(metric=Metric("m1"), minimize=False),
565+
Objective(metric=Metric("m2"), minimize=False),
566+
Objective(metric=Metric("m3"), minimize=False),
567+
Objective(metric=Metric("m4"), minimize=False),
568568
]
569569
)
570570
)
@@ -1052,9 +1052,9 @@ def test_compare_to_baseline_moo(self) -> None:
10521052
optimization_config = MultiObjectiveOptimizationConfig(
10531053
objective=MultiObjective(
10541054
objectives=[
1055-
Objective(metric=Metric("m0")),
1055+
Objective(metric=Metric("m0"), minimize=False),
10561056
Objective(metric=Metric("m1"), minimize=True),
1057-
Objective(metric=Metric("m3")),
1057+
Objective(metric=Metric("m3"), minimize=False),
10581058
]
10591059
)
10601060
)

ax/service/utils/report_utils.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
)
2626

2727
import gpytorch
28-
2928
import numpy as np
3029
import pandas as pd
3130
import plotly.graph_objects as go
@@ -140,7 +139,7 @@ def _get_objective_trace_plot(
140139
plot_objective_value_vs_trial_index(
141140
exp_df=exp_df,
142141
metric_colname=metric_name,
143-
minimize=(
142+
minimize=not_none(
144143
optimization_config.objective.minimize
145144
if optimization_config.objective.metric.name == metric_name
146145
else experiment.metrics[metric_name].lower_is_better

0 commit comments

Comments
 (0)