From f7d711c9e8de1b9f9961ab5509bd66cdba61c2a1 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 3 Jun 2026 13:23:10 -0700 Subject: [PATCH 1/4] Migrate TL adapter utils tests to non-fb location and drop shims Summary: The TL adapter utilities `get_joint_search_space`, `merge_dependents`, `merge_parameters` (in `ax/adapter/transfer_learning/utils.py`) and `get_mapped_parameter_names` (in `ax/adapter/transfer_learning/utils_torch.py`) were previously migrated out of `ax/fb/adapter/`, which left behind pure re-export shims at `ax/fb/adapter/utils.py` and `ax/fb/adapter/utils_torch.py`. The only remaining coverage for these functions lived in `ax/fb/adapter/tests/test_utils.py` and `test_utils_torch.py`, exercising the migrated code through those shims -- and the non-fb destination had no test coverage of its own. This moves both test files to `ax/adapter/transfer_learning/tests/`, switches their imports to the real non-fb modules (`ax.adapter.transfer_learning.utils`/`utils_torch`, `ax.core.auxiliary_source.AuxiliarySource`, `ax.utils.common.testutils.TestCase`), and removes the now-unused `get_unordered_choice`/`get_ordered_choice` helpers from `test_utils.py`. Since the two test files were the only callers of the `ax.fb.adapter.utils`/`utils_torch` shims repo-wide, the shims are deleted and the BUCK targets are updated accordingly: a new `test_utils` `python_unittest` is added under `ax/adapter/transfer_learning/BUCK`, and the old `test_utils` target, the orphaned `:utils` library, and the `utils_torch.py` src are removed from `ax/fb/adapter/BUCK`. The broader `ax.fb.core.auxiliary_source` shim is left in place; it still has many callers across admarket, pts, automl, storage, and docs, so cleaning it up is a separate effort. 
Differential Revision: D107429272 --- .../transfer_learning/tests/test_utils.py | 349 ++++++++++++++++++ .../tests/test_utils_torch.py | 118 ++++++ 2 files changed, 467 insertions(+) create mode 100644 ax/adapter/transfer_learning/tests/test_utils.py create mode 100644 ax/adapter/transfer_learning/tests/test_utils_torch.py diff --git a/ax/adapter/transfer_learning/tests/test_utils.py b/ax/adapter/transfer_learning/tests/test_utils.py new file mode 100644 index 00000000000..ccc644778f7 --- /dev/null +++ b/ax/adapter/transfer_learning/tests/test_utils.py @@ -0,0 +1,349 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.adapter.transfer_learning.utils import ( + get_joint_search_space, + merge_dependents, + merge_parameters, +) +from ax.core.auxiliary_source import AuxiliarySource +from ax.core.experiment import Experiment +from ax.core.parameter import ( + ChoiceParameter, + DerivedParameter, + FixedParameter, + Parameter, + ParameterType, + RangeParameter, +) +from ax.core.search_space import SearchSpace +from ax.utils.common.testutils import TestCase +from pyre_extensions import assert_is_instance, none_throws + + +class AxFbCoreUtilsTest(TestCase): + def test_get_joint_search_space(self) -> None: + parameters: list[Parameter] = [ + RangeParameter(f"x{i}", parameter_type=ParameterType.INT, lower=0, upper=5) + for i in range(3) + ] + exp1 = Experiment( + search_space=SearchSpace(parameters=parameters[:2]), name="test1" + ) + exp2 = Experiment( + search_space=SearchSpace(parameters=parameters[:2]), name="test2" + ) + exp3 = Experiment( + search_space=SearchSpace(parameters=parameters[1:]), name="test3" + ) + aux_2 = AuxiliarySource(experiment=exp2) + aux_3 = AuxiliarySource(experiment=exp3) + aux_4 = AuxiliarySource(experiment=exp3, transfer_param_config={"x0": "x2"}) + for exp, aux_srcs, expected_params in 
( + (exp1, [aux_2], {"x0", "x1"}), + (exp1, [aux_2, aux_3], {"x0", "x1", "x2"}), + (exp1, [aux_2, aux_4], {"x0", "x1"}), + ): + self.assertEqual( + set( + get_joint_search_space( + search_space=exp.search_space, auxiliary_sources=aux_srcs + ).parameters.keys() + ), + expected_params, + ) + + def test_get_joint_search_space_update_fixed_params(self) -> None: + # test update fixed params + range_param = RangeParameter( + "x", parameter_type=ParameterType.INT, lower=0, upper=5 + ) + fixed_param1 = FixedParameter("y", parameter_type=ParameterType.INT, value=1) + fixed_param2 = FixedParameter("y", parameter_type=ParameterType.INT, value=2) + exp = Experiment( + search_space=SearchSpace(parameters=[range_param, fixed_param1]), + name="test1", + ) + exp2 = Experiment( + search_space=SearchSpace(parameters=[range_param, fixed_param2]), + name="test2", + ) + for update_fixed_params in [True, False]: + aux2 = AuxiliarySource( + experiment=exp2, update_fixed_params=update_fixed_params + ) + ss_params = get_joint_search_space( + search_space=exp.search_space, auxiliary_sources=[aux2] + ).parameters + self.assertEqual( + assert_is_instance(ss_params["y"], FixedParameter).value, 1 + ) + self.assertIn("x", ss_params) + + def test_get_joint_search_space_with_hss_and_choice(self) -> None: + ss1 = SearchSpace( + parameters=[ + FixedParameter( + "root", + parameter_type=ParameterType.INT, + value=1, + dependents={1: ["learning_rate", "optimizer", "method"]}, + ), + ChoiceParameter( + "learning_rate", + parameter_type=ParameterType.FLOAT, + values=[0.01, 0.05], + ), + ChoiceParameter( + "optimizer", + parameter_type=ParameterType.STRING, + values=["Adam", "SGD", "AdaGrad"], + ), + ChoiceParameter( + "method", + parameter_type=ParameterType.STRING, + values=["train", "eval"], + ), + ] + ) + ss2 = SearchSpace( + parameters=[ + FixedParameter( + "root2", + parameter_type=ParameterType.INT, + value=1, + dependents={1: ["lr", "optimizer"]}, + ), + ChoiceParameter( + "lr", 
parameter_type=ParameterType.FLOAT, values=[0.01, 0.1] + ), + ChoiceParameter( + "optimizer", + parameter_type=ParameterType.STRING, + values=["Adam", "SGD"], + ), + ] + ) + aux_src = AuxiliarySource( + experiment=Experiment(search_space=ss2, name="test"), + transfer_param_config={"learning_rate": "lr", "root": "root2"}, + update_fixed_params=False, + ) + joint_ss = get_joint_search_space(search_space=ss1, auxiliary_sources=[aux_src]) + self.assertEqual( + set(joint_ss.parameters.keys()), + {"root", "learning_rate", "optimizer", "method"}, + ) + self.assertEqual( + set(joint_ss["root"].dependents[1]), + {"learning_rate", "optimizer", "method"}, + ) + self.assertEqual( + assert_is_instance( + joint_ss.parameters["learning_rate"], ChoiceParameter + ).values, + [0.01, 0.05, 0.1], + ) + self.assertEqual( + set( + assert_is_instance( + joint_ss.parameters["optimizer"], ChoiceParameter + ).values + ), + {"Adam", "SGD", "AdaGrad"}, + ) + + def test_merge_dependents(self) -> None: + p_no_dependents = FixedParameter( + "p", parameter_type=ParameterType.BOOL, value=True + ) + # No dependents returns None. + self.assertIsNone( + merge_dependents( + p1=p_no_dependents, p2=p_no_dependents, reverse_param_config={} + ) + ) + p_dependents_1 = FixedParameter( + "p1", parameter_type=ParameterType.INT, value=1, dependents={1: ["q"]} + ) + p_dependents_2 = FixedParameter( + "p2", parameter_type=ParameterType.INT, value=1, dependents={1: ["z"]} + ) + # p1 dependents do not get renamed. + self.assertEqual( + merge_dependents( + p1=p_dependents_1, p2=p_no_dependents, reverse_param_config={"q": "w"} + ), + {1: ["q"]}, + ) + # p2 dependents get renamed. + self.assertEqual( + merge_dependents( + p1=p_no_dependents, p2=p_dependents_1, reverse_param_config={"q": "w"} + ), + {1: ["w"]}, + ) + # Merge p1 & p2 dependents with renaming for p2 only. 
+ self.assertEqual( + set( + none_throws( + merge_dependents( + p1=p_dependents_1, + p2=p_dependents_2, + reverse_param_config={"q": "w", "z": "v"}, + ) + )[1] + ), + {"q", "v"}, + ) + + def test_merge_parameters(self) -> None: + p_fixed = FixedParameter( + name="fixed", parameter_type=ParameterType.BOOL, value=True + ) + p_fixed_2 = FixedParameter(name="f2", parameter_type=ParameterType.INT, value=1) + p_fixed_3 = FixedParameter(name="f3", parameter_type=ParameterType.INT, value=2) + p_fixed_4 = FixedParameter( + name="f4", parameter_type=ParameterType.INT, value=1, dependents={1: ["a"]} + ) + with self.assertRaisesRegex(ValueError, "different names"): + merge_parameters(p1=p_fixed, p2=p_fixed_2, reverse_param_config={}) + with self.assertRaisesRegex(ValueError, "different types"): + merge_parameters( + p1=p_fixed, p2=p_fixed_2, reverse_param_config={"f2": "fixed"} + ) + # Check that it works with both values of update_fixed_params. + for update_fixed_params in [True, False]: + self.assertEqual( + merge_parameters( + p1=p_fixed_2, + p2=p_fixed_3, + reverse_param_config={"f3": "f2"}, + update_fixed_params=update_fixed_params, + ), + FixedParameter( + name="f2", + parameter_type=ParameterType.INT, + value=1, + ), + ) + self.assertEqual( + merge_parameters( + p1=p_fixed_2, p2=p_fixed_4, reverse_param_config={"f4": "f2"} + ), + FixedParameter( + name="f2", + parameter_type=ParameterType.INT, + value=1, + dependents={1: ["a"]}, + ), + ) + p_range_1 = RangeParameter( + name="p", parameter_type=ParameterType.INT, lower=1, upper=3 + ) + p_range_2 = RangeParameter( + name="p", parameter_type=ParameterType.INT, lower=0, upper=2 + ) + self.assertEqual( + merge_parameters(p1=p_range_1, p2=p_range_2, reverse_param_config={}), + RangeParameter( + name="p", parameter_type=ParameterType.INT, lower=0, upper=3 + ), + ) + p_choice_1 = ChoiceParameter( + name="p", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + dependents={"a": ["p1"], "c": ["p2"]}, + ) + 
p_choice_2 = ChoiceParameter( + name="p", parameter_type=ParameterType.STRING, values=["a", "b", "d"] + ) + self.assertEqual( + merge_parameters(p1=p_choice_1, p2=p_choice_2, reverse_param_config={}), + ChoiceParameter( + name="p", + parameter_type=ParameterType.STRING, + values=["a", "b", "c", "d"], + dependents={"a": ["p1"], "c": ["p2"]}, + ), + ) + + # FixedParameter + ChoiceParameter: fixed value already in choices. + p_fixed_str = FixedParameter( + name="p", parameter_type=ParameterType.STRING, value="a" + ) + merged_fc = merge_parameters( + p1=p_fixed_str, p2=p_choice_1, reverse_param_config={} + ) + self.assertIsInstance(merged_fc, ChoiceParameter) + merged_fc_choice = assert_is_instance(merged_fc, ChoiceParameter) + self.assertEqual(set(merged_fc_choice.values), {"a", "b", "c"}) + # Dependents from the choice parameter are preserved. + self.assertEqual(merged_fc_choice.dependents, {"a": ["p1"], "c": ["p2"]}) + + # FixedParameter + ChoiceParameter: fixed value NOT in choices. + p_fixed_str_new = FixedParameter( + name="p", parameter_type=ParameterType.STRING, value="z" + ) + merged_fc2 = merge_parameters( + p1=p_fixed_str_new, p2=p_choice_1, reverse_param_config={} + ) + self.assertEqual( + set(assert_is_instance(merged_fc2, ChoiceParameter).values), + {"a", "b", "c", "z"}, + ) + + # Reversed order: ChoiceParameter as p1, FixedParameter as p2. + merged_cf = merge_parameters( + p1=p_choice_1, p2=p_fixed_str_new, reverse_param_config={} + ) + self.assertEqual( + set(assert_is_instance(merged_cf, ChoiceParameter).values), + {"a", "b", "c", "z"}, + ) + + # DerivedParameter: same expression succeeds. 
+ p_derived_1 = DerivedParameter( + name="d", + parameter_type=ParameterType.FLOAT, + expression_str="0.5 * x + 0.3 * y", + ) + p_derived_2 = DerivedParameter( + name="d", + parameter_type=ParameterType.FLOAT, + expression_str="0.5 * x + 0.3 * y", + ) + merged = merge_parameters( + p1=p_derived_1, p2=p_derived_2, reverse_param_config={} + ) + self.assertIsInstance(merged, DerivedParameter) + self.assertEqual( + assert_is_instance(merged, DerivedParameter).expression_str, + "0.5 * x + 0.3 * y", + ) + self.assertEqual(merged.name, "d") + + # DerivedParameter: different expressions raises ValueError. + p_derived_3 = DerivedParameter( + name="d", + parameter_type=ParameterType.FLOAT, + expression_str="0.7 * x + 0.1 * y", + ) + with self.assertRaisesRegex(ValueError, "different expressions"): + merge_parameters(p1=p_derived_1, p2=p_derived_3, reverse_param_config={}) + + # DerivedParameter vs FixedParameter raises ValueError (type mismatch). + p_fixed_float = FixedParameter( + name="d", parameter_type=ParameterType.FLOAT, value=1.0 + ) + with self.assertRaisesRegex(ValueError, "different types"): + merge_parameters( + p1=p_derived_1, + p2=p_fixed_float, + reverse_param_config={}, + ) diff --git a/ax/adapter/transfer_learning/tests/test_utils_torch.py b/ax/adapter/transfer_learning/tests/test_utils_torch.py new file mode 100644 index 00000000000..f9850b8a976 --- /dev/null +++ b/ax/adapter/transfer_learning/tests/test_utils_torch.py @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-strict + +from ax.adapter.transfer_learning.utils import get_joint_search_space +from ax.adapter.transfer_learning.utils_torch import get_mapped_parameter_names +from ax.adapter.transforms.one_hot import OneHot +from ax.adapter.transforms.remove_fixed import RemoveFixed +from ax.core.auxiliary_source import AuxiliarySource +from ax.core.parameter import ( + ChoiceParameter, + FixedParameter, + ParameterType, + RangeParameter, +) +from ax.core.search_space import SearchSpace +from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_branin_experiment, get_branin_search_space + + +class TestUtilsTorch(TestCase): + def setUp(self) -> None: + super().setUp() + base_params = list(get_branin_search_space().parameters.values()) + fp1 = FixedParameter(name="fp1", parameter_type=ParameterType.STRING, value="a") + fp2 = FixedParameter(name="fp2", parameter_type=ParameterType.STRING, value="b") + x3 = RangeParameter( + name="x3", parameter_type=ParameterType.FLOAT, lower=0, upper=1 + ) + rp1 = RangeParameter( + name="rp1", parameter_type=ParameterType.FLOAT, lower=-10, upper=20 + ) + cp1 = ChoiceParameter( + name="cp1", + parameter_type=ParameterType.STRING, + values=["a", "b", "c"], + is_ordered=False, + ) + self.target_ss = SearchSpace(base_params + [x3, fp1]) + source_ss = SearchSpace(base_params + [rp1, fp2]) + source_ss2 = SearchSpace(base_params + [x3, fp1, rp1, cp1]) + transfer_param_config = {"x3": "rp1"} + + source_exp1 = get_branin_experiment( + with_completed_trial=True, search_space=self.target_ss.clone() + ) + source_exp2 = get_branin_experiment( + with_completed_trial=True, search_space=source_ss + ) + source_exp3 = get_branin_experiment( + with_completed_trial=True, search_space=source_ss2 + ) + + self.auxsrc1 = AuxiliarySource( + experiment=source_exp1, update_fixed_params=False + ) + self.auxsrc2 = AuxiliarySource( + experiment=source_exp2, transfer_param_config=transfer_param_config + ) + self.auxsrc3 = 
AuxiliarySource( + experiment=source_exp2, + transfer_param_config=transfer_param_config, + update_fixed_params=False, + ) + self.auxsrc4 = AuxiliarySource(experiment=source_exp3) + + def test_mapped_parameter_names(self) -> None: + # Auxsrc1 has 4 params that should all get returned. + # The search space is same as the target. + mapped_names = get_mapped_parameter_names( + self.auxsrc1, target_search_space=self.target_ss + ) + self.assertEqual(mapped_names, ["x1", "x2", "x3", "fp1"]) + # Auxsrc2 has 4 params. The search space is different from the target. + # The fixed param fp2 will be replaced with fp1. rp1 will be mapped to x3. + mapped_names = get_mapped_parameter_names( + self.auxsrc2, target_search_space=self.target_ss + ) + self.assertEqual(mapped_names, ["x1", "x2", "x3", "fp1"]) + # This is same search space as auxsrc2 but fixed param should not change. + # rp1 will be mapped to x3. + mapped_names = get_mapped_parameter_names( + self.auxsrc3, target_search_space=self.target_ss + ) + self.assertEqual(mapped_names, ["x1", "x2", "fp2", "x3"]) + # Auxsrc4 has 6 params. No change expected. + mapped_names = get_mapped_parameter_names( + self.auxsrc4, target_search_space=self.target_ss + ) + self.assertEqual(mapped_names, ["x1", "x2", "x3", "rp1", "cp1", "fp1"]) + # With OneHot, cp1 will convert to 3 parameters. RemoveFixed will remove fp1. 
+ joint_ss = get_joint_search_space( + search_space=self.target_ss, + auxiliary_sources=[self.auxsrc4], + ) + mapped_names = get_mapped_parameter_names( + self.auxsrc4, + target_search_space=self.target_ss, + transforms={ # pyre-ignore[6] + "OneHot": OneHot(search_space=joint_ss), + "RemoveFixed": RemoveFixed(search_space=joint_ss), + }, + ) + self.assertEqual( + mapped_names, + [ + "x1", + "x2", + "x3", + "rp1", + "cp1_OH_PARAM_0", + "cp1_OH_PARAM_1", + "cp1_OH_PARAM_2", + ], + ) From f7862719fd23e41095b8f29ba2f1621a6e776311 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 3 Jun 2026 15:02:19 -0700 Subject: [PATCH 2/4] Add native step_size support to RangeParameter MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Adds a `step_size` arg to `RangeParameter` that snaps values to a grid anchored at `lower` (in `cast()`), for both FLOAT and INT parameters. This is the first diff in the step_size unification stack (see `ax/api/configs/RFC_step_size_unification.md`): `step_size` will subsume both the discrete-grid and limited-resolution (`digits`) use cases under one knob. In this diff `step_size` coexists with the existing `digits` arg (they are mutually exclusive at construction). Subsequent diffs in the stack migrate storage (JSON + SQA), transforms and utils, and the public API (`RangeParameterConfig`) to `step_size`, then deprecate `digits` in favor of it. Behavior: - `cast()` rounds `(value - lower) / step_size` to the nearest integer and returns `lower + n * step_size`. It does NOT clamp to `[lower, upper]`: an out-of-bounds input (e.g. a historical observation recorded outside the current bounds) snaps to the nearest grid point, which may itself be out of bounds. This mirrors the non-`step_size` `cast()`, which leaves out-of-bounds values in place rather than silently moving them into range — range validity is enforced by `validate()`, not `cast()`. 
- Both bounds must lie on the grid: `(upper - lower)` must be an integer multiple of `step_size` (within `EPS`). Off-grid bounds are rejected at construction. This guarantees `upper` is itself a feasible value, so a value near the upper bound snaps to `upper` rather than to a grid point short of it. - `step_size` must be strictly positive, and must be integer-valued for INT parameters. - `cardinality()` accounts for `step_size`: a grid-valued FLOAT reports the finite number of grid points instead of `inf`, and a grid-valued INT counts grid points rather than every integer in `[lower, upper]`. `step_size` defines a discrete grid but does not, by itself, force discrete acquisition optimization; how the optimizer treats the parameter depends on the grid cardinality and is determined at the generator level. Differential Revision: D107274057 --- ax/core/parameter.py | 196 +++++++++++++++++++++++++++++++- ax/core/tests/test_parameter.py | 191 +++++++++++++++++++++++++++++++ 2 files changed, 385 insertions(+), 2 deletions(-) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1e73f1720f0..5b6e657fc4a 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -341,6 +341,7 @@ def __init__( log_scale: bool = False, logit_scale: bool = False, digits: int | None = None, + step_size: float | None = None, is_fidelity: bool = False, target_value: TParamValue = None, backfill_value: TParamValue = None, @@ -359,6 +360,25 @@ def __init__( logit_scale: Whether to sample in logit space when drawing random values of the parameter. digits: Number of digits to round values to for float type. + Deprecated in favor of ``step_size``; cannot be set together + with ``step_size``. + step_size: If set, the parameter's feasible values are the grid + ``{lower + k * step_size : k in N}`` intersected with + ``[lower, upper]``. 
``cast()`` snaps values to the nearest grid + point (anchored at ``lower``) without clamping to the bounds, so + an out-of-bounds input snaps to an out-of-bounds grid point -- + mirroring the non-``step_size`` ``cast()``, which also leaves + out-of-bounds values in place. ``step_size`` must be strictly + positive, and the range must be an exact multiple of it: + ``(upper - lower)`` must be an integer multiple of ``step_size`` + (within ``EPS``), so that both bounds lie on the grid. For INT + parameters, ``step_size`` must itself be integer-valued. + + ``step_size`` defines a discrete grid but does not, by itself, + force discrete acquisition optimization. How the optimizer + treats the parameter depends on the grid cardinality + ``round((upper - lower) / step_size) + 1``, and is determined + at the generator level. is_fidelity: Whether this parameter is a fidelity parameter. target_value: Target value of this parameter if it is a fidelity. backfill_value: For parameters added to experiments that have already run @@ -378,6 +398,10 @@ def __init__( raise UserInputError("RangeParameter type must be int or float.") self._parameter_type = parameter_type self._digits = digits + # ``_step_size`` must be set before casting ``lower`` / ``upper`` below, + # since ``cast()`` reads it to snap values to the grid. + self._step_size: float | None = None + self._validate_and_set_step_size(step_size=step_size) self._lower: TNumeric = self.cast(lower) self._upper: TNumeric = self.cast(upper) self._log_scale = log_scale @@ -393,6 +417,12 @@ def __init__( self.cast(default_value) if default_value is not None else None ) + # Validate the raw inputs: this rejects invalid user input (e.g. a + # non-integer bound for an INT parameter) before ``cast()`` silently + # truncates it.
For the non-deprecated paths ``cast()`` does not move a + # bound that would otherwise pass validation -- FLOAT casting is a no-op + # on the value, and ``step_size`` snapping is skipped for bounds -- so + # validating the raw inputs also guarantees the stored bounds are valid. self._validate_range_param( parameter_type=parameter_type, lower=lower, @@ -400,8 +430,19 @@ def __init__( log_scale=log_scale, logit_scale=logit_scale, ) + # ``upper`` must additionally lie on the ``step_size`` grid (the grid is + # anchored at ``lower``). + self._validate_step_size_on_grid() def cardinality(self) -> TNumeric: + if self._step_size is not None: + # Values are snapped to the grid {lower + k * step_size} + # intersected with [lower, upper]. Both bounds lie on the grid + # (enforced at construction), so the number of grid points is + # (upper - lower) / step_size + 1. + step_size = none_throws(self._step_size) + return round((float(self.upper) - float(self.lower)) / step_size) + 1 + if self.parameter_type == ParameterType.FLOAT: return inf @@ -493,6 +534,19 @@ def digits(self) -> int | None: """ return self._digits + @property + def step_size(self) -> float | None: + """Grid spacing that values are snapped to in ``cast()``. + + If set, the parameter's feasible values are the grid + ``{lower + k * step_size : k in N}`` intersected with ``[lower, upper]``, + and ``cast()`` snaps values to the nearest grid point (without clamping + to the bounds). Both bounds are guaranteed to be on the grid (the + constructor requires ``(upper - lower)`` to be an integer multiple of + ``step_size``). ``None`` means no snapping. 
+ """ + return self._step_size + @property def log_scale(self) -> bool: """Whether the parameter's values should be sampled from log space.""" @@ -519,14 +573,25 @@ def update_range( if upper is None: upper = self._upper - cast_lower = self.cast(lower) - cast_upper = self.cast(upper) + # When ``step_size`` is set, cast the bounds without snapping to the + # (old) grid: bounds anchor the grid and must not be silently moved onto + # it. ``super().cast()`` applies only the type cast. The digits path + # (deprecated) keeps its historical rounding behavior via ``self.cast``. + if self._step_size is not None: + cast_lower = assert_is_instance(super().cast(lower), TNumeric) + cast_upper = assert_is_instance(super().cast(upper), TNumeric) + else: + cast_lower = self.cast(lower) + cast_upper = self.cast(upper) self._validate_range_param( lower=cast_lower, upper=cast_upper, log_scale=self.log_scale, logit_scale=self.logit_scale, ) + # The new bounds must lie on the ``step_size`` grid, if one is set. + # Validate before committing so a failed update leaves bounds unchanged. + self._validate_step_size_on_grid(lower=cast_lower, upper=cast_upper) self._lower = cast_lower self._upper = cast_upper return self @@ -546,6 +611,95 @@ def set_digits(self, digits: int | None) -> RangeParameter: self._upper = cast_upper return self + def set_step_size(self, step_size: float | None) -> RangeParameter: + """Set the grid spacing that values are snapped to in ``cast()``. + + The existing bounds are kept as-is (they anchor the grid and define the + feasible range); they are not snapped onto the new grid. Instead we + require that they already lie on it: ``(upper - lower)`` must be an + integer multiple of the new ``step_size``. + + Raises: + UserInputError: If the current bounds do not lie on the new grid. + """ + previous_step_size = self._step_size + self._validate_and_set_step_size(step_size=step_size) + try: + # The current (unchanged) bounds must lie on the new grid. 
+ self._validate_step_size_on_grid() + except UserInputError: + # Leave the parameter unchanged if the new grid is invalid. + self._step_size = previous_step_size + raise + return self + + def _validate_and_set_step_size(self, step_size: float | None) -> None: + """Validate ``step_size`` and store it on ``self._step_size``. + + Raises: + UserInputError: If ``step_size`` is non-positive, if it is set + together with ``digits``, or if it is not integer-valued for an + INT parameter. + """ + if step_size is None: + self._step_size = None + return + if self._digits is not None: + raise UserInputError( + f"Cannot set both `digits` and `step_size` on parameter " + f"{self._name}. `digits` is deprecated; use `step_size` only." + ) + if step_size <= 0: + raise UserInputError( + f"`step_size` must be strictly positive for parameter " + f"{self._name}. Got: {step_size}." + ) + if ( + self._parameter_type is ParameterType.INT + and not float(step_size).is_integer() + ): + raise UserInputError( + f"`step_size` must be integer-valued for INT parameter " + f"{self._name}. Got: {step_size}." + ) + self._step_size = float(step_size) + + def _validate_step_size_on_grid( + self, lower: TNumeric | None = None, upper: TNumeric | None = None + ) -> None: + """Validate that both bounds lie on the ``step_size`` grid. + + The grid is anchored at ``lower``, so ``lower`` is always on it. This + additionally requires ``upper`` to be on the grid, i.e. that + ``(upper - lower)`` is an integer multiple of ``step_size`` (within + ``EPS``). This guarantees ``upper`` is itself a feasible value, so a + value near the upper bound snaps to ``upper`` rather than to a grid + point short of it. + + Args: + lower: Lower bound to validate against. Defaults to ``self._lower``. + upper: Upper bound to validate against. Defaults to ``self._upper``. + These overrides let callers validate prospective bounds before + committing them. + + Raises: + UserInputError: If ``upper`` does not lie on the grid. 
+ """ + if self._step_size is None: + return + lower = self._lower if lower is None else lower + upper = self._upper if upper is None else upper + step_size = none_throws(self._step_size) + width = float(upper) - float(lower) + n = width / step_size + if abs(n - round(n)) * step_size > EPS: + raise UserInputError( + f"`step_size` must evenly divide the range of parameter " + f"{self._name}: (upper - lower) = {width} is not an integer " + f"multiple of step_size = {step_size}. Adjust the bounds or " + f"step_size so that both bounds lie on the grid." + ) + def set_log_scale(self, log_scale: bool) -> RangeParameter: self._log_scale = log_scale return self @@ -647,6 +801,7 @@ def clone(self) -> RangeParameter: log_scale=self._log_scale, logit_scale=self._logit_scale, digits=self._digits, + step_size=self._step_size, is_fidelity=self._is_fidelity, target_value=self._target_value, backfill_value=self._backfill_value, @@ -657,13 +812,50 @@ def cast(self, value: TParamValue) -> TNumeric: value = super().cast(value=value) if self.parameter_type is ParameterType.FLOAT and self._digits is not None: return round(float(value), none_throws(self._digits)) + # Skip snapping while the constructor is still casting the bounds + # themselves (before both ``self._lower`` and ``self._upper`` are set): + # the bounds anchor the grid and must not be snapped (``upper`` is only + # validated to be on the grid after both are assigned). ``_snap_to_grid`` + # needs ``self._lower``; gating on ``self._upper`` too is what excludes + # the ``upper`` cast at construction. + if ( + self._step_size is not None + and getattr(self, "_lower", None) is not None + and getattr(self, "_upper", None) is not None + ): + value = self._snap_to_grid(value=float(value)) return assert_is_instance(value, TNumeric) + def _snap_to_grid(self, value: float) -> TNumeric: + """Snap ``value`` to the nearest grid point. + + The grid is ``{lower + k * step_size : k in Z}``. 
The nearest grid point + is found by rounding ``(value - lower) / step_size`` to the nearest + integer. The result is *not* clamped to ``[lower, upper]``: an + out-of-bounds input (e.g. historical observations recorded outside the + current bounds) snaps to the nearest grid point, which may itself lie + outside the bounds. This mirrors the non-``step_size`` ``cast()``, which + leaves out-of-bounds values untouched rather than silently moving them + into range -- range validity is enforced by ``validate()``, not by + ``cast()``. For INT parameters the snapped value is integer-valued + (``step_size`` is validated to be an integer), so it is returned as an + ``int``. + """ + step_size = none_throws(self._step_size) + lower = float(self._lower) + n = round((value - lower) / step_size) + snapped = lower + n * step_size + if self.parameter_type is ParameterType.INT: + return int(round(snapped)) + return snapped + def __repr__(self) -> str: ret_val = self._base_repr() if self._digits is not None: ret_val += f", digits={self._digits}" + if self._step_size is not None: + ret_val += f", step_size={self._step_size}" return ret_val + ")" diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index dc40742aa8f..eb3d1182cd0 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -183,6 +183,197 @@ def test_Clone(self) -> None: param_clone._lower = 2.0 self.assertNotEqual(self.param1.lower, param_clone.lower) + def test_step_size_snapping(self) -> None: + # ``cast()`` snaps to the nearest grid point anchored at ``lower``. Each + # case is (param, [(input, expected), ...]). FLOAT inputs avoid exact + # half-points (e.g. 0.15) where round-half-to-even makes the result + # ambiguous; the INT cases use exactly representable half-points on purpose. + float_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # ``lower`` need not be zero; the grid {0.005, 0.015, 0.025} is anchored + # at the (non-zero) lower bound.
+ offset_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.005, + upper=0.025, + step_size=0.01, + ) + int_param = RangeParameter( + name="y", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + step_size=2, + ) + cases = [ + ( + float_param, + [ + (0.12, 0.1), + (0.16, 0.2), + (0.13, 0.1), + # Bounds are on the grid (1 / 0.1 = 10) and reachable. + (0.0, 0.0), + (1.0, 1.0), + # 0.98 snaps to the upper bound 1.0. + (0.98, 1.0), + # Out-of-bounds inputs (e.g. historical observations + # recorded outside the current bounds) snap to the nearest + # grid point WITHOUT being clamped into [lower, upper]. This + # mirrors the non-step_size cast(), which leaves + # out-of-bounds values in place rather than mutating them. + (1.5, 1.5), + (1.52, 1.5), + (-0.3, -0.3), + ], + ), + (offset_param, [(0.006, 0.005), (0.013, 0.015), (0.025, 0.025)]), + ( + int_param, + [ + # 3 -> (3-0)/2 = 1.5 -> round-half-to-even -> 2 -> 4. + (3, 4), + # 7 -> (7-0)/2 = 3.5 -> round-half-to-even -> 4 -> 8. + (7, 8), + (0, 0), + (10, 10), + # Out-of-bounds INT inputs are likewise not clamped: + # 15 -> (15-0)/2 = 7.5 -> round-half-to-even -> 8 -> 16, + # which exceeds upper=10 and is kept as-is. + (15, 16), + ], + ), + ] + for param, input_expected in cases: + for value, expected in input_expected: + with self.subTest(param=param.name, value=value): + snapped = none_throws(param.cast(value)) + self.assertAlmostEqual(snapped, expected) + if param.parameter_type is ParameterType.INT: + self.assertIsInstance(snapped, int) + + def test_step_size_validation(self) -> None: + # All special-value / off-grid rejections at construction. + # Non-positive step_size. 
+ with self.assertRaisesRegex(UserInputError, "must be strictly positive"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=0.0) + with self.assertRaisesRegex(UserInputError, "must be strictly positive"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=-0.1) + # Non-integer step_size for an INT parameter. + with self.assertRaisesRegex(UserInputError, "must be integer-valued"): + RangeParameter("y", ParameterType.INT, 0, 10, step_size=2.5) + # Off-grid bounds: (upper - lower) is not an integer multiple of + # step_size, so ``upper`` is off the grid anchored at ``lower``. FLOAT + # grid {0, 0.3, 0.6, 0.9} misses upper=1.0; INT grid {0, 3, 6, 9} misses + # upper=10. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=0.3) + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + RangeParameter("y", ParameterType.INT, 0, 10, step_size=3) + # Cannot set both digits and step_size. + with self.assertRaisesRegex(UserInputError, "Cannot set both"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, digits=2, step_size=0.1) + + def test_set_step_size(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + ) + self.assertIsNone(param.step_size) + returned = param.set_step_size(0.25) + self.assertIs(returned, param) + self.assertEqual(param.step_size, 0.25) + self.assertAlmostEqual(none_throws(param.cast(0.3)), 0.25) + # Clearing step_size disables snapping. + param.set_step_size(None) + self.assertIsNone(param.step_size) + self.assertAlmostEqual(none_throws(param.cast(0.3)), 0.3) + # ``set_step_size`` rejects a step that does not divide the range. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + param.set_step_size(0.3) + # The failed call leaves the parameter unchanged (no snapping). 
+ self.assertIsNone(param.step_size) + + def test_update_range_respects_step_size_grid(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # New bounds on the grid are accepted and not snapped away. + param.update_range(lower=0.0, upper=0.5) + self.assertAlmostEqual(float(param.upper), 0.5) + # New bounds off the grid are rejected. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + param.update_range(upper=0.55) + + def test_step_size_repr_and_clone(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + self.assertIn("step_size=0.1", str(param)) + clone = param.clone() + self.assertEqual(clone.step_size, 0.1) + self.assertEqual(param, clone) + # Parameters differing only in step_size are not equal. + other = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.2, + ) + self.assertNotEqual(param, other) + + def test_step_size_cardinality(self) -> None: + # FLOAT with a grid has finite cardinality (number of grid points), + # not inf. + float_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # Grid {0.0, 0.1, ..., 1.0} has 11 points. + self.assertEqual(float_param.cardinality(), 11) + # INT with a grid counts grid points, not every integer in [lower, + # upper]. + int_param = RangeParameter( + name="y", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + step_size=2, + ) + # Grid {0, 2, 4, 6, 8, 10} has 6 points (not 11). + self.assertEqual(int_param.cardinality(), 6) + # Without step_size, behavior is unchanged: FLOAT is inf, INT counts + # every integer. 
+ self.assertEqual( + RangeParameter( + name="z", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + ).cardinality(), + 11, + ) + def test_get_parameter_type(self) -> None: self.assertEqual(_get_parameter_type(float), ParameterType.FLOAT) self.assertEqual(_get_parameter_type(int), ParameterType.INT) From 4f1983902165d9c91dfe330390e587e3de580c43 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 3 Jun 2026 15:02:19 -0700 Subject: [PATCH 3/4] Add storage support for RangeParameter.step_size (JSON + SQA) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Teaches both storage backends to persist and restore the new `step_size` field on `RangeParameter`. Second diff in the step_size unification stack (builds on the native `step_size` support added in D107274057). The underlying SQA column (`SQAParameter.step_size`) and the corresponding EntAE schema were added and landed separately (D107108021, D107112306), so this diff is pure encoder/decoder logic — it does not touch the schema. SQA: the encoder writes `parameter.step_size` to the column (and continues writing the legacy `digits` column for now; `digits` is dropped in a later diff). The decoder prefers `step_size` and falls back to the legacy `digits` column for rows written by older code, passing only one to the constructor (which rejects both being set; it converts `digits` to `step_size` internally). JSON: `range_parameter_to_dict` emits `step_size` alongside `digits`. There is no custom JSON decoder for `RangeParameter` — the registry maps it directly to the class, and `ax_class_from_json_dict` splats the serialized dict straight into the constructor. So `step_size` flows through automatically, and legacy blobs carrying only `digits` still decode via the constructor's back-compat handling. No decoder change is required for JSON. 
Differential Revision: D107282941 --- ax/storage/json_store/encoders.py | 1 + .../json_store/tests/test_json_store.py | 39 ++++++++++++++++++- ax/storage/sqa_store/decoder.py | 10 ++++- ax/storage/sqa_store/encoder.py | 1 + ax/storage/sqa_store/tests/test_sqa_store.py | 35 +++++++++++++++++ 5 files changed, 84 insertions(+), 2 deletions(-) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index da84a260a77..181b8d5b19f 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -201,6 +201,7 @@ def range_parameter_to_dict(parameter: RangeParameter) -> dict[str, Any]: "log_scale": parameter.log_scale, "logit_scale": parameter.logit_scale, "digits": parameter.digits, + "step_size": parameter.step_size, "is_fidelity": parameter.is_fidelity, "target_value": parameter.target_value, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index d40598dcf47..ed14783fcc1 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -55,7 +55,7 @@ OptimizationConfig, PreferenceOptimizationConfig, ) -from ax.core.parameter import ChoiceParameter, ParameterType +from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.parameter_constraint import ParameterConstraint from ax.core.runner import Runner from ax.exceptions.core import AxStorageWarning, UnsupportedError @@ -398,6 +398,17 @@ ("ParameterConstraint", get_parameter_constraint), ("ParameterConstraint", get_equality_parameter_constraint), ("RangeParameter", get_range_parameter), + ( + "RangeParameter", + partial( + RangeParameter, + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ), + ), ("ScalarizedObjective", get_scalarized_objective), ("ScalarizedOutcomeConstraint", get_scalarized_outcome_constraint), ("OrchestratorOptions", get_default_orchestrator_options), @@ -1942,6 
+1953,32 @@ def test_multi_objective_from_json_warning(self) -> None: any("Found unexpected kwargs" in warning for warning in cm.output) ) + def test_range_parameter_legacy_digits_blob_decodes(self) -> None: + # A legacy blob has "digits" but no "step_size" key. It must still + # decode (the constructor accepts digits for back-compat). + legacy_blob = { + "__type": "RangeParameter", + "name": "x", + "parameter_type": {"__type": "ParameterType", "name": "FLOAT"}, + "lower": 0.0, + "upper": 1.0, + "log_scale": False, + "logit_scale": False, + "digits": 2, + "is_fidelity": False, + "target_value": None, + } + decoded = object_from_json( + legacy_blob, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertIsInstance(decoded, RangeParameter) + self.assertEqual(decoded.digits, 2) + self.assertIsNone(decoded.step_size) + # Rounding behavior from digits=2 is preserved. + self.assertEqual(decoded.cast(0.123), 0.12) + def test_choice_parameter_bypass_cardinality_check_encode_failure(self) -> None: choice_parameter = ChoiceParameter( name="test_choice", diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index d564c764765..3ac2d3c906b 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -470,13 +470,21 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: "`dependents` unexpectedly non-null on range parameter " f"{parameter_sqa.name}." ) + # Prefer the newer ``step_size`` column; fall back to the legacy + # ``digits`` column for rows written by older code. Only one may be + # passed to the constructor (it rejects both being set). The + # constructor converts ``digits`` to ``step_size`` internally. + # The two should never coexist in the DB; this is purely defensive.
+ step_size = parameter_sqa.step_size + digits = None if step_size is not None else parameter_sqa.digits parameter = RangeParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, lower=float(none_throws(parameter_sqa.lower)), upper=float(none_throws(parameter_sqa.upper)), log_scale=parameter_sqa.log_scale or False, - digits=parameter_sqa.digits, + digits=digits, + step_size=float(step_size) if step_size is not None else None, is_fidelity=parameter_sqa.is_fidelity or False, target_value=parameter_sqa.target_value, backfill_value=parameter_sqa.backfill_value, diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 7ee21b70b8d..2ba801ab61a 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -308,6 +308,7 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter: upper=float(parameter.upper), log_scale=parameter.log_scale, digits=parameter.digits, + step_size=parameter.step_size, is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, dependents=parameter.dependents if parameter.is_hierarchical else None, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 5f3c50540b4..d82dc508f37 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1880,6 +1880,41 @@ def test_logit_scale(self) -> None: ) ) + def test_step_size_round_trip(self) -> None: + parameter = RangeParameter( + name="foo", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + sqa_parameter = self.encoder.parameter_to_sqa(parameter) + self.assertEqual(sqa_parameter.step_size, 0.1) + self.assertIsNone(sqa_parameter.digits) + decoded = self.decoder.parameter_from_sqa(sqa_parameter) + self.assertEqual(decoded, parameter) + self.assertEqual(assert_is_instance(decoded, RangeParameter).step_size, 0.1) + + def test_legacy_digits_row_decodes_via_fallback(self) 
-> None: + # A legacy row carries ``digits`` but no ``step_size``. The decoder + # falls back to ``digits``, which the constructor converts internally. + parameter = RangeParameter( + name="foo", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + digits=2, + ) + sqa_parameter = self.encoder.parameter_to_sqa(parameter) + # Simulate an old row: clear step_size, keep digits populated. + sqa_parameter.step_size = None + sqa_parameter.digits = 2 + decoded = assert_is_instance( + self.decoder.parameter_from_sqa(sqa_parameter), RangeParameter + ) + # The decoded parameter round-trips its rounding behavior. + self.assertEqual(decoded.cast(0.123), parameter.cast(0.123)) + def test_bypass_cardinality_check(self) -> None: choice_parameter = ChoiceParameter( name="test_choice", From 6c023bc013d184e3927befe3d075ee3faff59f8a Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 4 Jun 2026 09:41:25 -0700 Subject: [PATCH 4/4] Migrate transforms and utils to RangeParameter.step_size Summary: Teaches the transform/util layer about step_size. Third diff in the step_size unification stack. - Cast: snaps RangeParameter values to the grid via parameter.cast() on both the observation-features path (already calls cast) and the experiment_data dataframe path (replacing the .round(digits) call), for params with digits OR step_size set. - Log/Logit/UnitX: clear step_size before rescaling (in addition to clearing digits, until digits is fully removed in a later diff); step_size is re-applied in the original space by Cast on untransform. - int_to_float: does not forward step_size to the FLOAT surrogate (anchor would be misaligned); the original INT param's snapping is re-applied by Cast. See TODO. - map_key_to_float, transfer_learning merge_parameters, service instantiation: forward/read step_size, preferring it over digits. - service instantiation: add step_size to EXPECTED_KEYS_IN_PARAM_REPR. 
The RangeParameter construction path already read representation["step_size"], but the key was not in the recognized-keys set, so any parameter representation passing step_size would have been rejected with an "Unexpected keys" error. This makes the step_size representation path actually usable. - core_stubs: add get_range_parameter_with_step_size helper. Test coverage added for the migrated paths: Log/Logit transform_search_space clears step_size (mirroring the existing UnitX and clears_digits tests), MapKeyToFloat forwards step_size from config, merge_parameters forwards step_size/digits from p1, and parameter_from_json accepts step_size and legacy digits for range parameters. Differential Revision: D107284896 --- .../transfer_learning/tests/test_utils.py | 35 +++++++++ ax/adapter/transfer_learning/utils.py | 1 + ax/adapter/transforms/base.py | 12 ++-- ax/adapter/transforms/cast.py | 35 +++++++-- ax/adapter/transforms/log.py | 7 +- ax/adapter/transforms/logit.py | 8 ++- ax/adapter/transforms/map_key_to_float.py | 7 +- .../transforms/tests/test_cast_transform.py | 71 +++++++++++++++++++ .../transforms/tests/test_log_transform.py | 26 ++++++- .../transforms/tests/test_logit_transform.py | 53 ++++++++------ .../tests/test_map_key_to_float_transform.py | 14 ++++ .../transforms/tests/test_unit_x_transform.py | 54 ++++++++------ ax/adapter/transforms/unit_x.py | 8 ++- ax/service/tests/test_instantiation_utils.py | 33 +++++++++ ax/service/utils/instantiation.py | 14 +++- ax/utils/testing/core_stubs.py | 11 +++ 16 files changed, 328 insertions(+), 61 deletions(-) diff --git a/ax/adapter/transfer_learning/tests/test_utils.py b/ax/adapter/transfer_learning/tests/test_utils.py index ccc644778f7..c98a47e89e5 100644 --- a/ax/adapter/transfer_learning/tests/test_utils.py +++ b/ax/adapter/transfer_learning/tests/test_utils.py @@ -254,6 +254,41 @@ def test_merge_parameters(self) -> None: name="p", parameter_type=ParameterType.INT, lower=0, upper=3 ), ) + # The grid spec 
(step_size, or legacy digits) is forwarded from p1, even + # when the merged bounds are widened by p2 (as long as they stay on p1's + # grid). + p_range_step = RangeParameter( + name="p", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + p_range_wide = RangeParameter( + name="p", parameter_type=ParameterType.FLOAT, lower=0.0, upper=2.0 + ) + merged_step = assert_is_instance( + merge_parameters(p1=p_range_step, p2=p_range_wide, reverse_param_config={}), + RangeParameter, + ) + self.assertEqual(merged_step.upper, 2.0) + self.assertEqual(merged_step.step_size, 0.1) + self.assertIsNone(merged_step.digits) + p_range_digits = RangeParameter( + name="p", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + digits=2, + ) + merged_digits = assert_is_instance( + merge_parameters( + p1=p_range_digits, p2=p_range_wide, reverse_param_config={} + ), + RangeParameter, + ) + self.assertEqual(merged_digits.digits, 2) + self.assertIsNone(merged_digits.step_size) p_choice_1 = ChoiceParameter( name="p", parameter_type=ParameterType.STRING, diff --git a/ax/adapter/transfer_learning/utils.py b/ax/adapter/transfer_learning/utils.py index 814b6f43e3a..0000f54e9d6 100644 --- a/ax/adapter/transfer_learning/utils.py +++ b/ax/adapter/transfer_learning/utils.py @@ -111,6 +111,7 @@ def merge_parameters( log_scale=p1.log_scale, logit_scale=p1.logit_scale, digits=p1.digits, + step_size=p1.step_size, is_fidelity=p1.is_fidelity, target_value=p1.target_value, ) diff --git a/ax/adapter/transforms/base.py b/ax/adapter/transforms/base.py index 2f47c7c9c86..7c89ad7f091 100644 --- a/ax/adapter/transforms/base.py +++ b/ax/adapter/transforms/base.py @@ -94,12 +94,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: transform (does nothing). 
NOTE for subclasses: If a transform changes the *scale* of a - RangeParameter (e.g., Log, UnitX, Logit), it must clear ``digits`` - via ``p.set_digits(digits=None)`` before calling ``update_range``. - Otherwise, rounding calibrated for the original scale will corrupt - the transformed bounds (e.g., ``digits=-3`` rounds to the nearest - 1000, which collapses [0, 1] to 0). The Cast transform re-applies - ``digits`` in the original space during untransform. + RangeParameter (e.g., Log, UnitX, Logit), it must clear ``step_size`` + via ``p.set_step_size(step_size=None)`` before calling ``update_range``. + Otherwise, snapping calibrated for the original scale will corrupt the + transformed bounds (a grid spacing meaningful in the original space is + meaningless after a non-linear rescale). The Cast transform re-applies + ``step_size`` in the original space during untransform. Args: search_space: The search space diff --git a/ax/adapter/transforms/cast.py b/ax/adapter/transforms/cast.py index 58e7aad5ead..68ab4c8a593 100644 --- a/ax/adapter/transforms/cast.py +++ b/ax/adapter/transforms/cast.py @@ -25,7 +25,7 @@ from ax.exceptions.core import UserInputError from ax.generators.types import TConfig from ax.utils.common.constants import Keys -from pandas import DataFrame +from pandas import DataFrame, Series from pyre_extensions import assert_is_instance, none_throws if TYPE_CHECKING: @@ -314,11 +314,38 @@ def transform_experiment_data( for p, param in self.search_space.parameters.items() } arm_data = arm_data.astype(dtype=column_to_type) - # Round to digits if any parameter specifies it. + # Snap to the parameter's grid (digits or step_size) if specified. + # These mirror ``RangeParameter.cast``'s rounding logic, but are applied + # in a vectorized manner over the whole column rather than via a per-row + # ``Series.apply`` (which calls ``parameter.cast`` once per element and is + # slow for large DataFrames). 
NaN / ``<NA>`` values (added for missing + # columns during the ``reindex`` above) propagate through ``round`` and + # the arithmetic, matching the previous ``value if value is None`` guard. for p_name in parameter_names: parameter = self.search_space.parameters[p_name] - if isinstance(parameter, RangeParameter) and parameter.digits is not None: - arm_data[p_name] = arm_data[p_name].round(parameter.digits) + if not isinstance(parameter, RangeParameter): + continue + column: Series = arm_data[p_name] + if ( + parameter.parameter_type is ParameterType.FLOAT + and parameter.digits is not None + ): + # ``Series.round`` uses round-half-to-even, same as Python's + # built-in ``round`` used in ``RangeParameter.cast``. + arm_data[p_name] = column.round(parameter.digits) + elif parameter.step_size is not None: + # Snap to the grid ``{lower + k * step_size : k in Z}`` by + # rounding ``(value - lower) / step_size`` to the nearest integer. + lower = float(parameter.lower) + step_size = none_throws(parameter.step_size) + steps: Series = column.sub(lower).div(step_size).round() + snapped: Series = steps.mul(step_size).add(lower) + if parameter.parameter_type is ParameterType.INT: + # Preserve the nullable ``Int64`` dtype so reindex-added + # ``<NA>`` values survive the cast. + arm_data[p_name] = snapped.round().astype("Int64") + else: + arm_data[p_name] = snapped return ExperimentData(arm_data=arm_data, observation_data=observation_data) diff --git a/ax/adapter/transforms/log.py b/ax/adapter/transforms/log.py index c03d45e150e..b466e53c007 100644 --- a/ax/adapter/transforms/log.py +++ b/ax/adapter/transforms/log.py @@ -79,7 +79,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: isinstance(p, RangeParameter) and p.parameter_type == ParameterType.FLOAT ): - # Don't round in log space + # Don't snap/round in log space; step_size (or legacy + # digits) will be re-applied in the original space by the + # Cast transform during untransform.
Both are cleared until + # digits is fully removed (see step_size unification RFC). + if p.step_size is not None: + p.set_step_size(step_size=None) if p.digits is not None: p.set_digits(digits=None) p.set_log_scale(False).update_range( diff --git a/ax/adapter/transforms/logit.py b/ax/adapter/transforms/logit.py index cfe43cd157e..9d0642fe18b 100644 --- a/ax/adapter/transforms/logit.py +++ b/ax/adapter/transforms/logit.py @@ -66,8 +66,12 @@ def transform_observation_features( def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: for p_name, p in search_space.parameters.items(): if p_name in self.transform_parameters and isinstance(p, RangeParameter): - # Don't round in logit space; digits will be re-applied in - # the original space by the Cast transform during untransform. + # Don't snap/round in logit space; step_size (or legacy digits) + # will be re-applied in the original space by the Cast transform + # during untransform. Both are cleared until digits is fully + # removed (see step_size unification RFC). + if p.step_size is not None: + p.set_step_size(step_size=None) if p.digits is not None: p.set_digits(digits=None) p.set_logit_scale(False).update_range( diff --git a/ax/adapter/transforms/map_key_to_float.py b/ax/adapter/transforms/map_key_to_float.py index 1ac256554f5..c5b24542654 100644 --- a/ax/adapter/transforms/map_key_to_float.py +++ b/ax/adapter/transforms/map_key_to_float.py @@ -127,13 +127,18 @@ def __init__( return p_config = self.parameters[MAP_KEY] + # Prefer ``step_size``; fall back to legacy ``digits``. Only one may + # be passed to the constructor (it rejects both being set). 
+ step_size = p_config.get("step_size", None) + digits = None if step_size is not None else p_config.get("digits", None) self._parameter_list.append( RangeParameter( name=MAP_KEY, parameter_type=ParameterType.FLOAT, lower=p_config.get("lower", min(values)), upper=p_config.get("upper", max(values)), - digits=p_config.get("digits", None), + digits=digits, + step_size=step_size, is_fidelity=p_config.get("is_fidelity", False), target_value=p_config.get("target_value", None), ) diff --git a/ax/adapter/transforms/tests/test_cast_transform.py b/ax/adapter/transforms/tests/test_cast_transform.py index d41de7085ee..40a18b1a68a 100644 --- a/ax/adapter/transforms/tests/test_cast_transform.py +++ b/ax/adapter/transforms/tests/test_cast_transform.py @@ -323,6 +323,77 @@ def test_cast_parameter_type_and_none(self) -> None: ] self.assertEqual(tf_observations, expected) + def test_cast_step_size_observation_features(self) -> None: + # Cast snaps RangeParameter values to step_size on (un)transform, just + # as it rounds to digits. + search_space = SearchSpace( + parameters=[ + RangeParameter( + name="range", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ), + ] + ) + t = Cast(search_space=search_space) + obs_features = [ + ObservationFeatures(parameters={"range": 0.12}), + ObservationFeatures(parameters={"range": 0.36}), + ] + tf_obs_features = t.transform_observation_features( + observation_features=obs_features + ) + self.assertAlmostEqual( + float(none_throws(tf_obs_features[0].parameters["range"])), 0.1 + ) + self.assertAlmostEqual( + float(none_throws(tf_obs_features[1].parameters["range"])), 0.4 + ) + + def test_transform_experiment_data_step_size(self) -> None: + # The experiment_data dataframe path snaps RangeParameter values to + # step_size, for both FLOAT and INT parameters. The INT parameter also + # checks that the snapped column keeps the nullable Int64 dtype. 
+ experiment = get_experiment_with_observations( + observations=[[0.0], [1.0]], + search_space=SearchSpace( + parameters=[ + RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ), + RangeParameter( + name="y", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + step_size=2, + ), + ] + ), + parameterizations=[ + {"x": 0.12, "y": 3}, + {"x": 0.36, "y": 7}, + ], + ) + experiment_data = extract_experiment_data( + experiment=experiment, data_loader_config=DataLoaderConfig() + ) + transformed = Cast( + search_space=experiment.search_space + ).transform_experiment_data(experiment_data=deepcopy(experiment_data)) + self.assertAlmostEqual(transformed.arm_data["x"].iloc[0], 0.1) + self.assertAlmostEqual(transformed.arm_data["x"].iloc[1], 0.4) + # 3 snaps to 4 (round half to even: 1.5 -> 2 steps), 7 snaps to 8. + self.assertEqual(transformed.arm_data["y"].iloc[0], 4) + self.assertEqual(transformed.arm_data["y"].iloc[1], 8) + self.assertEqual(transformed.arm_data["y"].dtype, "Int64") + def test_transform_experiment_data_flatten(self) -> None: # Tests for flattening of hierarchical parameterizations. columns = [ diff --git a/ax/adapter/transforms/tests/test_log_transform.py b/ax/adapter/transforms/tests/test_log_transform.py index 463db934c67..72be8a7b794 100644 --- a/ax/adapter/transforms/tests/test_log_transform.py +++ b/ax/adapter/transforms/tests/test_log_transform.py @@ -117,12 +117,36 @@ def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) - # Test float log-scale parameter transformation + # Test float log-scale parameter transformation. The grid (legacy + # ``digits`` here; ``step_size`` covered below) must be cleared during + # the transform -- a grid meaningful in the original space is + # meaningless after a log10 rescale, and Cast re-applies it in the + # original space on untransform. 
param_x = assert_is_instance(ss2.parameters["x"], RangeParameter) self.assertEqual(param_x.lower, math.log10(1)) self.assertEqual(param_x.upper, math.log10(3)) self.assertIsNone(param_x.digits) + # Same clearing behavior for ``step_size`` (mutually exclusive with + # ``digits``, so it needs its own parameter). + ss_step = SearchSpace( + parameters=[ + RangeParameter( + "x", + lower=1.0, + upper=1000.0, + parameter_type=ParameterType.FLOAT, + log_scale=True, + step_size=1.0, + ), + ] + ) + ss_step = Log(search_space=ss_step).transform_search_space(ss_step) + param_x_step = assert_is_instance(ss_step.parameters["x"], RangeParameter) + self.assertIsNone(param_x_step.step_size) + self.assertEqual(param_x_step.lower, math.log10(1.0)) + self.assertEqual(param_x_step.upper, math.log10(1000.0)) + # Test integer log-scale parameter transformation (converted to ChoiceParameter) param_y = assert_is_instance(ss2.parameters["y"], ChoiceParameter) self.assertEqual(param_y.parameter_type, ParameterType.FLOAT) diff --git a/ax/adapter/transforms/tests/test_logit_transform.py b/ax/adapter/transforms/tests/test_logit_transform.py index e76544ba288..ca4b868447b 100644 --- a/ax/adapter/transforms/tests/test_logit_transform.py +++ b/ax/adapter/transforms/tests/test_logit_transform.py @@ -122,27 +122,38 @@ def test_TransformSearchSpace(self) -> None: self.assertEqual(x_param.lower, logit(0.1)) self.assertEqual(x_param.upper, logit(0.3)) - def test_transform_search_space_clears_digits(self) -> None: - """Test that digits is cleared during transform to avoid rounding - in logit space.""" - ss = SearchSpace( - parameters=[ - RangeParameter( - "x", - lower=0.1, - upper=0.9, - parameter_type=ParameterType.FLOAT, - logit_scale=True, - digits=3, - ), - ] - ) - t = Logit(search_space=ss) - ss = t.transform_search_space(ss) - x = assert_is_instance(ss.parameters["x"], RangeParameter) - self.assertIsNone(x.digits) - self.assertAlmostEqual(x.lower, logit(0.1)) - self.assertAlmostEqual(x.upper, 
logit(0.9)) + def test_transform_search_space_clears_grid(self) -> None: + """The grid (legacy ``digits`` or ``step_size``, which are mutually + exclusive) must be cleared during the transform to avoid rounding / + snapping in logit space; it is re-applied in the original space by Cast + on untransform.""" + grid_params = [ + RangeParameter( + "x", + lower=0.1, + upper=0.9, + parameter_type=ParameterType.FLOAT, + logit_scale=True, + digits=3, + ), + RangeParameter( + "x", + lower=0.1, + upper=0.9, + parameter_type=ParameterType.FLOAT, + logit_scale=True, + step_size=0.1, + ), + ] + for param in grid_params: + with self.subTest(param=param): + ss = SearchSpace(parameters=[param]) + ss = Logit(search_space=ss).transform_search_space(ss) + x = assert_is_instance(ss.parameters["x"], RangeParameter) + self.assertIsNone(x.digits) + self.assertIsNone(x.step_size) + self.assertAlmostEqual(x.lower, logit(0.1)) + self.assertAlmostEqual(x.upper, logit(0.9)) def test_transform_experiment_data(self) -> None: parameterizations = [ diff --git a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py index 5ba81413886..767b39fab03 100644 --- a/ax/adapter/transforms/tests/test_map_key_to_float_transform.py +++ b/ax/adapter/transforms/tests/test_map_key_to_float_transform.py @@ -355,6 +355,20 @@ def test_Init(self) -> None: self.assertEqual(p.upper, 1.0) self.assertFalse(p.log_scale) + # step_size from the config is forwarded to the surrogate parameter. 
+ with self.subTest(msg="step_size from config"): + t = MapKeyToFloat( + experiment_data=self.experiment_data, + config={ + "parameters": { + self.map_key: {"lower": 0.0, "upper": 1.0, "step_size": 0.1} + } + }, + ) + p = t._parameter_list[0] + self.assertEqual(p.step_size, 0.1) + self.assertIsNone(p.digits) + def test_TransformSearchSpace(self) -> None: ss2 = deepcopy(self.search_space) ss2 = self.t.transform_search_space(ss2) diff --git a/ax/adapter/transforms/tests/test_unit_x_transform.py b/ax/adapter/transforms/tests/test_unit_x_transform.py index 5f615c8f182..575b2a79bf0 100644 --- a/ax/adapter/transforms/tests/test_unit_x_transform.py +++ b/ax/adapter/transforms/tests/test_unit_x_transform.py @@ -167,28 +167,38 @@ def test_TransformSearchSpaceEqualityConstraints(self) -> None: self.assertEqual(ineq_c.constraint_dict, {"x": -1.0, "y": 1.0}) self.assertEqual(ineq_c.bound, 0.0) - def test_transform_search_space_clears_digits(self) -> None: - """Test that digits is cleared during transform to avoid rounding - in unit space. Regression test for a bug where digits=-3 (round to - nearest 1000) collapsed [0, 1] bounds to (0.0, 0.0).""" - ss = SearchSpace( - parameters=[ - RangeParameter( - "w", - lower=5000.0, - upper=500000.0, - parameter_type=ParameterType.FLOAT, - digits=-3, - ), - ] - ) - t = UnitX(search_space=ss) - ss = t.transform_search_space(ss) - w = assert_is_instance(ss.parameters["w"], RangeParameter) - # digits must be cleared so rounding doesn't corrupt [0, 1] bounds. - self.assertIsNone(w.digits) - self.assertEqual(w.lower, 0.0) - self.assertEqual(w.upper, 1.0) + def test_transform_search_space_clears_grid(self) -> None: + """The grid (legacy ``digits`` or ``step_size``, which are mutually + exclusive) must be cleared during transform to avoid rounding / snapping + in unit space; a grid meaningful in the original space is meaningless + after rescaling to [0, 1]. 
The ``digits=-3`` case is a regression test + for a bug where rounding to the nearest 1000 collapsed the [0, 1] bounds + to (0.0, 0.0).""" + grid_params = [ + RangeParameter( + "w", + lower=5000.0, + upper=500000.0, + parameter_type=ParameterType.FLOAT, + digits=-3, + ), + RangeParameter( + "w", + lower=5000.0, + upper=500000.0, + parameter_type=ParameterType.FLOAT, + step_size=1000.0, + ), + ] + for param in grid_params: + with self.subTest(param=param): + ss = SearchSpace(parameters=[param]) + ss = UnitX(search_space=ss).transform_search_space(ss) + w = assert_is_instance(ss.parameters["w"], RangeParameter) + self.assertIsNone(w.digits) + self.assertIsNone(w.step_size) + self.assertEqual(w.lower, 0.0) + self.assertEqual(w.upper, 1.0) def test_TransformNewSearchSpace(self) -> None: new_ss = SearchSpace( diff --git a/ax/adapter/transforms/unit_x.py b/ax/adapter/transforms/unit_x.py index a0b6df828d4..86cb309035e 100644 --- a/ax/adapter/transforms/unit_x.py +++ b/ax/adapter/transforms/unit_x.py @@ -73,8 +73,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: if (p_bounds := self.bounds.get(p_name)) is not None and isinstance( p, RangeParameter ): - # Don't round in unit space; digits will be re-applied in - # the original space by the Cast transform during untransform. + # Don't snap/round in unit space; step_size (or legacy digits) + # will be re-applied in the original space by the Cast transform + # during untransform. Both are cleared until digits is fully + # removed (see step_size unification RFC). 
+ if p.step_size is not None: + p.set_step_size(step_size=None) if p.digits is not None: p.set_digits(digits=None) p.update_range( diff --git a/ax/service/tests/test_instantiation_utils.py b/ax/service/tests/test_instantiation_utils.py index 4c7f53ce511..2b75ff48f0f 100644 --- a/ax/service/tests/test_instantiation_utils.py +++ b/ax/service/tests/test_instantiation_utils.py @@ -390,6 +390,39 @@ def test_choice_with_is_sorted(self) -> None: } _ = InstantiationBase.parameter_from_json(representation) + def test_range_parameter_step_size_and_digits(self) -> None: + # ``step_size`` is forwarded from the representation to the parameter. + step_size_param = assert_is_instance( + InstantiationBase.parameter_from_json( + { + "name": "x", + "type": "range", + "bounds": [0.0, 1.0], + "step_size": 0.1, + } + ), + RangeParameter, + ) + self.assertEqual(step_size_param.step_size, 0.1) + self.assertIsNone(step_size_param.digits) + + # Legacy ``digits`` is still accepted for backwards compatibility. + digits_param = assert_is_instance( + InstantiationBase.parameter_from_json( + { + "name": "x", + "type": "range", + "bounds": [0.0, 1.0], + "digits": 2, + } + ), + RangeParameter, + ) + self.assertEqual(digits_param.digits, 2) + self.assertIsNone(digits_param.step_size) + # ``digits=2`` rounds cast values to two decimal places. 
+ self.assertEqual(digits_param.cast(0.123), 0.12) + def test_hss(self) -> None: parameter_dicts: list[ dict[str, TParamValue | Sequence[TParamValue] | dict[str, list[str]]] diff --git a/ax/service/utils/instantiation.py b/ax/service/utils/instantiation.py index 6e2399967de..72367a12707 100644 --- a/ax/service/utils/instantiation.py +++ b/ax/service/utils/instantiation.py @@ -81,6 +81,7 @@ "is_ordered", "is_task", "digits", + "step_size", "dependents", "expression_str", } @@ -246,7 +247,18 @@ def _make_range_param( lower=assert_is_instance_of_tuple(bounds[0], (float, int)), upper=assert_is_instance_of_tuple(bounds[1], (float, int)), log_scale=assert_is_instance(representation.get("log_scale", False), bool), - digits=assert_is_instance_optional(representation.get("digits", None), int), + # Prefer ``step_size``; legacy ``digits`` is used only when ``step_size`` + # is absent or None, because the constructor rejects both being set. + digits=( + None + if representation.get("step_size", None) is not None + else assert_is_instance_optional( + representation.get("digits", None), int + ) + ), + step_size=assert_is_instance_optional( + representation.get("step_size", None), float + ), is_fidelity=assert_is_instance( representation.get("is_fidelity", False), bool ), diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 926e7704e0d..58b8f84e19f 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -1966,6 +1966,17 @@ def get_range_parameter2() -> RangeParameter: return RangeParameter(name="x", parameter_type=ParameterType.INT, lower=1, upper=10) +def get_range_parameter_with_step_size() -> RangeParameter: + return RangeParameter( + name="w", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + log_scale=False, + step_size=0.1, + ) + + def get_choice_parameter() -> ChoiceParameter: return ChoiceParameter( name="y",