Skip to content

Commit bfcbff1

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Refactor Objective and OutcomeConstraint to expression-based
Summary: Refactors the Objective class to hold an expression string instead of Metric objects. This is Step 1 of decoupling Objective from Metric, enabling objectives to be defined purely by string expressions like "accuracy", "-loss", "2*acc + recall", or "acc, -loss" (multi-objective). Key changes: - Objective.__init__ now takes an `expression: str` as its primary argument, eagerly parsed via SymPy into cached metric names, weights, and objective type flags. - The deprecated `metric`/`minimize` kwargs are preserved for backward compatibility and emit DeprecationWarning. - New properties: `expression`, `metric_weights`, `metric_weights_by_objective`, `is_single_objective`, `is_scalarized_objective`, `is_multi_objective`. - Old properties that returned Metric objects (`metric`, `metrics`, `metric_signatures`, `get_unconstrainable_metrics`) now emit DeprecationWarning and raise NotImplementedError. - New `get_unconstrainable_metric_names()` replaces `get_unconstrainable_metrics()`. - MultiObjective and ScalarizedObjective are preserved as deprecated subclasses that build expression strings internally, keeping isinstance checks working during transition. - Added three expression-parsing helpers to ax/utils/common/sympy.py: `parse_objective_expression`, `extract_metric_names_from_objective_expr`, `extract_metric_weights_from_objective_expr`. - All private cached attributes removed from Objective.__init__ except _expression_str. All derived values become computed properties. - Refactored OutcomeConstraint, ObjectiveThreshold, and ScalarizedOutcomeConstraint to hold expression strings instead of Metric objects. - OutcomeConstraint now takes an expression string (e.g. "qps >= 700", "loss <= 0.5%") as its primary constructor argument. Differential Revision: D93520819
1 parent 05f34d5 commit bfcbff1

129 files changed

Lines changed: 4646 additions & 2800 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

ax/adapter/adapter_utils.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from ax.core.arm import Arm
2525
from ax.core.experiment import Experiment
26-
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
26+
from ax.core.objective import Objective
2727
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2828
from ax.core.optimization_config import (
2929
MultiObjectiveOptimizationConfig,
@@ -206,6 +206,7 @@ def extract_objective_thresholds(
206206
objective_thresholds: TRefPoint,
207207
objective: Objective,
208208
outcomes: list[str],
209+
experiment: Experiment,
209210
) -> npt.NDArray | None:
210211
"""Extracts objective thresholds' values, in the order of `outcomes`.
211212
@@ -221,6 +222,7 @@ def extract_objective_thresholds(
221222
objective_thresholds: Objective thresholds to extract values from.
222223
objective: The corresponding Objective, for validation purposes.
223224
outcomes: n-length list of names of metrics.
225+
experiment: The experiment, used to map metric names to signatures.
224226
225227
Returns:
226228
(n,) array of thresholds
@@ -230,15 +232,19 @@ def extract_objective_thresholds(
230232

231233
objective_threshold_dict = {}
232234
for ot in objective_thresholds:
235+
ot_signature = experiment.get_metric(ot.metric_names[0]).signature
233236
if ot.relative:
234237
raise ValueError(
235-
f"Objective {ot.metric.signature} has a relative threshold that "
238+
f"Objective {ot_signature} has a relative threshold that "
236239
f"is not supported here."
237240
)
238-
objective_threshold_dict[ot.metric.signature] = ot.bound
241+
objective_threshold_dict[ot_signature] = ot.bound
239242

240243
# Check that all thresholds correspond to a metric.
241-
if set(objective_threshold_dict.keys()).difference(set(objective.metric_names)):
244+
obj_metric_signatures = [
245+
experiment.get_metric(name).signature for name in objective.metric_names
246+
]
247+
if set(objective_threshold_dict.keys()).difference(set(obj_metric_signatures)):
242248
raise ValueError(
243249
"Some objective thresholds do not have corresponding metrics. "
244250
f"Got {objective_thresholds=} and {objective=}."
@@ -252,7 +258,9 @@ def extract_objective_thresholds(
252258
return obj_t
253259

254260

255-
def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.NDArray:
261+
def extract_objective_weights(
262+
objective: Objective, outcomes: list[str], experiment: Experiment
263+
) -> npt.NDArray:
256264
"""Extract a weights for objectives.
257265
258266
Weights are for a maximization problem.
@@ -268,29 +276,24 @@ def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.
268276
269277
Args:
270278
objective: Objective to extract weights from.
271-
outcomes: n-length list of names of metrics.
279+
outcomes: n-length list of metric signatures.
280+
experiment: The experiment, used to map metric names to signatures.
272281
273282
Returns:
274283
n-length array of weights.
275284
276285
"""
277286
objective_weights = np.zeros(len(outcomes))
278-
if isinstance(objective, ScalarizedObjective):
279-
s = -1.0 if objective.minimize else 1.0
280-
for obj_metric, obj_weight in objective.metric_weights:
281-
objective_weights[outcomes.index(obj_metric.signature)] = obj_weight * s
282-
elif isinstance(objective, MultiObjective):
283-
for obj in objective.objectives:
284-
s = -1.0 if obj.minimize else 1.0
285-
objective_weights[outcomes.index(obj.metric.signature)] = s
286-
else:
287-
s = -1.0 if objective.minimize else 1.0
288-
objective_weights[outcomes.index(objective.metric.signature)] = s
287+
# metric_weights returns sign-encoded (name, weight) tuples for all
288+
# objective types (single, scalarized, multi).
289+
for obj_metric_name, obj_weight in objective.metric_weights:
290+
sig = experiment.get_metric(obj_metric_name).signature
291+
objective_weights[outcomes.index(sig)] = obj_weight
289292
return objective_weights
290293

291294

292295
def extract_objective_weight_matrix(
293-
objective: Objective, outcomes: list[str]
296+
objective: Objective, outcomes: list[str], experiment: Experiment
294297
) -> npt.NDArray:
295298
"""Extract a 2D weight matrix for objectives.
296299
@@ -304,23 +307,31 @@ def extract_objective_weight_matrix(
304307
305308
Args:
306309
objective: Objective to extract weights from.
307-
outcomes: n-length list of names of metrics.
310+
outcomes: n-length list of signatures of metrics.
308311
309312
Returns:
310313
``(n_objectives, n)`` array of weights.
311314
"""
312-
if isinstance(objective, MultiObjective):
315+
if objective.is_multi_objective:
313316
rows: list[npt.NDArray] = []
314-
for obj in objective.objectives:
315-
rows.append(extract_objective_weights(obj, outcomes))
317+
for name, weight in objective.metric_weights:
318+
rows.append(
319+
extract_objective_weights(
320+
objective=Objective(expression=f"{weight} * {name}"),
321+
outcomes=outcomes,
322+
experiment=experiment,
323+
)
324+
)
316325
return np.stack(rows, axis=0)
317326
else:
318327
# Single row – covers Objective and ScalarizedObjective
319-
return extract_objective_weights(objective, outcomes).reshape(1, -1)
328+
return extract_objective_weights(objective, outcomes, experiment).reshape(1, -1)
320329

321330

322331
def extract_outcome_constraints(
323-
outcome_constraints: list[OutcomeConstraint], outcomes: list[str]
332+
outcome_constraints: list[OutcomeConstraint],
333+
outcomes: list[str],
334+
experiment: Experiment,
324335
) -> TBounds:
325336
if len(outcome_constraints) == 0:
326337
return None
@@ -330,11 +341,11 @@ def extract_outcome_constraints(
330341
for i, c in enumerate(outcome_constraints):
331342
s = 1 if c.op == ComparisonOp.LEQ else -1
332343
if isinstance(c, ScalarizedOutcomeConstraint):
333-
for c_metric, c_weight in c.metric_weights:
334-
j = outcomes.index(c_metric.signature)
344+
for c_metric_name, c_weight in c.metric_weights:
345+
j = outcomes.index(experiment.get_metric(c_metric_name).signature)
335346
A[i, j] = s * c_weight
336347
else:
337-
j = outcomes.index(c.metric.signature)
348+
j = outcomes.index(experiment.get_metric(c.metric_names[0]).signature)
338349
A[i, j] = s
339350
b[i, 0] = s * c.bound
340351
return (A, b)
@@ -645,16 +656,20 @@ def get_pareto_frontier_and_configs(
645656
)
646657
# Extract weights, constraints, and objective_thresholds
647658
objective_weights = extract_objective_weight_matrix(
648-
objective=optimization_config.objective, outcomes=adapter.outcomes
659+
objective=optimization_config.objective,
660+
outcomes=adapter.outcomes,
661+
experiment=adapter._experiment,
649662
)
650663
outcome_constraints = extract_outcome_constraints(
651664
outcome_constraints=optimization_config.outcome_constraints,
652665
outcomes=adapter.outcomes,
666+
experiment=adapter._experiment,
653667
)
654668
obj_t = extract_objective_thresholds(
655669
objective_thresholds=optimization_config.objective_thresholds,
656670
objective=optimization_config.objective,
657671
outcomes=adapter.outcomes,
672+
experiment=adapter._experiment,
658673
)
659674
if obj_t is not None:
660675
obj_t = array_to_tensor(obj_t)
@@ -1113,6 +1128,7 @@ def observation_features_to_array(
11131128
def feasible_hypervolume(
11141129
optimization_config: MultiObjectiveOptimizationConfig,
11151130
values: dict[str, npt.NDArray],
1131+
experiment: Experiment,
11161132
) -> npt.NDArray:
11171133
"""Compute the feasible hypervolume each iteration.
11181134
@@ -1121,34 +1137,35 @@ def feasible_hypervolume(
11211137
values: Dictionary from metric name to array of value at each
11221138
iteration (each array is `n`-dim). If optimization config contains
11231139
outcome constraints, values for them must be present in `values`.
1140+
experiment: The experiment, used to map metric names to signatures.
11241141
11251142
Returns: Array of feasible hypervolumes.
11261143
"""
11271144
# Get objective at each iteration
11281145
obj_threshold_dict = {
1129-
ot.metric.signature: ot.bound for ot in optimization_config.objective_thresholds
1146+
experiment.get_metric(ot.metric_names[0]).signature: ot.bound
1147+
for ot in optimization_config.objective_thresholds
11301148
}
1131-
f_vals = np.hstack(
1132-
[
1133-
values[m.signature].reshape(-1, 1)
1134-
for m in optimization_config.objective.metrics
1135-
]
1136-
)
1137-
obj_thresholds = np.array(
1138-
[obj_threshold_dict[m.signature] for m in optimization_config.objective.metrics]
1139-
)
1149+
obj_metric_names = optimization_config.objective.metric_names
1150+
obj_metrics = [experiment.get_metric(name) for name in obj_metric_names]
1151+
f_vals = np.hstack([values[m.signature].reshape(-1, 1) for m in obj_metrics])
1152+
obj_thresholds = np.array([obj_threshold_dict[m.signature] for m in obj_metrics])
11401153
# Set infeasible points to be the objective threshold
11411154
for oc in optimization_config.outcome_constraints:
11421155
if oc.relative:
11431156
raise ValueError(
11441157
"Benchmark aggregation does not support relative constraints"
11451158
)
1146-
g = values[oc.metric.signature]
1159+
oc_sig = experiment.get_metric(oc.metric_names[0]).signature
1160+
g = values[oc_sig]
11471161
feas = g <= oc.bound if oc.op == ComparisonOp.LEQ else g >= oc.bound
11481162
f_vals[~feas] = obj_thresholds
11491163

1164+
# Derive objective directions from the objective's metric_weights.
1165+
# Positive weight = maximize, negative weight = minimize.
1166+
obj_weight_dict = dict(optimization_config.objective.metric_weights)
11501167
obj_weights = np.array(
1151-
[-1 if m.lower_is_better else 1 for m in optimization_config.objective.metrics]
1168+
[1 if obj_weight_dict[name] > 0 else -1 for name in obj_metric_names]
11521169
)
11531170
obj_thresholds = obj_thresholds * obj_weights
11541171
f_vals = f_vals * obj_weights

ax/adapter/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201
"Optimization config is required when "
202202
"`fit_tracking_metrics` is False."
203203
)
204-
self.outcomes = sorted(self._optimization_config.metrics.keys())
204+
self.outcomes = sorted(self._optimization_config.metric_names)
205205

206206
# Set training data (in the raw / untransformed space). This also omits
207207
# out-of-design and abandoned observations depending on the corresponding flags.
@@ -466,10 +466,14 @@ def _set_status_quo(self, experiment: Experiment) -> None:
466466
)
467467
return
468468

469-
if has_map_metrics(optimization_config=self._optimization_config):
469+
if has_map_metrics(
470+
metrics=experiment.get_metrics(
471+
metric_names=[*self._optimization_config.metric_names]
472+
)
473+
):
470474
self._status_quo = _combine_multiple_status_quo_observations(
471475
status_quo_observations=status_quo_observations,
472-
metrics=set(none_throws(self._optimization_config).metrics),
476+
metrics=none_throws(self._optimization_config).metric_names,
473477
)
474478
else:
475479
logger.warning(
@@ -689,7 +693,7 @@ def _get_transformed_gen_args(
689693
# Check that the optimization config has the same metrics as
690694
# the original one. Otherwise, we may attempt to optimize over
691695
# metrics that do not have a fitted model.
692-
outcomes = set(optimization_config.metrics.keys())
696+
outcomes = optimization_config.metric_names
693697
if not outcomes.issubset(self.outcomes):
694698
raise UnsupportedError(
695699
"When fit_tracking_metrics is False, the optimization config "

ax/adapter/cross_validation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ax.adapter.data_utils import ExperimentData
2222
from ax.adapter.observation_utils import unwrap_observation_data
2323
from ax.adapter.torch import TorchAdapter
24+
from ax.core.experiment import Experiment
2425
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2526
from ax.core.optimization_config import OptimizationConfig
2627
from ax.exceptions.core import UnsupportedError
@@ -573,6 +574,7 @@ def assess_model_fit(
573574
def has_good_opt_config_model_fit(
574575
optimization_config: OptimizationConfig,
575576
assess_model_fit_result: AssessModelFitResult,
577+
experiment: Experiment,
576578
) -> bool:
577579
"""Assess model fit for given diagnostics results across the optimization
578580
config metrics
@@ -584,7 +586,8 @@ def has_good_opt_config_model_fit(
584586
585587
Args:
586588
optimization_config: Objective/Outcome constraint metrics to assess
587-
diagnostics: Output of compute_diagnostics
589+
assess_model_fit_result: Output of assess_model_fit
590+
experiment: The experiment, used to map metric names to signatures.
588591
589592
Returns:
590593
Two dictionaries, one for good metrics, one for bad metrics, each
@@ -594,8 +597,9 @@ def has_good_opt_config_model_fit(
594597
# Bad fit criteria: Any objective metrics are poorly fit
595598
# TODO[]: Incl. outcome constraints in assessment
596599
has_good_opt_config_fit = all(
597-
(m.signature in assess_model_fit_result.good_fit_metrics_to_fisher_score)
598-
for m in optimization_config.objective.metrics
600+
experiment.get_metric(name).signature
601+
in assess_model_fit_result.good_fit_metrics_to_fisher_score
602+
for name in optimization_config.objective.metric_names
599603
)
600604
return has_good_opt_config_fit
601605

ax/adapter/discrete.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,18 @@ def _gen(
163163
objective_weights = None
164164
outcome_constraints = None
165165
else:
166-
validate_transformed_optimization_config(optimization_config, self.outcomes)
166+
validate_transformed_optimization_config(
167+
optimization_config, self.outcomes, experiment=self._experiment
168+
)
167169
objective_weights = extract_objective_weights(
168-
objective=optimization_config.objective, outcomes=self.outcomes
170+
objective=optimization_config.objective,
171+
outcomes=self.outcomes,
172+
experiment=self._experiment,
169173
)
170174
outcome_constraints = extract_outcome_constraints(
171175
outcome_constraints=optimization_config.outcome_constraints,
172176
outcomes=self.outcomes,
177+
experiment=self._experiment,
173178
)
174179

175180
# Get fixed features

ax/adapter/tests/test_adapter_utils.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from ax.utils.common.hash_utils import compute_lilo_input_hash
4444
from ax.utils.common.testutils import TestCase
4545
from ax.utils.testing.core_stubs import (
46+
get_branin_experiment,
4647
get_experiment_with_observations,
4748
get_hierarchical_search_space,
4849
get_search_space_for_range_values,
@@ -61,7 +62,7 @@ def test_feasible_hypervolume(self) -> None:
6162
),
6263
outcome_constraints=[
6364
OutcomeConstraint(
64-
mc,
65+
metric=mc,
6566
op=ComparisonOp.GEQ,
6667
bound=0,
6768
relative=False,
@@ -78,6 +79,11 @@ def test_feasible_hypervolume(self) -> None:
7879
),
7980
],
8081
)
82+
experiment = Experiment(
83+
search_space=SearchSpace(parameters=[]),
84+
optimization_config=optimization_config,
85+
tracking_metrics=[mc],
86+
)
8187
feas_hv = feasible_hypervolume(
8288
optimization_config,
8389
values={
@@ -106,6 +112,7 @@ def test_feasible_hypervolume(self) -> None:
106112
]
107113
),
108114
},
115+
experiment=experiment,
109116
)
110117
self.assertEqual(list(feas_hv), [0.0, 0.0, 1.0, 1.0])
111118

@@ -537,20 +544,24 @@ def test_can_map_to_binary(self) -> None:
537544
def test_extract_objective_weight_matrix(self) -> None:
538545
m1, m2, m3 = Metric(name="m1"), Metric(name="m2"), Metric(name="m3")
539546
outcomes = ["m1", "m2", "m3"]
547+
experiment = get_branin_experiment()
548+
experiment.add_metric(m1)
549+
experiment.add_metric(m2)
550+
experiment.add_metric(m3)
540551

541552
# Single Objective: one row, nonzero only in matching column.
542553
obj = Objective(metric=m1, minimize=False)
543-
result = extract_objective_weight_matrix(obj, outcomes)
554+
result = extract_objective_weight_matrix(obj, outcomes, experiment)
544555
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0]])
545556

546557
# Minimization flips the sign.
547558
obj_min = Objective(metric=m2, minimize=True)
548-
result = extract_objective_weight_matrix(obj_min, outcomes)
559+
result = extract_objective_weight_matrix(obj_min, outcomes, experiment)
549560
np.testing.assert_array_equal(result, [[0.0, -1.0, 0.0]])
550561

551562
# ScalarizedObjective: single row with multiple nonzero entries.
552563
scal = ScalarizedObjective(metrics=[m1, m3], weights=[0.3, 0.7], minimize=False)
553-
result = extract_objective_weight_matrix(scal, outcomes)
564+
result = extract_objective_weight_matrix(scal, outcomes, experiment)
554565
np.testing.assert_array_almost_equal(result, [[0.3, 0.0, 0.7]])
555566

556567
# MultiObjective: one row per sub-objective.
@@ -560,7 +571,7 @@ def test_extract_objective_weight_matrix(self) -> None:
560571
Objective(metric=m3, minimize=True),
561572
]
562573
)
563-
result = extract_objective_weight_matrix(multi, outcomes)
574+
result = extract_objective_weight_matrix(multi, outcomes, experiment)
564575
np.testing.assert_array_equal(result, [[1.0, 0.0, 0.0], [0.0, 0.0, -1.0]])
565576

566577
def test_get_fresh_pairwise_trial_indices(self) -> None:

0 commit comments

Comments
 (0)