Skip to content

Commit 04c9aba

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Raise UnsupportedError for scalarized objectives in best_point_mixin and report_utils (#5065)
Summary: Pull Request resolved: #5065 Several methods in `best_point_mixin.py` and `report_utils.py` use `metric_names[0]` as the objective metric and call `.minimize`, which is semantically wrong for scalarized objectives (where the objective is a combination of metrics). Add early guards raising `UnsupportedError`. Affected methods: - `_get_trace_by_progression` - `get_improvement_over_baseline` - `_get_objective_trace_plot` - `maybe_extract_baseline_comparison_values` Reviewed By: Balandat Differential Revision: D97123523 fbshipit-source-id: 664f3284cf3c39cc3bb1ec30c5a4d2f985f64123
1 parent 3280609 commit 04c9aba

4 files changed

Lines changed: 93 additions & 11 deletions

File tree

ax/service/tests/test_best_point.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,15 @@
2424
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2525
from ax.core.optimization_config import (
2626
MultiObjectiveOptimizationConfig,
27+
OptimizationConfig,
2728
PreferenceOptimizationConfig,
2829
)
2930
from ax.core.outcome_constraint import OutcomeConstraint
3031
from ax.core.parameter import ParameterType, RangeParameter
3132
from ax.core.search_space import SearchSpace
3233
from ax.core.trial import Trial
3334
from ax.core.types import ComparisonOp
34-
from ax.exceptions.core import DataRequiredError, UserInputError
35+
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
3536
from ax.service.utils.best_point import (
3637
get_tensor_converter_adapter,
3738
get_trace,
@@ -799,3 +800,25 @@ def test_get_tensor_converter_adapter(self) -> None:
799800
self.assertIsInstance(
800801
get_tensor_converter_adapter(experiment=experiment), TorchAdapter
801802
)
803+
804+
def test_get_trace_by_progression_scalarized(self) -> None:
805+
"""_get_trace_by_progression raises UnsupportedError for scalarized."""
806+
experiment = get_experiment_with_trial()
807+
experiment._optimization_config = OptimizationConfig(
808+
objective=Objective(expression="2*m1 + -1*m2"),
809+
)
810+
with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"):
811+
BestPointMixin._get_trace_by_progression(experiment=experiment)
812+
813+
def test_get_improvement_over_baseline_scalarized(self) -> None:
814+
"""get_improvement_over_baseline raises UnsupportedError for scalarized."""
815+
experiment = get_experiment_with_trial()
816+
experiment._optimization_config = OptimizationConfig(
817+
objective=Objective(expression="2*m1 + -1*m2"),
818+
)
819+
mixin = BestPointMixin.__new__(BestPointMixin)
820+
with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"):
821+
mixin.get_improvement_over_baseline(
822+
experiment=experiment,
823+
generation_strategy=Mock(),
824+
)

ax/service/tests/test_report_utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,13 @@
2020
from ax.core.arm import Arm
2121
from ax.core.metric import Metric
2222
from ax.core.objective import MultiObjective, Objective
23-
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
23+
from ax.core.optimization_config import (
24+
MultiObjectiveOptimizationConfig,
25+
OptimizationConfig,
26+
)
2427
from ax.core.outcome_constraint import ObjectiveThreshold
2528
from ax.core.types import ComparisonOp
29+
from ax.exceptions.core import UnsupportedError
2630
from ax.generation_strategy.generation_node import GenerationStep
2731
from ax.generation_strategy.generation_strategy import GenerationStrategy
2832
from ax.orchestration.orchestrator import Orchestrator
@@ -49,6 +53,7 @@
4953
get_branin_experiment,
5054
get_branin_experiment_with_multi_objective,
5155
get_branin_experiment_with_timestamp_map_metric,
56+
get_branin_metric,
5257
get_experiment_with_observations,
5358
get_high_dimensional_branin_experiment,
5459
get_multi_type_experiment,
@@ -743,3 +748,31 @@ def test_maybe_extract_baseline_comparison_values_metric_missing_moo(self) -> No
743748
baseline_arm_name=arm_names[0],
744749
)
745750
self.assertIsNone(result)
751+
752+
def test_get_objective_trace_plot_scalarized(self) -> None:
753+
"""_get_objective_trace_plot raises UnsupportedError for scalarized."""
754+
exp = get_branin_experiment(with_completed_trial=True)
755+
exp.add_tracking_metric(get_branin_metric(name="branin2"))
756+
exp._optimization_config = OptimizationConfig(
757+
objective=Objective(expression="2*branin + -1*branin2"),
758+
)
759+
with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"):
760+
_get_objective_trace_plot(experiment=exp)
761+
762+
def test_maybe_extract_baseline_comparison_values_scalarized(self) -> None:
763+
"""maybe_extract_baseline_comparison_values raises UnsupportedError
764+
for scalarized."""
765+
exp = get_branin_experiment_with_multi_objective(with_batch=True)
766+
exp.trials[0].run()
767+
exp.fetch_data()
768+
arm_names = list(exp.arms_by_name.keys())
769+
exp._optimization_config = OptimizationConfig(
770+
objective=Objective(expression="2*branin_a + -1*branin_b"),
771+
)
772+
with self.assertRaisesRegex(UnsupportedError, "not supported for scalarized"):
773+
maybe_extract_baseline_comparison_values(
774+
experiment=exp,
775+
optimization_config=exp.optimization_config,
776+
comparison_arm_names=[arm_names[1]],
777+
baseline_arm_name=arm_names[0],
778+
)

ax/service/utils/best_point_mixin.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from ax.core.trial import Trial
2424
from ax.core.types import TModelPredictArm, TParameterization
25-
from ax.exceptions.core import UserInputError
25+
from ax.exceptions.core import UnsupportedError, UserInputError
2626
from ax.generation_strategy.generation_strategy import GenerationStrategy
2727
from ax.service.utils import best_point as best_point_utils
2828
from ax.service.utils.best_point import get_tensor_converter_adapter
@@ -385,6 +385,12 @@ def _get_trace_by_progression(
385385
optimization_config = optimization_config or none_throws(
386386
experiment.optimization_config
387387
)
388+
if optimization_config.objective.is_scalarized_objective:
389+
raise UnsupportedError(
390+
"`_get_trace_by_progression` is not supported for scalarized "
391+
"objectives. The objective is a combination of metrics, not a "
392+
"single metric."
393+
)
388394
objective = optimization_config.objective.metric_names[0]
389395
minimize = optimization_config.objective.minimize
390396
map_data = experiment.lookup_data()
@@ -458,16 +464,21 @@ def get_improvement_over_baseline(
458464
"`get_improvement_over_baseline` not yet implemented"
459465
+ " for multi-objective problems."
460466
)
467+
optimization_config = experiment.optimization_config
468+
if not optimization_config:
469+
raise ValueError("No optimization config found.")
470+
if optimization_config.objective.is_scalarized_objective:
471+
raise UnsupportedError(
472+
"`get_improvement_over_baseline` is not supported for "
473+
"scalarized objectives. The objective is a combination of "
474+
"metrics, not a single metric."
475+
)
461476
if not baseline_arm_name:
462477
baseline_arm_name, _ = select_baseline_name_default_first_trial(
463478
experiment=experiment,
464479
baseline_arm_name=baseline_arm_name,
465480
)
466481

467-
optimization_config = experiment.optimization_config
468-
if not optimization_config:
469-
raise ValueError("No optimization config found.")
470-
471482
objective_metric_name = optimization_config.objective.metric_names[0]
472483

473484
# get the baseline trial

ax/service/utils/report_utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
from ax.core.trial import BaseTrial
4343
from ax.core.trial_status import TrialStatus
4444
from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy
45-
from ax.exceptions.core import DataRequiredError, UserInputError
45+
from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError
4646
from ax.generation_strategy.generation_strategy import GenerationStrategy
4747
from ax.plot.contour import interact_contour_plotly
4848
from ax.plot.diagnostic import interact_cross_validation_plotly
@@ -123,10 +123,18 @@ def _get_objective_trace_plot(
123123
if optimization_config is None:
124124
return []
125125

126+
objective = optimization_config.objective
127+
if objective.is_scalarized_objective:
128+
raise UnsupportedError(
129+
"`_get_objective_trace_plot` is not supported for scalarized "
130+
"objectives. The objective is a combination of metrics, not a "
131+
"single metric."
132+
)
133+
126134
metric_names = (
127135
metric_name
128136
for metric_name in [
129-
optimization_config.objective.metric_names[0],
137+
objective.metric_names[0],
130138
true_objective_metric_name,
131139
]
132140
if metric_name is not None
@@ -137,8 +145,8 @@ def _get_objective_trace_plot(
137145
exp_df=exp_df,
138146
metric_colname=metric_name,
139147
minimize=none_throws(
140-
optimization_config.objective.minimize
141-
if optimization_config.objective.metric_names[0] == metric_name
148+
objective.minimize
149+
if objective.metric_names[0] == metric_name
142150
else experiment.metrics[metric_name].lower_is_better
143151
),
144152
title=f"Best {metric_name} found vs. trial index",
@@ -1372,6 +1380,13 @@ def maybe_extract_baseline_comparison_values(
13721380
result_list.append(result_tuple)
13731381
return result_list if result_list else None
13741382

1383+
if optimization_config.objective.is_scalarized_objective:
1384+
raise UnsupportedError(
1385+
"`maybe_extract_baseline_comparison_values` is not supported for "
1386+
"scalarized objectives. The objective is a combination of "
1387+
"metrics, not a single metric."
1388+
)
1389+
13751390
objective_name = optimization_config.objective.metric_names[0]
13761391

13771392
# Check if metric column exists in both comparison and baseline dataframes

0 commit comments

Comments
 (0)