Skip to content

Commit 358c4c6

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Fix metric_config_summary_df for scalarized objectives (facebook#5061)
Summary: Pull Request resolved: facebook#5061 The `metric_config_summary_df` method in `Experiment` crashes when using scalarized objectives because the `else` branch calls `sub_obj.minimize`, which raises `UserInputError` for scalarized objectives. Add an `elif objective.is_scalarized_objective:` branch that iterates `objective.metric_weights` and sets per-metric goal using `weight < 0`. Reviewed By: dme65 Differential Revision: D97122953 fbshipit-source-id: 8829b930426489cc98aebfd6b08fd60ed3dcd6d3
1 parent 4b2905b commit 358c4c6

2 files changed

Lines changed: 33 additions & 4 deletions

File tree

ax/core/experiment.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2215,13 +2215,23 @@ def metric_config_summary_df(self) -> pd.DataFrame:
22152215
if objective.is_multi_objective:
22162216
parts = [p.strip() for p in objective.expression.split(",")]
22172217
sub_objectives = [Objective(expression=part) for part in parts]
2218+
for sub_obj in sub_objectives:
2219+
obj_name = sub_obj.metric_names[0]
2220+
if obj_name in records:
2221+
records[obj_name][METRIC_DF_COLNAMES["goal"]] = (
2222+
"minimize" if sub_obj.minimize else "maximize"
2223+
)
2224+
elif objective.is_scalarized_objective:
2225+
for metric_name, weight in objective.metric_weights:
2226+
if metric_name in records:
2227+
records[metric_name][METRIC_DF_COLNAMES["goal"]] = (
2228+
"minimize" if weight < 0 else "maximize"
2229+
)
22182230
else:
2219-
sub_objectives = [objective]
2220-
for sub_obj in sub_objectives:
2221-
obj_name = sub_obj.metric_names[0]
2231+
obj_name = objective.metric_names[0]
22222232
if obj_name in records:
22232233
records[obj_name][METRIC_DF_COLNAMES["goal"]] = (
2224-
"minimize" if sub_obj.minimize else "maximize"
2234+
"minimize" if objective.minimize else "maximize"
22252235
)
22262236

22272237
objective_threshold_names: set[str] = set()

ax/core/tests/test_experiment.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,6 +1625,25 @@ def test_metric_summary_df(self) -> None:
16251625
)
16261626
pd.testing.assert_frame_equal(df, expected_df)
16271627

1628+
def test_metric_summary_df_scalarized_objective(self) -> None:
1629+
experiment = Experiment(
1630+
name="test_experiment",
1631+
search_space=SearchSpace(parameters=[]),
1632+
optimization_config=OptimizationConfig(
1633+
objective=Objective(expression="2*metric_a + -3*metric_b"),
1634+
),
1635+
tracking_metrics=[
1636+
Metric(name="metric_a", lower_is_better=False),
1637+
Metric(name="metric_b", lower_is_better=True),
1638+
],
1639+
)
1640+
df = experiment.metric_config_summary_df
1641+
# metric_a has positive weight -> maximize
1642+
# metric_b has negative weight -> minimize
1643+
goal_by_name = dict(zip(df["Name"], df["Goal"]))
1644+
self.assertEqual(goal_by_name["metric_a"], "maximize")
1645+
self.assertEqual(goal_by_name["metric_b"], "minimize")
1646+
16281647
def test_arms_by_signature_for_deduplication(self) -> None:
16291648
experiment = self.experiment
16301649
trial = experiment.new_trial()

0 commit comments

Comments
 (0)