Skip to content

Commit 446ddb8

Browse files
ItsMrLin authored and meta-codesync[bot] committed
Skip TransformToNewSQ for metrics with near-zero status quo mean (#5076)
Summary:
Pull Request resolved: #5076

When `ExpressionDerivedMetric` is used as an objective in PTS experiments, its status quo value is naturally zero (0% change from itself). This caused `TransformToNewSQ.transform_experiment_data` to crash with a `ValueError` in `relativize()` because division by zero is undefined for the delta method.

D96574758 handled missing SQ data (trials without any SQ) but not zero-valued SQ data (SQ exists but the metric value is zero).

This diff adds guards in `TransformToNewSQ` to skip metrics where the status quo mean is near-zero, with a warning so users know the transform was skipped. The `relativize()` utility itself still raises on a zero control mean -- we only prevent calling it with near-zero control means.

Two code paths are guarded:
- `transform_experiment_data` (vectorized/DataFrame path): checks target SQ and source trial SQ means before calling `relativize()`.
- `_get_rel_mean_sem` (per-observation path): same guard, needed for untransform symmetry so that predictions are not incorrectly un-transformed for metrics that were never transformed.

Meta: this unblocks Ax experiment `ifu_rbvm_session_proxy_pts`

Reviewed By: Balandat

Differential Revision: D97357997

fbshipit-source-id: e41f2d57998cd42ca03ab1d43cadbe4615fb0be1
1 parent b577f94 commit 446ddb8

4 files changed

Lines changed: 106 additions & 6 deletions

File tree

ax/adapter/torch.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -460,6 +460,13 @@ def _convert_experiment_data(
460460
# Drop NaN columns from means & corresponding params.
461461
outcome_means = mean_and_params[outcome_col_name].to_numpy()
462462
to_keep = ~np.isnan(outcome_means)
463+
if not np.any(to_keep):
464+
logger.warning(
465+
f"Skipping outcome '{outcome}': no non-NaN observations "
466+
f"remain after filtering. This can happen when a metric "
467+
f"has data in only a subset of trials."
468+
)
469+
continue
463470
Y = torch.from_numpy(outcome_means[to_keep]).double().view(-1, 1)
464471
X = torch.from_numpy(params_np[to_keep]).double()
465472
sem = sems_df[outcome].to_numpy()[to_keep]

ax/adapter/transforms/tests/test_transform_to_new_sq.py

Lines changed: 53 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
# pyre-strict
77

88

9+
import logging
910
import unittest
1011
from copy import deepcopy
1112
from unittest import mock
@@ -427,3 +428,55 @@ def test_non_relativizable_trial_preserved(self) -> None:
427428
transformed.observation_data.loc[0]["sem", "branin"],
428429
)
429430
)
431+
432+
def test_zero_sq_metric_skipped_with_warning(self) -> None:
433+
"""Metrics whose status quo mean is near-zero (e.g.,
434+
ExpressionDerivedMetric that is already relativized) should be
435+
skipped with a warning rather than crashing on division by zero.
436+
"""
437+
sobol = get_sobol(search_space=self.exp.search_space)
438+
for sq_val in (2.0, 3.0):
439+
t = self.exp.new_batch_trial(
440+
generator_run=sobol.gen(2), should_add_status_quo_arm=True
441+
).mark_completed(unsafe=True)
442+
data = get_branin_data_batch(batch=t)
443+
data.df.loc[(data.df["arm_name"] == "status_quo"), "mean"] = sq_val
444+
self.exp.attach_data(data=data)
445+
self._refresh_adapter()
446+
447+
experiment_data = extract_experiment_data(
448+
experiment=self.exp, data_loader_config=DataLoaderConfig()
449+
)
450+
451+
tf = TransformToNewSQ(
452+
search_space=None,
453+
adapter=self.adapter,
454+
config={"target_trial_index": 2},
455+
)
456+
457+
# Set the target trial's SQ mean to zero, simulating an
458+
# ExpressionDerivedMetric whose SQ is naturally zero.
459+
tf.status_quo_data_by_trial[2].means[0] = 0.0
460+
461+
with self.assertLogs(
462+
"ax.adapter.transforms.transform_to_new_sq", level=logging.WARNING
463+
) as cm:
464+
transformed_data = tf.transform_experiment_data(
465+
experiment_data=deepcopy(experiment_data)
466+
)
467+
468+
# Verify a warning was emitted for the skipped metric.
469+
self.assertTrue(
470+
any("near-zero" in msg for msg in cm.output),
471+
f"Expected a near-zero warning, got: {cm.output}",
472+
)
473+
474+
# Data for all non-target trials should be unchanged (no transform
475+
# was applied for the metric with zero SQ).
476+
for t_idx in (0, 1):
477+
orig = experiment_data.observation_data.loc[t_idx]
478+
orig_non_sq = orig[
479+
orig.index.get_level_values("arm_name") != self.adapter.status_quo_name
480+
]
481+
tf_data = transformed_data.observation_data.loc[t_idx]
482+
assert_frame_equal(orig_non_sq, tf_data)

ax/adapter/transforms/transform_to_new_sq.py

Lines changed: 41 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88

99
from __future__ import annotations
1010

11+
import logging
1112
from collections.abc import Callable
1213
from typing import TYPE_CHECKING
1314

@@ -21,9 +22,11 @@
2122
from ax.core.search_space import SearchSpace
2223
from ax.core.utils import get_target_trial_index
2324
from ax.generators.types import TConfig
24-
from ax.utils.stats.math_utils import relativize, unrelativize
25+
from ax.utils.stats.math_utils import MEAN_CONTROL_EPSILON, relativize, unrelativize
2526
from pyre_extensions import assert_is_instance, none_throws
2627

28+
logger: logging.Logger = logging.getLogger(__name__)
29+
2730
if TYPE_CHECKING:
2831
# import as module to make sphinx-autodoc-typehints happy
2932
from ax import adapter as adapter_module # noqa F401
@@ -127,6 +130,19 @@ def transform_experiment_data(
127130
if metric not in target_sq_data.metric_signatures:
128131
continue
129132

133+
# Check target SQ mean first -- if near-zero, relativization is
134+
# undefined (unrelativization would collapse all values to zero).
135+
target_j = get_metric_index(data=target_sq_data, metric_signature=metric)
136+
target_mean_c = target_sq_data.means[target_j]
137+
if np.abs(target_mean_c) < MEAN_CONTROL_EPSILON:
138+
logger.warning(
139+
f"Skipping TransformToNewSQ for metric '{metric}': "
140+
f"target trial status quo mean is near-zero "
141+
f"({target_mean_c}). This can happen when the metric "
142+
f"is already relativized (e.g., ExpressionDerivedMetric)."
143+
)
144+
continue
145+
130146
# Build per-row control arrays from each trial's SQ data.
131147
mean_c, sem_c = [], []
132148
for idx in trial_indices[transform_mask]:
@@ -135,18 +151,26 @@ def transform_experiment_data(
135151
mean_c.append(sq_data.means[j])
136152
sem_c.append(sq_data.covariance[j, j] ** 0.5)
137153

154+
mean_c_arr = np.array(mean_c)
155+
if np.any(np.abs(mean_c_arr) < MEAN_CONTROL_EPSILON):
156+
logger.warning(
157+
f"Skipping TransformToNewSQ for metric '{metric}': "
158+
f"one or more trial status quo means are near-zero. "
159+
f"This can happen when the metric is already relativized "
160+
f"(e.g., ExpressionDerivedMetric)."
161+
)
162+
continue
163+
138164
means_rel, sems_rel = relativize(
139165
means_t=observation_data.loc[transform_mask, ("mean", metric)],
140166
sems_t=observation_data.loc[transform_mask, ("sem", metric)],
141-
mean_c=np.array(mean_c),
167+
mean_c=mean_c_arr,
142168
sem_c=np.array(sem_c),
143169
as_percent=False,
144170
control_as_constant=self.control_as_constant,
145171
)
146172

147173
# Unrelativize with respect to target trial's status quo.
148-
target_j = get_metric_index(data=target_sq_data, metric_signature=metric)
149-
target_mean_c = target_sq_data.means[target_j]
150174
abs_target_mean_c = np.abs(target_mean_c)
151175
observation_data.loc[transform_mask, ("mean", metric)] = (
152176
means_rel * abs_target_mean_c + target_mean_c
@@ -232,6 +256,19 @@ def _get_rel_mean_sem(
232256
j = get_metric_index(data=target_status_quo_data, metric_signature=metric)
233257
target_mean_c = target_status_quo_data.means[j]
234258
abs_target_mean_c = np.abs(target_mean_c)
259+
# Skip if control or target SQ mean is near-zero -- relativization
260+
# is undefined (division by zero). The guard here is needed for
261+
# untransform symmetry: if transform_experiment_data skipped a
262+
# metric, the untransform path must also skip it.
263+
if abs_target_mean_c < MEAN_CONTROL_EPSILON or (
264+
np.abs(mean_c) < MEAN_CONTROL_EPSILON
265+
):
266+
logger.warning(
267+
f"Skipping TransformToNewSQ for metric '{metric}': "
268+
f"status quo mean is near-zero (target={target_mean_c}, "
269+
f"control={mean_c})."
270+
)
271+
return means_t, sems_t
235272
if rel_op == unrelativize:
236273
means_t = (means_t - target_mean_c) / abs_target_mean_c
237274
sems_t = sems_t / abs_target_mean_c

ax/utils/stats/math_utils.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,10 @@
99
import numpy as np
1010
import numpy.typing as npt
1111

12+
# Minimum absolute value for a control mean to be considered non-zero
13+
# for relativization via the delta method.
14+
MEAN_CONTROL_EPSILON: float = 1e-10
15+
1216

1317
def relativize(
1418
means_t: npt.NDArray | list[float] | float,
@@ -83,8 +87,7 @@ def relativize(
8387
8488
"""
8589
# if mean_c is too small, bail
86-
epsilon = 1e-10
87-
if np.any(np.abs(mean_c) < epsilon):
90+
if np.any(np.abs(mean_c) < MEAN_CONTROL_EPSILON):
8891
raise ValueError(
8992
"mean_control ({} +/- {}) is smaller than 1 in 10 billion, "
9093
"which is too small to reliably analyze ratios using the delta "

0 commit comments

Comments
 (0)