Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate FixedParameter in favor of single-element ChoiceParameter #3397

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions ax/core/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,7 @@ class ChoiceParameter(Parameter):
name: Name of the parameter.
parameter_type: Enum indicating the type of parameter
value (e.g. string, int).
values: List of allowed values for the parameter.
values: List of allowed value(s) for the parameter.
is_ordered: If False, the parameter is a categorical variable.
Defaults to False if parameter_type is STRING and ``values``
is longer than 2, else True.
Expand Down Expand Up @@ -566,9 +566,6 @@ def __init__(
self._is_task = is_task
self._is_fidelity = is_fidelity
self._target_value: TParamValue = self.cast(target_value)
# A choice parameter with only one value is a FixedParameter.
if not len(values) > 1:
raise UserInputError(f"{self._name}({values}): {FIXED_CHOICE_PARAM_ERROR}")
# Cap the number of possible values.
if len(values) > MAX_VALUES_CHOICE_PARAM:
raise UserInputError(
Expand Down Expand Up @@ -680,9 +677,6 @@ def set_values(self, values: list[TParamValue]) -> ChoiceParameter:
Args:
values: New list of allowed values.
"""
# A choice parameter with only one value is a FixedParameter.
if not len(values) > 1:
raise UserInputError(FIXED_CHOICE_PARAM_ERROR)
self._values = self._cast_values(values)
return self

Expand Down Expand Up @@ -757,7 +751,10 @@ def domain_repr(self) -> str:


class FixedParameter(Parameter):
"""Parameter object that specifies a single fixed value."""
"""
*DEPRECATED*: Use ChoiceParameter with a single value instead.

Parameter object that specifies a single fixed value."""

def __init__(
self,
Expand All @@ -768,7 +765,10 @@ def __init__(
target_value: TParamValue = None,
dependents: dict[TParamValue, list[str]] | None = None,
) -> None:
"""Initialize FixedParameter
"""
*DEPRECATED*: Use ChoiceParameter with a single value instead.

Initialize FixedParameter

Args:
name: Name of the parameter.
Expand All @@ -780,6 +780,10 @@ def __init__(
dependents: Optional mapping for parameters in hierarchical search
spaces; format is { value -> list of dependent parameter names }.
"""
warn(
"FixedParameter is deprecated. Use ChoiceParameter with a single value "
"instead.",
)
if is_fidelity and (target_value is None):
raise UserInputError(
"`target_value` should not be None for the fidelity parameter: "
Expand Down
8 changes: 0 additions & 8 deletions ax/core/tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,14 +334,6 @@ def test_Setter(self) -> None:
self.assertTrue(self.param1.validate("bar"))
self.assertFalse(self.param1.validate("foo"))

def test_SingleValue(self) -> None:
with self.assertRaises(UserInputError):
ChoiceParameter(
name="x", parameter_type=ParameterType.STRING, values=["foo"]
)
with self.assertRaises(UserInputError):
self.param1.set_values(["foo"])

def test_Clone(self) -> None:
param_clone = self.param1.clone()
self.assertEqual(len(self.param1.values), len(param_clone.values))
Expand Down
23 changes: 15 additions & 8 deletions ax/modelbridge/transforms/remove_fixed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# pyre-strict

from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, Union

from ax.core.observation import Observation, ObservationFeatures
from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter
Expand All @@ -21,9 +21,12 @@


class RemoveFixed(Transform):
"""Remove fixed parameters.
"""Remove fixed parameters and single-choice choice parameters from
the search space.

Fixed parameters and single-choice choice parameters should not be included
in the SearchSpace.

Fixed parameters should not be included in the SearchSpace.
This transform removes these parameters, leaving only tunable parameters.

Transform is done in-place for observation features.
Expand All @@ -38,24 +41,25 @@ def __init__(
) -> None:
assert search_space is not None, "RemoveFixed requires search space"
# Identify parameters that should be transformed
self.fixed_parameters: dict[str, FixedParameter] = {
self.single_choice_params: dict[str, Union[FixedParameter, ChoiceParameter]] = {
p_name: p
for p_name, p in search_space.parameters.items()
if isinstance(p, FixedParameter)
or (isinstance(p, ChoiceParameter) and len(p.values) == 1)
}

def transform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
for p_name in self.fixed_parameters:
for p_name in self.single_choice_params:
obsf.parameters.pop(p_name, None)
return observation_features

def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
tunable_parameters: list[ChoiceParameter | RangeParameter] = []
for p in search_space.parameters.values():
if p.name not in self.fixed_parameters:
if p.name not in self.single_choice_params:
# If it's not in fixed_parameters, it must be a tunable param.
# pyre: p_ is declared to have type `Union[ChoiceParameter,
# pyre: RangeParameter]` but is used as type `ax.core.
Expand All @@ -75,6 +79,9 @@ def untransform_observation_features(
self, observation_features: list[ObservationFeatures]
) -> list[ObservationFeatures]:
for obsf in observation_features:
for p_name, p in self.fixed_parameters.items():
obsf.parameters[p_name] = p.value
for p_name, p in self.single_choice_params.items():
if isinstance(p, FixedParameter):
obsf.parameters[p_name] = p.value
else:
obsf.parameters[p_name] = p.values[0]
return observation_features
17 changes: 13 additions & 4 deletions ax/modelbridge/transforms/tests/test_remove_fixed_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def setUp(self) -> None:
"b", parameter_type=ParameterType.STRING, values=["a", "b", "c"]
),
FixedParameter("c", parameter_type=ParameterType.STRING, value="a"),
ChoiceParameter("d", parameter_type=ParameterType.INT, values=[1]),
]
)
self.t = RemoveFixed(
Expand All @@ -41,11 +42,11 @@ def setUp(self) -> None:
)

def test_Init(self) -> None:
self.assertEqual(list(self.t.fixed_parameters.keys()), ["c"])
self.assertEqual(list(self.t.single_choice_params.keys()), ["c", "d"])

def test_TransformObservationFeatures(self) -> None:
observation_features = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
]
obs_ft2 = deepcopy(observation_features)
obs_ft2 = self.t.transform_observation_features(obs_ft2)
Expand All @@ -56,10 +57,10 @@ def test_TransformObservationFeatures(self) -> None:
self.assertEqual(obs_ft2, observation_features)

observation_features = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1})
]
observation_features_different = [
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b"})
ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b", "d": 10})
]
# Fixed parameter is out of design. It will still get removed.
t_obs = self.t.transform_observation_features(observation_features)
Expand All @@ -72,12 +73,16 @@ def test_TransformSearchSpace(self) -> None:
ss2 = self.search_space.clone()
ss2 = self.t.transform_search_space(ss2)
self.assertEqual(ss2.parameters.get("c"), None)
self.assertEqual(ss2.parameters.get("d"), None)

def test_w_parameter_distributions(self) -> None:
rss = get_robust_search_space()
rss.add_parameter(
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
)
rss.add_parameter(
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
)
# Transform a non-distributional parameter.
t = RemoveFixed(
search_space=rss,
Expand All @@ -90,6 +95,7 @@ def test_w_parameter_distributions(self) -> None:
# pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`.
self.assertEqual(len(rss.parameter_distributions), 2)
self.assertNotIn("d", rss.parameters)
self.assertNotIn("e", rss.parameters)
# Test with environmental variables.
all_params = list(rss.parameters.values())
rss = RobustSearchSpace(
Expand All @@ -102,6 +108,9 @@ def test_w_parameter_distributions(self) -> None:
rss.add_parameter(
FixedParameter("d", parameter_type=ParameterType.STRING, value="a"),
)
rss.add_parameter(
ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]),
)
t = RemoveFixed(
search_space=rss,
observations=[],
Expand Down
65 changes: 25 additions & 40 deletions ax/service/utils/instantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# pyre-strict

import enum
import warnings
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -251,14 +252,17 @@ def _make_choice_param(
parameter_type: str | None,
) -> ChoiceParameter:
values = representation["values"]
assert isinstance(values, list) and len(values) > 1, (
f"Cannot parse parameter {name}: for choice parameters, json representation"
" should include a list of two or more values."
assert isinstance(values, list), (
f"Can't parse parameter {name} with values: {values},"
"values should type list."
)
return ChoiceParameter(
name=name,
parameter_type=cls._to_parameter_type(
values, parameter_type, name, "values"
values,
parameter_type,
name,
"values",
),
values=values,
is_ordered=assert_is_instance_optional(
Expand All @@ -283,27 +287,20 @@ def _make_fixed_param(
name: str,
representation: TParameterRepresentation,
parameter_type: str | None,
) -> FixedParameter:
assert "value" in representation, "Value is required for fixed parameters."
value = representation["value"]
assert type(value) in PARAM_TYPES.values(), (
f"Cannot parse fixed parameter {name}: for fixed parameters, json "
"representation should include a single value."
) -> ChoiceParameter:
warnings.warn(
"`fixed` parameters are deprecated. Please use `ChoiceParameter` with a "
"single-value instead. This config will instantiate a `ChoiceParameter`",
DeprecationWarning,
stacklevel=2,
)
return FixedParameter(
# Convert TParameterRepresentation to a ChoiceParameter representation.
representation["values"] = [representation["value"]] # pyre-ignore[6]
representation["type"] = "choice"
return cls._make_choice_param(
name=name,
parameter_type=(
cls._get_parameter_type(type(value)) # pyre-ignore[6]
if parameter_type is None
# pyre-ignore[6]
else cls._get_parameter_type(PARAM_TYPES[parameter_type])
),
value=value, # pyre-ignore[6]
is_fidelity=assert_is_instance(
representation.get("is_fidelity", False), bool
),
target_value=representation.get("target_value", None), # pyre-ignore[6]
dependents=representation.get("dependents", None), # pyre-ignore[6]
representation=representation,
parameter_type=parameter_type,
)

@classmethod
Expand Down Expand Up @@ -357,23 +354,11 @@ def parameter_from_json(
assert (
"values" in representation
), "Values are required for choice parameters."
values = representation["values"]
if isinstance(values, list) and len(values) == 1:
logger.info(
f"Choice parameter {name} contains only one value, converting to a"
+ " fixed parameter instead."
)
# update the representation to a fixed parameter class
parameter_class = "fixed"
representation["type"] = parameter_class
representation["value"] = values[0]
del representation["values"]
else:
return cls._make_choice_param(
name=name,
representation=representation,
parameter_type=parameter_type,
)
return cls._make_choice_param(
name=name,
representation=representation,
parameter_type=parameter_type,
)

if parameter_class == "fixed":
assert not any(isinstance(val, list) for val in representation.values())
Expand Down