Skip to content

Commit 2322259

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Update PowerTransformY to operate without a config (facebook#3033)
Summary: Pull Request resolved: facebook#3033 Previously, this transform required a config to be specified. The updates make the config optional. Also updated `match_ci_width_truncated` to fix a bug I ran into during benchmarking but I could not reproduce it using tests. The update should be no-op simplification in the worst case. Reviewed By: Balandat Differential Revision: D65433936 fbshipit-source-id: 873c30fa71158bab99e233839b4974c8df16d9ff
1 parent a3c1402 commit 2322259

File tree

3 files changed

+43
-23
lines changed

3 files changed

+43
-23
lines changed

ax/modelbridge/transforms/power_transform_y.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,13 @@
1717
from ax.core.optimization_config import OptimizationConfig
1818
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
1919
from ax.core.search_space import SearchSpace
20+
from ax.exceptions.core import DataRequiredError
2021
from ax.modelbridge.transforms.base import Transform
2122
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
2223
from ax.models.types import TConfig
2324
from ax.utils.common.logger import get_logger
2425
from ax.utils.common.typeutils import checked_cast_list
26+
from pyre_extensions import assert_is_instance
2527
from sklearn.preprocessing import PowerTransformer
2628

2729
if TYPE_CHECKING:
@@ -57,19 +59,29 @@ def __init__(
5759
modelbridge: modelbridge_module.base.ModelBridge | None = None,
5860
config: TConfig | None = None,
5961
) -> None:
60-
assert observations is not None, "PowerTransformY requires observations"
61-
if config is None:
62-
raise ValueError("PowerTransform requires a config.")
63-
# pyre-fixme[6]: Same issue as for LogY
64-
metric_names = list(config.get("metrics", []))
65-
if len(metric_names) == 0:
66-
raise ValueError("Must specify at least one metric in the config.")
67-
# pyre-fixme[4]: Attribute must be annotated.
68-
self.clip_mean = config.get("clip_mean", True)
69-
# pyre-fixme[4]: Attribute must be annotated.
70-
self.metric_names = metric_names
62+
"""Initialize the ``PowerTransformY`` transform.
63+
64+
Args:
65+
search_space: The search space of the experiment. Unused.
66+
observations: A list of observations from the experiment.
67+
modelbridge: The `ModelBridge` within which the transform is used. Unused.
68+
config: A dictionary of options to control the behavior of the transform.
69+
Can contain the following keys:
70+
- "metrics": A list of metric names to apply the transform to. If
71+
omitted, all metrics found in `observations` are transformed.
72+
- "clip_mean": Whether to clip the mean to the image of the transform.
73+
Defaults to True.
74+
"""
75+
if observations is None or len(observations) == 0:
76+
raise DataRequiredError("PowerTransformY requires observations.")
77+
# pyre-fixme[9]: Can't annotate config["metrics"] properly.
78+
metric_names: list[str] | None = config.get("metrics", None) if config else None
79+
self.clip_mean: bool = (
80+
assert_is_instance(config.get("clip_mean", True), bool) if config else True
81+
)
7182
observation_data = [obs.data for obs in observations]
7283
Ys = get_data(observation_data=observation_data, metric_names=metric_names)
84+
self.metric_names: list[str] = list(Ys.keys())
7385
# pyre-fixme[4]: Attribute must be annotated.
7486
self.power_transforms = _compute_power_transforms(Ys=Ys)
7587
# pyre-fixme[4]: Attribute must be annotated.

ax/modelbridge/transforms/tests/test_power_y_transform.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,12 @@ def test_Init(self) -> None:
7272
"search_space": None,
7373
"observations": self.observations[:2],
7474
}
75-
# Test error for not specifying a config
76-
with self.assertRaises(ValueError):
77-
PowerTransformY(**shared_init_args)
78-
# Test error for not specifying at least one metric
79-
with self.assertRaises(ValueError):
80-
PowerTransformY(**shared_init_args, config={})
81-
# Test default init
75+
# Init without a config.
76+
t = PowerTransformY(**shared_init_args)
77+
self.assertTrue(t.clip_mean)
78+
self.assertEqual(t.metric_names, ["m1", "m2"])
79+
80+
# Test init with config.
8281
for m in ["m1", "m2"]:
8382
tf = PowerTransformY(**shared_init_args, config={"metrics": [m]})
8483
# tf.power_transforms should only exist for m and be a PowerTransformer
@@ -202,6 +201,12 @@ def test_TransformAndUntransformOneMetric(self) -> None:
202201
)[0]
203202
cov_results = np.array(transformed_obsd_nan.covariance)
204203
self.assertTrue(np.all(np.isnan(np.diag(cov_results))))
204+
untransformed = pt._untransform_observation_data([transformed_obsd_nan])[0]
205+
self.assertTrue(
206+
np.array_equal(
207+
untransformed.covariance, self.obsd_nan.covariance, equal_nan=True
208+
)
209+
)
205210

206211
def test_TransformAndUntransformAllMetrics(self) -> None:
207212
pt = PowerTransformY(

ax/modelbridge/transforms/utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,17 @@ def match_ci_width_truncated(
113113
See log_y transform for the original. Here, bounds are forced to lie
114114
within a [lower_bound, upper_bound] interval after transformation."""
115115
fac = norm.ppf(1 - (1 - level) / 2)
116-
d = fac * np.sqrt(variance)
117116
if clip_mean:
118117
mean = np.clip(mean, lower_bound + margin, upper_bound - margin)
119-
right = min(mean + d, upper_bound - margin)
120-
left = max(mean - d, lower_bound + margin)
121-
width_asym = transform(right) - transform(left)
122118
new_mean = transform(mean)
123-
new_variance = float("nan") if isnan(variance) else (width_asym / 2 / fac) ** 2
119+
if isnan(variance):
120+
new_variance = variance
121+
else:
122+
d = fac * np.sqrt(variance)
123+
right = min(mean + d, upper_bound - margin)
124+
left = max(mean - d, lower_bound + margin)
125+
width_asym = transform(right) - transform(left)
126+
new_variance = (width_asym / 2 / fac) ** 2
124127
return new_mean, new_variance
125128

126129

0 commit comments

Comments
 (0)