
Commit 8779887

Louis Tiao authored and facebook-github-bot committed
New Transform that adds metadata as parameters in an ObservationFeature (#3023)
Summary: Pull Request resolved: #3023

This implements a new transform, `MetadataToRange`, which extracts specified fields from each `ObservationFeatures` instance's metadata and incorporates them as parameters. It also updates the search space to include each specified field as a `RangeParameter`, with bounds inferred from the metadata of the observations provided at initialization.

Differential Revision: D65430943
1 parent f4aa969 commit 8779887
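
As a quick orientation before the diff, here is a minimal, hypothetical usage sketch of the new transform. It is not part of the PR: the toy search space, the `num_epochs` metadata key, and the metric values are invented, and the module path in the import is assumed since the new file's path is not shown in this excerpt. The `config` keys shown (`parameters`, and optionally `enforce_bounds`) are the ones the implementation below actually reads.

```python
import numpy as np

from ax.core.observation import Observation, ObservationData, ObservationFeatures
from ax.core.parameter import ParameterType, RangeParameter
from ax.core.search_space import SearchSpace

# NOTE: module path assumed; the new file's location is not shown in this diff.
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange

search_space = SearchSpace(
    parameters=[
        RangeParameter(
            name="x", parameter_type=ParameterType.FLOAT, lower=0.0, upper=1.0
        )
    ]
)

# Each observation carries the metadata field we want to promote to a parameter.
observations = [
    Observation(
        features=ObservationFeatures(
            parameters={"x": x}, metadata={"num_epochs": epochs}
        ),
        data=ObservationData(
            metric_names=["y"], means=np.array([y]), covariance=np.array([[0.1]])
        ),
    )
    for x, epochs, y in [(0.1, 10, 1.0), (0.7, 100, 2.0)]
]

t = MetadataToRange(
    search_space=search_space,
    observations=observations,
    # Per-parameter options such as "lower", "upper", "log_scale", "digits",
    # "is_fidelity", or "target_value" could also go in the inner dict.
    config={"parameters": {"num_epochs": {}}},
)

# "num_epochs" is added to the search space as a RangeParameter with bounds
# [10, 100] inferred from the observed metadata ...
search_space = t.transform_search_space(search_space)
# ... and moved from each feature's metadata into its parameters.
features = t.transform_observation_features([obs.features for obs in observations])
```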

File tree

4 files changed: +382 -6 lines changed


ax/core/parameter.py

Lines changed: 19 additions & 6 deletions
@@ -71,6 +71,9 @@ def is_numeric(self) -> bool:
     ParameterType.BOOL: bool,
 }
 
+INVERSE_PARAMETER_PYTHON_TYPE_MAP: dict[TParameterType, ParameterType] = {
+    v: k for k, v in PARAMETER_PYTHON_TYPE_MAP.items()
+}
 SUPPORTED_PARAMETER_TYPES: tuple[
     type[bool] | type[float] | type[int] | type[str], ...
 ] = tuple(PARAMETER_PYTHON_TYPE_MAP.values())

@@ -80,10 +83,21 @@ def is_numeric(self) -> bool:
 # avoid runtime subscripting errors.
 def _get_parameter_type(python_type: type) -> ParameterType:
     """Given a Python type, retrieve corresponding Ax ``ParameterType``."""
-    for param_type, py_type in PARAMETER_PYTHON_TYPE_MAP.items():
-        if py_type == python_type:
-            return param_type
-    raise ValueError(f"No Ax parameter type corresponding to {python_type}.")
+    try:
+        return INVERSE_PARAMETER_PYTHON_TYPE_MAP[python_type]
+    except KeyError:
+        raise ValueError(f"No Ax parameter type corresponding to {python_type}.")
+
+
+def _infer_parameter_type_from_value(value: TParameterType) -> ParameterType:
+    # search in order of class hierarchy (e.g. bool is a subclass of int)
+    # therefore cannot directly use SUPPORTED_PARAMETER_TYPES
+    # (unless it is sorted correctly)
+    return next(
+        INVERSE_PARAMETER_PYTHON_TYPE_MAP[typ]
+        for typ in (bool, int, float, str)
+        if isinstance(value, typ)
+    )
 
 
 class Parameter(SortableBase, metaclass=ABCMeta):

@@ -268,8 +282,7 @@ def __init__(
         """
        if is_fidelity and (target_value is None):
             raise UserInputError(
-                "`target_value` should not be None for the fidelity parameter: "
-                "{}".format(name)
+                f"`target_value` should not be None for the fidelity parameter: {name}"
             )
 
         self._name = name
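
The inverse map lets `_get_parameter_type` do a dictionary lookup instead of a linear scan, and `_infer_parameter_type_from_value` must test `bool` before `int` because `bool` is a subclass of `int` in Python. A small sketch of the pitfall the ordering avoids (the `naive_infer` helper below is hypothetical, for illustration only):

```python
# Why the order (bool, int, float, str) matters: bool is a subclass of int,
# so isinstance(True, int) is True and a check against int would match first.
def naive_infer(value: object) -> type:
    # Hypothetical helper (not in the diff): happens to test int before bool.
    for typ in (int, bool, float, str):
        if isinstance(value, typ):
            return typ
    raise ValueError(f"Unsupported value: {value!r}")


print(naive_infer(True))  # <class 'int'> -- a boolean is misclassified
# _infer_parameter_type_from_value tests bool first, so True correctly maps
# to ParameterType.BOOL rather than ParameterType.INT.
```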
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+from __future__ import annotations
+
+from logging import Logger
+from typing import Any, Iterable, Optional, TYPE_CHECKING
+
+from ax.core.observation import Observation, ObservationFeatures
+from ax.core.parameter import _infer_parameter_type_from_value, RangeParameter
+from ax.core.search_space import SearchSpace
+from ax.exceptions.core import DataRequiredError
+from ax.modelbridge.transforms.base import Transform
+from ax.models.types import TConfig
+from ax.utils.common.logger import get_logger
+from pyre_extensions import assert_is_instance, none_throws
+
+if TYPE_CHECKING:
+    # import as module to make sphinx-autodoc-typehints happy
+    from ax import modelbridge as modelbridge_module  # noqa F401
+
+
+logger: Logger = get_logger(__name__)
+
+
+class MetadataToRange(Transform):
+    """
+    A transform that converts metadata from observation features into range parameters
+    for a search space.
+
+    This transform takes a list of observations and extracts specified metadata keys
+    to be used as parameter in the search space. It also updates the search space with
+    new Range parameters based on the metadata values.
+
+    TODO[tiao]: update following
+    Accepts the following `config` parameters:
+
+    - "keys": A list of strings representing the metadata keys to be extracted and
+      used as features.
+    - "log_scale": A boolean indicating whether the parameters should be on a
+      log scale. Defaults to False.
+    - "is_fidelity": A boolean indicating whether the parameters are fidelity
+      parameters. Defaults to False.
+
+    Transform is done in-place.
+    """
+
+    DEFAULT_LOG_SCALE: bool = False
+    DEFAULT_LOGIT_SCALE: bool = False
+    DEFAULT_IS_FIDELITY: bool = False
+    ENFORCE_BOUNDS: bool = False
+
+    def __init__(
+        self,
+        search_space: SearchSpace | None = None,
+        observations: list[Observation] | None = None,
+        modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None,
+        config: TConfig | None = None,
+    ) -> None:
+        if observations is None or not observations:
+            raise DataRequiredError(
+                "`MetadataToRange` transform requires non-empty data."
+            )
+        config = config or {}
+        enforce_bounds: bool = assert_is_instance(
+            config.get("enforce_bounds", self.ENFORCE_BOUNDS), bool
+        )
+        self.parameters: dict[str, dict[str, Any]] = assert_is_instance(
+            config.get("parameters", {}), dict
+        )
+
+        self._parameter_list: list[RangeParameter] = []
+        for name in self.parameters:
+            parameter_type = None
+            lb = ub = None  # de facto bounds
+            for obs in observations:
+                obsf_metadata = none_throws(obs.features.metadata)
+                val = obsf_metadata[name]
+
+                # TODO[tiao]: give user option to explicitly specify parameter type(?)
+                # TODO[tiao]: check the inferred type is consistent across all
+                # observations; such inconsistencies may actually be impossible
+                # by virtue of the validations carried out upstream(?)
+                parameter_type = parameter_type or _infer_parameter_type_from_value(val)
+
+                lb = min(val, lb) if lb is not None else val
+                ub = max(val, ub) if ub is not None else val
+
+            lower = self.parameters[name].get("lower", lb)
+            upper = self.parameters[name].get("upper", ub)
+
+            if enforce_bounds:
+                if ub < upper:
+                    raise DataRequiredError(
+                        f"No values observed at upper bound {upper}"
+                        f" (highest observed: {ub})"
+                    )
+                if lb > lower:
+                    raise DataRequiredError(
+                        f"No values observed at lower bound {lower}"
+                        f" (lowest observed: {lb})"
+                    )
+
+            log_scale = self.parameters[name].get("log_scale", self.DEFAULT_LOG_SCALE)
+            logit_scale = self.parameters[name].get(
+                "logit_scale", self.DEFAULT_LOGIT_SCALE
+            )
+            digits = self.parameters[name].get("digits")
+            is_fidelity = self.parameters[name].get(
+                "is_fidelity", self.DEFAULT_IS_FIDELITY
+            )
+
+            # TODO[tiao]: necessary to check within bounds?
+            target_value = self.parameters[name].get("target_value")
+
+            parameter = RangeParameter(
+                name=name,
+                parameter_type=none_throws(parameter_type),
+                lower=lower,
+                upper=upper,
+                log_scale=log_scale,
+                logit_scale=logit_scale,
+                digits=digits,
+                is_fidelity=is_fidelity,
+                target_value=target_value,
+            )
+            self._parameter_list.append(parameter)
+
+    def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
+        for parameter in self._parameter_list:
+            search_space.add_parameter(parameter)
+        return search_space
+
+    def transform_observation_features(
+        self, observation_features: list[ObservationFeatures]
+    ) -> list[ObservationFeatures]:
+        for obsf in observation_features:
+            self._transform_observation_feature(obsf)
+        return observation_features
+
+    def untransform_observation_features(
+        self, observation_features: list[ObservationFeatures]
+    ) -> list[ObservationFeatures]:
+        for obsf in observation_features:
+            self._untransform_observation_feature(obsf)
+        return observation_features
+
+    def _transform_observation_feature(self, obsf: ObservationFeatures) -> None:
+        _transfer(
+            src=none_throws(obsf.metadata),
+            dst=obsf.parameters,
+            keys=self.parameters.keys(),
+        )
+
+    def _untransform_observation_feature(self, obsf: ObservationFeatures) -> None:
+        obsf.metadata = obsf.metadata or {}
+        _transfer(
+            src=obsf.parameters,
+            dst=obsf.metadata,
+            keys=self.parameters.keys(),
+        )
+
+
+def _transfer(
+    src: dict[str, Any],
+    dst: dict[str, Any],
+    keys: Iterable[str],
+) -> None:
+    """Transfer items in-place from one dictionary to another."""
+    for key in keys:
+        dst[key] = src.pop(key)
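
The `_transfer` helper is what makes the feature-level transform its own inverse: moving a key from `metadata` into `parameters` and back is the same pop-and-assign run in the opposite direction. A tiny sketch of that round trip, with plain dicts standing in for an `ObservationFeatures` instance's attributes and a hypothetical `num_epochs` key used only for illustration:

```python
from typing import Any, Iterable


def _transfer(src: dict[str, Any], dst: dict[str, Any], keys: Iterable[str]) -> None:
    """Transfer items in-place from one dictionary to another (as in the diff above)."""
    for key in keys:
        dst[key] = src.pop(key)


# metadata/parameters as they might appear on an ObservationFeatures instance.
metadata = {"num_epochs": 50}
parameters: dict[str, Any] = {"x": 0.3}

# transform: metadata -> parameters (what _transform_observation_feature does)
_transfer(src=metadata, dst=parameters, keys=["num_epochs"])
assert parameters == {"x": 0.3, "num_epochs": 50} and metadata == {}

# untransform: parameters -> metadata (what _untransform_observation_feature does)
_transfer(src=parameters, dst=metadata, keys=["num_epochs"])
assert parameters == {"x": 0.3} and metadata == {"num_epochs": 50}
```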

0 commit comments
