forked from facebook/Ax
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathfill_missing_parameters.py
More file actions
103 lines (90 loc) · 3.82 KB
/
fill_missing_parameters.py
File metadata and controls
103 lines (90 loc) · 3.82 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
from logging import Logger
from typing import cast, TYPE_CHECKING
from ax.adapter.data_utils import ExperimentData
from ax.adapter.transforms.base import Transform
from ax.core.observation import ObservationFeatures
from ax.core.parameter import DerivedParameter
from ax.core.search_space import SearchSpace
from ax.core.types import TParameterization
from ax.generators.types import TConfig
from ax.utils.common.logger import get_logger
if TYPE_CHECKING:
    # import as module to make sphinx-autodoc-typehints happy
    # Only needed for the ``adapter_module.base.Adapter`` annotation below; with
    # ``from __future__ import annotations`` that annotation is never evaluated
    # at runtime, so the import is skipped outside of type checking.
    from ax import adapter as adapter_module  # noqa F401

# Module-level logger obtained from Ax's ``get_logger`` for this module's name;
# used to report use of the deprecated ``config`` argument.
logger: Logger = get_logger(__name__)
class FillMissingParameters(Transform):
    """Fill parameters that are missing (or ``None``) in arms with fixed values.

    Fill values come from two sources, applied in this order so that the
    second overrides the first:

    1. ``config["fill_values"]`` -- deprecated; its use is logged as an error.
    2. ``search_space.backfill_values()``.

    Derived parameters found among the search space's non-tunable parameters
    are recomputed from the filled arm data in ``transform_experiment_data``.
    """

    def __init__(
        self,
        search_space: SearchSpace | None = None,
        experiment_data: ExperimentData | None = None,
        adapter: adapter_module.base.Adapter | None = None,
        config: TConfig | None = None,  # Deprecated
    ) -> None:
        """Collect fill values and derived parameters.

        Args:
            search_space: Source of backfill values and derived parameters.
            experiment_data: Forwarded to ``Transform.__init__``.
            adapter: Forwarded to ``Transform.__init__``.
            config: Deprecated. If given, its ``"fill_values"`` entry seeds the
                fill values; search-space backfill values take precedence.
        """
        super().__init__(
            search_space=search_space,
            experiment_data=experiment_data,
            adapter=adapter,
            config=config,
        )
        self._fill_values: TParameterization = {}
        self._derived_parameters: dict[str, DerivedParameter] = {}
        # Read fill_values from the deprecated config first so that values
        # coming from the search space (below) override them.
        if config is not None:
            logger.error(
                "Use of config for FillMissingParameters has been deprecated. "
                "Use search_space.add_parameters instead."
            )
            # ``or {}`` also tolerates an explicit ``"fill_values": None`` entry,
            # which would otherwise make ``dict.update`` raise TypeError.
            self._fill_values.update(
                cast(TParameterization, config.get("fill_values") or {})
            )
        if search_space is not None:
            self._fill_values.update(search_space.backfill_values())
            # Keep clones rather than the search space's own parameter objects
            # (presumably to isolate this transform from later mutation).
            self._derived_parameters = {
                name: p.clone()
                for name, p in search_space.nontunable_parameters.items()
                if isinstance(p, DerivedParameter)
            }

    def transform_observation_features(
        self, observation_features: list[ObservationFeatures]
    ) -> list[ObservationFeatures]:
        """Fill missing or ``None`` parameters of each observation feature.

        Each ``ObservationFeatures.parameters`` dict is updated in place and the
        same list is returned. Parameters that already hold a non-``None`` value
        are left untouched.
        """
        # NOTE(review): unlike ``transform_experiment_data``, derived parameters
        # are not (re-)computed here -- confirm this asymmetry is intended.
        for obsf in observation_features:
            params = obsf.parameters
            # ``params.get(name) is None`` covers both "absent" and "None".
            params.update(
                {
                    name: value
                    for name, value in self._fill_values.items()
                    if params.get(name) is None
                }
            )
        return observation_features

    def transform_experiment_data(
        self, experiment_data: ExperimentData
    ) -> ExperimentData:
        """Return a new ``ExperimentData`` with parameters filled in.

        Missing (NaN) entries in existing columns are filled from the fill
        values, fill-value columns absent from ``arm_data`` are added with the
        fill value, and every derived parameter is recomputed from the result.
        ``observation_data`` is passed through unchanged.
        """
        # ``DataFrame.fillna`` raises on a ``None`` fill value, so leave those
        # out of the mapping; such columns keep their missing entries.
        non_null_fills = {
            name: value
            for name, value in self._fill_values.items()
            if value is not None
        }
        # ``fillna`` (not in place) returns a new frame, so the column
        # assignments below do not mutate the caller's ``arm_data``.
        arm_data = experiment_data.arm_data.fillna(value=non_null_fills)
        # Add any fill columns missing from arm_data. Iterating the fill values
        # (instead of a set difference) keeps the appended column order
        # deterministic across runs.
        for name, value in self._fill_values.items():
            if name not in arm_data.columns:
                arm_data[name] = value
        # Derived parameters are deterministic functions of other parameters,
        # so they are always (re-)computed from the filled data.
        for name, param in self._derived_parameters.items():
            arm_data[name] = param.compute_array(arm_data)
        return ExperimentData(
            arm_data=arm_data,
            observation_data=experiment_data.observation_data,
        )