Skip to content

Commit

Permalink
New Transform that adds metadata as parameters in an ObservationFeatu…
Browse files Browse the repository at this point in the history
…re (#3023)

Summary:
Pull Request resolved: #3023

This implements a new transform, `MetadataToRange`, which extracts specified fields from each `ObservationFeature` instance's metadata and incorporates them as parameters. Furthermore, it updates the search space to include the specified field as a `RangeParameter` with bounds determined by observations provided during initialization. This process involves analyzing the metadata of each observation feature and identifying relevant fields that need to be included in the search space. The bounds for these fields are then determined based on the observations provided during initialization.

Differential Revision: D65430943
  • Loading branch information
Louis Tiao authored and facebook-github-bot committed Dec 12, 2024
1 parent f4aa969 commit 8779887
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 6 deletions.
25 changes: 19 additions & 6 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ def is_numeric(self) -> bool:
ParameterType.BOOL: bool,
}

INVERSE_PARAMETER_PYTHON_TYPE_MAP: dict[TParameterType, ParameterType] = {
v: k for k, v in PARAMETER_PYTHON_TYPE_MAP.items()
}
SUPPORTED_PARAMETER_TYPES: tuple[
type[bool] | type[float] | type[int] | type[str], ...
] = tuple(PARAMETER_PYTHON_TYPE_MAP.values())
Expand All @@ -80,10 +83,21 @@ def is_numeric(self) -> bool:
# avoid runtime subscripting errors.
def _get_parameter_type(python_type: type) -> ParameterType:
"""Given a Python type, retrieve corresponding Ax ``ParameterType``."""
for param_type, py_type in PARAMETER_PYTHON_TYPE_MAP.items():
if py_type == python_type:
return param_type
raise ValueError(f"No Ax parameter type corresponding to {python_type}.")
try:
return INVERSE_PARAMETER_PYTHON_TYPE_MAP[python_type]
except KeyError:
raise ValueError(f"No Ax parameter type corresponding to {python_type}.")


def _infer_parameter_type_from_value(value: TParameterType) -> ParameterType:
# search in order of class hierarchy (e.g. bool is a subclass of int)
# therefore cannot directly use SUPPORTED_PARAMETER_TYPES
# (unless it is sorted correctly)
return next(
INVERSE_PARAMETER_PYTHON_TYPE_MAP[typ]
for typ in (bool, int, float, str)
if isinstance(value, typ)
)


class Parameter(SortableBase, metaclass=ABCMeta):
Expand Down Expand Up @@ -268,8 +282,7 @@ def __init__(
"""
if is_fidelity and (target_value is None):
raise UserInputError(
"`target_value` should not be None for the fidelity parameter: "
"{}".format(name)
f"`target_value` should not be None for the fidelity parameter: {name}"
)

self._name = name
Expand Down
176 changes: 176 additions & 0 deletions ax/modelbridge/transforms/metadata_to_range.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from __future__ import annotations

from logging import Logger
from typing import Any, Iterable, Optional, TYPE_CHECKING

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import _infer_parameter_type_from_value, RangeParameter
from ax.core.search_space import SearchSpace
from ax.exceptions.core import DataRequiredError
from ax.modelbridge.transforms.base import Transform
from ax.models.types import TConfig
from ax.utils.common.logger import get_logger
from pyre_extensions import assert_is_instance, none_throws

if TYPE_CHECKING:
# import as module to make sphinx-autodoc-typehints happy
from ax import modelbridge as modelbridge_module # noqa F401


logger: Logger = get_logger(__name__)


class MetadataToRange(Transform):
"""
A transform that converts metadata from observation features into range parameters
for a search space.
This transform takes a list of observations and extracts specified metadata keys
to be used as parameter in the search space. It also updates the search space with
new Range parameters based on the metadata values.
TODO[tiao]: update following
Accepts the following `config` parameters:
- "keys": A list of strings representing the metadata keys to be extracted and
used as features.
- "log_scale": A boolean indicating whether the parameters should be on a
log scale. Defaults to False.
- "is_fidelity": A boolean indicating whether the parameters are fidelity
parameters. Defaults to False.
Transform is done in-place.
"""

DEFAULT_LOG_SCALE: bool = False
DEFAULT_LOGIT_SCALE: bool = False
DEFAULT_IS_FIDELITY: bool = False
ENFORCE_BOUNDS: bool = False

def __init__(
self,
search_space: SearchSpace | None = None,
observations: list[Observation] | None = None,
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
config: TConfig | None = None,
) -> None:
if observations is None or not observations:
raise DataRequiredError(
"`MetadataToRange` transform requires non-empty data."
)
config = config or {}
enforce_bounds: bool = assert_is_instance(
config.get("enforce_bounds", self.ENFORCE_BOUNDS), bool
)
self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
config.get("parameters", {}), dict
)

self._parameter_list: list[RangeParameter] = []
for name in self.parameters:
parameter_type = None
lb = ub = None # de facto bounds
for obs in observations:
obsf_metadata = none_throws(obs.features.metadata)
val = obsf_metadata[name]

# TODO[tiao]: give user option to explicitly specify parameter type(?)
# TODO[tiao]: check the inferred type is consistent across all
# observations; such inconsistencies may actually be impossible
# by virtue of the validations carried out upstream(?)
parameter_type = parameter_type or _infer_parameter_type_from_value(val)

lb = min(val, lb) if lb is not None else val
ub = max(val, ub) if ub is not None else val

lower = self.parameters[name].get("lower", lb)
upper = self.parameters[name].get("upper", ub)

if enforce_bounds:
if ub < upper:
raise DataRequiredError(
f"No values observed at upper bound {upper}"
f" (highest observed: {ub})"
)
if lb > lower:
raise DataRequiredError(
f"No values observed at lower bound {lower}"
f" (lowest observed: {lb})"
)

log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE)
logit_scale = self.parameters[name].get(
"logit_scale", self.DEFAULT_LOGIT_SCALE
)
digits = self.parameters[name].get("digits")
is_fidelity = self.parameters[name].get(
"is_fidelity", self.DEFAULT_IS_FIDELITY
)

# TODO[tiao]: necessary to check within bounds?
target_value = self.parameters[name].get("target_value")

parameter = RangeParameter(
name=name,
parameter_type=none_throws(parameter_type),
lower=lower,
upper=upper,
log_scale=log_scale,
logit_scale=logit_scale,
digits=digits,
is_fidelity=is_fidelity,
target_value=target_value,
)
self._parameter_list.append(parameter)

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
for parameter in self._parameter_list:
search_space.add_parameter(parameter)
return search_space

def transform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
self._transform_observation_feature(obsf)
return observation_features

def untransform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
self._untransform_observation_feature(obsf)
return observation_features

def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
_transfer(
src=none_throws(obsf.metadata),
dst=obsf.parameters,
keys=self.parameters.keys(),
)

def _untransform_observation_feature(self, obsf: ObservationFeatures) -> None:
obsf.metadata = obsf.metadata or {}
_transfer(
src=obsf.parameters,
dst=obsf.metadata,
keys=self.parameters.keys(),
)


def _transfer(
src: dict[str, Any],
dst: dict[str, Any],
keys: Iterable[str],
) -> None:
"""Transfer items in-place from one dictionary to another."""
for key in keys:
dst[key] = src.pop(key)
Loading

0 comments on commit 8779887

Please sign in to comment.