Skip to content

Commit e7cb051

Browse files
Sunny Shenfacebook-github-bot
authored andcommitted
Deprecate FixedParameter in favor of single-element ChoiceParameter (#3397)
Summary: As titled. Deprecating `FixedParameter` Doing a "soft" deprecation (i.e. instantiate`ChoiceParameter` if `FixedParameterConfig` is given) becuase there are partner integration code that relies on `FixedParameterConfig` (e.g. https://fburl.com/code/7nxmqhim) Differential Revision: D68241762
1 parent f80caac commit e7cb051

File tree

4 files changed

+55
-60
lines changed

4 files changed

+55
-60
lines changed

ax/core/parameter.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ class ChoiceParameter(Parameter):
527527
name: Name of the parameter.
528528
parameter_type: Enum indicating the type of parameter
529529
value (e.g. string, int).
530-
values: List of allowed values for the parameter.
530+
values: List of allowed value(s) for the parameter.
531531
is_ordered: If False, the parameter is a categorical variable.
532532
Defaults to False if parameter_type is STRING and ``values``
533533
is longer than 2, else True.
@@ -566,9 +566,6 @@ def __init__(
566566
self._is_task = is_task
567567
self._is_fidelity = is_fidelity
568568
self._target_value: TParamValue = self.cast(target_value)
569-
# A choice parameter with only one value is a FixedParameter.
570-
if not len(values) > 1:
571-
raise UserInputError(f"{self._name}({values}): {FIXED_CHOICE_PARAM_ERROR}")
572569
# Cap the number of possible values.
573570
if len(values) > MAX_VALUES_CHOICE_PARAM:
574571
raise UserInputError(
@@ -680,9 +677,6 @@ def set_values(self, values: list[TParamValue]) -> ChoiceParameter:
680677
Args:
681678
values: New list of allowed values.
682679
"""
683-
# A choice parameter with only one value is a FixedParameter.
684-
if not len(values) > 1:
685-
raise UserInputError(FIXED_CHOICE_PARAM_ERROR)
686680
self._values = self._cast_values(values)
687681
return self
688682

@@ -757,7 +751,10 @@ def domain_repr(self) -> str:
757751

758752

759753
class FixedParameter(Parameter):
760-
"""Parameter object that specifies a single fixed value."""
754+
"""
755+
*DEPRECATED*: Use ChoiceParameter with a single value instead.
756+
757+
Parameter object that specifies a single fixed value."""
761758

762759
def __init__(
763760
self,
@@ -768,7 +765,10 @@ def __init__(
768765
target_value: TParamValue = None,
769766
dependents: dict[TParamValue, list[str]] | None = None,
770767
) -> None:
771-
"""Initialize FixedParameter
768+
"""
769+
*DEPRECATED*: Use ChoiceParameter with a single value instead.
770+
771+
Initialize FixedParameter
772772
773773
Args:
774774
name: Name of the parameter.
@@ -780,6 +780,10 @@ def __init__(
780780
dependents: Optional mapping for parameters in hierarchical search
781781
spaces; format is { value -> list of dependent parameter names }.
782782
"""
783+
warn(
784+
"FixedParameter is deprecated. Use ChoiceParameter with a single value "
785+
"instead.",
786+
)
783787
if is_fidelity and (target_value is None):
784788
raise UserInputError(
785789
"`target_value` should not be None for the fidelity parameter: "

ax/modelbridge/transforms/remove_fixed.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from typing import Optional, TYPE_CHECKING
9+
from typing import Optional, TYPE_CHECKING, Union
1010

1111
from ax.core.observation import Observation, ObservationFeatures
1212
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
@@ -21,9 +21,9 @@
2121

2222

2323
class RemoveFixed(Transform):
24-
"""Remove fixed parameters.
24+
"""Remove fixed parameters and single-choice choice parameters from the search space.
2525
26-
Fixed parameters should not be included in the SearchSpace.
26+
Fixed parameters and single-choice choice parameters should not be included in the SearchSpace.
2727
This transform removes these parameters, leaving only tunable parameters.
2828
2929
Transform is done in-place for observation features.
@@ -38,24 +38,25 @@ def __init__(
3838
) -> None:
3939
assert search_space is not None, "RemoveFixed requires search space"
4040
# Identify parameters that should be transformed
41-
self.fixed_parameters: dict[str, FixedParameter] = {
41+
self.single_choice_params: dict[str, Union[FixedParameter, ChoiceParameter]] = {
4242
p_name: p
4343
for p_name, p in search_space.parameters.items()
4444
if isinstance(p, FixedParameter)
45+
or (isinstance(p, ChoiceParameter) and len(p.values) == 1)
4546
}
4647

4748
def transform_observation_features(
4849
self, observation_features: list[ObservationFeatures]
4950
) -> list[ObservationFeatures]:
5051
for obsf in observation_features:
51-
for p_name in self.fixed_parameters:
52+
for p_name in self.single_choice_params:
5253
obsf.parameters.pop(p_name, None)
5354
return observation_features
5455

5556
def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
5657
tunable_parameters: list[ChoiceParameter | RangeParameter] = []
5758
for p in search_space.parameters.values():
58-
if p.name not in self.fixed_parameters:
59+
if p.name not in self.single_choice_params:
5960
# If it's not in fixed_parameters, it must be a tunable param.
6061
# pyre: p_ is declared to have type `Union[ChoiceParameter,
6162
# pyre: RangeParameter]` but is used as type `ax.core.
@@ -75,6 +76,9 @@ def untransform_observation_features(
7576
self, observation_features: list[ObservationFeatures]
7677
) -> list[ObservationFeatures]:
7778
for obsf in observation_features:
78-
for p_name, p in self.fixed_parameters.items():
79-
obsf.parameters[p_name] = p.value
79+
for p_name, p in self.single_choice_params.items():
80+
if isinstance(p, FixedParameter):
81+
obsf.parameters[p_name] = p.value
82+
else:
83+
obsf.parameters[p_name] = p.values[0]
8084
return observation_features

ax/modelbridge/transforms/tests/test_remove_fixed_transform.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def setUp(self) -> None:
3333
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
3434
),
3535
FixedParameter("c", parameter_type=ParameterType.STRING, value="a"),
36+
ChoiceParameter("d", parameter_type=ParameterType.INT, values=[1]),
3637
]
3738
)
3839
self.t = RemoveFixed(
@@ -41,11 +42,11 @@ def setUp(self) -> None:
4142
)
4243

4344
def test_Init(self) -> None:
44-
self.assertEqual(list(self.t.fixed_parameters.keys()), ["c"])
45+
self.assertEqual(list(self.t.single_choice_params.keys()), ["c", "d"])
4546

4647
def test_TransformObservationFeatures(self) -> None:
4748
observation_features = [
48-
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
49+
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
4950
]
5051
obs_ft2 = deepcopy(observation_features)
5152
obs_ft2 = self.t.transform_observation_features(obs_ft2)
@@ -56,10 +57,10 @@ def test_TransformObservationFeatures(self) -> None:
5657
self.assertEqual(obs_ft2, observation_features)
5758

5859
observation_features = [
59-
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
60+
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
6061
]
6162
observation_features_different = [
62-
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b"})
63+
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b", "d": 10})
6364
]
6465
# Fixed parameter is out of design. It will still get removed.
6566
t_obs = self.t.transform_observation_features(observation_features)
@@ -72,12 +73,16 @@ def test_TransformSearchSpace(self) -> None:
7273
ss2 = self.search_space.clone()
7374
ss2 = self.t.transform_search_space(ss2)
7475
self.assertEqual(ss2.parameters.get("c"), None)
76+
self.assertEqual(ss2.parameters.get("d"), None)
7577

7678
def test_w_parameter_distributions(self) -> None:
7779
rss = get_robust_search_space()
7880
rss.add_parameter(
7981
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
8082
)
83+
rss.add_parameter(
84+
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
85+
)
8186
# Transform a non-distributional parameter.
8287
t = RemoveFixed(
8388
search_space=rss,
@@ -90,6 +95,7 @@ def test_w_parameter_distributions(self) -> None:
9095
# pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`.
9196
self.assertEqual(len(rss.parameter_distributions), 2)
9297
self.assertNotIn("d", rss.parameters)
98+
self.assertNotIn("e", rss.parameters)
9399
# Test with environmental variables.
94100
all_params = list(rss.parameters.values())
95101
rss = RobustSearchSpace(
@@ -102,6 +108,9 @@ def test_w_parameter_distributions(self) -> None:
102108
rss.add_parameter(
103109
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
104110
)
111+
rss.add_parameter(
112+
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
113+
)
105114
t = RemoveFixed(
106115
search_space=rss,
107116
observations=[],

ax/service/utils/instantiation.py

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
# pyre-strict
88

99
import enum
10+
import warnings
1011
from collections.abc import Sequence
1112
from copy import deepcopy
1213
from dataclasses import dataclass
@@ -251,10 +252,6 @@ def _make_choice_param(
251252
parameter_type: str | None,
252253
) -> ChoiceParameter:
253254
values = representation["values"]
254-
assert isinstance(values, list) and len(values) > 1, (
255-
f"Cannot parse parameter {name}: for choice parameters, json representation"
256-
" should include a list of two or more values."
257-
)
258255
return ChoiceParameter(
259256
name=name,
260257
parameter_type=cls._to_parameter_type(
@@ -283,27 +280,19 @@ def _make_fixed_param(
283280
name: str,
284281
representation: TParameterRepresentation,
285282
parameter_type: str | None,
286-
) -> FixedParameter:
287-
assert "value" in representation, "Value is required for fixed parameters."
288-
value = representation["value"]
289-
assert type(value) in PARAM_TYPES.values(), (
290-
f"Cannot parse fixed parameter {name}: for fixed parameters, json "
291-
"representation should include a single value."
283+
) -> ChoiceParameter:
284+
warnings.warn(
285+
"`fixed` parameters are deprecated. Please use `ChoiceParameter` with a "
286+
"single-value instead. This config will instantiate a `ChoiceParameter`",
287+
DeprecationWarning,
288+
stacklevel=2,
292289
)
293-
return FixedParameter(
290+
# Convert TParameterRepresentation to a ChoiceParameter representation.
291+
representation["values"] = [representation["value"]] # pyre-ignore[6]
292+
return cls._make_choice_param(
294293
name=name,
295-
parameter_type=(
296-
cls._get_parameter_type(type(value)) # pyre-ignore[6]
297-
if parameter_type is None
298-
# pyre-ignore[6]
299-
else cls._get_parameter_type(PARAM_TYPES[parameter_type])
300-
),
301-
value=value, # pyre-ignore[6]
302-
is_fidelity=assert_is_instance(
303-
representation.get("is_fidelity", False), bool
304-
),
305-
target_value=representation.get("target_value", None), # pyre-ignore[6]
306-
dependents=representation.get("dependents", None), # pyre-ignore[6]
294+
representation=representation,
295+
parameter_type=parameter_type,
307296
)
308297

309298
@classmethod
@@ -358,22 +347,11 @@ def parameter_from_json(
358347
"values" in representation
359348
), "Values are required for choice parameters."
360349
values = representation["values"]
361-
if isinstance(values, list) and len(values) == 1:
362-
logger.info(
363-
f"Choice parameter {name} contains only one value, converting to a"
364-
+ " fixed parameter instead."
365-
)
366-
# update the representation to a fixed parameter class
367-
parameter_class = "fixed"
368-
representation["type"] = parameter_class
369-
representation["value"] = values[0]
370-
del representation["values"]
371-
else:
372-
return cls._make_choice_param(
373-
name=name,
374-
representation=representation,
375-
parameter_type=parameter_type,
376-
)
350+
return cls._make_choice_param(
351+
name=name,
352+
representation=representation,
353+
parameter_type=parameter_type,
354+
)
377355

378356
if parameter_class == "fixed":
379357
assert not any(isinstance(val, list) for val in representation.values())

0 commit comments

Comments
 (0)