Skip to content

Commit

Permalink
Deprecate FixedParameter in favor of single-element ChoiceParameter (#…
Browse files Browse the repository at this point in the history
…3397)

Summary:

As titled. Deprecating `FixedParameter`


Doing a "soft" deprecation (i.e. instantiate`ChoiceParameter` if `FixedParameterConfig` is given) becuase there are partner integration code that relies on `FixedParameterConfig` (e.g. https://fburl.com/code/7nxmqhim)

Differential Revision: D68241762
  • Loading branch information
Sunny Shen authored and facebook-github-bot committed Feb 20, 2025
1 parent f80caac commit e7cb051
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 60 deletions.
22 changes: 13 additions & 9 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ class ChoiceParameter(Parameter):
name: Name of the parameter.
parameter_type: Enum indicating the type of parameter
value (e.g. string, int).
values: List of allowed values for the parameter.
values: List of allowed value(s) for the parameter.
is_ordered: If False, the parameter is a categorical variable.
Defaults to False if parameter_type is STRING and ``values``
is longer than 2, else True.
Expand Down Expand Up @@ -566,9 +566,6 @@ def __init__(
self._is_task = is_task
self._is_fidelity = is_fidelity
self._target_value: TParamValue = self.cast(target_value)
# A choice parameter with only one value is a FixedParameter.
if not len(values) > 1:
raise UserInputError(f"{self._name}({values}): {FIXED_CHOICE_PARAM_ERROR}")
# Cap the number of possible values.
if len(values) > MAX_VALUES_CHOICE_PARAM:
raise UserInputError(
Expand Down Expand Up @@ -680,9 +677,6 @@ def set_values(self, values: list[TParamValue]) -> ChoiceParameter:
Args:
values: New list of allowed values.
"""
# A choice parameter with only one value is a FixedParameter.
if not len(values) > 1:
raise UserInputError(FIXED_CHOICE_PARAM_ERROR)
self._values = self._cast_values(values)
return self

Expand Down Expand Up @@ -757,7 +751,10 @@ def domain_repr(self) -> str:


class FixedParameter(Parameter):
"""Parameter object that specifies a single fixed value."""
"""
*DEPRECATED*: Use ChoiceParameter with a single value instead.
Parameter object that specifies a single fixed value."""

def __init__(
self,
Expand All @@ -768,7 +765,10 @@ def __init__(
target_value: TParamValue = None,
dependents: dict[TParamValue, list[str]] | None = None,
) -> None:
"""Initialize FixedParameter
"""
*DEPRECATED*: Use ChoiceParameter with a single value instead.
Initialize FixedParameter
Args:
name: Name of the parameter.
Expand All @@ -780,6 +780,10 @@ def __init__(
dependents: Optional mapping for parameters in hierarchical search
spaces; format is { value -> list of dependent parameter names }.
"""
warn(
"FixedParameter is deprecated. Use ChoiceParameter with a single value "
"instead.",
)
if is_fidelity and (target_value is None):
raise UserInputError(
"`target_value` should not be None for the fidelity parameter: "
Expand Down
20 changes: 12 additions & 8 deletions ax/modelbridge/transforms/remove_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Union

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
Expand All @@ -21,9 +21,9 @@


class RemoveFixed(Transform):
"""Remove fixed parameters.
"""Remove fixed parameters and single-choice choice parameters from the search space.
Fixed parameters should not be included in the SearchSpace.
Fixed parameters and single-choice choice parameters should not be included in the SearchSpace.
This transform removes these parameters, leaving only tunable parameters.
Transform is done in-place for observation features.
Expand All @@ -38,24 +38,25 @@ def __init__(
) -> None:
assert search_space is not None, "RemoveFixed requires search space"
# Identify parameters that should be transformed
self.fixed_parameters: dict[str, FixedParameter] = {
self.single_choice_params: dict[str, Union[FixedParameter, ChoiceParameter]] = {
p_name: p
for p_name, p in search_space.parameters.items()
if isinstance(p, FixedParameter)
or (isinstance(p, ChoiceParameter) and len(p.values) == 1)
}

def transform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
for p_name in self.fixed_parameters:
for p_name in self.single_choice_params:
obsf.parameters.pop(p_name, None)
return observation_features

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
tunable_parameters: list[ChoiceParameter | RangeParameter] = []
for p in search_space.parameters.values():
if p.name not in self.fixed_parameters:
if p.name not in self.single_choice_params:
# If it's not in fixed_parameters, it must be a tunable param.
# pyre: p_ is declared to have type `Union[ChoiceParameter,
# pyre: RangeParameter]` but is used as type `ax.core.
Expand All @@ -75,6 +76,9 @@ def untransform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
for p_name, p in self.fixed_parameters.items():
obsf.parameters[p_name] = p.value
for p_name, p in self.single_choice_params.items():
if isinstance(p, FixedParameter):
obsf.parameters[p_name] = p.value
else:
obsf.parameters[p_name] = p.values[0]
return observation_features
17 changes: 13 additions & 4 deletions ax/modelbridge/transforms/tests/test_remove_fixed_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def setUp(self) -> None:
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
FixedParameter("c", parameter_type=ParameterType.STRING, value="a"),
ChoiceParameter("d", parameter_type=ParameterType.INT, values=[1]),
]
)
self.t = RemoveFixed(
Expand All @@ -41,11 +42,11 @@ def setUp(self) -> None:
)

def test_Init(self) -> None:
self.assertEqual(list(self.t.fixed_parameters.keys()), ["c"])
self.assertEqual(list(self.t.single_choice_params.keys()), ["c", "d"])

def test_TransformObservationFeatures(self) -> None:
observation_features = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
]
obs_ft2 = deepcopy(observation_features)
obs_ft2 = self.t.transform_observation_features(obs_ft2)
Expand All @@ -56,10 +57,10 @@ def test_TransformObservationFeatures(self) -> None:
self.assertEqual(obs_ft2, observation_features)

observation_features = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
]
observation_features_different = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b", "d": 10})
]
# Fixed parameter is out of design. It will still get removed.
t_obs = self.t.transform_observation_features(observation_features)
Expand All @@ -72,12 +73,16 @@ def test_TransformSearchSpace(self) -> None:
ss2 = self.search_space.clone()
ss2 = self.t.transform_search_space(ss2)
self.assertEqual(ss2.parameters.get("c"), None)
self.assertEqual(ss2.parameters.get("d"), None)

def test_w_parameter_distributions(self) -> None:
rss = get_robust_search_space()
rss.add_parameter(
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
)
rss.add_parameter(
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
)
# Transform a non-distributional parameter.
t = RemoveFixed(
search_space=rss,
Expand All @@ -90,6 +95,7 @@ def test_w_parameter_distributions(self) -> None:
# pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`.
self.assertEqual(len(rss.parameter_distributions), 2)
self.assertNotIn("d", rss.parameters)
self.assertNotIn("e", rss.parameters)
# Test with environmental variables.
all_params = list(rss.parameters.values())
rss = RobustSearchSpace(
Expand All @@ -102,6 +108,9 @@ def test_w_parameter_distributions(self) -> None:
rss.add_parameter(
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
)
rss.add_parameter(
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
)
t = RemoveFixed(
search_space=rss,
observations=[],
Expand Down
56 changes: 17 additions & 39 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import enum
import warnings
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -251,10 +252,6 @@ def _make_choice_param(
parameter_type: str | None,
) -> ChoiceParameter:
values = representation["values"]
assert isinstance(values, list) and len(values) > 1, (
f"Cannot parse parameter {name}: for choice parameters, json representation"
" should include a list of two or more values."
)
return ChoiceParameter(
name=name,
parameter_type=cls._to_parameter_type(
Expand Down Expand Up @@ -283,27 +280,19 @@ def _make_fixed_param(
name: str,
representation: TParameterRepresentation,
parameter_type: str | None,
) -> FixedParameter:
assert "value" in representation, "Value is required for fixed parameters."
value = representation["value"]
assert type(value) in PARAM_TYPES.values(), (
f"Cannot parse fixed parameter {name}: for fixed parameters, json "
"representation should include a single value."
) -> ChoiceParameter:
warnings.warn(
"`fixed` parameters are deprecated. Please use `ChoiceParameter` with a "
"single-value instead. This config will instantiate a `ChoiceParameter`",
DeprecationWarning,
stacklevel=2,
)
return FixedParameter(
# Convert TParameterRepresentation to a ChoiceParameter representation.
representation["values"] = [representation["value"]] # pyre-ignore[6]
return cls._make_choice_param(
name=name,
parameter_type=(
cls._get_parameter_type(type(value)) # pyre-ignore[6]
if parameter_type is None
# pyre-ignore[6]
else cls._get_parameter_type(PARAM_TYPES[parameter_type])
),
value=value, # pyre-ignore[6]
is_fidelity=assert_is_instance(
representation.get("is_fidelity", False), bool
),
target_value=representation.get("target_value", None), # pyre-ignore[6]
dependents=representation.get("dependents", None), # pyre-ignore[6]
representation=representation,
parameter_type=parameter_type,
)

@classmethod
Expand Down Expand Up @@ -358,22 +347,11 @@ def parameter_from_json(
"values" in representation
), "Values are required for choice parameters."
values = representation["values"]
if isinstance(values, list) and len(values) == 1:
logger.info(
f"Choice parameter {name} contains only one value, converting to a"
+ " fixed parameter instead."
)
# update the representation to a fixed parameter class
parameter_class = "fixed"
representation["type"] = parameter_class
representation["value"] = values[0]
del representation["values"]
else:
return cls._make_choice_param(
name=name,
representation=representation,
parameter_type=parameter_type,
)
return cls._make_choice_param(
name=name,
representation=representation,
parameter_type=parameter_type,
)

if parameter_class == "fixed":
assert not any(isinstance(val, list) for val in representation.values())
Expand Down

0 comments on commit e7cb051

Please sign in to comment.