Skip to content

Commit c62f3fd

Browse files
mgarrard and facebook-github-bot
authored and committed
Implement Chebyshev center fallback for center node (#4477)
Summary: This diff adds a Chebyshev-center fallback for the center node, used when the naive center of the search space violates the parameter constraints. Reviewed By: sdaulton Differential Revision: D85712042
1 parent 5dccf83 commit c62f3fd

3 files changed

Lines changed: 324 additions & 139 deletions

File tree

ax/core/search_space.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88

99
from __future__ import annotations
1010

11+
import math
12+
1113
import warnings
1214
from collections.abc import Mapping, Sequence
1315
from dataclasses import dataclass, field
1416
from logging import Logger
1517

18+
import numpy as np
19+
1620
import pandas as pd
1721
from ax import core
1822
from ax.core.arm import Arm
@@ -36,6 +40,9 @@
3640
from ax.utils.common.constants import Keys
3741
from ax.utils.common.logger import get_logger
3842
from pyre_extensions import none_throws
43+
from scipy.optimize import linprog
44+
45+
from scipy.special import expit, logit
3946

4047

4148
logger: Logger = get_logger(__name__)
@@ -572,6 +579,134 @@ def clone(self) -> SearchSpace:
572579
parameter_constraints=[pc.clone() for pc in self._parameter_constraints],
573580
)
574581

582+
def compute_naive_center(self) -> TParameterization:
    """Return the parameterization at the naive center of this search space.

    How each parameter type is centered:

    * Range parameters: the midpoint of ``[lower, upper]``. Log-scale
      parameters use the midpoint in log10 space and logit-scale parameters
      the midpoint in logit space; either is mapped back to the original
      scale. The value is then passed through the parameter's ``cast``.
    * Choice parameters: the element at index ``len(values) // 2``, so an
      even-length list resolves to the upper-middle element (``[0, 1]``
      yields ``1``).
    * Fixed parameters: their single allowed value.
    * Derived parameters: computed last, from the values chosen above.

    For a hierarchical search space the result is additionally passed
    through ``_cast_parameterization``.

    Returns:
        A mapping from parameter name to its center value.

    Raises:
        NotImplementedError: If a parameter has an unsupported type.
    """
    center_point = {}
    # Derived parameters depend on the other values, so they are resolved
    # only after every other parameter has been assigned.
    deferred = []
    for param_name, param in self.parameters.items():
        if isinstance(param, RangeParameter):
            lo, hi = param.lower, param.upper
            if param.logit_scale:
                # scipy's logit/expit are numerically stable.
                mid_logit = (logit(lo) + logit(hi)) / 2.0
                midpoint = expit(mid_logit)
            elif param.log_scale:
                mid_exponent = (math.log10(lo) + math.log10(hi)) / 2.0
                midpoint = 10**mid_exponent
            else:
                midpoint = (float(lo) + float(hi)) / 2.0
            center_point[param_name] = param.cast(midpoint)
            continue
        if isinstance(param, ChoiceParameter):
            center_point[param_name] = param.values[len(param.values) // 2]
            continue
        if isinstance(param, FixedParameter):
            center_point[param_name] = param.value
            continue
        if isinstance(param, DerivedParameter):
            deferred.append(param)
            continue
        raise NotImplementedError(f"Parameter type {type(param)} is not supported.")

    for derived in deferred:
        center_point[derived.name] = derived.compute(parameters=center_point)

    if not self.is_hierarchical:
        return center_point
    return self._cast_parameterization(parameters=center_point)
623+
624+
def compute_chebyshev_center(self) -> dict[str, float] | None:
    """Compute the Chebyshev center of the constraint polytope.

    The Chebyshev center is the center of the largest ball that fits inside
    the feasible region, so its position is driven by the tightest
    constraints. For a polytope ``A x <= b`` it is found by solving the
    linear program

        maximize    r
        subject to  a_i^T x + r * ||a_i||_2 <= b_i   for every row i
                    r >= 0

    The rows of ``A`` are the parameter constraints followed by the lower
    and upper bound of each parameter.

    Only natural (non-log, non-logit) range parameters take part. Constraint
    weights on any other parameter are ignored when the rows are built.

    Returns:
        A mapping from parameter name to its value at the Chebyshev center,
        an empty dict if there are no natural range parameters, or ``None``
        if the linear program could not be solved (e.g. it is infeasible).
    """
    natural = {
        name: param
        for name, param in self.range_parameters.items()
        if not (param.log_scale or param.logit_scale)
    }
    if not natural:
        return {}

    dim = len(natural)
    column_of = {name: col for col, name in enumerate(natural)}
    constraints = list(self.parameter_constraints)
    n_constraints = len(constraints)

    # Rows are laid out as: one per parameter constraint, then a
    # (lower, upper) pair of rows per parameter.
    n_rows = n_constraints + 2 * dim
    lhs = np.zeros((n_rows, dim))
    rhs = np.zeros(n_rows)

    for row, constraint in enumerate(constraints):
        for constrained_name, weight in constraint.constraint_dict.items():
            col = column_of.get(constrained_name)
            if col is not None:
                lhs[row, col] = weight
        rhs[row] = constraint.bound

    for col, param in enumerate(natural.values()):
        lower_row = n_constraints + 2 * col
        upper_row = lower_row + 1
        # lower bound written as -x_i <= -lower_i
        lhs[lower_row, col] = -1.0
        rhs[lower_row] = -float(param.lower)
        # upper bound written as x_i <= upper_i
        lhs[upper_row, col] = 1.0
        rhs[upper_row] = float(param.upper)

    # The extra column holds ||a_i||_2, the coefficient of the radius r.
    norms = np.linalg.norm(lhs, axis=1)
    lp_matrix = np.hstack([lhs, norms[:, None]])

    # linprog minimizes, so maximize r by minimizing -r.
    cost = np.zeros(dim + 1)
    cost[dim] = -1.0

    # The coordinates are free here (their bounds are already rows of the
    # matrix); only the radius is restricted to be non-negative.
    variable_bounds = [(None, None) for _ in range(dim)]
    variable_bounds.append((0, None))

    result = linprog(
        c=cost,
        A_ub=lp_matrix,
        b_ub=rhs,
        bounds=variable_bounds,
    )

    solved = result.success and result.x is not None
    if not solved:
        return None

    # Drop the trailing radius entry and keep only the coordinates.
    return {name: float(result.x[col]) for name, col in column_of.items()}
709+
575710
def _validate_parameter_constraints(
576711
self, parameter_constraints: list[ParameterConstraint]
577712
) -> None:

ax/generation_strategy/center_generation_node.py

Lines changed: 56 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
import math
109
from dataclasses import dataclass
1110
from typing import Any
1211

@@ -16,13 +15,8 @@
1615
from ax.core.experiment import Experiment
1716
from ax.core.generator_run import GeneratorRun
1817
from ax.core.observation import ObservationFeatures
19-
from ax.core.parameter import (
20-
ChoiceParameter,
21-
DerivedParameter,
22-
FixedParameter,
23-
RangeParameter,
24-
)
25-
from ax.core.search_space import SearchSpace
18+
from ax.core.parameter import DerivedParameter
19+
from ax.core.search_space import HierarchicalSearchSpace, SearchSpace
2620
from ax.core.types import TParameterization
2721
from ax.exceptions.generation_strategy import AxGenerationException
2822
from ax.generation_strategy.external_generation_node import ExternalGenerationNode
@@ -31,7 +25,6 @@
3125
AutoTransitionAfterGenOrExhaustion,
3226
)
3327
from pyre_extensions import none_throws
34-
from scipy.special import expit, logit
3528

3629

3730
@dataclass(init=False)
@@ -88,17 +81,16 @@ def gen(
8881
"""
8982
# Check if center already exists or is infeasible
9083
self.search_space = experiment.search_space
91-
center_params = self._compute_center_params()
92-
search_space = none_throws(self.search_space)
84+
center_params = self.compute_center_params()
9385

94-
# Check if center already exists in experiment
95-
center_arm = Arm(parameters=center_params)
96-
if center_arm.signature in experiment.arms_by_signature:
86+
# Check if unable to find a suitable center
87+
if center_params is None:
9788
self._should_skip = True
9889
return None
9990

100-
# Check if center violates parameter constraints
101-
if not search_space.check_membership(parameterization=center_params):
91+
# Check if center already exists in experiment
92+
center_arm = Arm(parameters=center_params)
93+
if center_arm.signature in experiment.arms_by_signature:
10294
self._should_skip = True
10395
return None
10496

@@ -112,33 +104,43 @@ def gen(
112104
**gs_gen_kwargs,
113105
)
114106

115-
def _compute_center_params(self) -> TParameterization:
116-
"""Compute the center of the search space."""
107+
def compute_center_params(self) -> TParameterization | None:
    """Compute the center of the search space, with a Chebyshev fallback.

    The naive center is used when it is a member of the search space.
    Otherwise the values in the Chebyshev center of the constraint polytope
    replace the matching entries of the naive center, derived parameters are
    recomputed, and membership is checked again.

    Returns:
        The center parameters, or None if the center cannot be computed
        (e.g., due to infeasible constraints).
    """
    search_space = none_throws(self.search_space)
    center = search_space.compute_naive_center()

    # The membership check also verifies the parameter constraints; a
    # feasible naive center needs no fallback.
    if search_space.check_membership(parameterization=center):
        return center

    chebyshev_center = search_space.compute_chebyshev_center()
    if chebyshev_center is None:
        return None

    # Overwrite only parameters already present in the naive center.
    for name, value in chebyshev_center.items():
        if name in center:
            center[name] = search_space[name].cast(value)

    # Derived parameters must be recomputed from the updated values.
    for param in search_space.parameters.values():
        if isinstance(param, DerivedParameter):
            center[param.name] = param.compute(parameters=center)

    if isinstance(search_space, HierarchicalSearchSpace):
        center = search_space._cast_parameterization(parameters=center)

    # Give up if the adjusted point is still outside the search space.
    if not search_space.check_membership(parameterization=center):
        return None
    return center
143145

144146
def get_next_candidate(
@@ -156,18 +158,18 @@ def get_next_candidate(
156158
favor of the larger value / index. For example, a binary parameter with
157159
values [0, 1] will be sampled as 1.
158160
Fixed parameters are returned at their only allowed value.
159-
"""
160-
search_space = none_throws(self.search_space)
161-
parameters = self._compute_center_params()
162161
163-
# Check for search space membership, which will check if the generated
164-
# point satisfies the parameter constraints.
165-
if not search_space.check_membership(parameterization=parameters):
166-
# TODO: Improve this handling by instead choosing the point
167-
# in the center of the feasible set (e.g. by finding the)
168-
# Chebyshev center of the constraint polytope.
162+
Note: If range naive midpoint fails to remain within parameter constraints, we
163+
attempt to compute the Chebyshev center of the constraint polytope defined by
164+
parameter bounds and parameter constraints w.r.t non-log range parameters.
165+
This finds the center of the largest inscribed ball in the feasible region.
166+
"""
167+
center_params = self.compute_center_params()
168+
if center_params is None:
169+
# raising an exception here will cause fallback to sobol, currently
170+
# it should be very unlikely to hit this case
169171
raise AxGenerationException(
170-
"Center of the search space does not satisfy parameter constraints. "
171-
"The generation strategy will fallback to Sobol. "
172+
"Center of the search space does not satisfy parameter "
173+
"constraints. The generation strategy will fallback to Sobol. "
172174
)
173-
return parameters
175+
return center_params

0 commit comments

Comments
 (0)