Skip to content

Commit 4f19839

Browse files
saitcakmakmeta-codesync[bot]
authored andcommitted
Add storage support for RangeParameter.step_size (JSON + SQA)
Summary: Teaches both storage backends to persist and restore the new `step_size` field on `RangeParameter`. Second diff in the step_size unification stack (builds on the native `step_size` support added in D107274057). The underlying SQA column (`SQAParameter.step_size`) and the corresponding EntAE schema were added and landed separately (D107108021, D107112306), so this diff is pure encoder/decoder logic — it does not touch the schema. SQA: the encoder writes `parameter.step_size` to the column (and continues writing the legacy `digits` column for now; `digits` is dropped in a later diff). The decoder prefers `step_size` and falls back to the legacy `digits` column for rows written by older code, passing only one to the constructor (which rejects both being set; it converts `digits` to `step_size` internally). JSON: `range_parameter_to_dict` emits `step_size` alongside `digits`. There is no custom JSON decoder for `RangeParameter` — the registry maps it directly to the class, and `ax_class_from_json_dict` splats the serialized dict straight into the constructor. So `step_size` flows through automatically, and legacy blobs carrying only `digits` still decode via the constructor's back-compat handling. No decoder change is required for JSON. Differential Revision: D107282941
1 parent f786271 commit 4f19839

5 files changed

Lines changed: 84 additions & 2 deletions

File tree

ax/storage/json_store/encoders.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def range_parameter_to_dict(parameter: RangeParameter) -> dict[str, Any]:
201201
"log_scale": parameter.log_scale,
202202
"logit_scale": parameter.logit_scale,
203203
"digits": parameter.digits,
204+
"step_size": parameter.step_size,
204205
"is_fidelity": parameter.is_fidelity,
205206
"target_value": parameter.target_value,
206207
}

ax/storage/json_store/tests/test_json_store.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
OptimizationConfig,
5656
PreferenceOptimizationConfig,
5757
)
58-
from ax.core.parameter import ChoiceParameter, ParameterType
58+
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
5959
from ax.core.parameter_constraint import ParameterConstraint
6060
from ax.core.runner import Runner
6161
from ax.exceptions.core import AxStorageWarning, UnsupportedError
@@ -398,6 +398,17 @@
398398
("ParameterConstraint", get_parameter_constraint),
399399
("ParameterConstraint", get_equality_parameter_constraint),
400400
("RangeParameter", get_range_parameter),
401+
(
402+
"RangeParameter",
403+
partial(
404+
RangeParameter,
405+
name="x",
406+
parameter_type=ParameterType.FLOAT,
407+
lower=0.0,
408+
upper=1.0,
409+
step_size=0.1,
410+
),
411+
),
401412
("ScalarizedObjective", get_scalarized_objective),
402413
("ScalarizedOutcomeConstraint", get_scalarized_outcome_constraint),
403414
("OrchestratorOptions", get_default_orchestrator_options),
@@ -1942,6 +1953,32 @@ def test_multi_objective_from_json_warning(self) -> None:
19421953
any("Found unexpected kwargs" in warning for warning in cm.output)
19431954
)
19441955

1956+
def test_range_parameter_legacy_digits_blob_decodes(self) -> None:
1957+
# A legacy blob has "digits" but no "step_size" key. It must still
1958+
# decode (the constructor accepts digits for back-compat).
1959+
legacy_blob = {
1960+
"__type": "RangeParameter",
1961+
"name": "x",
1962+
"parameter_type": {"__type": "ParameterType", "name": "FLOAT"},
1963+
"lower": 0.0,
1964+
"upper": 1.0,
1965+
"log_scale": False,
1966+
"logit_scale": False,
1967+
"digits": 2,
1968+
"is_fidelity": False,
1969+
"target_value": None,
1970+
}
1971+
decoded = object_from_json(
1972+
legacy_blob,
1973+
decoder_registry=CORE_DECODER_REGISTRY,
1974+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
1975+
)
1976+
self.assertIsInstance(decoded, RangeParameter)
1977+
self.assertEqual(decoded.digits, 2)
1978+
self.assertIsNone(decoded.step_size)
1979+
# Rounding behavior from digits=2 is preserved.
1980+
self.assertEqual(decoded.cast(0.123), 0.12)
1981+
19451982
def test_choice_parameter_bypass_cardinality_check_encode_failure(self) -> None:
19461983
choice_parameter = ChoiceParameter(
19471984
name="test_choice",

ax/storage/sqa_store/decoder.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,13 +470,21 @@ def parameter_from_sqa(self, parameter_sqa: SQAParameter) -> Parameter:
470470
"`dependents` unexpectedly non-null on range parameter "
471471
f"{parameter_sqa.name}."
472472
)
473+
# Prefer the newer ``step_size`` column; fall back to the legacy
474+
# ``digits`` column for rows written by older code. Only one may be
475+
# passed to the constructor (it rejects both being set). The
476+
# constructor converts ``digits`` to ``step_size`` internally.
477+
# We should never have the two together in the DB, so this is extra.
478+
step_size = parameter_sqa.step_size
479+
digits = None if step_size is not None else parameter_sqa.digits
473480
parameter = RangeParameter(
474481
name=parameter_sqa.name,
475482
parameter_type=parameter_sqa.parameter_type,
476483
lower=float(none_throws(parameter_sqa.lower)),
477484
upper=float(none_throws(parameter_sqa.upper)),
478485
log_scale=parameter_sqa.log_scale or False,
479-
digits=parameter_sqa.digits,
486+
digits=digits,
487+
step_size=float(step_size) if step_size is not None else None,
480488
is_fidelity=parameter_sqa.is_fidelity or False,
481489
target_value=parameter_sqa.target_value,
482490
backfill_value=parameter_sqa.backfill_value,

ax/storage/sqa_store/encoder.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ def parameter_to_sqa(self, parameter: Parameter) -> SQAParameter:
308308
upper=float(parameter.upper),
309309
log_scale=parameter.log_scale,
310310
digits=parameter.digits,
311+
step_size=parameter.step_size,
311312
is_fidelity=parameter.is_fidelity,
312313
target_value=parameter.target_value,
313314
dependents=parameter.dependents if parameter.is_hierarchical else None,

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,6 +1880,41 @@ def test_logit_scale(self) -> None:
18801880
)
18811881
)
18821882

1883+
def test_step_size_round_trip(self) -> None:
1884+
parameter = RangeParameter(
1885+
name="foo",
1886+
parameter_type=ParameterType.FLOAT,
1887+
lower=0.0,
1888+
upper=1.0,
1889+
step_size=0.1,
1890+
)
1891+
sqa_parameter = self.encoder.parameter_to_sqa(parameter)
1892+
self.assertEqual(sqa_parameter.step_size, 0.1)
1893+
self.assertIsNone(sqa_parameter.digits)
1894+
decoded = self.decoder.parameter_from_sqa(sqa_parameter)
1895+
self.assertEqual(decoded, parameter)
1896+
self.assertEqual(assert_is_instance(decoded, RangeParameter).step_size, 0.1)
1897+
1898+
def test_legacy_digits_row_decodes_via_fallback(self) -> None:
1899+
# A legacy row carries ``digits`` but no ``step_size``. The decoder
1900+
# falls back to ``digits``, which the constructor converts internally.
1901+
parameter = RangeParameter(
1902+
name="foo",
1903+
parameter_type=ParameterType.FLOAT,
1904+
lower=0.0,
1905+
upper=1.0,
1906+
digits=2,
1907+
)
1908+
sqa_parameter = self.encoder.parameter_to_sqa(parameter)
1909+
# Simulate an old row: clear step_size, keep digits populated.
1910+
sqa_parameter.step_size = None
1911+
sqa_parameter.digits = 2
1912+
decoded = assert_is_instance(
1913+
self.decoder.parameter_from_sqa(sqa_parameter), RangeParameter
1914+
)
1915+
# The decoded parameter round-trips its rounding behavior.
1916+
self.assertEqual(decoded.cast(0.123), parameter.cast(0.123))
1917+
18831918
def test_bypass_cardinality_check(self) -> None:
18841919
choice_parameter = ChoiceParameter(
18851920
name="test_choice",

0 commit comments

Comments
 (0)