diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 25b522fbc6d..a6f5f6afda8 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -1535,6 +1535,7 @@ def attach_trial( ttl_seconds: int | None = None, run_metadata: dict[str, Any] | None = None, optimize_for_power: bool = False, + raise_parameter_error: bool = True, ) -> tuple[dict[str, TParameterization], int]: """Attach a new trial with the given parameterization to the experiment. @@ -1552,6 +1553,9 @@ def attach_trial( trial such that the experiment's power to detect effects of certain size is as high as possible. Refer to documentation of `BatchTrial.set_status_quo_and_optimize_power` for more detail. + raise_parameter_error: If True, raise an error if validating membership + of the parameterization in the search space fails. If False, do not + raise an error. Returns: Tuple of arm name to parameterization dict, and trial index from @@ -1564,7 +1568,9 @@ def attach_trial( # Validate search space membership for all parameterizations for parameterization in parameterizations: - self.search_space.validate_membership(parameters=parameterization) + self.search_space.validate_membership( + parameters=parameterization, raise_error=raise_parameter_error + ) # Validate number of arm names if any arm names are provided. named_arms = False diff --git a/ax/core/search_space.py b/ax/core/search_space.py index b75cf380596..2f1824db704 100644 --- a/ax/core/search_space.py +++ b/ax/core/search_space.py @@ -378,8 +378,17 @@ def _validate_parameter_constraints( f"`{parameter_name}` does not exist in search space." ) - def validate_membership(self, parameters: TParameterization) -> None: - self.check_membership(parameterization=parameters, raise_error=True) + def validate_membership( + self, parameters: TParameterization, raise_error: bool = True + ) -> None: + """Validates if a parameterization belongs in the search space. + + Args: + raise_error: If True, raise an error if validating membership + of the parameterization in the search space fails. If False, do not + raise an error. + """ + self.check_membership(parameterization=parameters, raise_error=raise_error) # `check_membership` uses int and float interchangeably, which we don't # want here. for p_name, parameter in self.parameters.items(): diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 01a9af29c62..b7a63b0cb11 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -881,6 +881,7 @@ def attach_trial( ttl_seconds: int | None = None, run_metadata: dict[str, Any] | None = None, arm_name: str | None = None, + raise_parameter_error: bool = True, ) -> tuple[TParameterization, int]: """Attach a new trial with the given parameterization to the experiment. @@ -889,6 +890,9 @@ def attach_trial( ttl_seconds: If specified, will consider the trial failed after this many seconds. Used to detect dead trials that were not marked failed properly. + raise_parameter_error: If True, raise an error if validating membership + of the parameterization in the search space fails. If False, do not + raise an error. Returns: Tuple of parameterization and trial index from newly created trial. @@ -899,6 +903,7 @@ def attach_trial( arm_names=[arm_name] if arm_name else None, ttl_seconds=ttl_seconds, run_metadata=run_metadata, + raise_parameter_error=raise_parameter_error, ) self._save_or_update_trial_in_db_if_possible( experiment=self.experiment, diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 7988563b6b8..26c30bdd8ce 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -1941,6 +1941,39 @@ def test_attach_trial_and_get_trial_parameters(self) -> None: with self.assertRaisesRegex(UnsupportedError, ".* is of type"): ax_client.attach_trial({"x": 1, "y": 2}) + def test_attach_trial_invalid_parameters(self) -> None: + ax_client = AxClient() + ax_client.create_experiment( + parameters=[ + {"name": "x", "type": "range", "bounds": [-5.0, 10.0]}, + {"name": "y", "type": "range", "bounds": [0.0, 15.0]}, + ], + ) + # Test parameter outside bounds fails by default + with self.assertRaisesRegex( + ValueError, + "20\\.0 is not a valid value for parameter RangeParameter\\(name='x'", + ): + ax_client.attach_trial(parameters={"x": 20.0, "y": 1.0}) + + # Test parameter outside bounds fails with raise_parameter_error=True + with self.assertRaisesRegex( + ValueError, + "20\\.0 is not a valid value for parameter RangeParameter\\(name='x'", + ): + ax_client.attach_trial( + parameters={"x": 20.0, "y": 1.0}, raise_parameter_error=True + ) + + # Test parameter outside bounds succeeds with raise_parameter_error=False + _, idx = ax_client.attach_trial( + parameters={"x": 20.0, "y": 1.0}, raise_parameter_error=False + ) + ax_client.complete_trial(trial_index=idx, raw_data=5) + self.assertEqual( + ax_client.get_trial_parameters(trial_index=idx), {"x": 20.0, "y": 1.0} + ) + def test_attach_trial_ttl_seconds(self) -> None: ax_client = AxClient() ax_client.create_experiment(