forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adds subclass of MetadataToRange Transform that provides sensible def…
…aults for MapData (facebook#3155) Summary: Pull Request resolved: facebook#3155 Differential Revision: D66945078
- Loading branch information
1 parent
8779887
commit 0e94c4a
Showing
2 changed files
with
170 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from typing import Any, Optional, TYPE_CHECKING | ||
|
||
from ax.core.map_metric import MapMetric | ||
from ax.core.observation import Observation, ObservationFeatures | ||
from ax.core.search_space import SearchSpace | ||
from ax.modelbridge.transforms.metadata_to_range import MetadataToRange | ||
from ax.models.types import TConfig | ||
from pyre_extensions import assert_is_instance | ||
|
||
if TYPE_CHECKING: | ||
# import as module to make sphinx-autodoc-typehints happy | ||
from ax import modelbridge as modelbridge_module # noqa F401 | ||
|
||
|
||
class MapKeyToRange(MetadataToRange): | ||
DEFAULT_LOG_SCALE: bool = True | ||
DEFAULT_MAP_KEY: str = MapMetric.map_key_info.key | ||
|
||
def __init__( | ||
self, | ||
search_space: SearchSpace | None = None, | ||
observations: list[Observation] | None = None, | ||
modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, | ||
config: TConfig | None = None, | ||
) -> None: | ||
config = config or {} | ||
self.parameters: dict[str, dict[str, Any]] = assert_is_instance( | ||
config.setdefault("parameters", {}), dict | ||
) | ||
# TODO[tiao]: raise warning if `DEFAULT_MAP_KEY` is already in keys(?) | ||
self.parameters.setdefault(self.DEFAULT_MAP_KEY, {}) | ||
super().__init__( | ||
search_space=search_space, | ||
observations=observations, | ||
modelbridge=modelbridge, | ||
config=config, | ||
) | ||
|
||
def _transform_observation_feature(self, obsf: ObservationFeatures) -> None: | ||
if not obsf.parameters: | ||
for p in self._parameter_list: | ||
# TODO[tiao]: can we use be p.target_value? | ||
# (not its original intended use but could be advantageous) | ||
obsf.parameters[p.name] = p.upper | ||
return | ||
super()._transform_observation_feature(obsf) |
116 changes: 116 additions & 0 deletions
116
ax/modelbridge/transforms/tests/test_map_key_to_range_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
from copy import deepcopy | ||
from typing import Iterator | ||
|
||
import numpy as np | ||
from ax.core.observation import Observation, ObservationData, ObservationFeatures | ||
from ax.core.parameter import ParameterType, RangeParameter | ||
from ax.core.search_space import SearchSpace | ||
from ax.exceptions.core import DataRequiredError | ||
from ax.modelbridge.transforms.map_key_to_range import MapKeyToRange | ||
from ax.utils.common.testutils import TestCase | ||
from pyre_extensions import assert_is_instance | ||
|
||
|
||
WIDTHS = [2.0, 4.0, 8.0] | ||
HEIGHTS = [4.0, 2.0, 8.0] | ||
STEPS_ENDS = [1, 5, 3] | ||
|
||
|
||
def _enumerate() -> Iterator[tuple[int, float, float, float]]: | ||
yield from ( | ||
(trial_index, width, height, float(i + 1)) | ||
for trial_index, (width, height, steps_end) in enumerate( | ||
zip(WIDTHS, HEIGHTS, STEPS_ENDS) | ||
) | ||
for i in range(steps_end) | ||
) | ||
|
||
|
||
class MapKeyToRangeTransformTest(TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
|
||
self.search_space = SearchSpace( | ||
parameters=[ | ||
RangeParameter( | ||
name="width", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
RangeParameter( | ||
name="height", | ||
parameter_type=ParameterType.FLOAT, | ||
lower=1, | ||
upper=20, | ||
), | ||
] | ||
) | ||
|
||
self.observations = [] | ||
for trial_index, width, height, steps in _enumerate(): | ||
obs_feat = ObservationFeatures( | ||
trial_index=trial_index, | ||
parameters={"width": width, "height": height}, | ||
metadata={ | ||
"foo": 42, | ||
MapKeyToRange.DEFAULT_MAP_KEY: steps, | ||
}, | ||
) | ||
obs_data = ObservationData( | ||
metric_names=[], means=np.array([]), covariance=np.empty((0, 0)) | ||
) | ||
self.observations.append(Observation(features=obs_feat, data=obs_data)) | ||
|
||
# does not require explicitly specifying `config` | ||
self.t = MapKeyToRange( | ||
observations=self.observations, | ||
) | ||
|
||
def test_Init(self) -> None: | ||
self.assertEqual(len(self.t._parameter_list), 1) | ||
|
||
p = self.t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertTrue(p.log_scale) | ||
|
||
# test that one is able to override default config | ||
t = MapKeyToRange( | ||
observations=self.observations, | ||
config={ | ||
"parameters": {MapKeyToRange.DEFAULT_MAP_KEY: {"log_scale": False}} | ||
}, | ||
) | ||
self.assertDictEqual(t.parameters, {"steps": {"log_scale": False}}) | ||
|
||
self.assertEqual(len(t._parameter_list), 1) | ||
|
||
p = t._parameter_list[0] | ||
|
||
self.assertEqual(p.name, MapKeyToRange.DEFAULT_MAP_KEY) | ||
self.assertEqual(p.parameter_type, ParameterType.FLOAT) | ||
self.assertEqual(p.lower, 1.0) | ||
self.assertEqual(p.upper, 5.0) | ||
self.assertFalse(p.log_scale) | ||
|
||
def test_TransformObservationFeaturesWithEmptyParameters(self) -> None: | ||
obsf = ObservationFeatures(parameters={}) | ||
self.t.transform_observation_features([obsf]) | ||
|
||
p = self.t._parameter_list[0] | ||
self.assertEqual( | ||
obsf, | ||
ObservationFeatures(parameters={MapKeyToRange.DEFAULT_MAP_KEY: p.upper}), | ||
) |