Skip to content

Commit b668431

Browse files
sdaultonmeta-codesync[bot]
authored andcommitted
Thread equality constraints through adapter layer and TorchOptConfig (#5176)
Summary: Pull Request resolved: #5176 Add equality constraint extraction and propagation through the adapter layer to TorchOptConfig, enabling downstream generators to receive equality constraints. - Add `extract_equality_constraints` in `adapter_utils.py` (filters for `is_equality=True` constraints, returns `(A, b)` matrices). - Update `extract_parameter_constraints` to filter out equality constraints. - Add `equality_constraints` parameter to `validate_and_apply_final_transform` (now returns a 7-tuple). - Add `equality_constraints: tuple[Tensor, Tensor] | None` field to `TorchOptConfig`. - Update `TorchAdapter._get_transformed_model_gen_args` to extract equality constraints and pass them through to `TorchOptConfig`. Reviewed By: bletham Differential Revision: D100256480 fbshipit-source-id: c7a99278fc93fdc85df72e8840a788ab7aaeba90
1 parent cdc5060 commit b668431

5 files changed

Lines changed: 163 additions & 16 deletions

File tree

ax/adapter/adapter_utils.py

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,31 +76,67 @@
7676
from ax import adapter as adapter_module # noqa F401
7777

7878

79-
def extract_parameter_constraints(
80-
parameter_constraints: list[ParameterConstraint], param_names: list[str]
79+
def _extract_constraints(
80+
parameter_constraints: list[ParameterConstraint],
81+
param_names: list[str],
82+
is_equality: bool,
8183
) -> TBounds:
82-
"""Convert Ax parameter constraints into a tuple of NumPy arrays representing the
83-
system of linear inequality constraints.
84+
"""Extract linear constraints into a tuple of NumPy arrays.
85+
86+
Shared helper for extracting inequality (``A x <= b``) or equality
87+
(``A x = b``) constraints.
8488
8589
Args:
8690
parameter_constraints: A list of parameter constraint objects.
8791
param_names: A list of parameter names.
92+
is_equality: If True, extract equality constraints; otherwise
93+
extract inequality constraints.
8894
8995
Returns:
90-
An optional tuple of NumPy arrays (A, b) representing the system of linear
91-
inequality constraints A x < b.
96+
An optional tuple of NumPy arrays (A, b).
9297
"""
93-
if len(parameter_constraints) == 0:
98+
filtered = [c for c in parameter_constraints if c.is_equality == is_equality]
99+
if len(filtered) == 0:
94100
return None
95-
A = np.zeros((len(parameter_constraints), len(param_names)))
96-
b = np.zeros((len(parameter_constraints), 1))
97-
for i, c in enumerate(parameter_constraints):
101+
A = np.zeros((len(filtered), len(param_names)))
102+
b = np.zeros((len(filtered), 1))
103+
for i, c in enumerate(filtered):
98104
b[i, 0] = c.bound
99105
for name, val in c.constraint_dict.items():
100106
A[i, param_names.index(name)] = val
101107
return (A, b)
102108

103109

110+
def extract_inequality_constraints(
111+
parameter_constraints: list[ParameterConstraint], param_names: list[str]
112+
) -> TBounds:
113+
"""Convert Ax inequality parameter constraints into NumPy arrays.
114+
115+
Args:
116+
parameter_constraints: A list of parameter constraint objects.
117+
param_names: A list of parameter names.
118+
119+
Returns:
120+
An optional tuple of NumPy arrays (A, b) with ``A x <= b``.
121+
"""
122+
return _extract_constraints(parameter_constraints, param_names, is_equality=False)
123+
124+
125+
def extract_equality_constraints(
126+
parameter_constraints: list[ParameterConstraint], param_names: list[str]
127+
) -> TBounds:
128+
"""Convert Ax equality parameter constraints into NumPy arrays.
129+
130+
Args:
131+
parameter_constraints: A list of parameter constraint objects.
132+
param_names: A list of parameter names.
133+
134+
Returns:
135+
An optional tuple of NumPy arrays (A, b) with ``A x = b``.
136+
"""
137+
return _extract_constraints(parameter_constraints, param_names, is_equality=True)
138+
139+
104140
def extract_search_space_digest(
105141
search_space: SearchSpace, param_names: list[str]
106142
) -> SearchSpaceDigest:
@@ -402,6 +438,7 @@ def validate_and_apply_final_transform(
402438
pending_observations: list[npt.NDArray] | None,
403439
objective_thresholds: npt.NDArray | None = None,
404440
pruning_target_point: npt.NDArray | None = None,
441+
equality_constraints: tuple[npt.NDArray, npt.NDArray] | None = None,
405442
final_transform: Callable[[npt.NDArray], Tensor] = torch.tensor,
406443
) -> tuple[
407444
Tensor,
@@ -410,6 +447,7 @@ def validate_and_apply_final_transform(
410447
list[Tensor] | None,
411448
Tensor | None,
412449
Tensor | None,
450+
tuple[Tensor, Tensor] | None,
413451
]:
414452
# TODO: use some container down the road (similar to
415453
# SearchSpaceDigest) to limit the return arguments
@@ -437,13 +475,20 @@ def validate_and_apply_final_transform(
437475
pruning_target_tensor: Tensor | None = None
438476
if pruning_target_point is not None:
439477
pruning_target_tensor = final_transform(pruning_target_point)
478+
equality_constraints_tensors: tuple[Tensor, Tensor] | None = None
479+
if equality_constraints is not None:
480+
equality_constraints_tensors = (
481+
final_transform(equality_constraints[0]),
482+
final_transform(equality_constraints[1]),
483+
)
440484
return (
441485
obj_weights_tensor,
442486
outcome_constraints_tensors,
443487
linear_constraints_tensors,
444488
pending_obs_tensors,
445489
obj_thresholds_tensor,
446490
pruning_target_tensor,
491+
equality_constraints_tensors,
447492
)
448493

449494

@@ -700,7 +745,7 @@ def get_pareto_frontier_and_configs(
700745
if obj_t is not None:
701746
obj_t = array_to_tensor(obj_t)
702747
# Transform to tensors.
703-
obj_w, oc_c, _, _, _, _ = validate_and_apply_final_transform(
748+
obj_w, oc_c, _, _, _, _, _ = validate_and_apply_final_transform(
704749
objective_weights=objective_weights,
705750
outcome_constraints=outcome_constraints,
706751
linear_constraints=None,

ax/adapter/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import numpy as np
1313
from ax.adapter.adapter_utils import (
14-
extract_parameter_constraints,
14+
extract_inequality_constraints,
1515
extract_search_space_digest,
1616
get_fixed_features,
1717
parse_observation_features,
@@ -92,7 +92,7 @@ def _gen(
9292
# Get fixed features
9393
fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
9494
# Extract param constraints
95-
linear_constraints = extract_parameter_constraints(
95+
linear_constraints = extract_inequality_constraints(
9696
search_space.parameter_constraints, self.parameters
9797
)
9898
# Extract generated points.

ax/adapter/tests/test_adapter_utils.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
_get_fresh_pairwise_trial_indices,
1616
arm_to_np_array,
1717
can_map_to_binary,
18+
extract_equality_constraints,
19+
extract_inequality_constraints,
1820
extract_objective_weight_matrix,
1921
extract_search_space_digest,
2022
feasible_hypervolume,
@@ -35,6 +37,7 @@
3537
from ax.core.optimization_config import MultiObjectiveOptimizationConfig
3638
from ax.core.outcome_constraint import ObjectiveThreshold, OutcomeConstraint
3739
from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter
40+
from ax.core.parameter_constraint import ParameterConstraint
3841
from ax.core.search_space import SearchSpace
3942
from ax.core.types import ComparisonOp
4043
from ax.exceptions.core import UserInputError
@@ -377,6 +380,7 @@ def test_validate_and_apply_final_transform_with_target_point(self) -> None:
377380
_,
378381
_,
379382
target_p,
383+
_,
380384
) = validate_and_apply_final_transform(
381385
objective_weights=objective_weights,
382386
outcome_constraints=outcome_constraints,
@@ -412,6 +416,7 @@ def test_validate_and_apply_final_transform_none_target_point(self) -> None:
412416
_,
413417
_,
414418
target_p,
419+
_,
415420
) = validate_and_apply_final_transform(
416421
objective_weights=objective_weights,
417422
outcome_constraints=outcome_constraints,
@@ -652,3 +657,90 @@ def _attach(
652657
self.assertNotIn(0, result)
653658
self.assertNotIn(1, result)
654659
self.assertIn(2, result)
660+
661+
def test_extract_inequality_constraints(self) -> None:
662+
param_names = ["x", "y"]
663+
ineq = ParameterConstraint(inequality="x + y <= 1")
664+
eq = ParameterConstraint(equality="x + y == 1")
665+
666+
# Only inequality constraints are extracted
667+
result = extract_inequality_constraints([ineq, eq], param_names)
668+
self.assertIsNotNone(result)
669+
assert result is not None
670+
A, b = result
671+
self.assertEqual(A.shape, (1, 2))
672+
self.assertEqual(b.shape, (1, 1))
673+
np.testing.assert_array_equal(A[0], [1.0, 1.0])
674+
np.testing.assert_array_equal(b[0], [1.0])
675+
676+
# Returns None when no inequality constraints
677+
result = extract_inequality_constraints([eq], param_names)
678+
self.assertIsNone(result)
679+
680+
# Returns None for empty list
681+
result = extract_inequality_constraints([], param_names)
682+
self.assertIsNone(result)
683+
684+
def test_extract_equality_constraints(self) -> None:
685+
param_names = ["x", "y"]
686+
ineq = ParameterConstraint(inequality="x + y <= 1")
687+
eq = ParameterConstraint(equality="x + y == 1")
688+
689+
# Only equality constraints are extracted
690+
result = extract_equality_constraints([ineq, eq], param_names)
691+
self.assertIsNotNone(result)
692+
assert result is not None
693+
A, b = result
694+
self.assertEqual(A.shape, (1, 2))
695+
self.assertEqual(b.shape, (1, 1))
696+
np.testing.assert_array_equal(A[0], [1.0, 1.0])
697+
np.testing.assert_array_equal(b[0], [1.0])
698+
699+
# Returns None when no equality constraints
700+
result = extract_equality_constraints([ineq], param_names)
701+
self.assertIsNone(result)
702+
703+
def test_extract_constraints_mixed(self) -> None:
704+
"""Both functions correctly partition a mixed list."""
705+
param_names = ["x", "y"]
706+
ineq1 = ParameterConstraint(inequality="x <= 0.5")
707+
ineq2 = ParameterConstraint(inequality="y <= 0.8")
708+
eq1 = ParameterConstraint(equality="x + y == 1")
709+
710+
ineq_result = extract_inequality_constraints([ineq1, eq1, ineq2], param_names)
711+
eq_result = extract_equality_constraints([ineq1, eq1, ineq2], param_names)
712+
713+
assert ineq_result is not None
714+
assert eq_result is not None
715+
self.assertEqual(ineq_result[0].shape, (2, 2)) # 2 inequalities
716+
self.assertEqual(eq_result[0].shape, (1, 2)) # 1 equality
717+
718+
def test_validate_and_apply_final_transform_equality_constraints(self) -> None:
719+
"""equality_constraints are converted to tensors."""
720+
objective_weights = np.array([1.0, 0.0])
721+
A_eq = np.array([[1.0, 1.0]])
722+
b_eq = np.array([[1.0]])
723+
724+
_, _, _, _, _, _, eq_c = validate_and_apply_final_transform(
725+
objective_weights=objective_weights,
726+
outcome_constraints=None,
727+
linear_constraints=None,
728+
pending_observations=None,
729+
equality_constraints=(A_eq, b_eq),
730+
)
731+
self.assertIsNotNone(eq_c)
732+
assert eq_c is not None
733+
self.assertTrue(torch.equal(eq_c[0], torch.tensor(A_eq)))
734+
self.assertTrue(torch.equal(eq_c[1], torch.tensor(b_eq)))
735+
736+
def test_validate_and_apply_final_transform_no_equality_constraints(self) -> None:
737+
"""equality_constraints defaults to None."""
738+
objective_weights = np.array([1.0])
739+
740+
_, _, _, _, _, _, eq_c = validate_and_apply_final_transform(
741+
objective_weights=objective_weights,
742+
outcome_constraints=None,
743+
linear_constraints=None,
744+
pending_observations=None,
745+
)
746+
self.assertIsNone(eq_c)

ax/adapter/torch.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
_get_fresh_pairwise_trial_indices,
2121
arm_to_np_array,
2222
array_to_observation_data,
23+
extract_equality_constraints,
24+
extract_inequality_constraints,
2325
extract_objective_thresholds,
2426
extract_objective_weight_matrix,
2527
extract_outcome_constraints,
26-
extract_parameter_constraints,
2728
extract_search_space_digest,
2829
get_fixed_features,
2930
observation_data_to_array,
@@ -1044,7 +1045,10 @@ def _get_transformed_model_gen_args(
10441045
arm=optimization_config.pruning_target_parameterization,
10451046
parameters=self.parameters,
10461047
)
1047-
linear_constraints = extract_parameter_constraints(
1048+
linear_constraints = extract_inequality_constraints(
1049+
search_space.parameter_constraints, self.parameters
1050+
)
1051+
equality_constraints_np = extract_equality_constraints(
10481052
search_space.parameter_constraints, self.parameters
10491053
)
10501054
fixed_features_dict = get_fixed_features(fixed_features, self.parameters)
@@ -1065,14 +1069,15 @@ def _get_transformed_model_gen_args(
10651069
pending_array = pending_observations_as_array_list(
10661070
pending_observations, self.outcomes, self.parameters
10671071
)
1068-
obj_w, out_c, lin_c, pend_o, obj_t, pruning_target_p = (
1072+
obj_w, out_c, lin_c, pend_o, obj_t, pruning_target_p, eq_c = (
10691073
validate_and_apply_final_transform(
10701074
objective_weights=objective_weights,
10711075
outcome_constraints=outcome_constraints,
10721076
linear_constraints=linear_constraints,
10731077
pending_observations=pending_array,
10741078
objective_thresholds=objective_thresholds,
10751079
pruning_target_point=pruning_target_point,
1080+
equality_constraints=equality_constraints_np,
10761081
final_transform=self._array_to_tensor,
10771082
)
10781083
)
@@ -1089,6 +1094,7 @@ def _get_transformed_model_gen_args(
10891094
outcome_constraints=out_c,
10901095
objective_thresholds=obj_t,
10911096
linear_constraints=lin_c,
1097+
equality_constraints=eq_c,
10921098
fixed_features=fixed_features_dict,
10931099
pending_observations=pend_o,
10941100
model_gen_options=model_gen_options or {},

ax/generators/torch_base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ class TorchOptConfig:
5050
linear_constraints: A tuple of (A, b). For k linear constraints on
5151
d-dimensional x, A is (k x d) and b is (k x 1) such that
5252
A x <= b for feasible x.
53+
equality_constraints: A tuple of (A, b). For k equality constraints on
54+
d-dimensional x, A is (k x d) and b is (k x 1) such that
55+
A x = b for feasible x.
5356
fixed_features: A map {feature_index: value} for features that
5457
should be fixed to a particular value during generation.
5558
pending_observations: A list of m (k_i x d) feature tensors X
@@ -91,6 +94,7 @@ class TorchOptConfig:
9194
outcome_constraints: tuple[Tensor, Tensor] | None = None
9295
objective_thresholds: Tensor | None = None
9396
linear_constraints: tuple[Tensor, Tensor] | None = None
97+
equality_constraints: tuple[Tensor, Tensor] | None = None
9498
fixed_features: dict[int, float] | None = None
9599
pending_observations: list[Tensor] | None = None
96100
model_gen_options: TConfig = field(default_factory=dict)

0 commit comments

Comments
 (0)