|
17 | 17 | from ax.core.optimization_config import OptimizationConfig
|
18 | 18 | from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
|
19 | 19 | from ax.core.search_space import SearchSpace
|
| 20 | +from ax.exceptions.core import DataRequiredError |
20 | 21 | from ax.modelbridge.transforms.base import Transform
|
21 | 22 | from ax.modelbridge.transforms.utils import get_data, match_ci_width_truncated
|
22 | 23 | from ax.models.types import TConfig
|
23 | 24 | from ax.utils.common.logger import get_logger
|
24 | 25 | from ax.utils.common.typeutils import checked_cast_list
|
| 26 | +from pyre_extensions import assert_is_instance |
25 | 27 | from sklearn.preprocessing import PowerTransformer
|
26 | 28 |
|
27 | 29 | if TYPE_CHECKING:
|
@@ -57,19 +59,29 @@ def __init__(
|
57 | 59 | modelbridge: modelbridge_module.base.ModelBridge | None = None,
|
58 | 60 | config: TConfig | None = None,
|
59 | 61 | ) -> None:
|
60 |
| - assert observations is not None, "PowerTransformY requires observations" |
61 |
| - if config is None: |
62 |
| - raise ValueError("PowerTransform requires a config.") |
63 |
| - # pyre-fixme[6]: Same issue as for LogY |
64 |
| - metric_names = list(config.get("metrics", [])) |
65 |
| - if len(metric_names) == 0: |
66 |
| - raise ValueError("Must specify at least one metric in the config.") |
67 |
| - # pyre-fixme[4]: Attribute must be annotated. |
68 |
| - self.clip_mean = config.get("clip_mean", True) |
69 |
| - # pyre-fixme[4]: Attribute must be annotated. |
70 |
| - self.metric_names = metric_names |
| 62 | + """Initialize the ``PowerTransformY`` transform. |
| 63 | +
|
| 64 | + Args: |
| 65 | + search_space: The search space of the experiment. Unused. |
| 66 | + observations: A list of observations from the experiment. |
| 67 | + modelbridge: The `ModelBridge` within which the transform is used. Unused. |
| 68 | + config: A dictionary of options to control the behavior of the transform. |
| 69 | + Can contain the following keys: |
| 70 | + - "metrics": A list of metric names to apply the transform to. If |
| 71 | + omitted, all metrics found in `observations` are transformed. |
| 72 | + - "clip_mean": Whether to clip the mean to the image of the transform. |
| 73 | + Defaults to True. |
| 74 | + """ |
| 75 | + if observations is None or len(observations) == 0: |
| 76 | + raise DataRequiredError("PowerTransformY requires observations.") |
| 77 | + # pyre-fixme[9]: Can't annotate config["metrics"] properly. |
| 78 | + metric_names: list[str] | None = config.get("metrics", None) if config else None |
| 79 | + self.clip_mean: bool = ( |
| 80 | + assert_is_instance(config.get("clip_mean", True), bool) if config else True |
| 81 | + ) |
71 | 82 | observation_data = [obs.data for obs in observations]
|
72 | 83 | Ys = get_data(observation_data=observation_data, metric_names=metric_names)
|
| 84 | + self.metric_names: list[str] = list(Ys.keys()) |
73 | 85 | # pyre-fixme[4]: Attribute must be annotated.
|
74 | 86 | self.power_transforms = _compute_power_transforms(Ys=Ys)
|
75 | 87 | # pyre-fixme[4]: Attribute must be annotated.
|
|
0 commit comments