From 842ec7dca4ea95238f1ef0f5bfeef18083319432 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 11 Dec 2024 20:20:47 -0800 Subject: [PATCH] New Transform that adds metadata as parameters in an ObservationFeature (#3023) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3023 This implements a new transform, `MetadataToRange`, which extracts specified fields from each `ObservationFeature` instance's metadata and incorporates them as parameters. Furthermore, it updates the search space to include the specified field as a `RangeParameter` with bounds determined by observations provided during initialization. This process involves analyzing the metadata of each observation feature and identifying relevant fields that need to be included in the search space. The bounds for these fields are then determined based on the observations provided during initialization. Differential Revision: D65430943 --- ax/core/parameter.py | 25 ++- .../transforms/metadata_to_range.py | 176 +++++++++++++++++ .../tests/test_metadata_to_range_transform.py | 178 ++++++++++++++++++ sphinx/source/modelbridge.rst | 9 + 4 files changed, 382 insertions(+), 6 deletions(-) create mode 100644 ax/modelbridge/transforms/metadata_to_range.py create mode 100644 ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 45ea62d0be2..a2d24475def 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -71,6 +71,9 @@ def is_numeric(self) -> bool: ParameterType.BOOL: bool, } +INVERSE_PARAMETER_PYTHON_TYPE_MAP: dict[TParameterType, ParameterType] = { + v: k for k, v in PARAMETER_PYTHON_TYPE_MAP.items() +} SUPPORTED_PARAMETER_TYPES: tuple[ type[bool] | type[float] | type[int] | type[str], ... ] = tuple(PARAMETER_PYTHON_TYPE_MAP.values()) @@ -80,10 +83,21 @@ def is_numeric(self) -> bool: # avoid runtime subscripting errors. 
def _get_parameter_type(python_type: type) -> ParameterType: """Given a Python type, retrieve corresponding Ax ``ParameterType``.""" - for param_type, py_type in PARAMETER_PYTHON_TYPE_MAP.items(): - if py_type == python_type: - return param_type - raise ValueError(f"No Ax parameter type corresponding to {python_type}.") + try: + return INVERSE_PARAMETER_PYTHON_TYPE_MAP[python_type] + except KeyError: + raise ValueError(f"No Ax parameter type corresponding to {python_type}.") + + +def _infer_parameter_type_from_value(value: TParameterType) -> ParameterType: + # search in order of class hierarchy (e.g. bool is a subclass of int) + # therefore cannot directly use SUPPORTED_PARAMETER_TYPES + # (unless it is sorted correctly) + return next( + INVERSE_PARAMETER_PYTHON_TYPE_MAP[typ] + for typ in (bool, int, float, str) + if isinstance(value, typ) + ) class Parameter(SortableBase, metaclass=ABCMeta): @@ -268,8 +282,7 @@ def __init__( """ if is_fidelity and (target_value is None): raise UserInputError( - "`target_value` should not be None for the fidelity parameter: " - "{}".format(name) + f"`target_value` should not be None for the fidelity parameter: {name}" ) self._name = name diff --git a/ax/modelbridge/transforms/metadata_to_range.py b/ax/modelbridge/transforms/metadata_to_range.py new file mode 100644 index 00000000000..c64e2382d03 --- /dev/null +++ b/ax/modelbridge/transforms/metadata_to_range.py @@ -0,0 +1,176 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+# pyre-strict
+
+from __future__ import annotations
+
+from logging import Logger
+from typing import Any, Iterable, Optional, TYPE_CHECKING
+
+from ax.core.observation import Observation, ObservationFeatures
+from ax.core.parameter import _infer_parameter_type_from_value, RangeParameter
+from ax.core.search_space import SearchSpace
+from ax.exceptions.core import DataRequiredError
+from ax.modelbridge.transforms.base import Transform
+from ax.models.types import TConfig
+from ax.utils.common.logger import get_logger
+from pyre_extensions import assert_is_instance, none_throws
+
+if TYPE_CHECKING:
+    # import as module to make sphinx-autodoc-typehints happy
+    from ax import modelbridge as modelbridge_module  # noqa F401
+
+
+logger: Logger = get_logger(__name__)
+
+
+class MetadataToRange(Transform):
+    """
+    A transform that converts metadata from observation features into range parameters
+    for a search space.
+
+    This transform takes a list of observations and extracts specified metadata keys
+    to be used as parameters in the search space. It also updates the search space
+    with new Range parameters based on the metadata values.
+
+    Accepts the following `config` parameters:
+
+    - "parameters": A dict mapping each metadata field name to a dict of
+      keyword arguments for the resulting `RangeParameter` ("lower", "upper",
+      "log_scale", "logit_scale", "digits", "is_fidelity", "target_value").
+      Bounds not given explicitly are inferred from the observed values.
+    - "enforce_bounds": A boolean indicating whether to raise an error if
+      explicitly specified bounds are not attained by the observed values.
+      Defaults to False.
+
+    Transform is done in-place.
+ """ + + DEFAULT_LOG_SCALE: bool = False + DEFAULT_LOGIT_SCALE: bool = False + DEFAULT_IS_FIDELITY: bool = False + ENFORCE_BOUNDS: bool = False + + def __init__( + self, + search_space: SearchSpace | None = None, + observations: list[Observation] | None = None, + modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + config: TConfig | None = None, + ) -> None: + if observations is None or not observations: + raise DataRequiredError( + "`MetadataToRange` transform requires non-empty data." + ) + config = config or {} + enforce_bounds: bool = assert_is_instance( + config.get("enforce_bounds", self.ENFORCE_BOUNDS), bool + ) + self.parameters: dict[str, dict[str, Any]] = assert_is_instance( + config.get("parameters", {}), dict + ) + + self._parameter_list: list[RangeParameter] = [] + for name in self.parameters: + parameter_type = None + lb = ub = None # de facto bounds + for obs in observations: + obsf_metadata = none_throws(obs.features.metadata) + val = obsf_metadata[name] + + # TODO[tiao]: give user option to explicitly specify parameter type(?) + # TODO[tiao]: check the inferred type is consistent across all + # observations; such inconsistencies may actually be impossible + # by virtue of the validations carried out upstream(?) 
+ parameter_type = parameter_type or _infer_parameter_type_from_value(val) + + lb = min(val, lb) if lb is not None else val + ub = max(val, ub) if ub is not None else val + + lower = self.parameters[name].get("lower", lb) + upper = self.parameters[name].get("upper", ub) + + if enforce_bounds: + if ub < upper: + raise DataRequiredError( + f"No values observed at upper bound {upper}" + f" (highest observed: {ub})" + ) + if lb > lower: + raise DataRequiredError( + f"No values observed at lower bound {lower}" + f" (lowest observed: {lb})" + ) + + log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE) + logit_scale = self.parameters[name].get( + "logit_scale", self.DEFAULT_LOGIT_SCALE + ) + digits = self.parameters[name].get("digits") + is_fidelity = self.parameters[name].get( + "is_fidelity", self.DEFAULT_IS_FIDELITY + ) + + # TODO[tiao]: necessary to check within bounds? + target_value = self.parameters[name].get("target_value") + + parameter = RangeParameter( + name=name, + parameter_type=none_throws(parameter_type), + lower=lower, + upper=upper, + log_scale=log_scale, + logit_scale=logit_scale, + digits=digits, + is_fidelity=is_fidelity, + target_value=target_value, + ) + self._parameter_list.append(parameter) + + def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: + for parameter in self._parameter_list: + search_space.add_parameter(parameter) + return search_space + + def transform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + self._transform_observation_feature(obsf) + return observation_features + + def untransform_observation_features( + self, observation_features: list[ObservationFeatures] + ) -> list[ObservationFeatures]: + for obsf in observation_features: + self._untransform_observation_feature(obsf) + return observation_features + + def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: + 
_transfer( + src=none_throws(obsf.metadata), + dst=obsf.parameters, + keys=self.parameters.keys(), + ) + + def _untransform_observation_feature(self, obsf: ObservationFeatures) -> None: + obsf.metadata = obsf.metadata or {} + _transfer( + src=obsf.parameters, + dst=obsf.metadata, + keys=self.parameters.keys(), + ) + + +def _transfer( + src: dict[str, Any], + dst: dict[str, Any], + keys: Iterable[str], +) -> None: + """Transfer items in-place from one dictionary to another.""" + for key in keys: + dst[key] = src.pop(key) diff --git a/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py b/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py new file mode 100644 index 00000000000..2069c8ac911 --- /dev/null +++ b/ax/modelbridge/transforms/tests/test_metadata_to_range_transform.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from copy import deepcopy +from typing import Iterator + +import numpy as np +from ax.core.observation import Observation, ObservationData, ObservationFeatures +from ax.core.parameter import ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.exceptions.core import DataRequiredError +from ax.modelbridge.transforms.metadata_to_range import MetadataToRange +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance + + +WIDTHS = [2.0, 4.0, 8.0] +HEIGHTS = [4.0, 2.0, 8.0] +STEPS_ENDS = [1, 5, 3] + + +def _enumerate() -> Iterator[tuple[int, float, float, float]]: + yield from ( + (trial_index, width, height, float(i + 1)) + for trial_index, (width, height, steps_end) in enumerate( + zip(WIDTHS, HEIGHTS, STEPS_ENDS) + ) + for i in range(steps_end) + ) + + +class MetadataToRangeTransformTest(TestCase): + def setUp(self) -> None: + super().setUp() + + self.search_space = SearchSpace( + parameters=[ + RangeParameter( + name="width", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + RangeParameter( + name="height", + parameter_type=ParameterType.FLOAT, + lower=1, + upper=20, + ), + ] + ) + + self.observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": 3.0 * steps, + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + self.observations.append(Observation(features=obs_feat, data=obs_data)) + + self.t = MetadataToRange( + observations=self.observations, + config={ + "parameters": {"bar": {"log_scale": True}}, + }, + ) + + def test_Init(self) -> None: + self.assertEqual(len(self.t._parameter_list), 1) + + p = self.t._parameter_list[0] + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 
3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + with self.assertRaises(DataRequiredError): + MetadataToRange(search_space=None, observations=None) + with self.assertRaises(DataRequiredError): + MetadataToRange(search_space=None, observations=[]) + + with self.subTest("infer parameter type"): + observations = [] + for trial_index, width, height, steps in _enumerate(): + obs_feat = ObservationFeatures( + trial_index=trial_index, + parameters={"width": width, "height": height}, + metadata={ + "foo": 42, + "bar": int(steps), + }, + ) + obs_data = ObservationData( + metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) + ) + observations.append(Observation(features=obs_feat, data=obs_data)) + + t = MetadataToRange( + observations=observations, + config={ + "parameters": {"bar": {}}, + }, + ) + self.assertEqual(len(t._parameter_list), 1) + + p = t._parameter_list[0] + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.INT) + self.assertEqual(p.lower, 1) + self.assertEqual(p.upper, 5) + self.assertFalse(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def test_TransformSearchSpace(self) -> None: + ss2 = deepcopy(self.search_space) + ss2 = self.t.transform_search_space(ss2) + + self.assertSetEqual( + set(ss2.parameters.keys()), + {"height", "width", "bar"}, + ) + + p = assert_is_instance(ss2.parameters["bar"], RangeParameter) + + self.assertEqual(p.name, "bar") + self.assertEqual(p.parameter_type, ParameterType.FLOAT) + self.assertEqual(p.lower, 3.0) + self.assertEqual(p.upper, 15.0) + self.assertTrue(p.log_scale) + self.assertFalse(p.logit_scale) + self.assertIsNone(p.digits) + self.assertFalse(p.is_fidelity) + self.assertIsNone(p.target_value) + + def 
test_TransformObservationFeatures(self) -> None: + observation_features = [obs.features for obs in self.observations] + obs_ft2 = deepcopy(observation_features) + obs_ft2 = self.t.transform_observation_features(obs_ft2) + + self.assertEqual( + obs_ft2, + [ + ObservationFeatures( + trial_index=trial_index, + parameters={ + "width": width, + "height": height, + "bar": 3.0 * steps, + }, + metadata={"foo": 42}, + ) + for trial_index, width, height, steps in _enumerate() + ], + ) + obs_ft2 = self.t.untransform_observation_features(obs_ft2) + self.assertEqual(obs_ft2, observation_features) diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index 98a0c124cc7..fce78566ed2 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -310,6 +310,15 @@ Transforms :undoc-members: :show-inheritance: + +`ax.modelbridge.transforms.metadata\_to\_range` +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. automodule:: ax.modelbridge.transforms.metadata_to_range + :members: + :undoc-members: + :show-inheritance: + `ax.modelbridge.transforms.rounding` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~