Skip to content

Commit

Permalink
Update PowerTransformY to operate without a config (facebook#3033)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#3033

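Previously, this transform required a config to be specified. This update makes the config optional: when it is omitted, all metrics found in the observations are transformed and `clip_mean` defaults to `True`.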

Also updated `match_ci_width_truncated` to fix a bug I ran into during benchmarking, although I could not reproduce it in tests. In the worst case, the update is a no-op simplification.

Reviewed By: Balandat

Differential Revision: D65433936

fbshipit-source-id: 873c30fa71158bab99e233839b4974c8df16d9ff
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Nov 8, 2024
1 parent a3c1402 commit 2322259
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 23 deletions.
34 changes: 23 additions & 11 deletions ax/modelbridge/transforms/power_transform_y.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import checked_cast_list
from pyre_extensions import assert_is_instance
from sklearn.preprocessing import PowerTransformer

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,19 +59,29 @@ def __init__(
modelbridge: modelbridge_module.base.ModelBridge | None = None,
config: TConfig | None = None,
) -> None:
assert observations is not None, "PowerTransformY requires observations"
if config is None:
raise ValueError("PowerTransform requires a config.")
# pyre-fixme[6]: Same issue as for LogY
metric_names = list(config.get("metrics", []))
if len(metric_names) == 0:
raise ValueError("Must specify at least one metric in the config.")
# pyre-fixme[4]: Attribute must be annotated.
self.clip_mean = config.get("clip_mean", True)
# pyre-fixme[4]: Attribute must be annotated.
self.metric_names = metric_names
"""Initialize the ``PowerTransformY`` transform.
Args:
search_space: The search space of the experiment. Unused.
observations: A list of observations from the experiment.
modelbridge: The `ModelBridge` within which the transform is used. Unused.
config: A dictionary of options to control the behavior of the transform.
Can contain the following keys:
- "metrics": A list of metric names to apply the transform to. If
omitted, all metrics found in `observations` are transformed.
- "clip_mean": Whether to clip the mean to the image of the transform.
Defaults to True.
"""
if observations is None or len(observations) == 0:
raise DataRequiredError("PowerTransformY requires observations.")
# pyre-fixme[9]: Can't annotate config["metrics"] properly.
metric_names: list[str] | None = config.get("metrics", None) if config else None
self.clip_mean: bool = (
assert_is_instance(config.get("clip_mean", True), bool) if config else True
)
observation_data = [obs.data for obs in observations]
Ys = get_data(observation_data=observation_data, metric_names=metric_names)
self.metric_names: list[str] = list(Ys.keys())
# pyre-fixme[4]: Attribute must be annotated.
self.power_transforms = _compute_power_transforms(Ys=Ys)
# pyre-fixme[4]: Attribute must be annotated.
Expand Down
19 changes: 12 additions & 7 deletions ax/modelbridge/transforms/tests/test_power_y_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,12 @@ def test_Init(self) -> None:
"search_space": None,
"observations": self.observations[:2],
}
# Test error for not specifying a config
with self.assertRaises(ValueError):
PowerTransformY(**shared_init_args)
# Test error for not specifying at least one metric
with self.assertRaises(ValueError):
PowerTransformY(**shared_init_args, config={})
# Test default init
# Init without a config.
t = PowerTransformY(**shared_init_args)
self.assertTrue(t.clip_mean)
self.assertEqual(t.metric_names, ["m1", "m2"])

# Test init with config.
for m in ["m1", "m2"]:
tf = PowerTransformY(**shared_init_args, config={"metrics": [m]})
# tf.power_transforms should only exist for m and be a PowerTransformer
Expand Down Expand Up @@ -202,6 +201,12 @@ def test_TransformAndUntransformOneMetric(self) -> None:
)[0]
cov_results = np.array(transformed_obsd_nan.covariance)
self.assertTrue(np.all(np.isnan(np.diag(cov_results))))
untransformed = pt._untransform_observation_data([transformed_obsd_nan])[0]
self.assertTrue(
np.array_equal(
untransformed.covariance, self.obsd_nan.covariance, equal_nan=True
)
)

def test_TransformAndUntransformAllMetrics(self) -> None:
pt = PowerTransformY(
Expand Down
13 changes: 8 additions & 5 deletions ax/modelbridge/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,17 @@ def match_ci_width_truncated(
See log_y transform for the original. Here, bounds are forced to lie
within a [lower_bound, upper_bound] interval after transformation."""
fac = norm.ppf(1 - (1 - level) / 2)
d = fac * np.sqrt(variance)
if clip_mean:
mean = np.clip(mean, lower_bound + margin, upper_bound - margin)
right = min(mean + d, upper_bound - margin)
left = max(mean - d, lower_bound + margin)
width_asym = transform(right) - transform(left)
new_mean = transform(mean)
new_variance = float("nan") if isnan(variance) else (width_asym / 2 / fac) ** 2
if isnan(variance):
new_variance = variance
else:
d = fac * np.sqrt(variance)
right = min(mean + d, upper_bound - margin)
left = max(mean - d, lower_bound + margin)
width_asym = transform(right) - transform(left)
new_variance = (width_asym / 2 / fac) ** 2
return new_mean, new_variance


Expand Down

0 comments on commit 2322259

Please sign in to comment.