diff --git a/ax/modelbridge/transforms/power_transform_y.py b/ax/modelbridge/transforms/power_transform_y.py
index dca4ac63e3a..a8e79496b3f 100644
--- a/ax/modelbridge/transforms/power_transform_y.py
+++ b/ax/modelbridge/transforms/power_transform_y.py
@@ -17,11 +17,13 @@
 from ax.core.optimization_config import OptimizationConfig
 from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
 from ax.core.search_space import SearchSpace
+from ax.exceptions.core import DataRequiredError
 from ax.modelbridge.transforms.base import Transform
 from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
 from ax.models.types import TConfig
 from ax.utils.common.logger import get_logger
 from ax.utils.common.typeutils import checked_cast_list
+from pyre_extensions import assert_is_instance
 from sklearn.preprocessing import PowerTransformer
 
 if TYPE_CHECKING:
@@ -57,19 +59,29 @@ def __init__(
         modelbridge: modelbridge_module.base.ModelBridge | None = None,
         config: TConfig | None = None,
     ) -> None:
-        assert observations is not None, "PowerTransformY requires observations"
-        if config is None:
-            raise ValueError("PowerTransform requires a config.")
-        # pyre-fixme[6]: Same issue as for LogY
-        metric_names = list(config.get("metrics", []))
-        if len(metric_names) == 0:
-            raise ValueError("Must specify at least one metric in the config.")
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.clip_mean = config.get("clip_mean", True)
-        # pyre-fixme[4]: Attribute must be annotated.
-        self.metric_names = metric_names
+        """Initialize the ``PowerTransformY`` transform.
+
+        Args:
+            search_space: The search space of the experiment. Unused.
+            observations: A list of observations from the experiment.
+            modelbridge: The `ModelBridge` within which the transform is used. Unused.
+            config: A dictionary of options to control the behavior of the transform.
+                Can contain the following keys:
+                - "metrics": A list of metric names to apply the transform to. If
+                    omitted, all metrics found in `observations` are transformed.
+                - "clip_mean": Whether to clip the mean to the image of the transform.
+                    Defaults to True.
+        """
+        if observations is None or len(observations) == 0:
+            raise DataRequiredError("PowerTransformY requires observations.")
+        # pyre-fixme[9]: Can't annotate config["metrics"] properly.
+        metric_names: list[str] | None = config.get("metrics", None) if config else None
+        self.clip_mean: bool = (
+            assert_is_instance(config.get("clip_mean", True), bool) if config else True
+        )
         observation_data = [obs.data for obs in observations]
         Ys = get_data(observation_data=observation_data, metric_names=metric_names)
+        self.metric_names: list[str] = list(Ys.keys())
         # pyre-fixme[4]: Attribute must be annotated.
         self.power_transforms = _compute_power_transforms(Ys=Ys)
         # pyre-fixme[4]: Attribute must be annotated.
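The hunk above makes `config` optional for `PowerTransformY`: without one, every metric found in the observations is transformed and `clip_mean` defaults to True. As a quick illustration (not part of the patch), here is a minimal usage sketch of the new constructor behavior; the observation values and parameter names are made up for the example, and only the `PowerTransformY` API shown in this diff is assumed.

import numpy as np

from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.modelbridge.transforms.power_transform_y import PowerTransformY

# Illustrative observations with two metrics, m1 and m2.
observations = [
    Observation(
        features=ObservationFeatures(parameters={"x": float(i)}),
        data=ObservationData(
            metric_names=["m1", "m2"],
            means=np.array([float(i), float(i) ** 2]),
            covariance=np.diag([0.1, 0.2]),
        ),
    )
    for i in range(1, 4)
]

# No config: previously a ValueError, now both metrics are picked up
# and clip_mean falls back to its default.
t = PowerTransformY(search_space=None, observations=observations)
assert t.clip_mean is True
assert t.metric_names == ["m1", "m2"]

# A config can still restrict the transform to a subset of metrics.
t_m1 = PowerTransformY(
    search_space=None,
    observations=observations,
    config={"metrics": ["m1"], "clip_mean": False},
)
assert t_m1.metric_names == ["m1"]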
diff --git a/ax/modelbridge/transforms/tests/test_power_y_transform.py b/ax/modelbridge/transforms/tests/test_power_y_transform.py
index 97d5972fa69..15eb4f7fa0b 100644
--- a/ax/modelbridge/transforms/tests/test_power_y_transform.py
+++ b/ax/modelbridge/transforms/tests/test_power_y_transform.py
@@ -72,13 +72,12 @@ def test_Init(self) -> None:
             "search_space": None,
             "observations": self.observations[:2],
         }
-        # Test error for not specifying a config
-        with self.assertRaises(ValueError):
-            PowerTransformY(**shared_init_args)
-        # Test error for not specifying at least one metric
-        with self.assertRaises(ValueError):
-            PowerTransformY(**shared_init_args, config={})
-        # Test default init
+        # Init without a config.
+        t = PowerTransformY(**shared_init_args)
+        self.assertTrue(t.clip_mean)
+        self.assertEqual(t.metric_names, ["m1", "m2"])
+
+        # Test init with config.
         for m in ["m1", "m2"]:
             tf = PowerTransformY(**shared_init_args, config={"metrics": [m]})
             # tf.power_transforms should only exist for m and be a PowerTransformer
@@ -202,6 +201,12 @@ def test_TransformAndUntransformOneMetric(self) -> None:
         )[0]
         cov_results = np.array(transformed_obsd_nan.covariance)
         self.assertTrue(np.all(np.isnan(np.diag(cov_results))))
+        untransformed = pt._untransform_observation_data([transformed_obsd_nan])[0]
+        self.assertTrue(
+            np.array_equal(
+                untransformed.covariance, self.obsd_nan.covariance, equal_nan=True
+            )
+        )
 
     def test_TransformAndUntransformAllMetrics(self) -> None:
         pt = PowerTransformY(
diff --git a/ax/modelbridge/transforms/utils.py b/ax/modelbridge/transforms/utils.py
index 614a33e69ef..e3db346eecb 100644
--- a/ax/modelbridge/transforms/utils.py
+++ b/ax/modelbridge/transforms/utils.py
@@ -113,14 +113,17 @@ def match_ci_width_truncated(
     See log_y transform for the original. Here, bounds are forced to lie within
     a [lower_bound, upper_bound] interval after transformation."""
     fac = norm.ppf(1 - (1 - level) / 2)
-    d = fac * np.sqrt(variance)
     if clip_mean:
         mean = np.clip(mean, lower_bound + margin, upper_bound - margin)
-    right = min(mean + d, upper_bound - margin)
-    left = max(mean - d, lower_bound + margin)
-    width_asym = transform(right) - transform(left)
     new_mean = transform(mean)
-    new_variance = float("nan") if isnan(variance) else (width_asym / 2 / fac) ** 2
+    if isnan(variance):
+        new_variance = variance
+    else:
+        d = fac * np.sqrt(variance)
+        right = min(mean + d, upper_bound - margin)
+        left = max(mean - d, lower_bound + margin)
+        width_asym = transform(right) - transform(left)
+        new_variance = (width_asym / 2 / fac) ** 2
     return new_mean, new_variance
 
 
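For the `match_ci_width_truncated` change, a NaN variance is now returned unchanged instead of being pushed through the interval-width computation. A small sketch of that behavior (not part of the patch); the argument values are illustrative, and the call uses the parameter names visible in the diff body passed as keywords.

import math

from ax.modelbridge.transforms.utils import match_ci_width_truncated

new_mean, new_variance = match_ci_width_truncated(
    mean=0.5,
    variance=float("nan"),  # NaN variance is passed through untouched
    transform=lambda y: y,  # identity transform keeps the example easy to follow
    level=0.95,
    margin=0.001,
    lower_bound=0.0,
    upper_bound=1.0,
    clip_mean=True,
)
assert new_mean == 0.5  # identity transform of the (already in-bounds) mean
assert math.isnan(new_variance)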