Skip to content

Commit 8ff8cf2

Browse files
mpolson64meta-codesync[bot]
authored andcommitted
Refactor Objective and OutcomeConstraint to expression-based
Summary: Refactors the Objective class to hold an expression string instead of Metric objects. This is Step 1 of decoupling Objective from Metric, enabling objectives to be defined purely by string expressions like "accuracy", "-loss", "2*acc + recall", or "acc, -loss" (multi-objective). Key changes: - Objective.__init__ now takes an `expression: str` as its primary argument, eagerly parsed via SymPy into cached metric names, weights, and objective type flags. - The deprecated `metric`/`minimize` kwargs are preserved for backward compatibility and emit DeprecationWarning. - New properties: `expression`, `metric_weights`, `metric_weights_by_objective`, `is_single_objective`, `is_scalarized_objective`, `is_multi_objective`. - Old properties that returned Metric objects (`metric`, `metrics`, `metric_signatures`, `get_unconstrainable_metrics`) now emit DeprecationWarning and raise NotImplementedError. - New `get_unconstrainable_metric_names()` replaces `get_unconstrainable_metrics()`. - MultiObjective and ScalarizedObjective are preserved as deprecated subclasses that build expression strings internally, keeping isinstance checks working during transition. - Added three expression-parsing helpers to ax/utils/common/sympy.py: `parse_objective_expression`, `extract_metric_names_from_objective_expr`, `extract_metric_weights_from_objective_expr`. - All private cached attributes removed from Objective.__init__ except _expression_str. All derived values become computed properties. - Refactored OutcomeConstraint, ObjectiveThreshold, and ScalarizedOutcomeConstraint to hold expression strings instead of Metric objects. - OutcomeConstraint now takes an expression string (e.g. "qps >= 700", "loss <= 0.5%") as its primary constructor argument. Differential Revision: D93520819
1 parent 1054802 commit 8ff8cf2

128 files changed

Lines changed: 4650 additions & 2827 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

ax/adapter/adapter_utils.py

Lines changed: 57 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from ax.core.arm import Arm
2525
from ax.core.experiment import Experiment
26-
from ax.core.objective import MultiObjective, Objective, ScalarizedObjective
26+
from ax.core.objective import Objective
2727
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2828
from ax.core.optimization_config import (
2929
MultiObjectiveOptimizationConfig,
@@ -204,6 +204,7 @@ def extract_objective_thresholds(
204204
objective_thresholds: TRefPoint,
205205
objective: Objective,
206206
outcomes: list[str],
207+
experiment: Experiment,
207208
) -> npt.NDArray | None:
208209
"""Extracts objective thresholds' values, in the order of `outcomes`.
209210
@@ -219,6 +220,7 @@ def extract_objective_thresholds(
219220
objective_thresholds: Objective thresholds to extract values from.
220221
objective: The corresponding Objective, for validation purposes.
221222
outcomes: n-length list of names of metrics.
223+
experiment: The experiment, used to map metric names to signatures.
222224
223225
Returns:
224226
(n,) array of thresholds
@@ -228,15 +230,19 @@ def extract_objective_thresholds(
228230

229231
objective_threshold_dict = {}
230232
for ot in objective_thresholds:
233+
ot_signature = experiment.get_metric(ot.metric_names[0]).signature
231234
if ot.relative:
232235
raise ValueError(
233-
f"Objective {ot.metric.signature} has a relative threshold that "
236+
f"Objective {ot_signature} has a relative threshold that "
234237
f"is not supported here."
235238
)
236-
objective_threshold_dict[ot.metric.signature] = ot.bound
239+
objective_threshold_dict[ot_signature] = ot.bound
237240

238241
# Check that all thresholds correspond to a metric.
239-
if set(objective_threshold_dict.keys()).difference(set(objective.metric_names)):
242+
obj_metric_signatures = [
243+
experiment.get_metric(name).signature for name in objective.metric_names
244+
]
245+
if set(objective_threshold_dict.keys()).difference(set(obj_metric_signatures)):
240246
raise ValueError(
241247
"Some objective thresholds do not have corresponding metrics. "
242248
f"Got {objective_thresholds=} and {objective=}."
@@ -250,7 +256,9 @@ def extract_objective_thresholds(
250256
return obj_t
251257

252258

253-
def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.NDArray:
259+
def extract_objective_weights(
260+
objective: Objective, outcomes: list[str], experiment: Experiment
261+
) -> npt.NDArray:
254262
"""Extract a weights for objectives.
255263
256264
Weights are for a maximization problem.
@@ -266,29 +274,24 @@ def extract_objective_weights(objective: Objective, outcomes: list[str]) -> npt.
266274
267275
Args:
268276
objective: Objective to extract weights from.
269-
outcomes: n-length list of names of metrics.
277+
outcomes: n-length list of metric signatures.
278+
experiment: The experiment, used to map metric names to signatures.
270279
271280
Returns:
272281
n-length array of weights.
273282
274283
"""
275284
objective_weights = np.zeros(len(outcomes))
276-
if isinstance(objective, ScalarizedObjective):
277-
s = -1.0 if objective.minimize else 1.0
278-
for obj_metric, obj_weight in objective.metric_weights:
279-
objective_weights[outcomes.index(obj_metric.signature)] = obj_weight * s
280-
elif isinstance(objective, MultiObjective):
281-
for obj in objective.objectives:
282-
s = -1.0 if obj.minimize else 1.0
283-
objective_weights[outcomes.index(obj.metric.signature)] = s
284-
else:
285-
s = -1.0 if objective.minimize else 1.0
286-
objective_weights[outcomes.index(objective.metric.signature)] = s
285+
# metric_weights returns sign-encoded (name, weight) tuples for all
286+
# objective types (single, scalarized, multi).
287+
for obj_metric_name, obj_weight in objective.metric_weights:
288+
sig = experiment.get_metric(obj_metric_name).signature
289+
objective_weights[outcomes.index(sig)] = obj_weight
287290
return objective_weights
288291

289292

290293
def extract_objective_weight_matrix(
291-
objective: Objective, outcomes: list[str]
294+
objective: Objective, outcomes: list[str], experiment: Experiment
292295
) -> npt.NDArray:
293296
"""Extract a 2D weight matrix for objectives.
294297
@@ -302,23 +305,31 @@ def extract_objective_weight_matrix(
302305
303306
Args:
304307
objective: Objective to extract weights from.
305-
outcomes: n-length list of names of metrics.
308+
outcomes: n-length list of signatures of metrics.
306309
307310
Returns:
308311
``(n_objectives, n)`` array of weights.
309312
"""
310-
if isinstance(objective, MultiObjective):
313+
if objective.is_multi_objective:
311314
rows: list[npt.NDArray] = []
312-
for obj in objective.objectives:
313-
rows.append(extract_objective_weights(obj, outcomes))
315+
for name, weight in objective.metric_weights:
316+
rows.append(
317+
extract_objective_weights(
318+
objective=Objective(expression=f"{weight} * {name}"),
319+
outcomes=outcomes,
320+
experiment=experiment,
321+
)
322+
)
314323
return np.stack(rows, axis=0)
315324
else:
316325
# Single row – covers Objective and ScalarizedObjective
317-
return extract_objective_weights(objective, outcomes).reshape(1, -1)
326+
return extract_objective_weights(objective, outcomes, experiment).reshape(1, -1)
318327

319328

320329
def extract_outcome_constraints(
321-
outcome_constraints: list[OutcomeConstraint], outcomes: list[str]
330+
outcome_constraints: list[OutcomeConstraint],
331+
outcomes: list[str],
332+
experiment: Experiment,
322333
) -> TBounds:
323334
if len(outcome_constraints) == 0:
324335
return None
@@ -328,11 +339,11 @@ def extract_outcome_constraints(
328339
for i, c in enumerate(outcome_constraints):
329340
s = 1 if c.op == ComparisonOp.LEQ else -1
330341
if isinstance(c, ScalarizedOutcomeConstraint):
331-
for c_metric, c_weight in c.metric_weights:
332-
j = outcomes.index(c_metric.signature)
342+
for c_metric_name, c_weight in c.metric_weights:
343+
j = outcomes.index(experiment.get_metric(c_metric_name).signature)
333344
A[i, j] = s * c_weight
334345
else:
335-
j = outcomes.index(c.metric.signature)
346+
j = outcomes.index(experiment.get_metric(c.metric_names[0]).signature)
336347
A[i, j] = s
337348
b[i, 0] = s * c.bound
338349
return (A, b)
@@ -643,16 +654,20 @@ def get_pareto_frontier_and_configs(
643654
)
644655
# Extract weights, constraints, and objective_thresholds
645656
objective_weights = extract_objective_weight_matrix(
646-
objective=optimization_config.objective, outcomes=adapter.outcomes
657+
objective=optimization_config.objective,
658+
outcomes=adapter.outcomes,
659+
experiment=adapter._experiment,
647660
)
648661
outcome_constraints = extract_outcome_constraints(
649662
outcome_constraints=optimization_config.outcome_constraints,
650663
outcomes=adapter.outcomes,
664+
experiment=adapter._experiment,
651665
)
652666
obj_t = extract_objective_thresholds(
653667
objective_thresholds=optimization_config.objective_thresholds,
654668
objective=optimization_config.objective,
655669
outcomes=adapter.outcomes,
670+
experiment=adapter._experiment,
656671
)
657672
if obj_t is not None:
658673
obj_t = array_to_tensor(obj_t)
@@ -1111,6 +1126,7 @@ def observation_features_to_array(
11111126
def feasible_hypervolume(
11121127
optimization_config: MultiObjectiveOptimizationConfig,
11131128
values: dict[str, npt.NDArray],
1129+
experiment: Experiment,
11141130
) -> npt.NDArray:
11151131
"""Compute the feasible hypervolume each iteration.
11161132
@@ -1119,34 +1135,35 @@ def feasible_hypervolume(
11191135
values: Dictionary from metric name to array of value at each
11201136
iteration (each array is `n`-dim). If optimization config contains
11211137
outcome constraints, values for them must be present in `values`.
1138+
experiment: The experiment, used to map metric names to signatures.
11221139
11231140
Returns: Array of feasible hypervolumes.
11241141
"""
11251142
# Get objective at each iteration
11261143
obj_threshold_dict = {
1127-
ot.metric.signature: ot.bound for ot in optimization_config.objective_thresholds
1144+
experiment.get_metric(ot.metric_names[0]).signature: ot.bound
1145+
for ot in optimization_config.objective_thresholds
11281146
}
1129-
f_vals = np.hstack(
1130-
[
1131-
values[m.signature].reshape(-1, 1)
1132-
for m in optimization_config.objective.metrics
1133-
]
1134-
)
1135-
obj_thresholds = np.array(
1136-
[obj_threshold_dict[m.signature] for m in optimization_config.objective.metrics]
1137-
)
1147+
obj_metric_names = optimization_config.objective.metric_names
1148+
obj_metrics = [experiment.get_metric(name) for name in obj_metric_names]
1149+
f_vals = np.hstack([values[m.signature].reshape(-1, 1) for m in obj_metrics])
1150+
obj_thresholds = np.array([obj_threshold_dict[m.signature] for m in obj_metrics])
11381151
# Set infeasible points to be the objective threshold
11391152
for oc in optimization_config.outcome_constraints:
11401153
if oc.relative:
11411154
raise ValueError(
11421155
"Benchmark aggregation does not support relative constraints"
11431156
)
1144-
g = values[oc.metric.signature]
1157+
oc_sig = experiment.get_metric(oc.metric_names[0]).signature
1158+
g = values[oc_sig]
11451159
feas = g <= oc.bound if oc.op == ComparisonOp.LEQ else g >= oc.bound
11461160
f_vals[~feas] = obj_thresholds
11471161

1162+
# Derive objective directions from the objective's metric_weights.
1163+
# Positive weight = maximize, negative weight = minimize.
1164+
obj_weight_dict = dict(optimization_config.objective.metric_weights)
11481165
obj_weights = np.array(
1149-
[-1 if m.lower_is_better else 1 for m in optimization_config.objective.metrics]
1166+
[1 if obj_weight_dict[name] > 0 else -1 for name in obj_metric_names]
11501167
)
11511168
obj_thresholds = obj_thresholds * obj_weights
11521169
f_vals = f_vals * obj_weights

ax/adapter/base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def __init__(
201201
"Optimization config is required when "
202202
"`fit_tracking_metrics` is False."
203203
)
204-
self.outcomes = sorted(self._optimization_config.metrics.keys())
204+
self.outcomes = sorted(self._optimization_config.metric_names)
205205

206206
# Set training data (in the raw / untransformed space). This also omits
207207
# out-of-design and abandoned observations depending on the corresponding flags.
@@ -466,10 +466,14 @@ def _set_status_quo(self, experiment: Experiment) -> None:
466466
)
467467
return
468468

469-
if has_map_metrics(optimization_config=self._optimization_config):
469+
if has_map_metrics(
470+
metrics=experiment.get_metrics(
471+
metric_names=[*self._optimization_config.metric_names]
472+
)
473+
):
470474
self._status_quo = _combine_multiple_status_quo_observations(
471475
status_quo_observations=status_quo_observations,
472-
metrics=set(none_throws(self._optimization_config).metrics),
476+
metrics=none_throws(self._optimization_config).metric_names,
473477
)
474478
else:
475479
logger.warning(
@@ -689,7 +693,7 @@ def _get_transformed_gen_args(
689693
# Check that the optimization config has the same metrics as
690694
# the original one. Otherwise, we may attempt to optimize over
691695
# metrics that do not have a fitted model.
692-
outcomes = set(optimization_config.metrics.keys())
696+
outcomes = optimization_config.metric_names
693697
if not outcomes.issubset(self.outcomes):
694698
raise UnsupportedError(
695699
"When fit_tracking_metrics is False, the optimization config "

ax/adapter/cross_validation.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ax.adapter.data_utils import ExperimentData
2222
from ax.adapter.observation_utils import unwrap_observation_data
2323
from ax.adapter.torch import TorchAdapter
24+
from ax.core.experiment import Experiment
2425
from ax.core.observation import Observation, ObservationData, ObservationFeatures
2526
from ax.core.optimization_config import OptimizationConfig
2627
from ax.exceptions.core import UnsupportedError
@@ -573,6 +574,7 @@ def assess_model_fit(
573574
def has_good_opt_config_model_fit(
574575
optimization_config: OptimizationConfig,
575576
assess_model_fit_result: AssessModelFitResult,
577+
experiment: Experiment,
576578
) -> bool:
577579
"""Assess model fit for given diagnostics results across the optimization
578580
config metrics
@@ -584,7 +586,8 @@ def has_good_opt_config_model_fit(
584586
585587
Args:
586588
optimization_config: Objective/Outcome constraint metrics to assess
587-
diagnostics: Output of compute_diagnostics
589+
assess_model_fit_result: Output of assess_model_fit
590+
experiment: The experiment, used to map metric names to signatures.
588591
589592
Returns:
590593
Two dictionaries, one for good metrics, one for bad metrics, each
@@ -594,8 +597,9 @@ def has_good_opt_config_model_fit(
594597
# Bad fit criteria: Any objective metrics are poorly fit
595598
# TODO[]: Incl. outcome constraints in assessment
596599
has_good_opt_config_fit = all(
597-
(m.signature in assess_model_fit_result.good_fit_metrics_to_fisher_score)
598-
for m in optimization_config.objective.metrics
600+
experiment.get_metric(name).signature
601+
in assess_model_fit_result.good_fit_metrics_to_fisher_score
602+
for name in optimization_config.objective.metric_names
599603
)
600604
return has_good_opt_config_fit
601605

ax/adapter/discrete.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,18 @@ def _gen(
163163
objective_weights = None
164164
outcome_constraints = None
165165
else:
166-
validate_transformed_optimization_config(optimization_config, self.outcomes)
166+
validate_transformed_optimization_config(
167+
optimization_config, self.outcomes, experiment=self._experiment
168+
)
167169
objective_weights = extract_objective_weights(
168-
objective=optimization_config.objective, outcomes=self.outcomes
170+
objective=optimization_config.objective,
171+
outcomes=self.outcomes,
172+
experiment=self._experiment,
169173
)
170174
outcome_constraints = extract_outcome_constraints(
171175
outcome_constraints=optimization_config.outcome_constraints,
172176
outcomes=self.outcomes,
177+
experiment=self._experiment,
173178
)
174179

175180
# Get fixed features

0 commit comments

Comments
 (0)