Skip to content

Commit b575948

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Add native step_size support to RangeParameter (#5213)
Summary: Adds a `step_size` arg to `RangeParameter` that snaps values to a grid anchored at `lower` (in `cast()`), for both FLOAT and INT parameters. This is the first diff in the step_size unification stack: `step_size` will subsume both the discrete-grid and limited-resolution (`digits`) use cases under one knob. - Next diff will add storage support. The internal DB has already been updated to include the new column. - We will then migrate all current usage off `digits` and onto `step_size`. - We will add support for treating low-cardinality float-range parameters as discrete in `Adapter`, so that it is efficiently optimized over the correct grid (rather than having to use continuous optimization + rounding). - At this point, we will have proper support for `step_size`, so we can update the ax/api usage to leverage it, rather than resolving to `ChoiceParameter`. - We can then deprecate `digits` and do any remaining clean-up. In this diff `step_size` coexists with the existing `digits` arg (they are mutually exclusive at construction). Subsequent diffs in the stack migrate storage (JSON + SQA), transforms and utils, and the public API (`RangeParameterConfig`) to `step_size`, then deprecate `digits` in favor of it. Behavior: - `cast()` rounds `(value - lower) / step_size` to the nearest integer and returns `lower + n * step_size`. It does NOT clamp to `[lower, upper]`: an out-of-bounds input (e.g. a historical observation recorded outside the current bounds) snaps to the nearest grid point, which may itself be out of bounds. This mirrors the non-`step_size` `cast()`, which leaves out-of-bounds values in place rather than silently moving them into range — range validity is enforced by `validate()`, not `cast()`. - Both bounds must lie on the grid: `(upper - lower)` must be an integer multiple of `step_size` (within `EPS`). Off-grid bounds are rejected at construction. This guarantees `upper` is itself a feasible value, so a value near the upper bound snaps to `upper` rather than to a grid point short of it. - `step_size` must be strictly positive, and must be integer-valued for INT parameters. - `cardinality()` accounts for `step_size`: a grid-valued FLOAT reports the finite number of grid points instead of `inf`, and a grid-valued INT counts grid points rather than every integer in `[lower, upper]`. `step_size` defines a discrete grid but does not, by itself, force discrete acquisition optimization; how the optimizer treats the parameter depends on the grid cardinality and is determined at the generator level. Differential Revision: D107274057
1 parent d0ae700 commit b575948

2 files changed

Lines changed: 385 additions & 2 deletions

File tree

ax/core/parameter.py

Lines changed: 194 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,6 +341,7 @@ def __init__(
341341
log_scale: bool = False,
342342
logit_scale: bool = False,
343343
digits: int | None = None,
344+
step_size: float | None = None,
344345
is_fidelity: bool = False,
345346
target_value: TParamValue = None,
346347
backfill_value: TParamValue = None,
@@ -359,6 +360,25 @@ def __init__(
359360
logit_scale: Whether to sample in logit space when drawing
360361
random values of the parameter.
361362
digits: Number of digits to round values to for float type.
363+
Deprecated in favor of ``step_size``; cannot be set together
364+
with ``step_size``.
365+
step_size: If set, the parameter's feasible values are the grid
366+
``{lower + k * step_size : k in N}`` intersected with
367+
``[lower, upper]``. ``cast()`` snaps values to the nearest grid
368+
point (anchored at ``lower``) without clamping to the bounds, so
369+
an out-of-bounds input snaps to an out-of-bounds grid point --
370+
mirroring the non-``step_size`` ``cast()``, which also leaves
371+
out-of-bounds values in place. ``step_size`` must be strictly
372+
positive, and the range must be an exact multiple of it:
373+
``(upper - lower)`` must be an integer multiple of ``step_size``
374+
(within ``EPS``), so that both bounds lie on the grid. For INT
375+
parameters, ``step_size`` must itself be integer-valued.
376+
377+
``step_size`` defines a discrete grid but does not, by itself,
378+
force discrete acquisition optimization. How the optimizer
379+
treats the parameter depends on the grid cardinality
380+
``floor((upper - lower) / step_size) + 1``, and is determined
381+
at the generator level.
362382
is_fidelity: Whether this parameter is a fidelity parameter.
363383
target_value: Target value of this parameter if it is a fidelity.
364384
backfill_value: For parameters added to experiments that have already run
@@ -378,6 +398,10 @@ def __init__(
378398
raise UserInputError("RangeParameter type must be int or float.")
379399
self._parameter_type = parameter_type
380400
self._digits = digits
401+
# ``_step_size`` must be set before casting ``lower`` / ``upper`` below,
402+
# since ``cast()`` reads it to snap values to the grid.
403+
self._step_size: float | None = None
404+
self._validate_and_set_step_size(step_size=step_size)
381405
self._lower: TNumeric = self.cast(lower)
382406
self._upper: TNumeric = self.cast(upper)
383407
self._log_scale = log_scale
@@ -393,15 +417,32 @@ def __init__(
393417
self.cast(default_value) if default_value is not None else None
394418
)
395419

420+
# Validate the raw inputs: this rejects invalid user input (e.g. a
421+
# non-integer bound for an INT parameter) before ``cast()`` silently
422+
# truncates it. For the non-deprecated paths ``cast()`` does not move a
423+
# bound that would otherwise pass validation -- FLOAT casting is a no-op
424+
# on the value, and ``step_size`` snapping is skipped for bounds -- so
425+
# validating the raw inputs also guarantees the stored bounds are valid.
396426
self._validate_range_param(
397427
parameter_type=parameter_type,
398428
lower=lower,
399429
upper=upper,
400430
log_scale=log_scale,
401431
logit_scale=logit_scale,
402432
)
433+
# ``upper`` must additionally lie on the ``step_size`` grid (the grid is
434+
# anchored at ``lower``).
435+
self._validate_step_size_on_grid()
403436

404437
def cardinality(self) -> TNumeric:
438+
if self._step_size is not None:
439+
# Values are snapped to the grid {lower + k * step_size}
440+
# intersected with [lower, upper]. Both bounds lie on the grid
441+
# (enforced at construction), so the number of grid points is
442+
# (upper - lower) / step_size + 1.
443+
step_size = none_throws(self._step_size)
444+
return round((float(self.upper) - float(self.lower)) / step_size) + 1
445+
405446
if self.parameter_type == ParameterType.FLOAT:
406447
return inf
407448

@@ -493,6 +534,19 @@ def digits(self) -> int | None:
493534
"""
494535
return self._digits
495536

537+
@property
538+
def step_size(self) -> float | None:
539+
"""Grid spacing that values are snapped to in ``cast()``.
540+
541+
If set, the parameter's feasible values are the grid
542+
``{lower + k * step_size : k in N}`` intersected with ``[lower, upper]``,
543+
and ``cast()`` snaps values to the nearest grid point (without clamping
544+
to the bounds). Both bounds are guaranteed to be on the grid (the
545+
constructor requires ``(upper - lower)`` to be an integer multiple of
546+
``step_size``). ``None`` means no snapping.
547+
"""
548+
return self._step_size
549+
496550
@property
497551
def log_scale(self) -> bool:
498552
"""Whether the parameter's values should be sampled from log space."""
@@ -519,14 +573,25 @@ def update_range(
519573
if upper is None:
520574
upper = self._upper
521575

522-
cast_lower = self.cast(lower)
523-
cast_upper = self.cast(upper)
576+
# When ``step_size`` is set, cast the bounds without snapping to the
577+
# (old) grid: bounds anchor the grid and must not be silently moved onto
578+
# it. ``super().cast()`` applies only the type cast. The digits path
579+
# (deprecated) keeps its historical rounding behavior via ``self.cast``.
580+
if self._step_size is not None:
581+
cast_lower = assert_is_instance(super().cast(lower), TNumeric)
582+
cast_upper = assert_is_instance(super().cast(upper), TNumeric)
583+
else:
584+
cast_lower = self.cast(lower)
585+
cast_upper = self.cast(upper)
524586
self._validate_range_param(
525587
lower=cast_lower,
526588
upper=cast_upper,
527589
log_scale=self.log_scale,
528590
logit_scale=self.logit_scale,
529591
)
592+
# The new bounds must lie on the ``step_size`` grid, if one is set.
593+
# Validate before committing so a failed update leaves bounds unchanged.
594+
self._validate_step_size_on_grid(lower=cast_lower, upper=cast_upper)
530595
self._lower = cast_lower
531596
self._upper = cast_upper
532597
return self
@@ -546,6 +611,95 @@ def set_digits(self, digits: int | None) -> RangeParameter:
546611
self._upper = cast_upper
547612
return self
548613

614+
def set_step_size(self, step_size: float | None) -> RangeParameter:
615+
"""Set the grid spacing that values are snapped to in ``cast()``.
616+
617+
The existing bounds are kept as-is (they anchor the grid and define the
618+
feasible range); they are not snapped onto the new grid. Instead we
619+
require that they already lie on it: ``(upper - lower)`` must be an
620+
integer multiple of the new ``step_size``.
621+
622+
Raises:
623+
UserInputError: If the current bounds do not lie on the new grid.
624+
"""
625+
previous_step_size = self._step_size
626+
self._validate_and_set_step_size(step_size=step_size)
627+
try:
628+
# The current (unchanged) bounds must lie on the new grid.
629+
self._validate_step_size_on_grid()
630+
except UserInputError:
631+
# Leave the parameter unchanged if the new grid is invalid.
632+
self._step_size = previous_step_size
633+
raise
634+
return self
635+
636+
def _validate_and_set_step_size(self, step_size: float | None) -> None:
637+
"""Validate ``step_size`` and store it on ``self._step_size``.
638+
639+
Raises:
640+
UserInputError: If ``step_size`` is non-positive, if it is set
641+
together with ``digits``, or if it is not integer-valued for an
642+
INT parameter.
643+
"""
644+
if step_size is None:
645+
self._step_size = None
646+
return
647+
if self._digits is not None:
648+
raise UserInputError(
649+
f"Cannot set both `digits` and `step_size` on parameter "
650+
f"{self._name}. `digits` is deprecated; use `step_size` only."
651+
)
652+
if step_size <= 0:
653+
raise UserInputError(
654+
f"`step_size` must be strictly positive for parameter "
655+
f"{self._name}. Got: {step_size}."
656+
)
657+
if (
658+
self._parameter_type is ParameterType.INT
659+
and not float(step_size).is_integer()
660+
):
661+
raise UserInputError(
662+
f"`step_size` must be integer-valued for INT parameter "
663+
f"{self._name}. Got: {step_size}."
664+
)
665+
self._step_size = float(step_size)
666+
667+
def _validate_step_size_on_grid(
668+
self, lower: TNumeric | None = None, upper: TNumeric | None = None
669+
) -> None:
670+
"""Validate that both bounds lie on the ``step_size`` grid.
671+
672+
The grid is anchored at ``lower``, so ``lower`` is always on it. This
673+
additionally requires ``upper`` to be on the grid, i.e. that
674+
``(upper - lower)`` is an integer multiple of ``step_size`` (within
675+
``EPS``). This guarantees ``upper`` is itself a feasible value, so a
676+
value near the upper bound snaps to ``upper`` rather than to a grid
677+
point short of it.
678+
679+
Args:
680+
lower: Lower bound to validate against. Defaults to ``self._lower``.
681+
upper: Upper bound to validate against. Defaults to ``self._upper``.
682+
These overrides let callers validate prospective bounds before
683+
committing them.
684+
685+
Raises:
686+
UserInputError: If ``upper`` does not lie on the grid.
687+
"""
688+
if self._step_size is None:
689+
return
690+
lower = self._lower if lower is None else lower
691+
upper = self._upper if upper is None else upper
692+
step_size = none_throws(self._step_size)
693+
width = float(upper) - float(lower)
694+
n = width / step_size
695+
if abs(n - round(n)) * step_size > EPS:
696+
raise UserInputError(
697+
f"`step_size` must evenly divide the range of parameter "
698+
f"{self._name}: (upper - lower) = {width} is not an integer "
699+
f"multiple of step_size = {step_size}. Adjust the bounds or "
700+
f"step_size so that both bounds lie on the grid."
701+
)
702+
549703
def set_log_scale(self, log_scale: bool) -> RangeParameter:
550704
self._log_scale = log_scale
551705
return self
@@ -647,6 +801,7 @@ def clone(self) -> RangeParameter:
647801
log_scale=self._log_scale,
648802
logit_scale=self._logit_scale,
649803
digits=self._digits,
804+
step_size=self._step_size,
650805
is_fidelity=self._is_fidelity,
651806
target_value=self._target_value,
652807
backfill_value=self._backfill_value,
@@ -657,13 +812,50 @@ def cast(self, value: TParamValue) -> TNumeric:
657812
value = super().cast(value=value)
658813
if self.parameter_type is ParameterType.FLOAT and self._digits is not None:
659814
return round(float(value), none_throws(self._digits))
815+
# Skip snapping while the constructor is still casting the bounds
816+
# themselves (before both ``self._lower`` and ``self._upper`` are set):
817+
# the bounds anchor the grid and must not be snapped (``upper`` is only
818+
# validated to be on the grid after both are assigned). ``_snap_to_grid``
819+
# needs ``self._lower``; gating on ``self._upper`` too is what excludes
820+
# the ``upper`` cast at construction.
821+
if (
822+
self._step_size is not None
823+
and getattr(self, "_lower", None) is not None
824+
and getattr(self, "_upper", None) is not None
825+
):
826+
value = self._snap_to_grid(value=float(value))
660827
return assert_is_instance(value, TNumeric)
661828

829+
def _snap_to_grid(self, value: float) -> TNumeric:
830+
"""Snap ``value`` to the nearest grid point.
831+
832+
The grid is ``{lower + k * step_size : k in Z}``. The nearest grid point
833+
is found by rounding ``(value - lower) / step_size`` to the nearest
834+
integer. The result is *not* clamped to ``[lower, upper]``: an
835+
out-of-bounds input (e.g. historical observations recorded outside the
836+
current bounds) snaps to the nearest grid point, which may itself lie
837+
outside the bounds. This mirrors the non-``step_size`` ``cast()``, which
838+
leaves out-of-bounds values untouched rather than silently moving them
839+
into range -- range validity is enforced by ``validate()``, not by
840+
``cast()``. For INT parameters the snapped value is integer-valued
841+
(``step_size`` is validated to be an integer), so it is returned as an
842+
``int``.
843+
"""
844+
step_size = none_throws(self._step_size)
845+
lower = float(self._lower)
846+
n = round((value - lower) / step_size)
847+
snapped = lower + n * step_size
848+
if self.parameter_type is ParameterType.INT:
849+
return int(round(snapped))
850+
return snapped
851+
662852
def __repr__(self) -> str:
663853
ret_val = self._base_repr()
664854

665855
if self._digits is not None:
666856
ret_val += f", digits={self._digits}"
857+
if self._step_size is not None:
858+
ret_val += f", step_size={self._step_size}"
667859

668860
return ret_val + ")"
669861

0 commit comments

Comments
 (0)