Skip to content

Commit c40dfae

Browse files
ItsMrLinmeta-codesync[bot]
authored andcommitted
Per-metric relativization for Summary with preference metrics (#5218)
Summary: Pull Request resolved: #5218 ## Summary Supersedes D97533888. Addresses drfreund's review comment about deduping with `Data.relativize` by placing the per-metric scoping in the data layer rather than adding analysis-specific logic. When an experiment has both preference metrics (e.g., `pairwise_pref_query`) and standard tracking metrics, the Summary should relativize tracking metrics normally while skipping the preference metric (whose binary 0/1 labels have SQ mean near zero, causing "mean_control too small" crash). Previously D99037272 applied a blanket guard that skipped ALL relativization when any objective was a preference metric. This diff replaces that with per-metric scoping: non-preference metrics are relativized and %-formatted, while preference metrics are excluded from relativization and their columns are dropped from the summary table (binary 0/1 labels are not informative in a tabular summary). Labeling-only trial rows (with no tracking metric data) are also dropped. Changes: - `Data.relativize()` and `relativize_dataframe()` in `ax/core/data.py`: add `metric_names` parameter to scope which metrics get relativized. Unscoped metrics pass through with raw values. SEM zeroing for status quo rows is also scoped -- non-relativized metrics retain their original SEM. - `Experiment.to_df()` in `ax/core/experiment.py`: add `metric_names_to_relativize` parameter, threaded to `Data.relativize()`. Percentage formatting also scoped to only relativized metrics. - `Summary.compute()` in `ax/analysis/summary.py`: replace blanket `not has_preference_objective` guard with per-metric scoping. Builds a list of non-preference metric names, passes it as `metric_names_to_relativize`, then drops preference columns and labeling-only rows. Differential Revision: D99149923 fbshipit-source-id: dade9eb72af2d75b13bacf471eb0cd1df7bb95d7
1 parent b3ddd71 commit c40dfae

7 files changed

Lines changed: 280 additions & 35 deletions

File tree

ax/analysis/summary.py

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,44 @@ def compute(
8484
# (3) experiment data does not have has_step_column=True (data with a
8585
# progression doesn't support relativization due to time-series step
8686
# alignment complexities.)
87-
# (4) no preference metric objectives -- preference metrics (e.g.,
88-
# pairwise_pref_query) have binary 0/1 labels with SQ mean near zero,
89-
# causing relativization to crash with "mean_control too small."
9087
data = experiment.lookup_data(trial_indices=self.trial_indices)
91-
has_preference_objective = experiment.optimization_config is not None and any(
92-
is_preference_metric(n)
93-
for n in experiment.optimization_config.objective.metric_names
94-
)
9588
should_relativize = (
9689
len(experiment.metrics) > 0
9790
and experiment.status_quo is not None
9891
and not data.has_step_column
99-
and not has_preference_objective
10092
)
10193

94+
# When relativizing, scope to non-preference metrics only. Preference
95+
# metrics (e.g., pairwise_pref_query) have binary 0/1 labels with SQ
96+
# mean near zero, causing relativization to crash with "mean_control
97+
# too small." Non-preference tracking metrics are relativized normally.
98+
non_preference_metric_names: list[str] | None = None
99+
if should_relativize:
100+
non_preference_metric_names = [
101+
name for name in experiment.metrics if not is_preference_metric(name)
102+
]
103+
104+
df = experiment.to_df(
105+
trial_indices=self.trial_indices,
106+
omit_empty_columns=self.omit_empty_columns,
107+
trial_statuses=self.trial_statuses,
108+
relativize=should_relativize,
109+
metric_names_to_relativize=non_preference_metric_names,
110+
)
111+
112+
# Drop preference metric columns (e.g., pairwise_pref_query) -- their
113+
# binary 0/1 values are not informative in a summary table. Then drop
114+
# rows where all remaining metric columns are empty, which removes
115+
# labeling-only trials that have no tracking metric data.
116+
preference_cols = [col for col in df.columns if is_preference_metric(col)]
117+
if preference_cols:
118+
df = df.drop(columns=preference_cols)
119+
remaining_metric_cols = [
120+
col for col in df.columns if col in experiment.metrics
121+
]
122+
if remaining_metric_cols:
123+
df = df.dropna(subset=remaining_metric_cols, how="all")
124+
102125
return self._create_analysis_card(
103126
title=(
104127
"Summary for "
@@ -112,10 +135,5 @@ def compute(
112135
"Metric results are relativized against status quo."
113136
)
114137
),
115-
df=experiment.to_df(
116-
trial_indices=self.trial_indices,
117-
omit_empty_columns=self.omit_empty_columns,
118-
trial_statuses=self.trial_statuses,
119-
relativize=should_relativize,
120-
),
138+
df=df,
121139
)

ax/analysis/tests/test_summary.py

Lines changed: 78 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
from ax.analysis.summary import Summary
1111
from ax.api.client import Client
1212
from ax.api.configs import RangeParameterConfig
13-
from ax.core.arm import Arm
1413
from ax.core.base_trial import TrialStatus
1514
from ax.core.data import Data
15+
from ax.core.experiment import Experiment
1616
from ax.core.metric import Metric
1717
from ax.core.objective import MultiObjective, Objective
1818
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
@@ -224,35 +224,97 @@ def test_trial_status_filter(self) -> None:
224224
self.assertIn(0, card.df["trial_index"].values)
225225
self.assertIn(1, card.df["trial_index"].values)
226226

227-
def test_compute_with_preference_objective_skips_relativization(self) -> None:
228-
"""Summary should skip relativization when a preference metric is an
229-
objective, since binary 0/1 labels have SQ mean near zero which causes
230-
'mean_control too small' errors."""
227+
def _attach_binary_pairwise_data(
228+
self, experiment: Experiment, pairwise_name: str
229+
) -> None:
230+
"""Attach binary 0/1 pairwise data to every trial, with the status-quo
231+
(control) arm set to 0. A near-zero control mean is exactly what makes
232+
relativizing this metric crash with 'mean_control too small', so this
233+
deterministically reproduces the crash unless the preference metric is
234+
scoped out of relativization."""
235+
status_quo_name = none_throws(experiment.status_quo).name
236+
for trial in experiment.trials.values():
237+
arm_names = [arm.name for arm in trial.arms]
238+
# Control arm gets 0.0 so |mean_control| is below the relativization
239+
# epsilon; all other arms get 1.0.
240+
means = [0.0 if name == status_quo_name else 1.0 for name in arm_names]
241+
experiment.attach_data(
242+
Data(
243+
df=pd.DataFrame(
244+
{
245+
"arm_name": arm_names,
246+
"metric_name": [pairwise_name] * len(arm_names),
247+
"mean": means,
248+
"sem": [0.0] * len(arm_names),
249+
"trial_index": [trial.index] * len(arm_names),
250+
"metric_signature": [pairwise_name] * len(arm_names),
251+
}
252+
)
253+
)
254+
)
255+
256+
def test_compute_with_preference_objective_per_metric_relativization(
257+
self,
258+
) -> None:
259+
"""Summary with a preference metric objective should relativize only
260+
non-preference metrics. The preference metric (pairwise_pref_query)
261+
has binary 0/1 labels with SQ mean near zero -- relativizing it would
262+
crash with 'mean_control too small'. Non-preference metrics should
263+
be relativized normally."""
231264
pairwise_name = Keys.PAIRWISE_PREFERENCE_QUERY.value
232265

233-
# Use Client to set up experiment with SQ and data
234-
client = self.client
235-
client.configure_optimization(objective="foo")
236-
experiment = client._experiment
237-
experiment.status_quo = Arm(parameters={"x1": 0.5, "x2": 0.5})
266+
# Use an experiment with BatchTrials and SQ data, which triggers
267+
# relativization in the Summary. get_branin_experiment_with_status_quo_trials
268+
# creates BatchTrials with a SQ arm, so data.relativize() has SQ data.
269+
experiment = get_branin_experiment_with_status_quo_trials()
238270

239-
# Add pairwise_pref_query as an additional objective
271+
# Add pairwise_pref_query as an additional objective alongside branin.
240272
experiment.add_tracking_metric(Metric(name=pairwise_name))
241273
experiment.optimization_config = MultiObjectiveOptimizationConfig(
242274
objective=MultiObjective(
243275
objectives=[
244-
Objective(metric=Metric(name="foo"), minimize=True),
276+
Objective(metric=experiment.metrics["branin"], minimize=True),
245277
Objective(metric=Metric(name=pairwise_name), minimize=False),
246278
]
247279
)
248280
)
249281

250-
client.get_next_trials(max_trials=1)
251-
client.complete_trial(trial_index=0, raw_data={"foo": 1.0, pairwise_name: 0.0})
282+
self._attach_binary_pairwise_data(experiment, pairwise_name)
252283

253-
# Should succeed without "mean_control too small" error
284+
# Should succeed without "mean_control too small" crash
254285
card = Summary().compute(experiment=experiment)
255-
self.assertNotIn("relativized", card.subtitle)
286+
287+
# Subtitle should indicate relativization (non-preference metrics)
288+
self.assertIn("relativized", card.subtitle)
289+
290+
# Preference metric column should be dropped from the summary
291+
self.assertNotIn(pairwise_name, card.df.columns)
292+
293+
def test_compute_with_preference_tracking_metric_and_no_optimization_config(
294+
self,
295+
) -> None:
296+
"""A preference metric attached as a tracking metric (with a status quo
297+
but no optimization_config) must still be scoped out of relativization.
298+
Relativization is gated only on metrics/status_quo/step data, not on the
299+
optimization_config, so without scoping the binary 0/1 preference metric
300+
(SQ mean ~0) would crash with 'mean_control too small'."""
301+
pairwise_name = Keys.PAIRWISE_PREFERENCE_QUERY.value
302+
303+
experiment = get_branin_experiment_with_status_quo_trials()
304+
# Preference metric is tracking-only; there is no optimization_config.
305+
experiment.add_tracking_metric(Metric(name=pairwise_name))
306+
experiment._optimization_config = None
307+
308+
self._attach_binary_pairwise_data(experiment, pairwise_name)
309+
310+
# Should succeed without "mean_control too small" crash even though
311+
# there is no optimization_config.
312+
card = Summary().compute(experiment=experiment)
313+
314+
# branin is still relativized (non-preference metric).
315+
self.assertIn("relativized", card.subtitle)
316+
# Preference metric column should be dropped from the summary.
317+
self.assertNotIn(pairwise_name, card.df.columns)
256318

257319
def test_default_excludes_stale_trials(self) -> None:
258320
"""Test that Summary defaults to excluding STALE trials."""

ax/analysis/tests/test_utils.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -721,6 +721,59 @@ def test_relativize_df_with_sq_multiple_trials(self) -> None:
721721
decimal=1,
722722
)
723723

724+
def test_relativize_df_with_sq_skips_metrics_missing_from_sq(self) -> None:
725+
"""When model predictions include metrics not in the status quo df
726+
(e.g., pairwise_pref_query from a preference model), relativization
727+
should skip those metrics and leave their columns untouched."""
728+
df = pd.DataFrame(
729+
{
730+
"trial_index": [0, 0],
731+
"arm_name": ["status_quo", "arm1"],
732+
"foo_mean": [10.0, 12.0],
733+
"foo_sem": [1.0, 1.2],
734+
"pairwise_pref_query_mean": [0.5, 0.8],
735+
"pairwise_pref_query_sem": [0.1, 0.2],
736+
}
737+
)
738+
# Status quo df only has foo, not pairwise_pref_query
739+
status_quo_df = pd.DataFrame(
740+
{
741+
"trial_index": [0],
742+
"arm_name": ["status_quo"],
743+
"foo_mean": [10.0],
744+
"foo_sem": [1.0],
745+
}
746+
)
747+
748+
rel_df = _relativize_df_with_sq(
749+
df=df,
750+
status_quo_df=status_quo_df,
751+
status_quo_name="status_quo",
752+
)
753+
754+
with self.subTest("foo is relativized"):
755+
np.testing.assert_almost_equal(
756+
rel_df.loc[rel_df["arm_name"] == "arm1", "foo_mean"].iloc[0],
757+
0.2,
758+
decimal=1,
759+
)
760+
761+
with self.subTest("pairwise_pref_query is untouched"):
762+
np.testing.assert_almost_equal(
763+
rel_df.loc[
764+
rel_df["arm_name"] == "arm1", "pairwise_pref_query_mean"
765+
].iloc[0],
766+
0.8,
767+
decimal=5,
768+
)
769+
np.testing.assert_almost_equal(
770+
rel_df.loc[
771+
rel_df["arm_name"] == "arm1", "pairwise_pref_query_sem"
772+
].iloc[0],
773+
0.2,
774+
decimal=5,
775+
)
776+
724777
def test_is_preference_metric(self) -> None:
725778
self.assertTrue(is_preference_metric("pairwise_pref_query"))
726779
self.assertFalse(is_preference_metric("branin"))

ax/analysis/utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,14 @@ def _relativize_df_with_sq(
845845
and 'METRIC_NAME_sem' columns relativized to the status quo arm for each metric
846846
within each trial.
847847
"""
848-
metric_names = [name[:-5] for name in df.columns if name.endswith("_mean")]
848+
# Only relativize metrics present in both the data df and the status quo
849+
# df. Model predictions may include metrics (e.g., pairwise_pref_query)
850+
# that the status quo df doesn't have columns for.
851+
metric_names = [
852+
name[:-5]
853+
for name in df.columns
854+
if name.endswith("_mean") and name in status_quo_df.columns
855+
]
849856

850857
rel_df = df.copy()
851858

ax/core/data.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import itertools
1212
import math
1313
from bisect import bisect_right
14-
from collections.abc import Iterable
14+
from collections.abc import Iterable, Sequence
1515
from copy import deepcopy
1616
from functools import cached_property
1717
from io import StringIO
@@ -439,6 +439,7 @@ def relativize(
439439
include_sq: bool = False,
440440
bias_correction: bool = True,
441441
control_as_constant: bool = False,
442+
metric_names: Sequence[str] | None = None,
442443
) -> Data:
443444
"""Relativize a data object w.r.t. a status_quo arm.
444445
@@ -453,6 +454,9 @@ def relativize(
453454
ax.utils.stats.math_utils.relativize for more details.
454455
control_as_constant: If true, control is treated as a constant.
455456
bias_correction is ignored when this is true.
457+
metric_names: If provided, only relativize these metrics. Other
458+
metrics are passed through with raw values. If None, all
459+
metrics are relativized.
456460
457461
Returns:
458462
The new data object with the relativized metrics (excluding the
@@ -471,6 +475,7 @@ def relativize(
471475
include_sq=include_sq,
472476
bias_correction=bias_correction,
473477
control_as_constant=control_as_constant,
478+
metric_names=metric_names,
474479
)
475480
return self.__class__(df=df_rel)
476481

@@ -698,6 +703,7 @@ def relativize_dataframe(
698703
include_sq: bool = False,
699704
bias_correction: bool = True,
700705
control_as_constant: bool = False,
706+
metric_names: Sequence[str] | None = None,
701707
) -> pd.DataFrame:
702708
"""Relativize a dataframe w.r.t. a status_quo arm.
703709
@@ -712,6 +718,9 @@ def relativize_dataframe(
712718
ax.utils.stats.math_utils.relativize for more details.
713719
control_as_constant: If true, control is treated as a constant.
714720
bias_correction is ignored when this is true.
721+
metric_names: If provided, only relativize these metrics. Other
722+
metrics are passed through with raw values. If None, all
723+
metrics are relativized.
715724
716725
Returns:
717726
The new dataframe with the relativized metrics (excluding the
@@ -721,11 +730,26 @@ def relativize_dataframe(
721730
grp_cols = list(
722731
{"trial_index", "metric_name", "random_split"}.intersection(df.columns.values)
723732
)
733+
metric_names_set = set(metric_names) if metric_names is not None else None
724734

725735
grouped_df = df.groupby(grp_cols)
726736
dfs = []
727737
for grp in grouped_df.groups.keys():
728738
subgroup_df = grouped_df.get_group(grp)
739+
740+
# If metric scoping is requested, skip relativization for excluded
741+
# metrics and pass through raw data (with or without SQ row).
742+
if metric_names_set is not None and "metric_name" in grp_cols:
743+
grp_metric = (
744+
grp if isinstance(grp, str) else grp[grp_cols.index("metric_name")]
745+
)
746+
if grp_metric not in metric_names_set:
747+
if include_sq:
748+
dfs.append(subgroup_df)
749+
else:
750+
dfs.append(subgroup_df[subgroup_df["arm_name"] != status_quo_name])
751+
continue
752+
729753
is_sq = subgroup_df["arm_name"] == status_quo_name
730754

731755
# Check if status quo exists in this subgroup (trial)
@@ -758,7 +782,13 @@ def relativize_dataframe(
758782
dfs.append(subgroup_df.assign(mean=means_rel, sem=sems_rel))
759783
df_rel = pd.concat(dfs, axis=0)
760784
if include_sq:
761-
df_rel.loc[df_rel["arm_name"] == status_quo_name, "sem"] = 0.0
785+
# Zero SEM only for metrics that were actually relativized.
786+
# Non-relativized metrics (excluded via metric_names scoping)
787+
# should retain their original SEM values.
788+
sq_mask = df_rel["arm_name"] == status_quo_name
789+
if metric_names_set is not None and "metric_name" in df_rel.columns:
790+
sq_mask = sq_mask & df_rel["metric_name"].isin(metric_names_set)
791+
df_rel.loc[sq_mask, "sem"] = 0.0
762792
df_rel.reset_index(inplace=True, drop=True)
763793
# Reorder columns to match expected order (reuses Data class logic)
764794
df_rel = Data._get_df_with_cols_in_expected_order(df_rel)

0 commit comments

Comments
 (0)