diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 45ea62d0be2..a2d24475def 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -71,6 +71,9 @@ def is_numeric(self) -> bool: ParameterType.BOOL: bool, } +INVERSE_PARAMETER_PYTHON_TYPE_MAP: dict[TParameterType, ParameterType] = { + v: k for k, v in PARAMETER_PYTHON_TYPE_MAP.items() +} SUPPORTED_PARAMETER_TYPES: tuple[ type[bool] | type[float] | type[int] | type[str], ... ] = tuple(PARAMETER_PYTHON_TYPE_MAP.values()) @@ -80,10 +83,21 @@ def is_numeric(self) -> bool: # avoid runtime subscripting errors. def _get_parameter_type(python_type: type) -> ParameterType: """Given a Python type, retrieve corresponding Ax ``ParameterType``.""" - for param_type, py_type in PARAMETER_PYTHON_TYPE_MAP.items(): - if py_type == python_type: - return param_type - raise ValueError(f"No Ax parameter type corresponding to {python_type}.") + try: + return INVERSE_PARAMETER_PYTHON_TYPE_MAP[python_type] + except KeyError: + raise ValueError(f"No Ax parameter type corresponding to {python_type}.") + + +def _infer_parameter_type_from_value(value: TParameterType) -> ParameterType: + # search in order of class hierarchy (e.g. bool is a subclass of int) + # therefore cannot directly use SUPPORTED_PARAMETER_TYPES + # (unless it is sorted correctly) + return next( + INVERSE_PARAMETER_PYTHON_TYPE_MAP[typ] + for typ in (bool, int, float, str) + if isinstance(value, typ) + ) class Parameter(SortableBase, metaclass=ABCMeta): @@ -268,8 +282,7 @@ def __init__( """ if is_fidelity and (target_value is None): raise UserInputError( - "`target_value` should not be None for the fidelity parameter: " - "{}".format(name) + f"`target_value` should not be None for the fidelity parameter: {name}" ) self._name = name diff --git a/ax/modelbridge/transforms/metadata_to_range.py b/ax/modelbridge/transforms/metadata_to_range.py new file mode 100644 index 00000000000..c64e2382d03 --- /dev/null +++ b/ax/modelbridge/transforms/metadata_to_range.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from __future__ import annotations + +from logging import Logger +from typing import Any, Iterable, Optional, TYPE_CHECKING + +from ax.core.observation import Observation, ObservationFeatures +from ax.core.parameter import _infer_parameter_type_from_value, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import DataRequiredError +from ax.modelbridge.transforms.base import Transform +from ax.models.types import TConfig +from ax.utils.common.logger import get_logger +from pyre_extensions import assert_is_instance, none_throws + +if TYPE_CHECKING: + # import as module to make sphinx-autodoc-typehints happy + from ax import modelbridge as modelbridge_module # noqa F401 + + +logger: Logger = get_logger(__name__) + + +class MetadataToRange(Transform): + """ + A transform that converts metadata from observation features into range parameters + for a search space. + + This transform takes a list of observations and extracts specified metadata keys + to be used as parameter in the search space. It also updates the search space with + new Range parameters based on the metadata values. + + TODO[tiao]: update following + Accepts the following `config` parameters: + + - "keys": A list of strings representing the metadata keys to be extracted and + used as features. + - "log_scale": A boolean indicating whether the parameters should be on a + log scale. Defaults to False. + - "is_fidelity": A boolean indicating whether the parameters are fidelity + parameters. Defaults to False. + + Transform is done in-place. + """ + + DEFAULT_LOG_SCALE: bool = False + DEFAULT_LOGIT_SCALE: bool = False + DEFAULT_IS_FIDELITY: bool = False + ENFORCE_BOUNDS: bool = False + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + if observations is None or not observations: + raise DataRequiredError( + "`MetadataToRange` transform requires non-empty data." + ) + config = config or {} + enforce_bounds: bool = assert_is_instance( + config.get("enforce_bounds", self.ENFORCE_BOUNDS), bool + ) + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( + config.get("parameters", {}), dict + ) + + self._parameter_list: list[RangeParameter] = [] + for name in self.parameters: + parameter_type = None + lb = ub = None # de facto bounds + for obs in observations: + obsf_metadata = none_throws(obs.features.metadata) + val = obsf_metadata[name] + + # TODO[tiao]: give user option to explicitly specify parameter type(?) + # TODO[tiao]: check the inferred type is consistent across all + # observations; such inconsistencies may actually be impossible + # by virtue of the validations carried out upstream(?) + parameter_type = parameter_type or _infer_parameter_type_from_value(val) + + lb = min(val, lb) if lb is not None else val + ub = max(val, ub) if ub is not None else val + + lower = self.parameters[name].get("lower", lb) + upper = self.parameters[name].get("upper", ub) + + if enforce_bounds: + if ub < upper: + raise DataRequiredError( + f"No values observed at upper bound {upper}" + f" (highest observed: {ub})" + ) + if lb > lower: + raise DataRequiredError( + f"No values observed at lower bound {lower}" + f" (lowest observed: {lb})" + ) + + log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) + logit_scale = self.parameters[name].get( + "logit_scale", self.DEFAULT_LOGIT_SCALE + ) + digits = self.parameters[name].get("digits") + is_fidelity = self.parameters[name].get( + "is_fidelity", self.DEFAULT_IS_FIDELITY + ) + + # TODO[tiao]: necessary to check within bounds? + target_value = self.parameters[name].get("target_value") + + parameter = RangeParameter( + name=name, + parameter_type=none_throws(parameter_type), + lower=lower, + upper=upper, + log_scale=log_scale, + logit_scale=logit_scale, + digits=digits, + is_fidelity=is_fidelity, + target_value=target_value, + ) + self._parameter_list.append(parameter) + + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: + for parameter in self._parameter_list: + search_space.add_parameter(parameter) + return search_space + + def transform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + self._transform_observation_feature(obsf) + return observation_features + + def untransform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + self._untransform_observation_feature(obsf) + return observation_features + + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: + _transfer( + src=none_throws(obsf.metadata), + dst=obsf.parameters, + keys=self.parameters.keys(), + ) + + def _untransform_observation_feature(self, obsf: ObservationFeatures) -> None: + obsf.metadata = obsf.metadata or {} + _transfer( + src=obsf.parameters, + dst=obsf.metadata, + keys=self.parameters.keys(), + ) + + +def _transfer( + src: dict[str, Any], + dst: dict[str, Any], + keys: Iterable[str], +) -> None: + """Transfer items in-place from one dictionary to another.""" + for key in keys: + dst[key] = src.pop(key) diff --git a/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py b/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py new file mode 100644 index 00000000000..2069c8ac911 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from copy import deepcopy +from typing import Iterator + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import DataRequiredError +from ax.modelbridge.transforms.metadata_to_range import MetadataToRange +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance + + +WIDTHS = [2.0, 4.0, 8.0] +HEIGHTS = [4.0, 2.0, 8.0] +STEPS_ENDS = [1, 5, 3] + + +def _enumerate() -> Iterator[tuple[int, float, float, float]]: + yield from ( + (trial_index, width, height, float(i + 1)) + for trial_index, (width, height, steps_end) in enumerate( + zip(WIDTHS, HEIGHTS, STEPS_ENDS) + ) + for i in range(steps_end) + ) + + +class MetadataToRangeTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + name="width", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + RangeParameter( + name="height", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + ] + ) + + self.observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": 3.0 * steps, + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + self.observations.append(Observation(features=obs_feat, data=obs_data)) + + self.t = MetadataToRange( + observations=self.observations, + config={ + "parameters": {"bar": {"log_scale": True}}, + }, + ) + + def test_Init(self) -> None: + self.assertEqual(len(self.t._parameter_list), 1) + + p = self.t._parameter_list[0] + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + with self.assertRaises(DataRequiredError): + MetadataToRange(search_space=None, observations=None) + with self.assertRaises(DataRequiredError): + MetadataToRange(search_space=None, observations=[]) + + with self.subTest("infer parameter type"): + observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": int(steps), + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + observations.append(Observation(features=obs_feat, data=obs_data)) + + t = MetadataToRange( + observations=observations, + config={ + "parameters": {"bar": {}}, + }, + ) + self.assertEqual(len(t._parameter_list), 1) + + p = t._parameter_list[0] + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.INT) + self.assertEqual(p.lower, 1) + self.assertEqual(p.upper, 5) + self.assertFalse(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters.keys()), + {"height", "width", "bar"}, + ) + + p = assert_is_instance(ss2.parameters["bar"], RangeParameter) + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + "bar": 3.0 * steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 98a0c124cc7..fce78566ed2 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -310,6 +310,15 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.metadata\_to\_range` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.metadata_to_range + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~