From 5bee7dc9fe526a38746fc7d6ef383e9039ad9cdf Mon Sep 17 00:00:00 2001 From: Sunny Shen Date: Thu, 20 Feb 2025 15:46:21 -0800 Subject: [PATCH] Deprecate FixedParameter in favor of single-element ChoiceParameter (#3397) Summary: Deprecating `FixedParameter` -- if `FixedParameterConfig` is used, we will create a single-element `ChoiceParameter` instead Doing a "soft" deprecation (i.e. instantiate`ChoiceParameter` if `FixedParameterConfig` is given instead of erroring out) because many partner integration code relies on `FixedParameterConfig` (e.g. https://fburl.com/code/7nxmqhim) Differential Revision: D68241762 --- ax/core/parameter.py | 22 ++++--- ax/core/tests/test_parameter.py | 8 --- ax/modelbridge/transforms/remove_fixed.py | 23 ++++--- .../tests/test_remove_fixed_transform.py | 17 +++-- ax/service/utils/instantiation.py | 65 +++++++------------ 5 files changed, 66 insertions(+), 69 deletions(-) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 45ea62d0be2..5f384a8665d 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -527,7 +527,7 @@ class ChoiceParameter(Parameter): name: Name of the parameter. parameter_type: Enum indicating the type of parameter value (e.g. string, int). - values: List of allowed values for the parameter. + values: List of allowed value(s) for the parameter. is_ordered: If False, the parameter is a categorical variable. Defaults to False if parameter_type is STRING and ``values`` is longer than 2, else True. @@ -566,9 +566,6 @@ def __init__( self._is_task = is_task self._is_fidelity = is_fidelity self._target_value: TParamValue = self.cast(target_value) - # A choice parameter with only one value is a FixedParameter. - if not len(values) > 1: - raise UserInputError(f"{self._name}({values}): {FIXED_CHOICE_PARAM_ERROR}") # Cap the number of possible values. if len(values) > MAX_VALUES_CHOICE_PARAM: raise UserInputError( @@ -680,9 +677,6 @@ def set_values(self, values: list[TParamValue]) -> ChoiceParameter: Args: values: New list of allowed values. """ - # A choice parameter with only one value is a FixedParameter. - if not len(values) > 1: - raise UserInputError(FIXED_CHOICE_PARAM_ERROR) self._values = self._cast_values(values) return self @@ -757,7 +751,10 @@ def domain_repr(self) -> str: class FixedParameter(Parameter): - """Parameter object that specifies a single fixed value.""" + """ + *DEPRECATED*: Use ChoiceParameter with a single value instead. + + Parameter object that specifies a single fixed value.""" def __init__( self, @@ -768,7 +765,10 @@ def __init__( target_value: TParamValue = None, dependents: dict[TParamValue, list[str]] | None = None, ) -> None: - """Initialize FixedParameter + """ + *DEPRECATED*: Use ChoiceParameter with a single value instead. + + Initialize FixedParameter Args: name: Name of the parameter. @@ -780,6 +780,10 @@ def __init__( dependents: Optional mapping for parameters in hierarchical search spaces; format is { value -> list of dependent parameter names }. """ + warn( + "FixedParameter is deprecated. Use ChoiceParameter with a single value " + "instead.", + ) if is_fidelity and (target_value is None): raise UserInputError( "`target_value` should not be None for the fidelity parameter: " diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index e0b078e522c..49698e2f595 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -334,14 +334,6 @@ def test_Setter(self) -> None: self.assertTrue(self.param1.validate("bar")) self.assertFalse(self.param1.validate("foo")) - def test_SingleValue(self) -> None: - with self.assertRaises(UserInputError): - ChoiceParameter( - name="x", parameter_type=ParameterType.STRING, values=["foo"] - ) - with self.assertRaises(UserInputError): - self.param1.set_values(["foo"]) - def test_Clone(self) -> None: param_clone = self.param1.clone() self.assertEqual(len(self.param1.values), len(param_clone.values)) diff --git a/ax/modelbridge/transforms/remove_fixed.py b/ax/modelbridge/transforms/remove_fixed.py index 51d3a256ca8..67b49a2f28f 100644 --- a/ax/modelbridge/transforms/remove_fixed.py +++ b/ax/modelbridge/transforms/remove_fixed.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING, Union from ax.core.observation import Observation, ObservationFeatures from ax.core.parameter import ChoiceParameter, FixedParameter, RangeParameter @@ -21,9 +21,12 @@ class RemoveFixed(Transform): - """Remove fixed parameters. + """Remove fixed parameters and single-choice choice parameters from + the search space. + + Fixed parameters and single-choice choice parameters should not be included + in the SearchSpace. - Fixed parameters should not be included in the SearchSpace. This transform removes these parameters, leaving only tunable parameters. Transform is done in-place for observation features. @@ -38,24 +41,25 @@ def __init__( ) -> None: assert search_space is not None, "RemoveFixed requires search space" # Identify parameters that should be transformed - self.fixed_parameters: dict[str, FixedParameter] = { + self.single_choice_params: dict[str, Union[FixedParameter, ChoiceParameter]] = { p_name: p for p_name, p in search_space.parameters.items() if isinstance(p, FixedParameter) + or (isinstance(p, ChoiceParameter) and len(p.values) == 1) } def transform_observation_features( self, observation_features: list[ObservationFeatures] ) -> list[ObservationFeatures]: for obsf in observation_features: - for p_name in self.fixed_parameters: + for p_name in self.single_choice_params: obsf.parameters.pop(p_name, None) return observation_features def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: tunable_parameters: list[ChoiceParameter | RangeParameter] = [] for p in search_space.parameters.values(): - if p.name not in self.fixed_parameters: + if p.name not in self.single_choice_params: # If it's not in fixed_parameters, it must be a tunable param. # pyre: p_ is declared to have type `Union[ChoiceParameter, # pyre: RangeParameter]` but is used as type `ax.core. @@ -75,6 +79,9 @@ def untransform_observation_features( self, observation_features: list[ObservationFeatures] ) -> list[ObservationFeatures]: for obsf in observation_features: - for p_name, p in self.fixed_parameters.items(): - obsf.parameters[p_name] = p.value + for p_name, p in self.single_choice_params.items(): + if isinstance(p, FixedParameter): + obsf.parameters[p_name] = p.value + else: + obsf.parameters[p_name] = p.values[0] return observation_features diff --git a/ax/modelbridge/transforms/tests/test_remove_fixed_transform.py b/ax/modelbridge/transforms/tests/test_remove_fixed_transform.py index 686d744bfef..80c987a8a4e 100644 --- a/ax/modelbridge/transforms/tests/test_remove_fixed_transform.py +++ b/ax/modelbridge/transforms/tests/test_remove_fixed_transform.py @@ -33,6 +33,7 @@ def setUp(self) -> None: "b", parameter_type=ParameterType.STRING, values=["a", "b", "c"] ), FixedParameter("c", parameter_type=ParameterType.STRING, value="a"), + ChoiceParameter("d", parameter_type=ParameterType.INT, values=[1]), ] ) self.t = RemoveFixed( @@ -41,11 +42,11 @@ def setUp(self) -> None: ) def test_Init(self) -> None: - self.assertEqual(list(self.t.fixed_parameters.keys()), ["c"]) + self.assertEqual(list(self.t.single_choice_params.keys()), ["c", "d"]) def test_TransformObservationFeatures(self) -> None: observation_features = [ - ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"}) + ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1}) ] obs_ft2 = deepcopy(observation_features) obs_ft2 = self.t.transform_observation_features(obs_ft2) @@ -56,10 +57,10 @@ def test_TransformObservationFeatures(self) -> None: self.assertEqual(obs_ft2, observation_features) observation_features = [ - ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a"}) + ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "a", "d": 1}) ] observation_features_different = [ - ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b"}) + ObservationFeatures(parameters={"a": 2.2, "b": "b", "c": "b", "d": 10}) ] # Fixed parameter is out of design. It will still get removed. t_obs = self.t.transform_observation_features(observation_features) @@ -72,12 +73,16 @@ def test_TransformSearchSpace(self) -> None: ss2 = self.search_space.clone() ss2 = self.t.transform_search_space(ss2) self.assertEqual(ss2.parameters.get("c"), None) + self.assertEqual(ss2.parameters.get("d"), None) def test_w_parameter_distributions(self) -> None: rss = get_robust_search_space() rss.add_parameter( FixedParameter("d", parameter_type=ParameterType.STRING, value="a"), ) + rss.add_parameter( + ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]), + ) # Transform a non-distributional parameter. t = RemoveFixed( search_space=rss, @@ -90,6 +95,7 @@ def test_w_parameter_distributions(self) -> None: # pyre-fixme[16]: `SearchSpace` has no attribute `parameter_distributions`. self.assertEqual(len(rss.parameter_distributions), 2) self.assertNotIn("d", rss.parameters) + self.assertNotIn("e", rss.parameters) # Test with environmental variables. all_params = list(rss.parameters.values()) rss = RobustSearchSpace( @@ -102,6 +108,9 @@ def test_w_parameter_distributions(self) -> None: rss.add_parameter( FixedParameter("d", parameter_type=ParameterType.STRING, value="a"), ) + rss.add_parameter( + ChoiceParameter("e", parameter_type=ParameterType.INT, values=[1]), + ) t = RemoveFixed( search_space=rss, observations=[], diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 99fb39b142a..e0bfeb4e977 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -7,6 +7,7 @@ # pyre-strict import enum +import warnings from collections.abc import Sequence from copy import deepcopy from dataclasses import dataclass @@ -251,14 +252,17 @@ def _make_choice_param( parameter_type: str | None, ) -> ChoiceParameter: values = representation["values"] - assert isinstance(values, list) and len(values) > 1, ( - f"Cannot parse parameter {name}: for choice parameters, json representation" - " should include a list of two or more values." + assert isinstance(values, list), ( + f"Can't parse parameter {name} with values: {values}," + "values should type list." ) return ChoiceParameter( name=name, parameter_type=cls._to_parameter_type( - values, parameter_type, name, "values" + values, + parameter_type, + name, + "values", ), values=values, is_ordered=assert_is_instance_optional( @@ -283,27 +287,20 @@ def _make_fixed_param( name: str, representation: TParameterRepresentation, parameter_type: str | None, - ) -> FixedParameter: - assert "value" in representation, "Value is required for fixed parameters." - value = representation["value"] - assert type(value) in PARAM_TYPES.values(), ( - f"Cannot parse fixed parameter {name}: for fixed parameters, json " - "representation should include a single value." + ) -> ChoiceParameter: + warnings.warn( + "`fixed` parameters are deprecated. Please use `ChoiceParameter` with a " + "single-value instead. This config will instantiate a `ChoiceParameter`", + DeprecationWarning, + stacklevel=2, ) - return FixedParameter( + # Convert TParameterRepresentation to a ChoiceParameter representation. + representation["values"] = [representation["value"]] # pyre-ignore[6] + representation["type"] = "choice" + return cls._make_choice_param( name=name, - parameter_type=( - cls._get_parameter_type(type(value)) # pyre-ignore[6] - if parameter_type is None - # pyre-ignore[6] - else cls._get_parameter_type(PARAM_TYPES[parameter_type]) - ), - value=value, # pyre-ignore[6] - is_fidelity=assert_is_instance( - representation.get("is_fidelity", False), bool - ), - target_value=representation.get("target_value", None), # pyre-ignore[6] - dependents=representation.get("dependents", None), # pyre-ignore[6] + representation=representation, + parameter_type=parameter_type, ) @classmethod @@ -357,23 +354,11 @@ def parameter_from_json( assert ( "values" in representation ), "Values are required for choice parameters." - values = representation["values"] - if isinstance(values, list) and len(values) == 1: - logger.info( - f"Choice parameter {name} contains only one value, converting to a" - + " fixed parameter instead." - ) - # update the representation to a fixed parameter class - parameter_class = "fixed" - representation["type"] = parameter_class - representation["value"] = values[0] - del representation["values"] - else: - return cls._make_choice_param( - name=name, - representation=representation, - parameter_type=parameter_type, - ) + return cls._make_choice_param( + name=name, + representation=representation, + parameter_type=parameter_type, + ) if parameter_class == "fixed": assert not any(isinstance(val, list) for val in representation.values())