Skip to content

Commit 36cecb4

Browse files
blethamfacebook-github-bot
authored andcommitted
add parameter constraints while adding new parameters
Summary: AxClient and Experiment each have a method for adding parameters to the experiment (added in D93766951). In settings where we are dealing with constrained parameters, we need a similarly convenient interface for adding the constraints on those parameters. I think the cleanest approach is to add constraints as a kwarg to the existing method, which is what is done here. I also fix a bug in the AxClient add parameters interface where the docstring says "status_quo_values: Optional parameter values for the new parameters to use in the status quo (baseline) arm, if one is defined. If None, the backfill values will be used for the status quo." but didn't use the backfill. The Experiment method this subsequently calls requiers status_quo_values, so this was actually producing an error if status_quo_values were not provided. Differential Revision: D96522400
1 parent e2056d2 commit 36cecb4

4 files changed

Lines changed: 195 additions & 1 deletion

File tree

ax/core/experiment.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ax.core.objective import MultiObjective
3939
from ax.core.optimization_config import ObjectiveThreshold, OptimizationConfig
4040
from ax.core.parameter import DerivedParameter, Parameter
41+
from ax.core.parameter_constraint import ParameterConstraint
4142
from ax.core.runner import Runner
4243
from ax.core.search_space import SearchSpace
4344
from ax.core.trial import Trial
@@ -341,6 +342,7 @@ def add_parameters_to_search_space(
341342
self,
342343
parameters: Sequence[Parameter],
343344
status_quo_values: TParameterization | None = None,
345+
parameter_constraints: Sequence[ParameterConstraint] | None = None,
344346
) -> None:
345347
"""
346348
Add new parameters to the experiment's search space. This allows extending
@@ -355,6 +357,8 @@ def add_parameters_to_search_space(
355357
space.
356358
status_quo_values: Optional parameter values for the new parameters to
357359
use in the status quo (baseline) arm, if one is defined.
360+
parameter_constraints: Optional sequence of typed ParameterConstraint
361+
objects to add to the search space after the parameters are added.
358362
"""
359363
status_quo_values = status_quo_values or {}
360364

@@ -408,6 +412,10 @@ def add_parameters_to_search_space(
408412
# Add parameters to search space
409413
self._search_space.add_parameters(parameters)
410414

415+
# Add parameter constraints to search space
416+
if parameter_constraints:
417+
self._search_space.add_parameter_constraints(list(parameter_constraints))
418+
411419
def disable_parameters_in_search_space(
412420
self, default_parameter_values: TParameterization
413421
) -> None:

ax/core/tests/test_experiment.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ParameterType,
3838
RangeParameter,
3939
)
40+
from ax.core.parameter_constraint import ParameterConstraint
4041
from ax.core.search_space import SearchSpace
4142
from ax.core.types import ComparisonOp
4243
from ax.exceptions.core import (
@@ -412,6 +413,31 @@ def test_add_search_space_parameters(self) -> None:
412413
self.assertIn("new_param", experiment.status_quo.parameters)
413414
self.assertEqual(experiment.status_quo.parameters["new_param"], 0.0)
414415

416+
with self.subTest("Add parameter with parameter constraints"):
417+
experiment = self.experiment.clone_with(trial_indices=[])
418+
num_existing_constraints = len(
419+
experiment.search_space.parameter_constraints
420+
)
421+
constraint = ParameterConstraint(
422+
inequality="new_param + w <= 5.0",
423+
)
424+
experiment.add_parameters_to_search_space(
425+
parameters=[new_param],
426+
status_quo_values={new_param.name: 0.0},
427+
parameter_constraints=[constraint],
428+
)
429+
# Verify parameter was added
430+
self.assertIn("new_param", experiment.search_space.parameters)
431+
# Verify constraint was added
432+
self.assertEqual(
433+
len(experiment.search_space.parameter_constraints),
434+
num_existing_constraints + 1,
435+
)
436+
added_constraint = experiment.search_space.parameter_constraints[-1]
437+
self.assertIn("new_param", added_constraint.constraint_dict)
438+
self.assertIn("w", added_constraint.constraint_dict)
439+
self.assertEqual(added_constraint.bound, 5.0)
440+
415441
def test_add_derived_parameter_to_search_space_with_trials(self) -> None:
416442
"""Test adding DerivedParameters to an experiment that has existing trials.
417443

ax/service/ax_client.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from ax.core.objective import MultiObjective, Objective
3131
from ax.core.observation import ObservationFeatures
3232
from ax.core.parameter import RangeParameter
33+
from ax.core.parameter_constraint import ParameterConstraint
3334
from ax.core.runner import Runner
3435
from ax.core.trial import Trial
3536
from ax.core.trial_status import TrialStatus
@@ -524,6 +525,7 @@ def add_parameters(
524525
parameters: Sequence[RangeParameterConfig | ChoiceParameterConfig],
525526
backfill_values: TParameterization,
526527
status_quo_values: TParameterization | None = None,
528+
parameter_constraints: list[str] | None = None,
527529
) -> None:
528530
"""
529531
Add new parameters to the experiment's search space. This allows extending
@@ -543,6 +545,10 @@ def add_parameters(
543545
status_quo_values: Optional parameter values for the new parameters to
544546
use in the status quo (baseline) arm, if one is defined. If None,
545547
the backfill values will be used for the status quo.
548+
parameter_constraints: Optional list of string representations of
549+
parameter constraints to add (e.g., ``"x1 + x2 <= 5.0"``
550+
or ``"x1 <= x2"``). May reference both existing and new
551+
parameters.
546552
"""
547553
parameters_to_add = [
548554
parameter_from_config(parameter_config) for parameter_config in parameters
@@ -563,9 +569,25 @@ def add_parameters(
563569
for parameter in parameters_to_add:
564570
if parameter.name in backfill_values:
565571
parameter._backfill_value = backfill_values[parameter.name]
572+
573+
# Convert string constraints to typed ParameterConstraint objects.
574+
typed_parameter_constraints: list[ParameterConstraint] = []
575+
if parameter_constraints:
576+
# Build a parameter map with both existing and new parameters so
577+
# constraints can reference either.
578+
parameter_map = {
579+
**self.experiment.search_space.parameters,
580+
**{p.name: p for p in parameters_to_add},
581+
}
582+
typed_parameter_constraints = [
583+
InstantiationBase.constraint_from_str(c, parameter_map)
584+
for c in parameter_constraints
585+
]
586+
566587
self.experiment.add_parameters_to_search_space(
567588
parameters=parameters_to_add,
568-
status_quo_values=status_quo_values,
589+
status_quo_values=status_quo_values or backfill_values,
590+
parameter_constraints=typed_parameter_constraints or None,
569591
)
570592
self._save_experiment_to_db_if_possible(experiment=self.experiment)
571593

ax/service/tests/test_ax_client.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1493,6 +1493,144 @@ def test_add_parameters(self) -> None:
14931493
assert isinstance(param_x3, ChoiceParameter)
14941494
self.assertEqual(param_x3.values, ["a", "b", "c"])
14951495

1496+
def test_add_parameters_backfill_values_used_for_status_quo(self) -> None:
1497+
"""Test that backfill_values are used for the status quo arm when
1498+
status_quo_values is not provided.
1499+
"""
1500+
ax_client = AxClient()
1501+
ax_client.create_experiment(
1502+
name="test_experiment",
1503+
parameters=[
1504+
{
1505+
"name": "x1",
1506+
"type": "range",
1507+
"bounds": [0.0, 1.0],
1508+
"value_type": "float",
1509+
},
1510+
],
1511+
status_quo={"x1": 0.5},
1512+
is_test=True,
1513+
immutable_search_space_and_opt_config=False,
1514+
)
1515+
1516+
ax_client.add_parameters(
1517+
parameters=[
1518+
RangeParameterConfig(
1519+
name="x2",
1520+
bounds=(0.0, 10.0),
1521+
parameter_type="float",
1522+
),
1523+
ChoiceParameterConfig(
1524+
name="x3",
1525+
values=["a", "b", "c"],
1526+
parameter_type="str",
1527+
),
1528+
],
1529+
backfill_values={"x2": 5.0, "x3": "a"},
1530+
)
1531+
1532+
# Verify the status quo arm was updated with backfill_values
1533+
status_quo = ax_client.experiment.status_quo
1534+
self.assertIsNotNone(status_quo)
1535+
assert status_quo is not None
1536+
self.assertEqual(status_quo.parameters["x1"], 0.5)
1537+
self.assertEqual(status_quo.parameters["x2"], 5.0)
1538+
self.assertEqual(status_quo.parameters["x3"], "a")
1539+
1540+
def test_add_parameters_with_constraints(self) -> None:
1541+
"""Test that add_parameters correctly adds parameter constraints."""
1542+
ax_client = AxClient()
1543+
ax_client.create_experiment(
1544+
name="test_experiment",
1545+
parameters=[
1546+
{
1547+
"name": "x1",
1548+
"type": "range",
1549+
"bounds": [0.0, 10.0],
1550+
"value_type": "float",
1551+
},
1552+
],
1553+
is_test=True,
1554+
immutable_search_space_and_opt_config=False,
1555+
)
1556+
1557+
with self.subTest("Sum constraint on new parameters"):
1558+
ax_client.add_parameters(
1559+
parameters=[
1560+
RangeParameterConfig(
1561+
name="x2",
1562+
bounds=(0.0, 10.0),
1563+
parameter_type="float",
1564+
),
1565+
],
1566+
backfill_values={"x2": 5.0},
1567+
parameter_constraints=["x1 + x2 <= 5.0"],
1568+
)
1569+
search_space = ax_client.experiment.search_space
1570+
self.assertIn("x2", search_space.parameters)
1571+
self.assertEqual(len(search_space.parameter_constraints), 1)
1572+
constraint = search_space.parameter_constraints[0]
1573+
self.assertIn("x1", constraint.constraint_dict)
1574+
self.assertIn("x2", constraint.constraint_dict)
1575+
self.assertEqual(constraint.bound, 5.0)
1576+
1577+
with self.subTest("Order constraint referencing existing and new parameter"):
1578+
ax_client_2 = AxClient()
1579+
ax_client_2.create_experiment(
1580+
name="test_experiment_2",
1581+
parameters=[
1582+
{
1583+
"name": "x1",
1584+
"type": "range",
1585+
"bounds": [0.0, 10.0],
1586+
"value_type": "float",
1587+
},
1588+
],
1589+
is_test=True,
1590+
immutable_search_space_and_opt_config=False,
1591+
)
1592+
ax_client_2.add_parameters(
1593+
parameters=[
1594+
RangeParameterConfig(
1595+
name="x2",
1596+
bounds=(0.0, 10.0),
1597+
parameter_type="float",
1598+
),
1599+
],
1600+
backfill_values={"x2": 5.0},
1601+
parameter_constraints=["x1 <= x2"],
1602+
)
1603+
search_space = ax_client_2.experiment.search_space
1604+
self.assertEqual(len(search_space.parameter_constraints), 1)
1605+
1606+
with self.subTest("Constraint referencing non-existent parameter"):
1607+
ax_client_3 = AxClient()
1608+
ax_client_3.create_experiment(
1609+
name="test_experiment_3",
1610+
parameters=[
1611+
{
1612+
"name": "x1",
1613+
"type": "range",
1614+
"bounds": [0.0, 10.0],
1615+
"value_type": "float",
1616+
},
1617+
],
1618+
is_test=True,
1619+
immutable_search_space_and_opt_config=False,
1620+
)
1621+
with self.assertRaises(ValueError):
1622+
ax_client_3.add_parameters(
1623+
parameters=[
1624+
RangeParameterConfig(
1625+
name="x2",
1626+
bounds=(0.0, 10.0),
1627+
parameter_type="float",
1628+
),
1629+
],
1630+
backfill_values={"x2": 5.0},
1631+
parameter_constraints=["x1 + nonexistent <= 5.0"],
1632+
)
1633+
14961634
def test_disable_parameters(self) -> None:
14971635
"""Test that disable_parameters correctly disables parameters in the search
14981636
space."""

0 commit comments

Comments
 (0)