diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index 44c6b52411e..cef84761cad 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -140,7 +140,8 @@ def __init__( expand_model_space: If True, expand range parameter bounds in model space to cover given training data. This will make the modeling space larger than the search space if training data fall outside - the search space. + the search space. Will also include training points that violate + parameter constraints in the modeling. fit_out_of_design: If specified, all training data are used. Otherwise, only in design points are used. fit_abandoned: Whether data for abandoned arms or trials should be @@ -442,6 +443,8 @@ def _set_model_space(self, observations: list[Observation]) -> None: if isinstance(p, RangeParameter): p.lower = min(p.lower, min(param_vals[p.name])) p.upper = max(p.upper, max(param_vals[p.name])) + # Remove parameter constraints from the model space. + self._model_space.set_parameter_constraints([]) def _set_status_quo( self, diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index bc0822ba9bb..6232aae351b 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -22,6 +22,7 @@ from ax.core.observation import ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.parameter import FixedParameter, ParameterType, RangeParameter +from ax.core.parameter_constraint import SumConstraint from ax.core.search_space import SearchSpace from ax.exceptions.core import UnsupportedError, UserInputError from ax.modelbridge.base import ( @@ -1000,10 +1001,21 @@ def test_SetModelSpace(self) -> None: experiment.attach_data(get_branin_data_batch(batch=trial, fill_vals=sq_vals)) trial.mark_completed() data = experiment.lookup_data() + # Make search space with a parameter constraint + ss = experiment.search_space.clone() + ss.set_parameter_constraints( + [ + SumConstraint( + parameters=list(ss.parameters.values()), + is_upper_bound=True, + bound=30.0, + ) + ] + ) # Check that SQ and custom are OOD m = Adapter( - search_space=experiment.search_space, + search_space=ss, model=None, experiment=experiment, data=data, @@ -1014,10 +1026,11 @@ def test_SetModelSpace(self) -> None: self.assertEqual(set(ood_arms), {"status_quo", "custom"}) self.assertEqual(m.model_space.parameters["x1"].lower, -5.0) # pyre-ignore[16] self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16] + self.assertEqual(len(m.model_space.parameter_constraints), 1) # With expand model space, custom is not OOD, and model space is expanded m = Adapter( - search_space=experiment.search_space, + search_space=ss, model=None, experiment=experiment, data=data, @@ -1027,10 +1040,11 @@ def test_SetModelSpace(self) -> None: self.assertEqual(set(ood_arms), {"status_quo"}) self.assertEqual(m.model_space.parameters["x1"].lower, -20.0) self.assertEqual(m.model_space.parameters["x2"].upper, 18.0) + self.assertEqual(m.model_space.parameter_constraints, []) # With fill values, SQ is also in design, and x2 is further expanded m = Adapter( - search_space=experiment.search_space, + search_space=ss, model=None, experiment=experiment, data=data, @@ -1039,6 +1053,7 @@ def test_SetModelSpace(self) -> None: ) self.assertEqual(sum(m.training_in_design), 7) self.assertEqual(m.model_space.parameters["x2"].upper, 20) + self.assertEqual(m.model_space.parameter_constraints, []) @mock.patch( "ax.modelbridge.base.observations_from_data",