@@ -1725,6 +1725,145 @@ def test_choice_parameter_backward_compatibility_sort_values(self) -> None:
17251725 )
17261726 )
17271727
1728+ def test_arm_parameter_values_cast_to_parameter_type (self ) -> None :
1729+ """Test that arm parameter values are cast to the appropriate type on load."""
1730+ from ax .core .arm import Arm
1731+ from ax .core .experiment import Experiment
1732+ from ax .core .parameter import RangeParameter
1733+ from ax .core .search_space import SearchSpace
1734+ from ax .storage .json_store .encoder import object_to_json
1735+
1736+ # Create an experiment with INT parameters
1737+ search_space = SearchSpace (
1738+ parameters = [
1739+ RangeParameter (
1740+ name = "x" ,
1741+ parameter_type = ParameterType .INT ,
1742+ lower = 0 ,
1743+ upper = 10 ,
1744+ ),
1745+ RangeParameter (
1746+ name = "y" ,
1747+ parameter_type = ParameterType .FLOAT ,
1748+ lower = 0.0 ,
1749+ upper = 1.0 ,
1750+ ),
1751+ ]
1752+ )
1753+
1754+ experiment = Experiment (
1755+ name = "test_experiment" ,
1756+ search_space = search_space ,
1757+ status_quo = Arm (parameters = {"x" : 5 , "y" : 0.5 }, name = "status_quo" ),
1758+ )
1759+
1760+ # Add a trial with an arm
1761+ trial = experiment .new_trial ()
1762+ trial .add_arm (Arm (parameters = {"x" : 3 , "y" : 0.3 }))
1763+
1764+ # Encode the experiment to JSON
1765+ experiment_json = object_to_json (
1766+ experiment ,
1767+ encoder_registry = CORE_ENCODER_REGISTRY ,
1768+ class_encoder_registry = CORE_CLASS_ENCODER_REGISTRY ,
1769+ )
1770+
1771+ # Manually modify the JSON to simulate float values for INT parameters
1772+ # (as could happen when loading from external sources)
1773+ for arm_json in experiment_json ["trials" ][0 ]["generator_run" ]["arms" ]:
1774+ arm_json ["parameters" ]["x" ] = 3.0 # float instead of int
1775+ experiment_json ["status_quo" ]["parameters" ]["x" ] = 5.0 # float instead of int
1776+
1777+ # Decode the experiment from JSON
1778+ loaded_experiment = object_from_json (
1779+ experiment_json ,
1780+ decoder_registry = CORE_DECODER_REGISTRY ,
1781+ class_decoder_registry = CORE_CLASS_DECODER_REGISTRY ,
1782+ )
1783+
1784+ # Check that arm parameter values are cast to the correct type
1785+ loaded_arm = list (loaded_experiment .trials [0 ].arms )[0 ]
1786+ self .assertEqual (loaded_arm .parameters ["x" ], 3 )
1787+ self .assertIs (type (loaded_arm .parameters ["x" ]), int )
1788+ self .assertEqual (loaded_arm .parameters ["y" ], 0.3 )
1789+ self .assertIs (type (loaded_arm .parameters ["y" ]), float )
1790+
1791+ # Check that status_quo parameter values are cast to the correct type
1792+ status_quo = loaded_experiment .status_quo
1793+ self .assertIsNotNone (status_quo )
1794+ self .assertEqual (status_quo .parameters ["x" ], 5 )
1795+ self .assertIs (type (status_quo .parameters ["x" ]), int )
1796+ self .assertEqual (status_quo .parameters ["y" ], 0.5 )
1797+ self .assertIs (type (status_quo .parameters ["y" ]), float )
1798+
1799+ def test_cast_parameter_value_all_types (self ) -> None :
1800+ """Test _cast_parameter_value handles all parameter types correctly."""
1801+ from ax .storage .json_store .decoders import _cast_parameter_value
1802+
1803+ # Test INT casting
1804+ self .assertEqual (_cast_parameter_value (3.0 , ParameterType .INT ), 3 )
1805+ self .assertIs (type (_cast_parameter_value (3.0 , ParameterType .INT )), int )
1806+ self .assertEqual (_cast_parameter_value (3 , ParameterType .INT ), 3 )
1807+ self .assertIs (type (_cast_parameter_value (3 , ParameterType .INT )), int )
1808+
1809+ # Test FLOAT casting
1810+ self .assertEqual (_cast_parameter_value (3 , ParameterType .FLOAT ), 3.0 )
1811+ self .assertIs (type (_cast_parameter_value (3 , ParameterType .FLOAT )), float )
1812+ self .assertEqual (_cast_parameter_value (3.5 , ParameterType .FLOAT ), 3.5 )
1813+ self .assertIs (type (_cast_parameter_value (3.5 , ParameterType .FLOAT )), float )
1814+
1815+ # Test BOOL casting
1816+ self .assertEqual (_cast_parameter_value (1 , ParameterType .BOOL ), True )
1817+ self .assertIs (type (_cast_parameter_value (1 , ParameterType .BOOL )), bool )
1818+ self .assertEqual (_cast_parameter_value (0 , ParameterType .BOOL ), False )
1819+ self .assertIs (type (_cast_parameter_value (0 , ParameterType .BOOL )), bool )
1820+ self .assertEqual (_cast_parameter_value (True , ParameterType .BOOL ), True )
1821+ self .assertIs (type (_cast_parameter_value (True , ParameterType .BOOL )), bool )
1822+
1823+ # Test STRING casting
1824+ self .assertEqual (_cast_parameter_value ("test" , ParameterType .STRING ), "test" )
1825+ self .assertIs (type (_cast_parameter_value ("test" , ParameterType .STRING )), str )
1826+ self .assertEqual (_cast_parameter_value (123 , ParameterType .STRING ), "123" )
1827+ self .assertIs (type (_cast_parameter_value (123 , ParameterType .STRING )), str )
1828+
1829+ # Test None handling
1830+ self .assertIsNone (_cast_parameter_value (None , ParameterType .INT ))
1831+ self .assertIsNone (_cast_parameter_value (None , ParameterType .FLOAT ))
1832+ self .assertIsNone (_cast_parameter_value (None , ParameterType .BOOL ))
1833+ self .assertIsNone (_cast_parameter_value (None , ParameterType .STRING ))
1834+
1835+ def test_cast_arm_parameters_skips_unknown_params (self ) -> None :
1836+ """Test that _cast_arm_parameters skips parameters not in search space."""
1837+ from ax .core .arm import Arm
1838+ from ax .core .parameter import RangeParameter
1839+ from ax .core .search_space import SearchSpace
1840+ from ax .storage .json_store .decoder import _cast_arm_parameters
1841+
1842+ search_space = SearchSpace (
1843+ parameters = [
1844+ RangeParameter (
1845+ name = "x" ,
1846+ parameter_type = ParameterType .INT ,
1847+ lower = 0 ,
1848+ upper = 10 ,
1849+ ),
1850+ ]
1851+ )
1852+
1853+ # Create an arm with a parameter that's not in the search space
1854+ arm = Arm (parameters = {"x" : 3.0 , "unknown_param" : "some_value" })
1855+
1856+ # Cast should work without error and only cast known parameters
1857+ _cast_arm_parameters (arm , search_space )
1858+
1859+ # x should be cast to int
1860+ self .assertEqual (arm .parameters ["x" ], 3 )
1861+ self .assertIs (type (arm .parameters ["x" ]), int )
1862+
1863+ # unknown_param should remain unchanged
1864+ self .assertEqual (arm .parameters ["unknown_param" ], "some_value" )
1865+ self .assertIs (type (arm .parameters ["unknown_param" ]), str )
1866+
17281867 def test_surrogate_spec_backwards_compatibility (self ) -> None :
17291868 # This is an invalid example that has both deprecated args
17301869 # and model config specified. Deprecated args will be ignored.
0 commit comments