Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow suppressing parameter validation errors in ax_client.attach_trial() (#3269) #3402

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
8 changes: 7 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,6 +1535,7 @@ def attach_trial(
ttl_seconds: int | None = None,
run_metadata: dict[str, Any] | None = None,
optimize_for_power: bool = False,
raise_parameter_error: bool = True,
) -> tuple[dict[str, TParameterization], int]:
"""Attach a new trial with the given parameterization to the experiment.

Expand All @@ -1552,6 +1553,9 @@ def attach_trial(
trial such that the experiment's power to detect effects of
certain size is as high as possible. Refer to documentation of
`BatchTrial.set_status_quo_and_optimize_power` for more detail.
raise_parameter_error: If True, raise an error if validating membership
of the parameterization in the search space fails. If False, do not
raise an error.

Returns:
Tuple of arm name to parameterization dict, and trial index from
Expand All @@ -1564,7 +1568,9 @@ def attach_trial(

# Validate search space membership for all parameterizations
for parameterization in parameterizations:
self.search_space.validate_membership(parameters=parameterization)
self.search_space.validate_membership(
parameters=parameterization, raise_error=raise_parameter_error
)

# Validate number of arm names if any arm names are provided.
named_arms = False
Expand Down
13 changes: 11 additions & 2 deletions ax/core/search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,8 +378,17 @@ def _validate_parameter_constraints(
f"`{parameter_name}` does not exist in search space."
)

def validate_membership(self, parameters: TParameterization) -> None:
self.check_membership(parameterization=parameters, raise_error=True)
def validate_membership(
self, parameters: TParameterization, raise_error: bool = True
) -> None:
"""Validates if a parameterization belongs in the search space.

Args:
raise_error: If True, raise an error if validating membership
of the parameterization in the search space fails. If False, do not
raise an error.
"""
self.check_membership(parameterization=parameters, raise_error=raise_error)
# `check_membership` uses int and float interchangeably, which we don't
# want here.
for p_name, parameter in self.parameters.items():
Expand Down
5 changes: 5 additions & 0 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ def attach_trial(
ttl_seconds: int | None = None,
run_metadata: dict[str, Any] | None = None,
arm_name: str | None = None,
raise_parameter_error: bool = True,
) -> tuple[TParameterization, int]:
"""Attach a new trial with the given parameterization to the experiment.
Expand All @@ -889,6 +890,9 @@ def attach_trial(
ttl_seconds: If specified, will consider the trial failed after this
many seconds. Used to detect dead trials that were not marked
failed properly.
raise_parameter_error: If True, raise an error if validating membership
of the parameterization in the search space fails. If False, do not
raise an error.
Returns:
Tuple of parameterization and trial index from newly created trial.
Expand All @@ -899,6 +903,7 @@ def attach_trial(
arm_names=[arm_name] if arm_name else None,
ttl_seconds=ttl_seconds,
run_metadata=run_metadata,
raise_parameter_error=raise_parameter_error,
)
self._save_or_update_trial_in_db_if_possible(
experiment=self.experiment,
Expand Down
33 changes: 33 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,39 @@ def test_attach_trial_and_get_trial_parameters(self) -> None:
with self.assertRaisesRegex(UnsupportedError, ".* is of type"):
ax_client.attach_trial({"x": 1, "y": 2})

def test_attach_trial_invalid_parameters(self) -> None:
ax_client = AxClient()
ax_client.create_experiment(
parameters=[
{"name": "x", "type": "range", "bounds": [-5.0, 10.0]},
{"name": "y", "type": "range", "bounds": [0.0, 15.0]},
],
)
# Test parameter outside bounds fails by default
with self.assertRaisesRegex(
ValueError,
"20\\.0 is not a valid value for parameter RangeParameter\\(name='x'",
):
ax_client.attach_trial(parameters={"x": 20.0, "y": 1.0})

# Test parameter outside bounds fails with raise_parameter_error=True
with self.assertRaisesRegex(
ValueError,
"20\\.0 is not a valid value for parameter RangeParameter\\(name='x'",
):
ax_client.attach_trial(
parameters={"x": 20.0, "y": 1.0}, raise_parameter_error=True
)

# Test parameter outside bounds succeeds with raise_parameter_error=False
_, idx = ax_client.attach_trial(
parameters={"x": 20.0, "y": 1.0}, raise_parameter_error=False
)
ax_client.complete_trial(trial_index=idx, raw_data=5)
self.assertEqual(
ax_client.get_trial_parameters(trial_index=idx), {"x": 20.0, "y": 1.0}
)

def test_attach_trial_ttl_seconds(self) -> None:
ax_client = AxClient()
ax_client.create_experiment(
Expand Down