From c20744a55d3a6a2729080e69803ab0fb3b187987 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 4 Jun 2026 09:34:52 -0700 Subject: [PATCH 1/2] Add native step_size support to RangeParameter (#5213) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: TL;DR: Adds `step_size`, which is already exposed in ax/api `RangeParameterConfig` but resolves to `ChoiceParameter`, to `RangeParameter`. This will generalize and replace `digits` as a way to allow large discrete range parameters over a grid of possible values, with the downstream code deciding whether to treat it as a continuous or discrete parameter for optimization, based on parameter cardinality. Adds a `step_size` arg to `RangeParameter` that snaps values to a grid anchored at `lower` (in `cast()`), for both FLOAT and INT parameters. This is the first diff in the step_size unification stack: `step_size` will subsume both the discrete-grid and limited-resolution (`digits`) use cases under one knob. - Next diff will add storage support. The internal DB has already been updated to include the new column. - We will then migrate all current usage off `digits` and onto `step_size`. - We will add support for treating low-cardinality float-range parameters as discrete in `Adapter`, so that it is efficiently optimized over the correct grid (rather than having to use continuous optimization + rounding). - At this point, we will have proper support for `step_size`, so we can update the ax/api usage to leverage it, rather than resolving to `ChoiceParameter`. - We can then deprecate `digits` and do any remaining clean-up. In this diff `step_size` coexists with the existing `digits` arg (they are mutually exclusive at construction). Subsequent diffs in the stack migrate storage (JSON + SQA), transforms and utils, and the public API (`RangeParameterConfig`) to `step_size`, then deprecate `digits` in favor of it. Behavior: - `cast()` rounds `(value - lower) / step_size` to the nearest integer and returns `lower + n * step_size`. It does NOT clamp to `[lower, upper]`: an out-of-bounds input (e.g. a historical observation recorded outside the current bounds) snaps to the nearest grid point, which may itself be out of bounds. This mirrors the non-`step_size` `cast()`, which leaves out-of-bounds values in place rather than silently moving them into range — range validity is enforced by `validate()`, not `cast()`. - Both bounds must lie on the grid: `(upper - lower)` must be an integer multiple of `step_size` (within `EPS`). Off-grid bounds are rejected at construction. This guarantees `upper` is itself a feasible value, so a value near the upper bound snaps to `upper` rather than to a grid point short of it. - `step_size` must be strictly positive, and must be integer-valued for INT parameters. - `cardinality()` accounts for `step_size`: a grid-valued FLOAT reports the finite number of grid points instead of `inf`, and a grid-valued INT counts grid points rather than every integer in `[lower, upper]`. `step_size` defines a discrete grid but does not, by itself, force discrete acquisition optimization; how the optimizer treats the parameter depends on the grid cardinality and is determined at the generator level. Differential Revision: D107274057 --- ax/core/parameter.py | 196 +++++++++++++++++++++++++++++++- ax/core/tests/test_parameter.py | 191 +++++++++++++++++++++++++++++++ 2 files changed, 385 insertions(+), 2 deletions(-) diff --git a/ax/core/parameter.py b/ax/core/parameter.py index 1e73f1720f0..5b6e657fc4a 100644 --- a/ax/core/parameter.py +++ b/ax/core/parameter.py @@ -341,6 +341,7 @@ def __init__( log_scale: bool = False, logit_scale: bool = False, digits: int | None = None, + step_size: float | None = None, is_fidelity: bool = False, target_value: TParamValue = None, backfill_value: TParamValue = None, @@ -359,6 +360,25 @@ def __init__( logit_scale: Whether to sample in logit space when drawing random values of the parameter. digits: Number of digits to round values to for float type. + Deprecated in favor of ``step_size``; cannot be set together + with ``step_size``. + step_size: If set, the parameter's feasible values are the grid + ``{lower + k * step_size : k in N}`` intersected with + ``[lower, upper]``. ``cast()`` snaps values to the nearest grid + point (anchored at ``lower``) without clamping to the bounds, so + an out-of-bounds input snaps to an out-of-bounds grid point -- + mirroring the non-``step_size`` ``cast()``, which also leaves + out-of-bounds values in place. ``step_size`` must be strictly + positive, and the range must be an exact multiple of it: + ``(upper - lower)`` must be an integer multiple of ``step_size`` + (within ``EPS``), so that both bounds lie on the grid. For INT + parameters, ``step_size`` must itself be integer-valued. + + ``step_size`` defines a discrete grid but does not, by itself, + force discrete acquisition optimization. How the optimizer + treats the parameter depends on the grid cardinality + ``floor((upper - lower) / step_size) + 1``, and is determined + at the generator level. is_fidelity: Whether this parameter is a fidelity parameter. target_value: Target value of this parameter if it is a fidelity. backfill_value: For parameters added to experiments that have already run @@ -378,6 +398,10 @@ def __init__( raise UserInputError("RangeParameter type must be int or float.") self._parameter_type = parameter_type self._digits = digits + # ``_step_size`` must be set before casting ``lower`` / ``upper`` below, + # since ``cast()`` reads it to snap values to the grid. + self._step_size: float | None = None + self._validate_and_set_step_size(step_size=step_size) self._lower: TNumeric = self.cast(lower) self._upper: TNumeric = self.cast(upper) self._log_scale = log_scale @@ -393,6 +417,12 @@ def __init__( self.cast(default_value) if default_value is not None else None ) + # Validate the raw inputs: this rejects invalid user input (e.g. a + # non-integer bound for an INT parameter) before ``cast()`` silently + # truncates it. For the non-deprecated paths ``cast()`` does not move a + # bound that would otherwise pass validation -- FLOAT casting is a no-op + # on the value, and ``step_size`` snapping is skipped for bounds -- so + # validating the raw inputs also guarantees the stored bounds are valid. self._validate_range_param( parameter_type=parameter_type, lower=lower, @@ -400,8 +430,19 @@ def __init__( log_scale=log_scale, logit_scale=logit_scale, ) + # ``upper`` must additionally lie on the ``step_size`` grid (the grid is + # anchored at ``lower``). + self._validate_step_size_on_grid() def cardinality(self) -> TNumeric: + if self._step_size is not None: + # Values are snapped to the grid {lower + k * step_size} + # intersected with [lower, upper]. Both bounds lie on the grid + # (enforced at construction), so the number of grid points is + # (upper - lower) / step_size + 1. + step_size = none_throws(self._step_size) + return round((float(self.upper) - float(self.lower)) / step_size) + 1 + if self.parameter_type == ParameterType.FLOAT: return inf @@ -493,6 +534,19 @@ def digits(self) -> int | None: """ return self._digits + @property + def step_size(self) -> float | None: + """Grid spacing that values are snapped to in ``cast()``. + + If set, the parameter's feasible values are the grid + ``{lower + k * step_size : k in N}`` intersected with ``[lower, upper]``, + and ``cast()`` snaps values to the nearest grid point (without clamping + to the bounds). Both bounds are guaranteed to be on the grid (the + constructor requires ``(upper - lower)`` to be an integer multiple of + ``step_size``). ``None`` means no snapping. + """ + return self._step_size + @property def log_scale(self) -> bool: """Whether the parameter's values should be sampled from log space.""" @@ -519,14 +573,25 @@ def update_range( if upper is None: upper = self._upper - cast_lower = self.cast(lower) - cast_upper = self.cast(upper) + # When ``step_size`` is set, cast the bounds without snapping to the + # (old) grid: bounds anchor the grid and must not be silently moved onto + # it. ``super().cast()`` applies only the type cast. The digits path + # (deprecated) keeps its historical rounding behavior via ``self.cast``. + if self._step_size is not None: + cast_lower = assert_is_instance(super().cast(lower), TNumeric) + cast_upper = assert_is_instance(super().cast(upper), TNumeric) + else: + cast_lower = self.cast(lower) + cast_upper = self.cast(upper) self._validate_range_param( lower=cast_lower, upper=cast_upper, log_scale=self.log_scale, logit_scale=self.logit_scale, ) + # The new bounds must lie on the ``step_size`` grid, if one is set. + # Validate before committing so a failed update leaves bounds unchanged. + self._validate_step_size_on_grid(lower=cast_lower, upper=cast_upper) self._lower = cast_lower self._upper = cast_upper return self @@ -546,6 +611,95 @@ def set_digits(self, digits: int | None) -> RangeParameter: self._upper = cast_upper return self + def set_step_size(self, step_size: float | None) -> RangeParameter: + """Set the grid spacing that values are snapped to in ``cast()``. + + The existing bounds are kept as-is (they anchor the grid and define the + feasible range); they are not snapped onto the new grid. Instead we + require that they already lie on it: ``(upper - lower)`` must be an + integer multiple of the new ``step_size``. + + Raises: + UserInputError: If the current bounds do not lie on the new grid. + """ + previous_step_size = self._step_size + self._validate_and_set_step_size(step_size=step_size) + try: + # The current (unchanged) bounds must lie on the new grid. + self._validate_step_size_on_grid() + except UserInputError: + # Leave the parameter unchanged if the new grid is invalid. + self._step_size = previous_step_size + raise + return self + + def _validate_and_set_step_size(self, step_size: float | None) -> None: + """Validate ``step_size`` and store it on ``self._step_size``. + + Raises: + UserInputError: If ``step_size`` is non-positive, if it is set + together with ``digits``, or if it is not integer-valued for an + INT parameter. + """ + if step_size is None: + self._step_size = None + return + if self._digits is not None: + raise UserInputError( + f"Cannot set both `digits` and `step_size` on parameter " + f"{self._name}. `digits` is deprecated; use `step_size` only." + ) + if step_size <= 0: + raise UserInputError( + f"`step_size` must be strictly positive for parameter " + f"{self._name}. Got: {step_size}." + ) + if ( + self._parameter_type is ParameterType.INT + and not float(step_size).is_integer() + ): + raise UserInputError( + f"`step_size` must be integer-valued for INT parameter " + f"{self._name}. Got: {step_size}." + ) + self._step_size = float(step_size) + + def _validate_step_size_on_grid( + self, lower: TNumeric | None = None, upper: TNumeric | None = None + ) -> None: + """Validate that both bounds lie on the ``step_size`` grid. + + The grid is anchored at ``lower``, so ``lower`` is always on it. This + additionally requires ``upper`` to be on the grid, i.e. that + ``(upper - lower)`` is an integer multiple of ``step_size`` (within + ``EPS``). This guarantees ``upper`` is itself a feasible value, so a + value near the upper bound snaps to ``upper`` rather than to a grid + point short of it. + + Args: + lower: Lower bound to validate against. Defaults to ``self._lower``. + upper: Upper bound to validate against. Defaults to ``self._upper``. + These overrides let callers validate prospective bounds before + committing them. + + Raises: + UserInputError: If ``upper`` does not lie on the grid. + """ + if self._step_size is None: + return + lower = self._lower if lower is None else lower + upper = self._upper if upper is None else upper + step_size = none_throws(self._step_size) + width = float(upper) - float(lower) + n = width / step_size + if abs(n - round(n)) * step_size > EPS: + raise UserInputError( + f"`step_size` must evenly divide the range of parameter " + f"{self._name}: (upper - lower) = {width} is not an integer " + f"multiple of step_size = {step_size}. Adjust the bounds or " + f"step_size so that both bounds lie on the grid." + ) + def set_log_scale(self, log_scale: bool) -> RangeParameter: self._log_scale = log_scale return self @@ -647,6 +801,7 @@ def clone(self) -> RangeParameter: log_scale=self._log_scale, logit_scale=self._logit_scale, digits=self._digits, + step_size=self._step_size, is_fidelity=self._is_fidelity, target_value=self._target_value, backfill_value=self._backfill_value, @@ -657,13 +812,50 @@ def cast(self, value: TParamValue) -> TNumeric: value = super().cast(value=value) if self.parameter_type is ParameterType.FLOAT and self._digits is not None: return round(float(value), none_throws(self._digits)) + # Skip snapping while the constructor is still casting the bounds + # themselves (before both ``self._lower`` and ``self._upper`` are set): + # the bounds anchor the grid and must not be snapped (``upper`` is only + # validated to be on the grid after both are assigned). ``_snap_to_grid`` + # needs ``self._lower``; gating on ``self._upper`` too is what excludes + # the ``upper`` cast at construction. + if ( + self._step_size is not None + and getattr(self, "_lower", None) is not None + and getattr(self, "_upper", None) is not None + ): + value = self._snap_to_grid(value=float(value)) return assert_is_instance(value, TNumeric) + def _snap_to_grid(self, value: float) -> TNumeric: + """Snap ``value`` to the nearest grid point. + + The grid is ``{lower + k * step_size : k in Z}``. The nearest grid point + is found by rounding ``(value - lower) / step_size`` to the nearest + integer. The result is *not* clamped to ``[lower, upper]``: an + out-of-bounds input (e.g. historical observations recorded outside the + current bounds) snaps to the nearest grid point, which may itself lie + outside the bounds. This mirrors the non-``step_size`` ``cast()``, which + leaves out-of-bounds values untouched rather than silently moving them + into range -- range validity is enforced by ``validate()``, not by + ``cast()``. For INT parameters the snapped value is integer-valued + (``step_size`` is validated to be an integer), so it is returned as an + ``int``. + """ + step_size = none_throws(self._step_size) + lower = float(self._lower) + n = round((value - lower) / step_size) + snapped = lower + n * step_size + if self.parameter_type is ParameterType.INT: + return int(round(snapped)) + return snapped + def __repr__(self) -> str: ret_val = self._base_repr() if self._digits is not None: ret_val += f", digits={self._digits}" + if self._step_size is not None: + ret_val += f", step_size={self._step_size}" return ret_val + ")" diff --git a/ax/core/tests/test_parameter.py b/ax/core/tests/test_parameter.py index dc40742aa8f..eb3d1182cd0 100644 --- a/ax/core/tests/test_parameter.py +++ b/ax/core/tests/test_parameter.py @@ -183,6 +183,197 @@ def test_Clone(self) -> None: param_clone._lower = 2.0 self.assertNotEqual(self.param1.lower, param_clone.lower) + def test_step_size_snapping(self) -> None: + # ``cast()`` snaps to the nearest grid point anchored at ``lower``. Each + # case is (param, [(input, expected), ...]). Inputs avoid exact + # half-points (e.g. 0.15) where round-half-to-even makes the result + # ambiguous. + float_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # ``lower`` need not be zero; the grid {0.005, 0.015, 0.025} is anchored + # at the (non-zero) lower bound. + offset_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.005, + upper=0.025, + step_size=0.01, + ) + int_param = RangeParameter( + name="y", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + step_size=2, + ) + cases = [ + ( + float_param, + [ + (0.12, 0.1), + (0.16, 0.2), + (0.13, 0.1), + # Bounds are on the grid (1 / 0.1 = 10) and reachable. + (0.0, 0.0), + (1.0, 1.0), + # 0.98 snaps to the upper bound 1.0. + (0.98, 1.0), + # Out-of-bounds inputs (e.g. historical observations + # recorded outside the current bounds) snap to the nearest + # grid point WITHOUT being clamped into [lower, upper]. This + # mirrors the non-step_size cast(), which leaves + # out-of-bounds values in place rather than mutating them. + (1.5, 1.5), + (1.52, 1.5), + (-0.3, -0.3), + ], + ), + (offset_param, [(0.006, 0.005), (0.013, 0.015), (0.025, 0.025)]), + ( + int_param, + [ + # 3 -> (3-0)/2 = 1.5 -> round-half-to-even -> 2 -> 4. + (3, 4), + # 7 -> (7-0)/2 = 3.5 -> round-half-to-even -> 4 -> 8. + (7, 8), + (0, 0), + (10, 10), + # Out-of-bounds INT inputs are likewise not clamped: + # 15 -> (15-0)/2 = 7.5 -> round-half-to-even -> 8 -> 16, + # which exceeds upper=10 and is kept as-is. + (15, 16), + ], + ), + ] + for param, input_expected in cases: + for value, expected in input_expected: + with self.subTest(param=param.name, value=value): + snapped = none_throws(param.cast(value)) + self.assertAlmostEqual(snapped, expected) + if param.parameter_type is ParameterType.INT: + self.assertIsInstance(snapped, int) + + def test_step_size_validation(self) -> None: + # All special-value / off-grid rejections at construction. + # Non-positive step_size. + with self.assertRaisesRegex(UserInputError, "must be strictly positive"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=0.0) + with self.assertRaisesRegex(UserInputError, "must be strictly positive"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=-0.1) + # Non-integer step_size for an INT parameter. + with self.assertRaisesRegex(UserInputError, "must be integer-valued"): + RangeParameter("y", ParameterType.INT, 0, 10, step_size=2.5) + # Off-grid bounds: (upper - lower) is not an integer multiple of + # step_size, so ``upper`` is off the grid anchored at ``lower``. FLOAT + # grid {0, 0.3, 0.6, 0.9} misses upper=1.0; INT grid {0, 3, 6, 9} misses + # upper=10. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, step_size=0.3) + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + RangeParameter("y", ParameterType.INT, 0, 10, step_size=3) + # Cannot set both digits and step_size. + with self.assertRaisesRegex(UserInputError, "Cannot set both"): + RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0, digits=2, step_size=0.1) + + def test_set_step_size(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + ) + self.assertIsNone(param.step_size) + returned = param.set_step_size(0.25) + self.assertIs(returned, param) + self.assertEqual(param.step_size, 0.25) + self.assertAlmostEqual(none_throws(param.cast(0.3)), 0.25) + # Clearing step_size disables snapping. + param.set_step_size(None) + self.assertIsNone(param.step_size) + self.assertAlmostEqual(none_throws(param.cast(0.3)), 0.3) + # ``set_step_size`` rejects a step that does not divide the range. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + param.set_step_size(0.3) + # The failed call leaves the parameter unchanged (no snapping). + self.assertIsNone(param.step_size) + + def test_update_range_respects_step_size_grid(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # New bounds on the grid are accepted and not snapped away. + param.update_range(lower=0.0, upper=0.5) + self.assertAlmostEqual(float(param.upper), 0.5) + # New bounds off the grid are rejected. + with self.assertRaisesRegex(UserInputError, "must evenly divide the range"): + param.update_range(upper=0.55) + + def test_step_size_repr_and_clone(self) -> None: + param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + self.assertIn("step_size=0.1", str(param)) + clone = param.clone() + self.assertEqual(clone.step_size, 0.1) + self.assertEqual(param, clone) + # Parameters differing only in step_size are not equal. + other = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.2, + ) + self.assertNotEqual(param, other) + + def test_step_size_cardinality(self) -> None: + # FLOAT with a grid has finite cardinality (number of grid points), + # not inf. + float_param = RangeParameter( + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + # Grid {0.0, 0.1, ..., 1.0} has 11 points. + self.assertEqual(float_param.cardinality(), 11) + # INT with a grid counts grid points, not every integer in [lower, + # upper]. + int_param = RangeParameter( + name="y", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + step_size=2, + ) + # Grid {0, 2, 4, 6, 8, 10} has 6 points (not 11). + self.assertEqual(int_param.cardinality(), 6) + # Without step_size, behavior is unchanged: FLOAT is inf, INT counts + # every integer. + self.assertEqual( + RangeParameter( + name="z", + parameter_type=ParameterType.INT, + lower=0, + upper=10, + ).cardinality(), + 11, + ) + def test_get_parameter_type(self) -> None: self.assertEqual(_get_parameter_type(float), ParameterType.FLOAT) self.assertEqual(_get_parameter_type(int), ParameterType.INT) From e4fed74ee2fabecbad9da8ecf40f6e8aec9ac594 Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Thu, 4 Jun 2026 09:34:52 -0700 Subject: [PATCH 2/2] Add storage support for RangeParameter.step_size (JSON + SQA) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Teaches both storage backends to persist and restore the new `step_size` field on `RangeParameter`. Second diff in the step_size unification stack (builds on the native `step_size` support added in D107274057). The underlying SQA column (`SQAParameter.step_size`) and the corresponding EntAE schema were added and landed separately (D107108021, D107112306), so this diff is pure encoder/decoder logic — it does not touch the schema. SQA: the encoder writes `parameter.step_size` to the column (and continues writing the legacy `digits` column for now; `digits` is dropped in a later diff). The decoder prefers `step_size` and falls back to the legacy `digits` column for rows written by older code, passing only one to the constructor (which rejects both being set; it converts `digits` to `step_size` internally). JSON: `range_parameter_to_dict` emits `step_size` alongside `digits`. There is no custom JSON decoder for `RangeParameter` — the registry maps it directly to the class, and `ax_class_from_json_dict` splats the serialized dict straight into the constructor. So `step_size` flows through automatically, and legacy blobs carrying only `digits` still decode via the constructor's back-compat handling. No decoder change is required for JSON. Differential Revision: D107282941 --- ax/storage/json_store/encoders.py | 1 + .../json_store/tests/test_json_store.py | 39 ++++++++++++++++++- ax/storage/sqa_store/decoder.py | 10 ++++- ax/storage/sqa_store/encoder.py | 1 + ax/storage/sqa_store/tests/test_sqa_store.py | 35 +++++++++++++++++ 5 files changed, 84 insertions(+), 2 deletions(-) diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index da84a260a77..181b8d5b19f 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -201,6 +201,7 @@ def range_parameter_to_dict(parameter: RangeParameter) -> dict[str, Any]: "log_scale": parameter.log_scale, "logit_scale": parameter.logit_scale, "digits": parameter.digits, + "step_size": parameter.step_size, "is_fidelity": parameter.is_fidelity, "target_value": parameter.target_value, } diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index d40598dcf47..ed14783fcc1 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -55,7 +55,7 @@ OptimizationConfig, PreferenceOptimizationConfig, ) -from ax.core.parameter import ChoiceParameter, ParameterType +from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.parameter_constraint import ParameterConstraint from ax.core.runner import Runner from ax.exceptions.core import AxStorageWarning, UnsupportedError @@ -398,6 +398,17 @@ ("ParameterConstraint", get_parameter_constraint), ("ParameterConstraint", get_equality_parameter_constraint), ("RangeParameter", get_range_parameter), + ( + "RangeParameter", + partial( + RangeParameter, + name="x", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ), + ), ("ScalarizedObjective", get_scalarized_objective), ("ScalarizedOutcomeConstraint", get_scalarized_outcome_constraint), ("OrchestratorOptions", get_default_orchestrator_options), @@ -1942,6 +1953,32 @@ def test_multi_objective_from_json_warning(self) -> None: any("Found unexpected kwargs" in warning for warning in cm.output) ) + def test_range_parameter_legacy_digits_blob_decodes(self) -> None: + # A legacy blob has "digits" but no "step_size" key. It must still + # decode (the constructor accepts digits for back-compat). + legacy_blob = { + "__type": "RangeParameter", + "name": "x", + "parameter_type": {"__type": "ParameterType", "name": "FLOAT"}, + "lower": 0.0, + "upper": 1.0, + "log_scale": False, + "logit_scale": False, + "digits": 2, + "is_fidelity": False, + "target_value": None, + } + decoded = object_from_json( + legacy_blob, + decoder_registry=CORE_DECODER_REGISTRY, + class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, + ) + self.assertIsInstance(decoded, RangeParameter) + self.assertEqual(decoded.digits, 2) + self.assertIsNone(decoded.step_size) + # Rounding behavior from digits=2 is preserved. + self.assertEqual(decoded.cast(0.123), 0.12) + def test_choice_parameter_bypass_cardinality_check_encode_failure(self) -> None: choice_parameter = ChoiceParameter( name="test_choice", diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index d564c764765..3ac2d3c906b 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -470,13 +470,21 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter: "`dependents` unexpectedly non-null on range parameter " f"{parameter_sqa.name}." ) + # Prefer the newer ``step_size`` column; fall back to the legacy + # ``digits`` column for rows written by older code. Only one may be + # passed to the constructor (it rejects both being set). The + # constructor converts ``digits`` to ``step_size`` internally. + # We should never have the two together in the DB, so this is extra. + step_size = parameter_sqa.step_size + digits = None if step_size is not None else parameter_sqa.digits parameter = RangeParameter( name=parameter_sqa.name, parameter_type=parameter_sqa.parameter_type, lower=float(none_throws(parameter_sqa.lower)), upper=float(none_throws(parameter_sqa.upper)), log_scale=parameter_sqa.log_scale or False, - digits=parameter_sqa.digits, + digits=digits, + step_size=float(step_size) if step_size is not None else None, is_fidelity=parameter_sqa.is_fidelity or False, target_value=parameter_sqa.target_value, backfill_value=parameter_sqa.backfill_value, diff --git a/ax/storage/sqa_store/encoder.py b/ax/storage/sqa_store/encoder.py index 7ee21b70b8d..2ba801ab61a 100644 --- a/ax/storage/sqa_store/encoder.py +++ b/ax/storage/sqa_store/encoder.py @@ -308,6 +308,7 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter: upper=float(parameter.upper), log_scale=parameter.log_scale, digits=parameter.digits, + step_size=parameter.step_size, is_fidelity=parameter.is_fidelity, target_value=parameter.target_value, dependents=parameter.dependents if parameter.is_hierarchical else None, diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 5f3c50540b4..d82dc508f37 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1880,6 +1880,41 @@ def test_logit_scale(self) -> None: ) ) + def test_step_size_round_trip(self) -> None: + parameter = RangeParameter( + name="foo", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + step_size=0.1, + ) + sqa_parameter = self.encoder.parameter_to_sqa(parameter) + self.assertEqual(sqa_parameter.step_size, 0.1) + self.assertIsNone(sqa_parameter.digits) + decoded = self.decoder.parameter_from_sqa(sqa_parameter) + self.assertEqual(decoded, parameter) + self.assertEqual(assert_is_instance(decoded, RangeParameter).step_size, 0.1) + + def test_legacy_digits_row_decodes_via_fallback(self) -> None: + # A legacy row carries ``digits`` but no ``step_size``. The decoder + # falls back to ``digits``, which the constructor converts internally. + parameter = RangeParameter( + name="foo", + parameter_type=ParameterType.FLOAT, + lower=0.0, + upper=1.0, + digits=2, + ) + sqa_parameter = self.encoder.parameter_to_sqa(parameter) + # Simulate an old row: clear step_size, keep digits populated. + sqa_parameter.step_size = None + sqa_parameter.digits = 2 + decoded = assert_is_instance( + self.decoder.parameter_from_sqa(sqa_parameter), RangeParameter + ) + # The decoded parameter round-trips its rounding behavior. + self.assertEqual(decoded.cast(0.123), parameter.cast(0.123)) + def test_bypass_cardinality_check(self) -> None: choice_parameter = ChoiceParameter( name="test_choice",