Skip to content

Commit 56711df

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
cast arms' parameter values when decoding from json/sqa (facebook#4853)
Summary: Pull Request resolved: facebook#4853 See title. This casts arms' parameter values to the proper type when decoding from json/sqa. This fixes an issue where SQ specified via the UI can result in Float parameters being loaded as ints from SQA. This led to issues with imported QEs, where the imported SQ parameter had a value of 1.0, but the signature of the new SQ did not match the signature of the existing SQ value of 1. Reviewed By: Cesar-Cardoso Differential Revision: D92182397 fbshipit-source-id: 80bc928d43d93306acce268d23626fcec66096ff
1 parent 904aa7d commit 56711df

5 files changed

Lines changed: 284 additions & 0 deletions

File tree

ax/storage/json_store/decoder.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pandas as pd
2323
import torch
2424
from ax.adapter.registry import GeneratorRegistryBase
25+
from ax.core.arm import Arm
2526
from ax.core.base_trial import BaseTrial
2627
from ax.core.data import Data
2728
from ax.core.experiment import Experiment
@@ -47,6 +48,7 @@
4748
from ax.generators.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec
4849
from ax.generators.torch.botorch_modular.utils import ModelConfig
4950
from ax.storage.json_store.decoders import (
51+
_cast_parameter_value,
5052
batch_trial_from_json,
5153
botorch_component_from_json,
5254
tensor_from_json,
@@ -72,6 +74,24 @@
7274
logger: Logger = get_logger(__name__)
7375

7476

77+
def _cast_arm_parameters(arm: Arm, search_space: SearchSpace) -> None:
78+
"""Cast arm parameter values to the appropriate Python type.
79+
80+
This is necessary because JSON may deserialize values as different types
81+
(e.g., ints as floats). This function modifies the arm in place.
82+
83+
Args:
84+
arm: The arm whose parameter values should be cast.
85+
search_space: The search space containing parameter type information.
86+
"""
87+
for param_name, param_value in arm._parameters.items():
88+
if param_name in search_space.parameters:
89+
parameter = search_space.parameters[param_name]
90+
arm._parameters[param_name] = _cast_parameter_value(
91+
param_value, parameter.parameter_type
92+
)
93+
94+
7595
def _raise_on_legacy_callable_refs(kwarg_dict: dict[str, Any]) -> dict[str, Any]:
7696
"""Returns kwarg_dict unchanged if no legacy callable refs are present.
7797
@@ -759,11 +779,17 @@ def _load_experiment_info(
759779
)
760780
for trial in exp._trials.values():
761781
for arm in trial.arms:
782+
# Cast arm parameter values to the appropriate type based on the
783+
# search space parameter types. This is necessary because JSON may
784+
# deserialize values as different types (e.g., ints as floats).
785+
_cast_arm_parameters(arm, exp.search_space)
762786
exp._register_arm(arm)
763787
if trial.ttl_seconds is not None:
764788
exp._trials_have_ttl = True
765789
if exp.status_quo is not None:
766790
sq = none_throws(exp.status_quo)
791+
# Cast status_quo arm parameter values as well.
792+
_cast_arm_parameters(sq, exp.search_space)
767793
exp._register_arm(sq)
768794

769795

ax/storage/json_store/decoders.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from ax.core.parameter import (
2727
ChoiceParameter,
2828
FixedParameter,
29+
PARAMETER_PYTHON_TYPE_MAP,
2930
ParameterType,
3031
TParamValue,
3132
)
@@ -93,6 +94,27 @@ def string_to_parameter_value(s: str, parameter_type: ParameterType) -> TParamVa
9394
return s
9495

9596

97+
def _cast_parameter_value(
98+
value: TParamValue, parameter_type: ParameterType
99+
) -> TParamValue:
100+
"""Cast a parameter value to the appropriate Python type based on parameter_type.
101+
102+
This is necessary because JSON may deserialize values as different types
103+
(e.g., ints as floats).
104+
105+
Args:
106+
value: The value to cast.
107+
parameter_type: The ParameterType to cast to.
108+
109+
Returns:
110+
The value cast to the appropriate Python type.
111+
"""
112+
if value is None:
113+
return None
114+
python_type = PARAMETER_PYTHON_TYPE_MAP[parameter_type]
115+
return python_type(value)
116+
117+
96118
def batch_trial_from_json(
97119
experiment: core.experiment.Experiment,
98120
index: int,

ax/storage/json_store/tests/test_json_store.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,6 +1725,145 @@ def test_choice_parameter_backward_compatibility_sort_values(self) -> None:
17251725
)
17261726
)
17271727

1728+
def test_arm_parameter_values_cast_to_parameter_type(self) -> None:
1729+
"""Test that arm parameter values are cast to the appropriate type on load."""
1730+
from ax.core.arm import Arm
1731+
from ax.core.experiment import Experiment
1732+
from ax.core.parameter import RangeParameter
1733+
from ax.core.search_space import SearchSpace
1734+
from ax.storage.json_store.encoder import object_to_json
1735+
1736+
# Create an experiment with INT parameters
1737+
search_space = SearchSpace(
1738+
parameters=[
1739+
RangeParameter(
1740+
name="x",
1741+
parameter_type=ParameterType.INT,
1742+
lower=0,
1743+
upper=10,
1744+
),
1745+
RangeParameter(
1746+
name="y",
1747+
parameter_type=ParameterType.FLOAT,
1748+
lower=0.0,
1749+
upper=1.0,
1750+
),
1751+
]
1752+
)
1753+
1754+
experiment = Experiment(
1755+
name="test_experiment",
1756+
search_space=search_space,
1757+
status_quo=Arm(parameters={"x": 5, "y": 0.5}, name="status_quo"),
1758+
)
1759+
1760+
# Add a trial with an arm
1761+
trial = experiment.new_trial()
1762+
trial.add_arm(Arm(parameters={"x": 3, "y": 0.3}))
1763+
1764+
# Encode the experiment to JSON
1765+
experiment_json = object_to_json(
1766+
experiment,
1767+
encoder_registry=CORE_ENCODER_REGISTRY,
1768+
class_encoder_registry=CORE_CLASS_ENCODER_REGISTRY,
1769+
)
1770+
1771+
# Manually modify the JSON to simulate float values for INT parameters
1772+
# (as could happen when loading from external sources)
1773+
for arm_json in experiment_json["trials"][0]["generator_run"]["arms"]:
1774+
arm_json["parameters"]["x"] = 3.0 # float instead of int
1775+
experiment_json["status_quo"]["parameters"]["x"] = 5.0 # float instead of int
1776+
1777+
# Decode the experiment from JSON
1778+
loaded_experiment = object_from_json(
1779+
experiment_json,
1780+
decoder_registry=CORE_DECODER_REGISTRY,
1781+
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
1782+
)
1783+
1784+
# Check that arm parameter values are cast to the correct type
1785+
loaded_arm = list(loaded_experiment.trials[0].arms)[0]
1786+
self.assertEqual(loaded_arm.parameters["x"], 3)
1787+
self.assertIs(type(loaded_arm.parameters["x"]), int)
1788+
self.assertEqual(loaded_arm.parameters["y"], 0.3)
1789+
self.assertIs(type(loaded_arm.parameters["y"]), float)
1790+
1791+
# Check that status_quo parameter values are cast to the correct type
1792+
status_quo = loaded_experiment.status_quo
1793+
self.assertIsNotNone(status_quo)
1794+
self.assertEqual(status_quo.parameters["x"], 5)
1795+
self.assertIs(type(status_quo.parameters["x"]), int)
1796+
self.assertEqual(status_quo.parameters["y"], 0.5)
1797+
self.assertIs(type(status_quo.parameters["y"]), float)
1798+
1799+
def test_cast_parameter_value_all_types(self) -> None:
1800+
"""Test _cast_parameter_value handles all parameter types correctly."""
1801+
from ax.storage.json_store.decoders import _cast_parameter_value
1802+
1803+
# Test INT casting
1804+
self.assertEqual(_cast_parameter_value(3.0, ParameterType.INT), 3)
1805+
self.assertIs(type(_cast_parameter_value(3.0, ParameterType.INT)), int)
1806+
self.assertEqual(_cast_parameter_value(3, ParameterType.INT), 3)
1807+
self.assertIs(type(_cast_parameter_value(3, ParameterType.INT)), int)
1808+
1809+
# Test FLOAT casting
1810+
self.assertEqual(_cast_parameter_value(3, ParameterType.FLOAT), 3.0)
1811+
self.assertIs(type(_cast_parameter_value(3, ParameterType.FLOAT)), float)
1812+
self.assertEqual(_cast_parameter_value(3.5, ParameterType.FLOAT), 3.5)
1813+
self.assertIs(type(_cast_parameter_value(3.5, ParameterType.FLOAT)), float)
1814+
1815+
# Test BOOL casting
1816+
self.assertEqual(_cast_parameter_value(1, ParameterType.BOOL), True)
1817+
self.assertIs(type(_cast_parameter_value(1, ParameterType.BOOL)), bool)
1818+
self.assertEqual(_cast_parameter_value(0, ParameterType.BOOL), False)
1819+
self.assertIs(type(_cast_parameter_value(0, ParameterType.BOOL)), bool)
1820+
self.assertEqual(_cast_parameter_value(True, ParameterType.BOOL), True)
1821+
self.assertIs(type(_cast_parameter_value(True, ParameterType.BOOL)), bool)
1822+
1823+
# Test STRING casting
1824+
self.assertEqual(_cast_parameter_value("test", ParameterType.STRING), "test")
1825+
self.assertIs(type(_cast_parameter_value("test", ParameterType.STRING)), str)
1826+
self.assertEqual(_cast_parameter_value(123, ParameterType.STRING), "123")
1827+
self.assertIs(type(_cast_parameter_value(123, ParameterType.STRING)), str)
1828+
1829+
# Test None handling
1830+
self.assertIsNone(_cast_parameter_value(None, ParameterType.INT))
1831+
self.assertIsNone(_cast_parameter_value(None, ParameterType.FLOAT))
1832+
self.assertIsNone(_cast_parameter_value(None, ParameterType.BOOL))
1833+
self.assertIsNone(_cast_parameter_value(None, ParameterType.STRING))
1834+
1835+
def test_cast_arm_parameters_skips_unknown_params(self) -> None:
1836+
"""Test that _cast_arm_parameters skips parameters not in search space."""
1837+
from ax.core.arm import Arm
1838+
from ax.core.parameter import RangeParameter
1839+
from ax.core.search_space import SearchSpace
1840+
from ax.storage.json_store.decoder import _cast_arm_parameters
1841+
1842+
search_space = SearchSpace(
1843+
parameters=[
1844+
RangeParameter(
1845+
name="x",
1846+
parameter_type=ParameterType.INT,
1847+
lower=0,
1848+
upper=10,
1849+
),
1850+
]
1851+
)
1852+
1853+
# Create an arm with a parameter that's not in the search space
1854+
arm = Arm(parameters={"x": 3.0, "unknown_param": "some_value"})
1855+
1856+
# Cast should work without error and only cast known parameters
1857+
_cast_arm_parameters(arm, search_space)
1858+
1859+
# x should be cast to int
1860+
self.assertEqual(arm.parameters["x"], 3)
1861+
self.assertIs(type(arm.parameters["x"]), int)
1862+
1863+
# unknown_param should remain unchanged
1864+
self.assertEqual(arm.parameters["unknown_param"], "some_value")
1865+
self.assertIs(type(arm.parameters["unknown_param"]), str)
1866+
17281867
def test_surrogate_spec_backwards_compatibility(self) -> None:
17291868
# This is an invalid example that has both deprecated args
17301869
# and model config specified. Deprecated args will be ignored.

ax/storage/sqa_store/decoder.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from ax.exceptions.storage import JSONDecodeError, SQADecodeError
5858
from ax.generation_strategy.generation_strategy import GenerationStrategy
5959
from ax.storage.json_store.decoder import _DEPRECATED_GENERATOR_KWARGS, object_from_json
60+
from ax.storage.json_store.decoders import _cast_parameter_value
6061
from ax.storage.sqa_store.db import session_scope
6162
from ax.storage.sqa_store.sqa_classes import (
6263
SQAAbandonedArm,
@@ -90,6 +91,24 @@
9091
logger: Logger = get_logger(__name__)
9192

9293

94+
def _cast_arm_parameters(arm: Arm, search_space: SearchSpace) -> None:
95+
"""Cast arm parameter values to the appropriate Python type.
96+
97+
This is necessary because SQA may deserialize values as different types
98+
(e.g., ints as floats). This function modifies the arm in place.
99+
100+
Args:
101+
arm: The arm whose parameter values should be cast.
102+
search_space: The search space containing parameter type information.
103+
"""
104+
for param_name, param_value in arm._parameters.items():
105+
if param_name in search_space.parameters:
106+
parameter = search_space.parameters[param_name]
107+
arm._parameters[param_name] = _cast_parameter_value(
108+
param_value, parameter.parameter_type
109+
)
110+
111+
93112
class Decoder:
94113
"""Class that contains methods for loading an Ax experiment from SQLAlchemy.
95114
@@ -378,9 +397,15 @@ def experiment_from_sqa(
378397
if trial.ttl_seconds is not None:
379398
experiment._trials_have_ttl = True
380399
for arm in trial.arms:
400+
# Cast arm parameter values to the appropriate type based on the
401+
# search space parameter types. This is necessary because SQA may
402+
# deserialize values as different types (e.g., ints as floats).
403+
_cast_arm_parameters(arm, experiment.search_space)
381404
experiment._register_arm(arm)
382405
if experiment.status_quo is not None:
383406
sq = none_throws(experiment.status_quo)
407+
# Cast status_quo arm parameter values as well.
408+
_cast_arm_parameters(sq, experiment.search_space)
384409
experiment._register_arm(sq)
385410
experiment._time_created = experiment_sqa.time_created
386411
experiment._experiment_type = self.get_enum_name(

ax/storage/sqa_store/tests/test_sqa_store.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ax.core.outcome_constraint import OutcomeConstraint, ScalarizedOutcomeConstraint
4444
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
4545
from ax.core.runner import Runner
46+
from ax.core.search_space import SearchSpace
4647
from ax.core.trial import Trial
4748
from ax.core.trial_status import TrialStatus
4849
from ax.core.types import ComparisonOp
@@ -350,6 +351,77 @@ def test_experiment_save_load(self) -> None:
350351
loaded_experiment = load_experiment(exp.name)
351352
self.assertEqual(loaded_experiment, exp)
352353

354+
def test_arm_parameter_values_cast_to_parameter_type(self) -> None:
355+
"""Test that arm parameter values are cast to the appropriate type on load.
356+
357+
This is important because SQA may deserialize values as different types
358+
(e.g., ints as floats). This test directly modifies the SQA objects to
359+
simulate float values for INT parameters, then verifies that the decoder
360+
casts them back to the correct types.
361+
"""
362+
# Create an experiment with INT and FLOAT parameters
363+
search_space = SearchSpace(
364+
parameters=[
365+
RangeParameter(
366+
name="x",
367+
parameter_type=ParameterType.INT,
368+
lower=0,
369+
upper=10,
370+
),
371+
RangeParameter(
372+
name="y",
373+
parameter_type=ParameterType.FLOAT,
374+
lower=0.0,
375+
upper=1.0,
376+
),
377+
]
378+
)
379+
380+
experiment = Experiment(
381+
name="test_arm_param_cast_experiment_sqa",
382+
search_space=search_space,
383+
status_quo=Arm(parameters={"x": 5, "y": 0.5}, name="status_quo"),
384+
)
385+
386+
# Add a trial with an arm
387+
trial = experiment.new_trial()
388+
trial.add_arm(Arm(parameters={"x": 3, "y": 0.3}))
389+
390+
# Encode the experiment to SQA
391+
experiment_sqa = self.encoder.experiment_to_sqa(experiment)
392+
393+
# Manually modify the SQA objects to simulate float values for INT parameters
394+
# This simulates what can happen when loading from external sources or
395+
# when JSON deserialization converts ints to floats.
396+
# Modify status_quo parameters
397+
none_throws(experiment_sqa.status_quo_parameters)["x"] = (
398+
5.0 # float instead of int
399+
)
400+
401+
# Modify arm parameters in the trial's generator run
402+
for trial_sqa in experiment_sqa.trials:
403+
for gr_sqa in trial_sqa.generator_runs:
404+
for arm_sqa in gr_sqa.arms:
405+
arm_sqa.parameters["x"] = 3.0 # float instead of int
406+
407+
# Decode the experiment from SQA
408+
loaded_experiment = self.decoder.experiment_from_sqa(experiment_sqa)
409+
410+
# Check that arm parameter values are cast to the correct type
411+
loaded_arm = list(loaded_experiment.trials[0].arms)[0]
412+
self.assertEqual(loaded_arm.parameters["x"], 3)
413+
self.assertIs(type(loaded_arm.parameters["x"]), int)
414+
self.assertEqual(loaded_arm.parameters["y"], 0.3)
415+
self.assertIs(type(loaded_arm.parameters["y"]), float)
416+
417+
# Check that status_quo parameter values are cast to the correct type
418+
status_quo = loaded_experiment.status_quo
419+
self.assertIsNotNone(status_quo)
420+
self.assertEqual(none_throws(status_quo).parameters["x"], 5)
421+
self.assertIs(type(none_throws(status_quo).parameters["x"]), int)
422+
self.assertEqual(none_throws(status_quo).parameters["y"], 0.5)
423+
self.assertIs(type(none_throws(status_quo).parameters["y"]), float)
424+
353425
def test_saving_and_loading_experiment_with_aux_exp(self) -> None:
354426
aux_experiment = Experiment(
355427
name="test_aux_exp_in_SQAStoreTest",

0 commit comments

Comments
 (0)