Skip to content

Commit 4c1ed83

Browse files
ItsMrLin authored and meta-codesync[bot] committed
Include SQ arm in DerivedMetric output when relativize_inputs=True (#5029)
Summary: Pull Request resolved: #5029. When `relativize_inputs=True`, the status quo (SQ) arm was previously excluded from the output on the grounds that its relativized input values are trivially zero. Excluding it is incorrect for expressions that do not map zero inputs to zero, such as non-linear ones: `exp(0) = 1`, not 0. Instead of skipping the SQ arm, construct a DataFrame with zero-valued inputs for all input metrics and let `_compute_derived_values` evaluate the expression on them. This produces the correct SQ output for any expression (e.g., `a + b` evaluates to 0, `exp(a)` evaluates to 1). Reviewed By: Balandat. Differential Revision: D96558255. fbshipit-source-id: 7c8283609f459149a19c72473251715e442d8c28
1 parent f3fd499 commit 4c1ed83

2 files changed

Lines changed: 68 additions & 18 deletions

File tree

ax/core/derived_metric.py

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -318,8 +318,9 @@ def _relativize_arm_data(
318318
properly transform both means and SEMs.
319319
320320
When ``relativize_inputs`` is ``False``, returns ``arm_data``
321-
unchanged. When ``True``, the status quo arm is excluded from
322-
the returned dict (its relativized values are zero by definition).
321+
unchanged. When ``True``, the status quo arm is included with
322+
zero-valued inputs so the expression can be evaluated on it
323+
(e.g., ``exp(0)=1``).
323324
"""
324325
if not self._relativize_inputs:
325326
return arm_data
@@ -342,8 +343,24 @@ def _relativize_arm_data(
342343
# different SQ metric values (non-stationarity).
343344
relativized: dict[str, pd.DataFrame] = {}
344345
for arm_name, arm_df in arm_data.items():
345-
# Skip the SQ arm itself — its relativized values are zero.
346+
# SQ relativized against itself is trivially zero for all inputs.
347+
# Include it so _compute_derived_values can evaluate the expression
348+
# on zeros (e.g., exp(0)=1, a+b=0).
346349
if arm_name == sq_name:
350+
sq_rel_rows: list[dict[str, Any]] = []
351+
status_quo_trial_index = int(arm_df["trial_index"].iloc[0])
352+
for metric_name in self._input_metric_names:
353+
sq_rel_rows.append(
354+
{
355+
"trial_index": status_quo_trial_index,
356+
"arm_name": sq_name,
357+
"metric_name": metric_name,
358+
"metric_signature": metric_name,
359+
"mean": 0.0,
360+
"sem": 0.0,
361+
}
362+
)
363+
relativized[sq_name] = pd.DataFrame(sq_rel_rows)
347364
continue
348365

349366
# Determine this arm's source trial_index from its data.
@@ -467,8 +484,8 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs: Any) -> MetricFetchResult
467484
if isinstance(arm_data_result, MetricFetchE):
468485
return Err(arm_data_result)
469486

470-
# After relativization, arm_data may be empty (e.g., a SQ-only trial
471-
# where all arms were excluded). Return empty data, not an error.
487+
# After relativization, arm_data may be empty (e.g., a trial with
488+
# no arms). Return empty data, not an error.
472489
if not arm_data_result:
473490
return Ok(value=Data())
474491

@@ -651,7 +668,8 @@ def _compute_derived_values(
651668
"""Evaluate the expression for each arm using pre-collected data.
652669
653670
When ``relativize_inputs`` is ``True``, the base class has already
654-
relativized the ``mean`` values and excluded the status quo arm.
671+
relativized the ``mean`` values. The status quo arm is included
672+
with zero-valued inputs.
655673
"""
656674
result_rows: list[dict[str, Any]] = []
657675

ax/core/tests/test_derived_metric.py

Lines changed: 44 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -265,11 +265,15 @@ def test_relativize_arm_data(self) -> None:
265265
result = metric.fetch_trial_data(exp.trials[0])
266266
self.assertIsInstance(result, Ok)
267267
df = none_throws(result.ok).df
268-
# SQ arm should be excluded from output.
269-
self.assertEqual(set(df["arm_name"].unique()), {"arm1"})
268+
# SQ arm should be included with zero-valued inputs (sum=0).
269+
self.assertEqual(set(df["arm_name"].unique()), {"sq", "arm1"})
270+
sq_row = df[df["arm_name"] == "sq"]
271+
# SQ: inputs are zero after relativization, sum(0,0) = 0.
272+
self.assertAlmostEqual(sq_row["mean"].iloc[0], 0.0)
273+
arm1_row = df[df["arm_name"] == "arm1"]
270274
# arm1 relativized (as_percent=True):
271275
# a=(15-10)/10=50%, b=(30-20)/20=50%; sum=100.0
272-
self.assertAlmostEqual(df["mean"].iloc[0], 100.0)
276+
self.assertAlmostEqual(arm1_row["mean"].iloc[0], 100.0)
273277

274278
with self.subTest("no_status_quo"):
275279
exp_no_sq = Experiment(name="no_sq", search_space=get_branin_search_space())
@@ -698,12 +702,13 @@ def test_expression_evaluation_errors(self) -> None:
698702
# ------------------------------------------------------------------
699703

700704
def test_relativize_inputs(self) -> None:
701-
"""Relativized fetch: correct computation, SQ excluded, multi-arm.
705+
"""Relativized fetch: correct computation, SQ included, multi-arm.
702706
Also verifies that relativize_inputs=False (default) includes SQ
703707
and uses raw values."""
704708
# SQ: a=10, b=4.
705-
# arm_1: a=15, b=8 → a_rel=0.5, b_rel=1.0 → a/b = 0.5
706-
# arm_2: a=20, b=6 → a_rel=1.0, b_rel=0.5 → a/b = 2.0
709+
# arm_1: a=15, b=8 → a_rel=50%, b_rel=100% → a+b = 150
710+
# arm_2: a=20, b=6 → a_rel=100%, b_rel=50% → a+b = 150
711+
# SQ: a_rel=0, b_rel=0 → a+b = 0
707712
exp = self._batch_experiment_with_sq(
708713
sq_values={"a": 10.0, "b": 4.0},
709714
arm_values={
@@ -712,20 +717,47 @@ def test_relativize_inputs(self) -> None:
712717
},
713718
)
714719
metric = ExpressionDerivedMetric(
715-
name="ratio_rel",
720+
name="sum_rel",
716721
input_metric_names=["a", "b"],
717-
expression_str="a / b",
722+
expression_str="a + b",
718723
relativize_inputs=True,
719724
)
720725
result = metric.fetch_trial_data(exp.trials[0])
721726
self.assertIsInstance(result, Ok)
722727
df = none_throws(result.ok).df.sort_values("arm_name").reset_index(drop=True)
723-
self.assertEqual(len(df), 2)
724-
self.assertNotIn("status_quo", df["arm_name"].values)
725-
self.assertAlmostEqual(df.loc[0, "mean"], 0.5, places=10)
726-
self.assertAlmostEqual(df.loc[1, "mean"], 2.0, places=10)
728+
# SQ is included: 3 rows (arm_1, arm_2, status_quo).
729+
self.assertEqual(len(df), 3)
730+
self.assertIn("status_quo", df["arm_name"].values)
731+
arm1_row = df[df["arm_name"] == "arm_1"]
732+
arm2_row = df[df["arm_name"] == "arm_2"]
733+
sq_row = df[df["arm_name"] == "status_quo"]
734+
self.assertAlmostEqual(arm1_row["mean"].iloc[0], 150.0, places=10)
735+
self.assertAlmostEqual(arm2_row["mean"].iloc[0], 150.0, places=10)
736+
# SQ: zero-valued inputs → a+b = 0.
737+
self.assertAlmostEqual(sq_row["mean"].iloc[0], 0.0, places=10)
727738
self.assertTrue(df["sem"].isna().all())
728739

740+
with self.subTest("sq_evaluates_expression_on_zeros"):
741+
# exp(0) = 1, verifying the expression is evaluated (not
742+
# hardcoded to 0) on the SQ arm's zero-valued inputs.
743+
exp2 = self._batch_experiment_with_sq(
744+
sq_values={"a": 10.0},
745+
arm_values={"arm_1": {"a": 15.0}},
746+
)
747+
metric2 = ExpressionDerivedMetric(
748+
name="exp_a",
749+
input_metric_names=["a"],
750+
expression_str="exp(a)",
751+
relativize_inputs=True,
752+
)
753+
result2 = metric2.fetch_trial_data(exp2.trials[0])
754+
self.assertIsInstance(result2, Ok)
755+
df2 = none_throws(result2.ok).df
756+
sq_row2 = df2[df2["arm_name"] == "status_quo"]
757+
self.assertEqual(len(sq_row2), 1)
758+
# exp(0) = 1.0
759+
self.assertAlmostEqual(sq_row2["mean"].iloc[0], 1.0, places=10)
760+
729761
with self.subTest("not_applied_by_default"):
730762
exp = self._batch_experiment_with_sq(
731763
sq_values={"a": 10.0, "b": 5.0},

0 commit comments

Comments
 (0)