|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# |
| 4 | +# This source code is licensed under the MIT license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +# pyre-strict |
| 8 | + |
| 9 | +from __future__ import annotations |
| 10 | + |
| 11 | +from logging import Logger |
| 12 | +from typing import Any, Iterable, Optional, TYPE_CHECKING |
| 13 | + |
| 14 | +from ax.core.observation import Observation, ObservationFeatures |
| 15 | +from ax.core.parameter import _infer_parameter_type_from_value, RangeParameter |
| 16 | +from ax.core.search_space import SearchSpace |
| 17 | +from ax.exceptions.core import DataRequiredError |
| 18 | +from ax.modelbridge.transforms.base import Transform |
| 19 | +from ax.models.types import TConfig |
| 20 | +from ax.utils.common.logger import get_logger |
| 21 | +from pyre_extensions import assert_is_instance, none_throws |
| 22 | + |
| 23 | +if TYPE_CHECKING: |
| 24 | + # import as module to make sphinx-autodoc-typehints happy |
| 25 | + from ax import modelbridge as modelbridge_module # noqa F401 |
| 26 | + |
| 27 | + |
| 28 | +logger: Logger = get_logger(__name__) |
| 29 | + |
| 30 | + |
| 31 | +class MetadataToRange(Transform): |
| 32 | + """ |
| 33 | + A transform that converts metadata from observation features into range parameters |
| 34 | + for a search space. |
| 35 | +
|
| 36 | + This transform takes a list of observations and extracts specified metadata keys |
| 37 | + to be used as parameter in the search space. It also updates the search space with |
| 38 | + new Range parameters based on the metadata values. |
| 39 | +
|
| 40 | + TODO[tiao]: update following |
| 41 | + Accepts the following `config` parameters: |
| 42 | +
|
| 43 | + - "keys": A list of strings representing the metadata keys to be extracted and |
| 44 | + used as features. |
| 45 | + - "log_scale": A boolean indicating whether the parameters should be on a |
| 46 | + log scale. Defaults to False. |
| 47 | + - "is_fidelity": A boolean indicating whether the parameters are fidelity |
| 48 | + parameters. Defaults to False. |
| 49 | +
|
| 50 | + Transform is done in-place. |
| 51 | + """ |
| 52 | + |
| 53 | + DEFAULT_LOG_SCALE: bool = False |
| 54 | + DEFAULT_LOGIT_SCALE: bool = False |
| 55 | + DEFAULT_IS_FIDELITY: bool = False |
| 56 | + ENFORCE_BOUNDS: bool = False |
| 57 | + |
| 58 | + def __init__( |
| 59 | + self, |
| 60 | + search_space: SearchSpace | None = None, |
| 61 | + observations: list[Observation] | None = None, |
| 62 | + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, |
| 63 | + config: TConfig | None = None, |
| 64 | + ) -> None: |
| 65 | + if observations is None or not observations: |
| 66 | + raise DataRequiredError( |
| 67 | + "`MetadataToRange` transform requires non-empty data." |
| 68 | + ) |
| 69 | + config = config or {} |
| 70 | + enforce_bounds: bool = assert_is_instance( |
| 71 | + config.get("enforce_bounds", self.ENFORCE_BOUNDS), bool |
| 72 | + ) |
| 73 | + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( |
| 74 | + config.get("parameters", {}), dict |
| 75 | + ) |
| 76 | + |
| 77 | + self._parameter_list: list[RangeParameter] = [] |
| 78 | + for name in self.parameters: |
| 79 | + parameter_type = None |
| 80 | + lb = ub = None # de facto bounds |
| 81 | + for obs in observations: |
| 82 | + obsf_metadata = none_throws(obs.features.metadata) |
| 83 | + val = obsf_metadata[name] |
| 84 | + |
| 85 | + # TODO[tiao]: give user option to explicitly specify parameter type(?) |
| 86 | + # TODO[tiao]: check the inferred type is consistent across all |
| 87 | + # observations; such inconsistencies may actually be impossible |
| 88 | + # by virtue of the validations carried out upstream(?) |
| 89 | + parameter_type = parameter_type or _infer_parameter_type_from_value(val) |
| 90 | + |
| 91 | + lb = min(val, lb) if lb is not None else val |
| 92 | + ub = max(val, ub) if ub is not None else val |
| 93 | + |
| 94 | + lower = self.parameters[name].get("lower", lb) |
| 95 | + upper = self.parameters[name].get("upper", ub) |
| 96 | + |
| 97 | + if enforce_bounds: |
| 98 | + if ub < upper: |
| 99 | + raise DataRequiredError( |
| 100 | + f"No values observed at upper bound {upper}" |
| 101 | + f" (highest observed: {ub})" |
| 102 | + ) |
| 103 | + if lb > lower: |
| 104 | + raise DataRequiredError( |
| 105 | + f"No values observed at lower bound {lower}" |
| 106 | + f" (lowest observed: {lb})" |
| 107 | + ) |
| 108 | + |
| 109 | + log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) |
| 110 | + logit_scale = self.parameters[name].get( |
| 111 | + "logit_scale", self.DEFAULT_LOGIT_SCALE |
| 112 | + ) |
| 113 | + digits = self.parameters[name].get("digits") |
| 114 | + is_fidelity = self.parameters[name].get( |
| 115 | + "is_fidelity", self.DEFAULT_IS_FIDELITY |
| 116 | + ) |
| 117 | + |
| 118 | + # TODO[tiao]: necessary to check within bounds? |
| 119 | + target_value = self.parameters[name].get("target_value") |
| 120 | + |
| 121 | + parameter = RangeParameter( |
| 122 | + name=name, |
| 123 | + parameter_type=none_throws(parameter_type), |
| 124 | + lower=lower, |
| 125 | + upper=upper, |
| 126 | + log_scale=log_scale, |
| 127 | + logit_scale=logit_scale, |
| 128 | + digits=digits, |
| 129 | + is_fidelity=is_fidelity, |
| 130 | + target_value=target_value, |
| 131 | + ) |
| 132 | + self._parameter_list.append(parameter) |
| 133 | + |
| 134 | + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: |
| 135 | + for parameter in self._parameter_list: |
| 136 | + search_space.add_parameter(parameter) |
| 137 | + return search_space |
| 138 | + |
| 139 | + def transform_observation_features( |
| 140 | + self, observation_features: list[ObservationFeatures] |
| 141 | + ) -> list[ObservationFeatures]: |
| 142 | + for obsf in observation_features: |
| 143 | + self._transform_observation_feature(obsf) |
| 144 | + return observation_features |
| 145 | + |
| 146 | + def untransform_observation_features( |
| 147 | + self, observation_features: list[ObservationFeatures] |
| 148 | + ) -> list[ObservationFeatures]: |
| 149 | + for obsf in observation_features: |
| 150 | + self._untransform_observation_feature(obsf) |
| 151 | + return observation_features |
| 152 | + |
| 153 | + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: |
| 154 | + _transfer( |
| 155 | + src=none_throws(obsf.metadata), |
| 156 | + dst=obsf.parameters, |
| 157 | + keys=self.parameters.keys(), |
| 158 | + ) |
| 159 | + |
| 160 | + def _untransform_observation_feature(self, obsf: ObservationFeatures) -> None: |
| 161 | + obsf.metadata = obsf.metadata or {} |
| 162 | + _transfer( |
| 163 | + src=obsf.parameters, |
| 164 | + dst=obsf.metadata, |
| 165 | + keys=self.parameters.keys(), |
| 166 | + ) |
| 167 | + |
| 168 | + |
| 169 | +def _transfer( |
| 170 | + src: dict[str, Any], |
| 171 | + dst: dict[str, Any], |
| 172 | + keys: Iterable[str], |
| 173 | +) -> None: |
| 174 | + """Transfer items in-place from one dictionary to another.""" |
| 175 | + for key in keys: |
| 176 | + dst[key] = src.pop(key) |
0 commit comments