Skip to content

Commit 6c023bc

Browse files
saitcakmak authored and meta-codesync[bot] committed
Migrate transforms and utils to RangeParameter.step_size
Summary: Teaches the transform/util layer about step_size. Third diff in the step_size unification stack.

- Cast: snaps RangeParameter values to the grid via parameter.cast() on both the observation-features path (already calls cast) and the experiment_data dataframe path (replacing the .round(digits) call), for params with digits OR step_size set.
- Log/Logit/UnitX: clear step_size before rescaling (in addition to clearing digits, until digits is fully removed in a later diff); step_size is re-applied in the original space by Cast on untransform.
- int_to_float: does not forward step_size to the FLOAT surrogate (anchor would be misaligned); the original INT param's snapping is re-applied by Cast. See TODO.
- map_key_to_float, transfer_learning merge_parameters, service instantiation: forward/read step_size, preferring it over digits.
- service instantiation: add step_size to EXPECTED_KEYS_IN_PARAM_REPR. The RangeParameter construction path already read representation["step_size"], but the key was not in the recognized-keys set, so any parameter representation passing step_size would have been rejected with an "Unexpected keys" error. This makes the step_size representation path actually usable.
- core_stubs: add get_range_parameter_with_step_size helper.

Test coverage added for the migrated paths: Log/Logit transform_search_space clears step_size (mirroring the existing UnitX and clears_digits tests), MapKeyToFloat forwards step_size from config, merge_parameters forwards step_size/digits from p1, and parameter_from_json accepts step_size and legacy digits for range parameters.

Differential Revision: D107284896
1 parent 4f19839 commit 6c023bc

16 files changed

Lines changed: 328 additions & 61 deletions

ax/adapter/transfer_learning/tests/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,41 @@ def test_merge_parameters(self) -> None:
254254
name="p", parameter_type=ParameterType.INT, lower=0, upper=3
255255
),
256256
)
257+
# The grid spec (step_size, or legacy digits) is forwarded from p1, even
258+
# when the merged bounds are widened by p2 (as long as they stay on p1's
259+
# grid).
260+
p_range_step = RangeParameter(
261+
name="p",
262+
parameter_type=ParameterType.FLOAT,
263+
lower=0.0,
264+
upper=1.0,
265+
step_size=0.1,
266+
)
267+
p_range_wide = RangeParameter(
268+
name="p", parameter_type=ParameterType.FLOAT, lower=0.0, upper=2.0
269+
)
270+
merged_step = assert_is_instance(
271+
merge_parameters(p1=p_range_step, p2=p_range_wide, reverse_param_config={}),
272+
RangeParameter,
273+
)
274+
self.assertEqual(merged_step.upper, 2.0)
275+
self.assertEqual(merged_step.step_size, 0.1)
276+
self.assertIsNone(merged_step.digits)
277+
p_range_digits = RangeParameter(
278+
name="p",
279+
parameter_type=ParameterType.FLOAT,
280+
lower=0.0,
281+
upper=1.0,
282+
digits=2,
283+
)
284+
merged_digits = assert_is_instance(
285+
merge_parameters(
286+
p1=p_range_digits, p2=p_range_wide, reverse_param_config={}
287+
),
288+
RangeParameter,
289+
)
290+
self.assertEqual(merged_digits.digits, 2)
291+
self.assertIsNone(merged_digits.step_size)
257292
p_choice_1 = ChoiceParameter(
258293
name="p",
259294
parameter_type=ParameterType.STRING,

ax/adapter/transfer_learning/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def merge_parameters(
111111
log_scale=p1.log_scale,
112112
logit_scale=p1.logit_scale,
113113
digits=p1.digits,
114+
step_size=p1.step_size,
114115
is_fidelity=p1.is_fidelity,
115116
target_value=p1.target_value,
116117
)

ax/adapter/transforms/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
9494
transform (does nothing).
9595
9696
NOTE for subclasses: If a transform changes the *scale* of a
97-
RangeParameter (e.g., Log, UnitX, Logit), it must clear ``digits``
98-
via ``p.set_digits(digits=None)`` before calling ``update_range``.
99-
Otherwise, rounding calibrated for the original scale will corrupt
100-
the transformed bounds (e.g., ``digits=-3`` rounds to the nearest
101-
1000, which collapses [0, 1] to 0). The Cast transform re-applies
102-
``digits`` in the original space during untransform.
97+
RangeParameter (e.g., Log, UnitX, Logit), it must clear ``step_size``
98+
via ``p.set_step_size(step_size=None)`` before calling ``update_range``.
99+
Otherwise, snapping calibrated for the original scale will corrupt the
100+
transformed bounds (a grid spacing meaningful in the original space is
101+
meaningless after a non-linear rescale). The Cast transform re-applies
102+
``step_size`` in the original space during untransform.
103103
104104
Args:
105105
search_space: The search space

ax/adapter/transforms/cast.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ax.exceptions.core import UserInputError
2626
from ax.generators.types import TConfig
2727
from ax.utils.common.constants import Keys
28-
from pandas import DataFrame
28+
from pandas import DataFrame, Series
2929
from pyre_extensions import assert_is_instance, none_throws
3030

3131
if TYPE_CHECKING:
@@ -314,11 +314,38 @@ def transform_experiment_data(
314314
for p, param in self.search_space.parameters.items()
315315
}
316316
arm_data = arm_data.astype(dtype=column_to_type)
317-
# Round to digits if any parameter specifies it.
317+
# Snap to the parameter's grid (digits or step_size) if specified.
318+
# These mirror ``RangeParameter.cast``'s rounding logic, but are applied
319+
# in a vectorized manner over the whole column rather than via a per-row
320+
# ``Series.apply`` (which calls ``parameter.cast`` once per element and is
321+
# slow for large DataFrames). NaN / ``<NA>`` values (added for missing
322+
# columns during the ``reindex`` above) propagate through ``round`` and
323+
# the arithmetic, matching the previous ``value if value is None`` guard.
318324
for p_name in parameter_names:
319325
parameter = self.search_space.parameters[p_name]
320-
if isinstance(parameter, RangeParameter) and parameter.digits is not None:
321-
arm_data[p_name] = arm_data[p_name].round(parameter.digits)
326+
if not isinstance(parameter, RangeParameter):
327+
continue
328+
column: Series = arm_data[p_name]
329+
if (
330+
parameter.parameter_type is ParameterType.FLOAT
331+
and parameter.digits is not None
332+
):
333+
# ``Series.round`` uses round-half-to-even, same as Python's
334+
# built-in ``round`` used in ``RangeParameter.cast``.
335+
arm_data[p_name] = column.round(parameter.digits)
336+
elif parameter.step_size is not None:
337+
# Snap to the grid ``{lower + k * step_size : k in Z}`` by
338+
# rounding ``(value - lower) / step_size`` to the nearest integer.
339+
lower = float(parameter.lower)
340+
step_size = none_throws(parameter.step_size)
341+
steps: Series = column.sub(lower).div(step_size).round()
342+
snapped: Series = steps.mul(step_size).add(lower)
343+
if parameter.parameter_type is ParameterType.INT:
344+
# Preserve the nullable ``Int64`` dtype so reindex-added
345+
# ``<NA>`` values survive the cast.
346+
arm_data[p_name] = snapped.round().astype("Int64")
347+
else:
348+
arm_data[p_name] = snapped
322349

323350
return ExperimentData(arm_data=arm_data, observation_data=observation_data)
324351

ax/adapter/transforms/log.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,12 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
7979
isinstance(p, RangeParameter)
8080
and p.parameter_type == ParameterType.FLOAT
8181
):
82-
# Don't round in log space
82+
# Don't snap/round in log space; step_size (or legacy
83+
# digits) will be re-applied in the original space by the
84+
# Cast transform during untransform. Both are cleared until
85+
# digits is fully removed (see step_size unification RFC).
86+
if p.step_size is not None:
87+
p.set_step_size(step_size=None)
8388
if p.digits is not None:
8489
p.set_digits(digits=None)
8590
p.set_log_scale(False).update_range(

ax/adapter/transforms/logit.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,12 @@ def transform_observation_features(
6666
def transform_search_space(self, search_space: SearchSpace) -> SearchSpace:
6767
for p_name, p in search_space.parameters.items():
6868
if p_name in self.transform_parameters and isinstance(p, RangeParameter):
69-
# Don't round in logit space; digits will be re-applied in
70-
# the original space by the Cast transform during untransform.
69+
# Don't snap/round in logit space; step_size (or legacy digits)
70+
# will be re-applied in the original space by the Cast transform
71+
# during untransform. Both are cleared until digits is fully
72+
# removed (see step_size unification RFC).
73+
if p.step_size is not None:
74+
p.set_step_size(step_size=None)
7175
if p.digits is not None:
7276
p.set_digits(digits=None)
7377
p.set_logit_scale(False).update_range(

ax/adapter/transforms/map_key_to_float.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,13 +127,18 @@ def __init__(
127127
return
128128

129129
p_config = self.parameters[MAP_KEY]
130+
# Prefer ``step_size``; fall back to legacy ``digits``. Only one may
131+
# be passed to the constructor (it rejects both being set).
132+
step_size = p_config.get("step_size", None)
133+
digits = None if step_size is not None else p_config.get("digits", None)
130134
self._parameter_list.append(
131135
RangeParameter(
132136
name=MAP_KEY,
133137
parameter_type=ParameterType.FLOAT,
134138
lower=p_config.get("lower", min(values)),
135139
upper=p_config.get("upper", max(values)),
136-
digits=p_config.get("digits", None),
140+
digits=digits,
141+
step_size=step_size,
137142
is_fidelity=p_config.get("is_fidelity", False),
138143
target_value=p_config.get("target_value", None),
139144
)

ax/adapter/transforms/tests/test_cast_transform.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,77 @@ def test_cast_parameter_type_and_none(self) -> None:
323323
]
324324
self.assertEqual(tf_observations, expected)
325325

326+
def test_cast_step_size_observation_features(self) -> None:
327+
# Cast snaps RangeParameter values to step_size on (un)transform, just
328+
# as it rounds to digits.
329+
search_space = SearchSpace(
330+
parameters=[
331+
RangeParameter(
332+
name="range",
333+
parameter_type=ParameterType.FLOAT,
334+
lower=0.0,
335+
upper=1.0,
336+
step_size=0.1,
337+
),
338+
]
339+
)
340+
t = Cast(search_space=search_space)
341+
obs_features = [
342+
ObservationFeatures(parameters={"range": 0.12}),
343+
ObservationFeatures(parameters={"range": 0.36}),
344+
]
345+
tf_obs_features = t.transform_observation_features(
346+
observation_features=obs_features
347+
)
348+
self.assertAlmostEqual(
349+
float(none_throws(tf_obs_features[0].parameters["range"])), 0.1
350+
)
351+
self.assertAlmostEqual(
352+
float(none_throws(tf_obs_features[1].parameters["range"])), 0.4
353+
)
354+
355+
def test_transform_experiment_data_step_size(self) -> None:
356+
# The experiment_data dataframe path snaps RangeParameter values to
357+
# step_size, for both FLOAT and INT parameters. The INT parameter also
358+
# checks that the snapped column keeps the nullable Int64 dtype.
359+
experiment = get_experiment_with_observations(
360+
observations=[[0.0], [1.0]],
361+
search_space=SearchSpace(
362+
parameters=[
363+
RangeParameter(
364+
name="x",
365+
parameter_type=ParameterType.FLOAT,
366+
lower=0.0,
367+
upper=1.0,
368+
step_size=0.1,
369+
),
370+
RangeParameter(
371+
name="y",
372+
parameter_type=ParameterType.INT,
373+
lower=0,
374+
upper=10,
375+
step_size=2,
376+
),
377+
]
378+
),
379+
parameterizations=[
380+
{"x": 0.12, "y": 3},
381+
{"x": 0.36, "y": 7},
382+
],
383+
)
384+
experiment_data = extract_experiment_data(
385+
experiment=experiment, data_loader_config=DataLoaderConfig()
386+
)
387+
transformed = Cast(
388+
search_space=experiment.search_space
389+
).transform_experiment_data(experiment_data=deepcopy(experiment_data))
390+
self.assertAlmostEqual(transformed.arm_data["x"].iloc[0], 0.1)
391+
self.assertAlmostEqual(transformed.arm_data["x"].iloc[1], 0.4)
392+
# 3 snaps to 4 (round half to even: 1.5 -> 2 steps), 7 snaps to 8.
393+
self.assertEqual(transformed.arm_data["y"].iloc[0], 4)
394+
self.assertEqual(transformed.arm_data["y"].iloc[1], 8)
395+
self.assertEqual(transformed.arm_data["y"].dtype, "Int64")
396+
326397
def test_transform_experiment_data_flatten(self) -> None:
327398
# Tests for flattening of hierarchical parameterizations.
328399
columns = [

ax/adapter/transforms/tests/test_log_transform.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,36 @@ def test_TransformSearchSpace(self) -> None:
117117
ss2 = deepcopy(self.search_space)
118118
ss2 = self.t.transform_search_space(ss2)
119119

120-
# Test float log-scale parameter transformation
120+
# Test float log-scale parameter transformation. The grid (legacy
121+
# ``digits`` here; ``step_size`` covered below) must be cleared during
122+
# the transform -- a grid meaningful in the original space is
123+
# meaningless after a log10 rescale, and Cast re-applies it in the
124+
# original space on untransform.
121125
param_x = assert_is_instance(ss2.parameters["x"], RangeParameter)
122126
self.assertEqual(param_x.lower, math.log10(1))
123127
self.assertEqual(param_x.upper, math.log10(3))
124128
self.assertIsNone(param_x.digits)
125129

130+
# Same clearing behavior for ``step_size`` (mutually exclusive with
131+
# ``digits``, so it needs its own parameter).
132+
ss_step = SearchSpace(
133+
parameters=[
134+
RangeParameter(
135+
"x",
136+
lower=1.0,
137+
upper=1000.0,
138+
parameter_type=ParameterType.FLOAT,
139+
log_scale=True,
140+
step_size=1.0,
141+
),
142+
]
143+
)
144+
ss_step = Log(search_space=ss_step).transform_search_space(ss_step)
145+
param_x_step = assert_is_instance(ss_step.parameters["x"], RangeParameter)
146+
self.assertIsNone(param_x_step.step_size)
147+
self.assertEqual(param_x_step.lower, math.log10(1.0))
148+
self.assertEqual(param_x_step.upper, math.log10(1000.0))
149+
126150
# Test integer log-scale parameter transformation (converted to ChoiceParameter)
127151
param_y = assert_is_instance(ss2.parameters["y"], ChoiceParameter)
128152
self.assertEqual(param_y.parameter_type, ParameterType.FLOAT)

ax/adapter/transforms/tests/test_logit_transform.py

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -122,27 +122,38 @@ def test_TransformSearchSpace(self) -> None:
122122
self.assertEqual(x_param.lower, logit(0.1))
123123
self.assertEqual(x_param.upper, logit(0.3))
124124

125-
def test_transform_search_space_clears_digits(self) -> None:
126-
"""Test that digits is cleared during transform to avoid rounding
127-
in logit space."""
128-
ss = SearchSpace(
129-
parameters=[
130-
RangeParameter(
131-
"x",
132-
lower=0.1,
133-
upper=0.9,
134-
parameter_type=ParameterType.FLOAT,
135-
logit_scale=True,
136-
digits=3,
137-
),
138-
]
139-
)
140-
t = Logit(search_space=ss)
141-
ss = t.transform_search_space(ss)
142-
x = assert_is_instance(ss.parameters["x"], RangeParameter)
143-
self.assertIsNone(x.digits)
144-
self.assertAlmostEqual(x.lower, logit(0.1))
145-
self.assertAlmostEqual(x.upper, logit(0.9))
125+
def test_transform_search_space_clears_grid(self) -> None:
126+
"""The grid (legacy ``digits`` or ``step_size``, which are mutually
127+
exclusive) must be cleared during the transform to avoid rounding /
128+
snapping in logit space; it is re-applied in the original space by Cast
129+
on untransform."""
130+
grid_params = [
131+
RangeParameter(
132+
"x",
133+
lower=0.1,
134+
upper=0.9,
135+
parameter_type=ParameterType.FLOAT,
136+
logit_scale=True,
137+
digits=3,
138+
),
139+
RangeParameter(
140+
"x",
141+
lower=0.1,
142+
upper=0.9,
143+
parameter_type=ParameterType.FLOAT,
144+
logit_scale=True,
145+
step_size=0.1,
146+
),
147+
]
148+
for param in grid_params:
149+
with self.subTest(param=param):
150+
ss = SearchSpace(parameters=[param])
151+
ss = Logit(search_space=ss).transform_search_space(ss)
152+
x = assert_is_instance(ss.parameters["x"], RangeParameter)
153+
self.assertIsNone(x.digits)
154+
self.assertIsNone(x.step_size)
155+
self.assertAlmostEqual(x.lower, logit(0.1))
156+
self.assertAlmostEqual(x.upper, logit(0.9))
146157

147158
def test_transform_experiment_data(self) -> None:
148159
parameterizations = [

0 commit comments

Comments
 (0)