@@ -341,6 +341,7 @@ def __init__(
341341 log_scale : bool = False ,
342342 logit_scale : bool = False ,
343343 digits : int | None = None ,
344+ step_size : float | None = None ,
344345 is_fidelity : bool = False ,
345346 target_value : TParamValue = None ,
346347 backfill_value : TParamValue = None ,
@@ -359,6 +360,25 @@ def __init__(
359360 logit_scale: Whether to sample in logit space when drawing
360361 random values of the parameter.
361362 digits: Number of digits to round values to for float type.
363+ Deprecated in favor of ``step_size``; cannot be set together
364+ with ``step_size``.
365+ step_size: If set, the parameter's feasible values are the grid
366+ ``{lower + k * step_size : k in N}`` intersected with
367+ ``[lower, upper]``. ``cast()`` snaps values to the nearest grid
368+ point (anchored at ``lower``) without clamping to the bounds, so
369+ an out-of-bounds input snaps to an out-of-bounds grid point --
370+ mirroring the non-``step_size`` ``cast()``, which also leaves
371+ out-of-bounds values in place. ``step_size`` must be strictly
372+ positive, and the range must be an exact multiple of it:
373+ ``(upper - lower)`` must be an integer multiple of ``step_size``
374+ (within ``EPS``), so that both bounds lie on the grid. For INT
375+ parameters, ``step_size`` must itself be integer-valued.
376+
377+ ``step_size`` defines a discrete grid but does not, by itself,
378+ force discrete acquisition optimization. How the optimizer
379+ treats the parameter depends on the grid cardinality
380+ ``floor((upper - lower) / step_size) + 1``, and is determined
381+ at the generator level.
362382 is_fidelity: Whether this parameter is a fidelity parameter.
363383 target_value: Target value of this parameter if it is a fidelity.
364384 backfill_value: For parameters added to experiments that have already run
@@ -378,6 +398,10 @@ def __init__(
378398 raise UserInputError ("RangeParameter type must be int or float." )
379399 self ._parameter_type = parameter_type
380400 self ._digits = digits
401+ # ``_step_size`` must be set before casting ``lower`` / ``upper`` below,
402+ # since ``cast()`` reads it to snap values to the grid.
403+ self ._step_size : float | None = None
404+ self ._validate_and_set_step_size (step_size = step_size )
381405 self ._lower : TNumeric = self .cast (lower )
382406 self ._upper : TNumeric = self .cast (upper )
383407 self ._log_scale = log_scale
@@ -393,15 +417,32 @@ def __init__(
393417 self .cast (default_value ) if default_value is not None else None
394418 )
395419
420+ # Validate the raw inputs: this rejects invalid user input (e.g. a
421+ # non-integer bound for an INT parameter) before ``cast()`` silently
422+ # truncates it. For the non-deprecated paths ``cast()`` does not move a
423+ # bound that would otherwise pass validation -- FLOAT casting is a no-op
424+ # on the value, and ``step_size`` snapping is skipped for bounds -- so
425+ # validating the raw inputs also guarantees the stored bounds are valid.
396426 self ._validate_range_param (
397427 parameter_type = parameter_type ,
398428 lower = lower ,
399429 upper = upper ,
400430 log_scale = log_scale ,
401431 logit_scale = logit_scale ,
402432 )
433+ # ``upper`` must additionally lie on the ``step_size`` grid (the grid is
434+ # anchored at ``lower``).
435+ self ._validate_step_size_on_grid ()
403436
404437 def cardinality (self ) -> TNumeric :
438+ if self ._step_size is not None :
439+ # Values are snapped to the grid {lower + k * step_size}
440+ # intersected with [lower, upper]. Both bounds lie on the grid
441+ # (enforced at construction), so the number of grid points is
442+ # (upper - lower) / step_size + 1.
443+ step_size = none_throws (self ._step_size )
444+ return round ((float (self .upper ) - float (self .lower )) / step_size ) + 1
445+
405446 if self .parameter_type == ParameterType .FLOAT :
406447 return inf
407448
@@ -493,6 +534,19 @@ def digits(self) -> int | None:
493534 """
494535 return self ._digits
495536
537+ @property
538+ def step_size (self ) -> float | None :
539+ """Grid spacing that values are snapped to in ``cast()``.
540+
541+ If set, the parameter's feasible values are the grid
542+ ``{lower + k * step_size : k in N}`` intersected with ``[lower, upper]``,
543+ and ``cast()`` snaps values to the nearest grid point (without clamping
544+ to the bounds). Both bounds are guaranteed to be on the grid (the
545+ constructor requires ``(upper - lower)`` to be an integer multiple of
546+ ``step_size``). ``None`` means no snapping.
547+ """
548+ return self ._step_size
549+
496550 @property
497551 def log_scale (self ) -> bool :
498552 """Whether the parameter's values should be sampled from log space."""
@@ -519,14 +573,25 @@ def update_range(
519573 if upper is None :
520574 upper = self ._upper
521575
522- cast_lower = self .cast (lower )
523- cast_upper = self .cast (upper )
576+ # When ``step_size`` is set, cast the bounds without snapping to the
577+ # (old) grid: bounds anchor the grid and must not be silently moved onto
578+ # it. ``super().cast()`` applies only the type cast. The digits path
579+ # (deprecated) keeps its historical rounding behavior via ``self.cast``.
580+ if self ._step_size is not None :
581+ cast_lower = assert_is_instance (super ().cast (lower ), TNumeric )
582+ cast_upper = assert_is_instance (super ().cast (upper ), TNumeric )
583+ else :
584+ cast_lower = self .cast (lower )
585+ cast_upper = self .cast (upper )
524586 self ._validate_range_param (
525587 lower = cast_lower ,
526588 upper = cast_upper ,
527589 log_scale = self .log_scale ,
528590 logit_scale = self .logit_scale ,
529591 )
592+ # The new bounds must lie on the ``step_size`` grid, if one is set.
593+ # Validate before committing so a failed update leaves bounds unchanged.
594+ self ._validate_step_size_on_grid (lower = cast_lower , upper = cast_upper )
530595 self ._lower = cast_lower
531596 self ._upper = cast_upper
532597 return self
@@ -546,6 +611,95 @@ def set_digits(self, digits: int | None) -> RangeParameter:
546611 self ._upper = cast_upper
547612 return self
548613
614+ def set_step_size (self , step_size : float | None ) -> RangeParameter :
615+ """Set the grid spacing that values are snapped to in ``cast()``.
616+
617+ The existing bounds are kept as-is (they anchor the grid and define the
618+ feasible range); they are not snapped onto the new grid. Instead we
619+ require that they already lie on it: ``(upper - lower)`` must be an
620+ integer multiple of the new ``step_size``.
621+
622+ Raises:
623+ UserInputError: If the current bounds do not lie on the new grid.
624+ """
625+ previous_step_size = self ._step_size
626+ self ._validate_and_set_step_size (step_size = step_size )
627+ try :
628+ # The current (unchanged) bounds must lie on the new grid.
629+ self ._validate_step_size_on_grid ()
630+ except UserInputError :
631+ # Leave the parameter unchanged if the new grid is invalid.
632+ self ._step_size = previous_step_size
633+ raise
634+ return self
635+
636+ def _validate_and_set_step_size (self , step_size : float | None ) -> None :
637+ """Validate ``step_size`` and store it on ``self._step_size``.
638+
639+ Raises:
640+ UserInputError: If ``step_size`` is non-positive, if it is set
641+ together with ``digits``, or if it is not integer-valued for an
642+ INT parameter.
643+ """
644+ if step_size is None :
645+ self ._step_size = None
646+ return
647+ if self ._digits is not None :
648+ raise UserInputError (
649+ f"Cannot set both `digits` and `step_size` on parameter "
650+ f"{ self ._name } . `digits` is deprecated; use `step_size` only."
651+ )
652+ if step_size <= 0 :
653+ raise UserInputError (
654+ f"`step_size` must be strictly positive for parameter "
655+ f"{ self ._name } . Got: { step_size } ."
656+ )
657+ if (
658+ self ._parameter_type is ParameterType .INT
659+ and not float (step_size ).is_integer ()
660+ ):
661+ raise UserInputError (
662+ f"`step_size` must be integer-valued for INT parameter "
663+ f"{ self ._name } . Got: { step_size } ."
664+ )
665+ self ._step_size = float (step_size )
666+
667+ def _validate_step_size_on_grid (
668+ self , lower : TNumeric | None = None , upper : TNumeric | None = None
669+ ) -> None :
670+ """Validate that both bounds lie on the ``step_size`` grid.
671+
672+ The grid is anchored at ``lower``, so ``lower`` is always on it. This
673+ additionally requires ``upper`` to be on the grid, i.e. that
674+ ``(upper - lower)`` is an integer multiple of ``step_size`` (within
675+ ``EPS``). This guarantees ``upper`` is itself a feasible value, so a
676+ value near the upper bound snaps to ``upper`` rather than to a grid
677+ point short of it.
678+
679+ Args:
680+ lower: Lower bound to validate against. Defaults to ``self._lower``.
681+ upper: Upper bound to validate against. Defaults to ``self._upper``.
682+ These overrides let callers validate prospective bounds before
683+ committing them.
684+
685+ Raises:
686+ UserInputError: If ``upper`` does not lie on the grid.
687+ """
688+ if self ._step_size is None :
689+ return
690+ lower = self ._lower if lower is None else lower
691+ upper = self ._upper if upper is None else upper
692+ step_size = none_throws (self ._step_size )
693+ width = float (upper ) - float (lower )
694+ n = width / step_size
695+ if abs (n - round (n )) * step_size > EPS :
696+ raise UserInputError (
697+ f"`step_size` must evenly divide the range of parameter "
698+ f"{ self ._name } : (upper - lower) = { width } is not an integer "
699+ f"multiple of step_size = { step_size } . Adjust the bounds or "
700+ f"step_size so that both bounds lie on the grid."
701+ )
702+
549703 def set_log_scale (self , log_scale : bool ) -> RangeParameter :
550704 self ._log_scale = log_scale
551705 return self
@@ -647,6 +801,7 @@ def clone(self) -> RangeParameter:
647801 log_scale = self ._log_scale ,
648802 logit_scale = self ._logit_scale ,
649803 digits = self ._digits ,
804+ step_size = self ._step_size ,
650805 is_fidelity = self ._is_fidelity ,
651806 target_value = self ._target_value ,
652807 backfill_value = self ._backfill_value ,
@@ -657,13 +812,50 @@ def cast(self, value: TParamValue) -> TNumeric:
657812 value = super ().cast (value = value )
658813 if self .parameter_type is ParameterType .FLOAT and self ._digits is not None :
659814 return round (float (value ), none_throws (self ._digits ))
815+ # Skip snapping while the constructor is still casting the bounds
816+ # themselves (before both ``self._lower`` and ``self._upper`` are set):
817+ # the bounds anchor the grid and must not be snapped (``upper`` is only
818+ # validated to be on the grid after both are assigned). ``_snap_to_grid``
819+ # needs ``self._lower``; gating on ``self._upper`` too is what excludes
820+ # the ``upper`` cast at construction.
821+ if (
822+ self ._step_size is not None
823+ and getattr (self , "_lower" , None ) is not None
824+ and getattr (self , "_upper" , None ) is not None
825+ ):
826+ value = self ._snap_to_grid (value = float (value ))
660827 return assert_is_instance (value , TNumeric )
661828
829+ def _snap_to_grid (self , value : float ) -> TNumeric :
830+ """Snap ``value`` to the nearest grid point.
831+
832+ The grid is ``{lower + k * step_size : k in Z}``. The nearest grid point
833+ is found by rounding ``(value - lower) / step_size`` to the nearest
834+ integer. The result is *not* clamped to ``[lower, upper]``: an
835+ out-of-bounds input (e.g. historical observations recorded outside the
836+ current bounds) snaps to the nearest grid point, which may itself lie
837+ outside the bounds. This mirrors the non-``step_size`` ``cast()``, which
838+ leaves out-of-bounds values untouched rather than silently moving them
839+ into range -- range validity is enforced by ``validate()``, not by
840+ ``cast()``. For INT parameters the snapped value is integer-valued
841+ (``step_size`` is validated to be an integer), so it is returned as an
842+ ``int``.
843+ """
844+ step_size = none_throws (self ._step_size )
845+ lower = float (self ._lower )
846+ n = round ((value - lower ) / step_size )
847+ snapped = lower + n * step_size
848+ if self .parameter_type is ParameterType .INT :
849+ return int (round (snapped ))
850+ return snapped
851+
662852 def __repr__ (self ) -> str :
663853 ret_val = self ._base_repr ()
664854
665855 if self ._digits is not None :
666856 ret_val += f", digits={ self ._digits } "
857+ if self ._step_size is not None :
858+ ret_val += f", step_size={ self ._step_size } "
667859
668860 return ret_val + ")"
669861
0 commit comments