Skip to content

Commit 81d384c

Browse files
blethamfacebook-github-bot
authored andcommitted
add parameter constraints while adding new parameters (#5023)
Summary: AxClient and Experiment each have a method for adding parameters to the experiment (added in D93766951). In settings where we are dealing with constrained parameters, we need a similarly convenient interface for adding the constraints on those parameters. I think the cleanest approach is to add constraints as a kwarg to the existing method, which is what is done here. I also fix a bug in the AxClient add parameters interface where the docstring says "status_quo_values: Optional parameter values for the new parameters to use in the status quo (baseline) arm, if one is defined. If None, the backfill values will be used for the status quo." but didn't use the backfill. The Experiment method this subsequently calls requiers status_quo_values, so this was actually producing an error if status_quo_values were not provided. Differential Revision: D96522400
1 parent 0d3535e commit 81d384c

4 files changed

Lines changed: 195 additions & 1 deletion

File tree

ax/core/experiment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
OptimizationConfig,
4242
)
4343
from ax.core.parameter import DerivedParameter, Parameter
44+
from ax.core.parameter_constraint import ParameterConstraint
4445
from ax.core.runner import Runner
4546
from ax.core.search_space import SearchSpace
4647
from ax.core.trial import Trial
@@ -362,6 +363,7 @@ def add_parameters_to_search_space(
362363
self,
363364
parameters: Sequence[Parameter],
364365
status_quo_values: TParameterization | None = None,
366+
parameter_constraints: Sequence[ParameterConstraint] | None = None,
365367
) -> None:
366368
"""
367369
Add new parameters to the experiment's search space. This allows extending
@@ -376,6 +378,8 @@ def add_parameters_to_search_space(
376378
space.
377379
status_quo_values: Optional parameter values for the new parameters to
378380
use in the status quo (baseline) arm, if one is defined.
381+
parameter_constraints: Optional sequence of typed ParameterConstraint
382+
objects to add to the search space after the parameters are added.
379383
"""
380384
status_quo_values = status_quo_values or {}
381385

@@ -429,6 +433,10 @@ def add_parameters_to_search_space(
429433
# Add parameters to search space
430434
self._search_space.add_parameters(parameters)
431435

436+
# Add parameter constraints to search space
437+
if parameter_constraints:
438+
self._search_space.add_parameter_constraints(list(parameter_constraints))
439+
432440
def disable_parameters_in_search_space(
433441
self, default_parameter_values: TParameterization
434442
) -> None:

ax/core/tests/test_experiment.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
ParameterType,
3737
RangeParameter,
3838
)
39+
from ax.core.parameter_constraint import ParameterConstraint
3940
from ax.core.search_space import SearchSpace
4041
from ax.core.types import ComparisonOp
4142
from ax.exceptions.core import (
@@ -427,6 +428,31 @@ def test_add_search_space_parameters(self) -> None:
427428
self.assertIn("new_param", experiment.status_quo.parameters)
428429
self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0)
429430

431+
with self.subTest("Add parameter with parameter constraints"):
432+
experiment = self.experiment.clone_with(trial_indices=[])
433+
num_existing_constraints = len(
434+
experiment.search_space.parameter_constraints
435+
)
436+
constraint = ParameterConstraint(
437+
inequality="new_param + w <= 5.0",
438+
)
439+
experiment.add_parameters_to_search_space(
440+
parameters=[new_param],
441+
status_quo_values={new_param.name: 0.0},
442+
parameter_constraints=[constraint],
443+
)
444+
# Verify parameter was added
445+
self.assertIn("new_param", experiment.search_space.parameters)
446+
# Verify constraint was added
447+
self.assertEqual(
448+
len(experiment.search_space.parameter_constraints),
449+
num_existing_constraints + 1,
450+
)
451+
added_constraint = experiment.search_space.parameter_constraints[-1]
452+
self.assertIn("new_param", added_constraint.constraint_dict)
453+
self.assertIn("w", added_constraint.constraint_dict)
454+
self.assertEqual(added_constraint.bound, 5.0)
455+
430456
def test_add_derived_parameter_to_search_space_with_trials(self) -> None:
431457
"""Test adding DerivedParameters to an experiment that has existing trials.
432458

ax/service/ax_client.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ax.core.objective import Objective
3131
from ax.core.observation import ObservationFeatures
3232
from ax.core.parameter import RangeParameter
33+
from ax.core.parameter_constraint import ParameterConstraint
3334
from ax.core.runner import Runner
3435
from ax.core.trial import Trial
3536
from ax.core.trial_status import TrialStatus
@@ -555,6 +556,7 @@ def add_parameters(
555556
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
556557
backfill_values: TParameterization,
557558
status_quo_values: TParameterization | None = None,
559+
parameter_constraints: list[str] | None = None,
558560
) -> None:
559561
"""
560562
Add new parameters to the experiment's search space. This allows extending
@@ -574,6 +576,10 @@ def add_parameters(
574576
status_quo_values: Optional parameter values for the new parameters to
575577
use in the status quo (baseline) arm, if one is defined. If None,
576578
the backfill values will be used for the status quo.
579+
parameter_constraints: Optional list of string representations of
580+
parameter constraints to add (e.g., ``"x1 + x2 <= 5.0"``
581+
or ``"x1 <= x2"``). May reference both existing and new
582+
parameters.
577583
"""
578584
parameters_to_add = [
579585
parameter_from_config(parameter_config) for parameter_config in parameters
@@ -594,9 +600,25 @@ def add_parameters(
594600
for parameter in parameters_to_add:
595601
if parameter.name in backfill_values:
596602
parameter._backfill_value = backfill_values[parameter.name]
603+
604+
# Convert string constraints to typed ParameterConstraint objects.
605+
typed_parameter_constraints: list[ParameterConstraint] = []
606+
if parameter_constraints:
607+
# Build a parameter map with both existing and new parameters so
608+
# constraints can reference either.
609+
parameter_map = {
610+
**self.experiment.search_space.parameters,
611+
**{p.name: p for p in parameters_to_add},
612+
}
613+
typed_parameter_constraints = [
614+
InstantiationBase.constraint_from_str(c, parameter_map)
615+
for c in parameter_constraints
616+
]
617+
597618
self.experiment.add_parameters_to_search_space(
598619
parameters=parameters_to_add,
599-
status_quo_values=status_quo_values,
620+
status_quo_values=status_quo_values or backfill_values,
621+
parameter_constraints=typed_parameter_constraints or None,
600622
)
601623
self._save_experiment_to_db_if_possible(experiment=self.experiment)
602624

ax/service/tests/test_ax_client.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1506,6 +1506,144 @@ def test_add_parameters(self) -> None:
15061506
assert isinstance(param_x3, ChoiceParameter)
15071507
self.assertEqual(param_x3.values, ["a", "b", "c"])
15081508

1509+
def test_add_parameters_backfill_values_used_for_status_quo(self) -> None:
1510+
"""Test that backfill_values are used for the status quo arm when
1511+
status_quo_values is not provided.
1512+
"""
1513+
ax_client = AxClient()
1514+
ax_client.create_experiment(
1515+
name="test_experiment",
1516+
parameters=[
1517+
{
1518+
"name": "x1",
1519+
"type": "range",
1520+
"bounds": [0.0, 1.0],
1521+
"value_type": "float",
1522+
},
1523+
],
1524+
status_quo={"x1": 0.5},
1525+
is_test=True,
1526+
immutable_search_space_and_opt_config=False,
1527+
)
1528+
1529+
ax_client.add_parameters(
1530+
parameters=[
1531+
RangeParameterConfig(
1532+
name="x2",
1533+
bounds=(0.0, 10.0),
1534+
parameter_type="float",
1535+
),
1536+
ChoiceParameterConfig(
1537+
name="x3",
1538+
values=["a", "b", "c"],
1539+
parameter_type="str",
1540+
),
1541+
],
1542+
backfill_values={"x2": 5.0, "x3": "a"},
1543+
)
1544+
1545+
# Verify the status quo arm was updated with backfill_values
1546+
status_quo = ax_client.experiment.status_quo
1547+
self.assertIsNotNone(status_quo)
1548+
assert status_quo is not None
1549+
self.assertEqual(status_quo.parameters["x1"], 0.5)
1550+
self.assertEqual(status_quo.parameters["x2"], 5.0)
1551+
self.assertEqual(status_quo.parameters["x3"], "a")
1552+
1553+
def test_add_parameters_with_constraints(self) -> None:
1554+
"""Test that add_parameters correctly adds parameter constraints."""
1555+
ax_client = AxClient()
1556+
ax_client.create_experiment(
1557+
name="test_experiment",
1558+
parameters=[
1559+
{
1560+
"name": "x1",
1561+
"type": "range",
1562+
"bounds": [0.0, 10.0],
1563+
"value_type": "float",
1564+
},
1565+
],
1566+
is_test=True,
1567+
immutable_search_space_and_opt_config=False,
1568+
)
1569+
1570+
with self.subTest("Sum constraint on new parameters"):
1571+
ax_client.add_parameters(
1572+
parameters=[
1573+
RangeParameterConfig(
1574+
name="x2",
1575+
bounds=(0.0, 10.0),
1576+
parameter_type="float",
1577+
),
1578+
],
1579+
backfill_values={"x2": 5.0},
1580+
parameter_constraints=["x1 + x2 <= 5.0"],
1581+
)
1582+
search_space = ax_client.experiment.search_space
1583+
self.assertIn("x2", search_space.parameters)
1584+
self.assertEqual(len(search_space.parameter_constraints), 1)
1585+
constraint = search_space.parameter_constraints[0]
1586+
self.assertIn("x1", constraint.constraint_dict)
1587+
self.assertIn("x2", constraint.constraint_dict)
1588+
self.assertEqual(constraint.bound, 5.0)
1589+
1590+
with self.subTest("Order constraint referencing existing and new parameter"):
1591+
ax_client_2 = AxClient()
1592+
ax_client_2.create_experiment(
1593+
name="test_experiment_2",
1594+
parameters=[
1595+
{
1596+
"name": "x1",
1597+
"type": "range",
1598+
"bounds": [0.0, 10.0],
1599+
"value_type": "float",
1600+
},
1601+
],
1602+
is_test=True,
1603+
immutable_search_space_and_opt_config=False,
1604+
)
1605+
ax_client_2.add_parameters(
1606+
parameters=[
1607+
RangeParameterConfig(
1608+
name="x2",
1609+
bounds=(0.0, 10.0),
1610+
parameter_type="float",
1611+
),
1612+
],
1613+
backfill_values={"x2": 5.0},
1614+
parameter_constraints=["x1 <= x2"],
1615+
)
1616+
search_space = ax_client_2.experiment.search_space
1617+
self.assertEqual(len(search_space.parameter_constraints), 1)
1618+
1619+
with self.subTest("Constraint referencing non-existent parameter"):
1620+
ax_client_3 = AxClient()
1621+
ax_client_3.create_experiment(
1622+
name="test_experiment_3",
1623+
parameters=[
1624+
{
1625+
"name": "x1",
1626+
"type": "range",
1627+
"bounds": [0.0, 10.0],
1628+
"value_type": "float",
1629+
},
1630+
],
1631+
is_test=True,
1632+
immutable_search_space_and_opt_config=False,
1633+
)
1634+
with self.assertRaises(ValueError):
1635+
ax_client_3.add_parameters(
1636+
parameters=[
1637+
RangeParameterConfig(
1638+
name="x2",
1639+
bounds=(0.0, 10.0),
1640+
parameter_type="float",
1641+
),
1642+
],
1643+
backfill_values={"x2": 5.0},
1644+
parameter_constraints=["x1 + nonexistent <= 5.0"],
1645+
)
1646+
15091647
def test_disable_parameters(self) -> None:
15101648
"""Test that disable_parameters correctly disables parameters in the search
15111649
space."""

0 commit comments

Comments
 (0)